In [None]:
# ---------------------------------------------------------------------------
# Run configuration
#   Fit results are read from  os.path.join(fit_dir, f"{data_name}_{model}_{fit_tag}.json")
#   Tables are written to      {target_directory}/tables/{data_name}_{fit_tag}_{run_tag}_*.md
# ---------------------------------------------------------------------------
run_tag = "Model_Comparison"
fit_tag = "full_best_of_3"
fit_dir = "fits/"
target_directory = "projects/KahanaJacobs2000"

# Dataset whose fits are being compared.
data_name = "KahanaJacobs2000"

# ---------------------------------------------------------------------------
# Models in the comparison (order = column order of every output table).
# Uncomment an entry to include it.
# ---------------------------------------------------------------------------
model_names = [
    # "BaseCMR",
    "WeirdCMR",
    "WeirdPositionScaleCMR",
    # "NoReinstateCMR",
    "WeirdNoReinstateCMR",
    "OutlistCMRDE",
    # "WeirdNoPrexpPositionCMR",
    # "FlexPositionScaleCMR",
    # "NoPrexpPositionCMR",
    # "NoScaleNoReinstateCMR",
    # "NoScalePositionScaleBaseCMR",
    # "InstanceCMR",
]

# Earlier comparison line-ups, kept for reference:
#   ["NarrowWeirdInstanceCMRDE", "WeirdInstanceCMRDE", "ConnectionistCMR",
#    "OutlistInstanceCMRDE", "NarrowReinstateOutlistInstanceCMRDE",
#    "ReinstateOutlistCMRDE", "OutlistCMRDE", "ContextCMRDE",
#    "TrueInstanceCMRDE", "ReinstateContextCMRDE", "FlexCMR2"]
#   ["WeirdInstanceCMRDE", "OutlistInstanceCMRDE", "FlexCMRDE",
#    "WeirdPositionalCMR", "AdditiveItemPositionalCMR",
#    "MultiplicativeItemPositionalCMR", "PreexpMfcItemPositionalCMR",
#    "TwoAlphaItemPositionalCMR", "WeirdFlexPositionalCMR"]
# Other variants:
#   "FakeNarrowOutlistInstanceCMRDE", "FakeOutlistInstanceCMRDE",
#   "MultiContextCMRDE", "NormalContextCMRDE", "NormalMultiContextCMRDE"

# Display titles, aligned one-to-one with model_names.
# Leave empty to reuse model_names as the titles.
model_titles = []

# ---------------------------------------------------------------------------
# Parameters reported in the parameter-summary table.
# ---------------------------------------------------------------------------
query_parameters = [
    "encoding_drift_rate",
    "start_drift_rate",
    "recall_drift_rate",
    "shared_support",
    # "item_shared_support",
    # "position_shared_support",
    "item_support",
    "learning_rate",
    "primacy_scale",
    "primacy_decay",
    "stop_probability_scale",
    "stop_probability_growth",
    "choice_sensitivity",
    # "mfc_trace_sensitivity",
    # "mcf_trace_sensitivity",
    "mfc_choice_sensitivity",
    # "positional_scale",
    # "positional_mfc_scale",
    # "semantic_scale",
    # "semantic_choice_sensitivity",
]

In [2]:
import os
import json
from jaxcmr.summarize import (
    summarize_parameters,
    generate_t_p_matrices,
    winner_comparison_matrix,
    calculate_aic_weights,
)

In [3]:
# Fall back to the model identifiers as display titles when none were configured.
if not model_titles:
    model_titles = model_names.copy()

# zip() below would silently drop unmatched entries, hiding a model from every table.
if len(model_titles) != len(model_names):
    raise ValueError(
        f"model_titles has {len(model_titles)} entries but model_names has "
        f"{len(model_names)}; they must align one-to-one."
    )

# Load one fit-result JSON per model.
results = []
for model_name, model_title in zip(model_names, model_titles):
    fit_path = os.path.join(fit_dir, f"{data_name}_{model_name}_{fit_tag}.json")
    with open(fit_path) as f:
        fit_result = json.load(f)

    # Make subject ids available under "fits". When they are absent there, copy the
    # top-level "subject" entry (presumably older fit files store it only at top level — confirm).
    if "subject" not in fit_result["fits"]:
        fit_result["fits"]["subject"] = fit_result["subject"]
    # Label used for this model in the tables produced below.
    fit_result["name"] = model_title
    results.append(fit_result)

# Summarize the queried parameters per model (mean +/- CI and std), save as markdown, and show.
summary = summarize_parameters(
    results, query_parameters, include_std=True, include_ci=True
)

tables_dir = os.path.join(target_directory, "tables")
os.makedirs(tables_dir, exist_ok=True)  # open(..., "w") fails if the folder is missing
with open(
    os.path.join(tables_dir, f"{data_name}_{fit_tag}_{run_tag}_parameters.md"),
    "w",
) as f:
    f.write(summary)
print(summary)


| | | WeirdCMR | WeirdPositionScaleCMR | WeirdNoReinstateCMR | OutlistCMRDE |
|---|---|---|---|---|---|
| fitness | mean | 3850.28 +/- 530.27 | 3760.09 +/- 516.57 | 3850.20 +/- 530.33 | 3843.70 +/- 530.39 |
| | std | 1070.84 | 1043.18 | 1070.96 | 1071.09 |
| encoding drift rate | mean | 0.88 +/- 0.02 | 0.90 +/- 0.03 | 0.88 +/- 0.02 | 0.88 +/- 0.02 |
| | std | 0.04 | 0.05 | 0.04 | 0.04 |
| start drift rate | mean | 0.54 +/- 0.11 | 0.55 +/- 0.08 | 0.55 +/- 0.09 | 0.54 +/- 0.10 |
| | std | 0.21 | 0.17 | 0.18 | 0.21 |
| recall drift rate | mean | 0.81 +/- 0.03 | 0.81 +/- 0.04 | 0.81 +/- 0.03 | 0.80 +/- 0.04 |
| | std | 0.07 | 0.08 | 0.07 | 0.07 |
| shared support | mean | 42.35 +/- 19.56 | 41.47 +/- 19.36 | 35.24 +/- 18.91 | 31.79 +/- 17.25 |
| | std | 39.50 | 39.09 | 38.18 | 34.83 |
| item support | mean | 38.26 +/- 17.73 | 36.34 +/- 16.87 | 31.70 +/- 16.84 | 29.26 +/- 15.98 |
| | std | 35.81 | 34.06 | 34.02 | 32.27 |
| learning rate | mean | 0.03 +/- 0.02 | 0.13 +/- 0.04 | 0.06 +/- 0.03 

In [4]:
# Pairwise model comparison: df_t holds t statistics and df_p the p-values (row model vs column model).
# In the output, t is antisymmetric and p[i, j] + p[j, i] == 1, which suggests a one-sided test
# (NOTE(review): confirm in jaxcmr.summarize.generate_t_p_matrices).
df_t, df_p = generate_t_p_matrices(results)

p_matrix_markdown = df_p.to_markdown()
p_matrix_path = os.path.join(
    target_directory, "tables", f"{data_name}_{fit_tag}_{run_tag}_p_matrix.md"
)
with open(p_matrix_path, "w") as f:
    f.write(p_matrix_markdown)

print(p_matrix_markdown)
df_p

|                       | WeirdCMR              | WeirdPositionScaleCMR   | WeirdNoReinstateCMR    | OutlistCMRDE           |
|:----------------------|:----------------------|:------------------------|:-----------------------|:-----------------------|
| WeirdCMR              |                       | 0.9999999953126066      | 0.5194504154082416     | 0.9999931411531482     |
| WeirdPositionScaleCMR | 4.687393328416847e-09 |                         | 6.2027728317819655e-09 | 1.6182696418657597e-08 |
| WeirdNoReinstateCMR   | 0.48054958459175845   | 0.9999999937972271      |                        | 0.9999783185321464     |
| OutlistCMRDE          | 6.858846851830407e-06 | 0.9999999838173036      | 2.1681467853704367e-05 |                        |


Unnamed: 0,WeirdCMR,WeirdPositionScaleCMR,WeirdNoReinstateCMR,OutlistCMRDE
WeirdCMR,,1.0,0.51945,0.999993
WeirdPositionScaleCMR,0.0,,0.0,0.0
WeirdNoReinstateCMR,0.48055,1.0,,0.999978
OutlistCMRDE,7e-06,1.0,2.2e-05,


In [5]:
# Save the t-statistic matrix computed alongside df_p, then display it.
t_matrix_path = os.path.join(
    target_directory, "tables", f"{data_name}_{fit_tag}_{run_tag}_t_matrix.md"
)
with open(t_matrix_path, "w") as f:
    f.write(df_t.to_markdown())

df_t

Unnamed: 0,WeirdCMR,WeirdPositionScaleCMR,WeirdNoReinstateCMR,OutlistCMRDE
WeirdCMR,,9.968519,0.049458,5.90432
WeirdPositionScaleCMR,-9.968519,,-9.788075,-9.187512
WeirdNoReinstateCMR,-0.049458,9.788075,,5.354752
OutlistCMRDE,-5.90432,9.187512,-5.354752,


In [6]:
# AIC weights across the compared models; saved as markdown, then printed and displayed.
aic_weights = calculate_aic_weights(results)
aic_weights_markdown = aic_weights.to_markdown()

aic_weights_path = os.path.join(
    target_directory, "tables", f"{data_name}_{fit_tag}_{run_tag}_aic_weights.md"
)
with open(aic_weights_path, "w") as f:
    f.write(aic_weights_markdown)

print(aic_weights_markdown)
aic_weights

|    | Model                 |   AICw |
|---:|:----------------------|-------:|
|  1 | WeirdPositionScaleCMR |      1 |
|  0 | WeirdCMR              |      0 |
|  2 | WeirdNoReinstateCMR   |      0 |
|  3 | OutlistCMRDE          |      0 |


Unnamed: 0,Model,AICw
1,WeirdPositionScaleCMR,1.0
0,WeirdCMR,0.0
2,WeirdNoReinstateCMR,0.0
3,OutlistCMRDE,0.0


In [7]:
# Pairwise "winner ratio" matrix (presumably the share of subjects where the row model fits
# better than the column model — confirm in jaxcmr.summarize). The " nan " cells (the empty
# diagonal) are blanked in the markdown while keeping column widths intact.
df_comparison = winner_comparison_matrix(results)
winner_markdown = df_comparison.to_markdown().replace(" nan ", "     ")

winner_path = os.path.join(
    target_directory, "tables", f"{data_name}_{fit_tag}_{run_tag}_winner_ratios.md"
)
with open(winner_path, "w") as f:
    f.write(winner_markdown)

print(winner_markdown)

|                       |   WeirdCMR |   WeirdPositionScaleCMR |   WeirdNoReinstateCMR |   OutlistCMRDE |
|:----------------------|-----------:|------------------------:|----------------------:|---------------:|
| WeirdCMR              |            |                       0 |              0.315789 |       0.105263 |
| WeirdPositionScaleCMR |   1        |                         |              1        |       1        |
| WeirdNoReinstateCMR   |   0.684211 |                       0 |                       |       0.105263 |
| OutlistCMRDE          |   0.894737 |                       0 |              0.894737 |                |
