In [1]:
# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------
run_tag = "Model_Comparison"   # label embedded in every output table filename
fit_tag = "full_best_of_3"     # selects which fit files to load: {data_name}_{model}_{fit_tag}.json
fit_dir = "fits/"              # folder containing those fit JSON files
target_directory = ""          # output root; tables are written to <target_directory>/tables/

# ---------------------------------------------------------------------------
# Dataset
# ---------------------------------------------------------------------------
data_name = "BroitmanKahana2024"   # prefix shared by fit files and output tables

# ---------------------------------------------------------------------------
# Models to compare (order here is the order of columns/rows in the outputs).
# Toggle a model by (un)commenting its line.
#
# Earlier candidate sets, kept for reference:
# model_names = ["NarrowWeirdInstanceCMRDE", "WeirdInstanceCMRDE", "ConnectionistCMR", "OutlistInstanceCMRDE", "NarrowReinstateOutlistInstanceCMRDE", "ReinstateOutlistCMRDE", "OutlistCMRDE", "ContextCMRDE", "TrueInstanceCMRDE", "ReinstateContextCMRDE", "FlexCMR2"]
# model_names = ["WeirdInstanceCMRDE", "OutlistInstanceCMRDE", "FlexCMRDE", "WeirdPositionalCMR", "AdditiveItemPositionalCMR", "MultiplicativeItemPositionalCMR", "PreexpMfcItemPositionalCMR", "TwoAlphaItemPositionalCMR", "WeirdFlexPositionalCMR"]
# Other variants: "FakeNarrowOutlistInstanceCMRDE", "FakeOutlistInstanceCMRDE",
#                 "MultiContextCMRDE", "NormalContextCMRDE", "NormalMultiContextCMRDE",
# ---------------------------------------------------------------------------
model_names = [
    # "BaseCMR",
    "WeirdCMR",
    "WeirdPositionScaleCMR",
    # "NoReinstateCMR",
    "WeirdNoReinstateCMR",
    "OutlistCMRDE",
    # "WeirdNoPrexpPositionCMR",
    # "FlexPositionScaleCMR",
    # "NoPrexpPositionCMR",
    # "NoScaleNoReinstateCMR",
    # "NoScalePositionScaleBaseCMR",
    # "InstanceCMR",
]

# Optional display names, one per entry in model_names.
# Leave empty to reuse model_names as the titles.
model_titles = []

# ---------------------------------------------------------------------------
# Parameters reported in the parameter summary table
# ---------------------------------------------------------------------------
query_parameters = [
    # drift rates
    "encoding_drift_rate",
    "start_drift_rate",
    "recall_drift_rate",
    # support terms
    "shared_support",
    # "item_shared_support",
    # "position_shared_support",
    "item_support",
    "learning_rate",
    # primacy gradient
    "primacy_scale",
    "primacy_decay",
    # stopping rule
    "stop_probability_scale",
    "stop_probability_growth",
    # choice sensitivities
    "choice_sensitivity",
    # "mfc_trace_sensitivity",
    # "mcf_trace_sensitivity",
    "mfc_choice_sensitivity",
    # "positional_scale",
    # "positional_mfc_scale",
    # "semantic_scale",
    # "semantic_choice_sensitivity",
]

In [2]:
import os
import json
from jaxcmr.summarize import (
    summarize_parameters,
    generate_t_p_matrices,
    winner_comparison_matrix,
    calculate_aic_weights,
)

In [3]:
# Load one fit-result JSON per model, label it with its display title, and
# write/print the parameter summary table.

# Fall back to the raw model names when no display titles were configured.
if not model_titles:
    model_titles = model_names.copy()

# zip() below would silently drop models if the two lists disagree in length.
if len(model_titles) != len(model_names):
    raise ValueError(
        f"model_titles has {len(model_titles)} entries but model_names has "
        f"{len(model_names)}; give one title per model or leave model_titles empty."
    )

# All table-writing cells in this notebook write here; create it so a fresh
# checkout does not fail with FileNotFoundError on the first open(..., "w").
tables_dir = os.path.join(target_directory, "tables")
os.makedirs(tables_dir, exist_ok=True)

results = []
for model_name, model_title in zip(model_names, model_titles):
    fit_path = os.path.join(fit_dir, f"{data_name}_{model_name}_{fit_tag}.json")
    with open(fit_path) as f:
        fit_result = json.load(f)

    # Some fit files keep "subject" only at the top level; mirror it under
    # "fits" when missing there (top-level key is read only in that case).
    if "subject" not in fit_result["fits"]:
        fit_result["fits"]["subject"] = fit_result["subject"]
    fit_result["name"] = model_title
    results.append(fit_result)

summary = summarize_parameters(
    results, query_parameters, include_std=True, include_ci=True
)

with open(
    os.path.join(tables_dir, f"{data_name}_{fit_tag}_{run_tag}_parameters.md"),
    "w",
) as f:
    f.write(summary)
print(summary)


| | | WeirdCMR | WeirdPositionScaleCMR | WeirdNoReinstateCMR | OutlistCMRDE |
|---|---|---|---|---|---|
| fitness | mean | 1267.41 +/- 278.90 | 1260.70 +/- 278.83 | 1266.60 +/- 278.64 | 1258.53 +/- 276.51 |
| | std | 774.55 | 774.34 | 773.82 | 767.92 |
| encoding drift rate | mean | 0.74 +/- 0.03 | 0.81 +/- 0.05 | 0.72 +/- 0.03 | 0.72 +/- 0.04 |
| | std | 0.09 | 0.13 | 0.08 | 0.11 |
| start drift rate | mean | 0.36 +/- 0.11 | 0.29 +/- 0.10 | 0.34 +/- 0.11 | 0.35 +/- 0.11 |
| | std | 0.30 | 0.27 | 0.30 | 0.30 |
| recall drift rate | mean | 0.83 +/- 0.06 | 0.75 +/- 0.07 | 0.81 +/- 0.06 | 0.76 +/- 0.06 |
| | std | 0.17 | 0.20 | 0.17 | 0.18 |
| shared support | mean | 21.98 +/- 7.73 | 18.74 +/- 8.33 | 22.66 +/- 6.69 | 19.88 +/- 6.65 |
| | std | 21.47 | 23.15 | 18.58 | 18.45 |
| item support | mean | 30.63 +/- 8.61 | 27.78 +/- 10.57 | 31.68 +/- 8.36 | 24.96 +/- 7.59 |
| | std | 23.92 | 29.35 | 23.21 | 21.08 |
| learning rate | mean | 0.12 +/- 0.02 | 0.40 +/- 0.10 | 0.17 +/- 0.04 | 0.24 +/- 

In [4]:
# Pairwise model-comparison matrices from generate_t_p_matrices: df_t holds the
# t values (saved in the next cell), df_p the p-values saved and shown here.
df_t, df_p = generate_t_p_matrices(results)

p_matrix_md = df_p.to_markdown()  # render once; reused for file and console
p_matrix_path = os.path.join(
    target_directory, "tables", f"{data_name}_{fit_tag}_{run_tag}_p_matrix.md"
)
with open(p_matrix_path, "w") as out_file:
    out_file.write(p_matrix_md)

print(p_matrix_md)
df_p

|                       | WeirdCMR               | WeirdPositionScaleCMR   | WeirdNoReinstateCMR    | OutlistCMRDE       |
|:----------------------|:-----------------------|:------------------------|:-----------------------|:-------------------|
| WeirdCMR              |                        | 0.9689316856681806      | 0.975504071337862      | 0.9998132316685929 |
| WeirdPositionScaleCMR | 0.03106831433181948    |                         | 0.052073208152561924   | 0.7572985586225554 |
| WeirdNoReinstateCMR   | 0.024495928662137987   | 0.9479267918474381      |                        | 0.9995807153411382 |
| OutlistCMRDE          | 0.00018676833140703468 | 0.2427014413774446      | 0.00041928465886174405 |                    |


Unnamed: 0,WeirdCMR,WeirdPositionScaleCMR,WeirdNoReinstateCMR,OutlistCMRDE
WeirdCMR,,0.968932,0.975504,0.999813
WeirdPositionScaleCMR,0.031068,,0.052073,0.757299
WeirdNoReinstateCMR,0.024496,0.947927,,0.999581
OutlistCMRDE,0.000187,0.242701,0.000419,


In [5]:
# Save the t-value matrix returned alongside df_p by generate_t_p_matrices.
t_matrix_path = os.path.join(
    target_directory, "tables", f"{data_name}_{fit_tag}_{run_tag}_t_matrix.md"
)
with open(t_matrix_path, "w") as out_file:
    out_file.write(df_t.to_markdown())

df_t

Unnamed: 0,WeirdCMR,WeirdPositionScaleCMR,WeirdNoReinstateCMR,OutlistCMRDE
WeirdCMR,,1.93292,2.046538,3.97706
WeirdPositionScaleCMR,-1.93292,,-1.672653,0.705827
WeirdNoReinstateCMR,-2.046538,1.672653,,3.68602
OutlistCMRDE,-3.97706,-0.705827,-3.68602,


In [6]:
# Akaike weights across the compared models, as computed by calculate_aic_weights.
aic_weights = calculate_aic_weights(results)
aic_weights_md = aic_weights.to_markdown()  # render once; reused below

aic_weights_path = os.path.join(
    target_directory, "tables", f"{data_name}_{fit_tag}_{run_tag}_aic_weights.md"
)
with open(aic_weights_path, "w") as out_file:
    out_file.write(aic_weights_md)

print(aic_weights_md)
aic_weights

|    | Model                 |         AICw |
|---:|:----------------------|-------------:|
|  3 | OutlistCMRDE          | 1            |
|  1 | WeirdPositionScaleCMR | 2.28709e-31  |
|  2 | WeirdNoReinstateCMR   | 5.44146e-116 |
|  0 | WeirdCMR              | 1.40562e-127 |


Unnamed: 0,Model,AICw
3,OutlistCMRDE,1.0
1,WeirdPositionScaleCMR,2.287093e-31
2,WeirdNoReinstateCMR,5.441463e-116
0,WeirdCMR,1.4056179999999999e-127


In [7]:
# Pairwise "winner" ratios between models, as computed by winner_comparison_matrix.
df_comparison = winner_comparison_matrix(results)

# Blank out NaN cells (e.g. the diagonal) without disturbing column alignment:
# " nan " and its replacement are both exactly five characters wide.
winner_ratios_md = df_comparison.to_markdown().replace(" nan ", "     ")

winner_ratios_path = os.path.join(
    target_directory, "tables", f"{data_name}_{fit_tag}_{run_tag}_winner_ratios.md"
)
with open(winner_ratios_path, "w") as out_file:
    out_file.write(winner_ratios_md)

print(winner_ratios_md)

|                       |   WeirdCMR |   WeirdPositionScaleCMR |   WeirdNoReinstateCMR |   OutlistCMRDE |
|:----------------------|-----------:|------------------------:|----------------------:|---------------:|
| WeirdCMR              |            |                0.272727 |              0.454545 |       0.151515 |
| WeirdPositionScaleCMR |   0.727273 |                         |              0.727273 |       0.575758 |
| WeirdNoReinstateCMR   |   0.545455 |                0.272727 |                       |       0.181818 |
| OutlistCMRDE          |   0.848485 |                0.424242 |              0.818182 |                |
