# Seminar 10: Autoregression, VAE, and GAN

**Deep Learning Course 2025**

**Author:** Nikita Kiselev

In [None]:
import torch
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Thicker axes and larger fonts so the figures stay readable in slides.
plt.rcParams.update({
    "axes.linewidth": 1.5,
    "font.size": 16,
})

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

## 1. Autoregression

Let's use the binarized MNIST dataset: every pixel is thresholded at 0.5, so it is either 0 or 1.

In [None]:
# Binarize MNIST: after ToTensor, pixels above 0.5 become 1.0 and the rest 0.0.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda img: (img > 0.5).float()),
])
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    transform=transform,
    download=True,
)
train_loader = DataLoader(dataset=train_dataset, batch_size=128, shuffle=True)

**Building the model**

Image pixels are treated as input tokens. Their values can be either 0 or 1.

To enable autoregressive modeling, we prepend a `<START>` token to every sequence and give it the index 2.

Since we train with a classification objective over tokens from a fixed vocabulary (0, 1, 2), each token gets a learned embedding.

Recurrent neural network is implemented with LSTM module.

The last layer projects each hidden state to a single logit: the (pre-sigmoid) log-odds that the corresponding pixel equals 1.

In [None]:
class PixelRNN(nn.Module):
    """Autoregressive pixel model: an LSTM over a flattened binary image.

    Tokens 0 and 1 are pixel values; token 2 is the ``<START>`` token that is
    prepended so that the output at position ``t`` depends only on pixels
    before ``t``.
    """

    def __init__(self, hidden_dim=512):
        super().__init__()
        # Three-token vocabulary: pixel 0, pixel 1 and <START> (= 2).
        self.embed = nn.Embedding(3, hidden_dim)
        self.rnn = nn.LSTM(hidden_dim, hidden_dim, batch_first=True)
        # One raw logit per position: log-odds that the pixel equals 1.
        self.fc = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        """Return per-position logits of shape ``(batch, seq_len, 1)``.

        ``x`` holds integer pixel tokens of shape ``(batch, seq_len)``. It is
        shifted right by one step with ``<START>`` in front (teacher forcing).
        """
        batch_size = x.shape[0]
        start = torch.full((batch_size, 1), 2, dtype=torch.long, device=x.device)
        shifted_inputs = torch.cat([start, x[:, :-1]], dim=1)
        hidden_states, _ = self.rnn(self.embed(shifted_inputs))
        return self.fc(hidden_states)

**Training and evaluating the model**

We've already built the model class, so let's train it!

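First, we define the loss function: binary cross-entropy. We use `BCEWithLogitsLoss`, which takes raw logits and applies the sigmoid internally in a numerically stable way.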

In [None]:
# Per-pixel binary cross-entropy computed from raw logits (sigmoid applied inside).
criterion = nn.BCEWithLogitsLoss(reduction="mean")

Then, we make a training loop.

In [None]:
epochs = 10
learning_rate = 3e-4

model = PixelRNN().to(device)
optimizer = optim.AdamW(model.parameters(), lr=learning_rate)

model.train()

for epoch in tqdm(range(epochs), desc="Epoch"):
    for x, _ in train_loader:
        # Flatten each 28x28 binary image into a sequence of 784 pixel tokens.
        x = x.view(-1, 784).to(device)
        # The model returns raw logits, not probabilities;
        # BCEWithLogitsLoss applies the sigmoid internally.
        logits = model(x.long())
        loss = criterion(logits.view(-1), x.view(-1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

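**Sampling from the model**

PixelRNN has no latent space; images are generated pixel by pixel. Starting from the `<START>` token, the model predicts the probability that the next pixel is 1. We draw the pixel from this Bernoulli distribution and feed it back as the next input, repeating this for all 784 pixels.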

In [None]:
@torch.no_grad()
def sample(model, num_samples):
    """Draw binary 28x28 images from a PixelRNN, one pixel at a time.

    Generation starts from the ``<START>`` token (2). At each of the 784 steps
    the LSTM state is carried forward, the next pixel is drawn from a
    Bernoulli distribution with the predicted probability, and that pixel is
    fed back as the next input token.

    Args:
        model: module exposing ``embed``, ``rnn`` (a batch-first LSTM) and
            ``fc`` (hidden state -> one logit). It is switched to eval mode.
        num_samples: number of images to generate.

    Returns:
        Float tensor of shape ``(num_samples, 28, 28)`` on the CPU, values 0/1.
    """
    model.eval()
    # Follow the model's own device instead of a notebook-global ``device``.
    model_device = next(model.parameters()).device
    tokens = torch.full((num_samples, 1), 2, dtype=torch.long, device=model_device)
    state = None  # nn.LSTM uses a zero initial (h, c) when given None
    generated = []

    for _ in range(784):
        out, state = model.rnn(model.embed(tokens), state)
        p = torch.sigmoid(model.fc(out)).squeeze(-1)  # (num_samples, 1)
        tokens = torch.bernoulli(p).long()
        generated.append(tokens)

    images = torch.cat(generated, dim=1)  # (num_samples, 784)
    return images.view(num_samples, 28, 28).float().cpu()

In [None]:
num_samples = 16

images = sample(model, num_samples)

# Lay the generated images out on a two-row grid.
fig, axes = plt.subplots(2, num_samples // 2, figsize=(12, 4))
for ax, image in zip(axes.flat, images):
    ax.imshow(image, cmap='gray')
    ax.axis('off')

plt.suptitle("Generated Samples")
plt.show()

## 2. VAE

In contrast to PixelRNN example, here we use **non-binarized** MNIST.

In [None]:
# Plain ToTensor: pixel intensities stay continuous (no thresholding).
transform = transforms.ToTensor()
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    transform=transform,
    download=True,
)
train_loader = DataLoader(dataset=train_dataset, batch_size=128, shuffle=True)

**Building the encoder class**

The encoder takes an input image and compresses it into a compact latent representation.

But unlike a regular autoencoder, it doesn’t output a single point - instead, it outputs two vectors: the mean $\boldsymbol{\mu}$ and log-variance $\log \boldsymbol{\sigma}^2$.

These define a probability distribution from which we’ll later sample a latent vector.

In [None]:
class Encoder(nn.Module):
    """VAE encoder: maps a flattened image to the parameters of q(z | x).

    ``forward`` returns ``(mu, logvar)``: the mean and log-variance of a
    diagonal Gaussian, each with ``latent_dim`` features.
    """

    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.act = nn.GELU()
        # Two parallel heads read the same hidden representation.
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)

    def forward(self, x):
        hidden = self.act(self.fc1(x))
        return self.fc_mu(hidden), self.fc_logvar(hidden)

**Implementing the decoder class**

The decoder does the reverse, it takes a sampled point from the latent space and tries to reconstruct the original image.

This helps the VAE learn to generate new data similar to the training inputs.

In [None]:
class Decoder(nn.Module):
    """VAE decoder: maps a latent code to pixel intensities in (0, 1)."""

    def __init__(self, latent_dim=20, hidden_dim=400, output_dim=784):
        super().__init__()
        self.fc1 = nn.Linear(latent_dim, hidden_dim)
        self.act1 = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        # Sigmoid keeps every output pixel strictly between 0 and 1.
        self.act2 = nn.Sigmoid()

    def forward(self, x):
        return self.act2(self.fc2(self.act1(self.fc1(x))))

**Creating the main VAE class**

The VAE class combines the encoder and decoder and implements the reparameterization trick to keep training differentiable.

Since sampling directly from a distribution $\mathbf{z} \sim \mathcal{N}(\boldsymbol{\mu}, \boldsymbol{\sigma}^2)$ breaks backpropagation, we sample $\boldsymbol{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$ and compute $\mathbf{z} = \boldsymbol{\mu} + \boldsymbol{\sigma} \cdot  \boldsymbol{\epsilon}$ instead.

This trick allows gradients to flow through the sampling step during training.

In [None]:
class VAE(nn.Module):
    """Variational autoencoder built from an ``Encoder`` and a ``Decoder``."""

    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):
        super().__init__()
        self.encoder = Encoder(input_dim, hidden_dim, latent_dim)
        self.decoder = Decoder(latent_dim, hidden_dim, input_dim)

    def reparameterize(self, mu, logvar):
        """Sample z ~ N(mu, exp(logvar)) as ``mu + sigma * eps`` with eps ~ N(0, I).

        All randomness lives in ``eps``, so gradients reach ``mu`` and ``logvar``.
        """
        sigma = torch.exp(logvar * 0.5)
        noise = torch.randn_like(mu)
        return mu + sigma * noise

    def forward(self, x):
        """Return ``(x_recon, mu, logvar)`` for a batch of flattened images."""
        mu, logvar = self.encoder(x)
        x_recon = self.decoder(self.reparameterize(mu, logvar))
        return x_recon, mu, logvar

**Loss function**

Training a VAE involves optimizing a composite loss function that balances two goals:

1. **Reconstruction Loss:** Ensures the output is close to the original input.

2. **Regularization Loss (KL divergence):** Encourages the learned latent distribution to be close to a standard normal distribution.

Together these terms form the ELBO (Evidence Lower Bound): the expected log-likelihood minus the KL divergence. We aim to maximize the ELBO, which is equivalent to minimizing its negative, i.e. the sum of the two losses below.

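In our case (non-binarized MNIST with pixel intensities in $[0, 1]$), we use the squared-error reconstruction loss between the input $\mathbf{x}$ and its reconstruction $\mathbf{x}'$. Up to constants, this corresponds to a Gaussian decoder with fixed variance: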

$$ \mathcal{L}_2(\boldsymbol{\theta}, \boldsymbol{\phi}) = \| \mathbf{x} - \mathbf{x}' \|_2^2 $$

For a diagonal Gaussian posterior $\mathcal{N}(\boldsymbol{\mu}, \operatorname{diag}(\boldsymbol{\sigma}^2))$ and a standard normal prior, the KL divergence has a closed form, summed over the $d$ latent dimensions:

$$ \mathcal{L}_{\text{KL}}(\boldsymbol{\phi}) = - \frac{1}{2} \sum_{j=1}^{d} \left( 1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2 \right) $$

In [None]:
def loss_function(x, x_recon, mu, logvar):
    """Return the negative ELBO summed over the batch.

    The loss is the squared reconstruction error plus the closed-form
    KL(N(mu, exp(logvar)) || N(0, I)), each summed over all elements.
    """
    reconstruction = F.mse_loss(x_recon, x, reduction="sum")
    kl_divergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return reconstruction + kl_divergence

**Training and evaluating the VAE**

Now that we’ve defined the VAE model and its loss function, it’s time to train it.

We’ll run the training loop for several epochs, calculate the loss, and visualize how well the VAE learns to reconstruct and generate data.

In [None]:
epochs = 10
learning_rate = 1e-3

model = VAE().to(device)
optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
model.train()

for epoch in tqdm(range(epochs), desc="Epoch"):
    for images, _ in train_loader:
        # Flatten to (batch, 784) to match the MLP encoder/decoder.
        images = images.view(-1, 784).to(device)
        reconstruction, mu, logvar = model(images)
        loss = loss_function(images, reconstruction, mu, logvar)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

**Sampling from the latent space**

We’ll sample random points from a standard normal distribution and feed them through the decoder to generate synthetic images.

In [None]:
@torch.no_grad()
def sample(model, num_samples, latent_dim=20):
    """Generate images by decoding latent codes drawn from the N(0, I) prior.

    Args:
        model: VAE-like module whose ``decoder`` maps ``latent_dim``-sized
            codes to flattened 28x28 images. It is switched to eval mode.
        num_samples: number of images to generate.
        latent_dim: latent code size; must match the decoder input
            (default 20, the ``VAE`` default).

    Returns:
        CPU tensor of shape ``(num_samples, 1, 28, 28)``.
    """
    model.eval()
    # Use the model's own device rather than a notebook-global ``device``.
    model_device = next(model.parameters()).device
    z = torch.randn(num_samples, latent_dim).to(model_device)
    x = model.decoder(z).cpu()
    return x.view(-1, 1, 28, 28)

In [None]:
num_samples = 16

x = sample(model, num_samples)

fig, axes = plt.subplots(2, num_samples // 2, figsize=(12, 4))
for ax, image in zip(axes.flat, x):
    # Each sample is (1, 28, 28); drop the channel axis for imshow.
    ax.imshow(image[0], cmap='gray')
    ax.axis('off')

plt.suptitle("Generated Samples from Latent Space")
plt.show()

**Latent space interpolation**

We can also interpolate between two points in the latent space to see smooth transitions in generated images, which is a hallmark of a well-trained VAE.

In [None]:
@torch.no_grad()
def interpolate(model, z1, z2, steps=10):
    model.eval()
    z = torch.stack([
        z1 * (1 - t) + z2 * t for t in torch.linspace(0, 1, steps)
    ]).to(device)
    x = model.decoder(z.to(device)).cpu()
    x = x.view(-1, 1, 28, 28)
    return x

In [None]:
steps = 10
# Two random endpoints drawn from the N(0, I) prior.
z1 = torch.randn(1, 20)
z2 = torch.randn(1, 20)

x = interpolate(model, z1, z2, steps)

fig, axes = plt.subplots(1, steps, figsize=(15, 2))
for ax, frame in zip(axes.flat, x):
    ax.imshow(frame[0], cmap="gray")
    ax.axis("off")

plt.suptitle("Latent Space Interpolation")
plt.show()

## 3. GAN

Similar to VAE example, here we use **non-binarized** MNIST.

In [None]:
# Same continuous-intensity MNIST as in the VAE section.
transform = transforms.ToTensor()
train_dataset = datasets.MNIST(root='./data', train=True,
                               transform=transform, download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=128,
                          shuffle=True)

**Building the Generator class**

Generator $G$ maps random noise $\mathbf{z} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$ to images and tries to fool the discriminator.

In [None]:
class Generator(nn.Module):
    """GAN generator: an MLP that turns a noise vector into a flattened image.

    The final sigmoid keeps every generated pixel in (0, 1).
    """

    def __init__(self, latent_dim=20, hidden_dim=400, output_dim=784):
        super().__init__()
        self.fc1 = nn.Linear(latent_dim, hidden_dim)
        self.act1 = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.act2 = nn.Sigmoid()

    def forward(self, x):
        hidden = self.act1(self.fc1(x))
        logits = self.fc2(hidden)
        return self.act2(logits)

**Implementing the Discriminator class**

Discriminator $D$ receives an image and outputs a probability that it is real (from data) rather than fake (from $G$).

In [None]:
class Discriminator(nn.Module):
    """GAN discriminator: scores a flattened image with a value in (0, 1).

    The output passes through a sigmoid, so it is a probability of "real".
    This is why the training loop pairs it with ``nn.BCELoss`` rather than a
    logits-based loss.
    """

    def __init__(self, input_dim=784, hidden_dim=400):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.act1 = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, 1)
        self.act2 = nn.Sigmoid()

    def forward(self, x):
        score = self.fc2(self.act1(self.fc1(x)))
        return self.act2(score)

**Loss function**

We formalize GAN's training with a minimax objective:
$$
\min_{G} \max_{D} \left[ \mathbb{E}_{\pi(\mathbf{x})} \log D(\mathbf{x}) + \mathbb{E}_{p(\mathbf{z})} \log (1 - D(G(\mathbf{z}))) \right]
$$

Consider a binary cross entropy (BCE):
$$
\text{BCE}(\hat{y}, y) = - \left[ y \cdot \log \hat{y} + (1 - y) \cdot \log (1 - \hat{y}) \right]
$$

Using BCE, the discriminator loss (calculated on minibatch) becomes:

$$
\mathcal{L}_D = \text{BCE}(D(\mathbf{x}), 1) + \text{BCE}(D(G(\mathbf{z})), 0)
$$

If we followed the original minimax formula, the generator would minimize:

$$
\mathbb{E}_{p(\mathbf{z})} \log (1 - D(G(\mathbf{z})))
$$

But this saturates when $D$ is good at rejecting fakes, leading to weak gradients.

Instead, we use the **non-saturating trick**: maximize $\log D(G(\mathbf{z}))$, or in BCE minibatch notation:

$$
\mathcal{L}_G = \text{BCE}(D(G(\mathbf{z})), 1)
$$

**Training and evaluating the GAN**

Now that we’ve defined the GAN model, it’s time to train it.

We’ll run the training loop for several epochs, calculate the loss, and visualize how well the GAN learns to generate data.

In [None]:
epochs = 10
learning_rate = 1e-3

generator = Generator().to(device)
discriminator = Discriminator().to(device)

# The discriminator already ends in a sigmoid, so use BCE on probabilities.
criterion = nn.BCELoss()

optimizer_generator = optim.AdamW(generator.parameters(), lr=learning_rate)
optimizer_discriminator = optim.AdamW(discriminator.parameters(), lr=learning_rate)

generator.train()
discriminator.train()

for epoch in tqdm(range(epochs), desc="Epoch"):
    for x_real, _ in train_loader:
        x_real = x_real.view(-1, 784).to(device)
        bs = x_real.shape[0]

        labels_real = torch.ones(bs, 1, device=device)
        labels_fake = torch.zeros(bs, 1, device=device)

        # Discriminator step.
        # Zero its gradients FIRST: the previous generator step back-propagated
        # through the discriminator and left the generator-loss gradients in
        # its parameters. Without this they would leak into this update.
        optimizer_discriminator.zero_grad()
        y_real = discriminator(x_real)
        loss_discriminator_real = criterion(y_real, labels_real)

        z = torch.randn(bs, 20).to(device)
        x_fake = generator(z).detach()  # NOTE: stop gradient for generator here
        y_fake = discriminator(x_fake)
        loss_discriminator_fake = criterion(y_fake, labels_fake)

        loss_discriminator = loss_discriminator_real + loss_discriminator_fake
        loss_discriminator.backward()
        optimizer_discriminator.step()

        # Generator step (non-saturating loss: fakes are labelled as real).
        optimizer_generator.zero_grad()
        z = torch.randn(bs, 20).to(device)
        x_fake = generator(z)

        y_fake = discriminator(x_fake)
        loss_generator = criterion(y_fake, labels_real)
        loss_generator.backward()
        optimizer_generator.step()

**Sampling from the latent space**

We’ll sample random points from a standard normal distribution and feed them through the generator to generate synthetic images.

In [None]:
@torch.no_grad()
def sample(generator, num_samples, latent_dim=20):
    """Generate images by feeding N(0, I) noise through the generator.

    Args:
        generator: module mapping ``latent_dim``-sized noise to flattened
            28x28 images. It is switched to eval mode.
        num_samples: number of images to generate.
        latent_dim: noise size; must match the generator input
            (default 20, the ``Generator`` default).

    Returns:
        CPU tensor of shape ``(num_samples, 1, 28, 28)``.
    """
    generator.eval()
    # Use the generator's own device rather than a notebook-global ``device``.
    gen_device = next(generator.parameters()).device
    noise = torch.randn(num_samples, latent_dim).to(gen_device)
    images = generator(noise).cpu()
    return images.view(-1, 1, 28, 28)

In [None]:
num_samples = 16

x = sample(generator, num_samples)

fig, axes = plt.subplots(2, num_samples // 2, figsize=(12, 4))
for ax, image in zip(axes.flat, x):
    # Each sample is (1, 28, 28); show the single channel.
    ax.imshow(image[0], cmap='gray')
    ax.axis('off')

plt.suptitle("Generated Samples from Latent Space")
plt.show()

**Latent space interpolation**

We can also interpolate between two points in the latent space to see smooth transitions in generated images, which is a hallmark of a well-trained generator.

In [None]:
@torch.no_grad()
def interpolate(generator, z1, z2, steps=10):
    generator.eval()
    z = torch.stack([
        z1 * (1 - t) + z2 * t for t in torch.linspace(0, 1, steps)
    ]).to(device)
    x = generator(z.to(device)).cpu()
    x = x.view(-1, 1, 28, 28)
    return x

In [None]:
steps = 10
# Two random noise endpoints drawn from N(0, I).
z1 = torch.randn(1, 20)
z2 = torch.randn(1, 20)

x = interpolate(generator, z1, z2, steps)

fig, axes = plt.subplots(1, steps, figsize=(15, 2))
for ax, frame in zip(axes.flat, x):
    ax.imshow(frame[0], cmap="gray")
    ax.axis("off")

plt.suptitle("Latent Space Interpolation")
plt.show()