In [1]:
import sys
from pathlib import Path
import pandas as pd
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.patches import RegularPolygon
from datetime import datetime
import networkx as nx
import statsmodels.api as sm

sys.path.insert(0, str(Path.cwd().parent))

from src.behavior_import.import_data import *
from src.behavior_import.extract_trials import *
from src.behavior_import.filter_trials_by_subject import *
from src.behavior_analysis.get_good_reversal_info import *
from src.behavior_analysis.get_choice_probs_around_good_reversals import *
from src.behavior_analysis.split_early_late_good_reversals import *
from src.behavior_analysis.get_first_leave_after_good_reversals import *
from src.behavior_analysis.get_rank_counts_by_good_reversal import *
from src.behavior_analysis.get_bad_reversal_info import *
from src.behavior_analysis.get_diagnostic_p_value import *
from src.behavior_visualization.plot_num_reversals import *
from src.behavior_visualization.plot_first_leave_after_good_reversals import *
from src.behavior_visualization.plot_choice_probs_around_good_reversals import *

In [2]:
root = "../data/cohort-02/rawdata/"
subjects_data = import_data(root)
subjects_trials = extract_trials(subjects_data)

[INFO] Processed 6 subjects(s), 159 session(s).
[INFO] Merging multiple files for subject MY_05_L, session ses-8_date-20260114
[INFO] Merging multiple files for subject MY_05_L, session ses-9_date-20260115
[INFO] Merging multiple files for subject MY_05_N, session ses-2_date-20260111


In [3]:
print(subjects_trials["MY_04_L"]["ses-1_date-20260111"]["num_long_pokes"])

[1, 2, 2, 5, 1, 5, 6]


In [4]:
all_subjects = list(subjects_trials.keys())
avg_num_pokes_across_subjects = {}
for subject in all_subjects:
    for session in subjects_trials[subject].keys():
        try:
            num_long_pokes = subjects_trials[subject][session]["num_long_pokes"]
            avg_num_pokes_across_subjects.setdefault(subject, {})
            avg_num_pokes_across_subjects[subject][session] = np.mean(num_long_pokes)
        except:
            print(f"Missing num_long_pokes for {subject} {session}")
for session in avg_num_pokes_across_subjects[all_subjects[0]].keys():
    session_avg = np.mean([avg_num_pokes_across_subjects[subject][session] for subject in all_subjects if session in avg_num_pokes_across_subjects[subject]])
    print(f"{session}: {session_avg}")

Missing num_long_pokes for MY_05_N ses-1_date-20260111
ses-1_date-20260111: 2.5074175824175824
ses-2_date-20260111: 2.6159722222222226
ses-3_date-20260112: 2.180991462241462
ses-4_date-20260112: 1.239031339031339
ses-5_date-20260113: 0.9444310478793237
ses-6_date-20260113: 0.4512290391600737
ses-7_date-20260114: 0.295637326205347
ses-8_date-20260114: 0.2910994227833318
ses-9_date-20260115: 0.26900473564272065
ses-10_date-20260115: 0.2961140289449113
ses-11_date-20260116: 0.3294940685929058
ses-12_date-20260116: 0.30860113765288144
ses-13_date-20260117: 0.2385524903301586
ses-14_date-20260117: 0.09346263758028463
ses-15_date-20260118: 0.12142921676851304
ses-16_date-20260118: 0.15415722471989543
ses-17_date-20260119: 0.09249757961521032
ses-18_date-20260119: 0.049700956937799044
ses-19_date-20260120: 0.07277092065589845
ses-20_date-20260120: 0.06130221751806451
ses-21_date-20260121: 0.0856504976305757
ses-22_date-20260121: 0.017154431216931217
ses-23_date-20260122: 0.040164420917248235


In [5]:
def plot_num_pokes_across_mice(avg_num_pokes_across_subjects):

    per_mouse_blocklens = {}
    for mouse, rev_list in avg_num_pokes_across_subjects.items():
        if not rev_list:
            continue

        rev_sorted = sorted(rev_list, key=lambda r: r["reversal_idx"])
        rev_idxs = [r["reversal_idx"] for r in rev_sorted]

        if len(rev_idxs) < 2:
            continue

        lens = [rev_idxs[i + 1] - rev_idxs[i] for i in range(len(rev_idxs) - 1)]
        per_mouse_blocklens[mouse] = {b + 1: float(L) for b, L in enumerate(lens)}

    if not per_mouse_blocklens:
        raise ValueError("No block lengths computed. Need >=2 reversal boundaries per mouse.")

    # ---- 2) aggregate across mice per block number ----
    block_to_vals = defaultdict(list)
    for mouse, bl in per_mouse_blocklens.items():
        for b, L in bl.items():
            if np.isfinite(L):
                block_to_vals[b].append(L)

    blocks = sorted(block_to_vals.keys())
    meds = np.array([np.median(block_to_vals[b]) for b in blocks], dtype=float)

    # SE across mice (ddof=1 if >=2 mice have that block)
    ses = []
    for b in blocks:
        vals = np.asarray(block_to_vals[b], dtype=float)
        if len(vals) >= 2:
            ses.append(np.std(vals, ddof=1) / np.sqrt(len(vals)))
        else:
            ses.append(0.0)
    ses = np.asarray(ses, dtype=float)

    # ---- 3) plot ----
    fig, ax = plt.subplots(figsize=(10, 4.5))

    ax.bar(x,means,yerr=errs,capsize=6,edgecolor="black",linewidth=1.5,alpha=0.55,color=["#999999", "#999999"])

    # Per-mouse points (and optional connecting lines)
    for m in mice:
        c = mouse_to_color[m]
        xs, ys = [], []
        for b in blocks:
            if b in per_mouse_blocklens[m]:
                xs.append(b)
                ys.append(per_mouse_blocklens[m][b])

        if not xs:
            continue

        xs = np.asarray(xs, dtype=float)
        ys = np.asarray(ys, dtype=float)

        if jitter > 0:
            xs = xs + np.random.uniform(-jitter, jitter, size=len(xs))

        if show_lines and len(xs) >= 2:
            ax.plot(xs, ys, color=c, linewidth=2.5, alpha=0.9, zorder=4)

        ax.scatter(xs, ys, s=70, color=c, edgecolor="white", linewidth=0.8, zorder=5)

    # ---- 4) styling like your reversal plot ----
    ax.set_title(title, fontsize=16, pad=12)
    ax.set_xlabel("Block number", fontsize=16)
    ax.set_ylabel("Block length (trials)", fontsize=16)

    ax.set_xticks(blocks)
    ax.tick_params(axis="both", labelsize=13)

    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)

    # Legend: one handle per mouse (prevents the “two colors” issue)
    handles = [
        Line2D([0], [0], marker="o", linestyle="-",
               color=mouse_to_color[m], markerfacecolor=mouse_to_color[m],
               markeredgecolor="white", markeredgewidth=0.8,
               linewidth=2.5, markersize=9, label=m)
        for m in mice
    ]
    ax.legend(handles=handles, loc="upper right", frameon=True, fontsize=11)

    plt.tight_layout()
    plt.show()

    return per_mouse_blocklens, block_to_vals

In [6]:
print(avg_num_pokes_across_subjects["MY_04_L"].keys())

dict_keys(['ses-1_date-20260111', 'ses-2_date-20260111', 'ses-3_date-20260112', 'ses-4_date-20260112', 'ses-5_date-20260113', 'ses-6_date-20260113', 'ses-7_date-20260114', 'ses-8_date-20260114', 'ses-9_date-20260115', 'ses-10_date-20260115', 'ses-11_date-20260116', 'ses-12_date-20260116', 'ses-13_date-20260117', 'ses-14_date-20260117', 'ses-15_date-20260118', 'ses-16_date-20260118', 'ses-17_date-20260119', 'ses-18_date-20260119', 'ses-19_date-20260120', 'ses-20_date-20260120', 'ses-21_date-20260121', 'ses-22_date-20260121', 'ses-23_date-20260122', 'ses-24_date-20260122'])
