In [None]:
from src.models import DataLoadingSettings
from src.dataset import find_session_info, make_session_dataset
from src.visualization import plot_ethogram, plot_aligned_to, patch_index_colormap
from contraqctor.contract.utils import print_data_stream_tree
from itertools import groupby
import logging

from matplotlib import pyplot as plt
import numpy as np
import pandas as pd

# Keep the processing pipeline quiet, but surface plotting diagnostics.
for _logger_name, _logger_level in (
    ("src.processing", logging.WARNING),
    ("src.visualization", logging.DEBUG),
):
    logging.getLogger(_logger_name).setLevel(_logger_level)

# Trace line style keyed by a trial's `is_choice` flag:
# solid for choice trials, dashed for no-choice trials.
choice_linestyle = dict([(True, "-"), (False, "--")])

In [None]:
# Discover every available session, list them per subject, and build one
# dataset per session. Later cells analyse the first session's dataset.
settings = DataLoadingSettings()
session_info = list(find_session_info(settings))
if not session_info:
    # Fail early with a clear message instead of an IndexError on
    # `session_datasets[0]` further down.
    raise RuntimeError(
        "find_session_info returned no sessions; check DataLoadingSettings."
    )

# itertools.groupby only merges *consecutive* items with equal keys, so sort by
# subject first; otherwise a subject whose sessions are interleaved with another
# subject's would be listed several times. sorted() is stable, so each subject's
# sessions keep their discovery order.
groupby_subject = groupby(
    sorted(session_info, key=lambda info: info.subject),
    key=lambda info: info.subject,
)
for subject, sessions in groupby_subject:
    print(f"Subject: {subject}")
    for session in sessions:
        print(f"  - {session.session_id} ({session.date.strftime('%Y-%m-%d')})")

session_datasets = [
    make_session_dataset(info, processing_settings=settings.processing_settings)
    for info in session_info
]
dataset = session_datasets[0]

In [None]:
# Show the structure of the Behavior/SoftwareEvents data streams.
software_events_node = dataset.dataset.at("Behavior").at("SoftwareEvents")
software_events_tree = print_data_stream_tree(software_events_node)
print(software_events_tree)

In [None]:
# Ethogram covering the first few sites: from the start of the first site to
# the start of site N_ETHOGRAM_SITES, clamped to the last available site so
# short sessions do not raise an IndexError.
N_ETHOGRAM_SITES = 20
site_starts = dataset.sites["t_start"]
assert len(site_starts) > 0, "session has no sites to plot"
ax_velocity, ax_events = plot_ethogram(
    dataset,
    t_start=site_starts.iloc[0],
    t_end=site_starts.iloc[min(N_ETHOGRAM_SITES, len(site_starts) - 1)],
    figsize=(12, 3),
)
# plt.show() is the idiomatic way to render; Figure.show() can emit a
# "non-interactive, and thus cannot be shown" warning under the inline backend.
plt.show()

In [None]:
EVENT_WINDOW = (-0.5, 1.5)  # seconds, relative to each trial's index timestamp
BIN_WIDTH = 0.1  # seconds; spacing of the common grid used to average trials


def _bin_aligned_traces(traces, event_times, event_window, bin_width):
    """Resample event-aligned traces onto one shared relative-time grid.

    Parameters
    ----------
    traces : list of pd.Series (or pd.DataFrame)
        One trace per event, indexed by absolute time on the same clock as
        `event_times`. Each index must be sorted and unique (required by
        ``reindex(method="nearest")``).
    event_times : sequence of float
        Event timestamp for each trace, in the same order as `traces`.
    event_window : tuple of float
        ``(start, end)`` of the grid relative to the event, in seconds.
    bin_width : float
        Grid spacing, in seconds.

    Returns
    -------
    pd.DataFrame
        One column per trace (labelled 0..n-1), indexed by relative time
        ("Time"). Grid points farther than `bin_width` from any sample of a
        trace are NaN, so short traces are not padded with their edge values.

    Raises
    ------
    ValueError
        If the number of traces differs from the number of event times, in
        which case traces cannot be paired with their events.
    """
    if len(traces) != len(event_times):
        raise ValueError(
            f"Got {len(traces)} traces for {len(event_times)} events; "
            "cannot pair traces with their event times."
        )
    start, end = event_window
    # linspace (not arange) keeps both window edges exactly and, for windows
    # that are a multiple of bin_width, places a grid point at t=0.
    n_points = int(round((end - start) / bin_width)) + 1
    grid = pd.Index(np.linspace(start, end, n_points), name="Time")
    if len(traces) == 0:
        return pd.DataFrame(index=grid)
    resampled = [
        # set_axis returns a shifted copy instead of mutating the caller's trace.
        trace.set_axis(trace.index - t_event).reindex(
            grid, method="nearest", tolerance=bin_width
        )
        for trace, t_event in zip(traces, event_times)
    ]
    return pd.concat(resampled, axis=1, ignore_index=True)


fig, ax = plt.subplots(figsize=(6, 4))
for (patch_type, is_choice), trials in dataset.trials.groupby(["patch_index", "is_choice"]):
    color = patch_index_colormap[patch_type]
    linestyle = choice_linestyle[is_choice]
    # Thin per-trial traces; plot_aligned_to also returns the traces it drew,
    # which the original code treats as indexed by absolute time.
    ax, traces = plot_aligned_to(
        trials.index,
        dataset.processed_streams.position_velocity["velocity"],
        event_window=EVENT_WINDOW,
        plot_kwargs={
            "linestyle": linestyle,
            "color": color,
            "alpha": 0.1,
            "linewidth": 0.5,
        },
        ax=ax,
    )
    binned = _bin_aligned_traces(traces, trials.index, EVENT_WINDOW, BIN_WIDTH)
    if binned.empty:
        # No traces for this group: nothing to average, and an all-NaN line
        # would only add an empty legend entry.
        continue

    mean_trace = binned.mean(axis=1)
    # 2.5-97.5 percentile band across trials: the spread of individual trials,
    # not a confidence interval of the mean. quantile() skips NaN like
    # np.nanpercentile (both use linear interpolation).
    lower = binned.quantile(0.025, axis=1)
    upper = binned.quantile(0.975, axis=1)

    ax.plot(
        mean_trace.index,
        mean_trace.values,
        color=color,
        linestyle=linestyle,
        linewidth=2,
        label=f"Patch {patch_type} {'(Choice)' if is_choice else '(No Choice)'}",
    )
    ax.fill_between(
        mean_trace.index,
        lower,
        upper,
        color=color,
        alpha=0.1,
        linewidth=0,
    )

ax.axvline(0, color="k", linestyle="--", linewidth=1)
ax.set_xlabel("Time from patch entry (s)")
ax.set_ylabel("Velocity (cm/s)")
ax.set_title("Velocity around patch entry")
ax.set_xlim(EVENT_WINDOW)
ax.legend(loc="best")
plt.show()