In [None]:
import datetime
import logging
from pathlib import Path

import pandas as pd
import semver
from tqdm.notebook import tqdm

import ssvr.enrich_trials
import ssvr.qc
import ssvr.utils

%load_ext autoreload
%autoreload 2
from ssvr.dataset import SessionDataset, find_session_info
from ssvr.models import DataLoadingSettings, SessionInfo

# Only ERROR and above should come through from these package loggers.
for _quiet_logger in ("ssvr", "aind_behavior_services.base"):
    logging.getLogger(_quiet_logger).setLevel(logging.ERROR)

# Plot styling lookups.
# NOTE(review): the boolean keys presumably correspond to a trial's choice flag
# and the values are matplotlib line styles -- confirm against the plotting cells.
choice_linestyle = {
    True: "-",
    False: "--",
}
# NOTE(review): keys look like subject identifiers and the values like
# matplotlib colour-cycle aliases -- confirm.
subject_colors = {
    "808619": "C2",
    "808728": "C3",
    "789917": "C4",
}

In [None]:
settings = DataLoadingSettings()

# Keep only the sessions whose `version` is 0.6.0 or newer. The threshold is
# parsed once here rather than once per session inside the filter.
MIN_SESSION_VERSION = semver.Version.parse("0.6.0")
session_info = [
    session for session in find_session_info(settings) if session.version >= MIN_SESSION_VERSION
]
session_datasets: list[SessionDataset] = []

# Curation table read from `sessions.csv`, one directory above the first data root.
# The columns used below are `session`, `use` and `crop_trial`.
#
# `use` is read as text on purpose: with dtype inference a column that holds
# only 1/blank is parsed as numeric, a numeric 1 never equals the string "1",
# and every session would be silently excluded.
session_filter = pd.read_csv(Path(settings.root_path[0]) / "../sessions.csv", dtype={"use": str})
# A session is selected only when `use` is exactly "1"; blanks (NaN) compare
# unequal and therefore become False.
session_filter["use"] = session_filter["use"].eq("1")
# Non-numeric or missing `crop_trial` entries become <NA> in a nullable integer column.
session_filter["crop_trial"] = pd.to_numeric(session_filter["crop_trial"], errors="coerce").astype("Int64")


def _use_session(session: SessionInfo) -> bool:
    """Tell whether a session is selected in the curation table.

    Looks up ``session.session_id`` in the ``session`` column of the
    module-level ``session_filter`` frame.

    Args:
        session: Session whose ``session_id`` is looked up.

    Returns:
        ``False`` when the session is not listed; otherwise the ``use`` value
        of the first matching row, coerced to a plain ``bool``.
    """
    # Filter once; the same selection answers both "is it listed?" and "is it used?".
    use_flags = session_filter.loc[session_filter["session"] == session.session_id, "use"]
    if use_flags.empty:
        return False
    # `.iloc[0]` yields a numpy scalar; convert so the annotated return type holds.
    return bool(use_flags.iloc[0])


# Start from an empty list so re-running the cell does not append duplicates.
session_datasets = []
for info in tqdm(session_info, desc="Loading sessions"):
    # Sessions not selected in the curation table are never loaded.
    if not _use_session(info):
        continue
    try:
        _session = SessionDataset(session_info=info, processing_settings=settings.processing_settings)
        row = session_filter[session_filter["session"] == info.session_id].iloc[0]

        # Sessions that lasted less than 15 minutes are left out.
        if _session.session_metrics.session_duration < datetime.timedelta(minutes=15):
            continue

        # Optional per-session crop: keep only the first `crop_trial` trials.
        crop_trial = row["crop_trial"]
        if pd.notna(crop_trial):
            _session.trials = _session.trials[: int(crop_trial)]

        session_datasets.append(_session)
    except Exception as load_error:
        # Best effort: report the failure and carry on with the remaining sessions.
        print(f"Failed to load session {info.session_id}: {load_error}")

In [None]:
# QC report generation is opt-in: set RUN_QC to True to write the reports
# under ./derived/qc_reports. It stays off by default.
RUN_QC = False
if RUN_QC:
    ssvr.qc.run_qc(session_datasets=session_datasets, path=Path("./derived") / "qc_reports")

# Table enrichment

Per-trial columns considered for the enriched trials table:

* Waiting time
* Predicted waiting time
* Block label ❌
* Trials since block start ❌
* Trials to next block start ❌
* Trials of this type since block start
* Trials of this type to next block start
* Previous trial choice
* Previous trial reward
* Previous trial stimuli


In [16]:
# Run the block-related enrichment helpers on every loaded session.
# `ssvr.enrich_trials` is imported in the notebook's first cell, so it is not
# imported again here.
for session in tqdm(session_datasets, desc="Enriching sessions"):
    # Block info is added first, then the position of each trial relative to its block.
    ssvr.enrich_trials.enrich_with_block_info(session)
    # Only the result of the last helper, for the last session, is left in
    # `enriched_trials`.
    # NOTE(review): the trials table displayed further down already carries the
    # block columns, which suggests both helpers update `session` in place -- confirm.
    enriched_trials = ssvr.enrich_trials.enrich_with_relative_to_block(session)

Enriching sessions:   0%|          | 0/64 [00:00<?, ?it/s]

In [None]:
# Take the first loaded session and display its trials table as the cell output.
# Indexing raises IndexError when `session_datasets` is empty.
session = session_datasets[0]
session.trials






Unnamed: 0,odor_onset_time,choice_time,reward_time,reaction_duration,patch_index,is_rewarded,is_choice,p_reward,block_index,block,block_patch_probabilities,trials_from_last_block,trials_to_next_block
0,1.625212e+06,1.625214e+06,1.625215e+06,2.885024,1,True,True,0.9,0,environment_statistics=EnvironmentStatistics(p...,"[0.1, 0.9]",0,51
1,1.625221e+06,1.625223e+06,,2.779008,0,False,True,0.1,0,environment_statistics=EnvironmentStatistics(p...,"[0.1, 0.9]",1,50
2,1.625230e+06,1.625233e+06,,2.779008,0,False,True,0.1,0,environment_statistics=EnvironmentStatistics(p...,"[0.1, 0.9]",2,49
3,1.625238e+06,1.625240e+06,1.625241e+06,2.764000,1,True,True,0.9,0,environment_statistics=EnvironmentStatistics(p...,"[0.1, 0.9]",3,48
4,1.625248e+06,1.625251e+06,,2.816000,0,False,True,0.1,0,environment_statistics=EnvironmentStatistics(p...,"[0.1, 0.9]",4,47
...,...,...,...,...,...,...,...,...,...,...,...,...,...
305,1.627745e+06,,,,0,False,False,1.0,4,environment_statistics=EnvironmentStatistics(p...,"[0.1, 0.9]",80,4
306,1.627748e+06,,,,0,False,False,1.0,4,environment_statistics=EnvironmentStatistics(p...,"[0.1, 0.9]",81,3
307,1.627751e+06,1.627754e+06,1.627754e+06,2.875008,1,True,True,0.9,4,environment_statistics=EnvironmentStatistics(p...,"[0.1, 0.9]",82,2
308,1.627761e+06,1.627764e+06,1.627764e+06,2.882016,1,True,True,0.9,4,environment_statistics=EnvironmentStatistics(p...,"[0.1, 0.9]",83,1
