### Novelty and consolidation

Code for exploring how novelty might affect memory consolidation.

#### Installation:

In [None]:
# Install the packages listed in requirements.txt using pip (IPython shell escape).
!pip install -r requirements.txt

#### Imports:

In [None]:
from end_to_end import run_end_to_end
import tensorflow as tf
import tensorflow_datasets as tfds
from utils import prepare_data, noise, display
from initial_model import create_autoencoder
from initial_tests import check_initial_recall, iterative_recall
from generative_model import VAE, build_encoder_decoder_v3, build_encoder_decoder_v5
from generative_tests import interpolate_ims, plot_latent_space, check_generative_recall, plot_history, vector_arithmetic
from tensorflow import keras
import numpy as np
from random import randrange
from PIL import Image
import matplotlib.pyplot as plt
import hopfield_utils
from hopfield_models import ContinuousHopfield
import matplotlib.backends.backend_pdf
from config import models_dict, dims_dict
import matplotlib.pyplot as plt
import matplotlib

# Fix TensorFlow's global random seed so that repeated runs of the notebook
# give the same outputs.
tf.random.set_seed(seed=123)

#### Train VAE

As a starting point, train a VAE on MNIST data only (the first 5000 digits).

In [None]:
# Hyperparameters for the generative model.
generative_epochs = 10
latent_dim = 5
kl_weighting = 1
dataset = 'mnist'
lr = 0.002

# Pool the MNIST train and test images into one array, add a trailing
# channel axis and divide the pixel values by 255.
(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()
mnist_digits = np.concatenate([x_train, x_test], axis=0)
mnist_digits = np.expand_dims(mnist_digits, -1).astype("float32") / 255

# models_dict maps the dataset name to a builder that returns the
# (encoder, decoder) pair for the requested latent dimension.
encoder, decoder = models_dict[dataset](latent_dim=latent_dim)
vae = VAE(encoder, decoder, kl_weighting)

# Use `learning_rate`: the `lr` keyword is a deprecated alias that newer
# Keras releases no longer accept.
opt = keras.optimizers.Adam(learning_rate=lr)
vae.compile(optimizer=opt)

# Only the first 5000 digits are used for training; later items stay unseen
# so they can be used to measure reconstruction error.
history = vae.fit(mnist_digits[0:5000], epochs=generative_epochs, verbose=1, batch_size=32, shuffle=True)

#### Measure reconstruction error for new data

We'll use 2000 unseen items from each of three datasets: MNIST, KMNIST, and Fashion MNIST.

MNIST:

In [None]:
# 2000 MNIST digits taken from outside the [0:5000] slice used for training.
mnist_unseen = mnist_digits[5000:7000]

# Encode, then decode from the first encoder output
# (presumably the latent mean -- TODO confirm against the VAE definition).
encs = vae.encoder.predict(mnist_unseen)
decs = vae.decoder.predict(encs[0])

# mean_absolute_error averages over the last axis; summing the result over
# axes 1 and 2 then gives one reconstruction error per image.
mnist_abs_err = keras.losses.mean_absolute_error(mnist_unseen, decs)
mnist_recons = tf.reduce_sum(mnist_abs_err, axis=(1, 2)).numpy().tolist()

KMNIST:

In [None]:
# Image size looked up for `dataset` (the data the VAE was built for), so the
# KMNIST images are resized to that size before being fed to the model.
dim = dims_dict[dataset]

ds = tfds.load('kmnist', split='test', shuffle_files=True)
# Fetch the metadata of KMNIST itself; `dataset` names the VAE's training
# set, which is not the dataset being loaded here.
ds_info = tfds.builder('kmnist').info
df = tfds.as_dataframe(ds.take(2000), ds_info)

# Resize every image to (dim[0], dim[0]) and collect them in a single array.
new_train_data = np.empty([2000, dim[0], dim[0]])
train_data = df['image']
for ind, t in enumerate(train_data):
    im = Image.fromarray(t.reshape((28, 28))).resize((dim[0], dim[0]))
    new_train_data[ind] = np.asarray(im)
train_data = new_train_data

# Add a trailing channel axis and divide the pixel values by 255.
kmnist_digits = np.expand_dims(train_data, -1).astype("float32") / 255

# Encode, then decode from the first encoder output
# (presumably the latent mean -- TODO confirm against the VAE definition).
encs = vae.encoder.predict(kmnist_digits[0:2000])
decs = vae.decoder.predict(encs[0])

# Per-image reconstruction error: absolute error averaged over the last axis,
# then summed over axes 1 and 2.
kmnist_recons = tf.reduce_sum(keras.losses.mean_absolute_error(kmnist_digits[0:2000], decs), axis=(1, 2)).numpy().tolist()

Fashion MNIST:

In [None]:
# Pool the Fashion-MNIST train and test images, add a trailing channel axis
# and divide the pixel values by 255.
(x_train, _), (x_test, _) = keras.datasets.fashion_mnist.load_data()
fmnist_raw = np.concatenate((x_train, x_test), axis=0)
fmnist_digits = fmnist_raw[..., np.newaxis].astype("float32") / 255

# First 2000 items; encode, then decode from the first encoder output
# (presumably the latent mean -- TODO confirm against the VAE definition).
fmnist_sample = fmnist_digits[0:2000]
encs = vae.encoder.predict(fmnist_sample)
decs = vae.decoder.predict(encs[0])

# Per-image reconstruction error: absolute error averaged over the last axis,
# then summed over axes 1 and 2.
fmnist_abs_err = keras.losses.mean_absolute_error(fmnist_sample, decs)
fmnist_recons = tf.reduce_sum(fmnist_abs_err, axis=(1, 2)).numpy().tolist()

#### Generate plots

In [None]:
# Global figure styling (plt.rcParams and matplotlib.rcParams are the same object).
matplotlib.rcParams.update({'font.size': 12, "figure.figsize": (8, 5)})

fig = plt.figure()

# One overlaid, semi-transparent histogram per dataset, drawn in this order:
# (reconstruction errors, bar colour, legend label).
hist_specs = [
    (fmnist_recons, 'blue', 'Fashion-MNIST'),
    (mnist_recons, 'black', 'MNIST'),
    (kmnist_recons, 'red', 'KMNIST'),
]
for recon_errors, colour, legend_label in hist_specs:
    n, bins, patches = plt.hist(
        recon_errors,
        bins=25,
        density=True,
        facecolor=colour,
        alpha=0.5,
        label=legend_label,
    )

plt.title('Reconstruction error by dataset')
plt.xlabel('Reconstruction error')
plt.ylabel('Probability')
plt.legend()

# Write the figure to disk before showing it.
plt.savefig('recon_error_by_dataset.png')
plt.show()


In [None]:
labels = ['MNIST', 'KMNIST', 'Fashion-MNIST']
# Reconstruction error above which an item counts as needing Hopfield encoding.
threshold = 100

# Reconstruction errors in the same order as `labels`.
recons_by_dataset = [mnist_recons, kmnist_recons, fmnist_recons]

# Number of items per dataset whose error exceeds the threshold
# (the names say "means" but the values are counts).
hopfield_means = [sum(err > threshold for err in recons)
                  for recons in recons_by_dataset]

# Number of items per dataset at or below the threshold. `<=` (rather than
# `<`) makes sure an error exactly equal to the threshold is still counted
# in one of the two groups.
no_hopfield_means = [sum(err <= threshold for err in recons)
                     for recons in recons_by_dataset]

x = np.arange(len(labels))  # the label locations
width = 0.35  # the width of the bars

# Two side-by-side bars per dataset, centred on the label location.
fig, ax = plt.subplots()
rects1 = ax.bar(x - width/2, hopfield_means, width, label='Hopfield encoding required', facecolor='b', alpha=0.5)
rects2 = ax.bar(x + width/2, no_hopfield_means, width, label='Hopfield encoding not required', facecolor='r', alpha=0.5)

# Add some text for labels, title and custom x-axis tick labels, etc.
ax.set_ylabel('Number of memories')
ax.set_title('Number of memories stored in Hopfield network by dataset')
ax.set_xticks(x)
ax.set_xticklabels(labels)
ax.legend()
# Fixed y-range leaves headroom above the 2000 items per dataset.
ax.set_ylim([0, 2500])

# Write the figure to disk before showing it.
plt.savefig('Hopfield fraction.png')
plt.show()