Demonstrating model collapse using the balanced EMNIST dataset.

This notebook has a very similar structure to the corresponding MNIST scripts. In each iteration a Variational Auto-Encoder (VAE) is trained on the current dataset alongside a classifier. A full 112800 samples (equal to the number in the original balanced EMNIST training set) are then drawn from the VAE and labelled by the classifier trained in that iteration, forming the dataset for the next iteration.

Model collapse is visible by viewing random samples of the outputs of each iteration of VAE. Note that it is not necessarily possible to draw a balanced sample, as after a certain amount of time variance within the outputs shrinks so much that many of the original classes are no longer represented in the generated datasets. 

The primary metric of model collapse used is the performance of previously trained classifiers on the newly generated data. This is intended to be a data-driven description of how many of the original features remain, although it also incorporates (unwanted?) information about whether later classifiers identified the relevant features or are using an entirely different scheme.

In [1]:
# Package Imports

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt
import numpy as np
import math

In [2]:
# Defining parameters

# Hyperparameters. As of right now, these are broadly speaking the same values as used in the MNIST trial
batch_size = 128
latent_dim = 50 # width of the VAE bottleneck (the smallest layer in the VAE)
vae_hidden_layers = [256, 128, 64] # encoder widths; the decoder mirrors them, so the VAE is symmetric
classifier_hidden_dim = 512 # hidden dimension in the classifier
learning_rate = 5e-4 # used for both the VAE and the classifier Adam optimisers in the main loop
vae_num_epochs = 70 # VAE training epochs per iteration
classifier_num_epochs = 200 # classifier training epochs per iteration
num_iterations = 25  # Number of times to repeat the train -> generate -> relabel process
num_samples = 112800  # Number of samples to generate in each iteration # Note that original EMNIST (all 112800) are still used regardless of this value for the first iteration
num_images = 20 # Number of generated samples displayed after each iteration

In [3]:
# Load the balanced EMNIST dataset and normalise for faster model convergence

# Calculated using 'EMNIST Calculate Statistics.ipynb'
dataset_mean = 0.1751
dataset_std = 0.3332
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((dataset_mean,), (dataset_std,)) # mean and standard deviation (one value each, i.e. a single channel)
    # normalisation accelerates convergence of the VAE, and possibly of the classifier. However, it means MSE as opposed to BCE must be used
    # because the normalised pixel values are no longer confined to [0, 1]
])

original_dataset = torchvision.datasets.EMNIST(root='./data', split='balanced', train=True, download=True, transform=transform) # download=True fetches the archive into ./data (see the log output below)


Downloading https://biometrics.nist.gov/cs_links/EMNIST/gzip.zip to ./data\EMNIST\raw\gzip.zip


100%|██████████| 562M/562M [04:09<00:00, 2.25MB/s] 


Extracting ./data\EMNIST\raw\gzip.zip to ./data\EMNIST\raw


In [4]:
# Define the VAE used. Uses a standard architecture

class VAE(nn.Module):
    """Fully connected variational auto-encoder with a mirrored decoder.

    The encoder is a stack of ``Linear -> ReLU`` layers whose widths are
    given by ``hidden_layers``; two linear heads then produce the mean and
    log-variance of the latent Gaussian. The decoder applies the same widths
    in reverse order and ends in a plain linear layer (no output activation).

    Args:
        input_dim: Number of features per flattened sample (784 for 28x28).
        hidden_layers: Sequence of encoder layer widths. May be empty, in
            which case the heads/decoder connect directly to the input.
        latent_dim: Width of the latent space.
    """

    def __init__(self, input_dim, hidden_layers, latent_dim):
        super().__init__()
        # Remembered so forward() can flatten to the right width instead of
        # assuming 784 features.
        self.input_dim = input_dim
        self.latent_dim = latent_dim

        # Encoder
        encoder_layers = []
        prev_dim = input_dim
        for hidden_dim in hidden_layers:
            encoder_layers.extend([
                nn.Linear(prev_dim, hidden_dim),
                nn.ReLU()
            ])
            prev_dim = hidden_dim
        self.encoder = nn.Sequential(*encoder_layers)

        # prev_dim is the encoder's output width (input_dim when there are
        # no hidden layers), so this also works for an empty hidden list.
        self.fc_mu = nn.Linear(prev_dim, latent_dim)
        self.fc_logvar = nn.Linear(prev_dim, latent_dim)

        # Decoder
        decoder_layers = []
        prev_dim = latent_dim
        for hidden_dim in reversed(hidden_layers):
            decoder_layers.extend([
                nn.Linear(prev_dim, hidden_dim),
                nn.ReLU()
            ])
            prev_dim = hidden_dim
        decoder_layers.append(nn.Linear(prev_dim, input_dim))
        self.decoder = nn.Sequential(*decoder_layers)

    def encode(self, x):
        """Return ``(mu, logvar)`` of the latent Gaussian for flattened ``x``."""
        h = self.encoder(x)
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

        Writing the sample this way keeps it differentiable with respect to
        ``mu`` and ``logvar``.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent vectors ``z`` to reconstructions of width ``input_dim``."""
        return self.decoder(z)

    def forward(self, x):
        """Encode, sample and decode ``x``.

        ``x`` is flattened to ``(-1, input_dim)`` first, so image-shaped
        batches are accepted.

        Returns:
            Tuple ``(reconstruction, mu, logvar)``.
        """
        mu, logvar = self.encode(x.reshape(-1, self.input_dim))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

In [5]:
# Define the classifier
# Note that the classifier loss is given as cross entropy. Defined in the classifier train loop train_classifier

# Uses a fully connected linear architecture. Performance is adequate in earlier iterations
class Classifier(nn.Module):
    """Fully connected classifier with two hidden layers of equal width.

    The network outputs raw class scores (logits); the cross-entropy loss is
    applied by the training and evaluation loops, not here.

    Args:
        input_dim: Number of features per flattened sample.
        hidden_dim: Width of both hidden layers.
        num_classes: Number of output classes.
    """

    def __init__(self, input_dim, hidden_dim, num_classes):
        super().__init__()
        widths = [input_dim, hidden_dim, hidden_dim, num_classes]
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.ReLU())
        stack.pop()  # the final layer emits logits, so it has no activation
        self.model = nn.Sequential(*stack)

    def forward(self, x):
        """Return the logits for an already-flattened batch ``x``."""
        logits = self.model(x)
        return logits

In [6]:
# Custom loss function for the VAE

def vae_loss_function(recon_x, x, mu, logvar): # recon_x is short for reconstructed
    """Return the summed VAE loss: reconstruction error plus KL divergence.

    Args:
        recon_x: Decoder output for the batch.
        x: Original batch; reshaped to match ``recon_x`` so image-shaped
            input and any feature width are accepted (not just 784).
        mu: Latent means produced by the encoder.
        logvar: Latent log-variances produced by the encoder.

    Returns:
        Scalar tensor, summed (not averaged) over the whole batch.
    """
    # MSE required. BCE can be used if normalisation is not introduced.
    # BCE makes more theoretical sense, but seems irrelevant in practice
    target = x.reshape(recon_x.shape)
    MSE = nn.functional.mse_loss(recon_x, target, reduction='sum')
    # closed form for the KLD between N(mu, exp(logvar)) and N(0, I)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return MSE + KLD

In [7]:
# VAE main training loop

def train_vae(vae, optimizer, dataloader):
    """Train ``vae`` for one pass over ``dataloader``.

    Labels yielded by the loader are ignored.

    Returns:
        The accumulated (summed) batch losses divided by the number of
        samples in ``dataloader.dataset``, i.e. a per-sample loss that stays
        comparable when the dataset size changes.
    """
    vae.train()
    running_total = 0
    for images, _ in dataloader:
        optimizer.zero_grad()
        # forward pass returns the reconstruction plus the latent statistics
        reconstruction, mu, logvar = vae(images)
        batch_loss = vae_loss_function(reconstruction, images, mu, logvar)
        batch_loss.backward() # fails when reparameterisation isn't implemented appropriately
        optimizer.step()
        running_total += batch_loss.item()
    dataset_size = len(dataloader.dataset)
    return running_total / dataset_size

In [8]:
# Classifier main training loop

def train_classifier(classifier, optimizer, dataloader):
    """Train ``classifier`` for one pass over ``dataloader``.

    Each batch is flattened to ``(batch, features)`` before the forward pass.

    Returns:
        Tuple ``(mean_loss, accuracy)`` over all samples seen in this epoch.
        Both are measured on the fly, while the weights are still being
        updated.
    """
    classifier.train()
    criterion = nn.CrossEntropyLoss()
    train_loss = 0.0
    correct = 0
    for data, target in dataloader:
        optimizer.zero_grad()
        output = classifier(data.view(data.size(0), -1))
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        # criterion returns the batch *mean*; weight it by the batch size so
        # that dividing by the dataset size gives a true per-sample mean
        # (the same convention as evaluate_classifier).
        train_loss += loss.item() * data.size(0)
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()
    dataset_size = len(dataloader.dataset)
    return train_loss / dataset_size, correct / dataset_size

In [9]:
# Visualise loss and accuracy matrices of various classifiers

# Visualize the loss and accuracy matrices
def plot_matrix(matrix, title, cmap='viridis'):
    """Show ``matrix`` as a heat map with every cell's value printed on it.

    Rows are labelled as classifier iterations and columns as dataset
    iterations. The figure is large so the per-cell text stays readable.
    """
    plt.figure(figsize=(15, 12))
    plt.imshow(matrix, cmap=cmap)
    plt.colorbar()
    plt.title(title)
    plt.xlabel('Dataset Iteration')
    plt.ylabel('Classifier Iteration')
    # NOTE: axes currently count from 0 rather than from 1.
    for row, col in np.ndindex(matrix.shape[0], matrix.shape[1]):
        # imshow puts columns on the x axis, hence the (col, row) order
        plt.text(col, row, f'{matrix[row, col]:.2f}', ha='center', va='center', color='white')
    # plt.tight_layout() is deliberately not called: it produces graphical
    # bugs with the title, although it is clearer with the colourbar and labels
    plt.show()

# Visualise the loss and accuracy matrices without text. Just colour plots
def colour_plot_matrix(matrix, title, cmap='viridis', num_ticks=11):
    """Show ``matrix`` as a plain heat map, without per-cell text.

    Each axis gets ``num_ticks`` tick positions spread evenly over its range
    and truncated to integers; the tick labels are shifted by one so that
    iterations read as starting from 1.
    """
    plt.figure(figsize=(10, 8))
    plt.imshow(matrix, cmap=cmap)
    plt.colorbar()
    plt.title(title)
    plt.xlabel('Dataset Iteration')
    plt.ylabel('Classifier Iteration')

    # x ticks span the columns, y ticks span the rows
    col_positions = np.linspace(0, matrix.shape[1] - 1, num_ticks, dtype=int)
    row_positions = np.linspace(0, matrix.shape[0] - 1, num_ticks, dtype=int)
    for set_ticks, positions in ((plt.xticks, col_positions), (plt.yticks, row_positions)):
        set_ticks(positions, positions + 1)  # +1 to start from 1 instead of 0

    plt.show()

In [10]:
# Create, classify, and display digits

def generate_digits(vae, num_samples):
    """Draw ``num_samples`` images from ``vae`` by decoding standard-normal latents.

    The latent width is read from the model's own ``fc_mu`` head when it has
    one, so the samples always match the model that was passed in. The
    module-level ``latent_dim`` is only a fallback for models without that
    attribute.

    Returns:
        The decoder output for the sampled latents, computed without
        gradient tracking. The model is left in eval mode.
    """
    vae.eval()
    mu_head = getattr(vae, 'fc_mu', None)
    z_dim = mu_head.out_features if mu_head is not None else latent_dim
    with torch.no_grad():
        z = torch.randn(num_samples, z_dim)
        samples = vae.decode(z)
    return samples

# Classify generated digits. Uses the previous iteration's classifier, bootstrapping future classifiers.
# Note that this could potentially be done by directly partitioning the latent space while samples are drawn
def classify_digits(classifier, digits):
    """Return the predicted class index for every row of ``digits``.

    The classifier is switched to eval mode and run without gradient
    tracking; the label is the position of the largest score along dim 1.
    """
    classifier.eval()
    with torch.no_grad():
        scores = classifier(digits)
        predicted = torch.argmax(scores, dim=1)
    return predicted

def visualize_digits(digits, labels, iteration): # removed the original EMINST code and replaced it with MNIST work. Should work fine?
    num_digits = len(digits)
    num_rows = math.ceil(num_digits / 10) # required  to allow arbitrary numbers to be shown
    # Currently creates a single image out of all of them with 10 columns, and the digits stacked on top of one another
    num_cols = min(num_digits, 10)

    fig, axes = plt.subplots(num_rows, num_cols, figsize=(2*num_cols, 2*num_rows))
    fig.suptitle(f'Generated Digits - Iteration {iteration}')

    balanced_dict = {0: '0', 1: '1', 2: '2', 3: '3', 4: '4', 5: '5', 6: '6', 7: '7', 8: '8', 9: '9', 10: 'A', 11: 'B', 12: 'C', 13: 'D', 14: 'E', 15: 'F', 16: 
                 'G', 17: 'H', 18: 'I', 19: 'J', 20: 'K', 21: 'L', 22: 'M', 23: 'N', 24: 'O', 25: 'P', 26: 'Q', 27: 'R', 28: 'S', 29: 'T', 30: 'U', 31: 'V', 
                 32: 'W', 33: 'X', 34: 'Y', 35: 'Z', 36: 'a', 37: 'b', 38: 'd', 39: 'e', 40: 'f', 41: 'g', 42: 'h', 43: 'n', 44: 'q', 45: 'r', 46: 't'}

    if num_rows == 1:
        axes = axes.reshape(1, -1) # avoid index errors

    for i in range(num_digits):
        row = i // 10
        col = i % 10
        ax = axes[row, col]
        image = digits[i].reshape(28, 28)
        rotated_image = torch.rot90(image, k=-1)
        flipped_image = torch.flip(rotated_image, dims=[1])
        ax.imshow(flipped_image.squeeze(), cmap='gray') # recreates original MNIST images as faithfully as possible. Perhaps a custom cmap would be better for historic accuracy?
        ax.set_title(f'Label: {balanced_dict[labels[i].item()]}') # NOTE: Printed images are not sorted by label in any way. This could be resolved in future?
        ax.axis('off')

    # Remove any unused subplots. Without this, there may be graphical glitches and poorly placed images
    for i in range(num_digits, num_rows * num_cols):
        row = i // 10
        col = i % 10
        fig.delaxes(axes[row, col])

    plt.tight_layout() # I might want this. It seems to place the title too low, such that it overlaps with the images themselves. Removing this fixes the title issue
    plt.show()

In [11]:
# Evaluate a given classifier on a given dataset
# Returns the cross entropy loss and the accuracy over the whole dataset

def evaluate_classifier(classifier, dataloader):
    """Evaluate ``classifier`` on every batch of ``dataloader``.

    Each batch is flattened to ``(batch, features)`` before the forward
    pass, matching ``train_classifier``. The classifiers in this file start
    with ``nn.Linear(784, ...)``, so image-shaped ``(N, 1, 28, 28)`` input
    would otherwise raise a shape-mismatch error.

    Returns:
        Tuple ``(avg_loss, accuracy)``: the per-sample mean cross-entropy
        loss and the fraction of correct predictions over the whole dataset.
    """
    classifier.eval()
    criterion = nn.CrossEntropyLoss()
    total_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in dataloader:
            batch_size_seen = target.size(0)
            output = classifier(data.view(data.size(0), -1))
            # criterion returns the batch mean; re-weight by batch size so a
            # smaller final batch does not skew the average
            total_loss += criterion(output, target).item() * batch_size_seen
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
            total += batch_size_seen
    avg_loss = total_loss / total
    accuracy = correct / total
    return avg_loss, accuracy

In [None]:
# Main loop. Each iteration trains the VAE and a fresh classifier on the current dataset,
# then replaces the dataset with VAE samples labelled by that classifier
current_dataset = original_dataset
classifiers = []
# Entry [i, j] holds classifier i's loss / accuracy on the dataset generated in iteration j.
# Only j >= i is ever written; the remaining entries keep their initial value of zero
loss_matrix = np.zeros((num_iterations, num_iterations)) # numpy shortcut for square zeros?
accuracy_matrix = np.zeros((num_iterations, num_iterations))

vae = VAE(784, vae_hidden_layers, latent_dim) # VAE is not reset after each iteration to save on computational time, requiring fewer generations to keep sensible results
for iteration in range(num_iterations):
    print(f"Iteration {iteration + 1}/{num_iterations}")

    # Create DataLoader over the current dataset (original EMNIST in the first iteration, generated samples afterwards)
    dataloader = DataLoader(current_dataset, batch_size=batch_size, shuffle=True)

    # Train VAE. The model carries over from the previous iteration; only the optimiser is recreated
    #vae = VAE(784, vae_hidden_dim, latent_dim) # currently not being reinitialised
    vae_optimizer = optim.Adam(vae.parameters(), lr=learning_rate)
    for epoch in range(vae_num_epochs):
        loss = train_vae(vae, vae_optimizer, dataloader)
        if (epoch + 1) % 10 == 0:
            print(f'VAE Epoch {epoch+1}/{vae_num_epochs}, Loss: {loss:.4f}')

    # Initialize and train classifier. Unlike the VAE, a new classifier is created every iteration
    classifier = Classifier(784, classifier_hidden_dim, 47) # EMNIST has 47 classes
    classifier_optimizer = optim.Adam(classifier.parameters(), lr=learning_rate) # a new optimiser is needed because the classifier's parameters are new
    for epoch in range(classifier_num_epochs):
        loss, accuracy = train_classifier(classifier, classifier_optimizer, dataloader)
        if (epoch + 1) % 10 == 0:
            print(f'Classifier Epoch {epoch+1}/{classifier_num_epochs}, Loss: {loss:.4f}, Accuracy: {accuracy:.4f}')

    # Store the trained classifier in classifiers list. Used later for overall evaluation
    classifiers.append(classifier)

    # Generate new dataset
    generated_digits = generate_digits(vae, num_samples)
    generated_labels = classify_digits(classifier, generated_digits) # labelled by the classifier trained in this iteration

    # Visualize some generated digits
    #visualize_digits(generated_digits[:num_images*num_cols], generated_labels[:num_images*num_cols], iteration + 1, num_rows=num_images, num_cols=num_cols) # num_images used to be set to 25. Produces no issues
    visualize_digits(generated_digits[:num_images], generated_labels[:num_images], iteration + 1) # num_images used to be set to 25. Produces no issues


    # Create new dataset for next iteration
    current_dataset = TensorDataset(generated_digits, generated_labels) # torch.utils.data.TensorDataset

    # Evaluate every classifier trained so far (including this iteration's) on the new dataset, writing to the loss and accuracy arrays
    # NOTE(review): the newest classifier is scored against labels it produced itself, so the diagonal accuracy is expected to be ~1 by construction
    new_dataloader = DataLoader(current_dataset, batch_size=batch_size, shuffle=False) # shuffle should be unnecessary as the sampling from the latent space is random
    for prev_iteration, prev_classifier in enumerate(classifiers):
        loss, accuracy = evaluate_classifier(prev_classifier, new_dataloader)
        loss_matrix[prev_iteration, iteration] = loss
        accuracy_matrix[prev_iteration, iteration] = accuracy

    # This line can be used to save the current dataset for later analysis. Requires drive loading
    #torch.save(current_dataset, f'generated_dataset_iteration_{iteration + 1}.pt')

# Save matrices for further analysis
#np.save('loss_matrix_EMNIST.npy', loss_matrix)
#np.save('accuracy_matrix_EMNIST.npy', accuracy_matrix)

# Display the relative accuracies of trained classifiers

plot_matrix(loss_matrix, 'Loss Matrix')
plot_matrix(accuracy_matrix, 'Accuracy Matrix')

colour_plot_matrix(loss_matrix, 'Loss Matrix')
colour_plot_matrix(accuracy_matrix, 'Accuracy Matrix')

Iteration 1/25
VAE Epoch 10/70, Loss: 169.1161
