In [None]:
# Package imports

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt # imshow might be faster?
import numpy as np
import math

In [None]:
# Define hyperparameters

batch_size = 128
latent_dim = 20 # the smallest layer in the VAE. May be too high, as indicated by work analysing sparsity
vae_hidden_dim = 128 # size of the hidden dimension in the vae
classifier_hidden_dim = 256 # hidden dimension in the classifier
learning_rate = 1e-3
vae_num_epochs = 200 # higher slows collapse (author's observation)
classifier_num_epochs = 100 # in the author's runs, 100 is adequate to achieve 100% accuracy each time
num_iterations = 50  # Number of times to repeat the generate -> relabel -> retrain process
num_samples = 60000  # Number of samples to generate in each iteration. The first iteration always trains on the full original MNIST training set (60000 images) regardless of this value
num_images = 10 # Number of generated samples to display per iteration. Not relevant to the training logic

# Seed torch's global RNG so weight initialisation, DataLoader shuffling,
# reparameterisation noise and latent sampling are reproducible across runs.
# Change this value to obtain a different (but still reproducible) run.
random_seed = 0
torch.manual_seed(random_seed)

In [None]:
# Load original MNIST dataset

transform = transforms.Compose([
    transforms.ToTensor(), # PIL image -> float tensor scaled to [0, 1]
    transforms.Normalize((0.1307,), (0.3081,)) # per-channel mean and standard deviation (NOT variance) of the MNIST training pixels. See later cells to see how these were calculated
    # normalisation accelerates convergence of the VAE, and possibly of the classifier. However, the normalised pixels are no longer in [0, 1], so MSE as opposed to BCE must be used for reconstruction
])

# download=True only fetches the files when they are not already present under ./data; later runs reuse the local copy
original_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)

In [None]:
# Define the VAE model. Standard ReLU MLP with two hidden layers in both the encoder and the decoder

# The encoder predicts logvar = log(sigma^2) rather than sigma: it is unconstrained, so it emerges
# naturally from a linear layer, and it appears directly in the closed-form KLD
class VAE(nn.Module):
    """Fully-connected variational autoencoder with a diagonal Gaussian latent.

    Args:
        input_dim: number of features per flattened input sample (784 for 28x28 MNIST).
        hidden_dim: width of the hidden layers in the encoder and decoder.
        latent_dim: dimensionality of the latent code z.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(VAE, self).__init__()
        # Remembered so forward() flattens inputs to the configured size instead of a hard-coded 784
        self.input_dim = input_dim
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim)  # no output activation: targets are normalised, not in [0, 1]
        )

    def encode(self, x):
        """Map flattened inputs (batch, input_dim) to the mean and log-variance of q(z|x)."""
        h = self.encoder(x)
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        """Draw z = mu + sigma * eps with eps ~ N(0, I), keeping z differentiable w.r.t. mu and logvar."""
        std = torch.exp(0.5 * logvar) # sigma = exp(log(sigma^2) / 2)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent codes (batch, latent_dim) to reconstructions (batch, input_dim)."""
        return self.decoder(z)

    def forward(self, x):
        """Encode, sample with the reparameterisation trick, and decode.

        Returns:
            (reconstruction of shape (batch, input_dim), mu, logvar). Note that the third
            value is the LOG-variance, not the standard deviation.
        """
        mu, logvar = self.encode(x.view(-1, self.input_dim))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar # returns LOGVAR

In [None]:
# Define the Classifier

class Classifier(nn.Module):
    """Three-layer ReLU MLP that maps flattened inputs to unnormalised class scores (logits).

    Args:
        input_dim: number of input features per sample.
        hidden_dim: width of both hidden layers.
        num_classes: number of output logits.
    """

    def __init__(self, input_dim, hidden_dim, num_classes):
        super().__init__()
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return logits of shape (batch, num_classes) for inputs of shape (batch, input_dim)."""
        logits = self.model(x)
        return logits

In [None]:
# Helper functions

# VAE loss function
def vae_loss_function(recon_x, x, mu, logvar): # recon_x is short for reconstructed
    """Negative ELBO (up to constants) for a VAE with a unit-Gaussian prior, summed over the batch.

    Args:
        recon_x: decoder output, shape (batch, features).
        x: original inputs with the same number of elements as recon_x, in any shape
           (e.g. (batch, 1, 28, 28) MNIST images); reshaped to match recon_x.
        mu, logvar: encoder outputs (mean and LOG-variance of q(z|x)).

    Returns:
        Scalar tensor: summed squared reconstruction error + KL(q(z|x) || N(0, I)).
    """
    # Summed squared error. MSE is required because the inputs are normalised; BCE could be used if
    # normalisation were removed. BCE makes more theoretical sense, but seems irrelevant in practice.
    # Reshape to recon_x's shape rather than a hard-coded 784 so any input size works.
    MSE = nn.functional.mse_loss(recon_x, x.reshape(recon_x.shape), reduction='sum')
    # Closed-form KL divergence between N(mu, sigma^2) and N(0, I), summed over batch and latent dims
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return MSE + KLD

# Generates new digits for later training
def generate_digits(vae, num_samples):
    """Decode `num_samples` latent vectors drawn from the N(0, I) prior.

    The latent size is read from the model (``vae.fc_mu.out_features``) instead of
    the module-level ``latent_dim`` global, so this works for any VAE instance
    regardless of the current hyperparameter cell. Leaves ``vae`` in eval mode and
    returns the decoder outputs without autograd history.
    """
    vae.eval()
    latent_size = vae.fc_mu.out_features
    # Sample on the same device as the model's weights so a GPU-resident VAE also works
    device = next(vae.parameters()).device
    with torch.no_grad():
        z = torch.randn(num_samples, latent_size, device=device)
        samples = vae.decode(z)
    return samples

# Label generated digits with a trained classifier; in the main loop these pseudo-labels become the targets for the next iteration's classifier.
def classify_digits(classifier, digits):
    """Return the arg-max class index for each row of `digits`.

    Switches `classifier` to eval mode and runs without autograd, so the returned
    labels carry no gradient history.
    """
    classifier.eval()
    with torch.no_grad():
        return classifier(digits).argmax(dim=1)


# Evaluates a classifier on a dataset.
# Used to score historic classifiers on each newly generated set of images, as a proxy for the statistical distance between generations.
def evaluate_classifier(classifier, dataloader):
    """Compute the mean cross-entropy loss and accuracy of `classifier` over `dataloader`.

    Each batch's mean loss is weighted by its size, so the returned loss is the
    per-sample average over the whole dataset even when the last batch is short.
    Inputs are flattened to (batch, -1) before the forward pass.

    Returns:
        (avg_loss, accuracy) as Python floats.
    """
    classifier.eval()
    loss_fn = nn.CrossEntropyLoss()
    loss_sum = 0
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for inputs, targets in dataloader:
            batch_len = inputs.size(0)
            logits = classifier(inputs.view(batch_len, -1))
            batch_loss = loss_fn(logits, targets)
            loss_sum += batch_loss.item() * batch_len
            predictions = logits.argmax(dim=1)
            n_correct += (predictions == targets).sum().item()
            n_seen += batch_len
    return loss_sum / n_seen, n_correct / n_seen

In [None]:
# Training functions

# Train VAE
def train_vae(vae, optimizer, dataloader):
    """Run one epoch of VAE training and return the average loss per sample.

    Labels yielded by `dataloader` are ignored. `vae_loss_function` sums over the
    batch, so dividing the accumulated loss by the dataset size gives a per-sample
    figure that stays comparable when num_samples differs from 60,000.
    """
    vae.train()
    running_total = 0
    for inputs, _labels in dataloader:
        optimizer.zero_grad()
        reconstruction, mu, logvar = vae(inputs)
        batch_loss = vae_loss_function(reconstruction, inputs, mu, logvar)
        batch_loss.backward()  # fails if reparameterisation is not implemented appropriately
        running_total += batch_loss.item()
        optimizer.step()
    return running_total / len(dataloader.dataset)

# Train Classifier
def train_classifier(classifier, optimizer, dataloader): # could pass the raw dataset to avoid dependencies?
    """Run one epoch of classifier training.

    Inputs are flattened to (batch, -1). Returns ``(avg_loss, accuracy)``:
    avg_loss is the per-sample mean cross-entropy over the epoch, and accuracy is
    the fraction of samples predicted correctly *during* the epoch (i.e. with the
    weights as they were when each batch was seen, before that batch's update).
    """
    classifier.train()
    criterion = nn.CrossEntropyLoss()
    train_loss = 0
    correct = 0 # for later accuracy calculation. Entirely for evaluation
    for data, target in dataloader:
        batch_len = data.size(0)
        optimizer.zero_grad()
        output = classifier(data.view(batch_len, -1)) # flatten each image to a vector
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        # CrossEntropyLoss averages over the batch, so weight by batch size before dividing by the
        # dataset size below; summing raw batch means would under-report the loss by ~batch_size
        # (this matches the accumulation in evaluate_classifier)
        train_loss += loss.item() * batch_len
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()
    num_seen = len(dataloader.dataset)
    return train_loss / num_seen, correct / num_seen

In [None]:
# Plotting functions

def visualize_digits(digits, labels, iteration):
    """Show digits in a grid of up to 10 per row, each titled with its label.

    Args:
        digits: indexable collection of image tensors, each with 784 elements (drawn as 28x28).
        labels: tensor of class indices aligned with `digits`.
        iteration: iteration number shown in the figure title.
    """
    num_digits = len(digits)
    if num_digits == 0:
        return  # nothing to draw; plt.subplots(0, 0) would raise
    num_rows = math.ceil(num_digits / 10) # allows arbitrarily many images to be shown. Displays these images in rows of 10
    num_cols = min(num_digits, 10)

    # squeeze=False always returns a 2-D array of Axes. Without it, a single digit yields a bare
    # Axes object (which has no .reshape) and a single row yields a 1-D array.
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(2*num_cols, 2*num_rows), squeeze=False)
    fig.suptitle(f'Generated Digits - Iteration {iteration}')

    for i in range(num_digits):
        ax = axes[i // 10, i % 10]
        # detach/cpu so tensors that require grad or live on a GPU can still be converted to numpy
        ax.imshow(digits[i].detach().cpu().view(28, 28).numpy(), cmap='gray')
        ax.set_title(f'Label: {labels[i].item()}')
        ax.axis('off')

    # Remove any unused subplots in the last row
    for i in range(num_digits, num_rows * num_cols):
        fig.delaxes(axes[i // 10, i % 10])

    plt.tight_layout()
    plt.show()

# Visualize the loss and accuracy matrices
def plot_matrix(matrix, title, cmap='viridis'): # viridis: matplotlib's perceptually uniform default colormap
    """Show a 2-D `matrix` as a heatmap with every cell's value printed on it.

    Rows are classifier iterations and columns are dataset iterations. Tick labels
    count from 1, matching colour_plot_matrix and the "Iteration k/N" training log.
    Intended for small runs, where the per-cell text remains legible.
    """
    plt.figure(figsize=(10, 8))
    plt.imshow(matrix, cmap=cmap)
    plt.colorbar() # considerably better visualisation
    plt.title(title)
    plt.xlabel('Dataset Iteration')
    plt.ylabel('Classifier Iteration')
    num_rows, num_cols = matrix.shape
    # Label ticks from 1 while keeping the 0-based cell positions imshow uses
    plt.xticks(np.arange(num_cols), np.arange(1, num_cols + 1))
    plt.yticks(np.arange(num_rows), np.arange(1, num_rows + 1))
    for i in range(num_rows):
        for j in range(num_cols):
            plt.text(j, i, f'{matrix[i, j]:.2f}', ha='center', va='center', color='white')
    #plt.tight_layout() # produces graphical bugs with the title, but is clearer with the colourbar and labels
    plt.show()

# Creates a colour plot without including values. Far more appropriate for large runs
def colour_plot_matrix(matrix, title, cmap='viridis', num_ticks=11):
    """Show a 2-D `matrix` as a heatmap with up to `num_ticks` evenly spaced, 1-based tick labels per axis.

    Rows are classifier iterations and columns are dataset iterations. Cell values
    are not printed, so this scales to large matrices.
    """
    plt.figure(figsize=(10, 8))
    plt.imshow(matrix, cmap=cmap)
    plt.colorbar()
    plt.title(title)
    plt.xlabel('Dataset Iteration')
    plt.ylabel('Classifier Iteration')

    # Evenly spaced tick locations. np.unique removes the duplicates that integer truncation
    # produces when an axis has fewer entries than num_ticks (e.g. a 5x5 matrix with 11 ticks
    # would otherwise draw most tick labels two or three times on top of each other).
    x_ticks = np.unique(np.linspace(0, matrix.shape[1] - 1, num_ticks, dtype=int))
    y_ticks = np.unique(np.linspace(0, matrix.shape[0] - 1, num_ticks, dtype=int))

    # Set tick locations and labels
    plt.xticks(x_ticks, x_ticks + 1)  # +1 to start from 1 instead of 0
    plt.yticks(y_ticks, y_ticks + 1)

    plt.show()

In [None]:
# Main training loop

# Initialisation
current_dataset = original_dataset
classifiers = []
# Entry [i, j] holds the loss/accuracy of classifier i on the dataset generated in iteration j.
# Only i <= j is ever filled (classifier i does not exist yet when dataset j < i is scored), so
# start from NaN: the unfilled lower triangle then renders blank in imshow instead of as a
# misleading 0 loss / 0% accuracy that also stretches the colour scale.
loss_matrix = np.full((num_iterations, num_iterations), np.nan)
accuracy_matrix = np.full((num_iterations, num_iterations), np.nan)
vae = VAE(784, vae_hidden_dim, latent_dim) # VAE is not reset after each iteration

for iteration in range(num_iterations):
    print(f"Iteration {iteration + 1}/{num_iterations}")

    # Create DataLoader
    dataloader = DataLoader(current_dataset, batch_size=batch_size, shuffle=True)

    # Train the VAE. Its weights carry over between iterations, but a fresh Adam optimiser
    # (and therefore fresh moment estimates) is created every iteration
    vae_optimizer = optim.Adam(vae.parameters(), lr=learning_rate)
    for epoch in range(vae_num_epochs):
        loss = train_vae(vae, vae_optimizer, dataloader)
        if (epoch + 1) % 50 == 0:
            print(f'VAE Epoch {epoch+1}/{vae_num_epochs}, Loss: {loss:.4f}')

    # Initialize and train classifier. Note that the classifier is reinitialised every training loop
    classifier = Classifier(784, classifier_hidden_dim, 10)
    classifier_optimizer = optim.Adam(classifier.parameters(), lr=learning_rate)
    for epoch in range(classifier_num_epochs):
        loss, accuracy = train_classifier(classifier, classifier_optimizer, dataloader)
        if (epoch + 1) % 50 == 0:
            print(f'Classifier Epoch {epoch+1}/{classifier_num_epochs}, Loss: {loss:.4f}, Accuracy: {accuracy:.4f}')

    # Store the trained classifier in classifiers list. Used later for overall evaluation
    classifiers.append(classifier)

    # Generate new dataset
    generated_digits = generate_digits(vae, num_samples)
    # Labelled by the classifier just trained on THIS iteration's dataset; these labels (not the
    # true digits) become the training targets for the next iteration
    generated_labels = classify_digits(classifier, generated_digits)

    # Visualize some generated digits. num_images represents the number of images to be displayed
    visualize_digits(generated_digits[:num_images], generated_labels[:num_images], iteration + 1)

    # Create new dataset for next iteration
    current_dataset = TensorDataset(generated_digits, generated_labels) # torch.utils.data.TensorDataset alias

    # Evaluate every classifier so far (including this iteration's) on the new dataset, writing to the
    # loss and accuracy arrays. The diagonal therefore scores each classifier against its own labels.
    new_dataloader = DataLoader(current_dataset, batch_size=batch_size, shuffle=False) # shuffle should be unnecessary as the sampling from the latent space is random
    for prev_iteration, prev_classifier in enumerate(classifiers):
        loss, accuracy = evaluate_classifier(prev_classifier, new_dataloader)
        loss_matrix[prev_iteration, iteration] = loss
        accuracy_matrix[prev_iteration, iteration] = accuracy

    # This line can be used to save the current dataset for later analysis
    #torch.save(current_dataset, f'generated_dataset_iteration_{iteration + 1}.pt')

# For smaller runs, use plot_matrix to display all relevant values
colour_plot_matrix(loss_matrix, 'Loss Matrix')
colour_plot_matrix(accuracy_matrix, 'Accuracy Matrix')

# Save matrices for further analysis (never-evaluated entries are NaN)
#np.save('loss_matrix.npy', loss_matrix)
#np.save('accuracy_matrix.npy', accuracy_matrix)