In [None]:
import datetime
import logging
from pathlib import Path

import pandas as pd
import semver
from tqdm.notebook import tqdm

import ssvr.qc
import ssvr.utils

%load_ext autoreload
%autoreload 2
from ssvr.dataset import SessionDataset, find_session_info
from ssvr.models import DataLoadingSettings, SessionInfo

# Only ERROR and above is let through from these two loggers.
for _logger_name in ("ssvr", "aind_behavior_services.base"):
    logging.getLogger(_logger_name).setLevel(logging.ERROR)

# Plot styling lookups: line style keyed by a boolean, colour keyed by subject id string.
choice_linestyle = {
    True: "-",
    False: "--",
}
subject_colors = {
    "808619": "C2",
    "808728": "C3",
    "789917": "C4",
}

In [None]:
# Collect the session descriptors yielded by find_session_info and keep only those
# whose version is at least 0.6.0. The threshold is parsed once instead of once per session.
settings = DataLoadingSettings()
min_session_version = semver.Version.parse("0.6.0")
session_info = [info for info in find_session_info(settings) if info.version >= min_session_version]
session_datasets: list[SessionDataset] = []

# Curation table, looked up one directory above the first configured root path.
# "use" is read as text on purpose: with default dtype inference a column that holds only
# 1 / 0 / blank cells is parsed as numbers, and the comparison against "1" below would
# then never match, silently excluding every session.
session_filter = pd.read_csv(Path(settings.root_path[0]) / "../sessions.csv", dtype={"use": str})
# A session is flagged only when its "use" cell is exactly "1"; blank cells (NaN) and any other text are False.
session_filter["use"] = session_filter["use"].eq("1")
# Optional per-session trial cap; blank or non-numeric cells become <NA> in a nullable integer column.
session_filter["crop_trial"] = pd.to_numeric(session_filter["crop_trial"], errors="coerce").astype("Int64")


def _use_session(session: SessionInfo) -> bool:
    """Report whether ``session`` is flagged for use in the ``session_filter`` table.

    The session's ``session_id`` is matched against the "session" column of the
    module-level ``session_filter`` frame.

    Returns:
        ``False`` when the id is not listed at all; otherwise the "use" value of
        the first matching row.
    """
    is_match = session_filter["session"] == session.session_id
    if not is_match.any():
        # Sessions missing from the table are excluded by default.
        return False
    return session_filter.loc[is_match, "use"].iloc[0]


# Load every session flagged in the curation table. Loading is best-effort:
# a session that fails to load is reported and skipped.
session_datasets = []
for info in tqdm(session_info, desc="Loading sessions"):
    if not _use_session(info):
        continue
    try:
        _session = SessionDataset(session_info=info, processing_settings=settings.processing_settings)
        row = session_filter[session_filter["session"] == info.session_id].iloc[0]
        # Sessions lasting less than 15 minutes are not kept.
        if _session.session_metrics.session_duration < datetime.timedelta(minutes=15):
            continue
        crop_at = row["crop_trial"]
        if pd.notna(crop_at):
            # Keep only the first `crop_trial` entries of the trial table.
            _session.trials = _session.trials[: int(crop_at)]
        session_datasets.append(_session)
    except Exception as e:
        print(f"Failed to load session {info.session_id}: {e}")

In [None]:
# QC report generation is switched off; change the condition to True to run it.
if False:
    qc_report_dir = Path("./derived") / "qc_reports"
    ssvr.qc.run_qc(session_datasets=session_datasets, path=qc_report_dir)

# Table enrichment

* Waiting time
* Waiting time predicted
* Block label
* Trials since block start
* Trials to next block start
* Trials of this type since block start
* Trials of this type to next block start
* Previous trial choice
* Previous trial reward
* Previous trial stimuli


In [None]:
# Prototype the enrichment on the first loaded session.
session = session_datasets[0]

# Deep copy so columns added during enrichment do not modify the session's own trial table.
enriched_trials = session.trials.copy(deep=True)

## Get trials since block start
block = session.dataset["Behavior"]["SoftwareEvents"]["Block"].data
# TODO: parse each block payload before deriving per-block counters, e.g.
#   block["parsed"] = block["data"].apply(task_logic.Block.model_validate)
# NOTE(review): `task_logic` is not imported in the visible cells, and the earlier draft of
# this line returned `model_validate` itself instead of calling it on each payload.
# The "parsed" column is not created yet, so it is deliberately not indexed here: the lookup
# was a bare expression with no visible output and would raise KeyError on a fresh kernel
# unless the raw table already carries such a column.
print(block)