In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(1231);  # you may want to make use of this in various cells for reproducibility

# Generative Adversarial Networks (GANs)

## Introduction and Motivation

Recall from the lecture:

GANs are a popular class of generative models introduced in
[Goodfellow *et al.* 2014](https://arxiv.org/pdf/1406.2661.pdf). Such models
can capture the data distribution of a given dataset and be used to generate
new data samples that resemble those in the dataset.

The two major components of a GAN are a generator network $G$ and a
discriminator network $D$.

$G$ takes noise samples $z$ from a prior distribution and uses them to generate
new data samples. $D$ takes data samples as input and decides whether they are
real (coming from the actual training set) or fake (generated by $G$).

In practice, these networks "battle": $G$ continually attempts to generate
more realistic samples and $D$ continually tries to get better at
distinguishing real samples from fake samples. This can be formulated as a
two-player zero-sum minimax game:

$$
\min_G \max_D V(D, G)
= \mathbb{E}_{x \sim p_{data}(x)} [\log D(x)]
+ \mathbb{E}_{z \sim p_z(z)} [\log (1 - D(G(z)))]
$$
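where $p_{data}(x)$ is the data distribution, $p_z(z)$ denotes a prior over the
noise vector $z$, and $D(\cdot)$ is the probability, as estimated by $D$, that
its input is real.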
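To connect $V(D, G)$ to the training code used later, here is a minimal sketch that estimates $V$ on a batch. It assumes, as in the training cell, that $D$ outputs logits, so $D(x)$ is the sigmoid of the network output; the function name `value_estimate` and the random stand-in logits are hypothetical. It also checks that the discriminator loss used in the training cell, `(BCE(real, 1) + BCE(fake, 0)) / 2`, equals $-V/2$, so minimizing that loss maximizes $V$ over $D$.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def value_estimate(d_real_logits, d_fake_logits):
    """Monte Carlo estimate of V(D, G) from discriminator logits.

    Assumes D outputs logits, so D(x) = sigmoid(logit).
    """
    # E_x[log D(x)] over a batch of real samples
    real_term = F.logsigmoid(d_real_logits).mean()
    # E_z[log(1 - D(G(z)))], using log(1 - sigmoid(s)) = logsigmoid(-s)
    fake_term = F.logsigmoid(-d_fake_logits).mean()
    return real_term + fake_term

# Random stand-in logits (hypothetical); a real D would produce these from x and G(z).
gen = torch.Generator().manual_seed(0)
real_logits = torch.randn(32, 1, generator=gen)
fake_logits = torch.randn(32, 1, generator=gen)
v = value_estimate(real_logits, fake_logits)

# The averaged BCE loss for D equals -V / 2, so minimizing it maximizes V.
bce = nn.BCEWithLogitsLoss()
loss_D_demo = (bce(real_logits, torch.ones_like(real_logits))
               + bce(fake_logits, torch.zeros_like(fake_logits))) / 2
print(torch.allclose(loss_D_demo, -v / 2))  # True
```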

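The hope is that, over training, $D$ becomes a very good discriminator and $G$
becomes a very good generator. At the global optimum of this game, $G$'s samples
follow $p_{data}$ exactly and the best $D$ can do is output $1/2$ everywhere.
After training, we can then use $G$ to generate new data samples as if they were
drawn from the distribution of the original training dataset.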

## GAN for Gaussian Distribution

As an introduction, we will define and train a GAN that can replicate a simple
2D Gaussian distribution.

Let's start by generating random samples from a 2D Gaussian distribution as our
training dataset.

In [None]:
A = torch.tensor([[1., 2.], [2., 1.]])
u = torch.tensor([2., 1.])
xs = torch.randn((1000, 2)) @ A + u

#Plotting:
plt.scatter(xs[:,0].numpy(), xs[:,1].numpy(), s=1, c='b')
plt.axis('equal')
plt.show()
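Since each row of `xs` is `z @ A + u` with `z` drawn from a standard normal, the training data should have mean `u` and covariance $A^\top A = \begin{pmatrix} 5 & 4 \\ 4 & 5 \end{pmatrix}$. The sketch below checks this numerically on a larger sample. It uses a separate `torch.Generator` and a new variable name (`check_xs`, an assumption for illustration) so that it leaves the notebook's `xs` and global random state untouched.

```python
import torch

A = torch.tensor([[1., 2.], [2., 1.]])
u = torch.tensor([2., 1.])

# Many samples from the same affine map, drawn with a private generator
# so the notebook's global RNG state is left untouched.
gen = torch.Generator().manual_seed(0)
check_xs = torch.randn((100_000, 2), generator=gen) @ A + u

emp_mean = check_xs.mean(dim=0)
emp_cov = torch.cov(check_xs.T)  # torch.cov expects one variable per row
print(emp_mean)  # close to u = [2., 1.]
print(emp_cov)   # close to A.T @ A = [[5., 4.], [4., 5.]]
```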

Now, let's define our $D$ and $G$ networks as described in the introduction.

For the generator $G$, we want a network that can map a noise vector to a 2D
vector which appears to be drawn from our training distribution. Since our
target distribution is simple, we can sample noise from a 2D standard normal
distribution and transform it using a very simple neural network. (In fact, a
single linear layer should suffice.)
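Why a single linear layer suffices: the training data are an affine transform of standard normal noise, and `nn.Linear(2, 2)` computes `z @ W.T + b`. The sketch below sets the weights by hand (`W = A.T`, `b = u`) to show that one parameter setting reproduces the data-generating map exactly; the names `lin` and `z_demo` are illustrative only.

```python
import torch
import torch.nn as nn

A = torch.tensor([[1., 2.], [2., 1.]])
u = torch.tensor([2., 1.])

# nn.Linear(2, 2) computes z @ W.T + b, so W = A.T and b = u reproduce z @ A + u.
lin = nn.Linear(2, 2)
with torch.no_grad():
    lin.weight.copy_(A.T)
    lin.bias.copy_(u)

z_demo = torch.randn(5, 2)
with torch.no_grad():
    print(torch.allclose(lin(z_demo), z_demo @ A + u))  # True
```

Training need not recover exactly `W = A.T`: because the noise is isotropic, any `W` with `W @ W.T = A.T @ A` (for example `A.T @ Q` for an orthogonal `Q`) produces the same distribution.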

(5 points) In the cell below, define the generator and see what the generated
data look like.

In [None]:
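class generator(nn.Module):
    def __init__(self):
        super().__init__()
        # TODO: define the layers (a single linear layer should suffice)

    def forward(self, z):
        # TODO: map noise z of shape (N, 2) to samples of shape (N, 2)
        raise NotImplementedError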

G = generator()

zs = torch.randn(1000, 2)
with torch.no_grad():
    fake_xs = G(zs)

plt.scatter(xs[:,0].numpy(), xs[:,1].numpy(), s=1, c='b')
plt.scatter(fake_xs[:,0].numpy(), fake_xs[:,1].numpy(), s=1, c='r')
plt.axis('equal')
plt.show()

Next, let's define the discriminator network $D$. Basically, we need a
classifier that takes a 2D data sample and outputs a score indicating how
likely the sample is to be real rather than fake; passing this score through a
sigmoid gives that probability.

(5 points) In the cell below, define a simple multi-layer perceptron that
maps each 2D sample to a single scalar score (a logit). Since we will use
`BCEWithLogitsLoss`, which applies the sigmoid internally, do not apply a
sigmoid to the output yourself.
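A small check of why raw logits are preferred here: for moderate logits, `BCEWithLogitsLoss` matches `BCELoss` applied after `torch.sigmoid`, but for extreme logits the two-step version breaks down. The values below are illustrative; the clamp of `BCELoss`'s log at -100 is documented PyTorch behavior.

```python
import torch
import torch.nn as nn

logits = torch.tensor([[-3.0], [0.0], [2.5]])
targets = torch.tensor([[0.0], [1.0], [1.0]])

# For moderate logits the two formulations agree.
fused = nn.BCEWithLogitsLoss()(logits, targets)
two_step = nn.BCELoss()(torch.sigmoid(logits), targets)
print(torch.allclose(fused, two_step))  # True

# For an extreme logit, sigmoid(-200) underflows to 0 in float32, so BCELoss
# falls back on its clamp of log() at -100, while the fused loss still
# returns -log(sigmoid(-200)), which is about 200.
big = torch.tensor([[-200.0]])
one = torch.tensor([[1.0]])
fused_big = nn.BCEWithLogitsLoss()(big, one)
two_step_big = nn.BCELoss()(torch.sigmoid(big), one)
print(fused_big.item(), two_step_big.item())
```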

In [None]:
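class discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        # TODO: define a small MLP mapping 2D inputs to one score per sample

    def forward(self, x):
        # TODO: return raw scores (logits) of shape (N, 1); no sigmoid
        raise NotImplementedError
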
D = discriminator()

## Training the GAN

Now, we will set up the training regime for $G$ and $D$.

For each batch:

1. While fixing $G$, train the discriminator $D$ to make it better at telling real samples from generated ones.

2. While fixing $D$, train the generator $G$ to make it better at generating realistic samples, such that $D$ predicts these samples as real samples.
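Note that step 2, as implemented in the training cell, does not literally minimize the $\log(1 - D(G(z)))$ term of $V(D, G)$: it labels generated samples as real, i.e. it minimizes $-\log D(G(z))$. Goodfellow *et al.* (2014) suggest this "non-saturating" variant because, early in training, $D$ confidently rejects fakes and $\log(1 - D(G(z)))$ then gives $G$ almost no gradient. The sketch below compares the two gradients with respect to a single discriminator logit `s`; the value `s = -5` is an illustrative assumption.

```python
import torch
import torch.nn.functional as F

# Illustrative logit for a fake sample that D confidently rejects:
# D(G(z)) = sigmoid(-5), roughly 0.0067.
s = torch.tensor(-5.0, requires_grad=True)

# Saturating objective from V(D, G): G minimizes log(1 - D(G(z))) = logsigmoid(-s).
(grad_saturating,) = torch.autograd.grad(F.logsigmoid(-s), s)

# Non-saturating objective: G minimizes -log D(G(z)) = -logsigmoid(s),
# which is what BCEWithLogitsLoss computes with target 1.
(grad_non_saturating,) = torch.autograd.grad(-F.logsigmoid(s), s)

print(grad_saturating.item())      # about -0.0067: almost no learning signal
print(grad_non_saturating.item())  # about -0.9933: strong learning signal
```

Both gradients point in the same direction (increase `s`, i.e. make the fake look more real), but the non-saturating one is roughly 150 times larger here.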

In [None]:
epochs = 20
batch_size = 32
lr_G = 0.1
lr_D = 0.1

# Data loaders:
train_dataset = torch.utils.data.TensorDataset(xs)
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True)

# Optimizers:
optimizer_G = torch.optim.SGD(G.parameters(), lr=lr_G)
optimizer_D = torch.optim.SGD(D.parameters(), lr=lr_D)

# Loss function:
loss_fn = nn.BCEWithLogitsLoss()

# Training loop:
losses_D = []
losses_G = []
for epoch in range(epochs):
    for (x,) in train_loader:
        # Match the noise batch to the data batch (the last batch may be smaller).
        z = torch.randn(x.shape[0], 2)

        # Train discriminator D while fixing G.
        optimizer_D.zero_grad()

        x_fake = G(z)
        y_fake = D(x_fake.detach())  # detach: no gradients flow into G here
        y_real = D(x)

        # We want D to predict high scores for real samples,
        loss_D = loss_fn(y_real, torch.ones_like(y_real))
        # ... and low scores for synthesized samples.
        loss_D += loss_fn(y_fake, torch.zeros_like(y_fake))
        loss_D /= 2

        loss_D.backward()
        optimizer_D.step()
        losses_D.append(loss_D.item())

        # Train generator G while fixing D: only optimizer_G steps, and the
        # gradients this leaves in D are cleared by optimizer_D.zero_grad().
        optimizer_G.zero_grad()

        x_fake = G(z)
        y_fake = D(x_fake)
        # Label the fakes as real (non-saturating generator loss).
        loss_G = loss_fn(y_fake, torch.ones_like(y_fake))

        loss_G.backward()
        optimizer_G.step()
        losses_G.append(loss_G.item())

    print(f"[{epoch+1:>3d}/{epochs:>3d}] loss_D: {loss_D.item():>.6f}, loss_G: {loss_G.item():>.6f}")

# Generate new samples using trained G and compare them to the training distribution.
zs = torch.randn(1000, 2)
with torch.no_grad():
    fake_xs = G(zs)

plt.scatter(xs[:,0].numpy(), xs[:,1].numpy(), s=1, c='b')
plt.scatter(fake_xs[:,0].numpy(), fake_xs[:,1].numpy(), s=1, c='r')
plt.axis('equal')
plt.show()
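The scatter plot gives a visual comparison; a numerical one can be sketched by comparing the first two moments of the real and generated samples. The helper `compare_moments` below is hypothetical (not part of the assignment), and matching mean and covariance is only a necessary condition for matching the distribution, although for a Gaussian target generated by a linear $G$ it is also sufficient. The synthetic usage checks it on samples with known moments.

```python
import torch

def compare_moments(real, fake):
    """Return (mean gap, covariance gap) between two (N, 2) sample sets.

    mean gap: Euclidean norm of the difference of sample means.
    cov gap:  Frobenius norm of the difference of sample covariances.
    Smaller is better; both are 0 for identical moments.
    """
    mean_gap = torch.linalg.norm(real.mean(dim=0) - fake.mean(dim=0)).item()
    cov_gap = torch.linalg.norm(torch.cov(real.T) - torch.cov(fake.T)).item()
    return mean_gap, cov_gap

# Synthetic check with the same A and u as the training data.
gen = torch.Generator().manual_seed(0)
A = torch.tensor([[1., 2.], [2., 1.]])
u = torch.tensor([2., 1.])
real = torch.randn(20_000, 2, generator=gen) @ A + u
same = torch.randn(20_000, 2, generator=gen) @ A + u
isotropic = torch.randn(20_000, 2, generator=gen) + u  # right mean, wrong covariance
print(compare_moments(real, same))       # both gaps small
print(compare_moments(real, isotropic))  # covariance gap near 8
```

In the notebook, `compare_moments(xs, fake_xs)` would summarize the final scatter plot as two numbers.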