In [1]:
# params for this run
# Every table written below lands in <target_directory>/tables/ and its
# filename is tagged with data_name, fit_tag and run_tag.
run_tag = "Model_Comparison"
fit_tag = "full_best_of_3"
fit_dir = "fits/"  # fits are read from <fit_dir>/<data_name>_<model_name>_<fit_tag>.json
target_directory = ""  # "" -> tables are written relative to the current working directory

# data params
data_name = "LohnasKahana2014"

# models
# (earlier candidate sets, kept for reference)
# model_names = ["NarrowWeirdInstanceCMRDE", "WeirdInstanceCMRDE", "ConnectionistCMR", "OutlistInstanceCMRDE", "NarrowReinstateOutlistInstanceCMRDE", "ReinstateOutlistCMRDE", "OutlistCMRDE", "ContextCMRDE", "TrueInstanceCMRDE", "ReinstateContextCMRDE", "FlexCMR2"]
# model_names = ["WeirdInstanceCMRDE", "OutlistInstanceCMRDE", "FlexCMRDE", "WeirdPositionalCMR", "AdditiveItemPositionalCMR", "MultiplicativeItemPositionalCMR", "PreexpMfcItemPositionalCMR", "TwoAlphaItemPositionalCMR", "WeirdFlexPositionalCMR"]
# Toggle models by (un)commenting entries. Every active entry -- including the
# last one -- ends with a comma; otherwise uncommenting the next line would
# silently concatenate two string literals into a single bogus model name.
model_names = [
    # "BaseCMR",
    "WeirdCMR",
    "WeirdPositionScaleCMR",
    # "NoReinstateCMR",
    "WeirdNoReinstateCMR",
    "OutlistCMRDE",
    "WeirdNoPrexpPositionCMR",
    "FlexPositionScaleCMR",
    "NoPrexpPositionCMR",
    "WeirdPrimacyPositionScaleCMR",
    # "NoScaleNoReinstateCMR",
    # "NoScalePositionScaleBaseCMR",
    # "InstanceCMR",
]

# "FakeNarrowOutlistInstanceCMRDE", "FakeOutlistInstanceCMRDE",
# "MultiContextCMRDE", "NormalContextCMRDE", "NormalMultiContextCMRDE",

# Optional display names, one per entry of model_names (same order);
# leave empty to label each model by its name.
model_titles = []

# params to focus on in outputs (passed to summarize_parameters)
query_parameters = [
    "encoding_drift_rate",
    "start_drift_rate",
    "recall_drift_rate",
    "shared_support",
    # "item_shared_support",
    # "position_shared_support",
    "item_support",
    "learning_rate",
    "primacy_scale",
    "primacy_decay",
    "stop_probability_scale",
    "stop_probability_growth",
    "choice_sensitivity",
    # "mfc_trace_sensitivity",
    # "mcf_trace_sensitivity",
    "mfc_choice_sensitivity",
    # "positional_scale",
    # "positional_mfc_scale",
    # "semantic_scale",
    # "semantic_choice_sensitivity",
]

In [2]:
import os
import json
from jaxcmr.summarize import (
    summarize_parameters,
    generate_t_p_matrices,
    winner_comparison_matrix,
    calculate_aic_weights,
)

In [3]:
# Fall back to the raw model names when no display titles were configured.
if not model_titles:
    model_titles = model_names.copy()

# zip() truncates to the shorter list, which would silently drop models from
# every comparison below if the titles list has the wrong length.
if len(model_titles) != len(model_names):
    raise ValueError(
        f"model_titles has {len(model_titles)} entries but model_names has "
        f"{len(model_names)}; give one title per model or leave model_titles empty."
    )

# Load one fit file per model, tagging each with its display title.
results = []
for model_name, model_title in zip(model_names, model_titles):
    fit_path = os.path.join(fit_dir, f"{data_name}_{model_name}_{fit_tag}.json")
    with open(fit_path) as f:
        fit_result = json.load(f)

    # Some fit files keep "subject" only at the top level; mirror it into
    # "fits" (presumably where the jaxcmr summary helpers look -- TODO confirm).
    if "subject" not in fit_result["fits"]:
        fit_result["fits"]["subject"] = fit_result["subject"]
    fit_result["name"] = model_title
    results.append(fit_result)

summary = summarize_parameters(
    results, query_parameters, include_std=True, include_ci=True
)

# This is the first cell that writes a table; open(..., "w") fails if the
# tables/ folder does not exist yet, so create it up front.
tables_dir = os.path.join(target_directory, "tables")
os.makedirs(tables_dir, exist_ok=True)

with open(
    os.path.join(tables_dir, f"{data_name}_{fit_tag}_{run_tag}_parameters.md"),
    "w",
) as f:
    f.write(summary)
print(summary)


| | | WeirdCMR | WeirdPositionScaleCMR | WeirdNoReinstateCMR | OutlistCMRDE | WeirdNoPrexpPositionCMR | FlexPositionScaleCMR | NoPrexpPositionCMR | WeirdPrimacyPositionScaleCMR |
|---|---|---|---|---|---|---|---|---|---|
| fitness | mean | 1667.75 +/- 146.53 | 1659.83 +/- 146.81 | 1667.86 +/- 146.86 | 1660.34 +/- 146.28 | 1659.83 +/- 146.78 | 1660.92 +/- 147.25 | 1663.70 +/- 147.32 | 1660.20 +/- 147.09 |
| | std | 420.43 | 421.22 | 421.38 | 419.71 | 421.13 | 422.50 | 422.70 | 422.03 |
| encoding drift rate | mean | 0.77 +/- 0.04 | 0.71 +/- 0.04 | 0.77 +/- 0.04 | 0.75 +/- 0.04 | 0.72 +/- 0.04 | 0.70 +/- 0.04 | 0.67 +/- 0.05 | 0.72 +/- 0.04 |
| | std | 0.12 | 0.12 | 0.10 | 0.12 | 0.12 | 0.11 | 0.13 | 0.13 |
| start drift rate | mean | 0.51 +/- 0.12 | 0.45 +/- 0.10 | 0.37 +/- 0.12 | 0.41 +/- 0.10 | 0.37 +/- 0.10 | 0.44 +/- 0.10 | 0.49 +/- 0.10 | 0.46 +/- 0.10 |
| | std | 0.35 | 0.29 | 0.34 | 0.29 | 0.30 | 0.29 | 0.30 | 0.28 |
| recall drift rate | mean | 0.93 +/- 0.02 | 0.91 +/- 0.02 | 0.

In [4]:
# Pairwise model-comparison matrices. df_t is saved/shown in the next cell;
# df_p (presumably one-sided p-values: mirrored entries sum to 1 in the
# captured output) is saved and shown here.
df_t, df_p = generate_t_p_matrices(results)

p_matrix_md = df_p.to_markdown()
p_matrix_path = os.path.join(
    target_directory, "tables", f"{data_name}_{fit_tag}_{run_tag}_p_matrix.md"
)
with open(p_matrix_path, "w") as p_matrix_file:
    p_matrix_file.write(p_matrix_md)

print(p_matrix_md)
df_p

|                              | WeirdCMR               | WeirdPositionScaleCMR   | WeirdNoReinstateCMR    | OutlistCMRDE        | WeirdNoPrexpPositionCMR   | FlexPositionScaleCMR   | NoPrexpPositionCMR    | WeirdPrimacyPositionScaleCMR   |
|:-----------------------------|:-----------------------|:------------------------|:-----------------------|:--------------------|:--------------------------|:-----------------------|:----------------------|:-------------------------------|
| WeirdCMR                     |                        | 0.9999961851599968      | 0.4607493193733744     | 0.9999993900826833  | 0.9999988540109558        | 0.9997562412149947     | 0.9980435882732807    | 0.9999969592895686             |
| WeirdPositionScaleCMR        | 3.814840003284842e-06  |                         | 1.5893901076511372e-07 | 0.3607711231613683  | 0.49710342997778967       | 0.10391516401755524    | 0.0025709868468370236 | 0.2863453573401475             |
| WeirdNoReinstateCMR          | 0.5

Unnamed: 0,WeirdCMR,WeirdPositionScaleCMR,WeirdNoReinstateCMR,OutlistCMRDE,WeirdNoPrexpPositionCMR,FlexPositionScaleCMR,NoPrexpPositionCMR,WeirdPrimacyPositionScaleCMR
WeirdCMR,,0.999996,0.460749,0.999999,0.999999,0.999756,0.998044,0.999997
WeirdPositionScaleCMR,4e-06,,0.0,0.360771,0.497103,0.103915,0.002571,0.286345
WeirdNoReinstateCMR,0.539251,1.0,,0.999998,1.0,0.999982,0.999207,1.0
OutlistCMRDE,1e-06,0.639229,2e-06,,0.645834,0.367588,0.010432,0.541711
WeirdNoPrexpPositionCMR,1e-06,0.502897,0.0,0.354166,,0.215913,8e-06,0.36474
FlexPositionScaleCMR,0.000244,0.896085,1.8e-05,0.632412,0.784087,,0.038199,0.778152
NoPrexpPositionCMR,0.001956,0.997429,0.000793,0.989568,0.999992,0.961801,,0.9961
WeirdPrimacyPositionScaleCMR,3e-06,0.713655,0.0,0.458289,0.63526,0.221848,0.0039,


In [5]:
# Save the t-statistic matrix computed in the previous cell, then display it.
t_matrix_path = os.path.join(
    target_directory, "tables", f"{data_name}_{fit_tag}_{run_tag}_t_matrix.md"
)
with open(t_matrix_path, "w") as t_matrix_file:
    t_matrix_file.write(df_t.to_markdown())

df_t

Unnamed: 0,WeirdCMR,WeirdPositionScaleCMR,WeirdNoReinstateCMR,OutlistCMRDE,WeirdNoPrexpPositionCMR,FlexPositionScaleCMR,NoPrexpPositionCMR,WeirdPrimacyPositionScaleCMR
WeirdCMR,,5.271482,-0.09928,5.883552,5.672886,3.856557,3.096077,5.347219
WeirdPositionScaleCMR,-5.271482,,-6.334471,-0.359368,-0.007314,-1.283979,-2.991125,-0.569604
WeirdNoReinstateCMR,0.09928,6.334471,,5.525735,7.061236,4.750718,3.433346,6.453527
OutlistCMRDE,-5.883552,0.359368,-5.525735,,0.377251,-0.341036,-2.423089,0.105527
WeirdNoPrexpPositionCMR,-5.672886,0.007314,-7.061236,-0.377251,,-0.795528,-5.031358,-0.348679
FlexPositionScaleCMR,-3.856557,1.283979,-4.750718,0.341036,0.795528,,-1.827594,0.775002
NoPrexpPositionCMR,-3.096077,2.991125,-3.433346,2.423089,5.031358,1.827594,,2.827788
WeirdPrimacyPositionScaleCMR,-5.347219,0.569604,-6.453527,-0.105527,0.348679,-0.775002,-2.827788,


In [6]:
# AIC weights per model; render the markdown once, save it, echo it, then
# show the rich table.
aic_weights = calculate_aic_weights(results)
aic_weights_md = aic_weights.to_markdown()

aic_weights_path = os.path.join(
    target_directory, "tables", f"{data_name}_{fit_tag}_{run_tag}_aic_weights.md"
)
with open(aic_weights_path, "w") as aic_weights_file:
    aic_weights_file.write(aic_weights_md)

print(aic_weights_md)
aic_weights

|    | Model                        |         AICw |
|---:|:-----------------------------|-------------:|
|  1 | WeirdPositionScaleCMR        | 0.571049     |
|  4 | WeirdNoPrexpPositionCMR      | 0.428951     |
|  7 | WeirdPrimacyPositionScaleCMR | 3.55099e-07  |
|  3 | OutlistCMRDE                 | 3.24986e-09  |
|  5 | FlexPositionScaleCMR         | 5.27696e-18  |
|  6 | NoPrexpPositionCMR           | 8.5139e-60   |
|  0 | WeirdCMR                     | 1.90224e-121 |
|  2 | WeirdNoReinstateCMR          | 3.98598e-123 |


Unnamed: 0,Model,AICw
1,WeirdPositionScaleCMR,0.5710489
4,WeirdNoPrexpPositionCMR,0.4289507
7,WeirdPrimacyPositionScaleCMR,3.550992e-07
3,OutlistCMRDE,3.249855e-09
5,FlexPositionScaleCMR,5.276964e-18
6,NoPrexpPositionCMR,8.513900000000001e-60
0,WeirdCMR,1.902238e-121
2,WeirdNoReinstateCMR,3.985982e-123


In [7]:
# Head-to-head "winner" ratios between every pair of models.
df_comparison = winner_comparison_matrix(results)

# Missing cells (the blank diagonal in the captured output) render as "nan";
# swap each " nan " for five spaces so the markdown columns keep their width.
winner_ratios_md = df_comparison.to_markdown().replace(" nan ", "     ")

winner_ratios_path = os.path.join(
    target_directory, "tables", f"{data_name}_{fit_tag}_{run_tag}_winner_ratios.md"
)
with open(winner_ratios_path, "w") as winner_ratios_file:
    winner_ratios_file.write(winner_ratios_md)

print(winner_ratios_md)

|                              |   WeirdCMR |   WeirdPositionScaleCMR |   WeirdNoReinstateCMR |   OutlistCMRDE |   WeirdNoPrexpPositionCMR |   FlexPositionScaleCMR |   NoPrexpPositionCMR |   WeirdPrimacyPositionScaleCMR |
|:-----------------------------|-----------:|------------------------:|----------------------:|---------------:|--------------------------:|-----------------------:|---------------------:|-------------------------------:|
| WeirdCMR                     |            |               0.142857  |              0.571429 |       0.114286 |                  0.142857 |               0.257143 |             0.257143 |                      0.142857  |
| WeirdPositionScaleCMR        |   0.857143 |                         |              0.914286 |       0.542857 |                  0.542857 |               0.6      |             0.657143 |                      0.514286  |
| WeirdNoReinstateCMR          |   0.428571 |               0.0857143 |                       |       0.2      |