In [1]:
!pip install torchsummary

[0m^C


In [3]:
import os
from pathlib import Path

# Import glob to get the files directories recursively
import glob

# Import Garbage collector interface
import gc 

# Import OpenCV to transforme pictures
import cv2

# Import Time
import time

# import numpy for math calculations
import numpy as np

# Import pandas for data (csv) manipulation
import pandas as pd

# Import matplotlib for plotting
import matplotlib.pyplot as plt
import matplotlib
matplotlib.style.use('fivethirtyeight') 
%matplotlib inline

import PIL
from PIL import Image
from skimage.color import rgb2lab, lab2rgb

import pytorch_lightning as pl

# Import pytorch to build Deel Learling Models 
import torch
from torch import nn, optim
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable
from torchvision import models
from torch.nn import functional as F
import torch.utils.data
from torchvision.models.inception import inception_v3
from scipy.stats import entropy

from torchsummary import summary

# Import tqdm to show a smart progress meter
from tqdm import tqdm

# Import warnings to hide the unnessairy warniings
import warnings
warnings.filterwarnings('ignore')

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

<hr>
<h1><center>I. Project Understanding</center></h1>
<hr>

## A. Introduction


<div class="alert alert-block alert-info" style="font-size:14px; color:black; background-color:#e6f2ff; border-color: #3399ff">
    <aside>
    📘 Pix2pix is a powerful model for image-to-image translation tasks, but it can be further improved for specific applications such as colorization. One way to improve the performance of pix2pix for colorization is to use a Wasserstein GAN (WGAN) instead of the traditional GAN architecture. WGANs use the Wasserstein distance metric to train the generator and discriminator, which can help stabilize the training process and produce more realistic results.

Another way to improve the performance of pix2pix for colorization is to use a U-Net architecture based on residual blocks. U-Net is a type of convolutional neural network (CNN) that is well-suited for image segmentation tasks. It consists of a series of convolutional layers and max pooling layers, with skip connections between layers of the same resolution. This allows the network to learn fine details of the input image, which can be particularly useful for colorization tasks.

Residual blocks are a type of building block for neural networks, which consist of multiple layers with skip connections. The skip connections allow the gradient to pass through the layers more easily, which can help the network converge faster and produce better results.

Using WGAN with a U-Net architecture based on residual blocks can help improve the performance of pix2pix for colorization by providing better stability and improved ability to learn fine details of the input image.
    </aside>
</div>

<aside>
📌 The goal of this paper is :

- The goal of the "Image-to-Image Translation with Conditional Adversarial Networks" (pix2pix) paper is to propose a method for image-to-image translation using a conditional GAN architecture.
- The method uses a generator network that is trained to convert images from one domain (e.g. sketches) to another domain (e.g. photographs).
- The generator is trained using a combination of adversarial loss and L1 loss.
- The generator takes an image from the input domain and a random noise as input, and generates an image in the target domain.
- The discriminator network is trained to classify the generated image as real or fake, based on whether it is similar to a target image from the target domain.
- The paper uses the patch GAN architecture to discriminator with 70x70 patches.
- The authors showed the effectiveness of the proposed method on a variety of image-to-image translation tasks, such as converting edges to photographs, day to night, and labels to street scenes.
</aside>

<aside>
❓ What is image colorization ? 
  <br> <br/>
Image colorization is the process of adding color to a grayscale image or a black and white image. It involves mapping the intensity values of the grayscale image to a color space, such as RGB, and then filling in the missing color channels to produce a full-color image. There are different ways to approach image colorization, but most methods involve some form of image processing, such as image segmentation, texture synthesis, or machine learning.

One popular approach is to use a deep learning-based method, such as a convolutional neural network (CNN) to colorize images. This approach typically involves training a CNN on a large dataset of color images, and then using this network to predict the missing color channels of a grayscale image.

Another approach is to use a Generative Adversarial Network (GAN) model, where a generator network generates the color version of the grayscale image and a discriminator network is trained to distinguish between the generated color version and the real color image.

In recent years, there have been some impressive results in image colorization using deep learning-based methods, which can produce high-quality colorization results on a wide range of images.
</aside>

## B. Theory 

### 1. Generative adversarial networks (GAN)

A generative adversarial network (GAN) is a type of deep learning network that can generate data with similar characteristics as the input training data.

A GAN consists of two networks that train together:

* Generator — Given a vector of random values as input, this network generates data with the same structure as the training data.

* Discriminator — Given batches of data containing observations from both the training data, and generated data from the generator, this network attempts to classify the observations as "real" or "generated".

<center><img src="https://it.mathworks.com/help/examples/nnet/win64/TrainConditionalGenerativeAdversarialNetworkCGANExample_01.png"/> </center>

### 2. Conditional Generative adversarial networks (cGAN)


A conditional generative adversarial network (CGAN) is a type of GAN that also takes advantage of labels during the training process.

* Generator — Given a label and random array as input, this network generates data with the same structure as the training data observations corresponding to the same label.

* Discriminator — Given batches of labeled data containing observations from both the training data and generated data from the generator, this network attempts to classify the observations as "real" or "generated".

<center><img src="https://it.mathworks.com/help/examples/nnet/win64/TrainConditionalGenerativeAdversarialNetworkCGANExample_02.png"/> </center>

To train a conditional GAN, train both networks simultaneously to maximize the performance of both:

* Train the generator to generate data that "fools" the discriminator.

* Train the discriminator to distinguish between real and generated data.

To maximize the performance of the generator, maximize the loss of the discriminator when given generated labeled data. That is, the objective of the generator is to generate labeled data that the discriminator classifies as "real".

To maximize the performance of the discriminator, minimize the loss of the discriminator when given batches of both real and generated labeled data. That is, the objective of the discriminator is to not be "fooled" by the generator.

Ideally, these strategies result in a generator that generates convincingly realistic data that corresponds to the input labels and a discriminator that has learned strong feature representations that are characteristic of the training data for each label.

source : https://it.mathworks.com/help/deeplearning/ug/train-conditional-generative-adversarial-network.html

### 3. Why choosing cGAN over GAN

Conditional Generative Adversarial Networks (CGANs) are an extension of standard Generative Adversarial Networks (GANs) that are designed to handle conditional data. A CGAN consists of a generator network and a discriminator network, just like a standard GAN. However, in a CGAN, the generator and discriminator are both conditioned on some additional input data. This additional input data can be used to control the output of the generator, allowing it to produce more specific or customized results.

There are several reasons why a CGAN can be better than a standard GAN:

1. Control over the generated data: In a CGAN, the generator's output is conditioned on the input data, which allows the model to be more specific and controlled in its output. For example, if the input is a grayscale image, the model can colorize it to a specific color scheme.

2. Improved stability and training: Because the generator is conditioned on additional input data, it can be easier to train and more stable than a standard GAN. This is because the generator is able to focus on a specific subset of the data, rather than trying to generate all possible outputs.

3. Handling missing data: CGANs are well suited for handling missing data or data with missing modalities. The additional input data can be used to condition the generator to produce plausible outputs for the missing data.

4. Handling multiple classes: CGANs can be used to generate data for multiple classes in a one-to-many mapping, where the generator is conditioned on the class label and produces an image from that class.

5. Handling conditional data: In some tasks, the data is conditional, such as in image-to-image translation, where the output is conditioned on the input. CGANs can handle this kind of conditional data very well.

It's important to note that in some tasks a GAN might be enough or even better than a CGAN, it depends on the task and the data.

<hr>
<h1><center>II. Data Preparation</center></h1>
<hr>

In [4]:
ab_path = r"C:\Users\Raoul\Desktop\Studium\Semester_10\ki_seminar\coloring_images\ab\ab/ab1.npy"
l_path = r"C:\Users\Raoul\Desktop\Studium\Semester_10\ki_seminar\coloring_images\l/gray_scale.npy"

In [5]:
ab_df = np.load(ab_path)[0:5000]
L_df = np.load(l_path)[0:5000]
dataset = (L_df,ab_df )
gc.collect()

22

<hr>
<h1><center>III. Data Exploration and Visualiaztion</center></h1>
<hr>

## A. L*a*b* Colors

Like geographic coordinates – longitude, latitude, and altitude – L*a*b* color values give us a way to locate and communicate colors.

In the 1940’s, Richard Hunter introduced a tri-stimulus model, Lab, which is scaled to achieve near uniform spacing of perceived color differences. While Hunter’s Lab was adopted as the de facto model for plotting absolute color coordinates and differences between colors, it was never formally accepted as an international standard. 

### What does L*a*b* stand for? 
 it’s important to know what L*, a*, and b*stand for. 

* L*: Lightness
* a*: Red/Green Value
* b*: Blue/Yellow Value

As an example,  showing the color-plotting diagrams for L*a*b*.
<center>
<img src="https://www.xrite.com/-/media/modules/weblog/blog/lab-color-space/lab-color-space.png?h=622&w=600&la=en&hash=53A76941BAB3015346FAB3689739E967843CF8EA"></center>

* The a* axis runs from left to right. A color measurement movement in the +a direction depicts a shift toward red.
* Along the b* axis, +b movement represents a shift toward yellow.
* The center L* axis shows L = 0 (black or total absorption) at the bottom.
* At the center of this plane is neutral or gray.

Source : https://www.xrite.com/blog/lab-color-space

### Why chosing L*a*b* for our problem ?

sing the LAB color space for image colorization can be beneficial for several reasons:

1. LAB color space is perceptually uniform: The LAB color space separates color information (A and B channels) from lightness information (L channel), which allows for more accurate color representation. This can be particularly useful for image colorization, as it allows for more precise control over the colorization process.

2. LAB color space is more suitable for image processing: LAB color space is designed to be perceptually uniform and it separates the lightness information from the color information. This makes it more suitable for image processing tasks such as colorization, where it is important to maintain the relationship between lightness and color.

3. LAB color space is more robust to lighting changes: The L channel of the LAB color space represents the lightness of the image, which is relatively robust to changes in lighting conditions. This can be useful when colorizing images taken under different lighting conditions, as it allows for more consistent results.

4. LAB color space is more similar to human vision: The LAB color space is based on the way human eyes perceive color, which means that the results of colorization in LAB space are more similar to what a human would perceive.

5. LAB color space can be converted to other color spaces: LAB color space can be easily converted to other color spaces such as RGB which is the most common color space used in computer vision and image processing.

6. To train a model for colorization, we should give it a grayscale image and hope that it will make it colorful. When using L*a*b, we can give the L channel to the model (which is the grayscale image) and want it to predict the other two channels (*a, *b) and after its prediction, we concatenate all the channels and we get our colorful image. But if you use RGB, you have to first convert your image to grayscale, feed the grayscale image to the model and hope it will predict 3 numbers for you which is a way more difficult and unstable task due to the many more possible combinations of 3 numbers compared to two numbers. 

It's important to note that using LAB color space is not the only option for image colorization, and other color spaces such as RGB can also be used. But LAB color space has been proved to be a better option for image colorization due to the above-mentioned reasons.

In [6]:
def lab_to_rgb(L, ab):
    """
    Takes an image or a batch of images and converts from LAB space to RGB
    """
    L = L  * 100
    ab = (ab - 0.5) * 128 * 2
    Lab = torch.cat([L, ab], dim=2).numpy()
    rgb_imgs = []
    for img in Lab:
        img_rgb = lab2rgb(img)
        rgb_imgs.append(img_rgb)
    return np.stack(rgb_imgs, axis=0)

In [None]:
plt.figure(figsize=(30,30))
for i in range(1,16,2):
    plt.subplot(4,4,i)
    img = np.zeros((224,224,3))
    img[:,:,0] = L_df[i]
    plt.title('B&W')
    plt.imshow(lab2rgb(img))
    
    plt.subplot(4,4,i+1)
    img[:,:,1:] = ab_df[i]
    img = img.astype('uint8')
    img = cv2.cvtColor(img, cv2.COLOR_LAB2RGB)
    plt.title('Colored')
    plt.imshow(img)

In [7]:
gc.collect()

0

<hr>
<h1><center>III. Data Loader</center></h1>
<hr>

In [8]:
class ImageColorizationDataset(Dataset):
    ''' Black and White (L) Images and corresponding A&B Colors'''
    def __init__(self, dataset, transform=None):
        '''
        :param dataset: Dataset name.
        :param data_dir: Directory with all the images.
        :param transform: Optional transform to be applied on sample
        '''
        self.dataset = dataset
        self.transform = transform
    
    def __len__(self):
        return len(self.dataset[0])
    
    def __getitem__(self, idx):
        L = np.array(dataset[0][idx]).reshape((224,224,1))
        L = transforms.ToTensor()(L)
        
        ab = np.array(dataset[1][idx])
        ab = transforms.ToTensor()(ab)

        return ab, L

In [9]:
batch_size = 1

# Prepare the Datasets
train_dataset = ImageColorizationDataset(dataset = (L_df, ab_df))
test_dataset = ImageColorizationDataset(dataset = (L_df, ab_df))

# Build DataLoaders
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle = True, pin_memory = True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle = False, pin_memory = True)

<hr>
<h1><center>IV. Data Modelling</center></h1>
<hr>

# A. Generator ( ResU-NET )
>UNet with ResBlock for Semantic Segmentation

## 1. Theory

The U-Net is a convolutional neural network architecture that is designed for fast and precise segmentation of images. It has performed extremely well in several challenges and to this day, it is one of the most popular end-to-end architectures in the field of semantic segmentation.

We can split the network into two parts: 
* The encoder path (backbone) and the decoder path. The encoder captures features at different scales of the images by using a traditional stack of convolutional and max pooling layers.Concretely speaking, a block in the encoder consists of the repeated use of two convolutional layers (k=3, s=1), each followed by a non-linearity layer, and a max-pooling layer (k=2, s=2). For every convolution block and its associated max pooling operation, the number of feature maps is doubled to ensure that the network can learn the complex structures effectively.

* The decoder path is a symmetric expanding counterpart that uses transposed convolutions. This type of convolutional layer is an up-sampling method with trainable parameters and performs the reverse of (down)pooling layers such as the max pool. Similar to the encoder, each convolution block is followed by such an up-convolutional layer. The number of feature maps is halved in every block. Because recreating a segmentation mask from a small feature map is a rather difficult task for the network, the output after every up-convolutional layer is appended by the feature maps of the corresponding encoder block. The feature maps of the encoder layer are cropped if the dimensions exceed the one of the corresponding decoder layers.

Source : https://towardsdatascience.com/creating-and-training-a-u-net-model-with-pytorch-for-2d-3d-semantic-segmentation-model-building-6ab09d6a0862

## 2. UNet with ResBlock for Semantic Segmentation
UNet architecture was a great step forward in computer vision that revolutionized segmentation not just in medical imaging but in other fields as well. The long skip connection between each level of contracting path and expanding path is the key feature of the UNet. It’s like FCN is pulled upwards from both ends.

Another revolutionary advancement in computer vision was ResNet. The residual blocks in ResNet with skip connections helped in making a deeper and deeper convolution neural network and achieved record-breaking results for classification on the ImageNet dataset.
<center> <img src="https://miro.medium.com/max/720/0*Q6Dq_Ztsno3zV8TF"> </center>

Now by replacing convolutions in U-Net on each level with ResBlock, we can get better performance than the original UNet almost every time. Below is the detailed model architecture diagram.

source  : https://medium.com/@nishanksingla/unet-with-resblock-for-semantic-segmentation-dd1766b4ff66

## 3. Achitecture 

<center> <img src="https://i.imgur.com/k6ErEni.png"></center

## 4. Implementation

In [10]:
class ResBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.layer = nn.Sequential(
            nn.Conv2d(in_channels,out_channels,kernel_size=3, padding=1, stride=stride, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels,kernel_size=3,padding=1, stride=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

        self.identity_map = nn.Conv2d(in_channels, out_channels,kernel_size=1,stride=stride)
        self.relu = nn.ReLU(inplace=True)
    def forward(self, inputs):
        x = inputs.clone().detach()
        out = self.layer(x)
        residual  = self.identity_map(inputs)
        skip = out + residual
        return self.relu(skip)

In [11]:
class DownSampleConv(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.layer = nn.Sequential(
            nn.MaxPool2d(2),
            ResBlock(in_channels, out_channels)
        )

    def forward(self, inputs):
        return self.layer(inputs)

In [12]:
class UpSampleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        
        self.upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.res_block = ResBlock(in_channels + out_channels, out_channels)
        
    def forward(self, inputs, skip):
        x = self.upsample(inputs)
        x = torch.cat([x, skip], dim=1)
        x = self.res_block(x)
        return x

In [13]:
class Generator(nn.Module):
    def __init__(self, input_channel, output_channel, dropout_rate = 0.2):
        super().__init__()
        self.encoding_layer1_ = ResBlock(input_channel,64)
        self.encoding_layer2_ = DownSampleConv(64, 128)
        self.encoding_layer3_ = DownSampleConv(128, 256)
        self.bridge = DownSampleConv(256, 512)
        self.decoding_layer3_ = UpSampleConv(512, 256)
        self.decoding_layer2_ = UpSampleConv(256, 128)
        self.decoding_layer1_ = UpSampleConv(128, 64)
        self.output = nn.Conv2d(64, output_channel, kernel_size=1)
        self.dropout = nn.Dropout2d(dropout_rate)
        
    def forward(self, inputs):
        ###################### Enocoder #########################
        e1 = self.encoding_layer1_(inputs)
        e1 = self.dropout(e1)
        e2 = self.encoding_layer2_(e1)
        e2 = self.dropout(e2)
        e3 = self.encoding_layer3_(e2)
        e3 = self.dropout(e3)
        
        ###################### Bridge #########################
        bridge = self.bridge(e3)
        bridge = self.dropout(bridge)
        
        ###################### Decoder #########################
        d3 = self.decoding_layer3_(bridge, e3)
        d2 = self.decoding_layer2_(d3, e2)
        d1 = self.decoding_layer1_(d2, e1)
        
        ###################### Output #########################
        output = self.output(d1)
        return output

In [14]:
model = Generator(1,2).to(device)
summary(model, (1, 224, 224), batch_size = 1)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [1, 64, 224, 224]             576
       BatchNorm2d-2          [1, 64, 224, 224]             128
              ReLU-3          [1, 64, 224, 224]               0
            Conv2d-4          [1, 64, 224, 224]          36,864
       BatchNorm2d-5          [1, 64, 224, 224]             128
              ReLU-6          [1, 64, 224, 224]               0
            Conv2d-7          [1, 64, 224, 224]             128
              ReLU-8          [1, 64, 224, 224]               0
          ResBlock-9          [1, 64, 224, 224]               0
        Dropout2d-10          [1, 64, 224, 224]               0
        MaxPool2d-11          [1, 64, 112, 112]               0
           Conv2d-12         [1, 128, 112, 112]          73,728
      BatchNorm2d-13         [1, 128, 112, 112]             256
             ReLU-14         [1, 128, 1

# B. Discriminator ( Critic )

## 1. Theory

The class Critic in this research is a crucial component of the proposed architecture for image recoloring using a conditional WGAN. This class defines the architecture of the critic network, which is responsible for evaluating the quality of the generated images. The critic network is trained to differentiate between real and fake images, where the real images are the ground truth images in the LAB color space, and the fake images are the generated images by the generator network.

## 2. Architecture

The architecture of the critic network follows a standard convolutional neural network (CNN) design, where the input image is processed through a series of convolutional layers followed by batch normalization, LeakyReLU activation, and downsampling layers. The convolutional layers are designed to extract features from the input image, and the downsampling layers are responsible for reducing the spatial dimensions of the feature maps while increasing the number of filters.

The architecture of the critic network is designed to handle the input images in the LAB color space, where the input image is the concatenation of the ab channels and the L channel. The output of the critic network is a scalar value, representing the probability of the input image being real or fake.

The proposed architecture of the critic network is important for image recoloring task, as it allows the generator network to learn the underlying distribution of the real images in the LAB color space. By providing a reliable evaluation of the generated images, the critic network helps the generator network to produce more realistic and high-quality images. Additionally, the use of the LeakyReLU activation function and the Instance Normalization layers improve the performance of the critic network, as they help to stabilize the training process and reduce the mode collapse problem.

<div class="alert alert-block alert-info" style="font-size:14px; color:black; background-color:#ffeee6; border-color: #ff661a">
    <aside>
📙  In our project, The discriminator network is composed of 5 convolutional layers, each followed by a Intestance normalization layer and a LeakyReLU activation function. The first 4 layers have a stride of 2, which reduces the spatial resolution of the feature map by a factor of 4. The last layer is a fully connected layer that outputs a single output representing the probability that a patch is real. 

<center> <img src="https://i.imgur.com/rG6DjQA.png"></center>

## 2. Implementation

In [15]:
class Critic(nn.Module):
    def __init__(self, in_channels=3):
        super(Critic, self).__init__()

        def critic_block(in_filters, out_filters, normalization=True):
            """Returns layers of each critic block"""
            layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]
            if normalization:
                layers.append(nn.InstanceNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *critic_block(in_channels, 64, normalization=False),
            *critic_block(64, 128),
            *critic_block(128, 256),
            *critic_block(256, 512),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(512, 1)
        )

    def forward(self, ab, l):
        # Concatenate image and condition image by channels to produce input
        img_input = torch.cat((ab, l), 1)
        output = self.model(img_input)
        return output

<div class="alert alert-block alert-info" style="font-size:14px; color:black; background-color:#e6ffe6; border-color: #1aff1a">
    <aside>
📗 The forward method takes in two inputs, "ab" and "l", concatenates them along the channel dimension and pass them through the model. The model is a sequential of convolutional layers, Instance normalization, Leaky ReLU activation and AdaptiveAvgPool2d, Flatten and Linear layers. The output of the forward method is a single scalar value representing the Wasserstein distance between the true and fake data distributions.

In [16]:
model = Critic(3).to(device)
summary(model, [(2, 224, 224), (1, 224, 224)], batch_size = 1)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [1, 64, 112, 112]           3,136
         LeakyReLU-2          [1, 64, 112, 112]               0
            Conv2d-3           [1, 128, 56, 56]         131,200
    InstanceNorm2d-4           [1, 128, 56, 56]               0
         LeakyReLU-5           [1, 128, 56, 56]               0
            Conv2d-6           [1, 256, 28, 28]         524,544
    InstanceNorm2d-7           [1, 256, 28, 28]               0
         LeakyReLU-8           [1, 256, 28, 28]               0
            Conv2d-9           [1, 512, 14, 14]       2,097,664
   InstanceNorm2d-10           [1, 512, 14, 14]               0
        LeakyReLU-11           [1, 512, 14, 14]               0
AdaptiveAvgPool2d-12             [1, 512, 1, 1]               0
          Flatten-13                   [1, 512]               0
           Linear-14                   

# C. Generative Adversarial Network

## 1. Theory

WGAN (Wasserstein Generative Adversarial Network) is a type of GAN (Generative Adversarial Network) that uses the Wasserstein distance as the loss function for the generator and the critic. The Wasserstein distance, also known as the Earth Mover's distance, is a distance metric that measures the amount of "work" required to transform one probability distribution into another.

In a traditional GAN, the generator and the critic are trained to minimize the Jensen-Shannon divergence, which is a measure of the difference between two probability distributions. However, the Jensen-Shannon divergence can be difficult to optimize and can lead to instability in the training process.

The WGAN addresses this issue by using the Wasserstein distance as the loss function, which allows for more stable training of the generator and the critic. Additionally, the Wasserstein distance has a nice property that the gradient of the loss function is always well defined, which allows for more stable optimization.

In WGAN, the critic is trained to approximate the Wasserstein distance between the real and fake data distributions. The generator is trained to generate samples that will cause the critic to output low values. In WGAN, the critic network is trained to be a 1-Lipschitz function, which means that the critic's output should change by at most 1 for any small change in the input.

Overall, WGAN is a more stable and well-defined version of GAN, that allows for more control over the training process and can help to improve the quality and diversity of the generated samples.

In [17]:
# https://stackoverflow.com/questions/49433936/how-to-initialize-weights-in-pytorch
def _weights_init(m):
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    if isinstance(m, nn.BatchNorm2d):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
        torch.nn.init.constant_(m.bias, 0)

In [18]:
def display_progress(cond, real, fake, current_epoch = 0, figsize=(20,15)):
    """
    Save cond, real (original) and generated (fake)
    images in one panel 
    """
    cond = cond.detach().cpu().permute(1, 2, 0)   
    real = real.detach().cpu().permute(1, 2, 0)
    fake = fake.detach().cpu().permute(1, 2, 0)
    
    images = [cond, real, fake]
    titles = ['input','real','generated']
    print(f'Epoch: {current_epoch}')
    fig, ax = plt.subplots(1, 3, figsize=figsize)
    for idx,img in enumerate(images):
        if idx == 0:
            ab = torch.zeros((224,224,2))
            img = torch.cat([images[0]* 100, ab], dim=2).numpy()
            imgan = lab2rgb(img)
        else:
            imgan = lab_to_rgb(images[0],img)
        ax[idx].imshow(imgan)
        ax[idx].axis("off")
    for idx, title in enumerate(titles):    
        ax[idx].set_title('{}'.format(title))
    plt.show()

In [19]:
class CWGAN(pl.LightningModule):

    def __init__(self, in_channels, out_channels, learning_rate=0.0002, lambda_recon=100, display_step=10, lambda_gp=10, lambda_r1=10,):

        super().__init__()
        self.save_hyperparameters()
        
        self.display_step = display_step
        
        self.generator = Generator(in_channels, out_channels)
        self.critic = Critic(in_channels + out_channels)
        self.optimizer_G = optim.Adam(self.generator.parameters(), lr=learning_rate, betas=(0.5, 0.9))
        self.optimizer_C = optim.Adam(self.critic.parameters(), lr=learning_rate, betas=(0.5, 0.9))
        self.lambda_recon = lambda_recon
        self.lambda_gp = lambda_gp
        self.lambda_r1 = lambda_r1
        self.recon_criterion = nn.L1Loss()
        self.generator_losses, self.critic_losses  =[],[]
    
    def configure_optimizers(self):
        return [self.optimizer_C, self.optimizer_G]
        
    def generator_step(self, real_images, conditioned_images):
        # WGAN has only a reconstruction loss
        self.optimizer_G.zero_grad()
        fake_images = self.generator(conditioned_images)
        recon_loss = self.recon_criterion(fake_images, real_images)
        recon_loss.backward()
        self.optimizer_G.step()
        
        # Keep track of the average generator loss
        self.generator_losses += [recon_loss.item()]
        
        
    def critic_step(self, real_images, conditioned_images):
        self.optimizer_C.zero_grad()
        fake_images = self.generator(conditioned_images)
        fake_logits = self.critic(fake_images, conditioned_images)
        real_logits = self.critic(real_images, conditioned_images)
        
        # Compute the loss for the critic
        loss_C = real_logits.mean() - fake_logits.mean()

        # Compute the gradient penalty
        alpha = torch.rand(real_images.size(0), 1, 1, 1, requires_grad=True)
        alpha = alpha.to(device)
        interpolated = (alpha * real_images + (1 - alpha) * fake_images.detach()).requires_grad_(True)
        
        interpolated_logits = self.critic(interpolated, conditioned_images)
        
        grad_outputs = torch.ones_like(interpolated_logits, dtype=torch.float32, requires_grad=True)
        gradients = torch.autograd.grad(outputs=interpolated_logits, inputs=interpolated, grad_outputs=grad_outputs,create_graph=True, retain_graph=True)[0]

        
        gradients = gradients.view(len(gradients), -1)
        gradients_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
        loss_C += self.lambda_gp * gradients_penalty
        
        # Compute the R1 regularization loss
        r1_reg = gradients.pow(2).sum(1).mean()
        loss_C += self.lambda_r1 * r1_reg

        # Backpropagation
        loss_C.backward()
        self.optimizer_C.step()
        self.critic_losses += [loss_C.item()]
        
    def training_step(self, batch, batch_idx, optimizer_idx):
        real, condition = batch
        if optimizer_idx == 0:
            self.critic_step(real, condition)
        elif optimizer_idx == 1:
            self.generator_step(real, condition)
        gen_mean = sum(self.generator_losses[-self.display_step:]) / self.display_step
        crit_mean = sum(self.critic_losses[-self.display_step:]) / self.display_step
        if self.current_epoch%self.display_step==0 and batch_idx==0 and optimizer_idx==1:
            fake = self.generator(condition).detach()
            torch.save(cwgan.generator.state_dict(), "ResUnet_"+ str(self.current_epoch) +".pt")
            torch.save(cwgan.critic.state_dict(), "PatchGAN_"+ str(self.current_epoch) +".pt")
            print(f"Epoch {self.current_epoch} : Generator loss: {gen_mean}, Critic loss: {crit_mean}")
            display_progress(condition[0], real[0], fake[0], self.current_epoch)


In [20]:
gc.collect()
cwgan = CWGAN(in_channels = 1, out_channels = 2 ,learning_rate=2e-4, lambda_recon=100, display_step=10)

In [25]:
trainer = pl.Trainer(max_epochs=150)
trainer.fit(cwgan, train_loader)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


RuntimeError: Training with multiple optimizers is only supported with manual optimization. Remove the `optimizer_idx` argument from `training_step`, set `self.automatic_optimization = False` and access your optimizers in `training_step` with `opt1, opt2, ... = self.optimizers()`.

<hr>
<h1><center>VI. Model Inferencing</center></h1>
<hr>

In [None]:
plt.figure(figsize=(30,60))
idx =1
for batch_idx, batch in enumerate(test_loader):
    real, condition = batch
    pred = cwgan.generator(condition).detach().squeeze().permute(1, 2, 0)
    condition  = condition.detach().squeeze(0).permute(1, 2, 0)
    real  = real.detach().squeeze(0).permute(1, 2, 0)
    plt.subplots_adjust(wspace=0, hspace=0)
    plt.subplot(6,3,idx)
    plt.grid(False)
    
    ab = torch.zeros((224,224,2))
    img = torch.cat([condition * 100, ab], dim=2).numpy()
    imgan = lab2rgb(img)
    plt.imshow(imgan)
    plt.title('Input')
    
    plt.subplot(6,3,idx + 1)
    
    ab = torch.zeros((224,224,2))
    imgan = lab_to_rgb(condition,real)
    plt.imshow(imgan)
    plt.title('Real')
    
    plt.subplot(6,3,idx + 2)
    imgan = lab_to_rgb(condition,pred)
    plt.title('Generated')
    plt.imshow(imgan)
    idx += 3
    if idx >= 18:
        break

<hr>
<h1><center>VI. Evaluation</center></h1>
<hr>

## A. Measuring GAN Performance

### 1. Human scoring
One way to evaluate the performance of a GAN is to use human scoring, where real images and images created by the GAN are randomly stacked together, and human scorers label each image as real or fake. This can be done using platforms such as Amazon Mechanical Turk.

### 2. Inception Score (IS)
Another way to evaluate GAN performance is to use the Inception Score, which uses a pre-trained Inception model to classify the generated images. The IS considers both the quality and diversity of the generated images by measuring the entropy of the classification labels. The higher the entropy, the more diverse and unpredictable the generated images are, indicating a better performance.

<center><img width = "700px" src="https://cdn-images-1.medium.com/max/1600/1*RdIYRsqXxRAKwcjtxg6_kw.jpeg"></center>


### 3. Frechet Inception Distance (FID)
In Frechet Inception Distance (FID), we utilize the Inception network to extract features from an intermediate layer. Then, we model the data distribution for these features using a multivariate Gaussian distribution with mean $\mu$ and covariance $\Sigma$. The FID between the real images $x$ and generated images $g$ is computed as: \\


$FID(x,g) = \left\|\mu_x - \mu_g\right\|^2_2 + Tr(\Sigma_x + \Sigma_g - 2(\Sigma_x\Sigma_g)^{\frac{1}{2}})$ 


where $Tr$ sums up all the diagonal elements. Lower FID values indicate better image quality and diversity.

FID is sensitive to mode collapse, meaning that as the distance between simulated missing modes increases, the FID value will also increase. Additionally, FID is more robust to noise than Inception Score (IS) as it is less affected by models that only generate one image per class. Therefore, FID is a better measurement for image diversity. However, it should be noted that FID has some rather high bias but low variance. By computing the FID between a training dataset and a testing dataset, we should expect the FID to be zero since both are real images. However, running the test with different batches of training sample shows non-zero FID.


## B. Implemenation of Inception Score

In [None]:
# disable grads + batchnorm + dropout
torch.set_grad_enabled(False)
cwgan.generator.eval()
all_preds = []
all_real = []

for batch_idx, batch in enumerate(test_loader):
    real, condition = batch
    pred = cwgan.generator(condition).detach()
    Lab = torch.cat([condition, pred], dim=1).numpy()
    Lab_real = torch.cat([condition, real], dim=1).numpy()
    all_preds.append(Lab.squeeze())
    all_real.append(Lab_real.squeeze())
    if batch_idx == 500: break

In [None]:
class InceptionScore:
    def __init__(self, device):
        self.device = device
        self.inception = inception_v3(pretrained=True, transform_input=False).to(self.device)
        self.inception.eval()

    def calculate_is(self, generated_images):
        generated_images = generated_images.to(self.device)

        with torch.no_grad():
            generated_features = self.inception(generated_images.view(-1,3,224,224))

        generated_features = generated_features.view(generated_features.size(0), -1)
        p = F.softmax(generated_features, dim=1)

        kl = p * (torch.log(p) - torch.log(torch.tensor(1.0/generated_features.size(1)).to(self.device)))
        kl = kl.sum(dim=1)

        return kl.mean().item(), kl.std().item()

In [None]:
# Initialize the InceptionScore class
device = "cuda" # or "cpu" if you don't have a GPU
is_calculator = InceptionScore(device)

all_preds = np.concatenate(all_preds, axis=0)
all_preds = torch.tensor(all_preds).float()

all_real = np.concatenate(all_real, axis=0)
all_real = torch.tensor(all_real).float()

is_model = InceptionScore(device)

# Calculate the Inception Score
mean_real, std_real = is_model.calculate_is(all_real)
print("Inception Score of real images: mean: {:.4f}, std: {:.4f}".format(mean_real, std_real))
mean_is, std_is = is_model.calculate_is(all_preds)
print("Inception Score of fake images: mean: {:.4f}, std: {:.4f}".format(mean_is, std_is))

## B. Implemenation of Fréchet Inception Distance (FID)

In [None]:
class FID:
    def __init__(self, device):
        self.device = device
        self.inception = inception_v3(pretrained=True, transform_input=False).to(self.device)
        self.inception.eval()
        self.mu = None
        self.sigma = None

    def calculate_fid(self, real_images, generated_images):
        real_images = real_images.to(self.device)
        generated_images = generated_images.to(self.device)

        with torch.no_grad():
            real_features = self.inception(real_images.view(-1,3,224,224))
            generated_features = self.inception(generated_images.view(-1,3,224,224))

        real_features = real_features.view(real_features.size(0), -1)
        generated_features = generated_features.view(generated_features.size(0), -1)

        if self.mu is None:
            self.mu = real_features.mean(dim=0)

        if self.sigma is None:
            self.sigma = real_features.std(dim=0)

        real_mu = real_features.mean(dim=0)
        real_sigma = real_features.std(dim=0)

        generated_mu = generated_features.mean(dim=0)
        generated_sigma = generated_features.std(dim=0)

        mu_diff = real_mu - generated_mu
        sigma_diff = real_sigma - generated_sigma

        fid = mu_diff.pow(2).sum() + (self.sigma - generated_sigma).pow(2).sum() + (self.mu - generated_mu).pow(2).sum()
        return fid.item()

In [None]:
# Initialize the FID class
device = "cuda" # or "cpu" if you don't have a GPU
fid_calculator = FID(device)

# Calculate the FID
fid_value = fid_calculator.calculate_fid(all_real, all_preds)
print("FID: {:.4f}".format(fid_value))

<hr>
<h1><center>VII. Conclusion</center></h1>
<hr>


<div class="alert alert-block alert-info" style="font-size:14px; color:black; background-color:#e6f2ff; border-color: #3399ff">
    <aside>
    📘 In conclusion, Pix2pix is a powerful model for image-to-image translation tasks, but it can be further improved for specific applications such as colorization. Using Wasserstein GAN (WGAN) and a U-Net architecture based on residual blocks are two ways to improve the performance of pix2pix for colorization. WGANs use the Wasserstein distance metric to train the generator and discriminator which can help stabilize the training process and produce more realistic results. A U-Net architecture which is well-suited for image segmentation tasks combined with Residual blocks allows the network to learn fine details of the input image, which can be particularly useful for colorization tasks. This can help improve the stability and the ability to learn fine details of the input image, producing more realistic results.
    </aside>
</div>