## M5 - pix2pix: Image-to-Image Translation

For this notebook we will translate one image to another, as described in the [pix2pix article](https://arxiv.org/pdf/1611.07004.pdf). Using the `cityscapes` dataset, we will map from the original photo to its semantic map.
We could just as well learn the reverse mapping, from the semantic map to the photo; this flexibility is what makes image-to-image translation appealing.

We will train a U-Net generator and a discriminator to map one image to another. We assume that such a mapping exists and, through adversarial training, try to learn it:

$$f: X \to Y, \quad X, Y \subseteq \mathbb{R}^{H \times W \times C}$$


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import os
import glob
import time

import torch
from torch import nn
from torch.nn import functional as F

import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets
import datetime

from PIL import Image
import sys
sys.path.append('..')

In [None]:
args = {
    'epochs': 200,
    'img_size': 256,
    'batch_size': 32,
    'cuda': True if torch.cuda.is_available() else False,
    'sample_interval': 20,
    'checkpoint_interval': -1,
    'dataset_name': 'cityscapes',

    'lr': 0.0002,
    'b1': 0.5,
    'b2': 0.999,
}

if torch.cuda.is_available():
    args['device'] = torch.device('cuda')
else:
    args['device'] = torch.device('cpu')

print(args['device'])

### Dataset

Each file in the cityscapes dataset stores the photo and its "label" (the semantic map) side by side in a single .jpg. To separate them, we use the `PIL` library to load the file and crop it in half along its width: the left half becomes `img_A` (the photo, used as the generator input) and the right half becomes `img_B` (the semantic map, used as the target).

It is worth mentioning that converting a PIL image to a NumPy array gives the shape *(H x W x C)*, i.e. the channels are in the last dimension of the array. PyTorch, however, expects the channels in the first dimension, so we use `transpose(2, 0, 1)` to obtain *(C x H x W)*.

Finally, we scale the pixel values from [0, 255] to [-1, 1] via $x / 127.5 - 1$, a common choice that tends to make GAN training more stable.
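
The following NumPy sketch reproduces these three steps (crop in half, channels-first transpose, scaling to [-1, 1]) on a tiny synthetic image pair, together with the inverse mapping used when plotting. The helper names `split_and_normalize` and `to_display` are hypothetical and not part of this notebook:

```python
import numpy as np

def split_and_normalize(pair):
    """Split an (H, 2W, 3) uint8 side-by-side image into two (3, H, W) float32
    arrays in [-1, 1], following the same steps as ImageDataset."""
    half = pair.shape[1] // 2
    left, right = pair[:, :half, :], pair[:, half:, :]  # img_A, img_B

    def to_chw_scaled(a):
        # (H, W, C) -> (C, H, W), then [0, 255] -> [-1, 1]
        return (a.transpose(2, 0, 1) / 127.5 - 1.0).astype(np.float32)

    return to_chw_scaled(left), to_chw_scaled(right)

def to_display(chw):
    """Inverse mapping used for plotting: (C, H, W) in [-1, 1] -> (H, W, C) in [0, 1]."""
    return chw.transpose(1, 2, 0) * 0.5 + 0.5

# Hypothetical 4x8 pair: black left half, white right half
pair = np.zeros((4, 8, 3), dtype=np.uint8)
pair[:, 4:, :] = 255

img_A, img_B = split_and_normalize(pair)
print(img_A.shape, img_A.min(), img_B.max())  # (3, 4, 4) -1.0 1.0
```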

In [None]:
from PIL import ImageOps

class ImageDataset(Dataset):
    def __init__(self, root = '/pgeoprj/ciag2023/datasets/cityscapes/', transforms_= None, mode="train"):
        self.transform = transforms_
        # Every file in <root>/<mode> holds a side-by-side image pair
        self.files = sorted(glob.glob(os.path.join(root, mode) + "/*.*"))

    def __getitem__(self, index):
        img = Image.open(self.files[index % len(self.files)]).convert("RGB")
        w, h = img.size
        # Left half -> img_A (photo, generator input); right half -> img_B (semantic map, target)
        img_A = img.crop((0, 0, w // 2, h))
        img_B = img.crop((w // 2, 0, w, h))

        # Random horizontal flip, applied to both halves so they stay aligned
        if np.random.random() < 0.5:
            img_A = ImageOps.mirror(img_A)
            img_B = ImageOps.mirror(img_B)

        if self.transform:
            img_A = self.transform(img_A)
            img_B = self.transform(img_B)

        # PIL -> NumPy gives (H, W, C); PyTorch expects (C, H, W)
        img_A = np.array(img_A).transpose(2, 0, 1)
        img_B = np.array(img_B).transpose(2, 0, 1)

        # Scale [0, 255] -> [-1, 1]
        img_A = img_A / 127.5 - 1.0
        img_B = img_B / 127.5 - 1.0

        # Returned as (target, input), i.e. (real_img, label) in the training loop
        return img_B.astype(np.float32), img_A.astype(np.float32)

    def __len__(self):
        return len(self.files)

transforms_ = transforms.Compose([
    transforms.Resize(size=(args['img_size'], args['img_size'])),
])

In [None]:
cityscapes = ImageDataset(transforms_ = transforms_)
len(cityscapes)

In [None]:
cityscapes_dataloader = DataLoader(dataset = cityscapes,
                                batch_size = args['batch_size'],
                                shuffle = True,
                                num_workers = 1)

### Generator and Discriminator

The discriminator is a fairly simple network. It receives two images, the input image (`label`) and either the real target or a generated image, concatenated along the channel dimension. It then applies a sequence of downsampling convolutions, `nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)`, each of which halves the spatial size. Rather than a single real/fake number, it outputs a PatchGAN map of shape `patch = (1, img_size // 16, img_size // 16)`: each entry scores whether the corresponding patch of the image looks real (target 1) or fake (target 0). Following the [DCGAN](https://arxiv.org/pdf/1511.06434.pdf) guidelines, the discriminator uses no fully connected layers, only strided convolutions, and `LeakyReLU` activations instead of `ReLU`.
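
A minimal sketch of such a PatchGAN discriminator, assuming four stride-2 blocks (consistent with the `img_size // 2 ** 4` used later to compute `patch`) and `InstanceNorm2d` normalization. The class name, layer widths and normalization choice are assumptions for illustration, not the contents of `models.py`:

```python
import torch
from torch import nn

class PatchDiscriminatorSketch(nn.Module):
    """Hypothetical PatchGAN-style discriminator (not the models.py class).

    Concatenates (img_B, img_A) along channels, applies four stride-2 blocks
    (each halves H and W) and returns one score per patch.
    """

    def __init__(self, in_channels=3):
        super().__init__()

        def block(in_f, out_f, normalize=True):
            layers = [nn.Conv2d(in_f, out_f, 4, stride=2, padding=1)]
            if normalize:
                layers.append(nn.InstanceNorm2d(out_f))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(in_channels * 2, 64, normalize=False),
            *block(64, 128),
            *block(128, 256),
            *block(256, 512),
            nn.ZeroPad2d((1, 0, 1, 0)),       # pad so the last conv keeps H/16 x W/16
            nn.Conv2d(512, 1, 4, padding=1),  # one raw score per patch, no sigmoid
        )

    def forward(self, img_B, img_A):
        return self.model(torch.cat((img_B, img_A), dim=1))

d = PatchDiscriminatorSketch()
x = torch.randn(2, 3, 64, 64)
print(d(x, x).shape)  # torch.Size([2, 1, 4, 4])
```

In this sketch the last convolution has no sigmoid: with the least-squares (MSE) loss used in this notebook, the scores are regressed directly towards 0 or 1.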

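For the generator, we will use a [U-Net](https://arxiv.org/pdf/1505.04597v1.pdf). We studied the U-Net extensively; its key feature is the skip connections, which concatenate each encoder feature map with the decoder feature map of the same resolution. This lets low-level details such as edges and object boundaries bypass the bottleneck, which suits image-to-image translation, where input and output share the same underlying structure.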
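
To make the skip connection concrete, here is a hypothetical two-level U-Net (`TinyUNetSketch`, not the `GeneratorUNet` from `models.py`, which is deeper). It downsamples twice, upsamples twice, and concatenates the first encoder feature map with the matching decoder feature map. The `tanh` output is an assumption chosen to match the [-1, 1] data range:

```python
import torch
from torch import nn

class TinyUNetSketch(nn.Module):
    """Hypothetical two-level U-Net with a single skip connection."""

    def __init__(self, in_channels=3, out_channels=3, width=16):
        super().__init__()
        # Encoder: each stride-2 conv halves H and W
        self.down1 = nn.Sequential(nn.Conv2d(in_channels, width, 4, 2, 1), nn.LeakyReLU(0.2))
        self.down2 = nn.Sequential(nn.Conv2d(width, width * 2, 4, 2, 1), nn.LeakyReLU(0.2))
        # Decoder: each stride-2 transposed conv doubles H and W
        self.up1 = nn.Sequential(nn.ConvTranspose2d(width * 2, width, 4, 2, 1), nn.ReLU())
        # Input channels = width (upsampled) + width (skip from down1)
        self.up2 = nn.ConvTranspose2d(width * 2, out_channels, 4, 2, 1)

    def forward(self, x):
        d1 = self.down1(x)                 # (N, w,  H/2, W/2)
        d2 = self.down2(d1)                # (N, 2w, H/4, W/4), the bottleneck
        u1 = self.up1(d2)                  # (N, w,  H/2, W/2)
        u1 = torch.cat((u1, d1), dim=1)    # skip connection: (N, 2w, H/2, W/2)
        return torch.tanh(self.up2(u1))    # (N, out, H, W), values in [-1, 1]

net = TinyUNetSketch()
y = net(torch.randn(2, 3, 32, 32))
print(y.shape)  # torch.Size([2, 3, 32, 32])
```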

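The `weights_init_normal` function defines how the weights of our networks are initialized: convolution weights are drawn from a normal distribution with mean 0 and standard deviation 0.02, and batch-norm scales from a normal distribution with mean 1 and standard deviation 0.02, with zero bias.

The `GeneratorUNet` and `Discriminator` classes themselves are imported from `models.py`, located in the parent folder (hence the `sys.path.append('..')` above).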

In [None]:
def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm2d") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)
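
As a quick sanity check of what `weights_init_normal` does, this sketch applies the same function to a hypothetical toy network and prints the resulting weight statistics. The toy network is an assumption for illustration only:

```python
import torch
from torch import nn

# Same logic as the weights_init_normal cell above, repeated so this snippet runs on its own
def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)  # conv biases keep PyTorch's default init
    elif classname.find("BatchNorm2d") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)

torch.manual_seed(0)
toy = nn.Sequential(nn.Conv2d(64, 64, 4), nn.BatchNorm2d(64))  # hypothetical toy network
toy.apply(weights_init_normal)  # apply() visits every submodule recursively

conv_w = toy[0].weight.detach()
bn = toy[1]
print(f"conv weight: mean={conv_w.mean().item():.4f}, std={conv_w.std().item():.4f}")
print(f"batchnorm: weight mean={bn.weight.mean().item():.3f}, bias all zero={bool((bn.bias == 0).all())}")
```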


In [None]:
from models import GeneratorUNet, Discriminator
generator = GeneratorUNet().to(args['device'])
discriminator = Discriminator().to(args['device'])

# Initialize weight
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal);

### Visualizing the generated image and its corresponding inputs

In [None]:
def sample_images(data_loader):
    # One batch: imgs_real is the target (semantic map), label is the input photo
    imgs_real, label = next(iter(data_loader))
    imgs_real = imgs_real.to(args['device'])
    label = label.to(args['device'])

    # No gradients are needed just to visualize the generator's output
    with torch.no_grad():
        imgs_fake = generator(label)

    n_rows = min(imgs_real.size(0), 2)
    # squeeze=False keeps `ax` 2-D even when n_rows == 1
    fig, ax = plt.subplots(nrows=n_rows, ncols=3, figsize=(12, 4 * n_rows), squeeze=False)

    panels = [(imgs_real, 'Real'), (imgs_fake, 'Generated'), (label, 'Label')]
    for i in range(n_rows):
        for j, (imgs, title) in enumerate(panels):
            # (C, H, W) in [-1, 1] -> (H, W, C) in [0, 1] for imshow
            img = imgs[i].cpu().numpy().transpose(1, 2, 0) * 0.5 + 0.5
            ax[i, j].imshow(img.clip(0, 1))
            ax[i, j].set_yticks([])
            ax[i, j].set_xticks([])
            ax[i, j].set_title(title)

    plt.show()

In [None]:
sample_images(cityscapes_dataloader)

### Defining the loss functions and the optimizers

Following [Least Squares GANs](https://arxiv.org/pdf/1611.04076.pdf) (LSGAN), we use `nn.MSELoss` as the adversarial loss instead of binary cross-entropy. In addition, pix2pix defines a pixel-wise loss between the generated and the target image, which we calculate with `nn.L1Loss`.
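
To see what each criterion computes, this NumPy sketch evaluates the LSGAN terms and the L1 term by hand on small made-up arrays (the 2x2 score maps and images are hypothetical, not model outputs). It assumes the default `'mean'` reduction of `nn.MSELoss` and `nn.L1Loss`:

```python
import numpy as np

def mse(pred, target):
    """Mean squared error, as computed by nn.MSELoss with reduction='mean'."""
    return np.mean((pred - target) ** 2)

def l1(pred, target):
    """Mean absolute error, as computed by nn.L1Loss with reduction='mean'."""
    return np.mean(np.abs(pred - target))

# Hypothetical 2x2 PatchGAN score maps
pred_fake = np.array([[0.1, 0.3], [0.5, 0.9]])
pred_real = np.array([[0.8, 1.0], [0.6, 1.2]])

# LSGAN targets: ones for "real", zeros for "fake"
g_adv = mse(pred_fake, np.ones_like(pred_fake))    # generator wants fakes scored as 1
d_real = mse(pred_real, np.ones_like(pred_real))   # discriminator wants reals scored as 1
d_fake = mse(pred_fake, np.zeros_like(pred_fake))  # ... and fakes scored as 0

# Hypothetical 2x2 single-channel images in [-1, 1]
fake_img = np.array([[-1.0, 0.0], [0.5, 1.0]])
real_img = np.array([[-1.0, 0.5], [0.5, 0.0]])
pix = l1(fake_img, real_img)

print(round(g_adv, 4), round(d_real, 4), round(d_fake, 4), round(pix, 4))
# 0.39 0.06 0.29 0.375
```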

For the optimizers, we will use `torch.optim.Adam`, with the learning rate `lr` and the momentum coefficients `betas = (b1, b2)` defined in the `args` dictionary.

In [None]:
# Loss functions
criterion_GAN = torch.nn.MSELoss()
criterion_pixelwise = torch.nn.L1Loss()

In [None]:
# Loss weight of L1 pixel-wise loss between translated image and real image
lambda_pixel = 100

# Output shape of the PatchGAN discriminator: its downsampling reduces each
# spatial side by a factor of 2**4 (256 -> 16), with one score channel
patch = (1, args['img_size'] // 2 ** 4, args['img_size'] // 2 ** 4)

In [None]:
# Optimizers
optimizer_G = torch.optim.Adam( generator.parameters(),
                               lr = args['lr'],
                               betas = (args['b1'], args['b2']) )

optimizer_D = torch.optim.Adam( discriminator.parameters(),
                               lr = args['lr'],
                               betas = (args['b1'], args['b2']) )

## Training procedure


The overall workflow for each batch goes as follows:

1. Train the generator:
    1. Generate a fake image using the generator.
    2. Get the discriminator prediction for the fake image.
    3. Calculate the generator loss from the MSE (adversarial) term and the pixel-wise L1 term, then update the generator.
2. Train the discriminator:
    1. Get the discriminator prediction for the real image and calculate its MSE loss against the "real" target.
    2. Get the discriminator prediction for the fake image, detached from the generator with `.detach()`, and calculate its MSE loss against the "fake" target.
    3. Average the two losses and update the discriminator.
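
Why detach the fake image in the discriminator step? This sketch uses two hypothetical single-layer stand-ins for `G` and `D` (a channel concatenation stands in for the notebook's `discriminator(img, label)` call) and checks where gradients end up, with and without `.detach()`:

```python
import torch
from torch import nn

torch.manual_seed(0)
G = nn.Conv2d(3, 3, 3, padding=1)            # hypothetical stand-in generator
D = nn.Conv2d(6, 1, 4, stride=2, padding=1)  # hypothetical stand-in discriminator
mse = nn.MSELoss()

label = torch.randn(2, 3, 8, 8)
fake_img = G(label)

# Discriminator "fake" loss with fake_img.detach(), as in the training loop
pred_fake = D(torch.cat((fake_img.detach(), label), dim=1))
mse(pred_fake, torch.zeros_like(pred_fake)).backward()
g_untouched = G.weight.grad is None          # detach() cut the graph before G
d_updated = D.weight.grad is not None        # D still receives gradients

# Same loss without detach(): gradients now also flow back into G
pred_fake = D(torch.cat((fake_img, label), dim=1))
mse(pred_fake, torch.zeros_like(pred_fake)).backward()
g_touched = G.weight.grad is not None

print(g_untouched, d_updated, g_touched)  # True True True
```

Detaching reuses the `fake_img` already computed in the generator step, so no second generator forward pass is needed, while keeping the discriminator loss from altering the generator's gradients.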

### Losses

> Generator 

`loss_GAN` is the MSE between the discriminator's output for the generated image and `y_true`, a tensor of ones. Minimizing it, as part of the generator loss `loss_G`, pushes the generator to fool the discriminator into predicting 1 (real), even though the image it receives is fake.

The pixel-wise loss `loss_pixel` is the L1 distance between the generated image and the real image, that is, between what the generator produced and what it should ideally produce.

Finally, we combine these two losses as a weighted sum, with `lambda_pixel = 100` (the weight used in the pix2pix paper) on the L1 term. The code also divides the sum by `1 + lambda_pixel`, which only rescales the loss.
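
Written out, with $x$ the generator input (`label`), $y$ the target (`real_img`), $\lambda$ = `lambda_pixel`, $P$ the number of discriminator scores in the batch and $N$ the number of pixel values, the generator loss computed in the training cell is:

$$\mathcal{L}_G = \frac{\mathcal{L}_{\text{GAN}} + \lambda \, \mathcal{L}_{L1}}{1 + \lambda}, \qquad \mathcal{L}_{\text{GAN}} = \frac{1}{P} \sum_{p=1}^{P} \big( D(G(x), x)_p - 1 \big)^2, \qquad \mathcal{L}_{L1} = \frac{1}{N} \sum_{i=1}^{N} \big| G(x)_i - y_i \big|$$

With $\lambda = 100$ the denominator is 101, so the division changes the overall loss scale but not the relative weighting of the two terms. With Adam, such a constant rescaling should have little effect on the updates, since Adam divides each step by a running estimate of the gradient magnitude.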

> Discriminator

`pred_real` is the discriminator output for the real image, and `loss_real` is the MSE between this output and `y_true`, a tensor of ones. This trains the discriminator to predict 1 for real images.

`pred_fake` is the discriminator output for the fake image, and `loss_fake` is the MSE between this output and `y_fake`, a tensor of zeros. This trains the discriminator to predict 0 for fake images. The fake image is passed through `.detach()`, so this loss does not propagate gradients back into the generator.

Finally, we take the mean of the `loss_fake` and `loss_real`.
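
In formula form, with $x$ = `label`, $y$ = `real_img`, $P$ the number of discriminator scores in the batch, and $G(x)$ treated as a constant because of `fake_img.detach()`:

$$\mathcal{L}_D = \frac{1}{2} \left[ \frac{1}{P} \sum_{p=1}^{P} \big( D(y, x)_p - 1 \big)^2 + \frac{1}{P} \sum_{p=1}^{P} D(G(x), x)_p^2 \right]$$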



In [None]:
################################################################################
#  Training ####################################################################
################################################################################

prev_time = time.time()

for epoch in range(1, args['epochs'] + 1):

    for i, (real_img, label) in enumerate(cityscapes_dataloader):

        # Transfer images and labels to GPU
        real_img = real_img.to(args['device'])
        label = label.to(args['device'])

        # Adversarial ground truths
        y_true = torch.ones(size=(real_img.size(0), *patch), requires_grad=False).to(args['device'])
        y_fake = torch.zeros(size=(real_img.size(0), *patch), requires_grad=False).to(args['device'])

        # ------------------
        #  Train Generators
        # ------------------

        # Clearing gradients for G optimizer.
        optimizer_G.zero_grad()

        # GAN loss.
        fake_img = generator(label)
        pred_fake = discriminator(fake_img, label)
        loss_GAN = criterion_GAN(pred_fake, y_true)

        # Pixel-wise loss
        loss_pixel = criterion_pixelwise(fake_img, real_img)

        # Total loss
        loss_G = (loss_GAN + lambda_pixel * loss_pixel) / (1 + lambda_pixel)

        # G backward and optimizer step.
        loss_G.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------

        # Clearing gradients for D optimizer.
        optimizer_D.zero_grad()

        # Real loss
        pred_real = discriminator(real_img, label)
        loss_real = criterion_GAN(pred_real, y_true)

        # Fake loss
        pred_fake = discriminator(fake_img.detach(), label)
        loss_fake = criterion_GAN(pred_fake, y_fake)

        # Total loss
        loss_D = 0.5 * (loss_real + loss_fake)

        # D backward and optimizer step.
        loss_D.backward()
        optimizer_D.step()

        # --------------
        #  Log Progress
        # --------------

        # Determine approximate time left (epoch is 1-indexed, so subtract 1)
        batches_done = (epoch - 1) * len(cityscapes_dataloader) + i
        batches_left = args['epochs'] * len(cityscapes_dataloader) - batches_done
        time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
        prev_time = time.time()

        # If at sample interval save image
        if batches_done % args['sample_interval'] == 0:

            # Print log
            print(f'[Epoch {epoch}/{ args["epochs"] }] [Batch {i}/{len(cityscapes_dataloader)}] [D loss: {loss_D.item():.4f}] [G loss: {loss_G.item():.4f}, pixel: {loss_pixel.item():.4f}, adv: {loss_GAN.item():.4f}] ETA: {time_left}')
            sample_images(cityscapes_dataloader)


### Validating the generator network

In [None]:
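cityscapes_validation = ImageDataset(transforms_ = transforms_, mode = 'val')
# Use the validation split, not the training set
cityscapes_validation_dataloader = DataLoader(dataset = cityscapes_validation,
                                batch_size = args['batch_size'],
                                shuffle = True,
                                num_workers = 1)

def validate(dataloader):
    # eval() makes layers such as dropout and batch norm behave as at inference time
    generator.eval()
    with torch.no_grad():
        sample_images(dataloader)
    # Restore training mode in case training continues afterwards
    generator.train()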

In [None]:
validate(cityscapes_validation_dataloader)