# Results

Notebook for comparing our generative models

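Contains:

- Fréchet LeNet5 distance (FLD) between each model's samples and the MNIST training set
- Distribution of the labels that LeNet5 predicts for each model's samples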

In [None]:
import os
import csv

import torch
from torchvision import transforms
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.axes, matplotlib.figure
from matplotlib.ticker import PercentFormatter

from lenet import LeNet5
import distances

In [None]:
# Master switch for writing output files: when True, the cells below save
# figures under plots/ and the score table to data/fld_scores.csv.
enable_file_save = False # Set to True to save figures etc.

In [None]:
# Select the compute device: the GPU when CUDA is available, else the CPU
cuda = torch.cuda.is_available()
device = torch.device("cuda" if cuda else "cpu")

print("Using device", device)

# Disable GPU always
# device = torch.device("cpu")

In [None]:
# Load the MNIST test images from a NumPy file.
# NOTE(review): presumably shaped (N, H, W) with pixel values in [0, 1] -- the
# helpers below add a channel axis and the plots clip to [0, 1]; confirm
# against whatever wrote this file.
mnist_test = np.load('data/mnist_test.npy')

In [None]:
# Build LeNet5 on the selected device and load its trained weights.
# map_location remaps the checkpoint's tensors onto `device`, so a checkpoint
# that was saved on a GPU still loads on a CPU-only machine.
leNet_classifier = LeNet5(n_classes=10).to(device)
leNet_classifier.load_state_dict(
    torch.load('models/mnist_lenet5.pth', map_location=device)
)
# The network is only used for inference in this notebook
leNet_classifier.eval()

def _np_to_net(images):
    """Convert a NumPy image batch into the tensor layout fed to LeNet5.

    The array is moved to ``device`` as a float tensor, given a channel axis
    at position 1, and resized to 32x32.
    """
    batch = torch.from_numpy(images).to(device).float()
    batch = batch.unsqueeze(1)  # add channel dimension
    # LeNet expects 32x32 size images
    return transforms.Resize((32, 32)).forward(batch)

def _net_to_np(array):
    """Return the tensor's values as a NumPy array on the host, without gradients."""
    host_tensor = array.detach().cpu()
    return host_tensor.numpy()

def classify_images(images):
    """Return the class index LeNet5 predicts for each image.

    Args:
        images: NumPy array of images, converted with ``_np_to_net``.

    Returns:
        1-D NumPy array with one predicted class index per image (the argmax
        over axis 1 of the second value returned by the network's forward
        pass).
    """
    images = _np_to_net(images)
    # Inference only: disabling autograd avoids building a computation graph,
    # which matters when a whole dataset is classified in one call.
    with torch.no_grad():
        _, probs = leNet_classifier.forward(images)
    probs = _net_to_np(probs)
    predicted_class = np.argmax(probs, axis=1)
    return predicted_class

def embed_images(images):
    """Return the LeNet5 ``forward_headless`` output for each image.

    Args:
        images: NumPy array of images, converted with ``_np_to_net``.

    Returns:
        NumPy array holding the network's headless output for the batch.
        NOTE(review): presumably one feature vector per image (row) -- confirm
        against ``LeNet5.forward_headless``.
    """
    images = _np_to_net(images)
    # Inference only: disabling autograd avoids building a computation graph,
    # which matters when a whole dataset is embedded in one call.
    with torch.no_grad():
        embedding = leNet_classifier.forward_headless(images)
    embedding = _net_to_np(embedding)
    return embedding

In [None]:
def plot_classification(images, dataset_name):
    """Show a row of images, each titled with the class LeNet5 predicts for it.

    Args:
        images: Array of 2-D images; every image gets its own subplot.
        dataset_name: Name used in the figure title and, when
            ``enable_file_save`` is set, in the saved file name (spaces are
            replaced by underscores).
    """
    classes = classify_images(images)
    # squeeze=False keeps `axs` a 2-D array even when there is a single image;
    # otherwise plt.subplots returns a bare Axes that cannot be indexed.
    fig, axs = plt.subplots(1, len(images), figsize=(12, 1.6), squeeze=False)
    fig.suptitle("Image classification with LeNet5 for %s" % dataset_name)

    for ax, img, label in zip(axs[0], images, classes):
        # Clip to the displayable range before drawing
        ax.imshow(np.clip(img, 0, 1), cmap='gray_r')
        ax.set_title("%s" % label)
        ax.set_xticks([])
        ax.set_yticks([])

    if enable_file_save:
        plt.savefig('plots/LeNet_classification_examples_%s.png' % dataset_name.replace(' ', '_'))

    plt.show()

# Sanity check: show LeNet5's predictions for the first 15 MNIST test images
plot_classification(mnist_test[:15], 'MNIST')

In [None]:
def fit_gaussian(data):
    """Fit a multivariate Gaussian to the rows of ``data``.

    Args:
        data: 2-D array-like of shape (n_samples, n_features); each row is one
            observation.

    Returns:
        Tuple ``(mean, cov)`` with the per-feature mean of shape
        (n_features,) and the covariance matrix of shape
        (n_features, n_features).
    """
    data = np.asarray(data)
    mean = np.mean(data, axis=0)
    # np.cov collapses the single-feature case to a 0-d array; atleast_2d
    # restores the (1, 1) matrix so the result is always a square matrix.
    cov = np.atleast_2d(np.cov(data, rowvar=False))
    assert mean.shape == data.shape[1:]
    assert cov.shape == (len(mean), len(mean))

    return mean, cov

In [None]:
# Two disjoint 20-image slices of the MNIST test set, used as stand-in
# datasets for now
samples1, samples2 = mnist_test[:20], mnist_test[20:40]

In [None]:
# Fréchet distance between Gaussians fitted to the embeddings of the two
# sample sets
classified1, classified2 = embed_images(samples1), embed_images(samples2)
gauss1, gauss2 = fit_gaussian(classified1), fit_gaussian(classified2)
frechet_dist = distances.frechet_distance(gauss1, gauss2)

# small numerical error can give complex distance
imag_magnitude = np.abs(np.imag(frechet_dist))
assert imag_magnitude < 1e-6, np.max(imag_magnitude)
frechet_dist = np.real(frechet_dist)

print(frechet_dist)

In [None]:
# Each entry pairs a sample file with the title used for it in plots and tables
dataset_specs = [
    ('data/mnist_test.npy', 'MNIST test set'),
    ('data/FC_VAE_samples.npy', 'Simple Variational Autoencoder'),
    ('data/convolutional_VAE_samples.npy', 'Convolutional Variational Autoencoder'),
    ('data/diffusion_samples.npy', 'Diffusion model'),
]

datasets_files = [path for path, _ in dataset_specs]
dataset_titles = [title for _, title in dataset_specs]

# Fail early, naming the first file that is missing
for file in datasets_files:
    assert os.path.exists(file), file

In [None]:
# Reference distribution: a Gaussian fitted to the embeddings of the MNIST
# training set, in the manner of the Fréchet inception distance
# https://en.wikipedia.org/wiki/Fr%C3%A9chet_inception_distance
mnist_train = np.load('data/mnist_train.npy')
train_embedding = embed_images(mnist_train)
reference_gaussian = fit_gaussian(train_embedding)

In [None]:
# Load every image set listed in datasets_files, in the same order
datasets = list(map(np.load, datasets_files))

In [None]:
# Fréchet distance between the reference Gaussian and the Gaussian fitted to
# each dataset's embeddings, printed as a table and collected in `scores`
print("FLD scores")
print("==========")
scores = {}
for images, name in zip(datasets, dataset_titles):
    model_gaussian = fit_gaussian(embed_images(images))
    frechet_dist = distances.frechet_distance(reference_gaussian, model_gaussian)

    # small numerical error can give complex distance; keep only the real part
    # so the printed and saved scores are plain real numbers
    assert np.abs(np.imag(frechet_dist)) < 1e-6, np.max(np.abs(np.imag(frechet_dist)))
    frechet_dist = np.real(frechet_dist)

    print('{0:40}  {1}'.format(name, frechet_dist))
    scores[name] = frechet_dist

print("==========")

In [None]:
# Save the score table as a two-column CSV (only when file saving is enabled)
if enable_file_save:
    with open('data/fld_scores.csv', 'w', newline='') as f:
        score_writer = csv.writer(f)
        score_writer.writerow(('Model', 'Score'))
        score_writer.writerows(scores.items())

In [None]:
# The histogram uses the MNIST training set as its first series in place of
# the test set, followed by the remaining datasets
histogram_datasets = [mnist_train, *datasets[1:]]
histogram_series_labels = ['MNIST training set', *dataset_titles[1:]]
class_distributions = list(map(classify_images, histogram_datasets))

In [None]:
# Grouped histogram of the predicted-class distribution, one series per dataset

fig = plt.figure(figsize=(8, 4))
fontsize = 12

# Bin edges at -0.5, 0.5, ..., 9.5 so the x-axis labels line up with the bars
# (workaround from https://stackoverflow.com/a/27084005)
bins = np.arange(11) - 0.5

plt.hist(
    class_distributions,
    bins,
    density=True,
    histtype='bar',
    label=histogram_series_labels,
)

# Axis labels, with the y-axis shown as percentages
plt.xlabel('Image label', fontsize=fontsize)
plt.ylabel('Frequency', fontsize=fontsize)
plt.gca().yaxis.set_major_formatter(PercentFormatter(1))

# One tick per class label
plt.xticks(range(10), fontsize=fontsize)
plt.yticks(fontsize=fontsize)
plt.legend(fontsize=fontsize)

if enable_file_save:
    # Raise the figure resolution before writing the file
    fig.dpi = 500
    plt.savefig('plots/class_distribution_histogram.png')

#plt.title('Distribution of generated images (and MNIST training set)')
plt.show()

In [None]:
# Show LeNet5's predictions for the first 15 images of every dataset
for images, name in zip(datasets, dataset_titles):
    plot_classification(images[:15], name)