In [None]:
## Importing required Libraries

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
import numpy as np
from skimage.color import rgb2lab

In [None]:
# Pick the compute device once for the whole notebook: CUDA when a GPU is
# available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Function to convert RGB images to L, a, and b channels in Lab color space
def rgb_to_lab(images):
    """Convert a batch of RGB images into normalized CIE-Lab L and ab tensors.

    Args:
        images: RGB batch shaped (N, 3, H, W), e.g. the float tensors in
            [0, 1] produced by ``transforms.ToTensor``. It may live on any
            device and may track gradients. An iterable of (3, H, W) tensors
            is accepted as well.

    Returns:
        tuple ``(L, ab)`` of float32 tensors placed on the global ``device``:
            L:  (N, 1, H, W), lightness divided by 100.
            ab: (N, 2, H, W), the a and b channels mapped through
                ``(value + 128) / 255``.
    """
    if not torch.is_tensor(images):
        # Also accept a plain iterable of per-image (3, H, W) tensors.
        images = torch.stack(list(images))

    # skimage works on channel-last NumPy arrays on the host. detach() allows
    # tensors that require grad to be exported; the colour conversion is
    # per-pixel, so the whole batch can go through a single rgb2lab call.
    rgb = images.detach().permute(0, 2, 3, 1).cpu().numpy()  # (N, H, W, 3)
    lab = rgb2lab(rgb)  # (N, H, W, 3) holding L, a, b

    # Normalize: L to [0, 1]; a and b shifted/scaled to roughly [0, 1].
    l_norm = lab[..., 0] / 100.0            # (N, H, W)
    ab_norm = (lab[..., 1:] + 128) / 255.0  # (N, H, W, 2)

    # Back to channel-first float32 tensors on the training device.
    L = torch.from_numpy(l_norm).to(torch.float32).unsqueeze(1).to(device)  # (N, 1, H, W)
    ab = torch.from_numpy(ab_norm).to(torch.float32).permute(0, 3, 1, 2).to(device)  # (N, 2, H, W)
    return L, ab

In [None]:
# CIFAR10 Dataset Loader
def load_cifar10_dataset(batch_size=8, num_workers=2):
    """Build the CIFAR-10 train and test DataLoaders used for colorization.

    Every image is resized to 224x224 (the input size the ResNet/DenseNet
    encoders are fed with) and converted to a float tensor in [0, 1].
    The data is downloaded into ``./data/cifar10`` when it is not there yet.

    Args:
        batch_size: number of images per batch for both loaders.
        num_workers: worker processes used by each DataLoader.

    Returns:
        ``(train_loader, test_loader)``; only the training loader shuffles.
    """
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])

    # (train flag, shuffle flag) per split; train first, then test.
    split_flags = (True, False)

    # Instantiate both datasets before any loader, train split first.
    splits = [
        datasets.CIFAR10(root='./data/cifar10', train=is_train, download=True, transform=preprocess)
        for is_train in split_flags
    ]

    loaders = [
        DataLoader(split, batch_size=batch_size, shuffle=is_train, num_workers=num_workers)
        for split, is_train in zip(splits, split_flags)
    ]

    return loaders[0], loaders[1]

In [None]:
# Model Architecture

class ColorizationModel(nn.Module):
    """Predict the Lab chroma channels (a, b) of an image from its L channel.

    Structure:
      * Two ImageNet-pretrained encoders (ResNet50 and DenseNet121) whose first
        convolution is swapped for a single-channel one so they accept the
        L channel.
      * Four 1x1-conv fusion blocks applied in sequence to the concatenated
        encoder features (2048 + 1024 channels).
      * Five decoder blocks, each upsampling by 2, with the outputs of fusion
        blocks 3, 2 and 1 concatenated in as skip connections.

    Input:  (N, 1, H, W) L-channel tensor. The skip connections are sized for
            7x7 encoder feature maps, i.e. the 224x224 inputs this notebook
            feeds in.
    Output: (N, 2, 224, 224) for such inputs, squashed by Tanh to [-1, 1].

    NOTE(review): the training targets produced by ``rgb_to_lab`` are scaled
    to [0, 1], whereas the final Tanh spans [-1, 1] - confirm this is intended.
    """

    def __init__(self):
        super().__init__()

        # ResNet50 encoder: 1-channel stem, avgpool + fc head dropped.
        resnet = models.resnet50(weights='IMAGENET1K_V1')
        resnet.conv1 = self._adapt_conv_to_grayscale(resnet.conv1)
        self.encoder_resnet = nn.Sequential(*list(resnet.children())[:-2])

        # DenseNet121 encoder: 1-channel stem, classifier dropped.
        densenet = models.densenet121(weights='IMAGENET1K_V1')
        densenet.features.conv0 = self._adapt_conv_to_grayscale(densenet.features.conv0)
        self.encoder_densenet = nn.Sequential(*list(densenet.children())[:-1])

        # Pooling layer that brings the DenseNet output to 7x7
        self.downsample_densenet = nn.AdaptiveAvgPool2d((7, 7))

        # Fusion blocks. The first one takes the concatenated encoder channels;
        # the others take the previous fusion output concatenated with itself.
        self.fusion_block1 = self._fusion_block(2048 + 1024)
        self.fusion_block2 = self._fusion_block(256 + 256)
        self.fusion_block3 = self._fusion_block(256 + 256)
        self.fusion_block4 = self._fusion_block(256 + 256)

        # Decoder blocks. Blocks 2-4 take 512 channels: the previous decoder
        # output concatenated with an upsampled fusion-block output.
        self.decoder_block1 = self._decoder_block(256)
        self.decoder_block2 = self._decoder_block(512)
        self.decoder_block3 = self._decoder_block(512)
        self.decoder_block4 = self._decoder_block(512)

        self.decoder_block5 = nn.Sequential(
            nn.Conv2d(256, 2, kernel_size=3, padding=1),
            nn.Tanh(),  # Output range [-1, 1]
            nn.Upsample(scale_factor=2)  # Upsample to 224 x 224 (original spatial resolution)
        )

        # Upsampling layers that bring the skip connections to the decoder resolutions
        self.upsample_fb3 = nn.Upsample(scale_factor=2)  # 7x7 -> 14x14
        self.upsample_fb2 = nn.Upsample(scale_factor=4)  # 7x7 -> 28x28
        self.upsample_fb1 = nn.Upsample(scale_factor=8)  # 7x7 -> 56x56

    @staticmethod
    def _adapt_conv_to_grayscale(rgb_conv):
        """Build the 1-channel stem conv, seeded with the pretrained RGB filters.

        The RGB filters of ``rgb_conv`` are summed over the input-channel axis,
        so the new layer responds to a grey image exactly as the pretrained
        layer responds to that image replicated across R, G and B.
        """
        gray_conv = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        with torch.no_grad():
            gray_conv.weight.copy_(rgb_conv.weight.sum(dim=1, keepdim=True))
        return gray_conv

    @staticmethod
    def _fusion_block(in_channels):
        """1x1 conv -> BatchNorm -> ReLU, mapping ``in_channels`` to 256 channels."""
        return nn.Sequential(
            nn.Conv2d(in_channels, 256, kernel_size=1),
            nn.BatchNorm2d(256),
            nn.ReLU()
        )

    @staticmethod
    def _decoder_block(in_channels):
        """3x3 conv -> BatchNorm -> ReLU -> 2x upsample, producing 256 channels."""
        return nn.Sequential(
            nn.Conv2d(in_channels, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Upsample(scale_factor=2)
        )

    def forward(self, x):
        """Run the L-channel batch ``x`` through both encoders, fusion and decoder."""
        # Encoders
        x_resnet = self.encoder_resnet(x)
        x_densenet = self.downsample_densenet(self.encoder_densenet(x))

        # Fusion blocks
        fb1_output = self.fusion_block1(torch.cat([x_resnet, x_densenet], dim=1))  # 2048 + 1024
        # Each later fusion block sees only the previous output, duplicated
        # along the channel axis to fill its 512 input channels.
        fb2_output = self.fusion_block2(torch.cat([fb1_output, fb1_output], dim=1))
        fb3_output = self.fusion_block3(torch.cat([fb2_output, fb2_output], dim=1))
        fb4_output = self.fusion_block4(torch.cat([fb3_output, fb3_output], dim=1))

        # Decoder with skip connections
        db1_output = self.decoder_block1(fb4_output)

        # fb3 is upsampled to db1_output's size before concatenation
        db2_input = torch.cat([db1_output, self.upsample_fb3(fb3_output)], dim=1)
        db2_output = self.decoder_block2(db2_input)

        # fb2 is upsampled to db2_output's size before concatenation
        db3_input = torch.cat([db2_output, self.upsample_fb2(fb2_output)], dim=1)
        db3_output = self.decoder_block3(db3_input)

        # fb1 is upsampled to db3_output's size before concatenation
        db4_input = torch.cat([db3_output, self.upsample_fb1(fb1_output)], dim=1)
        db4_output = self.decoder_block4(db4_input)

        return self.decoder_block5(db4_output)


In [None]:
import torch.nn.functional as F

# Training loop
def train(model, train_loader, criterion, optimizer, num_epochs):
    """Train ``model`` to predict ab channels from the L channel.

    Args:
        model: network called as ``model(L)``; it must already be on ``device``.
        train_loader: iterable of batches whose first element is the RGB image
            tensor (the labels of the dataset are ignored).
        criterion: loss called as ``criterion(ab_pred, ab_target)``.
        optimizer: optimizer bound to ``model``'s parameters.
        num_epochs: number of full passes over ``train_loader``.

    Returns:
        None. The loss of every 500th batch is printed.
    """
    model.train()  # Set the model to training mode
    for epoch in range(num_epochs):
        for batch_idx, batch in enumerate(train_loader):
            # rgb_to_lab does the colour conversion in NumPy on the CPU and
            # moves its results to `device` itself, so the RGB batch is handed
            # over as loaded instead of being copied to the GPU and back.
            L, ab_target = rgb_to_lab(batch[0])

            # Zero gradients
            optimizer.zero_grad()

            # Forward pass: the model predicts 'a' and 'b' from the 'L' channel only
            ab_pred = model(L)

            # Loss between predicted and ground-truth ab channels
            loss = criterion(ab_pred, ab_target)
            loss.backward()  # Backward pass
            optimizer.step()  # Update weights

            # Print loss every 500 batches
            if batch_idx % 500 == 0:
                print(f'Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx}/{len(train_loader)}], Loss: {loss.item():.4f}')

In [None]:
# Example usage: build the data loaders, model, loss and optimizer, then train.
# The names assigned here are notebook globals; the cell that saves the
# weights reads `model` from this cell.
if __name__ == "__main__":
    # Load the CIFAR10 dataset (train and test loaders share the batch size)
    batch_size = 8
    train_loader, test_loader = load_cifar10_dataset(batch_size=batch_size)

    # Initialize the model (downloads the pretrained encoder weights on first use)
    model = ColorizationModel().to(device)  # Move model to GPU if available

    # Define loss function and optimizer
    criterion = nn.MSELoss()  # Loss for comparing predicted ab with ground truth ab
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # Train the model; only the training loader is used here
    num_epochs = 2
    train(model, train_loader, criterion, optimizer, num_epochs)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar10/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:05<00:00, 29593531.79it/s]


Extracting ./data/cifar10/cifar-10-python.tar.gz to ./data/cifar10
Files already downloaded and verified


Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 167MB/s]
Downloading: "https://download.pytorch.org/models/densenet121-a639ec97.pth" to /root/.cache/torch/hub/checkpoints/densenet121-a639ec97.pth
100%|██████████| 30.8M/30.8M [00:00<00:00, 135MB/s]


Epoch [1/2], Batch [0/6250], Loss: 0.3705


KeyboardInterrupt: 

In [None]:
import torch

# Run after training: `model` is the global created by the training cell,
# so that cell must have been executed in this kernel first.
model_save_path = "colorization_model_cat.pth"

# Save only the model's state_dict (the recommended way to save models in
# PyTorch); the file is written relative to the current working directory.
torch.save(model.state_dict(), model_save_path)

print(f"Model saved to {model_save_path}")