In [1]:
# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------
# Fit files are read from  {fit_dir}/{data_name}_{model}_{fit_tag}.json  and
# every output table is written to
#   {target_directory}/tables/{data_name}_{fit_tag}_{run_tag}_<table>.md
run_tag = "Model_Comparison"   # distinguishes this comparison's output tables
fit_tag = "full_best_of_3"     # selects which fit files to load
fit_dir = "fits/"              # folder holding the fit JSON files
target_directory = ""          # root of the tables/ folder ("" = working dir)

# Dataset whose fits are compared.
data_name = "LohnasKahana2014"

# Models to compare, in the order they appear as table columns.
# Candidates used in earlier runs, kept for reference:
#   ["NarrowWeirdInstanceCMRDE", "WeirdInstanceCMRDE", "ConnectionistCMR",
#    "OutlistInstanceCMRDE", "NarrowReinstateOutlistInstanceCMRDE",
#    "ReinstateOutlistCMRDE", "OutlistCMRDE", "ContextCMRDE",
#    "TrueInstanceCMRDE", "ReinstateContextCMRDE", "FlexCMR2"]
#   ["WeirdInstanceCMRDE", "OutlistInstanceCMRDE", "FlexCMRDE",
#    "WeirdPositionalCMR", "AdditiveItemPositionalCMR",
#    "MultiplicativeItemPositionalCMR", "PreexpMfcItemPositionalCMR",
#    "TwoAlphaItemPositionalCMR", "WeirdFlexPositionalCMR"]
#   "PositionScaleCMR", "NoScaleNoReinstateCMR",
#   "NoScalePositionScaleBaseCMR", "InstanceCMR",
#   "FakeNarrowOutlistInstanceCMRDE", "FakeOutlistInstanceCMRDE",
#   "MultiContextCMRDE", "NormalContextCMRDE", "NormalMultiContextCMRDE"
model_names = ["BaseCMR", "WeirdCMR", "NoReinstateCMR", "WeirdNoReinstateCMR"]

# Optional display titles, one per entry of model_names. When left empty,
# the loading cell reuses model_names as the titles.
model_titles = []

# Parameters reported in the summary table. A parameter that none of the
# compared models has shows up as an empty row in the summary output.
query_parameters = [
    "encoding_drift_rate", "start_drift_rate", "recall_drift_rate",
    "shared_support", "item_shared_support", "position_shared_support",
    "item_support", "learning_rate",
    "primacy_scale", "primacy_decay",
    "stop_probability_scale", "stop_probability_growth",
    "choice_sensitivity", "mfc_trace_sensitivity",
    "mcf_trace_sensitivity", "mfc_choice_sensitivity",
    "positional_scale", "positional_mfc_scale",
    # "semantic_scale", "semantic_choice_sensitivity",
]

In [2]:
import os
import json
from jaxcmr.summarize import (
    summarize_parameters,
    generate_t_p_matrices,
    winner_comparison_matrix,
    calculate_aic_weights,
)

In [3]:
# Use the model names as display titles unless explicit titles were configured.
if not model_titles:
    model_titles = model_names.copy()

# zip() stops at the shorter input, so a title list of the wrong length would
# silently drop models from every table below; fail loudly instead.
if len(model_titles) != len(model_names):
    raise ValueError(
        f"model_titles has {len(model_titles)} entries but model_names has "
        f"{len(model_names)}; provide one title per model or leave it empty."
    )

# Load one fit file per model and tag it with its display title.
results = []
for model_name, model_title in zip(model_names, model_titles):
    fit_path = os.path.join(fit_dir, f"{data_name}_{model_name}_{fit_tag}.json")
    with open(fit_path) as f:
        fit_result = json.load(f)

    # When "fits" has no "subject" entry, copy the top-level "subject" value
    # into it so every result exposes fits["subject"] the same way.
    if "subject" not in fit_result["fits"]:
        fit_result["fits"]["subject"] = fit_result["subject"]
    fit_result["name"] = model_title
    results.append(fit_result)

summary = summarize_parameters(
    results, query_parameters, include_std=True, include_ci=True
)

# open(..., "w") does not create missing folders, so make sure tables/ exists
# before writing the first table into it.
tables_dir = os.path.join(target_directory, "tables")
os.makedirs(tables_dir, exist_ok=True)
with open(
    os.path.join(tables_dir, f"{data_name}_{fit_tag}_{run_tag}_parameters.md"),
    "w",
) as f:
    f.write(summary)
print(summary)


| | | BaseCMR | WeirdCMR | NoReinstateCMR | WeirdNoReinstateCMR |
|---|---|---|---|---|---|
| fitness | mean | 1665.55 +/- 146.25 | 1667.75 +/- 146.53 | 1668.19 +/- 147.12 | 1667.86 +/- 146.86 |
| | std | 419.61 | 420.43 | 422.12 | 421.38 |
| encoding drift rate | mean | 0.76 +/- 0.03 | 0.77 +/- 0.04 | 0.76 +/- 0.04 | 0.77 +/- 0.04 |
| | std | 0.09 | 0.12 | 0.11 | 0.10 |
| start drift rate | mean | 0.56 +/- 0.12 | 0.51 +/- 0.12 | 0.47 +/- 0.12 | 0.37 +/- 0.12 |
| | std | 0.34 | 0.35 | 0.34 | 0.34 |
| recall drift rate | mean | 0.95 +/- 0.01 | 0.93 +/- 0.02 | 0.95 +/- 0.02 | 0.94 +/- 0.01 |
| | std | 0.04 | 0.05 | 0.04 | 0.04 |
| shared support | mean | 2.75 +/- 2.06 | 2.93 +/- 2.08 | 3.18 +/- 2.62 | 2.24 +/- 2.91 |
| | std | 5.92 | 5.97 | 7.51 | 8.34 |
| item shared support | mean | | | | |
| | std | | | | |
| position shared support | mean | | | | |
| | std | | | | |
| item support | mean | 6.38 +/- 4.31 | 9.16 +/- 6.16 | 6.67 +/- 4.65 | 6.28 +/- 4.47 |
| | std | 12.38 | 17.68 | 13.34

In [4]:
# Pairwise model comparison: df_t / df_p hold the t-statistic and p-value
# matrices (per the names and the outputs below). Only the p-value table is
# saved here; the t table is saved by the following cell.
df_t, df_p = generate_t_p_matrices(results)

p_matrix_path = os.path.join(
    target_directory, "tables", f"{data_name}_{fit_tag}_{run_tag}_p_matrix.md"
)
with open(p_matrix_path, "w") as f:
    # Render the markdown once and reuse it for both the file and the printout.
    p_matrix_markdown = df_p.to_markdown()
    f.write(p_matrix_markdown)

print(p_matrix_markdown)
df_p

|                     | BaseCMR            | WeirdCMR             | NoReinstateCMR       | WeirdNoReinstateCMR   |
|:--------------------|:-------------------|:---------------------|:---------------------|:----------------------|
| BaseCMR             |                    | 0.011853712058516947 | 0.007247849685136102 | 0.012210625428854123  |
| WeirdCMR            | 0.9881462879414831 |                      | 0.33974651012291335  | 0.4607493193733744    |
| NoReinstateCMR      | 0.9927521503148639 | 0.6602534898770867   |                      | 0.6340989573038773    |
| WeirdNoReinstateCMR | 0.9877893745711459 | 0.5392506806266256   | 0.3659010426961227   |                       |


Unnamed: 0,BaseCMR,WeirdCMR,NoReinstateCMR,WeirdNoReinstateCMR
BaseCMR,,0.011854,0.007248,0.012211
WeirdCMR,0.988146,,0.339747,0.460749
NoReinstateCMR,0.992752,0.660253,,0.634099
WeirdNoReinstateCMR,0.987789,0.539251,0.365901,


In [5]:
# Save the t-statistic matrix produced alongside df_p by generate_t_p_matrices.
t_matrix_path = os.path.join(
    target_directory, "tables", f"{data_name}_{fit_tag}_{run_tag}_t_matrix.md"
)
with open(t_matrix_path, "w") as f:
    f.write(df_t.to_markdown())

df_t

Unnamed: 0,BaseCMR,WeirdCMR,NoReinstateCMR,WeirdNoReinstateCMR
BaseCMR,,-2.368098,-2.576494,-2.355237
WeirdCMR,2.368098,,-0.416733,-0.09928
NoReinstateCMR,2.576494,0.416733,,0.34556
WeirdNoReinstateCMR,2.355237,0.09928,-0.34556,


In [6]:
# Akaike weights across the compared models; as the output below shows, rows
# come back ordered by descending weight (original row index retained).
aic_weights = calculate_aic_weights(results)

aic_weights_path = os.path.join(
    target_directory, "tables", f"{data_name}_{fit_tag}_{run_tag}_aic_weights.md"
)
with open(aic_weights_path, "w") as f:
    # Render once; the same markdown is written to disk and printed.
    aic_weights_markdown = aic_weights.to_markdown()
    f.write(aic_weights_markdown)

print(aic_weights_markdown)
aic_weights

|    | Model               |        AICw |
|---:|:--------------------|------------:|
|  0 | BaseCMR             | 1           |
|  1 | WeirdCMR            | 3.27824e-34 |
|  3 | WeirdNoReinstateCMR | 6.86928e-36 |
|  2 | NoReinstateCMR      | 7.49202e-41 |


Unnamed: 0,Model,AICw
0,BaseCMR,1.0
1,WeirdCMR,3.278242e-34
3,WeirdNoReinstateCMR,6.869284e-36
2,NoReinstateCMR,7.492015999999999e-41


In [7]:
# Model-vs-model winner ratios. Presumably the fraction of subjects for which
# the row model beats the column model (mirrored entries sum to 1 in the
# output below) -- confirm against winner_comparison_matrix.
df_comparison = winner_comparison_matrix(results)

winner_ratios_path = os.path.join(
    target_directory, "tables", f"{data_name}_{fit_tag}_{run_tag}_winner_ratios.md"
)
with open(winner_ratios_path, "w") as f:
    # NaN cells (the self-comparison diagonal) render as " nan "; swap them for
    # five spaces so the cell is blank but column widths stay aligned.
    winner_ratios_markdown = df_comparison.to_markdown().replace(" nan ", "     ")
    f.write(winner_ratios_markdown)

print(winner_ratios_markdown)

|                     |    BaseCMR |   WeirdCMR |   NoReinstateCMR |   WeirdNoReinstateCMR |
|:--------------------|-----------:|-----------:|-----------------:|----------------------:|
| BaseCMR             |            |   0.714286 |         0.742857 |              0.657143 |
| WeirdCMR            |   0.285714 |            |         0.571429 |              0.571429 |
| NoReinstateCMR      |   0.257143 |   0.428571 |                  |              0.457143 |
| WeirdNoReinstateCMR |   0.342857 |   0.428571 |         0.542857 |                       |
