In [1]:
%load_ext autoreload
%autoreload 2

In [None]:
import json
import os

import jax.numpy as jnp
from jax import random

from jaxcmr.helpers import (
    generate_trial_mask,
    import_from_string,
    load_data,
    save_dict_to_hdf5,
)
from jaxcmr.simulation import simulate_h5_from_h5

# --- Run configuration --------------------------------------------------------
seed = 0  # PRNG seed for the simulations below
experiment_count = 50  # forwarded to simulate_h5_from_h5(experiment_count=...)

# Suffix shared by every fit / simulation file name: f"{data_tag}_{model_name}_{run_tag}".
run_tag = "full_best_of_3"
data_path = "data/{}.h5"  # format string filled with a data tag
target_dir = "projects/repfr"  # fit JSON files are read from f"{target_dir}/fits/"

# --- Datasets -----------------------------------------------------------------
# data_tags, trial_queries and allow_repeated_recalls are paired positionally
# via zip(); entries beyond len(data_tags) are simply unused.
data_tags = [
    "LohnasKahana2014",
    # "KahanaJacobs2000",
]

allow_repeated_recalls = [
    False,
    True,
]

# Trial-selection expressions handed to generate_trial_mask; they reference
# `data` and `jnp` (presumably evaluated against the loaded dataset — confirm
# in jaxcmr.helpers).
trial_queries = [
    "data['list_type'] != 0",
    "jnp.logical_and(data['recall_attempt'] == 1, data['recall_total'] > 0)",
]

# --- Models -------------------------------------------------------------------
# model_names[i] labels the fit file; model_paths[i] is its factory import path.
model_names = [
    "WeirdCMR",
    "WeirdNoReinstateCMR",
    "WeirdPositionScaleCMR",
    "OutlistCMRDE",
    "WeirdAmaxPositionScaleCMR",
    # "WeirdDriftPositionScaleCMR",
]

model_paths = [
    "jaxcmr.models.weird_cmr.BaseCMRFactory",
    "jaxcmr.models.weird_no_reinstate_cmr.BaseCMRFactory",
    "jaxcmr.models.weird_position_scale_cmr.BaseCMRFactory",
    "jaxcmr.models.outlistcmrde.BaseCMRFactory",
    # NOTE(review): same factory as WeirdPositionScaleCMR; only the fit JSON
    # (selected by model name) differs. Confirm this is intended.
    "jaxcmr.models.weird_position_scale_cmr.BaseCMRFactory",
    # "jaxcmr.models.weird_drift_position_scale_cmr.BaseCMRFactory",  # verify module location
]

# zip() truncates silently, so mismatched lists would drop datasets or models
# without any error. Fail loudly instead.
assert len(trial_queries) >= len(data_tags), "each data tag needs a trial query"
assert len(allow_repeated_recalls) >= len(data_tags), (
    "each data tag needs an allow_repeated_recalls flag"
)
assert len(model_names) == len(model_paths), (
    "model_names and model_paths must align one-to-one"
)

model_factories = [import_from_string(path) for path in model_paths]

In [None]:
def _load_fit_parameters(fit_path, allow_repeats):
    """Load fitted parameters from a fit JSON file as ``jnp`` arrays.

    Args:
        fit_path: Path to the JSON file produced by the fitting run. It must
            contain a ``"fits"`` mapping and a ``"fitness"`` list, plus a
            top-level ``"subject"`` entry when ``"fits"`` lacks one.
        allow_repeats: Value repeated into an ``allow_repeated_recalls`` array
            with one entry per element of ``"fitness"``.

    Returns:
        Dict mapping each key of ``"fits"`` (including ``"subject"``) to a
        ``jnp`` array, plus the ``"allow_repeated_recalls"`` array.
    """
    with open(fit_path) as f:
        results = json.load(f)

    fits = results["fits"]
    # Some fit files keep subject ids only at the top level; copy them in so
    # they are passed to the simulator together with the parameters.
    if "subject" not in fits:
        fits["subject"] = results["subject"]

    params = {key: jnp.array(val) for key, val in fits.items()}  # type: ignore
    # One flag per "fitness" entry (presumably one per fitted subject).
    params["allow_repeated_recalls"] = jnp.array([allow_repeats] * len(results["fitness"]))
    return params


# Simulate every (dataset, model) fit and record the resulting file paths.
data_paths = {}
for data_tag, trial_query, allow_repeats in zip(data_tags, trial_queries, allow_repeated_recalls):
    source_path = data_path.format(data_tag)
    data_paths[data_tag] = [source_path]
    data = load_data(source_path)
    max_size = jnp.max(data["pres_itemnos"])
    connections = jnp.zeros((max_size, max_size))
    # The trial mask depends only on the dataset and query, not on the model.
    trial_mask = generate_trial_mask(data, trial_query)

    for model_name, model_factory in zip(model_names, model_factories):
        tag = f"{data_tag}_{model_name}_{run_tag}"
        print(tag)
        # NOTE(review): fits are read from `target_dir/fits`, but simulations are
        # written to a cwd-relative `fits/` directory. Confirm this split is intended.
        fit_path = os.path.join(target_dir, "fits", f"{tag}.json")
        sim_path = f"fits/{tag}.h5"

        params = _load_fit_parameters(fit_path, allow_repeats)

        # The key is re-derived from `seed` for every model, so every model
        # receives the same PRNG key (presumably for like-for-like comparisons).
        _, rng_iter = random.split(random.PRNGKey(seed))
        sim = simulate_h5_from_h5(
            model_factory=model_factory,
            dataset=data,
            connections=connections,
            parameters=params,
            trial_mask=trial_mask,
            experiment_count=experiment_count,
            rng=rng_iter,
        )

        os.makedirs(os.path.dirname(sim_path), exist_ok=True)
        save_dict_to_hdf5(sim, sim_path)
        data_paths[data_tag].append(sim_path)
        print(f"Saved {tag} to {sim_path}")

data_paths

LohnasKahana2014_WeirdCMR_full_best_of_3
Saved LohnasKahana2014_WeirdCMR_full_best_of_3 to fits/LohnasKahana2014_WeirdCMR_full_best_of_3.h5
LohnasKahana2014_WeirdNoReinstateCMR_full_best_of_3
Saved LohnasKahana2014_WeirdNoReinstateCMR_full_best_of_3 to fits/LohnasKahana2014_WeirdNoReinstateCMR_full_best_of_3.h5
LohnasKahana2014_WeirdPositionScaleCMR_full_best_of_3
Saved LohnasKahana2014_WeirdPositionScaleCMR_full_best_of_3 to fits/LohnasKahana2014_WeirdPositionScaleCMR_full_best_of_3.h5
LohnasKahana2014_OutlistCMRDE_full_best_of_3
Saved LohnasKahana2014_OutlistCMRDE_full_best_of_3 to fits/LohnasKahana2014_OutlistCMRDE_full_best_of_3.h5
LohnasKahana2014_WeirdAmaxPositionScaleCMR_full_best_of_3
Saved LohnasKahana2014_WeirdAmaxPositionScaleCMR_full_best_of_3 to fits/LohnasKahana2014_WeirdAmaxPositionScaleCMR_full_best_of_3.h5


{'LohnasKahana2014': ['data/LohnasKahana2014.h5',
  'fits/LohnasKahana2014_WeirdCMR_full_best_of_3.h5',
  'fits/LohnasKahana2014_WeirdNoReinstateCMR_full_best_of_3.h5',
  'fits/LohnasKahana2014_WeirdPositionScaleCMR_full_best_of_3.h5',
  'fits/LohnasKahana2014_OutlistCMRDE_full_best_of_3.h5',
  'fits/LohnasKahana2014_WeirdAmaxPositionScaleCMR_full_best_of_3.h5']}

In [None]:
# Rebuild the {data tag: [dataset file, simulation files...]} mapping without
# re-running the simulations. The zip() pairing matches the simulation cell, so
# only the dataset/model combinations that cell would produce are listed.
data_paths = {
    tag_of_data: [data_path.format(tag_of_data)]
    + [
        f"fits/{tag_of_data}_{name}_{run_tag}.h5"
        for name, _ in zip(model_names, model_factories)
    ]
    for tag_of_data, _, _ in zip(data_tags, trial_queries, allow_repeated_recalls)
}

data_paths

{'LohnasKahana2014': ['data/LohnasKahana2014.h5',
  'fits/LohnasKahana2014_WeirdCMR_full_best_of_3.h5',
  'fits/LohnasKahana2014_WeirdNoReinstateCMR_full_best_of_3.h5',
  'fits/LohnasKahana2014_WeirdPositionScaleCMR_full_best_of_3.h5',
  'fits/LohnasKahana2014_OutlistCMRDE_full_best_of_3.h5',
  'fits/LohnasKahana2014_WeirdDriftPositionScaleCMR_full_best_of_3.h5'],
 'KahanaJacobs2000': ['data/KahanaJacobs2000.h5',
  'fits/KahanaJacobs2000_WeirdCMR_full_best_of_3.h5',
  'fits/KahanaJacobs2000_WeirdNoReinstateCMR_full_best_of_3.h5',
  'fits/KahanaJacobs2000_WeirdPositionScaleCMR_full_best_of_3.h5',
  'fits/KahanaJacobs2000_OutlistCMRDE_full_best_of_3.h5',
  'fits/KahanaJacobs2000_WeirdDriftPositionScaleCMR_full_best_of_3.h5']}