In [1]:
# --- Run identification ------------------------------------------------------
# Tags are embedded in the fit-file names that are read and the table names that are written.
run_tag = "Model_Comparison"
fit_tag = "full_best_of_3"
fit_dir = "fits/"
target_directory = ""  # output root; tables go under <target_directory>/tables

# --- Dataset ------------------------------------------------------------------
data_name = "LohnasKahana2014"

# --- Models to compare ----------------------------------------------------------
# Each entry is looked up as <fit_dir>/<data_name>_<model_name>_<fit_tag>.json.
#
# Earlier comparison sets, kept for reference:
# model_names = ["NarrowWeirdInstanceCMRDE", "WeirdInstanceCMRDE", "ConnectionistCMR", "OutlistInstanceCMRDE", "NarrowReinstateOutlistInstanceCMRDE", "ReinstateOutlistCMRDE", "OutlistCMRDE", "ContextCMRDE", "TrueInstanceCMRDE", "ReinstateContextCMRDE", "FlexCMR2"]
# model_names = ["WeirdInstanceCMRDE", "OutlistInstanceCMRDE", "FlexCMRDE", "WeirdPositionalCMR", "AdditiveItemPositionalCMR", "MultiplicativeItemPositionalCMR", "PreexpMfcItemPositionalCMR", "TwoAlphaItemPositionalCMR", "WeirdFlexPositionalCMR"]
#
# Variants currently switched off for this run:
# "BaseCMR", "NoReinstateCMR", "NoScaleNoReinstateCMR",
# "NoScalePositionScaleBaseCMR", "InstanceCMR",
# "FakeNarrowOutlistInstanceCMRDE", "FakeOutlistInstanceCMRDE",
# "MultiContextCMRDE", "NormalContextCMRDE", "NormalMultiContextCMRDE",
model_names = [
    "WeirdCMR",
    "WeirdPositionScaleCMR",
    "WeirdNoReinstateCMR",
    "OutlistCMRDE",
    "WeirdNoPrexpPositionCMR",
]

# Display titles, one per entry of model_names; leave empty to reuse the names.
model_titles = []

# --- Parameters reported in the summary table -----------------------------------
query_parameters = [
    # drift rates
    "encoding_drift_rate",
    "start_drift_rate",
    "recall_drift_rate",
    # support terms
    "shared_support",
    "item_shared_support",
    "position_shared_support",
    "item_support",
    # learning / primacy
    "learning_rate",
    "primacy_scale",
    "primacy_decay",
    # stopping
    "stop_probability_scale",
    "stop_probability_growth",
    # sensitivities
    "choice_sensitivity",
    "mfc_trace_sensitivity",
    "mcf_trace_sensitivity",
    "mfc_choice_sensitivity",
    # positional scaling
    "positional_scale",
    "positional_mfc_scale",
    # currently excluded: "semantic_scale", "semantic_choice_sensitivity",
]

In [2]:
import os
import json
from jaxcmr.summarize import (
    summarize_parameters,
    generate_t_p_matrices,
    winner_comparison_matrix,
    calculate_aic_weights,
)

In [3]:
# Load every model's fit file, tag it with its display title, then build and
# save a markdown summary of the queried parameters.
if not model_titles:
    model_titles = model_names.copy()

# zip() below would silently drop models if the two lists were misaligned.
if len(model_titles) != len(model_names):
    raise ValueError(
        f"model_titles has {len(model_titles)} entries but model_names has "
        f"{len(model_names)}; they must correspond one-to-one."
    )

results = []
for model_name, model_title in zip(model_names, model_titles):
    fit_path = os.path.join(fit_dir, f"{data_name}_{model_name}_{fit_tag}.json")

    # JSON is UTF-8 by spec; don't depend on the platform default encoding.
    with open(fit_path, encoding="utf-8") as f:
        fit_result = json.load(f)

    # Some fit files keep "subject" only at the top level; mirror it into
    # "fits" when absent (presumably where the summarize helpers look — confirm).
    if "subject" not in fit_result["fits"]:
        fit_result["fits"]["subject"] = fit_result["subject"]
    fit_result["name"] = model_title
    results.append(fit_result)

summary = summarize_parameters(
    results, query_parameters, include_std=True, include_ci=True
)

# Writing with open(..., "w") fails if the tables folder does not exist yet.
tables_dir = os.path.join(target_directory, "tables")
os.makedirs(tables_dir, exist_ok=True)

with open(
    os.path.join(tables_dir, f"{data_name}_{fit_tag}_{run_tag}_parameters.md"),
    "w",
    encoding="utf-8",
) as f:
    f.write(summary)
print(summary)


| | | WeirdCMR | WeirdPositionScaleCMR | WeirdNoReinstateCMR | OutlistCMRDE | WeirdNoPrexpPositionCMR |
|---|---|---|---|---|---|---|
| fitness | mean | 1667.75 +/- 146.53 | 1659.83 +/- 146.81 | 1667.86 +/- 146.86 | 1660.34 +/- 146.28 | 1659.83 +/- 146.78 |
| | std | 420.43 | 421.22 | 421.38 | 419.71 | 421.13 |
| encoding drift rate | mean | 0.77 +/- 0.04 | 0.71 +/- 0.04 | 0.77 +/- 0.04 | 0.75 +/- 0.04 | 0.72 +/- 0.04 |
| | std | 0.12 | 0.12 | 0.10 | 0.12 | 0.12 |
| start drift rate | mean | 0.51 +/- 0.12 | 0.45 +/- 0.10 | 0.37 +/- 0.12 | 0.41 +/- 0.10 | 0.37 +/- 0.10 |
| | std | 0.35 | 0.29 | 0.34 | 0.29 | 0.30 |
| recall drift rate | mean | 0.93 +/- 0.02 | 0.91 +/- 0.02 | 0.94 +/- 0.01 | 0.93 +/- 0.01 | 0.94 +/- 0.02 |
| | std | 0.05 | 0.06 | 0.04 | 0.04 | 0.04 |
| shared support | mean | 2.93 +/- 2.08 | 4.66 +/- 2.32 | 2.24 +/- 2.91 | 5.34 +/- 2.78 | 2.31 +/- 1.44 |
| | std | 5.97 | 6.66 | 8.34 | 7.98 | 4.14 |
| item shared support | mean | | | | | |
| | std | | | | | |
| position s

In [4]:
# Pairwise comparison matrices across the fitted models: t statistics (df_t)
# and p-values (df_p). The p-value table is saved and echoed as markdown.
df_t, df_p = generate_t_p_matrices(results)

p_matrix_path = os.path.join(
    target_directory, "tables", f"{data_name}_{fit_tag}_{run_tag}_p_matrix.md"
)
with open(p_matrix_path, "w") as f:
    p_matrix_md = df_p.to_markdown()
    f.write(p_matrix_md)

print(p_matrix_md)
df_p

|                         | WeirdCMR               | WeirdPositionScaleCMR   | WeirdNoReinstateCMR    | OutlistCMRDE        | WeirdNoPrexpPositionCMR   |
|:------------------------|:-----------------------|:------------------------|:-----------------------|:--------------------|:--------------------------|
| WeirdCMR                |                        | 0.9999961851599968      | 0.4607493193733744     | 0.9999993900826833  | 0.9999988540109558        |
| WeirdPositionScaleCMR   | 3.814840003284842e-06  |                         | 1.5893901076511372e-07 | 0.3607711231613683  | 0.49710342997778967       |
| WeirdNoReinstateCMR     | 0.5392506806266256     | 0.9999998410609893      |                        | 0.9999982189592076  | 0.9999999813102048        |
| OutlistCMRDE            | 6.09917316586724e-07   | 0.6392288768386317      | 1.781040792387375e-06  |                     | 0.6458341521361792        |
| WeirdNoPrexpPositionCMR | 1.1459890441792031e-06 | 0.5028965700222103     

Unnamed: 0,WeirdCMR,WeirdPositionScaleCMR,WeirdNoReinstateCMR,OutlistCMRDE,WeirdNoPrexpPositionCMR
WeirdCMR,,0.999996,0.460749,0.999999,0.999999
WeirdPositionScaleCMR,4e-06,,0.0,0.360771,0.497103
WeirdNoReinstateCMR,0.539251,1.0,,0.999998,1.0
OutlistCMRDE,1e-06,0.639229,2e-06,,0.645834
WeirdNoPrexpPositionCMR,1e-06,0.502897,0.0,0.354166,


In [5]:
# Save the pairwise t-statistic matrix next to the p-value table.
t_matrix_path = os.path.join(
    target_directory, "tables", f"{data_name}_{fit_tag}_{run_tag}_t_matrix.md"
)
with open(t_matrix_path, "w") as f:
    f.write(df_t.to_markdown())

df_t

Unnamed: 0,WeirdCMR,WeirdPositionScaleCMR,WeirdNoReinstateCMR,OutlistCMRDE,WeirdNoPrexpPositionCMR
WeirdCMR,,5.271482,-0.09928,5.883552,5.672886
WeirdPositionScaleCMR,-5.271482,,-6.334471,-0.359368,-0.007314
WeirdNoReinstateCMR,0.09928,6.334471,,5.525735,7.061236
OutlistCMRDE,-5.883552,0.359368,-5.525735,,0.377251
WeirdNoPrexpPositionCMR,-5.672886,0.007314,-7.061236,-0.377251,


In [6]:
# Akaike weights (AICw) for each model within this comparison set.
aic_weights = calculate_aic_weights(results)

aic_weights_path = os.path.join(
    target_directory, "tables", f"{data_name}_{fit_tag}_{run_tag}_aic_weights.md"
)
with open(aic_weights_path, "w") as f:
    aic_weights_md = aic_weights.to_markdown()
    f.write(aic_weights_md)

print(aic_weights_md)
aic_weights

|    | Model                   |         AICw |
|---:|:------------------------|-------------:|
|  1 | WeirdPositionScaleCMR   | 0.571049     |
|  4 | WeirdNoPrexpPositionCMR | 0.428951     |
|  3 | OutlistCMRDE            | 3.24986e-09  |
|  0 | WeirdCMR                | 1.90224e-121 |
|  2 | WeirdNoReinstateCMR     | 3.98598e-123 |


Unnamed: 0,Model,AICw
1,WeirdPositionScaleCMR,0.5710491
4,WeirdNoPrexpPositionCMR,0.4289509
3,OutlistCMRDE,3.249856e-09
0,WeirdCMR,1.902238e-121
2,WeirdNoReinstateCMR,3.985983e-123


In [7]:
# Pairwise "winner" ratios between models (presumably the share of subjects
# for which the row model beats the column model — confirm in jaxcmr).
# NaN cells, such as the self-comparison diagonal, are blanked. The
# 5-character replacement keeps the markdown columns aligned.
df_comparison = winner_comparison_matrix(results)

winner_ratios_path = os.path.join(
    target_directory, "tables", f"{data_name}_{fit_tag}_{run_tag}_winner_ratios.md"
)
with open(winner_ratios_path, "w") as f:
    winner_ratios_md = df_comparison.to_markdown().replace(" nan ", "     ")
    f.write(winner_ratios_md)

print(winner_ratios_md)

|                         |   WeirdCMR |   WeirdPositionScaleCMR |   WeirdNoReinstateCMR |   OutlistCMRDE |   WeirdNoPrexpPositionCMR |
|:------------------------|-----------:|------------------------:|----------------------:|---------------:|--------------------------:|
| WeirdCMR                |            |               0.142857  |              0.571429 |       0.114286 |                  0.142857 |
| WeirdPositionScaleCMR   |   0.857143 |                         |              0.914286 |       0.542857 |                  0.542857 |
| WeirdNoReinstateCMR     |   0.428571 |               0.0857143 |                       |       0.2      |                  0.114286 |
| OutlistCMRDE            |   0.885714 |               0.457143  |              0.8      |                |                  0.4      |
| WeirdNoPrexpPositionCMR |   0.857143 |               0.457143  |              0.885714 |       0.6      |                           |
