In [None]:
# Import Python packages (standard library first, then third-party)
import warnings

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
from scipy.stats import zscore

# Silence all library warnings for the rest of the notebook
warnings.filterwarnings('ignore')

# Seed numpy's global RNG so every stochastic cell below is reproducible
SEED = 42
np.random.seed(SEED)

# parameters to change

In [None]:
# Scaling for the (commented-out) dynamic drift; not used with manual drift
# scale_drift = 0.4

# Raw strengths of the semantic and episodic retrieval routes
sem = 0.5
episodic = 0.5

# Normalize the two strengths so the route weights sum to 1
route_total = sem + episodic
sem_weight, episodic_weight = sem / route_total, episodic / route_total

# reward sequence

In [None]:
# Presentation order: a random permutation of the feature indices 1..10.
# Kept 1-indexed (as in the original R code); subtract 1 wherever it indexes an array.
pres_indices = 1 + np.random.permutation(10)

# Reward sequence: element 0 is the initial "reward expectation", followed by
# one outcome per presented item.
# Alternative - high RPE within the sequence (initial expectation 0):
# sequence = np.array([0, 54, 57, 56, 53, 55, 7, 5, 4, 6, 3])

# Primacy condition (initial expectation 50)
sequence = np.array([50, 54, 53, 52, 51, 50, 49, 48, 47, 46, 45])

# dynamic drift

In [None]:
# dynamic beta
# absPE = np.zeros((len(pres_indices), 1))

# take absolute difference of previous reward with current reward
# (first "prediction error" is reward - initial expectation)
# for seq in range(len(sequence) - 1):
#     absPE[seq] = abs(sequence[seq] - sequence[seq + 1])

# take absolute value of z-scored absPE and multiply by scaling parameter
# B_encD = np.abs(zscore(absPE)) * scale_drift

# B_encD = np.where(B_encD > 1, 1, B_encD)  # if it's over 1, make it 1

# manual drift

In [None]:
# Manual drift: one encoding drift rate (beta) per presented item -
# full drift (1.0) for the first item, then a constant 0.65.
# Comment this out when using the dynamic drift above or the stable B_enc.
B_encD = np.concatenate(([1.0], np.full(9, 0.65)))

# semantic matrix

In [None]:
# "Semantic" matrix: each item activates only its own feature unit (activation 1,
# nothing else), so the semantic route is the identity mapping.
sem_mat = np.identity(len(pres_indices))

# original parameters (Polyn, Norman & Kahana, 2009)

In [None]:
# SET PARAMETERS

# --- network construction ---
gamma_fc = 0.581          # learning rate on feature->context weights; pre-existing weights get 1 - gamma_fc
eye_fc = 1 - gamma_fc     # pre-existing M_fc strength (items are orthonormal, so M_fc starts as a scaled identity)
eye_cf = 0                # pre-existing M_cf strength (M_cf starts empty)

# --- encoding ---
B_enc = 0.745             # stable (scalar) context drift rate at encoding; B_encD is the per-item alternative
lrate_fc_enc = gamma_fc   # feature-to-context learning rate during encoding
lrate_cf_enc = 1          # context-to-feature learning rate during encoding

# --- recall ---
B_rec = 0.36              # context drift rate at recall
lrate_fc_rec = 0          # feature-to-context learning rate during recall
lrate_cf_rec = 0          # context-to-feature learning rate during recall
thresh = 1                # threshold an accumulator must reach to win the decision competition
rec_time = 90000          # length of the recall period (interpreted as 90 seconds, i.e. ms)
dt = 100                  # step size of the decision process, in the same units as rec_time
L = 0.375                 # lateral inhibition between units in the decision competition
K = 0.091                 # decay (leak) rate of the accumulating elements
eta = 0.3699              # standard deviation of the Gaussian noise term
tau = 413                 # time constant of the decision competition

# --- simulation bookkeeping ---
n_sims = 1000                                        # number of simulations
recall_sims = np.zeros((len(pres_indices), n_sims))  # one column of recalled serial positions per simulation
times_sims = np.zeros_like(recall_sims)              # matching recall times, same layout

# model run

In [None]:
# Run n_sims independent simulations: encode the list once (context drift plus Hebbian
# learning), then free recall through a leaky competing accumulator until rec_time elapses.
# Fills recall_sims / times_sims (one column per simulation). net_w_fc, net_w_cf and
# net_c are left holding the LAST simulation's state, which the weight plots below use.


def _normalize(vec):
    """Return column vector `vec` scaled to unit length; an all-zero vector is returned unchanged."""
    norm = np.sqrt(vec.T @ vec)[0, 0]
    if norm == 0:
        # e.g. gamma_fc = 1 gives eye_fc = 0, so the first item has no M_fc input;
        # dividing would turn the whole context into NaN.
        return vec
    return vec / norm


def _advance_context(c, c_in, beta):
    """Drift context `c` toward the unit-length input `c_in` at rate `beta`.

    rho is chosen so that a unit-length `c` stays unit length after the update.
    """
    dot_product = (c.T @ c_in)[0, 0]
    rho = np.sqrt(1 + (beta**2) * ((dot_product**2) - 1)) - beta * dot_product
    return rho * c + beta * c_in


def _present_item(feature_idx, net_c, net_w_fc, net_w_cf, beta, lrate_fc, lrate_cf):
    """Activate feature `feature_idx` (0-indexed), drift context, then apply Hebbian learning.

    Used both for study presentations and for reinstating a recalled item.
    Returns the updated (net_c, net_w_fc, net_w_cf); the inputs are not modified in place.
    """
    net_f = np.zeros((net_w_fc.shape[1], 1))
    net_f[feature_idx] = 1
    c_in = _normalize(net_w_fc @ net_f)
    net_c = _advance_context(net_c, c_in, beta)
    net_w_fc = net_w_fc + (net_c @ net_f.T) * lrate_fc   # feature -> (new) context
    net_w_cf = net_w_cf + (net_f @ net_c.T) * lrate_cf   # (new) context -> feature
    return net_c, net_w_fc, net_w_cf


n_items = len(pres_indices)
inds = np.arange(n_items)

# Loop-invariant pieces of the decision competition
dt_tau = dt / tau
sq_dt_tau = np.sqrt(dt_tau)                                # noise SD scales with sqrt(dt / tau)
lmat = (~np.eye(n_items, dtype=bool)).astype(float) * L    # every unit is inhibited by every other unit

# Inverse of the presentation order: 0-indexed serial position of each feature
position_of_feature = np.empty(n_items, dtype=int)
position_of_feature[pres_indices - 1] = inds

for sims in range(n_sims):

    # NETWORK: empty context; M_fc starts as eye_fc * identity, M_cf as eye_cf * identity
    net_c = np.zeros((n_items, 1))
    net_w_fc = np.eye(n_items) * eye_fc
    net_w_cf = np.eye(n_items) * eye_cf

    # ENCODING: present each item once, in list order
    for item in range(n_items):
        B = B_encD[item]  # per-item (dynamic / manual) drift at encoding
        # B = B_enc       # use this instead for stable drift
        net_c, net_w_fc, net_w_cf = _present_item(
            pres_indices[item] - 1, net_c, net_w_fc, net_w_cf,
            B, lrate_fc_enc, lrate_cf_enc)

    # RECALL
    recalls = np.zeros((n_items, 1))   # 1-indexed serial positions in output order; 0 = no recall
    times = np.zeros((n_items, 1))     # elapsed time at which each recall occurred
    time_passed = 0
    recall_count = 0
    retrieved = np.zeros((n_items, 1), dtype=bool)
    thresholds = np.ones((n_items, 1)) * thresh

    while time_passed < rec_time:

        # Cue the feature layer with the current context via the semantic + episodic routes.
        # Rebuilt on every attempt so recall-phase M_cf learning (lrate_cf_rec) takes effect,
        # just as recall-phase M_fc learning already does.
        net_weights = episodic_weight * net_w_cf + sem_weight * sem_mat
        f_in = net_weights @ net_c

        # The competition may run for at most the remaining time
        max_cycles = int((rec_time - time_passed) / dt)
        if max_cycles <= 0:
            # Less than one cycle left (possible when dt does not divide rec_time):
            # no time would elapse, and without this break the loop would never end.
            break
        noise = np.random.normal(0, eta * sq_dt_tau, (n_items, max_cycles))

        x = np.zeros((n_items, 1))
        winner = None
        i = 0

        # ACCUMULATORS CYCLING
        while i < max_cycles and winner is None:
            lx = lmat @ x   # lateral inhibition felt by each unit
            kx = K * x      # activity leaking from each unit
            x = x + ((f_in - kx - lx) * dt_tau + noise[:, i:i + 1])
            x[x < 0] = 0

            # Already-retrieved items keep competing but are held just below threshold
            reset_these = retrieved & (x >= thresholds)
            x[reset_these] = 0.95 * thresholds[reset_these]

            # Not-yet-retrieved units at or above threshold (ascending feature index)
            winners = inds[((~retrieved) & (x >= thresholds)).flatten()]
            if winners.size == 1:
                winner = winners[0]
            elif winners.size > 1:
                winner = np.random.choice(winners)  # random tiebreak

            i += 1

        time_passed += i * dt

        if winner is not None:
            # Reinstate the recalled item: drift at B_rec, learn at the recall rates
            net_c, net_w_fc, net_w_cf = _present_item(
                winner, net_c, net_w_fc, net_w_cf,
                B_rec, lrate_fc_rec, lrate_cf_rec)

            recalls[recall_count, 0] = position_of_feature[winner] + 1  # back to 1-indexed
            times[recall_count, 0] = time_passed
            recall_count += 1

            # Retrieved items cannot be recalled again
            retrieved[winner] = True

    recall_sims[:, sims] = recalls.flatten()
    times_sims[:, sims] = times.flatten()

print(f"Model run complete. Simulated {n_sims} trials.")

# serial position curve

In [None]:
# Proportion of simulations in which each serial position was recalled (at any output position)
serial_positions = np.arange(1, len(pres_indices) + 1)
position = pd.DataFrame({'position': serial_positions})

numSums = np.array([np.sum(recall_sims == pos) for pos in serial_positions], dtype=float)
recall = numSums / n_sims

prop_recall = position.assign(recall=recall)

print(prop_recall)

In [None]:
# Serial-position curve: probability of recall for each study position
fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(prop_recall['position'], prop_recall['recall'], 'o-', linewidth=2, markersize=8)
ax.set_xlabel('serial position', fontsize=15)
ax.set_ylabel('probability of recall', fontsize=15)
ax.set_ylim(0, 1)
ax.tick_params(axis='both', labelsize=13)
ax.grid(False)
fig.tight_layout()
plt.show()

# first recall probability

In [None]:
# Probability of first recall: fraction of simulations whose first output was each serial position.
# Every serial position is listed (freq 0 if it was never recalled first) so the plot
# does not silently connect across positions that are missing from the table.
first_recall = recall_sims[0, :]
first_recall = first_recall[first_recall > 0]  # drop simulations with no recall at all

all_positions = np.arange(1, len(pres_indices) + 1)
counts = np.array([np.sum(first_recall == pos) for pos in all_positions])
first_recall_table = pd.DataFrame({
    'position': all_positions,
    'freq': counts,
    'prop': counts / n_sims,  # denominator: all simulations, including those with no recall
})

print(first_recall_table)

In [None]:
# Plot the probability of first recall as a function of serial position
fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(first_recall_table['position'], first_recall_table['prop'], 'o-', linewidth=2, markersize=8)
ax.set_xlabel('serial position', fontsize=15)
ax.set_ylabel('probability of first recall', fontsize=15)
# Keep the usual 0-0.8 range, but expand it instead of clipping a point above 0.8
# (NaN from an empty table compares False and falls back to 0.8).
max_prop = first_recall_table['prop'].max()
ax.set_ylim(0, max_prop * 1.05 if max_prop > 0.8 else 0.8)
ax.tick_params(axis='both', labelsize=13)
ax.grid(False)
fig.tight_layout()
plt.show()

# conditional response probabilities

In [None]:
# Every nonzero lag possible for a list of this length: -(n-1)..(n-1) without 0
n_items = len(pres_indices)
all_lags = np.arange(-(n_items - 1), n_items)
poss_outcomes = all_lags[all_lags != 0]

# Candidate positions must be SERIAL positions (recall_sims stores serial positions),
# not the presentation-order permutation of feature indices.
serial_positions = np.arange(1, n_items + 1)

# Actual transitions: lag between consecutive recalls (0 where no further recall was made)
trans_sims = np.zeros((n_items, n_sims))
for subj in range(n_sims):
    currentSub = recall_sims[:, subj]
    for trial in range(n_items - 1):
        if currentSub[trial + 1] > 0:
            trans_sims[trial, subj] = currentSub[trial + 1] - currentSub[trial]

# Possible transitions: lags from the current item to every not-yet-recalled position,
# counted only at output positions where a transition actually occurred.
possTransFrame = []
for subj in range(n_sims):
    currentSub = recall_sims[:, subj]
    possTrans_sims = np.zeros((n_items, n_items - 1))
    for trial in range(n_items - 1):
        if currentSub[trial + 1] > 0:
            remaining = serial_positions[~np.isin(serial_positions, currentSub[:trial + 1])]
            lags = remaining - currentSub[trial]
            possTrans_sims[trial, :len(lags)] = lags
    possTransFrame.append(possTrans_sims)

possTransFrame = np.vstack(possTransFrame)

In [None]:
# Conditional response probability (CRP) for each lag:
# times the lag was taken / times it was available.
actual_transitions = trans_sims[trans_sims != 0]
possible_transitions = possTransFrame[possTransFrame != 0]

# Count frequencies
tab_a_t_unique, tab_a_t_counts = np.unique(actual_transitions, return_counts=True)
tab_p_t_unique, tab_p_t_counts = np.unique(possible_transitions, return_counts=True)

tab_a_t = pd.DataFrame({'actual_transitions': tab_a_t_unique, 'Freq': tab_a_t_counts})
tab_p_t = pd.DataFrame({'possible_transitions': tab_p_t_unique, 'Freq': tab_p_t_counts})

# Keep short lags only (|lag| <= 5)
tab_a_t = tab_a_t[(tab_a_t['actual_transitions'] < 6) & (tab_a_t['actual_transitions'] > -6)]
tab_p_t = tab_p_t[(tab_p_t['possible_transitions'] < 6) & (tab_p_t['possible_transitions'] > -6)]

# Align counts by lag VALUE, not row position: a lag that was possible but never taken
# is absent from tab_a_t, which would misalign (or break) an element-wise division.
possible_by_lag = tab_p_t.set_index('possible_transitions')['Freq']
actual_by_lag = (
    tab_a_t.set_index('actual_transitions')['Freq']
    .reindex(possible_by_lag.index, fill_value=0)
)

crp = actual_by_lag.to_numpy() / possible_by_lag.to_numpy()

crps = pd.DataFrame({
    'transitions': possible_by_lag.index.to_numpy(),
    'crp': crp,
})

In [None]:
# Lag-CRP plot. Backward and forward lags are drawn as separate lines, so no segment
# crosses lag 0 (never a valid transition). Overpainting a white line would also erase
# half of the lag +/-1 markers.
fig, ax = plt.subplots(figsize=(8, 5))
backward = crps[crps['transitions'] < 0]
forward = crps[crps['transitions'] > 0]
(backward_line,) = ax.plot(backward['transitions'], backward['crp'], 'o-', linewidth=2, markersize=8)
ax.plot(forward['transitions'], forward['crp'], 'o-', linewidth=2, markersize=8,
        color=backward_line.get_color())

ax.set_xlabel('lag', fontsize=15)
ax.set_ylabel('conditional response probability', fontsize=15)
ax.tick_params(axis='both', labelsize=13)
ax.grid(False)
fig.tight_layout()
plt.show()

# weight matrices

In [None]:
# Reorder the last simulation's weights from feature order into serial (presentation) order
serial_order = pres_indices - 1  # 0-indexed feature presented at each serial position
net_w_fc_inorder = np.take(net_w_fc, serial_order, axis=1)  # M_fc: one column per item
net_w_cf_inorder = np.take(net_w_cf, serial_order, axis=0)  # M_cf: one row per item

In [None]:
# Heatmaps of the last simulation's weight matrices, with items along x in serial order
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
tick_positions = np.arange(len(pres_indices))
panels = [
    (axes[0], net_w_fc_inorder, 'YlOrRd', 'feature to context weight matrix'),
    (axes[1], net_w_cf_inorder.T, 'Reds', 'context to feature weight matrix'),
]
for ax, matrix, cmap, title in panels:
    image = ax.imshow(matrix, cmap=cmap, aspect='auto')
    ax.set_title(title, fontsize=16)
    ax.set_xlabel('serial position', fontsize=14)
    ax.set_ylabel('', fontsize=14)
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(tick_positions + 1)
    plt.colorbar(image, ax=ax)

plt.tight_layout()
plt.show()

# correlation matrices

In [None]:
# Correlations between items' learned weight vectors in the last simulation:
# entry (i, j) correlates the vectors for the items at serial positions i+1 and j+1.
corr_fc = np.corrcoef(net_w_fc_inorder.T)  # columns of M_fc -> one vector per item
corr_cf = np.corrcoef(net_w_cf_inorder)    # rows of M_cf -> one vector per item

fig, axes = plt.subplots(1, 2, figsize=(14, 5))
tick_positions = np.arange(len(pres_indices))
tick_labels = tick_positions + 1

# Titles say "correlation" (these are corrcoef heatmaps, not the raw weights)
for ax, corr, title in (
    (axes[0], corr_fc, 'feature to context correlation matrix'),
    (axes[1], corr_cf, 'context to feature correlation matrix'),
):
    image = ax.imshow(corr, cmap='YlOrRd', aspect='auto', vmin=-1, vmax=1)
    ax.set_title(title, fontsize=16)
    ax.set_xlabel('serial position', fontsize=14)
    ax.set_ylabel('serial position', fontsize=14)  # both axes index items by serial position
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(tick_labels)
    ax.set_yticks(tick_positions)
    ax.set_yticklabels(tick_labels)
    plt.colorbar(image, ax=ax)

plt.tight_layout()
plt.show()