In [None]:
from src.models import DataLoadingSettings
from src.dataset import find_session_info, make_session_dataset
from src.visualization import (
    plot_ethogram,
    patch_index_colormap,
    plot_aligned_to_grouped_by,
    plot_session_trials,
    a_lot_of_style,
)
from itertools import groupby
import logging
import pandas as pd
import dataclasses
import matplotlib.pyplot as plt
import math
import numpy as np

# Line style per is_choice value: solid for True, dashed for False.
choice_linestyle = dict([(True, "-"), (False, "--")])

# Fixed matplotlib color-cycle entry per subject id, so each animal is drawn
# in the same color in every figure of this notebook.
subject_colors = dict(
    [
        ("808619", "C2"),
        ("808728", "C3"),
        ("789917", "C4"),
    ]
)

# Silence INFO/DEBUG chatter from the plotting helpers; WARNING and above still show.
logging.getLogger("src.visualization").setLevel(logging.WARNING)

In [None]:
# Discover all sessions, list them per subject, load them, and tabulate
# per-session metrics (one row per successfully loaded session).
settings = DataLoadingSettings()
session_info = list(find_session_info(settings))

print(f"Found {len(session_info)} sessions:")

# itertools.groupby only merges *consecutive* items with equal keys, so the
# sessions must be sorted by subject first; otherwise a subject whose sessions
# are interleaved with another subject's would be listed several times.
# sorted() is stable, so discovery order is preserved within each subject.
groupby_subject = groupby(
    sorted(session_info, key=lambda x: x.subject), key=lambda x: x.subject
)
for subject, sessions in groupby_subject:
    print(f"Subject: {subject}")
    for session in sessions:
        print(f"  - {session.session_id} ({session.date.strftime('%Y-%m-%d')})")

# Best-effort loading: a session that fails is reported and skipped so that a
# single bad session does not abort the notebook. Failures are also kept as
# (session_info, exception) pairs for later inspection.
session_datasets = []
failed_sessions = []

for info in session_info:
    try:
        session_datasets.append(
            make_session_dataset(info, processing_settings=settings.processing_settings)
        )
    except Exception as e:
        failed_sessions.append((info, e))
        print(f"Failed to load session {info.session_id}: {e}")

# Merge each dataset's metrics dataclass and session-info dataclass into one row.
session_metrics = pd.DataFrame(
    [
        {
            **dataclasses.asdict(ds.session_metrics),
            **dataclasses.asdict(ds.session_info),
        }
        for ds in session_datasets
    ]
)
display(session_metrics)

In [None]:
## Across animals session plots
# One panel per entry of ax_to_plot (session_metrics column -> y-axis label);
# in each panel every subject is one line of that metric over session dates.
ax_to_plot = {
    "reward_site_count": "Reward Site Count (#)",
    "total_distance": "Total Distance (cm)",
    "reward_count": "Reward Count (#)",
    "total_reward_ml": "Total Reward (mL)",
}

# Near-square grid with at least one cell per metric.
n_panels = len(ax_to_plot)
n_rows = math.ceil(math.sqrt(n_panels))
n_cols = math.ceil(n_panels / n_rows)

with a_lot_of_style():
    fig, axs = plt.subplots(
        nrows=n_rows,
        ncols=n_cols,
        figsize=(15, 8),
        sharex=True,
        squeeze=False,  # always a 2-D array, so .flatten() also works for 1 panel
    )
    axs = axs.flatten()

    for i, metric in enumerate(ax_to_plot):
        axs[i].set_xlabel("Date")
        axs[i].set_ylabel(ax_to_plot[metric])
        for subject, group in session_metrics.groupby("subject"):
            group_sorted = group.sort_values("date")
            axs[i].plot(
                group_sorted["date"],
                group_sorted[metric],
                marker="o",
                linestyle="-",
                label=subject,
                # Subjects without a fixed color fall back to the default
                # color cycle (color=None) instead of raising KeyError.
                color=subject_colors.get(subject),
            )
        axs[i].legend(title="Subject")

    # Hide grid cells left over when the metric count does not fill the grid.
    for unused_ax in axs[n_panels:]:
        unused_ax.set_visible(False)

    fig.autofmt_xdate(rotation=45)
    fig.tight_layout()

In [None]:
# Per-session figures: an ethogram excerpt, odor-onset-aligned velocity traces
# grouped two ways, and the trial overview with block changes marked.
ETHOGRAM_FIRST_TRIAL, ETHOGRAM_LAST_TRIAL = 30, 40  # trial positions framing the ethogram
N_TRIALS_P_REWARD_PLOT = 150  # first N trials (p_reward < 1) in the p_reward plot

with a_lot_of_style():
    for dataset in session_datasets:
        print(
            f"Session: {dataset.session_info.session_id} ({dataset.session_info.subject})"
        )
        time_of_trial = dataset.trials["odor_onset_time"]
        n_trials = len(time_of_trial)
        if n_trials == 0:
            # Nothing to align to; skip rather than fail on the lookups below.
            print("  No trials, skipping session.")
            continue

        # Ethogram of a short excerpt. Positions are clamped so sessions with
        # fewer than ETHOGRAM_LAST_TRIAL + 1 trials show their last trials
        # instead of crashing the loop. .iloc is positional; with a 0-based
        # integer trials index this equals the former label lookup.
        last_trial = min(ETHOGRAM_LAST_TRIAL, n_trials - 1)
        first_trial = max(last_trial - (ETHOGRAM_LAST_TRIAL - ETHOGRAM_FIRST_TRIAL), 0)
        ax_velocity, ax_events = plot_ethogram(
            dataset,
            t_start=time_of_trial.iloc[first_trial],
            t_end=time_of_trial.iloc[last_trial],
            figsize=(12, 3),
        )

        # Velocity around odor onset, grouped by (patch_index, is_choice):
        # patch sets the color, is_choice sets the line style.
        unique_patches = dataset.trials["patch_index"].unique()

        pairwise_style = {
            (patch_idx, is_choice): {
                "color": patch_index_colormap[patch_idx],
                "linestyle": choice_linestyle[is_choice],
                "alpha": 0.02,
            }
            for patch_idx in unique_patches
            for is_choice in [True, False]
        }

        ax, summary = plot_aligned_to_grouped_by(
            timestamp_df=dataset.trials,
            timeseries=dataset.processed_streams.position_velocity["velocity"],
            by=["patch_index", "is_choice"],
            timestamp_column="odor_onset_time",
            plot_kwargs=pairwise_style,
            event_window=(-2, 5),
        )
        ax.set_ylabel("Velocity (cm/s)")

        # Same alignment grouped by (p_reward, is_choice): red for p_reward > 0.5,
        # blue otherwise; only trials with p_reward < 1.0 are plotted.
        pairwise_style = {
            (p_reward, is_choice): {
                "color": "red" if p_reward > 0.5 else "blue",
                "linestyle": choice_linestyle[is_choice],
                "alpha": 0.02,
            }
            for p_reward in np.unique(dataset.trials["p_reward"].values)
            for is_choice in [True, False]
        }
        ax, summary = plot_aligned_to_grouped_by(
            timestamp_df=dataset.trials.query("p_reward < 1.0").head(N_TRIALS_P_REWARD_PLOT),
            timeseries=dataset.processed_streams.position_velocity["velocity"],
            by=["p_reward", "is_choice"],
            timestamp_column="odor_onset_time",
            plot_kwargs=pairwise_style,
            event_window=(-2, 5),
        )
        ax.set_ylabel("Velocity (cm/s)")

        ax = plot_session_trials(dataset, alpha=0.33, figsize=(16, 6))

        # Mark block changes: map each Block event timestamp to the last trial
        # whose odor onset is at or before it (searchsorted side="right" - 1;
        # assumes onset times are sorted ascending). Events before the first
        # trial are clamped to the first trial.
        blocks = (
            dataset.dataset["Behavior"]["SoftwareEvents"]["Block"].load().data.copy()
        )
        block_times = blocks.index.values
        trial_indices = time_of_trial.searchsorted(block_times, side="right") - 1
        trial_indices = np.maximum(trial_indices, 0)
        blocks["trial_idx"] = time_of_trial.iloc[trial_indices].index.values
        ax.vlines(
            blocks["trial_idx"].values,
            ymin=ax.get_ylim()[0],
            ymax=ax.get_ylim()[1],
            colors="k",
            linestyles="dashed",
            label="Block Change",
        )

        plt.show()