# 1) Mount drive, unzip data, clone repo, install packages

## 1.1) Define paths
Main Google Drive root: `/workspace/GoogleDrive`

In [None]:
import os
from ipywidgets import FileUpload

# Define Google Drive related paths
drive_root = "/workspace/GoogleDrive"
!mkdir -p "$drive_root"
!mkdir -p "$drive_root/Models"
!mkdir -p "$drive_root/Datasets"
!mkdir -p "$drive_root/GitHub Keys"

# Upload ssh keys
is_first_run = not os.path.exists(f'{drive_root}/GitHub Keys/config') or not os.path.exists(f'{drive_root}/GitHub Keys/id_rsa') or not os.path.exists(f'{drive_root}/client_secrets.json')
if is_first_run:
    #   - config
    with open(f'{drive_root}/GitHub Keys/config', 'w') as fp:
        fp.writelines(['Host github.com\n', '    Hostname github.com\n', f'    IdentityFile "{drive_root}/GitHub Keys/id_rsa"\n', '    IdentitiesOnly yes\n'])
    #   - id_rsa.pub
    upload = FileUpload(multiple=True)
    display(upload)

In [None]:
if is_first_run:
    with open(f'{drive_root}/GitHub Keys/id_rsa', 'wb') as i:
        i.write(upload.value['id_rsa']['content'])
    !chmod 600 "$drive_root/GitHub Keys/id_rsa"
    with open(f'{drive_root}/GitHub Keys/id_rsa.pub', 'wb') as i:
        i.write(upload.value['id_rsa.pub']['content'])
    !chmod 600 "$drive_root/GitHub Keys/id_rsa.pub"

    # client_secrets.json
    with open(f'{drive_root}/client_secrets.json', 'wb') as i:
        i.write(upload.value['client_secrets.json']['content'])
    
    # kaggle.json
    with open(f'/workspace/kaggle.json', 'wb') as i:
        i.write(upload.value['kaggle.json']['content'])
    !mkdir -p /root/.kaggle
    !cp -rf /workspace/kaggle.json /root/.kaggle/
    !chmod 600 /root/.kaggle/kaggle.json

## 1.2) Clone GitHub repo
Clone the achariso/gans-thesis repo into `/workspace/code/gans-thesis` using `git clone` over SSH.
For more info see: https://medium.com/@purba0101/how-to-clone-private-github-repo-in-google-colab-using-ssh-77384cfef18f

In [None]:
if is_first_run:
    !conda install git -y

import os

repo_root = '/workspace/code/gans-thesis'
ssh_root = '/root/.ssh'
!rm -rf "$repo_root"
!rm -rf "$ssh_root"
if not os.path.exists(repo_root) and not os.path.exists(f'{repo_root}/requirements.txt'):
    # Check that ssh keys exist
    assert os.path.exists(f'{drive_root}/GitHub Keys')
    id_rsa_abs_drive = f'{drive_root}/GitHub Keys/id_rsa'
    id_rsa_pub_abs_drive = f'{id_rsa_abs_drive}.pub'
    assert os.path.exists(id_rsa_abs_drive)
    assert os.path.exists(id_rsa_pub_abs_drive)
    # On first run: Add ssh key in repo
    if not os.path.exists(ssh_root) or True:
        # Transfer config file
        ssh_config_abs_drive = f'{drive_root}/GitHub Keys/config'
        assert os.path.exists(ssh_config_abs_drive)
        !mkdir -p "$ssh_root"
        !cp -f "$ssh_config_abs_drive" "$ssh_root/"
        # # Add github.com to known hosts
        !ssh-keyscan -t rsa github.com >> "$ssh_root/known_hosts"
        # Test: !ssh -T git@github.com

    # Remove any previous attempts
    !rm -rf "$repo_root"
    !mkdir -p "$repo_root"
    # Clone repo
    !ssh -o StrictHostKeyChecking=no github.com
    !git clone git@github.com:achariso/gans-thesis.git "$repo_root" -o StrictHostKeyChecking=no
    src_root = f'{repo_root}/src'
    !rm -rf "$repo_root"/report

## 1.3) Install pip packages
All required packages are listed in a `requirements.txt` file at the repository's root.
Use `pip install -r requirements.txt` from inside the dir to install required packages.

In [None]:
# Move into the repository root so that requirements.txt (and the relative paths used by later cells) resolve
%cd "$repo_root"
# Install the project's dependencies, then upgrade the kaggle package
!pip install -r requirements.txt
!pip install kaggle --upgrade

# Disabled snippet: sends signal 9 to the kernel's own process, i.e. kills it
# (presumably to force a kernel restart so freshly installed packages are picked up -- TODO confirm)
# import os
# os.kill(os.getpid(), 9)

## 1.4) Add code/, */src/ to path
This is necessary in order to be able to run the modules.

In [None]:
content_root_abs = f'{repo_root}'
src_root_abs = f'{repo_root}/src'
%env PYTHONPATH="/env/python:$content_root_abs:$src_root_abs"

# 2) Train StyleGAN model on DeepFashion's Fashion Image Synthesis Benchmark dataset
In this section we run the actual training loop for the StyleGAN network. StyleGAN consists of a stylized generator and a
fairly naive discriminator architecture. Both, however, are progressively grown, starting from a resolution of 4x4 up to
the final resolution of 128x128.

In [None]:
# Settings that are forwarded as command-line arguments to src/train_setup.py below
chkpt_step = 'latest'  # supported: 'latest', <int>, None
log_level = 'debug'    # supported: 'debug', 'info', 'warning', 'error', 'critical', 'fatal'
device = 'cuda'        # supported: 'cpu', 'cuda', 'cuda:<GPU_INDEX>'

# Running with -i enables us to get variables defined inside the script (the script runs inline)
# NOTE(review): later cells use names that no notebook cell defines (e.g. args, run_locally, exec_device,
#               datasets_groot, models_groot, in_notebook) -- presumably they come from this script; confirm.
%run -i src/train_setup.py --log_level $log_level --chkpt_step $chkpt_step --seed 42 --device $device -use_refresh_token

## Download dataset from Kaggle instead of Google Drive
This results in much, much faster download times.

In [None]:
if not os.path.exists(f'{drive_root}/Datasets/DeepFashion/Fashion Synthesis Benchmark/Img.h5'):
    !kaggle datasets download achariso/deepfashion-fisb -p /workspace
    !rm -rf "$drive_root"/Datasets/DeepFashion/Fashion\ Synthesis\ Benchmark
    !mkdir -p "$drive_root"/Datasets/DeepFashion/Fashion\ Synthesis\ Benchmark
    !conda install unzip -y
    !unzip /workspace/deepfashion-fisb.zip -d "$drive_root"/Datasets/DeepFashion/Fashion\ Synthesis\ Benchmark/
    if os.path.exists(f'{drive_root}/Datasets/DeepFashion/Fashion Synthesis Benchmark/Img.h5'):
        !rm -rf /workspace/deepfashion-fisb.zip
    else:
        print(f'Error: Dataset not found at "{drive_root}/Datasets/DeepFashion/Fashion Synthesis Benchmark"')

### StyleGAN Training

Finally, run the code!


In [None]:
%cd 'src/'

import torch
from IPython.core.display import display
from PIL import Image
from torch import Tensor
from torch.nn import DataParallel
# noinspection PyProtectedMember
from torch.utils.data import DataLoader
from torchvision.transforms import transforms

from datasets.deep_fashion import FISBDataset, FISBDataloader
from modules.stylegan import StyleGan
from tqdm import tqdm
from utils.dep_free import get_tqdm
from utils.ifaces import FilesystemDataset
from utils.metrics import GanEvaluator

###################################
###  Hyper-parameters settings  ###
###################################
# Every value that depends on `run_locally` shrinks to a tiny setting for local runs.
#   - training
n_epochs = 300
batch_size = 4 if run_locally else 128
train_test_splits = [90, 10]  # i.e. 90% of the items for training and 10% for evaluation
#   - evaluation
if run_locally:
    metrics_n_samples, metrics_batch_size, f1_k = 2, 1, 1
else:
    metrics_n_samples, metrics_batch_size, f1_k = 1000, 32, 3
#   - visualizations / checkpoints steps (one entry per generator resolution)
display_steps = dict(zip((4, 8, 16, 32, 64, 128), (40, 80, 120, 160, 200, 200)))
#     > a checkpoint is taken every 3 display steps
checkpoint_steps = {res: 3 * n_steps for res, n_steps in display_steps.items()}
#     > metrics are computed every 3 checkpoints
metrics_steps = {res: 3 * n_steps for res, n_steps in checkpoint_steps.items()}
#   - dataset
target_shape, target_channels = 128, 3
#   - StyleGAN configuration (the half-precision variant uses a config id with the "_half" suffix)
z_dim = 512
use_half_precision = False
if use_half_precision:
    stgan_config_id = f'default_z{z_dim}_half'
else:
    stgan_config_id = f'default_z{z_dim}'

###################################
###   Dataset Initialization    ###
###################################
#   - image transforms: obtained from the dataset class for the requested resolution and number of channels
gen_transforms = FISBDataset.get_image_transforms(target_channels=target_channels, target_shape=target_shape)
#   - dataloader serving the training batches
#     > len(dataloader) = <number of batches>
#     > len(dataloader.dataset) = <number of total dataset items>
dataloader = FISBDataloader(
    dataset_fs_folder_or_root=datasets_groot,
    batch_size=batch_size,
    log_level=log_level,
    image_transforms=gen_transforms,
    splits=train_test_splits,
    pin_memory=not run_locally,
    load_in_memory=False,
)
dataset = dataloader.dataset  # the training dataset is kept around as `dataset`
#   - locate the filesystem-backed dataset (either `dataset` itself or the dataset it wraps), then make sure it
#     is fetched locally and unzipped
_fs_dataset = None
if isinstance(dataset, FilesystemDataset):
    _fs_dataset = dataset
elif hasattr(dataset, 'dataset') and isinstance(dataset.dataset, FilesystemDataset):
    _fs_dataset = dataset.dataset
if _fs_dataset is None:
    raise TypeError('dataset must implement utils.ifaces.FilesystemDataset in order to be auto-downloaded and unzipped')
_fs_dataset.fetch_and_unzip(in_parallel=False, show_progress=True)
#   - sanity checks: dataloader type, number of batches (rounded up) and shape of the first batch
assert isinstance(dataloader, DataLoader)
_n_full_batches, _n_leftover_items = divmod(len(dataset), batch_size)
assert len(dataloader) == _n_full_batches + (1 if _n_leftover_items else 0)
_real = next(iter(dataloader))
assert tuple(_real.shape) == (batch_size, target_channels, target_shape, target_shape)

###################################
###    Models Initialization    ###
###################################
#   - initialize evaluator instance (used to run GAN evaluation metrics: FID, IS, PRECISION, RECALL, F1 and SSIM)
evaluator = GanEvaluator(model_fs_folder_or_root=models_groot, gen_dataset=dataset, z_dim=z_dim, device=exec_device,
                         n_samples=metrics_n_samples, batch_size=metrics_batch_size, f1_k=f1_k)
#   - resolve the checkpoint to resume from: 'latest' is kept as is, a numeric string is converted to int and
#     anything else maps to None. The `args` lookup lives inside the try-block so that the NameError fallback can
#     actually trigger when `args` is not defined.
try:
    chkpt_step = args.chkpt_step
    if chkpt_step == 'latest':
        stgan_chkpt_step = chkpt_step
    elif isinstance(chkpt_step, str) and chkpt_step.isdigit():
        stgan_chkpt_step = int(chkpt_step)
    else:
        stgan_chkpt_step = None
except NameError:
    stgan_chkpt_step = None
#   - initialize model
stgan = StyleGan(model_fs_folder_or_root=models_groot, config_id=stgan_config_id, dataset_len=len(dataset),
                 chkpt_epoch=stgan_chkpt_step, evaluator=evaluator, device=exec_device, log_level=log_level)
stgan.logger.debug(f'Using device: {str(exec_device)}')
stgan.logger.debug(f'Model initialized. Number of params = {stgan.nparams_hr}')
# FIX: Warmup counters before first batch (only when the model has no step set yet)
if stgan.step is None:
    stgan.gforward(batch_size=batch_size)
    stgan.logger.debug(f'Model warmed-up (internal counters).')
# FIX: Dataloader batch_size need update (the model's current batch size takes precedence over the configured one)
if stgan.current_batch_size is not None and stgan.current_batch_size != batch_size:
    stgan.logger.debug(f'Updating Dataloader batch_size (from {batch_size} --> {stgan.current_batch_size}).')
    batch_size = stgan.current_batch_size
    dataloader = dataloader.update_batch_size(batch_size=batch_size)
#   - setup multi-GPU training
if torch.cuda.device_count() > 1:
    stgan.gen = DataParallel(stgan.gen)
    # logged through stgan.logger, like every other message in this notebook
    stgan.logger.info(f'Using {torch.cuda.device_count()} GPUs for StyleGAN Generator (via torch.nn.DataParallel)')
#   - load dataloader state (from model checkpoint)
if 'dataloader' in stgan.other_state_dicts.keys():
    dataloader.set_state(stgan.other_state_dicts['dataloader'])
    stgan.logger.debug(f'Loaded dataloader state! Current pem_index={dataloader.get_state()["perm_index"]}')

# FIX: Change batch size (if needed)
stgan.update_batch_size(batch_size, sampler_instance=dataloader.sampler)

### Actual Training Loop

In [None]:
def _generator_resolution(model):
    """
    Return the current resolution of `model`'s generator.

    When the generator has been wrapped in `torch.nn.DataParallel` (multi-GPU setup), the wrapper does not forward
    custom attributes, so the resolution is read from the wrapped network (`.module`) instead.

    :param model: object exposing the generator as `model.gen`
    :return: the value of the generator's `resolution` attribute
    """
    generator = model.gen
    if isinstance(generator, DataParallel):
        generator = generator.module
    return generator.resolution


# Update lr
stgan.update_lr(gen_new_lr=1e-5, disc_new_lr=1e-5)

###################################
###       Training Loop         ###
###################################
#   - get the correct tqdm instance
exec_tqdm = get_tqdm()
#   - start training loop from last checkpoint's epoch and step
torch.cuda.empty_cache()
gcapture_ready = True   # False while the results of the latest gcapture() are still pending
async_results = None    # results returned by the latest gcapture() call (if any is still being tracked)
stgan.logger.info(f'[training loop] STARTING (epoch={stgan.epoch}, step={stgan.initial_step})')
for epoch in range(stgan.epoch, n_epochs):
    # Check if the networks should grow
    if stgan.growing() or batch_size != stgan.current_batch_size:
        batch_size = stgan.current_batch_size
        stgan.logger.critical(f'Reinitializing Dataloader... (new batch_size={batch_size})')
        dataloader = dataloader.update_batch_size(batch_size=batch_size)
        stgan.update_batch_size(batch_size, sampler_instance=dataloader.sampler)

    # Set steps (they depend on the generator's current resolution)
    _resolution = _generator_resolution(stgan)
    display_step = display_steps[_resolution]
    checkpoint_step = checkpoint_steps[_resolution]
    metrics_step = metrics_steps[_resolution]

    # noinspection PyProtectedMember
    d = {
        'step': stgan.step,
        'initial_step': stgan.initial_step,
        'epoch': stgan.epoch,
        '_counter': stgan._counter,
        'epoch_inc': stgan.epoch_inc,
    }
    # initial_step = stgan.initial_step % len(dataloader)
    stgan.logger.debug('[START OF EPOCH] ' + str(d))

    # Instantiate progress bar
    progress_bar: tqdm = exec_tqdm(dataloader, initial=stgan.initial_step)
    progress_bar.set_description(f'[e {str(epoch).zfill(3)}/{str(n_epochs).zfill(3)}]' +
                                 f'[g --.-- | d --.--]')

    real: Tensor
    for real in progress_bar:
        # Downsample images whenever the batch does not match the generator's current resolution
        _resolution = _generator_resolution(stgan)
        if real.shape[-1] != _resolution:
            real = transforms.Resize(size=_resolution, interpolation=Image.BILINEAR)(real)

        # Transfer image batches to GPU
        real = real.to(exec_device)

        # Perform a forward + backward pass + weight update on the Generator & Discriminator models
        disc_loss, gen_loss = stgan(real)

        # Update loss in tqdm description
        if gen_loss is not None and disc_loss is not None:
            progress_bar.set_description(f'[e {str(epoch).zfill(3)}/{str(n_epochs).zfill(3)}]' +
                                         f'[g {round(gen_loss.item(), 2)} | d {round(disc_loss.item(), 2)}]')

        # Metrics & Checkpoint Code
        if stgan.step % checkpoint_step == 0:
            # Check if another upload is pending
            if not gcapture_ready and async_results:
                # Wait for previous upload to finish
                stgan.logger.warning('Waiting for previous gcapture() to finish...')
                for _pending_result in async_results:
                    _pending_result.wait()
                stgan.logger.warning('DONE! Starting new capture now.')
            # Capture current model state, including metrics and visualizations
            async_results = stgan.gcapture(checkpoint=True, metrics=stgan.step % metrics_step == 0, visualizations=True,
                                           dataloader=dataloader, in_parallel=True, show_progress=True,
                                           delete_after=True)
        # Visualization code
        elif stgan.step % display_step == 0:
            visualization_img = stgan.visualize()
            if in_notebook():
                display(visualization_img)
            else:
                visualization_img.show()

        # Check if a pending checkpoint upload has finished
        if async_results:
            gcapture_ready = all([r.ready() for r in async_results])
            if gcapture_ready:
                stgan.logger.info(f'gcapture() finished')
                if stgan.latest_checkpoint_had_metrics:
                    stgan.logger.info(str(stgan.latest_metrics))
                async_results = None

        # If run locally one pass is enough
        if run_locally and gcapture_ready:
            break

    # If run locally one pass is enough
    if run_locally:
        break

    # noinspection PyProtectedMember
    d = {
        'step': stgan.step,
        'initial_step': stgan.initial_step,
        'epoch': stgan.epoch,
        '_counter': stgan._counter,
        'epoch_inc': stgan.epoch_inc,
    }
    stgan.logger.debug('[END OF EPOCH] ' + str(d))

# Check if a pending checkpoint exists; if so, block until all of its results are available
if async_results:
    for _pending_result in async_results:
        _pending_result.wait()
    stgan.logger.info(f'last gcapture() finished')
    if stgan.latest_checkpoint_had_metrics:
        stgan.logger.info(str(stgan.latest_metrics))
    async_results = None

# Training finished!
stgan.logger.info('[training loop] DONE')

In [None]:
# Show the learning rate of the first param group of the generator and discriminator optimizers, plus stgan.disc_iters
stgan.gen_opt.param_groups[0]['lr'], stgan.disc_opt.param_groups[0]['lr'], stgan.disc_iters