In [None]:
import datetime
import logging
import typing as t
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import semver
from tqdm.notebook import tqdm

import ssvr.enrich_trials
import ssvr.qc
import ssvr.utils

%load_ext autoreload
%autoreload 2
from ssvr.dataset import SessionDataset, find_session_info
from ssvr.models import DataLoadingSettings, SessionInfo

# Only let errors through from the project / acquisition-framework loggers.
for _quiet_logger in ("ssvr", "aind_behavior_services.base"):
    logging.getLogger(_quiet_logger).setLevel(logging.ERROR)

# Plot styling lookups: line style keyed by a boolean (presumably choice vs. no choice),
# and a matplotlib colour-cycle code ("Cn") per subject id.
choice_linestyle = {
    True: "-",
    False: "--",
}
subject_colors = {
    "808619": "C2",
    "808728": "C3",
    "789917": "C4",
}

In [None]:
settings = DataLoadingSettings()
session_info = list(find_session_info(settings))
# Only sessions recorded with version 0.6.0 or newer are analysed.
session_info = [session for session in session_info if session.version >= semver.Version.parse("0.6.0")]
session_datasets: list[SessionDataset] = []

# Curated table of sessions: whether to use each one and (optionally) after how many trials to crop it.
# `session` and `use` are read as text: otherwise pandas may infer `use` as numeric (1 / 1.0 / NaN),
# which makes a comparison against the string "1" always False and silently excludes every session.
# Reading `session` as text also keeps numeric-looking ids intact.
session_filter = pd.read_csv(Path(settings.root_path[0]) / "../sessions.csv", dtype={"session": str, "use": str})
# A session is used only when its `use` cell is exactly "1" (surrounding whitespace ignored); blanks -> False.
session_filter["use"] = session_filter["use"].fillna("").str.strip().eq("1")
session_filter["crop_trial"] = pd.to_numeric(session_filter["crop_trial"], errors="coerce").astype("Int64")


def _use_session(session: SessionInfo) -> bool:
    """Return whether ``session`` is marked for use in the module-level ``session_filter`` table.

    Sessions that do not appear in the table are excluded. If a session id appears in
    several rows, the first matching row decides.
    """
    use_flags = session_filter.loc[session_filter["session"] == session.session_id, "use"]
    return False if use_flags.empty else use_flags.iloc[0]


session_datasets = []
for info in tqdm(session_info, desc="Loading sessions"):
    # Skip sessions that sessions.csv does not mark for use.
    if not _use_session(info):
        continue
    try:
        dataset = SessionDataset(session_info=info, processing_settings=settings.processing_settings)
        filter_row = session_filter[session_filter["session"] == info.session_id].iloc[0]
        # Drop sessions shorter than 15 minutes.
        if dataset.session_metrics.session_duration < datetime.timedelta(minutes=15):
            continue
        # Optionally keep only the first `crop_trial` trials.
        if pd.notna(filter_row["crop_trial"]):
            dataset.trials = dataset.trials[: int(filter_row["crop_trial"])]
        session_datasets.append(dataset)
    except Exception as e:
        # Best effort: report the failing session and keep loading the rest.
        print(f"Failed to load session {info.session_id}: {e}")

# Add block-related columns to every session. The helpers' return values are not needed here:
# the following cells read the enriched columns (e.g. "block_index") from `session.trials`, so the
# helpers presumably update the session in place — TODO confirm in ssvr.enrich_trials.
for session in tqdm(session_datasets, desc="Enriching sessions"):
    ssvr.enrich_trials.enrich_with_block_info(session)
    ssvr.enrich_trials.enrich_with_relative_to_block(session)
    ssvr.enrich_trials.enrich_with_previous_trial(session, n_previous=5)

# Set to True to (re)generate the QC reports under ./derived/qc_reports.
RUN_QC = False
if RUN_QC:
    ssvr.qc.run_qc(session_datasets=session_datasets, path=Path("./derived") / "qc_reports")

In [None]:
all_trials = []
for session in session_datasets:
    # reset_index turns the per-session trial index into a column. The rename assumes the index is
    # unnamed (so the new column is called "index") — TODO confirm for SessionDataset.trials.
    df = (
        session.trials.reset_index()
        .rename(columns={"index": "trial_number"})
        .assign(subject=session.session_info.subject, session_id=session.session_info.session_id)
    )
    all_trials.append(df)

if not all_trials:
    # pd.concat([]) would fail with an opaque "No objects to concatenate".
    raise ValueError("No sessions were loaded; check sessions.csv and the 'Failed to load session' messages above.")

all_trials_df = t.cast(pd.DataFrame, pd.concat(all_trials, ignore_index=True))
# all_trials_df.to_csv(Path("./derived") / "all_sessions_enriched_trials.csv")
all_trials_df.info()  # info() prints its own report and returns None, so it must not be wrapped in print()

In [None]:
# Goal: plot choice probability (y axis) against trial position relative to a block switch (x axis).
# By default we only keep switches where the patch probabilities actually changed
# (controlled by `block_switch_filter` below).
import numpy as np

block_switch_filter: t.Literal["same", "different", "both"] = "different"

# Sort BEFORE comparing neighbouring rows: both the diff and the shift below compare each trial with the
# row above it. (Previously the sort ran after the diff had already been taken.)
all_trials_df = all_trials_df.sort_index()

# A block switch is a trial whose block_index is larger than the previous trial's within the same session.
# The explicit session check means we no longer rely on block_index never increasing across a session boundary.
same_session_as_previous = all_trials_df["session_id"].eq(all_trials_df["session_id"].shift())
block_switches = all_trials_df["block_index"].diff().gt(0) & same_session_as_previous
switch_indices = all_trials_df.index[block_switches]

block_probabilities_before = all_trials_df["block_patch_probabilities"].shift()[block_switches]
block_probabilities_after = all_trials_df["block_patch_probabilities"][block_switches]
# NOTE(review): np.argmax/np.argmin return the first index on ties, so when all patches share the same
# probability the "high" and "low" index point at the same patch.
prob_switch_df = pd.DataFrame(
    {
        "before": block_probabilities_before.values,
        "after": block_probabilities_after.values,
        "before_high_index": [np.argmax(probs) for probs in block_probabilities_before.values],
        "after_high_index": [np.argmax(probs) for probs in block_probabilities_after.values],
        "after_low_index": [np.argmin(probs) for probs in block_probabilities_after.values],
        "before_low_index": [np.argmin(probs) for probs in block_probabilities_before.values],
    },
    index=switch_indices,
)

if block_switch_filter not in ("same", "different", "both"):
    raise ValueError(f"Invalid block_switch_filter: {block_switch_filter}")

# Compare as tuples so list/array-valued probability vectors are compared element-wise as a whole.
probabilities_unchanged = prob_switch_df["before"].apply(tuple) == prob_switch_df["after"].apply(tuple)
if block_switch_filter == "same":
    prob_switch_df = prob_switch_df[probabilities_unchanged]
elif block_switch_filter == "different":
    prob_switch_df = prob_switch_df[~probabilities_unchanged]

In [None]:
# Window of trials around each switch as [start, stop) relative positions: 0 is the first trial of the
# new block, negative positions are trials of the previous block.
trial_window = (-10, 30)
is_patch_low_after_zip = (0, 1)  # third-axis index: 0 = patch that is high after the switch, 1 = low

# Column of relative position 0 in the output array.
switch_column = -trial_window[0]

# Initialize with NaN to handle missing data (positions without a trial of the given patch stay NaN).
switch_choice_data = np.full(
    shape=(len(prob_switch_df), trial_window[1] - trial_window[0], 2), fill_value=np.nan, dtype=float
)

# Also store subject info for each switch (row-aligned with the first axis of switch_choice_data).
switch_subjects = []

# Split the trial table once per session instead of re-filtering all trials for every switch.
# groupby keeps the original row order within each group.
trials_by_session = {
    session_id: trials for session_id, trials in all_trials_df.groupby("session_id", sort=False)
}

for i_switch, (trial_switch, switch_row) in tqdm(
    enumerate(prob_switch_df.iterrows()), total=len(prob_switch_df), desc="Processing block switches"
):
    switch_trial = all_trials_df.loc[trial_switch]
    switch_subjects.append(switch_trial["subject"])
    session_trials = trials_by_session[switch_trial["session_id"]]
    new_block_index = switch_trial["block_index"]

    trial_window_mask_after = (session_trials["trials_from_last_block_by_trial_type"] < trial_window[1]) & (
        session_trials["block_index"] == new_block_index
    )
    trial_window_mask_before = (session_trials["trials_to_next_block_by_trial_type"] < -trial_window[0]) & (
        session_trials["block_index"] == new_block_index - 1
    )

    for is_patch_low_after in is_patch_low_after_zip:
        patch_idx = switch_row["after_low_index"] if is_patch_low_after else switch_row["after_high_index"]
        patch_id_mask = session_trials["patch_index"] == patch_idx
        trials_before = session_trials[trial_window_mask_before & patch_id_mask]
        trials_after = session_trials[trial_window_mask_after & patch_id_mask]

        # Place every trial by its own counter. The previous version derived one slice start from the first
        # selected trial, assuming it belonged to the previous block; when that block had no trial of this
        # patch in the window, the next block's counter was used and the index went negative or NaN.
        # Assumes both counters are 0-based, as the "< window" masks above imply: trials_to_next == 0 is the
        # last trial before the switch (column switch_column - 1) and trials_from_last == 0 is the first
        # trial after it (column switch_column).
        columns = np.concatenate(
            [
                switch_column - 1 - trials_before["trials_to_next_block_by_trial_type"].to_numpy(dtype=int),
                switch_column + trials_after["trials_from_last_block_by_trial_type"].to_numpy(dtype=int),
            ]
        )
        choices = np.concatenate(
            [
                trials_before["is_choice"].to_numpy(dtype=float, na_value=np.nan),
                trials_after["is_choice"].to_numpy(dtype=float, na_value=np.nan),
            ]
        )
        switch_choice_data[i_switch, columns, is_patch_low_after] = choices

# Convert to array so switches can be selected per subject with a boolean mask
switch_subjects = np.array(switch_subjects)
unique_subjects = np.unique(switch_subjects)

# x-axis values: trial position relative to the block switch (0 = first trial of the new block)
x_positions = np.arange(trial_window[0], trial_window[1])

patch_names = ["High Reward Patch", "Low Reward Patch"]
patch_colors = ["red", "blue"]


def _nan_mean_sem(values: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Column-wise mean and standard error of a (switch x trial position) array, ignoring NaN.

    The SEM is ``nanstd / sqrt(n)`` with ddof=0; for 0/1 choice data this equals sqrt(p * (1 - p) / n).
    Columns without any observation are NaN in both outputs. They are excluded from the numpy
    reductions so that no "Mean of empty slice" warnings are emitted.
    """
    counts = np.sum(~np.isnan(values), axis=0)
    has_data = counts > 0
    mean = np.full(values.shape[1], np.nan)
    sem = np.full(values.shape[1], np.nan)
    mean[has_data] = np.nanmean(values[:, has_data], axis=0)
    sem[has_data] = np.nanstd(values[:, has_data], axis=0) / np.sqrt(counts[has_data])
    return mean, sem


def _plot_switch_choice_patterns(choice_data: np.ndarray, label: str, filename: str) -> None:
    """Draw per-switch heatmaps and the average choice curve for both patches, then save and show the figure.

    Args:
        choice_data: array indexed (switch, trial position, patch) as built in the previous cell, where
            patch 0 is the patch that is high after the switch and patch 1 the one that is low; NaN = no trial.
        label: text used in the titles, e.g. "All Animals" or "Animal 808619".
        filename: path the figure is saved to.
    """
    if choice_data.shape[0] == 0:  # nothing to plot (e.g. no switches for this selection)
        return

    fig, axes = plt.subplots(3, 1, figsize=(12, 15))
    fig.suptitle(f"Choice Patterns Around Block Switches - {label}", fontsize=16)

    tick_positions = np.arange(0, len(x_positions), 10)
    switch_position = -trial_window[0]  # heatmap column of relative trial 0

    # Top two panels: one heatmap row per block switch, one column per relative trial
    for patch_id, ax in enumerate(axes[:2]):
        im = ax.imshow(
            choice_data[:, :, patch_id], aspect="auto", cmap="RdYlBu_r", interpolation="none", vmin=0, vmax=1
        )
        ax.set_title(f"{patch_names[patch_id]} - Individual Switches ({label})")
        ax.set_ylabel("Block Switch Number")
        ax.set_xticks(tick_positions)
        ax.set_xticklabels(x_positions[tick_positions])
        ax.axvline(x=switch_position, color="white", linestyle="--", alpha=0.8)
        fig.colorbar(im, ax=ax, label="Choice Probability")

    # Bottom panel: mean +/- SEM choice probability for both patches, overlaid
    ax_mean = axes[2]
    for patch_id in range(2):
        mean_choice, sem_choice = _nan_mean_sem(choice_data[:, :, patch_id])
        ax_mean.plot(
            x_positions,
            mean_choice,
            "o-",
            color=patch_colors[patch_id],
            alpha=0.8,
            label=patch_names[patch_id],
            linewidth=2,
        )
        ax_mean.fill_between(
            x_positions, mean_choice - sem_choice, mean_choice + sem_choice, alpha=0.2, color=patch_colors[patch_id]
        )

    ax_mean.set_title(f"Average Choice Probability - Both Patches ({label})")
    ax_mean.set_xlabel("Trials Relative to Block Switch")
    ax_mean.set_ylabel("Choice Probability")
    ax_mean.set_ylim(0, 1)
    ax_mean.grid(True, alpha=0.3)
    ax_mean.axvline(x=0, color="black", linestyle="--", alpha=0.8, label="Block Switch")
    ax_mean.legend()

    fig.tight_layout()
    fig.savefig(filename)
    plt.show()
    plt.close(fig)  # free the figure; this is called once per animal


# Plot 1: all animals combined
_plot_switch_choice_patterns(switch_choice_data, "All Animals", "block_switch_choice_patterns_all_animals.png")

# Plot 2: one figure per animal
for subject in unique_subjects:
    _plot_switch_choice_patterns(
        switch_choice_data[switch_subjects == subject],
        f"Animal {subject}",
        f"block_switch_choice_patterns_animal_{subject}.png",
    )