<a href="https://colab.research.google.com/github/maha-alarifi/AIArtathon/blob/master/AligatuoAI.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import copy
from collections import defaultdict
from pathlib import Path

from matplotlib import pyplot as plt

import torch
from torch import nn
from torch.optim import Adam
from torch.nn import functional as F
from torch.utils.data import DataLoader

from torchvision import transforms
from torchvision.utils import make_grid
from torchvision.datasets import ImageFolder

# Discriminative Models

In [0]:
class Discriminator(nn.Module):
  """DCGAN-style convolutional discriminator.

  Maps a batch of images to one probability per image of being real.

  Args:
    channels: number of channels in the input images.
    memory: number of filters in the first convolution; each of the three
      following stride-2 convolutions doubles it.
  """

  def __init__(self, channels=1, memory=64):
    super().__init__()
    # First stage: strided conv + LeakyReLU, no batch norm.
    layers = [
      nn.Conv2d(channels, memory, 4, 2, 1, bias=False),
      nn.LeakyReLU(0.2, inplace=True),
    ]
    # Three more stages, each halving the resolution and doubling the width.
    # The layer order (and hence state_dict keys such as 'features.3.weight')
    # is the same as a hand-written list would give.
    width = memory
    for _ in range(3):
      layers += [
        nn.Conv2d(width, width * 2, 4, 2, 1, bias=False),
        nn.BatchNorm2d(width * 2),
        nn.LeakyReLU(0.2, inplace=True),
      ]
      width *= 2
    # Final 4x4 valid convolution -> single-channel logit map
    # (1x1 for 64x64 inputs, larger for larger inputs).
    layers.append(nn.Conv2d(width, 1, 4, 1, 0, bias=False))
    self.features = nn.Sequential(*layers)  # fully convolutional

  def forward(self, images):
    """Return a (batch, 1) tensor of probabilities in [0, 1]."""
    logits = self.features(images)
    # Average the logit map over all spatial positions, then squash.
    per_image = logits.flatten(start_dim=1).mean(dim=1, keepdim=True)
    return torch.sigmoid(per_image)


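# F.cross_entropy is the most common loss for multi-class classification; this
# discriminator outputs a single (sigmoid) probability, so it is trained with
# F.binary_cross_entropy instead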

# Generative Models

In [0]:
class Generator(nn.Module):
  """DCGAN-style generator decoding latent vectors into images.

  Args:
    input_dim: length of each latent vector.
    channels: number of channels of the generated images.
    memory: width factor; the first transposed convolution has memory * 8
      filters and each later upsampling stage halves the width.
  """

  def __init__(self, input_dim=100, channels=1, memory=64):
    super().__init__()
    self.input_dim = input_dim
    # Project the 1x1 latent "image" to a 4x4 feature map.
    blocks = [
      nn.ConvTranspose2d(input_dim, memory * 8, 4, 1, 0, bias=False),
      nn.BatchNorm2d(memory * 8),
      nn.ReLU(inplace=True),
    ]
    # Three upsampling stages: each doubles the resolution, halves the width.
    for factor in (8, 4, 2):
      blocks.extend([
        nn.ConvTranspose2d(memory * factor, memory * factor // 2, 4, 2, 1,
                           bias=False),
        nn.BatchNorm2d(memory * factor // 2),
        nn.ReLU(inplace=True),
      ])
    # Last upsampling stage maps to image channels; tanh bounds output to [-1, 1].
    blocks.extend([
      nn.ConvTranspose2d(memory, channels, 4, 2, 1, bias=False),
      nn.Tanh(),
    ])
    self.decoder = nn.Sequential(*blocks)  # fully convolutional

  def forward(self, latent_code):
    """Decode latent codes into (batch, channels, 64, 64) images.

    `latent_code` is reshaped with `view` to (-1, input_dim, 1, 1), so it
    must be view-compatible and hold a multiple of `input_dim` elements.
    """
    as_image = latent_code.view(-1, self.input_dim, 1, 1)
    return self.decoder(as_image)

## Generative Adversarial Networks (GANs)

#### Configuration

In [0]:
# define commonly-changed training options

# --- data ---
# path to an image folder in torchvision ImageFolder layout
dataset = '/content/shapes'
channels = 1                 # image channels: 1 for gray-scale, 3 for RGB
batch_size = 64              # images per training batch

# --- model sizes ---
z_dim = 100                  # length of the latent vector z fed to the generator
image_size = 64              # height and width of the (square) training images
capacity_d = 64              # width factor (`memory`) of the discriminator
capacity_g = 64              # width factor (`memory`) of the generator

# --- optimization ---
epochs = 25                  # number of passes over the dataset
lr = 2e-4                    # Adam learning rate (shared by both networks)
beta1 = 0.5                  # Adam beta1: decay rate of the 1st-moment estimate

# train on the first GPU when one is available
if torch.cuda.is_available():
  device = 'cuda:0'
else:
  device = 'cpu'

#### Data

In [0]:
# get the dataset and initialize the data loader
def get_dataset(dataset, image_size=image_size, channels=channels):
  """Load an image-folder dataset, resized, cropped and normalized for the GAN.

  Args:
    dataset: path to a directory in torchvision ImageFolder layout.
    image_size: side length of the square output images.
    channels: 1 to convert images to gray-scale, 3 to keep them in color.

  Returns:
    An ImageFolder whose images are float tensors of shape
    (channels, image_size, image_size) with values in [-1, 1].

  Raises:
    ValueError: if `channels` is neither 1 nor 3.
  """
  if channels not in (1, 3):
    raise ValueError(
        f'channels must be 1 (gray-scale) or 3 (RGB), got {channels}')
  steps = [
      transforms.Resize(image_size),    # shorter side -> image_size
      transforms.CenterCrop(image_size),
  ]
  # Collapse to one channel only when gray-scale output is requested. Applying
  # Grayscale() unconditionally produced 1-channel tensors that the 3-channel
  # Normalize below rejects when channels == 3. For channels == 3 this relies
  # on ImageFolder's default loader yielding RGB images (torchvision default).
  if channels == 1:
    steps.append(transforms.Grayscale())
  steps += [
      transforms.ToTensor(),            # pixel values in [0, 1]
      # (x - 0.5) / 0.5 maps [0, 1] to [-1, 1], the generator's Tanh range
      transforms.Normalize([0.5] * channels, [0.5] * channels),
  ]
  return ImageFolder(Path(dataset), transform=transforms.Compose(steps))

# Reshuffle every epoch. drop_last=True discards a final short batch, so every
# batch holds exactly `batch_size` images.
data_loader = DataLoader(
    get_dataset(dataset),
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=4,  # background worker processes for loading/decoding
)

In [0]:
def plot_image_grid(images, columns=8, ax=None, show=True):
  """Draw a batch of images as one grid on a matplotlib axis.

  Args:
    images: batch tensor passed to torchvision's make_grid (assumed
      (N, C, H, W); make_grid expands single-channel batches to 3 channels).
    columns: number of images per grid row (make_grid's `nrow`).
    ax: axis to draw on; a new 10x10-inch figure is created when None.
    show: whether to call plt.show() after drawing.

  Returns:
    The axis that was drawn on.
  """
  if ax is None:
    _, ax = plt.subplots(figsize=(10, 10))
  # normalize=True min-max rescales the batch into [0, 1] for display
  grid = make_grid(images.detach().cpu(), nrow=columns, normalize=True)
  ax.imshow(grid.permute(1, 2, 0), interpolation='nearest')  # CHW -> HWC
  ax.axis('off')
  if show:
    # plt.show() does not take a figure. On standard backends its only
    # parameter is the keyword-only `block`, so a positional figure raises
    # TypeError; the inline backend would misread it as its `close` flag.
    plt.show()
  return ax

# Preview one batch of training images to sanity-check loading and
# preprocessing. Element 0 of a batch holds the images. Assigning the returned
# axis to `_` keeps it from being echoed as the cell's output.
_ = plot_image_grid(next(iter(data_loader))[0])

#### Initialization

In [0]:
# build both networks on the target device (discriminator first, then generator)
net_d, net_g = (
    Discriminator(channels, memory=capacity_d).to(device),
    Generator(z_dim, channels, memory=capacity_g).to(device),
)

# one Adam optimizer per network, both with the same hyper-parameters
optimizer_d, optimizer_g = (
    Adam(net.parameters(), lr=lr, betas=(beta1, 0.999))
    for net in (net_d, net_g)
)

# fixed latent batch, reused after each epoch to visualize the generator's
# progress, plus constant BCE targets: 1 = real, 0 = fake
fixed_noise = torch.randn((batch_size, z_dim), device=device)
real_label = torch.full((batch_size, 1), 1.0, device=device)
fake_label = torch.full((batch_size, 1), 0.0, device=device)

# training history: checkpoints (the training loop keeps the last three) and
# one metrics log per epoch
states, metrics = [], []
def progress(log):
  """Format the running averages of the logged GAN metrics as one line.

  Args:
    log: mapping from metric name ('loss/gen', 'loss/d_real', 'loss/d_fake',
      'score/d_real', 'score/d_fake', 'score/gen') to per-batch values.

  Returns:
    A status string that starts with a carriage return, so repeated prints with
    end='' overwrite the same console line. Each metric is averaged over its
    own values. A metric with no values yet is shown as 'nan' instead of
    raising ZeroDivisionError (e.g. when the data loader yielded no batches).
  """
  def mean(key):
    values = log.get(key) or ()  # .get: don't insert keys into a defaultdict
    return sum(values) / len(values) if values else float('nan')

  err_d = mean('loss/d_real') + mean('loss/d_fake')
  err_g = mean('loss/gen')
  d_x = mean('score/d_real')      # D(x)
  d_g_z1 = mean('score/d_fake')   # D(G(z)) during the discriminator step
  d_g_z2 = mean('score/gen')      # D(G(z)) during the generator step
  return (f'\rLoss_D: {err_d:.4f} Loss_G: {err_g:.4f} '
          f'D(x): {d_x:.4f} D(G(z)): {d_g_z1:.4f} / {d_g_z2:.4f}')

#### Training

In [0]:
def analyze(images, labels, binary_classifier):
  """Score `images` with `binary_classifier` and compute binary cross-entropy.

  When the loss is attached to an autograd graph, it is back-propagated
  (gradients accumulate into whatever requires grad). Zeroing gradients and
  stepping optimizers is left to the caller.

  Returns:
    (loss, mean_score) as Python floats.
  """
  probabilities = binary_classifier(images)  # expected to lie in [0, 1]
  bce = F.binary_cross_entropy(probabilities, labels)
  if bce.requires_grad:  # skipped for detached inputs or under torch.no_grad()
    bce.backward()
  mean_score = probabilities.detach().mean()
  return bce.item(), mean_score.item()

# Start training. Re-running this cell resumes from the first epoch that has
# no entry in `metrics` yet.
for epoch in range(len(metrics), epochs):
  print(f'Epoch [{epoch}/{epochs - 1}]')
  net_g.train(True)  # sampling below switches net_g to eval mode each epoch
  log = defaultdict(list)
  metrics.append(log)
  for i, data in enumerate(data_loader):
    # sample real (from dataset) and fake (using net_g) images
    real = data[0].to(device)  # real images
    n = real.size(0)
    fake = net_g(torch.randn(n, z_dim, device=device))
    fake_detached = fake.detach()  # D's update must not backpropagate into net_g
    # Slice the targets so they match `real` even if the loader ever yields a
    # batch smaller than `batch_size`.
    target_real, target_fake = real_label[:n], fake_label[:n]

    # train the discriminator: maximize log(D(x)) + log(1 - D(G(z)))
    optimizer_d.zero_grad()
    err_d_real, d_x = analyze(real, target_real, net_d)
    err_d_fake, d_g_z1 = analyze(fake_detached, target_fake, net_d)
    optimizer_d.step()  # learn

    # Train the generator: maximize log(D(G(z))) by labelling fakes as real.
    # This backward pass also accumulates gradients in net_d; they are cleared
    # by optimizer_d.zero_grad() before the next discriminator step.
    optimizer_g.zero_grad()
    err_g, d_g_z2 = analyze(fake, target_real, net_d)
    optimizer_g.step()  # learn

    ##############################################################

    # record the metrics
    log['loss/d_real'].append(err_d_real)
    log['loss/d_fake'].append(err_d_fake)
    log['loss/gen'].append(err_g)
    log['score/d_real'].append(d_x)        # D(x)
    log['score/d_fake'].append(d_g_z1)     # D(G(z_1))
    log['score/gen'].append(d_g_z2)        # D(G(z_2))

    if i % 10 == 0:
      print(progress(log), end='')

  # display progress with a few examples after every epoch
  print(progress(log))
  with torch.no_grad():
    net_g.train(False)
    epoch_samples = net_g(fixed_noise)
  ax = plot_image_grid(epoch_samples)
  # ax.figure.savefig(f'./samples_{epoch:03d}.png')

  # Save your progress in a checkpoint (optional). state_dict() returns
  # references to the live parameter/buffer tensors, and Adam updates its
  # moment buffers in place. Without deep copies, every entry in `states`
  # would alias the current, still-changing weights.
  state = {
      'log': log,
      'epoch': epoch,
      'net_d': copy.deepcopy(net_d.state_dict()),
      'net_g': copy.deepcopy(net_g.state_dict()),
      'opt_d': copy.deepcopy(optimizer_d.state_dict()),
      'opt_g': copy.deepcopy(optimizer_g.state_dict()),
  }
  # torch.save(state, f'./gan_{epoch:03d}.pth')
  states.append(state)
  if len(states) > 3:  # keep only the last 3 states
    states.pop(0)

*See also: [PyTorch-Lightning](https://colab.research.google.com/drive/1F_RNcHzTfFuQf-LeKvSlud6x7jXYkG31).*

#### Play with It

In [0]:
# every time you run this, you will plot new randomly generated samples
with torch.no_grad():
  net_g.eval()  # equivalent to train(False): batch norm uses running statistics
  z = torch.randn((batch_size, z_dim), device=device)  # fresh noise on every run
  plot_image_grid(images=net_g(z))

In [0]:
#@title interpolate between images { run: "auto" }
seed = 42 #@param {type:"slider", min:0, max:100, step:1}
# `seed` fixes the random endpoints, so a slider position always shows the same pairs
num_pairs = 8  # number of (z_1, z_2) endpoint pairs; each becomes one grid row
num_interpolations = 7  # images per pair, both endpoints included
def interpolate(x, y, count=num_interpolations):
  """Linearly interpolate between paired latent codes.

  Args:
    x, y: tensors of the same shape (N, ...) holding the N start and end
      points; any number of trailing dimensions is supported (the previous
      permute(1, 0, 2) only handled 2-D inputs).
    count: number of evenly spaced points per pair, both endpoints included.

  Returns:
    A contiguous tensor of shape (N, count, ...) with
    out[i, k] = (1 - a_k) * x[i] + a_k * y[i], where a_k runs evenly from 0 to 1.
  """
  alphas = torch.linspace(0, 1, count, device=x.device)
  # reshape to (count, 1, ..., 1) so the weights broadcast over every dim of x
  alphas = alphas.view(count, *([1] * x.dim()))
  path = (1 - alphas) * x + alphas * y  # (count, N, ...)
  return path.transpose(0, 1).contiguous()

torch.manual_seed(seed)  # same slider value -> same endpoint pairs
# draw the start codes first, then the end codes
z_1, z_2 = (torch.randn(num_pairs, z_dim, device=device) for _ in range(2))
with torch.no_grad():
  net_g.eval()  # equivalent to train(False)
  # one grid row per pair, `num_interpolations` images per row
  plot_image_grid(net_g(interpolate(z_1, z_2)), columns=num_interpolations)

### GAN Loss visualization

<button disabled><img src="https://3qeqpr26caki16dnhd19sv6by6v-wpengine.netdna-ssl.com/wp-content/uploads/2019/07/Line-Plots-of-Loss-and-Accuracy-for-a-Generative-Adversarial-Network-with-Mode-Collapse.png" width="600"></button>
<h5><i><b>Source:</b> <a href="https://machinelearningmastery.com/practical-guide-to-gan-failure-modes/">Jason Brownlee</a></i></h5>

GANs suffer from [**problems**](https://developers.google.com/machine-learning/gan/problems) like vanishing gradients, mode collapse, and failure to converge. [Patience and deep understanding](https://machinelearningmastery.com/practical-guide-to-gan-failure-modes/) are required.  
[**Here**](https://towardsdatascience.com/10-lessons-i-learned-training-generative-adversarial-networks-gans-for-a-year-c9071159628) is a quick, roughly 10-minute read by Marco Pasini that nicely summarizes ten lessons to keep in mind when training GANs.  
You might also be interested in reading "[Stabilizing GAN Training: A Survey](https://arxiv.org/abs/1910.00927)" by Maciej Wiatrak and Stefano Albrecht.



In [0]:
chunk_size = 100  # number of consecutive iterations averaged into one plotted point


def smooth(values, chunk_size=chunk_size):  # 1d non-overlapping mean filter
  """Average `values` over consecutive, non-overlapping chunks.

  Args:
    values: sequence of numbers (e.g. per-iteration losses).
    chunk_size: number of values averaged into each output point; for 1 or
      less, `values` is returned unchanged.

  Returns:
    A list with one mean per chunk. A trailing partial chunk is averaged over
    the values it contains instead of being dropped, so histories shorter
    than `chunk_size` still yield a point rather than an empty list.
  """
  if chunk_size <= 1:
    return values
  chunks = (values[start:start + chunk_size]
            for start in range(0, len(values), chunk_size))
  return [sum(chunk) / len(chunk) for chunk in chunks]

if not metrics:
  print('No metrics recorded yet: run the training cell first.')
else:
  # concatenate the per-epoch logs into one series per metric, then smooth it
  plot = {key: smooth([value for epoch_log in metrics for value in epoch_log[key]])
          for key in metrics[0]}
  # Spread the smoothed points evenly over the epochs actually recorded;
  # training may have stopped before `epochs` was reached. This matches the
  # old `epochs - 1` endpoint when training completed.
  x = torch.linspace(0, len(metrics) - 1, len(plot['loss/gen'])).tolist()
  # one figure for the losses, one for the discriminator scores
  for group in ('loss', 'score'):
    plt.figure(figsize=(10, 7))
    for series in ('d_real', 'd_fake', 'gen'):
      key = f'{group}/{series}'
      plt.plot(x, plot[key], label=key)
    plt.legend()
    plt.grid()
    plt.show()