<a href="https://colab.research.google.com/github/azfarkhoja305/GANs/blob/training-loop-v2/notebooks/TransGAN_CIFAR.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Initialization

In [1]:
# Check GPU
!nvidia-smi

Wed Apr 21 17:46:39 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.67       Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   53C    P8     9W /  70W |      0MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [2]:
import os
import sys
import pdb
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from matplotlib import animation, rc
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.utils as vutils

# from torchsummary import summary
from types import SimpleNamespace

In [3]:
from google.colab import drive
drive.mount('/content/drive')
gdrive = Path('drive/MyDrive')
# gdrive = Path('/home/AP')

# Create a required checkpoint instance. 
# If does not exists, Checkpoint class will create one.
ckp_folder = gdrive/'full_v2_moving_avg'

Mounted at /content/drive


In [4]:
!git clone --single-branch --branch training-loop-v2 https://github.com/azfarkhoja305/GANs.git
!git -C GANs/ pull

Cloning into 'GANs'...
remote: Enumerating objects: 514, done.[K
remote: Counting objects: 100% (514/514), done.[K
remote: Compressing objects: 100% (386/386), done.[K
remote: Total 514 (delta 298), reused 265 (delta 127), pack-reused 0[K
Receiving objects: 100% (514/514), 135.76 MiB | 14.63 MiB/s, done.
Resolving deltas: 100% (298/298), done.
Already up to date.


In [5]:
Path.ls = lambda x: list(x.iterdir())
if Path('./GANs').exists():
    sys.path.insert(0,'./GANs')

In [6]:
from models.transformer_generator import TGenerator
from models.ViT_discriminator import Discriminator
from utils.utils import check_gpu, display_images, set_seed, reduce_resolution, weights_init, LinearLrDecay
from utils.checkpoint import Checkpoint
from utils.loss import wgangp_eps_loss
from utils.datasets import ImageDataset
%load_ext autoreload
%autoreload 2

In [7]:
# seed notebook
set_seed(seed=123)
device = check_gpu()
print(f'Using device: {device}')

Using device: cuda


In [8]:
!PYTHONPATH=./GANs python ./GANs/scripts/create_fid_stats.py -d cifar_10 -t False

Namespace(batch_size=256, dataset='cifar_10', save='fid_stats', train='False')
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar_10/cifar-10-python.tar.gz
170499072it [00:03, 46807101.82it/s]                   
Extracting data/cifar_10/cifar-10-python.tar.gz to data/cifar_10
Downloading: "https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth" to /root/.cache/torch/hub/checkpoints/pt_inception-2015-12-05-6726825d.pth
100% 91.2M/91.2M [00:01<00:00, 61.8MB/s]


# Training Setup

In [9]:
gen_batch_sz = 64
dis_batch_sz = 32
latent_dims = 1024
lr, beta1, beta2 = 1e-4, 0, 0.999
max_iter = 500000
num_epochs = 320
num_ckps = 64
start_after = 0.015625
disc_args = SimpleNamespace(**{"d_depth": 7, "df_dim": 384, "img_size": 32, "patch_size": 8})
disc_augments = "translation,cutout,color"
lamb = 50.0
n_critic = 5
moving_avg_alpha = 0.001
mse_loss = nn.MSELoss(reduction='mean')
fixed_z = torch.randn(gen_batch_sz, latent_dims, device=device)

In [10]:
# Add Horizontal Flip by default
tfms = [transforms.RandomHorizontalFlip()]
dataset =  ImageDataset('cifar_10', batch_sz=dis_batch_sz, tfms=tfms, 
                        num_workers=2, drop_last=True)
# display_images(dataset.train_loader)

Files already downloaded and verified


In [11]:
Gen = TGenerator(latent_dims=latent_dims, use_att_mask=True).to(device)
# summary(Gen,(latent_dims,))

In [12]:
Dis = Discriminator(disc_args,augments=disc_augments).to(device)
# summary(Dis,(3,32,32,))

In [13]:
optG = optim.AdamW(Gen.parameters(), lr=lr, betas=(beta1, beta2))
optD = optim.AdamW(Dis.parameters(), lr=lr, betas=(beta1, beta2))

In [14]:
loss_logs = {'gen_loss': [], 'dis_loss': []}
img_list = []

In [15]:
ckp_class = Checkpoint(ckp_folder, max_epochs=num_epochs, num_ckps=num_ckps, start_after=start_after)

# check if any existing checkpoint exists, none found hence start_epoch is 0.
# Optimizer states also get saved
Gen, Gen_avg_params, Dis, optG, optD, start_epoch, step, old_logs = \
                        ckp_class.check_if_exists(Gen, Dis, optG, optD)

loss_logs = old_logs or loss_logs
start_epoch #, loss_logs
print(step)

Checkpoint folder with checkpoints already exists. Searching for the latest.
=> Loading checkpoint: drive/MyDrive/full_v2_moving_avg/GanModel_010.pth
17182


In [16]:
# number of prints per epoch 
num = 4
print_at = np.linspace(100, len(dataset.train_loader)-1, num=num, dtype=np.int).tolist()

# Training Loop v2 Fixes

In [17]:
gen_scheduler = LinearLrDecay(optG, lr, 0.0, 0, max_iter * n_critic)
dis_scheduler = LinearLrDecay(optD, lr, 0.0, 0, max_iter * n_critic)
g_lr = gen_scheduler.step(step)
d_lr = dis_scheduler.step(step)
assert g_lr == d_lr
print(g_lr)

9.931272000000001e-05


In [18]:
if start_epoch == 0:
    print('Initializing parameters...')
    Gen.apply(weights_init)
    Dis.apply(weights_init)

# Loop

In [19]:
ckp = torch.load("drive/MyDrive/full_v2_moving_avg/GanModel_005.pth")

In [20]:
# from copy import deepcopy
# generator_avg = deepcopy(Gen)
# generator_avg.load_state_dict(ckp['generator_avg_state_dict'])
# generator_avg_params = deepcopy(list(p.data for p in generator_avg.parameters()))
# print(generator_avg_params)

In [21]:
for epoch in tqdm(range(start_epoch, num_epochs)):
    for i, data in enumerate(tqdm(dataset.train_loader, leave=False)):
        
        ###########################
        # (1) Update Dis network
        ###########################
        
        ## Train with all-real batch
        Dis.zero_grad()
        real = data[0].to(device)
        output_real = Dis(real).view(-1)

        ## Train with all-fake batch
        dis_z = torch.randn(dis_batch_sz, latent_dims, device=device)
        fake_1 = Gen(dis_z, epoch).detach()
        output_fake_1 = Dis(fake_1).view(-1)

        ## Compute loss and backpropagate
        errD = wgangp_eps_loss(Dis, real, fake_1, 1.0, output_real, output_fake_1)
        errD.backward()
        torch.nn.utils.clip_grad_norm_(Dis.parameters(), 5.0)
        optD.step()

        ###########################
        # (2) Update Gen network
        ###########################
        if step % n_critic == 0:
            Gen.zero_grad()
            gen_z = torch.randn(gen_batch_sz, latent_dims, device=device)
            fake_2 = Gen(gen_z, epoch)
            output_fake_2 = Dis(fake_2).view(-1)
            errG = -torch.mean(output_fake_2)

            # Co-training task
            low_res_img = reduce_resolution(real)
            generated_img = Gen.super_resolution(low_res_img,epoch=epoch)
            co_train_loss = mse_loss(generated_img, real)

            errG += lamb * co_train_loss
            errG.backward()
            torch.nn.utils.clip_grad_norm_(Gen.parameters(), 5.0)
            optG.step()

            g_lr = gen_scheduler.step(step)
            d_lr = dis_scheduler.step(step)
            assert g_lr == d_lr

            # moving average weight
            # if use_moving_avg:
            for params, avg_params in zip(Gen.parameters(), Gen_avg_params):
                avg_params.mul_(1.0 - moving_avg_alpha).add_(params.data, alpha=moving_avg_alpha)
            
            # Save Losses for plotting later
            loss_logs['gen_loss'].append(errG.item())

        ###########################
        # (3) Output
        ###########################

        # Save Losses for plotting later
        loss_logs['dis_loss'].append(errD.item())

        step += 1

        if i in print_at:
            # Output training stats averged over last 100 iterations
            gen_mean = np.mean(loss_logs['gen_loss'][-100:])
            dis_mean = np.mean(loss_logs['dis_loss'][-100:])
            print(f'[{epoch}/{num_epochs}][{i}/{len(dataset.train_loader)}]\t'
                  f'Loss_D: {round(gen_mean, 4)}\t',
                  f'Loss_G: {round(dis_mean, 4)}')
        
            # Check how the generator is doing by saving G's output on fixed_noise
            with torch.no_grad():
                fixed_fake = Gen(fixed_z, epoch).detach().cpu()
            img_list.append(vutils.make_grid(fixed_fake, padding=2, normalize=True))
    
    # Checkpoint
    ckp_class.at_epoch_end(Gen, Gen_avg_params, Dis, optG, optD, epoch=epoch, step=step, loss_logs=loss_logs)

HBox(children=(FloatProgress(value=0.0, max=309.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1562.0), HTML(value='')))

[11/320][100/1562]	Loss_D: 1.5859	 Loss_G: 0.0046


KeyboardInterrupt: ignored

# Analysis

In [None]:
_, axs = plt.subplots(1,2,figsize=(15,15))
display_images(dataset.train_loader,ax = axs[0])
display_images(img_list, ax = axs[1])
plt.tight_layout()

In [None]:
plt.figure(figsize=(10,5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(loss_logs['gen_loss'],label="G")
plt.plot(loss_logs['dis_loss'],label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.ylim([-10,10])
plt.legend()
plt.show()

Calculating FID Score

In [None]:
from metrics.torch_is_fid_score import is_fid_from_generator
stat_path = Path('fid_stats/cifar_10_valid_fid_stats.npz')
inception_score, fid = is_fid_from_generator(generator=Gen,
                                        latent_dims=latent_dims,
                                        num_imgs=10000,
                                        batch_sz=256,
                                        fid_stat_path = stat_path)

In [None]:
print(f"\nFID score: {fid}")
print(f"\nIS: {inception_score}")

In [None]:
# rc('animation', html='jshtml')
# fig = plt.figure(figsize=(8,8))
# plt.axis("off")
# ims = [[plt.imshow(np.transpose(i,(1,2,0)), animated=True)] for i in img_list]
# ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)
# ani

In [None]:
!kill $(ps aux | awk '{print $2}')