### Modelling memory distortions with a variational autoencoder: the Carmichael experiment

In an experiment by Carmichael et al. (1932), subjects were asked to reproduce ambiguous sketches. A context was established by telling the subjects that they would see images from a certain category. It was found that when the subjects tried to reproduce the image after a delay, their drawings were distorted to look more like members of the context class.

Nagy et al. (2020) showed that a variational autoencoder trained on a class biases recall towards that class, but they used a separate model for each class (rather than a single model with context as an input). The work below extends this by using a single model, with the context represented by a cue in the image.

In this notebook, I train variational autoencoders on pairs of visually similar classes from various datasets. The results show that the model 'recalls' the same ambiguous image differently depending on the context; as in the Carmichael experiment, the 'recalled' image looks more like the context class.

#### Imports:

In [None]:
import random
from random import randrange

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.backends.backend_pdf
from matplotlib.pyplot import imshow
from PIL import Image
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.datasets import mnist, fashion_mnist, cifar10

import hopfield_utils
from config import models_dict, dims_dict
from utils import prepare_data, noise, display
from initial_model import create_autoencoder
from initial_tests import check_initial_recall, iterative_recall
from generative_model import (VAE, build_encoder_decoder_v1, build_encoder_decoder_v2,
                              build_encoder_decoder_v3, build_encoder_decoder_v4,
                              build_encoder_decoder_v5)
from generative_tests import (interpolate_ims, plot_latent_space, check_generative_recall,
                              plot_history, vector_arithmetic)
from hopfield_models import ContinuousHopfield

tf.random.set_seed(123)

#### Load the data

Based on Carmichael et al. (1932), let's pick a pair of similar-looking classes.

In [None]:
(train_data, train_label), (test_data, _) = mnist.load_data()

In [None]:
# The class bar (context cue) is added to data1, which holds the class_2 images;
# data2 holds the class_1 images and gets no bar
class_1 = 2
class_2 = 3

data1 = [i for ind, i in enumerate(train_data) if train_label[ind] == class_2]
data2 = [i for ind, i in enumerate(train_data) if train_label[ind] == class_1]
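
The same selection can also be written with NumPy boolean masks, which avoids the Python-level loop. A minimal sketch on synthetic stand-in arrays; the names `toy_data`, `toy_label`, `toy_data1` and `toy_data2` are illustrative and not part of the notebook:

```python
import numpy as np

# Synthetic stand-ins for train_data / train_label (illustrative only)
rng = np.random.default_rng(0)
toy_data = rng.integers(0, 256, size=(20, 28, 28), dtype=np.uint8)
toy_label = np.array([2, 3] * 10)

class_1, class_2 = 2, 3

# A boolean mask selects every image of one class in a single step
toy_data1 = toy_data[toy_label == class_2]  # these would receive the class bar
toy_data2 = toy_data[toy_label == class_1]  # these would stay without a bar

print(toy_data1.shape, toy_data2.shape)
```

The results are arrays rather than lists, but slicing such as `[0:num]` and iterating over the images work the same way.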

In [None]:
len(data2)

#### Add context cues

Equivalent to the use of a word to set the context in Carmichael et al. (1932), I added a feature to each image that indicates the class. This allows us to manipulate the context and see how that affects the reconstruction. A horizontal line at the top of the image indicates a certain class.

In [None]:
def add_class_bar(d):
    d_bar = [255]*28 + list(d.flatten())[28:]
    return np.array(d_bar).reshape((28,28,1))

def preprocess(d):
    d_bar = list(d.flatten())
    return np.array(d_bar).reshape((28,28,1))

def remove_class_bar(d):
    d_bar = [0]*28 + list(d.flatten())[28:]
    return np.array(d_bar).reshape((28,28,1))
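
To make the effect of the cue concrete, the sketch below applies a slice-based version of the same idea to a blank image. `add_class_bar_fast` and `remove_class_bar_fast` are hypothetical names introduced here; for 28x28 inputs they are intended to produce the same pixel values as the functions above, although the output dtype may differ.

```python
import numpy as np

def add_class_bar_fast(d):
    """Return a (28, 28, 1) copy of d with the top row set to 255."""
    out = np.array(d).reshape((28, 28, 1))  # np.array makes a copy
    out[0, :, 0] = 255
    return out

def remove_class_bar_fast(d):
    """Return a (28, 28, 1) copy of d with the top row set to 0."""
    out = np.array(d).reshape((28, 28, 1))
    out[0, :, 0] = 0
    return out

blank = np.zeros((28, 28), dtype=np.uint8)
with_cue = add_class_bar_fast(blank)
without_cue = remove_class_bar_fast(with_cue)

# Only the top row differs between the two versions
print(with_cue[0, :5, 0], without_cue[0, :5, 0])
```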

In [None]:
num = 5000
# Training set: data1 carries the class bar, data2 does not
resized_train = [add_class_bar(d) for d in data1][0:num] + [preprocess(d) for d in data2][0:num]
train = np.stack(resized_train, axis=0).astype("float32") / 255

# Inverse set: the cue is swapped, so data2 carries the bar and data1 does not
inverse_resized_train = [preprocess(d) for d in data1][0:num] + [add_class_bar(d) for d in data2][0:num]
inverse_train = np.stack(inverse_resized_train, axis=0).astype("float32") / 255

Plot an item from the first class as an example to see how the context appears:

In [None]:
fig, ax = plt.subplots()
im = ax.imshow(resized_train[0].reshape((28,28)))

In [None]:
fig, ax = plt.subplots()
im = ax.imshow(resized_train[num].reshape((28,28)))

#### Train the VAE

Using the imported helper functions, build a VAE with six latent variables:

In [None]:
encoder, decoder = build_encoder_decoder_v5(latent_dim = 6)
vae = VAE(encoder, decoder, kl_weighting=5)
opt = keras.optimizers.Adam(learning_rate=0.001)  # `lr` is a deprecated alias
vae.compile(optimizer=opt)
history = vae.fit(train, epochs=100, verbose=0, batch_size=32, shuffle=True)
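
The `VAE` class and its `kl_weighting` argument come from `generative_model`, which is not shown here. Assuming the weight scales the KL term as in a β-VAE (an assumption, not something this notebook confirms), the per-example objective would look roughly like the NumPy sketch below. `weighted_vae_loss` is a hypothetical helper, and the squared-error reconstruction term is likewise an assumption.

```python
import numpy as np

def weighted_vae_loss(x, x_recon, z_mean, z_log_var, kl_weighting=1.0):
    """Per-example loss: reconstruction error plus a weighted KL term."""
    # Reconstruction term: sum of squared pixel errors (assumed; binary
    # cross-entropy is another common choice)
    pixel_axes = tuple(range(1, x.ndim))
    recon = np.sum((x - x_recon) ** 2, axis=pixel_axes)
    # KL divergence between N(z_mean, exp(z_log_var)) and the prior N(0, I)
    kl = -0.5 * np.sum(1 + z_log_var - z_mean ** 2 - np.exp(z_log_var), axis=1)
    return recon + kl_weighting * kl

# Two toy examples with perfect reconstruction and six latent variables
x = np.zeros((2, 28, 28, 1))
x_recon = np.zeros_like(x)
z_mean = np.array([[0.0] * 6, [1.0] * 6])
z_log_var = np.zeros((2, 6))

loss = weighted_vae_loss(x, x_recon, z_mean, z_log_var, kl_weighting=5)
print(loss)
```

Under this assumption, a larger `kl_weighting` pulls the encodings more strongly towards the prior, which would plausibly make reconstructions less item-specific.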

#### Explore how context affects recall

Let's see how the model reconstructs the same drawing with and without the context cue.

In [None]:
class_a_with_cue = [np.resize(add_class_bar(d), (28,28,1)) for d in data1[0:10]]
class_a_with_cue = np.array(class_a_with_cue).astype("float32") / 255
class_a_without_cue = [np.resize(d, (28,28,1)) for d in data1[0:10]]
class_a_without_cue = np.array(class_a_without_cue).astype("float32") / 255

class_b_with_cue = [np.resize(add_class_bar(d), (28,28,1)) for d in data2[0:10]]
class_b_with_cue = np.array(class_b_with_cue).astype("float32") / 255
class_b_without_cue = [np.resize(d, (28,28,1)) for d in data2[0:10]]
class_b_without_cue = np.array(class_b_without_cue).astype("float32") / 255

#### How does context affect recall of ambiguous items from the class_b category?

In [None]:
fig, axs = plt.subplots(5, 3, figsize=(6,8), sharex=True)

for i in range(5):
    item = class_b_without_cue[i+5]
    axs[i,0].imshow(np.resize(item, (28,28)), cmap='Greys')
    axs[i,0].axis('off')

    encoding = encoder.predict(item.reshape(1,28,28,1))
    x_decoded = decoder.predict(encoding)
    axs[i,1].imshow(remove_class_bar(np.resize(x_decoded, (28,28))), cmap='Greys')
    axs[i,1].axis('off')
    
    item = class_b_with_cue[i+5]  # same drawing as in the 'Original' column
    encoding = encoder.predict(item.reshape(1,28,28,1))
    x_decoded = decoder.predict(encoding)
    axs[i,2].imshow(remove_class_bar(np.resize(x_decoded, (28,28))), cmap='Greys')
    axs[i,2].axis('off')
    
    
for ax, col in zip(axs[0,:], ['Original', 
                              'Recalled (context 1)',  
                              'Recalled (context 2)']):
    ax.set_title(col, size=12)
    
fig.tight_layout() 
plt.show()

#### How does context affect recall of ambiguous items from the class_a category?

In [None]:
fig, axs = plt.subplots(5, 3, figsize=(6,8), sharex=True)

for i in range(5):
    item = class_a_without_cue[i+5]
    axs[i,0].imshow(np.resize(item, (28,28)), cmap='Greys')
    axs[i,0].axis('off')

    encoding = encoder.predict(item.reshape(1,28,28,1))
    x_decoded = decoder.predict(encoding)
    axs[i,1].imshow(remove_class_bar(np.resize(x_decoded, (28,28))), cmap='Greys')
    axs[i,1].axis('off')
    
    item = class_a_with_cue[i+5]  # same drawing as in the 'Original' column
    encoding = encoder.predict(item.reshape(1,28,28,1))
    x_decoded = decoder.predict(encoding)
    axs[i,2].imshow(remove_class_bar(np.resize(x_decoded, (28,28))), cmap='Greys')
    axs[i,2].axis('off')
    
    
for ax, col in zip(axs[0,:], ['Original', 
                              'Recalled (context 1)',  
                              'Recalled (context 2)']):
    ax.set_title(col, size=12)
    
fig.tight_layout() 
plt.show()

### Loop through different KL weight / latent dimension combinations

In [None]:
def distortions(kl_weighting=1, l_d=4):
    encoder, decoder = build_encoder_decoder_v5(latent_dim = l_d)
    vae = VAE(encoder, decoder, kl_weighting=kl_weighting)
    opt = keras.optimizers.Adam(learning_rate=0.001)  # `lr` is a deprecated alias
    vae.compile(optimizer=opt)
    history = vae.fit(train, epochs=100, verbose=0, batch_size=32, shuffle=True)
    
    class_a_with_cue = [np.resize(add_class_bar(d), (28,28,1)) for d in data1[0:10]]
    class_a_with_cue = np.array(class_a_with_cue).astype("float32") / 255
    class_a_without_cue = [np.resize(d, (28,28,1)) for d in data1[0:10]]
    class_a_without_cue = np.array(class_a_without_cue).astype("float32") / 255

    class_b_with_cue = [np.resize(add_class_bar(d), (28,28,1)) for d in data2[0:10]]
    class_b_with_cue = np.array(class_b_with_cue).astype("float32") / 255
    class_b_without_cue = [np.resize(d, (28,28,1)) for d in data2[0:10]]
    class_b_without_cue = np.array(class_b_without_cue).astype("float32") / 255
    
    fig, axs = plt.subplots(5, 3, figsize=(8,8), sharex=True)

    for i in range(5):
        item = class_b_without_cue[i]
        axs[i,0].imshow(np.resize(item, (28,28)), cmap='Greys')
        axs[i,0].axis('off')

        encoding = encoder.predict(item.reshape(1,28,28,1))
        x_decoded = decoder.predict(encoding)
        axs[i,1].imshow(remove_class_bar(np.resize(x_decoded, (28,28))), cmap='Greys')
        axs[i,1].axis('off')

        item = class_b_with_cue[i]
        encoding = encoder.predict(item.reshape(1,28,28,1))
        x_decoded = decoder.predict(encoding)
        axs[i,2].imshow(remove_class_bar(np.resize(x_decoded, (28,28))), cmap='Greys')
        axs[i,2].axis('off')


    for ax, col in zip(axs[0,:], ['Original:', 
                                  'Recalled (context={}):'.format(class_1),  
                                  'Recalled (context={}):'.format(class_2)]):
        ax.set_title(col, size=12)

    fig.tight_layout() 
    #plt.show()
    fig.savefig('./distortions/mnist_{}lv_{}kl_{}vs{}_1.png'.format(l_d, kl_weighting, class_1, class_2))
    
    fig, axs = plt.subplots(5, 3, figsize=(8,8), sharex=True)

    for i in range(5):
        item = class_a_without_cue[i]
        axs[i,0].imshow(np.resize(item, (28,28)), cmap='Greys')
        axs[i,0].axis('off')

        encoding = encoder.predict(item.reshape(1,28,28,1))
        x_decoded = decoder.predict(encoding)
        axs[i,1].imshow(remove_class_bar(np.resize(x_decoded, (28,28))), cmap='Greys')
        axs[i,1].axis('off')

        item = class_a_with_cue[i]
        encoding = encoder.predict(item.reshape(1,28,28,1))
        x_decoded = decoder.predict(encoding)
        axs[i,2].imshow(remove_class_bar(np.resize(x_decoded, (28,28))), cmap='Greys')
        axs[i,2].axis('off')


    for ax, col in zip(axs[0,:], ['Original:', 
                                  'Recalled (context={}):'.format(class_1),  
                                  'Recalled (context={}):'.format(class_2)]):
        ax.set_title(col, size=12)

    fig.tight_layout() 
    #plt.show()
    fig.savefig('./distortions/mnist_{}lv_{}kl_{}vs{}_2.png'.format(l_d, kl_weighting, class_1, class_2))
    
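import os

# Make sure the output directory used by fig.savefig exists
os.makedirs('./distortions', exist_ok=True)

klws = [1]
lds = [4]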

for ld in lds:
    for klw in klws:
        print(ld, klw)
        distortions(kl_weighting=klw, l_d=ld)

### Other things of interest

#### Effect of the number of latent variables

Let's try building models with latent spaces of different dimensions, and see how that affects these results. The following cell trains four models with 2, 4, 6, and 8 latent variables respectively.

In [None]:
import pickle

models = []
latent_dims = [2,4,6,8]

for l_d in latent_dims:
    print("Training model with {} latent variables.".format(l_d))
    encoder, decoder = build_encoder_decoder_v5(latent_dim=l_d)
    vae = VAE(encoder, decoder)
    vae.compile(optimizer=keras.optimizers.Adam(learning_rate=0.001))
    vae.fit(train, epochs=100, batch_size=128, verbose=0)
    
    models.append((encoder,decoder,vae))

Let's see how the different models recall the same item.

It appears that the smaller the dimension of the latent space, the greater the distortion (this makes sense, as the memory is compressed more). In other words, a variational autoencoder model of memory suggests that gist-based distortion increases as the storage capacity of the 'semantic memory' decreases.
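
One way to put a number on this impression is the mean squared error between each item and its reconstruction. The sketch below uses synthetic arrays in place of model output; `recall_error` is a hypothetical helper, and in this notebook `recalled` would instead come from passing the items through a trained encoder and decoder.

```python
import numpy as np

def recall_error(originals, recalled):
    """Mean squared pixel error between originals and their reconstructions."""
    originals = np.asarray(originals, dtype="float32")
    recalled = np.asarray(recalled, dtype="float32").reshape(originals.shape)
    return float(np.mean((originals - recalled) ** 2))

# Synthetic stand-ins: a perfect recall and a uniformly shifted recall
rng = np.random.default_rng(0)
originals = rng.random((10, 28, 28, 1))
perfect = originals.copy()
shifted = originals + 0.1

print(recall_error(originals, perfect))
print(recall_error(originals, shifted))
```

Computing this value for each latent dimension on the same set of items would give a rough check of whether the error grows as the latent space shrinks.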

In [None]:
fig, axs = plt.subplots(10, 5, figsize=(8,12), sharex=True)

# Sample the items once so that every model recalls the same images
items = [inverse_train[random.randrange(0, len(inverse_train))] for _ in range(10)]

for i, item in enumerate(items):
    axs[i,0].imshow(remove_class_bar(np.resize(item, (28,28))), cmap='Greys')
    axs[i,0].axis('off')

    # One column per model, all reconstructing the same item
    for ind, (encoder, decoder, vae) in enumerate(models):
        encoding = encoder.predict(item.reshape(1,28,28,1))
        x_decoded = decoder.predict(encoding)
        axs[i,ind+1].imshow(remove_class_bar(np.resize(x_decoded, (28,28))), cmap='Greys')
        axs[i,ind+1].axis('off')
        
for ax, col in zip(axs[0,:], ['Original', 
                              '2 L.V.', 
                              '4 L.V.', 
                              '6 L.V.',
                              '8 L.V.']):
    ax.set_title(col, size=9)
        
plt.show()

#### References

Carmichael, L., Hogan, H. P., & Walter, A. A. (1932). An experimental study of the effect of language on the reproduction of visually perceived form. Journal of Experimental Psychology, 15(1), 73.

Nagy, D. G., Török, B., & Orbán, G. (2020). Optimal forgetting: Semantic compression of episodic memories. PLOS Computational Biology, 16(10), e1008367.