# 1) Set paths, check data, clone repo, install packages

## 1.1) Define paths
Define the paths below once the dataset and the assets (e.g. git keys) have been uploaded successfully.

In [None]:
import os

# Define dataset paths
df_icrb_root = '/kaggle/input/deepfashion-icrb'
df_icrb_img_root = f'{df_icrb_root}/Img'
assert os.path.exists(df_icrb_img_root), f'df_icrb_img_root={df_icrb_img_root}: NOT FOUND'

# Define asset paths
git_keys_root = '/kaggle/input/git-keys/github-keys'
assert os.path.exists(git_keys_root), f'git_keys_root={git_keys_root}: NOT FOUND'
client_secrets_path = '/kaggle/input/git-keys/client_secrets.json'
assert os.path.exists(client_secrets_path), f'client_secrets_path={client_secrets_path}: NOT FOUND'
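
The asserts above stop at the first missing path. As a minimal sketch (the helper name `find_missing_paths` is hypothetical and not part of the repository), all paths could be checked in one pass so that every missing upload is reported at once:

```python
import os
import tempfile


def find_missing_paths(named_paths):
    """Return the names of all entries whose path does not exist."""
    return [name for name, path in named_paths.items() if not os.path.exists(path)]


# Demo on a temporary directory so the sketch runs anywhere;
# in the notebook the dict would hold df_icrb_img_root, git_keys_root, etc.
with tempfile.TemporaryDirectory() as tmp:
    demo_paths = {
        'existing_dir': tmp,
        'missing_file': os.path.join(tmp, 'client_secrets.json'),
    }
    missing = find_missing_paths(demo_paths)

print(missing)
```

An empty list means every path was found; otherwise the list names exactly what still has to be uploaded.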

Create the root Google Drive directory. This is where all model checkpoints and metrics live, along with Datasets, Fonts, etc. The dataset's Img folder is symlinked into it to avoid code changes and to keep the notebook interoperable with Google Colab.

In [None]:
# Create root directory if not exists
gdrive_root = '/kaggle/working/GoogleDrive'
!mkdir -p "$gdrive_root"

# Create the dataset link inside Google Drive (-sfn replaces an existing link, so the cell can be re-run)
gdrive_icrb_root = f'{gdrive_root}/Datasets/DeepFashion/In-shop Clothes Retrieval Benchmark'
!mkdir -p "$gdrive_icrb_root"
!ln -sfn /kaggle/input/deepfashion-icrb/Img "$gdrive_icrb_root"/Img

# Copy the Fonts dir inside local Google Drive root
!cp -r /kaggle/input/mplfonts/Fonts "$gdrive_root"

# Copy the Inceptionv3 model checkpoint into the local Google Drive root and restore its .pth file name
!mkdir -p "$gdrive_root"/Models
!cp -r "/kaggle/input/inception-model/model_name=inceptionv3" "$gdrive_root"/Models
!mv "$gdrive_root"/Models/model_name=inceptionv3/Checkpoints/1a9a5a14.pth.bak "$gdrive_root"/Models/model_name=inceptionv3/Checkpoints/1a9a5a14.pth

# Also create an empty Img.zip file, so that the GDriveDataset instance treats the dataset as already
# downloaded and unzipped
!touch "$gdrive_icrb_root"/Img.zip

# FIX: We need client_secrets.json to be writable, so copy to /kaggle/working
!cp "$client_secrets_path" "$gdrive_root"
client_secrets_path = f'{gdrive_root}/client_secrets.json'
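
Re-running a cell that creates a symlink can fail if the link already exists. The following is a rough pure-Python sketch of an idempotent alternative; the helper `ensure_symlink` is an illustrative name, not something the repository provides, and the demo uses a temporary directory instead of the real Kaggle paths:

```python
import os
import tempfile


def ensure_symlink(target, link_path):
    """Create (or refresh) a symlink at link_path that points to target."""
    os.makedirs(os.path.dirname(link_path), exist_ok=True)
    if os.path.islink(link_path):
        os.remove(link_path)  # drop the old link so that re-runs do not fail
    os.symlink(target, link_path)
    return link_path


with tempfile.TemporaryDirectory() as tmp:
    # Stand-ins for the read-only input folder and the local Google Drive root
    img_dir = os.path.join(tmp, 'input', 'Img')
    os.makedirs(img_dir)
    link = os.path.join(tmp, 'GoogleDrive', 'Datasets', 'Img')

    ensure_symlink(img_dir, link)
    ensure_symlink(img_dir, link)  # a second call simply refreshes the link
    resolved_ok = os.path.realpath(link) == os.path.realpath(img_dir)

print(resolved_ok)
```

If `link_path` already exists as a regular file or directory, `os.symlink` raises `FileExistsError`, which is deliberate here: the sketch never deletes real data.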

## 1.2) Clone GitHub repo
Clone the achariso/gans-thesis repo into /kaggle/working/code using `git clone` over SSH. For a similar procedure in Colab,
see: https://medium.com/@purba0101/how-to-clone-private-github-repo-in-google-colab-using-ssh-77384cfef18f

In [None]:
# Clean failed attempts
!rm -rf /root/.ssh
!rm -rf /kaggle/working/code
!mkdir -p /kaggle/working/code

repo_root = '/kaggle/working/code/gans-thesis'
if not os.path.exists(repo_root):
    # Check that ssh keys exist
    id_rsa_abs_drive = f'{git_keys_root}/id_rsa'
    id_rsa_pub_abs_drive = f'{id_rsa_abs_drive}.pub'
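    assert os.path.exists(id_rsa_abs_drive), f'id_rsa_abs_drive={id_rsa_abs_drive}: NOT FOUND'
    assert os.path.exists(id_rsa_pub_abs_drive), f'id_rsa_pub_abs_drive={id_rsa_pub_abs_drive}: NOT FOUND'
    # On first run: set up SSH (config file and known hosts)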
    if not os.path.exists('/root/.ssh'):
        # Transfer config file
        ssh_config_abs_drive = f'{git_keys_root}/config'
        assert os.path.exists(ssh_config_abs_drive)
        !mkdir -p ~/.ssh
        !cp -f "$ssh_config_abs_drive" ~/.ssh/
        # Add github.com to known hosts
        !ssh-keyscan -t rsa github.com >> ~/.ssh/known_hosts
        # Test ssh connection
        # !ssh -T git@github.com

    # Remove any previous attempts
    !rm -rf "$repo_root"
    !mkdir -p "$repo_root"
    # Clone repo
    !git clone git@github.com:achariso/gans-thesis.git "$repo_root"
    src_root = f'{repo_root}/src'
    !rm -rf "$repo_root"/report
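
The cell above copies only the SSH `config` file; whether the private key also has to be copied depends on the `IdentityFile` entry of that config, which is not shown here. If the key does need to live under `~/.ssh`, note that OpenSSH refuses private keys that other users can read. A minimal sketch under that assumption (the helper `install_private_key` is hypothetical, and the demo writes a fake key into a temporary directory):

```python
import os
import shutil
import stat
import tempfile


def install_private_key(src_key, ssh_dir):
    """Copy a private key into ssh_dir with owner-only permissions."""
    os.makedirs(ssh_dir, exist_ok=True)
    os.chmod(ssh_dir, 0o700)   # directory accessible by the owner only
    dst_key = os.path.join(ssh_dir, os.path.basename(src_key))
    shutil.copyfile(src_key, dst_key)
    os.chmod(dst_key, 0o600)   # key readable and writable by the owner only
    return dst_key


with tempfile.TemporaryDirectory() as tmp:
    fake_key = os.path.join(tmp, 'id_rsa')
    with open(fake_key, 'w') as f:
        f.write('not a real key\n')

    installed = install_private_key(fake_key, os.path.join(tmp, 'dot_ssh'))
    mode = stat.S_IMODE(os.stat(installed).st_mode)

print(oct(mode))
```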

## 1.3) Install pip packages
All required packages are listed in a requirements.txt file at the repository's root.
Run `pip install -r requirements.txt` from inside that directory to install them.

In [None]:
%cd $repo_root
!pip install -r requirements.txt

## 1.4) Update path to include src dir
This is necessary for the repository's modules to be importable and to function correctly.

In [None]:
content_root_abs = f'{repo_root}'
src_root_abs = f'{repo_root}/src'
# Note: %env stores the value verbatim, so it must not be wrapped in quotes
%env PYTHONPATH=/kaggle/lib/kagglegym:/kaggle/lib:$content_root_abs:$src_root_abs
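
`PYTHONPATH` is read when a Python interpreter starts, so setting it from a running notebook affects processes launched afterwards but not the import path of the kernel itself. If imports are also needed inside the kernel, `sys.path` can be extended directly. A small sketch, where `add_to_sys_path` is an illustrative helper and the two demo paths merely mirror `content_root_abs` and `src_root_abs`:

```python
import sys


def add_to_sys_path(*paths):
    """Prepend each path to sys.path once, keeping the given order."""
    for path in reversed(paths):
        if path not in sys.path:
            sys.path.insert(0, path)


# Hypothetical values mirroring content_root_abs and src_root_abs
demo_root = '/kaggle/working/code/gans-thesis'
demo_src = f'{demo_root}/src'

add_to_sys_path(demo_root, demo_src)
add_to_sys_path(demo_root, demo_src)  # repeated calls add nothing new

print(sys.path[:2])
```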

# 2) Train PGPG model on DeepFashion
In this section we run the actual training loop for the PGPG network. PGPG consists of a two-stage generator, where each stage is a U-Net-like model, and, in our version, a PatchGAN discriminator.
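
A PatchGAN discriminator outputs a grid of real/fake scores, each of which judges one patch of the input rather than the whole image. The patch size equals the receptive field of the convolution stack. The sketch below computes it for the layer layout commonly used for the 70×70 PatchGAN of pix2pix (three 4×4 stride-2 convolutions followed by two 4×4 stride-1 convolutions); this layout is an assumption for illustration, and the discriminator in this repository may be configured differently.

```python
def receptive_field(layers):
    """Receptive field of stacked conv layers given (kernel, stride) pairs, listed input to output."""
    rf = 1
    # Walk backwards from one output unit to the input pixels it depends on
    for kernel, stride in reversed(layers):
        rf = (rf - 1) * stride + kernel
    return rf


# Assumed layout: the 70x70 PatchGAN configuration from pix2pix
pix2pix_patchgan = [(4, 2), (4, 2), (4, 2), (4, 1), (4, 1)]

print(receptive_field(pix2pix_patchgan))  # 70
```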


### Colab Bug Workaround
Bug: matplotlib cache not rebuilding.
Solution: Run the workaround code and then restart the kernel. The code is now included inside `src/train_pgpg.py`, so it is not repeated here.


### Actual Run
Finally, run the code!

In [None]:
chkpt_step = None       # supported: 'latest', <int>, None
log_level = 'debug'     # supported: 'debug', 'info', 'warning', 'error', 'critical', 'fatal'

# Running with -i enables us to get variables defined inside the script (the script runs inline)
%run -i src/train_pgpg.py
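
Since a mistyped option would only surface once the script is already running, the two settings could be validated up front. The following is a sketch based solely on the supported values listed in the comments above; `validate_run_config` is a hypothetical helper, and rejecting booleans as checkpoint steps is an assumption of this sketch.

```python
import logging

SUPPORTED_LOG_LEVELS = ('debug', 'info', 'warning', 'error', 'critical', 'fatal')


def validate_run_config(chkpt_step, log_level):
    """Check both settings and return the numeric logging level."""
    step_ok = (
        chkpt_step is None
        or chkpt_step == 'latest'
        # bool is a subclass of int; treat True/False as a mistake (assumption)
        or (isinstance(chkpt_step, int) and not isinstance(chkpt_step, bool))
    )
    if not step_ok:
        raise ValueError(f"chkpt_step={chkpt_step!r}: expected 'latest', an int or None")
    if log_level not in SUPPORTED_LOG_LEVELS:
        raise ValueError(f'log_level={log_level!r}: expected one of {SUPPORTED_LOG_LEVELS}')
    # The logging module exposes each level as an upper-case constant
    return getattr(logging, log_level.upper())


print(validate_run_config(None, 'debug'))  # 10
```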

# 3) Evaluate PGPG
In this section we evaluate the generation performance of our trained network using state-of-the-art GAN evaluation metrics.

## 3.1) Get the metrics evolution plots
We plot how the metrics evolved during training. The GAN is **not** trained to minimize these metrics (they are
calculated under `torch.no_grad()`), so their evolution depends only on the network itself and shows how the
GAN evaluation metrics correlate with the losses (e.g. adversarial & reconstruction) used to optimize the network.

In [None]:
# Since PGPG implements utils.ifaces.Visualizable, we can
# directly call visualize_metrics() on the model instance.
_ = pgpg.visualize_metrics(upload=True, preview=True)

## 3.2) Evaluate Generated Samples
Here we evaluate the generated samples, so that the model can be compared with other GAN architectures trained on the same dataset. For this purpose we re-calculate the evaluation metrics described above, but with a much larger number of samples. This makes the metrics more trustworthy and comparable with the corresponding metrics in the original paper.
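
The reason a larger sample count helps can be illustrated with a generic statistical sketch: the spread of an averaged estimate shrinks roughly with the square root of the number of samples. The simulation below uses synthetic Gaussian values only; it does not model FID, IS or any other GAN metric specifically, and the function name and sample sizes are arbitrary choices for illustration.

```python
import random
import statistics


def spread_of_sample_mean(n_samples, n_trials=100, seed=0):
    """Empirical standard deviation of the sample mean over repeated draws."""
    rng = random.Random(seed)
    means = [
        statistics.fmean(rng.gauss(0.0, 1.0) for _ in range(n_samples))
        for _ in range(n_trials)
    ]
    return statistics.stdev(means)


se_small = spread_of_sample_mean(100)
se_large = spread_of_sample_mean(2500)

# 25x more samples should cut the spread by roughly a factor of 5
print(f'n=100:  {se_small:.4f}')
print(f'n=2500: {se_large:.4f}')
```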


In [None]:
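# Initialize a new evaluator instance
# (used to run GAN evaluation metrics: FID, IS, PRECISION, RECALL, F1 and SSIM)
# GanEvaluator, models_groot, dataset, exec_device, metrics_batch_size and f1_k are expected to be
# in the namespace already, since src/train_pgpg.py was run with -i above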
evaluator = GanEvaluator(model_fs_folder_or_root=models_groot, gen_dataset=dataset, target_index=1, device=exec_device,
                         condition_indices=(0, 2), n_samples=10000, batch_size=metrics_batch_size,
                         f1_k=f1_k)
# Run the evaluator
metrics_dict = evaluator.evaluate(gen=pgpg.gen, metric_name='all', show_progress=True)

# Print results
import json
print(json.dumps(metrics_dict, indent=4))
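
Among the reported metrics, F1 summarizes precision and recall in a single number. Assuming the standard definition (the harmonic mean of the two), which may differ in detail from what `GanEvaluator` computes internally, the relation can be sketched as:

```python
def f1_score(precision, recall):
    """Harmonic mean of precision and recall (standard F1 definition)."""
    if precision + recall == 0:
        return 0.0  # avoid division by zero when both are zero
    return 2 * precision * recall / (precision + recall)


print(f1_score(0.5, 0.5))  # 0.5
print(f1_score(1.0, 0.0))  # 0.0
```

Because the harmonic mean is dominated by the smaller value, a generator scores well on F1 only if its samples are both realistic (precision) and diverse (recall).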