In [1]:
# ---- Run identification -----------------------------------------------------
# These tags are combined into the fit-file and output-table names below.
run_tag = "Model_Comparison"
fit_tag = "full_best_of_3"
fit_dir = "fits/"
target_directory = ""

# ---- Dataset ----------------------------------------------------------------
data_name = "LohnasKahana2014"

# ---- Models to compare ------------------------------------------------------
# Each entry is looked up as "<fit_dir>/<data_name>_<model_name>_<fit_tag>.json".
model_names = [
    "BaseCMR",
    "NoReinstateCMR",
    # Disabled variants (uncomment to include):
    # "PositionScaleCMR",
    # "NoScaleNoReinstateCMR",
    # "NoScalePositionScaleBaseCMR",
    # "InstanceCMR",
]

# Alternative model sets, kept for reference:
# model_names = ["NarrowWeirdInstanceCMRDE", "WeirdInstanceCMRDE", "ConnectionistCMR", "OutlistInstanceCMRDE", "NarrowReinstateOutlistInstanceCMRDE", "ReinstateOutlistCMRDE", "OutlistCMRDE", "ContextCMRDE", "TrueInstanceCMRDE", "ReinstateContextCMRDE", "FlexCMR2"]
# model_names = ["WeirdInstanceCMRDE", "OutlistInstanceCMRDE", "FlexCMRDE", "WeirdPositionalCMR", "AdditiveItemPositionalCMR", "MultiplicativeItemPositionalCMR", "PreexpMfcItemPositionalCMR", "TwoAlphaItemPositionalCMR", "WeirdFlexPositionalCMR"]
# "FakeNarrowOutlistInstanceCMRDE", "FakeOutlistInstanceCMRDE",
# "MultiContextCMRDE", "NormalContextCMRDE", "NormalMultiContextCMRDE",

# Display titles, one per entry of model_names; leave empty to reuse model_names.
model_titles = []

# ---- Parameters reported in the summary table (order = row order) -----------
query_parameters = [
    # context drift
    "encoding_drift_rate",
    "start_drift_rate",
    "recall_drift_rate",
    # associative support
    "shared_support",
    "item_shared_support",
    "position_shared_support",
    "item_support",
    # learning / primacy
    "learning_rate",
    "primacy_scale",
    "primacy_decay",
    # recall termination
    "stop_probability_scale",
    "stop_probability_growth",
    # choice / trace sensitivity
    "choice_sensitivity",
    "mfc_trace_sensitivity",
    "mcf_trace_sensitivity",
    "mfc_choice_sensitivity",
    # positional scaling
    "positional_scale",
    "positional_mfc_scale",
    # "semantic_scale",
    # "semantic_choice_sensitivity",
]

In [2]:
import os
import json
from jaxcmr.summarize import (
    summarize_parameters,
    generate_t_p_matrices,
    winner_comparison_matrix,
    calculate_aic_weights,
)

In [3]:
# Fall back to the raw model names when no display titles were configured.
if not model_titles:
    model_titles = model_names.copy()

# zip() below would silently drop trailing models if the lists disagree in length.
if len(model_titles) != len(model_names):
    raise ValueError(
        f"model_titles has {len(model_titles)} entries but model_names has "
        f"{len(model_names)}; provide one title per model or leave model_titles empty."
    )

# Load one fit-result JSON per model and tag it with its display title.
results = []
for model_name, model_title in zip(model_names, model_titles):
    fit_path = os.path.join(fit_dir, f"{data_name}_{model_name}_{fit_tag}.json")
    with open(fit_path) as f:
        fit_result = json.load(f)

    # Some fit files keep the subject list only at the top level; mirror it under
    # "fits" when absent (presumably where the summarize helpers look for it).
    if "subject" not in fit_result["fits"]:
        fit_result["fits"]["subject"] = fit_result["subject"]
    fit_result["name"] = model_title
    results.append(fit_result)

summary = summarize_parameters(
    results, query_parameters, include_std=True, include_ci=True
)

# Make sure the output folder exists so a fresh checkout doesn't fail on write.
tables_dir = os.path.join(target_directory, "tables")
os.makedirs(tables_dir, exist_ok=True)
with open(
    os.path.join(tables_dir, f"{data_name}_{fit_tag}_{run_tag}_parameters.md"),
    "w",
) as f:
    f.write(summary)
print(summary)


| | | BaseCMR | NoReinstateCMR |
|---|---|---|---|
| fitness | mean | 1665.55 +/- 146.25 | 1668.19 +/- 147.12 |
| | std | 419.61 | 422.12 |
| encoding drift rate | mean | 0.76 +/- 0.03 | 0.76 +/- 0.04 |
| | std | 0.09 | 0.11 |
| start drift rate | mean | 0.56 +/- 0.12 | 0.47 +/- 0.12 |
| | std | 0.34 | 0.34 |
| recall drift rate | mean | 0.95 +/- 0.01 | 0.95 +/- 0.02 |
| | std | 0.04 | 0.04 |
| shared support | mean | 2.75 +/- 2.06 | 3.18 +/- 2.62 |
| | std | 5.92 | 7.51 |
| item shared support | mean | | |
| | std | | |
| position shared support | mean | | |
| | std | | |
| item support | mean | 6.38 +/- 4.31 | 6.67 +/- 4.65 |
| | std | 12.38 | 13.34 |
| learning rate | mean | 0.42 +/- 0.06 | 0.45 +/- 0.08 |
| | std | 0.18 | 0.24 |
| primacy scale | mean | 15.04 +/- 9.01 | 17.96 +/- 8.80 |
| | std | 25.86 | 25.24 |
| primacy decay | mean | 12.45 +/- 8.30 | 10.36 +/- 7.76 |
| | std | 23.83 | 22.28 |
| stop probability scale | mean | 0.02 +/- 0.01 | 0.02 +/- 0.01 |
| | std | 0.02 | 0.02

In [4]:
# Pairwise t-statistic and p-value matrices between the fitted models,
# as produced by jaxcmr.summarize.generate_t_p_matrices.
df_t, df_p = generate_t_p_matrices(results)

p_matrix_path = os.path.join(
    target_directory, "tables", f"{data_name}_{fit_tag}_{run_tag}_p_matrix.md"
)
with open(p_matrix_path, "w") as f:
    # Render once and reuse the same text for the file and the printed copy.
    p_matrix_md = df_p.to_markdown()
    f.write(p_matrix_md)

print(p_matrix_md)
df_p

|                | BaseCMR            | NoReinstateCMR       |
|:---------------|:-------------------|:---------------------|
| BaseCMR        |                    | 0.007247849685136102 |
| NoReinstateCMR | 0.9927521503148639 |                      |


Unnamed: 0,BaseCMR,NoReinstateCMR
BaseCMR,,0.007248
NoReinstateCMR,0.992752,


In [5]:
# Save the t-statistic matrix (computed alongside df_p above) as a markdown table.
t_matrix_path = os.path.join(
    target_directory, "tables", f"{data_name}_{fit_tag}_{run_tag}_t_matrix.md"
)
with open(t_matrix_path, "w") as f:
    t_matrix_md = df_t.to_markdown()
    f.write(t_matrix_md)

df_t

Unnamed: 0,BaseCMR,NoReinstateCMR
BaseCMR,,-2.576494
NoReinstateCMR,2.576494,


In [6]:
# Akaike weights per model, as computed by jaxcmr.summarize.calculate_aic_weights.
aic_weights = calculate_aic_weights(results)

aic_path = os.path.join(
    target_directory, "tables", f"{data_name}_{fit_tag}_{run_tag}_aic_weights.md"
)
with open(aic_path, "w") as f:
    # Render once; the same markdown is written to disk and printed below.
    aic_weights_md = aic_weights.to_markdown()
    f.write(aic_weights_md)

print(aic_weights_md)
aic_weights

|    | Model          |        AICw |
|---:|:---------------|------------:|
|  0 | BaseCMR        | 1           |
|  1 | NoReinstateCMR | 7.49202e-41 |


Unnamed: 0,Model,AICw
0,BaseCMR,1.0
1,NoReinstateCMR,7.492015999999999e-41


In [7]:
# Pairwise "winner" ratios between models from jaxcmr.summarize.winner_comparison_matrix
# (presumably the share of subjects where the row model fits better — verify in jaxcmr).
df_comparison = winner_comparison_matrix(results)

winner_path = os.path.join(
    target_directory, "tables", f"{data_name}_{fit_tag}_{run_tag}_winner_ratios.md"
)
with open(winner_path, "w") as f:
    # Blank the NaN diagonal with five spaces so column alignment is preserved.
    winner_md = df_comparison.to_markdown().replace(" nan ", "     ")
    f.write(winner_md)

print(winner_md)

|                |    BaseCMR |   NoReinstateCMR |
|:---------------|-----------:|-----------------:|
| BaseCMR        |            |         0.742857 |
| NoReinstateCMR |   0.257143 |                  |
