In [1]:
%load_ext autoreload
%autoreload 2

import json
import os
import warnings

import jax.numpy as jnp
import matplotlib.pyplot as plt
from IPython.display import Image, display
import numpy as np
from jax import random
from matplotlib import rcParams  # type: ignore

from jaxcmr.fitting import ScipyDE as fitting_method
from jaxcmr.helpers import (
    generate_trial_mask,
    import_from_string,
    load_data,
    save_dict_to_hdf5,
)
from jaxcmr.likelihood import MemorySearchLikelihoodFnGenerator as loss_fn_generator
from jaxcmr import repetition
from jaxcmr.simulation import simulate_h5_from_h5
from jaxcmr.summarize import summarize_parameters

# Silence every warning category for the rest of the kernel session so the
# long fitting/simulation outputs below stay readable.
# NOTE(review): this also hides deprecation and numerical warnings — confirm
# that is intended before relying on a clean output.
warnings.filterwarnings("ignore")

## Setup

In [2]:
# Default configuration for this notebook. Every name below may be rebound by a
# later parameters cell, so this cell only holds plain literal assignments.

# repeat params
allow_repeated_recalls = False
filter_repeated_recalls = False
data_tag = "LohnasKahana2014"
data_path = "data/LohnasKahana2014.h5"

# data params
trial_query = "data['list_type'] > 0"  # expression string used to build the trial mask
run_tag = "full_best_of_3"

# fitting params
redo_fits = True
model_factory_path = "jaxcmr.models_repfr.weird_cmr.BaseCMRFactory"
model_name = "WeirdCMR"
relative_tolerance = 0.001
popsize = 15
num_steps = 1000
cross_rate = 0.9
diff_w = 0.85
best_of = 3
target_dir = "projects/thesis"

# sim params
redo_sims = True
seed = 0
experiment_count = 50

# figure params
redo_figures = True

# Model parameters: "fixed" values are passed to the model as-is, "free" entries
# are [lower, upper] search bounds handed to the fitting routine.
#
# The lower bound is float64 machine epsilon. For the unit-interval parameters
# the upper bound 0.9999999999999998 is a distinct double just below 1.0. The
# wider ranges were previously spelled 99.9999999999999998 / 9.9999999999999998,
# but those literals are not representable and evaluate to exactly 100.0 / 10.0,
# so they are written as such here to make the effective bounds explicit.
parameters = {
    "fixed": {},
    "free": {
        "encoding_drift_rate": [2.220446049250313e-16, 0.9999999999999998],
        "start_drift_rate": [2.220446049250313e-16, 0.9999999999999998],
        "recall_drift_rate": [2.220446049250313e-16, 0.9999999999999998],
        "shared_support": [2.220446049250313e-16, 100.0],
        "item_support": [2.220446049250313e-16, 100.0],
        "learning_rate": [2.220446049250313e-16, 0.9999999999999998],
        "primacy_scale": [2.220446049250313e-16, 100.0],
        "primacy_decay": [2.220446049250313e-16, 100.0],
        "stop_probability_scale": [2.220446049250313e-16, 0.9999999999999998],
        "stop_probability_growth": [2.220446049250313e-16, 10.0],
        "choice_sensitivity": [2.220446049250313e-16, 100.0],
    },
}

# Dotted import paths of plotting functions applied to the simulated dataset alone.
single_analysis_paths = [
    "jaxcmr.analyses.repcrp.plot_rep_crp",
    "jaxcmr.analyses.backrepcrp.plot_back_rep_crp",
]

# Dotted import paths of plotting functions applied to simulation and data together.
comparison_analysis_paths = [
    "jaxcmr.analyses.spc.plot_spc",
    "jaxcmr.analyses.crp.plot_crp",
    "jaxcmr.analyses.pnr.plot_pnr",
    "jaxcmr.analyses.repneighborcrp.plot_repneighborcrp_i2j",
    "jaxcmr.analyses.repneighborcrp.plot_repneighborcrp_j2i",
    "jaxcmr.analyses.repneighborcrp.plot_repneighborcrp_both",
    "jaxcmr.analyses.rpl.plot_rpl",
    "jaxcmr.analyses.rpl.plot_full_rpl",
]

In [3]:
# Parameters
# Run-specific values; each assignment rebinds a name already set by the
# default configuration cell, so the later value is the one used below.

# dataset selection
data_tag = "RepeatedRecallsKahanaJacobs2000"
base_data_tag = "KahanaJacobs2000"
data_path = "data/RepeatedRecallsKahanaJacobs2000.h5"
trial_query = "jnp.logical_and(data['recall_attempt'] == 1, data['recall_total'] > 0)"
allow_repeated_recalls = True
filter_repeated_recalls = False

# model selection
model_name = "WeirdPositionalCMR"
model_factory_path = "jaxcmr.models_repfr.weird_positional_cmr.BaseCMRFactory"

# no figures are produced for this run
single_analysis_paths = []
comparison_analysis_paths = []

# reuse cached products when they exist
redo_fits = False
redo_sims = False
redo_figures = False

# fixed model values and [lower, upper] bounds of the free parameters
parameters = {
    "fixed": {
        "mfc_choice_sensitivity": 1.0,
    },
    "free": {
        "encoding_drift_rate": [2.220446049250313e-16, 0.9999999999999998],
        "start_drift_rate": [2.220446049250313e-16, 0.9999999999999998],
        "recall_drift_rate": [2.220446049250313e-16, 0.9999999999999998],
        "shared_support": [2.220446049250313e-16, 100.0],
        "item_support": [2.220446049250313e-16, 100.0],
        "learning_rate": [2.220446049250313e-16, 0.9999999999999998],
        "primacy_scale": [2.220446049250313e-16, 100.0],
        "primacy_decay": [2.220446049250313e-16, 100.0],
        "stop_probability_scale": [2.220446049250313e-16, 0.9999999999999998],
        "stop_probability_growth": [2.220446049250313e-16, 10.0],
        "choice_sensitivity": [2.220446049250313e-16, 100.0],
    },
}


In [4]:
# Sanity check first, before any directory is created or data is loaded:
# "repeatedrecalls" must appear in both data_tag and data_path, or in neither.
# (The previous nested `or` / `and not` test could never be true, so a
# mismatch between the two was silently accepted.)
tag_mentions_repeats = "repeatedrecalls" in data_tag.lower()
path_mentions_repeats = "repeatedrecalls" in data_path.lower()
if tag_mentions_repeats != path_mentions_repeats:
    raise ValueError(
        "If 'repeatedrecalls' is in data_tag or data_path, it must be in both."
    )

# add subdirectories for each product type: json, figures, h5
product_dirs = {}
for product in ["fits", "figures", "simulations"]:
    product_dir = os.path.join(target_dir, product)
    product_dirs[product] = product_dir
    # exist_ok makes the call idempotent and avoids the exists()/makedirs() race
    os.makedirs(product_dir, exist_ok=True)

data = load_data(data_path)
trial_mask = generate_trial_mask(data, trial_query)

# Square all-zero matrix sized by the largest value found in `pres_itemnos`;
# it is passed to the fitter and the simulator as `connections`.
max_size = np.max(data["pres_itemnos"])
connections = jnp.zeros((max_size, max_size))

# Resolve the dotted-path strings from the configuration into callables.
single_analyses = [import_from_string(path) for path in single_analysis_paths]
comparison_analyses = [import_from_string(path) for path in comparison_analysis_paths]
model_factory = import_from_string(model_factory_path)

# derive list of query parameters from keys of `parameters`
# (the repeat flag is forwarded to the model as a fixed parameter)
parameters['fixed']['allow_repeated_recalls'] = allow_repeated_recalls
query_parameters = list(parameters["free"].keys())

## Fitting

In [5]:
# Fit results are cached as JSON, keyed by dataset, model and run tags.
fit_path = os.path.join(product_dirs["fits"], f"{data_tag}_{model_name}_{run_tag}.json")
print(fit_path)

if os.path.exists(fit_path) and not redo_fits:
    # Reuse the cached fit.
    with open(fit_path) as f:
        results = json.load(f)
    # Some result files keep "subject" at the top level only; mirror it into
    # "fits" so downstream code can read it from there.
    if "subject" not in results["fits"]:
        results["fits"]["subject"] = results["subject"]

else:
    base_params = parameters["fixed"]
    bounds = parameters["free"]
    fitter = fitting_method(
        data,
        connections,
        base_params,
        model_factory,
        loss_fn_generator,
        hyperparams={
            "num_steps": num_steps,
            "pop_size": popsize,
            "relative_tolerance": relative_tolerance,
            "cross_over_rate": cross_rate,
            "diff_w": diff_w,
            "progress_bar": True,
            "display_iterations": False,
            "bounds": bounds,
            "best_of": best_of,
        },
    )

    results = dict(fitter.fit(trial_mask))

# Attach run metadata, then persist once. (Previously a fresh fit was written
# twice in a row, the first time without this metadata.)
results["data_query"] = trial_query
results["model"] = model_name
results["name"] = f"{data_tag}_{model_name}_{run_tag}"

# Serialize before opening the file: opening with "w" truncates it, so a
# serialization error must not be able to wipe out an existing cached fit.
serialized_results = json.dumps(results, indent=4)
with open(fit_path, "w") as f:
    f.write(serialized_results)

print(
    summarize_parameters([results], query_parameters, include_std=True, include_ci=True)
)

projects/thesis/fits/RepeatedRecallsKahanaJacobs2000_WeirdPositionalCMR_full_best_of_3.json


  0%|          | 0/19 [00:00<?, ?it/s]

Subject=200, Fitness=3681.031494140625:   0%|          | 0/19 [04:49<?, ?it/s]

Subject=200, Fitness=3681.031494140625:   5%|▌         | 1/19 [04:49<1:26:57, 289.87s/it]

Subject=201, Fitness=3589.544677734375:   5%|▌         | 1/19 [10:39<1:26:57, 289.87s/it]

Subject=201, Fitness=3589.544677734375:  11%|█         | 2/19 [10:39<1:32:01, 324.80s/it]

Subject=205, Fitness=3370.51416015625:  11%|█         | 2/19 [15:39<1:32:01, 324.80s/it] 

Subject=205, Fitness=3370.51416015625:  16%|█▌        | 3/19 [15:39<1:23:37, 313.61s/it]

Subject=206, Fitness=4166.34912109375:  16%|█▌        | 3/19 [20:39<1:23:37, 313.61s/it]

Subject=206, Fitness=4166.34912109375:  21%|██        | 4/19 [20:39<1:17:06, 308.45s/it]

Subject=210, Fitness=3277.752685546875:  21%|██        | 4/19 [26:27<1:17:06, 308.45s/it]

Subject=210, Fitness=3277.752685546875:  26%|██▋       | 5/19 [26:27<1:15:15, 322.54s/it]

Subject=215, Fitness=5601.51708984375:  26%|██▋       | 5/19 [32:47<1:15:15, 322.54s/it] 

Subject=215, Fitness=5601.51708984375:  32%|███▏      | 6/19 [32:47<1:14:05, 341.95s/it]

Subject=220, Fitness=3539.63720703125:  32%|███▏      | 6/19 [37:26<1:14:05, 341.95s/it]

Subject=220, Fitness=3539.63720703125:  37%|███▋      | 7/19 [37:26<1:04:19, 321.59s/it]

Subject=230, Fitness=4454.98876953125:  37%|███▋      | 7/19 [42:36<1:04:19, 321.59s/it]

Subject=230, Fitness=4454.98876953125:  42%|████▏     | 8/19 [42:36<58:15, 317.81s/it]  

Subject=240, Fitness=5342.37158203125:  42%|████▏     | 8/19 [48:30<58:15, 317.81s/it]

Subject=240, Fitness=5342.37158203125:  47%|████▋     | 9/19 [48:30<54:50, 329.04s/it]

Subject=256, Fitness=2964.753173828125:  47%|████▋     | 9/19 [53:15<54:50, 329.04s/it]

Subject=256, Fitness=2964.753173828125:  53%|█████▎    | 10/19 [53:15<47:20, 315.66s/it]

Subject=299, Fitness=1787.01708984375:  53%|█████▎    | 10/19 [56:52<47:20, 315.66s/it] 

Subject=299, Fitness=1787.01708984375:  58%|█████▊    | 11/19 [56:52<38:02, 285.35s/it]

Subject=300, Fitness=4190.390625:  58%|█████▊    | 11/19 [1:01:29<38:02, 285.35s/it]   

Subject=300, Fitness=4190.390625:  63%|██████▎   | 12/19 [1:01:29<33:00, 282.90s/it]

Subject=301, Fitness=5163.38232421875:  63%|██████▎   | 12/19 [1:09:03<33:00, 282.90s/it]

Subject=301, Fitness=5163.38232421875:  68%|██████▊   | 13/19 [1:09:03<33:27, 334.63s/it]

Subject=303, Fitness=3334.4501953125:  68%|██████▊   | 13/19 [1:15:14<33:27, 334.63s/it] 

Subject=303, Fitness=3334.4501953125:  74%|███████▎  | 14/19 [1:15:14<28:48, 345.69s/it]

Subject=305, Fitness=4706.203125:  74%|███████▎  | 14/19 [1:19:58<28:48, 345.69s/it]    

Subject=305, Fitness=4706.203125:  79%|███████▉  | 15/19 [1:19:58<21:47, 326.97s/it]

Subject=306, Fitness=4406.20703125:  79%|███████▉  | 15/19 [1:24:10<21:47, 326.97s/it]

Subject=306, Fitness=4406.20703125:  84%|████████▍ | 16/19 [1:24:10<15:13, 304.34s/it]

Subject=307, Fitness=2882.212646484375:  84%|████████▍ | 16/19 [1:28:43<15:13, 304.34s/it]

Subject=307, Fitness=2882.212646484375:  89%|████████▉ | 17/19 [1:28:43<09:50, 295.15s/it]

Subject=308, Fitness=3359.059814453125:  89%|████████▉ | 17/19 [1:33:16<09:50, 295.15s/it]

Subject=308, Fitness=3359.059814453125:  95%|█████████▍| 18/19 [1:33:16<04:48, 288.26s/it]

Subject=666, Fitness=1617.85888671875:  95%|█████████▍| 18/19 [1:35:30<04:48, 288.26s/it] 

Subject=666, Fitness=1617.85888671875: 100%|██████████| 19/19 [1:35:30<00:00, 242.07s/it]

Subject=666, Fitness=1617.85888671875: 100%|██████████| 19/19 [1:35:30<00:00, 301.61s/it]

| | | RepeatedRecallsKahanaJacobs2000 WeirdPositionalCMR full best of 3 |
|---|---|---|
| fitness | mean | 3759.75 +/- 516.93 |
| | std | 1043.91 |
| encoding drift rate | mean | 0.90 +/- 0.03 |
| | std | 0.05 |
| start drift rate | mean | 0.55 +/- 0.08 |
| | std | 0.17 |
| recall drift rate | mean | 0.81 +/- 0.04 |
| | std | 0.08 |
| shared support | mean | 45.45 +/- 19.28 |
| | std | 38.93 |
| item support | mean | 39.56 +/- 16.54 |
| | std | 33.40 |
| learning rate | mean | 0.13 +/- 0.04 |
| | std | 0.08 |
| primacy scale | mean | 7.69 +/- 3.97 |
| | std | 8.02 |
| primacy decay | mean | 16.90 +/- 10.81 |
| | std | 21.83 |
| stop probability scale | mean | 0.01 +/- 0.01 |
| | std | 0.02 |
| stop probability growth | mean | 0.46 +/- 0.09 |
| | std | 0.18 |
| choice sensitivity | mean | 75.46 +/- 9.69 |
| | std | 19.57 |






## Simulation

In [6]:
# Simulated datasets are cached as HDF5, keyed by dataset, model and run tags.
sim_path = os.path.join(
    product_dirs["simulations"],
    f"{data_tag}_{model_name}_{run_tag}.h5",
)
print(sim_path)

# Read the fit back from disk so this cell works from the saved file alone.
with open(fit_path) as fit_file:
    results = json.load(fit_file)
if "subject" not in results["fits"]:
    # mirror the top-level "subject" entry into "fits" when it is missing there
    results["fits"]["subject"] = results["subject"]

# Seeded key, split once; the second half drives the simulation.
rng = random.PRNGKey(seed)
rng, rng_iter = random.split(rng)

trial_mask = generate_trial_mask(data, trial_query)

# One array per entry of the fit results.
params = {name: jnp.array(values) for name, values in results["fits"].items()}  # type: ignore

if redo_sims or not os.path.exists(sim_path):
    # No usable cache (or a re-run was requested): simulate and store the result.
    sim = simulate_h5_from_h5(
        model_factory=model_factory,
        dataset=data,
        connections=connections,
        parameters=params,
        trial_mask=trial_mask,
        experiment_count=experiment_count,
        rng=rng_iter,
    )
    save_dict_to_hdf5(sim, sim_path)
    print(f"Saved to {sim_path}")
else:
    sim = load_data(sim_path)
    print(f"Loaded from {sim_path}")

# Filtering happens after saving/loading, so the file on disk stays unfiltered.
if filter_repeated_recalls:
    sim['recalls'] = repetition.filter_repeated_recalls(sim['recalls'])

# Last expression: display the parameter arrays as the cell output.
params

projects/thesis/simulations/RepeatedRecallsKahanaJacobs2000_WeirdPositionalCMR_full_best_of_3.h5


Saved to projects/thesis/simulations/RepeatedRecallsKahanaJacobs2000_WeirdPositionalCMR_full_best_of_3.h5


{'encoding_drift_rate': Array([0.97380733, 0.89902353, 0.8902182 , 0.9005012 , 0.88293475,
        0.8372477 , 0.80210084, 0.8204367 , 0.9540917 , 0.8738369 ,
        0.92977947, 0.85710466, 0.95551026, 0.9013666 , 0.95778465,
        0.97172534, 0.8272172 , 0.87720025, 0.9085238 ], dtype=float32),
 'start_drift_rate': Array([0.35395676, 0.59607995, 0.5256642 , 0.5856657 , 0.70755273,
        0.70429486, 0.6538253 , 0.7312479 , 0.7083114 , 0.6502304 ,
        0.38177982, 0.7673987 , 0.6122898 , 0.38500515, 0.13249072,
        0.32248506, 0.58901453, 0.42258617, 0.6169741 ], dtype=float32),
 'recall_drift_rate': Array([0.91724575, 0.79573745, 0.79249686, 0.8582841 , 0.73545647,
        0.7471737 , 0.69940937, 0.75860375, 0.9822245 , 0.84242874,
        0.79738605, 0.7925049 , 0.8557381 , 0.74275434, 0.8834948 ,
        0.92511046, 0.79366165, 0.6545081 , 0.8126339 ], dtype=float32),
 'shared_support': Array([11.087187 , 75.372505 ,  9.134972 , 75.0989   , 87.556335 ,
        97.32288  ,

## Figures

In [7]:
#|code-summary: single-dataset views

# Render (or re-display) one figure per single-dataset analysis, using only the
# simulated dataset.
for analysis in single_analyses:
    # `__name__[5:]` drops the first five characters of the function name
    # (the "plot_" prefix of the functions listed in `single_analysis_paths`).
    figure_str = f"{data_tag}_{model_name}_{run_tag}_{analysis.__name__[5:]}.png"
    figure_path = os.path.join(product_dirs["figures"], figure_str)
    print(f"![]({figure_path})")

    # Reuse the saved image unless a re-render was requested via `redo_figures`.
    if os.path.exists(figure_path) and not redo_figures:
        display(Image(filename=figure_path))
        continue

    color_cycle = [each["color"] for each in rcParams["axes.prop_cycle"]]

    # Apply the trial query to the simulated dataset. The mask gets its own name
    # so the notebook-level `trial_mask` built from `data` is not overwritten.
    sim_trial_mask = generate_trial_mask(sim, trial_query)

    axis = analysis(
        datasets=[sim],
        trial_masks=[np.array(sim_trial_mask)],
        color_cycle=color_cycle,
        labels=["First", "Second"],
        contrast_name="Repetition Index",
        axis=None,
        distances=None,
    )

    plt.savefig(figure_path, bbox_inches="tight", dpi=600)
    plt.show()

In [8]:
#| code-summary: mixed vs control views

# One figure per comparison analysis, contrasting the simulated dataset with
# the observed data.
for analysis in comparison_analyses:

    # File name: run identifiers plus the analysis function name minus its
    # first five characters.
    figure_str = f"{data_tag}_{model_name}_{run_tag}_{analysis.__name__[5:]}.png"
    figure_path = os.path.join(product_dirs["figures"], figure_str)
    print(f"![]({figure_path})")

    # A saved figure is shown as-is unless `redo_figures` asks for a re-render.
    must_render = redo_figures or not os.path.exists(figure_path)
    if not must_render:
        display(Image(filename=figure_path))
        continue

    color_cycle = [each["color"] for each in rcParams["axes.prop_cycle"]]

    # The same trial query is evaluated separately against data and simulation.
    trial_mask = generate_trial_mask(data, trial_query)
    sim_trial_mask = generate_trial_mask(sim, trial_query)

    # Simulation first, data second — the order matches `labels`.
    axis = analysis(
        datasets=[sim, data],
        trial_masks=[np.array(sim_trial_mask), np.array(trial_mask)],
        color_cycle=color_cycle,
        labels=["Model", "Data"],
        contrast_name="Source",
        axis=None,
        distances=None,
    )

    plt.savefig(figure_path, bbox_inches="tight", dpi=600)
    plt.show()