*Notebook* based on Erik Linder-Norén's code, available at https://github.com/eriklindernoren/PyTorch-GAN

**Feel free to look through his implementations of the many GAN variants from the literature — they may help you with the second coursework!**
# Imports

In [None]:
import argparse
import os
import numpy as np
import math
import sys

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd
import torch

# Script constants

In [None]:
# Training hyper-parameters shared by the three experiments below.
n_epochs = 5  # passes over the MNIST training set
batch_size = 64  # images per mini-batch
lr = 0.0002  # learning rate for the Adam / RMSprop optimizers
b1 = 0.5  # Adam beta1
b2 = 0.999  # Adam beta2
n_cpu = 8  # NOTE(review): not referenced by any cell below; the DataLoader uses its default workers
latent_dim = 100  # dimensionality of the generator's noise input z
img_size = 28  # images are resized to img_size x img_size
channels = 1  # single-channel (grayscale) images
n_critic = 5  # critic updates per generator update in the WGAN / WGAN-GP loops
clip_value = 0.01  # WGAN weight clipping: critic parameters are clamped to [-clip_value, clip_value]
sample_interval = 400  # save a grid of generated images when the batch counter is a multiple of this

# Where MNIST is downloaded and where each experiment writes its sample grids.
# Every folder ends with "/", so file names are appended directly.
data_folder = "./data/mnist/"
inference_folder = "./images/"
gan_inference_folder = "./images/gan/"
wgan_inference_folder = "./images/wgan/"
wgangp_inference_folder = "./images/wgan-gp/"


img_shape = (channels, img_size, img_size)  # (C, H, W) of one image
# torch.cuda.is_available() already returns a bool.
cuda = torch.cuda.is_available()

# Dataset preparation

In [None]:
# Make sure the MNIST download folder and every sample-output folder exist.
for _folder in (
    data_folder,
    inference_folder,
    gan_inference_folder,
    wgan_inference_folder,
    wgangp_inference_folder,
):
    os.makedirs(_folder, exist_ok=True)

# MNIST training split (downloaded on first use). Each image is resized to
# img_size, converted to a tensor, then normalised with mean 0.5 / std 0.5.
# Batches are reshuffled every epoch.
dataloader = DataLoader(
    datasets.MNIST(
        data_folder,
        train=True,
        download=True,
        transform=transforms.Compose(
            [
                transforms.Resize(img_size),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        ),
    ),
    batch_size=batch_size,
    shuffle=True,
)

# Generative Adversarial Network (Goodfellow *et al.*)
## Network

Here we define the Generator and Discriminator used in the first experiments.

**Please read them carefully since you will be asked later to modify them!**

In [None]:
class Generator(nn.Module):
    """MLP generator mapping latent vectors to images of shape ``img_shape``.

    Reads the module-level ``latent_dim`` and ``img_shape`` settings. The body
    is a stack of Linear -> (BatchNorm1d) -> LeakyReLU(0.2) stages widening
    128 -> 256 -> 512 -> 1024, then a Linear projection to the flattened image
    and a Tanh, so every output pixel lies in (-1, 1).
    """

    def __init__(self):
        super().__init__()

        def linear_stage(n_in, n_out, normalize=True):
            # Linear layer, optional batch norm, then LeakyReLU(0.2).
            stage = [nn.Linear(n_in, n_out)]
            if normalize:
                # NOTE: BatchNorm1d's second positional argument is ``eps``
                # (not ``momentum``); written as a keyword to make the
                # unusually large value explicit.
                stage.append(nn.BatchNorm1d(n_out, eps=0.8))
            stage.append(nn.LeakyReLU(0.2, inplace=True))
            return stage

        n_pixels = int(np.prod(img_shape))
        self.model = nn.Sequential(
            *linear_stage(latent_dim, 128, normalize=False),
            *linear_stage(128, 256),
            *linear_stage(256, 512),
            *linear_stage(512, 1024),
            nn.Linear(1024, n_pixels),
            nn.Tanh()
        )

    def forward(self, z):
        """Map a batch of latent vectors to a batch of images ``(N, *img_shape)``."""
        flat = self.model(z)
        return flat.view(flat.shape[0], *img_shape)


class Discriminator(nn.Module):
    """MLP discriminator scoring each flattened image with a value in (0, 1).

    Linear(prod(img_shape) -> 512) -> LeakyReLU(0.2) -> Linear(512 -> 256)
    -> LeakyReLU(0.2) -> Linear(256 -> 1) -> Sigmoid. Reads the module-level
    ``img_shape``. The Sigmoid output is what ``BCELoss`` consumes.
    """

    def __init__(self):
        super().__init__()

        n_pixels = int(np.prod(img_shape))
        self.model = nn.Sequential(
            nn.Linear(n_pixels, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        """Flatten each image in the batch and return one score per image, shape ``(N, 1)``."""
        return self.model(img.view(img.shape[0], -1))

## Initializations


In [None]:
# Loss for the Goodfellow et al. GAN: binary cross-entropy on the
# discriminator's Sigmoid output.
adversarial_loss = torch.nn.BCELoss()

generator = Generator()
discriminator = Discriminator()

# Decide the device once: move the modules there and pick the matching
# legacy tensor type used to allocate batches in the training loop.
if cuda:
    generator.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()
    Tensor = torch.cuda.FloatTensor
else:
    Tensor = torch.FloatTensor

# Both networks use Adam with the shared learning rate and betas.
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))

## Training Loop
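The generator's minimax objective is to minimize $\frac{1}{m} \sum_{i=1}^{m} \log \left(1-D\left(G\left(\boldsymbol{z}^{(i)}\right)\right)\right)$. In practice (and in the code below) the non-saturating variant $-\frac{1}{m} \sum_{i=1}^{m} \log D\left(G\left(\boldsymbol{z}^{(i)}\right)\right)$ is minimized instead, i.e. the binary cross-entropy of $D(G(\boldsymbol{z}))$ against the "real" label.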

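The discriminator maximizes $\frac{1}{m} \sum_{i=1}^{m}\left[\log D\left(\boldsymbol{x}^{(i)}\right)+\log \left(1-D\left(G\left(\boldsymbol{z}^{(i)}\right)\right)\right)\right]$; the code equivalently minimizes the binary cross-entropy averaged over the real and fake terms, $-\frac{1}{2m} \sum_{i=1}^{m}\left[\log D\left(\boldsymbol{x}^{(i)}\right)+\log \left(1-D\left(G\left(\boldsymbol{z}^{(i)}\right)\right)\right)\right]$.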

In [None]:
for epoch in range(n_epochs):
    for i, (imgs, _) in enumerate(dataloader):

        # BCE targets, one per image: 1.0 means "real", 0.0 means "fake".
        # Plain tensors suffice; the legacy Variable wrapper adds nothing.
        valid = Tensor(imgs.size(0), 1).fill_(1.0)
        fake = Tensor(imgs.size(0), 1).fill_(0.0)

        real_imgs = imgs.type(Tensor)

        # ---------------- Generator update (runs first) ----------------
        optimizer_G.zero_grad()

        # Latent noise z ~ N(0, I): one latent_dim-vector per image in the batch.
        z = Tensor(np.random.normal(0, 1, (imgs.shape[0], latent_dim)))
        gen_imgs = generator(z)

        # Non-saturating loss: BCE of D(G(z)) against the "real" label.
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)
        g_loss.backward()
        optimizer_G.step()

        # ---------------- Discriminator update ----------------
        # zero_grad also clears the gradients g_loss.backward() left in D.
        optimizer_D.zero_grad()

        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        # detach(): this update must not backpropagate into the generator.
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()

        print(
            "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
            % (epoch, n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
        )

        # Global batch index; every sample_interval batches save a 5x5 grid
        # of the current generated batch.
        batches_done = epoch * len(dataloader) + i
        if batches_done % sample_interval == 0:
            save_image(gen_imgs.data[:25], gan_inference_folder + "%d.png" % batches_done, nrow=5, normalize=True)

# Wasserstein GAN (Arjovsky *et al.*)
## Discriminator (critic)

In [None]:
# WGAN critic: the same MLP as the GAN discriminator, but without the final
# Sigmoid, so it outputs an unbounded real-valued score instead of a probability.
class Discriminator(nn.Module):
    """WGAN critic scoring each flattened image with an unbounded real number.

    Linear(prod(img_shape) -> 512) -> LeakyReLU(0.2) -> Linear(512 -> 256)
    -> LeakyReLU(0.2) -> Linear(256 -> 1). Reads the module-level ``img_shape``.
    """

    def __init__(self):
        super().__init__()

        n_pixels = int(np.prod(img_shape))
        self.model = nn.Sequential(
            nn.Linear(n_pixels, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
        )

    def forward(self, img):
        """Flatten each image in the batch and return one critic score per image, shape ``(N, 1)``."""
        return self.model(img.view(img.shape[0], -1))

## Initializations

In [None]:
# Fresh networks for the WGAN experiment; Discriminator now names the
# Sigmoid-free critic class.
generator = Generator()
discriminator = Discriminator()

if cuda:
    generator.cuda()
    discriminator.cuda()

# RMSprop optimizers with the shared learning rate `lr`.
optimizer_G = torch.optim.RMSprop(generator.parameters(), lr=lr)
optimizer_D = torch.optim.RMSprop(discriminator.parameters(), lr=lr)

# Legacy tensor type used to allocate batches on the chosen device.
Tensor = torch.FloatTensor if not cuda else torch.cuda.FloatTensor

## Training loop
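The critic (discriminator) maximizes $\frac{1}{m} \sum_{i=1}^{m} D\left(x^{(i)}\right)-\frac{1}{m} \sum_{i=1}^{m} D\left(G\left(z^{(i)}\right)\right)$, so its loss (the quantity the code minimizes) is the negative of this expression. After every critic update its weights are clipped to $[-c, c]$ with $c$ = `clip_value`.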

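The generator's loss is $-\frac{1}{m} \sum_{i=1}^{m} D\left(G\left(z^{(i)}\right)\right)$, i.e. the generator tries to raise the critic's score on its samples; it is updated once every `n_critic` critic updates.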

In [None]:
batches_done = 0
for epoch in range(n_epochs):
    for i, (imgs, _) in enumerate(dataloader):
        #|------------------------|
        #|    Critic training     |
        #|------------------------|
        # The critic (discriminator) is updated on every batch.
        real_imgs = imgs.type(Tensor)

        optimizer_D.zero_grad()

        # Latent noise z ~ N(0, I); reused below if this batch also trains G.
        z = Tensor(np.random.normal(0, 1, (imgs.shape[0], latent_dim)))

        # detach(): the critic update must not backpropagate into the generator.
        fake_imgs = generator(z).detach()

        # Critic loss: -(mean D(real) - mean D(fake)).
        loss_D = -torch.mean(discriminator(real_imgs)) + torch.mean(discriminator(fake_imgs))

        loss_D.backward()
        optimizer_D.step()

        # Weight clipping: clamp every critic parameter to [-clip_value, clip_value].
        for p in discriminator.parameters():
            p.data.clamp_(-clip_value, clip_value)

        #|------------------------|
        #|   Generator training   |
        #|------------------------|
        # One generator update every n_critic critic updates.
        if i % n_critic == 0:
            optimizer_G.zero_grad()

            gen_imgs = generator(z)

            # Generator loss: -mean D(G(z)), i.e. raise the critic's score on fakes.
            loss_G = -torch.mean(discriminator(gen_imgs))

            loss_G.backward()
            optimizer_G.step()

            # i == batches_done % len(dataloader), since batches_done counts every batch.
            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
                % (epoch, n_epochs, i, len(dataloader), loss_D.item(), loss_G.item())
            )

        # gen_imgs holds the samples from the most recent generator step (batch 0
        # of the first epoch always runs one, so it is defined here).
        # wgan_inference_folder already ends with "/", so the file name is
        # appended directly, as in the other experiments.
        if batches_done % sample_interval == 0:
            save_image(gen_imgs.data[:25], wgan_inference_folder + "%d.png" % batches_done, nrow=5, normalize=True)
        batches_done += 1

# WGAN with Gradient Penalty
## Gradient Penalty
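The interpolated samples are drawn as follows (Gulrajani *et al.*): _"We implicitly define ${\mathbb{P}}_{\hat{x}}$ sampling uniformly along straight lines between pairs of points sampled from the data distribution ${\mathbb{P}}_{r}$ and the generator distribution ${\mathbb{P}}_{g}$."_ Concretely, $\hat{\boldsymbol{x}} = \alpha \boldsymbol{x} + (1-\alpha)\tilde{\boldsymbol{x}}$ with $\boldsymbol{x} \sim \mathbb{P}_r$, $\tilde{\boldsymbol{x}} \sim \mathbb{P}_g$ and $\alpha \sim U[0, 1]$ drawn independently per sample.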

The gradient penalty is calculated as: $\lambda \underset{\hat{\boldsymbol{x}} \sim \mathbb{P}_{\hat{\boldsymbol{x}}}}{\mathbb{E}}\left[\left(\left\|\nabla_{\hat{\boldsymbol{x}}} D(\hat{\boldsymbol{x}})\right\|_{2}-1\right)^{2}\right]$

In [None]:
# Weight of the gradient-penalty term in the WGAN-GP critic loss.
lambda_gp = 10

def compute_gradient_penalty(D, real_samples, fake_samples):
    """Return the WGAN-GP gradient penalty (without the ``lambda_gp`` factor).

    Draws ``x_hat = alpha * real + (1 - alpha) * fake`` with one
    ``alpha ~ U[0, 1)`` per sample, then returns
    ``mean((||grad_{x_hat} D(x_hat)||_2 - 1) ** 2)``.

    Args:
        D: critic; maps a batch shaped like ``real_samples`` to scores of any
            shape (e.g. ``(N, 1)`` or ``(N,)``).
        real_samples: batch of real samples, shape ``(N, ...)``.
        fake_samples: batch of generated samples, same shape as ``real_samples``.

    Returns:
        Scalar tensor on ``real_samples``' device. The input gradients are
        computed with ``create_graph=True``, so the result can be
        backpropagated into ``D``'s parameters.
    """
    n = real_samples.size(0)

    # Interpolated samples: alpha has shape (N, 1, ..., 1) so it broadcasts over
    # every non-batch dimension, whatever the input rank. NumPy's RNG is kept
    # (as elsewhere in the notebook); dtype/device follow the inputs rather
    # than the global ``Tensor`` type.
    alpha_shape = (n,) + (1,) * (real_samples.dim() - 1)
    alpha = torch.as_tensor(
        np.random.random(alpha_shape),
        dtype=real_samples.dtype,
        device=real_samples.device,
    )
    interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)

    d_interpolates = D(interpolates)
    # Seed gradient of ones, matching D's output shape, so autograd returns
    # d(sum D(x_hat)) / d x_hat, i.e. the per-sample input gradients.
    grad_outputs = torch.ones_like(d_interpolates)

    gradients = autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=grad_outputs,
        create_graph=True,
        retain_graph=True,
    )[0]

    # One flattened gradient vector per sample, then the penalty on its L2 norm.
    gradients = gradients.reshape(n, -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()

    return gradient_penalty

## Initializations

In [None]:
# Fresh networks for WGAN-GP. The gradient penalty replaces weight clipping,
# and Adam is used instead of RMSprop.
generator = Generator()
discriminator = Discriminator()

# Decide the device once: move the networks and pick the matching legacy
# tensor type used to allocate batches in the training loop.
if cuda:
    generator.cuda()
    discriminator.cuda()
    Tensor = torch.cuda.FloatTensor
else:
    Tensor = torch.FloatTensor

optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))

## Training loop
The critic's (discriminator's) loss is $\underset{\tilde{\boldsymbol{x}} \sim \mathbb{P}_{g}}{\mathbb{E}}[D(\tilde{\boldsymbol{x}})]-\underset{\boldsymbol{x} \sim \mathbb{P}_{r}}{\mathbb{E}}[D(\boldsymbol{x})]+\lambda \underset{\hat{\boldsymbol{x}} \sim \mathbb{P}_{\hat{\boldsymbol{x}}}}{\mathbb{E}}\left[\left(\left\|\nabla_{\hat{\boldsymbol{x}}} D(\hat{\boldsymbol{x}})\right\|_{2}-1\right)^{2}\right]$

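The generator's loss is $-\underset{\tilde{\boldsymbol{x}} \sim \mathbb{P}_{g}}{\mathbb{E}}[D(\tilde{\boldsymbol{x}})]$, minimized once every `n_critic` critic updates.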

In [None]:
batches_done = 0
for epoch in range(n_epochs):
    for i, (imgs, _) in enumerate(dataloader):

        real_imgs = imgs.type(Tensor)

        #|------------------------|
        #|    Critic training     |
        #|------------------------|
        optimizer_D.zero_grad()

        # Latent noise z ~ N(0, I); reused below if this batch also trains G.
        z = Tensor(np.random.normal(0, 1, (imgs.shape[0], latent_dim)))

        # detach(): the critic loss must not backpropagate into the generator.
        # Any generator gradients produced here would be discarded anyway by
        # optimizer_G.zero_grad() before the generator step, so detaching only
        # skips a wasted backward pass through the generator on every batch.
        fake_imgs = generator(z).detach()

        # Gradient penalty on random interpolates between real and fake samples.
        gradient_penalty = compute_gradient_penalty(discriminator, real_imgs, fake_imgs)

        # Critic loss: mean D(fake) - mean D(real) + lambda_gp * penalty.
        d_loss = -torch.mean(discriminator(real_imgs)) + torch.mean(discriminator(fake_imgs)) + lambda_gp * gradient_penalty

        d_loss.backward()
        optimizer_D.step()

        #|------------------------|
        #|   Generator training   |
        #|------------------------|
        # One generator update every n_critic critic updates.
        if i % n_critic == 0:
            optimizer_G.zero_grad()

            fake_imgs = generator(z)

            # Generator loss: -mean D(G(z)).
            g_loss = -torch.mean(discriminator(fake_imgs))

            g_loss.backward()
            optimizer_G.step()

            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
                % (epoch, n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
            )

            if batches_done % sample_interval == 0:
                save_image(fake_imgs.data[:25], wgangp_inference_folder + "%d.png" % batches_done, nrow=5, normalize=True)

            # NOTE: batches_done advances by n_critic per generator step, so it
            # only approximates the true batch count when len(dataloader) is not
            # a multiple of n_critic (an epoch's last, partial group of batches
            # is still counted as n_critic).
            batches_done += n_critic