In [1]:
# Pick up edits to imported project modules (e.g. jaxcmr.summarize) without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import json
from jaxcmr.summarize import (
    calculate_aic_weights,
    generate_t_p_matrices,
    summarize_parameters,
    winner_comparison_matrix,
)

In [3]:
# Fit files are read as `{fit_dir}/{data_name}_{model_name}_{fit_tag}.json`;
# summary tables are written under `{target_directory}/tables/`.
fit_tag = "full_best_of_3"
fit_dir = "projects/thesis/fits/"
target_directory = "projects/thesis/"

data_names = [
    "LohnasKahana2014",
    "BroitmanKahana2024",
    # "GordonRanschburg2021",
    # "KahanaJacobs2000",
]

# Model name -> display title used as the column header in every table.
# Each pair lives on one line, so enabling/disabling a model cannot shift the
# titles of the models after it (which parallel name/title lists allow).
# NOTE(review): titles for the commented-out models were matched by name;
# confirm before enabling them.
model_title_by_name = {
    "WeirdCMR": "Standard CMR",
    "WeirdStudyReinfPositionalCMR": "Reinforced Positional CMR",
    "WeirdReinfPositionalCMR": "Reinforced Positional CMR (No Study)",
    "FullWeirdPositionalCMR": "Full Positional CMR",
    "WeirdPositionalCMR": "Positional CMR",
    # "WeirdNoReinstateCMR": "No-Reinstate CMR",
    # "WeirdCMRDistinctContexts": "Distinct Contexts CMR",
    # "WeirdPositionScaleCMR": "WeirdPositionScaleCMR",  # TODO: no display title chosen yet
    # "WeirdAmaxPositionScaleCMR": "Amax Positional CMR",
    # "OutlistCMRDE": "CMR-DE",
}

# Downstream cells consume names and titles as two index-aligned lists.
model_names = list(model_title_by_name)
model_titles = list(model_title_by_name.values())

# Parameters summarized (rows of the parameter table), in display order.
query_parameters = [
    "encoding_drift_rate",
    "start_drift_rate",
    "recall_drift_rate",
    "shared_support",
    "item_support",
    "learning_rate",
    "primacy_scale",
    "primacy_decay",
    "stop_probability_scale",
    "stop_probability_growth",
    "choice_sensitivity",
    # The loading cell renames `mfc_trace_sensitivity` to this name, so query it here.
    # "repetition_orthogonality",
]

In [4]:
run_tag = "Model_Comparison"

if not model_titles:
    model_titles = model_names.copy()

# zip() below would silently drop trailing models if the lists drifted apart,
# leaving every table mislabelled or incomplete without any error.
if len(model_names) != len(model_titles):
    raise ValueError(
        f"model_names ({len(model_names)}) and model_titles ({len(model_titles)}) "
        "must have the same length"
    )

tables_dir = os.path.join(target_directory, "tables")
# Create the output directory up front so the first write does not fail on a fresh checkout.
os.makedirs(tables_dir, exist_ok=True)


def _load_fit(data_name: str, model_name: str, model_title: str) -> dict:
    """Load one saved fit and normalize it to the schema the summary helpers read.

    Reads ``{fit_dir}/{data_name}_{model_name}_{fit_tag}.json`` and, on the loaded dict:

    - copies the top-level ``"subject"`` into ``"fits"`` when it is missing there;
    - when ``"fixed"`` lacks ``allow_repeated_recalls``, sets it to ``False`` and adds a
      matching per-subject ``[False, ...]`` list to ``"fits"``;
    - sets ``"name"`` to ``model_title`` (used as the table column header);
    - renames ``mfc_trace_sensitivity`` to ``repetition_orthogonality`` in both
      ``"free"`` and ``"fits"``.

    Returns:
        The normalized fit dict.

    Raises:
        KeyError: if the file lacks the ``"fits"``/``"fixed"``/``"free"`` sections, or
            has ``mfc_trace_sensitivity`` in ``"free"`` but not in ``"fits"``.
    """
    fit_path = os.path.join(fit_dir, f"{data_name}_{model_name}_{fit_tag}.json")
    with open(fit_path, encoding="utf-8") as f:
        fit = json.load(f)

    fits = fit["fits"]
    if "subject" not in fits:
        fits["subject"] = fit["subject"]
    if "allow_repeated_recalls" not in fit["fixed"]:
        fit["fixed"]["allow_repeated_recalls"] = False
        fits["allow_repeated_recalls"] = [False] * len(fits["subject"])
    fit["name"] = model_title
    if "mfc_trace_sensitivity" in fit["free"]:
        fit["free"]["repetition_orthogonality"] = fit["free"].pop("mfc_trace_sensitivity")
        fits["repetition_orthogonality"] = fits.pop("mfc_trace_sensitivity")
    return fit


def _write_table(data_name: str, suffix: str, markdown: str) -> None:
    """Write ``markdown`` to ``{tables_dir}/{data_name}_{fit_tag}_{run_tag}_{suffix}.md``."""
    table_path = os.path.join(tables_dir, f"{data_name}_{fit_tag}_{run_tag}_{suffix}.md")
    with open(table_path, "w", encoding="utf-8") as f:
        f.write(markdown)


for data_name in data_names:
    print(f"### {data_name}\n")
    results = [
        _load_fit(data_name, model_name, model_title)
        for model_name, model_title in zip(model_names, model_titles)
    ]

    summary = summarize_parameters(
        results, query_parameters, include_std=True, include_ci=True
    )
    _write_table(data_name, "parameters", summary)
    print(summary)

    # Only the p-value matrix is reported; df_t stays at module level in case
    # later cells inspect it.
    df_t, df_p = generate_t_p_matrices(results)
    print(df_p.to_markdown())
    print()

    aic_weights = calculate_aic_weights(results)
    aic_weights_md = aic_weights.to_markdown()
    _write_table(data_name, "aic_weights", aic_weights_md)
    print(aic_weights_md)
    print()

    df_comparison = winner_comparison_matrix(results)
    # Blank out NaN cells so the rendered markdown table stays readable.
    winner_md = df_comparison.to_markdown().replace(" nan ", "     ")
    _write_table(data_name, "winner_ratios", winner_md)
    print(winner_md)
    print()


### LohnasKahana2014

| | | Standard CMR | Reinforced Positional CMR | Reinforced Positional CMR (No Study) | Full Positional CMR | Positional CMR |
|---|---|---|---|---|---|---|
| fitness | mean | 1668.19 +/- 146.93 | 1658.99 +/- 146.97 | 1659.38 +/- 146.68 | 1660.00 +/- 146.91 | 1661.01 +/- 147.02 |
| | std | 421.58 | 421.70 | 420.85 | 421.51 | 421.82 |
| encoding drift rate | mean | 0.76 +/- 0.04 | 0.75 +/- 0.03 | 0.73 +/- 0.03 | 0.72 +/- 0.03 | 0.70 +/- 0.05 |
| | std | 0.12 | 0.10 | 0.09 | 0.10 | 0.15 |
| start drift rate | mean | 0.50 +/- 0.12 | 0.47 +/- 0.10 | 0.43 +/- 0.10 | 0.43 +/- 0.11 | 0.45 +/- 0.11 |
| | std | 0.35 | 0.30 | 0.30 | 0.30 | 0.30 |
| recall drift rate | mean | 0.94 +/- 0.01 | 0.93 +/- 0.01 | 0.91 +/- 0.02 | 0.93 +/- 0.01 | 0.91 +/- 0.04 |
| | std | 0.03 | 0.04 | 0.06 | 0.04 | 0.12 |
| shared support | mean | 2.24 +/- 1.85 | 9.96 +/- 5.77 | 8.26 +/- 3.92 | 6.48 +/- 3.62 | 8.42 +/- 6.06 |
| | std | 5.30 | 16.54 | 11.26 | 10.38 | 17.39 |
| item support | mean | 