Very similar to the EMNIST Model Collapse code, but using CNN classifiers rather than fully connected ones.

Originally intended to increase the amount of compute dedicated to VAE training. Experiments show that this approach may improve the classifier-accuracy evaluation and thus give a better metric of model collapse. However, these improvements are quite minimal, and likely not worth the additional computational time and code complexity.

In [6]:
# Package imports

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt
import numpy as np
import math
import torch.nn.functional as F 

In [7]:
# Define parameters, including the details of the CNN design

# CNN hyperparameters
cnn_channels = [32, 64, 128]  # Number of output channels in each convolutional layer
cnn_kernel_sizes = [3, 3, 3]  # Kernel sizes for each convolutional layer
cnn_strides = [1, 1, 1]  # Strides for each convolutional layer
cnn_paddings = [1, 1, 1]  # Paddings for each convolutional layer
use_pooling = True  # Whether to use max pooling after each conv layer
pool_size = 2  # Size of the max pooling window
fc_layers = [256, 128]  # Widths of the fully connected layers after the convolutions (the output layer is added separately)

# Training hyperparameters
batch_size = 128
latent_dim = 50 # VAE latent dimension
vae_hidden_layers = [256, 128, 64] # encoder layer widths; the decoder uses the same widths in reverse order. Weights are not tied
classifier_hidden_dim = 512 # hidden dimension in the classifier. NOTE(review): not referenced by the CNN classifier cells visible here - confirm whether it is still needed
learning_rate = 5e-4 # shared by the VAE and classifier Adam optimisers
vae_num_epochs = 200 # VAE training epochs per iteration
classifier_num_epochs = 100 # Larger numbers required for EMNIST
num_iterations = 25 # number of train / generate / relabel iterations in the main loop
num_samples = 112800 # number of samples generated per iteration. NOTE(review): presumably chosen to match the size of the EMNIST balanced training split - confirm
num_images = 20 # number of generated samples visualised per iteration

In [8]:
# Load balanced EMNIST dataset

# Load original EMNIST dataset
# Normalisation statistics applied to every image by the transform below
dataset_mean = 0.1751
dataset_std = 0.3332
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((dataset_mean,), (dataset_std,)) # mean and standard deviation (not variance). See later cells to see how these were calculated
    # normalisation accelerates convergence of the VAE, and possibly of the classifier. However, it means MSE as opposed to BCE must be used
])

# Training split of EMNIST 'balanced', stored under ./data and passed through the transform above.
# NOTE(review): the original comment said "separate download each time" - presumably because ./data does not persist between sessions; confirm
original_dataset = torchvision.datasets.EMNIST(root='./data', split='balanced', train=True, download=True, transform=transform) # separate download each time

In [9]:
# VAE architecture specified

class VAE(nn.Module):
    """Fully connected variational autoencoder over flattened inputs.

    The encoder is a stack of Linear + ReLU layers with the widths given in
    ``hidden_layers``, followed by two linear heads producing the mean and
    log-variance of the latent distribution. The decoder uses the same widths
    in reverse order and ends in a linear layer with no activation, so
    outputs are unbounded. Encoder and decoder weights are not tied.

    Args:
        input_dim: Number of features per flattened sample (784 for 28x28 images).
        hidden_layers: Widths of the encoder hidden layers. May be empty, in
            which case the heads act directly on the input.
        latent_dim: Dimension of the latent space.
    """

    def __init__(self, input_dim, hidden_layers, latent_dim):
        super().__init__()
        # Remembered so forward() can flatten to the right width instead of
        # assuming 28x28 images.
        self.input_dim = input_dim
        self.latent_dim = latent_dim

        # Encoder
        encoder_layers = []
        prev_dim = input_dim
        for hidden_dim in hidden_layers:
            encoder_layers.extend([
                nn.Linear(prev_dim, hidden_dim),
                nn.ReLU()
            ])
            prev_dim = hidden_dim
        self.encoder = nn.Sequential(*encoder_layers)

        # prev_dim is the encoder output width (input_dim if there are no hidden layers)
        self.fc_mu = nn.Linear(prev_dim, latent_dim)
        self.fc_logvar = nn.Linear(prev_dim, latent_dim)

        # Decoder
        decoder_layers = []
        prev_dim = latent_dim
        for hidden_dim in reversed(hidden_layers):
            decoder_layers.extend([
                nn.Linear(prev_dim, hidden_dim),
                nn.ReLU()
            ])
            prev_dim = hidden_dim
        # Final projection back to input space, no activation
        decoder_layers.append(nn.Linear(prev_dim, input_dim))
        self.decoder = nn.Sequential(*decoder_layers)

    def encode(self, x):
        """Return ``(mu, logvar)`` for a batch of flattened inputs."""
        h = self.encoder(x)
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

        Writing the sample this way keeps it differentiable with respect to
        ``mu`` and ``logvar``.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent vectors to flattened reconstructions of width ``input_dim``."""
        return self.decoder(z)

    def forward(self, x):
        """Encode, sample and decode a batch.

        ``x`` may have any shape whose elements split evenly into rows of
        ``input_dim`` features; it is flattened with ``view``.

        Returns:
            Tuple ``(reconstruction, mu, logvar)``.
        """
        mu, logvar = self.encode(x.view(-1, self.input_dim))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

In [10]:
# CNN classifier design

# Note that the classifier is trained using cross entropy
class Classifier(nn.Module):
    """Configurable CNN classifier: conv (+ optional max-pool) stages, then an MLP head.

    Args:
        input_channels: Channels of the input images.
        input_height, input_width: Spatial size of the input images.
        num_classes: Number of output logits.
        cnn_channels: Output channels of each convolutional layer.
        cnn_kernel_sizes, cnn_strides, cnn_paddings: Per-layer convolution
            settings, indexed in step with ``cnn_channels``.
        use_pooling: If true, a ``MaxPool2d(pool_size)`` follows every convolution.
        pool_size: Max-pooling window size.
        fc_layers: Widths of the hidden fully connected layers; a final
            ``num_classes``-wide layer is always appended.
    """

    def __init__(self, input_channels, input_height, input_width, num_classes,
                 cnn_channels, cnn_kernel_sizes, cnn_strides, cnn_paddings,
                 use_pooling, pool_size, fc_layers):
        super().__init__()

        self.conv_layers = nn.ModuleList()
        self.use_pooling = use_pooling

        def _extent_after_conv(extent, kernel, stride, padding):
            # Standard convolution output-size formula along one axis
            return (extent - kernel + 2 * padding) // stride + 1

        # Build the conv stack and track the feature-map size in one pass
        feature_h, feature_w = input_height, input_width
        for idx, out_ch in enumerate(cnn_channels):
            in_ch = cnn_channels[idx - 1] if idx > 0 else input_channels
            kernel = cnn_kernel_sizes[idx]
            stride = cnn_strides[idx]
            padding = cnn_paddings[idx]

            self.conv_layers.append(
                nn.Conv2d(in_ch, out_ch, kernel_size=kernel, stride=stride, padding=padding)
            )
            feature_h = _extent_after_conv(feature_h, kernel, stride, padding)
            feature_w = _extent_after_conv(feature_w, kernel, stride, padding)

            if use_pooling:
                self.conv_layers.append(nn.MaxPool2d(pool_size))
                feature_h = feature_h // pool_size
                feature_w = feature_w // pool_size

        # Fully connected head; widths chain from the flattened feature map
        # through fc_layers and end at num_classes
        self.fc_layers = nn.ModuleList()
        flat_features = cnn_channels[-1] * feature_h * feature_w
        widths = [flat_features, *fc_layers, num_classes]
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            self.fc_layers.append(nn.Linear(fan_in, fan_out))

    def forward(self, x):
        """Return raw class logits for a batch shaped (batch, channels, height, width)."""
        # ReLU is applied after every entry of conv_layers, pooling layers included
        for stage in self.conv_layers:
            x = F.relu(stage(x))

        # Flatten everything except the batch dimension
        x = x.view(x.size(0), -1)

        # Hidden dense layers get ReLU; the last layer returns logits unchanged
        *hidden, head = self.fc_layers
        for dense in hidden:
            x = F.relu(dense(x))
        return head(x)

In [11]:
# Train VAE function

# VAE loss function
def vae_loss_function(recon_x, x, mu, logvar):
    """Return the summed VAE loss: reconstruction MSE plus KL divergence.

    Args:
        recon_x: Decoder output of shape (batch, features).
        x: Original inputs; flattened with ``view`` to the width of ``recon_x``
            (784 for 28x28 images), so any per-sample shape is accepted.
        mu: Latent means from the encoder.
        logvar: Latent log-variances from the encoder.

    Returns:
        Scalar tensor. Both terms are summed (not averaged) over the batch.
    """
    # Flatten to the decoder's output width rather than a hard-coded 784
    target = x.view(-1, recon_x.size(-1))
    # MSE required because inputs are normalised. BCE can be used if normalisation is not introduced
    MSE = nn.functional.mse_loss(recon_x, target, reduction='sum')
    # Closed-form KL divergence between N(mu, exp(logvar)) and a standard normal
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return MSE + KLD

# Train VAE
def train_vae(vae, optimizer, dataloader):
    """Train the VAE for one pass over ``dataloader``.

    Labels yielded by the dataloader are ignored.

    Returns:
        The accumulated batch losses divided by the dataset size, so the
        figure is comparable between datasets of different sizes.
    """
    vae.train()
    epoch_loss = 0
    for inputs, _ in dataloader:
        optimizer.zero_grad()
        # forward pass returns the reconstruction plus the latent statistics
        reconstruction, mu, logvar = vae(inputs)
        batch_loss = vae_loss_function(reconstruction, inputs, mu, logvar)
        # gradients reach the encoder through the reparameterised sample
        batch_loss.backward()
        epoch_loss += batch_loss.item()
        optimizer.step()
    return epoch_loss / len(dataloader.dataset)

In [12]:
# Train classifier

def train_classifier(classifier, optimizer, dataloader):
    """Train the classifier for one pass over ``dataloader``.

    Args:
        classifier: Model mapping (batch, 1, 28, 28) images to class logits.
        optimizer: Optimiser over the classifier's parameters.
        dataloader: Yields ``(data, target)`` batches; ``data`` is reshaped to
            (batch, 1, 28, 28) and ``target`` holds class indices.

    Returns:
        Tuple ``(mean loss per sample, accuracy)`` over the samples seen.
        Accuracy is measured on the predictions made during training, i.e.
        before each batch's weight update.

    Raises:
        ZeroDivisionError: If the dataloader yields no samples.
    """
    classifier.train()
    criterion = nn.CrossEntropyLoss()
    train_loss = 0
    correct = 0
    total = 0
    for data, target in dataloader:
        data = data.view(-1, 1, 28, 28)  # Reshape to (batch_size, channels, height, width)
        optimizer.zero_grad()
        output = classifier(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        # CrossEntropyLoss averages over the batch, so weight by the batch size
        # before summing; otherwise the per-sample mean comes out too small by
        # a factor of roughly batch_size. evaluate_classifier does the same.
        train_loss += loss.item() * data.size(0)
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()
        total += data.size(0)
    return train_loss / total, correct / total

In [13]:
# Generate, classify, and visualise digits

# Function to generate new digits. Note that this is for training the next iteration of the models. The visualisation occurs elsewhere, and displays the first num_images of these generated samples
def generate_digits(vae, num_samples):
    """Decode ``num_samples`` latent vectors drawn from a standard normal.

    The latent width is read from the module-level ``latent_dim``. The VAE is
    switched to eval mode and no gradients are recorded.

    Returns:
        Whatever ``vae.decode`` returns for the sampled batch.
    """
    vae.eval()
    with torch.no_grad():
        latent_points = torch.randn(num_samples, latent_dim)
        return vae.decode(latent_points)

# This could probably be done by specifying regions in latent space instead. How could I do this?
def classify_digits(classifier, digits, batch_size=batch_size):
    """Label ``digits`` with the classifier's most likely class.

    Args:
        classifier: Model taking (batch, 1, 28, 28) inputs and returning logits.
        digits: Tensor of samples; each batch is reshaped to (batch, 1, 28, 28).
        batch_size: Number of samples classified per forward pass.

    Returns:
        1-D tensor of predicted class indices, in the same order as ``digits``.
    """
    classifier.eval()
    # shuffle=False keeps predictions aligned with the input order
    loader = DataLoader(TensorDataset(digits), batch_size=batch_size, shuffle=False)

    with torch.no_grad():
        predictions = [
            classifier(images.view(-1, 1, 28, 28)).argmax(dim=1)
            for (images,) in loader
        ]

    return torch.cat(predictions)


def visualize_digits(digits, labels, iteration):
    """Show generated samples in a grid, each titled with its predicted character.

    Args:
        digits: Indexable collection of tensors, each reshapeable to 28x28.
        labels: Indexable collection of scalar tensors holding class indices
            in 0..46 (the 47 classes of the EMNIST 'balanced' split).
        iteration: Iteration number shown in the figure title.

    The grid has at most 10 columns and as many rows as needed. Images are
    shown in input order, not sorted by label. Nothing is drawn when
    ``digits`` is empty.
    """
    num_digits = len(digits)
    if num_digits == 0:
        return  # plt.subplots cannot build a 0x0 grid

    max_cols = 10
    num_rows = math.ceil(num_digits / max_cols)  # allows arbitrary numbers of images to be shown
    num_cols = min(num_digits, max_cols)

    # squeeze=False always returns a 2-D array of Axes. Without it a single
    # row comes back 1-D, and a single image comes back as a bare Axes object
    # that cannot be indexed as axes[row, col].
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(2*num_cols, 2*num_rows), squeeze=False)
    fig.suptitle(f'Generated Digits - Iteration {iteration}')

    # Class index -> character: digits, upper-case letters, then the 11
    # lower-case letters that have their own class in this split
    balanced_dict = dict(enumerate('0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabdefghnqrt'))

    for i in range(num_digits):
        row, col = divmod(i, max_cols)
        ax = axes[row, col]
        image = digits[i].reshape(28, 28)
        # Clockwise quarter-turn followed by a horizontal flip, which together
        # amount to a transpose of the 28x28 image
        rotated_image = torch.rot90(image, k=-1)
        flipped_image = torch.flip(rotated_image, dims=[1])
        ax.imshow(flipped_image.squeeze(), cmap='gray')
        ax.set_title(f'Label: {balanced_dict[labels[i].item()]}')
        ax.axis('off')

    # Remove any unused subplots in the last row. Without this, there may be graphical glitches and poorly placed images
    for i in range(num_digits, num_rows * num_cols):
        row, col = divmod(i, max_cols)
        fig.delaxes(axes[row, col])

    # tight_layout can place the title low enough to overlap the images; remove this call if that becomes a problem
    plt.tight_layout()
    plt.show()

In [14]:
# Evaluate a given classifier and dataloader

# Update the evaluate_classifier function
def evaluate_classifier(classifier, dataloader):
    """Measure cross-entropy loss and accuracy of ``classifier`` on ``dataloader``.

    Each batch is reshaped to (batch, 1, 28, 28) before the forward pass. The
    model is put in eval mode and no gradients are recorded.

    Returns:
        Tuple ``(mean loss per sample, accuracy)`` over all samples seen.
    """
    classifier.eval()
    loss_fn = nn.CrossEntropyLoss()
    loss_sum, num_correct, num_seen = 0, 0, 0
    with torch.no_grad():
        for images, targets in dataloader:
            images = images.view(-1, 1, 28, 28)
            batch_len = images.size(0)
            logits = classifier(images)
            # the criterion returns a batch mean, so scale it back up to a sum
            loss_sum += loss_fn(logits, targets).item() * batch_len
            predicted = logits.argmax(dim=1, keepdim=True)
            num_correct += predicted.eq(targets.view_as(predicted)).sum().item()
            num_seen += batch_len
    return loss_sum / num_seen, num_correct / num_seen


In [15]:
# Visualise loss and accuracy matrices of various classifiers

# Visualize the loss and accuracy matrices
def plot_matrix(matrix, title, cmap='viridis'):
    """Draw ``matrix`` as a heatmap with every cell's value printed on it.

    Rows are labelled as classifier iterations and columns as dataset
    iterations, both counted from 0. Values are printed to two decimals in
    white. Best suited to small matrices where the text stays legible.
    """
    plt.figure(figsize=(15, 12))  # large figure so the per-cell text stays readable
    plt.imshow(matrix, cmap=cmap)
    plt.colorbar()
    plt.title(title)
    plt.xlabel('Dataset Iteration')
    plt.ylabel('Classifier Iteration')

    n_rows, n_cols = matrix.shape[0], matrix.shape[1]
    cells = ((r, c) for r in range(n_rows) for c in range(n_cols))
    for r, c in cells:
        # text takes (x, y), i.e. (column, row)
        plt.text(c, r, f'{matrix[r, c]:.2f}', ha='center', va='center', color='white')

    # tight_layout is deliberately not called: it produces graphical bugs with the title
    plt.show()

# Visualise the loss and accuracy matrices without text. Just colour plots
def colour_plot_matrix(matrix, title, cmap='viridis', num_ticks=11):
    """Draw ``matrix`` as a heatmap without per-cell text.

    Args:
        matrix: 2-D array; rows are classifier iterations, columns are dataset
            iterations.
        title: Figure title.
        cmap: Matplotlib colour map name.
        num_ticks: Maximum number of evenly spaced ticks per axis. Fewer are
            drawn when an axis has fewer cells than this.

    Tick labels count from 1 rather than 0.
    """
    plt.figure(figsize=(10, 8))
    plt.imshow(matrix, cmap=cmap)
    plt.colorbar()
    plt.title(title)
    plt.xlabel('Dataset Iteration')
    plt.ylabel('Classifier Iteration')

    # Evenly spaced tick locations. Truncating to int yields repeated
    # positions when the axis is shorter than num_ticks, so drop duplicates;
    # np.unique keeps the positions in ascending order.
    x_ticks = np.unique(np.linspace(0, matrix.shape[1] - 1, num_ticks, dtype=int))
    y_ticks = np.unique(np.linspace(0, matrix.shape[0] - 1, num_ticks, dtype=int))

    # Set tick locations and labels
    plt.xticks(x_ticks, x_ticks + 1)  # +1 to start from 1 instead of 0
    plt.yticks(y_ticks, y_ticks + 1)

    plt.show()

In [None]:
# Main loop. Iterates over training of models and generating samples for new models

# Iteration 1 trains on the real EMNIST data; every later iteration trains on the
# samples generated (and labelled) in the iteration before it.
current_dataset = original_dataset
classifiers = [] # one trained classifier per iteration, kept for the cross-evaluation below
# Indexed [classifier iteration, dataset iteration]. Only entries where the classifier
# iteration is <= the dataset iteration are written; the rest stay at 0.
loss_matrix = np.zeros((num_iterations, num_iterations)) 
accuracy_matrix = np.zeros((num_iterations, num_iterations))

vae = VAE(784, vae_hidden_layers, latent_dim) # VAE is not reset after each iteration to save on computational time
for iteration in range(num_iterations):
    print(f"Iteration {iteration + 1}/{num_iterations}")

    dataloader = DataLoader(current_dataset, batch_size=batch_size, shuffle=True)

    # A fresh Adam optimiser is created each iteration, so its state is reset even though the VAE weights carry over
    vae_optimizer = optim.Adam(vae.parameters(), lr=learning_rate)
    for epoch in range(vae_num_epochs):
        loss = train_vae(vae, vae_optimizer, dataloader)
        if (epoch + 1) % 10 == 0:
            print(f'VAE Epoch {epoch+1}/{vae_num_epochs}, Loss: {loss:.4f}')

    # Unlike the VAE, the classifier is re-initialised from scratch every iteration
    classifier = Classifier(input_channels=1, input_height=28, input_width=28, num_classes=47,
                        cnn_channels=cnn_channels, cnn_kernel_sizes=cnn_kernel_sizes,
                        cnn_strides=cnn_strides, cnn_paddings=cnn_paddings,
                        use_pooling=use_pooling, pool_size=pool_size, fc_layers=fc_layers)
    
    # The optimiser must be rebuilt here because it is bound to the new classifier's parameters
    classifier_optimizer = optim.Adam(classifier.parameters(), lr=learning_rate) # does this need to be respecified?
    for epoch in range(classifier_num_epochs):
        loss, accuracy = train_classifier(classifier, classifier_optimizer, dataloader)
        if (epoch + 1) % 10 == 0:
            print(f'Classifier Epoch {epoch+1}/{classifier_num_epochs}, Loss: {loss:.4f}, Accuracy: {accuracy:.4f}')

    # Store the trained classifier in classifiers list. Used later for overall evaluation
    classifiers.append(classifier)

    # Generate new dataset
    generated_digits = generate_digits(vae, num_samples)
    generated_labels = classify_digits(classifier, generated_digits) # labels come from the classifier trained in this iteration

    # Visualize some generated digits
    visualize_digits(generated_digits[:num_images], generated_labels[:num_images], iteration + 1) 


    # Create new dataset for next iteration
    current_dataset = TensorDataset(generated_digits, generated_labels) # torch.utils.data.TensorDataset

    # Evaluate every classifier trained so far (including this iteration's) on the new dataset, writing to the loss and accuracy arrays.
    # NOTE(review): this iteration's classifier produced the labels, so its own accuracy here is presumably ~1 by construction - confirm
    new_dataloader = DataLoader(current_dataset, batch_size=batch_size, shuffle=False) # shuffle should be unnecessary as the sampling from the latent space is random
    for prev_iteration, prev_classifier in enumerate(classifiers):
        loss, accuracy = evaluate_classifier(prev_classifier, new_dataloader)
        loss_matrix[prev_iteration, iteration] = loss
        accuracy_matrix[prev_iteration, iteration] = accuracy

    # This line can be used to save the current dataset for later analysis. Requires drive loading
    #torch.save(current_dataset, f'generated_dataset_iteration_{iteration + 1}.pt')

# Save matrices for further analysis
#np.save('loss_matrix_EMNIST.npy', loss_matrix)
#np.save('accuracy_matrix_EMNIST.npy', accuracy_matrix)

# Show relative loss and accuracies of classifiers on generated datasets
# The plot_matrix function explicitly includes precise values. More appropriate for small numbers of iterations
plot_matrix(loss_matrix, 'Loss Matrix')
plot_matrix(accuracy_matrix, 'Accuracy Matrix')

# Colour-only versions of the same matrices, without per-cell text
colour_plot_matrix(loss_matrix, 'Loss Matrix')
colour_plot_matrix(accuracy_matrix, 'Accuracy Matrix')

Iteration 1/25
