In [None]:
from src.models import DataLoadingSettings
from src.dataset import find_session_info, make_session_dataset
from src.visualization import (
    plot_ethogram,
    patch_index_colormap,
    plot_aligned_to_grouped_by,
)
from contraqctor.contract.utils import print_data_stream_tree
from itertools import groupby
import logging


# Emit DEBUG-level records from the "src.visualization" logger while exploring.
visualization_logger = logging.getLogger("src.visualization")
visualization_logger.setLevel(logging.DEBUG)

# Matplotlib line style keyed by a boolean flag (used below with `is_choice`).
choice_linestyle = {
    True: "-",
    False: "--",
}

In [None]:
# Build the loading settings and collect every session they resolve to.
settings = DataLoadingSettings()
session_info = list(find_session_info(settings))

# Print a per-subject overview of the discovered sessions.
# itertools.groupby only merges *consecutive* items that share a key, so the
# sessions are sorted by subject first; without this, a subject whose sessions
# are not adjacent in `session_info` would be listed more than once.
# A sorted *copy* is used (sorted() is stable) so the order of `session_info`,
# and therefore of `session_datasets` below, is left untouched.
sessions_by_subject = sorted(session_info, key=lambda x: x.subject)
groupby_subject = groupby(sessions_by_subject, key=lambda x: x.subject)
for subject, sessions in groupby_subject:
    print(f"Subject: {subject}")
    for session in sessions:
        print(f"  - {session.session_id} ({session.date.strftime('%Y-%m-%d')})")

# One dataset per discovered session, in the same order as `session_info`.
session_datasets = [
    make_session_dataset(info, processing_settings=settings.processing_settings)
    for info in session_info
]
# The remaining cells work on the last dataset in the list.
dataset = session_datasets[-1]

In [None]:
# Walk down to the "SoftwareEvents" node under "Behavior" and print whatever
# print_data_stream_tree produces for it.
software_events_node = dataset.dataset.at("Behavior").at("SoftwareEvents")
print(print_data_stream_tree(software_events_node))

In [None]:
# Plot the ethogram for the window bounded by the `choice_time` values stored
# under keys 10 and 20 of the trials table.
choice_times = dataset.trials["choice_time"]
ethogram_window = {"t_start": choice_times[10], "t_end": choice_times[20]}
ethogram_axes = plot_ethogram(dataset, figsize=(12, 3), **ethogram_window)

# plot_ethogram hands back two objects, unpacked here as the velocity and
# events axes; the figure is shown through the velocity axis.
ax_velocity, ax_events = ethogram_axes
ax_velocity.figure.show()

In [None]:
# Distinct patch indices present in this session's trials.
unique_patches = dataset.trials["patch_index"].unique()

# One plotting style per (patch_index, is_choice) pair: colour comes from the
# patch, line style from the choice flag, and every trace is drawn faintly.
pairwise_style = {
    (patch_idx, is_choice): dict(
        color=patch_index_colormap[patch_idx],
        linestyle=choice_linestyle[is_choice],
        alpha=0.05,
    )
    for patch_idx in unique_patches
    for is_choice in (True, False)
}

# Velocity aligned to odor onset, grouped by the same two keys used for the
# style lookup above.
velocity_trace = dataset.processed_streams.position_velocity["velocity"]
grouping_columns = ["patch_index", "is_choice"]
ax, summary = plot_aligned_to_grouped_by(
    timestamp_df=dataset.trials,
    timeseries=velocity_trace,
    by=grouping_columns,
    timestamp_column="odor_onset_time",
    plot_kwargs=pairwise_style,
)
ax.set_ylabel("Velocity (cm/s)")

In [None]:
## Make session choice plot

# Bind the selected session's trials table to a short name; the bare expression
# on the last line makes the notebook render it as this cell's output.
# NOTE(review): no plotting happens in the lines shown here — presumably the
# choice plot itself is still to be added; confirm.
trials = dataset.trials
trials