In [1]:
#!%load_ext autoreload
#!%autoreload 2

In [2]:
# | parametrization

import json
from tqdm import trange
import jax.numpy as jnp
import numpy as np
from cmrt.fitting import make_subject_trial_masks
from cmrt.likelihood import MemorySearchLikelihoodFnGenerator as LikelihoodFnGenerator
from cmrt.cmr import CMRFactory as model_factory
from cmrt.helpers import load_data, generate_trial_mask
from jax import lax

fit_path = "fits/HerrKaha24_BaseCMR_test.json"

data_name = "HerrKaha24.h5"
data_query = "data['subject'] > -1"
data_path = "data/HerrKaha24.h5"

use_progress_bar = True

In [3]:
# | load data

with open(fit_path) as f:
    fit_result = json.load(f)
    if "subject" not in fit_result["fits"]:
        fit_result["fits"]["subject"] = fit_result["subject"]


data = load_data(data_path)
trial_mask = generate_trial_mask(data, data_query)
max_size = np.max(data["pres_itemnos"])
connections = jnp.zeros((max_size, max_size))

subjects = data["subject"].flatten()

likelihood_generator = LikelihoodFnGenerator(model_factory, data, connections)

subject_trial_masks, unique_subjects = make_subject_trial_masks(trial_mask, subjects)

In [36]:
subject_range = (
            trange(len(unique_subjects)) if use_progress_bar else range(len(unique_subjects))
        )

for s in subject_range:

    if np.sum(subject_trial_masks[s]) == 0:
        continue

    trial_indices = jnp.where(subject_trial_masks[s])[0]
    parameters = {key: fit_result['fits'][key][s] for key in fit_result['fits'] if key != "subject"}        

    for trial_index in trial_indices:

        model = likelihood_generator.init_model_for_retrieval(trial_index, parameters)
        trial = likelihood_generator.trials[trial_index]
        trial = jnp.concat((trial[jnp.nonzero(trial)], jnp.array([0])))
        reaction_times = data['irt'][trial_index][jnp.nonzero(trial)]

        recall_positions = jnp.arange(1, len(trial) + 1)[jnp.nonzero(trial)]

        # whether the current recall attempt continues or terminates retrieval
        coding = trial == 0

        # pre-scaled memory activations for each retrieval attempt
        model, prescaled_activations = lax.scan(
            lambda m, c: (m.retrieve(c), jnp.dot(m.context.state, m.mcf.state) * m.recallable), model, trial
        )

        # converted to an average activation over recallable items at each retrieval attempt
        average_activation = jnp.sum(prescaled_activations, axis=1) / jnp.sum(
            prescaled_activations != 0, axis=1)
        
        assert(False)

  0%|          | 0/456 [00:00<?, ?it/s]


AssertionError: 

In [38]:
recall_positions, average_activation, coding, reaction_times

(Array([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10], dtype=int32),
 Array([197.8275 , 188.14159, 185.1032 , 185.62675, 186.85577, 187.96103,
        188.77097, 189.12077, 187.81125, 185.07805, 180.34671],      dtype=float32),
 Array([False, False, False, False, False, False, False, False, False,
        False,  True], dtype=bool),
 Array([ 6549,  6199,  3782,  3500,  2539,  2840,  3093, 11515,  3649,
         5263], dtype=int32))

In [None]:
jnp.max(lax.scan(
        lambda m, c: (m.retrieve(c), m.activations()), model, trial
    )[1], axis=1)

Array([1.0000001 , 1.0000001 , 1.0000001 , 1.0000001 , 1.0000001 ,
       1.0000001 , 1.0000001 , 1.0000001 , 1.0000001 , 1.0000001 ,
       0.00687336, 0.00687336, 0.00687336, 0.00687336, 0.00687336,
       0.00687336, 0.00687336, 0.00687336, 0.00687336, 0.00687336,
       0.00687336, 0.00687336, 0.00687336, 0.00687336, 0.00687336,
       0.00687336, 0.00687336, 0.00687336], dtype=float32)

In [None]:
likelihood_generator.present_and_predict_trials(trial_indices, parameters).shape

(24, 28)

In [None]:
{key: jnp.array(val) for key, val in fit_result["fits"].items()}

{'encoding_drift_rate': Array([0.44938543, 0.47920167, 0.84217674, 0.5360002 , 0.26933333,
        0.59014446, 0.54241186, 0.45352584, 0.6070481 , 0.6147673 ,
        0.71054256, 0.5774032 , 0.4422168 , 0.6476088 , 0.68213236,
        0.7848329 , 0.60253024, 0.6824302 , 0.8834518 , 0.62867755,
        0.7179525 , 0.51034236, 0.80288416, 0.4265573 , 0.6202975 ,
        0.51358074, 0.4120547 , 0.65794057, 0.6196855 , 0.3717013 ,
        0.36296645, 0.7439885 , 0.50260615, 0.71665114, 0.8376869 ,
        0.79982376, 0.5113187 , 0.2737898 , 0.40850675, 0.3551169 ,
        0.5393236 , 0.6131607 , 0.25954995, 0.21783216, 0.6446211 ,
        0.59626055, 0.5939617 , 0.47423947, 0.30032685, 0.29033118,
        0.40712202, 0.6242779 , 0.5156292 , 0.33245167, 0.45459023,
        0.09023021, 0.5715727 , 0.26975113, 0.46545142, 0.29335165,
        0.30920154, 0.4608867 , 0.54023993, 0.6755189 , 0.67787045,
        0.31482914, 0.34914058, 0.7859265 , 0.5226945 , 0.2876349 ,
        0.6752981 , 0.778