# Purpose and Objectives:

If you are coming from the previous notebook, you should now have a solid understanding of the architectures and capabilities of Autoencoders (AEs) and Variational Autoencoders (VAEs). Now we will look at a more recent model: the **VQ-VAE** (Vector-Quantized Variational Autoencoder).

After completing this notebook, you will:
1. Have a thorough understanding of the architecture and capabilities of the VQ-VAE
2. Be able to implement the VQ-VAE to reconstruct CIFAR-10 images
3. Evaluate the models' performance and explore a plethora of visualizations and ablations

## Motivation for VQ-VAE

Recall that we motivated the VAE architecture by saying that we wanted to regularize the latent space of the AE. We modeled the latent space as being sampled from a Gaussian distribution. This is a convenient and useful interpretation, but can we do better? Yes!

![coco_instance_segmentation.jpeg](attachment:4bad17a6-b078-4a21-977d-4f53fd91eacc.jpeg)

<p style="text-align: center;">Fig 1. An image taken from the COCO dataset </p>

[Source](https://manipulation.csail.mit.edu/segmentation.html)

The above image, taken from the COCO dataset, illustrates a very important facet of much of the data that is of interest to the deep learning community. Whether it is language, scene images, or audio, much of the semantically meaningful data we care about has *discrete elements within it.*

Thus, a logical improvement to our VAE bottleneck would be to **sample from a discrete distribution instead of a continuous one**. That is precisely the main motivation for the VQ-VAE architecture.

### Concept Check 3-3.1

In [45]:
import random
import math

import numpy as np

import torch
import torchvision
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torch.distributions import MultivariateNormal, Normal, Independent
from torchvision.utils import make_grid

from tqdm import tqdm

In [3]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [4]:
# Reproducibility: seed every random number generator we use
random.seed(0) #Python
torch.manual_seed(0) #Torch
np.random.seed(0) #NumPy

## Implementing VQ-VAE

### Residual Layers For Image Processing

In [5]:
class ResidualLayerBlock(nn.Module):
    def __init__(self, in_dim, h_dim, res_h_dim):
        super().__init__()
        self.block = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(in_channels=in_dim, out_channels = res_h_dim, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(res_h_dim),
            nn.ReLU(),
            nn.Conv2d(in_channels=res_h_dim, out_channels=h_dim, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(h_dim)
        )

    def forward(self, x):
        out = x + self.block(x)
        return out

In [6]:
class ResidualLayers(nn.Module):
    def __init__(self, in_dim, h_dim, res_h_dim, n_res_layers):
        super().__init__()
        self.n_res_layers = n_res_layers
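        # Build n_res_layers independent blocks (list * n would repeat one block and share its weights)
        self.layers = nn.ModuleList(
            [ResidualLayerBlock(in_dim, h_dim, res_h_dim) for _ in range(n_res_layers)]
        )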

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return F.relu(x)
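
A quick aside on building a stack of layers: in Python, `[module] * n` repeats a reference to one module object `n` times, so every entry of the resulting `nn.ModuleList` shares the same weights, whereas a list comprehension constructs `n` independent modules. The toy check below uses `nn.Linear(4, 4)` as a stand-in for `ResidualLayerBlock` (an assumption made only to keep the example small); `parameters()` yields each distinct parameter once, so shared weights are counted once.

```python
import torch.nn as nn

# Repeating a list holds the SAME Linear object three times
shared = nn.ModuleList([nn.Linear(4, 4)] * 3)
# A list comprehension builds three separate Linear objects
independent = nn.ModuleList([nn.Linear(4, 4) for _ in range(3)])

def n_params(module):
    # parameters() deduplicates shared Parameters, so a reused layer counts once
    return sum(p.numel() for p in module.parameters())

print(shared[0] is shared[2], n_params(shared))                 # True 20
print(independent[0] is independent[2], n_params(independent))  # False 60
```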

## Encoder

In [7]:
class Encoder(nn.Module):
    # Same argument order as Decoder: (in_dim, h_dim, n_res_layers, res_h_dim)
    def __init__(self, in_dim, h_dim, n_res_layers, res_h_dim):
        super().__init__()
        # Two stride-2 convolutions downsample 4x (32x32 -> 8x8),
        # followed by a 3x3 convolution and a stack of residual layers
        self.conv_block = nn.Sequential(
            nn.Conv2d(in_dim, h_dim // 2, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(h_dim // 2),
            nn.ReLU(),
            nn.Conv2d(h_dim // 2, h_dim, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(h_dim),
            nn.ReLU(),
            nn.Conv2d(h_dim, h_dim, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(h_dim),
            ResidualLayers(h_dim, h_dim, res_h_dim, n_res_layers)
        )

    def forward(self, x):
        return self.conv_block(x)
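
With the hyperparameters used later in this notebook (`n_hiddens = 64`, 32×32 CIFAR-10 inputs), the encoder downsamples by a factor of 4, so $z_e(x)$ is an $8 \times 8$ grid; the residual stack and the $1 \times 1$ pre-quantization convolution leave the spatial size unchanged. The standalone check below rebuilds only the three `Conv2d` layers of `Encoder` (BatchNorm and residual layers are omitted, on the assumption that they do not change shapes) and compares them with the output-size formula $\lfloor (H + 2p - k)/s \rfloor + 1$.

```python
import torch
import torch.nn as nn

def conv_out(size, kernel, stride, padding):
    # Output size of nn.Conv2d with dilation=1
    return (size + 2 * padding - kernel) // stride + 1

# (kernel, stride, padding) of the three Conv2d layers in Encoder
layers = [(4, 2, 1), (4, 2, 1), (3, 1, 1)]
size, sizes = 32, [32]
for k, s, p in layers:
    size = conv_out(size, k, s, p)
    sizes.append(size)
print(sizes)  # [32, 16, 8, 8]

# Cross-check with PyTorch, using h_dim = 64 as in the hyperparameters
x = torch.zeros(1, 3, 32, 32)
stack = nn.Sequential(
    nn.Conv2d(3, 32, kernel_size=4, stride=2, padding=1),
    nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),
    nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
)
print(stack(x).shape)  # torch.Size([1, 64, 8, 8])
```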

## [Important] Vector Quantizer:

Recall that the VQ-VAE loss is defined as:

$$\text{Recon}(x, \hat{x}) + \|sg(z_e(x)) - e\|_2^2 + \beta \|z_e(x) - sg(e)\|_2^2$$

where $sg(\cdot)$ is the stop-gradient operator. The second term is the codebook loss, which moves the codebook vectors $e$ toward the encoder outputs; the third term is the commitment loss, which keeps the encoder outputs committed to (close to) their chosen codebook vectors so that they do not grow without bound. Quantization is applied independently at every spatial position of the encoder output: the $D$-dimensional vector across the channels at each position is replaced by its nearest codebook vector. For example, with a batch of 64 CIFAR-10 images and our hyperparameters, the encoder (followed by the pre-quantization convolution) produces $z_e(x)$ of shape $[64, 64, 8, 8]$ (batch, $D$ = `embedding_dim`, height, width). We permute it to $[64, 8, 8, 64]$ so that the channel vector comes last, then flatten it to $[64 \cdot 8 \cdot 8, 64]$. Each of the $4096$ rows is then mapped to its closest codebook vector. Because this nearest-neighbour lookup ($\arg\min$) is not differentiable, no gradient flows from $e$ back to the encoder. To deal with this, we copy the gradient from the quantized output to the encoder output (the straight-through estimator). In other words, we do:

$$e := z_e(x) + sg(e - z_e(x))$$

In the forward pass this expression equals $e$, but its gradient with respect to $z_e(x)$ is the identity, so the gradient of the loss with respect to $e$ is copied over to $z_e(x)$ and backpropagated through the encoder.
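
The flatten, nearest-neighbour, and straight-through steps can be checked on toy tensors. This is a minimal sketch, not the notebook's `VectorQuantizer`: the sizes, the random "codebook", and the use of `torch.cdist` are illustrative assumptions. It also confirms that `torch.cdist` agrees with the expansion $\|z\|^2 + \|e\|^2 - 2\, z \cdot e$ that `VectorQuantizer` uses for the distances.

```python
import torch

torch.manual_seed(0)

# Toy sizes chosen for illustration: batch B, code dimension D, H x W grid, K codebook entries
B, D, H, W, K = 2, 4, 3, 3, 5
z_e = torch.randn(B, D, H, W, requires_grad=True)  # stand-in for the encoder output z_e(x)
codebook = torch.randn(K, D)                        # stand-in for the learned codebook

# BCHW -> BHWC -> (B*H*W, D): one D-dimensional row per spatial position
flat = z_e.permute(0, 2, 3, 1).reshape(-1, D)

# Squared Euclidean distances, two ways: cdist, and the expansion used in VectorQuantizer
distances = torch.cdist(flat, codebook) ** 2
expanded = (flat ** 2).sum(dim=1, keepdim=True) + (codebook ** 2).sum(dim=1) - 2 * flat @ codebook.t()

# Nearest codebook entry per position, reshaped back to BCHW
indices = distances.argmin(dim=1)
e = codebook[indices].view(B, H, W, D).permute(0, 3, 1, 2)

# Straight-through estimator: the value is e, the gradient w.r.t. z_e is the identity
z_q = z_e + (e - z_e).detach()

# Backpropagate an arbitrary upstream gradient through z_q
upstream = torch.randn_like(z_q)
(z_q * upstream).sum().backward()

print(indices.view(B, H, W))                       # code index chosen at each position
print(torch.allclose(z_q, e, atol=1e-5))           # True: forward pass returns codebook vectors
print(torch.allclose(z_e.grad, upstream))          # True: gradient copied unchanged to z_e
```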

### Concept Check 3.2.1-3.2.3

In [8]:
class VectorQuantizer(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, commitment_cost):
        super(VectorQuantizer, self).__init__()
        
        self._embedding_dim = embedding_dim
        self._num_embeddings = num_embeddings
        
        self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
        self._embedding.weight.data.uniform_(-1/self._num_embeddings, 1/self._num_embeddings)
        self._commitment_cost = commitment_cost

    def forward(self, inputs):
        # convert inputs from BCHW -> BHWC
        inputs = inputs.permute(0, 2, 3, 1).contiguous()
        input_shape = inputs.shape
        
        # Flatten input
        flat_input = inputs.view(-1, self._embedding_dim)
        
        # Calculate distances
        distances = (torch.sum(flat_input**2, dim=1, keepdim=True) 
                    + torch.sum(self._embedding.weight**2, dim=1)
                    - 2 * torch.matmul(flat_input, self._embedding.weight.t()))
            
        # Encoding
        encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1) 
        encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device)
        encodings.scatter_(1, encoding_indices, 1)
        
        # Quantize and unflatten
        quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape)
        
        # Loss
        e_latent_loss = F.mse_loss(quantized.detach(), inputs)
        q_latent_loss = F.mse_loss(quantized, inputs.detach())
        loss = q_latent_loss + self._commitment_cost * e_latent_loss
        
        quantized = inputs + (quantized - inputs).detach() #allows for copied gradients
        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
        
        # convert quantized from BHWC -> BCHW
        return loss, quantized.permute(0, 3, 1, 2).contiguous(), perplexity, encodings
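
The `perplexity` returned above is the exponential of the entropy of the average code usage. It is about 1 when every vector maps to the same codebook entry (codebook collapse) and equals the codebook size when all entries are used equally often, which makes it a handy training diagnostic. A small standalone check, using a hypothetical 8-entry codebook and hand-built one-hot assignments:

```python
import torch

def codebook_perplexity(encodings):
    """exp(entropy) of the average one-hot code assignment, computed as in VectorQuantizer.forward."""
    avg_probs = encodings.float().mean(dim=0)
    return torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

K = 8  # hypothetical codebook size for this demo

# Every vector assigned to code 0: the codebook has collapsed to a single entry
collapsed = torch.zeros(100, K)
collapsed[:, 0] = 1

# Each of the K codes used equally often (10 rows per code)
uniform = torch.eye(K).repeat(10, 1)

print(codebook_perplexity(collapsed).item())  # about 1.0
print(codebook_perplexity(uniform).item())    # about 8.0
```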

## The Decoder Class

The decoder takes the quantized latents $z_q$ and reconstructs the image from them.

In [9]:
class Decoder(nn.Module):
    def __init__(self, in_dim, h_dim, n_res_layers, res_h_dim):
        super(Decoder, self).__init__()
        kernel = 4
        stride = 2

        self.inverse_conv_stack = nn.Sequential(
            nn.ConvTranspose2d(
                in_dim, h_dim, kernel_size=kernel-1, stride=stride-1, padding=1),
            ResidualLayers(h_dim, h_dim, res_h_dim, n_res_layers),
            nn.ConvTranspose2d(h_dim, h_dim // 2,
                               kernel_size=kernel, stride=stride, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(h_dim//2, 3, kernel_size=kernel,
                               stride=stride, padding=1)
        )

    def forward(self, x):
        return self.inverse_conv_stack(x)
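
Mirroring the encoder, the decoder upsamples the $8 \times 8$ latent grid back to $32 \times 32$. For `nn.ConvTranspose2d` with dilation 1 and no output padding, the output size is $(H - 1)s - 2p + k$. The sketch below applies this formula to the three transposed convolutions in `Decoder` (the residual layers keep the shape) and cross-checks against PyTorch, assuming `embedding_dim = h_dim = 64` as in the hyperparameters below.

```python
import torch
import torch.nn as nn

def conv_transpose_out(size, kernel, stride, padding):
    # Output size of nn.ConvTranspose2d with dilation=1 and output_padding=0
    return (size - 1) * stride - 2 * padding + kernel

# (kernel, stride, padding) of the three ConvTranspose2d layers in Decoder (kernel=4, stride=2)
layers = [(3, 1, 1), (4, 2, 1), (4, 2, 1)]

size, sizes = 8, [8]  # z_q is an 8x8 grid for 32x32 CIFAR-10 inputs
for k, s, p in layers:
    size = conv_transpose_out(size, k, s, p)
    sizes.append(size)
print(sizes)  # [8, 8, 16, 32]

# Cross-check with PyTorch using the same channel counts as Decoder
x = torch.zeros(1, 64, 8, 8)
stack = nn.Sequential(
    nn.ConvTranspose2d(64, 64, kernel_size=3, stride=1, padding=1),
    nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
    nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1),
)
print(stack(x).shape)  # torch.Size([1, 3, 32, 32])
```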

## The VQ-VAE Module

This is the main module for the VQ-VAE. The model consists of:

$$x \to \text{Encoder} \to \text{Vector Quantizer} \to \text{Decoder} \to \hat{x}$$

Each image gets its own discrete representation. Let $f_{vq}(x)$ denote the grid of codebook indices that the vector quantizer assigns to an image $x$. Once the model is trained, the mapping $x \to f_{vq}(x)$ preserves enough information to reconstruct $x$. However, the VQ-VAE itself gives us no way to sample new index grids: to generate images, we need to learn a prior over the learned $f_{vq}(x)$, for example with an autoregressive model such as PixelCNN. Sampled index grids can then be looked up in the codebook and passed through the decoder.
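
For our settings, $f_{vq}(x)$ is an $8 \times 8$ grid of integers in $[0, 512)$, and turning such a grid back into decoder input is just a codebook lookup. The sketch below uses random indices and a freshly initialised `nn.Embedding` as a stand-in for the trained codebook (both assumptions), and checks that a direct lookup matches the one-hot matrix multiply used inside `VectorQuantizer`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

n_codes, code_dim, grid = 512, 64, 8        # mirrors n_embeddings, embedding_dim, 8x8 latent grid
codebook = nn.Embedding(n_codes, code_dim)  # stand-in for the trained codebook

# f_vq(x) for a batch of 4 images: one integer index per latent position
indices = torch.randint(0, n_codes, (4, grid, grid))

# Route 1: direct lookup, BHW -> BHWC -> BCHW (the layout the decoder expects)
z_q_lookup = codebook(indices).permute(0, 3, 1, 2)

# Route 2: one-hot @ weight, as done inside VectorQuantizer
one_hot = torch.zeros(indices.numel(), n_codes)
one_hot.scatter_(1, indices.view(-1, 1), 1)
z_q_matmul = (one_hot @ codebook.weight).view(4, grid, grid, code_dim).permute(0, 3, 1, 2)

print(z_q_lookup.shape)                         # torch.Size([4, 64, 8, 8])
print(torch.allclose(z_q_lookup, z_q_matmul))   # True
```

In the notebook, the fourth output of `vqvae.vector_quantization(...)` is the one-hot `encodings` matrix with rows in (batch, height, width) order, so `encodings.argmax(dim=1).view(-1, 8, 8)` recovers $f_{vq}(x)$ for a batch.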

In [10]:
class VQVAE(nn.Module):
    def __init__(self, h_dim, res_h_dim, n_res_layers,
                 n_embeddings, embedding_dim, beta, save_img_embedding_map=False):
        super(VQVAE, self).__init__()
        # encode image into continuous latent space
        self.encoder = Encoder(3, h_dim, n_res_layers, res_h_dim)
        self.pre_quantization_conv = nn.Conv2d(
            h_dim, embedding_dim, kernel_size=1, stride=1)
        # pass continuous latent vector through discretization bottleneck
        self.vector_quantization = VectorQuantizer(
            n_embeddings, embedding_dim, beta)
        # decode the discrete latent representation
        self.decoder = Decoder(embedding_dim, h_dim, n_res_layers, res_h_dim)

        if save_img_embedding_map:
            self.img_to_embedding_map = {i: [] for i in range(n_embeddings)}
        else:
            self.img_to_embedding_map = None

    def forward(self, x, verbose=False):

        z_e = self.encoder(x)

        z_e = self.pre_quantization_conv(z_e)
        embedding_loss, z_q, perplexity, _ = self.vector_quantization(z_e)
        x_hat = self.decoder(z_q)

        return embedding_loss, x_hat, perplexity

In [11]:
training_data = torchvision.datasets.CIFAR10(root="data", train=True, download=True,
                                  transform=transforms.Compose([
                                      transforms.ToTensor(),
                                      transforms.Normalize((0.5,0.5,0.5), (1.0,1.0,1.0))
                                  ]))

validation_data = torchvision.datasets.CIFAR10(root="data", train=False, download=True,
                                  transform=transforms.Compose([
                                      transforms.ToTensor(),
                                      transforms.Normalize((0.5,0.5,0.5), (1.0,1.0,1.0))
                                  ]))

Files already downloaded and verified
Files already downloaded and verified


In [12]:
data_variance = np.var(training_data.data / 255.0)
data_variance

0.06328692405746414

## Define our Hyperparameters

In [22]:
batch_size = 64
n_hiddens = 64
n_residual_hiddens = 32
n_residual_layers = 2
embedding_dim = 64
n_embeddings = 512
beta = .25
lr = 3e-4
epochs = 25

In [23]:
train_loader = torch.utils.data.DataLoader(training_data, batch_size = batch_size, shuffle = True)
validation_loader = torch.utils.data.DataLoader(validation_data,batch_size=32,shuffle=True)

In [58]:
vqvae = VQVAE(n_hiddens, n_residual_hiddens, n_residual_layers,
              n_embeddings, embedding_dim, 
              beta).to(device)

In [119]:
vqvae


In [47]:
optimizer = torch.optim.Adam(vqvae.parameters(), lr=lr, amsgrad=False)

In [48]:
train_res_recon_error = []
test_res_recon_error = []
train_res_perplexity = []
for epoch in range(epochs):
    vqvae.train()
    with tqdm(train_loader, unit="batch") as tepoch:
        for data, target in tepoch:
            data = data.to(device)
            optimizer.zero_grad()

            vq_loss, data_recon, perplexity = vqvae(data)
            # reconstruction MSE normalized by the pixel variance of the training set
            recon_error = F.mse_loss(data_recon, data) / data_variance
            loss = recon_error + vq_loss
            loss.backward()

            optimizer.step()
            tepoch.set_postfix(loss=loss.item())
            train_res_recon_error.append(recon_error.item())
            train_res_perplexity.append(perplexity.item())

    # validation: no gradients needed, BatchNorm uses its running statistics
    avg_loss = 0
    vqvae.eval()
    with torch.no_grad():
        for data, target in validation_loader:
            data = data.to(device)

            vq_loss, data_recon, perplexity = vqvae(data)
            recon_error = F.mse_loss(data_recon, data) / data_variance

            # weight by the actual batch size (the last batch can be smaller)
            avg_loss += recon_error.item() * data.size(0) / len(validation_data)
            test_res_recon_error.append(recon_error.item())

    print(f'Epoch {epoch + 1} validation loss: {avg_loss:.4f}')



In [None]:
# Optionally save the trained weights for later reuse:
#torch.save(vqvae.state_dict(), 'VQ-VAE_Model')

In [None]:
plt.plot(train_res_recon_error)
plt.xlabel('Batches')
plt.ylabel('Training Reconstruction Error')
plt.show()

# Codebook usage over training (see the perplexity returned by VectorQuantizer)
plt.plot(train_res_perplexity)
plt.xlabel('Batches')
plt.ylabel('Training Perplexity')
plt.show()

In [None]:
vqvae.eval()
# Separate small loader for visualization, so validation_loader keeps its batch size for later cells
vis_loader = torch.utils.data.DataLoader(validation_data, batch_size=16, shuffle=True)
(valid_originals, _) = next(iter(vis_loader))
valid_originals = valid_originals.to(device)

with torch.no_grad():
    _, valid_recon, _ = vqvae(valid_originals)

In [None]:
def show(img):
    npimg = img.numpy()
    fig = plt.imshow(np.transpose(npimg, (1,2,0)), interpolation='nearest')
    fig.axes.get_xaxis().set_visible(False)
    fig.axes.get_yaxis().set_visible(False)

In [None]:
# Shift from [-0.5, 0.5] back to [0, 1] for display
show((torchvision.utils.make_grid(valid_recon.cpu()) + 0.5).clamp(0, 1))
plt.show()
show((torchvision.utils.make_grid(valid_originals.cpu()) + 0.5).clamp(0, 1))
plt.show()

In [None]:
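def sample_model(model, grid_size=8):
    """Decode a grid of codebook indices drawn uniformly at random (no learned prior)."""
    model.eval()
    with torch.no_grad():
        # one random codebook index per latent position
        encoding_indices = torch.randint(0, n_embeddings, (grid_size * grid_size, 1), device=device)
        encodings = torch.zeros(encoding_indices.shape[0], n_embeddings, device=device)
        encodings.scatter_(1, encoding_indices, 1)
        # look up the code vectors, reshape to BHWC, then permute to BCHW for the decoder
        quantized = torch.matmul(encodings, model.vector_quantization._embedding.weight)
        quantized = quantized.view(1, grid_size, grid_size, embedding_dim)
        quantized = quantized.permute(0, 3, 1, 2).contiguous()
        x_hat = model.decoder(quantized)
    return x_hat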

In [None]:
sample = sample_model(vqvae).squeeze(0).permute(1, 2, 0).cpu().detach()
# shift from [-0.5, 0.5] back to [0, 1] for display
plt.imshow((sample + 0.5).clamp(0, 1))
plt.show()

In [None]:
dl = torch.utils.data.DataLoader(training_data,batch_size=1024,shuffle=True)

In [None]:
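# Collect the quantized latents z_q for the whole training set
# (the second output of vector_quantization is z_q, not an image reconstruction)
vqvae.eval()
quantized_batches = []
with torch.no_grad():
    for data, _ in dl:
        data = data.to(device)
        loss, x_hat, perplexity, encodings = vqvae.vector_quantization(
            vqvae.pre_quantization_conv(vqvae.encoder(data)))
        quantized_batches.append(x_hat.cpu())
out = torch.cat(quantized_batches, dim=0)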

In [None]:
x_hat.shape

### Concept Check 3.2.4

## Ablations & Visualizations:


In [13]:
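def add_noise(tensor, mean=0., std=1., noise_weight=0.5, low=-0.5, high=0.5):
    # randn_like keeps the noise on the same device and dtype as the input
    noise = torch.randn_like(tensor) * std + mean
    # clip to the normalized pixel range: Normalize(0.5, 1.0) maps [0, 1] to [-0.5, 0.5]
    return torch.clip(tensor + noise_weight * noise, low, high)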

### Quick AE Implementation Using Same Encoder-Decoder Architecture as VQ-VAE:

In [37]:
# AE/VAE hyperparams
batch_size = 64
n_hiddens = 64
n_residual_hiddens = 32
n_residual_layers = 2
embedding_dim = 64
code_size=8
#n_embeddings = 512
#beta = .25
#lr = 3e-4

epochs = 25
lr = 1e-3
noise=False
lin_dim=256
regularization_weight = .0001

In [38]:
class AE(nn.Module):
    def __init__(self, h_dim, res_h_dim, n_res_layers, embedding_dim, lin_dim):
        super().__init__()

        self.h_dim = h_dim
        # encode image into a continuous latent feature map (h_dim x code_size x code_size)
        self.encoder = Encoder(3, h_dim, n_res_layers, res_h_dim)

        #FC Projections to and from the lin_dim bottleneck
        self.fc1 = nn.Linear(h_dim*code_size*code_size, lin_dim)
        self.fc2 = nn.Linear(lin_dim, h_dim*code_size*code_size)

        # decode the continuous latent representation
        self.decoder = Decoder(embedding_dim, h_dim, n_res_layers, res_h_dim)

    def encode(self, x):
        x = self.encoder(x)
        # use the actual batch size so partial batches and other loaders also work
        x = x.view(x.size(0), self.h_dim*code_size*code_size)
        return self.fc1(x)

    def decode(self, z):
        z = self.fc2(z)
        z = z.view(z.size(0), self.h_dim, code_size, code_size)
        return self.decoder(z)
    
    def train_step(self, optimizer, x_in, x_star):
        z = self.encode(x_in)
        x_hat = self.decode(z)

        loss = self.loss(x_star, x_hat)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        return loss
    
    def test_step(self, x):
        z = self.encode(x)
        x_hat = self.decode(z)
        
        loss = self.loss(x, x_hat)
        return loss
    @staticmethod
    def loss(x, x_hat):
        #Mean-Squared Error Reconstruction Loss
        criterion = nn.MSELoss()
        return criterion(x, x_hat)

In [39]:
ae = AE(n_hiddens, n_residual_hiddens, n_residual_layers, embedding_dim, lin_dim).to(device)
optimizer = torch.optim.Adam(ae.parameters(), lr=lr)

# train
ae_train_losses = []
ae_test_losses = []
step = 0
report_every = 500
for epoch in range(epochs):
    print(f"Epoch {epoch + 1}")
    #train loss (BatchNorm in training mode):
    ae.train()
    for x, y in tqdm(train_loader):
        x_no_noise = x.to(device)
        if noise:
            x_in = add_noise(x_no_noise, noise_weight=0.5)
        else:
            x_in = x_no_noise
        loss = ae.train_step(optimizer, x_in=x_in, x_star=x_no_noise)
        ae_train_losses.append(loss.item()) #loss every iteration
        step += 1
        if step % report_every == 0:
            print(f"Training loss: {loss.item():.4f}")
    #ae_train_losses.append(loss.item()) #loss after every epoch
    #test loss (BatchNorm in eval mode, no gradient tracking):
    ae.eval()
    with torch.no_grad():
        for b, (X_test, y_test) in enumerate(validation_loader):
            X_test = X_test.to(device)
            loss = ae.test_step(X_test)
            ae_test_losses.append(loss.item()) #loss every iteration
    #ae_test_losses.append(loss.item()) #loss after every epoch


### Quick VAE Implementation Using Same Encoder-Decoder Architecture as VQ-VAE:

In [40]:
class VAE(nn.Module):
    def __init__(self, h_dim, res_h_dim, n_res_layers, embedding_dim):
        super().__init__()
        # separate convolutional encoders for the mean and the log standard deviation
        self.z_mean = Encoder(3, h_dim, n_res_layers, res_h_dim)
        self.z_log_std = Encoder(3, h_dim, n_res_layers, res_h_dim)
        self.decoder = Decoder(embedding_dim, h_dim, n_res_layers, res_h_dim)

        self.h_dim = h_dim
        #FC Projections (separate heads for the mean and the log std)
        self.fc1 = nn.Linear(h_dim*code_size*code_size, lin_dim)
        self.fc_log_std = nn.Linear(h_dim*code_size*code_size, lin_dim)
        self.fc2 = nn.Linear(lin_dim, h_dim*code_size*code_size)
    
    def _encode(self, x):
        # both branches read the input image x
        h_mean = self.z_mean(x).view(x.size(0), self.h_dim*code_size*code_size)
        z_mean = self.fc1(h_mean)

        h_log_std = self.z_log_std(x).view(x.size(0), self.h_dim*code_size*code_size)
        z_log_std = self.fc_log_std(h_log_std)

        # reparameterization trick
        z_std = torch.exp(z_log_std)
        eps = torch.randn_like(z_std)
        z = z_mean + eps * z_std

        # log prob
        # 'd' not sampled on purpose
        # to show reparameterization trick
        d = Independent(Normal(z_mean, z_std), 1)
        log_prob = d.log_prob(z)
        
        return z, log_prob
    
    def encode(self, x):
        z, _ = self._encode(x)
        return z

    def decode(self, z):
        z = self.fc2(z)
        z = z.view(z.size(0), self.h_dim, code_size, code_size)
        return self.decoder(z)
    
    def train_step(self, optimizer, x_in, x_star):
        z, log_prob = self._encode(x_in)
        x_hat = self.decode(z)
        loss = self.loss(x_star, x_hat, z, log_prob)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        return loss
    
    def test_step(self, x):
        z, log_prob = self._encode(x)
        x_hat = self.decode(z)
        
        loss = self.loss(x, x_hat, z, log_prob)
        return loss

    @staticmethod
    def loss(x, x_hat, z, log_prob, kl_weight=regularization_weight):
        criterion = nn.MSELoss()
        reconst_loss = criterion(x, x_hat)

        z_dim = z.shape[-1]
        standard_normal = MultivariateNormal(torch.zeros(z_dim).to(device), 
                                             torch.eye(z_dim).to(device))
        #print(MultivariateNormal.device)
        kld_loss = (log_prob - standard_normal.log_prob(z)).mean()
        
        return reconst_loss + kl_weight * kld_loss
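
The `kld_loss` above is a single-sample Monte Carlo estimate of $\mathrm{KL}(q(z \mid x) \,\|\, p(z))$: it evaluates $\log q(z \mid x) - \log p(z)$ at the reparameterized sample $z$. The standalone sketch below, with made-up values of the mean and standard deviation for a single 4-dimensional example, checks that averaging this estimator over many samples approaches the closed-form KL from `torch.distributions.kl_divergence`. It uses an `Independent(Normal(0, 1))` prior, which is equivalent to the identity-covariance `MultivariateNormal` in the loss.

```python
import torch
from torch.distributions import Independent, Normal, kl_divergence

torch.manual_seed(0)
z_dim = 4
mu = torch.tensor([0.5, -1.0, 0.0, 2.0])   # made-up posterior mean
std = torch.tensor([0.5, 1.5, 1.0, 0.3])   # made-up posterior standard deviation

q = Independent(Normal(mu, std), 1)                                # q(z|x) for one example
p = Independent(Normal(torch.zeros(z_dim), torch.ones(z_dim)), 1)  # standard normal prior

z = q.rsample((200_000,))                        # reparameterized samples, shape (200000, 4)
mc_kl = (q.log_prob(z) - p.log_prob(z)).mean()   # average of the per-sample estimator
exact_kl = kl_divergence(q, p)                   # closed form, summed over the 4 dimensions

print(mc_kl.item(), exact_kl.item())
```

Each individual estimate is noisy (it can even be negative), which is one reason the KL term is weighted and averaged over the batch.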

In [44]:
#vae

In [43]:
vae = VAE(n_hiddens, n_residual_hiddens, n_residual_layers, embedding_dim).to(device)
optimizer = torch.optim.Adam(vae.parameters(), lr=lr)

# train
vae_train_losses = []
vae_test_losses = []
step = 0
report_every = 500
for epoch in range(epochs):
    print(f"Epoch {epoch + 1}")
    #train loss (BatchNorm in training mode):
    vae.train()
    for x, y in tqdm(train_loader):
        x_no_noise = x.to(device)
        if noise:
            x_in = add_noise(x_no_noise, noise_weight=0.5)
        else:
            x_in = x_no_noise
        loss = vae.train_step(optimizer, x_in=x_in, x_star=x_no_noise)
        vae_train_losses.append(loss.item()) #loss every iteration
        step += 1
        if step % report_every == 0:
            print(f"Training loss: {loss.item():.4f}")
    #vae_train_losses.append(loss.item()) #loss after every epoch
    #test loss (BatchNorm in eval mode, no gradient tracking):
    vae.eval()
    with torch.no_grad():
        for b, (X_test, y_test) in enumerate(validation_loader):
            X_test = X_test.to(device)
            loss = vae.test_step(X_test)
            vae_test_losses.append(loss.item()) #loss every iteration
    #vae_test_losses.append(loss.item()) #loss after every epoch


### 1. Loss Visualization (AE vs. VAE vs. VQVAE):

In [None]:
with torch.no_grad():
    # The VQ-VAE lists store MSE / data_variance; rescale to plain MSE so the curves share a scale
    # (the VAE curve still includes its weighted KL term)
    vqvae_train_mse = np.array(train_res_recon_error) * data_variance
    vqvae_test_mse = np.array(test_res_recon_error) * data_variance

    fig, ax = plt.subplots(2, figsize=(8,10))
    ax[0].plot(ae_train_losses, label="AE")
    ax[0].plot(vae_train_losses, label="VAE")
    ax[0].plot(vqvae_train_mse, label="VQ-VAE")
    ax[0].set_title("Train Losses")
    ax[0].set_ylabel("Train Loss")
    ax[0].set_xlabel("Iteration")
    ax[0].legend()
    
    ax[1].plot(ae_test_losses, label="AE")
    ax[1].plot(vae_test_losses, label="VAE")
    ax[1].plot(vqvae_test_mse, label="VQ-VAE")
    ax[1].set_title("Test Losses")
    ax[1].set_ylabel("Test Loss")
    ax[1].set_xlabel("Iteration")
    ax[1].legend()

#Note: if you want per epoch (or avg. epoch) losses, you will have to change the above code somewhat
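
One way to get per-epoch curves without changing the training loops is to average the per-iteration lists afterwards. A minimal sketch; `per_epoch_means` is a hypothetical helper, and it assumes each list holds plain numbers with the same number of entries in every epoch:

```python
import numpy as np

def per_epoch_means(per_iter_losses, iters_per_epoch):
    """Average a flat list of per-iteration losses into one value per complete epoch."""
    losses = np.asarray(per_iter_losses, dtype=float)
    n_epochs = len(losses) // iters_per_epoch  # a trailing partial epoch is dropped
    return losses[: n_epochs * iters_per_epoch].reshape(n_epochs, iters_per_epoch).mean(axis=1)

print(per_epoch_means([1, 2, 3, 4, 5, 6, 7], iters_per_epoch=3))  # [2. 5.]
```

In the notebook, `iters_per_epoch` would be `len(train_loader)` for the training lists and `len(validation_loader)` for the test lists, e.g. `ax[0].plot(per_epoch_means(ae_train_losses, len(train_loader)), label="AE")`.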

### 2. Reconstruction Visualization (AE vs. VAE vs. VQVAE):

In [48]:
#Visualize the samples in bulk:
ae.eval()
vae.eval()
vqvae.eval()
with torch.no_grad():
    #Grab the first batch of images:
    images, labels = next(iter(validation_loader))
    images_dev = images.to(device)
    recon_ae = ae.decode(ae.encode(images_dev)).cpu()
    recon_vae = vae.decode(vae.encode(images_dev)).cpu()
    _, recon_vqvae, _ = vqvae(images_dev)
    recon_vqvae = recon_vqvae.cpu()

    #Print and show the first 10 samples:
    print(f"Labels: {labels[0:10]}")
    # shift from [-0.5, 0.5] back to [0, 1] for display
    im = (make_grid(images[:10], nrow=10) + 0.5).clamp(0, 1)
    ae_im = (make_grid(recon_ae[:10], nrow=10) + 0.5).clamp(0, 1)
    vae_im = (make_grid(recon_vae[:10], nrow=10) + 0.5).clamp(0, 1)
    vqvae_im = (make_grid(recon_vqvae[:10], nrow=10) + 0.5).clamp(0, 1)

    fig, ax = plt.subplots(4, figsize=(45,4.5))
    fig.tight_layout(pad=1.5)
    ax[0].imshow(np.transpose(im.numpy(), (1, 2, 0))) #PyTorch images are CHW, but matplotlib expects HWC
    ax[0].set_title("Original:")

    ax[1].imshow(np.transpose(ae_im.numpy(), (1, 2, 0)))
    ax[1].set_title("AE Reconstruction:")
    
    ax[2].imshow(np.transpose(vae_im.numpy(), (1, 2, 0)))
    ax[2].set_title("VAE Reconstruction:")

    ax[3].imshow(np.transpose(vqvae_im.numpy(), (1, 2, 0)))
    ax[3].set_title("VQ-VAE Reconstruction:")

### 3. Denoising Ablation:

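Now, let us examine the denoising capabilities of these architectures. Go back to the `add_noise` function (copy it over from the AE-VAE notebook if needed), set the `noise` hyperparameter to `True`, and retrain. Each training input is then corrupted with Gaussian noise while the loss is still measured against the clean images $X$, so the model has to learn to remove the noise. The AE and VAE training loops read the `noise` flag; the VQ-VAE training loop does not, so apply `add_noise` to its inputs yourself.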
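
A minimal sketch of what a single denoising update could look like for a model with the VQ-VAE's `(vq_loss, x_hat, perplexity)` outputs. `denoising_step` and `TinyStandIn` are hypothetical names introduced here; the stand-in model only exists so the example runs on its own, and this copy of `add_noise` assumes (as our adjustment) clipping to the $[-0.5, 0.5]$ range produced by the `Normalize` transform.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def add_noise(tensor, mean=0., std=1., noise_weight=0.5, low=-0.5, high=0.5):
    # Gaussian noise on the same device as the input, clipped to the normalized pixel range
    noise = torch.randn_like(tensor) * std + mean
    return torch.clamp(tensor + noise_weight * noise, low, high)

def denoising_step(model, optimizer, x_clean, data_variance, noise_weight=0.5):
    """One update that feeds a noisy input but scores the output against the clean image."""
    x_noisy = add_noise(x_clean, noise_weight=noise_weight)
    vq_loss, x_hat, _ = model(x_noisy)
    recon_error = F.mse_loss(x_hat, x_clean) / data_variance
    loss = recon_error + vq_loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

class TinyStandIn(nn.Module):
    """Hypothetical placeholder returning (vq_loss, x_hat, perplexity) like VQVAE.forward."""
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, kernel_size=3, padding=1)

    def forward(self, x):
        return torch.zeros((), device=x.device), self.conv(x), torch.ones(())

torch.manual_seed(0)
model = TinyStandIn()
opt = torch.optim.Adam(model.parameters(), lr=1e-2)
x = torch.rand(8, 3, 32, 32) - 0.5  # fake batch in [-0.5, 0.5], like the normalized CIFAR-10 data
losses = [denoising_step(model, opt, x, data_variance=0.063) for _ in range(50)]
print(losses[0], losses[-1])
```

With the real model, the same step would replace the body of the VQ-VAE training loop, passing `vqvae`, its optimizer, a clean batch, and the `data_variance` computed earlier.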

In [None]:
with torch.no_grad():
    # The VQ-VAE lists store MSE / data_variance; rescale to plain MSE so the curves share a scale
    # (the VAE curve still includes its weighted KL term)
    vqvae_train_mse = np.array(train_res_recon_error) * data_variance
    vqvae_test_mse = np.array(test_res_recon_error) * data_variance

    fig, ax = plt.subplots(2, figsize=(8,10))
    ax[0].plot(ae_train_losses, label="AE")
    ax[0].plot(vae_train_losses, label="VAE")
    ax[0].plot(vqvae_train_mse, label="VQ-VAE")
    ax[0].set_title("Train Losses")
    ax[0].set_ylabel("Train Loss")
    ax[0].set_xlabel("Iteration")
    ax[0].legend()
    
    ax[1].plot(ae_test_losses, label="AE")
    ax[1].plot(vae_test_losses, label="VAE")
    ax[1].plot(vqvae_test_mse, label="VQ-VAE")
    ax[1].set_title("Test Losses")
    ax[1].set_ylabel("Test Loss")
    ax[1].set_xlabel("Iteration")
    ax[1].legend()

#Note: if you want per epoch (or avg. epoch) losses, you will have to change the above code somewhat

In [None]:
#Visualize denoising on noisy validation images:
ae.eval()
vae.eval()
vqvae.eval()
with torch.no_grad():
    #Grab the first batch of images and corrupt it with the same noise used in training:
    images, labels = next(iter(validation_loader))
    noisy = add_noise(images)
    noisy_dev = noisy.to(device)
    recon_ae = ae.decode(ae.encode(noisy_dev)).cpu()
    recon_vae = vae.decode(vae.encode(noisy_dev)).cpu()
    _, recon_vqvae, _ = vqvae(noisy_dev)
    recon_vqvae = recon_vqvae.cpu()

    #Print and show the first 10 samples:
    print(f"Labels: {labels[0:10]}")
    batches = [images, noisy, recon_ae, recon_vae, recon_vqvae]
    titles = ["Original:", "Noisy input:", "AE Reconstruction:",
              "VAE Reconstruction:", "VQ-VAE Reconstruction:"]

    fig, ax = plt.subplots(5, figsize=(45,5.5))
    fig.tight_layout(pad=1.5)
    for a, batch, title in zip(ax, batches, titles):
        # shift from [-0.5, 0.5] back to [0, 1] for display
        grid = (make_grid(batch[:10], nrow=10) + 0.5).clamp(0, 1)
        a.imshow(np.transpose(grid.numpy(), (1, 2, 0)))  # CHW -> HWC for matplotlib
        a.set_title(title)

### 4. Latent Space Dimension & Size Ablation:

### 5. Sample Generation Ablation:

## References

Any references utilized in this project can be found in the README of our repo