# Deep Learning & Applied AI

We recommend going through the notebook using Google Colaboratory.

# Tutorial 8b: VQ-VAEs

In this tutorial, we will cover a modern variant of VAEs that triggered a great shift in generative AI!

Author:

- Prof. Emanuele Rodolà

Course:

- Website and notebooks will be available at https://github.com/erodola/DLAI-s2-2025/

# Exercise 1: From VAEs to VQ-VAEs

Welcome! In this lab, you will implement your first **Vector-Quantized Variational Autoencoder (VQ-VAE)**, an important extension of VAEs that combines ideas from **discrete latent variables** and **vector quantization**.

You already know that in a standard VAE, the encoder maps an image to the parameters (mean and variance) of a Gaussian distribution, from which a single continuous latent vector is sampled.

VQ-VAEs change this in two key ways:

1. Instead of a single continuous vector, **each image is mapped to a small 2D grid of discrete codes** (one code per local region).

2. Instead of Gaussian sampling, each code is selected as the nearest entry of a learned **codebook of discrete embeddings** with a fixed number of entries.

This method is simple yet very powerful: it allows models to learn more structured, symbolic, and compressed representations.  

VQ-VAEs are especially useful for generative tasks on images, video, and speech, and they later inspired models like VQ-VAE-2 and DALL·E!

In this lab, you will **read the original VQ-VAE paper**:  
> **[VQ-VAE: Neural Discrete Representation Learning](https://arxiv.org/abs/1711.00937)** by van den Oord et al. (2017)

As you read and implement, focus on these **key differences** compared to classic VAEs:
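- **No KL divergence**: with a uniform prior over the codebook and a deterministic (one-hot) code assignment, the KL term is a constant (log K for K codes) and can be dropped from the loss.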
- **Discrete codebook lookup**: latents snap to nearest embedding vectors.
- **Latent space is spatial**: encoder outputs a grid (e.g., 7×7), not just a vector.
- **Commitment loss**: an extra loss term keeps each encoder output close to its assigned embedding, so the encoder commits to a code instead of fluctuating between codes or letting its outputs grow unboundedly.
- **Straight-through estimator**: allows backpropagation through discrete code assignments.
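
For reference, here is the objective from the paper written out as a worked summary. The notation is ours, loosely following the paper: $z_e(x)$ is the encoder output at one grid position, $e_1, \dots, e_K$ are the $K$ codebook vectors, $\mathrm{sg}[\cdot]$ is the stop-gradient operator, and $\beta$ is the commitment weight (`commitment_cost` in the code below).

$$
k = \arg\min_{j} \lVert z_e(x) - e_j \rVert_2^2, \qquad z_q(x) = e_k
$$

$$
\mathcal{L} = -\log p\big(x \mid z_q(x)\big) + \lVert \mathrm{sg}[z_e(x)] - e_k \rVert_2^2 + \beta \, \lVert z_e(x) - \mathrm{sg}[e_k] \rVert_2^2
$$

The KL term disappears because, with a one-hot posterior $q(z = k \mid x) = 1$ and a uniform prior $p(z = j) = 1/K$,

$$
D_{\mathrm{KL}}\big(q(z \mid x) \,\|\, p(z)\big) = 1 \cdot \log \frac{1}{1/K} = \log K,
$$

which is a constant. In this notebook the reconstruction term is a mean-squared error (a Gaussian likelihood with fixed variance, up to constants and scaling), and the two latent terms are averaged by `F.mse_loss` rather than summed, which only rescales them.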

During this lab, you will:
- Implement a simple VQ layer from scratch.
- Train a basic VQ-VAE on the MNIST dataset.
- Visualize how well your VQ-VAE can reconstruct digits using only discrete latents!

Good luck — and have fun exploring this new generation of representation learning! 🚀

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

In [None]:
# Load MNIST Dataset

transform = transforms.Compose([
    transforms.ToTensor()
])

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data', train=True, download=True, transform=transform),
    batch_size=128, shuffle=True
)

Below you'll find a barebones implementation of a convolutional `Encoder` and `Decoder` to start experimenting with (it should be enough for MNIST digits). Feel free to swap in your own architecture!

In [None]:
class Encoder(nn.Module):
    def __init__(self, hidden_dim=128, embedding_dim=64):
        super().__init__()
        self.conv1 = nn.Conv2d(1, hidden_dim, 4, stride=2, padding=1)
        self.conv2 = nn.Conv2d(hidden_dim, embedding_dim, 4, stride=2, padding=1)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.conv2(x)
        return x

class Decoder(nn.Module):
    def __init__(self, embedding_dim=64, hidden_dim=128):
        super().__init__()
        self.conv1 = nn.ConvTranspose2d(embedding_dim, hidden_dim, 4, stride=2, padding=1)
        self.conv2 = nn.ConvTranspose2d(hidden_dim, 1, 4, stride=2, padding=1)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = torch.sigmoid(self.conv2(x))
        return x
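
Why does a 28×28 digit end up as a 7×7 latent grid (the size hard-coded later in this notebook)? Each stride-2 layer above halves (encoder) or doubles (decoder) the spatial size. A quick check using the output-size formulas from the PyTorch `Conv2d` and `ConvTranspose2d` documentation, assuming dilation 1 and no `output_padding`; the helper names are ours:

```python
def conv_out(size, kernel_size=4, stride=2, padding=1):
    """Output size of nn.Conv2d along one spatial dimension (dilation=1)."""
    return (size + 2 * padding - kernel_size) // stride + 1

def conv_transpose_out(size, kernel_size=4, stride=2, padding=1, output_padding=0):
    """Output size of nn.ConvTranspose2d along one spatial dimension (dilation=1)."""
    return (size - 1) * stride - 2 * padding + kernel_size + output_padding

enc_sizes = [28]                 # MNIST images are 28x28
for _ in range(2):               # Encoder: conv1, conv2
    enc_sizes.append(conv_out(enc_sizes[-1]))

dec_sizes = [enc_sizes[-1]]
for _ in range(2):               # Decoder: conv1, conv2 (both transposed)
    dec_sizes.append(conv_transpose_out(dec_sizes[-1]))

print(enc_sizes, dec_sizes)      # [28, 14, 7] [7, 14, 28]
```

If you change kernel sizes, strides, or the number of layers, rerun this arithmetic before trusting the hard-coded 7×7 below.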

Now, focus on completing the `VectorQuantizer`'s `forward` method below.

You need to compute distances, find nearest embeddings, and apply the straight-through estimator. Don't worry about getting perfect reconstructions — if the digits look recognizable, you're good.
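
The distance step usually relies on expanding the squared Euclidean distance, $\lVert z - e \rVert^2 = \lVert z \rVert^2 + \lVert e \rVert^2 - 2\, z \cdot e$, so that all pairwise distances come from a single matrix multiplication instead of a `(N, K, D)` tensor of differences. A quick numerical sanity check on random toy tensors (the sizes are arbitrary assumptions):

```python
import torch

torch.manual_seed(0)
z_flat = torch.randn(10, 64)     # 10 flattened encoder vectors (toy sizes)
codebook = torch.randn(512, 64)  # 512 codebook entries of dimension 64

# Expanded form ||z||^2 + ||e||^2 - 2 z.e: one matmul, result has shape (10, 512)
dist_expanded = (z_flat.pow(2).sum(dim=1, keepdim=True)
                 + codebook.pow(2).sum(dim=1)
                 - 2 * z_flat @ codebook.t())

# Reference: explicit differences, materializing a (10, 512, 64) tensor
dist_reference = (z_flat[:, None, :] - codebook[None, :, :]).pow(2).sum(dim=-1)

# The two forms agree up to float32 rounding, and the argmin picks a truly nearest entry
nearest = dist_expanded.argmin(dim=1)
picked = dist_reference.gather(1, nearest[:, None]).squeeze(1)
best = dist_reference.min(dim=1).values
print(torch.allclose(dist_expanded, dist_reference, rtol=1e-4, atol=1e-3),
      torch.allclose(picked, best, rtol=1e-4, atol=1e-3))
```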

In [None]:
# Vector Quantizer (✏️ your solution here)

class VectorQuantizer(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, commitment_cost):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.commitment_cost = commitment_cost

        self.embeddings = nn.Embedding(num_embeddings, embedding_dim)
        self.embeddings.weight.data.uniform_(-1/self.num_embeddings, 1/self.num_embeddings)

    def forward(self, z):
        # z: (batch, channel, height, width)
        z_flattened = z.permute(0, 2, 3, 1).contiguous()
        z_flattened = z_flattened.view(-1, self.embedding_dim)

        # TODO: Compute squared L2 distances between z_flattened and the embeddings
        # distances = ✏️

        # TODO: Find closest embeddings
        # encoding_indices = ✏️
        # encodings = ✏️

        # TODO: Quantize and unflatten
        # quantized = ✏️

        # TODO: Compute the commitment and codebook losses
        # e_latent_loss = ✏️
        # q_latent_loss = ✏️
        # loss = ✏️

        # Straight-through estimator trick
        # quantized = z + (quantized - z).detach()

        # return quantized, loss
        raise NotImplementedError("Fill in the VectorQuantizer forward pass!")

In [None]:
# @title Solution 👀

class VectorQuantizer(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, commitment_cost):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.commitment_cost = commitment_cost

        self.embeddings = nn.Embedding(num_embeddings, embedding_dim)
        self.embeddings.weight.data.uniform_(-1/self.num_embeddings, 1/self.num_embeddings)

    def forward(self, z):
        # z: (batch, channel, height, width)
        z_flattened = z.permute(0, 2, 3, 1).contiguous()
        z_flattened = z_flattened.view(-1, self.embedding_dim)

        # Compute distances
        distances = (torch.sum(z_flattened**2, dim=1, keepdim=True)
                     + torch.sum(self.embeddings.weight**2, dim=1)
                     - 2 * torch.matmul(z_flattened, self.embeddings.weight.t()))

        # Find nearest embeddings
        encoding_indices = torch.argmin(distances, dim=1)
        encodings = F.one_hot(encoding_indices, self.num_embeddings).float()

        # Quantize
        quantized = torch.matmul(encodings, self.embeddings.weight)
        quantized = quantized.view(z.shape[0], z.shape[2], z.shape[3], self.embedding_dim)
        quantized = quantized.permute(0, 3, 1, 2).contiguous()

        # Loss
        e_latent_loss = F.mse_loss(quantized.detach(), z)  # push encoder outputs closer to embeddings
        q_latent_loss = F.mse_loss(quantized, z.detach())  # improve embeddings to match encoder outputs
        loss = q_latent_loss + self.commitment_cost * e_latent_loss

        # Straight-through estimator
        quantized = z + (quantized - z).detach()

        return quantized, loss
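
Why does `quantized = z + (quantized - z).detach()` let gradients reach the encoder? In the forward pass the two `z` terms cancel, so the value is exactly the quantized vector; in the backward pass the detached term is a constant, so the gradient with respect to `z` is copied straight from the output. A standalone toy check with made-up values (not the notebook's tensors):

```python
import torch

# Toy "encoder output" and a toy "nearest codebook entry" (values chosen to be exact in float32)
z = torch.tensor([0.25, -1.5], requires_grad=True)
quantized = torch.tensor([0.5, -1.0])

# Straight-through estimator: forward value equals `quantized`, gradient w.r.t. z is the identity
z_q = z + (quantized - z).detach()

# Any downstream loss; for a sum of squares, d(loss)/d(z_q) = 2 * z_q
loss = (z_q ** 2).sum()
loss.backward()

print(z_q.detach())  # same values as `quantized`
print(z.grad)        # equals 2 * quantized: the gradient passed straight through to z
```

Without this trick, `argmin` would block the gradient entirely and the encoder would only be trained by the commitment term.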

You can test your `VectorQuantizer` layer in the full model below:

In [None]:
# Full VQ-VAE Model

class VQVAE(nn.Module):
    def __init__(self, hidden_dim=128, embedding_dim=64, num_embeddings=512, commitment_cost=0.25):
        super().__init__()
        self.encoder = Encoder(hidden_dim, embedding_dim)
        self.vq = VectorQuantizer(num_embeddings, embedding_dim, commitment_cost)
        self.decoder = Decoder(embedding_dim, hidden_dim)

    def forward(self, x):
        z = self.encoder(x)
        quantized, vq_loss = self.vq(z)
        x_recon = self.decoder(quantized)
        return x_recon, vq_loss

And here's a training loop for you to use:

In [None]:
# Training Loop

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VQVAE().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(1, 6):
    model.train()
    total_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        recon_batch, vq_loss = model(data)
        recon_loss = F.mse_loss(recon_batch, data)
        loss = recon_loss + vq_loss
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        if batch_idx % 100 == 0:
            print(f'Epoch {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)}] Loss: {loss.item():.6f}')

    print(f'====> Epoch: {epoch} Average loss: {total_loss / len(train_loader):.4f}')

In [None]:
# Visualize Reconstructions

model.eval()
data_iter = iter(train_loader)
images, _ = next(data_iter)
images = images.to(device)

reconstructions, _ = model(images)

n = 8
plt.figure(figsize=(20, 4))
for i in range(n):
    # Original
    ax = plt.subplot(2, n, i + 1)
    plt.imshow(images[i][0].cpu().detach(), cmap='gray')
    plt.axis('off')

    # Reconstruction
    ax = plt.subplot(2, n, i + 1 + n)
    plt.imshow(reconstructions[i][0].cpu().detach(), cmap='gray')
    plt.axis('off')
plt.show()

# Exercise 2: Make it more efficient

Try training the model again with the following tweaks:

- Smaller codebook: for example, 128 embeddings instead of 512

- Smaller embedding dimension: e.g. 32 instead of 64

- Fewer channels in the encoder/decoder: e.g. 64 instead of 128

- Train for fewer epochs (3 instead of 5)

You should be able to train a working VQ-VAE for MNIST in **under 20 seconds** on GPU!
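
To see where the savings come from, you can count parameters by hand. A sketch assuming exactly the two-layer `Encoder`/`Decoder` defined above (4×4 kernels, biases enabled, one input channel) plus the codebook; the helper names are ours, not a library API:

```python
def conv_params(in_ch, out_ch, kernel_size):
    """Parameters of a Conv2d or ConvTranspose2d layer with bias: weights plus one bias per output channel."""
    return in_ch * out_ch * kernel_size * kernel_size + out_ch

def vqvae_param_count(hidden_dim, embedding_dim, num_embeddings, kernel_size=4, in_channels=1):
    """Parameter count of the barebones Encoder/Decoder above plus the VQ codebook."""
    encoder = (conv_params(in_channels, hidden_dim, kernel_size)
               + conv_params(hidden_dim, embedding_dim, kernel_size))
    decoder = (conv_params(embedding_dim, hidden_dim, kernel_size)
               + conv_params(hidden_dim, in_channels, kernel_size))
    codebook = num_embeddings * embedding_dim
    return {"encoder": encoder, "decoder": decoder, "codebook": codebook,
            "total": encoder + decoder + codebook}

default = vqvae_param_count(hidden_dim=128, embedding_dim=64, num_embeddings=512)
small = vqvae_param_count(hidden_dim=64, embedding_dim=32, num_embeddings=128)
print(default["total"], small["total"])  # 299329 71841
```

Most parameters sit in the two middle convolutions, which scale with `hidden_dim * embedding_dim`, so halving both shrinks them by about 4×. In the notebook, `sum(p.numel() for p in model.parameters())` on `VQVAE()` and on `VQVAE(hidden_dim=64, embedding_dim=32, num_embeddings=128)` should give the same totals, as long as the architecture is unchanged.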

# Exercise 3: Sample new images

We have seen that we can sample the latent spaces of AEs and VAEs to **generate new data points**. Can this be done for VQ-VAEs as well? Let's explore this possibility here.

First, keep in mind that the encoder downsamples the input image, e.g. to a **latent 7×7 grid**. Each position in this grid corresponds to one code from the discrete codebook.

Write the code to do the following:

- Randomly pick integers in [0, num_embeddings) to create a new latent grid.

- Map those integers to their embedding vectors.

- Reshape this back into (batch_size, embedding_dim, 7, 7).

- Finally, pass this through the decoder to generate images!
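
Steps 2 and 3 hinge on one PyTorch detail: indexing the `(num_embeddings, embedding_dim)` weight matrix with an integer tensor of shape `(N, H, W)`, or calling the `nn.Embedding` module on it, returns a tensor of shape `(N, H, W, embedding_dim)`, which then needs a permute to the channels-first layout the decoder expects. A standalone check with toy sizes matching the defaults above (a fresh, untrained codebook, not the model's):

```python
import torch
import torch.nn as nn

num_embeddings, embedding_dim = 512, 64   # assumed to match the model defaults
num_samples, latent_height, latent_width = 4, 7, 7

codebook = nn.Embedding(num_embeddings, embedding_dim)
codes = torch.randint(0, num_embeddings, (num_samples, latent_height, latent_width))

# Two equivalent lookups: index the weight matrix, or call the module
via_index = codebook.weight[codes]    # (4, 7, 7, 64)
via_module = codebook(codes)          # (4, 7, 7, 64)

# Decoder input layout is channels-first: (batch, embedding_dim, height, width)
latents = via_module.permute(0, 3, 1, 2).contiguous()
print(via_index.shape, latents.shape)  # torch.Size([4, 7, 7, 64]) torch.Size([4, 64, 7, 7])
```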

In [None]:
model.eval()

# make sure these values correspond to the ones you chose earlier
latent_height = 7
latent_width = 7
embedding_dim = 64
num_embeddings = 512

# Randomly sample latent codes, map to embeddings, decode
# ✏️

# Visualize generated images
# ✏️

In [None]:
# @title Solution 👀

model.eval()

# make sure these values correspond to the ones you chose earlier
latent_height = 7
latent_width = 7
embedding_dim = 64
num_embeddings = 512

# Randomly sample latent codes
num_samples = 8  # how many new images to generate
random_codes = torch.randint(0, num_embeddings, (num_samples, latent_height, latent_width), device=device)

# Codebook: (num_embeddings, embedding_dim)
embeddings = model.vq.embeddings.weight

with torch.no_grad():
    # Look up one embedding per grid position: (num_samples, latent_height, latent_width, embedding_dim)
    quantized_latents = embeddings[random_codes]
    # Channels-first layout expected by the decoder: (num_samples, embedding_dim, latent_height, latent_width)
    quantized_latents = quantized_latents.permute(0, 3, 1, 2).contiguous()

    # Decode sampled latents into images
    generated_images = model.decoder(quantized_latents)

# Visualize generated images
n = num_samples
plt.figure(figsize=(20, 4))
for i in range(n):
    ax = plt.subplot(1, n, i + 1)
    plt.imshow(generated_images[i][0].cpu().detach(), cmap='gray')
    plt.axis('off')
plt.suptitle('Generated Samples from Random Discrete Latents', fontsize=16)
plt.show()

If you are getting bullet holes on a wall like in old gangster movies, **don't panic!** You are safe here.

Simply put: the "holes" are valid codes and the "wall" represents untrained codes.

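In VQ-VAEs, not all codes are used equally. Codes that the encoder never selected during training receive no updates from the codebook loss, so they stay at their random initialization and the decoder never learned what to do with them. If you sample uniformly, you will often pick such codes and get garbage outputs. Sampling only from **used codes** focuses generation on the parts of the latent space where the model actually learned something meaningful!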

> **TL;DR** You want to avoid sampling from "dead" areas of the codebook that were never trained properly.
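
A common way to quantify this, logged by many VQ-VAE implementations though not part of this notebook, is to count how often each code is selected and report the codebook *perplexity*, the exponential of the usage entropy: it equals K when all K codes are used equally and 1 when a single code is used. A sketch; `codebook_usage` is a hypothetical helper, and `torch.xlogy` requires a reasonably recent PyTorch:

```python
import torch

def codebook_usage(encoding_indices, num_embeddings):
    """Usage statistics for a tensor of selected code indices (any shape)."""
    counts = torch.bincount(encoding_indices.flatten(), minlength=num_embeddings).float()
    probs = counts / counts.sum()
    # xlogy(0, 0) = 0, so unused codes contribute nothing to the entropy
    entropy = -torch.xlogy(probs, probs).sum()
    return {
        "used": int((counts > 0).sum()),
        "dead": int((counts == 0).sum()),
        "perplexity": float(entropy.exp()),
    }

# Toy example: a codebook of 8 entries, but only codes 0-3 are ever selected (uniformly)
toy_indices = torch.tensor([0, 1, 2, 3, 0, 1, 2, 3])
print(codebook_usage(toy_indices, num_embeddings=8))  # 4 used, 4 dead, perplexity close to 4
```

In the notebook you could pass the `encoding_indices` computed in the cell below (moved to the CPU) together with `model.vq.num_embeddings`.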

Therefore, you could:

1. First, collect the set of code indices that the trained encoder actually selects on real images.

2. Then sample from that set only.

We're doing this for you in the code block below:

In [None]:
# 1. Build a set of used code indices
# (Encode a big batch of real images to find used codes)

model.eval()
used_codes_set = set()

for i, (data, _) in enumerate(train_loader):
    data = data.to(device)
    with torch.no_grad():
        z = model.encoder(data)

        # Nearest-embedding assignment, same computation as in VectorQuantizer.forward
        z_flattened = z.permute(0, 2, 3, 1).contiguous().view(-1, model.vq.embedding_dim)
        distances = (torch.sum(z_flattened**2, dim=1, keepdim=True)
                     + torch.sum(model.vq.embeddings.weight**2, dim=1)
                     - 2 * torch.matmul(z_flattened, model.vq.embeddings.weight.t()))
        encoding_indices = torch.argmin(distances, dim=1)

    used_codes_set.update(encoding_indices.cpu().tolist())

    if i > 10:
        break  # Only scan a few batches to save time

used_codes = torch.tensor(list(used_codes_set))

print(f"Found {len(used_codes)} used codes out of {model.vq.num_embeddings} total codes.")

# 2. Now sample from used_codes only
num_samples = 8
latent_height = 7
latent_width = 7

random_codes = used_codes[torch.randint(0, len(used_codes), (num_samples, latent_height, latent_width))]

# 3. Map sampled codes back to embeddings
embeddings = model.vq.embeddings.weight

quantized_latents = embeddings[random_codes]  # (num_samples, 7, 7, embedding_dim)
quantized_latents = quantized_latents.permute(0, 3, 1, 2).contiguous()

# 4. Decode
with torch.no_grad():
    generated_images = model.decoder(quantized_latents.to(device))

# 5. Visualize
n = num_samples
plt.figure(figsize=(20, 4))
for i in range(n):
    ax = plt.subplot(1, n, i + 1)
    plt.imshow(generated_images[i][0].cpu().detach(), cmap='gray')
    plt.axis('off')
plt.suptitle('Generated Samples (from used latent codes)', fontsize=16)
plt.show()

No more holes, but the samples still look like garbage! **What's happening now?**

We're sampling codes that were actually used by the encoder, but still **assembling them randomly without any spatial coherence**.

This is exactly why full generative models like VQ-VAE-2 or DALL·E use autoregressive models over the codes to sample coherent patterns. We'll study autoregressive models later in the course!

In the original VQ-VAE paper, the authors train a PixelCNN on the discrete latent codes to sample spatially coherent code grids.
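
To make the idea concrete without a PixelCNN, here is a deliberately tiny autoregressive prior. It is our own toy construction, not anything from the paper: each code is sampled conditioned only on the code to its left, with probabilities counted from encoded training grids, and rows are sampled independently. It is far weaker than PixelCNN (no vertical context), but it shows the two-stage recipe: fit a distribution over code grids, then sample codes from it instead of uniformly. The function names are hypothetical.

```python
import torch

def fit_left_neighbor_prior(code_grids, num_embeddings, smoothing=1.0):
    """Toy prior from (N, H, W) integer code grids (on the CPU).

    Returns P(first code of a row), shape (K,), and P(code | code to its left), shape (K, K).
    """
    left = code_grids[:, :, :-1].reshape(-1)
    right = code_grids[:, :, 1:].reshape(-1)
    # Count left -> right transitions, plus additive smoothing
    counts = torch.full((num_embeddings, num_embeddings), float(smoothing))
    counts.index_put_((left, right), torch.ones(left.numel()), accumulate=True)
    transition = counts / counts.sum(dim=1, keepdim=True).clamp(min=1e-12)
    # Distribution of codes in the first column
    first = torch.bincount(code_grids[:, :, 0].reshape(-1), minlength=num_embeddings).float() + smoothing
    return first / first.sum(), transition

def sample_code_grids(start_probs, transition, num_samples, height, width):
    """Sample each row left to right; rows are independent of each other."""
    grids = torch.empty(num_samples, height, width, dtype=torch.long)
    for i in range(height):
        grids[:, i, 0] = torch.multinomial(start_probs, num_samples, replacement=True)
        for j in range(1, width):
            probs = transition[grids[:, i, j - 1]]  # (num_samples, K): one distribution per sample
            grids[:, i, j] = torch.multinomial(probs, 1).squeeze(1)
    return grids

# Toy data: two rows that always read 0, 1, 2 from left to right
toy_grids = torch.tensor([[[0, 1, 2], [0, 1, 2]]])
start, trans = fit_left_neighbor_prior(toy_grids, num_embeddings=3, smoothing=0.0)
print(sample_code_grids(start, trans, num_samples=2, height=2, width=3))  # every row reads 0, 1, 2
```

In the notebook, `code_grids` could be built by reshaping each batch's `encoding_indices` from the used-codes cell with `.view(-1, 7, 7)` and moving them to the CPU; sampled grids are then decoded exactly as in Exercise 3. Keep `smoothing > 0` on real data so every transition row is a valid distribution.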

We come to the following, very important take-home message:

> ***Separate learning the representation (e.g. VQ-VAE) from modeling its distribution (e.g. PixelCNN).***

# Exercise 4: Interpolate between two images in the latent space!

In this final task, you'll run the classic latent-interpolation test from the autoencoder literature, with a VQ-VAE twist:

- Pick two real images from the dataset.

- Encode both images into their discrete latent codes.

- Create a sequence of blended latent codes between them.

- Decode each blend and visualize the transformation!

Note that since the latent space is discrete, you can't simply average the two latents (the average of two embeddings is generally not a codebook entry), but you can crossfade between the two code grids patch by patch!
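
If you draw an independent random mask for each interpolation step, a patch can switch back and forth between the two images along the sequence. One way to avoid this, our own suggestion rather than part of the exercise, is to fix a single random order of grid positions so that each step swaps a superset of the previous step's patches. A sketch on plain index grids; the helper names are hypothetical:

```python
import torch

def nested_crossfade_masks(height, width, num_steps, generator=None):
    """Boolean (height, width) masks, True = take the code from image 2.

    One random order of grid positions is drawn once; step k swaps the first
    round(alpha_k * height * width) positions, so every mask contains the previous one.
    """
    num_cells = height * width
    order = torch.randperm(num_cells, generator=generator)
    masks = []
    for alpha in torch.linspace(0, 1, steps=num_steps):
        num_swapped = int(round(alpha.item() * num_cells))
        flat = torch.zeros(num_cells, dtype=torch.bool)
        flat[order[:num_swapped]] = True
        masks.append(flat.view(height, width))
    return masks

def blend_codes(codes1, codes2, mask):
    """Take codes2 where mask is True and codes1 elsewhere (mask moved to the codes' device)."""
    return torch.where(mask.to(codes1.device), codes2, codes1)

# Toy 7x7 grids standing in for the two images' code grids
codes1 = torch.zeros(7, 7, dtype=torch.long)
codes2 = torch.ones(7, 7, dtype=torch.long)
masks = nested_crossfade_masks(7, 7, num_steps=8)
print([int(blend_codes(codes1, codes2, m).sum()) for m in masks])  # [0, 7, 14, 21, 28, 35, 42, 49]
```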

In [None]:
# Pick two images from the dataset

model.eval()
data_iter = iter(train_loader)
images, _ = next(data_iter)
images = images.to(device)

img1 = images[0:1]
img2 = images[10:11]

# Visualize originals

plt.figure(figsize=(4,2))
plt.subplot(1,2,1)
plt.title('Image 1')
plt.imshow(img1[0][0].cpu().detach(), cmap='gray')
plt.axis('off')

plt.subplot(1,2,2)
plt.title('Image 2')
plt.imshow(img2[0][0].cpu().detach(), cmap='gray')
plt.axis('off')
plt.show()

# Encode both images
# ✏️

# Flatten and find nearest embeddings
# ✏️

# Interpolate between codes
num_steps = 8
interpolated_images = []
# ✏️

# Plot interpolation results
plt.figure(figsize=(20, 4))
for i, img in enumerate(interpolated_images):
    ax = plt.subplot(1, num_steps, i + 1)
    plt.imshow(img[0][0].cpu().detach(), cmap='gray')
    plt.axis('off')
plt.suptitle('Latent Space Interpolation', fontsize=16)
plt.show()

In [None]:
# @title Solution 👀

# Pick two images from the dataset

model.eval()
data_iter = iter(train_loader)
images, _ = next(data_iter)
images = images.to(device)

img1 = images[0:1]
img2 = images[10:11]

# Visualize originals

plt.figure(figsize=(4,2))
plt.subplot(1,2,1)
plt.title('Image 1')
plt.imshow(img1[0][0].cpu().detach(), cmap='gray')
plt.axis('off')

plt.subplot(1,2,2)
plt.title('Image 2')
plt.imshow(img2[0][0].cpu().detach(), cmap='gray')
plt.axis('off')
plt.show()

# Encode both images
with torch.no_grad():
    z1 = model.encoder(img1)
    z2 = model.encoder(img2)

# Flatten and find nearest embeddings
def encode_to_indices(z, model):
    # z: (1, embedding_dim, H, W) encoder output for a single image
    z_flat = z.permute(0,2,3,1).contiguous().view(-1, model.vq.embedding_dim)
    distances = (torch.sum(z_flat**2, dim=1, keepdim=True)
                 + torch.sum(model.vq.embeddings.weight**2, dim=1)
                 - 2 * torch.matmul(z_flat, model.vq.embeddings.weight.t()))
    indices = torch.argmin(distances, dim=1)
    return indices.view(z.shape[2], z.shape[3])  # (H, W) grid of code indices, e.g. 7x7

codes1 = encode_to_indices(z1, model)
codes2 = encode_to_indices(z2, model)

# Interpolate between codes
num_steps = 8
interpolated_images = []

for alpha in torch.linspace(0, 1, steps=num_steps):
    # For each location, randomly pick from img1 or img2 codes based on alpha
    mask = (torch.rand_like(codes1.float()) < alpha).long()
    blended_codes = codes1 * (1 - mask) + codes2 * mask

    # Map codes to embeddings: (H, W) -> (1, embedding_dim, H, W)
    blended_latents = model.vq.embeddings(blended_codes.flatten())
    blended_latents = blended_latents.view(1, *blended_codes.shape, model.vq.embedding_dim)
    blended_latents = blended_latents.permute(0, 3, 1, 2).contiguous()

    # Decode to image
    with torch.no_grad():
        img = model.decoder(blended_latents.to(device))
    interpolated_images.append(img)

# Plot interpolation results
plt.figure(figsize=(20, 4))
for i, img in enumerate(interpolated_images):
    ax = plt.subplot(1, num_steps, i + 1)
    plt.imshow(img[0][0].cpu().detach(), cmap='gray')
    plt.axis('off')
plt.suptitle('Latent Space Interpolation', fontsize=16)
plt.show()

**You've just built the core of a modern generative model.** 🚀

With a few more steps, such as learning a prior over the discrete codes, you could generate entirely new images, much like the models behind DALL·E!

You're closer to state-of-the-art generative AI than you might think. Keep experimenting — the next breakthrough could come from you!