### Exploring prototypical distortions

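This notebook demonstrates that generative networks such as variational autoencoders (VAEs) make their outputs more prototypical: after passing through the network, images vary less within each class than the original inputs did.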

Tested with TensorFlow 2.11.0 and Python 3.10.9.

#### Installation

In [None]:
!pip install -r requirements.txt
!pip install umap-learn

#### Imports

In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds
from distortions_utils import *
from tensorflow import keras
import numpy as np
from random import randrange
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.backends.backend_pdf
from config import dims_dict
from generative_model import models_dict
import matplotlib
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from umap import UMAP
from tensorflow.keras import layers
import tensorflow.keras.backend as K
from utils import display
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import numpy as np

import random
import warnings

# Silence all library warnings to keep cell outputs readable.
# NOTE(review): this also hides genuinely useful deprecation/runtime warnings.
warnings.filterwarnings("ignore")

# Seed every random number generator the notebook may draw from so outputs are
# reproducible: TensorFlow, NumPy's global RNG, and Python's `random` module
# (`randrange` is imported from it above).
SEED = 123
tf.random.set_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

#### Measuring intra-class variation

Load the previously trained MNIST VAE, along with the MNIST test set:

In [None]:
# Only the MNIST test split is analysed. Pixel values are scaled from [0, 255] to
# [0, 1] and a trailing channel axis is added: (n, 28, 28) -> (n, 28, 28, 1), float32.
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
mnist_digits = np.expand_dims(x_test, -1).astype("float32") / 255
mnist_labels = y_test.copy()

# Rebuild the small encoder/decoder architecture and load the weights saved by an
# earlier training run. The latent size must match the one used during training.
LATENT_DIM = 20
encoder, decoder = build_encoder_decoder_small(latent_dim=LATENT_DIM)
encoder.load_weights("model_weights/mnist_encoder.h5")
decoder.load_weights("model_weights/mnist_decoder.h5")
vae = VAE(encoder, decoder, kl_weighting=1)

Get the latents and outputs before and after recall, and optionally plot the latent spaces.

Note that the `check_generative_recall()` function from `distortions_utils` applies no noise here, since added noise would invalidate the comparison of the variances.

In [None]:
# Start from the original test images; `test_data` is replaced by the recalled
# images after every pass through the VAE.
test_data = mnist_digits
all_pixels, all_latents = [], []

# Fit 2-D UMAP projections of the latent space and of the raw pixel space on the
# original images. Both are handed to check_generative_recall() below.
# NOTE(review): latents[0] is presumably the encoder's mean output (z_mean) — confirm
# against the encoder definition in distortions_utils.
latents = vae.encoder.predict(test_data)
umap_kwargs = dict(n_components=2, min_dist=1, n_neighbors=20)
latent_umap = UMAP(**umap_kwargs)
latent_umap.fit(latents[0])
pixel_umap = UMAP(**umap_kwargs)
pixel_umap.fit(test_data.reshape(test_data.shape[0], 784))

# Two recall passes. all_pixels[0] holds the original images and all_pixels[1] the
# images after one recall; all_latents[k] holds whatever pass k returned as latents.
for _ in range(2):
    all_pixels.append(test_data)
    test_data, latents = check_generative_recall(
        vae, test_data, mnist_labels, latent_umap, pixel_umap,
        displaybool=False, n=3000,
    )
    all_latents.append(latents)

Calculate the per-pixel variance within each MNIST class (using the first 500 test examples of each digit) before and after recall:

In [None]:
# Per-pixel variance within each digit class, at each time step:
#   variances[t][d] -> 1-D array with one entry per pixel, holding the variance of
#   that pixel over the first N_PER_CLASS examples of digit d among the first
#   N_IMAGES images of all_pixels[t] (t = 0: inputs, t = 1: outputs after recall).
N_IMAGES = 10000
N_PER_CLASS = 500
N_CLASSES = 10

variances = {}

for time_step in range(0, 2):
    variances[time_step] = []
    # The image and label slices are identical for every digit, so take them once
    # per time step instead of once per digit.
    px = all_pixels[time_step][0:N_IMAGES]
    class_labels = mnist_labels[0:N_IMAGES]
    for digit in range(0, N_CLASSES):
        inds = np.where(class_labels == digit)
        px_for_digit = px[inds][0:N_PER_CLASS]

        # Flatten each image into a 1-D vector: (n, ...) -> (n, 784) for MNIST.
        reshaped_images = px_for_digit.reshape((px_for_digit.shape[0], -1))
        # Population variance (ddof=0) of each pixel across the images -> (784,)
        variance_vec = np.var(reshaped_images, axis=0)
        variances[time_step].append(variance_vec)

In [None]:
# Sanity check: shape of the flattened image matrix left over from the last
# (time step, digit) iteration of the variance loop — (n_examples, 784) for MNIST.
np.shape(reshaped_images)

Plot the intra-class variation of each MNIST class before and after recall:

In [None]:
# Standard error of the mean per-pixel variance, one value per digit (inputs only):
# population std of the per-pixel variances divided by sqrt(number of pixels).
[np.std(pixel_vars) / np.sqrt(len(pixel_vars)) for pixel_vars in variances[0]]

In [None]:
import os

matplotlib.rcParams.update({'font.size': 12})

# Summarise each digit's per-pixel variances by their mean; the error bars are the
# standard error of that mean (population std / sqrt(number of pixels)).
labels = range(len(variances[0]))
before_means = [np.mean(v) for v in variances[0]]
after_means = [np.mean(v) for v in variances[1]]
before_sem = [np.std(v) / np.sqrt(len(v)) for v in variances[0]]
after_sem = [np.std(v) / np.sqrt(len(v)) for v in variances[1]]

x = np.arange(len(labels))
width = 0.4  # bar width; each digit's pair of bars sits at x - width/2 and x + width/2

fig, ax = plt.subplots(figsize=(7, 4))

rects1 = ax.bar(x - width/2, before_means, width, yerr=before_sem, capsize=5, label='Inputs', color='red', alpha=0.5)
rects2 = ax.bar(x + width/2, after_means, width, yerr=after_sem, capsize=5, label='Outputs', color='blue', alpha=0.5)

ax.set_xlabel('Digit class')
ax.set_ylabel('Mean variance per pixel')
ax.set_title('Intra-class image variation')
ax.set_xticks(x)
# Fixed upper limit. NOTE(review): bars or error bars above 0.073 would be clipped.
ax.set_ylim(0, 0.073)
ax.legend()

fig.tight_layout()
# savefig raises if the target directory does not exist (e.g. on a fresh clone).
os.makedirs('misc_plots', exist_ok=True)
fig.savefig('misc_plots/mnist_variance.pdf')
plt.show()

In [None]:
import os

# Per-digit distributions of per-pixel variance for the inputs ('before') and the
# recalled outputs ('after'); element d of each list corresponds to digit d.
before_data = variances[0]
after_data = variances[1]

fig, ax = plt.subplots(figsize=(7, 4))

# One pair of boxes per digit, centred on even x positions so neighbouring pairs
# do not touch.
positions = np.arange(len(before_data)) * 2
width = 0.65  # width of each box

# Styling shared by both box sets. Whiskers span the 10th-90th percentiles and
# points beyond them are not drawn.
box_style = dict(widths=width, patch_artist=True, medianprops=dict(color='black'),
                 showfliers=False, whis=[10, 90])
box1 = ax.boxplot(before_data, positions=positions - width/2,
                  boxprops=dict(facecolor='red', alpha=0.5), **box_style)
box2 = ax.boxplot(after_data, positions=positions + width/2,
                  boxprops=dict(facecolor='blue', alpha=0.5), **box_style)

# No y-axis label or title is drawn on this figure.
ax.tick_params(axis='y', labelsize=16)
ax.set_ylim(0, 0.235)
ax.set_xticks(positions)
# Tick labels are the digit classes 0-9: element d of before_data/after_data holds
# digit d, so the labels must start at 0, not 1.
ax.set_xticklabels(range(len(before_data)), fontsize=16)
ax.legend([box1["boxes"][0], box2["boxes"][0]], ['Inputs', 'Outputs'], fontsize=16)

fig.tight_layout()
# savefig raises if the target directory does not exist (e.g. on a fresh clone).
os.makedirs('misc_plots', exist_ok=True)
fig.savefig('misc_plots/mnist_variance_boxplot.pdf')
plt.show()


#### Statistical analysis

In [None]:
from scipy.stats import ttest_rel

# Pool the per-pixel variances of all digits into two paired samples: element k
# of each array refers to the same digit and pixel, before vs. after recall.
# NOTE(review): neighbouring pixels are correlated, so treating them as independent
# pairs likely overstates the significance of the test.
flattened_before_data = np.concatenate(before_data)
flattened_after_data = np.concatenate(after_data)

# Paired t-test of before vs. after, plus the confidence interval of the mean
# difference (before - after).
res = ttest_rel(flattened_before_data, flattened_after_data)
print(res)
print(res.confidence_interval())

# Effect size for paired samples (Cohen's d_z): the mean of the paired differences
# divided by their sample standard deviation (ddof=1).
# Differences are taken as after - before. A negative d therefore means the outputs
# vary less than the inputs, which is the opposite sign to the t statistic above
# (computed from before - after).
differences = flattened_after_data - flattened_before_data
mean_difference = np.mean(differences)
std_dev_difference = np.std(differences, ddof=1)
cohens_d = mean_difference / std_dev_difference
print(f"Cohen's d: {cohens_d}")