<a href="https://colab.research.google.com/github/faisalalz171/COMP8230/blob/main/AML_A1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## 1. Importing Libraries

In [1]:
# Imports
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.datasets import load_digits


## 2. Defining Constants

In [2]:
# Mathematical constants.
PI = torch.tensor(np.pi, dtype=torch.float64)  # pi as a 0-dim float64 tensor
EPS = 1.e-5  # small margin used when clamping probabilities before taking logs


## 3. Uploading the Dataset
A `Digits` dataset class is created to load and organize the scikit-learn Digits dataset. The data is split into training, validation, and testing sets by index range: samples 0–999 for training, 1000–1349 for validation, and 1350 onward for testing.

The `__len__` method returns the number of samples in the selected split, and the `__getitem__` method returns one sample at a time (after applying the optional transform), which is how the data is accessed during training.

In [3]:
# Uploading digits Dataset
class Digits(Dataset):
    """Dataset over scikit-learn's Digits data, split by index range.

    Args:
        mode: ``'train'`` selects samples ``[0, 1000)``, ``'val'`` selects
            ``[1000, 1350)``; any other value selects everything from
            index 1350 onward.
        transforms: optional callable applied to each sample on access.
    """

    def __init__(self, mode='train', transforms=None):
        if mode == 'train':
            window = slice(None, 1000)
        elif mode == 'val':
            window = slice(1000, 1350)
        else:
            # Every other mode value falls through to the tail of the data.
            window = slice(1350, None)
        self.data = load_digits().data[window].astype(np.float32)
        self.transforms = transforms

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return sample ``idx``, passed through ``transforms`` when set."""
        item = self.data[idx]
        return self.transforms(item) if self.transforms else item


## 4. Probability Functions

Helper functions are created to calculate the log-probabilities needed by the VAE model. They measure how well a predicted distribution explains the true data. They cover the categorical, Bernoulli, and normal distributions, and are used by the encoder, decoder, and prior during training.

In [4]:
# Probability Distribution Functions
def log_categorical(x, p, num_classes=256, reduction=None, dim=None):
    """Categorical log-likelihood terms of integer labels ``x`` under ``p``.

    ``x`` is cast to long and one-hot encoded over ``num_classes``; the
    encoding then masks ``log(p)``, where ``p`` is first clamped to
    ``[EPS, 1 - EPS]`` so the logarithm stays finite.

    Args:
        x: tensor of class indices.
        p: tensor of class probabilities, multiplied element-wise with the
            one-hot encoding of ``x``.
        num_classes: size of the one-hot encoding.
        reduction: ``'avg'`` for ``torch.mean``, ``'sum'`` for ``torch.sum``;
            anything else returns the unreduced tensor.
        dim: dimension argument forwarded to the reduction.

    Returns:
        The masked log-probabilities, reduced as requested.
    """
    mask = F.one_hot(x.long(), num_classes=num_classes)
    safe_p = torch.clamp(p, EPS, 1. - EPS)
    masked_log_p = mask * torch.log(safe_p)
    if reduction == 'sum':
        return torch.sum(masked_log_p, dim)
    if reduction == 'avg':
        return torch.mean(masked_log_p, dim)
    return masked_log_p

def log_bernoulli(x, p, reduction=None, dim=None):
    """Bernoulli log-likelihood ``x*log(p) + (1-x)*log(1-p)``, element-wise.

    ``p`` is clamped to ``[EPS, 1 - EPS]`` before either logarithm is taken
    so that neither term can become infinite.

    Args:
        x: tensor of targets.
        p: tensor of success probabilities.
        reduction: ``'avg'`` for ``torch.mean``, ``'sum'`` for ``torch.sum``;
            anything else returns the unreduced tensor.
        dim: dimension argument forwarded to the reduction.

    Returns:
        The log-likelihood terms, reduced as requested.
    """
    safe_p = torch.clamp(p, EPS, 1. - EPS)
    on_term = x * torch.log(safe_p)
    off_term = (1. - x) * torch.log(1. - safe_p)
    log_lik = on_term + off_term
    if reduction == 'sum':
        return torch.sum(log_lik, dim)
    if reduction == 'avg':
        return torch.mean(log_lik, dim)
    return log_lik

def log_normal_diag(x, mu, log_var, reduction=None, dim=None):
    """Element-wise log-density of a diagonal Gaussian ``N(mu, exp(log_var))``.

    Each element receives exactly one ``-0.5 * log(2*pi)`` normalizer, so
    summing the result over ``D`` dimensions yields the log-density of the
    ``D``-dimensional diagonal Gaussian.

    Args:
        x: tensor of evaluation points.
        mu: tensor of means, broadcastable against ``x``.
        log_var: tensor of log-variances, broadcastable against ``x``.
        reduction: ``'avg'`` for ``torch.mean``, ``'sum'`` for ``torch.sum``;
            anything else returns the unreduced tensor.
        dim: dimension argument forwarded to the reduction.

    Returns:
        The log-density terms, reduced as requested.
    """
    # Plain Python float so the dtype of the result follows the inputs.
    log_two_pi = float(np.log(2. * np.pi))
    log_p = -0.5 * log_two_pi - 0.5 * log_var - 0.5 * torch.exp(-log_var) * (x - mu)**2.
    if reduction == 'avg':
        return torch.mean(log_p, dim)
    elif reduction == 'sum':
        return torch.sum(log_p, dim)
    else:
        return log_p

def log_standard_normal(x, reduction=None, dim=None):
    """Element-wise log-density of the standard normal ``N(0, 1)``.

    Uses the same per-element ``-0.5 * log(2*pi)`` normalizer as
    ``log_normal_diag``, so the two stay directly comparable (their
    constants cancel when one is subtracted from the other).

    Args:
        x: tensor of evaluation points.
        reduction: ``'avg'`` for ``torch.mean``, ``'sum'`` for ``torch.sum``;
            anything else returns the unreduced tensor.
        dim: dimension argument forwarded to the reduction.

    Returns:
        The log-density terms, reduced as requested.
    """
    # Plain Python float so the dtype of the result follows the input.
    log_two_pi = float(np.log(2. * np.pi))
    log_p = -0.5 * log_two_pi - 0.5 * x**2.
    if reduction == 'avg':
        return torch.mean(log_p, dim)
    elif reduction == 'sum':
        return torch.sum(log_p, dim)
    else:
        return log_p


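## 5. Encoder

The encoder maps the input to the mean and log-variance of a Gaussian over a small hidden (latent) code. It then samples the code by adding scaled random noise to the mean (the reparameterization trick), which keeps the sampling step differentiable during training.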

In [5]:
# Encoder
class Encoder(nn.Module):
    """Gaussian encoder wrapped around a user-supplied network.

    The wrapped network's output is split into two equal chunks along
    dim 1: the first is used as the mean, the second as the log-variance.
    """

    def __init__(self, encoder_net):
        super().__init__()
        self.encoder = encoder_net

    @staticmethod
    def reparameterization(mu, log_var):
        """Return ``mu + std * eps`` with ``std = exp(0.5 * log_var)`` and
        ``eps`` drawn from a standard normal of the same shape as ``std``."""
        sigma = torch.exp(0.5 * log_var)
        noise = torch.randn_like(sigma)
        return mu + sigma * noise

    def encode(self, x):
        """Run the network on ``x`` and return ``(mu, log_var)``."""
        mean, log_variance = torch.chunk(self.encoder(x), 2, dim=1)
        return mean, log_variance

    def sample(self, x=None, mu_e=None, log_var_e=None):
        """Draw a latent sample.

        When neither ``mu_e`` nor ``log_var_e`` is given they are computed
        from ``x``; otherwise both must be supplied.

        Raises:
            ValueError: if exactly one of ``mu_e`` / ``log_var_e`` is None.
        """
        no_mu = mu_e is None
        no_log_var = log_var_e is None
        if no_mu and no_log_var:
            mu_e, log_var_e = self.encode(x)
        elif no_mu or no_log_var:
            raise ValueError('mu and log-var cannot be None!')
        return self.reparameterization(mu_e, log_var_e)

    def log_prob(self, x=None, mu_e=None, log_var_e=None, z=None):
        """Return ``log_normal_diag(z, mu_e, log_var_e)`` without reduction.

        When ``x`` is given, ``mu_e``, ``log_var_e`` and ``z`` are recomputed
        from it (any values passed for them are ignored); otherwise all
        three must be supplied.

        Raises:
            ValueError: if ``x`` is None and any of the other three is None.
        """
        if x is None:
            if any(arg is None for arg in (mu_e, log_var_e, z)):
                raise ValueError('mu, log-var and z cannot be None!')
        else:
            mu_e, log_var_e = self.encode(x)
            z = self.sample(mu_e=mu_e, log_var_e=log_var_e)
        return log_normal_diag(z, mu_e, log_var_e)


## 6. Decoder

The decoder takes the hidden code and tries to rebuild the original image. It predicts the pixel values based on the hidden information.

In [6]:
# Decoder
class Decoder(nn.Module):
    """Decoder wrapping a user-supplied network with a categorical or
    Bernoulli output distribution.

    Args:
        decoder_net: module applied to the latent code.
        distribution: ``'categorical'`` or ``'bernoulli'``; checked when the
            decoder is used, not at construction.
        num_vals: number of categories per output dimension (used only in
            the categorical case).
    """

    def __init__(self, decoder_net, distribution='categorical', num_vals=None):
        super().__init__()
        self.decoder = decoder_net
        self.distribution = distribution
        self.num_vals = num_vals

    def decode(self, z):
        """Return a one-element list holding the output probabilities.

        Categorical: the network output is viewed as
        ``(batch, shape[1] // num_vals, num_vals)`` and soft-maxed over the
        last axis. Bernoulli: a sigmoid is applied to the network output.

        Raises:
            ValueError: for any other ``distribution``.
        """
        h_d = self.decoder(z)
        if self.distribution == 'bernoulli':
            return [torch.sigmoid(h_d)]
        if self.distribution == 'categorical':
            batch = h_d.shape[0]
            n_dims = h_d.shape[1] // self.num_vals
            logits = h_d.view(batch, n_dims, self.num_vals)
            return [torch.softmax(logits, 2)]
        raise ValueError('Distribution must be categorical or bernoulli')

    def sample(self, z):
        """Sample an output from the distribution decoded from ``z``.

        Raises:
            ValueError: for an unsupported ``distribution``.
        """
        decoded = self.decode(z)
        if self.distribution == 'categorical':
            probs = decoded[0]
            batch, n_dims = probs.shape[0], probs.shape[1]
            # One multinomial draw per (sample, dimension) row.
            rows = probs.view(-1, self.num_vals)
            return torch.multinomial(rows, num_samples=1).view(batch, n_dims)
        if self.distribution == 'bernoulli':
            return torch.bernoulli(decoded[0])
        raise ValueError('Distribution must be categorical or bernoulli')

    def log_prob(self, x, z):
        """Return the log-likelihood of ``x`` given latent ``z``.

        Categorical: ``log_categorical`` summed over the last axis and then
        over the next-to-last one. Bernoulli: ``log_bernoulli`` summed over
        the last axis.

        Raises:
            ValueError: for an unsupported ``distribution``.
        """
        decoded = self.decode(z)
        if self.distribution == 'categorical':
            per_dim = log_categorical(
                x, decoded[0], num_classes=self.num_vals, reduction='sum', dim=-1
            )
            return per_dim.sum(-1)
        if self.distribution == 'bernoulli':
            return log_bernoulli(x, decoded[0], reduction='sum', dim=-1)
        raise ValueError('Distribution must be categorical or bernoulli')


## 7. Prior

The prior is defined as a standard normal distribution (zero mean, unit variance in every latent dimension). It controls the shape of the hidden space where the encoder sends the data.

In [7]:
# Prior
class Prior(nn.Module):
    """Standard-normal prior over an ``L``-dimensional latent space.

    Args:
        L: number of latent dimensions.
    """

    def __init__(self, L):
        super().__init__()
        self.L = L
        # Zero-element, non-persistent buffer: it moves with the module on
        # .to()/.cuda() so sample() knows which device to allocate on, and
        # it is left out of state_dict so existing checkpoints still load.
        self.register_buffer('_device_anchor', torch.empty(0), persistent=False)

    def sample(self, batch_size):
        """Draw ``batch_size`` latent vectors of shape ``(batch_size, L)``.

        The samples are created on the device this module currently lives
        on, so they can be fed straight to a decoder moved alongside it.
        """
        return torch.randn((batch_size, self.L), device=self._device_anchor.device)

    def log_prob(self, z):
        """Return ``log_standard_normal(z)`` without reduction."""
        return log_standard_normal(z)


## 8. β-VAE Model

The full model combines the encoder, decoder, and prior into a β-VAE.

In the loss calculation, the KL divergence term is multiplied by beta, which controls how strongly we force the hidden space to be organized. A bigger beta means the model will care more about making the hidden space neat, but it might not rebuild the images as perfectly.

In this assignment, beta is set to 4.0 to balance learning a clean hidden space and still making good quality images.

In [8]:
# β-VAE Model
class VAE(nn.Module):
    """β-VAE assembled from an ``Encoder``, a ``Decoder`` and a ``Prior``.

    Args:
        encoder_net: network handed to ``Encoder``.
        decoder_net: network handed to ``Decoder``.
        num_vals: number of categories per output dimension for the decoder.
        L: latent dimensionality handed to ``Prior``.
        likelihood_type: decoder distribution name.
        beta: weight applied to the prior-minus-posterior term of the loss.
    """

    def __init__(self, encoder_net, decoder_net, num_vals=256, L=16, likelihood_type='categorical', beta=4.0):
        super().__init__()
        self.encoder = Encoder(encoder_net=encoder_net)
        self.decoder = Decoder(
            distribution=likelihood_type,
            decoder_net=decoder_net,
            num_vals=num_vals,
        )
        self.prior = Prior(L=L)
        self.num_vals = num_vals
        self.likelihood_type = likelihood_type
        self.beta = beta

    def forward(self, x, reduction='avg'):
        """Return the scalar training loss for batch ``x``.

        The loss is ``-(RE + beta * KL)``, where ``RE`` is the decoder
        log-likelihood of ``x`` and ``KL`` is ``log p(z) - log q(z|x)``
        summed over the last axis (i.e. a sampled negative-KL term).
        ``reduction='sum'`` sums over the batch; any other value averages.
        """
        mu_e, log_var_e = self.encoder.encode(x)
        z = self.encoder.sample(mu_e=mu_e, log_var_e=log_var_e)

        recon_ll = self.decoder.log_prob(x, z)
        log_prior = self.prior.log_prob(z)
        log_posterior = self.encoder.log_prob(mu_e=mu_e, log_var_e=log_var_e, z=z)
        neg_kl = (log_prior - log_posterior).sum(-1)

        objective = recon_ll + self.beta * neg_kl
        reduced = objective.sum() if reduction == 'sum' else objective.mean()
        return -reduced

    def sample(self, batch_size=64):
        """Decode ``batch_size`` latent draws from the prior into samples."""
        latents = self.prior.sample(batch_size=batch_size)
        return self.decoder.sample(latents)


## 9. Model Setup

The input dimension D is set to 64 because each image from the Digits dataset is 8×8 pixels, resulting in 64 values when flattened.

The latent dimension L is chosen as 16 to balance compression and reconstruction quality, allowing the model to capture important features without losing too much detail.

The hidden layer size M is set to 256 to provide enough capacity for the encoder and decoder to learn meaningful patterns, while keeping training efficient for a small dataset.

As explained above, the beta value is set to 4.0 to encourage better organization in the latent space without significantly harming the quality of the reconstructed images, achieving a good balance between disentanglement and reconstruction.

The likelihood type is set to 'categorical' because each pixel in the Digits dataset takes one of 17 discrete integer values (0–16).

In [9]:
# Hyperparameters
D = 64    # Input dimension
L = 16    # Latent dimension
M = 256   # Hidden layer size
beta = 4.0 # Beta for β-VAE
likelihood_type = 'categorical'
# Number of discrete values per input dimension (digits data has values 0–16).
# Kept in one place because the decoder output width and the VAE's num_vals
# must agree for the decoder's (batch, D, num_vals) view to work.
num_vals = 17

# Model architecture
# The encoder emits 2 * L values: one half is used as the mean and the other
# half as the log-variance of the latent code.
encoder = nn.Sequential(
    nn.Linear(D, M), nn.LeakyReLU(),
    nn.Linear(M, M), nn.LeakyReLU(),
    nn.Linear(M, 2 * L)
)

# The decoder emits num_vals scores for each of the D input dimensions.
decoder = nn.Sequential(
    nn.Linear(L, M), nn.LeakyReLU(),
    nn.Linear(M, M), nn.LeakyReLU(),
    nn.Linear(M, num_vals * D)
)

# Instantiate model
model = VAE(encoder_net=encoder, decoder_net=decoder, num_vals=num_vals, L=L, likelihood_type=likelihood_type, beta=beta)
