In [1]:
import itertools
import papermill as pm
from tqdm import tqdm

In [2]:
# Analyses run for every dataset: one single-condition repetition CRP plus the
# set of model-vs-data comparison plots.
_SINGLE_ANALYSIS_PATHS = ("jaxcmr.analyses.serialrepcrp.plot_rep_crp",)
_COMPARISON_ANALYSIS_PATHS = (
    "jaxcmr.analyses.spc.plot_spc",
    "jaxcmr.analyses.crp.plot_crp",
    "jaxcmr.analyses.pnr.plot_pnr",
    "jaxcmr.analyses.serialrepcrp.plot_first_rep_crp",
    "jaxcmr.analyses.serialrepcrp.plot_second_rep_crp",
)


def _dataset_config(base_data_tag, trial_query, mixed_condition, control_condition):
    """Build one dataset entry for the fitting sweep.

    The mixed and control queries are ``trial_query`` further narrowed by
    ``mixed_condition`` / ``control_condition`` via ``jnp.logical_and``.
    Each call returns fresh analysis-path lists.
    """
    return {
        "base_data_tag": base_data_tag,
        "trial_query": trial_query,
        "mixed_trial_query": f"jnp.logical_and({trial_query}, {mixed_condition})",
        "control_trial_query": f"jnp.logical_and({trial_query}, {control_condition})",
        "single_analysis_paths": list(_SINGLE_ANALYSIS_PATHS),
        "comparison_analysis_paths": list(_COMPARISON_ANALYSIS_PATHS),
    }


data_parameters = [
    # {"base_data_tag": "BroitmanKahana2024", "trial_query": "data['subject'] != -1"},
    _dataset_config(
        "KahanaJacobs2000",
        "jnp.logical_and(data['recall_attempt'] == 1, data['recall_total'] > 0)",
        mixed_condition="data['repetitions'] == 1",
        control_condition="data['repetitions'] == 0",
    ),
    _dataset_config(
        "GordonRanschburg2021",
        "data['condition'] == 2",
        mixed_condition="data['lag'] != 0",
        control_condition="data['lag'] == 0",
    ),
]

In [3]:
# Repeated-recall modes to sweep. True keeps repeated recalls and uses the
# "RepeatedRecalls"-prefixed dataset; False filters repeats out of the base dataset.
handle_repeats = [
    True,
]

In [4]:
# Fitting bounds shared by every model configuration below.
# EPS is float64 machine epsilon (2**-52). It keeps lower bounds strictly above 0
# and unit-interval upper bounds strictly below 1 (1.0 - EPS == 0.9999999999999998).
# The 10 and 100 upper bounds are written plainly: an EPS offset is below float64
# resolution at those magnitudes (100.0 - EPS == 100.0), so a literal such as
# 99.9999999999999998 evaluates to exactly 100.0 anyway.
EPS = 2.220446049250313e-16
UNIT_BOUNDS = (EPS, 1.0 - EPS)
TEN_BOUNDS = (EPS, 10.0)
HUNDRED_BOUNDS = (EPS, 100.0)

# Free parameters common to every CMR variant, in the order they are passed on.
BASE_FREE_BOUNDS = {
    "encoding_drift_rate": UNIT_BOUNDS,
    "start_drift_rate": UNIT_BOUNDS,
    "recall_drift_rate": UNIT_BOUNDS,
    "shared_support": HUNDRED_BOUNDS,
    "item_support": HUNDRED_BOUNDS,
    "learning_rate": UNIT_BOUNDS,
    "primacy_scale": HUNDRED_BOUNDS,
    "primacy_decay": HUNDRED_BOUNDS,
    "stop_probability_scale": UNIT_BOUNDS,
    "stop_probability_growth": TEN_BOUNDS,
    "choice_sensitivity": HUNDRED_BOUNDS,
}


def _model_config(model_name, model_factory_path, fixed=None, extra_free=None):
    """Build one model configuration for the fitting sweep.

    Args:
        model_name: Label for the model; it becomes part of each output notebook name.
        model_factory_path: Dotted import path of the model factory.
        fixed: Parameters held constant during fitting (default: none).
        extra_free: Bounds for free parameters beyond ``BASE_FREE_BOUNDS``;
            they are appended after the shared ones.

    Returns:
        A config dict. Every free parameter gets a fresh ``[lower, upper]`` list,
        so entries never share mutable state.
    """
    free_bounds = {**BASE_FREE_BOUNDS, **(extra_free or {})}
    return {
        "model_name": model_name,
        "model_factory_path": model_factory_path,
        "redo_fits": False,
        "redo_sims": False,
        "redo_figures": True,
        "parameters": {
            "fixed": dict(fixed or {}),
            "free": {name: list(bounds) for name, bounds in free_bounds.items()},
        },
    }


model_parameters = [
    _model_config(
        "WeirdCMR",
        "jaxcmr.models_repfr.weird_cmr.BaseCMRFactory",
        extra_free={"mfc_choice_sensitivity": HUNDRED_BOUNDS},
    ),
    # _model_config(
    #     "WeirdReinfPositionalCMR",
    #     "jaxcmr.models_repfr.weird_reinf_positional_cmr.BaseCMRFactory",
    #     fixed={"mfc_choice_sensitivity": 1.0},
    #     extra_free={"first_presentation_reinforcement": HUNDRED_BOUNDS},
    # ),
    _model_config(
        "WeirdStudyReinfPositionalCMR",
        "jaxcmr.models_repfr.weird_study_reinf_positional_cmr.BaseCMRFactory",
        fixed={"mfc_choice_sensitivity": 1.0},
        extra_free={"first_presentation_reinforcement": HUNDRED_BOUNDS},
    ),
    # Same factory as WeirdPositionalCMR, but mfc_choice_sensitivity is fitted
    # instead of being fixed at 1.0.
    _model_config(
        "FullWeirdPositionalCMR",
        "jaxcmr.models_repfr.weird_positional_cmr.BaseCMRFactory",
        extra_free={"mfc_choice_sensitivity": HUNDRED_BOUNDS},
    ),
    _model_config(
        "WeirdPositionalCMR",
        "jaxcmr.models_repfr.weird_positional_cmr.BaseCMRFactory",
        fixed={"mfc_choice_sensitivity": 1.0},
    ),
    # NOTE(review): mfc_choice_sensitivity is neither fixed nor free here, so
    # whatever the factory does by default applies — confirm this is intended.
    _model_config(
        "WeirdNoReinstateCMR",
        "jaxcmr.models_repfr.weird_no_reinstate_cmr.BaseCMRFactory",
    ),
    # _model_config(
    #     "WeirdCMRDistinctContexts",
    #     "jaxcmr.models_repfr.weird_cmr_distinct_contexts.BaseCMRFactory",
    # ),
    # _model_config(
    #     "TrueWeirdPositionalCMR",
    #     "jaxcmr.models_repfr.true_weird_positional_cmr.BaseCMRFactory",
    # ),
]

In [None]:
# Run Fitting_Serial.ipynb once per (dataset, model, repeated-recall mode)
# combination, saving each executed copy under projects/thesis/.
n_runs = len(data_parameters) * len(model_parameters) * len(handle_repeats)
for data_params, model_params, allow_repeated_recalls in tqdm(
    itertools.product(data_parameters, model_parameters, handle_repeats),
    # product() has no len(); without an explicit total tqdm only shows "0it"
    total=n_runs,
    desc="fitting runs",
):
    # Keeping repeated recalls selects the "RepeatedRecalls"-prefixed dataset;
    # otherwise the base dataset is used with filter_repeated_recalls=True.
    base_data_tag = data_params["base_data_tag"]
    filter_repeated_recalls = not allow_repeated_recalls
    data_tag = (
        f"RepeatedRecalls{base_data_tag}" if allow_repeated_recalls else base_data_tag
    )
    data_path = f"data/{data_tag}.h5"

    output_path = (
        f"projects/thesis/{data_tag}_{model_params['model_name']}_Fitting.ipynb"
    )
    print(output_path)
    print(data_params)
    print(model_params)

    # NOTE(review): papermill has reported base_data_tag and the *_trial_query keys
    # as unknown parameters of Fitting_Serial.ipynb, and could not parse its query
    # lines. Confirm that its parameters cell declares everything passed here.
    pm.execute_notebook(
        "projects/thesis/Fitting_Serial.ipynb",
        output_path,
        autosave_cell_every=180,
        log_output=True,
        parameters={
            "allow_repeated_recalls": allow_repeated_recalls,
            "filter_repeated_recalls": filter_repeated_recalls,
            "data_tag": data_tag,
            "data_path": data_path,
            **data_params,
            **model_params,
        },
    )

0it [00:00, ?it/s]Unable to parse line 8 'trial_query =  "jnp.logical_and(data['recall_attempt'] == 1, data['recall_total'] > 0)"'.
Unable to parse line 9 'mixed_trial_query = "jnp.logical_and(jnp.logical_and(data['recall_attempt'] == 1, data['recall_total'] > 0), data['repetitions'] == 1)"'.
Unable to parse line 10 'control_trial_query = "jnp.logical_and(jnp.logical_and(data['recall_attempt'] == 1, data['recall_total'] > 0), data['repetitions'] == 0)"'.
Passed unknown parameter: base_data_tag
Passed unknown parameter: trial_query
Passed unknown parameter: mixed_trial_query
Passed unknown parameter: control_trial_query


projects/thesis/RepeatedRecallsKahanaJacobs2000_WeirdCMR_Fitting.ipynb
{'base_data_tag': 'KahanaJacobs2000', 'trial_query': "jnp.logical_and(data['recall_attempt'] == 1, data['recall_total'] > 0)", 'mixed_trial_query': "jnp.logical_and(jnp.logical_and(data['recall_attempt'] == 1, data['recall_total'] > 0), data['repetitions'] == 1)", 'control_trial_query': "jnp.logical_and(jnp.logical_and(data['recall_attempt'] == 1, data['recall_total'] > 0), data['repetitions'] == 0)", 'single_analysis_paths': ['jaxcmr.analyses.serialrepcrp.plot_rep_crp'], 'comparison_analysis_paths': ['jaxcmr.analyses.spc.plot_spc', 'jaxcmr.analyses.crp.plot_crp', 'jaxcmr.analyses.pnr.plot_pnr', 'jaxcmr.analyses.serialrepcrp.plot_first_rep_crp', 'jaxcmr.analyses.serialrepcrp.plot_second_rep_crp']}
{'model_name': 'WeirdCMR', 'model_factory_path': 'jaxcmr.models_repfr.weird_cmr.BaseCMRFactory', 'redo_fits': False, 'redo_sims': False, 'redo_figures': True, 'parameters': {'fixed': {}, 'free': {'encoding_drift_rate': [2.

Executing:   0%|          | 0/13 [00:00<?, ?cell/s]