In [1]:
import os
import re
import pathlib
from skimage.color import rgb2lab, lab2rgb
from tqdm import tqdm
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import cv2
from PIL import Image
import random

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torch.nn.functional as F
from torchvision import datasets,transforms
from torch.utils.data import DataLoader, Dataset

In [2]:
!pip install torchviz

Collecting torchviz
  Downloading torchviz-0.0.3-py3-none-any.whl.metadata (2.1 kB)
Downloading torchviz-0.0.3-py3-none-any.whl (5.7 kB)
Installing collected packages: torchviz
Successfully installed torchviz-0.0.3


In [3]:
from torchviz import make_dot

In [4]:
# Configurations

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
np.random.seed(25)
torch.manual_seed(25)

cuda


<torch._C.Generator at 0x7d25a6b42430>

In [5]:
img_dir = "/kaggle/input/coco25k/images"
working_dir = "/kaggle/working/"
IMG_DIM = 256
batch_size = 32 #adjust to 16 if required
learning_rate = 2e-4 #optimal rate for training GANs
NUM_EPOCHS = 20
beta1 = 0.5
beta2 = 0.999
lambda_L1 = 100.

In [6]:
class COCODataset(Dataset):
    def __init__(self, img_dir, transforms=None):
        self.img_dir = img_dir
        self.transforms = transforms
        self.image_paths = []
        all_images = [os.path.join(img_dir, file) for file in os.listdir(img_dir) 
                      if file.endswith('.jpg')]
        self.image_paths = random.sample(all_images, min(len(all_images), 10000))

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img = Image.open(self.image_paths[idx]).convert("RGB")
        img = self.transforms(img)
        img = np.array(img)
        img_lab = rgb2lab(img).astype('float32')  #change to float16 for faster training
        img_lab = transforms.ToTensor()(img_lab)
        L = img_lab[[0], ...]/ 50.0 - 1.0
        ab = img_lab[[1, 2], ...]/ 128.0 
        return {'L': L, 'ab': ab}

In [7]:
# Make sure to not add jitter/ noise as it affects the image colors

train_transforms = transforms.Compose([
    transforms.Resize((IMG_DIM, IMG_DIM), Image.BICUBIC),
    transforms.RandomHorizontalFlip(),
    # transforms.RandomVerticalFlip(),
])

In [8]:
train_dataset = COCODataset(img_dir, train_transforms)
train_dl = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)

In [9]:
# Sanity check

data = next(iter(train_dl))
Ls, abs_ = data['L'], data['ab']
print(Ls.shape, abs_.shape)
print(len(train_dl))

torch.Size([32, 1, 256, 256]) torch.Size([32, 2, 256, 256])
313


## Introducing ResNet backed UNet

In [10]:
from fastai.vision.learner import create_body
# from torchvision.models.resnet import resnet34
from fastai.vision.models.unet import DynamicUnet

resnet34 = torch.hub.load('pytorch/vision:v0.10.0', 'resnet34', pretrained=True)

Downloading: "https://github.com/pytorch/vision/zipball/v0.10.0" to /root/.cache/torch/hub/v0.10.0.zip
Downloading: "https://download.pytorch.org/models/resnet34-b627a593.pth" to /root/.cache/torch/hub/checkpoints/resnet34-b627a593.pth
100%|██████████| 83.3M/83.3M [00:00<00:00, 186MB/s]


Building Generator

In [11]:
def build_res_unet(n_input=1, n_output=2, size=IMG_DIM):
    body = create_body(resnet34, pretrained=True, n_in=n_input, cut=-2)
    net_G = DynamicUnet(body, n_output, (size, size))
    return net_G

Building Patch Discriminator

In [12]:
class PatchDisc(nn.Module):
    def __init__(self, input_c, num_filters=64, n_down=3):
        super().__init__()
        # We define the first and last layers outside the loop since they are required to not have activation 
        # and normalization according to the paper
        model = [self.get_layers(input_c, num_filters, norm=False)]
        model += [self.get_layers(num_filters * 2 ** i, num_filters * 2 ** (i + 1), s=1 if i == (n_down-1) else 2) 
                          for i in range(n_down)] 
        model += [self.get_layers(num_filters * 2 ** n_down, 1, s=1, norm=False, act=False)]
        self.model = nn.Sequential(*model)                                                   
        
    def get_layers(self, ni, nf, k=4, s=2, p=1, norm=True, act=True):
        layers = [nn.Conv2d(ni, nf, k, s, p, bias=not norm)]         
        if norm: layers += [nn.BatchNorm2d(nf)]
        if act: layers += [nn.LeakyReLU(0.2, True)]
        return nn.Sequential(*layers)
    
    def forward(self, x):
        return self.model(x)

In [13]:
# Defining loss functions
'''PatchGAN uses a combination of adversarial and L1 loss.'''

adversarial_loss = nn.BCEWithLogitsLoss()  
l1_loss = nn.L1Loss()  

In [14]:
# Initialize models
net_G = build_res_unet(n_input=1, n_output=2, size=IMG_DIM).to(device)
net_D = PatchDisc(input_c=3).to(device)

# Define optimizers
optimizer_G = optim.Adam(net_G.parameters(), lr=learning_rate, betas=(beta1, beta2))
optimizer_D = optim.Adam(net_D.parameters(), lr=learning_rate, betas=(beta1, beta2))

In [15]:
# Learning Rate Scheduler
scheduler_G = torch.optim.lr_scheduler.StepLR(optimizer_G, step_size=15, gamma=0.5)
scheduler_D = torch.optim.lr_scheduler.StepLR(optimizer_D, step_size=15, gamma=0.5)

def pretrain_G(net_G, train_dl, opt, criterion, epochs=20):
    net_G.train()  # Set generator to training mode
    for e in range(epochs):
        total_loss = 0.0  # Accumulate loss for each epoch
        for data in tqdm(train_dl, desc=f"Pretraining Epoch {e + 1}/{epochs}"):
            L, ab = data['L'].to(device), data['ab'].to(device)
            preds = net_G(L)  # Predicted color channels
            loss = criterion(preds, ab)  # L1 loss between predicted and ground truth
            opt.zero_grad()
            loss.backward()
            opt.step()
            total_loss += loss.item()
        
        avg_loss = total_loss / len(train_dl)
        print(f"Pretraining Epoch {e + 1}/{epochs}, Average L1 Loss: {avg_loss:.5f}")
        if (e + 1) % 5 == 0:  # Save every 5 epochs
            torch.save(net_G.state_dict(), f"generator_pretrain_epoch_{e+1}.pth")
        
        # Step scheduler
        scheduler_G.step()

        

pretrain_epochs = 20
pretrain_criterion = nn.L1Loss()
pretrain_optimizer = optim.Adam(net_G.parameters(), lr=learning_rate, betas=(beta1, beta2))

In [16]:
from torchvision.utils import save_image

# Pretrain the generator
# pretrain_G(net_G, train_dl, pretrain_optimizer, pretrain_criterion, epochs=pretrain_epochs)

for epoch in range(NUM_EPOCHS):
    lambda_L1 = max(100.0 / (epoch + 1), 1.0)  # Dynamically adjust lambda_L1

    for i, data in enumerate(train_dl):
        # Load data
        real_L = data['L'].to(device)  # Grayscale input
        real_ab = data['ab'].to(device)  # Ground truth color channels
        
        # ==========================================
        # Train Discriminator
        # ==========================================
        optimizer_D.zero_grad()
        
        # Real images (input + ground truth)
        real_input = torch.cat([real_L, real_ab], dim=1)  # Concatenate grayscale and color
        real_validity = net_D(real_input)  # Discriminator output for real images
        real_loss = adversarial_loss(real_validity, torch.ones_like(real_validity, device=device))  # Target: 1 (real)

        # Fake images (input + generated output)
        fake_ab = net_G(real_L)  # Generator's output
        fake_input = torch.cat([real_L, fake_ab], dim=1)  # Concatenate grayscale and fake color
        fake_validity = net_D(fake_input.detach())  # Discriminator output for fake images
        fake_loss = adversarial_loss(fake_validity, torch.zeros_like(fake_validity, device=device))  # Target: 0 (fake)

        # Total discriminator loss
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        # ==========================================
        # Train Generator
        # ==========================================
        optimizer_G.zero_grad()

        # Adversarial loss for generator
        fake_validity = net_D(fake_input)  # Discriminator's response to fake images
        g_adv_loss = adversarial_loss(fake_validity, torch.ones_like(fake_validity, device=device))  # Target: 1 (fool D)

        # L1 loss for generator
        g_l1_loss = lambda_L1 * l1_loss(fake_ab, real_ab)  # Pixel-wise similarity

        # Total generator loss
        g_loss = g_adv_loss + g_l1_loss
        g_loss.backward()

        # Gradient clipping
        nn.utils.clip_grad_norm_(net_G.parameters(), max_norm=1.0)
        
        optimizer_G.step()
        
        # ==========================================
        # Logging and visualization (optional)
        # ==========================================
        if i % 100 == 0:
            print(f"Epoch [{epoch}/{NUM_EPOCHS}], Step [{i}/{len(train_dl)}], "
                  f"D Loss: {d_loss.item():.4f}, G Loss: {g_loss.item():.4f}")
    
    # ==========================================
    # Save sample outputs every 5 epochs
    # ==========================================
    if (epoch + 1) % 5 == 0:
        with torch.no_grad():
            # Take a batch of 8 grayscale images
            sample_L = real_L[:8]
            sample_fake_ab = net_G(sample_L)  # Generate fake color channels
            
            # Reverse normalization for saved images
            real_L = (sample_L + 1.0) * 50.0  # Reverse normalization for L channel
            sample_fake_ab = sample_fake_ab * 128.0  # Reverse normalization for ab channels

            # Combine L and ab channels to form the LAB image
            sample_fake_lab = torch.cat([real_L, sample_fake_ab], dim=1)  # Concatenate L and ab

            # Convert from LAB to RGB
            sample_fake_rgb = []
            for i in range(len(sample_fake_lab)):
                lab_image = sample_fake_lab[i].cpu().numpy().transpose(1, 2, 0)  # Convert to HWC format
                rgb_image = lab2rgb(lab_image)  # Convert LAB to RGB
                sample_fake_rgb.append(torch.from_numpy(rgb_image).permute(2, 0, 1))  # Convert back to tensor

            # Save as grid
            save_image(torch.stack(sample_fake_rgb), f"sample_epoch_{epoch+1}.png", nrow=4)
            print(f"Saved sample outputs for epoch {epoch+1}")

         # Save model weights
        torch.save(net_G.state_dict(), f"/kaggle/working/net_G_epoch_{epoch+1}.pth")
        torch.save(net_D.state_dict(), f"/kaggle/working/net_D_epoch_{epoch+1}.pth")
        print(f"Saved model weights for epoch {epoch+1}")
    
    # Step scheduler for both generators and discriminators
    scheduler_G.step()
    scheduler_D.step()


Epoch [0/20], Step [0/313], D Loss: 0.7153, G Loss: 46.3890
Epoch [0/20], Step [100/313], D Loss: 0.6336, G Loss: 11.7141
Epoch [0/20], Step [200/313], D Loss: 0.2201, G Loss: 10.9877
Epoch [0/20], Step [300/313], D Loss: 0.8490, G Loss: 10.5690
Epoch [1/20], Step [0/313], D Loss: 0.3192, G Loss: 5.8533
Epoch [1/20], Step [100/313], D Loss: 0.6642, G Loss: 5.6802
Epoch [1/20], Step [200/313], D Loss: 0.7808, G Loss: 6.3628
Epoch [1/20], Step [300/313], D Loss: 0.8916, G Loss: 6.9342
Epoch [2/20], Step [0/313], D Loss: 0.5045, G Loss: 4.4878
Epoch [2/20], Step [100/313], D Loss: 0.5039, G Loss: 5.0193
Epoch [2/20], Step [200/313], D Loss: 0.7473, G Loss: 3.8400
Epoch [2/20], Step [300/313], D Loss: 0.8031, G Loss: 3.6776
Epoch [3/20], Step [0/313], D Loss: 0.6294, G Loss: 3.2148
Epoch [3/20], Step [100/313], D Loss: 0.6153, G Loss: 3.2055
Epoch [3/20], Step [200/313], D Loss: 0.5961, G Loss: 3.1336
Epoch [3/20], Step [300/313], D Loss: 0.6072, G Loss: 3.0541
Epoch [4/20], Step [0/313], 

  rgb_image = lab2rgb(lab_image)  # Convert LAB to RGB


Saved sample outputs for epoch 5
Saved model weights for epoch 5
Epoch [5/20], Step [0/313], D Loss: 0.7455, G Loss: 2.3611
Epoch [5/20], Step [100/313], D Loss: 0.6963, G Loss: 2.5493
Epoch [5/20], Step [200/313], D Loss: 0.6088, G Loss: 2.2470
Epoch [5/20], Step [300/313], D Loss: 0.5387, G Loss: 2.5871
Epoch [6/20], Step [0/313], D Loss: 0.6542, G Loss: 2.2268
Epoch [6/20], Step [100/313], D Loss: 0.5890, G Loss: 2.1973
Epoch [6/20], Step [200/313], D Loss: 0.6525, G Loss: 1.9788
Epoch [6/20], Step [300/313], D Loss: 0.6149, G Loss: 2.7638
Epoch [7/20], Step [0/313], D Loss: 0.6659, G Loss: 1.8671
Epoch [7/20], Step [100/313], D Loss: 0.6057, G Loss: 2.1471
Epoch [7/20], Step [200/313], D Loss: 0.7098, G Loss: 2.0329
Epoch [7/20], Step [300/313], D Loss: 0.7440, G Loss: 2.0459
Epoch [8/20], Step [0/313], D Loss: 0.7343, G Loss: 1.9354
Epoch [8/20], Step [100/313], D Loss: 0.6200, G Loss: 1.9883
Epoch [8/20], Step [200/313], D Loss: 0.7066, G Loss: 1.7344
Epoch [8/20], Step [300/313]

  rgb_image = lab2rgb(lab_image)  # Convert LAB to RGB


Saved sample outputs for epoch 10
Saved model weights for epoch 10
Epoch [10/20], Step [0/313], D Loss: 0.6867, G Loss: 1.7393
Epoch [10/20], Step [100/313], D Loss: 0.6275, G Loss: 1.6031
Epoch [10/20], Step [200/313], D Loss: 0.6058, G Loss: 1.7833
Epoch [10/20], Step [300/313], D Loss: 0.6665, G Loss: 1.8435
Epoch [11/20], Step [0/313], D Loss: 0.6901, G Loss: 1.5403
Epoch [11/20], Step [100/313], D Loss: 0.7267, G Loss: 1.5588
Epoch [11/20], Step [200/313], D Loss: 0.6003, G Loss: 1.6307
Epoch [11/20], Step [300/313], D Loss: 0.6436, G Loss: 1.6069
Epoch [12/20], Step [0/313], D Loss: 0.6795, G Loss: 1.5402
Epoch [12/20], Step [100/313], D Loss: 0.7358, G Loss: 1.4176
Epoch [12/20], Step [200/313], D Loss: 0.6253, G Loss: 1.7190
Epoch [12/20], Step [300/313], D Loss: 0.6962, G Loss: 1.6128
Epoch [13/20], Step [0/313], D Loss: 0.7487, G Loss: 1.4080
Epoch [13/20], Step [100/313], D Loss: 0.7469, G Loss: 1.5275
Epoch [13/20], Step [200/313], D Loss: 0.7177, G Loss: 1.4780
Epoch [13/2

  rgb_image = lab2rgb(lab_image)  # Convert LAB to RGB
  rgb_image = lab2rgb(lab_image)  # Convert LAB to RGB


Saved sample outputs for epoch 20
Saved model weights for epoch 20


## Evaluation

In [17]:
from PIL import Image
import numpy as np
from skimage.color import rgb2lab, lab2rgb
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
from fastai.vision.learner import create_body
from fastai.vision.models.unet import DynamicUnet
from huggingface_hub import hf_hub_download
import matplotlib.pyplot as plt
import os
import random
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torch.nn.functional as F
from torchvision import datasets,transforms
from torch.utils.data import DataLoader, Dataset

In [18]:
net_G.eval().to(device)

DynamicUnet(
  (layers): ModuleList(
    (0): Sequential(
      (0): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (4): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (1): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05

In [19]:
# Preprocessing
transform = transforms.Compose([
    transforms.Resize((IMG_DIM, IMG_DIM)),
])

In [20]:
# Dataset
class CocoSubset(Dataset):
    def __init__(self, img_dir, transform=None, limit=1000):
        self.paths = [os.path.join(img_dir, f) for f in os.listdir(img_dir) if f.endswith('.jpg')]
        self.paths = random.sample(self.paths, min(limit, len(self.paths)))
        self.transform = transform

    def __len__(self): return len(self.paths)

    def __getitem__(self, idx):
        img = Image.open(self.paths[idx]).convert("RGB")
        img = self.transform(img)
        lab = rgb2lab(np.array(img).astype(np.float32) / 255.0)
        L = (lab[:, :, 0] / 50.0 - 1.0)[np.newaxis, ...]
        ab = (lab[:, :, 1:] / 128.0).transpose(2, 0, 1)
        return torch.tensor(L).float(), torch.tensor(ab).float(), lab

# Dataloader
dataset = CocoSubset('/kaggle/input/coco25k/images', transform=transform, limit=5000)
dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

In [21]:
# Evaluation
psnr_list, ssim_list, delta_e_list = [], [], []

for L, ab_gt, lab_gt in tqdm(dataloader):
    L = L.to(device)
    with torch.no_grad():
        ab_pred = net_G(L).cpu()

    # Denormalize
    L_denorm = (L.cpu().numpy()[0, 0] + 1.0) * 50.0
    ab_gt_denorm = ab_gt.numpy()[0].transpose(1, 2, 0) * 128.0
    ab_pred_denorm = ab_pred.numpy()[0].transpose(1, 2, 0) * 128.0

    lab_pred = np.concatenate([L_denorm[..., None], ab_pred_denorm], axis=2)
    lab_true = lab_gt[0].numpy()

    rgb_pred = lab2rgb(lab_pred.clip(0, 100))
    rgb_true = lab2rgb(lab_true.clip(0, 100))

    psnr_list.append(psnr(rgb_true, rgb_pred, data_range=1.0))
    ssim_list.append(ssim(rgb_true, rgb_pred, channel_axis=2, data_range=1.0))
    delta_e_list.append(np.mean(np.linalg.norm(lab_true - lab_pred, axis=2)))

  rgb_pred = lab2rgb(lab_pred.clip(0, 100))
  rgb_pred = lab2rgb(lab_pred.clip(0, 100))
  rgb_pred = lab2rgb(lab_pred.clip(0, 100))
  rgb_pred = lab2rgb(lab_pred.clip(0, 100))
  rgb_pred = lab2rgb(lab_pred.clip(0, 100))
  rgb_pred = lab2rgb(lab_pred.clip(0, 100))
  rgb_pred = lab2rgb(lab_pred.clip(0, 100))
  rgb_pred = lab2rgb(lab_pred.clip(0, 100))
  rgb_pred = lab2rgb(lab_pred.clip(0, 100))
  rgb_pred = lab2rgb(lab_pred.clip(0, 100))
  rgb_pred = lab2rgb(lab_pred.clip(0, 100))
  rgb_pred = lab2rgb(lab_pred.clip(0, 100))
  rgb_pred = lab2rgb(lab_pred.clip(0, 100))
  rgb_pred = lab2rgb(lab_pred.clip(0, 100))
  rgb_pred = lab2rgb(lab_pred.clip(0, 100))
  rgb_pred = lab2rgb(lab_pred.clip(0, 100))
  rgb_pred = lab2rgb(lab_pred.clip(0, 100))
  rgb_pred = lab2rgb(lab_pred.clip(0, 100))
  rgb_pred = lab2rgb(lab_pred.clip(0, 100))
  rgb_pred = lab2rgb(lab_pred.clip(0, 100))
  rgb_pred = lab2rgb(lab_pred.clip(0, 100))
  rgb_pred = lab2rgb(lab_pred.clip(0, 100))
  rgb_pred = lab2rgb(lab_pred.cl

In [22]:
# Results
print({
    "PSNR": np.mean(psnr_list),
    "SSIM": np.mean(ssim_list),
    "DeltaE": np.mean(delta_e_list)
})

{'PSNR': 22.185953826026445, 'SSIM': 0.8788718, 'DeltaE': 21.862093}
