# Getting Started

## General Tips
In each homework problem, you will implement various autoencoder models and run them on two datasets (dataset 1 and dataset 2). The expected outputs for dataset 1 are already provided to help as a sanity check.

Feel free to print whatever output you want (e.g., debugging or training logs); only the submitted PDF with images is graded.

After you complete the assignment, download all of the images written to the results/ folder and upload them to the figure folder of the provided LaTeX template.

Run the cells below to download and load the starter code. They may take a while to run, since the datasets are larger.

In [1]:
# Change to the repository root (two levels above this notebook)
import os
os.chdir('../../')
os.getcwd()

'/Users/gokul/Study/Unsupervised-DL/Git/MAI_DUL_WS24'

In [None]:
# Install the deepul package from the repository root in editable mode
!pip install -e .

In [2]:
from deepul.hw2_helper import *

# Question 3: VQ-VAE [40pts]
In this question, you will train a [VQ-VAE](https://arxiv.org/abs/1711.00937) on SVHN and CIFAR-10. If you are confused about how the VQ-VAE works, you may find [Lilian Weng's blog post](https://lilianweng.github.io/lil-log/2018/08/12/from-autoencoder-to-beta-vae.html#vq-vae-and-vq-vae-2) useful.

You may experiment with different hyperparameters and architecture designs, but the following designs for the VQ-VAE architecture may be useful.

```
conv2d(in_channels, out_channels, kernel_size, stride, padding)
transpose_conv2d(in_channels, out_channels, kernel_size, stride, padding)
linear(in_dim, out_dim)
batch_norm2d(dim)

residual_block(dim)
    batch_norm2d(dim)
    relu()
    conv2d(dim, dim, 3, 1, 1)
    batch_norm2d(dim)
    relu()
    conv2d(dim, dim, 1, 1, 0)

Encoder
    conv2d(3, 256, 4, 2, 1) 16 x 16
    batch_norm2d(256)
    relu()
    conv2d(256, 256, 4, 2, 1) 8 x 8
    residual_block(256)
    residual_block(256)

Decoder
    residual_block(256)
    residual_block(256)
    batch_norm2d(256)
    relu()
    transpose_conv2d(256, 256, 4, 2, 1) 16 x 16
    batch_norm2d(256)
    relu()
    transpose_conv2d(256, 3, 4, 2, 1) 32 x 32
```
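The architecture above can be sketched in PyTorch as follows. This is a minimal sketch, not the homework's `homeworks.hw2.vae.VQVAE`: `ResidualBlock`, `make_encoder`, and `make_decoder` are hypothetical names, and the skip connection `x + f(x)` in the residual block is an assumption implied by its name rather than written out in the spec.

```python
import torch
import torch.nn as nn


class ResidualBlock(nn.Module):
    """BN -> ReLU -> 3x3 conv -> BN -> ReLU -> 1x1 conv, added back onto the input."""

    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.BatchNorm2d(dim),
            nn.ReLU(),
            nn.Conv2d(dim, dim, 3, 1, 1),
            nn.BatchNorm2d(dim),
            nn.ReLU(),
            nn.Conv2d(dim, dim, 1, 1, 0),
        )

    def forward(self, x):
        return x + self.net(x)  # assumed skip connection


def make_encoder(dim=256):
    # 3 x 32 x 32 -> dim x 16 x 16 -> dim x 8 x 8
    return nn.Sequential(
        nn.Conv2d(3, dim, 4, 2, 1),
        nn.BatchNorm2d(dim),
        nn.ReLU(),
        nn.Conv2d(dim, dim, 4, 2, 1),
        ResidualBlock(dim),
        ResidualBlock(dim),
    )


def make_decoder(dim=256):
    # dim x 8 x 8 -> dim x 16 x 16 -> 3 x 32 x 32
    return nn.Sequential(
        ResidualBlock(dim),
        ResidualBlock(dim),
        nn.BatchNorm2d(dim),
        nn.ReLU(),
        nn.ConvTranspose2d(dim, dim, 4, 2, 1),
        nn.BatchNorm2d(dim),
        nn.ReLU(),
        nn.ConvTranspose2d(dim, 3, 4, 2, 1),
    )


enc, dec = make_encoder(), make_decoder()
z_e = enc(torch.randn(2, 3, 32, 32))  # one 256-dim vector per position of the 8 x 8 grid
x_hat = dec(z_e)
print(z_e.shape, x_hat.shape)  # torch.Size([2, 256, 8, 8]) torch.Size([2, 3, 32, 32])
```

Each stride-2 convolution with kernel 4 and padding 1 halves the spatial size, and each matching transposed convolution doubles it, which is where the 16 x 16 and 8 x 8 annotations in the spec come from.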

A few other tips:
*   Use a codebook with $K = 128$ latents each with a $D = 256$ dimensional embedding vector
*   You should initialize each element in your $K\times D$ codebook to be uniformly random in $[-1/K, 1/K]$
*   Use batch size 128 with a learning rate of $10^{-3}$ and an Adam optimizer
*   Center and scale your images to $[-1, 1]$
*   Supposing that $z_e(x)$ is the encoder output and $z_q(x)$ is its quantized version from the codebook, you can implement the straight-through estimator as follows; the result is what gets fed into the decoder:
  * `(z_q(x) - z_e(x)).detach() + z_e(x)` in PyTorch
  * `tf.stop_gradient(z_q(x) - z_e(x)) + z_e(x)` in TensorFlow
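A minimal sketch of a quantization layer that follows these tips, assuming PyTorch: the hypothetical `VectorQuantizer` initializes a $K \times D$ codebook uniformly in $[-1/K, 1/K]$, picks the nearest codebook vector for every spatial position, and applies the straight-through estimator. The two mean-squared-error terms are one common way to write the codebook and commitment losses; how they are weighted is left to the caller.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class VectorQuantizer(nn.Module):
    def __init__(self, K=128, D=256):
        super().__init__()
        self.codebook = nn.Embedding(K, D)
        # Initialize every codebook element uniformly in [-1/K, 1/K]
        nn.init.uniform_(self.codebook.weight, -1.0 / K, 1.0 / K)

    def forward(self, z_e):
        B, D, H, W = z_e.shape
        # (B, D, H, W) -> (B*H*W, D): one D-dimensional vector per spatial position
        flat = z_e.permute(0, 2, 3, 1).reshape(-1, D)
        # Index of the nearest codebook vector (Euclidean distance) for each position
        indices = torch.cdist(flat, self.codebook.weight).argmin(dim=1)
        z_q = self.codebook(indices).view(B, H, W, D).permute(0, 3, 1, 2)
        # Codebook loss pulls embeddings toward encoder outputs; commitment loss does the reverse
        codebook_loss = F.mse_loss(z_q, z_e.detach())
        commitment_loss = F.mse_loss(z_e, z_q.detach())
        # Straight-through estimator: forward pass uses z_q, gradients are copied to z_e
        z_q_st = (z_q - z_e).detach() + z_e
        return z_q_st, commitment_loss, codebook_loss, indices.view(B, H, W)


vq = VectorQuantizer()
z_e = torch.randn(4, 256, 8, 8, requires_grad=True)  # stand-in for an encoder output
z_q, commit, cb, idx = vq(z_e)
z_q.sum().backward()
print(idx.shape, z_e.grad.mean().item())
```

In the usage above, `z_e.grad` is all ones: the straight-through output equals `z_e` plus an offset that carries no gradient, so the decoder's gradient passes to the encoder unchanged.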

In addition to training the VQ-VAE, you will need to train a Transformer prior on the categorical latents in order to sample. Feel free to reuse your implementation from HW1! Flatten the $H \times W$ grid of VQ-VAE tokens into a sequence of length $HW$ and prepend a start token.
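One way to set up the prior, sketched under stated assumptions: `TinyPrior` and `make_prior_batch` are hypothetical, the start token is index $K = 128$ (one past the last codebook index), and the grid is flattened in row-major order. Each input sequence is the start token followed by the first $HW - 1$ codes, so position $t$ predicts code $t$. The positional embedding is sliced to the current length so that the same model also accepts the shorter prefixes fed to it during sampling.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyPrior(nn.Module):
    """Causal Transformer over flattened VQ-VAE codes; index K is the start token."""

    def __init__(self, K=128, seq_len=64, d_model=128, n_heads=4, n_layers=2):
        super().__init__()
        self.tok_emb = nn.Embedding(K + 1, d_model)  # K codes + 1 start token
        self.pos_emb = nn.Parameter(torch.randn(1, seq_len, d_model) * 0.02)
        layer = nn.TransformerEncoderLayer(d_model, n_heads, 4 * d_model, batch_first=True)
        self.blocks = nn.TransformerEncoder(layer, n_layers)
        self.head = nn.Linear(d_model, K)  # logits over real codebook indices only

    def forward(self, tokens):
        T = tokens.shape[1]
        # Slice positions to the current length so short prefixes work when sampling
        h = self.tok_emb(tokens) + self.pos_emb[:, :T]
        # Causal mask: position t may attend only to positions <= t
        mask = torch.triu(torch.full((T, T), float("-inf"), device=tokens.device), diagonal=1)
        return self.head(self.blocks(h, mask=mask))  # (B, T, K)


def make_prior_batch(indices, K=128):
    """(B, H, W) grid of codes -> (inputs, targets), each of shape (B, H*W)."""
    targets = indices.flatten(1)  # row-major flattening of the H x W grid
    start = torch.full_like(targets[:, :1], K)
    inputs = torch.cat([start, targets[:, :-1]], dim=1)  # shift right by one position
    return inputs, targets


codes = torch.randint(0, 128, (4, 8, 8))  # stand-in for the VQ-VAE's encoded indices
inputs, targets = make_prior_batch(codes)
prior = TinyPrior()
logits = prior(inputs)  # (4, 64, 128)
loss = F.cross_entropy(logits.reshape(-1, 128), targets.reshape(-1))

# The first sampling step feeds only the start token: a length-1 prefix
first = prior(torch.full((2, 1), 128))  # (2, 1, 128)
```

Keeping the output head at $K$ classes means the start token can never be sampled; if a prior also predicts the start index, slicing the last-step logits to the first $K$ entries before sampling has the same effect.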

**You will provide the following deliverables**


1.   Over the course of training, record the average loss of the training data (per minibatch) and test data (for your entire test set) **for both your VQ-VAE and Transformer prior**. Code is provided that automatically plots the training curves.
2. Report the final test set performance of your final models
3. 100 samples from your trained VQ-VAE and Transformer prior
4. 50 real-image / reconstruction pairs (for some $x$, encode and then decode)
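For the test curve, the average loss over the entire test set is most naturally a per-example average. Dividing a sum of per-batch means by the number of batches gives a smaller final batch too much weight. A minimal sketch of the example-weighted version, assuming the loss function returns a batch mean (`dataset_mean_loss` is a hypothetical helper):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def dataset_mean_loss(loss_fn, loader):
    """Per-example average over the whole dataset, given a loss_fn that returns a batch mean."""
    total, count = 0.0, 0
    with torch.no_grad():
        for (x,) in loader:
            # Multiply the batch mean by the batch size so every example gets equal weight
            total += loss_fn(x).item() * x.shape[0]
            count += x.shape[0]
    return total / count


data = TensorDataset(torch.arange(10, dtype=torch.float32))
loader = DataLoader(data, batch_size=4)  # batches of 4, 4 and 2 examples
exact = dataset_mean_loss(lambda x: x.mean(), loader)
naive = sum(x.mean().item() for (x,) in loader) / len(loader)
print(exact, naive)
```

Here the example-weighted result is 4.5, the true mean of 0 through 9, whereas averaging the three batch means gives about 5.17.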

## Solution
Fill out the function below and return the necessary arguments. Feel free to create more cells if needed.

In [3]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
from homeworks.hw2.vae import VQVAE, TransformerPrior

In [4]:
def evaluate_vqvae(model, loader, device):
    model.eval()
    total_loss = 0

    with torch.no_grad():
        for x, _ in loader:
            x = x * 2 - 1
            x = x.to(device)
            x_recon, commitment_loss, codebook_loss, _ = model(x)
            loss = F.mse_loss(x_recon, x) + 0.25 * commitment_loss + 0.25 * codebook_loss
            total_loss += loss.item()
    
    return total_loss / len(loader)

In [5]:
def train_vqvae(model, train_loader, test_loader, device, num_epochs=50, learning_rate=1e-3):
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    train_losses = []
    test_losses = [evaluate_vqvae(model, test_loader, device)]  # Initial test loss
    
    for epoch in range(num_epochs):
        model.train()
        pbar = tqdm(train_loader, unit='batch')

        for batch_idx, (x, _) in enumerate(pbar):
            x = x * 2 - 1  # Scale to [-1, 1]
            x = x.to(device)
            
            optimizer.zero_grad()
            x_recon, commitment_loss, codebook_loss, _ = model(x)
            
            recon_loss = F.mse_loss(x_recon, x)
            loss = recon_loss + 0.25 * commitment_loss + 0.25 * codebook_loss
            
            loss.backward()
            optimizer.step()
            
            train_losses.append(loss.item())
            pbar.set_description(desc=f"batch_loss={loss.item():.4f}")
        
        # Evaluate on test set
        test_loss = evaluate_vqvae(model, test_loader, device)
        test_losses.append(test_loss)
        
    return np.array(train_losses), np.array(test_losses)

In [6]:
def prior_batch(vqvae, x, device, start_token=128):
    """Encode images and build shifted (inputs, targets) token sequences for the prior.

    Assumes vqvae.encode takes images in [-1, 1] and returns a (B, 8, 8) grid of code
    indices, and that the transformer maps a (B, T) token sequence to (B, T, vocab)
    next-token logits (the same interface sample_images relies on).
    """
    x = x.to(device) * 2 - 1  # the VQ-VAE was trained on images scaled to [-1, 1]
    targets = vqvae.encode(x).view(x.shape[0], -1)  # (B, 64) flattened codes
    start = torch.full_like(targets[:, :1], start_token)  # start token 128 = K
    inputs = torch.cat([start, targets[:, :-1]], dim=1)  # shift right by one position
    return inputs, targets


def evaluate_transformer(vqvae, transformer, loader, device):
    vqvae.eval()
    transformer.eval()
    criterion = nn.CrossEntropyLoss()
    total_loss = 0

    with torch.no_grad():
        for x, _ in loader:
            inputs, targets = prior_batch(vqvae, x, device)
            logits = transformer(inputs)  # (B, 64, vocab)
            loss = criterion(logits.reshape(-1, logits.shape[-1]), targets.reshape(-1))
            total_loss += loss.item()

    return total_loss / len(loader)

In [7]:
def train_transformer(vqvae, transformer, train_loader, test_loader, device, num_epochs=50, learning_rate=1e-3):
    optimizer = optim.Adam(transformer.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()

    train_losses = []
    test_losses = [evaluate_transformer(vqvae, transformer, test_loader, device)]

    vqvae.eval()  # the VQ-VAE stays frozen while the prior trains
    for epoch in range(num_epochs):
        transformer.train()
        pbar = tqdm(train_loader, unit='batch')

        for x, _ in pbar:
            # Get VQ-VAE codes; no gradients flow into the frozen VQ-VAE
            with torch.no_grad():
                inputs, targets = prior_batch(vqvae, x, device)

            # Train transformer: position t predicts code t from the start token and codes < t
            optimizer.zero_grad()
            logits = transformer(inputs)
            loss = criterion(logits.reshape(-1, logits.shape[-1]), targets.reshape(-1))

            loss.backward()
            optimizer.step()

            train_losses.append(loss.item())
            pbar.set_description(desc=f"batch_loss={loss.item():.4f}")

        # Evaluate
        test_loss = evaluate_transformer(vqvae, transformer, test_loader, device)
        test_losses.append(test_loss)

    return np.array(train_losses), np.array(test_losses)

In [8]:
@torch.no_grad()
def sample_images(vqvae, transformer, device, n_samples=100, start_token=128, h=8, w=8):
    vqvae.eval()
    transformer.eval()

    # Every sequence begins with the start token (index K = 128, outside the codebook)
    seq = torch.full((n_samples, 1), start_token, dtype=torch.long, device=device)

    # Autoregressively sample the h * w latent codes, one position at a time for the whole batch
    for _ in range(h * w):
        logits = transformer(seq)  # (n_samples, current_length, vocab)
        # Keep only real codebook entries so the start token is never sampled
        probs = F.softmax(logits[:, -1, :start_token], dim=-1)
        next_token = torch.multinomial(probs, 1)  # (n_samples, 1)
        seq = torch.cat([seq, next_token], dim=1)

    # Remove start token and restore the spatial latent grid
    indices = seq[:, 1:].reshape(n_samples, h, w)

    # Decode and map from [-1, 1] back to uint8 pixels in [0, 255]
    samples = vqvae.decode(indices)
    samples = ((samples + 1) / 2).clamp(0, 1)
    return (samples * 255).cpu().numpy().transpose(0, 2, 3, 1).astype(np.uint8)

In [9]:
@torch.no_grad()
def get_reconstructions(vqvae, test_loader, device, n_samples=50):
    vqvae.eval()
    pairs = []
    
    for x, _ in test_loader:
        if len(pairs) >= n_samples:
            break
            
        x = x.to(device)
        x_recon = vqvae(x * 2 - 1)[0]
        x_recon = (x_recon + 1) / 2
        x_recon = x_recon.clamp(0, 1)
        
        for i in range(x.shape[0]):
            if len(pairs) >= n_samples:
                break
                
            orig = (x[i] * 255).cpu().numpy().transpose(1, 2, 0).astype(np.uint8)
            recon = (x_recon[i] * 255).cpu().numpy().transpose(1, 2, 0).astype(np.uint8)
            pairs.extend([orig, recon])
    
    return np.stack(pairs)

In [10]:
def q3(train_data, test_data, dset_id):
    """
    train_data: torch dataset with (n_train, 3, 32, 32) color images as tensors with 256 values rescaled to [0, 1]
    test_data: torch dataset with (n_test, 3, 32, 32) color images as tensors with 256 values rescaled to [0, 1]
    dset_id: An identifying number of which dataset is given (1 or 2). Most likely
               used to set different hyperparameters for different datasets

    Returns
    - a (# of training iterations,) numpy array of VQ-VAE train losses evaluated every minibatch
    - a (# of epochs + 1,) numpy array of VQ-VAE test losses evaluated once at initialization and after each epoch
    - a (# of training iterations,) numpy array of Transformer prior train losses evaluated every minibatch
    - a (# of epochs + 1,) numpy array of Transformer prior test losses evaluated once at initialization and after each epoch
    - a (100, 32, 32, 3) numpy array of 100 samples with values in {0, ... 255}
    - a (100, 32, 32, 3) numpy array of 50 real image / reconstruction pairs
      FROM THE TEST SET with values in [0, 255]
    """
    torch.manual_seed(1)

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() 
                         else "mps" if torch.backends.mps.is_available() 
                         else "cpu")

    print(f"Using device: {device}")

    # DataLoader settings based on device
    kwargs = {'num_workers': 8, 'pin_memory': False} if torch.cuda.is_available() else \
             {'num_workers': 0} if torch.backends.mps.is_available() else \
             {}

    # Hyperparameters
    batch_size = 128
    learning_rate = 1e-3
    num_epochs = 1
    
    # Create data loaders
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, **kwargs)
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, **kwargs)
    
    # Initialize models
    vqvae = VQVAE().to(device)
    transformer = TransformerPrior().to(device)
    
    # Train VQ-VAE
    vqvae_train_losses, vqvae_test_losses = train_vqvae(
        vqvae, train_loader, test_loader, device, num_epochs, learning_rate
    )
    
    # Train Transformer
    transformer_train_losses, transformer_test_losses = train_transformer(
        vqvae, transformer, train_loader, test_loader, device, num_epochs, learning_rate
    )
    
    # Generate samples
    samples = sample_images(vqvae, transformer, device, n_samples=100)
    
    # Get reconstructions
    reconstructions = get_reconstructions(vqvae, test_loader, device, n_samples=50)
    
    return (
        vqvae_train_losses,
        vqvae_test_losses,
        transformer_train_losses,
        transformer_test_losses,
        samples,
        reconstructions
    )

## Results
Once you've finished `q3`, execute the cells below to visualize and save your results.

In [11]:
q3_save_results(1, q3)

Using downloaded and verified file: homeworks/hw2/data/train_32x32.mat
Using downloaded and verified file: homeworks/hw2/data/test_32x32.mat
Using device: mps




  0%|          | 0/573 [00:00<?, ?batch/s]

RuntimeError: The size of tensor a (2) must match the size of tensor b (65) at non-singleton dimension 1

In [None]:
q3_save_results(2, q3)