In [11]:
import sys
from pathlib import Path
import pandas as pd
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.patches import RegularPolygon
from datetime import datetime
import networkx as nx
import statsmodels.api as sm

sys.path.insert(0, str(Path.cwd().parent))

from src.behavior_import.import_data import *
from src.behavior_import.extract_trials import *
from src.behavior_analysis.get_good_reversal_info import *
from src.behavior_analysis.get_choice_probs_around_good_reversals import *
from src.behavior_analysis.split_early_late_good_reversals import *
from src.behavior_analysis.get_first_leave_after_good_reversals import *
from src.behavior_analysis.get_rank_counts_by_good_reversal import *
from src.behavior_visualization.plot_num_reversals import *
from src.behavior_visualization.plot_first_leave_after_good_reversals import *

In [12]:
root = "../data/cohort-02/rawdata/"
subjects_data = import_data(root)

[INFO] Processed 6 subjects(s), 108 session(s).


In [13]:
subjects_trials = extract_trials(subjects_data)

[INFO] Merging multiple files for subject MY_05_N, session ses-2_date-20260111
[INFO] Merging multiple files for subject MY_05_L, session ses-8_date-20260114
[INFO] Merging multiple files for subject MY_05_L, session ses-9_date-20260115


In [17]:
from scipy.stats import friedmanchisquare, wilcoxon
from statsmodels.stats.multitest import multipletests

best = []
second = []
third = []

good_reversal_info = get_good_reversal_info(subjects_trials)
rank_counts_by_good_reversal = get_rank_counts_by_good_reversal(good_reversal_info)

for subj, rows in rank_counts_by_good_reversal.items():
    if not rows:
        continue
    best.append(np.nanmean([r["best_prop"] for r in rows]))
    second.append(np.nanmean([r["second_prop"] for r in rows]))
    third.append(np.nanmean([r["third_prop"] for r in rows]))

best = np.array(best, float)
second = np.array(second, float)
third = np.array(third, float)

stat, p = friedmanchisquare(best, second, third)
print("Friedman:", stat, p)

pairs = [("best-second", best, second),
         ("best-third", best, third),
         ("second-third", second, third)]

pvals = []
stats = []
names = []
for name, a, b in pairs:
    s, pv = wilcoxon(a, b, zero_method="wilcox", alternative="two-sided")
    names.append(name); stats.append(s); pvals.append(pv)

reject, p_corr, _, _ = multipletests(pvals, method="holm")
for name, s, pc in zip(names, stats, p_corr):
    print(f"{name}: W={s:.3f}, p_holm={pc:.4g}")

best = []
second = []
third = []

for subj, rows in rank_counts_by_good_reversal.items():
    if not rows:
        continue
    best.append(np.nanmean([r["best_prop"] for r in rows]))
    second.append(np.nanmean([r["second_prop"] for r in rows]))
    third.append(np.nanmean([r["third_prop"] for r in rows]))

best = np.array(best)
second = np.array(second)
third = np.array(third)

chance = 1/3

w_best, p_best = wilcoxon(best - chance, alternative="greater")
w_second, p_second = wilcoxon(second - chance, alternative="greater")
w_third, p_third = wilcoxon(third - chance, alternative="greater")

print("Best > chance:", w_best, p_best)
print("Second > chance:", w_second, p_second)
print("Third > chance:", w_third, p_third)

pvals = [p_best, p_second, p_third]
reject, p_corr, _, _ = multipletests(pvals, method="holm")

for name, p in zip(["Best", "Second", "Third"], p_corr):
    print(f"{name} > chance (Holm): p = {p:.4g}")

[SKIP] MY_05_L reversal@401 (block 8): reward magnitudes before reversal were [4, 0, 0] across towers ['A3', 'A1', 'C3'] (expected a permutation of [4, 1, 0])
Friedman: 9.333333333333329 0.00940356255149523
best-second: W=0.000, p_holm=0.09375
best-third: W=4.000, p_holm=0.2188
second-third: W=0.000, p_holm=0.09375
Best > chance: 21.0 0.015625
Second > chance: 0.0 1.0
Third > chance: 18.0 0.078125
Best > chance (Holm): p = 0.04688
Second > chance (Holm): p = 1
Third > chance (Holm): p = 0.1562


In [30]:
merged, _ = get_vars_across_all_sessions(subjects_trials)



In [32]:
merged['MY_04_L'].keys()

dict_keys(['trial', 'good_reversals', 'bad_reversals', 'blocks', 'trials_in_block', 'reward_magnitudes_by_tower', 'choices_by_tower', 'choices_by_rank'])

In [4]:
def make_arm_mapping(tower_keys):
    """Deterministic mapping: sorted tower keys -> Arm1/Arm2/Arm3."""
    keys_sorted = sorted(list(tower_keys))
    return {k: f"Arm{i+1}" for i, k in enumerate(keys_sorted)}

def remap_tower_dict_to_arms(d, mapping):
    """Remap {tower: list} -> {ArmX: list} using mapping[tower]."""
    out = {}
    for tower, arr in d.items():
        if tower in mapping:
            out[mapping[tower]] = arr
    return out

def get_chosen_reward_per_trial(choices_by_tower, reward_magnitudes_by_tower):
    """
    Returns:
        rewards: list length n_trials, where rewards[t] is the reward magnitude
                 of the chosen tower at trial t.

    Args:
        strict: if True, raises if a trial is not exactly one-hot.
                if False, returns fill_value for ambiguous/missing trials.
        fill_value: value to use when strict=False and choice is ambiguous/missing.
    """
    towers = list(reward_magnitudes_by_tower.keys())
    if not towers:
        return []

    # Use the minimum available length across towers to avoid IndexError
    n_trials = min(len(reward_magnitudes_by_tower[t]) for t in towers)

    # Ensure we have choice arrays for the same towers
    missing = [t for t in towers if t not in choices_by_tower]
    if missing:
        raise KeyError(f"choices_by_tower missing towers: {missing}")

    # Also ensure choice arrays are long enough
    n_trials = min(n_trials, *(len(choices_by_tower[t]) for t in towers))

    rewards = []
    for i in range(n_trials):
        chosen = [t for t in towers if choices_by_tower[t][i] == 1]

        if len(chosen) != 1:
            raise ValueError(f"Trial {i}: expected exactly one chosen tower, got {chosen}")

        t = chosen[0]
        rewards.append(reward_magnitudes_by_tower[t][i])

    return rewards

In [5]:
import numpy as np

def format_glm_input(data, input_dimension=29, drop_redundant_cols=True):
    merged_data_across_subjects, _ = get_vars_across_all_sessions(data)
    regressors_across_subjects = {}
    true_choices_across_subjects = {}

    # columns that are redundant given one-hot choice structure
    drop_cols = [3, 6, 9, 12, 19, 22, 25, 28]  # Arm3 for choices + Arm3 for interactions
    keep_cols = [i for i in range(input_dimension) if i not in drop_cols]

    for subject in merged_data_across_subjects.keys():
        current_subject_data = merged_data_across_subjects[subject]

        choices_by_tower = current_subject_data["choices_by_tower"]
        chosen_rewards = get_chosen_reward_per_trial(
            choices_by_tower,
            current_subject_data["reward_magnitudes_by_tower"]
        )

        tower_to_arm = make_arm_mapping(choices_by_tower.keys())
        choices_by_arm = remap_tower_dict_to_arms(choices_by_tower, tower_to_arm)

        num_trials = len(current_subject_data["trial"])
        regressors = np.zeros((num_trials, input_dimension))
        true_choices = np.zeros((num_trials, len(choices_by_tower.keys())))

        for t in range(1, num_trials):
            # explicit ordering = Arm1, Arm2, Arm3
            true_choices[t, :] = [
                choices_by_arm["Arm1"][t],
                choices_by_arm["Arm2"][t],
                choices_by_arm["Arm3"][t],
            ]

            regressors[t, 0] = 1  # Bias term

            # lag 1
            regressors[t, 1] = choices_by_arm["Arm1"][t - 1]
            regressors[t, 2] = choices_by_arm["Arm2"][t - 1]
            regressors[t, 3] = choices_by_arm["Arm3"][t - 1]
            regressors[t, 13] = chosen_rewards[t - 1]
            regressors[t, 17] = choices_by_arm["Arm1"][t - 1] * chosen_rewards[t - 1]
            regressors[t, 18] = choices_by_arm["Arm2"][t - 1] * chosen_rewards[t - 1]
            regressors[t, 19] = choices_by_arm["Arm3"][t - 1] * chosen_rewards[t - 1]

            # lag 2
            if t >= 2:
                regressors[t, 4] = choices_by_arm["Arm1"][t - 2]
                regressors[t, 5] = choices_by_arm["Arm2"][t - 2]
                regressors[t, 6] = choices_by_arm["Arm3"][t - 2]
                regressors[t, 14] = chosen_rewards[t - 2]
                regressors[t, 20] = choices_by_arm["Arm1"][t - 2] * chosen_rewards[t - 2]
                regressors[t, 21] = choices_by_arm["Arm2"][t - 2] * chosen_rewards[t - 2]
                regressors[t, 22] = choices_by_arm["Arm3"][t - 2] * chosen_rewards[t - 2]

            # lag 3
            if t >= 3:
                regressors[t, 7] = choices_by_arm["Arm1"][t - 3]
                regressors[t, 8] = choices_by_arm["Arm2"][t - 3]
                regressors[t, 9] = choices_by_arm["Arm3"][t - 3]
                regressors[t, 15] = chosen_rewards[t - 3]
                regressors[t, 23] = choices_by_arm["Arm1"][t - 3] * chosen_rewards[t - 3]
                regressors[t, 24] = choices_by_arm["Arm2"][t - 3] * chosen_rewards[t - 3]
                regressors[t, 25] = choices_by_arm["Arm3"][t - 3] * chosen_rewards[t - 3]

            # lag 4
            if t >= 4:
                regressors[t, 10] = choices_by_arm["Arm1"][t - 4]
                regressors[t, 11] = choices_by_arm["Arm2"][t - 4]
                regressors[t, 12] = choices_by_arm["Arm3"][t - 4]
                regressors[t, 16] = chosen_rewards[t - 4]
                regressors[t, 26] = choices_by_arm["Arm1"][t - 4] * chosen_rewards[t - 4]
                regressors[t, 27] = choices_by_arm["Arm2"][t - 4] * chosen_rewards[t - 4]
                regressors[t, 28] = choices_by_arm["Arm3"][t - 4] * chosen_rewards[t - 4]

        # drop first trial (no history)
        regressors = regressors[1:, :]
        true_choices = true_choices[1:, :]

        # drop redundant columns to make X full rank
        if drop_redundant_cols:
            regressors = regressors[:, keep_cols]

        regressors_across_subjects[subject] = regressors
        true_choices_across_subjects[subject] = true_choices

    return regressors_across_subjects, true_choices_across_subjects, keep_cols

In [None]:
def format_glm_input(data, input_dimension=29):
    merged_data_across_subjects, _ = get_vars_across_all_sessions(data)
    regressors_across_subjects = {}
    true_choices_across_subjects = {}
    for subject in merged_data_across_subjects.keys():
        current_subject_data = merged_data_across_subjects[subject]

        choices_by_tower = current_subject_data["choices_by_tower"]
        chosen_rewards = get_chosen_reward_per_trial(choices_by_tower, current_subject_data["reward_magnitudes_by_tower"])

        tower_to_arm = make_arm_mapping(choices_by_tower.keys())
        choices_by_arm = remap_tower_dict_to_arms(choices_by_tower, tower_to_arm)

        num_trials = len(current_subject_data['trial'])
        regressors = np.zeros((num_trials, input_dimension))
        true_choices = np.zeros((num_trials, len(choices_by_tower.keys())))

        # Fill input features
        # 0: bias
        # 1: prev choice arm 1
        # 2: prev choice arm 2
        # 3: prev choice arm 3
        # 4: n - 1 choice arm 1
        # 5: n - 1 choice arm 2
        # 6: n - 1 choice arm 3
        # 7: n - 2 choice arm 1
        # 8: n - 2 choice arm 2
        # 9: n - 2 choice arm 3
        # 10: n - 3 choice arm 1
        # 11: n - 3 choice arm 2
        # 12: n - 3 choice arm 3
        # 13: prev reward
        # 14: n - 1 reward
        # 15: n - 2 reward
        # 16: n - 3 reward
        # 17: prev choice arm 1 x reward
        # 18: prev choice arm 2 x reward
        # 19: prev choice arm 3 x reward
        # 20: n - 1 choice arm 1 x reward
        # 21: n - 1 choice arm 2 x reward
        # 22: n - 1 choice arm 3 x reward
        # 23: n - 2 choice arm 1 x reward
        # 24: n - 2 choice arm 2 x reward
        # 25: n - 2 choice arm 3 x reward
        # 26: n - 3 choice arm 1 x reward
        # 27: n - 3 choice arm 2 x reward
        # 28: n - 3 choice arm 3 x reward
        for t in range(1, num_trials):
            true_choices[t, :] = [choices_by_arm[arm][t] for arm in sorted(choices_by_arm.keys())]

            regressors[t, 0] = 1  # Bias term
            regressors[t, 1] = choices_by_arm['Arm1'][t-1]
            regressors[t, 2] = choices_by_arm['Arm2'][t-1]
            regressors[t, 3] = choices_by_arm['Arm3'][t-1]
            regressors[t, 13] = chosen_rewards[t-1]
            regressors[t, 17] = choices_by_arm['Arm1'][t-1] * chosen_rewards[t-1]
            regressors[t, 18] = choices_by_arm['Arm2'][t-1] * chosen_rewards[t-1]
            regressors[t, 19] = choices_by_arm['Arm3'][t-1] * chosen_rewards[t-1]

            if t >= 2:
                regressors[t, 4] = choices_by_arm['Arm1'][t-2]
                regressors[t, 5] = choices_by_arm['Arm2'][t-2]
                regressors[t, 6] = choices_by_arm['Arm3'][t-2]
                regressors[t, 14] = chosen_rewards[t-2]
                regressors[t, 20] = choices_by_arm['Arm1'][t-2] * chosen_rewards[t-2]
                regressors[t, 21] = choices_by_arm['Arm2'][t-2] * chosen_rewards[t-2]
                regressors[t, 22] = choices_by_arm['Arm3'][t-2] * chosen_rewards[t-2]
            if t >= 3:
                regressors[t, 7] = choices_by_arm['Arm1'][t-3]
                regressors[t, 8] = choices_by_arm['Arm2'][t-3]
                regressors[t, 9] = choices_by_arm['Arm3'][t-3]
                regressors[t, 15] = chosen_rewards[t-3]
                regressors[t, 23] = choices_by_arm['Arm1'][t-3] * chosen_rewards[t-3]
                regressors[t, 24] = choices_by_arm['Arm2'][t-3] * chosen_rewards[t-3]
                regressors[t, 25] = choices_by_arm['Arm3'][t-3] * chosen_rewards[t-3]
            if t >= 4:
                regressors[t, 10] = choices_by_arm['Arm1'][t-4]
                regressors[t, 11] = choices_by_arm['Arm2'][t-4]
                regressors[t, 12] = choices_by_arm['Arm3'][t-4]
                regressors[t, 16] = chosen_rewards[t-4]
                regressors[t, 26] = choices_by_arm['Arm1'][t-4] * chosen_rewards[t-4]
                regressors[t, 27] = choices_by_arm['Arm2'][t-4] * chosen_rewards[t-4]
                regressors[t, 28] = choices_by_arm['Arm3'][t-4] * chosen_rewards[t-4]
        regressors = regressors[1:, :]
        true_choices = true_choices[1:, :]
        regressors_across_subjects[subject] = regressors
        true_choices_across_subjects[subject] = true_choices
    return regressors_across_subjects, true_choices_across_subjects

In [8]:
regressors_across_subjects, true_choices_across_subjects, keep_cols = format_glm_input(subjects_data)



In [9]:
y = true_choices_across_subjects['MY_04_L']
X = regressors_across_subjects['MY_04_L']

In [10]:
X_by, Y_by, keep_cols = format_glm_input(subjects_trials)

X = X_by["MY_04_L"]
Y = Y_by["MY_04_L"]
y = np.argmax(Y, axis=1)

print("rank vs p:", np.linalg.matrix_rank(X), X.shape[1])  # should match

mn = sm.MNLogit(y, X)
res = mn.fit(method="bfgs", maxiter=500, disp=False)
print(res.summary())

rank vs p: 21 21
                          MNLogit Regression Results                          
Dep. Variable:                      y   No. Observations:                  694
Model:                        MNLogit   Df Residuals:                      652
Method:                           MLE   Df Model:                           40
Date:                Mon, 19 Jan 2026   Pseudo R-squ.:                 0.08920
Time:                        17:25:54   Log-Likelihood:                -691.33
converged:                       True   LL-Null:                       -759.03
Covariance Type:            nonrobust   LLR p-value:                 2.706e-12
       y=1       coef    std err          z      P>|z|      [0.025      0.975]
------------------------------------------------------------------------------
const          0.5762      0.366      1.575      0.115      -0.141       1.293
x1            -0.3465      0.455     -0.762      0.446      -1.238       0.545
x2             0.2412      0.398   

In [21]:
import numpy as np

X = regressors_across_subjects["MY_04_L"]
print("rank:", np.linalg.matrix_rank(X), "p:", X.shape[1])


rank: 24 p: 29
