In [None]:
import datetime
import logging
import typing as t
from pathlib import Path

import pandas as pd
import semver
from tqdm.notebook import tqdm

import ssvr.enrich_trials
import ssvr.qc
import ssvr.utils

%load_ext autoreload
%autoreload 2
from ssvr.dataset import SessionDataset, find_session_info
from ssvr.models import DataLoadingSettings, SessionInfo

# Only report errors from these loggers for the rest of the notebook.
for _logger_name in ("ssvr", "aind_behavior_services.base"):
    logging.getLogger(_logger_name).setLevel(logging.ERROR)

# Line style looked up by a boolean key (solid for True, dashed for False).
choice_linestyle = {
    True: "-",
    False: "--",
}
# Colour assigned to each subject ID.
subject_colors = {
    "808619": "C2",
    "808728": "C3",
    "789917": "C4",
}

In [None]:
# Discover every session described by the loading settings and keep only those
# recorded with version 0.6.0 or newer.
settings = DataLoadingSettings()
minimum_version = semver.Version.parse("0.6.0")
session_info = [candidate for candidate in find_session_info(settings) if candidate.version >= minimum_version]
session_datasets: list[SessionDataset] = []

# sessions.csv sits one level above the first data root and provides, per session,
# a "use" flag and an optional "crop_trial" value.
# "use" is read as text on purpose: if the column held only 1/blank, pandas would
# infer a numeric dtype and the comparison against "1" below would never match.
session_filter = pd.read_csv(Path(settings.root_path[0]) / "../sessions.csv", dtype={"use": str})
# Missing flags count as "do not use"; only the exact string "1" opts a session in.
session_filter["use"] = session_filter["use"].apply(lambda x: x == "1" if pd.notna(x) else False)
# Non-numeric crop values become <NA>; the nullable Int64 dtype keeps valid ones integral.
session_filter["crop_trial"] = pd.to_numeric(session_filter["crop_trial"], errors="coerce").astype("Int64")


def _use_session(session: SessionInfo) -> bool:
    """Return whether ``session`` is marked for use in the ``session_filter`` table.

    Args:
        session: Session whose ``session_id`` is looked up in the "session" column.

    Returns:
        ``False`` when the session is not listed; otherwise the "use" flag of the
        first matching row, converted to a plain ``bool``.
    """
    # A single lookup serves as both the membership check and the flag retrieval.
    use_flags = session_filter.loc[session_filter["session"] == session.session_id, "use"]
    if use_flags.empty:
        return False
    # bool() turns the numpy boolean into the builtin type promised by the signature.
    return bool(use_flags.iloc[0])


# Load every session flagged for use. A session is kept unless its duration is
# under 15 minutes; any failure while loading is reported and the session is skipped.
session_datasets = []
for info in tqdm(session_info, desc="Loading sessions"):
    if not _use_session(info):
        continue
    try:
        _session = SessionDataset(session_info=info, processing_settings=settings.processing_settings)
        filter_row = session_filter[session_filter["session"] == info.session_id].iloc[0]
        if _session.session_metrics.session_duration < datetime.timedelta(minutes=15):
            continue
        # An optional crop value keeps only the first `crop_trial` trials.
        if pd.notna(filter_row["crop_trial"]):
            _session.trials = _session.trials[: int(filter_row["crop_trial"])]
        session_datasets.append(_session)
    except Exception as exc:
        print(f"Failed to load session {info.session_id}: {exc}")

# Enrichment steps applied to each session, in order, as (function, extra keyword arguments).
# NOTE(review): the return values are never read afterwards, so these helpers
# presumably update the session in place — confirm in ssvr.enrich_trials.
enrichment_steps = (
    (ssvr.enrich_trials.enrich_with_block_info, {}),
    (ssvr.enrich_trials.enrich_with_relative_to_block, {}),
    (ssvr.enrich_trials.enrich_with_previous_trial, {"n_previous": 5}),
    (ssvr.enrich_trials.enrich_with_block_probability, {}),
)
for session in tqdm(session_datasets, desc="Enriching sessions"):
    for enrich, extra_kwargs in enrichment_steps:
        enriched_trials = enrich(session, **extra_kwargs)

# QC report generation is switched off; set the flag to True to write the reports.
run_qc = False
if run_qc:
    ssvr.qc.run_qc(session_datasets=session_datasets, path=Path("./derived") / "qc_reports")

In [None]:
# Stack the trials of every loaded session into one table, tagging each row with
# its subject and session ID so later cells can group across sessions.
all_trials = []
for session in session_datasets:
    # reset_index() already returns a new frame, so the session's own table is untouched;
    # the former index becomes the "trial_number" column.
    session_trials = session.trials.reset_index().rename(columns={"index": "trial_number"})
    session_trials["subject"] = session.session_info.subject
    session_trials["session_id"] = session.session_info.session_id
    all_trials.append(session_trials)

all_trials_df = t.cast(pd.DataFrame, pd.concat(all_trials, ignore_index=True))
# all_trials_df.to_csv(Path("./derived") / "all_sessions_enriched_trials.csv")
# info() prints its summary itself and returns None, so it must not be wrapped in
# print() (that would append a stray "None" to the output).
all_trials_df.info()

In [None]:
# NOTE(review): in the visible notebook `switch_trials_df` is first assigned in a later
# cell (returned by calculate_choice_matrix), so on a fresh kernel this cell raises
# NameError under Restart & Run All — move it below that cell or remove it.
switch_trials_df

In [None]:
from ssvr.analysis.block_switching_behavior import (
    calculate_choice_matrix,
    plot_block_switch_choice_patterns,
)

# Trial window passed to both the choice-matrix computation and the plots.
trial_window = (-10, 30)

# All figures from this cell are written under ./derived; create it up front so that
# savefig does not fail when the directory is missing (e.g. on a fresh checkout).
derived_dir = Path("./derived")
derived_dir.mkdir(parents=True, exist_ok=True)


def _save_choice_pattern_figure(matrix, title, filename):
    """Plot block-switch choice patterns for ``matrix``, title the figure and save it under ``derived_dir``."""
    pattern_fig, pattern_ax = plot_block_switch_choice_patterns(matrix, trial_window)
    pattern_fig.suptitle(title)
    pattern_fig.savefig(derived_dir / filename)
    return pattern_fig, pattern_ax


choice_behavior_matrix, switch_trials_df = calculate_choice_matrix(
    all_trials_df, trial_window=trial_window, block_switch_filter="different"
)
fig, ax = _save_choice_pattern_figure(
    choice_behavior_matrix, "All subjects", "block_switch_choice_patterns_all_subjects.png"
)

# One figure per subject: the "index_ord" column of switch_trials_df selects that
# subject's entries along the first axis of the choice matrix.
for subject, switch_rows in switch_trials_df.groupby("subject"):
    fig, ax = _save_choice_pattern_figure(
        choice_behavior_matrix[switch_rows["index_ord"].values, :, :],
        f"Subject {subject}",
        f"block_switch_choice_patterns_subject_{subject}.png",
    )

## Check if there are biases depending on the odor patch switched to
for (subject, high_patch_index), switch_rows in switch_trials_df.groupby(["subject", "after_high_index"]):
    fig, ax = _save_choice_pattern_figure(
        choice_behavior_matrix[switch_rows["index_ord"].values, :, :],
        f"Subject {subject}, After High Patch Index {high_patch_index}",
        f"block_switch_choice_patterns_subject_{subject}_after_high_patch_{high_patch_index}.png",
    )

In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.lines import Line2D

from ssvr.analysis.block_switching_behavior import calculate_consecutive_choice_runs, plot_trials_to_criterion_histogram

# Figures from this cell are written under ./derived; create it up front so that
# savefig does not fail when the directory is missing (e.g. on a fresh checkout).
derived_dir = Path("./derived")
derived_dir.mkdir(parents=True, exist_ok=True)

consecutive_runs_df = calculate_consecutive_choice_runs(
    all_trials_df, switch_trials_df, n_consecutive=5, max_trials_ahead=40
)


def get_medians(df):
    """Return ``(low, high)`` median trials-to-criterion for the rows in ``df``.

    ``low`` is the median of "trials_to_n_consecutive_false" over rows where
    "is_low_reward_patch" is True; ``high`` is the median of
    "trials_to_n_consecutive_true" over rows where it is False. Missing values are
    dropped first, and ``np.nan`` is returned for a group with nothing left.
    """
    low = df.loc[df["is_low_reward_patch"] == True, "trials_to_n_consecutive_false"].dropna()  # noqa: E712
    high = df.loc[df["is_low_reward_patch"] == False, "trials_to_n_consecutive_true"].dropna()  # noqa: E712
    return (np.median(low) if len(low) > 0 else np.nan, np.median(high) if len(high) > 0 else np.nan)


# One record per subject holding the four medians (low/high patch x first/second half).
summary_stats = []

for subject, subject_df in consecutive_runs_df.groupby("subject"):
    subject_df = subject_df.sort_values("trial_index")

    # Split on the ordered distinct trial_index values: the lower half of those values
    # forms the first split, the remainder the second.
    unique_trials = subject_df["trial_index"].unique()
    mid_point = len(unique_trials) // 2

    first_half_df = subject_df[subject_df["trial_index"].isin(unique_trials[:mid_point])]
    second_half_df = subject_df[subject_df["trial_index"].isin(unique_trials[mid_point:])]

    fig, axes = plt.subplots(1, 2, figsize=(12, 5), sharey=True)
    plot_trials_to_criterion_histogram(first_half_df, ax=axes[0], title="First Half")
    plot_trials_to_criterion_histogram(second_half_df, ax=axes[1], title="Second Half")

    fig.suptitle(f"Animal {subject} - Split Half Analysis")
    fig.tight_layout()
    fig.savefig(derived_dir / f"trials_to_criterion_split_animal_{subject}.png")
    plt.show()

    # Calculate Medians for Summary
    l1, h1 = get_medians(first_half_df)
    l2, h2 = get_medians(second_half_df)

    summary_stats.append({"subject": subject, "low_1st": l1, "high_1st": h1, "low_2nd": l2, "high_2nd": h2})

# Summary figure: for every subject, one line per patch type joining the
# first-half median to the second-half median.
summary_df = pd.DataFrame(summary_stats)
fig_sum, ax_sum = plt.subplots(figsize=(8, 6))

half_positions = [0, 1]
for stats in summary_df.itertuples(index=False):
    # Low Reward (Blue)
    ax_sum.plot(half_positions, [stats.low_1st, stats.low_2nd], "o-", color="blue", alpha=0.5)
    # High Reward (Red)
    ax_sum.plot(half_positions, [stats.high_1st, stats.high_2nd], "o-", color="red", alpha=0.5)

ax_sum.set_xticks(half_positions)
ax_sum.set_xticklabels(["1st Half", "2nd Half"])
ax_sum.set_ylabel("Median Trials to Criterion")
ax_sum.set_title("Change in Performance Across Session Halves")
ax_sum.grid(True, alpha=0.3)

# Custom legend: one proxy line per patch type rather than one entry per subject.
legend_entries = (("blue", "Low Reward Patch"), ("red", "High Reward Patch"))
legend_handles = [Line2D([0], [0], color=entry_color, lw=2, marker="o") for entry_color, _ in legend_entries]
legend_labels = [entry_label for _, entry_label in legend_entries]
ax_sum.legend(legend_handles, legend_labels)

summary_figure_path = Path("./derived") / "trials_to_criterion_summary_split.png"
fig_sum.savefig(summary_figure_path)
plt.show()