In [None]:
import datetime
import logging
import typing as t
from pathlib import Path

import pandas as pd
import semver
from tqdm.notebook import tqdm

import ssvr.enrich_trials
import ssvr.qc
import ssvr.utils

%load_ext autoreload
%autoreload 2
from ssvr.dataset import SessionDataset, find_session_info
from ssvr.models import DataLoadingSettings, SessionInfo

# Keep notebook output readable: these loggers only emit ERROR and above.
for _quiet_logger in ("ssvr", "aind_behavior_services.base"):
    logging.getLogger(_quiet_logger).setLevel(logging.ERROR)

# Plot styling lookups: line style keyed by a boolean flag, colour keyed by
# subject identifier string.
choice_linestyle = {
    True: "-",
    False: "--",
}
subject_colors = {
    "808619": "C2",
    "808728": "C3",
    "789917": "C4",
}

In [None]:
# Collect the sessions reported by find_session_info for these settings and keep
# only those whose `version` is 0.6.0 or newer.
settings = DataLoadingSettings()
minimum_version = semver.Version.parse("0.6.0")
session_info = [candidate for candidate in find_session_info(settings) if candidate.version >= minimum_version]
session_datasets: list[SessionDataset] = []

# Per-session table stored next to the first data root. Columns read by this
# notebook: "session", "use" and "crop_trial".
session_filter = pd.read_csv(Path(settings.root_path[0]) / "../sessions.csv")

# Parse "use" numerically instead of comparing against the string "1":
# read_csv infers a numeric dtype when the column only holds 1/0/blank, and
# then `x == "1"` is False for every row, silently dropping all sessions.
# Blank or non-numeric entries coerce to NaN and therefore become False.
session_filter["use"] = pd.to_numeric(session_filter["use"], errors="coerce").eq(1)

# Nullable integer so that sessions without a crop point stay <NA>.
session_filter["crop_trial"] = pd.to_numeric(session_filter["crop_trial"], errors="coerce").astype("Int64")


def _use_session(session: SessionInfo) -> bool:
    """Return whether ``session`` is flagged for use in ``session_filter``.

    Args:
        session: Session whose ``session_id`` is looked up in the "session"
            column of the module-level ``session_filter`` table.

    Returns:
        ``False`` when the session is not listed; otherwise the "use" value of
        the first matching row, converted to a plain ``bool``.
    """
    # One boolean-mask lookup serves both the membership test and the value read.
    use_flags = session_filter.loc[session_filter["session"] == session.session_id, "use"]
    if use_flags.empty:
        return False
    # .iloc[0] yields a numpy scalar; cast so the annotated return type holds.
    return bool(use_flags.iloc[0])


# Build a SessionDataset for every session flagged in sessions.csv. Sessions
# shorter than the minimum duration are skipped, and sessions with a
# "crop_trial" entry are truncated to that many trials.
session_datasets = []
min_session_duration = datetime.timedelta(minutes=15)
for info in tqdm(session_info, desc="Loading sessions"):
    if not _use_session(info):
        continue
    try:
        _session = SessionDataset(session_info=info, processing_settings=settings.processing_settings)
        filter_row = session_filter[session_filter["session"] == info.session_id].iloc[0]
        if _session.session_metrics.session_duration < min_session_duration:
            continue
        crop_trial = filter_row["crop_trial"]
        if pd.notna(crop_trial):
            _session.trials = _session.trials[: int(crop_trial)]
        session_datasets.append(_session)
    except Exception as exc:
        # Best effort: report the failure and carry on with the remaining sessions.
        print(f"Failed to load session {info.session_id}: {exc}")

# Run the ssvr.enrich_trials helpers on every loaded session, in a fixed order
# (block info, position relative to block, then the 5 previous trials).
# NOTE(review): each call's return value is overwritten or never read, so this
# loop presumably relies on the helpers updating `session` in place — confirm
# against ssvr.enrich_trials.
for session in tqdm(session_datasets, desc="Enriching sessions"):
    enriched_trials = ssvr.enrich_trials.enrich_with_block_info(session)
    enriched_trials = ssvr.enrich_trials.enrich_with_relative_to_block(session)
    enriched_trials = ssvr.enrich_trials.enrich_with_previous_trial(session, n_previous=5)

# QC is switched off by the constant-false condition; make it truthy to call
# ssvr.qc.run_qc on all loaded sessions with ./derived/qc_reports as its path.
if 0:
    ssvr.qc.run_qc(session_datasets=session_datasets, path=Path("./derived") / "qc_reports")

In [None]:
# Stack the trials of every loaded session into one long table. Each session's
# index is moved into a column (renamed to "trial_number" when the index is
# unnamed, since reset_index then calls it "index") and every row is tagged
# with the subject and session it came from.
all_trials = []
for session in session_datasets:
    session_trials = (
        session.trials.reset_index()  # returns a new frame; session.trials is untouched
        .rename(columns={"index": "trial_number"})
        .assign(
            subject=session.session_info.subject,
            session_id=session.session_info.session_id,
        )
    )
    all_trials.append(session_trials)

if not all_trials:
    # pd.concat raises a terse ValueError on an empty list; say what went wrong.
    raise ValueError("No sessions were loaded, so there are no trials to concatenate.")

all_trials_df = t.cast(pd.DataFrame, pd.concat(all_trials, ignore_index=True))
# all_trials_df.to_csv(Path("./derived") / "all_sessions_enriched_trials.csv")

# DataFrame.info() prints its own summary and returns None, so wrapping it in
# print() would append a stray "None" line to the output.
all_trials_df.info()

In [None]:
from ssvr.analysis.block_switching_behavior import (
    calculate_choice_matrix,
    get_switches_df,
    plot_block_switch_choice_patterns,
)

# Figures are written to ./derived. Create it up front, because savefig does
# not create missing parent directories and would fail on a fresh checkout.
output_dir = Path("./derived")
output_dir.mkdir(parents=True, exist_ok=True)

# Window passed to both the matrix computation and the plots.
# NOTE(review): presumably (first, last) trial offsets relative to a block
# switch — confirm against ssvr.analysis.block_switching_behavior.
trial_window = (-10, 30)
switch_df = get_switches_df(all_trials_df, block_switch_filter="different")
choice_behavior_matrix, switch_trials = calculate_choice_matrix(all_trials_df, switch_df, trial_window=trial_window)

# Pooled figure across every subject.
fig, ax = plot_block_switch_choice_patterns(choice_behavior_matrix, trial_window)
fig.suptitle("All subjects")
fig.savefig(output_dir / "block_switch_choice_patterns_all_subjects.png")

# One figure per subject, built from that subject's rows of the pooled matrix.
for subject, subject_switches in switch_trials.groupby("subject"):
    # "index_ord" is used to select entries along the matrix's first axis.
    fig, ax = plot_block_switch_choice_patterns(
        choice_behavior_matrix[subject_switches["index_ord"].values, :, :],
        trial_window,
    )
    fig.suptitle(f"Subject {subject}")
    fig.savefig(output_dir / f"block_switch_choice_patterns_subject_{subject}.png")

In [None]:
from ssvr.analysis.block_switching_behavior import _find_consecutive_run, calculate_consecutive_choice_runs

# Compute consecutive-choice runs for the switches found earlier
# (n_consecutive=3, max_trials_ahead=50) and preview the first rows.
consecutive_runs_df = calculate_consecutive_choice_runs(
    all_trials_df,
    switch_df,
    n_consecutive=3,
    max_trials_ahead=50,
)
display(consecutive_runs_df.head())
