# Testing the ReGene framework

## Setup


Import libraries

In [1]:
import importlib
import regene_models
importlib.reload(regene_models)
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import os

Set the device

In [None]:
# Pick the compute backend: CUDA if present, otherwise Apple MPS, otherwise CPU.
device = torch.device(
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)

print(f"Using device: {device}")

Load the dataset

In [3]:
# Load the MNIST training split as tensors and wrap it in a shuffled loader.
transform = transforms.Compose([transforms.ToTensor()])
trainset = torchvision.datasets.MNIST(
    root='../data',
    train=True,
    download=True,
    transform=transform,
)
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=64,
    shuffle=True,
    num_workers=2,
)

Set the latent dimension

In [4]:
latent_dim: int = 256  # latent size passed to every Classifier/Decoder constructor below

Create a models directory if it doesn't exist

In [5]:
# Make sure the checkpoint folder used by torch.save exists; no error if it already does.
os.makedirs(name='models', exist_ok=True)

## Classifier

### Training

First we define the classifier

In [6]:
classifier = regene_models.Classifier(
    latent_dim=latent_dim,
    num_classes=10,  # MNIST has ten digit classes
    device=device,
)

Then we train it and save the weights.

In [None]:
# Fit the classifier on the training loader, then checkpoint its weights.
classifier.train_classifier(
    trainloader,
    num_epochs=10,
    lr=0.001,
)
torch.save(
    classifier.state_dict(),
    'models/classifier.pth',
)

### Testing


First, let's test the classifier on a few images drawn at random from the training set.

In [None]:
# Draw five random training samples and split them into image and label tensors.
random_indices = torch.randint(0, len(trainset), (5,))
images = torch.stack([trainset[idx][0] for idx in random_indices])
labels = torch.tensor([trainset[idx][1] for idx in random_indices])

# Inference only: the classifier returns (latent representation, class predictions);
# keep the predictions and take the highest-scoring class per image.
classifier.eval()
with torch.no_grad():
    images = images.to(device)
    predictions = classifier(images)[1]
    predicted_classes = predictions.argmax(dim=1)

# One panel per sample, titled with its true and predicted digit.
plt.figure(figsize=(15, 3))
for panel, (img, true_label, pred_label) in enumerate(
    zip(images, labels, predicted_classes), start=1
):
    plt.subplot(1, 5, panel)
    plt.imshow(img.cpu().squeeze().numpy(), cmap='gray')
    plt.title(f'True: {true_label.item()}\nPred: {pred_label.cpu().item()}')
    plt.axis('off')

plt.tight_layout()
plt.show()


We'll also visualise the latent space by taking the latent representations of 1,000 randomly chosen training images and projecting them to 2D with t-SNE.

In [None]:
# Get latent representations for 1000 random training images
# (torch.randint samples with replacement, so a few duplicates are possible).
random_indices = torch.randint(0, len(trainset), (1000,))
images = torch.stack([trainset[i][0] for i in random_indices])
labels = torch.tensor([trainset[i][1] for i in random_indices])

# Encode all 1000 images in one no-grad forward pass. The classifier returns
# (latent representation, class predictions); only the latent part is kept.
classifier.eval()
with torch.no_grad():
    images = images.to(device)
    latent_reps, _ = classifier(images)
    latent_reps = latent_reps.cpu().numpy()

# Perform t-SNE dimensionality reduction (fixed t-SNE seed; the sampled
# images above are not seeded, so the plot still varies between runs)
from sklearn.manifold import TSNE
tsne = TSNE(n_components=2, random_state=42)
latent_2d = tsne.fit_transform(latent_reps)

# Plot the 2D latent space, coloured by the true digit label
plt.figure(figsize=(10, 8))
scatter = plt.scatter(latent_2d[:, 0], latent_2d[:, 1], c=labels.numpy(), cmap='tab10')
plt.colorbar(scatter, label='Digit Class')
plt.title('t-SNE Visualization of Latent Space')
plt.xlabel('t-SNE Dimension 1')
plt.ylabel('t-SNE Dimension 2')
plt.legend(*scatter.legend_elements(), title="Classes")
plt.show()


## Decoder

### Training

We define the decoder, and then train it using the classifier's latent space.

In [99]:
decoder = regene_models.Decoder(
    latent_dim=latent_dim,
    device=device,
)

In [None]:
# Train the decoder on the (already trained) classifier's latent space, then checkpoint it.
decoder.train_decoder(
    trainloader,
    classifier,
    num_epochs=12,
    lr=0.001,
)
torch.save(
    decoder.state_dict(),
    'models/decoder.pth',
)

### Testing


First, let's visualise some reconstructions.

In [None]:
# Take the first (shuffled) training batch and keep 10 of its images.
dataiter = iter(trainloader)
images = next(dataiter)[0][:10].to(device)

# Encode with the classifier, decode back to image space, without tracking gradients.
classifier.eval()
decoder.eval()
with torch.no_grad():
    z = classifier(images)[0]
    reconstructed = decoder(z)

# Top row shows the originals, bottom row the reconstructions.
fig, axes = plt.subplots(2, 10, figsize=(15, 3))
for row, (batch, heading) in enumerate([(images, 'Original'), (reconstructed, 'Reconstructed')]):
    for col in range(10):
        ax = axes[row, col]
        ax.imshow(batch[col].cpu().squeeze(), cmap='gray')
        ax.axis('off')
    axes[row, 0].set_title(heading, pad=10)

plt.tight_layout()
plt.show()


## Joint training


Let's try training the models with a joint objective

In [110]:
from importlib import reload
import regene_models
importlib.reload(regene_models)
from regene_models import ClassifierGenerator

The `lambda_recon` argument determines how much weight is given to the reconstruction loss.

In [None]:
# Fresh decoder/classifier pair trained together with a combined objective.
# Use the shared latent_dim (not a literal) so these weights stay loadable by
# the comparison cell, which rebuilds the models with latent_dim.
# NOTE(review): lambda_recon is presumably the weight on the reconstruction
# loss - confirm in regene_models.train_joint.
joint_decoder = regene_models.Decoder(latent_dim=latent_dim, device=device)
joint_classifier = regene_models.Classifier(latent_dim=latent_dim, num_classes=10, device=device)
regene_models.train_joint(joint_classifier, joint_decoder, trainloader, num_epochs=12, lr=0.001, lambda_recon=0.8)

# Save models
torch.save(joint_decoder.state_dict(), 'models/joint_decoder.pth')
torch.save(joint_classifier.state_dict(), 'models/joint_classifier.pth')

In [None]:
# Grab one batch of training images to reconstruct
dataiter = iter(trainloader)
images, labels = next(dataiter)
images = images.to(device)

# Switch to inference mode before evaluating (as done for the standard models):
# train_joint may leave the models in training mode, which changes the behaviour
# of layers such as dropout or batch-norm if the architectures contain them.
joint_classifier.eval()
joint_decoder.eval()

# Get reconstructions
with torch.no_grad():
    z, _ = joint_classifier(images)
    reconstructed = joint_decoder(z)

# Plot original vs reconstructed images (first 5 of the batch)
fig, axes = plt.subplots(2, 5, figsize=(15, 6))
for i in range(5):
    # Original images
    axes[0,i].imshow(images[i].cpu().squeeze(), cmap='gray')
    axes[0,i].axis('off')
    if i == 0:
        axes[0,i].set_title('Original', pad=10)

    # Reconstructed images
    axes[1,i].imshow(reconstructed[i].cpu().squeeze(), cmap='gray')
    axes[1,i].axis('off')
    if i == 0:
        axes[1,i].set_title('Reconstructed', pad=10)

plt.tight_layout()
plt.show()


## Training encoder and classifier separately


In this final section, we will train the encoder and classifier separately. The encoder is trained to minimise the reconstruction loss, and the classifier is then trained to minimise the cross-entropy loss on the encoder's latent space.

In [115]:
from regene_models import train_autoencoder, train_classifier_only

# Fresh, untrained pair for the two-stage experiment (construction order kept:
# classifier first, then decoder).
separate_classifier = regene_models.Classifier(
    latent_dim=latent_dim, num_classes=10, device=device
)
separate_decoder = regene_models.Decoder(
    latent_dim=latent_dim, device=device
)

In [None]:
# Stage 1: train the classifier (acting as encoder) and decoder as an autoencoder.
train_autoencoder(
    classifier=separate_classifier,
    decoder=separate_decoder,
    train_loader=trainloader,
    num_epochs=12,
    lr=0.001,
)

# Checkpoint both models. NOTE: this runs before train_classifier_only, so the
# classifier saved here reflects autoencoder training only.
torch.save(separate_classifier.state_dict(),
           'models/separate_classifier.pth')
torch.save(separate_decoder.state_dict(),
           'models/separate_decoder.pth')

In [118]:
# Stage 2: train the classifier on the autoencoder's latent space, then re-save it.
# Without this save, models/separate_classifier.pth would still hold the
# pre-stage-2 weights, and the comparison section would load those.
train_classifier_only(separate_classifier, trainloader, num_epochs=10, lr=0.001)
torch.save(separate_classifier.state_dict(), 'models/separate_classifier.pth')

## Comparison

We will now compare the performance of the different models.



In [None]:
# Load models: rebuild each architecture and restore its saved weights.
# map_location=device lets checkpoints saved on one device (e.g. CUDA) be
# loaded on a machine where only another device (e.g. CPU/MPS) is available.
classifier_loaded = regene_models.Classifier(latent_dim=latent_dim, num_classes=10, device=device)
decoder_loaded = regene_models.Decoder(latent_dim=latent_dim, device=device)
classifier_loaded.load_state_dict(torch.load('models/classifier.pth', map_location=device))
decoder_loaded.load_state_dict(torch.load('models/decoder.pth', map_location=device))

joint_classifier_loaded = regene_models.Classifier(latent_dim=latent_dim, num_classes=10, device=device)
joint_decoder_loaded = regene_models.Decoder(latent_dim=latent_dim, device=device)
joint_classifier_loaded.load_state_dict(torch.load('models/joint_classifier.pth', map_location=device))
joint_decoder_loaded.load_state_dict(torch.load('models/joint_decoder.pth', map_location=device))

separate_classifier_loaded = regene_models.Classifier(latent_dim=latent_dim, num_classes=10, device=device)
separate_decoder_loaded = regene_models.Decoder(latent_dim=latent_dim, device=device)
separate_classifier_loaded.load_state_dict(torch.load('models/separate_classifier.pth', map_location=device))
separate_decoder_loaded.load_state_dict(torch.load('models/separate_decoder.pth', map_location=device))

# (classifier, decoder) pairs in the order: standard, joint, separate
models = [(classifier_loaded, decoder_loaded), (joint_classifier_loaded, joint_decoder_loaded), (separate_classifier_loaded, separate_decoder_loaded)]


In [None]:
import torch.nn.functional as F
from torchmetrics import Accuracy
import pandas as pd
from IPython.display import display

# Function to calculate metrics
def calculate_metrics(classifier, decoder, test_loader, device=None):
    """Evaluate classification accuracy and reconstruction MSE over a loader.

    Args:
        classifier: model whose forward pass returns ``(z, outputs)``: a latent
            code and per-class scores (argmax over dim 1 gives the prediction).
        decoder: model mapping ``z`` back to image space.
        test_loader: iterable of ``(images, labels)`` batches.
        device: where to run inference. Defaults to the device of the
            classifier's first parameter (CPU if it has no parameters).

    Returns:
        ``(accuracy, avg_mse)``: accuracy in percent, and the mean squared
        reconstruction error per element averaged over every sample (not a
        mean of per-batch means, which would over-weight a short final batch).

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    if device is None:
        first_param = next(classifier.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    classifier.eval()
    decoder.eval()

    total = 0
    correct = 0
    squared_error_sum = 0.0
    element_count = 0

    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)

            # Get predictions and reconstructions
            z, outputs = classifier(images)
            # Match the input's layout so e.g. (N, 28, 28) vs (N, 1, 28, 28)
            # cannot silently broadcast inside mse_loss.
            reconstructed = decoder(z).reshape_as(images)

            # Accumulate accuracy
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            # Accumulate summed squared error and element count
            squared_error_sum += F.mse_loss(reconstructed, images, reduction='sum').item()
            element_count += images.numel()

    if total == 0:
        raise ValueError("test_loader yielded no samples")

    accuracy = 100 * correct / total
    avg_mse = squared_error_sum / element_count

    return accuracy, avg_mse

# Evaluate on the held-out MNIST test split: every model above was fitted on
# trainloader, so training-set numbers would conflate fit with generalisation.
testset = torchvision.datasets.MNIST(root='../data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False, num_workers=2)

# Calculate metrics for each model
model_names = ['Standard', 'Joint Training', 'Separate Training']
results = []

for (clf, dec), name in zip(models, model_names):
    accuracy, mse = calculate_metrics(clf, dec, testloader)
    results.append({
        'Model': name,
        'Accuracy (%)': f'{accuracy:.2f}',
        'MSE': f'{mse:.4f}',
    })

# Create and display DataFrame
df = pd.DataFrame(results)
display(df)


We also compare their latent spaces.

In [None]:
# Get latent representations for 1000 random training images
random_indices = torch.randint(0, len(trainset), (1000,))
images = torch.stack([trainset[i][0] for i in random_indices])
labels = torch.tensor([trainset[i][1] for i in random_indices])

# Get latent representations for each model (the same images for every model)
model_names = ['Standard', 'Joint Training', 'Separate Training']
latent_spaces = []

images = images.to(device)  # move once rather than once per model
for clf, _ in models:
    clf.eval()
    with torch.no_grad():
        latent_reps, _ = clf(images)
    latent_spaces.append(latent_reps.cpu().numpy())

# Create subplot for each model's latent space
fig, axes = plt.subplots(1, 3, figsize=(20, 6))
fig.suptitle('t-SNE Visualization of Latent Spaces')

# Perform t-SNE and plot for each model (fit_transform refits for each model)
from sklearn.manifold import TSNE
tsne = TSNE(n_components=2, random_state=42)

for i, (latent_reps, name) in enumerate(zip(latent_spaces, model_names)):
    latent_2d = tsne.fit_transform(latent_reps)

    # Pin the colour range to digits 0-9 so every panel uses the same mapping
    # and the single shared colourbar below is valid for all three.
    scatter = axes[i].scatter(latent_2d[:, 0], latent_2d[:, 1], c=labels, cmap='tab10', vmin=0, vmax=9)
    axes[i].set_title(name)
    axes[i].set_xlabel('t-SNE Dimension 1')
    axes[i].set_ylabel('t-SNE Dimension 2')

# Lay out the subplot grid before adding the colourbar axes, which are not
# part of the gridspec and would otherwise confuse tight_layout.
plt.tight_layout()

# Add colorbar
plt.colorbar(scatter, ax=axes.ravel().tolist(), label='Digit Class')

# Add legend to last subplot; take labels from legend_elements so they match
# the classes actually present rather than assuming all ten appear.
handles, class_labels = scatter.legend_elements()
axes[-1].legend(handles, class_labels, title="Classes", bbox_to_anchor=(1.3, 1))

plt.show()
