# 1) Mount drive, unzip data, clone repo, install packages

## 1.1) Mount Drive and define paths
Run the provided Colab code to mount Google Drive, then define the dataset paths relative to the mount point.

In [None]:
!rm -rf /content/sample_data
!rm -rf /content/*.jpg
!rm -rf /content/*.png
!rm -rf /content/*.json

In [None]:
# noinspection PyUnresolvedReferences
from google.colab import drive

# Mount point of Google Drive inside the Colab VM; `drive_root` is the
# "MyDrive" folder under it, used as the base for all Drive paths.
mount_root_abs = '/content/drive'
drive_root = mount_root_abs + '/MyDrive'
drive.mount(mount_root_abs)

In [None]:
import os

# DeepFashion In-shop Clothes Retrieval Benchmark (ICRB) paths on Google Drive.
# Each assert names the missing path so a mis-placed dataset is easy to spot.
df_root_drive = f'{drive_root}/Datasets/DeepFashion'
assert os.path.exists(df_root_drive), f'DeepFashion root not found: {df_root_drive}'
df_icrb_root_drive = f'{df_root_drive}/In-shop Clothes Retrieval Benchmark'
assert os.path.exists(df_icrb_root_drive), f'ICRB root not found: {df_icrb_root_drive}'
df_icrb_img_zip_abs_drive = f'{df_icrb_root_drive}/Img.zip'

# If Img.zip is not present, we need to unzip the .../Img/img_iuv.zip archive
# from the Drive root and then run ICRBScraper.run() from /src/dataset/deep_fashion.
# This notebook skips that step, since it would take an eternity to complete
# over the mounted Google Drive.
assert os.path.exists(df_icrb_img_zip_abs_drive), \
  'Please upload a processed zip (processing img.zip in colab will take' + \
  f' for AGES). \nTried: {df_icrb_img_zip_abs_drive}'

## 1.2) Unzip Img directory in Colab
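By unzipping the processed `Img.zip` onto Colab's local disk before running our model, we gain significant disk-read speedups compared to reading the images from the mounted Google Drive.
So, the first step is to copy the archive from Drive, unzip the images directory locally, and then delete the copied archive before proceeding.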

In [None]:
df_icrb_root = df_icrb_root_drive.replace(drive_root, '/content/data')
df_icrb_img_root = f'{df_icrb_root}/Img'
if not os.path.exists(df_icrb_img_root):
  # Clear any previous attempts
  # ATTENTION: This will remove /contents/data/*. So, before running, please make
  # sure no usable files will be deleted.
  !mkdir -p /content/data
  !rm -rf /content/data

  # Create output directory
  !mkdir -p "$df_icrb_root"

  # Transfer Img.zip from Google Drive to Colab
  df_icrb_img_zip_abs = f'{df_icrb_root}/{os.path.basename(df_icrb_img_zip_abs_drive)}'
  if not os.path.exists(df_icrb_img_zip_abs):
    !cp "$df_icrb_img_zip_abs_drive" "$df_icrb_root"
  # Unzip it in Colab
  !unzip -q "$df_icrb_img_zip_abs" -d "$df_icrb_root"
  # Handle newly-created image directory
  assert os.path.exists(df_icrb_img_root), f'df_icrb_img_root: {df_icrb_img_root}'
  assert not os.path.exists(f'{df_icrb_img_root}/Img')
  assert not os.path.exists(f'{df_icrb_img_root}/img')
  !rm -f "$df_icrb_img_zip_abs"
  assert not os.path.exists(df_icrb_img_zip_abs)

## 1.3) Clone github repo
Clone the achariso/gans-thesis repo into /content/code using `git clone` over SSH (the SSH keys and config are read from Google Drive).
For more info see: https://medium.com/@purba0101/how-to-clone-private-github-repo-in-google-colab-using-ssh-77384cfef18f

In [None]:
repo_root = '/content/code/gans-thesis'
if not os.path.exists(repo_root) and not os.path.exists(f'{repo_root}/requirements.txt'):
  # Check that ssh keys exist
  assert os.path.exists(f'{drive_root}/GitHub Keys')
  id_rsa_abs_drive = f'{drive_root}/GitHub Keys/id_rsa'
  id_rsa_pub_abs_drive = f'{id_rsa_abs_drive}.pub'
  assert os.path.exists(id_rsa_abs_drive)
  assert os.path.exists(id_rsa_pub_abs_drive)
  # On first run: Add ssh key in repo
  # !cat "$id_rsa_pub_abs_drive", copy & paste in repo's Deploy Keys
  # Transfer config file
  ssh_config_abs_drive = f'{drive_root}/GitHub Keys/config'
  assert os.path.exists(ssh_config_abs_drive)
  !mkdir -p ~/.ssh
  !cp -f "$ssh_config_abs_drive" ~/.ssh/
  # # Add github.com to known hosts
  !ssh-keyscan -t rsa github.com >> ~/.ssh/known_hosts
  # Test: !ssh -T git@github.com

  # Remove any previous attempts
  !mkdir -p "$repo_root"
  !rm -rf "$repo_root"
  !mkdir -p "$repo_root"
  # Clone repo
  !git clone git@github.com:achariso/gans-thesis.git "$repo_root"
  src_root = f'{repo_root}/src'
  !rm -rf "$repo_root"/report

## 1.4) Install pip packages
All required packages are listed in a requirements.txt file at the repository's root.
Use `pip install -r requirements.txt` from inside that directory to install them.

In [None]:
%cd $repo_root
!pip install -r requirements.txt

In [None]:
import torch
# The training/metrics code below targets a GPU; fail early with a clear hint.
assert torch.cuda.is_available(), 'CUDA is not available: please switch the Colab runtime to a GPU.'

## 1.5) Add code/, */src/ to path
This is necessary in order to be able to run the modules.

In [None]:
import os
import sys

content_root_abs = repo_root
src_root_abs = f'{repo_root}/src'
# Expose the repo root and src/ to subprocesses (e.g. `!python -m unittest`).
# Set through os.environ rather than %env: %env stores quote characters
# literally, and IPython leaves the whole line unexpanded when it references an
# undefined Python name such as $PATH, which would overwrite PATH with a
# literal "$PATH:..." string.
os.environ['PYTHONPATH'] = f'/env/python:{content_root_abs}:{src_root_abs}'
os.environ['PATH'] = f"{os.environ['PATH']}:{content_root_abs}:{src_root_abs}"
# PYTHONPATH is only read at interpreter start-up, so also extend sys.path for
# imports made from this already-running kernel (appended, so existing entries
# keep precedence).
for _path in (content_root_abs, src_root_abs):
  if _path not in sys.path:
    sys.path.append(_path)

# 2) Test code
Verify that the pulled code works by running its entry point and its unit tests.

In [None]:
import matplotlib.pyplot as plt

# Use large 20x20-inch figures so image previews are easier to inspect in Colab.
plt.rc('figure', figsize=(20, 20))

In [None]:
# Smoke-test the package entry point. The relative path resolves against the
# repository root, which `%cd $repo_root` made the working directory.
%run src/main.py

In [None]:
test_root = f'{repo_root}/tests'
!python -m unittest discover -s "$test_root" -t "$test_root"


# 3) Train PGPG model on DeepFashion
In this section we give the actual training code for PGPG network. PGPG consists of a 2-stage generator, where each stage is a UNET-like model, and, in our version, a PatchGAN discriminator.

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm

## 3.1) Hyper-parameters settings
In this sub-section we define some important training hyper-parameters such as the batch size, and the number of epochs.

In [None]:
import os

# Training hyperparams
n_epochs = 100
batch_size = 48
lr = 1e-3
adv_criterion = nn.MSELoss()
recon_criterion = nn.L1Loss()

# Image geometry: only target_* is passed to the dataloader below; load_* is
# presumably the size/channels of the stored images (TODO confirm).
load_shape = 256
load_channels = 3
target_shape = 128
target_channels = 3

use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')

# Set save_model to True to have the training loop create model checkpoints
# every `checkpoint_step` steps
save_model = True

# Set the desired LR schedulers for Generator and Discriminator optimizers
# Supported types: 'on_plateau', 'cyclic'; None disables the scheduler
gen_lr_scheduler_type = 'on_plateau'
disc_lr_scheduler_type = None

# Step listeners: print losses/previews every `display_step` steps and save a
# checkpoint every `checkpoint_step` steps
display_step = 200
checkpoint_step = 500

# Checkpoints live on Google Drive, next to the rest of the project data
model_chkpts_root_drive = f'{drive_root}/Model Checkpoints'
if save_model:
    # torch.save() does not create missing parent directories
    os.makedirs(model_chkpts_root_drive, exist_ok=True)

## 3.2) Data
In this sub-section we define the dataset & the dataloader to use. In particular, we use DeepFashion's In-shop Clothes Retrieval Benchmark (ICRB) in cross-pose mode. As such, at each iteration the dataloader yields 3 tensors, each holding a batch of images:


*   `image_1`: the input (condition) image to the generator (from which the styles, clothes, etc. will be "fetched")
*   `image_2`: `image_1` at a different pose and/or scale (from which the output "body structure" and "position" will be "fetched")
*   `dense_pose_2`: the dense pose annotation (as an RGB image) of `image_2`



In [None]:
from dataset.deep_fashion import ICRBCrossPoseDataloader

# Image transforms (resize & crop when target_shape differs from the load
# shape, grayscale conversion when target_channels differs from the load
# channels) are built by the dataloader itself from target_shape and
# target_channels.
#   - len(dataloader) = <number of batches>
#   - len(dataloader.dataset) = <number of total dataset items>
dl = ICRBCrossPoseDataloader(batch_size=batch_size,
                             target_channels=target_channels,
                             target_shape=target_shape,
                             norm_mean=0.5, norm_std=0.5,
                             skip_pose_norm=True)

print(dl.dataset.transforms)

# Basic tests: it is a torch DataLoader that keeps the last partial batch,
# i.e. number of batches == ceil(number of items / batch_size) ...
assert isinstance(dl, DataLoader)
_n_items = len(dl.dataset)
assert len(dl) == (_n_items + batch_size - 1) // batch_size, \
    f'len(dl)={len(dl)} for {_n_items} items, batch_size={batch_size}'
# ... and every tensor of the first batch has the expected shape
_image_1, _image_2, _dense_pose_2 = next(iter(dl))
_expected_shape = (batch_size, target_channels, target_shape, target_shape)
for _name, _tensor in (('image_1', _image_1), ('image_2', _image_2), ('dense_pose_2', _dense_pose_2)):
    assert tuple(_tensor.shape) == _expected_shape, \
        f'{_name}: {tuple(_tensor.shape)} != {_expected_shape}'

## 3.3) Define PGPG model
In this sub-section we define the PGPG's children network models and in particular:


*   the stage-I generator, `G1`: UNET-like with skip connections, and an FC layer at the bottleneck
*   the stage-II generator, `G2`: UNET-like with skip connections and Dropout at the first half of the "encoder" part of the UNET-like model
*   the whole generator, `G`: composed of the two stages
*   the discriminator, `D`: PatchGAN discriminator



In [None]:
from modules.discriminators.patch_gan import PatchGANDiscriminator
from modules.generators.pgpg import PGPGGenerator
from utils.train import get_adam_optimizer, weights_init_naive, load_model_chkpt, get_optimizer_lr_scheduler

torch.cuda.empty_cache()

# Define models. c_in = 2 * target_channels: presumably two images stacked
# channel-wise (TODO confirm against the module implementations).
gen = PGPGGenerator(c_in=2*target_channels, c_out=target_channels,
                    w_in=target_shape, h_in=target_shape).to(device)
disc = PatchGANDiscriminator(c_in=2*target_channels, n_contracting_blocks=6,
                             use_spectral_norm=True).to(device)

# Initialize weights before building the optimizers, so the optimizers are
# guaranteed to hold the final parameter tensors.
gen = gen.apply(weights_init_naive)
disc = disc.apply(weights_init_naive)

# Define optimizers
# Note: Both generators, G1 & G2, are trained using a joint optimizer
gen_opt = get_adam_optimizer(gen, lr=lr, betas=(0.9, 0.999))
disc_opt = get_adam_optimizer(disc, lr=lr, betas=(0.9, 0.999))

# Optimizer LR Schedulers (None when no schedule type is configured)
if gen_lr_scheduler_type:
    gen_opt_lr_scheduler = get_optimizer_lr_scheduler(gen_opt, schedule_type=gen_lr_scheduler_type)
else:
    gen_opt_lr_scheduler = None
if disc_lr_scheduler_type:
    disc_opt_lr_scheduler = get_optimizer_lr_scheduler(disc_opt, schedule_type=disc_lr_scheduler_type)
else:
    disc_opt_lr_scheduler = None

# Load checkpoints. load_model_chkpt signals a failed load with AssertionError;
# training then starts from epoch 0, step 0.
chkpt_epoch = chkpt_step = 0
if save_model:
    try:
        load_model_chkpt(model=disc, model_name='pgpg', dict_key='disc', model_opt=disc_opt)
        print('Loaded Discriminator checkpoint!')
    except AssertionError as e:
        print(f'Error while loading Discriminator checkpoint: {str(e)}')

    try:
        chkpt_images = load_model_chkpt(model=gen, model_name='pgpg', dict_key='gen', model_opt=gen_opt)
        # `chkpt_images` counts images while len(dl) counts batches per epoch:
        # convert images to steps first, then split steps into (epoch, step).
        chkpt_epoch, chkpt_step = divmod(chkpt_images // batch_size, len(dl))
        print('Loaded Generator checkpoint!')
        print(f'\t--> chkpt_images={chkpt_images}, chkpt_epoch={chkpt_epoch}, chkpt_step={chkpt_step}')
    except AssertionError as e:
        print(f'Error while loading Generator checkpoint: {str(e)}')
        chkpt_epoch = chkpt_step = 0

## 3.4) Evaluation Metrics
The evaluation metrics that we will use to evaluate model performance during and after training are:

* Frechet Inception Distance, FID
* Inception Score, IS
* Structural Similarity Index (SSIM)
* Precision/Recall/F1 (P, R, F1)

These metrics are saved on each model checkpoint. Below, we initialize the calculators for these metrics.


In [None]:
from dataset.deep_fashion import ICRBCrossPoseDataset, ICRBDataset

from utils.metrics.fid import FID
from utils.metrics.is_ import IS
from utils.metrics.f1 import F1
from utils.metrics.ssim import SSIM

# Sample count / batch size handed to every metric calculator
metrics_n_sample = 2048
metrics_batch_sample = 16
# Run the metric calculators on the same device as the models ('cuda' on a GPU
# runtime) instead of hard-coding it.
metrics_device = str(device)
# Transforms applied to generator inputs when computing metrics
metrics_gen_transforms = ICRBDataset.get_image_transforms(target_shape=target_shape, target_channels=target_channels)
# FID/IS/F1 read the untransformed dataset and receive `gen_transforms`
# separately; SSIM reads a dataset that already applies the same transforms.
metrics_dataset = ICRBCrossPoseDataset(image_transforms=None, pose=True)
metrics_dataset_with_transforms = ICRBCrossPoseDataset(image_transforms=metrics_gen_transforms, pose=True)

fid_calculator = FID(n_samples=metrics_n_sample, batch_size=metrics_batch_sample, device=metrics_device)
is_calculator = IS(n_samples=metrics_n_sample, batch_size=metrics_batch_sample, device=metrics_device)
f1_calculator = F1(n_samples=metrics_n_sample, batch_size=metrics_batch_sample, device=metrics_device)
ssim_calculator = SSIM(n_samples=metrics_n_sample, batch_size=metrics_batch_sample,
                       c_img=target_channels, device=metrics_device)


## 3.5) PGPG Training
Finally, let's start training the model! During training we also periodically plot generated images, to preview how inference (i.e. generation) quality evolves.

In [None]:
import matplotlib.pyplot as plt


def _show_sample(idx, image_1, pose_2, image_2, g1_out, g_out):
    """Plot sample `idx` of the current batch as one horizontal strip:
    [image_1 | pose_2 | image_2 | G1 output | G output - G1 output | G output].

    NOTE(review): images are normalized with mean=std=0.5 and may lie in
    [-1, 1], which imshow clips to [0, 1]; consider de-normalizing for display.
    """
    g1_sample = g1_out[idx].detach().cpu().float()
    g_sample = g_out[idx].detach().cpu().float()
    g2_sample = g_sample - g1_sample  # difference added on top of the G1 output
    strip = torch.cat((image_1[idx].cpu(), pose_2[idx].cpu(), image_2[idx].cpu(),
                       g1_sample, g2_sample, g_sample), dim=2).detach().cpu()
    if target_channels == 1:
        strip = strip.view([strip.shape[1], strip.shape[2]])
    else:
        strip = strip.permute(1, 2, 0)  # channels-first -> channels-last for imshow
    plt.imshow(strip, cmap='gray' if target_channels == 1 else None)
    plt.show()


mean_generator_loss = 0
mean_discriminator_loss = 0
# NOTE(review): cur_step is never reset across epochs (global step), but on
# resume it starts from the in-epoch chkpt_step; checkpoint file names use it.
cur_step = chkpt_step
torch.cuda.empty_cache()
for epoch in range(chkpt_epoch, n_epochs):
    for image_1, image_2, pose_2 in tqdm(dl, initial=chkpt_step):
        image_1 = image_1.to(device)
        image_2 = image_2.to(device)
        pose_2 = pose_2.to(device)

        ##########################################
        ########   Update Discriminator   ########
        ##########################################
        disc_opt.zero_grad()                    # Zero out gradient before backprop
        # The generator output is only an input here: no gradients through gen
        with torch.no_grad():
            _, g_out = gen(image_1, pose_2)
        disc_loss = disc.get_loss(real=image_2, fake=g_out, condition=image_1,
                                  criterion=adv_criterion)
        # No retain_graph: nothing else back-propagates through this graph, and
        # retaining it would keep its buffers alive during the generator step.
        disc_loss.backward()                    # Update discriminator gradients
        disc_opt.step()                         # Update discriminator weights
        # Update LR (if needed); the 'on_plateau' scheduler is fed the loss
        if disc_lr_scheduler_type and disc_opt_lr_scheduler:
            if disc_lr_scheduler_type == 'on_plateau':
                disc_opt_lr_scheduler.step(metrics=disc_loss)
            else:
                disc_opt_lr_scheduler.step()

        ##########################################
        ########     Update Generator     ########
        ##########################################
        # Any gradients this pass leaves on disc's parameters are cleared by
        # disc_opt.zero_grad() at the start of the next iteration.
        gen_opt.zero_grad()
        g1_loss, g2_loss, g1_out, g_out = gen.get_loss(x=image_1, y=image_2,
                    y_pose=pose_2.clone(), disc=disc, adv_criterion=adv_criterion,
                    recon_criterion=recon_criterion)
        gen_loss = g1_loss + g2_loss
        gen_loss.backward()                     # Update generator gradients
        gen_opt.step()                          # Update generator weights
        # Update LR (if needed); the 'on_plateau' scheduler is fed the loss
        if gen_lr_scheduler_type and gen_opt_lr_scheduler:
            if gen_lr_scheduler_type == 'on_plateau':
                gen_opt_lr_scheduler.step(metrics=gen_loss)
            else:
                gen_opt_lr_scheduler.step()

        ##########################################
        ########      Visualizations      ########
        ##########################################
        # Keep track of the average losses over the last `display_step` steps
        mean_discriminator_loss += disc_loss.item() / display_step
        mean_generator_loss += gen_loss.item() / display_step

        if cur_step % display_step == 0:
            print(f"\n\t--> Epoch {epoch}: Step {cur_step}: Generator loss: {mean_generator_loss}, Discriminator loss: {mean_discriminator_loss}")
            # Preview the first and the last sample of the batch
            _show_sample(0, image_1, pose_2, image_2, g1_out, g_out)
            _show_sample(-1, image_1, pose_2, image_2, g1_out, g_out)
            mean_generator_loss = 0
            mean_discriminator_loss = 0

        ##########################################
        ########     Save checkpoints     ########
        ##########################################
        if save_model and cur_step % checkpoint_step == 0:
            # Calculate metrics with the generator in eval mode
            gen = gen.eval()
            with torch.no_grad():
                metrics_fid = fid_calculator(metrics_dataset, gen=gen, gen_transforms=metrics_gen_transforms,
                                             target_index=1, condition_indices=(0, 2))
                metrics_is = is_calculator(metrics_dataset, gen=gen, gen_transforms=metrics_gen_transforms,
                                           target_index=1, condition_indices=(0, 2))
                metrics_f1, metrics_p, metrics_r = f1_calculator(metrics_dataset, gen=gen,
                                                                 gen_transforms=metrics_gen_transforms,
                                                                 target_index=1, condition_indices=(0, 2))
                metrics_ssim = ssim_calculator(metrics_dataset_with_transforms, gen=gen,
                                               target_index=1, condition_indices=(0, 2))
            gen = gen.train()
            # Save state dicts alongside metrics in a single .pth file
            torch.save({
                'gen': gen.state_dict(),
                'gen_opt': gen_opt.state_dict(),
                'disc': disc.state_dict(),
                'disc_opt': disc_opt.state_dict(),
                'metrics': {
                    'fid': metrics_fid,
                    'is': metrics_is,
                    'f1': metrics_f1, 'precision': metrics_p, 'recall': metrics_r,
                    'ssim': metrics_ssim
                }
            }, f"{model_chkpts_root_drive}/pgpg_{cur_step}_{batch_size}.pth")

        cur_step += 1
    # Only the first (resumed) epoch starts mid-way through the progress bar
    chkpt_step = 0

## 3.6) Evaluate Generated Samples
To evaluate the generated samples and compare the model with other GAN architectures trained on the same dataset, we calculate the following metrics:

* Frechet Inception Distance (FID) - lower is better
* Inception Score (IS) - higher is better
* Structural Similarity (SSIM) - higher is better
* Precision/Recall/F1 (P, R, F1) - higher is better


In [None]:
# Calculate final metrics
gen = gen.eval()
with torch.no_grad():
    metrics_fid = fid_calculator(metrics_dataset, gen=gen, gen_transforms=metrics_gen_transforms,
                                 target_index=1, condition_indices=(0, 2))
    metrics_is = is_calculator(metrics_dataset, gen=gen, gen_transforms=metrics_gen_transforms,
                                 target_index=1, condition_indices=(0, 2))
    metrics_f1, metrics_p, metrics_r = f1_calculator(metrics_dataset, gen=gen,
                                                     gen_transforms=metrics_gen_transforms,
                                                     target_index=1, condition_indices=(0, 2))
    metrics_ssim = ssim_calculator(metrics_dataset_with_transforms, gen=gen,
                                   target_index=1, condition_indices=(0, 2))
print(str({
    'fid': metrics_fid,
    'is': metrics_is,
    'f1': metrics_f1, 'precision': metrics_p, 'recall': metrics_r,
    'ssim': metrics_ssim
}))