$$
\newcommand{\mat}[1]{\boldsymbol {#1}}
\newcommand{\mattr}[1]{\boldsymbol {#1}^\top}
\newcommand{\matinv}[1]{\boldsymbol {#1}^{-1}}
\newcommand{\vec}[1]{\boldsymbol {#1}}
\newcommand{\vectr}[1]{\boldsymbol {#1}^\top}
\newcommand{\rvar}[1]{\mathrm {#1}}
\newcommand{\rvec}[1]{\boldsymbol{\mathrm{#1}}}
\newcommand{\diag}{\mathop{\mathrm {diag}}}
\newcommand{\set}[1]{\mathbb {#1}}
\newcommand{\norm}[1]{\left\lVert#1\right\rVert}
\newcommand{\pderiv}[2]{\frac{\partial #1}{\partial #2}}
\newcommand{\bm}[1]{{\bf #1}}
\newcommand{\bb}[1]{\bm{\mathrm{#1}}}
$$

# GAN project models
<a id=part3></a>

In this part we will implement and train a generative adversarial network and apply it to the task of image generation.

In [None]:
import unittest
import os
import sys
import pathlib
import urllib
import shutil
import re
import zipfile
import pickle
import numpy as np
import torch
import matplotlib.pyplot as plt

# Automatically reload imported modules (e.g. the project.* code) when their source files change.
%load_ext autoreload
%autoreload 2
# TestCase instance for ad-hoc assertions (test.assertEqual, ...); not used in the cells shown here.
test = unittest.TestCase()
plt.rcParams.update({'font.size': 12})
# All models and data batches below are moved to this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)

In [None]:
# The setup cell already ran %load_ext autoreload; on a fresh kernel this reload is redundant.
# It is only needed if the extension was unloaded or the setup cell was skipped.
%reload_ext autoreload

In [None]:
# Colab only: mount Google Drive so that the project directory used below is reachable.
from google.colab import drive
drive.mount('/content/drive/')

Joel's project path in Google Drive (run only the `cd` cell that matches your Drive layout):

In [None]:
cd /content/drive/MyDrive/IDC/Courses/hw4 - pro

Gil's project path in Google Drive:

In [None]:
cd /content/drive/MyDrive/finalproject_mlds_MSc/hw4 - project/hw4 - pro

In [None]:
from project.inception import inception_score

### Obtaining the dataset
<a id=part3_1></a>

We'll use the same dataset as in Part 2, loaded below with `load_bush_dataset()`.

As before, you can use a custom dataset by editing the `PART3_CUSTOM_DATA_URL` variable in `hw4/answers.py`.

In [None]:
from project.gan_models import *

In [None]:
# NOTE(review): load_bush_dataset is not defined in this notebook; presumably it is provided by
# `from project.gan_models import *` above — TODO confirm. Used below as the training dataset.
ds_gwb = load_bush_dataset()

In [None]:
import torch.optim as optim
from torch.utils.data import DataLoader
import project.gan_models as gan

# Seed torch's global RNG for this kernel session (model init and sampling draw from it).
torch.manual_seed(42)


def prepare_trainer(dl_train, im_size, z_dim, hp, sn=False, wgan=False):
    """Build models, optimizers and checkpoint names for one training run.

    Args:
        dl_train: Training DataLoader. Not used here; kept so all call sites stay uniform.
        im_size: Shape of one input image, passed to ``gan.Discriminator``.
        z_dim: Latent dimension, passed to ``gan.Generator``.
        hp: Hyperparameter dict. ``hp['discriminator_optimizer']`` and
            ``hp['generator_optimizer']`` are optimizer specs whose ``'type'`` key names a
            ``torch.optim`` class; every other key is forwarded as a keyword argument.
            ``hp['n_critic']`` is read when ``wgan == True``.
        sn: Flag forwarded to ``gan.Discriminator``; adds ``_sn`` to the run name.
        wgan: Selects the ``wgan`` run-name prefix (and the ``_ncritic_N`` suffix).

    Returns:
        ``(dsc, gen, dsc_optimizer, gen_optimizer, checkpoint_file, checkpoint_file_final, name)``.

    Side effect: deletes any existing ``checkpoints/<name>.pt`` so the run starts fresh.
    """
    # Models (the discriminator is built first, then the generator, both moved to `device`).
    dsc = gan.Discriminator(im_size, sn).to(device)
    gen = gan.Generator(z_dim, featuremap_size=4).to(device)

    def build_optimizer(params, spec):
        # Everything but 'type' becomes a constructor kwarg; the caller's dict is not mutated.
        kwargs = {key: value for key, value in spec.items() if key != 'type'}
        return optim.__dict__[spec['type']](params, **kwargs)

    dsc_optimizer = build_optimizer(dsc.parameters(), hp['discriminator_optimizer'])
    gen_optimizer = build_optimizer(gen.parameters(), hp['generator_optimizer'])

    # Run name, e.g. 'gan', 'gan_sn', 'wgan_ncritic_5', 'wgan_sn_ncritic_5'.
    parts = ['wgan' if wgan == True else 'gan']
    if sn == True:
        parts.append('sn')
    if wgan == True:
        parts.append('ncritic_' + str(hp['n_critic']))
    name = '_'.join(parts)

    checkpoint_file = f'checkpoints/{name}'
    checkpoint_file_final = checkpoint_file + '_final'
    stale_checkpoint = checkpoint_file + '.pt'
    if os.path.isfile(stale_checkpoint):
        os.remove(stale_checkpoint)

    return dsc, gen, dsc_optimizer, gen_optimizer, checkpoint_file, checkpoint_file_final, name

# Loss
def dsc_loss_fn(y_data, y_generated, wgan, hp):
    """Discriminator loss for one batch.

    Uses ``gan.wgan_discriminator_loss_fn`` unless ``wgan == False``; in that case uses
    ``gan.discriminator_loss_fn`` with ``hp['data_label']`` and ``hp['label_noise']``.
    """
    use_wasserstein = not (wgan == False)
    if use_wasserstein:
        return gan.wgan_discriminator_loss_fn(y_data, y_generated)
    data_label, label_noise = hp['data_label'], hp['label_noise']
    return gan.discriminator_loss_fn(y_data, y_generated, data_label, label_noise)

def gen_loss_fn(y_generated, wgan, hp):
    """Generator loss for one batch.

    Uses ``gan.wgan_generator_loss_fn`` unless ``wgan == False``; in that case uses
    ``gan.generator_loss_fn`` with ``hp['data_label']``.
    """
    use_wasserstein = not (wgan == False)
    if use_wasserstein:
        return gan.wgan_generator_loss_fn(y_generated)
    return gan.generator_loss_fn(y_generated, hp['data_label'])


In [None]:
import IPython.display
import tqdm
from project.gan_models import train_batch, save_checkpoint

def _inception_pickle_path(name, wgan, hp):
    """Return the file name under which a run's per-epoch Inception Scores are pickled.

    WGAN runs are tagged with ``hp['n_critic']``; other runs with ``hp['label_noise']``.
    """
    run_tag = hp['n_critic'] if wgan else hp['label_noise']
    return f"{name}_{hp['batch_size']}_{hp['z_dim']}_{run_tag}.pickle"


def _show_generated_samples(gen, n, **plot_kwargs):
    """Draw ``n`` samples from ``gen`` without gradients and display them inline."""
    samples = gen.sample(n, with_grad=False).cpu()
    # NOTE(review): `plot` is not imported in this notebook's visible cells; presumably it
    # arrives via `from project.gan_models import *` — TODO confirm.
    fig, _ = plot.tensors_as_images(samples, **plot_kwargs)
    IPython.display.display(fig)
    plt.close(fig)  # don't accumulate open figures across epochs


def train_model(name, checkpoint_file, num_epochs, dl_train, dsc, gen, dsc_loss_fn, gen_loss_fn,
                dsc_optimizer, gen_optimizer, wgan, hp, is_num_samples=1000, sample_every=50):
    """Train ``gen``/``dsc`` for ``num_epochs`` epochs, tracking losses and Inception Score.

    Each epoch does three things:
      * runs ``train_batch`` on every batch of ``dl_train``;
      * scores ``is_num_samples`` generated images with ``inception_score``;
      * calls ``save_checkpoint`` with the per-epoch average losses.
    Every ``sample_every`` epochs, 5 samples are shown inline.

    After the last epoch, 15 samples from the final (in-memory) generator are shown. The
    per-epoch scores ``{'score': [...], 'std': [...]}`` are then pickled. A KeyboardInterrupt
    stops training early, and in that case nothing is pickled.

    Args:
        name: Run name (from ``prepare_trainer``); prefix of the Inception Score pickle.
        checkpoint_file: Checkpoint path prefix handed to ``save_checkpoint``.
        num_epochs: Number of passes over ``dl_train``.
        dl_train: DataLoader yielding ``(x, _)`` batches.
        dsc, gen: Discriminator and generator; ``gen`` must provide ``sample(n, with_grad=...)``.
        dsc_loss_fn, gen_loss_fn: Loss callables, forwarded unchanged to ``train_batch``.
        dsc_optimizer, gen_optimizer: Optimizers for the two models.
        wgan: WGAN flag. It is forwarded to ``train_batch`` and selects the pickle-name tag.
        hp: Hyperparameter dict forwarded to ``train_batch``. The pickle name also needs:
            - ``'batch_size'`` and ``'z_dim'``;
            - ``'n_critic'`` (WGAN) or ``'label_noise'`` (otherwise).
        is_num_samples: Generated images scored per epoch (default 1000, the previous value).
        sample_every: Show samples every this many epochs; 0/None disables (default 50).

    Returns:
        ``checkpoint_file``, so ``x_checkpoint_file = train_model(...)`` holds a usable value.
        Previously the function returned None. The value is also returned after an interrupt.
    """
    print(f'*********** TRAINING MODEL {name} WITH hyperparams> **************')
    print(hp)
    # Only ask inception_score for CUDA when the notebook actually runs on a GPU;
    # the previous hard-coded cuda=True was wrong on CPU-only machines.
    use_cuda = device.type == 'cuda'
    try:
        dsc_avg_losses, gen_avg_losses = [], []
        inception = {'score': [], 'std': []}
        for epoch_idx in range(num_epochs):
            # Accumulate batch losses and report their average once per epoch.
            dsc_losses, gen_losses = [], []
            print(f'--- EPOCH {epoch_idx+1}/{num_epochs} ---')

            with tqdm.tqdm(total=len(dl_train.batch_sampler), file=sys.stdout) as pbar:
                for x_data, _ in dl_train:
                    x_data = x_data.to(device)
                    dsc_loss, gen_loss = train_batch(
                        dsc, gen,
                        dsc_loss_fn, gen_loss_fn,
                        dsc_optimizer, gen_optimizer,
                        x_data, wgan, hp)
                    dsc_losses.append(dsc_loss)
                    gen_losses.append(gen_loss)
                    pbar.update()

            mu, sigma = inception_score(gen.sample(is_num_samples, with_grad=False), cuda=use_cuda,
                                        batch_size=32, resize=True, splits=1)
            inception['score'].append(mu)
            inception['std'].append(sigma)
            dsc_avg_losses.append(np.mean(dsc_losses))
            gen_avg_losses.append(np.mean(gen_losses))
            print(f'Discriminator loss: {dsc_avg_losses[-1]}')
            print(f'Generator loss:     {gen_avg_losses[-1]}')
            print(f'Inception Score , std are: {inception["score"][-1]},{inception["std"][-1]}')
            if save_checkpoint(gen, dsc_avg_losses, gen_avg_losses, checkpoint_file):
                print(f'Saved checkpoint.')

            if sample_every and (epoch_idx + 1) % sample_every == 0:
                _show_generated_samples(gen, 5, figsize=(6, 2))

        # These samples come from the generator as it stands after the last epoch.
        print('\n\n\n*** Images Generated from final model:')
        _show_generated_samples(gen, 15, nrows=3, figsize=(6, 6))
        with open(_inception_pickle_path(name, wgan, hp), 'wb') as handle:
            pickle.dump(inception, handle, protocol=pickle.HIGHEST_PROTOCOL)
    except KeyboardInterrupt:
        print('\n *** Training interrupted by user')
    return checkpoint_file

In [None]:
# Shared data/loader setup used by every experiment below.
dataset = ds_gwb
batch_size = 32
# NOTE(review): train_model names its Inception Score pickle with hp['batch_size'], while
# training always uses this batch_size. Keep the two in sync or the file names mislabel runs.
dl_train = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Shape of one image tensor: the first element of the first dataset item.
first_image = dataset[0][0]
im_size = first_image.shape

num_epochs = 100

In [None]:
from project.hyperparams import vanilla_hyperparams

# Vanilla GAN run: sn flag off, non-WGAN losses.
hp = vanilla_hyperparams()
z_dim = hp['z_dim']
sn = False
wgan = False

# Reseed so this run's initialization does not depend on which cells ran before it.
torch.manual_seed(42)
dsc, gen, dsc_optimizer, gen_optimizer, checkpoint_file, checkpoint_file_final, name = prepare_trainer(
    dl_train, im_size, z_dim, hp, sn, wgan)
vanilla_checkpoint_file = train_model(
    name, checkpoint_file, num_epochs, dl_train, dsc, gen,
    dsc_loss_fn, gen_loss_fn, dsc_optimizer, gen_optimizer, wgan, hp)


In [None]:
from project.hyperparams import sn_gan_hyperparams

# SN-GAN run: sn flag on, non-WGAN losses.
hp = sn_gan_hyperparams()
z_dim = hp['z_dim']
sn = True
wgan = False

# Reseed so this run's initialization does not depend on which cells ran before it.
torch.manual_seed(42)
dsc, gen, dsc_optimizer, gen_optimizer, checkpoint_file, checkpoint_file_final, name = prepare_trainer(
    dl_train, im_size, z_dim, hp, sn, wgan)
sn_gan_checkpoint_file = train_model(
    name, checkpoint_file, num_epochs, dl_train, dsc, gen,
    dsc_loss_fn, gen_loss_fn, dsc_optimizer, gen_optimizer, wgan, hp)


In [None]:
from project.hyperparams import wgan_hyperparams

# WGAN sweep (sn flag off) over hp['n_critic'].
# The result of each run is kept per n_critic instead of overwriting sn_gan_checkpoint_file,
# which holds the SN-GAN run's result.
sn = False
wgan = True
wgan_checkpoint_files = {}
for nc in [1, 2, 5, 10, 20]:
    # Fresh hyperparameters per run, so no state leaks between iterations.
    hp = wgan_hyperparams()
    hp['n_critic'] = nc
    z_dim = hp['z_dim']
    # Same seed for every n_critic, so runs differ only in n_critic, not in initialization.
    torch.manual_seed(42)
    dsc, gen, dsc_optimizer, gen_optimizer, checkpoint_file, checkpoint_file_final, name = prepare_trainer(
        dl_train, im_size, z_dim, hp, sn, wgan)
    wgan_checkpoint_files[nc] = train_model(
        name, checkpoint_file, num_epochs, dl_train, dsc, gen,
        dsc_loss_fn, gen_loss_fn, dsc_optimizer, gen_optimizer, wgan, hp)


In [None]:
from project.hyperparams import wgan_hyperparams

# WGAN sweep (sn flag on) over hp['n_critic'].
# The result of each run is kept per n_critic instead of overwriting sn_gan_checkpoint_file,
# which holds the SN-GAN run's result.
sn = True
wgan = True
wgan_sn_checkpoint_files = {}
for nc in [1, 2, 5, 10, 20]:
    # Fresh hyperparameters per run, so no state leaks between iterations.
    hp = wgan_hyperparams()
    hp['n_critic'] = nc
    z_dim = hp['z_dim']
    # Same seed for every n_critic, so runs differ only in n_critic, not in initialization.
    torch.manual_seed(42)
    dsc, gen, dsc_optimizer, gen_optimizer, checkpoint_file, checkpoint_file_final, name = prepare_trainer(
        dl_train, im_size, z_dim, hp, sn, wgan)
    wgan_sn_checkpoint_files[nc] = train_model(
        name, checkpoint_file, num_epochs, dl_train, dsc, gen,
        dsc_loss_fn, gen_loss_fn, dsc_optimizer, gen_optimizer, wgan, hp)

In [None]:
from project.hyperparams import wgan_hyperparams

# Single WGAN run with the sn flag on and n_critic=5.
# NOTE: the sn=True n_critic sweep above also trains this exact configuration, under the same
# run name 'wgan_sn_ncritic_5'. Running this cell therefore retrains from scratch, because
# prepare_trainer deletes checkpoints/wgan_sn_ncritic_5.pt. It also overwrites that run's
# Inception Score pickle.
hp = wgan_hyperparams()
hp['n_critic'] = 5
z_dim = hp['z_dim']
sn = True
wgan = True

# Reseed so this run's initialization does not depend on which cells ran before it.
torch.manual_seed(42)
dsc, gen, dsc_optimizer, gen_optimizer, checkpoint_file, checkpoint_file_final, name = prepare_trainer(
    dl_train, im_size, z_dim, hp, sn, wgan)
# Own variable, so the SN-GAN run's sn_gan_checkpoint_file is not clobbered.
wgan_sn_ncritic5_checkpoint_file = train_model(
    name, checkpoint_file, num_epochs, dl_train, dsc, gen,
    dsc_loss_fn, gen_loss_fn, dsc_optimizer, gen_optimizer, wgan, hp)


In [None]:
from project.hyperparams import wgan_hyperparams

# Single WGAN run with the sn flag on and n_critic=20.
# NOTE: the sn=True n_critic sweep above also trains this exact configuration, under the same
# run name 'wgan_sn_ncritic_20'. Running this cell therefore retrains from scratch, because
# prepare_trainer deletes checkpoints/wgan_sn_ncritic_20.pt. It also overwrites that run's
# Inception Score pickle.
hp = wgan_hyperparams()
hp['n_critic'] = 20
z_dim = hp['z_dim']
sn = True
wgan = True

# Reseed so this run's initialization does not depend on which cells ran before it.
torch.manual_seed(42)
dsc, gen, dsc_optimizer, gen_optimizer, checkpoint_file, checkpoint_file_final, name = prepare_trainer(
    dl_train, im_size, z_dim, hp, sn, wgan)
# Own variable, so the SN-GAN run's sn_gan_checkpoint_file is not clobbered.
wgan_sn_ncritic20_checkpoint_file = train_model(
    name, checkpoint_file, num_epochs, dl_train, dsc, gen,
    dsc_loss_fn, gen_loss_fn, dsc_optimizer, gen_optimizer, wgan, hp)