# 1) Mount drive, unzip data, clone repo, install packages

## 1.1) Define paths
The dataset's main `Img.zip` must be present among the input datasets. Then, run the provided code to define the dataset paths
relative to the Kaggle mount points (`/kaggle/input` & `/kaggle/working`).

In [None]:
import os

# Define dataset paths
# LookBook + (partially) DeepFashion In-shop Clothes Retrieval Benchmark (ICRB)
lb_root = '/kaggle/input/lookbook-deepfashion-model2clothes'
lb_img_root = f'{lb_root}/Img'
assert os.path.exists(lb_img_root), f'lb_img_root={lb_img_root}: NOT FOUND'

# Define asset paths
git_keys_root = '/kaggle/input/git-keys/github-keys'
assert os.path.exists(git_keys_root), f'git_keys_root={git_keys_root}: NOT FOUND'
client_secrets_path = '/kaggle/input/git-keys/client_secrets.json'
assert os.path.exists(client_secrets_path), f'client_secrets_path={client_secrets_path}: NOT FOUND'
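The repeated `assert os.path.exists(...)` checks above could be collapsed into one helper that reports every missing input at once. This is a minimal sketch. `require_paths` is a hypothetical name and is not part of the repository.

```python
import os
from typing import Mapping


def require_paths(paths: Mapping[str, str]) -> None:
    """Raise FileNotFoundError listing every missing path (hypothetical helper)."""
    missing = [f'{name}={path}' for name, path in paths.items() if not os.path.exists(path)]
    if missing:
        raise FileNotFoundError('NOT FOUND: ' + ', '.join(missing))
```

In the cell above, it would be called as `require_paths({'lb_img_root': lb_img_root, 'git_keys_root': git_keys_root, 'client_secrets_path': client_secrets_path})`. Unlike `assert`, an explicit `raise` also survives running Python with `-O`, which strips assertions.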

Create the root Google Drive directory. This is where all model checkpoints/metrics, as well as Datasets, Fonts etc., are stored.
Symlink the dataset's `Img` folder into it. This avoids code changes and keeps the notebook interoperable with Google Colab.

In [None]:
# Create root directory if not exists
gdrive_root = '/kaggle/working/GoogleDrive'
!mkdir -p "$gdrive_root"

# Create the Dataset link inside Google Drive
gdrive_icrb_root = f'{gdrive_root}/Datasets/LookBook'
!mkdir -p "$gdrive_icrb_root"
!ln -s "/kaggle/input/lookbook-deepfashion-model2clothes/Img" "$gdrive_icrb_root"

# Copy the Fonts dir inside local Google Drive root
!cp -r /kaggle/input/mplfonts/Fonts "$gdrive_root"

# Copy the InceptionV3 model checkpoint inside the local Google Drive root
!mkdir -p "$gdrive_root"/Models
!cp -r "/kaggle/input/inception-model/model_name=inceptionv3" "$gdrive_root"/Models
# Rename the checkpoint from .pth.bak to .pth
!mv "$gdrive_root"/Models/model_name=inceptionv3/Checkpoints/1a9a5a14.pth.bak "$gdrive_root"/Models/model_name=inceptionv3/Checkpoints/1a9a5a14.pth

# Also create an empty Img.zip file so that the GDriveDataset instance assumes the dataset was already downloaded
# and unzipped
!touch "$gdrive_icrb_root"/Img.zip

# FIX: We need client_secrets.json to be writable, so copy to /kaggle/working
!cp "$client_secrets_path" "$gdrive_root"
client_secrets_path = f'{gdrive_root}/client_secrets.json'
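`ln -s TARGET DIR` creates `DIR/<basename of TARGET>`. On a re-run of the cell, it fails with "File exists". The following minimal sketch shows an idempotent Python equivalent for the dataset link. `ensure_symlink` is a hypothetical helper and is not used by the repository.

```python
import os


def ensure_symlink(target: str, link_dir: str) -> str:
    """Create link_dir/<basename(target)> -> target, idempotently (hypothetical helper)."""
    os.makedirs(link_dir, exist_ok=True)
    link_path = os.path.join(link_dir, os.path.basename(os.path.normpath(target)))
    if os.path.islink(link_path):
        if os.readlink(link_path) == target:
            return link_path  # already points to the right place
        os.unlink(link_path)  # stale link from a previous run
    # If link_path is a real file/dir, this raises FileExistsError instead of deleting data
    os.symlink(target, link_path)
    return link_path
```

`ensure_symlink(lb_img_root, gdrive_icrb_root)` would then reproduce the `!ln -s` line safely on every run.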

## 1.2) Clone GitHub repo
Clone the achariso/gans-thesis repo into `/kaggle/working/code` using `git clone` over SSH. For a similar procedure in Colab,
see: https://medium.com/@purba0101/how-to-clone-private-github-repo-in-google-colab-using-ssh-77384cfef18f

In [None]:
# Clean failed attempts
!rm -rf /root/.ssh
!rm -rf /kaggle/working/code
!mkdir -p /kaggle/working/code

repo_root = '/kaggle/working/code/gans-thesis'
if not os.path.exists(repo_root):
    # Check that ssh keys exist
    id_rsa_abs_drive = f'{git_keys_root}/id_rsa'
    id_rsa_pub_abs_drive = f'{id_rsa_abs_drive}.pub'
    assert os.path.exists(id_rsa_abs_drive)
    assert os.path.exists(id_rsa_pub_abs_drive)
    # On first run: set up SSH (config file + github.com known host)
    if not os.path.exists('/root/.ssh'):
        # Transfer config file
        ssh_config_abs_drive = f'{git_keys_root}/config'
        assert os.path.exists(ssh_config_abs_drive)
        !mkdir -p ~/.ssh
        !cp -f "$ssh_config_abs_drive" ~/.ssh/
        # Add github.com to known hosts
        !ssh-keyscan -t rsa github.com >> ~/.ssh/known_hosts
        # Test ssh connection
        # !ssh -T git@github.com

    # Remove any previous attempts
    !rm -rf "$repo_root"
    !mkdir -p "$repo_root"
    # Clone repo
    !git clone git@github.com:achariso/gans-thesis.git "$repo_root"
    src_root = f'{repo_root}/src'
    !rm -rf "$repo_root"/report
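The cell above asserts that `id_rsa` and `id_rsa.pub` exist, but it copies only `config` into `~/.ssh`. Cloning therefore relies on whatever `IdentityFile` that config names; its contents are not shown here. If the config expects the keys under `~/.ssh`, they must also be copied there, and OpenSSH refuses a private key that other users can read. The following minimal sketch works under that assumption. `install_ssh_keys` is a hypothetical helper.

```python
import os
import shutil
from typing import List, Sequence


def install_ssh_keys(src_dir: str, ssh_dir: str = os.path.expanduser('~/.ssh'),
                     names: Sequence[str] = ('id_rsa', 'id_rsa.pub', 'config')) -> List[str]:
    """Copy SSH files into ssh_dir with OpenSSH-friendly permissions (hypothetical helper)."""
    os.makedirs(ssh_dir, mode=0o700, exist_ok=True)
    os.chmod(ssh_dir, 0o700)  # makedirs' mode is masked by umask, so set it explicitly
    installed = []
    for name in names:
        dst = os.path.join(ssh_dir, name)
        shutil.copyfile(os.path.join(src_dir, name), dst)
        # The public key may be world-readable; the private key and config should not be
        os.chmod(dst, 0o644 if name.endswith('.pub') else 0o600)
        installed.append(dst)
    return installed
```

With this helper, `install_ssh_keys(git_keys_root)` would run before `git clone`.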

## 1.3) Install pip packages
All required packages are listed in a `requirements.txt` file at the repository's root.
Use `pip install -r requirements.txt` from inside the repository's dir to install them.

In [None]:
%cd "$repo_root"
!pip install -r requirements.txt
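To confirm that the install succeeded without re-reading pip's log, you can check the requirement names against the installed distributions with `importlib.metadata` (Python 3.8+). This is a simplified sketch. `requirement_names` and `installed_versions` are hypothetical helpers. The parser only handles plain `name[extras]<specifier>` lines, not URLs or VCS links.

```python
import re
from importlib import metadata
from typing import Dict, List, Optional


def requirement_names(lines: List[str]) -> List[str]:
    """Extract bare distribution names from simple requirements.txt lines (sketch)."""
    names = []
    for line in lines:
        line = line.split('#', 1)[0].strip()  # drop trailing comments
        if not line or line.startswith('-'):  # skip blanks and options such as -r / -e
            continue
        match = re.match(r'[A-Za-z0-9][A-Za-z0-9._-]*', line)
        if match:
            names.append(match.group(0))
    return names


def installed_versions(names: List[str]) -> Dict[str, Optional[str]]:
    """Map each name to its installed version, or None if it is not installed."""
    result = {}
    for name in names:
        try:
            result[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            result[name] = None
    return result
```

From inside the repository's dir: `with open('requirements.txt') as f: print(installed_versions(requirement_names(f.read().splitlines())))`.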

In [None]:
import torch
assert torch.cuda.is_available()

## 1.4) Add code/, */src/ to path
This is necessary in order to be able to import and run the repository's modules.

In [None]:
content_root_abs = f'{repo_root}'
src_root_abs = f'{repo_root}/src'
%env PYTHONPATH=/env/python:$content_root_abs:$src_root_abs
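`%env` sets `PYTHONPATH` for the kernel process and the subprocesses it spawns (e.g. `!python ...`). However, an already-running interpreter reads `PYTHONPATH` only at startup, so the notebook's own `sys.path` is unchanged. If the repo modules must be importable in the kernel itself, you can prepend the directories directly. The following minimal sketch uses a hypothetical `prepend_to_sys_path` helper.

```python
import sys
from typing import Iterable, List


def prepend_to_sys_path(dirs: Iterable[str]) -> List[str]:
    """Insert dirs at the front of sys.path, keeping their order and skipping duplicates."""
    added = []
    for d in reversed(list(dirs)):  # insert in reverse so the first dir ends up first
        if d not in sys.path:
            sys.path.insert(0, d)
            added.append(d)
    return added[::-1]
```

`prepend_to_sys_path([content_root_abs, src_root_abs])` would mirror the `%env` line for the current kernel.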

# 2) Train PixelDTGAN model on LookBook + DeepFashion (part of the ICRB dataset)
In this section we run the actual training loop for the PixelDTGAN network. PixelDTGAN consists of an AE-like (encoder-decoder) generator and, in our version, two PatchGAN discriminators.
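To make the architecture description concrete, here is a deliberately tiny PyTorch sketch with two components:

- An encoder-decoder generator that maps a 3x64x64 source image to a 3x64x64 target image.
- A PatchGAN-style discriminator that outputs a grid of per-patch logits instead of a single score.

Layer sizes are illustrative assumptions and do not mirror `modules.pixel_dt_gan.PixelDTGan`. The second discriminator is assumed to score the concatenated (source, target) pair, as the domain discriminator of the original PixelDTGAN does.

```python
import torch
from torch import nn


class TinyAEGenerator(nn.Module):
    """Hypothetical encoder-decoder generator: 3x64x64 source -> 3x64x64 target."""

    def __init__(self, channels: int = 3, nf: int = 32):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(channels, nf, 4, 2, 1), nn.LeakyReLU(0.2),                         # 64 -> 32
            nn.Conv2d(nf, nf * 2, 4, 2, 1), nn.BatchNorm2d(nf * 2), nn.LeakyReLU(0.2),  # 32 -> 16
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(nf * 2, nf, 4, 2, 1), nn.BatchNorm2d(nf), nn.ReLU(),      # 16 -> 32
            nn.ConvTranspose2d(nf, channels, 4, 2, 1), nn.Tanh(),                        # 32 -> 64
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.decoder(self.encoder(x))


class TinyPatchDiscriminator(nn.Module):
    """Hypothetical PatchGAN-style critic: one logit per overlapping input patch."""

    def __init__(self, in_channels: int, nf: int = 32):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, nf, 4, 2, 1), nn.LeakyReLU(0.2),  # 64 -> 32
            nn.Conv2d(nf, nf * 2, 4, 2, 1), nn.LeakyReLU(0.2),       # 32 -> 16
            nn.Conv2d(nf * 2, 1, 4, 1, 1),                            # 16 -> 15 (logit map)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


gen = TinyAEGenerator()
disc_real = TinyPatchDiscriminator(in_channels=3)   # "is this a realistic target image?"
disc_assoc = TinyPatchDiscriminator(in_channels=6)  # "does this target belong to this source?"

img_s = torch.randn(2, 3, 64, 64)
img_fake = gen(img_s)
real_logits = disc_real(img_fake)                                # shape (2, 1, 15, 15)
assoc_logits = disc_assoc(torch.cat([img_s, img_fake], dim=1))  # shape (2, 1, 15, 15)
```

Each of the 15x15 logits judges only a local receptive field, which is what distinguishes a PatchGAN critic from a single-score discriminator.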


### Colab Bug Workaround
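Bug: matplotlib cache not rebuilding.
Solution: run the following code, which kills the kernel process so that it restarts. This code now lives inside `src/train_setup.py`, so it is kept commented out here.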

In [None]:
# os.kill(os.getpid(), 9)
# now inside src/train_setup.py

### Actual Run
Finally, run the code!

In [None]:
chkpt_step = 'latest'   # supported: 'latest', <int>, None
log_level = 'debug'     # supported: 'debug', 'info', 'warning', 'error', 'critical', 'fatal'
device = 'cuda'         # supported: 'cpu', 'cuda'

# From epoch=37, lambda_recon in G2's loss went from 1 --> 5
# From epoch=66, lambda_recon in G2's loss went from 5 --> 10

# Running with -i executes the script inline, in the notebook's namespace (instead of a fresh one); the variables it
# defines (e.g. args, exec_device, run_locally) are then used by the cells below
%run -i src/train_setup.py --log_level $log_level --chkpt_step $chkpt_step --seed 42 --device $device -use_refresh_token
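The training-setup cell normalizes the supported `chkpt_step` values listed above before passing them to `PixelDTGan` as `chkpt_epoch`. Below is a standalone restatement of that logic as a hypothetical helper, `parse_chkpt_step`. Unlike the setup cell, it also accepts an `int` directly.

```python
from typing import Union


def parse_chkpt_step(value: Union[str, int, None]) -> Union[str, int, None]:
    """Normalize a checkpoint selector: 'latest', a step number, or None (start from scratch)."""
    if value == 'latest':
        return 'latest'
    if isinstance(value, int) and not isinstance(value, bool):
        return value
    if isinstance(value, str) and value.isdigit():
        return int(value)
    return None  # None, 'None', or anything unrecognized -> no checkpoint
```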

### PixelDTGAN Training
Setup / preparation before starting the PixelDTGAN training loop.

In [None]:
%cd src/

import torch
from IPython.display import display
from torch import Tensor
from torch.nn import DataParallel
# noinspection PyProtectedMember
from torch.utils.data import DataLoader

from datasets.look_book import PixelDTDataset, PixelDTDataloader
from modules.pixel_dt_gan import PixelDTGan
from utils.dep_free import get_tqdm
from utils.ifaces import FilesystemDataset
from utils.metrics import GanEvaluator

###################################
###  Hyper-parameters settings  ###
###################################
# TODO: finish this notebook and train in Colab/Kaggle
#   - training
n_epochs = 300
batch_size = 256 if not run_locally else 2
train_test_splits = [90, 10]  # for a 90% training - 10% evaluation set split
#   - evaluation
metrics_n_samples = 1000 if not run_locally else 2
metrics_batch_size = 64 if not run_locally else 1
f1_k = 3 if not run_locally else 1
#   - visualizations / checkpoints steps
display_step = 200
checkpoint_step = 600
metrics_step = 1800  # evaluate model every 3 checkpoints
#   - dataset
target_shape = 64
target_channels = 3
#   - PixelDTGAN configuration
pxldtg_config_id = 'default'  # as proposed in the original paper

###################################
###   Dataset Initialization    ###
###################################
#   - image transforms:
#     If target_shape is different from the loaded one, resize & crop. If target_channels is different from the
#     loaded one, convert to grayscale.
#     Update: Now done automatically if you set target_channels, target_shape when instantiating the dataloader.
gen_transforms = PixelDTDataset.get_image_transforms(target_shape=target_shape, target_channels=target_channels)
#   - the dataloader used to access the training dataset of cross-scale/pose image pairs at every epoch
#     > len(dataloader) = <number of batches>
#     > len(dataloader.dataset) = <number of total dataset items>
dataloader = PixelDTDataloader(dataset_fs_folder_or_root=datasets_groot, batch_size=batch_size,
                               image_transforms=gen_transforms, splits=train_test_splits,
                               pin_memory=not run_locally, log_level=log_level)
dataset = dataloader.dataset  # save training dataset as `dataset`
#   - ensure dataset is fetched locally and unzipped
if isinstance(dataset, FilesystemDataset):
    dataset.fetch_and_unzip(in_parallel=False, show_progress=True)
elif hasattr(dataset, 'dataset') and isinstance(dataset.dataset, FilesystemDataset):
    dataset.dataset.fetch_and_unzip(in_parallel=False, show_progress=True)
else:
    raise TypeError('dataset must implement utils.ifaces.FilesystemDataset in order to be auto-downloaded and unzipped')
#   - apply rudimentary tests
assert issubclass(dataloader.__class__, DataLoader)
assert len(dataloader) == len(dataset) // batch_size + (1 if len(dataset) % batch_size else 0)
_img_s, _img_t = next(iter(dataloader))
assert tuple(_img_s.shape) == (batch_size, target_channels, target_shape, target_shape)
assert tuple(_img_t.shape) == (batch_size, target_channels, target_shape, target_shape)

###################################
###    Models Initialization    ###
###################################
#   - initialize evaluator instance (used to run GAN evaluation metrics: FID, IS, PRECISION, RECALL, F1 and SSIM)
evaluator = GanEvaluator(model_fs_folder_or_root=models_groot, gen_dataset=dataset, target_index=1, device=exec_device,
                         condition_indices=(0, ), n_samples=metrics_n_samples, batch_size=metrics_batch_size,
                         f1_k=f1_k)
#   - initialize model
try:
    # `args` is created by the %run of src/train_setup.py; without it, start from scratch
    chkpt_step = args.chkpt_step
    if chkpt_step == 'latest':
        pxldtg_chkpt_step = chkpt_step
    elif isinstance(chkpt_step, str) and chkpt_step.isdigit():
        pxldtg_chkpt_step = int(chkpt_step)
    else:
        pxldtg_chkpt_step = None
except NameError:
    pxldtg_chkpt_step = None
pxldtg = PixelDTGan(model_fs_folder_or_root=models_groot, config_id=pxldtg_config_id, dataset_len=len(dataset),
                    chkpt_epoch=pxldtg_chkpt_step, evaluator=evaluator, device=exec_device, log_level=log_level)
pxldtg.logger.debug(f'Using device: {str(exec_device)}')
pxldtg.logger.debug(f'Model initialized. Number of params = {pxldtg.nparams_hr}')
# FIX: Warmup counters before first batch
if pxldtg.step is None:
    pxldtg.gforward(batch_size=batch_size)
    pxldtg.logger.debug(f'Model warmed-up (internal counters).')
#   - setup multi-GPU training
if torch.cuda.device_count() > 1:
    pxldtg.gen = DataParallel(pxldtg.gen)
    pxldtg.logger.info(f'Using {torch.cuda.device_count()} GPUs for PixelDTGAN Generator (via torch.nn.DataParallel)')
#   - load dataloader state (from model checkpoint)
if 'dataloader' in pxldtg.other_state_dicts.keys():
    dataloader.set_state(pxldtg.other_state_dicts['dataloader'])
    pxldtg.logger.debug(f'Loaded dataloader state! Current perm_index={dataloader.get_state()["perm_index"]}')
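The `display_step`, `checkpoint_step` and `metrics_step` values set above interact through the `if` / `elif` branching of the training loop. The following pure-function restatement of that branching is written only for illustration. `actions_for_step` is hypothetical.

```python
from typing import List


def actions_for_step(step: int, display_step: int = 200, checkpoint_step: int = 600,
                     metrics_step: int = 1800) -> List[str]:
    """Return which periodic actions the training loop triggers at `step`."""
    if step % checkpoint_step == 0:
        # A checkpoint capture includes visualizations; metrics are added only on metrics steps
        actions = ['checkpoint', 'visualizations']
        if step % metrics_step == 0:
            actions.append('metrics')
        return actions
    if step % display_step == 0:
        return ['display']  # the elif branch: plain visualization only
    return []
```

With the values above, metrics are computed every third checkpoint (steps 1800, 3600, ...). Plain visualizations are shown at the remaining multiples of 200 (200, 400, 800, ...).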


In [None]:
# FIX: client_secrets.json invalid error
import json
with open(client_secrets_path) as json_fp:
    client_secrets = json.load(json_fp)
if 'web' not in client_secrets.keys():
    client_secrets = {'web': client_secrets}
    with open(client_secrets_path, 'w') as json_fp:
        json.dump(client_secrets, json_fp, indent=4)
    with open(client_secrets_path) as json_fp:
        print(json.dumps(json.load(json_fp), indent=4))
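This fix is presumably needed because the OAuth client-secrets loader used by Google's Python clients (e.g. oauth2client, on which PyDrive builds) expects the JSON to have a single top-level key naming the client type, `web` or `installed`. The following minimal pure-function version of the same check assumes that format. `normalize_client_secrets` is hypothetical. Unlike the cell above, it leaves an `installed`-type file untouched.

```python
from typing import Any, Dict


def normalize_client_secrets(secrets: Dict[str, Any]) -> Dict[str, Any]:
    """Wrap bare OAuth client fields under 'web' unless a client-type key is already present."""
    if 'web' in secrets or 'installed' in secrets:
        return secrets  # already in the {'<client type>': {...}} shape
    return {'web': secrets}
```

Keeping the transformation pure means the cell only has to read the file, call this function, and write the result back.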

### PixelDTGAN Main training loop
Start / continue training PixelDTGAN until reaching the desired number of epochs.

In [None]:

###################################
###       Training Loop         ###
###################################
#   - get the correct tqdm instance
exec_tqdm = get_tqdm()
#   - start training loop from last checkpoint's epoch and step
gcapture_ready = True
async_results = None
pxldtg.logger.info(f'[training loop] STARTING (epoch={pxldtg.epoch}, step={pxldtg.initial_step})')
for epoch in range(pxldtg.epoch, n_epochs):
    # noinspection PyProtectedMember
    d = {
        'step': pxldtg.step,
        'initial_step': pxldtg.initial_step,
        'epoch': pxldtg.epoch,
        '_counter': pxldtg._counter,
        'epoch_inc': pxldtg.epoch_inc,
    }
    # initial_step = pxldtg.initial_step % len(dataloader)
    pxldtg.logger.debug('[START OF EPOCH] ' + str(d))

    img_s: Tensor
    img_t: Tensor
    for img_s, img_t in exec_tqdm(dataloader, initial=pxldtg.initial_step):
        # Transfer image batches to GPU
        img_s = img_s.to(exec_device)
        img_t = img_t.to(exec_device)

        # Perform a forward + backward pass + weight update on the Generator & Discriminator models
        disc_r_loss, disc_a_loss, gen_loss = pxldtg(img_s, img_t)

        # Metrics & Checkpoint Code
        if pxldtg.step % checkpoint_step == 0:
            # Check if another upload is pending
            if not gcapture_ready and async_results:
                # Wait for previous upload to finish
                pxldtg.logger.warning('Waiting for previous gcapture() to finish...')
                for r in async_results:
                    r.wait()
                pxldtg.logger.warning('DONE! Starting new capture now.')
            # Capture current model state, including metrics and visualizations
            async_results = pxldtg.gcapture(checkpoint=True, metrics=pxldtg.step % metrics_step == 0,
                                            visualizations=True, dataloader=dataloader, in_parallel=True,
                                            show_progress=True, delete_after=False)
        # Visualization code
        elif pxldtg.step % display_step == 0:
            visualization_img = pxldtg.visualize()
            if in_notebook():
                display(visualization_img)
            else:
                visualization_img.show()

        # Check if a pending checkpoint upload has finished
        if async_results:
            gcapture_ready = all([r.ready() for r in async_results])
            if gcapture_ready:
                pxldtg.logger.info(f'gcapture() finished')
                if pxldtg.latest_checkpoint_had_metrics:
                    pxldtg.logger.info(str(pxldtg.latest_metrics))
                async_results = None

        # If run locally one pass is enough
        if run_locally and gcapture_ready:
            break

    # If run locally one pass is enough
    if run_locally:
        break

    # noinspection PyProtectedMember
    d = {
        'step': pxldtg.step,
        'initial_step': pxldtg.initial_step,
        'epoch': pxldtg.epoch,
        '_counter': pxldtg._counter,
        'epoch_inc': pxldtg.epoch_inc,
    }
    pxldtg.logger.debug('[END OF EPOCH] ' + str(d))

# Check if a pending checkpoint exists
if async_results:
    for r in async_results:
        r.wait()
    pxldtg.logger.info(f'last gcapture() finished')
    if pxldtg.latest_checkpoint_had_metrics:
        pxldtg.logger.info(str(pxldtg.latest_metrics))
    async_results = None

# Training finished!
pxldtg.logger.info('[training loop] DONE')
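The loop manages `async_results` with non-blocking `ready()` checks on every step and a blocking `wait()` only when a new capture is due. This usage matches the interface of `multiprocessing.pool.AsyncResult`. Whether `gcapture(in_parallel=True)` returns exactly such objects is an assumption. The following sketch only illustrates the pattern with a thread pool and a hypothetical `fake_upload` function.

```python
import time
from multiprocessing.pool import ThreadPool


def fake_upload(seconds: float) -> str:
    """Hypothetical stand-in for a slow checkpoint upload."""
    time.sleep(seconds)
    return 'uploaded'


with ThreadPool(processes=2) as pool:
    async_results = [pool.apply_async(fake_upload, (0.05,)) for _ in range(2)]
    steps_while_uploading = 0
    # Non-blocking check, like `gcapture_ready = all(r.ready() ...)` in the loop
    while not all(r.ready() for r in async_results):
        steps_while_uploading += 1  # a real loop would run a training step here
        time.sleep(0.01)
    # Blocking wait, like the one before starting a new capture (returns at once here)
    for r in async_results:
        r.wait()
    results = [r.get() for r in async_results]
```

Training keeps advancing while uploads run, and the loop blocks only if a new checkpoint is due before the previous one has finished.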


# 3) Evaluate PixelDTGAN
In this section we evaluate the generation performance of our trained network using state-of-the-art GAN evaluation metrics.
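For reference, FID (the first metric listed for `GanEvaluator`) compares Gaussian fits of Inception-v3 features of real ($r$) and generated ($g$) images, with feature means $\mu$ and covariances $\Sigma$. Lower is better. Presumably, this is what the InceptionV3 checkpoint copied in section 1.1 is used for.

```latex
\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2 + \operatorname{Tr}\!\left(\Sigma_r + \Sigma_g - 2\left(\Sigma_r \Sigma_g\right)^{1/2}\right)
```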

## 3.1) Get the metrics evolution plots
We plot how the metrics evolved during training. The GAN is **not** trained to minimize those metrics (they are
calculated using `torch.no_grad()`). Their evolution is therefore driven only by the training losses, and it showcases how
the GAN evaluation metrics correlate with the losses (e.g. adversarial & reconstruction) used to optimize the network.

In [None]:
# Since the PixelDTGAN implements utils.ifaces.Visualizable, we can
# directly call visualize_metrics() on the model instance.
_ = pxldtg.visualize_metrics(upload=True, preview=True)
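A toy example can check the claim that the metrics do not influence training. A value computed inside `torch.no_grad()` carries no autograd graph, so it cannot contribute to any parameter update; only the losses can. All tensors here are hypothetical stand-ins.

```python
import torch

weights = torch.randn(4, requires_grad=True)  # stand-in for generator parameters
fake = weights * 2.0                          # stand-in for generated samples

with torch.no_grad():
    metric = fake.pow(2).mean()  # stand-in for an evaluation metric: no graph is recorded

loss = (fake - 1.0).abs().mean()  # stand-in for a training loss (e.g. reconstruction)
loss.backward()                   # gradients reach `weights` only through the loss
```

Here `metric.requires_grad` is `False` and `metric.grad_fn` is `None`, while `weights.grad` is populated by the loss alone.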

## 3.2) Evaluate Generated Samples
We evaluate the generated samples in order to compare the model with other GAN architectures trained on the same dataset.
For this purpose, we re-calculate the evaluation metrics stated above, but with a much larger number of samples (10000 instead of 1000).
In this way, the metrics are more trustworthy and comparable with the corresponding metrics in the original paper.

In [None]:
# Initialize a new evaluator instance
# (used to run GAN evaluation metrics: FID, IS, PRECISION, RECALL, F1 and SSIM)
evaluator = GanEvaluator(model_fs_folder_or_root=models_groot, gen_dataset=dataset, target_index=1, device=exec_device,
                         condition_indices=(0, ), n_samples=10000, batch_size=metrics_batch_size,
                         f1_k=f1_k, ssim_c_img=target_channels)
# Run the evaluator
metrics_dict = evaluator.evaluate(gen=pxldtg.gen, metric_name='all', show_progress=True)

# Print results
import json
print(json.dumps(metrics_dict, indent=4))
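A generic, stdlib-only simulation illustrates why more samples make the numbers more trustworthy. The spread of an estimate built from n random samples shrinks roughly like 1/sqrt(n). Going from 1000 to 10000 samples (the values used in sections 2 and 3.2) therefore narrows it by about sqrt(10). This toy uses a sample mean, not an actual GAN metric. Metrics such as FID are also known to depend on the sample size itself, which is another reason to compare models only at equal n. `spread_of_sample_mean` is a hypothetical helper.

```python
import random
import statistics


def spread_of_sample_mean(n_samples: int, n_repeats: int = 100, seed: int = 0) -> float:
    """Std-dev of a sample-mean estimate over repeated draws of n_samples uniform values."""
    rng = random.Random(seed)
    estimates = [statistics.fmean([rng.random() for _ in range(n_samples)]) for _ in range(n_repeats)]
    return statistics.stdev(estimates)


small, large = spread_of_sample_mean(1_000), spread_of_sample_mean(10_000)
ratio = small / large  # expected to be roughly sqrt(10), i.e. about 3.16
```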