## Generative Adversarial Nets

### About GANs
- Proposed by Ian Goodfellow et al., the original paper **Generative Adversarial Nets** can be found at <https://arxiv.org/abs/1406.2661>
- This work introduces a new framework for estimating generative models via an **adversarial** process, which consists of simultaneously training a generator G and a discriminator D
- Goodfellow has said that the idea came to him while he was out with friends at a bar in Montréal, where he was a PhD student at Université de Montréal

### Importing necessary libraries...

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter

### The GAN framework
- Key idea: The generative model is pitted against an adversary: a discriminative model that learns to detect whether a sample comes from the model (generator's) distribution or the (original) data distribution
- Analogy:
    - Generative model: Team of counterfeiters trying to produce fake currency and use it without detection
    - Discriminative model: Police trying to detect the counterfeit currency
- Both the generator $G$ and the discriminator $D$ are multilayer perceptrons
- The generator $G$ takes a fixed-dimensional noise vector $z$ as input, usually sampled from a Gaussian, and tries to transform it into images $G(z)$ that are indistinguishable from the original data ($x \sim p_{data}$). The discriminator's output $D(x)$ is the probability that the sample $x$ comes from $p_{data}$ rather than from the generator
- In other words, $D$ and $G$ play a **two-player minimax game** with the value function $V(D, G)$:
$$ \min_G \max_D V(D,G) = \mathbb{E}_{x \sim p_{data}(x)} [ \log D(x)] +  \mathbb{E}_{z \sim p_{z}(z)} [\log(1-D(G(z)))]$$
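- The authors also derive a closed-form expression (in function space) for the optimal discriminator given any generator $G$, namely $D^*(x) = p_{data}(x) / (p_{data}(x) + p_g(x))$, where $p_g$ is the generator's distribution. They show that the resulting training criterion reaches its global minimum if and only if $p_g = p_{data}$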
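
The closed-form result can be checked numerically on a toy example. The sketch below assumes two made-up discrete distributions over three outcomes (the values of `p_data` and `p_g` are illustrative, not taken from the paper) and uses the paper's expression for the optimal discriminator, $D^*(x) = p_{data}(x) / (p_{data}(x) + p_g(x))$. The helper names `optimal_discriminator` and `value` are hypothetical.

```python
import math


def optimal_discriminator(p_data, p_g):
    """D*(x) = p_data(x) / (p_data(x) + p_g(x)) for each outcome x."""
    return [pd / (pd + pg) for pd, pg in zip(p_data, p_g)]


def value(d, p_data, p_g):
    """V(D, G) = E_{x~p_data}[log D(x)] + E_{x~p_g}[log(1 - D(x))]."""
    return sum(
        pd * math.log(dx) + pg * math.log(1.0 - dx)
        for dx, pd, pg in zip(d, p_data, p_g)
    )


# Hypothetical toy distributions over three outcomes (illustrative values only).
p_data = [0.5, 0.3, 0.2]
p_g = [0.2, 0.3, 0.5]

# Generator distribution differs from the data distribution.
d_star = optimal_discriminator(p_data, p_g)
v_mismatch = value(d_star, p_data, p_g)

# Generator distribution matches the data distribution exactly:
# D* is 1/2 everywhere and the value equals -log 4.
d_star_matched = optimal_discriminator(p_data, p_data)
v_matched = value(d_star_matched, p_data, p_data)

print("D* (mismatched):", d_star)
print("V at D* (mismatched):", v_mismatch)
print("V at D* (matched):   ", v_matched, "vs -log 4 =", -math.log(4))
```

With these toy values the mismatched case gives a value above $-\log 4$, which is consistent with the paper's result that the criterion equals $-\log 4 + 2 \cdot JSD(p_{data} \,\|\, p_g)$.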

### Implementation of Generator & Discriminator

In [2]:
class Discriminator(nn.Module):
    def __init__(self, in_features):
        super().__init__()
        self.disc = nn.Sequential(
            nn.Linear(in_features, 128),
            nn.LeakyReLU(0.01),
            nn.Linear(128, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.disc(x)


class Generator(nn.Module):
    def __init__(self, z_dim, img_dim):
        super().__init__()
        self.gen = nn.Sequential(
            nn.Linear(z_dim, 256),
            nn.LeakyReLU(0.01),
            nn.Linear(256, img_dim),
            nn.Tanh(),  # real images are normalized to [-1, 1], so G's outputs are squashed to the same range
        )

    def forward(self, x):
        return self.gen(x)
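
A quick shape check can catch dimension mismatches before training. This is a minimal sketch that uses `nn.Sequential` stand-ins with the same layer layout as the two classes above, so it runs on its own. The batch size of 8 is an arbitrary assumption, while the noise size 64 and image size 784 match the hyperparameters used in the next cell.

```python
import torch
import torch.nn as nn

# Stand-ins mirroring the Discriminator and Generator layer layouts.
disc = nn.Sequential(
    nn.Linear(784, 128), nn.LeakyReLU(0.01), nn.Linear(128, 1), nn.Sigmoid()
)
gen = nn.Sequential(
    nn.Linear(64, 256), nn.LeakyReLU(0.01), nn.Linear(256, 784), nn.Tanh()
)

z = torch.randn(8, 64)  # hypothetical batch of 8 noise vectors

with torch.no_grad():
    fake = gen(z)      # flattened images, one row per noise vector
    prob = disc(fake)  # one probability per image

print(fake.shape, prob.shape)
print(fake.min().item(), fake.max().item())  # Tanh keeps values in [-1, 1]
```

The generator output can be fed straight into the discriminator because both work on flattened 784-dimensional vectors.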

In [3]:
# Hyperparameters etc.
device = "cuda" if torch.cuda.is_available() else "cpu"
lr = 3e-4
z_dim = 64
image_dim = 28 * 28 * 1  # 784
batch_size = 32
num_epochs = 50

disc = Discriminator(image_dim).to(device)
gen = Generator(z_dim, image_dim).to(device)
fixed_noise = torch.randn((batch_size, z_dim)).to(device)
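# Named `transform` (singular) so the `torchvision.transforms` module is not shadowed
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),  # maps pixel values from [0, 1] to [-1, 1]
    ]
)

dataset = datasets.MNIST(root="dataset/", transform=transform, download=True)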
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
opt_disc = optim.Adam(disc.parameters(), lr=lr)
opt_gen = optim.Adam(gen.parameters(), lr=lr)
criterion = nn.BCELoss()
writer_fake = SummaryWriter("logs/fake")
writer_real = SummaryWriter("logs/real")
step = 0

for epoch in range(num_epochs):
    for batch_idx, (real, _) in enumerate(loader):
        real = real.view(-1, image_dim).to(device)
        batch_size = real.shape[0]

        ### Train Discriminator: max log(D(x)) + log(1 - D(G(z)))
        noise = torch.randn(batch_size, z_dim).to(device)
        fake = gen(noise)
        disc_real = disc(real).view(-1)
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))
        disc_fake = disc(fake).view(-1)
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        lossD = (lossD_real + lossD_fake) / 2
        disc.zero_grad()
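        # retain_graph=True keeps the generator's graph alive so `fake` can be reused in the generator step below
        lossD.backward(retain_graph=True)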
        opt_disc.step()

        ### Train Generator: min log(1 - D(G(z))) <-> max log(D(G(z)))
        # The second form (maximizing) is used because it does not suffer
        # from saturating gradients early in training
        output = disc(fake).view(-1)
        lossG = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        lossG.backward()
        opt_gen.step()

        if batch_idx == 0:
            print(
                f"Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(loader)} \
                      Loss D: {lossD:.4f}, loss G: {lossG:.4f}"
            )

            with torch.no_grad():
                fake = gen(fixed_noise).reshape(-1, 1, 28, 28)
                data = real.reshape(-1, 1, 28, 28)
                img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
                img_grid_real = torchvision.utils.make_grid(data, normalize=True)

                writer_fake.add_image(
                    "Images - Generated", img_grid_fake, global_step=step
                )
                writer_real.add_image(
                    "Images - MNIST", img_grid_real, global_step=step
                )
            step += 1
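
The comment on the generator step mentions saturating gradients. A small sketch can make that concrete: write $D(G(z)) = \sigma(a)$ for a discriminator logit $a$ and compare the derivative of each generator loss with respect to $a$. The function names here are hypothetical, and the logit values are arbitrary probes rather than numbers from the training run.

```python
import math


def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))


def grad_saturating(a):
    """d/da of log(1 - sigmoid(a)), the original minimax generator loss."""
    return -sigmoid(a)


def grad_non_saturating(a):
    """d/da of -log(sigmoid(a)), the BCE loss with all-ones targets."""
    return sigmoid(a) - 1.0


# A very negative logit means the discriminator confidently rejects the fake,
# which is the typical situation early in training.
for a in (-6.0, 0.0, 6.0):
    print(
        f"logit={a:+.1f}  D(G(z))={sigmoid(a):.4f}  "
        f"saturating={grad_saturating(a):+.4f}  "
        f"non-saturating={grad_non_saturating(a):+.4f}"
    )
```

When the discriminator confidently rejects a fake, the saturating loss passes almost no gradient back to the generator, while the non-saturating loss passes a gradient of magnitude close to 1. That is why the loop computes `criterion(output, torch.ones_like(output))` for the generator.

With that loss in place, the training cell logs the following at the first batch of each epoch: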

Epoch [0/50] Batch 0/1875                       Loss D: 0.7869, loss G: 0.6809
Epoch [1/50] Batch 0/1875                       Loss D: 0.3136, loss G: 1.6344
Epoch [2/50] Batch 0/1875                       Loss D: 0.4887, loss G: 1.1313
Epoch [3/50] Batch 0/1875                       Loss D: 0.3926, loss G: 1.4554
Epoch [4/50] Batch 0/1875                       Loss D: 1.0535, loss G: 0.5698
Epoch [5/50] Batch 0/1875                       Loss D: 0.4222, loss G: 1.4418
Epoch [6/50] Batch 0/1875                       Loss D: 1.3877, loss G: 0.6035
Epoch [7/50] Batch 0/1875                       Loss D: 0.9995, loss G: 0.7359
Epoch [8/50] Batch 0/1875                       Loss D: 0.8987, loss G: 0.5971
Epoch [9/50] Batch 0/1875                       Loss D: 0.3613, loss G: 1.5525
Epoch [10/50] Batch 0/1875                       Loss D: 0.9427, loss G: 0.8585
Epoch [11/50] Batch 0/1875                       Loss D: 0.5983, loss G: 1.1564
Epoch [12/50] Batch 0/1875                       L