In [None]:
# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------
# fit_tag selects which fit files are loaded from fit_dir
# ("{data_name}_{model_name}_{fit_tag}.json"); fit_tag and run_tag together
# label every table written to "<target_directory>/tables/".
run_tag = "Model_Comparison"
fit_tag = "full_best_of_3"
fit_dir = "fits/"
target_directory = "projects/Logan2021"

# Dataset identifier; prefixes both the fit filenames and the output table filenames.
data_name = "GordonRanschburg2021"

# Models to compare, in the order their results are assembled.
# Each entry needs a matching fit file in fit_dir.
model_names = [
    "WeirdCMR", "WeirdPositionScaleCMR", "WeirdNoReinstateCMR", "OutlistCMRDE",
]

# Currently disabled candidates (add back to model_names to include them):
#   "BaseCMR", "NoReinstateCMR", "WeirdNoPrexpPositionCMR",
#   "FlexPositionScaleCMR", "NoPrexpPositionCMR", "NoScaleNoReinstateCMR",
#   "NoScalePositionScaleBaseCMR", "InstanceCMR",
#   "FakeNarrowOutlistInstanceCMRDE", "FakeOutlistInstanceCMRDE",
#   "MultiContextCMRDE", "NormalContextCMRDE", "NormalMultiContextCMRDE",
#
# Earlier comparison sets, kept for reference:
# model_names = ["NarrowWeirdInstanceCMRDE", "WeirdInstanceCMRDE", "ConnectionistCMR", "OutlistInstanceCMRDE", "NarrowReinstateOutlistInstanceCMRDE", "ReinstateOutlistCMRDE", "OutlistCMRDE", "ContextCMRDE", "TrueInstanceCMRDE", "ReinstateContextCMRDE", "FlexCMR2"]
# model_names = ["WeirdInstanceCMRDE", "OutlistInstanceCMRDE", "FlexCMRDE", "WeirdPositionalCMR", "AdditiveItemPositionalCMR", "MultiplicativeItemPositionalCMR", "PreexpMfcItemPositionalCMR", "TwoAlphaItemPositionalCMR", "WeirdFlexPositionalCMR"]

# Optional display titles, paired by position with model_names.
# Leave empty to label each model by its own name.
model_titles = []

# Fitted parameters to report in the parameter summary table.
query_parameters = [
    "encoding_drift_rate", "start_drift_rate", "recall_drift_rate",
    "shared_support", "item_support",
    "learning_rate",
    "primacy_scale", "primacy_decay",
    "stop_probability_scale", "stop_probability_growth",
    "choice_sensitivity", "mfc_choice_sensitivity",
]
# Commented-out options (not reported):
#   "item_shared_support", "position_shared_support",
#   "mfc_trace_sensitivity", "mcf_trace_sensitivity",
#   "positional_scale", "positional_mfc_scale",
#   "semantic_scale", "semantic_choice_sensitivity",

In [2]:
import os
import json
from jaxcmr.summarize import (
    summarize_parameters,
    generate_t_p_matrices,
    winner_comparison_matrix,
    calculate_aic_weights,
)

In [3]:
# Load each model's fit file, label it with its display title, summarize the
# queried parameters, and save the summary as markdown under
# "<target_directory>/tables/".
if not model_titles:
    # No explicit display titles given: label each model by its own name.
    model_titles = model_names.copy()

# zip() would silently drop models (or titles) if the two lists differ in
# length, making a model vanish from the comparison; fail loudly instead.
if len(model_titles) != len(model_names):
    raise ValueError(
        f"model_titles has {len(model_titles)} entries but model_names has "
        f"{len(model_names)}; they must be paired one-to-one"
    )

results = []
for model_name, model_title in zip(model_names, model_titles):
    fit_path = os.path.join(fit_dir, f"{data_name}_{model_name}_{fit_tag}.json")
    with open(fit_path) as fit_file:
        fit = json.load(fit_file)

    # Ensure fit["fits"] carries the subject list; when absent, copy it from
    # the top-level "subject" entry (presumably what the jaxcmr summary
    # helpers read — verify).
    if "subject" not in fit["fits"]:
        fit["fits"]["subject"] = fit["subject"]
    fit["name"] = model_title
    results.append(fit)

summary = summarize_parameters(
    results, query_parameters, include_std=True, include_ci=True
)

tables_dir = os.path.join(target_directory, "tables")
# Create the output folder if needed so a fresh checkout doesn't fail here.
os.makedirs(tables_dir, exist_ok=True)
with open(
    os.path.join(tables_dir, f"{data_name}_{fit_tag}_{run_tag}_parameters.md"),
    "w",
) as f:
    f.write(summary)
print(summary)


| | | WeirdCMR | WeirdPositionScaleCMR | WeirdNoReinstateCMR | OutlistCMRDE |
|---|---|---|---|---|---|
| fitness | mean | 770.97 +/- 59.74 | 726.69 +/- 60.46 | 771.07 +/- 59.61 | 771.46 +/- 61.54 |
| | std | 138.50 | 140.17 | 138.19 | 142.68 |
| encoding drift rate | mean | 0.84 +/- 0.01 | 0.54 +/- 0.06 | 0.84 +/- 0.01 | 0.82 +/- 0.05 |
| | std | 0.02 | 0.14 | 0.03 | 0.11 |
| start drift rate | mean | 0.77 +/- 0.03 | 0.40 +/- 0.12 | 0.76 +/- 0.03 | 0.70 +/- 0.09 |
| | std | 0.07 | 0.28 | 0.08 | 0.21 |
| recall drift rate | mean | 0.75 +/- 0.02 | 0.67 +/- 0.06 | 0.75 +/- 0.02 | 0.67 +/- 0.07 |
| | std | 0.04 | 0.15 | 0.05 | 0.16 |
| shared support | mean | 48.29 +/- 14.85 | 28.13 +/- 13.56 | 58.42 +/- 14.06 | 49.42 +/- 15.93 |
| | std | 34.42 | 31.43 | 32.59 | 36.94 |
| item support | mean | 28.00 +/- 11.69 | 18.39 +/- 9.69 | 32.60 +/- 11.16 | 31.46 +/- 12.46 |
| | std | 27.10 | 22.47 | 25.87 | 28.88 |
| learning rate | mean | 0.01 +/- 0.01 | 0.19 +/- 0.10 | 0.01 +/- 0.01 | 0.09 +/- 0.

In [4]:
# Pairwise t-statistic (df_t) and p-value (df_p) matrices between the fitted
# models, as returned by jaxcmr's generate_t_p_matrices. Save and show the p-values.
df_t, df_p = generate_t_p_matrices(results)

p_matrix_path = os.path.join(
    target_directory, "tables", f"{data_name}_{fit_tag}_{run_tag}_p_matrix.md"
)
with open(p_matrix_path, "w") as out_file:
    # Render the markdown once and reuse it for both the file and the printout.
    p_matrix_md = df_p.to_markdown()
    out_file.write(p_matrix_md)

print(p_matrix_md)
df_p

|                       | WeirdCMR               | WeirdPositionScaleCMR   | WeirdNoReinstateCMR   | OutlistCMRDE          |
|:----------------------|:-----------------------|:------------------------|:----------------------|:----------------------|
| WeirdCMR              |                        | 0.9999999999998908      | 0.37185584923926396   | 0.4095753524411694    |
| WeirdPositionScaleCMR | 1.0928408201197199e-13 |                         | 6.299530749345191e-14 | 5.624222367663339e-12 |
| WeirdNoReinstateCMR   | 0.628144150760736      | 0.9999999999999369      |                       | 0.42945442558291547   |
| OutlistCMRDE          | 0.5904246475588306     | 0.9999999999943757      | 0.5705455744170845    |                       |


Unnamed: 0,WeirdCMR,WeirdPositionScaleCMR,WeirdNoReinstateCMR,OutlistCMRDE
WeirdCMR,,1.0,0.371856,0.409575
WeirdPositionScaleCMR,0.0,,0.0,0.0
WeirdNoReinstateCMR,0.628144,1.0,,0.429454
OutlistCMRDE,0.590425,1.0,0.570546,


In [5]:
# Save the pairwise t-statistic matrix produced by generate_t_p_matrices as markdown.
t_matrix_path = os.path.join(
    target_directory, "tables", f"{data_name}_{fit_tag}_{run_tag}_t_matrix.md"
)
with open(t_matrix_path, "w") as out_file:
    out_file.write(df_t.to_markdown())

df_t

Unnamed: 0,WeirdCMR,WeirdPositionScaleCMR,WeirdNoReinstateCMR,OutlistCMRDE
WeirdCMR,,15.032704,-0.330905,-0.231269
WeirdPositionScaleCMR,-15.032704,,-15.432126,-12.411247
WeirdNoReinstateCMR,0.330905,15.432126,,-0.179768
OutlistCMRDE,0.231269,12.411247,0.179768,


In [6]:
# AIC weights for each fitted model, as computed by jaxcmr's calculate_aic_weights;
# saved as markdown and displayed.
aic_weights = calculate_aic_weights(results)

aic_weights_path = os.path.join(
    target_directory, "tables", f"{data_name}_{fit_tag}_{run_tag}_aic_weights.md"
)
with open(aic_weights_path, "w") as out_file:
    # Render the markdown once and reuse it for both the file and the printout.
    aic_weights_md = aic_weights.to_markdown()
    out_file.write(aic_weights_md)

print(aic_weights_md)
aic_weights

|    | Model                 |   AICw |
|---:|:----------------------|-------:|
|  1 | WeirdPositionScaleCMR |      1 |
|  0 | WeirdCMR              |      0 |
|  2 | WeirdNoReinstateCMR   |      0 |
|  3 | OutlistCMRDE          |      0 |


Unnamed: 0,Model,AICw
1,WeirdPositionScaleCMR,1.0
0,WeirdCMR,0.0
2,WeirdNoReinstateCMR,0.0
3,OutlistCMRDE,0.0


In [7]:
# Pairwise model "winner" ratios from jaxcmr's winner_comparison_matrix;
# saved as markdown and printed.
df_comparison = winner_comparison_matrix(results)

winner_ratios_path = os.path.join(
    target_directory, "tables", f"{data_name}_{fit_tag}_{run_tag}_winner_ratios.md"
)
with open(winner_ratios_path, "w") as out_file:
    # Blank out " nan " cells (presumably the model-vs-itself diagonal) so the
    # rendered table is clean; computed once, reused for file and printout.
    winner_ratios_md = df_comparison.to_markdown().replace(" nan ", "     ")
    out_file.write(winner_ratios_md)

print(winner_ratios_md)

|                       |   WeirdCMR |   WeirdPositionScaleCMR |   WeirdNoReinstateCMR |   OutlistCMRDE |
|:----------------------|-----------:|------------------------:|----------------------:|---------------:|
| WeirdCMR              |            |                       0 |              0.583333 |       0.416667 |
| WeirdPositionScaleCMR |   1        |                         |              1        |       1        |
| WeirdNoReinstateCMR   |   0.416667 |                       0 |                       |       0.416667 |
| OutlistCMRDE          |   0.583333 |                       0 |              0.583333 |                |
