In [None]:
import datetime
import logging
import typing as t

import pandas as pd
from tqdm.notebook import tqdm

import ssvr.enrich_trials
import ssvr.qc
import ssvr.utils

%load_ext autoreload
%autoreload 2
import matplotlib.pyplot as plt

import ssvr.visualization as viz
from ssvr.dataset import SessionDataset, create_session_info
from ssvr.models import DataLoadingSettings

# Only let ERROR (and above) through from the project logger and the AIND services base logger.
for _noisy_logger in ("ssvr", "aind_behavior_services.base"):
    logging.getLogger(_noisy_logger).setLevel(logging.ERROR)

# Line style looked up by the boolean `is_choice` flag of a trial.
choice_linestyle = dict([(True, "-"), (False, "--")])
# Matplotlib colour-cycle entry assigned to each subject ID.
subject_colors = dict(zip(("808619", "808728", "789917"), ("C2", "C3", "C4")))

In [None]:
settings = DataLoadingSettings()
DERIVED_PATH = settings.root_derived_path

# Opt-in switches for the slow / side-effecting steps of this cell.
SYNC_DATASET_LOCALLY = False  # set True to sync the whole dataset locally before loading
RUN_QC_REPORTS = False  # set True to write QC reports under DERIVED_PATH / "qc_reports"
# Sessions shorter than this are not kept for analysis.
MIN_SESSION_DURATION = datetime.timedelta(minutes=15)

if SYNC_DATASET_LOCALLY:
    from ssvr.s3_utils import sync_dataset

    sync_dataset(settings)

session_datasets: list[SessionDataset] = []
for entry in tqdm(settings.sessions_to_load, desc="Loading sessions", total=len(settings.sessions_to_load)):
    # Check that the *session directory* exists under each root. Testing only the
    # root itself made every existing root a "candidate", so the not-found error
    # never fired for a missing session and the first root was used even when the
    # session lived under another one.
    candidate_paths = [
        session_path for session_path in (p / entry.session_id for p in settings.root_path) if session_path.exists()
    ]
    if not candidate_paths:
        raise FileNotFoundError(f"Session {entry.session_id} not found in any root path.")
    if len(candidate_paths) > 1:
        logging.warning(f"Multiple paths found for session {entry.session_id}, using the first one.")
    info = create_session_info(candidate_paths[0])
    try:
        _session = SessionDataset(session_info=info, processing_settings=settings.processing_settings)
        if _session.session_metrics.session_duration < MIN_SESSION_DURATION:
            # Too short to analyse; skipped without a message (as before).
            continue
        if entry.crop_max_trials is not None:
            # Keep only the first `crop_max_trials` trials of this session.
            _session.trials = _session.trials[: int(entry.crop_max_trials)]
        session_datasets.append(_session)
    except Exception as e:
        # Best-effort loading: report the failure and carry on with the other sessions.
        print(f"Failed to load session {info.session_id}: {e}")

for session in tqdm(session_datasets, desc="Enriching sessions"):
    # The return values were never used downstream, so they are no longer bound.
    # NOTE(review): presumably each call adds its columns to `session.trials` in
    # place — confirm against ssvr.enrich_trials.
    ssvr.enrich_trials.enrich_with_session_type(session)  # is_fixed_stop_duration
    ssvr.enrich_trials.enrich_with_block_info(session)
    ssvr.enrich_trials.enrich_with_relative_to_block(session)
    ssvr.enrich_trials.enrich_with_previous_trial(session, n_previous=5)
    ssvr.enrich_trials.enrich_with_block_probability(session)
    ssvr.enrich_trials.enrich_with_reward_rate(session, exponential_decay=0.2)

if RUN_QC_REPORTS:
    ssvr.qc.run_qc(session_datasets=session_datasets, path=DERIVED_PATH / "qc_reports")

In [None]:
# Stack the trial tables of every loaded session into one frame, tagging each
# row with its subject, its session and its within-session trial number.
all_trials = []
for session in session_datasets:
    df = session.trials.copy()
    # NOTE(review): the rename assumes the trials index is unnamed, so that
    # reset_index() produces a column called "index" — confirm.
    df = df.reset_index().rename(columns={"index": "trial_number"})
    df["subject"] = session.session_info.subject
    df["session_id"] = session.session_info.session_id
    all_trials.append(df)

# ignore_index=True gives the pooled table a fresh 0..N-1 index.
all_trials_df = t.cast(pd.DataFrame, pd.concat(all_trials, ignore_index=True))
# all_trials_df.to_csv(DERIVED_PATH / "all_sessions_enriched_trials.csv")
# DataFrame.info() writes its report itself and returns None; wrapping it in
# print() only appended a stray "None" line to the output.
all_trials_df.info()

# Print summary statistics
print(f"Total trials across all sessions: {len(all_trials_df)}")
print(f"Number of sessions: {all_trials_df['session_id'].nunique()}")
print(f"Number of subjects: {all_trials_df['subject'].nunique()}")
print()

# Print per-subject statistics
for subject in sorted(all_trials_df["subject"].unique()):
    subject_df = all_trials_df[all_trials_df["subject"] == subject]
    n_sessions = subject_df["session_id"].nunique()
    n_trials = len(subject_df)
    print(f"Subject {subject}: {n_sessions} sessions, {n_trials} trials")

In [None]:
from ssvr.analysis.block_switching_behavior import (
    calculate_choice_matrix,
    plot_block_switch_choice_patterns,
)

# Window handed to both the matrix computation and the plots.
# NOTE(review): presumably trial offsets relative to the block switch — confirm
# against calculate_choice_matrix.
trial_window = (-10, 30)
choice_behavior_matrix, switch_trials_df = calculate_choice_matrix(
    all_trials_df,
    trial_window=trial_window,
    block_switch_filter="different",
)


# --- All subjects pooled, with a per-row subject marker ---------------------
with viz.a_lot_of_style():
    fig, ax = plot_block_switch_choice_patterns(choice_behavior_matrix, trial_window)

    # The first two axes are the heatmaps. Each row gets a dot in its subject's
    # colour at x = -1.5; clipping is disabled so the dot may fall outside the axes.
    marker_kwargs = {"markersize": 4, "clip_on": False}
    for heatmap_ax in ax[:2]:
        for s in switch_trials_df["subject"].unique():
            belongs_to_subject = switch_trials_df["subject"] == s
            subject_rows = switch_trials_df[belongs_to_subject]["index_ord"].values
            color = subject_colors[s]
            for row in subject_rows:
                heatmap_ax.plot(-1.5, row, "o", color=color, **marker_kwargs)

    fig.suptitle("All subjects")
    fig.savefig(DERIVED_PATH / "block_switch_choice_patterns_all_subjects.svg")


# --- One figure per subject ---------------------------------------------------
with viz.a_lot_of_style():
    for subject, df in switch_trials_df.groupby("subject"):
        rows_for_subject = df["index_ord"].values
        fig, ax = plot_block_switch_choice_patterns(choice_behavior_matrix[rows_for_subject, :, :], trial_window)
        # Hide the legend on the third axis of the per-subject figures.
        ax[2].legend_.set_visible(False)
        fig.suptitle(f"Subject {subject}")
        fig.savefig(DERIVED_PATH / f"block_switch_choice_patterns_subject_{subject}.svg")

## Check if there are biases depending on the odor patch switched to
for (subject, high_patch_index), df in switch_trials_df.groupby(["subject", "after_high_index"]):
    rows_for_group = df["index_ord"].values
    fig, ax = plot_block_switch_choice_patterns(choice_behavior_matrix[rows_for_group, :, :], trial_window)
    fig.suptitle(f"Subject {subject}, After High Patch Index {high_patch_index}")
    out_name = f"block_switch_choice_patterns_subject_{subject}_after_high_patch_{high_patch_index}.svg"
    fig.savefig(DERIVED_PATH / out_name)
plt.show()

In [None]:
def _is_low_patch(row):
    """Return True when the row is not flagged as a high-reward-patch trial."""
    return not row["is_high_reward_patch"]


# Row-wise predicates over `switch_trials_df`, keyed by a human-readable name
# that is reused in figure titles and output file names.
conditions = {
    "low first rewarded": lambda row: _is_low_patch(row) and (row["is_rewarded"] == True),
    "low first not rewarded": lambda row: _is_low_patch(row) and (row["is_rewarded"] != True),
    "low first not stop": lambda row: _is_low_patch(row) and (row["is_choice"] == False),
}

for condition_name, condition_fn in conditions.items():
    keep_row = switch_trials_df.apply(condition_fn, axis=1)
    condition_df = switch_trials_df[keep_row]
    file_stem = f"block_switch_choice_patterns_all_subjects_condition_{condition_name.replace(' ', '_')}"

    # Pooled figure for this condition.
    pooled_rows = condition_df["index_ord"].values
    fig, ax = plot_block_switch_choice_patterns(choice_behavior_matrix[pooled_rows, :, :], trial_window)
    fig.suptitle(f"All subjects - Condition: {condition_name}")
    fig.savefig(DERIVED_PATH / f"{file_stem}.svg")

    # Same condition, one figure per subject.
    for subject in condition_df["subject"].unique():
        subject_condition_df = condition_df[condition_df["subject"] == subject]
        subject_rows = subject_condition_df["index_ord"].values
        fig, ax = plot_block_switch_choice_patterns(choice_behavior_matrix[subject_rows, :, :], trial_window)
        fig.suptitle(f"All subjects - Condition: {condition_name} - Subject: {subject}")
        fig.savefig(DERIVED_PATH / f"{file_stem}_{subject}.svg")

plt.show()

In [None]:
# Display the pooled per-trial table assembled from all loaded sessions.
all_trials_df

In [None]:
## Grab P(Stay) on the "low patch" after not getting reward on the high patch
import numpy as np

N_BOOT = 1000  # bootstrap resamples per (subject, outcome) estimate
p_stay_rng = np.random.default_rng(0)  # seeded so the confidence intervals are reproducible

# Conditioning trials: high-reward-patch trials with a known reward outcome.
# This frame used to be computed and then ignored, so *every* trial (low patch
# included) was used as the "previous" trial, contrary to the plot title.
high_patch = all_trials_df[all_trials_df["is_high_reward_patch"] & (all_trials_df["is_rewarded"].notna())]

with viz.a_lot_of_style():
    plt.figure(figsize=(3, 4))
    animals = []
    for subject, subject_df in high_patch.groupby("subject"):
        # NaN (not 0) marks an outcome with no usable follow-up trial, so it is
        # neither drawn as a real data point nor averaged into the group mean.
        out = {True: np.nan, False: np.nan}
        for l, sub_df in subject_df.groupby("is_rewarded"):
            outcome = bool(l)

            # The pooled table has a 0..N-1 index with sessions stacked back to
            # back, so "index + 1" is the following row. Drop candidates that
            # run off the end of the table or land in a different session.
            next_idx = sub_df.index + 1
            in_table = next_idx.isin(all_trials_df.index)
            next_trial = all_trials_df.loc[next_idx[in_table], :]
            same_session = next_trial["session_id"].to_numpy() == sub_df["session_id"].to_numpy()[in_table]
            next_trial = next_trial[same_session]

            next_trial_low_outside_switch = next_trial[
                ~next_trial["is_high_reward_patch"] & (next_trial["trials_from_last_block_by_trial_type"] > 0)
            ]
            stay = next_trial_low_outside_switch["is_choice"].dropna().to_numpy(dtype=float)
            if stay.size == 0:
                continue  # nothing to estimate for this outcome

            # Bootstrap the mean of the stay indicator for a 95% percentile interval.
            res = p_stay_rng.choice(stay, size=(N_BOOT, stay.size), replace=True)
            boot_means = res.mean(axis=1)

            mean_val = stay.mean()
            ci_low, ci_high = np.percentile(boot_means, [2.5, 97.5])
            out[outcome] = mean_val

            # fmt="o" already draws the marker, so no separate scatter call is needed.
            plt.errorbar(
                [int(outcome)],
                [mean_val],
                yerr=[[mean_val - ci_low], [ci_high - mean_val]],
                color=subject_colors[subject],
                fmt="o",
            )

        plt.plot(
            [0, 1], [out[False], out[True]], marker=None, label=f"Subject {subject}", color=subject_colors[subject]
        )
        animals.append(out)

    # Across-animal mean, ignoring animals with no estimate for an outcome.
    plt.plot(
        [0, 1],
        [np.nanmean([animal[False] for animal in animals]), np.nanmean([animal[True] for animal in animals])],
        marker="o",
        color="k",
        linewidth=3,
    )
    plt.xticks([0, 1], ["Not Rewarded", "Rewarded"])
    plt.ylabel("P(Stay)")
    plt.title("P(Stay) on Low Patch after High Patch Trial Outcome")
    plt.legend(loc="upper left", bbox_to_anchor=(1.02, 1))
    plt.savefig(DERIVED_PATH / "p_stay_on_low_patch_after_high_patch_outcome.svg")
    plt.show()

In [None]:
import numpy as np
import pandas as pd
from matplotlib.lines import Line2D

from ssvr.analysis.block_switching_behavior import calculate_consecutive_choice_runs, plot_trials_to_criterion_histogram

# Criterion: this many consecutive identical choices after a block switch.
n_consecutive = 3
consecutive_runs_df = calculate_consecutive_choice_runs(
    all_trials_df, switch_trials_df, n_consecutive=n_consecutive, max_trials_ahead=40
)


def get_medians(df):
    """Return (low, high) median trials-to-criterion for one table of runs.

    Low-reward-patch rows contribute `trials_to_n_consecutive_false`, all other
    rows contribute `trials_to_n_consecutive_true`. A side with no non-missing
    values yields NaN.
    """
    is_low = df["is_low_reward_patch"]
    low = df[is_low]["trials_to_n_consecutive_false"].dropna()
    high = df[~is_low]["trials_to_n_consecutive_true"].dropna()
    median_low = np.median(low) if len(low) > 0 else np.nan
    median_high = np.median(high) if len(high) > 0 else np.nan
    return (median_low, median_high)


def _criterion_histogram_figure(runs_df, out_path, suptitle=None):
    """Draw, save and show one trials-to-criterion histogram with the criterion line."""
    fig, axes = plt.subplots(1, 1)
    plot_trials_to_criterion_histogram(runs_df, ax=axes)
    # Dashed green line at the criterion length, spanning the current y-range.
    axes.vlines(n_consecutive, 0, axes.get_ylim()[1], colors="green", linestyles="dashed")
    if suptitle is not None:
        fig.suptitle(suptitle)
    fig.savefig(out_path)
    plt.show()


with viz.a_lot_of_style():
    # Pooled histogram first, then one per animal.
    _criterion_histogram_figure(consecutive_runs_df, DERIVED_PATH / "trials_to_criterion_all_subjects.svg")

    for subject, subject_df in consecutive_runs_df.groupby("subject"):
        _criterion_histogram_figure(
            subject_df,
            DERIVED_PATH / f"trials_to_criterion_animal_{subject}.svg",
            suptitle=f"Animal {subject}",
        )


# Split-half comparison: the first vs second half of each animal's unique
# `trial_index` values, in sorted order.
summary_stats = []

for subject, subject_df in consecutive_runs_df.groupby("subject"):
    subject_df = subject_df.sort_values("trial_index")

    unique_trials = subject_df["trial_index"].unique()
    mid_point = len(unique_trials) // 2
    in_first_half = subject_df["trial_index"].isin(unique_trials[:mid_point])
    in_second_half = subject_df["trial_index"].isin(unique_trials[mid_point:])

    first_half_df = subject_df[in_first_half]
    second_half_df = subject_df[in_second_half]

    fig, axes = plt.subplots(1, 2, figsize=(8, 5), sharey=True)
    for panel, (half_df, half_title) in zip(axes, ((first_half_df, "First Half"), (second_half_df, "Second Half"))):
        plot_trials_to_criterion_histogram(half_df, ax=panel, title=half_title)

    fig.suptitle(f"Animal {subject} - Split Half Analysis")
    fig.tight_layout()
    fig.savefig(DERIVED_PATH / f"trials_to_criterion_split_animal_{subject}.svg")
    plt.show()

    # Medians per half feed the summary figure below.
    l1, h1 = get_medians(first_half_df)
    l2, h2 = get_medians(second_half_df)

    summary_stats.append({"subject": subject, "low_1st": l1, "high_1st": h1, "low_2nd": l2, "high_2nd": h2})

summary_df = pd.DataFrame(summary_stats)
fig_sum, ax_sum = plt.subplots(figsize=(8, 6))

# One line per animal and patch type: low reward in blue, high reward in red.
patch_colors = (("low", "blue"), ("high", "red"))
for _, row in summary_df.iterrows():
    for patch_kind, patch_color in patch_colors:
        ax_sum.plot(
            [0, 1],
            [row[f"{patch_kind}_1st"], row[f"{patch_kind}_2nd"]],
            "o-",
            color=patch_color,
            alpha=0.5,
        )

ax_sum.set_xticks([0, 1])
ax_sum.set_xticklabels(["1st Half", "2nd Half"])
ax_sum.set_ylabel("Median Trials to Criterion")
ax_sum.set_title("Change in Performance Across Session Halves")
ax_sum.grid(True, alpha=0.3)

# Custom legend
custom_lines = [Line2D([0], [0], color=patch_color, lw=2, marker="o") for _, patch_color in patch_colors]
ax_sum.legend(custom_lines, ["Low Reward Patch", "High Reward Patch"])

fig_sum.savefig(DERIVED_PATH / "trials_to_criterion_summary_split.svg")
plt.show()

In [None]:
from ssvr.analysis.logistic_regression import (
    create_regression_design_matrix,
    fit_logistic_regression,
    perform_bootstrap_regression,
    perform_cross_validation,
    plot_regression_coefficients,
    plot_regression_coefficients_with_ci,
)


def _format_accuracy(scores):
    """Format an array of scores as 'mean (+/- 2*std)' with three decimals."""
    return f"{scores.mean():.3f} (+/- {scores.std() * 2:.3f})"


with viz.a_lot_of_style():
    # 1. Prepare Data
    n_back = 10
    fit_intercept = True
    regression_df, feature_cols = create_regression_design_matrix(all_trials_df, n_back=n_back)
    # Keyword shared by every fit in this cell.
    fit_kwargs = {"fit_intercept": fit_intercept}

    print(f"Data shape after cleaning: {regression_df.shape}")

    # 2. Cross Validation (Pooled)
    print("\n--- Cross Validation Results (Pooled) ---")
    scores_all = perform_cross_validation(regression_df, feature_cols, cv=5, **fit_kwargs)
    print(f"All Subjects CV Accuracy: {_format_accuracy(scores_all)}")

    # 3. Bootstrapping for Confidence Intervals (Pooled)
    print("\n--- Bootstrapping Confidence Intervals (Pooled) ---")
    n_bootstraps = 100
    print(f"Running {n_bootstraps} bootstraps...")
    coefs_boot, intercepts_boot, scores_boot = perform_bootstrap_regression(
        regression_df, feature_cols, n_bootstraps=n_bootstraps, **fit_kwargs
    )
    print(f"Bootstrap OOB Accuracy: {_format_accuracy(scores_boot)}")

    fig_ci, ax_ci = plot_regression_coefficients_with_ci(coefs_boot, intercepts_boot, n_back)
    fig_ci.suptitle("Logistic Regression with 95% CI (All Subjects)")
    fig_ci.savefig(DERIVED_PATH / "logistic_regression_coefficients_all_ci.svg")
    plt.show()

    # 4. Per Subject: same cross-validation + bootstrap pipeline on each animal's rows.
    for subject, subject_df in regression_df.groupby("subject"):
        print(f"\nProcessing Subject {subject}...")

        scores_subj = perform_cross_validation(subject_df, feature_cols, cv=5, **fit_kwargs)
        print(f"  CV Accuracy: {_format_accuracy(scores_subj)}")

        print(f"  Running bootstraps for {subject}...")
        coefs_subj_boot, intercepts_subj_boot, scores_subj_boot = perform_bootstrap_regression(
            subject_df, feature_cols, n_bootstraps=n_bootstraps, **fit_kwargs
        )
        print(f"  Bootstrap OOB Accuracy: {_format_accuracy(scores_subj_boot)}")

        fig_subj_ci, ax_subj_ci = plot_regression_coefficients_with_ci(coefs_subj_boot, intercepts_subj_boot, n_back)
        fig_subj_ci.suptitle(f"Logistic Regression with 95% CI: Subject {subject}")
        fig_subj_ci.savefig(DERIVED_PATH / f"logistic_regression_coefficients_{subject}_ci.svg")
        plt.show()

    # 5. Negative Control (Shuffled Labels): permute the outcome column so that
    # it no longer lines up with the regressors.
    print("\n--- Negative Control (Shuffled Labels) ---")
    shuffled_df = regression_df.copy()
    shuffled_df["is_choice"] = np.random.permutation(shuffled_df["is_choice"].values)

    # Fit & CV on Shuffled Data (Pooled)
    scores_shuffled = perform_cross_validation(shuffled_df, feature_cols, cv=5, **fit_kwargs)
    print(f"Shuffled Labels CV Accuracy: {_format_accuracy(scores_shuffled)}")

    # A single (non-bootstrapped) fit is used for the control plot: it is faster
    # and sufficient as a quick check that the coefficients are noise.
    model_shuffled = fit_logistic_regression(shuffled_df, feature_cols, **fit_kwargs)
    fig, ax = plot_regression_coefficients(model_shuffled, n_back)
    fig.suptitle("Logistic Regression: Shuffled Control (All Subjects)")
    fig.savefig(DERIVED_PATH / "logistic_regression_coefficients_shuffled.svg")
    plt.show()

I think plotting the distributions here is fair because:
- The distribution of underlying delays is the same across patches.
- We are asking:
  * Given trials where the animal did not make a stop, where do the "leave times" cluster?
  * Given trials where the animal did make a choice, where do the choice times cluster?
  * Are the animals willing to wait less in both cases?

In [None]:
# (Re)compute the patch-quality flag: True where `patch_index` equals `high_patch_index`.
all_trials_df["is_high_reward_patch"] = all_trials_df["patch_index"].eq(all_trials_df["high_patch_index"])

# Keep trials flagged as not fixed-stop-duration that are more than 5 trials
# past the last block change (counted per trial type).
_is_variable_stop = ~all_trials_df["is_fixed_stop_duration"]
_is_past_block_start = all_trials_df["trials_from_last_block_by_trial_type"].gt(5)
mask_for_variable_stop = _is_variable_stop & _is_past_block_start
# To analyse the fixed-stop-duration trials instead, drop the `~` on `_is_variable_stop`.

filtered_df = all_trials_df.loc[mask_for_variable_stop]


def calculate_hazard_with_censoring(non_completed_data, all_data, bins):
    """
    Calculate a discrete hazard function accounting for censored observations.

    For every bin the hazard is the number of non-completed trials (events)
    whose time falls in the bin, divided by the number of trials of any kind
    whose time is greater than or equal to the bin's left edge (the at-risk set).

    Parameters:
    - non_completed_data: Times for non-completed trials (events). Any 1-D
      array-like (list, NumPy array, pandas Series).
    - all_data: Times for all trials (completed + non-completed). Any array-like.
    - bins: Monotonically increasing bin edges. Any 1-D array-like.

    Returns:
    - hazard: float array of length ``len(bins) - 1``. Bins with nobody at risk
      are reported as 0.
    """
    # Coerce up front so plain Python lists work too; the element-wise
    # comparison below is not defined for a list against a float.
    edges = np.asarray(bins, dtype=float)
    all_times = np.asarray(all_data, dtype=float).ravel()
    event_times = np.asarray(non_completed_data, dtype=float)

    # Numerator: events per bin. As with any np.histogram call, values outside
    # [edges[0], edges[-1]] are not counted.
    event_counts, _ = np.histogram(event_times, bins=edges)

    # Denominator: for each left edge, how many trials lasted at least that long.
    # NaN times compare False and therefore never count as at risk.
    at_risk = (all_times[:, None] >= edges[None, :-1]).sum(axis=0)

    hazard = np.zeros(len(edges) - 1)
    np.divide(event_counts, at_risk, out=hazard, where=at_risk > 0)
    return hazard


choices = [False, True]
bin_width = 0.25
max_time = 7.1
bins = np.arange(0, max_time + bin_width, bin_width)
bin_centers = (bins[:-1] + bins[1:]) / 2
n_bootstraps = 1000
hazard_rng = np.random.default_rng(0)  # seeded so the confidence bands are reproducible

# (legend label, colour) per value of `is_high_reward_patch`.
patch_style = {False: ("Low Reward", "blue"), True: ("High Reward", "red")}

for subject, subject_df in filtered_df.groupby("subject"):
    fig, axes = plt.subplots(1, 3, figsize=(18, 5), sharex=True)

    # is_choice_times_data[is_choice][is_high_patch] -> Series of durations <= max_time.
    is_choice_times_data = {False: {}, True: {}}
    for is_choice in choices:
        subset = subject_df[subject_df["is_choice"] == is_choice]
        for is_high_patch in [False, True]:
            sub_subset = subset[subset["is_high_reward_patch"] == is_high_patch]
            if is_choice:
                # Choice trials: time from stop to choice.
                stop_duration = (sub_subset["choice_time"] - sub_subset["stop_time"]).dropna()
            else:
                # No-choice trials: longest stop recorded on the trial.
                stop_duration = sub_subset["longest_stop_duration"].dropna()
            is_choice_times_data[is_choice][is_high_patch] = stop_duration[stop_duration <= max_time]

    # Panels 0 and 1: duration histograms for no-choice / choice trials.
    # The per-outcome dict gets its own name; it used to be bound to `choices`,
    # replacing the [False, True] list that drives these loops.
    for i, is_choice in enumerate(choices):
        ax = axes[i]
        durations_by_patch = is_choice_times_data[is_choice]
        for is_high_patch in [False, True]:
            label, color = patch_style[is_high_patch]
            patch_durations = durations_by_patch[is_high_patch]
            if len(patch_durations) == 0:
                continue  # nothing to draw, and no meaningful median

            # Equal weights summing to 1 turn counts into fractions of this group's trials.
            weights = np.full(len(patch_durations), 1.0 / len(patch_durations))
            ax.hist(patch_durations, bins=bins, alpha=0.3, weights=weights, label=label, color=color)

            median_val = patch_durations.median()
            ax.axvline(median_val, color=color, linestyle="--", linewidth=2, label=f"Median: {median_val:.2f}")

        # Decorate each panel once, after both patch types have been drawn.
        ax.set_title(f"Is Choice: {is_choice}")
        ax.set_xlabel("Time (s)")
        ax.set_ylabel("Fraction of Total Trials")
        ax.grid(True, alpha=0.3)
        ax.legend(loc="upper right")

    # Third plot: Hazard rate with bootstrap confidence intervals
    ax_hazard = axes[2]

    for is_high_patch in [False, True]:
        non_completed = is_choice_times_data[False].get(is_high_patch)
        completed = is_choice_times_data[True].get(is_high_patch)

        if non_completed is None or len(non_completed) == 0:
            continue

        label, color = patch_style[is_high_patch]

        # Pool every trial of this patch type and remember which ones are
        # events (non-completed); completed trials only add to the at-risk set.
        nc_vals = np.asarray(non_completed, dtype=float)
        completed_vals = np.asarray(completed, dtype=float) if completed is not None else np.empty(0)
        all_durations = np.concatenate([nc_vals, completed_vals])
        is_event = np.arange(len(all_durations)) < len(nc_vals)

        # Resample whole trials, so each bootstrap replicate keeps its events
        # inside its own at-risk pool. Drawing the events and the pool
        # independently broke that pairing.
        boot_hazards = np.empty((n_bootstraps, len(bins) - 1))
        for b in range(n_bootstraps):
            idx = hazard_rng.integers(0, len(all_durations), size=len(all_durations))
            sample = all_durations[idx]
            boot_hazards[b] = calculate_hazard_with_censoring(sample[is_event[idx]], sample, bins)

        mean_hazard = np.mean(boot_hazards, axis=0)
        ci_lower = np.percentile(boot_hazards, 2.5, axis=0)
        ci_upper = np.percentile(boot_hazards, 97.5, axis=0)

        # Plot mean hazard with confidence interval
        ax_hazard.plot(bin_centers, mean_hazard, color=color, linestyle="-", linewidth=2, label=label, marker="o")
        ax_hazard.fill_between(bin_centers, ci_lower, ci_upper, color=color, alpha=0.2)

    ax_hazard.set_ylabel("Hazard Rate")
    ax_hazard.set_xlabel("Time (s)")
    ax_hazard.set_title("Hazard rate of non-completed trials")
    ax_hazard.legend(loc="upper right")
    ax_hazard.grid(True, alpha=0.3)
    ax_hazard.set_ylim(bottom=0, top=1)

    fig.suptitle(f"Subject {subject}")
    plt.tight_layout()
    plt.show()

In [None]:
def _trace_style(color, is_choice):
    """Per-trace plot kwargs: the given colour, the choice line style, faint alpha."""
    return {"color": color, "linestyle": choice_linestyle[is_choice], "alpha": 0.02}


# Overview figures for the first loaded session only.
for dataset in session_datasets[0:1]:
    print(f"Session: {dataset.session_info.session_id} ({dataset.session_info.subject})")

    # Ethogram between the odor onsets labelled 30 and 40 in the trials table.
    ax_velocity, ax_events = viz.plot_ethogram(
        dataset,
        t_start=dataset.trials["odor_onset_time"][30],
        t_end=dataset.trials["odor_onset_time"][40],
        figsize=(12, 3),
    )
    # Bare lookup whose result is discarded.
    # NOTE(review): looks like a leftover inspection line — confirm before removing.
    dataset.dataset["Behavior"]["OperationControl"]

    unique_patches = dataset.trials["patch_index"].unique()

    # (accessor, y-axis label) for each stream aligned to odor onset. Accessors
    # are zero-argument callables so each stream is looked up only when plotted.
    aligned_streams = (
        (lambda: dataset.processed_streams.sniff_ipi_frequency["frequency"], "Sniff Frequency (Hz)"),
        (lambda: dataset.processed_streams.position_velocity["velocity"], "Velocity (cm/s)"),
        (lambda: dataset.processed_streams.lickometer.frequency, "Lick rate (Hz)"),
    )

    # Pass 1: traces grouped by (patch_index, is_choice), coloured per patch index.
    pairwise_style = {}
    for patch_idx in unique_patches:
        for is_choice in [True, False]:
            pairwise_style[(patch_idx, is_choice)] = _trace_style(viz.patch_index_colormap[patch_idx], is_choice)

    for get_stream, y_label in aligned_streams:
        fig, ax = viz.plot_aligned_to_grouped_by(
            timestamp_df=dataset.trials,
            timeseries=get_stream(),
            by=["patch_index", "is_choice"],
            timestamp_column="odor_onset_time",
            plot_kwargs=pairwise_style,
            event_window=(-2, 5),
        )
        ax.set_ylabel(y_label)

    # Pass 2: traces grouped by (p_reward, is_choice) on trials with
    # p_reward < 1.0; red when p_reward > 0.5, blue otherwise.
    pairwise_style = {}
    for p_reward in np.unique(dataset.trials["p_reward"].values):
        for is_choice in [True, False]:
            pairwise_style[(p_reward, is_choice)] = _trace_style("red" if p_reward > 0.5 else "blue", is_choice)

    for get_stream, _ in aligned_streams:
        fig, ax = viz.plot_aligned_to_grouped_by(
            timestamp_df=dataset.trials.query("p_reward < 1.0"),
            timeseries=get_stream(),
            by=["p_reward", "is_choice"],
            timestamp_column="odor_onset_time",
            plot_kwargs=pairwise_style,
            event_window=(-2, 5),
        )

    # Session-wide trial plot with dashed lines at block changes.
    ax = viz.plot_session_trials(dataset, alpha=0.33, figsize=(16, 6))

    time_of_trial = dataset.trials["odor_onset_time"]

    blocks = dataset.dataset["Behavior"]["SoftwareEvents"]["Block"].load().data.copy()
    # For each block event, the position of the last trial whose odor onset is
    # at or before the event time, clamped to the first trial.
    last_trial_position = time_of_trial.searchsorted(blocks.index.values, side="right") - 1
    last_trial_position = np.maximum(last_trial_position, 0)
    blocks["trial_idx"] = time_of_trial.iloc[last_trial_position].index.values

    y_bottom, y_top = ax.get_ylim()
    ax.vlines(
        blocks["trial_idx"].values,
        ymin=y_bottom,
        ymax=y_top,
        colors="k",
        linestyles="dashed",
        label="Block Change",
    )

    plt.show()

In [None]:
# Plot kwargs for every (is_high_reward_patch, is_choice) combination:
# red for high-reward patches, blue otherwise; line style from `choice_linestyle`.
pairwise_style = {}
for patch_quality in (True, False):
    for is_choice in (True, False):
        pairwise_style[(patch_quality, is_choice)] = {
            "color": "red" if patch_quality else "blue",
            "linestyle": choice_linestyle[is_choice],
            "alpha": 0.02,
        }


def make_legend(pairwise_style):
    custom_lines = []
    labels = []
    added = set()
    for (patch_quality, is_choice), style in pairwise_style.items():
        label = f"{'High Reward' if patch_quality else 'Low Reward'} - {'Choice' if is_choice else 'No Choice'}"
        if label not in added:
            line = plt.Line2D([0], [0], color=style["color"], linestyle=style["linestyle"])
            custom_lines.append(line)
            labels.append(label)
            added.add(label)
    return custom_lines, labels


# (accessor, figure title, y-limits) for each stream aligned to odor onset.
streams = [
    (lambda d: d.processed_streams.sniff_ipi_frequency["frequency"], "Sniff Frequency Aligned to Odor Onset", (1, 10)),
    (lambda d: d.processed_streams.position_velocity["velocity"], "Velocity Aligned to Odor Onset", (-5, 60)),
    (lambda d: d.processed_streams.lickometer.frequency, "Lick Rate Aligned to Odor Onset", (0, 10)),
]
WINDOW = (-2, 5)

# Computed once so the panel count and the panel order come from the same array.
subjects = all_trials_df["subject"].unique()

with viz.a_lot_of_style():
    for stream_fn, stream_label, y_limits in streams:
        # squeeze=False always returns a 2-D array of axes. Without it a
        # single-subject dataset yields a bare Axes and `axs[i]` fails.
        fig, axs = plt.subplots(1, len(subjects), figsize=(20, 5), sharex=True, sharey=True, squeeze=False)

        for i, subject in enumerate(subjects):
            subject_sessions = [d for d in session_datasets if d.session_info.subject == subject]
            _, ax = viz.plot_aligned_to_grouped_by(
                # Only trials more than 5 trials past the last block change.
                timestamp_df=[d.trials.query("trials_from_last_block_by_trial_type > 5") for d in subject_sessions],
                timeseries=[stream_fn(d) for d in subject_sessions],
                by=["is_high_reward_patch", "is_choice"],
                timestamp_column="odor_onset_time",
                event_window=WINDOW,
                time_bin_width=0.025,
                plot_kwargs=pairwise_style,
                agg_plot_kwarg_modifier={"linewidth": 3, "alpha": 1},
                agg_spread_kwarg_modifier={"alpha": 0.05, "linewidth": 0},
                ax=axs[0, i],
            )
            ax.set_xlim(WINDOW[0], WINDOW[1])
            ax.set_ylim(y_limits[0], y_limits[1])
            # Dashed line at odor onset (t = 0), spanning the y-limits set just above.
            ax.vlines(0, ymin=ax.get_ylim()[0], ymax=ax.get_ylim()[1], colors="k", linestyles="dashed")
            ax.set_xlabel(f"Time from Odor Onset (s) \n Subject {subject}")
        fig.suptitle(stream_label)
        fig.legend(*make_legend(pairwise_style), loc="upper right", bbox_to_anchor=(1, 1))
    plt.show()