In [34]:
import copy
import glob
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.graph_objs as go
from IPython.display import clear_output, display
from sklearn.decomposition import PCA

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

# Testing GPU acceleration

In [2]:
# Probe the accelerators and pick the best available device, in priority order:
# NVIDIA GPU (CUDA) -> Apple Silicon (MPS) -> CPU.
is_cuda_available = torch.cuda.is_available()
print("Is CUDA available:", is_cuda_available)

if is_cuda_available:
    DEVICE = "cuda"
elif torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"
print("Using device:", DEVICE)

Is CUDA available: True
Using device: cuda


In [3]:
# Smoke test: build a small float tensor and move it onto the CUDA device.
# Skipped entirely on machines without CUDA.
if torch.cuda.is_available():
    test_tensor = torch.tensor([1.0, 2.0, 3.0]).to("cuda")
    print("Test tensor on CUDA:", test_tensor)

Test tensor on CUDA: tensor([1., 2., 3.], device='cuda:0')


In [4]:
if is_cuda_available:
    # Allocate an integer tensor on DEVICE; report (rather than raise) if that fails.
    try:
        test_tensor = torch.tensor([1, 2, 3], device=DEVICE)
    except RuntimeError as err:
        print("Error moving a tensor to the device:", err)
    else:
        print("Successfully moved a tensor to the device:", test_tensor)

Successfully moved a tensor to the device: tensor([1, 2, 3], device='cuda:0')


# Defining VAE classes

In [5]:
class StandardVAE(nn.Module):
    """Convolutional VAE for single-channel 28x28 images.

    Encoder: two stride-2 convolutions (28 -> 14 -> 7), a 128-unit hidden layer and two
    linear heads producing ``mu`` and ``logvar``. Decoder: the mirror image of the encoder,
    ending in a sigmoid so reconstructions lie in (0, 1).

    Args:
        latent_dim: Size of the latent vector ``z`` (default 20).
    """

    def __init__(self, latent_dim=20):
        super().__init__()
        self.latent_dim = latent_dim
        hidden_units = 128
        flat_features = 32 * 7 * 7  # channels x height x width after the encoder convolutions

        # Encoder: [bs, 1, 28, 28] -> [bs, 16, 14, 14] -> [bs, 32, 7, 7] -> [bs, 128]
        self.enc_conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1)
        self.enc_conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1)
        self.enc_fc1 = nn.Linear(flat_features, hidden_units)
        # Latent heads, each [bs, latent_dim]: enc_fc2 -> mu, enc_fc3 -> logvar
        self.enc_fc2 = nn.Linear(hidden_units, latent_dim)
        self.enc_fc3 = nn.Linear(hidden_units, latent_dim)

        # Decoder: [bs, latent_dim] -> [bs, 128] -> [bs, 1568] -> [bs, 16, 14, 14] -> [bs, 1, 28, 28]
        self.dec_fc1 = nn.Linear(latent_dim, hidden_units)
        self.dec_fc2 = nn.Linear(hidden_units, flat_features)
        self.dec_conv1 = nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.dec_conv2 = nn.ConvTranspose2d(16, 1, kernel_size=3, stride=2, padding=1, output_padding=1)

    def encode(self, x):
        """Map images ``x`` to the posterior parameters ``(mu, logvar)``."""
        features = x
        for conv in (self.enc_conv1, self.enc_conv2):
            features = F.relu(conv(features))
        hidden = F.relu(self.enc_fc1(features.flatten(start_dim=1)))
        mu = self.enc_fc2(hidden)
        logvar = self.enc_fc3(hidden)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + sigma * eps`` with ``eps ~ N(0, I)`` and ``sigma = exp(logvar / 2)``."""
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def decode(self, z):
        """Map latent vectors ``z`` to images in (0, 1) of shape [bs, 1, 28, 28]."""
        hidden = F.relu(self.dec_fc1(z))
        grid = F.relu(self.dec_fc2(hidden)).view(-1, 32, 7, 7)  # [bs, 1568] -> [bs, 32, 7, 7]
        upsampled = F.relu(self.dec_conv1(grid))
        return self.dec_conv2(upsampled).sigmoid()

    def forward(self, x):
        """Encode ``x``, sample ``z`` and decode it.

        Returns:
            ``(recon_x, z, mu, logvar)``, matching the leading positional arguments of
            ``loss_function_standard``.
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        recon_x = self.decode(z)
        return recon_x, z, mu, logvar

def loss_function_standard(recon_x, z, mu, logvar, x):
    """Standard VAE loss: summed binary cross-entropy reconstruction term plus KL divergence.

    Args:
        recon_x: Decoder output in (0, 1), same shape as ``x``.
        z: Sampled latent; unused, accepted so ``loss(*model(x), x)`` works.
        mu: Posterior means from the encoder.
        logvar: Posterior log-variances from the encoder.
        x: Target images with values in [0, 1].

    Returns:
        Scalar tensor ``BCE + KLD``, both summed over the batch. The terms can be reweighted
        if desired.
    """
    reconstruction = F.binary_cross_entropy(recon_x, x, reduction='sum')
    # Closed-form KL( N(mu, exp(logvar)) || N(0, I) )
    kl_divergence = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum()
    return reconstruction + kl_divergence

LakeVAE inherits from StandardVAE but overrides the forward pass.
Its encoding and decoding layers are identical.
The only difference is that its forward pass keeps and returns every intermediate activation needed to compute the layer-wise reconstruction loss.

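After `ToTensor()`, pixel values lie in [0, 1], and most are close to 0 or 1. This means we can use BCE between the input image and the reconstructed image.
The intermediate-layer activations are different: they are continuous, non-negative ReLU outputs that are not confined to [0, 1]. For those we use MSE instead, which is also what the paper's code does in practice.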

In [6]:
class LakeVAE(StandardVAE):
    """StandardVAE whose forward pass also returns its intermediate activations.

    All layers are inherited unchanged. Only ``forward`` differs: it returns the encoder and
    decoder activations that ``loss_function_lake`` compares layer by layer.
    """

    def forward(self, x):
        """Run encoder -> reparameterization -> decoder, keeping every hidden activation.

        Returns:
            ``(recon_x, z, mu, logvar, enc_conv1_out, enc_conv2_out, h,
            dec_fc1_out, dec_fc2_out, dec_conv1_out)``, in the positional order expected by
            ``loss_function_lake``.
        """
        # ---- encoder ----  input: [bs, 1, 28, 28]
        enc_conv1_out = F.relu(self.enc_conv1(x))                            # [bs, 16, 14, 14]
        enc_conv2_out = F.relu(self.enc_conv2(enc_conv1_out))                # [bs, 32, 7, 7]
        h = F.relu(self.enc_fc1(enc_conv2_out.flatten(start_dim=1)))         # [bs, 128]
        mu = self.enc_fc2(h)                                                 # [bs, latent_dim]
        logvar = self.enc_fc3(h)                                             # [bs, latent_dim]

        # ---- sampling + decoder ----
        z = self.reparameterize(mu, logvar)                                  # [bs, latent_dim]
        dec_fc1_out = F.relu(self.dec_fc1(z))                                # [bs, 128]
        dec_fc2_out = F.relu(self.dec_fc2(dec_fc1_out)).view(-1, 32, 7, 7)   # [bs, 1568] -> [bs, 32, 7, 7]
        dec_conv1_out = F.relu(self.dec_conv1(dec_fc2_out))                  # [bs, 16, 14, 14]
        recon_x = self.dec_conv2(dec_conv1_out).sigmoid()                    # [bs, 1, 28, 28]

        encoder_acts = (enc_conv1_out, enc_conv2_out, h)
        decoder_acts = (dec_fc1_out, dec_fc2_out, dec_conv1_out)
        return (recon_x, z, mu, logvar, *encoder_acts, *decoder_acts)

def loss_function_lake(recon_x, z, mu, logvar, enc_conv1_out, enc_conv2_out, h, dec_fc1_out, dec_fc2_out, dec_conv1_out, x):
    """Layer-constrained VAE loss: BCE + KL divergence + layer-wise MSE reconstruction terms.

    Each encoder activation is compared with the decoder activation of the same shape:
    ``enc_conv1_out`` vs ``dec_conv1_out``, ``enc_conv2_out`` vs ``dec_fc2_out`` and
    ``h`` vs ``dec_fc1_out``.

    NOTE(review): ``F.mse_loss`` runs with its default ``reduction='mean'``, while BCE and KLD
    are summed over the batch. The layer-wise terms are therefore on a much smaller scale than
    the other two; confirm this weighting is intended.

    Returns:
        Scalar tensor ``BCE + KLD + layer_loss``.
    """
    reconstruction = F.binary_cross_entropy(recon_x, x, reduction='sum')
    matched_layers = (
        (enc_conv1_out, dec_conv1_out),
        (enc_conv2_out, dec_fc2_out),
        (h, dec_fc1_out),
    )
    layer_loss = sum(F.mse_loss(encoded, decoded) for encoded, decoded in matched_layers)
    kl_divergence = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum()
    return reconstruction + kl_divergence + layer_loss

# Defining Utility Functions for Training and Visualisation

In [7]:
def save_model(model, path):
    """Write ``model.state_dict()`` (parameters and buffers only, not the class) to ``path``."""
    weights = model.state_dict()
    torch.save(weights, path)

In [8]:
def load_model(model, path, device):
    """Load a ``state_dict`` checkpoint from ``path`` into ``model`` in place.

    ``map_location=device`` remaps the stored tensors while loading, so a checkpoint written
    on a GPU can be restored on a CPU-only machine. ``torch.load`` unpickles the file, so only
    load checkpoints from trusted sources.
    """
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint)

In [9]:
def plot_losses(train_losses, val_losses):
    """Draw per-epoch training and validation loss curves on one static figure."""
    fig, ax = plt.subplots(figsize=(10, 6))
    for series, label in ((train_losses, 'Training Loss'), (val_losses, 'Validation Loss')):
        ax.plot(series, label=label)
    ax.set_xlabel('Epochs')
    ax.set_ylabel('Loss')
    ax.set_title('Training and Validation Loss')
    ax.legend()
    plt.show()

In [10]:
def generate_and_save_images(model, epoch, latent_vectors, folder="generated_images"):
    """Decode ``latent_vectors`` and save each image as ``{folder}/img_{epoch}_{i}.png``.

    Args:
        model: VAE exposing ``decode``.
        epoch: Used only in the output filenames.
        latent_vectors: Batch of latent codes; presumably already on the model's device.
        folder: Output directory; created if it does not exist.

    The model is switched to eval mode for decoding and then restored to its previous
    train/eval mode (instead of being forced into training mode).
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            generated = model.decode(latent_vectors).cpu()
    finally:
        model.train(was_training)

    # savefig fails if the directory does not exist.
    os.makedirs(folder, exist_ok=True)
    for i, img in enumerate(generated):
        # One figure per image, closed after saving, so artists/figures do not pile up.
        fig, ax = plt.subplots()
        ax.imshow(img.squeeze(), cmap='gray')
        fig.savefig(f"{folder}/img_{epoch}_{i}.png")
        plt.close(fig)

In [11]:
def plot_latent_space(model, data_loader, device, num_samples=1000):
    """Static scatter plot of a 2-D PCA projection of encoder means, coloured by label.

    Args:
        model: VAE exposing ``encode`` (returns ``(mu, logvar)``).
        data_loader: Yields ``(images, labels)`` batches.
        device: Device the images are moved to before encoding.
        num_samples: Maximum number of samples to encode and plot.

    Raises:
        ValueError: if ``data_loader`` yields no batches.

    The model's train/eval mode is restored afterwards rather than forced to training mode.
    """
    was_training = model.training
    model.eval()
    latents = []
    labels = []
    collected = 0
    try:
        with torch.no_grad():
            for data, label in data_loader:
                mu, _ = model.encode(data.to(device))
                latents.append(mu.cpu().numpy())
                labels.append(label.numpy())
                # Count real samples (robust to a short last batch or batch_size=None).
                collected += len(label)
                if collected >= num_samples:
                    break
    finally:
        model.train(was_training)

    if not latents:
        raise ValueError("data_loader yielded no batches to encode")
    latents = np.concatenate(latents, axis=0)[:num_samples]
    labels = np.concatenate(labels, axis=0)[:num_samples]

    pca = PCA(n_components=2)
    latents_reduced = pca.fit_transform(latents)

    fig, ax = plt.subplots(figsize=(10, 6))
    scatter = ax.scatter(latents_reduced[:, 0], latents_reduced[:, 1], c=labels, cmap='viridis', s=2, alpha=0.6)
    fig.colorbar(scatter, ax=ax)
    ax.set_title("Latent Space (PCA-reduced)")
    ax.set_xlabel("PCA Component 1")
    ax.set_ylabel("PCA Component 2")
    plt.show()

In [24]:
is_processing = False  # re-entrancy guard for the interactive click handlers


def plot_losses_interactive(model, model_name, model_states, train_losses, val_losses, train_loader, val_loader, device, best_train_loss, best_val_loss):
    """Show an interactive Plotly loss plot; clicking a point shows that epoch's reconstruction.

    Clicking the point for epoch ``i`` (1-based x axis) on either curve does the following:
    - temporarily loads ``model_states[i - 1]`` into ``model``;
    - reconstructs a batch from the matching loader (training curve -> ``train_loader``,
      validation curve -> ``val_loader``);
    - displays the first image together with ``best_loss / loss`` as a percentage.

    The model's own weights and train/eval mode are restored afterwards, so inspecting an old
    epoch does not overwrite the trained model.

    Args:
        model: VAE whose ``forward`` returns the reconstruction as its first element.
        model_name: Label used in the titles.
        model_states: One ``state_dict`` snapshot per epoch, aligned with the loss lists.
        train_losses: Per-epoch training losses to plot.
        val_losses: Per-epoch validation losses to plot.
        train_loader: Loader supplying images for the training curve.
        val_loader: Loader supplying images for the validation curve.
        device: Device the input batch is moved to.
        best_train_loss: Reference loss for the training-curve percentage.
        best_val_loss: Reference loss for the validation-curve percentage.
    """
    epochs = len(train_losses)
    epoch_axis = list(range(1, epochs + 1))
    fig = go.FigureWidget()

    # Add traces for training and validation losses
    fig.add_trace(go.Scatter(x=epoch_axis, y=train_losses, mode='lines+markers', name='Training Loss'))
    fig.add_trace(go.Scatter(x=epoch_axis, y=val_losses, mode='lines+markers', name='Validation Loss'))

    fig.update_layout(
        title=f'Interactive {model_name} Training and Validation Loss',
        xaxis_title='Epoch',
        yaxis_title='Loss',
        width=800, height=600
    )

    # trace name -> (loader, per-epoch losses, best loss, label used in the image title)
    trace_info = {
        'Training Loss': (train_loader, train_losses, best_train_loss, 'Training'),
        'Validation Loss': (val_loader, val_losses, best_val_loss, 'Validation'),
    }

    def update_image(trace, points, selector):
        global is_processing
        # on_click fires for every trace; only the clicked trace carries point indices.
        if is_processing or not points.point_inds or trace.name not in trace_info:
            return
        is_processing = True
        try:
            epoch = points.point_inds[0]  # 0-based index into the loss lists / model_states
            data_loader, losses, best_loss, loss_type = trace_info[trace.name]
            current_loss = losses[epoch]
            loss_percentage = (best_loss / current_loss) * 100 if current_loss else float('nan')

            # Snapshot the live weights/mode so viewing an old epoch does not clobber the model.
            saved_state = copy.deepcopy(model.state_dict())
            was_training = model.training
            try:
                model.load_state_dict(model_states[epoch])
                model.eval()
                data, _ = next(iter(data_loader))
                with torch.no_grad():
                    reconstructed_img = model(data.to(device))[0].cpu().squeeze()
            finally:
                model.load_state_dict(saved_state)
                model.train(was_training)

            if reconstructed_img.ndim == 3:  # a whole batch remained: show only the first image
                reconstructed_img = reconstructed_img[0]

            img_fig, ax = plt.subplots(figsize=(5, 5))
            ax.imshow(reconstructed_img.numpy(), cmap='gray')
            ax.set_title(f'{model_name} {loss_type} Loss\nEpoch: {epoch + 1}\nLoss: {current_loss:.4f} ({loss_percentage:.2f}% rel to best)')
            ax.axis('off')
            clear_output(wait=True)
            display(img_fig)
            plt.close(img_fig)  # already rendered by display(); free it so figures don't accumulate
        finally:
            # Always release the guard, even if something above raised.
            is_processing = False

    fig.data[0].on_click(update_image)  # training-loss trace
    fig.data[1].on_click(update_image)  # validation-loss trace

    display(fig)

In [25]:
def validate_vae(model, val_loader, loss_function, device):
    """Return ``model``'s average per-sample loss over ``val_loader`` with gradients disabled.

    The model's train/eval mode is not changed here; the caller is expected to call
    ``model.eval()`` first (``train_vae`` does). ``loss_function`` is called as
    ``loss_function(*model(batch), batch)``. The accumulated total is divided by the dataset
    size, so the loss should be summed (not averaged) over each batch.
    """
    running = 0
    with torch.no_grad():
        for batch, _labels in val_loader:
            batch = batch.to(device)
            running += loss_function(*model(batch), batch).item()
    return running / len(val_loader.dataset)

In [26]:
def train_vae(model, train_loader, val_loader, loss_function, optimiser, epochs, device, model_name, plot_interval=1):
    """Train a VAE, checkpoint its best-validation weights and show an interactive loss plot.

    Each epoch makes one pass over ``train_loader`` and then evaluates on ``val_loader`` via
    ``validate_vae``. Epoch losses are divided by the dataset size, so ``loss_function`` is
    expected to sum over the batch. A deep copy of the weights is kept for every epoch (used by
    the interactive plot).

    After training, the weights with the lowest validation loss are loaded back into ``model``
    and saved to ``best_{model_name}_epoch_{best_epoch}.pth``. Note that ``best_epoch`` is
    0-based, while the printed epochs are 1-based.

    Args:
        model: VAE whose outputs are unpacked as ``loss_function(*model(data), data)``.
        train_loader: Loader yielding ``(images, labels)`` for training.
        val_loader: Loader yielding ``(images, labels)`` for validation.
        loss_function: Callable returning a scalar loss tensor.
        optimiser: Optimiser over ``model``'s parameters.
        epochs: Number of training epochs.
        device: Device batches are moved to.
        model_name: Used in the checkpoint filename and plot title.
        plot_interval: Unused; kept for backward compatibility.

    Returns:
        ``(train_losses, val_losses)``: per-sample average losses for each epoch.
    """
    train_losses = []
    val_losses = []
    model_states = []
    best_val_loss = float('inf')
    best_model_state = None
    best_epoch = 0

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        for data, _ in train_loader:
            data = data.to(device)
            optimiser.zero_grad()
            outputs = model(data)
            loss = loss_function(*outputs, data)
            loss.backward()
            optimiser.step()
            total_loss += loss.item()

        train_loss = total_loss / len(train_loader.dataset)
        train_losses.append(train_loss)

        model.eval()
        val_loss = validate_vae(model, val_loader, loss_function, device)
        val_losses.append(val_loss)

        # Deep copy: state_dict() tensors alias the live parameters. A shallow ``.copy()``
        # would make every snapshot (and the "best" state) equal to the final weights.
        epoch_state = copy.deepcopy(model.state_dict())
        model_states.append(epoch_state)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_model_state = epoch_state
            best_epoch = epoch

        print(f'Epoch {epoch+1}, Training Loss: {train_loss}, Validation Loss: {val_loss}')

    if best_model_state is not None:
        # Restore the best weights *before* saving, so the checkpoint really is the best one.
        model.load_state_dict(best_model_state)
        model_path = f"best_{model_name}_epoch_{best_epoch}.pth"
        save_model(model, model_path)
        print(f"Best {model_name} model saved as {model_path}")

    if not train_losses:
        # epochs == 0: nothing to plot (min() of an empty list would raise).
        return train_losses, val_losses

    plot_losses_interactive(model, model_name, model_states, train_losses, val_losses, train_loader, val_loader, device, min(train_losses), min(val_losses))

    return train_losses, val_losses

# Loading Dataset and Training

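This cell loads the MNIST dataset, splits its training portion 80/20 into training and validation sets, and loads the separate test set. The DataLoaders are created in the next cell.
`transforms.ToTensor()` converts each image to a tensor and scales its pixel values to [0, 1].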

In [27]:
DATA_ROOT = '~/.pytorch/MNIST_data/'
SPLIT_SEED = 42  # fixes the train/validation split so results are reproducible across runs

# ToTensor converts images to float tensors with pixel values scaled to [0, 1].
transform = transforms.Compose([transforms.ToTensor()])

# Download the MNIST training data
mnist_trainset = datasets.MNIST(root=DATA_ROOT, train=True, download=True, transform=transform)

# 80/20 train/validation split, seeded so every run uses the same partition
train_size = int(0.8 * len(mnist_trainset))
validation_size = len(mnist_trainset) - train_size
train_dataset, validation_dataset = random_split(
    mnist_trainset,
    [train_size, validation_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Download and load the test data
test_dataset = datasets.MNIST(root=DATA_ROOT, train=False, download=True, transform=transform)

In [28]:
# Training parameters
batch_size = 64
learning_rate = 1e-3
epochs = 2 

trainloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=False)
validationloader = DataLoader(validation_dataset, batch_size=batch_size, shuffle=True, drop_last=False)
testloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True, drop_last=False)

Standard_VAE = StandardVAE().to(DEVICE)
Lake_VAE = LakeVAE().to(DEVICE)

In [29]:
print(f'Training parameters: batch_size={batch_size}, learning_rate={learning_rate}, epochs={epochs}')

# Standard VAE: BCE reconstruction + KL divergence (loss_function_standard)
print("Training Standard_VAE...")
Optimiser_Standard = torch.optim.Adam(params=Standard_VAE.parameters(), lr=learning_rate)
# NOTE: train_vae returns (train_losses, val_losses), not a model, despite this variable's name.
Trained_Standard_VAE = train_vae(
    model=Standard_VAE,
    train_loader=trainloader,
    val_loader=validationloader,
    loss_function=loss_function_standard,
    optimiser=Optimiser_Standard,
    epochs=epochs,
    device=DEVICE,
    model_name="Standard_VAE",
)

Training parameters: batch_size=64, learning_rate=0.001, epochs=2
Training Standard_VAE...
Epoch 1, Training Loss: 177.80979630533855, Validation Loss: 130.8518377278646
Epoch 2, Training Loss: 120.63934587605794, Validation Loss: 115.23266548665364
Best Standard_VAE model saved as best_Standard_VAE_epoch_1.pth


FigureWidget({
    'data': [{'mode': 'lines+markers',
              'name': 'Training Loss',
              'type': 'scatter',
              'uid': 'd54b8d21-15f0-41ea-a129-2a438eec7acc',
              'x': [1, 2],
              'y': [177.80979630533855, 120.63934587605794]},
             {'mode': 'lines+markers',
              'name': 'Validation Loss',
              'type': 'scatter',
              'uid': '8e1ea81f-d7b3-45f4-82be-11efb047568b',
              'x': [1, 2],
              'y': [130.8518377278646, 115.23266548665364]}],
    'layout': {'height': 600,
               'template': '...',
               'title': {'text': 'Interactive Standard_VAE Training and Validation Loss'},
               'width': 800,
               'xaxis': {'title': {'text': 'Epoch'}},
               'yaxis': {'title': {'text': 'Loss'}}}
})

Training uses only about 25 MiB of GPU memory. To monitor GPU usage, we could run the notebook in JupyterLab with the jupyterlab-nvdashboard extension.

In [30]:
print(f'Training parameters: batch_size={batch_size}, learning_rate={learning_rate}, epochs={epochs}')

# Layer-constrained ("Lake") VAE: standard loss plus layer-wise MSE terms (loss_function_lake)
print("Training Lake_VAE...")
Optimiser_Lake = torch.optim.Adam(params=Lake_VAE.parameters(), lr=learning_rate)
# NOTE: train_vae returns (train_losses, val_losses), not a model, despite this variable's name.
Trained_Lake_VAE = train_vae(
    model=Lake_VAE,
    train_loader=trainloader,
    val_loader=validationloader,
    loss_function=loss_function_lake,
    optimiser=Optimiser_Lake,
    epochs=epochs,
    device=DEVICE,
    model_name="Lake_VAE",
)

Training parameters: batch_size=64, learning_rate=0.001, epochs=2
Training Lake_VAE...
Epoch 1, Training Loss: 176.32783767700195, Validation Loss: 132.32506477864584
Epoch 2, Training Loss: 123.31711346435547, Validation Loss: 116.42845627848307
Best Lake_VAE model saved as best_Lake_VAE_epoch_1.pth


FigureWidget({
    'data': [{'mode': 'lines+markers',
              'name': 'Training Loss',
              'type': 'scatter',
              'uid': 'dee4ed24-9166-42d5-97a0-a58ed3e927f0',
              'x': [1, 2],
              'y': [176.32783767700195, 123.31711346435547]},
             {'mode': 'lines+markers',
              'name': 'Validation Loss',
              'type': 'scatter',
              'uid': '2cd65548-8b87-46fa-8449-20f285ae86ce',
              'x': [1, 2],
              'y': [132.32506477864584, 116.42845627848307]}],
    'layout': {'height': 600,
               'template': '...',
               'title': {'text': 'Interactive Lake_VAE Training and Validation Loss'},
               'width': 800,
               'xaxis': {'title': {'text': 'Epoch'}},
               'yaxis': {'title': {'text': 'Loss'}}}
})

# Loading Models and Visualising the Latent Space

In [19]:
# train_vae saves checkpoints as best_<model_name>_epoch_<index>.pth, and the index changes from
# run to run. Pick the newest matching file instead of hard-coding an epoch number (a hard-coded
# name breaks Restart & Run All as soon as the number of epochs changes).
def _latest_checkpoint(model_name):
    """Return the most recently modified ``best_{model_name}_epoch_*.pth`` in the working directory.

    Raises:
        FileNotFoundError: if no matching checkpoint exists.
    """
    pattern = f"best_{model_name}_epoch_*.pth"
    candidates = glob.glob(pattern)
    if not candidates:
        raise FileNotFoundError(f"No checkpoint matching {pattern!r}; run the training cell for {model_name} first.")
    return max(candidates, key=os.path.getmtime)


best_standard_vae_path = _latest_checkpoint("Standard_VAE")
best_lake_vae_path = _latest_checkpoint("Lake_VAE")
print("Loading", best_standard_vae_path, "and", best_lake_vae_path)

# Load the best models
best_Standard_VAE = StandardVAE().to(DEVICE)
load_model(best_Standard_VAE, best_standard_vae_path, DEVICE)

best_Lake_VAE = LakeVAE().to(DEVICE)
load_model(best_Lake_VAE, best_lake_vae_path, DEVICE)

In [35]:
is_processing = False  # re-entrancy guard for the interactive click handlers


def reconstruct_from_latent_space(model, latent_point, pca, device):
    """Decode a point of the 2-D PCA plane back into an image.

    The point is lifted back to the full latent space with ``pca.inverse_transform``. Directions
    discarded by the PCA are therefore not recovered. The result is then decoded with
    ``model.decode`` under ``torch.no_grad()``, since this is inference only and no autograd
    graph is needed.

    Args:
        model: VAE exposing ``decode``.
        latent_point: Length-2 array of PCA coordinates.
        pca: The fitted PCA used to reduce the latents.
        device: Device to decode on.

    Returns:
        The first decoded image as a CPU tensor with singleton dimensions squeezed.
    """
    original_latent = pca.inverse_transform([latent_point])
    original_latent_tensor = torch.from_numpy(original_latent).float().to(device)
    with torch.no_grad():
        reconstructed_img = model.decode(original_latent_tensor).cpu()
    return reconstructed_img[0].squeeze()

def plot_latent_space_interactive(model, data_loader, device, vaename, dataname, num_samples=1000):
    """Interactive Plotly scatter of the 2-D PCA projection of encoder means, one trace per digit.

    Up to ``num_samples`` inputs are encoded (using ``mu`` from ``model.encode``), a
    2-component PCA is fitted, and the projection is plotted with one trace per label 0-9.
    Clicking a point decodes that 2-D location back into an image and displays it.

    Args:
        model: VAE exposing ``encode`` and ``decode``.
        data_loader: Yields ``(images, labels)`` batches; labels are assumed to be digits 0-9.
        device: Device used for encoding/decoding.
        vaename: Model name used in the figure title.
        dataname: Dataset name used in the figure title.
        num_samples: Maximum number of samples to encode and plot.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    was_training = model.training
    model.eval()
    latents = []
    labels = []
    collected = 0
    try:
        with torch.no_grad():
            for data, label in data_loader:
                mu, _ = model.encode(data.to(device))
                latents.append(mu.cpu().numpy())
                labels.append(label.numpy())
                # Count real samples (robust to a short last batch or batch_size=None).
                collected += len(label)
                if collected >= num_samples:
                    break
    finally:
        model.train(was_training)

    if not latents:
        raise ValueError("data_loader yielded no batches to encode")
    latents = np.concatenate(latents, axis=0)[:num_samples]
    labels = np.concatenate(labels, axis=0)[:num_samples]

    pca = PCA(n_components=2)
    latents_reduced = pca.fit_transform(latents)

    # Define a custom color scale (10 different colors for digits 0-9)
    custom_color_scale = ['#636EFA', '#EF553B', '#00CC96', '#AB63FA', '#FFA15A',
                          '#19D3F3', '#FF6692', '#B6E880', '#FF97FF', '#FECB52']

    # Create traces for each digit with hover text
    traces = []
    for digit in range(10):
        digit_indices = np.where(labels == digit)[0]
        trace = go.Scatter(
            x=latents_reduced[digit_indices, 0], y=latents_reduced[digit_indices, 1],
            mode='markers', marker=dict(color=custom_color_scale[digit], size=10),
            name=str(digit),
            hoverinfo='text',
            text=[f'Label: {digit}, Pos: ({x:.2f}, {y:.2f})' for x, y in latents_reduced[digit_indices]]
        )
        traces.append(trace)

    # Plotly figure with separate traces
    fig = go.FigureWidget(traces)
    fig.update_layout(
        title=f'{vaename} Latent Space Visualization of MNIST {dataname} Dataset',
        xaxis_title='Principal Component 1',
        yaxis_title='Principal Component 2',
        width=800, height=600,
        legend_title_text='Digit Label'
    )

    def update_image(trace, points, selector):
        global is_processing
        # on_click fires for every trace; only the clicked trace carries point indices.
        if is_processing or not points.point_inds:
            return
        is_processing = True
        try:
            idx = points.point_inds[0]
            # point_inds index into the *clicked trace* (one trace per digit), not into
            # latents_reduced, so read the clicked coordinates from the trace itself.
            latent_point = np.array([trace.x[idx], trace.y[idx]])
            cb_was_training = model.training
            model.eval()
            try:
                img = reconstruct_from_latent_space(model, latent_point, pca, device)
            finally:
                model.train(cb_was_training)
            img_fig, ax = plt.subplots()
            ax.imshow(img.detach().numpy(), cmap='gray')
            ax.axis('off')
            clear_output(wait=True)
            display(img_fig)
            plt.close(img_fig)  # already rendered; free it so figures don't accumulate
        finally:
            # Always release the guard, even if decoding or plotting raised.
            is_processing = False

    for trace in fig.data:
        trace.on_click(update_image)

    display(fig)

In [36]:
plot_latent_space_interactive(best_Standard_VAE, trainloader, DEVICE, 'Best Standard VAE' ,'Training', num_samples=500)
# To inspect the validation split instead:
# plot_latent_space_interactive(best_Standard_VAE, validationloader, DEVICE, 'Best Standard VAE', 'Validation')

FigureWidget({
    'data': [{'hoverinfo': 'text',
              'marker': {'color': '#636EFA', 'size': 10},
              'mode': 'markers',
              'name': '0',
              'text': [Label: 0, Pos: (0.02, -1.18), Label: 0, Pos: (1.53, 1.71),
                       Label: 0, Pos: (1.72, 1.25), Label: 0, Pos: (1.39, -1.87),
                       Label: 0, Pos: (0.34, 0.87), Label: 0, Pos: (1.20, -1.24),
                       Label: 0, Pos: (-0.18, 0.24), Label: 0, Pos: (0.46, 0.18),
                       Label: 0, Pos: (1.06, 0.67), Label: 0, Pos: (1.53, -0.36),
                       Label: 0, Pos: (1.90, -1.79), Label: 0, Pos: (0.98, -0.28),
                       Label: 0, Pos: (0.84, 0.10), Label: 0, Pos: (1.31, 0.15),
                       Label: 0, Pos: (-1.20, 0.05), Label: 0, Pos: (0.93, 0.03),
                       Label: 0, Pos: (0.14, -1.70), Label: 0, Pos: (1.17, -1.68),
                       Label: 0, Pos: (0.33, -1.33), Label: 0, Pos: (0.49, -0.81),
          

In [22]:
plot_latent_space_interactive(best_Lake_VAE, trainloader, DEVICE, 'Best layer-constrained VAE', 'Training', num_samples=500)
# To inspect the validation split instead:
# plot_latent_space_interactive(best_Lake_VAE, validationloader, DEVICE, 'Best layer-constrained VAE', 'Validation')

FigureWidget({
    'data': [{'hoverinfo': 'text',
              'marker': {'color': '#636EFA', 'size': 10},
              'mode': 'markers',
              'name': '0',
              'text': [Label: 0, Pos: (2.38, -1.25), Label: 0, Pos: (1.54, -1.22),
                       Label: 0, Pos: (-0.02, -0.21), Label: 0, Pos: (1.33, -0.22),
                       Label: 0, Pos: (0.88, -0.69), Label: 0, Pos: (3.49, -0.18),
                       Label: 0, Pos: (0.40, 0.05), Label: 0, Pos: (2.56, 0.73),
                       Label: 0, Pos: (1.79, 1.16), Label: 0, Pos: (1.18, 0.62),
                       Label: 0, Pos: (2.67, 0.44), Label: 0, Pos: (1.11, -0.34),
                       Label: 0, Pos: (1.24, -1.81), Label: 0, Pos: (2.06, 1.83),
                       Label: 0, Pos: (-0.16, 0.73), Label: 0, Pos: (1.77, 0.45),
                       Label: 0, Pos: (2.21, -0.35), Label: 0, Pos: (1.15, -0.61),
                       Label: 0, Pos: (0.48, 0.69), Label: 0, Pos: (1.87, -0.94),
         