In [None]:
%load_ext autoreload
%autoreload 2

from src.models import DataLoadingSettings
from src.dataset import find_session_info, make_session_dataset
from src.processing import process_sites
from src.visualization import plot_ethogram
from aind_behavior_vr_foraging import task_logic as vrf_task

from itertools import groupby
import logging

logger = logging.getLogger(__name__)

In [None]:
settings = DataLoadingSettings()
session_info = list(find_session_info(settings))

# itertools.groupby only merges *adjacent* items with equal keys, so sort by
# subject before grouping; otherwise a subject whose sessions are not
# contiguous in `session_info` would be printed under several headers.
# sorted() is stable, so each subject's sessions keep their discovery order.
# A sorted *copy* is used so `session_info` (and hence the order of
# `session_datasets` below) is left unchanged.
groupby_subject = groupby(
    sorted(session_info, key=lambda x: x.subject),
    key=lambda x: x.subject,
)
for subject, sessions in groupby_subject:
    print(f"Subject: {subject}")
    for session in sessions:
        print(f"  - {session.session_id} ({session.date.strftime('%Y-%m-%d')})")

# Build a dataset for every discovered session; the cells below work on the
# first one only.
session_datasets = [
    make_session_dataset(info, processing_settings=settings.processing_settings)
    for info in session_info
]
if not session_datasets:
    # Fail with an actionable message rather than an IndexError on [0].
    raise RuntimeError("No sessions found for the current DataLoadingSettings.")
dataset = session_datasets[0]

In [None]:
# Palette cycled through for reward sites, indexed by patch. These are the
# first seven colors of ColorBrewer's "Dark2" qualitative palette.
_colormap = [
    "#1b9e77",
    "#d95f02",
    "#7570b3",
    "#e7298a",
    "#66a61e",
    "#e6ab02",
    "#a6761d",
]


def get_color_from_site(site_label: str, patch_idx: int) -> str:
    """Return the hex color used to draw a site.

    Reward sites take a color from ``_colormap`` chosen by patch index,
    wrapping around once the palette is exhausted. Inter-patch and
    inter-site segments use fixed greys.

    Args:
        site_label: Site label, compared against the members of
            ``vrf_task.VirtualSiteLabels``.
        patch_idx: Index of the patch the site belongs to. Only consulted
            for reward sites.

    Returns:
        A ``"#RRGGBB"`` color string.

    Raises:
        ValueError: If ``site_label`` is not one of the handled labels.
    """
    labels = vrf_task.VirtualSiteLabels

    if site_label == labels.REWARDSITE:
        palette_size = len(_colormap)
        return _colormap[patch_idx % palette_size]
    if site_label == labels.INTERPATCH:
        return "#A9A9A9"  # light grey
    if site_label == labels.INTERSITE:
        return "#4C4C4C"  # dark grey

    raise ValueError(f"Unknown site label: {site_label}")

In [None]:
# Derive the site table for the selected session from its underlying
# `dataset.dataset` object. It is indexed later as `sites["t_start"].iloc[...]`,
# so presumably a pandas DataFrame with one row per site -- TODO confirm
# against `src.processing.process_sites`.
sites = process_sites(dataset.dataset)

In [None]:
# Plot an ethogram over only the first few sites so individual events stay
# legible. The window runs from the start of the first site to the start of
# site `n_sites_to_plot`. For sessions with fewer sites, the end index is
# clamped to the last available site instead of raising IndexError.
n_sites_to_plot = 20
if len(sites) == 0:
    raise ValueError("No sites to plot for this session.")
end_idx = min(n_sites_to_plot, len(sites) - 1)

ax = plot_ethogram(
    sites,
    dataset,
    t_start=sites["t_start"].iloc[0],
    t_end=sites["t_start"].iloc[end_idx],
    figsize=(12, 3),
)
ax.figure.show()

## Metrics to compute
* Total distance traveled
* Number of sites
* Number of rewards
* P(Stop | Odor)

## Plots
* Histograms of inter-odor distance and inter-odor time
* Odor-aligned velocity
