# Context Maintenance and Retrieval (CMR) Model

This notebook implements the CMR model (Polyn, Norman & Kahana, 2009), a computational model of human episodic memory that explains how people recall lists of items. The model combines:

- **Episodic associations**: Learned connections between items and the temporal context in which they were presented
- **Semantic associations**: Pre-existing associations between items based on their meaning
- **Context drift**: The gradual change in mental context over time during encoding and retrieval

The model simulates classic memory phenomena including primacy effects (better recall of early items), recency effects (better recall of recent items), and temporal contiguity (tendency to recall items presented close together in time).

In [None]:
# Import Python packages
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
from scipy.stats import zscore
import warnings
warnings.filterwarnings('ignore')

# Set random seed for reproducibility
np.random.seed(42)

# Parameters to Change

These parameters control the relative strength of semantic versus episodic retrieval routes:

- **sem**: Weight of semantic associations (meaning-based connections)
- **episodic**: Weight of episodic associations (temporal context-based connections)

The weights are normalized so they sum to 1, determining the relative contribution of each route during memory retrieval. Equal weights (0.5/0.5) mean both routes contribute equally to recall.
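
As a minimal sketch of the normalization described above, the helper below rescales two raw route strengths so they sum to 1. The function name `normalize_route_weights` and the example values are illustrative assumptions, not part of the notebook.

```python
def normalize_route_weights(sem, episodic):
    """Return (sem_weight, episodic_weight), rescaled to sum to 1.

    Hypothetical helper: the notebook does this inline with two divisions.
    """
    total = sem + episodic
    if total <= 0:
        raise ValueError("sem + episodic must be positive")
    return sem / total, episodic / total


# Example: an episodic route three times as strong as the semantic route
sem_w, epi_w = normalize_route_weights(sem=1.0, episodic=3.0)
print(sem_w, epi_w)  # 0.25 0.75
```

Only the ratio of the two raw values matters: `sem=1, episodic=3` and `sem=0.5, episodic=1.5` give the same weights.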

In [None]:
# strength of drift
# scale_drift = 0.4  # scaling parameter (not used if inputting manual drift)

# strength of semantic + episodic route
sem = 0.5
episodic = 0.5

# normalize
sem_weight = sem / (episodic + sem)
episodic_weight = episodic / (episodic + sem)

# Reward Sequence

Defines the presentation order and values of items in the simulated memory list:

- **pres_indices**: Random permutation determining the order items are presented
- **sequence**: Reward values associated with each position (used to calculate prediction errors)

The "primacy" sequence used here starts from an initial expectation of 50, jumps to 54 for the first item, and then decreases by 1 per item down to 45. These values only influence context drift when prediction-error-based drift is enabled, which simulates scenarios where reward or value changes over the course of a sequence.

In [None]:
# random presentation order, 1-indexed to match R behavior (converted to 0-indexed wherever it is used for indexing)
pres_indices = np.random.permutation(10) + 1

# high RPE within sequence (first outcome is "reward expectation" at 0)
# sequence = np.array([0, 54, 57, 56, 53, 55, 7, 5, 4, 6, 3])  # initial expectation is 0

# primacy (first outcome is "reward expectation" at 50)
sequence = np.array([50, 54, 53, 52, 51, 50, 49, 48, 47, 46, 45])

# Dynamic Drift

**Context drift** refers to how quickly the mental context changes during encoding. This section (currently commented out) implements dynamic drift based on prediction errors:

**Logic**:
1. Calculate absolute prediction errors (difference between consecutive rewards)
2. Z-score normalize the prediction errors
3. Scale by `scale_drift` parameter
4. Cap maximum drift at 1.0

**Interpretation**: Larger prediction errors cause greater context drift, meaning unexpected events create stronger temporal boundaries in memory. This can enhance memory for items around surprising events.

This approach is currently disabled in favor of manual drift specification below.
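
The four steps above can be sketched as a single function. This is an illustrative version under stated assumptions: the name `drift_from_rewards` is hypothetical, the z-score is computed with plain NumPy (population standard deviation, which is also the default of `scipy.stats.zscore`), and returning zeros for a constant sequence is a choice made here, not something the notebook specifies.

```python
import numpy as np


def drift_from_rewards(sequence, scale_drift=0.4):
    """Per-item encoding drift from absolute prediction errors.

    `sequence[0]` is the initial reward expectation; the remaining
    values are the rewards paired with each presented item.
    """
    sequence = np.asarray(sequence, dtype=float)
    abs_pe = np.abs(np.diff(sequence))        # step 1: |reward_t - reward_(t-1)|
    sd = abs_pe.std()                         # population SD (ddof=0)
    if sd == 0:
        return np.zeros_like(abs_pe)          # assumption: no variation, no drift
    z = (abs_pe - abs_pe.mean()) / sd         # step 2: z-score
    drift = np.abs(z) * scale_drift           # step 3: scale
    return np.minimum(drift, 1.0)             # step 4: cap at 1.0


primacy = [50, 54, 53, 52, 51, 50, 49, 48, 47, 46, 45]
B_sketch = drift_from_rewards(primacy)
print(np.round(B_sketch, 3))  # first value is capped at 1.0, the rest are about 0.133
```

Because the absolute value of the z-score is used, prediction errors far below the list average also produce nonzero drift, not only unusually large ones.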

In [None]:
# dynamic beta
# absPE = np.zeros((len(pres_indices), 1))

# take absolute difference of previous reward with current reward
# (first "prediction error" is reward - initial expectation)
# for seq in range(len(sequence) - 1):
#     absPE[seq] = abs(sequence[seq] - sequence[seq + 1])

# take absolute value of z-scored absPE and multiply by scaling parameter
# B_encD = np.abs(zscore(absPE)) * scale_drift

# B_encD = np.where(B_encD > 1, 1, B_encD)  # if it's over 1, make it 1

# Manual Drift

Manually specified context integration rates (β values) for each item during encoding:

- **B_encD[0] = 1.0**: First item causes maximum context drift (complete context update)
- **B_encD[1:] = 0.65**: Subsequent items cause moderate drift

**Interpretation**: High drift for the first item creates a strong primacy effect - the first item becomes strongly associated with the initial context state. Lower drift for subsequent items means they blend more with the evolving context, creating temporal associations between nearby items.

In [None]:
# manual drift (comment out if using above or stable drift)
B_encD = np.array([1, 0.65, 0.65, 0.65, 0.65, 0.65, 0.65, 0.65, 0.65, 0.65])

# Semantic Matrix

Defines pre-existing semantic associations between items:

- **Identity matrix structure**: Each item has activation of 1.0 to itself, 0 to all others
- This represents **orthogonal** items with no semantic similarity

**Interpretation**: In this simple version, items have no semantic relationships to one another. Because `sem_mat` is the identity, the semantic route adds each context element's activation directly to the item with the same index and never spreads activation between different items. For more realistic simulations, this could be replaced with a structured similarity matrix based on word associations, categories, etc.
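
As a hedged illustration of what a structured alternative might look like, the sketch below builds a category-based similarity matrix. The function name, the category labels, and the similarity values (0.3 within a category, 0 between categories) are all made up for the example; the notebook itself only uses the identity matrix.

```python
import numpy as np


def category_similarity_matrix(categories, within=0.3, between=0.0):
    """Symmetric item-by-item similarity matrix from category labels.

    Rows and columns follow item index (not serial position).
    All numeric values are illustrative assumptions.
    """
    categories = np.asarray(categories)
    same = categories[:, None] == categories[None, :]
    sim = np.where(same, within, between).astype(float)
    np.fill_diagonal(sim, 1.0)  # each item is fully similar to itself
    return sim


# Five items: the first two share one category, the last three another
sem_mat_example = category_similarity_matrix([0, 0, 1, 1, 1])
print(sem_mat_example)
```

Setting `within=0.0` recovers the identity matrix used in this notebook.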

In [None]:
# create "semantic" matrix
# each item has an activation of 1 (and to no other units)
sem_mat = np.eye(len(pres_indices))

# Original Parameters (Polyn, Norman & Kahana, 2009)

Core model parameters calibrated to fit human free recall data:

## Network Architecture
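- **gamma_fc (0.581)**: Relative weight of newly learned versus pre-existing feature→context associations. In the code, new learning is scaled by `gamma_fc` and the pre-existing identity weights by `1 - gamma_fc`
- **eye_fc, eye_cf**: Scaling of the identity matrices that initialize the feature→context (`1 - gamma_fc`) and context→feature (0) connections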

## Encoding Parameters
- **B_enc (0.745)**: Standard context integration rate during study
- **lrate_fc_enc**: Learning rate for feature→context connections
- **lrate_cf_enc**: Learning rate for context→feature connections

## Retrieval Parameters
- **B_rec (0.36)**: Context integration during recall (slower than encoding)
- **lrate_fc_rec, lrate_cf_rec (0)**: No new learning during recall

## Decision Competition
- **thresh (1)**: Activation threshold for successful recall
- **L (0.375)**: Lateral inhibition between competing items
- **K (0.091)**: Decay rate of accumulator activation
- **eta (0.3699)**: Noise in the decision process
- **tau (413 ms)**: Time constant for accumulation

These parameters create a competitive retrieval process where items race to threshold, with noise and inhibition producing realistic recall dynamics.
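
To make the roles of `L`, `K`, `eta`, `tau`, and `thresh` concrete, here is a simplified, self-contained sketch of one accumulator race. It is not the notebook's retrieval loop: the function name `race_to_threshold` is hypothetical, and the handling of already-retrieved items is left out.

```python
import numpy as np


def race_to_threshold(f_in, thresh=1.0, L=0.375, K=0.091, eta=0.3699,
                      tau=413.0, dt=100.0, max_cycles=900, rng=None):
    """Run one leaky, competing accumulator race.

    Returns (winner_index, elapsed_ms), or (None, elapsed_ms) if no unit
    reaches threshold within max_cycles.
    """
    rng = np.random.default_rng() if rng is None else rng
    f_in = np.asarray(f_in, dtype=float)
    n = f_in.size
    lmat = L * (1.0 - np.eye(n))   # inhibition from every *other* unit
    x = np.zeros(n)
    dt_tau = dt / tau
    for cycle in range(1, max_cycles + 1):
        noise = rng.normal(0.0, eta * np.sqrt(dt_tau), n)
        # input, minus leak (K * x), minus lateral inhibition (lmat @ x)
        x = x + (f_in - K * x - lmat @ x) * dt_tau + noise
        x = np.maximum(x, 0.0)     # activations cannot go negative
        winners = np.flatnonzero(x >= thresh)
        if winners.size:
            # random tiebreak if several units cross on the same cycle
            return int(rng.choice(winners)), cycle * dt
    return None, max_cycles * dt


winner, elapsed = race_to_threshold([1.0, 0.2, 0.1], rng=np.random.default_rng(0))
print(winner, elapsed)
```

With `eta=0` the race is deterministic and the unit with the largest input wins; the noise term is what lets weaker items occasionally win.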

In [None]:
# SET PARAMETERS

# for creating the network
gamma_fc = 0.581  # relative weight of newly learned feature-to-context associations (used as lrate_fc_enc below)
eye_fc = 1 - gamma_fc  # weight of pre-existing associations; items are orthonormal vectors, so this scales an identity matrix ("eye")
eye_cf = 0

# during encoding
B_enc = 0.745  # context integration rate at encoding when drift is stable (per-item alternative: B_encD above)
lrate_fc_enc = gamma_fc  # feature-to-context during encoding
lrate_cf_enc = 1  # context-to-feature during encoding

# during recall
B_rec = 0.36  # context integration rate at recall
lrate_fc_rec = 0  # feature-to-context during recall
lrate_cf_rec = 0  # context-to-feature during recall
thresh = 1  # threshold for an accumulating element to win the decision competition (fixed at 1)
rec_time = 90000  # max recall process (interpreted as 90 seconds)
dt = 100  # time step of the decision process (ms)
L = 0.375  # lateral inhibition between units in decision competition
K = 0.091  # decay rate for the accumulating elements in decision competition
eta = 0.3699  # standard deviation of gaussian noise term in decision competition
tau = 413  # time constant in decision competition

n_sims = 1000  # number of simulations
recall_sims = np.zeros((len(pres_indices), n_sims))
times_sims = np.zeros((len(pres_indices), n_sims))

# Model Run

Main simulation loop implementing the CMR model's encoding and retrieval phases.

## Overall Structure
For each simulation:
1. **Initialize network**: Create feature and context layers, connection weights
2. **Encoding phase**: Present items sequentially, update context and associations
3. **Retrieval phase**: Use context cues to retrieve items via competitive accumulation

## Encoding Phase Logic
For each presented item:
1. Activate the item's feature representation
2. Calculate context input from feature layer (via learned associations)
3. Update context using drift equation: `c_new = ρ·c_old + β·c_in`
   - **ρ (rho)**: Ensures context vector stays normalized
   - **β (beta)**: Controls integration rate (how much context changes)
4. Learn associations between active features and current context

**Key insight**: Items encoded in similar contexts will be more likely to cue each other during recall.
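
The drift equation in step 3 can be sketched in isolation. The function name `update_context` is an assumption for illustration; the formula for ρ is the one used in the simulation loop of this notebook.

```python
import numpy as np


def update_context(c, c_in, beta):
    """Return rho * c + beta * c_in, with rho chosen to keep unit length.

    c    : current context vector (unit length, or all zeros at list start)
    c_in : input to context from the presented item (normalized here)
    beta : integration rate, between 0 and 1
    """
    c = np.asarray(c, dtype=float)
    c_in = np.asarray(c_in, dtype=float)
    c_in = c_in / np.linalg.norm(c_in)
    dot = float(c @ c_in)
    rho = np.sqrt(1.0 + beta**2 * (dot**2 - 1.0)) - beta * dot
    return rho * c + beta * c_in


c_old = np.array([1.0, 0.0, 0.0])
c_new = update_context(c_old, np.array([0.0, 1.0, 0.0]), beta=0.65)
print(c_new, np.linalg.norm(c_new))  # the norm stays at 1 (up to rounding)
```

When the context starts as all zeros and `beta = 1`, ρ is 0 and the new context equals the item's input exactly, which is what happens for the first list item under the manual drift setting.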

## Retrieval Phase Logic
Iterative process until time runs out:
1. Use current context to activate features (via episodic + semantic routes)
2. Run accumulator race:
   - Each item accumulates activation based on its cue strength
   - Lateral inhibition suppresses competitors
   - Noise creates variability
   - Decay prevents runaway activation
3. When an item crosses threshold:
   - Record the recall
   - Mark item as retrieved (prevent repetitions)
   - Use retrieved item to update context
   - Continue with new context state

**Key insight**: Each recall changes context, which determines what gets recalled next - creating temporal contiguity effects.

In [None]:
for sims in range(n_sims):
    
    # NETWORK
    
    # initialize the features and context layers
    net_f = np.zeros((len(pres_indices), 1))
    net_c = np.zeros((len(pres_indices), 1))
    
    # learning rate matrices
    net_lrate_fc = np.zeros((len(pres_indices), len(pres_indices)))
    net_lrate_cf = np.zeros((len(pres_indices), len(pres_indices)))
    
    # the lrate matrices
    net_lrate_fc_enc = np.full((len(pres_indices), len(pres_indices)), lrate_fc_enc)
    net_lrate_cf_enc = np.full((len(pres_indices), len(pres_indices)), lrate_cf_enc)
    net_lrate_fc_rec = np.full((len(pres_indices), len(pres_indices)), lrate_fc_rec)
    net_lrate_cf_rec = np.full((len(pres_indices), len(pres_indices)), lrate_cf_rec)
    
    net_w_fc = np.eye(len(net_c)) * eye_fc  # m_fc eye() creates identity matrices
    net_w_cf = np.eye(len(net_f)) * eye_cf  # m_cf zero
    net_weights = np.zeros((len(pres_indices), len(pres_indices)))
    
    # ENCODING
    
    net_idx = np.arange(len(pres_indices))
    
    for item in range(len(pres_indices)):
        
        # present item
        feature_idx = pres_indices[item] - 1  # Convert to 0-indexed
        
        # activates the indexed feature (each item activates one element)
        net_f = np.zeros((len(pres_indices), 1))
        net_f[feature_idx] = 1
        
        # update context representations
        net_c_in = net_w_fc @ net_f
        
        # normalize vector
        vec = net_c_in
        denom_vec = np.sqrt(vec.T @ vec)[0, 0]
        norm_vec = vec / denom_vec
        net_c_in = norm_vec
        
        # advance context
        c_in = net_c_in
        c = net_c
        
        # set dynamic or stable drift
        B = B_encD[item]  # beta at encoding for dynamic
        # B = B_enc  # beta at encoding if stable
        
        dot_product = (c.T @ c_in)[0, 0]
        rho = np.sqrt(1 + (B**2) * ((dot_product**2) - 1)) - B * dot_product
        updated_c = rho * c + B * c_in
        net_c = updated_c
        
        # determine current learning rate
        lrate_fc = net_lrate_fc_enc
        lrate_cf = net_lrate_cf_enc
        
        # update weights
        
        # w_fc
        delta = (net_c @ net_f.T) * lrate_fc
        net_w_fc = net_w_fc + delta
        
        # w_cf
        delta = (net_f @ net_c.T) * lrate_cf
        net_w_cf = net_w_cf + delta
    
    # RECALL
    
    # set up
    recalls = np.zeros((len(pres_indices), 1))
    times = np.zeros((len(pres_indices), 1))
    
    rec_time_local = rec_time  # max recall period in ms (set in the parameters cell)
    time_passed = 0
    recall_count = 0
    
    retrieved = np.zeros((len(pres_indices), 1), dtype=bool)
    thresholds = np.ones((len(pres_indices), 1))
    
    # semantic + episodic routes
    net_weights = episodic_weight * net_w_cf + sem_weight * sem_mat
    
    # GO!!!
    
    while time_passed < rec_time_local:
        
        # input to the feature layer, from last context cue
        f_in = net_weights @ net_c
        
        # set max number of cycles
        max_cycles = int((rec_time_local - time_passed) / dt)
        
        # for noise error standard deviation
        dt_tau = dt / tau
        sq_dt_tau = np.sqrt(dt_tau)
        
        # noise matrix
        noise = np.random.normal(0, eta * sq_dt_tau, (len(pres_indices), max_cycles))
        eyeI = ~np.eye(len(pres_indices), dtype=bool)
        lmat = eyeI.astype(float) * L
        
        ncycles = noise.shape[1]
        inds = np.arange(len(pres_indices))
        
        crossed = 0
        
        x = np.zeros((len(pres_indices), 1))
        
        K_array = np.ones((len(pres_indices), 1)) * K
        
        i = 0
        
        # ACCUMULATORS CYCLING
        while i < ncycles and crossed == 0:
            
            # the lateral inhibition felt by each unit
            lx = lmat @ x
            
            # the activity leaking from each unit
            kx = K_array * x
            
            # change in each accumulator
            x = x + ((f_in - kx - lx) * dt_tau + noise[:, i:i+1])
            x[x < 0] = 0
            
            # clamp already-retrieved items just below threshold: they still compete (inhibit others) but cannot win again
            reset_these = retrieved & (x >= thresholds)
            x[reset_these] = 0.95 * thresholds[reset_these]
            
            # retrieved items cannot be repeated
            retrievable = ~retrieved
            
            # determine whether any items have crossed thresholds
            crossed = 0
            if np.any(x[retrievable] >= thresholds[retrievable]):
                crossed = 1
                temp_win = x[retrievable] >= thresholds[retrievable]
                temp_ind = inds[retrievable.flatten()]
                winners = temp_ind[temp_win.flatten()]
                
                # if there is a tie, random tiebreak
                if len(winners) > 1:
                    winners = np.array([np.random.choice(winners)])
                
                winner_position = np.where(pres_indices - 1 == winners[0])[0][0]
            
            i = i + 1
        
        # calculate the amount of elapsed time
        time = i * dt
        
        time_passed = time_passed + time
        
        # reactivate item if there has been a retrieval
        if crossed == 1:
            
            # activate the retrieved feature
            net_f = np.zeros((len(pres_indices), 1))
            net_f[winners[0]] = 1
            
            # update context representations
            net_c_in = net_w_fc @ net_f
            
            # normalize vector
            vec = net_c_in
            denom_vec = np.sqrt(vec.T @ vec)[0, 0]
            norm_vec = vec / denom_vec
            net_c_in = norm_vec
            
            # advance context
            c_in = net_c_in
            c = net_c
            B = B_rec  # beta at retrieval
            
            dot_product = (c.T @ c_in)[0, 0]
            rho = np.sqrt(1 + (B**2) * ((dot_product**2) - 1)) - B * dot_product
            updated_c = rho * c + B * c_in
            net_c = updated_c
            
            # determine current learning rate
            lrate_fc = net_lrate_fc_rec
            lrate_cf = net_lrate_cf_rec
            
            # w_fc
            delta = (net_c @ net_f.T) * lrate_fc
            net_w_fc = net_w_fc + delta
            
            # w_cf
            delta = (net_f @ net_c.T) * lrate_cf
            net_w_cf = net_w_cf + delta
            
            # record data
            recall_count = recall_count + 1
            recalls[recall_count - 1, 0] = winner_position + 1  # Convert back to 1-indexed
            times[recall_count - 1, 0] = time_passed
            
            # update retrieved vector
            retrieved[winners[0]] = True
    
    recall_sims[:, sims] = recalls.flatten()
    times_sims[:, sims] = times.flatten()

print(f"Model run complete. Simulated {n_sims} trials.")

# Serial Position Curve

Analyzes recall probability as a function of an item's position in the study list.

## Calculation Logic
For each serial position (1-10):
- Count how many times items from that position were recalled across all simulations
- Divide by total number of simulations to get recall probability

## Expected Pattern
Classic serial position curve shows:
- **Primacy effect**: Higher recall for early items (positions 1-3)
  - Caused by high initial context drift creating distinct temporal context
- **Recency effect**: Higher recall for recent items (positions 8-10)
  - Caused by end-of-list context still being active during recall
- **Middle items**: Lower recall for mid-list positions
  - These items compete with many similar contexts

## Interpretation
The curve shape reveals how temporal context affects memory strength. Items encoded in unique contexts (beginning/end) are easier to retrieve than items in crowded temporal neighborhoods (middle).
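
As a small worked example of the calculation described above, the sketch below computes recall probability by serial position from a toy recall matrix laid out like `recall_sims` (rows are output positions, columns are simulations, 0 means no recall). The function name and the toy data are illustrative assumptions.

```python
import numpy as np


def serial_position_curve(recall_sims, list_length):
    """Proportion of simulations in which each serial position was recalled."""
    recall_sims = np.asarray(recall_sims)
    n_sims = recall_sims.shape[1]
    positions = np.arange(1, list_length + 1)
    counts = np.array([(recall_sims == p).sum() for p in positions])
    return counts / n_sims


# Three simulated trials of a 3-item list (one column per trial)
toy_recalls = np.array([[3, 3, 3],
                        [2, 1, 0],
                        [0, 2, 0]])
spc = serial_position_curve(toy_recalls, list_length=3)
print(np.round(spc, 3))  # [0.333 0.667 1.   ]
```

Position 3 is recalled in all three toy trials, position 2 in two, and position 1 in one, which gives a purely recency-shaped curve.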

In [None]:
# calculate the total proportion of recall given the serial position
position = pd.DataFrame({'position': np.arange(1, len(pres_indices) + 1)})

numSums = np.zeros(len(pres_indices))

for numSum in range(len(numSums)):
    numSums[numSum] = np.sum(recall_sims == (numSum + 1))

recall = numSums / n_sims

prop_recall = pd.DataFrame({
    'position': position['position'],
    'recall': recall
})

print(prop_recall)

## Serial Position Curve Visualization

This plot shows the probability of recalling an item as a function of its study position. The shape reveals fundamental memory processes:

- **Rising curve**: Indicates strong recency effect (later items better recalled)
- **U-shape**: Would indicate both primacy and recency
- **Flat curve**: Would suggest position-independent recall

With high first-item drift (B=1.0) and moderate subsequent drift (B=0.65), we expect strong recency and possible primacy enhancement.

In [None]:
# plot recall success as a function of serial position
plt.figure(figsize=(8, 5))
plt.plot(prop_recall['position'], prop_recall['recall'], 'o-', linewidth=2, markersize=8)
plt.xlabel('serial position', fontsize=15)
plt.ylabel('probability of recall', fontsize=15)
plt.ylim(0, 1)
plt.xticks(fontsize=13)
plt.yticks(fontsize=13)
plt.grid(False)
plt.tight_layout()
plt.show()

# First Recall Probability

Analyzes which items are most likely to be recalled first when retrieval begins.

## Calculation Logic
- Extract the first recalled item from each simulation
- Count frequency of each position being recalled first
- Convert to proportions

## Theoretical Predictions
First recall is strongly influenced by:
1. **Recency**: Recent items have highest context match with end-of-list state
2. **Primacy**: With high initial drift, first item may have unique retrieval advantage
3. **Semantic strength**: Items with strong semantic cues may be recalled first

## Interpretation
Typically shows extreme recency bias - the last item is most likely to be recalled first because recall begins from the end-of-list context state. The probability of first recall (PFR) curve is often steeper than the overall serial position curve.
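
A minimal sketch of the first-recall calculation follows, using `np.bincount` so that positions never recalled first still appear with probability 0. The function name and toy data are assumptions for illustration; the input layout matches `recall_sims` (first row holds the first recall of each simulation, 0 means nothing was recalled).

```python
import numpy as np


def first_recall_probability(recall_sims, list_length):
    """Probability that each serial position is the first item recalled."""
    first = np.asarray(recall_sims)[0, :].astype(int)
    # index 0 counts trials with no recall; drop it and keep positions 1..N
    counts = np.bincount(first, minlength=list_length + 1)[1:list_length + 1]
    return counts / first.size


# Four simulated trials of a 3-item list; the third trial recalled nothing
toy_recalls = np.array([[3, 3, 0, 1],
                        [2, 1, 0, 0],
                        [0, 0, 0, 0]])
pfr = first_recall_probability(toy_recalls, list_length=3)
print(pfr)  # [0.25 0.   0.5 ]
```

The proportions are taken over all simulations, so they sum to less than 1 whenever some trials produce no recall at all.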

In [None]:
# determine the proportion of "first recall" items as a function of serial position
first_recall = recall_sims[0, :]
first_recall = first_recall[first_recall > 0]  # Filter out zeros

# Count frequencies
unique, counts = np.unique(first_recall, return_counts=True)
first_recall_table = pd.DataFrame({
    'position': unique.astype(int),
    'freq': counts,
    'prop': counts / n_sims
})

print(first_recall_table)

## First Recall Probability Visualization

Shows the likelihood of each position being the first item recalled. This analysis reveals:

- **Peak position**: Identifies which items have strongest initial retrieval strength
- **Slope steepness**: Indicates how strongly position affects immediate accessibility
- **Distribution spread**: Shows whether recall initiates from diverse or concentrated positions

In free recall, the last few items typically dominate first recalls due to context matching.

In [None]:
# plot first recall proportion as a function of serial position
plt.figure(figsize=(8, 5))
plt.plot(first_recall_table['position'], first_recall_table['prop'], 'o-', linewidth=2, markersize=8)
plt.xlabel('serial position', fontsize=15)
plt.ylabel('probability of first recall', fontsize=15)
plt.ylim(0, 0.8)
plt.xticks(fontsize=13)
plt.yticks(fontsize=13)
plt.grid(False)
plt.tight_layout()
plt.show()

# Conditional Response Probabilities (CRP)

Analyzes the **temporal contiguity effect**: the tendency to recall items that were studied near each other in time.

## Calculation Logic

### Step 1: Calculate Actual Transitions
For each recall sequence, compute the "lag" between consecutive recalls:
- Lag = (position of item N+1) - (position of item N)
- Example: If positions 5→7 are recalled consecutively, lag = +2
- Example: If positions 8→6 are recalled consecutively, lag = -2

### Step 2: Calculate Possible Transitions
At each recall, determine which items could still be recalled:
- Exclude already-recalled items
- Calculate all possible lags from current position

### Step 3: Compute CRP
For each lag value:
- CRP(lag) = P(transition to lag | lag is available)
- CRP(lag) = (# actual transitions of lag) / (# times lag was available)

## Expected Pattern
- **Peak at lag ±1**: Strongest tendency to recall adjacent items
- **Asymmetry**: Often stronger forward (positive lag) than backward
- **Decline with distance**: Probability decreases for distant items
- **Gap at lag 0**: Can't recall the same item twice

## Interpretation
CRP reveals temporal organization in memory. High values near lag ±1 show that retrieving an item reinstates its encoding context, which cues nearby items. This is the signature of context-based retrieval.
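
The three steps above can be sketched compactly as one function. This is an illustrative implementation, not the notebook's: the name `lag_crp` is hypothetical, the input is assumed to follow the `recall_sims` layout (serial positions in output order, one column per simulation, zeros as padding), and lags that were never available are returned as NaN.

```python
import numpy as np


def lag_crp(recall_sims, list_length):
    """Return (lags, crp) where crp[i] = actual / possible transitions at lags[i]."""
    recall_sims = np.asarray(recall_sims).astype(int)
    offset = list_length - 1                   # array index of lag 0
    lags = np.arange(-offset, list_length)
    actual = np.zeros(lags.size)
    possible = np.zeros(lags.size)
    all_positions = np.arange(1, list_length + 1)

    for seq in recall_sims.T:                  # one recall sequence per column
        seq = seq[seq > 0]                     # drop zero padding
        for k in range(len(seq) - 1):
            current = seq[k]
            # Step 2: positions not yet recalled are the possible targets
            remaining = np.setdiff1d(all_positions, seq[:k + 1])
            possible[remaining - current + offset] += 1
            # Step 1: the transition that actually happened
            actual[seq[k + 1] - current + offset] += 1

    # Step 3: divide, leaving never-available lags (including lag 0) as NaN
    with np.errstate(divide="ignore", invalid="ignore"):
        crp = np.where(possible > 0, actual / possible, np.nan)
    return lags, crp


# One trial of a 3-item list recalled in forward order: 1 -> 2 -> 3
lags, crp = lag_crp(np.array([[1], [2], [3]]), list_length=3)
print(lags)  # [-2 -1  0  1  2]
print(crp)   # [nan nan nan  1.  0.]
```

In the toy trial, lag +1 was available twice and taken twice (CRP = 1), lag +2 was available once and never taken (CRP = 0), and backward lags were never available.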

In [None]:
poss_outcomes = np.arange(-9, 10)
poss_outcomes = np.delete(poss_outcomes, 9)  # Remove the element at index 9 (value 0)

# create matrix of actual transitions
trans_sims = np.zeros((len(pres_indices), n_sims))

for subj in range(n_sims):
    currentSub = recall_sims[:, subj]
    
    for trial in range(9):
        if currentSub[trial + 1] > 0:
            trans_sims[trial, subj] = currentSub[trial + 1] - currentSub[trial]
        else:
            trans_sims[trial, subj] = 0

# create matrix of all possible transitions
possTransFrame = []

for subj in range(n_sims):
    currentSub = recall_sims[:, subj]
    possTrans_sims = np.zeros((10, 9))
    
    for trial in range(9):
        if currentSub[trial + 1] > 0:
            currentTrial = currentSub[trial]
            itemTally = currentSub[0:trial + 1]
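            # recall_sims stores serial positions; pres_indices is a permutation of 1..N,
            # so it serves here as the set of all serial positions
            possPositions = pres_indices[~np.isin(pres_indices, itemTally)]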
            possTransitions = possPositions - currentTrial
            
            for poss in range(len(possTransitions)):
                possTrans_sims[trial, poss] = possTransitions[poss]
    
    possTransFrame.append(possTrans_sims)

possTransFrame = np.vstack(possTransFrame)

## CRP Data Processing

This cell computes the conditional response probability for each lag:

1. **Filter transitions**: Focus on lags within ±5 positions (most informative range)
2. **Calculate probabilities**: Divide actual transitions by possible transitions
3. **Result interpretation**:
   - High CRP values: Strong tendency for that lag
   - Low CRP values: Weak tendency or random transitions
   - Relative heights: Compare forward vs. backward recall patterns

In [None]:
# create CRP
actual_transitions = trans_sims[trans_sims != 0]
possible_transitions = possTransFrame[possTransFrame != 0]

# Count frequencies
tab_a_t_unique, tab_a_t_counts = np.unique(actual_transitions, return_counts=True)
tab_p_t_unique, tab_p_t_counts = np.unique(possible_transitions, return_counts=True)

tab_a_t = pd.DataFrame({
    'actual_transitions': tab_a_t_unique,
    'Freq': tab_a_t_counts
})

tab_p_t = pd.DataFrame({
    'possible_transitions': tab_p_t_unique,
    'Freq': tab_p_t_counts
})

# Filter
tab_a_t = tab_a_t[(tab_a_t['actual_transitions'] < 6) & (tab_a_t['actual_transitions'] > -6)]
tab_p_t = tab_p_t[(tab_p_t['possible_transitions'] < 6) & (tab_p_t['possible_transitions'] > -6)]

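# Calculate CRP, aligning counts by lag so that a lag with no actual
# transitions gets a CRP of 0 instead of misaligning the two arrays
actual_by_lag = tab_a_t.set_index('actual_transitions')['Freq']
possible_by_lag = tab_p_t.set_index('possible_transitions')['Freq']
crp = actual_by_lag.reindex(possible_by_lag.index, fill_value=0) / possible_by_lag

crps = pd.DataFrame({
    'transitions': possible_by_lag.index.values,
    'crp': crp.values
})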

## CRP Visualization

The lag-CRP curve is a key signature of temporal context effects in memory:

**Key Features to Observe**:
- **Central peak**: Should be highest near lag ±1 (adjacent items)
- **Contiguity gradient**: Probability decreases with increasing lag distance
- **Forward asymmetry**: Often stronger for positive lags (forward in time)
- **White gap at lag 0**: Represents the discontinuity (can't recall same item)

**What this reveals**: The peaked shape demonstrates that context serves as an effective retrieval cue - items studied in similar temporal contexts are recalled together. The asymmetry (if present) suggests forward-going associations are stronger than backward associations during encoding.

In [None]:
# plot CRP
plt.figure(figsize=(8, 5))
plt.plot(crps['transitions'], crps['crp'], 'o-', linewidth=2, markersize=8)

# Add white segments to mask the gap at lag 0 (between -1 and 1)
mask_indices = (crps['transitions'] == -1) | (crps['transitions'] == 1)
if mask_indices.sum() == 2:
    idx_neg1 = crps[crps['transitions'] == -1].index[0]
    idx_pos1 = crps[crps['transitions'] == 1].index[0]
    plt.plot([-1, 1], [crps.loc[idx_neg1, 'crp'], crps.loc[idx_pos1, 'crp']], 
             'w-', linewidth=3, zorder=10)

plt.xlabel('lag', fontsize=15)
plt.ylabel('conditional response probability', fontsize=15)
plt.xticks(fontsize=13)
plt.yticks(fontsize=13)
plt.grid(False)
plt.tight_layout()
plt.show()

# Weight Matrices

Visualizes the learned associations between items and temporal context after encoding.

## Two Association Matrices

### Feature-to-Context (M^FC)
- **Rows**: Context elements
- **Columns**: Item positions (serial order)
- **Values**: How strongly each item activates each context element
- **Interpretation**: Shows what context was active when each item was encoded

### Context-to-Feature (M^CF)
- **Rows**: Item positions
- **Columns**: Context elements  
- **Values**: How strongly each context element retrieves each item
- **Interpretation**: Shows which items are cued by each context state

## Expected Patterns

**Diagonal structure**: Items are most strongly associated with their encoding context

**Gradient/smearing**: Due to context drift, nearby items share similar contexts
- More smearing = more context overlap = stronger temporal associations
- Less smearing = more distinct contexts = weaker temporal associations

**First item distinctiveness**: High initial drift (B=1.0) should create unique first-item pattern

## Functional Significance
These matrices implement the episodic memory route. During retrieval, the current context state activates items via M^CF, and retrieved items update context via M^FC, creating a feedback loop that drives sequential recall.
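
To illustrate where the diagonal and the smearing come from, here is a stripped-down, encoding-only sketch. It assumes items are presented in index order (so no reordering is needed), uses a drift of 1.0 for the first item and a constant `beta` afterwards, and takes its learning rates from the parameter values used in this notebook. The function name `encode_list` is hypothetical.

```python
import numpy as np


def encode_list(n_items, beta=0.65, gamma_fc=0.581):
    """Encode items 0..n_items-1 in order; return (w_fc, w_cf)."""
    w_fc = np.eye(n_items) * (1 - gamma_fc)    # pre-existing feature->context
    w_cf = np.zeros((n_items, n_items))        # context->feature starts empty
    c = np.zeros(n_items)
    for item in range(n_items):
        f = np.zeros(n_items)
        f[item] = 1.0
        c_in = w_fc @ f
        c_in = c_in / np.linalg.norm(c_in)
        b = 1.0 if item == 0 else beta         # full drift on the first item
        dot = c @ c_in
        rho = np.sqrt(1 + b**2 * (dot**2 - 1)) - b * dot
        c = rho * c + b * c_in
        w_fc += gamma_fc * np.outer(c, f)      # rows: context, columns: items
        w_cf += np.outer(f, c)                 # rows: items, columns: context
    return w_fc, w_cf


w_fc_demo, w_cf_demo = encode_list(4)
print(np.round(w_cf_demo, 3))
```

Each row of `w_cf_demo` is the context that was active when that item was studied: the diagonal entry reflects the item's own contribution, and the entries to its left are the fading traces of earlier items. Entries to the right of the diagonal stay at zero because later items had not yet entered context.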

In [None]:
# re-organize weight matrices for plotting (so as to view them by serial position)
net_w_fc_inorder = net_w_fc[:, pres_indices - 1]  # Convert to 0-indexed
net_w_cf_inorder = net_w_cf[pres_indices - 1, :]  # Convert to 0-indexed

## Weight Matrix Visualization

These heatmaps reveal the structure of learned episodic associations:

**Left plot (Feature→Context)**:
- Warm colors (orange/red): Strong activation of context by item
- Pattern reveals temporal context evolution during encoding
- Examine: Is there a diagonal? How much do patterns overlap?

**Right plot (Context→Feature)**:
- Warm colors (red): Strong retrieval cue strength
- Pattern reveals which items are accessible from each context
- Examine: Can you see temporal gradients? First/last item distinctiveness?

Together, these matrices encode the episodic memory trace that drives context-based retrieval.

In [None]:
# set up weight matrices plots
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Feature to context weight matrix
im1 = axes[0].imshow(net_w_fc_inorder, cmap='YlOrRd', aspect='auto')
axes[0].set_title('feature to context weight matrix', fontsize=16)
axes[0].set_xlabel('serial position', fontsize=14)
axes[0].set_ylabel('', fontsize=14)
axes[0].set_xticks(np.arange(len(pres_indices)))
axes[0].set_xticklabels(np.arange(1, len(pres_indices) + 1))
plt.colorbar(im1, ax=axes[0])

# Context to feature weight matrix
im2 = axes[1].imshow(net_w_cf_inorder.T, cmap='Reds', aspect='auto')
axes[1].set_title('context to feature weight matrix', fontsize=16)
axes[1].set_xlabel('serial position', fontsize=14)
axes[1].set_ylabel('', fontsize=14)
axes[1].set_xticks(np.arange(len(pres_indices)))
axes[1].set_xticklabels(np.arange(1, len(pres_indices) + 1))
plt.colorbar(im2, ax=axes[1])

plt.tight_layout()
plt.show()

# Correlation Matrices

Analyzes the similarity structure of learned representations by computing correlations between the item vectors of the weight matrices (columns of M^FC, rows of M^CF, both ordered by serial position).

## What Correlations Reveal

### Feature-to-Context Correlations
- Correlates context patterns activated by different items
- **High correlation**: Two items activated similar contexts (encoded nearby in time)
- **Low correlation**: Two items activated distinct contexts (encoded far apart)
- **Pattern**: Should show temporal gradient - higher correlation for items closer in study order

### Context-to-Feature Correlations  
- Correlates the retrieval strength patterns across context states
- **High correlation**: Two items are similarly cued by the same contexts
- **Low correlation**: Two items are cued by different contexts
- **Pattern**: Reflects temporal organization of memory accessibility

## Expected Structure

**Banded diagonal**: Strongest correlations near the diagonal (adjacent items)

**Gradient decay**: Correlation decreases with increasing temporal distance

**Primacy/recency islands**: Possible distinct correlation structures at list boundaries

## Functional Interpretation
These correlation patterns quantify the temporal organization of episodic memory. The gradient structure directly predicts the temporal contiguity effect in recall - items with high context correlation are likely to be recalled together.
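
One way to summarize the gradient structure is to average a correlation matrix along its off-diagonals, which gives the mean correlation as a function of the distance between serial positions. The sketch below is illustrative only: the function name and the toy correlation matrix are assumptions, not output from the model.

```python
import numpy as np


def mean_correlation_by_lag(corr):
    """Average correlation between positions separated by lag 1, 2, ..., n-1."""
    corr = np.asarray(corr, dtype=float)
    n = corr.shape[0]
    lags = np.arange(1, n)
    means = np.array([np.diagonal(corr, offset=k).mean() for k in lags])
    return lags, means


# Toy 3-item correlation matrix (made-up values)
toy_corr = np.array([[1.0, 0.6, 0.2],
                     [0.6, 1.0, 0.5],
                     [0.2, 0.5, 1.0]])
lags, means = mean_correlation_by_lag(toy_corr)
print(lags, means)  # [1 2] [0.55 0.2 ]
```

A profile that falls as lag increases is the "gradient decay" pattern described above.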

In [None]:
# set up correlation plots
corr_fc = np.corrcoef(net_w_fc_inorder.T)
corr_cf = np.corrcoef(net_w_cf_inorder)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Feature to context correlation matrix
im1 = axes[0].imshow(corr_fc, cmap='YlOrRd', aspect='auto', vmin=-1, vmax=1)
axes[0].set_title('feature to context correlation matrix', fontsize=16)
axes[0].set_xlabel('serial position', fontsize=14)
axes[0].set_ylabel('', fontsize=14)
axes[0].set_xticks(np.arange(len(pres_indices)))
axes[0].set_xticklabels(np.arange(1, len(pres_indices) + 1))
axes[0].set_yticks(np.arange(len(pres_indices)))
axes[0].set_yticklabels(np.arange(1, len(pres_indices) + 1))
plt.colorbar(im1, ax=axes[0])

# Context to feature correlation matrix
im2 = axes[1].imshow(corr_cf, cmap='YlOrRd', aspect='auto', vmin=-1, vmax=1)
axes[1].set_title('context to feature correlation matrix', fontsize=16)
axes[1].set_xlabel('serial position', fontsize=14)
axes[1].set_ylabel('', fontsize=14)
axes[1].set_xticks(np.arange(len(pres_indices)))
axes[1].set_xticklabels(np.arange(1, len(pres_indices) + 1))
axes[1].set_yticks(np.arange(len(pres_indices)))
axes[1].set_yticklabels(np.arange(1, len(pres_indices) + 1))
plt.colorbar(im2, ax=axes[1])

plt.tight_layout()
plt.show()

## Interpretation Guide for Correlation Matrices

**Reading the plots**:
- **Diagonal (perfect correlation)**: Each item correlates perfectly with itself
- **Near-diagonal bands**: Show temporal neighborhood structure
- **Off-diagonal patterns**: Reveal long-range associations
- **Color intensity**: Strength of representational similarity

**What to look for**:
1. **Width of diagonal band**: Indicates temporal resolution of context
   - Narrow band = sharp temporal discrimination
   - Wide band = gradual context change

2. **Asymmetries**: Might reveal forward/backward association differences

3. **Clustering**: Groups of items with similar correlation profiles

These patterns link the model's internal representations to behavioral recall patterns, showing how temporal context similarity translates to recall transitions.