In [1]:
import os
import re
import pathlib
from skimage.color import rgb2lab, lab2rgb
from tqdm import tqdm
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import cv2
from PIL import Image
import random

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torch.nn.functional as F
from torchvision import datasets,transforms
from torch.utils.data import DataLoader, Dataset

In [2]:
# Configurations

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
np.random.seed(25)
torch.manual_seed(25)

cuda


<torch._C.Generator at 0x7a9314ea2430>

In [3]:
img_dir = "/kaggle/input/coco25k/images"
working_dir = "/kaggle/working/"
IMG_DIM = 256
batch_size = 32 #adjust to 16 if required
learning_rate = 2e-4 #optimal rate for training GANs
NUM_EPOCHS = 20
beta1 = 0.5
beta2 = 0.999
lambda_L1 = 100.

In [4]:
class COCODataset(Dataset):
    def __init__(self, img_dir, transforms=None):
        self.img_dir = img_dir
        self.transforms = transforms
        self.image_paths = []
        all_images = [os.path.join(img_dir, file) for file in os.listdir(img_dir) 
                      if file.endswith('.jpg')]
        self.image_paths = random.sample(all_images, min(len(all_images), 10000))

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img = Image.open(self.image_paths[idx]).convert("RGB")
        img = self.transforms(img)
        img = np.array(img)
        img_lab = rgb2lab(img).astype('float32')  #change to float16 for faster training
        img_lab = transforms.ToTensor()(img_lab)
        L = img_lab[[0], ...]/ 50.0 - 1.0
        ab = img_lab[[1, 2], ...]/ 128.0 
        return {'L': L, 'ab': ab}

In [5]:
# Make sure to not add jitter/ noise as it affects the image colors

train_transforms = transforms.Compose([
    transforms.Resize((IMG_DIM, IMG_DIM), Image.BICUBIC),
    transforms.RandomHorizontalFlip(),
    # transforms.RandomVerticalFlip(),
])

In [6]:
train_dataset = COCODataset(img_dir, train_transforms)
train_dl = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)

In [7]:
# Sanity check

data = next(iter(train_dl))
Ls, abs_ = data['L'], data['ab']
print(Ls.shape, abs_.shape)
print(len(train_dl))

torch.Size([32, 1, 256, 256]) torch.Size([32, 2, 256, 256])
313


## Introducing ResNet backed UNet

In [8]:
from fastai.vision.learner import create_body
# from torchvision.models.resnet import resnet34
from fastai.vision.models.unet import DynamicUnet

resnet34 = torch.hub.load('pytorch/vision:v0.10.0', 'resnet34', pretrained=True)

Downloading: "https://github.com/pytorch/vision/zipball/v0.10.0" to /root/.cache/torch/hub/v0.10.0.zip
Downloading: "https://download.pytorch.org/models/resnet34-b627a593.pth" to /root/.cache/torch/hub/checkpoints/resnet34-b627a593.pth
100%|██████████| 83.3M/83.3M [00:00<00:00, 200MB/s]


Building Generator

In [9]:
def build_res_unet(n_input=1, n_output=2, size=IMG_DIM):
    body = create_body(resnet34, pretrained=True, n_in=n_input, cut=-2)
    net_G = DynamicUnet(body, n_output, (size, size))
    return net_G

Building Patch Discriminator

In [10]:
class PatchDisc(nn.Module):
    def __init__(self, input_c, num_filters=64, n_down=3):
        super().__init__()
        # We define the first and last layers outside the loop since they are required to not have activation 
        # and normalization according to the paper
        model = [self.get_layers(input_c, num_filters, norm=False)]
        model += [self.get_layers(num_filters * 2 ** i, num_filters * 2 ** (i + 1), s=1 if i == (n_down-1) else 2) 
                          for i in range(n_down)] 
        model += [self.get_layers(num_filters * 2 ** n_down, 1, s=1, norm=False, act=False)]
        self.model = nn.Sequential(*model)                                                   
        
    def get_layers(self, ni, nf, k=4, s=2, p=1, norm=True, act=True):
        layers = [nn.Conv2d(ni, nf, k, s, p, bias=not norm)]         
        if norm: layers += [nn.BatchNorm2d(nf)]
        if act: layers += [nn.LeakyReLU(0.2, True)]
        return nn.Sequential(*layers)
    
    def forward(self, x):
        return self.model(x)

In [11]:
# Defining loss functions
'''PatchGAN uses a combination of adversarial and L1 loss.'''

adversarial_loss = nn.BCEWithLogitsLoss()  
l1_loss = nn.L1Loss()  

In [12]:
# Initialize models
net_G = build_res_unet(n_input=1, n_output=2, size=IMG_DIM).to(device)
net_D = PatchDisc(input_c=3).to(device)

# Define optimizers
optimizer_G = optim.Adam(net_G.parameters(), lr=learning_rate, betas=(beta1, beta2))
optimizer_D = optim.Adam(net_D.parameters(), lr=learning_rate, betas=(beta1, beta2))

In [13]:
# Learning Rate Scheduler
scheduler_G = torch.optim.lr_scheduler.StepLR(optimizer_G, step_size=15, gamma=0.5)
scheduler_D = torch.optim.lr_scheduler.StepLR(optimizer_D, step_size=15, gamma=0.5)

def pretrain_G(net_G, train_dl, opt, criterion, epochs=20):
    net_G.train()  # Set generator to training mode
    for e in range(epochs):
        total_loss = 0.0  # Accumulate loss for each epoch
        for data in tqdm(train_dl, desc=f"Pretraining Epoch {e + 1}/{epochs}"):
            L, ab = data['L'].to(device), data['ab'].to(device)
            preds = net_G(L)  # Predicted color channels
            loss = criterion(preds, ab)  # L1 loss between predicted and ground truth
            opt.zero_grad()
            loss.backward()
            opt.step()
            total_loss += loss.item()
        
        avg_loss = total_loss / len(train_dl)
        print(f"Pretraining Epoch {e + 1}/{epochs}, Average L1 Loss: {avg_loss:.5f}")
        if (e + 1) % 5 == 0:  # Save every 5 epochs
            torch.save(net_G.state_dict(), f"generator_pretrain_epoch_{e+1}.pth")
        
        # Step scheduler
        scheduler_G.step()

        

pretrain_epochs = 20
pretrain_criterion = nn.L1Loss()
pretrain_optimizer = optim.Adam(net_G.parameters(), lr=learning_rate, betas=(beta1, beta2))

In [14]:
from torchvision.utils import save_image

# Pretrain the generator
pretrain_G(net_G, train_dl, pretrain_optimizer, pretrain_criterion, epochs=pretrain_epochs)

for epoch in range(NUM_EPOCHS):
    lambda_L1 = max(100.0 / (epoch + 1), 1.0)  # Dynamically adjust lambda_L1

    for i, data in enumerate(train_dl):
        # Load data
        real_L = data['L'].to(device)  # Grayscale input
        real_ab = data['ab'].to(device)  # Ground truth color channels
        
        # ==========================================
        # Train Discriminator
        # ==========================================
        optimizer_D.zero_grad()
        
        # Real images (input + ground truth)
        real_input = torch.cat([real_L, real_ab], dim=1)  # Concatenate grayscale and color
        real_validity = net_D(real_input)  # Discriminator output for real images
        real_loss = adversarial_loss(real_validity, torch.ones_like(real_validity, device=device))  # Target: 1 (real)

        # Fake images (input + generated output)
        fake_ab = net_G(real_L)  # Generator's output
        fake_input = torch.cat([real_L, fake_ab], dim=1)  # Concatenate grayscale and fake color
        fake_validity = net_D(fake_input.detach())  # Discriminator output for fake images
        fake_loss = adversarial_loss(fake_validity, torch.zeros_like(fake_validity, device=device))  # Target: 0 (fake)

        # Total discriminator loss
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        # ==========================================
        # Train Generator
        # ==========================================
        optimizer_G.zero_grad()

        # Adversarial loss for generator
        fake_validity = net_D(fake_input)  # Discriminator's response to fake images
        g_adv_loss = adversarial_loss(fake_validity, torch.ones_like(fake_validity, device=device))  # Target: 1 (fool D)

        # L1 loss for generator
        g_l1_loss = lambda_L1 * l1_loss(fake_ab, real_ab)  # Pixel-wise similarity

        # Total generator loss
        g_loss = g_adv_loss + g_l1_loss
        g_loss.backward()

        # Gradient clipping
        nn.utils.clip_grad_norm_(net_G.parameters(), max_norm=1.0)
        
        optimizer_G.step()
        
        # ==========================================
        # Logging and visualization (optional)
        # ==========================================
        if i % 100 == 0:
            print(f"Epoch [{epoch}/{NUM_EPOCHS}], Step [{i}/{len(train_dl)}], "
                  f"D Loss: {d_loss.item():.4f}, G Loss: {g_loss.item():.4f}")
    
    # ==========================================
    # Save sample outputs every 5 epochs
    # ==========================================
    if (epoch + 1) % 5 == 0:
        with torch.no_grad():
            # Take a batch of 8 grayscale images
            sample_L = real_L[:8]
            sample_fake_ab = net_G(sample_L)  # Generate fake color channels
            
            # Reverse normalization for saved images
            real_L = (sample_L + 1.0) * 50.0  # Reverse normalization for L channel
            sample_fake_ab = sample_fake_ab * 128.0  # Reverse normalization for ab channels

            # Combine L and ab channels to form the LAB image
            sample_fake_lab = torch.cat([real_L, sample_fake_ab], dim=1)  # Concatenate L and ab

            # Convert from LAB to RGB
            sample_fake_rgb = []
            for i in range(len(sample_fake_lab)):
                lab_image = sample_fake_lab[i].cpu().numpy().transpose(1, 2, 0)  # Convert to HWC format
                rgb_image = lab2rgb(lab_image)  # Convert LAB to RGB
                sample_fake_rgb.append(torch.from_numpy(rgb_image).permute(2, 0, 1))  # Convert back to tensor

            # Save as grid
            save_image(torch.stack(sample_fake_rgb), f"sample_epoch_{epoch+1}.png", nrow=4)
            print(f"Saved sample outputs for epoch {epoch+1}")

         # Save model weights
        torch.save(net_G.state_dict(), f"/kaggle/working/net_G_epoch_{epoch+1}.pth")
        torch.save(net_D.state_dict(), f"/kaggle/working/net_D_epoch_{epoch+1}.pth")
        print(f"Saved model weights for epoch {epoch+1}")
    
    # Step scheduler for both generators and discriminators
    scheduler_G.step()
    scheduler_D.step()


Pretraining Epoch 1/20: 100%|██████████| 313/313 [07:00<00:00,  1.34s/it]


Pretraining Epoch 1/20, Average L1 Loss: 0.08041


Pretraining Epoch 2/20: 100%|██████████| 313/313 [07:12<00:00,  1.38s/it]


Pretraining Epoch 2/20, Average L1 Loss: 0.06756


Pretraining Epoch 3/20: 100%|██████████| 313/313 [07:10<00:00,  1.38s/it]


Pretraining Epoch 3/20, Average L1 Loss: 0.06629


Pretraining Epoch 4/20: 100%|██████████| 313/313 [07:10<00:00,  1.38s/it]


Pretraining Epoch 4/20, Average L1 Loss: 0.06520


Pretraining Epoch 5/20: 100%|██████████| 313/313 [07:09<00:00,  1.37s/it]


Pretraining Epoch 5/20, Average L1 Loss: 0.06436


Pretraining Epoch 6/20: 100%|██████████| 313/313 [07:09<00:00,  1.37s/it]


Pretraining Epoch 6/20, Average L1 Loss: 0.06339


Pretraining Epoch 7/20: 100%|██████████| 313/313 [07:08<00:00,  1.37s/it]


Pretraining Epoch 7/20, Average L1 Loss: 0.06252


Pretraining Epoch 8/20: 100%|██████████| 313/313 [07:07<00:00,  1.37s/it]


Pretraining Epoch 8/20, Average L1 Loss: 0.06156


Pretraining Epoch 9/20: 100%|██████████| 313/313 [07:08<00:00,  1.37s/it]


Pretraining Epoch 9/20, Average L1 Loss: 0.06051


Pretraining Epoch 10/20: 100%|██████████| 313/313 [07:08<00:00,  1.37s/it]


Pretraining Epoch 10/20, Average L1 Loss: 0.05967


Pretraining Epoch 11/20: 100%|██████████| 313/313 [07:08<00:00,  1.37s/it]


Pretraining Epoch 11/20, Average L1 Loss: 0.05855


Pretraining Epoch 12/20: 100%|██████████| 313/313 [07:07<00:00,  1.37s/it]


Pretraining Epoch 12/20, Average L1 Loss: 0.05773


Pretraining Epoch 13/20: 100%|██████████| 313/313 [07:06<00:00,  1.36s/it]


Pretraining Epoch 13/20, Average L1 Loss: 0.05676


Pretraining Epoch 14/20: 100%|██████████| 313/313 [07:05<00:00,  1.36s/it]


Pretraining Epoch 14/20, Average L1 Loss: 0.05602


Pretraining Epoch 15/20: 100%|██████████| 313/313 [07:04<00:00,  1.36s/it]


Pretraining Epoch 15/20, Average L1 Loss: 0.05534


Pretraining Epoch 16/20: 100%|██████████| 313/313 [07:05<00:00,  1.36s/it]


Pretraining Epoch 16/20, Average L1 Loss: 0.05446


Pretraining Epoch 17/20: 100%|██████████| 313/313 [07:04<00:00,  1.36s/it]


Pretraining Epoch 17/20, Average L1 Loss: 0.05341


Pretraining Epoch 18/20: 100%|██████████| 313/313 [07:03<00:00,  1.35s/it]


Pretraining Epoch 18/20, Average L1 Loss: 0.05264


Pretraining Epoch 19/20: 100%|██████████| 313/313 [07:02<00:00,  1.35s/it]


Pretraining Epoch 19/20, Average L1 Loss: 0.05169


Pretraining Epoch 20/20: 100%|██████████| 313/313 [07:01<00:00,  1.35s/it]


Pretraining Epoch 20/20, Average L1 Loss: 0.05112
Epoch [0/20], Step [0/313], D Loss: 0.7157, G Loss: 6.3963
Epoch [0/20], Step [100/313], D Loss: 0.6496, G Loss: 5.7424
Epoch [0/20], Step [200/313], D Loss: 0.5984, G Loss: 5.6708
Epoch [0/20], Step [300/313], D Loss: 0.6543, G Loss: 5.6039
Epoch [1/20], Step [0/313], D Loss: 0.6748, G Loss: 3.5439
Epoch [1/20], Step [100/313], D Loss: 0.6992, G Loss: 3.3087
Epoch [1/20], Step [200/313], D Loss: 0.7048, G Loss: 3.4984
Epoch [1/20], Step [300/313], D Loss: 0.6198, G Loss: 3.6826
Epoch [2/20], Step [0/313], D Loss: 0.5790, G Loss: 2.3902
Epoch [2/20], Step [100/313], D Loss: 0.7070, G Loss: 2.5440
Epoch [2/20], Step [200/313], D Loss: 0.6166, G Loss: 2.5743
Epoch [2/20], Step [300/313], D Loss: 0.6643, G Loss: 2.7944
Epoch [3/20], Step [0/313], D Loss: 0.7362, G Loss: 1.9359
Epoch [3/20], Step [100/313], D Loss: 0.6846, G Loss: 1.9452
Epoch [3/20], Step [200/313], D Loss: 0.7040, G Loss: 2.1007
Epoch [3/20], Step [300/313], D Loss: 0.646

  rgb_image = lab2rgb(lab_image)  # Convert LAB to RGB
  rgb_image = lab2rgb(lab_image)  # Convert LAB to RGB
  rgb_image = lab2rgb(lab_image)  # Convert LAB to RGB
  rgb_image = lab2rgb(lab_image)  # Convert LAB to RGB


Saved sample outputs for epoch 5
Saved model weights for epoch 5
Epoch [5/20], Step [0/313], D Loss: 0.7298, G Loss: 1.5612
Epoch [5/20], Step [100/313], D Loss: 0.6795, G Loss: 1.5933
Epoch [5/20], Step [200/313], D Loss: 0.7094, G Loss: 1.6117
Epoch [5/20], Step [300/313], D Loss: 0.7245, G Loss: 1.6176
Epoch [6/20], Step [0/313], D Loss: 0.7222, G Loss: 1.4334
Epoch [6/20], Step [100/313], D Loss: 0.7176, G Loss: 1.3290
Epoch [6/20], Step [200/313], D Loss: 0.7080, G Loss: 1.4440
Epoch [6/20], Step [300/313], D Loss: 0.7188, G Loss: 1.5024
Epoch [7/20], Step [0/313], D Loss: 0.6844, G Loss: 1.3412
Epoch [7/20], Step [100/313], D Loss: 0.7155, G Loss: 1.4053
Epoch [7/20], Step [200/313], D Loss: 0.7044, G Loss: 1.3899
Epoch [7/20], Step [300/313], D Loss: 0.6697, G Loss: 1.4601
Epoch [8/20], Step [0/313], D Loss: 0.6855, G Loss: 1.2846
Epoch [8/20], Step [100/313], D Loss: 0.6939, G Loss: 1.2085
Epoch [8/20], Step [200/313], D Loss: 0.6804, G Loss: 1.2903
Epoch [8/20], Step [300/313]

  rgb_image = lab2rgb(lab_image)  # Convert LAB to RGB


Saved sample outputs for epoch 15
Saved model weights for epoch 15
Epoch [15/20], Step [0/313], D Loss: 0.6778, G Loss: 1.0594
Epoch [15/20], Step [100/313], D Loss: 0.6861, G Loss: 1.0063
Epoch [15/20], Step [200/313], D Loss: 0.6868, G Loss: 1.0194
Epoch [15/20], Step [300/313], D Loss: 0.6812, G Loss: 1.0568
Epoch [16/20], Step [0/313], D Loss: 0.6757, G Loss: 1.0089
Epoch [16/20], Step [100/313], D Loss: 0.6878, G Loss: 0.9614
Epoch [16/20], Step [200/313], D Loss: 0.7012, G Loss: 1.0158
Epoch [16/20], Step [300/313], D Loss: 0.6798, G Loss: 0.9752
Epoch [17/20], Step [0/313], D Loss: 0.6823, G Loss: 1.0145
Epoch [17/20], Step [100/313], D Loss: 0.7042, G Loss: 0.9316
Epoch [17/20], Step [200/313], D Loss: 0.6768, G Loss: 0.9887
Epoch [17/20], Step [300/313], D Loss: 0.6860, G Loss: 0.9680
Epoch [18/20], Step [0/313], D Loss: 0.6911, G Loss: 1.0049
Epoch [18/20], Step [100/313], D Loss: 0.6906, G Loss: 0.9695
Epoch [18/20], Step [200/313], D Loss: 0.7012, G Loss: 0.9098
Epoch [18/2