In [173]:
from IPython.display import display, clear_output
import plotly.graph_objs as go
import matplotlib.pyplot as plt

import numpy as np
import pandas as pd 
import time
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

from sklearn.decomposition import PCA

# Testing GPU acceleration

In [174]:
# Check if CUDA is available
is_cuda_available = torch.cuda.is_available()
print("Is CUDA available:", is_cuda_available)

# Determine the device to use: GPU (CUDA), Apple Silicon (MPS), or CPU
DEVICE = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print("Using device:", DEVICE)

Is CUDA available: True
Using device: cuda


In [175]:
if torch.cuda.is_available():
    # Test tensor operation on GPU
    test_tensor = torch.tensor([1.0, 2.0, 3.0], device="cuda")
    print("Test tensor on CUDA:", test_tensor)

Test tensor on CUDA: tensor([1., 2., 3.], device='cuda:0')


In [176]:
if is_cuda_available:
    try:
        test_tensor = torch.tensor([1, 2, 3], device=DEVICE)
        print("Successfully moved a tensor to the device:", test_tensor)
    except RuntimeError as e:
        print("Error moving a tensor to the device:", e)

Successfully moved a tensor to the device: tensor([1, 2, 3], device='cuda:0')


# Defining VAE classes

In [177]:
class StandardVAE(nn.Module):
    def __init__(self, latent_dim=20, w_dim=10):
        super(StandardVAE, self).__init__()
        self.latent_dim = latent_dim

        # Encoder layers
        # Input: [bs, 1, 28, 28]
        self.enc_conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1)  # Output: [bs, 16, 14, 14]
        self.enc_conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1) # Output: [bs, 32, 7, 7]
        self.enc_fc1 = nn.Linear(32 * 7 * 7, w_dim)  # Output: [bs, 128]
        # Two output layers for the latent space
        self.enc_fc2 = nn.Linear(w_dim, latent_dim)  # For mu, Output: [bs, latent_dim]
        self.enc_fc3 = nn.Linear(w_dim, latent_dim)  # For logvar, Output: [bs, latent_dim]

        # Decoder layers
        self.dec_fc1 = nn.Linear(latent_dim, w_dim)  # Output: [bs, 128]
        self.dec_fc2 = nn.Linear(w_dim, 32 * 7 * 7)  # Output: [bs, 1568]
        self.dec_conv1 = nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1)  # Output: [bs, 16, 14, 14]
        self.dec_conv2 = nn.ConvTranspose2d(16, 1, kernel_size=3, stride=2, padding=1, output_padding=1)  # Output: [bs, 1, 28, 28]

    def encode(self, x):
        h = F.relu(self.enc_conv1(x))
        h = F.relu(self.enc_conv2(h))
        h = torch.flatten(h, start_dim=1)
        h = F.relu(self.enc_fc1(h))
        return self.enc_fc2(h), self.enc_fc3(h)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        # Returns a tensor with the same size as input that is filled with random numbers from a normal distribution with mean 0 and variance 1.
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        h = F.relu(self.dec_fc1(z)) 
        h = F.relu(self.dec_fc2(h)).view(-1, 32, 7, 7) # .view reshapes [bs, 1568] to [bs, 32, 7, 7]
        h = F.relu(self.dec_conv1(h))
        return torch.sigmoid(self.dec_conv2(h))

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), z, mu, logvar

def loss_function_standard(recon_x, z, mu, logvar, x):
    BCE = F.binary_cross_entropy(recon_x, x, reduction='sum')
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    # Can reweight BCE + KLD as desired
    return BCE + KLD

LakeVAE inheirits from StandardVAE, but with a modified forward pass.
Its encoding and decoding layers are identicle. 
The only difference is that the forward pass stores and returns all the intermediate values required to calculate the modified reconstruction loss.

Since pixel values are either 0 or 1, we can use BCE between the input image and output image.
However, for the intermediate layers, that are continuous, we use can MSE instead. This is also what the paper's code does in practice.

In [178]:
class LakeVAE(StandardVAE):
    def forward(self, x):
        # Encoder layers
        # Input: [bs, 1, 28, 28]
        enc_conv1_out = F.relu(self.enc_conv1(x))  # Output: [bs, 16, 14, 14]
        enc_conv2_out = F.relu(self.enc_conv2(enc_conv1_out))  # Output: [bs, 32, 7, 7]
        flattened = torch.flatten(enc_conv2_out, start_dim=1)  # Output: [bs, 1568]
        
        # w = F.relu(self.enc_fc1(flattened))  # Output: [bs, 128] # Eqn (5) in LAKE paper
        w = self.enc_fc1(flattened)  # Not sure if relu is used in the paper or not. Output: [bs, 128] # Eqn (5) in LAKE paper
        
        mu, logvar = self.enc_fc2(w), self.enc_fc3(w)  # Output: [bs, latent_dim], [bs, latent_dim] # Eqn (6) in LAKE paper

        # Reparameterization and Decoding layers
        z = self.reparameterize(mu, logvar)  # Output: [bs, latent_dim] # Eqn (7) in LAKE paper
        dec_fc1_out = F.relu(self.dec_fc1(z))  # Output: [bs, 128]
        dec_fc2_out = F.relu(self.dec_fc2(dec_fc1_out)).view(-1, 32, 7, 7)  # Output: [bs, 1568], then reshaped to [bs, 32, 7, 7]
        dec_conv1_out = F.relu(self.dec_conv1(dec_fc2_out))  # Output: [bs, 16, 14, 14]
        recon_x = torch.sigmoid(self.dec_conv2(dec_conv1_out))  # Output: [bs, 1, 28, 28]

        return recon_x, z, mu, logvar, enc_conv1_out, enc_conv2_out, w, dec_fc1_out, dec_fc2_out, dec_conv1_out

def loss_function_lake(recon_x, z, mu, logvar, enc_conv1_out, enc_conv2_out, w, dec_fc1_out, dec_fc2_out, dec_conv1_out, x):
    BCE = F.binary_cross_entropy(recon_x, x, reduction='sum')
    # Layer-wise reconstruction loss
    layer_loss = F.mse_loss(enc_conv1_out, dec_conv1_out) + F.mse_loss(enc_conv2_out, dec_fc2_out) + F.mse_loss(w, dec_fc1_out)
    # KL Divergence
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD + layer_loss

# Definining Utility Classes for training and visualisation

In [179]:
def save_model(model, path):
    torch.save(model.state_dict(), path)

In [180]:
def load_model(model, path, device):
    model.load_state_dict(torch.load(path, map_location=device))

In [181]:
def plot_losses(train_losses, val_losses):
    plt.figure(figsize=(10, 6))
    plt.plot(train_losses, label='Training Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title('Training and Validation Loss')
    plt.legend()
    plt.show()

In [182]:
def generate_and_save_images(model, epoch, latent_vectors, folder="generated_images"):
    model.eval()
    with torch.no_grad():
        generated = model.decode(latent_vectors).cpu()
    for i, img in enumerate(generated):
        plt.imshow(img.squeeze(), cmap='gray')
        plt.savefig(f"{folder}/img_{epoch}_{i}.png")
    model.train()

In [183]:
def plot_latent_space(model, data_loader, device, num_samples=1000):
    model.eval()
    latents = []
    labels = []
    with torch.no_grad():
        for data, label in data_loader:
            data = data.to(device)
            mu, _ = model.encode(data)
            latents.append(mu.cpu().numpy())
            labels.append(label.numpy())
            if len(latents) * data_loader.batch_size > num_samples:
                break
    
    latents = np.concatenate(latents, axis=0)[:num_samples]
    labels = np.concatenate(labels, axis=0)[:num_samples]
    
    pca = PCA(n_components=2)
    latents_reduced = pca.fit_transform(latents)
    
    plt.figure(figsize=(10, 6))
    plt.scatter(latents_reduced[:, 0], latents_reduced[:, 1], c=labels, cmap='viridis', s=2, alpha=0.6)
    plt.colorbar()
    plt.title("Latent Space (PCA-reduced)")
    plt.xlabel("PCA Component 1")
    plt.ylabel("PCA Component 2")
    plt.show()
    model.train()

In [184]:
is_processing = False
def plot_losses_interactive(model, model_name, model_states, train_losses, val_losses, train_loader, val_loader, device, best_train_loss, best_val_loss):
    global is_processing
    epochs = len(train_losses)
    fig = go.FigureWidget()

    # Add traces for training and validation losses
    fig.add_trace(go.Scatter(x=list(range(1, epochs + 1)), y=train_losses, mode='lines+markers', name='Training Loss'))
    fig.add_trace(go.Scatter(x=list(range(1, epochs + 1)), y=val_losses, mode='lines+markers', name='Validation Loss'))

    # Set layout for the plot
    fig.update_layout(
        title=f'Interactive {model_name} Training and Validation Loss',
        xaxis_title='Epoch',
        yaxis_title='Loss',
        width=800, height=600
    )

    # Function to update the image on clicking a point on the plot
    def update_image(trace, points, selector):
        global is_processing
        if is_processing:
            return
        is_processing = True
    
        if points.point_inds:
            epoch = points.point_inds[0]
            model.load_state_dict(model_states[epoch])
            model.eval()
    
            # Determine which dataloader and best loss to use based on which trace was clicked
            if trace.name == 'Training Loss':
                data_loader = train_loader
                current_loss = train_losses[epoch]
                best_loss = best_train_loss
                loss_type = 'Training'
            elif trace.name == 'Validation Loss':
                data_loader = val_loader
                current_loss = val_losses[epoch]
                best_loss = best_val_loss
                loss_type = 'Validation'

            # Calculate loss as a percentage of the best loss
            loss_percentage = (best_loss / current_loss) * 100

            # Generate and display an image
            data, _ = next(iter(data_loader))
            data = data.to(device)
            reconstructed_img = model(data)[0].cpu().squeeze()

            if reconstructed_img.ndim == 3:  # If image has 3 dimensions, take the first one
                reconstructed_img = reconstructed_img[0]

            # Display information and the image
            plt.figure(figsize=(5, 5))
            plt.imshow(reconstructed_img.detach().numpy(), cmap='gray')
            plt.title(f'{model_name} {loss_type} Loss\nEpoch: {epoch + 1}\nLoss: {current_loss:.4f} ({loss_percentage:.2f}% rel to best)')
            plt.axis('off')
            clear_output(wait=True)
            display(plt.gcf())
            model.train()
    
        is_processing = False

    # Attach the click handler to the plot
    fig.data[0].on_click(update_image)  # For training loss
    fig.data[1].on_click(update_image)  # For validation loss

    # Display the plot
    display(fig)

In [185]:
def validate_vae(model, val_loader, loss_function, device):
    total_loss = 0
    with torch.no_grad():
        for data, _ in val_loader:
            data = data.to(device)
            outputs = model(data)
            loss = loss_function(*outputs, data)
            total_loss += loss.item()
    return total_loss / len(val_loader.dataset)

modify train_vae to instead to the probability density estimation training.
I'm not sure how its training, or what this is doing, considering there's no updates to anything?
write the functions for rec_euclidean and rec_cosine.
The first thing the VAE classes return with their forward pass is their reconstructed images x'.

In [186]:
def train_vae(model, train_loader, val_loader, loss_function, optimiser, epochs, device, model_name, plot_interval=1):
    train_losses = []
    val_losses = []
    model_states = []
    best_val_loss = float('inf')
    best_model_state = None
    best_epoch = 0

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        for batch_idx, (data, _) in enumerate(train_loader):
            data = data.to(device)
            optimiser.zero_grad()
            outputs = model(data)
            loss = loss_function(*outputs, data)
            loss.backward()
            optimiser.step()
            total_loss += loss.item()

        train_loss = total_loss / len(train_loader.dataset)
        train_losses.append(train_loss)

        model.eval()
        val_loss = validate_vae(model, val_loader, loss_function, device)
        val_losses.append(val_loss)

        # Save the current model state
        model_states.append(model.state_dict().copy())

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_model_state = model.state_dict()
            best_epoch = epoch

        print(f'Epoch {epoch+1}, Training Loss: {train_loss}, Validation Loss: {val_loss}')

        print(f'Saving {model_name} model at epoch {epoch+1}')
        save_model(model, f'{model_name}_{epoch+1}.pth')
        

    if best_model_state is not None:
        model_path = f"best_{model_name}_epoch_{best_epoch}.pth"
        save_model(model, model_path)
        print(f"Best {model_name} model saved as {model_path}")
        model.load_state_dict(best_model_state)

    
    best_train_loss = min(train_losses)
    best_val_loss = min(val_losses)
    plot_losses_interactive(model, model_name, model_states, train_losses, val_losses, train_loader, val_loader, device, best_train_loss, best_val_loss)
        
    return train_losses, val_losses

# Loading Dataset and Training

This cell loads in the mnist dataset into train, validation, and test dataloaders.
While doing so, they are normalised to be [0,1] and turned into tensors.

In [187]:
transform = transforms.Compose([transforms.ToTensor()])

# Download the MNIST dataset
mnist_trainset = datasets.MNIST(root='~/.pytorch/MNIST_data/', train=True, download=True, transform=transform)

# Splitting the dataset into train and validation sets
train_size = int(0.8 * len(mnist_trainset))
validation_size = len(mnist_trainset) - train_size
train_dataset, validation_dataset = random_split(mnist_trainset, [train_size, validation_size])

# Download and load the test data
test_dataset = datasets.MNIST(root='~/.pytorch/MNIST_data/', train=False, download=True, transform=transform)

In [188]:
# Training parameters
batch_size = 64
learning_rate = 1e-3
epochs = 100

trainloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=False)
validationloader = DataLoader(validation_dataset, batch_size=batch_size, shuffle=True, drop_last=False)
testloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True, drop_last=False)

Standard_VAE = StandardVAE().to(DEVICE)
Lake_VAE = LakeVAE().to(DEVICE)

In [189]:
# Run cell to train Standard VAE
print(f'Training parameters: batch_size={batch_size}, learning_rate={learning_rate}, epochs={epochs}')
print("Training Standard_VAE...")
Optimiser_Standard = torch.optim.Adam(Standard_VAE.parameters(), lr=learning_rate)
Trained_Standard_VAE = train_vae(Standard_VAE, trainloader, validationloader, loss_function_standard, Optimiser_Standard, epochs, DEVICE, "Standard_VAE")

Training parameters: batch_size=64, learning_rate=0.001, epochs=10
Training Standard_VAE...
Epoch 1, Training Loss: 226.89376039632162, Validation Loss: 193.1483650716146
Saving Standard_VAE model at epoch 1
Epoch 2, Training Loss: 184.66886783854167, Validation Loss: 174.04674861653646
Saving Standard_VAE model at epoch 2
Epoch 3, Training Loss: 164.75998073323566, Validation Loss: 159.01668273925782
Saving Standard_VAE model at epoch 3
Epoch 4, Training Loss: 156.47857543945312, Validation Loss: 154.85309574381512
Saving Standard_VAE model at epoch 4
Epoch 5, Training Loss: 153.54923905436198, Validation Loss: 152.9229608561198
Saving Standard_VAE model at epoch 5
Epoch 6, Training Loss: 151.72211771647136, Validation Loss: 151.54623213704426
Saving Standard_VAE model at epoch 6
Epoch 7, Training Loss: 150.495251180013, Validation Loss: 150.29708162434895
Saving Standard_VAE model at epoch 7
Epoch 8, Training Loss: 149.39842336018881, Validation Loss: 149.78097778320313
Saving Standa

FigureWidget({
    'data': [{'mode': 'lines+markers',
              'name': 'Training Loss',
              'type': 'scatter',
              'uid': 'afdb25bf-818f-48d3-8a0e-e54da4e43351',
              'x': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
              'y': [226.89376039632162, 184.66886783854167, 164.75998073323566,
                    156.47857543945312, 153.54923905436198, 151.72211771647136,
                    150.495251180013, 149.39842336018881, 147.88727329508464,
                    146.97939158121744]},
             {'mode': 'lines+markers',
              'name': 'Validation Loss',
              'type': 'scatter',
              'uid': 'da655403-db7c-4118-83f3-503c0dc824c9',
              'x': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
              'y': [193.1483650716146, 174.04674861653646, 159.01668273925782,
                    154.85309574381512, 152.9229608561198, 151.54623213704426,
                    150.29708162434895, 149.78097778320313, 147.8560028889974,
                  

In [190]:
# Run cell to train Lake VAE
print(f'Training parameters: batch_size={batch_size}, learning_rate={learning_rate}, epochs={epochs}')
print("Training Lake_VAE...")
Optimiser_Lake = torch.optim.Adam(Lake_VAE.parameters(), lr=learning_rate)
Trained_Lake_VAE = train_vae(Lake_VAE, trainloader, validationloader, loss_function_lake, Optimiser_Lake, epochs, DEVICE, "Lake_VAE")

Training parameters: batch_size=64, learning_rate=0.001, epochs=10
Training Lake_VAE...
Epoch 1, Training Loss: 196.97447145589192, Validation Loss: 157.30794738769532
Saving Lake_VAE model at epoch 1
Epoch 2, Training Loss: 146.9710359802246, Validation Loss: 139.83250321451823
Saving Lake_VAE model at epoch 2
Epoch 3, Training Loss: 135.70760551961263, Validation Loss: 133.41635420735676
Saving Lake_VAE model at epoch 3
Epoch 4, Training Loss: 131.69676136271158, Validation Loss: 130.97847662353516
Saving Lake_VAE model at epoch 4
Epoch 5, Training Loss: 129.85162173461913, Validation Loss: 129.6448965250651
Saving Lake_VAE model at epoch 5
Epoch 6, Training Loss: 128.53322006225585, Validation Loss: 128.36548286946615
Saving Lake_VAE model at epoch 6
Epoch 7, Training Loss: 127.50781346638998, Validation Loss: 128.03276088460285
Saving Lake_VAE model at epoch 7
Epoch 8, Training Loss: 126.6456785176595, Validation Loss: 126.97166536458333
Saving Lake_VAE model at epoch 8
Epoch 9, Tr

FigureWidget({
    'data': [{'mode': 'lines+markers',
              'name': 'Training Loss',
              'type': 'scatter',
              'uid': '8651cf43-7c5a-4bf3-8f33-aa9be1d3ca0c',
              'x': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
              'y': [196.97447145589192, 146.9710359802246, 135.70760551961263,
                    131.69676136271158, 129.85162173461913, 128.53322006225585,
                    127.50781346638998, 126.6456785176595, 125.97908441162109,
                    125.418255859375]},
             {'mode': 'lines+markers',
              'name': 'Validation Loss',
              'type': 'scatter',
              'uid': '98df721b-b72b-4c13-9c03-17e1744d826e',
              'x': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
              'y': [157.30794738769532, 139.83250321451823, 133.41635420735676,
                    130.97847662353516, 129.6448965250651, 128.36548286946615,
                    128.03276088460285, 126.97166536458333, 126.47102274576822,
                  

Training only uses 25MiB of GPU memory. Perhaps could try Jupyterlab and jupyterlab-nvdashboard to monitor GPU usage.

# Loading Models

In [192]:
# best_standard_vae_path = 'best_Standard_VAE_epoch_X.pth'
# best_lake_vae_path = 'best_Lake_VAE_epoch_Y.pth'

standard_vae_path = 'Standard_VAE_9.pth'
lake_vae_path = 'Lake_VAE_9.pth'

# Load the best models
Standard_VAE = StandardVAE().to(DEVICE)
load_model(Standard_VAE, standard_vae_path, DEVICE)

Lake_VAE = LakeVAE().to(DEVICE)
load_model(Lake_VAE, lake_vae_path, DEVICE)

# Visualising Trained Latent Space

In [193]:
is_processing = False

def reconstruct_from_latent_space(model, latent_point, pca, device):
    original_latent = pca.inverse_transform([latent_point])
    original_latent_tensor = torch.from_numpy(original_latent).float().to(device)
    reconstructed_img = model.decode(original_latent_tensor).cpu()
    return reconstructed_img[0].squeeze()

def plot_latent_space_interactive(model, data_loader, device, vaename, dataname, num_samples=1000):
    model.eval()
    latents = []
    labels = []
    with torch.no_grad():
        for data, label in data_loader:
            data = data.to(device)
            mu, _ = model.encode(data)
            latents.append(mu.cpu().numpy())
            labels.append(label.numpy())
            if len(latents) * data_loader.batch_size > num_samples:
                break

    latents = np.concatenate(latents, axis=0)[:num_samples]
    labels = np.concatenate(labels, axis=0)[:num_samples]
    
    pca = PCA(n_components=2)
    latents_reduced = pca.fit_transform(latents)

    # Define a custom color scale (10 different colors for digits 0-9)
    custom_color_scale = ['#636EFA', '#EF553B', '#00CC96', '#AB63FA', '#FFA15A',
                          '#19D3F3', '#FF6692', '#B6E880', '#FF97FF', '#FECB52']

    # Create traces for each digit with hover text
    traces = []
    for digit in range(10):
        digit_indices = np.where(labels == digit)[0]
        trace = go.Scatter(
            x=latents_reduced[digit_indices, 0], y=latents_reduced[digit_indices, 1],
            mode='markers', marker=dict(color=custom_color_scale[digit], size=10),
            name=str(digit),
            hoverinfo='text',
            text=[f'Label: {digit}, Pos: ({x:.2f}, {y:.2f})' for x, y in latents_reduced[digit_indices]]
        )
        traces.append(trace)

    # Plotly figure with separate traces
    fig = go.FigureWidget(traces)
    fig.update_layout(
        title=f'{vaename} Latent Space Visualisation of MNIST {dataname} Data',
        xaxis_title='Principal Component 1',
        yaxis_title='Principal Component 2',
        width=800, height=600,
        legend_title_text='Digit Label'
    )

    def update_image(trace, points, selector):
        global is_processing
        if is_processing:
            return
    
        is_processing = True
    
        if points.point_inds:
            idx = points.point_inds[0]
            latent_point = latents_reduced[idx]
            img = reconstruct_from_latent_space(model, latent_point, pca, device)
            plt.imshow(img.detach().numpy(), cmap='gray')
            plt.axis('off')
            clear_output(wait=True)
            display(plt.gcf())
    
        is_processing = False

    for trace in fig.data:
        trace.on_click(update_image)

    display(fig)
    model.train()

In [194]:
#plot_latent_space_interactive(Standard_VAE, trainloader, DEVICE, 'Standard VAE Epochs' ,'Training', num_samples=500)
plot_latent_space_interactive(Standard_VAE, validationloader, DEVICE, 'Standard VAE Epochs' ,'Validation', num_samples=5000)

FigureWidget({
    'data': [{'hoverinfo': 'text',
              'marker': {'color': '#636EFA', 'size': 10},
              'mode': 'markers',
              'name': '0',
              'text': [Label: 0, Pos: (2.49, -0.01), Label: 0, Pos: (2.81, 1.56),
                       Label: 0, Pos: (2.69, 0.50), ..., Label: 0, Pos: (2.68,
                       0.98), Label: 0, Pos: (2.73, -0.38), Label: 0, Pos: (2.58,
                       -0.35)],
              'type': 'scatter',
              'uid': '581d373a-bd57-4a76-9c8f-6ecd65a4f751',
              'x': array([2.4901571, 2.8095226, 2.686181 , ..., 2.6832032, 2.7348704, 2.582164 ],
                         dtype=float32),
              'y': array([-0.01425579,  1.5630798 ,  0.49716645, ...,  0.9761483 , -0.38007715,
                          -0.34660178], dtype=float32)},
             {'hoverinfo': 'text',
              'marker': {'color': '#EF553B', 'size': 10},
              'mode': 'markers',
              'name': '1',
              'tex

In [195]:
#plot_latent_space_interactive(Lake_VAE, trainloader, DEVICE, 'layer-constrained VAE Epochs', 'Training', num_samples=500)
plot_latent_space_interactive(Lake_VAE, validationloader, DEVICE, 'layer-constrained VAE Epochs', 'Validiation', num_samples=5000)

FigureWidget({
    'data': [{'hoverinfo': 'text',
              'marker': {'color': '#636EFA', 'size': 10},
              'mode': 'markers',
              'name': '0',
              'text': [Label: 0, Pos: (2.38, -0.38), Label: 0, Pos: (1.01, -0.05),
                       Label: 0, Pos: (0.65, -0.70), ..., Label: 0, Pos: (0.53,
                       -0.18), Label: 0, Pos: (0.11, -1.20), Label: 0, Pos: (1.58,
                       -1.55)],
              'type': 'scatter',
              'uid': '6218fe83-b214-4540-9664-aada615ec135',
              'x': array([2.378218  , 1.0071332 , 0.6456206 , ..., 0.5291387 , 0.11081575,
                          1.5753206 ], dtype=float32),
              'y': array([-0.38112515, -0.05147202, -0.7019567 , ..., -0.17912616, -1.2015626 ,
                          -1.553151  ], dtype=float32)},
             {'hoverinfo': 'text',
              'marker': {'color': '#EF553B', 'size': 10},
              'mode': 'markers',
              'name': '1',
        

# Probability Density Estimation Training and Testing

Using 900 MNIST validation images (labelled 0) and 100 GAN generated MNIST images (labelled 1) to create a combined testing dataset. 
Then attempting to see if the LAKE anomaly detection can find the synthetic images.

In [196]:
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, TensorDataset, ConcatDataset
import torch
import matplotlib.pyplot as plt
import numpy as np
from sklearn.neighbors import KernelDensity
from sklearn.model_selection import GridSearchCV

# Load the validation dataset
mnist_validation_data = validation_dataset 

# Extract MNIST images and labels
mnist_images = mnist_validation_data.dataset.data[mnist_validation_data.indices][:900]
mnist_labels = torch.zeros(900)  # Label '0' for real MNIST data

# Load GAN generated images
gan_images_path = 'gan_generated_mnist_images.pt'
gan_images = torch.load(gan_images_path)

# Normalize the MNIST and GAN images if they are not already
mnist_images = mnist_images.float() / 255.0
mnist_images = mnist_images[:,None]
gan_images = gan_images.float() / 255.0 if gan_images.max() > 1.0 else gan_images

# Add a pure white image as an anomaly
white_image = torch.ones(1, 1, 28, 28)
white_label = torch.tensor([1])  # Anomaly label

# Combine MNIST, GAN, and white image into one dataset
combined_images = torch.cat((mnist_images, gan_images, white_image), dim=0)
combined_labels = torch.cat((mnist_labels, torch.ones(100), white_label), dim=0)

# Create a TensorDataset and DataLoader
combined_dataset = TensorDataset(combined_images, combined_labels)
combined_loader = DataLoader(combined_dataset, batch_size=64, shuffle=True)

# Check the normalization
max_pixel_value = combined_images.max()
min_pixel_value = combined_images.min()
print(f"Range of pixel values: {min_pixel_value} to {max_pixel_value}")

Range of pixel values: 0.0 to 1.0


In [212]:
# Function to encode dataset and compute reconstruction errors
def encode_and_reconstruct(model, dataloader, device):
    model.eval()
    ws = []
    rec_errors = []
    with torch.no_grad():
        for x, _ in dataloader:
            x = x.to(device)
            # Get w and x' in forward pass
            recon_x, _, _, _, _, _, w, _, _, _ = model(x)

            # Store w
            ws.append(w.cpu()) 

            # Calc and store rec_eu and rec_co
            rec_euclidean = torch.norm(x - recon_x, p=2, dim=(1, 2, 3))
            rec_cosine = F.cosine_similarity(x.view(x.size(0), -1), recon_x.view(recon_x.size(0), -1), dim=1)
            r = torch.stack((rec_euclidean, rec_cosine), dim=1)
            rec_errors.append(r.cpu()) 
            
    return torch.cat(ws, dim=0), torch.cat(rec_errors, dim=0)
    
Lake_VAE.to(DEVICE)

# For Training Data
encoded_ws, reconstruction_rs = encode_and_reconstruct(Lake_VAE, trainloader, DEVICE)
assert not torch.isnan(encoded_ws).any(), "NaNs in encoded_ws"
assert not torch.isnan(reconstruction_rs).any(), "NaNs in reconstruction_rs"
C_Train = np.hstack((encoded_ws, reconstruction_rs))
print(C_Train.shape)

# Testing Data
encoded_ws, reconstruction_rs = encode_and_reconstruct(Lake_VAE, combined_loader, DEVICE)
assert not torch.isnan(encoded_ws).any(), "NaNs in encoded_ws"
assert not torch.isnan(reconstruction_rs).any(), "NaNs in reconstruction_rs"
C = np.hstack((encoded_ws, reconstruction_rs))
print(C.shape)

(48000, 12)
(1001, 12)


There are three main things that affect the anomaly detection results. <br>
A. The model used (how well the data is compressed). The ability of your VAE to compress and reconstruct data is critical. It is contained within w and r. <br>
B. The KDE's bandwidth setting, corresponding to the smoothness of the density estimate. <br>
If it's too narrow, you might have a very bumpy estimate that's sensitive to noise. <br>
If it's too wide, the estimate might be too smooth and anomalies could be missed because they blend in with the normal data. <br>
C. Anomaly detection threshold. What proportion of your dataset you set expected to be an anomaly. <br>

In [214]:
# KernelDensity from sklearn represents fh(s) = (1/n) ∑[i=1 to n] Kh(s - ci)
from joblib import dump

def perform_kde(C_Train):
    # Doing k-fold validation on to find good bandwidth value
    params = {'bandwidth': np.linspace(0.01, 0.2, 20)}
    grid = GridSearchCV(KernelDensity(kernel='gaussian'), params, cv=5)
    grid.fit(C_Train)
    print(f"Optimal bandwidth: {grid.best_estimator_.bandwidth}")
    kde = KernelDensity(kernel='gaussian', bandwidth=grid.best_estimator_.bandwidth)
    
    #kde = KernelDensity(kernel='gaussian', bandwidth=0.05) # Manual bandwidth setting
    kde.fit(C_Train)
    return kde

# Function to estimate density from C values
def estimate_density(kde, C):
    log_density = kde.score_samples(C)
    return np.exp(log_density)

# Calibrate and Save KDE
kde_model = perform_kde(C_Train)
dump(kde_model, 'kde_model.joblib')

density_estimates = estimate_density(kde_model, C)

# Determine anomaly threshold and detect anomalies
threshold = np.percentile(density_estimates, 10)
anomalies = density_estimates < threshold

# Calculate and print anomaly detection results
detected_anomalies = np.sum(anomalies[-101:]) 
print(f"Detected {detected_anomalies} anomalies out of 101 synthetic images.")

# Plotting
plt.figure(figsize=(12, 6))
plt.hist(density_estimates, bins=100, alpha=0.5, color='blue', label='Density Scores')
plt.axvline(threshold, color='red', linestyle='--', label='Threshold')
plt.legend()
plt.title('Density Estimates and Anomaly Threshold')
plt.xlabel('Density Score')
plt.ylabel('Frequency')
plt.tight_layout()
plt.show()

KeyboardInterrupt: 

In [None]:
from joblib import load
kde_model = load('kde_model.joblib')