# Display generated images

Notebook for displaying random samples from our datasets of generated images, alongside samples from MNIST for comparison.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.axes, matplotlib.figure
import os.path
# Seed NumPy's global RNG so the randomly selected sample indices below are
# the same on every run of the notebook.
np.random.seed(0)

In [None]:
# When True, plot_datasets() also writes its figure to
# plots/dataset_samples.png; when False, figures are only displayed.
enable_file_save = False # Set to True to save figures etc.

In [None]:
# .npy files to sample from; row i of the combined figure uses
# datasets_files[i] together with dataset_titles[i].
datasets_files = [
    'data/mnist_test.npy',
    'data/FC_VAE_samples.npy',
    'data/convolutional_VAE_samples.npy',
    'data/diffusion_samples.npy',
]

# Display titles, in the same order and with the same length as datasets_files.
dataset_titles = [
    'MNIST test set',
    'Simple Variational Autoencoder',
    'Convolutional Variational Autoencoder',
    'Diffusion model',
]

# NOTE(review): not referenced by the cells in this notebook; presumably kept
# for other analyses — confirm before removing.
reference_dataset = 'data/mnist_train.npy'
reference_dataset_title = 'MNIST training set'

# Validate the configuration up front with explicit exceptions: `assert` is
# stripped under `python -O`, and listing every missing file at once is more
# useful than stopping at the first.
if len(datasets_files) != len(dataset_titles):
    raise ValueError(
        f'datasets_files ({len(datasets_files)} entries) and dataset_titles '
        f'({len(dataset_titles)} entries) must have the same length'
    )

missing_files = [path for path in datasets_files if not os.path.exists(path)]
if missing_files:
    raise FileNotFoundError(f'Missing dataset files: {missing_files}')

In [None]:
# Number of images shown per dataset row.
num_samples = 20
# Number of images each dataset is expected to contain; sample indices are
# drawn from range(num_images).
num_images = 10000
# Draw distinct indices. The previous randint() sampled *with* replacement,
# so the same image could be shown twice in one row. The same indices are
# reused for every dataset that is plotted.
selected_indices = np.random.choice(num_images, size=num_samples, replace=False)

def plot_images(images, title):
    """Show ``images`` side by side in one row under a common title.

    Args:
        images: Sequence or array of images; each one is clipped to [0, 1]
            and drawn with ``imshow`` using the reversed-gray colormap.
            Presumably (N, H, W) grayscale — TODO confirm against the data.
        title: Figure-level title shown above the row.
    """
    # squeeze=False keeps `axs` a 2-D array even for a single image. With the
    # default squeeze=True, plt.subplots returns a bare Axes when
    # len(images) == 1, and indexing it would fail.
    fig, axs = plt.subplots(1, len(images), figsize=(12, 1.5), squeeze=False)
    fig.suptitle(title)

    for ax, img in zip(axs[0], images):
        # Clamp values outside [0, 1] so imshow's colour scaling is consistent.
        ax.imshow(np.clip(img, 0, 1), cmap='gray_r')
        ax.set_xticks([])
        ax.set_yticks([])

    plt.show()

def plot_dataset(dataset, title):
    """Load a ``.npy`` dataset and plot the images at ``selected_indices``.

    Args:
        dataset: Path to a ``.npy`` file, passed to ``np.load``.
        title: Title shown above the row of sampled images.

    Raises:
        ValueError: If the loaded array does not contain exactly
            ``num_images`` items.
    """
    loaded_images = np.load(dataset)
    # Use an explicit check rather than `assert`. Asserts are stripped under
    # `python -O`, and the old one did not say which file was wrong.
    if len(loaded_images) != num_images:
        raise ValueError(
            f'{dataset}: expected {num_images} images, '
            f'got {len(loaded_images)}'
        )

    # Fancy indexing on the first axis (same result as np.take(..., axis=0)).
    selected_images = loaded_images[selected_indices]
    plot_images(selected_images, title)

# Quick preview: plot only the first configured dataset (the MNIST test set).
plot_dataset(datasets_files[0], dataset_titles[0])

In [None]:
def plot_datasets():
    """Plot one row of sampled images per dataset, stacked in a single figure.

    Row ``i`` is a subfigure titled ``dataset_titles[i]``. It shows the images
    at ``selected_indices`` from ``datasets_files[i]``, clipped to [0, 1]. If
    ``enable_file_save`` is True, the figure is also written to
    ``plots/dataset_samples.png``; the directory is created if needed.
    """
    num_rows = len(dataset_titles)
    fig = plt.figure(constrained_layout=True, figsize=(20, num_rows * 1.2 + 2))

    # One subfigure per dataset. squeeze=False keeps `subfigs` a 2-D array
    # even when num_rows == 1. Otherwise subfigures(1, 1) returns a single
    # SubFigure, which cannot be iterated.
    subfigs = fig.subfigures(nrows=num_rows, ncols=1, squeeze=False)
    for subfig, title, path in zip(subfigs[:, 0], dataset_titles, datasets_files):
        subfig.suptitle(title, fontsize=20)

        loaded_images = np.load(path)
        selected_images = loaded_images[selected_indices]

        # 1 x num_samples axes per row; squeeze=False for the same
        # single-element reason as above.
        axs = subfig.subplots(nrows=1, ncols=num_samples, squeeze=False)
        for ax, img in zip(axs[0], selected_images):
            ax.imshow(np.clip(img, 0, 1), cmap='gray_r')
            ax.set_xticks([])
            ax.set_yticks([])

    if enable_file_save:
        # Save *this* figure explicitly, rather than pyplot's "current" figure,
        # and make sure the output directory exists first.
        os.makedirs('plots', exist_ok=True)
        fig.savefig('plots/dataset_samples.png')

    plt.show()


# Render the combined figure: one row of samples per configured dataset.
plot_datasets()