# Autoencoders

This week, you learned about autoencoders, which learn to compress data into a low-dimensional _code_ (also called a _latent variable_ or _encoding_). You also met variational autoencoders (VAEs), which enable sample generation. In this session, we implement an autoencoder; you will implement a variational autoencoder in your assignment.

In [None]:
import os
from typing import Sequence, Union
from tqdm import tqdm
import numpy as np

import matplotlib.pyplot as plt
import ipywidgets as widgets
try:
    from google.colab import output
    output.enable_custom_widget_manager()
except ImportError:
    pass
try:
    %matplotlib widget
except Exception:
    os.system('pip install ipympl -qq')
    %matplotlib widget


import torch
from torch import nn
from torch import optim
from torch.optim import lr_scheduler
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import v2

Device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using {Device} device')

In [None]:
"""
Load and look at the MNIST dataset
"""

train_dataset = MNIST(
    root = 'MNIST',
    train = True,
    download = True,
    transform = v2.Compose([
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
    ]),
)

test_dataset = MNIST(
    root = 'MNIST',
    train = False,
    download = True,
    transform = v2.Compose([
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
    ]),
)

In [None]:
class ImageDataViz:
    """
    An interactive image data visualization tool for Jupyter Notebook.
    Make sure to use the magic command: %matplotlib widget
    """
    def __init__(self, dataset: Dataset):
        self.dataset = dataset
        self.n_samples = len(dataset)
        self.index = widgets.IntSlider(
            value = 0, 
            min = 0, 
            max = self.n_samples-1, 
            step = 1, 
            description = 'Index', 
            continuous_update = True,
            layout = widgets.Layout(width='50%'),
        )

    def update(self, index: int):
        x, y = self.dataset[index]
        image = x.moveaxis(0, -1).squeeze().numpy()
        self.img.set_data(image)
        self.ax.set_title(f'Label: {y}')

    def show(self):
        self.fig, self.ax = plt.subplots()
        x, y = self.dataset[0]
        image = x.moveaxis(0, -1).squeeze().numpy()
        self.img = self.ax.imshow(image, cmap='gray')
        self.ax.axis('off')
        self.ax.set_title(f'Label: {y}')
        widgets.interact(self.update, index=self.index)
        

In [None]:
viz = ImageDataViz(train_dataset)
viz.show()

## Defining an autoencoder

Typically, an autoencoder reduces the dimension of the data gradually to reach the compact encoding (the encoder's job), and then tries to reconstruct the input from the compact encoding (the decoder's job). Let's take a look at an example:

In [None]:
class Encoder(nn.Module):
    
    def __init__(
            self,
            input_size: int,
            hidden_sizes: Sequence[int],
            latent_size: int,
            activation: str = 'ReLU',
            ):
        super().__init__()

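        # Look up the activation class by name, e.g. 'ReLU' -> nn.ReLU
        act = getattr(nn, activation)

        n_layers = len(hidden_sizes) + 1
        # list(...) so that any Sequence (e.g. a tuple) can be concatenated
        sizes = list(hidden_sizes) + [latent_size]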
        self.layers = nn.Sequential(nn.Flatten())

        for i in range(n_layers):
            in_size = input_size if i == 0 else sizes[i-1]
            out_size = sizes[i]
            self.layers.append(nn.Linear(in_size, out_size))
            if i < n_layers - 1:
                self.layers.append(act())

    def forward(self, x):
        return self.layers(x)
    

class Decoder(nn.Module):
    
    def __init__(
            self,
            latent_size: int,
            hidden_sizes: Sequence[int],
            output_size: int,
            activation: str = 'ReLU',
        ):
        super().__init__()

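        # Look up the activation class by name, e.g. 'ReLU' -> nn.ReLU
        act = getattr(nn, activation)

        n_layers = len(hidden_sizes) + 1
        # list(...) so that any Sequence (e.g. a tuple) can be concatenated
        sizes = list(hidden_sizes) + [output_size]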
        self.layers = nn.Sequential()

        for i in range(n_layers):
            in_size = latent_size if i == 0 else sizes[i-1]
            out_size = sizes[i]
            self.layers.append(nn.Linear(in_size, out_size))
            if i < n_layers - 1:
                self.layers.append(act())

        self.layers.append(nn.Unflatten(1, (1, 28, 28)))
        
    def forward(self, x):
        return self.layers(x)
    

class Autoencoder(nn.Module):

    def __init__(
            self,
            input_size: int,
            hidden_sizes: Sequence[int],
            latent_size: int,
            activation: str = 'ReLU',
        ):
        super().__init__()

        self.encoder = Encoder(
            input_size = input_size, 
            hidden_sizes = hidden_sizes, 
            latent_size = latent_size, 
            activation = activation,
            )
        
        self.decoder = Decoder(
            latent_size = latent_size, 
            hidden_sizes = hidden_sizes[::-1], # reverse order. Gradually increasing the size
            output_size = input_size, 
            activation = activation,
            )
        
    def forward(
            self, 
            x: torch.FloatTensor, # Shape: (batch_size, 1, 28, 28)
            ) -> torch.FloatTensor: # Shape: (batch_size, 1, 28, 28)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        return x_hat

In [None]:
@torch.enable_grad()
def train_epoch(
    model: nn.Module,
    train_loader: DataLoader,
    loss_fn: nn.Module,
    optimizer: torch.optim.Optimizer,
    device: str = Device,
    ):

    model.train().to(device)

    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        # Forward pass is just reconstruction
        x_rec = model(x)
        # The loss is calculated between the input and the output (input's reconstruction)
        loss = loss_fn(x_rec, x)
        loss.backward()
        optimizer.step()


@torch.inference_mode()
def eval_epoch(
    model: nn.Module,
    data_loader: DataLoader, # can be train_loader or val_loader or test_loader
    loss_fn: nn.Module,
    device: str = Device,
    ):
    assert loss_fn.reduction in ['mean', 'sum'], 'Invalid reduction method!'

    model.eval().to(device)
    
    n = len(data_loader.dataset)
    Loss = 0.

    for x, y in data_loader:
        b = len(x)
        x, y = x.to(device), y.to(device)
        # Forward pass is just reconstruction
        x_rec = model(x)
        # The loss is calculated between the input and the output (input's reconstruction)
        loss = loss_fn(x_rec, x)
        if loss_fn.reduction == 'mean':
            Loss += loss.item()*b
        elif loss_fn.reduction == 'sum':
            Loss += loss.item()

    return Loss/n


def train(
    # Model and data
    model: nn.Module,
    train_dataset: Dataset,
    test_dataset: Dataset,
    loss_fn: nn.Module = nn.MSELoss(reduction='mean'),
    device: str = Device,

    # train config
    optim_name: str = 'Adam', # from optim
    optim_config: dict = dict(),
    lr_scheduler_name: Union[str, None] = None, # from optim.lr_scheduler
    lr_scheduler_config: dict = dict(),
    n_epochs: int = 10,
    batch_size: int = 32,
    ):
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    optimizer = getattr(optim, optim_name)(model.parameters(), **optim_config)

    if lr_scheduler_name is not None:
        scheduler = getattr(lr_scheduler, lr_scheduler_name)(optimizer, **lr_scheduler_config)

    epoch_pbar = tqdm(
        range(n_epochs),
        desc = 'epochs',
        unit = 'epoch',
        dynamic_ncols = True,
        leave = True,
        )

    for epoch in epoch_pbar:

        train_epoch(
            model = model,
            train_loader = train_loader,
            loss_fn = loss_fn,
            optimizer = optimizer,
            device = device,
            )

        train_loss = eval_epoch(
            model = model,
            data_loader = train_loader,
            loss_fn = loss_fn,
            device = device,
            )

        test_loss = eval_epoch(
            model = model,
            data_loader = test_loader,
            loss_fn = loss_fn,
            device = device,
            )

        if lr_scheduler_name == 'ReduceLROnPlateau':
            scheduler.step(train_loss)
        elif lr_scheduler_name is not None:
            scheduler.step()

        epoch_pbar.set_postfix_str(f'train loss: {train_loss:.4f}, test loss: {test_loss:.4f}')

In [None]:
model_config = dict(
    input_size = 28*28,
    hidden_sizes = [256, 128, 64, 32],
    latent_size = 2,
    activation = 'LeakyReLU',
)

train_config = dict(
    optim_name = 'Adam',
    optim_config = {},
    lr_scheduler_name = 'ReduceLROnPlateau',
    lr_scheduler_config = dict(factor=0.5, patience=5),
    n_epochs = 5,
    batch_size = 64,
)

In [None]:
if __name__ == '__main__':
    model = Autoencoder(**model_config)
    train(
        model = model, 
        train_dataset = train_dataset, 
        test_dataset = test_dataset, 
        loss_fn = nn.MSELoss(), 
        device = Device,
        **train_config,
        )

# Using a trained autoencoder

## Data compression and reconstruction
- We can use the encoder to compress the data, although some quality is lost. To see how much, we compare each image with its reconstruction from the autoencoder.


In [None]:
class AutoEncoderViz:
    """
    Interactive plotting of data and the reconstruction of the autoencoder
    """
    def __init__(
            self, 
            dataset: Dataset, 
            model: Autoencoder,
            ):
        self.dataset = dataset
        self.n_samples = len(dataset)
        self.index = widgets.IntSlider(
            value = 0, 
            min = 0, 
            max = self.n_samples-1, 
            step = 1, 
            description = 'Index', 
            continuous_update = True,
            layout = widgets.Layout(width='50%'),
        )

        self.model = model.eval().cpu()


    def show(self):
        self.fig, (self.ax, self.ax2) = plt.subplots(1, 2)
        x, y = self.dataset[0]
        image = x.moveaxis(0, -1).squeeze().numpy()
        self.img = self.ax.imshow(image, cmap='gray', vmin=0, vmax=1)
        self.ax.axis('off')
        self.ax.set_title(f'Label: {y}')
        self.img_rec = self.ax2.imshow(np.zeros((28, 28), dtype=np.float32), cmap='gray', vmin=0, vmax=1)
        self.ax2.axis('off')
        self.ax2.set_title('Reconstruction')
        widgets.interact(self.update, index=self.index)


    @torch.inference_mode()
    def update(self, index: int):

        x, y = self.dataset[index]

        image = x.moveaxis(0, -1).squeeze().numpy()
        self.img.set_data(image)
        self.ax.set_title(f'Label: {y}')

        x_rec = self.model(x.unsqueeze(0))[0]
        x_rec = x_rec.moveaxis(0, -1).squeeze().numpy()
        self.img_rec.set_data(x_rec)

        z0, z1 = self.model.encoder(x.unsqueeze(0))[0].numpy()
        self.ax2.set_title(f'Latent variable: ({z0:.2f}, {z1:.2f})')

In [None]:
AEviz = AutoEncoderViz(train_dataset, model)
AEviz.show()
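
The visual comparison can be complemented with a number. The sketch below computes one reconstruction error per image, plus the nominal compression ratio for a 2-dimensional code. `stand_in` and `per_sample_mse` are hypothetical names introduced here, and `stand_in` is a small untrained network used only so that the block runs on its own; in the notebook you would pass the trained `model` instead.

```python
import torch
from torch import nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical stand-in for the trained model (assumption: same input and
# latent sizes as model_config, but untrained and with one hidden layer).
stand_in = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 32), nn.LeakyReLU(),
    nn.Linear(32, 2),                 # 2-dimensional code
    nn.Linear(2, 32), nn.LeakyReLU(),
    nn.Linear(32, 28 * 28),
    nn.Unflatten(1, (1, 28, 28)),
)


@torch.inference_mode()
def per_sample_mse(model: nn.Module, x: torch.Tensor) -> torch.Tensor:
    """Return one reconstruction error per image; x has shape (B, 1, 28, 28)."""
    x_rec = model(x)
    # reduction='none' keeps the per-pixel errors so we can average per image
    return F.mse_loss(x_rec, x, reduction='none').flatten(start_dim=1).mean(dim=1)


x = torch.rand(8, 1, 28, 28)          # random images standing in for an MNIST batch
errors = per_sample_mse(stand_in.eval(), x)
compression_ratio = (28 * 28) / 2     # values per image / values per code
print(errors.shape, compression_ratio)
```

Sorting a dataset by this error is one way to find the digits that a 2-dimensional code represents worst.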

## Analyze distribution in latent space

We can plot the distribution of the data in the latent space, which is easier to visualize and analyze qualitatively.

In [None]:
@torch.inference_mode()
def plot_latent_space():
    """
    Encodes the whole dataset into the latent space
    and plots the 2D latent space with different colors for different classes
    """

    model.eval().cpu()

    train_loader = DataLoader(train_dataset, batch_size=1024, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=1024, shuffle=False)

    train_encoded = []
    train_labels = []

    for x, y in train_loader:
        z = model.encoder(x)
        train_encoded.append(z)
        train_labels.append(y)

    test_encoded = []
    test_labels = []

    for x, y in test_loader:
        z = model.encoder(x)
        test_encoded.append(z)
        test_labels.append(y)

    train_encoded = torch.cat(train_encoded)
    train_labels = torch.cat(train_labels)
    test_encoded = torch.cat(test_encoded)
    test_labels = torch.cat(test_labels)

    # 2 subplots for train and test
    # different classes are colored differently

    fig, axs = plt.subplots(1, 2, figsize=(10, 4), sharex=True, sharey=True)

    for i, (encoded, labels) in enumerate([(train_encoded, train_labels), (test_encoded, test_labels)]):
        ax = axs[i]
        for c in range(10):
            idx = labels == c
            ax.scatter(encoded[idx, 0].numpy(), encoded[idx, 1].numpy(), label=str(c), alpha=0.5, s=5)
        ax.set_title(['Train', 'Test'][i])
        ax.grid(linestyle='--')

    axs[1].legend(loc=(1.05, 0))

    plt.show()

In [None]:
plot_latent_space()
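
The scatter plot can also be summarized numerically. As a minimal sketch, the helper below computes the centroid and spread of each class in the latent space. `latent_class_stats` is a hypothetical name, and the four hand-made points stand in for the `encoded` and `labels` tensors that `plot_latent_space` builds internally.

```python
import torch


def latent_class_stats(encoded: torch.Tensor, labels: torch.Tensor, n_classes: int = 10):
    """Return {class: (centroid, std)} for the latent codes of each class present."""
    stats = {}
    for c in range(n_classes):
        z = encoded[labels == c]      # all codes that belong to class c
        if len(z) == 0:
            continue                  # skip classes with no samples
        centroid = z.mean(dim=0)
        std = ((z - centroid) ** 2).mean(dim=0).sqrt()   # population std per latent axis
        stats[c] = (centroid, std)
    return stats


# Tiny hand-made example standing in for the encoded dataset
encoded = torch.tensor([[0., 0.], [2., 2.], [10., -4.], [12., -6.]])
labels = torch.tensor([0, 0, 1, 1])
stats = latent_class_stats(encoded, labels)
for c, (centroid, std) in stats.items():
    print(c, centroid.tolist(), std.tolist())
```

Classes whose centroids are close relative to their spread are the ones that overlap in the plot, and they are likely to be confused in reconstruction.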

## Generation

We can use the decoder to generate new samples from arbitrary latent variables. __However__, not every latent variable will lead to a good generated sample.

In [None]:
class Generator:
    """
    Generate new samples from arbitrary latent variables using the decoder
    """
    def __init__(
            self, 
            model: Autoencoder,
            ):
        self.model = model.eval().cpu()


    def show(self):

        # create widgets for latent space variables
        self.z0 = widgets.FloatSlider(
            value = 0.0,
            min = -20.0,
            max = 20.0,
            step = 0.1,
            description = 'z0',
            continuous_update = True,
            layout = widgets.Layout(width='50%'),
        )

        self.z1 = widgets.FloatSlider(
            value = 0.0,
            min = -20.0,
            max = 20.0,
            step = 0.1,
            description = 'z1',
            continuous_update = True,
            layout = widgets.Layout(width='50%'),
        )

        self.fig, self.ax = plt.subplots()
        self.ax.axis('off')
        self.img = self.ax.imshow(np.zeros((28, 28), dtype=np.float32), cmap='gray', vmin=0, vmax=1)

        widgets.interact(self.update, z0=self.z0, z1=self.z1)


    @torch.inference_mode()
    def update(self, z0: float, z1: float):

        z = torch.tensor([z0, z1], dtype=torch.float32)
        x = self.model.decoder(z[None, ...])[0]
        image = x.moveaxis(0, -1).squeeze().numpy()
        self.img.set_data(image)


In [None]:
generator = Generator(model)
generator.show()
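
Instead of moving the sliders one point at a time, it can help to decode a whole grid of latent points at once and see which regions produce plausible digits. The sketch below does this under stated assumptions: `decode_grid` and `stand_in_decoder` are hypothetical names, the stand-in is an untrained layer used only to make the block self-contained, and the range of -20 to 20 simply mirrors the slider limits above.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-in for model.decoder: maps (B, 2) codes to (B, 1, 28, 28)
stand_in_decoder = nn.Sequential(
    nn.Linear(2, 28 * 28),
    nn.Unflatten(1, (1, 28, 28)),
).eval()


@torch.inference_mode()
def decode_grid(decoder: nn.Module, z_min: float, z_max: float, steps: int) -> torch.Tensor:
    """Decode a steps x steps grid of latent points into one tiled image."""
    axis = torch.linspace(z_min, z_max, steps)
    z = torch.cartesian_prod(axis, axis)            # (steps*steps, 2), first axis varies slowest
    images = decoder(z).squeeze(1)                  # (steps*steps, 28, 28)
    tiles = images.reshape(steps, steps, 28, 28)    # tiles[i, j] decodes (axis[i], axis[j])
    # Interleave the tile and pixel axes so the tiles sit side by side in one 2D array
    return tiles.permute(0, 2, 1, 3).reshape(steps * 28, steps * 28)


canvas = decode_grid(stand_in_decoder, z_min=-20.0, z_max=20.0, steps=5)
print(canvas.shape)
```

With the trained model, `plt.imshow(decode_grid(model.decoder, -20.0, 20.0, 15).numpy(), cmap='gray')` would show the grid in a single figure.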

## So what is the point of Variational Autoencoders (VAEs)?

In the lectures, you might have heard that VAEs are used for sample generation and have a probabilistic nature. But we just saw that we can also generate samples with a vanilla autoencoder. So what is the point?

We cannot always use vanilla (simple) autoencoders for generation, because we do not know the range and distribution of the data in the learned latent space. Here it was simple because the latent space is small (size 2), so we can visualize it and see which ranges of latent values correspond to which data samples.

In general, the latent space may be relatively high-dimensional. For example, high-fidelity image and video data have dimensionality on the order of millions of pixels. In such cases, an encoder might be trained to compress the data into a latent space of size 512. Although this is still far more compact than the original data, it is too high-dimensional to visualize, and we cannot assume that the codes fall within a particular range or follow a distribution of our choosing.

This is where variational autoencoders come into play. They have a mechanism to encourage the latent space to follow a certain distribution. You will learn more about VAEs in this week's assignment.
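
To make the idea of "encouraging a distribution" concrete, here is a minimal sketch of the two ingredients used in the most common VAE formulation. It assumes a standard normal prior and an encoder that outputs a mean `mu` and a log-variance `logvar` per latent dimension. These are assumptions of this sketch, and the function names are hypothetical; the assignment may set things up differently.

```python
import torch


def reparameterize(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    """Sample z ~ N(mu, sigma^2) as mu + sigma * eps, so gradients reach mu and logvar."""
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)       # noise drawn from N(0, I)
    return mu + std * eps


def kl_to_standard_normal(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    """KL( N(mu, sigma^2) || N(0, I) ) per sample, summed over latent dimensions."""
    return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)


mu = torch.zeros(4, 2)
logvar = torch.zeros(4, 2)
z = reparameterize(mu, logvar)                    # shape (4, 2)
kl_at_prior = kl_to_standard_normal(mu, logvar)   # zero: already matches N(0, I)
kl_shifted = kl_to_standard_normal(mu + 3.0, logvar)  # positive: codes far from the origin are penalized
print(kl_at_prior, kl_shifted)
```

Adding this KL term to the reconstruction loss pulls the codes toward a known distribution, so sampling from that distribution and decoding becomes meaningful even when the latent space cannot be visualized.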