# 1) Mount drive, unzip data, clone repo, install packages

## 1.1) Mount Drive and define paths
Run the Colab-provided code to mount Google Drive, then define dataset paths relative to the mount point.

In [None]:
!rm -rf '/content/sample_data'
!rm -rf '/content/*.jpg'
!rm -rf '/content/*.png'
!rm -rf '/content/*.json'

In [None]:
# noinspection PyUnresolvedReferences,PyPackageRequirements
from google.colab import drive

# Where Google Drive gets mounted, and the "MyDrive" folder underneath it that
# later cells use as the base for Drive-relative paths.
mount_root_abs = '/content/drive'
drive_root = mount_root_abs + '/MyDrive'

# Perform the actual mount
drive.mount(mount_root_abs)

## 1.2) Clone GitHub repo
Clone the achariso/gans-thesis repo into /content/code/gans-thesis using `git clone` over SSH.
For more info see: https://medium.com/@purba0101/how-to-clone-private-github-repo-in-google-colab-using-ssh-77384cfef18f

In [None]:
import os

repo_root = '/content/code/gans-thesis'
!rm -rf "$repo_root"
if not os.path.exists(repo_root) and not os.path.exists(f'{repo_root}/requirements.txt'):
    # Check that ssh keys exist
    assert os.path.exists(f'{drive_root}/GitHub Keys')
    id_rsa_abs_drive = f'{drive_root}/GitHub Keys/id_rsa'
    id_rsa_pub_abs_drive = f'{id_rsa_abs_drive}.pub'
    assert os.path.exists(id_rsa_abs_drive)
    assert os.path.exists(id_rsa_pub_abs_drive)
    # On first run: Add ssh key in repo
    if not os.path.exists('/root/.ssh'):
        # Transfer config file
        ssh_config_abs_drive = f'{drive_root}/GitHub Keys/config'
        assert os.path.exists(ssh_config_abs_drive)
        !mkdir -p ~/.ssh
        !cp -f "$ssh_config_abs_drive" ~/.ssh/
        # # Add github.com to known hosts
        !ssh-keyscan -t rsa github.com >> ~/.ssh/known_hosts
        # Test: !ssh -T git@github.com

    # Remove any previous attempts
    !rm -rf "$repo_root"
    !mkdir -p "$repo_root"
    # Clone repo
    !git clone git@github.com:achariso/gans-thesis.git "$repo_root"
    src_root = f'{repo_root}/src'
    !rm -rf "$repo_root"/report

## 1.3) Install pip packages
All required packages are listed in a requirements.txt file at the repository's root.
Use `pip install -r requirements.txt` from inside the dir to install required packages.

In [None]:
# Make the repo root the notebook's working directory; the relative paths used from here on
# (requirements.txt below, src/... in later cells) are resolved against it.
%cd "$repo_root"
# Install the packages listed in the repo's requirements.txt
!pip install -r requirements.txt

In [None]:
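# Uncomment to kill the current kernel process with signal 9
# (presumably kept here to force a runtime restart after installing packages - confirm):
# import os
# os.kill(os.getpid(), 9)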

In [None]:
import torch

## 1.4) Add code/, */src/ to path
This is necessary in order to be able to run the modules.

In [None]:
content_root_abs = f'{repo_root}'
src_root_abs = f'{repo_root}/src'
%env PYTHONPATH="/env/python:$content_root_abs:$src_root_abs"

# 2) Update model checkpoints
Initially, define the model name. Then define a function to handle each checkpoint. Then we'll see...

### Run the setup script for basic setup.
This sets up access to Google Drive via custom accessors, the tqdm instance, etc.

In [None]:
chkpt_step = None      # supported: 'latest', <int>, None
log_level = 'debug'    # supported: 'debug', 'info', 'warning', 'error', 'critical', 'fatal'
device = 'cpu'         # supported: 'cpu', 'cuda', 'cuda:<GPU_INDEX>'

# Running with -i enables us to get variables defined inside the script (the script runs inline)
%run -i src/train_setup.py --log_level $log_level --chkpt_step $chkpt_step --seed 42 --device $device

In [None]:
%cd 'src/'

import torch
from IPython.core.display import display
from torch import Tensor
# noinspection PyProtectedMember
from torch.utils.data import DataLoader
from utils.dep_free import get_tqdm
from utils.ifaces import FilesystemDataset


In [None]:
from modules.pgpg import PGPG
from modules.pixel_dt_gan import PixelDTGan
from modules.cycle_gan import CycleGAN
from modules.stylegan import StyleGan


def _visualize_model(model_cls, tag, config_id, loss_keys):
    """
    Instantiate a model and render its loss and metric visualizations.

    Relies on the notebook-level names `models_groot`, `exec_device` and `log_level`
    (the first two are presumably defined by train_setup.py via `%run -i` - confirm).

    :param model_cls: the model class to instantiate
    :param tag: name placed inside the square brackets of the debug message
    :param config_id: configuration identifier passed to the model's constructor
    :param loss_keys: value forwarded as `dict_keys` to `visualize_losses()`
    :return: the instantiated model
    """
    model = model_cls(model_fs_folder_or_root=models_groot, config_id=config_id,
                      device=exec_device, log_level=log_level)
    model.logger.debug(f'[{tag}] Starting visualization in device: {str(exec_device)}')
    # Losses first, then metrics; both are uploaded and previewed
    model.visualize_losses(dict_keys=loss_keys, upload=True, preview=True, extract_dicts=True)
    model.visualize_metrics(upload=True, preview=True)
    return model


# Loss keys shared by every model that has a single generator / discriminator pair
_gen_disc_keys = ('gen', 'disc', ('gen', 'disc'))

# PGPG
pgpg_config_id = '128_MSE_256_6_4_5_none_none_1e4_true_false_false'
pgpg = _visualize_model(PGPG, 'PGPG', pgpg_config_id, _gen_disc_keys)

# PixelDTGAN
# NOTE(review): the original comment here read "as proposed in DiscoGAN paper", which looks like
#               it belongs to the CycleGAN config below - confirm.
pxldtgan_config_id = 'default'
pxldtgan = _visualize_model(
    PixelDTGan, 'PixelDTGan', pxldtgan_config_id,
    ('gen', 'disc_r', 'disc_a', ('gen', 'disc_r'), ('gen', 'disc_a'), ('disc_r', 'disc_a')),
)

# CycleGAN
ccgan_config_id = 'discogan'
ccgan = _visualize_model(CycleGAN, 'CycleGAN', ccgan_config_id, _gen_disc_keys)

# StyleGAN
stgan_config_id = 'default_z512'
stgan = _visualize_model(StyleGan, 'StyleGan', stgan_config_id, _gen_disc_keys)