In [None]:
import gc
import glob
import os
import random
import time
import warnings
from pathlib import Path

import PIL
import cv2
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
import torch.utils.data
from PIL import Image
from scipy.stats import entropy
from skimage.color import rgb2lab, lab2rgb
from torch import nn, optim
from torch.autograd import Variable
from torch.nn import functional as F  # noqa
from torch.utils.data import Dataset, DataLoader
from torchsummary import summary
from torchvision import models
from torchvision import transforms
from torchvision.models.inception import inception_v3
from tqdm import tqdm

In [None]:
# Silence every Python warning so the cell outputs stay readable.
warnings.filterwarnings("ignore")
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Matplotlib 3.6 renamed the bundled seaborn styles to "seaborn-v0_8-*" and the
# old names were removed afterwards, so use whichever spelling is installed.
for _style_name in ("seaborn-v0_8-pastel", "seaborn-pastel"):
    if _style_name in matplotlib.style.available:
        matplotlib.style.use(_style_name)
        break

In [None]:
# Locations of the pre-computed LAB arrays (AB channels and L channel).
AB_image_path = "/kaggle/input/image-colorization/ab/ab/ab1.npy"
L_image_path = "/kaggle/input/image-colorization/l/gray_scale.npy"

AB_image_df = np.load(AB_image_path)
# Keep only as many L images as there are AB images so the two arrays stay
# index-aligned.
L_image_df = np.load(L_image_path)[: AB_image_df.shape[0]]

print(f"Total {AB_image_df.shape[0]} Color images of shape {AB_image_df.shape[1:]}")
# This array holds the L (lightness) channel, so label it as grayscale.
print(f"Total {L_image_df.shape[0]} Grayscale images of shape {L_image_df.shape[1:]}")

gc.collect()

In [None]:
def LAB_to_RGB(L_img, AB_img):
    """Convert normalised L and AB tensors into an RGB image array.

    Args:
        L_img: Channels-last tensor of shape (H, W, 1). Values are multiplied
            by 100, i.e. they are assumed to be normalised to [0, 1].
        AB_img: Channels-last tensor of shape (H, W, 2). Values are mapped
            with ``(x - 0.5) * 256``, i.e. [0, 1] becomes [-128, 128].

    Returns:
        The ``numpy`` array produced by ``skimage.color.lab2rgb`` for the
        stacked (H, W, 3) LAB image.
    """
    # detach()/cpu() make the conversion safe for CUDA tensors and for tensors
    # that are still attached to an autograd graph.
    lightness = L_img.detach().cpu() * 100
    chroma = (AB_img.detach().cpu() - 0.5) * 128 * 2
    lab = torch.cat([lightness, chroma], dim=2).numpy()
    # lab2rgb converts pixel-wise along the last axis, so one call on the whole
    # image replaces the former row-by-row loop.
    return lab2rgb(lab)

In [None]:
# Show eight randomly chosen samples as (grayscale, colour) pairs in a 4x4 grid.
n = random.randint(20, 50)
plt.figure(figsize=(30, 30))

for i in range(n + 1, n + 17, 2):
    plt.subplot(4, 4, (i - n))
    # Size the canvas from the sample itself instead of assuming 224x224.
    img = np.zeros(L_image_df[i].shape + (3,))
    # The stored L channel uses OpenCV's 8-bit convention (0-255, see the
    # cv2.cvtColor call below) while skimage's lab2rgb expects L in 0-100,
    # so rescale before converting.
    img[:, :, 0] = L_image_df[i] * (100.0 / 255.0)
    plt.title("B&W")
    plt.imshow(lab2rgb(img))

    plt.subplot(4, 4, (i + 1 - n))
    # Rebuild the image in the raw 8-bit LAB layout that OpenCV expects.
    img[:, :, 0] = L_image_df[i]
    img[:, :, 1:] = AB_image_df[i]
    img = img.astype("uint8")
    img = cv2.cvtColor(img, cv2.COLOR_LAB2RGB)
    plt.title("Colored")
    plt.imshow(img)

gc.collect()

In [None]:
class ImageColorizationDataset(Dataset):
    """Pairs every L-channel image with its AB-channel target.

    Args:
        dataset: A ``(L_images, AB_images)`` pair of index-aligned sequences;
            element 0 supplies the L channel and element 1 the AB channels.
        transform: Stored on the instance but not applied by ``__getitem__``.
    """

    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        # Kept for interface compatibility; currently unused when fetching items.
        self.transform = transform

    def __len__(self):
        """Return the number of L-channel images."""
        return len(self.dataset[0])

    def __getitem__(self, idx):
        """Return ``(AB, L)`` as channels-first tensors for sample ``idx``."""
        L = np.array(self.dataset[0][idx])
        # Add the trailing channel axis using the image's own height/width, so
        # resolutions other than 224x224 are accepted as well.
        L = L.reshape(L.shape[0], L.shape[1], 1)
        # ToTensor converts HxWxC arrays to CxHxW tensors (uint8 input is
        # scaled to [0, 1]).
        L = transforms.ToTensor()(L)

        AB = np.array(self.dataset[1][idx])
        AB = transforms.ToTensor()(AB)

        # Target first, condition second: callers unpack `real, condition`.
        return AB, L

In [None]:
# Loader settings: one image per batch, 30% of the samples held out for testing.
batch_size = 1
split = 0.3
train_size = int(len(AB_image_df) * (1 - split))
test_size = int(len(AB_image_df) * split)

# Training uses the leading samples, testing the trailing ones.
train_dataset = ImageColorizationDataset(
    dataset=(L_image_df[:train_size], AB_image_df[:train_size]),
)
test_dataset = ImageColorizationDataset(
    dataset=(L_image_df[-test_size:], AB_image_df[-test_size:]),
)

print(f"Train dataset has {len(train_dataset)} images")
print(f"Test dataset has {len(test_dataset)} images")

# Only the training loader shuffles; the test loader keeps a stable order.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    pin_memory=True,
)
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    shuffle=False,
    pin_memory=True,
)

In [None]:
class ResBlock(nn.Module):
    """Residual block: two 3x3 conv/BatchNorm/ReLU stages plus a 1x1 shortcut.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by the block.
        stride: Stride of the first 3x3 convolution and of the 1x1 shortcut,
            so both branches end up with the same spatial size.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.layer = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=stride, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, stride=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

        # 1x1 projection so the shortcut matches the main branch in channels
        # and spatial size.
        self.identity_map = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, inputs):
        """Return ``relu(layer(inputs) + identity_map(inputs))``."""
        # Feed the live tensor to the conv branch. Detaching it here would cut
        # that branch out of the autograd graph, leaving upstream layers with
        # gradient from the 1x1 shortcut only. The first op in `self.layer` is
        # an out-of-place convolution, so `inputs` is never modified in place.
        out = self.layer(inputs)
        residual = self.identity_map(inputs)
        return self.relu(out + residual)


class DownSampleConv(nn.Module):
    """Encoder stage: 2x2 max pooling followed by a ResBlock.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the ResBlock.
        stride: Accepted for interface compatibility; not used by this module.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        pool = nn.MaxPool2d(2)
        refine = ResBlock(in_channels, out_channels)
        self.layer = nn.Sequential(pool, refine)

    def forward(self, inputs):
        """Pool ``inputs`` and run the result through the ResBlock."""
        downsampled = self.layer(inputs)
        return downsampled


class UpSampleConv(nn.Module):
    """Decoder stage: 2x bilinear upsampling, skip concatenation, ResBlock.

    Args:
        in_channels: Channels of the feature map being upsampled.
        out_channels: Channels of the skip connection and of the output.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        # After concatenation the ResBlock sees both the upsampled map and the skip.
        merged_channels = in_channels + out_channels
        self.res_block = ResBlock(merged_channels, out_channels)

    def forward(self, inputs, skip):
        """Upsample ``inputs``, append ``skip`` along channels and refine."""
        merged = torch.cat((self.upsample(inputs), skip), dim=1)
        return self.res_block(merged)


class Generator(nn.Module):
    """U-Net style generator built from residual blocks.

    Three encoder stages and a bridge are mirrored by three decoder stages
    that consume the encoder outputs as skip connections. A final 1x1
    convolution maps to ``output_channel`` channels.

    Args:
        input_channel: Channels of the conditioning image.
        output_channel: Channels of the generated image.
        dropout_rate: Probability used by the shared ``Dropout2d`` layer.
    """

    def __init__(self, input_channel, output_channel, dropout_rate=0.2):
        super().__init__()
        width1, width2, width3, width4 = 64, 128, 256, 512
        # Encoder path.
        self.encoding_layer1_ = ResBlock(input_channel, width1)
        self.encoding_layer2_ = DownSampleConv(width1, width2)
        self.encoding_layer3_ = DownSampleConv(width2, width3)
        # Bottleneck.
        self.bridge = DownSampleConv(width3, width4)
        # Decoder path.
        self.decoding_layer3_ = UpSampleConv(width4, width3)
        self.decoding_layer2_ = UpSampleConv(width3, width2)
        self.decoding_layer1_ = UpSampleConv(width2, width1)
        self.output = nn.Conv2d(width1, output_channel, kernel_size=1)
        self.dropout = nn.Dropout2d(dropout_rate)

    def forward(self, inputs):
        """Run the encoder, bridge and decoder and return the raw output map."""
        drop = self.dropout

        # Dropout is applied to each encoder output before it is reused, so the
        # skip connections carry the dropped-out features as well.
        skip1 = drop(self.encoding_layer1_(inputs))
        skip2 = drop(self.encoding_layer2_(skip1))
        skip3 = drop(self.encoding_layer3_(skip2))

        bottleneck = drop(self.bridge(skip3))

        up3 = self.decoding_layer3_(bottleneck, skip3)
        up2 = self.decoding_layer2_(up3, skip2)
        up1 = self.decoding_layer1_(up2, skip1)

        return self.output(up1)


In [None]:
# Build a throwaway generator (1 input channel, 2 output channels) and print
# its layer summary for a single 1x224x224 input.
model = Generator(input_channel=1, output_channel=2).to(device)
summary(model, input_size=(1, 224, 224), batch_size=1)

In [None]:
class Discriminator(nn.Module):
    """Convolutional critic that scores an (AB, L) image pair.

    Four stride-2 4x4 convolution stages (64, 128, 256, 512 filters) are
    followed by global average pooling and a linear layer producing one score
    per sample. Every stage except the first applies instance normalisation.

    Args:
        in_channels: Total channels of the concatenated ``ab`` and ``l`` inputs.
    """

    def __init__(self, in_channels=3):
        super().__init__()

        widths = (in_channels, 64, 128, 256, 512)
        stages = []
        for stage_no, (n_in, n_out) in enumerate(zip(widths[:-1], widths[1:])):
            stages.append(nn.Conv2d(n_in, n_out, 4, stride=2, padding=1))
            # The first stage is left un-normalised.
            if stage_no != 0:
                stages.append(nn.InstanceNorm2d(n_out))
            stages.append(nn.LeakyReLU(0.2, inplace=True))

        # Collapse the spatial dimensions and map the 512 features to one score.
        stages.extend([nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(512, 1)])
        self.model = nn.Sequential(*stages)

    def forward(self, ab, l):
        """Concatenate ``ab`` and ``l`` along channels and return the scores."""
        stacked = torch.cat([ab, l], dim=1)
        return self.model(stacked)

In [None]:
# Build a throwaway discriminator for 3 stacked channels and print its layer
# summary; it takes two inputs, the 2-channel AB map and the 1-channel L map.
model = Discriminator(in_channels=3).to(device)
summary(model, input_size=[(2, 224, 224), (1, 224, 224)], batch_size=1)

In [None]:
def _weights_init(m):
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    if isinstance(m, nn.BatchNorm2d):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
        torch.nn.init.constant_(m.bias, 0)


def display_progress(cond, real, fake, current_epoch=0, figsize=(20, 15)):
    """Plot the conditioning image, the ground truth and the generated image.

    Args:
        cond: Channels-first L tensor (1, H, W) of a single sample.
        real: Channels-first ground-truth AB tensor (2, H, W).
        fake: Channels-first generated AB tensor (2, H, W).
        current_epoch: Epoch number printed above the figure.
        figsize: Size passed to ``plt.subplots``.
    """
    # Move to the CPU and switch to channels-last, which the colour
    # conversions below expect.
    cond = cond.detach().cpu().permute(1, 2, 0)
    real = real.detach().cpu().permute(1, 2, 0)
    fake = fake.detach().cpu().permute(1, 2, 0)

    print(f"Epoch: {current_epoch}")

    # The grayscale panel is the L channel with zeroed AB channels. The
    # placeholder is sized from the input rather than assuming 224x224.
    height, width = cond.shape[:2]
    neutral_ab = torch.zeros((height, width, 2), dtype=cond.dtype)
    gray_rgb = lab2rgb(torch.cat([cond * 100, neutral_ab], dim=2).numpy())

    panels = [
        ("input", gray_rgb),
        ("real", LAB_to_RGB(cond, real)),
        ("generated", LAB_to_RGB(cond, fake)),
    ]

    fig, ax = plt.subplots(1, 3, figsize=figsize)
    for axis, (title, rgb) in zip(ax, panels):
        axis.imshow(rgb)
        axis.axis("off")
        axis.set_title("{}".format(title))
    plt.show()
    # Release the figure so repeated calls during training do not pile up
    # open figures.
    plt.close(fig)


class ConditionalWGAN(pl.LightningModule):
    """Conditional WGAN pairing a ``Generator`` with a ``Discriminator`` critic.

    The critic is trained with ``mean(real) - mean(fake)`` plus a gradient
    penalty and a squared-gradient-norm term, both evaluated on random
    interpolations between real and generated images. The generator is
    trained with an L1 reconstruction loss.

    Args:
        in_channels: Channels of the conditioning image.
        out_channels: Channels of the generated image.
        learning_rate: Adam learning rate for both networks.
        lambda_recon: Stored on the instance; not used by any loss below.
        display_step: Epoch interval between progress plots and the window
            over which the reported losses are averaged.
        lambda_gp: Weight of the gradient penalty.
        lambda_r1: Weight of the squared-gradient-norm term.
    """

    def __init__(
            self,
            in_channels,
            out_channels,
            learning_rate=0.0002,
            lambda_recon=100,
            display_step=10,
            lambda_gp=10,
            lambda_r1=10,
    ):

        super().__init__()
        self.save_hyperparameters()

        self.display_step = display_step

        self.generator = Generator(in_channels, out_channels)
        # The critic sees the generated/real channels stacked with the condition.
        self.discriminator = Discriminator(in_channels + out_channels)
        self.optimizer_G = optim.Adam(self.generator.parameters(), lr=learning_rate, betas=(0.5, 0.9))
        self.optimizer_C = optim.Adam(self.discriminator.parameters(), lr=learning_rate, betas=(0.5, 0.9))
        self.lambda_recon = lambda_recon
        self.lambda_gp = lambda_gp
        self.lambda_r1 = lambda_r1
        self.recon_criterion = nn.L1Loss()
        # Per-step loss history, appended to by the two step helpers.
        self.generator_losses, self.discriminator_losses = [], []

    def configure_optimizers(self):
        """Return the critic optimizer (index 0) and generator optimizer (index 1)."""
        return [self.optimizer_C, self.optimizer_G]

    def generator_step(self, real_images, conditioned_images):
        """Run one manual generator update using the L1 reconstruction loss."""
        # NOTE(review): the generator is optimised with the reconstruction loss
        # only; `lambda_recon` and the critic's score are not part of this
        # objective. Confirm that this is intentional.
        self.optimizer_G.zero_grad()
        fake_images = self.generator(conditioned_images)
        recon_loss = self.recon_criterion(fake_images, real_images)
        recon_loss.backward()
        self.optimizer_G.step()

        self.generator_losses += [recon_loss.item()]

    def discriminator_step(self, real_images, conditioned_images):
        """Run one manual critic update with gradient regularisation."""
        self.optimizer_C.zero_grad()
        # The critic update needs no gradient w.r.t. the generator, so skip
        # building the generator's autograd graph.
        with torch.no_grad():
            fake_images = self.generator(conditioned_images)
        fake_logits = self.discriminator(fake_images, conditioned_images)
        real_logits = self.discriminator(real_images, conditioned_images)

        loss_C = real_logits.mean() - fake_logits.mean()

        # One mixing coefficient per sample, created directly on the batch's
        # device and dtype instead of relying on a notebook-level global.
        alpha = torch.rand(
            real_images.size(0), 1, 1, 1,
            device=real_images.device,
            dtype=real_images.dtype,
        )
        interpolated = (alpha * real_images + (1 - alpha) * fake_images).requires_grad_(True)

        interpolated_logits = self.discriminator(interpolated, conditioned_images)

        # create_graph=True keeps the gradient differentiable, so the penalties
        # below can themselves be back-propagated into the critic.
        gradients = torch.autograd.grad(
            outputs=interpolated_logits,
            inputs=interpolated,
            grad_outputs=torch.ones_like(interpolated_logits),
            create_graph=True,
            retain_graph=True,
        )[0]

        gradients = gradients.reshape(len(gradients), -1)
        # WGAN-GP term: push the per-sample gradient norm towards 1.
        gradients_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
        loss_C = loss_C + self.lambda_gp * gradients_penalty

        # NOTE(review): this squared-norm term is computed on the interpolated
        # samples (same gradients as above), not on the real images.
        r1_reg = gradients.pow(2).sum(1).mean()
        loss_C = loss_C + self.lambda_r1 * r1_reg

        loss_C.backward()
        self.optimizer_C.step()
        self.discriminator_losses += [loss_C.item()]

    @staticmethod
    def _recent_mean(losses, window):
        """Average the last ``window`` entries of ``losses`` (NaN when empty)."""
        recent = losses[-window:]
        # Divide by the number of values actually present, so early reports are
        # not scaled down while the history is shorter than the window.
        return sum(recent) / len(recent) if recent else float("nan")

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Dispatch to the critic (index 0) or generator (index 1) update.

        Nothing is returned: both helpers run backward and the optimizer step
        themselves.
        """
        real, condition = batch
        if optimizer_idx == 0:
            self.discriminator_step(real, condition)
        elif optimizer_idx == 1:
            self.generator_step(real, condition)

        # Report once per `display_step` epochs, on the first batch, after the
        # generator update.
        if self.current_epoch % self.display_step == 0 and batch_idx == 0 and optimizer_idx == 1:
            gen_mean = self._recent_mean(self.generator_losses, self.display_step)
            crit_mean = self._recent_mean(self.discriminator_losses, self.display_step)
            with torch.no_grad():
                fake = self.generator(condition)
            print(f"Epoch {self.current_epoch} : Generator loss: {gen_mean}, discriminator loss: {crit_mean}")
            display_progress(condition[0], real[0], fake[0], self.current_epoch)

In [None]:
# Free unreferenced objects before the networks are allocated.
gc.collect()
# One conditioning channel in (L), two generated channels out (AB).
cwgan = ConditionalWGAN(
    1,
    2,
    learning_rate=0.0002,
    lambda_recon=100,
    display_step=10,
)

In [None]:
# `gpus=-1` requests every visible GPU and fails when there is none, so fall
# back to CPU training, matching the `device` selection at the top of the
# notebook.
trainer = pl.Trainer(max_epochs=150, gpus=-1 if torch.cuda.is_available() else None)
trainer.fit(cwgan, train_loader)

In [None]:
# Plot six test samples as rows of (input, real, generated) in a 6x3 grid.
plt.figure(figsize=(30, 60))
plt.subplots_adjust(wspace=0, hspace=0)
# Run the generator wherever its weights currently live, then bring the result
# back to the CPU for plotting.
generator_device = next(cwgan.generator.parameters()).device
# Subplot indices are 1-based and must stay within the 6 * 3 = 18 cells.
idx = 1
for batch_idx, batch in enumerate(test_loader):
    real, condition = batch
    # NOTE(review): the generator is left in whatever train/eval mode it is
    # already in; only gradient tracking is disabled here.
    with torch.no_grad():
        pred = cwgan.generator(condition.to(generator_device)).cpu()
    # Drop the batch axis (the loader yields one image per batch) and switch to
    # channels-last for the colour conversions.
    pred = pred.squeeze(0).permute(1, 2, 0)
    condition = condition.detach().squeeze(0).permute(1, 2, 0)
    real = real.detach().squeeze(0).permute(1, 2, 0)

    plt.subplot(6, 3, idx)
    plt.grid(False)
    # Grayscale panel: the L channel with zeroed AB channels, sized from the
    # sample itself rather than assuming 224x224.
    ab = torch.zeros(tuple(condition.shape[:2]) + (2,), dtype=condition.dtype)
    img = torch.cat([condition * 100, ab], dim=2).numpy()
    imgan = lab2rgb(img)
    plt.imshow(imgan)
    plt.title('Input')

    plt.subplot(6, 3, idx + 1)
    imgan = LAB_to_RGB(condition, real)
    plt.imshow(imgan)
    plt.title('Real')

    plt.subplot(6, 3, idx + 2)
    imgan = LAB_to_RGB(condition, pred)
    plt.title('Generated')
    plt.imshow(imgan)

    # Move on to the next row and stop once all six rows are filled.
    idx += 3
    if idx >= 18:
        break