# Image Colorizer using Deep Learning
## Introduction
Image colorization is the process of adding color to a grayscale image. This can be done through a variety of methods, including manual methods such as painting or digital methods such as neural networks.

One common approach to digital colorization is to use a convolutional neural network (CNN) trained on a dataset of color images. The network is trained to learn the relationships between the grayscale values of an image and the corresponding color values. Once trained, the network can then be used to colorize new grayscale images by using the learned relationships to predict the appropriate color values for each pixel.

Another approach is using Generative Adversarial Networks (GANs) where a generator network learns to generate a colored image from a grayscale image and a discriminator network is trained to distinguish between real color images and the generated colored images. As the generator network improves, the discriminator network is no longer able to distinguish between the real and generated images, resulting in a high-quality colorization.

## Goal
The goal is to develop a deep learning-based image colorization system using a cGAN architecture that can accurately and realistically colorize grayscale images while preserving fine details and textures in the final output. The aim is to achieve this by training the generator network to learn the relationships between the grayscale input image and the corresponding color output, and the discriminator network to effectively distinguish between the generated colored images and real color images.

# Generative Adversarial Network (GAN)
## What is GAN?
A Generative Adversarial Network (GAN) is a type of deep learning model that is used for generative tasks, such as image generation and image colorization. It is composed of two main parts: a generator network and a discriminator network.

The generator network is trained to generate new data samples that are similar to the training data. For example, it can generate new images that resemble the training images. The generator network takes a random input, called a noise vector, and maps it to a sample of the target data distribution.

The discriminator network is trained to distinguish between real data samples and the generated samples produced by the generator network. It takes an input (real or generated) and outputs the probability that the input is real.

The generator and discriminator networks are trained together in an adversarial manner: the generator tries to produce samples that fool the discriminator, while the discriminator tries to correctly identify whether its input is real or fake. Through this competition, the generator learns to produce more realistic samples and the discriminator becomes better at identifying fake ones. Ideally, training converges to a point where the discriminator can no longer tell generated samples from real data, meaning the GAN has learned the target data distribution.

> This project implements a `Conditional Wasserstein GAN`.
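
For reference, the critic objective of a conditional WGAN with gradient penalty is usually written as below. This is the textbook form, given here as an assumption about the intended setup rather than a description of the exact code in this notebook. The symbols are notation introduced for this sketch: $C$ is the critic (discriminator), $G$ the generator, $l$ the grayscale L channel, $x$ the real AB channels, and $\lambda_{gp}$ the gradient penalty weight. Sign conventions differ between implementations.

```latex
\mathcal{L}_C
  = \mathbb{E}\big[C(G(l), l)\big] - \mathbb{E}\big[C(x, l)\big]
  + \lambda_{gp}\, \mathbb{E}\Big[\big(\lVert \nabla_{\hat{x}} C(\hat{x}, l) \rVert_2 - 1\big)^2\Big],
\qquad
\hat{x} = \alpha x + (1 - \alpha)\, G(l), \quad \alpha \sim U(0, 1)
```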

# Code

## Importing necessary modules/packages

In [None]:
import gc
import random
import warnings

import cv2
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch
from skimage.color import lab2rgb
from torch import nn, optim
from torch.nn import functional as F  # noqa
from torch.utils.data import Dataset, DataLoader
from torchsummary import summary
from torchvision import transforms
from torchvision.models.inception import inception_v3

## Setting module/package environments

In [None]:
warnings.filterwarnings("ignore")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
matplotlib.style.use("seaborn-pastel")

## Dataset
This project uses the [Image Colorization Dataset by Shravankumar Shetty](https://www.kaggle.com/datasets/shravankumar9892/image-colorization) from [Kaggle](https://www.kaggle.com/). The dataset consists of 25,000 images in the [LAB](https://www.xrite.com/blog/lab-color-space) color space.

Because the dataset is very large, only the first 10,000 images are used for this project.

In [None]:
AB_image_path = "/kaggle/input/image-colorization/ab/ab/ab1.npy"
L_image_path = "/kaggle/input/image-colorization/l/gray_scale.npy"

AB_image_df = np.load(AB_image_path)
L_image_df = np.load(L_image_path)[: AB_image_df.shape[0]]

print(f"Total {AB_image_df.shape[0]} Color images of shape {AB_image_df.shape[1:]}")
print(f"Total {L_image_df.shape[0]} Grayscale images of shape {L_image_df.shape[1:]}")

gc.collect()

**Utility function to convert images to RGB from separate L and AB components**

In [None]:
def LAB_to_RGB(L_img, AB_img):
    L_img = L_img * 100
    AB_img = (AB_img - 0.5) * 128 * 2
    LAB_img = torch.cat([L_img, AB_img], dim=2).numpy()
    RGB_images = []
    for img in LAB_img:
        img_RGB = lab2rgb(img)
        RGB_images.append(img_RGB)
    return np.stack(RGB_images, axis=0)
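
The scaling inside `LAB_to_RGB` can be checked in isolation. The sketch below assumes the `.npy` arrays are stored as `uint8`, in which case torchvision's `ToTensor` divides them by 255; the helper names `to_unit_range` and `denormalize_lab` are hypothetical and only illustrate the arithmetic.

```python
import numpy as np


def to_unit_range(channel_uint8):
    """Mimic ToTensor scaling for uint8 input: [0, 255] -> [0, 1]."""
    return channel_uint8.astype(np.float32) / 255.0


def denormalize_lab(L_unit, AB_unit):
    """Apply the same rescaling as LAB_to_RGB to unit-range channels."""
    L = L_unit * 100.0              # lightness in [0, 100]
    AB = (AB_unit - 0.5) * 128 * 2  # chroma in [-128, 128]
    return L, AB


# Hypothetical pixel values, chosen only to show the two end points.
raw = np.array([0, 255], dtype=np.uint8)
L, AB = denormalize_lab(to_unit_range(raw), to_unit_range(raw))
print(L)   # end points of the lightness range: 0 and 100
print(AB)  # end points of the chroma range: -128 and 128
```

These are the value ranges that `skimage.color.lab2rgb` expects for the L and AB channels.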

**Plot a random sample of images from the dataset**

In [None]:
n = random.randint(20, 50)
plt.figure(figsize=(30, 30))

for i in range(n + 1, n + 17, 2):
    plt.subplot(4, 4, (i - n))
    img = np.zeros((224, 224, 3))
    img[:, :, 0] = L_image_df[i]
    plt.title("B&W")
    plt.imshow(lab2rgb(img))

    plt.subplot(4, 4, (i + 1 - n))
    img[:, :, 1:] = AB_image_df[i]
    img = img.astype("uint8")
    img = cv2.cvtColor(img, cv2.COLOR_LAB2RGB)
    plt.title("Colored")
    plt.imshow(img)

gc.collect()

**Wrap the dataset so that it returns PyTorch tensors**

In [None]:
class ImageColorizationDataset(Dataset):
    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset[0])

    def __getitem__(self, idx):
        L = np.array(self.dataset[0][idx]).reshape((224, 224, 1))
        L = transforms.ToTensor()(L)

        AB = np.array(self.dataset[1][idx])
        AB = transforms.ToTensor()(AB)

        return AB, L

**Build the data loaders**

In [None]:
batch_size = 1
split = 0.3
train_size = int(AB_image_df.shape[0] * (1 - split))
test_size = int(AB_image_df.shape[0] * split)

train_dataset = ImageColorizationDataset(dataset=(L_image_df[:train_size], AB_image_df[:train_size]))
test_dataset = ImageColorizationDataset(dataset=(L_image_df[-test_size:], AB_image_df[-test_size:]))

print(f"Train dataset has {len(train_dataset)} images")
print(f"Test dataset has {len(test_dataset)} images")

train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False, pin_memory=True)

## Generator
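The generator is a U-Net built from residual blocks (`ResBlock`). U-Net was originally designed for semantic segmentation; here it maps the 1-channel L input to the 2-channel AB output.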

In [None]:
class ResBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.layer = nn.Sequential(
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=3,
                padding=1,
                stride=stride,
                bias=False,
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                out_channels,
                out_channels,
                kernel_size=3,
                padding=1,
                stride=1,
                bias=False,
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

        self.identity_map = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride)
        self.relu = nn.ReLU(inplace=True)

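    def forward(self, inputs):
        # Feed the input to the main path directly. Detaching it here would stop
        # gradients from reaching earlier layers through the convolutional path.
        out = self.layer(inputs)
        residual = self.identity_map(inputs)
        skip = out + residual
        return self.relu(skip)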


class DownSampleConv(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.layer = nn.Sequential(nn.MaxPool2d(2), ResBlock(in_channels, out_channels))

    def forward(self, inputs):
        return self.layer(inputs)


class UpSampleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.res_block = ResBlock(in_channels + out_channels, out_channels)

    def forward(self, inputs, skip):
        x = self.upsample(inputs)
        x = torch.cat([x, skip], dim=1)
        x = self.res_block(x)
        return x


class Generator(nn.Module):
    def __init__(self, input_channel, output_channel, dropout_rate=0.2):
        super().__init__()
        self.encoding_layer1_ = ResBlock(input_channel, 64)
        self.encoding_layer2_ = DownSampleConv(64, 128)
        self.encoding_layer3_ = DownSampleConv(128, 256)
        self.bridge = DownSampleConv(256, 512)
        self.decoding_layer3_ = UpSampleConv(512, 256)
        self.decoding_layer2_ = UpSampleConv(256, 128)
        self.decoding_layer1_ = UpSampleConv(128, 64)
        self.output = nn.Conv2d(64, output_channel, kernel_size=1)
        self.dropout = nn.Dropout2d(dropout_rate)

    def forward(self, inputs):
        e1 = self.encoding_layer1_(inputs)
        e1 = self.dropout(e1)
        e2 = self.encoding_layer2_(e1)
        e2 = self.dropout(e2)
        e3 = self.encoding_layer3_(e2)
        e3 = self.dropout(e3)

        bridge = self.bridge(e3)
        bridge = self.dropout(bridge)

        d3 = self.decoding_layer3_(bridge, e3)
        d2 = self.decoding_layer2_(d3, e2)
        d1 = self.decoding_layer1_(d2, e1)

        output = self.output(d1)
        return output

In [None]:
model = Generator(1, 2).to(device)
summary(model, (1, 224, 224), batch_size=1)
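
To make the encoder/decoder layout easier to follow, the sketch below traces channel counts and spatial sizes through the generator without running it. The function `unet_shapes` is a hypothetical helper written for this illustration; it mirrors the layer widths used in `Generator` and assumes a 224 x 224 input.

```python
def unet_shapes(size=224, in_ch=1, out_ch=2, widths=(64, 128, 256, 512)):
    """Return (stage, channels, height/width) for each stage of the generator."""
    stages = [("input", in_ch, size), ("e1", widths[0], size)]
    s = size
    # Each DownSampleConv halves the resolution with MaxPool2d(2).
    for name, ch in zip(("e2", "e3", "bridge"), widths[1:]):
        s //= 2
        stages.append((name, ch, s))
    # Each UpSampleConv doubles the resolution, concatenates the matching
    # encoder output, and reduces to that encoder stage's channel count.
    for name, ch in zip(("d3", "d2", "d1"), reversed(widths[:-1])):
        s *= 2
        stages.append((name, ch, s))
    stages.append(("output", out_ch, s))
    return stages


for name, ch, s in unet_shapes():
    print(f"{name:>6}: {ch:>3} x {s} x {s}")
```

Because 224 is divisible by 8, the three pooling steps and three upsampling steps line up exactly, so each skip connection can be concatenated without cropping.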

## Discriminator
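The discriminator is a standard convolutional neural network (CNN). It receives the AB channels concatenated with the L channel and outputs a single unbounded score rather than a probability, as is usual for a Wasserstein critic.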

In [None]:
class Discriminator(nn.Module):
    def __init__(self, in_channels=3):
        super(Discriminator, self).__init__()

        def discriminator_block(in_filters, out_filters, normalization=True):
            layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]
            if normalization:
                layers.append(nn.InstanceNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *discriminator_block(in_channels, 64, normalization=False),
            *discriminator_block(64, 128),
            *discriminator_block(128, 256),
            *discriminator_block(256, 512),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(512, 1),
        )

    def forward(self, ab, l):
        img_input = torch.cat((ab, l), 1)
        output = self.model(img_input)
        return output

In [None]:
model = Discriminator(3).to(device)
summary(model, [(2, 224, 224), (1, 224, 224)], batch_size=1)
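
The spatial size after each `discriminator_block` follows from the standard convolution output formula. The short sketch below applies it to the kernel size, stride, and padding used above; `conv_out` is a hypothetical helper, not part of the model code.

```python
def conv_out(size, kernel=4, stride=2, padding=1):
    """Output size of a Conv2d along one spatial dimension."""
    return (size + 2 * padding - kernel) // stride + 1


size = 224
sizes = [size]
for _ in range(4):  # four discriminator_block stages
    size = conv_out(size)
    sizes.append(size)

print(sizes)  # [224, 112, 56, 28, 14]
```

The final 512 x 14 x 14 feature map is then averaged to a 512-dimensional vector by `AdaptiveAvgPool2d(1)` and mapped to one score by the linear layer.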

## Generative Adversarial Network

In [None]:
def _weights_init(m):
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    if isinstance(m, nn.BatchNorm2d):
        torch.nn.init.normal_(m.weight, 1.0, 0.02)  # BatchNorm scale starts around 1, as in DCGAN
        torch.nn.init.constant_(m.bias, 0)


def display_progress(cond, real, fake, current_epoch=0, figsize=(20, 15)):
    cond = cond.detach().cpu().permute(1, 2, 0)
    real = real.detach().cpu().permute(1, 2, 0)
    fake = fake.detach().cpu().permute(1, 2, 0)

    images = [cond, real, fake]
    titles = ["input", "real", "generated"]
    print(f"Epoch: {current_epoch}")
    fig, ax = plt.subplots(1, 3, figsize=figsize)
    for idx, img in enumerate(images):
        if idx == 0:
            ab = torch.zeros((224, 224, 2))
            img = torch.cat([images[0] * 100, ab], dim=2).numpy()
            imgan = lab2rgb(img)
        else:
            imgan = LAB_to_RGB(images[0], img)
        ax[idx].imshow(imgan)
        ax[idx].axis("off")
    for idx, title in enumerate(titles):
        ax[idx].set_title("{}".format(title))
    plt.show()


class ConditionalWGAN(pl.LightningModule):
    def __init__(
        self,
        in_channels,
        out_channels,
        learning_rate=0.0002,
        lambda_recon=100,
        display_step=10,
        lambda_gp=10,
        lambda_r1=10,
    ):

        super().__init__()
        self.save_hyperparameters()

        self.display_step = display_step

        self.generator = Generator(in_channels, out_channels)
        self.discriminator = Discriminator(in_channels + out_channels)
        self.optimizer_G = optim.Adam(self.generator.parameters(), lr=learning_rate, betas=(0.5, 0.9))
        self.optimizer_C = optim.Adam(self.discriminator.parameters(), lr=learning_rate, betas=(0.5, 0.9))
        self.lambda_recon = lambda_recon
        self.lambda_gp = lambda_gp
        self.lambda_r1 = lambda_r1
        self.recon_criterion = nn.L1Loss()
        self.generator_losses, self.discriminator_losses = [], []

    def configure_optimizers(self):
        return [self.optimizer_C, self.optimizer_G]

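    def generator_step(self, real_images, conditioned_images):
        # As written, the generator is updated with the L1 reconstruction loss only:
        # no adversarial term from the discriminator is added, and lambda_recon is unused.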
        self.optimizer_G.zero_grad()
        fake_images = self.generator(conditioned_images)
        recon_loss = self.recon_criterion(fake_images, real_images)
        recon_loss.backward()
        self.optimizer_G.step()

        self.generator_losses += [recon_loss.item()]

    def discriminator_step(self, real_images, conditioned_images):
        self.optimizer_C.zero_grad()
        fake_images = self.generator(conditioned_images)
        fake_logits = self.discriminator(fake_images, conditioned_images)
        real_logits = self.discriminator(real_images, conditioned_images)

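        # Critic loss. Note the sign: minimizing this lowers the scores of real images
        # and raises the scores of fakes, the reverse of the more common fake - real form.
        loss_C = real_logits.mean() - fake_logits.mean()

        # Gradient penalty (WGAN-GP): evaluate the critic on random interpolations
        # between the real and the generated AB channels.
        alpha = torch.rand(real_images.size(0), 1, 1, 1, device=real_images.device)
        interpolated = (alpha * real_images + (1 - alpha) * fake_images.detach()).requires_grad_(True)

        interpolated_logits = self.discriminator(interpolated, conditioned_images)

        grad_outputs = torch.ones_like(interpolated_logits)
        gradients = torch.autograd.grad(
            outputs=interpolated_logits,
            inputs=interpolated,
            grad_outputs=grad_outputs,
            create_graph=True,
            retain_graph=True,
        )[0]

        # Penalize deviation of each sample's gradient norm from 1.
        gradients = gradients.view(len(gradients), -1)
        gradients_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
        loss_C += self.lambda_gp * gradients_penalty

        # Extra penalty on the squared gradient norm. Unlike the usual R1 regularizer,
        # it is computed at the interpolated points rather than at real samples.
        r1_reg = gradients.pow(2).sum(1).mean()
        loss_C += self.lambda_r1 * r1_reg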

        loss_C.backward()
        self.optimizer_C.step()
        self.discriminator_losses += [loss_C.item()]

    def training_step(self, batch, batch_idx, optimizer_idx):
        real, condition = batch
        if optimizer_idx == 0:
            self.discriminator_step(real, condition)
        elif optimizer_idx == 1:
            self.generator_step(real, condition)
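        # Average over the losses actually recorded so far (at most display_step of them).
        recent_gen = self.generator_losses[-self.display_step :]
        recent_crit = self.discriminator_losses[-self.display_step :]
        gen_mean = sum(recent_gen) / max(len(recent_gen), 1)
        crit_mean = sum(recent_crit) / max(len(recent_crit), 1)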
        if self.current_epoch % self.display_step == 0 and batch_idx == 0 and optimizer_idx == 1:
            fake = self.generator(condition).detach()
            print(f"Epoch {self.current_epoch} : Generator loss: {gen_mean}, discriminator loss: {crit_mean}")
            display_progress(condition[0], real[0], fake[0], self.current_epoch)
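
The gradient penalty term in `discriminator_step` can be illustrated without autograd. The sketch below assumes a hypothetical linear critic f(x) = <w, x>, whose gradient with respect to x is simply w for every input; the function name `gradient_penalty` and the toy tensors are made up for this example.

```python
import numpy as np


def gradient_penalty(grads):
    """WGAN-GP term: mean over the batch of (||grad||_2 - 1)^2."""
    flat = grads.reshape(len(grads), -1)
    norms = np.linalg.norm(flat, axis=1)
    return float(np.mean((norms - 1.0) ** 2))


# Hypothetical linear critic: the gradient w.r.t. the input equals w everywhere.
rng = np.random.default_rng(0)
w = rng.normal(size=(2, 4, 4))
w_unit = w / np.linalg.norm(w)  # rescaled so the gradient norm is 1

batch = 3
grads_unit = np.broadcast_to(w_unit, (batch, *w.shape))
grads_big = np.broadcast_to(2.0 * w_unit, (batch, *w.shape))

print(gradient_penalty(grads_unit))  # close to 0: gradient norm is 1
print(gradient_penalty(grads_big))   # close to 1: norm 2 gives (2 - 1)^2
```

The penalty is zero only when every per-sample gradient has unit norm, which is how WGAN-GP softly enforces the 1-Lipschitz constraint on the critic.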

## Build and Train the model

In [None]:
gc.collect()
cwgan = ConditionalWGAN(in_channels=1, out_channels=2, learning_rate=2e-4, lambda_recon=100, display_step=10)
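
The constructor accepts `lambda_recon`, although `generator_step` above optimizes only the L1 term. In pix2pix-style conditional GANs, a common generator objective adds an adversarial term to the weighted L1 term. The NumPy sketch below shows that combination under the assumption that higher critic scores mean "more real"; the function `generator_objective` and the toy values are hypothetical, and this is not what the training loop in this notebook computes.

```python
import numpy as np


def generator_objective(critic_scores_fake, fake_ab, real_ab, lambda_recon=100.0):
    """Adversarial term (raise critic scores on fakes) plus weighted L1 reconstruction."""
    adversarial = -np.mean(critic_scores_fake)
    recon = np.mean(np.abs(fake_ab - real_ab))
    return float(adversarial + lambda_recon * recon)


# Toy values for illustration only.
scores = np.array([0.5, -0.5])      # critic outputs for two generated images
real = np.zeros((2, 2, 4, 4))       # batch of AB channels
fake = np.full((2, 2, 4, 4), 0.01)  # every AB value is off by 0.01

print(generator_objective(scores, fake, real))  # approximately 1.0 (0 + 100 * 0.01)
```

With `lambda_recon=100`, the reconstruction term dominates unless the L1 error is already small, which keeps generated colors close to the ground truth while the adversarial term encourages realism.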

In [None]:
trainer = pl.Trainer(max_epochs=150, gpus=-1)
trainer.fit(cwgan, train_loader)

## Model inference

In [None]:
plt.figure(figsize=(30, 60))
idx = 1  # subplot indices are 1-based; the 6 x 3 grid holds positions 1 to 18
for batch_idx, batch in enumerate(test_loader):
    real, condition = batch
    pred = cwgan.generator(condition).detach().squeeze().permute(1, 2, 0)
    condition = condition.detach().squeeze(0).permute(1, 2, 0)
    real = real.detach().squeeze(0).permute(1, 2, 0)
    plt.subplots_adjust(wspace=0, hspace=0)
    plt.subplot(6, 3, idx)
    plt.grid(False)

    ab = torch.zeros((224, 224, 2))
    img = torch.cat([condition * 100, ab], dim=2).numpy()
    imgan = lab2rgb(img)
    plt.imshow(imgan)
    plt.title("Input")

    plt.subplot(6, 3, idx + 1)

    ab = torch.zeros((224, 224, 2))
    imgan = LAB_to_RGB(condition, real)
    plt.imshow(imgan)
    plt.title("Real")

    plt.subplot(6, 3, idx + 2)
    imgan = LAB_to_RGB(condition, pred)
    plt.title("Generated")
    plt.imshow(imgan)
    idx += 3
    if idx >= 18:
        break

## Model evaluation

**Inception Score**

In [None]:
torch.set_grad_enabled(False)
cwgan.generator.eval()
all_preds = []
all_real = []

for batch_idx, batch in enumerate(test_loader):
    real, condition = batch
    pred = cwgan.generator(condition).detach()
    Lab = torch.cat([condition, pred], dim=1).numpy()
    Lab_real = torch.cat([condition, real], dim=1).numpy()
    all_preds.append(Lab.squeeze())
    all_real.append(Lab_real.squeeze())
    if batch_idx == 500:
        break

In [None]:
class InceptionScore:
    def __init__(self, device):
        self.device = device
        self.inception = inception_v3(pretrained=True, transform_input=False).to(self.device)
        self.inception.eval()

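    def calculate_is(self, generated_images):
        generated_images = generated_images.to(self.device)

        with torch.no_grad():
            logits = self.inception(generated_images.view(-1, 3, 224, 224))

        logits = logits.view(logits.size(0), -1)
        # log_softmax avoids log(0) when a class probability underflows.
        log_p = F.log_softmax(logits, dim=1)
        p = log_p.exp()

        # Per-image KL divergence between p(y|x) and a uniform distribution over classes.
        # The standard Inception Score instead uses the marginal p(y) as the reference
        # and exponentiates the mean KL, so these values are not directly comparable
        # to published scores.
        log_uniform = -torch.log(torch.tensor(float(logits.size(1)), device=self.device))
        kl = (p * (log_p - log_uniform)).sum(dim=1)

        return kl.mean().item(), kl.std().item()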

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

all_preds = np.concatenate(all_preds, axis=0)
all_preds = torch.tensor(all_preds).float()

all_real = np.concatenate(all_real, axis=0)
all_real = torch.tensor(all_real).float()

is_model = InceptionScore(device)

mean_real, std_real = is_model.calculate_is(all_real)
mean_is, std_is = is_model.calculate_is(all_preds)

print("Inception Score of real images: mean: {:.4f}, std: {:.4f}".format(mean_real, std_real))
print("Inception Score of fake images: mean: {:.4f}, std: {:.4f}".format(mean_is, std_is))
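
For comparison, the conventional Inception Score is exp(E_x[KL(p(y|x) || p(y))]), where p(y) is the marginal class distribution over all images. The NumPy sketch below computes that quantity from a matrix of class probabilities; the function `inception_score` and the toy probability matrices are assumptions made for illustration, and no Inception network is involved.

```python
import numpy as np


def inception_score(probs, eps=1e-12):
    """probs: (N, num_classes) array of per-image class probabilities p(y|x)."""
    p_y = probs.mean(axis=0, keepdims=True)  # marginal p(y)
    kl = np.sum(probs * (np.log(probs + eps) - np.log(p_y + eps)), axis=1)
    return float(np.exp(kl.mean()))


# Every image gets the same uniform prediction: no confidence, no diversity.
uniform = np.full((8, 4), 0.25)
# Each image is confidently assigned to a different class.
one_hot = np.eye(4)

print(inception_score(uniform))  # approximately 1.0, the minimum
print(inception_score(one_hot))  # approximately 4.0, the number of classes
```

The score rewards both confident per-image predictions and diversity across images, which a comparison against a fixed uniform distribution cannot capture.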

**Frechet Inception Distance (FID)**

In [None]:
class FID:
    def __init__(self, device):
        self.device = device
        self.inception = inception_v3(pretrained=True, transform_input=False).to(self.device)
        self.inception.eval()

    def calculate_fid(self, real_images, generated_images):
        real_images = real_images.to(self.device)
        generated_images = generated_images.to(self.device)

        with torch.no_grad():
            real_features = self.inception(real_images.view(-1, 3, 224, 224))
            generated_features = self.inception(generated_images.view(-1, 3, 224, 224))

        real_features = real_features.view(real_features.size(0), -1)
        generated_features = generated_features.view(generated_features.size(0), -1)

        # Simplified FID: feature dimensions are treated as independent (diagonal
        # covariance) and the 1000-way Inception logits serve as features. The standard
        # FID uses full covariance matrices of the 2048-d pooling features.
        real_mu = real_features.mean(dim=0)
        real_sigma = real_features.std(dim=0)

        generated_mu = generated_features.mean(dim=0)
        generated_sigma = generated_features.std(dim=0)

        mu_diff = real_mu - generated_mu
        sigma_diff = real_sigma - generated_sigma

        # Each term is counted once: squared mean difference plus squared std difference.
        fid = mu_diff.pow(2).sum() + sigma_diff.pow(2).sum()
        return fid.item()

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
fid_calculator = FID(device)

fid_value = fid_calculator.calculate_fid(all_real, all_preds)
print("FID: {:.4f}".format(fid_value))
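
The conventional FID is the Fréchet distance between two Gaussians with full covariance matrices: ||mu_a - mu_b||^2 + Tr(S_a + S_b - 2 (S_a S_b)^(1/2)). The NumPy sketch below computes it on synthetic features that stand in for Inception activations; the function `frechet_distance` and the random data are assumptions for illustration only.

```python
import numpy as np


def frechet_distance(feats_a, feats_b):
    """Frechet distance between Gaussians fitted to two (N, D) feature arrays."""
    mu_a, mu_b = feats_a.mean(axis=0), feats_b.mean(axis=0)
    cov_a = np.cov(feats_a, rowvar=False)
    cov_b = np.cov(feats_b, rowvar=False)
    # Tr((cov_a cov_b)^(1/2)) equals the sum of the square roots of the eigenvalues
    # of cov_a @ cov_b, which are real and non-negative for PSD inputs.
    eigvals = np.linalg.eigvals(cov_a @ cov_b)
    tr_sqrt = np.sum(np.sqrt(np.clip(eigvals.real, 0.0, None)))
    diff = mu_a - mu_b
    return float(diff @ diff + np.trace(cov_a) + np.trace(cov_b) - 2.0 * tr_sqrt)


# Synthetic features stand in for Inception activations.
rng = np.random.default_rng(0)
feats = rng.normal(size=(500, 8))
shift = np.full(8, 0.5)

print(frechet_distance(feats, feats))          # approximately 0
print(frechet_distance(feats, feats + shift))  # approximately 8 * 0.5**2 = 2.0
```

Shifting every feature by a constant leaves the covariance unchanged, so the distance reduces to the squared norm of the mean difference.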

# Conclusion
This project combines two techniques for image colorization: a Wasserstein GAN (WGAN) objective and a U-Net generator built from residual blocks. The WGAN objective trains against an estimate of the Wasserstein distance, which tends to make training more stable than the standard GAN loss and can lead to more realistic output. The U-Net, an architecture that is particularly effective for image segmentation, passes encoder features to the decoder through skip connections, and the residual blocks help the network retain the fine details of the input image that colorization depends on.