# 1) Mount drive, unzip data, clone repo, install packages

## 1.1) Mount Drive and define paths
Run the Colab-provided code to mount Google Drive, then define the dataset paths relative to the mount point.

In [None]:
# noinspection PyUnresolvedReferences,PyPackageRequirements
from google.colab import drive

# Mount point of Google Drive inside the Colab VM (force a fresh mount on every run of this cell)
mount_root_abs = '/content/drive'
drive.mount(mount_root_abs, force_remount=True)

# Project storage folder inside "My Drive" (holds the datasets, the GitHub keys, ...)
drive_root = mount_root_abs + '/MyDrive/ProjectGStorage'

## 1.2) Unzip Img directory in Colab
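Unzipping `lin-48x80.zip` onto the Colab VM's local disk before running our model gives significant disk-read speedups compared to reading the images from the mounted Drive.
So, the first step is to download the archive (if needed), unzip the image directories locally, and link them back into the Drive dataset folder before proceeding.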

In [None]:
import os

# Check if LIN Dataset is present / Download Dataset
df_root_drive = f'{drive_root}/Datasets/LIN_48x80'
if not os.path.exists(f'{df_root_drive}/lin-48x80.zip'):
    !pip install kaggle --upgrade
    os.environ['KAGGLE_CONFIG_DIR'] = drive_root
    !mkdir -p $df_root_drive
    !kaggle datasets download "achariso/lin-48x80" -p "$df_root_drive"

# Unzip
if not os.path.exists(f"/content/data/LIN_48x80/LIN_Normalized_WT_size-48-80_train"):
    !pip install unzip
    !mkdir -p "/content/data/LIN_48x80"
    !cp -f "$df_root_drive/lin-48x80.zip" "/content/data/LIN_48x80"
    !unzip -qq "/content/data/LIN_48x80/lin-48x80.zip" -d "/content/data/LIN_48x80"
    if os.path.exists(f'/content/data/LIN_48x80/LIN_Normalized_WT_size-48-80_train'):
        # Create symbolic links
        !ln -s "/content/data/LIN_48x80/LIN_Normalized_WT_size-48-80_train" "$df_root_drive/LIN_Normalized_WT_size-48-80_train"
        !ln -s "/content/data/LIN_48x80/LIN_Normalized_WT_size-48-80_test" "$df_root_drive/LIN_Normalized_WT_size-48-80_test"
    else:
        print(f'Error: Dataset not found at "/content/data/LIN_48x80"')

## 1.3) Clone GitHub repo
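Clone the `kth-ml-course-projects/biogans` repo into `/content/code/biogans` using `git clone` over SSH (the SSH keys and config are read from the `GitHub Keys` folder in Drive).
For more info see: https://medium.com/@purba0101/how-to-clone-private-github-repo-in-google-colab-using-ssh-77384cfef18f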

In [None]:
repo_root = '/content/code/biogans'
!rm -rf "$repo_root"
if not os.path.exists(repo_root) and not os.path.exists(f'{repo_root}/requirements.txt'):
    # Check that ssh keys exist
    assert os.path.exists(f'{drive_root}/GitHub Keys')
    id_rsa_abs_drive = f'{drive_root}/GitHub Keys/id_rsa'
    id_rsa_pub_abs_drive = f'{id_rsa_abs_drive}.pub'
    assert os.path.exists(id_rsa_abs_drive)
    assert os.path.exists(id_rsa_pub_abs_drive)
    # On first run: Add ssh key in repo
    if not os.path.exists('/root/.ssh'):
        # Transfer config file
        ssh_config_abs_drive = f'{drive_root}/GitHub Keys/config'
        assert os.path.exists(ssh_config_abs_drive)
        !mkdir -p ~ /.ssh
        !cp -f "$ssh_config_abs_drive" ~ /.ssh /
        # # Add github.com to known hosts
        !ssh-keyscan -t rsa github.com >> ~ /.ssh / known_hosts
        # Test: !ssh -T git@github.com

    # Remove any previous attempts
    !rm -rf "$repo_root"
    !mkdir -p "$repo_root"
    # Clone repo
    !git clone git @ github.com:kth-ml-course-projects / biogans.git "$repo_root"
    src_root = f'{repo_root}/src'

## 1.4) Install pip packages
All required packages are listed in the `requirements.txt` file at the repository's root.
Run `pip install -r requirements.txt` from inside that directory to install them.

In [None]:
% cd "$repo_root"
!pip install -r requirements.txt

In [None]:
# If the pip install above upgraded packages that were already imported, restart the runtime before continuing,
# e.g. via the Colab menu or by running:  import os; os.kill(os.getpid(), 9)

import torch

# The rest of the notebook expects a GPU runtime
assert torch.cuda.is_available(), 'CUDA is not available: switch the Colab runtime type to GPU'

## 1.5) Add code/, */src/ to path
This is necessary in order to be able to run the modules.

In [None]:
content_root_abs = f'{repo_root}'
src_root_abs = f'{repo_root}/src'
% env PYTHONPATH="/env/python:$content_root_abs:$src_root_abs

# 2) Add C2ST Metric to Existing Ones
The method has already been implemented in `gdrive/__init__`; it only needs to be called.

### Actual Run
Finally, run the code!

In [None]:
chkpt_step = None  # supported: 'latest', <int>, None
log_level = 'debug'  # supported: 'debug', 'info', 'warning', 'error', 'critical', 'fatal'
device = 'cuda'  # supported: 'cpu', 'cuda', 'cuda:<GPU_INDEX>'
gdrive_which = 'personal'  # supported: 'personal', 'auth'

# Running with -i enables us to get variables defined inside the script (the script runs inline)
% run -i src / train_setup.py --log_level $log_level --chkpt_step $chkpt_step --seed 42 --device $device --gdrive_which $gdrive_which -use_refresh_token
% cd src /

### BioGAN Training
The code that follows defines the dataloader, the evaluator and the BioGAN model (optionally restored from a checkpoint); no training loop is run in this notebook.


In [None]:
import torch
from IPython.display import display
from torch import Tensor
from torch.nn import DataParallel
# noinspection PyProtectedMember
from torch.utils.data import DataLoader

from datasets.lin import LINDataloader6Class
from modules.biogan_ind import BioGanInd6Class
from utils.metrics import GanEvaluator6Class

# NOTE: `run_locally`, `datasets_groot`, `models_groot`, `exec_device` and `args` are not defined in this notebook;
#       they are expected to be created by `train_setup.py`, executed inline (`%run -i`) in an earlier cell.

###################################
###  Hyper-parameters settings  ###
###################################
#   - training
n_epochs = 540

batch_size = 32 if not run_locally else 2
train_test_splits = [90, 10]  # for a 90% training - 10% evaluation set split
#   - evaluation
metrics_n_samples = 1000 if not run_locally else 2
metrics_batch_size = 32 if not run_locally else 1
f1_k = 3 if not run_locally else 1
#   - visualizations / checkpoints steps
display_step = 150
checkpoint_step = 6 * display_step
metrics_step = 3 * checkpoint_step  # evaluate model every 3 checkpoints

###################################
###   Dataset Initialization    ###
###################################
#   - the dataloader used to access the training dataset of cross-scale/pose image pairs at every epoch
#     > len(dataloader) = <number of batches>
#     > len(dataloader.dataset) = <number of total dataset items>
# FIX: Add subfolders to GoogleDrive
datasets_groot.subfolder_by_name('LIN_48x80').subfolder_by_name_or_create('LIN_Normalized_WT_size-48-80_train')
datasets_groot.subfolder_by_name('LIN_48x80').subfolder_by_name_or_create('LIN_Normalized_WT_size-48-80_test')
dataloader = LINDataloader6Class(dataset_fs_folder_or_root=datasets_groot, train_not_test=True,
                                 batch_size=batch_size, pin_memory=not run_locally, shuffle=True)
dataset = dataloader.dataset
dataset.logger.debug('Transforms: ' + repr(dataset.transforms))
#   - apply rudimentary tests
assert issubclass(dataloader.__class__, DataLoader), 'dataloader must be a torch DataLoader'
assert len(dataloader) == len(dataset) // batch_size + (1 if len(dataset) % batch_size else 0), \
    'unexpected number of batches (expected ceil(len(dataset) / batch_size))'
_x = next(iter(dataloader))
assert tuple(_x.shape) == (6, batch_size, 2, 48, 80), f'unexpected batch shape: {tuple(_x.shape)}'

###################################
###    Models Initialization    ###
###################################
#   - initialize evaluator instance (used to run GAN evaluation metrics: FID, IS, PRECISION, RECALL, F1 and SSIM)
evaluator = GanEvaluator6Class(
    model_fs_folder_or_root=models_groot, gen_dataset=dataset, device=exec_device, z_dim=-1,
    n_samples=metrics_n_samples, batch_size=metrics_batch_size, f1_k=f1_k, ssim_c_img=2)
#   - initialize model
biogan_config = 'wgan-gp-independent-sep'  # other options: 'gan-independent-sep', 'default'
#   - checkpoint to restore: taken from the CLI args parsed by train_setup.py; `args` may be missing,
#     so the lookup itself must sit inside the NameError guard
try:
    chkpt_step = args.chkpt_step
except NameError:
    chkpt_step = None
if chkpt_step == 'latest':
    _chkpt_step = chkpt_step
elif isinstance(chkpt_step, int) or (isinstance(chkpt_step, str) and chkpt_step.isdigit()):
    _chkpt_step = int(chkpt_step)
else:
    # e.g. None or the string 'None' (what `--chkpt_step None` yields); presumably no checkpoint is loaded
    _chkpt_step = None
biogan = BioGanInd6Class(model_fs_folder_or_root=models_groot, config_id=biogan_config, dataset_len=len(dataset),
                         chkpt_epoch=_chkpt_step, evaluator=evaluator, device=exec_device, log_level=log_level,
                         gen_transforms=dataloader.transforms)
biogan.logger.debug(f'Using device: {str(exec_device)}')
biogan.logger.debug(f'Model initialized. Number of params = {biogan.nparams_hr}')
# # FIX: Warmup counters before first batch
# if biogan.step is None:
#     biogan.gforward(batch_size=batch_size)
#     biogan.logger.debug(f'Model warmed-up (internal counters).')
# #   - setup multi-GPU training
# if torch.cuda.device_count() > 1:
#     biogan.gen = DataParallel(biogan.gen, list(range(torch.cuda.device_count())))
#     biogan.logger.info(f'Using {torch.cuda.device_count()} GPUs for PGPG Generator (via torch.nn.DataParallel)')
# #   - load dataloader state (from model checkpoint)
# if 'dataloader' in biogan.other_state_dicts.keys():
#     dataloader.set_state(biogan.other_state_dicts['dataloader'])
#     biogan.logger.debug(f'Loaded dataloader state! Current perm_index={dataloader.get_state()["perm_index"]}')

### BioGAN Main Metrics Update call


In [None]:
import json

# Recompute the metrics selected by which='c2st' and show the returned metrics dict
md = biogan.update_all_metrics(which='c2st')
# default=str: values that are not JSON-serializable (e.g. numpy/torch scalars, if any) are printed via str()
# instead of raising a TypeError
print(json.dumps(md, indent=4, default=str))
biogan.visualize_metrics(upload=True, preview=True)