In [None]:
from src.models import DataLoadingSettings
from src.dataset import find_session_info, make_session_dataset
from src.visualization import (
    plot_ethogram,
    patch_index_colormap,
    plot_aligned_to_grouped_by,
    plot_session_trials,
)
from itertools import groupby
import logging
import pandas as pd
import dataclasses
import matplotlib.pyplot as plt
import math

# Only let warnings (and worse) through from the plotting helpers' logger.
logging.getLogger("src.visualization").setLevel(logging.WARNING)

# Line style looked up by the boolean `is_choice` flag when styling traces.
choice_linestyle = {
    True: "-",
    False: "--",
}

# Fixed matplotlib colour-cycle entry per subject ID, so a subject keeps the
# same colour wherever this mapping is used.
subject_colors = {
    "808619": "C2",
    "808728": "C3",
    "789917": "C4",
}

In [None]:
# Discover every session described by the default loading settings.
settings = DataLoadingSettings()
session_info = list(find_session_info(settings))

print(f"Found {len(session_info)} sessions:")

# itertools.groupby only merges *consecutive* items that share a key, so the
# sessions must be sorted by subject first; otherwise a subject whose sessions
# are not adjacent would be listed under several separate headers.
# `sorted` is stable (order within a subject is kept) and `session_info`
# itself is left untouched for the cells that follow.
sessions_sorted_by_subject = sorted(session_info, key=lambda x: x.subject)
groupby_subject = groupby(sessions_sorted_by_subject, key=lambda x: x.subject)
for subject, sessions in groupby_subject:
    print(f"Subject: {subject}")
    for session in sessions:
        print(f"  - {session.session_id} ({session.date.strftime('%Y-%m-%d')})")

# Best-effort loading: a session that fails to build is reported and skipped
# so that one bad session does not prevent the others from being loaded.
session_datasets = []

for info in session_info:
    try:
        loaded_dataset = make_session_dataset(
            info, processing_settings=settings.processing_settings
        )
    except Exception as e:
        print(f"Failed to load session {info.session_id}: {e}")
    else:
        session_datasets.append(loaded_dataset)

# One table row per loaded session: the session's metrics merged with its info.
# The info fields are applied last, so they win if a field name appears in both.
metric_rows = []
for ds in session_datasets:
    row = dataclasses.asdict(ds.session_metrics)
    row.update(dataclasses.asdict(ds.session_info))
    metric_rows.append(row)

session_metrics = pd.DataFrame(metric_rows)
display(session_metrics)

In [None]:
## Across animals session plots
# Column of `session_metrics` to plot -> y-axis label; one panel per entry.
ax_to_plot = {
    "reward_site_count": "Reward Site Count (#)",
    "total_distance": "Total Distance (cm)",
    "reward_count": "Reward Count (#)",
}

# Near-square grid with at least one cell per metric.
n_rows = math.ceil(math.sqrt(len(ax_to_plot)))
n_cols = math.ceil(len(ax_to_plot) / n_rows)
fig, axs = plt.subplots(
    nrows=n_rows,
    ncols=n_cols,
    figsize=(10, 6),
    squeeze=False,  # always return a 2-D array so flatten() also works for a 1x1 grid
)
axs = axs.flatten()

for i, (metric, metric_label) in enumerate(ax_to_plot.items()):
    axs[i].set_xlabel("Date")
    axs[i].set_ylabel(metric_label)
    for subject, group in session_metrics.groupby("subject"):
        group_sorted = group.sort_values("date")
        axs[i].plot(
            # Plot real datetimes rather than strings: string x-values form a
            # categorical axis ordered by first appearance, which is not
            # chronological once subjects have different session dates.
            pd.to_datetime(group_sorted["date"]),
            group_sorted[metric],
            marker="o",
            linestyle="-",
            label=subject,
            # Subjects without a fixed colour fall back to matplotlib's colour
            # cycle (color=None) instead of raising KeyError.
            color=subject_colors.get(subject),
        )
    axs[i].legend(title="Subject")

# The grid can have more cells than metrics (e.g. 2x2 for 3); hide the spares.
for spare_ax in axs[len(ax_to_plot):]:
    spare_ax.set_visible(False)

fig.autofmt_xdate()  # slant the date tick labels so they do not overlap
fig.tight_layout()


# Not used by the code visible in this cell; kept because later cells may read it.
p_per_patch = session_metrics["p_stop_per_odor"]

In [None]:
# Positions (0-based) of the trials whose odor onsets bound the ethogram snippet.
ETHOGRAM_FIRST_TRIAL = 30
ETHOGRAM_LAST_TRIAL = 40

# Per-session figures: an ethogram snippet, odor-onset-aligned velocity grouped
# by (patch, choice), and the session trial overview.
for dataset in session_datasets:
    print(
        f"Session: {dataset.session_info.session_id} ({dataset.session_info.subject})"
    )

    # Positional (.iloc) access with a clamped end, so sessions with fewer
    # trials or a non-default index do not raise KeyError and abort the loop
    # for every remaining session.
    # NOTE(review): assumes the original [30]/[40] were meant as trial
    # positions, not index labels — confirm against how `trials` is indexed.
    onset_times = dataset.trials["odor_onset_time"]
    if len(onset_times) > ETHOGRAM_FIRST_TRIAL + 1:
        last_trial = min(ETHOGRAM_LAST_TRIAL, len(onset_times) - 1)
        ax_velocity, ax_events = plot_ethogram(
            dataset,
            t_start=onset_times.iloc[ETHOGRAM_FIRST_TRIAL],
            t_end=onset_times.iloc[last_trial],
            figsize=(12, 3),
        )
    else:
        print(
            f"  Skipping ethogram: only {len(onset_times)} trials "
            f"(need at least {ETHOGRAM_FIRST_TRIAL + 2})"
        )

    unique_patches = dataset.trials["patch_index"].unique()

    # One plot style per (patch, choice) group: colour encodes the patch,
    # line style encodes the choice flag; low alpha because many traces overlap.
    pairwise_style = {
        (patch_idx, is_choice): {
            "color": patch_index_colormap[patch_idx],
            "linestyle": choice_linestyle[is_choice],
            "alpha": 0.05,
        }
        for patch_idx in unique_patches
        for is_choice in [True, False]
    }

    ax, summary = plot_aligned_to_grouped_by(
        timestamp_df=dataset.trials,
        timeseries=dataset.processed_streams.position_velocity["velocity"],
        by=["patch_index", "is_choice"],
        timestamp_column="odor_onset_time",
        plot_kwargs=pairwise_style,
        event_window=(-2, 5),
    )
    ax.set_ylabel("Velocity (cm/s)")

    plot_session_trials(dataset, alpha=0.5)
    plt.show()