# Generative models

- We will implement VAE and GAN.

In [1]:
import matplotlib.pyplot as plt
import numpy as np
import torch as th
import torchvision as thvis
import torchvision.transforms as T

from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.utils import save_image

## Variational Auto-Encoder

In [2]:
batch_size = 100
transform = T.ToTensor()

train_dataset = thvis.datasets.MNIST("./data", transform=transform, train=True, download=True)
train_dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_dataset  = thvis.datasets.MNIST("./data", transform=transform, train=False, download=True)
test_dataloader  = DataLoader(dataset=test_dataset,  batch_size=batch_size, shuffle=False)

<center><img src="./figs/vae.png" width = 600></center>

- The encoder $f(x)$ produces mean $\mu(x)$ and log of variance $\log\sigma^2(x)$.

In [3]:
class Encoder(nn.Module):

  def __init__(self, input_dim, hidden_dim, latent_dim):
    super().__init__()

    self.fc1 = nn.Linear(input_dim, hidden_dim)
    self.fc2 = nn.Linear(hidden_dim, hidden_dim)
    self.fc_mean = nn.Linear(hidden_dim, latent_dim)
    self.fc_log_variance = nn.Linear (hidden_dim, latent_dim)
    self.activation = nn.LeakyReLU(0.2)
    self.training = True

  def forward(self, x):
    h = self.activation(self.fc1(x))
    h = self.activation(self.fc2(h))
    mean = self.fc_mean(h)
    log_variance = self.fc_log_variance(h)
    return mean, log_variance

encoder = Encoder(input_dim=784, hidden_dim=400, latent_dim=200)

<center><img src="./figs/vae.png" width = 600></center>

- Given a latent encoding $z$, the decoder $g$ produces reconstructed image $\hat{x}$.

In [4]:
class Decoder(nn.Module):

  def __init__(self, latent_dim, hidden_dim, output_dim):
    super().__init__()

    self.fc1 = nn.Linear(latent_dim, hidden_dim)
    self.fc2 = nn.Linear(hidden_dim, hidden_dim)
    self.fc3 = nn.Linear(hidden_dim, output_dim)
    self.activation = nn.LeakyReLU(0.2)

  def forward(self, x):
    h = self.activation(self.fc1(x))
    h = self.activation(self.fc2(h))
    return th.sigmoid(self.fc3(h))

decoder = Decoder(latent_dim=200, hidden_dim=400, output_dim=784)

Given two images $x_1,x_2$ (`images`), let's sample latent encoding $z_1,z_2\in\mathbb{R}^{200}$ (`encodings`).
$$
z_i\sim\mathcal{N}(\mu(x_i),\sigma^2(x_i))
$$

In [5]:
x1, _ = train_dataset[0]
x2, _ = train_dataset[1]
images = th.cat([x1, x2], axis=0)
print(1, images.shape)
images = images.view(-1, 784)
print(2, images.shape)
mean, log_variance = """ Change here """
print(3, mean.shape, log_variance.shape)
stddev = """ Change here """
epsilon = """ Change here """
encodings = mean + stddev * epsilon
print(4, encodings.shape)

1 torch.Size([2, 28, 28])
2 torch.Size([2, 784])
3 torch.Size([2, 200]) torch.Size([2, 200])
4 torch.Size([2, 200])


In [6]:
class VAE(nn.Module):

  def __init__(self, encoder, decoder):
    super().__init__()
    self.encoder = encoder
    self.decoder = decoder

  def sample(self, mean, stddev):
    epsilon = th.randn_like(stddev)
    z = mean + stddev * epsilon
    return z

  def forward(self, x):
    mean, log_variance = self.encoder(x)
    z = self.sample(mean, th.exp(0.5 * log_variance))
    return self.decoder(z), mean, log_variance

vae = VAE(encoder=encoder, decoder=decoder)

The VAE loss function is simply a sum of reconstruction loss and KL divergence term:
$$
\sum_{i=1}^B \lVert x_i - g(z_i) \rVert_2^2 + \mathrm{KL}( \mathcal{N}(\mu(x_i),\sigma^2(x_i)) \Vert \mathcal{N}(0, 1) ) \ ,
$$
where $B$ denotes the batch size.

We provide a function to compute the KL divergence term.

In [7]:
def KL(mean, log_variance):
  return - 0.5 * th.sum(1 + log_variance - mean.pow(2) - log_variance.exp())

In [8]:
n_epochs = 30
optimizer = th.optim.Adam(vae.parameters(), lr=1e-3)

vae.train()
for epoch in range(1, n_epochs + 1):
  # Monitor training loss
  train_loss = 0.0

  for batch_idx, (batch, _) in enumerate(train_dataloader):
    optimizer.zero_grad()

    # The shape of `images` is (B, 1, 28, 28)
    # We expect 784=28*28 points.
    batch = batch.view(-1, 784)
    # Compute reconstructed images, and Gaussian parameters.
    batch_hat, mean, log_variance = vae(batch)

    reconstruction_loss = """ Change here """
    kldiv = """ Change here """
    loss = reconstruction_loss + kldiv

    # Optimize.
    loss.backward()
    optimizer.step()

    # Update stats
    train_loss += loss.item() * images.size(0)

  # Log training stats
  train_loss = train_loss / len(train_dataloader)
  print(f"Epoch [{epoch}/{n_epochs}] loss={train_loss:.6f}")

Epoch [1/30] loss=34657.864723
Epoch [2/30] loss=25592.713070
Epoch [3/30] loss=23373.347513
Epoch [4/30] loss=22469.289118
Epoch [5/30] loss=21948.079928
Epoch [6/30] loss=21643.518008


- Original images vs. Reconstructed images

In [None]:
# Obtain one batch of test images.
images, labels = next(iter(test_dataloader))
images_flatten = images.view(images.size(0), -1)
images = images.numpy()

# Get sample outputs
outputs, _, _ = vae(images_flatten)
outputs = outputs.view(batch_size, 1, 28, 28)
outputs = outputs.detach().numpy()

fig, axes = plt.subplots(nrows=2, 
                         ncols=10, 
                         figsize=(20, 4))
for images, row in zip([images, outputs], axes):
  for img, ax in zip(images, row):
    ax.imshow(np.squeeze(img), cmap="gray")
fig.tight_layout()

- Generating new images

In [None]:
epsilon = th.randn(batch_size, 200)
generated_images = """ Change here """
generated_images.shape

In [None]:
fig, ax = plt.subplots()
def show(i):
  ax.imshow(generated_images[i].reshape(28, 28).detach().numpy(), cmap="gray")
show(6)

The fun part is interpolating between classes, that reveals transition from a class A to another class B.

In [None]:
class_a = 5
class_b = 6
n_steps = 10

step_size = 1.0 / n_steps

for (images, labels) in train_dataloader:
  images_a = images[labels == class_a][0]
  images_b = images[labels == class_b][0]
  break

def compute_encoding(x):
  mean, log_variance = vae.encoder(x)
  return vae.sample(mean, th.exp(0.5 * log_variance))

with th.no_grad():
  encodings_a = compute_encoding(images_a.reshape(-1, 784))
  encodings_b = compute_encoding(images_b.reshape(-1, 784))

diff = encodings_b - encodings_a
steps = th.arange(0.0, 1.0 + step_size, step_size, dtype=th.float).reshape(-1, 1)
interpolated_encodings = encodings_a + (steps * diff)
interpolated_images = vae.decoder(interpolated_encodings).reshape(-1, 28, 28)

ncols = len(interpolated_images)
fig, axes = plt.subplots(ncols=ncols, figsize=(2*ncols, 2))
for i in range(ncols):
  axes[i].imshow(interpolated_images[i].detach().numpy(), cmap="gray")

fig.tight_layout()

## Generative Adversarial Networks

<center><img src="./figs/gan.png" width = 600></center>

In [None]:
batch_size = 100

transform = T.Compose([T.ToTensor(),
                       T.Normalize(mean=[0.5], std=[0.5])])
train_dataset = thvis.datasets.MNIST(root="./data", train=True, transform=transform, download=True)
train_dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_dataset = thvis.datasets.MNIST(root="./data", train=False, transform=transform, download=False)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
class Generator(nn.Module):

  def __init__(self, input_dim, output_dim):
    super().__init__()       
    self.fc1 = nn.Linear(input_dim, 256)
    self.fc2 = nn.Linear(256, 512)
    self.fc3 = nn.Linear(512, 1024)
    self.fc4 = nn.Linear(self.fc3.out_features, output_dim)

  def forward(self, x): 
    x = F.leaky_relu(self.fc1(x), 0.2)
    x = F.leaky_relu(self.fc2(x), 0.2)
    x = F.leaky_relu(self.fc3(x), 0.2)
    return th.tanh(self.fc4(x))

latent_dim = 100
G = Generator(input_dim=latent_dim, output_dim=784)

In [None]:
class Discriminator(nn.Module):

  def __init__(self, input_dim):
    super().__init__()
    self.fc1 = nn.Linear(input_dim, 1024)
    self.fc2 = nn.Linear(1024, 512)
    self.fc3 = nn.Linear(512, 256)
    self.fc4 = nn.Linear(256, 1)

  def forward(self, x):
    x = F.leaky_relu(self.fc1(x), 0.2)
    x = F.dropout(x, 0.3)
    x = F.leaky_relu(self.fc2(x), 0.2)
    x = F.dropout(x, 0.3)
    x = F.leaky_relu(self.fc3(x), 0.2)
    x = F.dropout(x, 0.3)
    return th.sigmoid(self.fc4(x))

D = Discriminator(784)

### Training discriminator

<center><img src="./figs/gan_discriminator.png" width = 600></center>

We **maximize** the following objective to update the discriminator $D$:
$$
\frac{1}{B}\sum_{i=1}^B \Big( \log D(x_i) + \log (1 - D(G(z_i))) \Big) \ ,
$$
where $B$ denotes the batch size, and $z_i\sim\mathcal{N}(0, 1)$ for each $i=1,2,\dots,B$.

- $D(x)$ is the probability of $x$ being **real**.
- Note that we do not update $G$.

It is equivalent to minimize the following loss function:
$$
\frac{1}{B}\sum_{i=1}^B \Big( \ell_\texttt{BCE}(D(x_i), 1) + \ell_\texttt{BCE}(D(G(z_i)), 0) \Big) \ ,
$$
where $\ell_\texttt{BCE}(p,q)\coloneqq -q\log p - (1 - q)\log (1-p)$ is the binary cross entropy loss.

In [None]:
D_optimizer = th.optim.Adam(D.parameters(), lr=0.0002)
criterion = nn.BCELoss()

def train_discriminator(batch):
  # The shape of `batch` is `(B, 1, 28, 28)`

  D_optimizer.zero_grad()

  real_loss = """ Change here """
  fake_loss = """ Change here """

  loss = real_loss + fake_loss
  loss.backward()
  D_optimizer.step()
      
  return loss.item()

### Training generator

<center><img src="./figs/gan_generator.png" width = 600></center>

We **minimize** the following objective to update the generator $G$:
$$
\frac{1}{B}\sum_{i=1}^B  \log (1 - D(G(z_i))) \ ,
$$
where $B$ denotes the batch size, and $z_i\sim\mathcal{N}(0, 1)$ for each $i=1,2,\dots,B$.

- $D(x)$ is the probability of $x$ being **real**.
- Note that we do not update $D$.

It is equivalent to minimize the following loss function:
$$
\frac{1}{B}\sum_{i=1}^B \ell_\texttt{BCE}(D(G(z_i)), 1) \ ,
$$
where $\ell_\texttt{BCE}(p,q)\coloneqq -q\log p - (1 - q)\log (1-p)$ is the binary cross entropy loss.

In [None]:
G_optimizer = th.optim.Adam(G.parameters(), lr=0.0002)
criterion = nn.BCELoss()

def train_generator():

  G_optimizer.zero_grad()
  loss = """ Change here """
  loss.backward()
  G_optimizer.step()
      
  return loss.item()

We alternate training of the discriminator $D$ and the generator $G$.

In [None]:
n_epochs = 20
for epoch in range(1, n_epochs + 1):
  D_losses, G_losses = [], []
  for (batch, _) in train_dataloader:
    D_losses.append(train_discriminator(batch))
    G_losses.append(train_generator())
  print(f"Epoch [{epoch}/{n_epochs}] discriminator_loss={np.mean(D_losses):.3f}, generator_loss={np.mean(G_losses)}")

In [None]:
def evaluate(G, D):

  p_real = 0.0
  p_fake = 0.0
  for images, _ in test_dataloader:
    with th.no_grad():
      p_real += (D(images.view(-1, 28 * 28))).sum().item()
      p_fake += (D(G(th.randn(batch_size, latent_dim)))).sum().item()
  return (p_real / len(test_dataset), 
          p_fake / len(test_dataset))

In [None]:
n_epochs = 20

p_real_trace = []
p_fake_trace = []
for epoch in range(1, n_epochs + 1):
  D_losses, G_losses = [], []
  for (batch, _) in train_dataloader:
    D_losses.append(train_discriminator(batch))
    G_losses.append(train_generator())
  print(f"Epoch [{epoch}/{n_epochs}] discriminator_loss={np.mean(D_losses):.3f}, generator_loss={np.mean(G_losses)}")

  p_real, p_fake = evaluate(G, D)
  p_real_trace.append(p_real)
  p_fake_trace.append(p_fake)

- How well dose the discriminator $D$ discriminate between generated images and real images?

In [None]:
fig, ax = plt.subplots()
ax.plot(p_fake_trace, label="D(generated images)")
ax.plot(p_real_trace, label="D(real images)")
ax.set_ylabel("Probability")
ax.set_xlabel("Epoch")
ax.legend()

- Visualizing generated images

In [None]:
with th.no_grad():
  z = th.randn(batch_size, latent_dim)
  generated_images = G(z).view(-1, 1, 28, 28)
save_image(generated_images, "gan_generated_images.png")