In [None]:
import os
import json
import numpy as np
import jax.numpy as jnp

# Factory A
from jaxcmr.cmr import MixedCMRFactory as MixedCMRFactoryA

# Factory B
from jaxcmr.instance_cmr import MixedCMRFactory as MixedCMRFactoryB

from jaxcmr.likelihood import MemorySearchLikelihoodFnGenerator
from jaxcmr.helpers import load_data, find_project_root
from jax import numpy as jnp, lax, vmap
from jaxcmr.math import lb

## Load the Dataset

In [None]:
# Resolve the HDF5 dataset against the project root, load it, and summarize it.
data_path = os.path.join(find_project_root(), "data/HealeyKahana2014.h5")
data = load_data(data_path)

dataset_keys = list(data.keys())
n_trials = data["recalls"].shape[0]
print("Dataset keys:", dataset_keys)
print("Number of trials:", n_trials)

Dataset keys: ['listLength', 'listtype', 'pres_itemids', 'pres_itemnos', 'rec_itemids', 'recalls', 'session', 'subject']
Number of trials: 14112


No semantic or associative connections are used here, so we supply an all-zero connection matrix sized by the largest presented item number.

In [None]:
# All-zero (no semantic/associative links) square matrix, one side per item number
# up to the largest presented one. NOTE(review): assumes item numbers start at 1.
max_itemno = np.max(data["pres_itemnos"])
connections = jnp.zeros((max_itemno,) * 2)

## Load the First Subject's Fitted Parameters

In [None]:
# Load the saved InstanceCMR fit results (JSON) from the project's fits/ directory.
fit_results_path = os.path.join(
    find_project_root(), "fits", "HealeyKahana2014_InstanceCMR_best_of_1.json"
)
with open(fit_results_path) as fit_file:
    fit_results = json.load(fit_file)

fit_keys = list(fit_results.keys())
print("Fit results keys:", fit_keys)
print("Parameter names in fit:", fit_results["fits"].keys())

Fit results keys: ['fixed', 'free', 'fitness', 'fits', 'hyperparameters', 'fit_time', 'data_query', 'model', 'name']
Parameter names in fit: dict_keys(['encoding_drift_rate', 'start_drift_rate', 'recall_drift_rate', 'shared_support', 'item_support', 'learning_rate', 'primacy_scale', 'primacy_decay', 'stop_probability_scale', 'stop_probability_growth', 'mcf_trace_sensitivity', 'choice_sensitivity', 'subject'])


In [None]:
# Use the subject of the very first trial and gather every trial they completed.
subject_id = data["subject"][0].item()
print(f"Using subject {subject_id}")

subject_ids = data["subject"].flatten()
trial_mask = subject_ids == subject_id
print(f"Number of trials for subject {subject_id}: {trial_mask.sum()}")

# Display the recall sequence of this subject's first trial.
subject_recalls = data["recalls"][trial_mask]
trial = subject_recalls[0]
trial

Using subject 63
Number of trials for subject 63: 112


Array([15, 16,  9, 13, 14,  1,  2,  0,  0,  0,  0,  0,  0,  0,  0,  0],      dtype=int32)

In [None]:
fit_dict = fit_results["fits"]
param_names = [name for name in fit_dict if name != "subject"]

# Each fitted parameter is presumably stored as one value per fitted subject,
# aligned with fit_dict["subject"]. Look up the entry belonging to `subject_id`
# instead of assuming it is the first one.
fitted_subjects = np.asarray(fit_dict["subject"]).ravel()
subject_matches = np.flatnonzero(fitted_subjects == subject_id)
if subject_matches.size == 0:
    raise ValueError(f"Subject {subject_id} not found in fit results")
subject_index = int(subject_matches[0])

# Build a simple dictionary of param_name -> single float value for the chosen subject
params_for_subject = {
    name: float(np.asarray(fit_dict[name], dtype=float)[subject_index])
    for name in param_names
}

print("Subject parameter dictionary:")
params_for_subject

Subject parameter dictionary:


{'encoding_drift_rate': 0.19690460839075208,
 'start_drift_rate': 0.13471814445395458,
 'recall_drift_rate': 0.9184788152966787,
 'shared_support': 75.02002448558682,
 'item_support': 98.32965873913969,
 'learning_rate': 0.6151140503551014,
 'primacy_scale': 53.6643963763364,
 'primacy_decay': 0.8904633246586897,
 'stop_probability_scale': 0.004061401990797298,
 'stop_probability_growth': 0.35607019330468326,
 'mcf_trace_sensitivity': 12.16929668644169,
 'choice_sensitivity': 1.0}

## Initialize Models

In [None]:
# Study list (item numbers) for the chosen subject's first trial.
present = data["pres_itemnos"][trial_mask][0]


def study_then_retrieve(model):
    """Feed every item of ``present`` to ``model.experience`` in order, then start retrieval."""
    studied = lax.fori_loop(
        0, present.size, lambda step, m: m.experience(present[step]), model
    )
    return studied.start_retrieving()


# NOTE(review): the first create_model argument is 0 — presumably a trial index;
# confirm it refers to the same trial as `present`.
initial_models = [
    factory(data, connections).create_model(0, params_for_subject)
    for factory in (MixedCMRFactoryA, MixedCMRFactoryB)
]
models = [study_then_retrieve(model) for model in initial_models]

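The expected (correct) outcome probabilities at the start of retrieval are shown below. The first entry is the stop probability, which matches `stop_probability_scale` above. The remaining 16 entries correspond to the studied items.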

```
[0.0040614  0.24123573 0.12437733 0.06647404 0.03827588 0.02519201
 0.0199643  0.01901533 0.02062784 0.02403779 0.0289853  0.0354953
 0.04377616 0.05417768 0.0671835  0.08342308 0.10369731]
 ```

In [None]:
# One row per model (A first, then B) for comparison against the expected values.
print("Outcome Probabilities")
for model in models:
    print(model.outcome_probabilities())


Outcome Probabilities
[4.0614018e-03 9.9590391e-01 3.2995831e-05 1.2037624e-07 1.1872118e-07
 1.1872098e-07 1.1872098e-07 1.1872098e-07 1.1872098e-07 1.1872098e-07
 1.1872098e-07 1.1872098e-07 1.1872098e-07 1.1872098e-07 1.1872098e-07
 1.1872098e-07 1.1872098e-07]
[4.0614018e-03 9.9590218e-01 3.3114491e-05 2.3909678e-07 2.3744172e-07
 2.3744153e-07 2.3744153e-07 2.3744153e-07 2.3744153e-07 2.3744153e-07
 2.3744153e-07 2.3744153e-07 2.3744153e-07 2.3744153e-07 2.3744153e-07
 2.3744153e-07 2.3744153e-07]


In [None]:
# Inspect row 2 of model A's MCF memory state.
mcf_state = models[0].mcf.state
mcf_state[2]

Array([ 0.      ,  0.      ,  0.      ,  1.      ,  0.      ,  0.      ,
        0.      ,  0.      ,  0.      ,  0.      ,  0.      ,  0.      ,
        0.      ,  0.      ,  0.      ,  0.      ,  0.      , 75.02003 ,
       75.02003 , 98.329666, 75.02003 , 75.02003 , 75.02003 , 75.02003 ,
       75.02003 , 75.02003 , 75.02003 , 75.02003 , 75.02003 , 75.02003 ,
       75.02003 , 75.02003 , 75.02003 ], dtype=float32)

In [None]:

# Compare trace-level and item-level activations of both models for the
# current (post-study) retrieval context.


def _context_trace_activations(model):
    """Return MCF trace activations cued by ``model``'s current context, plus ``lb``.

    The context state is written into the leading entries of the MCF probe vector.
    """
    context_state = model.context.state
    probe = model.mcf._probe.at[: context_state.size].set(context_state)
    return model.mcf.trace_activations(probe) + lb


print("Trace Activations")
for model in models:
    print(_context_trace_activations(model))

print("Item Activations")
for model in models:
    activations = model.activations()
    try:
        trace_activations = _context_trace_activations(model)
        # Per-item activation vector, one entry per element of `unique_items`.
        item_activations = vmap(
            lambda item_index: model.item_activation(item_index, trace_activations)
        )(model.unique_items)
        # Manual cross-check: sum the activations of every trace storing each item.
        # Uses its own name so the vector above is not overwritten by a scalar.
        for item_index in model.unique_items:
            is_item_trace = model.trace_items == item_index
            summed_activation = jnp.sum(trace_activations * is_item_trace)
            print(f"Item {item_index} activation: {summed_activation}")
            print(f"Total Relevant traces: {jnp.sum(is_item_trace)}")
    except Exception as err:
        # Presumably raised for models lacking unique_items / trace_items /
        # item_activation; report why and fall back to model.activations().
        print(f"Falling back to model.activations(): {err!r}")
        item_activations = activations

    print(item_activations)
    item_activation_sum = jnp.sum(item_activations)
    print(item_activation_sum)
    # Outcome probabilities would then be
    # [p_stop, (1 - p_stop) * item_activations / item_activation_sum],
    # with a zero sum replaced by 1.0 to avoid division by zero.


Trace Activations
[1.1920929e-07 1.1920929e-07 1.1920929e-07 1.1920929e-07 1.1920929e-07
 1.1920929e-07 1.1920929e-07 1.1920929e-07 1.1920929e-07 1.1920929e-07
 1.1920929e-07 1.1920929e-07 1.1920929e-07 1.1920929e-07 1.1920929e-07
 1.1920929e-07 1.0000001e+00 3.3131546e-05 1.2087136e-07 1.1920949e-07
 1.1920929e-07 1.1920929e-07 1.1920929e-07 1.1920929e-07 1.1920929e-07
 1.1920929e-07 1.1920929e-07 1.1920929e-07 1.1920929e-07 1.1920929e-07
 1.1920929e-07 1.1920929e-07]
[1.1920929e-07 1.1920929e-07 1.1920929e-07 1.1920929e-07 1.1920929e-07
 1.1920929e-07 1.1920929e-07 1.1920929e-07 1.1920929e-07 1.1920929e-07
 1.1920929e-07 1.1920929e-07 1.1920929e-07 1.1920929e-07 1.1920929e-07
 1.1920929e-07 1.0000001e+00 3.3131546e-05 1.2087136e-07 1.1920949e-07
 1.1920929e-07 1.1920929e-07 1.1920929e-07 1.1920929e-07 1.1920929e-07
 1.1920929e-07 1.1920929e-07 1.1920929e-07 1.1920929e-07 1.1920929e-07
 1.1920929e-07 1.1920929e-07]
Item Activations
[1.0000001e+00 3.3131546e-05 1.2087136e-07 1.1920949e