In [None]:
%load_ext autoreload
%autoreload 2

import datetime
import logging
import typing as t

import matplotlib.pyplot as plt
import pandas as pd
from tqdm.notebook import tqdm

import ssvr.enrich_trials
import ssvr.qc
import ssvr.utils
import ssvr.visualization as viz
from ssvr.dataset import SessionDataset, create_session_info
from ssvr.models import DataLoadingSettings, SessionToLoad
from ssvr.visualization import a_lot_of_style, plot_ethogram

# Only let ERROR-and-above records through from the project logger and the
# AIND behavior-services base logger, keeping cell output readable.
for _quiet_logger in ("ssvr", "aind_behavior_services.base"):
    logging.getLogger(_quiet_logger).setLevel(logging.ERROR)

# Shared plot styling.
# NOTE(review): presumably keyed by a boolean "was this a choice trial" flag —
# solid line for True, dashed for False.
choice_linestyle = {True: "-", False: "--"}
# Fixed matplotlib cycle colour per subject id so subjects look the same in every figure.
subject_colors = {"808619": "C2", "808728": "C3", "789917": "C4"}

In [None]:
# Configure which sessions to load and where to look for them. Each entry of
# `root_path` is searched for a sub-directory named after the session id.
# NOTE(review): the UNC root below is share/machine specific — consider moving it
# to an environment variable or config file.
settings = DataLoadingSettings(root_path=[r"\\allen\aind\scratch\bruno.cruz"],
                               sessions_to_load=[SessionToLoad(session_id="TestMouse_2026-01-24T011822Z")],
                               root_derived_path="./")
DERIVED_PATH = settings.root_derived_path
# Sessions whose duration is below this threshold are skipped. With 0 minutes, only
# sessions reporting a negative duration are dropped (same as the previous check).
MIN_SESSION_DURATION = datetime.timedelta(minutes=0)

session_datasets: list[SessionDataset] = []
failed_session_ids: list = []
for entry in tqdm(settings.sessions_to_load, desc="Loading sessions"):
    # A root only counts as a candidate if it actually contains this session's
    # directory; checking the root alone would accept a reachable share that
    # lacks the session and make the "not found" error unreachable.
    candidate_paths = [
        p / entry.session_id for p in settings.root_path if (p / entry.session_id).exists()
    ]
    if not candidate_paths:
        raise FileNotFoundError(f"Session {entry.session_id} not found in any root path.")
    if len(candidate_paths) > 1:
        logging.warning(f"Multiple paths found for session {entry.session_id}, using the first one.")
    info = create_session_info(candidate_paths[0])
    try:
        _session = SessionDataset(session_info=info, processing_settings=settings.processing_settings)
        session_duration = _session.session_metrics.session_duration
        if session_duration < MIN_SESSION_DURATION:
            print(f"Skipping session {info.session_id}: duration {session_duration} is below {MIN_SESSION_DURATION}.")
            continue
        if entry.crop_max_trials is not None:
            # Positional crop: keep the first `crop_max_trials` rows regardless of index labels.
            _session.trials = _session.trials.iloc[: int(entry.crop_max_trials)]
        session_datasets.append(_session)
    except Exception as e:
        # Best effort: one broken session should not abort the whole batch, but it
        # is recorded so the summary below makes the failure visible.
        print(f"Failed to load session {info.session_id}: {e}")
        failed_session_ids.append(info.session_id)

# NOTE(review): presumably SessionDataset construction opens matplotlib figures
# (e.g. QC plots); close them so they are not rendered in this cell's output.
plt.close("all")
print(f"Loaded {len(session_datasets)} sessions.")
if failed_session_ids:
    print(f"Failed to load {len(failed_session_ids)} session(s): {failed_session_ids}")

In [None]:
# Stack every session's trial table into one long frame, tagging each row with its
# subject and session id so later cells can group on either.
if not session_datasets:
    # pd.concat([]) would otherwise fail with an opaque "No objects to concatenate".
    raise RuntimeError("No sessions were loaded; check the session-loading cell output.")

all_trials = []
for session in session_datasets:
    df = (
        session.trials.reset_index()
        # assumes the trials index is unnamed, so reset_index() names the column "index"
        .rename(columns={"index": "trial_number"})
        .assign(subject=session.session_info.subject, session_id=session.session_info.session_id)
    )
    all_trials.append(df)

all_trials_df = t.cast(pd.DataFrame, pd.concat(all_trials, ignore_index=True))
# Optional export (disabled): all_trials_df.to_csv(DERIVED_PATH / "all_sessions_enriched_trials.csv")

# Print summary statistics
print(f"Total trials across all sessions: {len(all_trials_df)}")
print(f"Number of sessions: {all_trials_df['session_id'].nunique()}")
print(f"Number of subjects: {all_trials_df['subject'].nunique()}")
print()

# Print per-subject statistics (groupby sorts subject keys, matching sorted(unique()))
for subject, subject_df in all_trials_df.groupby("subject"):
    n_sessions = subject_df["session_id"].nunique()
    n_trials = len(subject_df)
    print(f"Subject {subject}: {n_sessions} sessions, {n_trials} trials")

In [None]:
# DataFrame.info() writes its summary to stdout itself and returns None; wrapping it
# in print() only appended a stray "None" line.
all_trials_df.info()

In [None]:
# Ethogram of the first loaded session (plot_ethogram is imported in the top import cell).
# The trailing ";" suppresses the repr of whatever plot_ethogram returns: show_plot=True
# already renders the figure, and a returned Figure would otherwise be displayed twice.
plot_ethogram(session_datasets[0], show_plot=True);

In [None]:
# Per-site stopping and reward probabilities for each loaded session.
# Sessions are iterated explicitly instead of relying on a loop variable leaked
# from an earlier cell (hidden state that breaks when cells are re-run out of order).
for site_session in session_datasets:
    site_session_trials = site_session.trials
    print(f"Session {site_session.session_info.session_id}:")
    # sort=False keeps first-appearance order (as .unique() did); dropna=False keeps
    # trials with a missing patch_index, which an `== site_type` mask would never match.
    for site_type, site_trials in site_session_trials.groupby("patch_index", sort=False, dropna=False):
        print(f"Site type {site_type}: {len(site_trials)} trials")
        p_stop = site_trials["is_choice"].mean()
        print(f"  Probability of stopping: {p_stop:.2f}")
        p_reward = site_trials["is_rewarded"].mean()
        print(f"  Probability of reward: {p_reward:.2f}")