In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import cv2
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv
import matplotlib.pyplot as plt

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# for dirname, _, filenames in os.walk('/kaggle/input'):
#     for filename in filenames:
#         print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

## A PyTorch implementation of the following: https://www.kaggle.com/code/ttymonkey/cyclegan-starter

# 1. Data Check

In [None]:
# Root directory of the input data on Kaggle
root_path = "/kaggle/input/gan-getting-started"
# List the entries directly under the root; as the last expression it is shown as the cell output
os.listdir(root_path)

In [None]:
def read_img(path):
    """Read the image stored at ``path`` and return it in RGB channel order.

    The image is loaded with ``cv2.imread`` and then converted with
    ``cv2.COLOR_BGR2RGB`` so that it can be shown directly with matplotlib.

    Args:
        path: Location of the image file on disk.

    Returns:
        The decoded image with its channels reordered from BGR to RGB.

    Raises:
        OSError: If ``cv2.imread`` returns ``None`` (file missing or not
            decodable), instead of failing later inside ``cv2.cvtColor``.
    """
    bgr = cv2.imread(path)
    if bgr is None:
        # imread signals failure by returning None rather than raising
        raise OSError(f"Could not read image: {path!r}")
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

In [None]:
# Load one sample photo and one sample Monet painting.
# os.listdir returns names in arbitrary order, so the listing is sorted first
# to make the chosen sample the same on every run.
data_path = f"{root_path}/photo_jpg"
sample_photo = read_img(os.path.join(data_path, sorted(os.listdir(data_path))[0]))

data_path = f"{root_path}/monet_jpg"
sample_monet = read_img(os.path.join(data_path, sorted(os.listdir(data_path))[0]))

In [None]:
# Shape of the loaded sample photo array
sample_photo.shape

In [None]:
# Value range and dtype of the sample photo (needed to pick the normalisation later)
sample_photo.min(), sample_photo.max(), sample_photo.dtype

In [None]:
# Photo (left panel)
plt.subplot(121)
plt.title("Photo")
plt.imshow(sample_photo)

# Monet (right panel)
plt.subplot(122)
plt.title("Monet")
plt.imshow(sample_monet)

# 2. Build The Model

In [None]:
import torch
import torch.nn as nn

In [None]:
# Image height, width and channel count used to build random inputs for the block shape checks
IMG_H, IMG_W, IMG_C = 256, 256, 3

## 2.1 Generator

In [None]:
def downsample(in_channels, out_channels, kernel_size, norm=True):
    """Build a stride-2 convolutional block that halves the feature map size.

    e.g. x [256, 256] -> out [128, 128]

    Layers, in order:
        Conv2d (stride 2, padding ``(kernel_size - 1) // 2``, no bias)
        -> GroupNorm (only if ``norm`` is True) -> LeakyReLU

    The GroupNorm uses one group per channel (``num_groups == num_channels``),
    i.e. every channel is normalised on its own.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Side length of the square convolution kernel.
        norm: Whether to insert the GroupNorm layer after the convolution.

    Returns:
        An ``nn.Sequential`` holding the layers listed above.
    """
    conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                     stride=2,
                     padding=(kernel_size - 1) // 2,
                     bias=False)

    # Optional per-channel normalisation sits between the conv and the activation
    maybe_norm = [nn.GroupNorm(out_channels, out_channels)] if norm else []

    return nn.Sequential(conv, *maybe_norm, nn.LeakyReLU())

In [None]:
# Sanity-check the downsample block on a random image-sized batch
x = torch.randn(1, IMG_C, IMG_H, IMG_W)

# Stride-2 block that keeps the channel count (3 -> 3) with a 3x3 kernel
out = downsample(in_channels=3, out_channels=3, kernel_size=3)(x)

# Channels are preserved and both spatial dimensions are exactly halved
assert (out.shape[1], out.shape[2] * 2, out.shape[3] * 2) == (IMG_C, IMG_H, IMG_W)

In [None]:
def upsample(in_channels, out_channels, kernel_size, dropout=True):
    """Build a stride-2 transpose-convolutional block that doubles the feature map size.

    e.g. x [256, 256] -> out [512, 512]

    Layers, in order:
        ConvTranspose2d (stride 2, no bias) -> GroupNorm
        -> Dropout(0.5) (only if ``dropout`` is True) -> ReLU

    The GroupNorm uses one group per channel (``num_groups == num_channels``),
    i.e. every channel is normalised on its own.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by the transpose convolution.
        kernel_size: Side length of the square kernel. Both even and odd
            sizes give an output that is exactly twice the input size.
        dropout: Whether to insert ``nn.Dropout(0.5)`` before the activation.

    Returns:
        An ``nn.Sequential`` holding the layers listed above.
    """
    padding = (kernel_size - 1) // 2
    # With stride 2 the output size is 2 * in - 2 + kernel_size - 2 * padding + output_padding.
    # For even kernels this is already 2 * in; odd kernels fall one pixel short,
    # so they need one extra pixel of output padding.
    output_padding = kernel_size % 2

    layers = [nn.ConvTranspose2d(in_channels=in_channels,
                                 out_channels=out_channels,
                                 kernel_size=kernel_size,
                                 stride=2,
                                 padding=padding,
                                 output_padding=output_padding,
                                 bias=False),
              nn.GroupNorm(num_groups=out_channels,
                           num_channels=out_channels)
              ]

    if dropout:
        layers.append(nn.Dropout(0.5))

    layers.append(nn.ReLU())

    return nn.Sequential(*layers)

In [None]:
def weight_init(m):
    """Initialise a Conv2d / ConvTranspose2d module in place.

    Weights are drawn from N(0, 0.02) and the bias, when the layer has one,
    is set to zero. Any other module type is left untouched, which makes the
    function suitable for ``nn.Module.apply``.

    Args:
        m: The module to (possibly) initialise.
    """
    if not isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        return

    nn.init.normal_(m.weight, mean=0.0, std=0.02)
    if m.bias is not None:
        nn.init.zeros_(m.bias)


In [None]:
# Sanity-check the upsample block on a random image-sized batch
x = torch.randn(1, IMG_C, IMG_H, IMG_W)

# Upsample by 2 and keep the same number of channels
out = upsample(3, 3, 4)(x)

# Compare against the exact doubled size: a floor division of the output size
# would also accept an output that is one pixel too large.
assert out.shape[1] == IMG_C
assert out.shape[2] == IMG_H * 2
assert out.shape[3] == IMG_W * 2

In [None]:
class Generator(nn.Module):
    """U-Net style encoder/decoder mapping a 3-channel image to a 3-channel image.

    The encoder halves the spatial size seven times (x128 smaller overall).
    The decoder doubles it six times, concatenating the matching encoder
    output onto the result of every block (skip connections). A final
    transpose convolution followed by Tanh brings the map back to the input
    resolution, so the output values lie in [-1, 1].

    Every Conv2d / ConvTranspose2d in the model is re-initialised with
    ``weight_init`` at construction time.
    """

    def __init__(self):
        super().__init__()

        # Encoder and Decoder
        self.encoder = self._init_encoder()
        self.decoder = self._init_decoder()

        # Final upsample. It receives 128 channels because the last decoder
        # block (64 channels) is concatenated with the first encoder output (64).
        self.out = nn.ConvTranspose2d(in_channels=128,
                                      out_channels=3,
                                      kernel_size=4,
                                      stride=2, padding=1)
        self.act = nn.Tanh()

        self.apply(weight_init)

    def forward(self, x):
        """Downsample the image to the bottleneck and upsample it back.

        Args:
            x: Batch of images with 3 channels, laid out as (N, 3, H, W).
               The notebook uses H = W = 256.

        Returns:
            Tensor with the same shape as ``x`` and values in [-1, 1] (Tanh).
        """

        # Encode: keep the output of every encoder block for the skip connections
        skips = []  # Skip connections U-Net Style
        for layer in self.encoder:
            x = layer(x)
            skips.append(x)

        # The last encoder output is the bottleneck itself (already in x), so it
        # is dropped; the rest is reversed to pair deepest-first with the decoder.
        skips = reversed(skips[:-1])

        # Decode from the bottleneck: upsample, then append the encoder features
        # of the same resolution along the channel dimension.
        for layer, skip in zip(self.decoder, skips):
            x = layer(x)
            x = torch.cat((x, skip), dim=1)

        # Upsample so that out.shape == x.shape
        out = self.act(self.out(x))

        return out

    def _init_encoder(self):
        """ A Sequential Encoder with 7 Downsample Blocks each downsampling by 2

        The first block has no normalisation; channels grow 64 -> 128 -> 256 -> 512
        and then stay at 512.
        """
        base_channels = 64
        kernel_size = 4

        encoder = nn.Sequential(
            # [64, //2, //2]
            downsample(in_channels=3, out_channels=base_channels,
                       kernel_size=kernel_size, norm=False),
            # [128, //4, //4]
            downsample(in_channels=base_channels, out_channels=base_channels * 2,
                       kernel_size=kernel_size),

            # [256, //8, //8]
            downsample(in_channels=base_channels * 2, out_channels=base_channels * 4,
                       kernel_size=kernel_size),

            # [512, //16, //16]
            downsample(in_channels=base_channels * 4, out_channels=base_channels * 8,
                       kernel_size=kernel_size),

            # [512, //32, //32]
            downsample(in_channels=base_channels * 8, out_channels=base_channels * 8,
                       kernel_size=kernel_size),

            # [512, //64, //64]
            downsample(in_channels=base_channels * 8, out_channels=base_channels * 8,
                       kernel_size=kernel_size),

            # [512, //128, //128]
            downsample(in_channels=base_channels * 8, out_channels=base_channels * 8,
                       kernel_size=kernel_size),

        )

        return encoder

    def _init_decoder(self):
        """ A Sequential Decoder with 6 Upsample Blocks each upsampling by 2

        From the second block on, ``in_channels`` is twice the previous
        ``out_channels`` because ``forward`` concatenates a skip connection of
        equal width after every block. Dropout is used only in the first three
        (lowest-resolution) blocks.
        """
        base_channels = 512
        kernel_size = 4

        decoder = nn.Sequential(
            # [512, *2, *2]
            upsample(in_channels=base_channels, out_channels=base_channels,
                     kernel_size=kernel_size, dropout=True),

            # [512, *4, *4]
            upsample(in_channels=base_channels * 2, out_channels=base_channels,
                     kernel_size=kernel_size, dropout=True),

            # [512, *8, *8]
            upsample(in_channels=base_channels * 2, out_channels=base_channels,
                     kernel_size=kernel_size, dropout=True),

            # The remaining blocks must switch dropout off explicitly:
            # upsample() enables it by default.

            # [256, *16, *16]
            upsample(in_channels=base_channels * 2, out_channels=base_channels // 2,
                     kernel_size=kernel_size, dropout=False),

            # [128, *32, *32]
            upsample(in_channels=base_channels, out_channels=base_channels // 4,
                     kernel_size=kernel_size, dropout=False),

            # [64, *64, *64]
            upsample(in_channels=base_channels // 2, out_channels=base_channels // 8,
                     kernel_size=kernel_size, dropout=False),

        )

        return decoder

In [None]:
# Smoke-test the Generator on a random 256x256 RGB batch
g = Generator()
x = torch.randn(1, 3, 256, 256)

enc_out = g.encoder(x)
gen_out = g(x)

# The encoder alone ends with 512 channels at 1/128 of the input resolution
assert tuple(enc_out.shape[1:]) == (512, x.shape[2] // 128, x.shape[3] // 128)

# The full generator returns channels, height and width identical to the input
assert gen_out.shape[1:] == x.shape[1:]

In [None]:
# Check that the weights were initialised correctly.
# weight_init touches both Conv2d and ConvTranspose2d, so both kinds are
# collected here; otherwise the decoder and output layer would go unchecked.
convolutions = np.concatenate([
    m.weight.detach().flatten().numpy()
    for m in g.modules()
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d))
])

# Expected to be close to (0.0, 0.02)
convolutions.mean(), convolutions.std()

## 2.2 Discriminator

In [None]:
class Discriminator(nn.Module):
    """Convolutional discriminator that scores an image patch-wise.

    Three stride-2 ``downsample`` blocks (64, 128, 256 channels; the first one
    without normalisation) are followed by two zero-padded stride-1
    convolutions that end in a single-channel score map. No final activation
    is applied, so the outputs are raw scores.

    For a 256x256 input the sizes are 256 -> 128 -> 64 -> 32, then
    pad 34 -> conv 31 -> pad 33 -> conv 30, giving an output of (N, 1, 30, 30).

    Every Conv2d in the model is re-initialised with ``weight_init``.
    """

    def __init__(self):
        super().__init__()

        base_channels = 64
        kernel_size = 4
        self.discriminator = nn.Sequential(
            downsample(in_channels=3, out_channels=base_channels,
                       kernel_size=kernel_size, norm=False),
            downsample(in_channels=base_channels, out_channels=base_channels * 2,
                       kernel_size=kernel_size),
            downsample(in_channels=base_channels * 2, out_channels=base_channels * 4,
                       kernel_size=kernel_size),
            # Pads 2 pixels on the right and bottom only (left, right, top, bottom)
            nn.ZeroPad2d(padding=(0, 2, 0, 2)),
            nn.Conv2d(in_channels=base_channels * 4, out_channels=base_channels * 8,
                      kernel_size=kernel_size, stride=1, bias=False),
            nn.GroupNorm(num_groups=base_channels * 8,
                         num_channels=base_channels * 8),
            # Non-linearity between the normalised features and the final conv.
            # It shares one slot with the padding so that the parameterised
            # layers keep positions 4, 5 and 7 (stable state_dict keys).
            nn.Sequential(nn.LeakyReLU(),
                          nn.ZeroPad2d(padding=(0, 2, 0, 2))),
            nn.Conv2d(in_channels=base_channels * 8, out_channels=1,
                      kernel_size=kernel_size, stride=1)

        )

        self.apply(weight_init)

    def forward(self, x):
        """ Takes an image and produces a reduced, single-channel score map

        Args:
            x: Batch of 3-channel images laid out as (N, 3, H, W).

        Returns:
            Tensor of shape (N, 1, H', W') with one raw score per patch.
        """

        return self.discriminator(x)
    

In [None]:
# Smoke-test the Discriminator on a random 256x256 RGB batch
d = Discriminator()
x = torch.randn(1, 3, 256, 256)

# The score map must have exactly one channel
assert d(x).size(1) == 1

## 2.3 Test with images

In [None]:
# Turn the sample photo into a model-ready batch in one chain:
#   float32 tensor -> scale to [0, 1] -> shift/scale to [-1, 1]
#   -> H,W,C to C,H,W -> add a leading batch dimension (1,C,H,W)
torch_sample = (
    torch.tensor(sample_photo, dtype=torch.float32)
    .div(255.0)
    .sub(0.5)
    .div(0.5)
    .permute(2, 0, 1)
    .unsqueeze(dim=0)
)

In [None]:
# Freshly initialised generator (no training has happened at this point)
monet_generator = Generator()

In [None]:
# Eval mode switches the Dropout layers off for a deterministic forward pass
monet_generator.eval()

# inference_mode disables gradient tracking, so no extra detach() is needed
with torch.inference_mode():
    to_monet = monet_generator(torch_sample)

# Must be the last top-level expression of the cell to be displayed
to_monet.shape

In [None]:
# Bring both tensors from C,H,W in [-1, 1] back to H,W,C in [0, 1] for matplotlib
photo_rgb = torch_sample[0].permute(1, 2, 0) * 0.5 + 0.5
monet_rgb = to_monet[0].permute(1, 2, 0) * 0.5 + 0.5

# Left: the input photo
plt.subplot(121)
plt.title("Original Photo")
plt.imshow(photo_rgb)

# Right: the generator output
plt.subplot(122)
plt.title("Monet Photo")
plt.imshow(monet_rgb)