### DRM experiment simulation - extended model

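The Deese-Roediger-McDermott (DRM) task is a classic way to measure memory distortion: participants study lists of semantically related words and often falsely recall a related but unpresented "lure" word. This notebook tries to recreate the human results in VAE and AE models.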

#### Installation:

In [None]:
!pip install tensorflow==2.11.0
!pip install tensorflow-datasets
!pip install tfds-nightly
!pip install scikit-learn --upgrade

#### Imports:

In [None]:
import pandas as pd
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from config import DRM_lists, lures
from data_preparation import *
from generative_model import *
# from hpc import HPC
from hopfield_models import *
# from drm_utils import *

# Seed Python's `random`, NumPy and TensorFlow together (same TensorFlow seed as
# before) so that every source of randomness in the notebook is reproducible.
tf.keras.utils.set_random_seed(1234)

#### VAE pre-training

In [None]:
# Build the training matrix and the fitted vectorizer, without ID tokens (ids=False).
# max_df / min_df are presumably document-frequency cut-offs forwarded to the
# vectorizer - confirm in data_preparation.prepare_data.
x_train, vectorizer = prepare_data(ids=False, min_df=0.0005, max_df=0.1)

# The call below records how the VAE was pre-trained. It stays commented out because
# the notebook loads the weights file of the same name instead of retraining.
# vae = train_vae(x_train, vectorizer, eps=100, ld=300, beta=0.001, batch=128, l1_value=0.01,
#                 model_save_path='300ld_100eps_0.001beta_128batch_0.01l1.h5')

#### Load model

In [None]:
# Load the VAE from the weights file written during pre-training. ld and beta must
# match the values encoded in the file name (300 latent dimensions, beta 0.001).
model_weights_path = '300ld_100eps_0.001beta_128batch_0.01l1.h5'
vae = load_vae(
    vectorizer,
    model_weights_path=model_weights_path,
    beta=0.001,
    ld=300,
)

#### Extended model - store latent codes plus IDs:

In [None]:
# Prefix each DRM word list with a unique token `id_<n>` that stands in for the
# list's spatiotemporal context, and join everything into one string per list.
experiences = [
    ' '.join([f'id_{idx}', *words])
    for idx, words in enumerate(DRM_data.values())
]
print("Example list with unique spatiotemporal feature:")
print(experiences[0])

# Second vectorizer, prepared with ids=True; the data returned alongside it is unused.
_, id_vectorizer = prepare_data(ids=True, min_df=0.0005, max_df=0.1)

# Cell output: the encoding of a single ID token by the ID vectorizer.
id_vectorizer.transform(['id_0'])

In [None]:
# Latent dimension of the VAE codes stored in the network
ld = 300

# Every stored pattern is a latent code (ld values) followed by one count per entry
# of the ID vectorizer's vocabulary, so the MHN needs ld + vocabulary-size units.
id_vocab_size = len(id_vectorizer.vocabulary_)
net = ContinuousHopfield(ld + id_vocab_size, beta=100, do_normalization=False)

In [None]:
def get_latent_code(sent):
    """Encode a text with the VAE encoder.

    The text is vectorised with the global `vectorizer`, passed through
    `vae.encoder.predict`, and element 0 of the prediction is returned
    (presumably the latent mean - confirm against the encoder definition).
    """
    features = vectorizer.transform([sent])
    return vae.encoder.predict(features)[0]

# Build one column vector per experience: [latent code | ID count vector].
patterns = []
for experience in experiences:
    # Latent code for the whole list, flattened to 1-D
    latent_vec = get_latent_code(experience).flatten()
    # The first token of each experience is its unique `id_<n>` context marker
    id_token = experience.split()[0]
    id_vec = id_vectorizer.transform([id_token]).toarray()
    # Sanity check: show which ID term the vectorizer recognised
    print(id_vectorizer.inverse_transform(id_vec))
    combined = np.concatenate([latent_vec, id_vec.flatten()])
    patterns.append(combined.reshape(-1, 1))

# Store all patterns in the Hopfield network in a single call
net.learn(np.array(patterns))

#### Test recall:

In [None]:
def hybrid_recall(test_text, net, top_n=15):
    """Recall a stored experience from its ID token alone.

    A cue is built with zeros in the latent part and the ID vectorizer's counts
    for `test_text` in the ID part. The network's retrieved pattern is then split:
    the first `ld` values are decoded by the VAE decoder, the rest are mapped back
    to ID terms.

    Relies on the notebook globals `ld`, `vae`, `id_vectorizer` and `word_lookup`.

    Args:
        test_text: text containing the ID token to cue with (e.g. 'id_3').
        net: network exposing `retrieve(pattern)`.
        top_n: number of decoded vocabulary terms to return (default 15).

    Returns:
        A list of (term, score) tuples: first the recovered ID term with a fixed
        score of 1, then the `top_n` vocabulary terms in descending order of
        decoder output.
    """
    # Integer zeros for the latent part, as in the original np.full((1, ld), 0)
    latent_cue = np.zeros(ld, dtype=int)
    id_cue = id_vectorizer.transform([test_text]).toarray().flatten()
    cue = np.concatenate([latent_cue, id_cue]).reshape(-1, 1)

    memory = net.retrieve(cue)

    decoded = vae.decoder.predict(memory[0:ld].reshape((1, ld)))
    # Rank once and keep only the best top_n indices before looking up the words
    ranked = np.argsort(-decoded)[0][:top_n]
    top_words = [(word_lookup[index], decoded[0][index]) for index in ranked]

    # Reshape to a single row of however many ID features the memory holds,
    # instead of hard-coding the ID vocabulary size.
    id_part = np.array(memory[ld:]).reshape((1, -1))
    # inverse_transform lists every ID term with a non-zero retrieved value and
    # [0][0] keeps the first of them.
    # NOTE(review): if retrieval is not sparse, consider picking the argmax instead.
    unpredictable_component = id_vectorizer.inverse_transform(id_part)
    recalled_words = [(unpredictable_component[0][0], 1)] + top_words
    return recalled_words

def hybrid_plot(ax, terms, scores, clrs, lure_word):
    """Draw one recall bar chart on the given axes.

    Uses only Axes methods, so the function neither depends on nor changes
    pyplot's global "current axes".

    Args:
        ax: matplotlib Axes to draw on.
        terms: bar labels (recalled terms), one per bar.
        scores: bar heights, aligned with `terms`.
        clrs: bar colours, aligned with `terms`.
        lure_word: lure word shown in the subplot title.
    """
    ax.bar(terms, scores, color=clrs, alpha=0.5)
    # Dashed reference line at a recall score of 0.5
    ax.axhline(y=0.5, color='grey', linestyle='--')
    ax.set_ylabel('Recall score', fontsize=16)
    ax.set_title(f"Lure word '{lure_word}'", fontsize=18)
    # Vertical x tick labels so that long terms do not overlap
    ax.tick_params(axis='x', labelrotation=90, labelsize=16)

# Reverse vocabulary: feature index -> term (read by hybrid_recall at call time)
word_lookup = {v: k for k, v in vectorizer.vocabulary_.items()}

# squeeze=False always returns a 2-D array of Axes, even for a single lure, so
# the loop below works for any number of lures; ravel() gives back a 1-D array.
fig, axs = plt.subplots(len(lures), 1, figsize=(10, 4*len(lures)), squeeze=False)
axs = axs.ravel()
fig.tight_layout(h_pad=12)

# One subplot per lure.
# NOTE(review): assumes lures[i] is the key of the i-th list in DRM_data, because
# the `id_<i>` tokens were assigned by enumerating DRM_data.values() - confirm.
for i, ax in enumerate(axs):
    lure = lures[i]
    list_words = DRM_data[lure] + [f'id_{i}']
    recalled = hybrid_recall(f'id_{i}', net)
    terms = [term for term, _ in recalled]
    scores = [score for _, score in recalled]
    # red: the lure itself; blue: studied words and the ID token; grey: anything else
    clrs = ['red' if term == lure else 'blue' if term in list_words else 'grey' for term in terms]
    hybrid_plot(ax, terms, scores, clrs, lure)

plt.savefig('mhn_drm.png', bbox_inches='tight')
plt.show()

#### Show that longer lists increase false recall:

In [None]:
def test_list_length(vectorizer, encoder, decoder, list_name, num_words):
    """Check whether the lure is recalled after studying a truncated DRM list.

    The first `num_words` words of `DRM_data[list_name]` are vectorised, encoded
    and decoded; every vocabulary term whose decoded value exceeds 0.5 counts as
    recalled.

    Args:
        vectorizer: fitted vectorizer exposing `transform` and `vocabulary_`.
        encoder: model whose `predict` output is indexed at 0 to get the code.
        decoder: model whose `predict` maps the code back to vocabulary space.
        list_name: key into `DRM_data`; also the word looked for among the
            recalled words.
        num_words: number of leading list words to present.

    Returns:
        1 if `list_name` is among the recalled words, 0 if it is not, and None
        (after printing a message) when `num_words` exceeds the list length.
    """
    full_list = DRM_data[list_name]

    # Guard: cannot present more words than the list contains
    if len(full_list) < num_words:
        print("The number of words is greater than the list length.")
        return None

    subset_list = full_list[:num_words]
    print(f"Subset list: {subset_list}")

    # Vectorise the truncated list, then run it through encoder and decoder
    features = vectorizer.transform([' '.join(subset_list)])
    encoded = encoder.predict(features)[0]
    decoded = decoder.predict(encoded)

    # Map feature indices back to terms; keep those scoring above 0.5, best first
    word_lookup = {idx: term for term, idx in vectorizer.vocabulary_.items()}
    ranked = np.argsort(-decoded)[0]
    recalled_words = [word_lookup[idx] for idx in ranked if decoded[0][idx] > 0.5]

    print(f"Recalled words: {recalled_words}")

    return 1 if list_name in recalled_words else 0

# List lengths to test: 1 through 14 words
word_counts = range(1, 15)
# Names of all DRM lists (the keys of DRM_data)
list_names = [name for name in DRM_data.keys()]

In [None]:
# results[list_name][k] holds the outcome of test_list_length for the k-th entry
# of word_counts: 1 if the lure was recalled, 0 if not, None if the list is too short.
results = {
    list_name: [
        test_list_length(vectorizer, vae.encoder, vae.decoder, list_name, num_words)
        for num_words in word_counts
    ]
    for list_name in list_names
}



In [None]:
# Percentage of lists whose lure was recalled, and its standard error, per list length
percentages = []
errors = []

for i, num_words in enumerate(word_counts):
    # test_list_length returns None when a list has fewer than num_words words;
    # leave those out so that np.mean / np.std do not fail on None.
    lure_retrieved_counts = [
        results[list_name][i]
        for list_name in list_names
        if results[list_name][i] is not None
    ]
    if lure_retrieved_counts:
        percentage = np.mean(lure_retrieved_counts) * 100
        # Standard error of the mean (population std / sqrt(n)), in percent
        error = np.std(lure_retrieved_counts) / np.sqrt(len(lure_retrieved_counts)) * 100
    else:
        # No list is long enough for this word count
        percentage = np.nan
        error = np.nan
    percentages.append(percentage)
    errors.append(error)

# Create a plot
plt.figure(figsize=(8, 6))

# Create the line plot with error bars
plt.errorbar(word_counts, percentages, yerr=errors, fmt='-o', capsize=5)

# Add labels
plt.xlabel("Number of words", fontsize=16)
plt.ylabel("Lure recall percentage", fontsize=16)
plt.xticks(fontsize=16)
plt.yticks(fontsize=16)

# Save and show the plot
plt.savefig('lure_words_fraction.png', dpi=300)
plt.show()
