In [None]:
import datetime
import logging
import typing as t
from pathlib import Path

import pandas as pd
import semver
from tqdm.notebook import tqdm

import ssvr.enrich_trials
import ssvr.qc
import ssvr.utils

%load_ext autoreload
%autoreload 2
from ssvr.dataset import SessionDataset, find_session_info
from ssvr.models import DataLoadingSettings, SessionInfo

# Raise the threshold of these two loggers to ERROR so lower-severity records are dropped.
for _quiet_logger in ("ssvr", "aind_behavior_services.base"):
    logging.getLogger(_quiet_logger).setLevel(logging.ERROR)

# Plot styling lookups: a line style per boolean flag and a matplotlib colour-cycle
# entry per key (the keys look like subject ids -- confirm against the data).
choice_linestyle = {
    True: "-",
    False: "--",
}
subject_colors = {
    "808619": "C2",
    "808728": "C3",
    "789917": "C4",
}

In [None]:
# Discover every session described by the loading settings and keep only those
# whose version is at least 0.6.0.
settings = DataLoadingSettings()
session_info = list(find_session_info(settings))
session_info = [session for session in session_info if session.version >= semver.Version.parse("0.6.0")]
session_datasets: list[SessionDataset] = []

# Curation table located one level above the first data root. This cell normalises
# its "use" and "crop_trial" columns.
session_filter = pd.read_csv(Path(settings.root_path[0]) / "../sessions.csv")
# read_csv infers a numeric dtype when every "use" entry is a number or blank, in
# which case comparing against the string "1" would reject every session. Coercing
# to numbers first accepts both the text "1" and the number 1; blanks and any
# non-numeric text become False.
session_filter["use"] = pd.to_numeric(session_filter["use"], errors="coerce").eq(1)
# Entries that cannot be parsed as numbers become <NA> in the nullable integer column.
session_filter["crop_trial"] = pd.to_numeric(session_filter["crop_trial"], errors="coerce").astype("Int64")


def _use_session(session: SessionInfo) -> bool:
    """Report whether the curation table marks ``session`` as usable.

    The lookup is done against the notebook-level ``session_filter`` table.

    Args:
        session: Session whose ``session_id`` is matched against the "session" column.

    Returns:
        ``False`` when the id does not appear in the table; otherwise the "use"
        value of the first matching row, as a plain Python ``bool``.
    """
    matching_flags = session_filter.loc[session_filter["session"] == session.session_id, "use"]
    if matching_flags.empty:
        return False
    # Cast so callers receive the annotated ``bool`` rather than a NumPy scalar.
    return bool(matching_flags.iloc[0])


# Build a dataset for every session accepted by the curation table. Sessions
# shorter than 15 minutes are left out, and trials are truncated when the table
# provides a crop point.
session_datasets = []
for info in tqdm(session_info, desc="Loading sessions"):
    if not _use_session(info):
        continue
    try:
        _session = SessionDataset(session_info=info, processing_settings=settings.processing_settings)
        row = session_filter.loc[session_filter["session"] == info.session_id].iloc[0]
        if _session.session_metrics.session_duration < datetime.timedelta(minutes=15):
            continue  # shorter than 15 minutes: leave it out
        crop_at = row["crop_trial"]
        if pd.notna(crop_at):
            # Keep only the first `crop_at` trials.
            _session.trials = _session.trials[: int(crop_at)]
        session_datasets.append(_session)
    except Exception as e:
        # Best effort: report the failure and carry on with the next session.
        print(f"Failed to load session {info.session_id}: {e}")

# Enrichment steps in the order they are applied; each entry is called as
# step(session, **kwargs).
_enrichment_steps = (
    (ssvr.enrich_trials.enrich_with_session_type, {}),  # is_fixed_stop_duration
    (ssvr.enrich_trials.enrich_with_block_info, {}),
    (ssvr.enrich_trials.enrich_with_relative_to_block, {}),
    (ssvr.enrich_trials.enrich_with_previous_trial, {"n_previous": 5}),
    (ssvr.enrich_trials.enrich_with_block_probability, {}),
    (ssvr.enrich_trials.enrich_with_reward_rate, {"exponential_decay": 0.7071}),
)

for session in tqdm(session_datasets, desc="Enriching sessions"):
    for _enrich, _enrich_kwargs in _enrichment_steps:
        # The return value is overwritten at every step and not read in this cell.
        enriched_trials = _enrich(session, **_enrich_kwargs)

# The QC run is switched off; set the flag to True to call ssvr.qc.run_qc with
# ./derived/qc_reports as its path.
_run_qc = False
if _run_qc:
    ssvr.qc.run_qc(path=Path("./derived") / "qc_reports", session_datasets=session_datasets)

In [None]:
# Stack the trials of every loaded session into one long table, tagging each row
# with the subject and session it came from.
all_trials = []
for session in session_datasets:
    df = session.trials.copy()
    # Turn the frame's index into a regular column (named "trial_number" when the
    # index is unnamed and reset_index therefore calls it "index").
    df = df.reset_index().rename(columns={"index": "trial_number"})
    df["subject"] = session.session_info.subject
    df["session_id"] = session.session_info.session_id
    all_trials.append(df)

all_trials_df = t.cast(pd.DataFrame, pd.concat(all_trials, ignore_index=True))
# all_trials_df.to_csv(Path("./derived") / "all_sessions_enriched_trials.csv")
# DataFrame.info() writes its summary to stdout itself and returns None, so wrapping
# it in print() only appended a stray "None" line to the output.
all_trials_df.info()

In [None]:
from ssvr.analysis.block_switching_behavior import (
    calculate_choice_matrix,
    plot_block_switch_choice_patterns,
)

# Trial window handed to both helpers below; presumably (first, last) trial offsets
# relative to a block switch -- confirm against calculate_choice_matrix.
trial_window = (-10, 30)

# Figures are written to ./derived. Create it up front: savefig does not create
# missing directories and would otherwise fail when the folder is absent.
derived_dir = Path("./derived")
derived_dir.mkdir(parents=True, exist_ok=True)

choice_behavior_matrix, switch_trials_df = calculate_choice_matrix(
    all_trials_df, trial_window=trial_window, block_switch_filter="different"
)

# All switches pooled together.
fig, ax = plot_block_switch_choice_patterns(choice_behavior_matrix, trial_window)
fig.suptitle("All subjects")
fig.savefig(derived_dir / "block_switch_choice_patterns_all_subjects.png")

# One figure per subject; "index_ord" picks that group's entries along the first
# axis of the choice matrix.
for subject, df in switch_trials_df.groupby("subject"):
    fig, ax = plot_block_switch_choice_patterns(
        choice_behavior_matrix[df["index_ord"].values, :, :],
        trial_window,
    )
    fig.suptitle(f"Subject {subject}")
    fig.savefig(derived_dir / f"block_switch_choice_patterns_subject_{subject}.png")

## Check if there are biases depending on the odor patch switched to
for (subject, high_patch_index), df in switch_trials_df.groupby(["subject", "after_high_index"]):
    fig, ax = plot_block_switch_choice_patterns(
        choice_behavior_matrix[df["index_ord"].values, :, :],
        trial_window,
    )
    fig.suptitle(f"Subject {subject}, After High Patch Index {high_patch_index}")
    fig.savefig(
        derived_dir / f"block_switch_choice_patterns_subject_{subject}_after_high_patch_{high_patch_index}.png"
    )

In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.lines import Line2D

from ssvr.analysis.block_switching_behavior import calculate_consecutive_choice_runs, plot_trials_to_criterion_histogram

def get_medians(df):
    """Return the ``(low, high)`` median trials-to-criterion for a set of rows.

    Rows flagged ``is_low_reward_patch`` contribute ``trials_to_n_consecutive_false``;
    the remaining rows contribute ``trials_to_n_consecutive_true``. Missing values
    are dropped first, and a group left empty yields NaN.
    """
    is_low = df["is_low_reward_patch"]
    low = df.loc[is_low, "trials_to_n_consecutive_false"].dropna()
    high = df.loc[~is_low, "trials_to_n_consecutive_true"].dropna()
    low_median = np.median(low) if len(low) > 0 else np.nan
    high_median = np.median(high) if len(high) > 0 else np.nan
    return (low_median, high_median)


consecutive_runs_df = calculate_consecutive_choice_runs(
    all_trials_df,
    switch_trials_df,
    n_consecutive=5,
    max_trials_ahead=40,
)

summary_stats = []

for subject, subject_df in consecutive_runs_df.groupby("subject"):
    subject_df = subject_df.sort_values("trial_index")

    # Split the distinct trial_index values (ascending after the sort) at the
    # midpoint; with an odd count the extra value lands in the second half.
    unique_trials = subject_df["trial_index"].unique()
    mid_point = len(unique_trials) // 2
    in_first_half = subject_df["trial_index"].isin(unique_trials[:mid_point])
    in_second_half = subject_df["trial_index"].isin(unique_trials[mid_point:])
    first_half_df = subject_df[in_first_half]
    second_half_df = subject_df[in_second_half]

    # Side-by-side histograms with a shared y axis.
    fig, axes = plt.subplots(1, 2, figsize=(12, 5), sharey=True)
    panels = zip(axes, (first_half_df, second_half_df), ("First Half", "Second Half"))
    for panel_ax, half_df, panel_title in panels:
        plot_trials_to_criterion_histogram(half_df, ax=panel_ax, title=panel_title)

    fig.suptitle(f"Animal {subject} - Split Half Analysis")
    fig.tight_layout()
    fig.savefig(Path("./derived") / f"trials_to_criterion_split_animal_{subject}.png")
    plt.show()

    # Calculate Medians for Summary
    l1, h1 = get_medians(first_half_df)
    l2, h2 = get_medians(second_half_df)

    summary_stats.append(
        {
            "subject": subject,
            "low_1st": l1,
            "high_1st": h1,
            "low_2nd": l2,
            "high_2nd": h2,
        }
    )

# One row per subject holding the first/second-half medians collected above.
summary_df = pd.DataFrame(summary_stats)
fig_sum, ax_sum = plt.subplots(figsize=(8, 6))

half_positions = [0, 1]
for stats in summary_df.itertuples(index=False):
    # Low Reward (Blue)
    ax_sum.plot(half_positions, [stats.low_1st, stats.low_2nd], "o-", color="blue", alpha=0.5)
    # High Reward (Red)
    ax_sum.plot(half_positions, [stats.high_1st, stats.high_2nd], "o-", color="red", alpha=0.5)

ax_sum.set_xticks(half_positions)
ax_sum.set_xticklabels(["1st Half", "2nd Half"])
ax_sum.set_ylabel("Median Trials to Criterion")
ax_sum.set_title("Change in Performance Across Session Halves")
ax_sum.grid(True, alpha=0.3)

# Custom legend: one proxy line per colour, since every subject adds its own pair of lines.
custom_lines = [Line2D([0], [0], color=line_color, lw=2, marker="o") for line_color in ("blue", "red")]
ax_sum.legend(custom_lines, ["Low Reward Patch", "High Reward Patch"])

fig_sum.savefig(Path("./derived") / "trials_to_criterion_summary_split.png")
plt.show()

## Logistic regression

### Regressors
* Previous stimulus sameness (0 or 1) and IsRewarded (-0.5 = stop without reward, 0 = no stop and no reward, 0.5 = stop with reward)
* Two-way interaction between Previous Stimulus sameness and IsRewarded
* Repeat for N choices back


In [None]:
from ssvr.analysis.logistic_regression import (
    create_regression_design_matrix,
    fit_logistic_regression,
    perform_bootstrap_regression,
    perform_cross_validation,
    plot_regression_coefficients,
    plot_regression_coefficients_with_ci,
)

# 1. Prepare Data
n_back = 10
fit_intercept = True
regression_df, feature_cols = create_regression_design_matrix(all_trials_df, n_back=n_back)

print(f"Data shape after cleaning: {regression_df.shape}")

# 2. Cross Validation (Pooled)
print("\n--- Cross Validation Results (Pooled) ---")
scores_all = perform_cross_validation(regression_df, feature_cols, cv=5, fit_intercept=fit_intercept)
pooled_cv_mean = scores_all.mean()
pooled_cv_spread = scores_all.std() * 2  # reported as +/- two standard deviations
print(f"All Subjects CV Accuracy: {pooled_cv_mean:.3f} (+/- {pooled_cv_spread:.3f})")

# 3. Bootstrapping for Confidence Intervals (Pooled)
print("\n--- Bootstrapping Confidence Intervals (Pooled) ---")
n_bootstraps = 100
print(f"Running {n_bootstraps} bootstraps...")
coefs_boot, intercepts_boot, scores_boot = perform_bootstrap_regression(
    regression_df,
    feature_cols,
    n_bootstraps=n_bootstraps,
    fit_intercept=fit_intercept,
)
pooled_boot_mean = scores_boot.mean()
pooled_boot_spread = scores_boot.std() * 2
print(f"Bootstrap OOB Accuracy: {pooled_boot_mean:.3f} (+/- {pooled_boot_spread:.3f})")

# Coefficient plot built from the bootstrap draws.
fig_ci, ax_ci = plot_regression_coefficients_with_ci(coefs_boot, intercepts_boot, n_back)
fig_ci.suptitle("Logistic Regression with 95% CI (All Subjects)")
fig_ci.savefig(Path("./derived") / "logistic_regression_coefficients_all_ci.png")
plt.show()


# 4. Per Subject
# Same cross-validation, bootstrap and coefficient plot as the pooled analysis,
# repeated on each subject's rows.
for subject, subject_df in regression_df.groupby("subject"):
    print(f"\nProcessing Subject {subject}...")

    scores_subj = perform_cross_validation(subject_df, feature_cols, cv=5, fit_intercept=fit_intercept)
    subj_cv_mean = scores_subj.mean()
    subj_cv_spread = scores_subj.std() * 2  # reported as +/- two standard deviations
    print(f"  CV Accuracy: {subj_cv_mean:.3f} (+/- {subj_cv_spread:.3f})")

    print(f"  Running bootstraps for {subject}...")
    coefs_subj_boot, intercepts_subj_boot, scores_subj_boot = perform_bootstrap_regression(
        subject_df,
        feature_cols,
        n_bootstraps=n_bootstraps,
        fit_intercept=fit_intercept,
    )
    subj_boot_mean = scores_subj_boot.mean()
    subj_boot_spread = scores_subj_boot.std() * 2
    print(f"  Bootstrap OOB Accuracy: {subj_boot_mean:.3f} (+/- {subj_boot_spread:.3f})")

    fig_subj_ci, ax_subj_ci = plot_regression_coefficients_with_ci(coefs_subj_boot, intercepts_subj_boot, n_back)
    fig_subj_ci.suptitle(f"Logistic Regression with 95% CI: Subject {subject}")
    fig_subj_ci.savefig(Path("./derived") / f"logistic_regression_coefficients_{subject}_ci.png")
    plt.show()

# 5. Negative Control (Shuffled Labels)
print("\n--- Negative Control (Shuffled Labels) ---")
# Shuffle with a dedicated, seeded generator so the permutation of the labels is
# the same on every Restart & Run All. Any randomness inside the project helpers
# called below is not controlled by this seed.
shuffle_rng = np.random.default_rng(0)
shuffled_df = regression_df.copy()
shuffled_df["is_choice"] = shuffle_rng.permutation(shuffled_df["is_choice"].values)

# Fit & CV on Shuffled Data (Pooled)
scores_shuffled = perform_cross_validation(shuffled_df, feature_cols, cv=5, fit_intercept=fit_intercept)
print(f"Shuffled Labels CV Accuracy: {scores_shuffled.mean():.3f} (+/- {scores_shuffled.std() * 2:.3f})")

# For negative control, we can just show the standard plot as a quick check, or bootstrap it too.
# Standard plot is faster and sufficient to show it's noise.
model_shuffled = fit_logistic_regression(shuffled_df, feature_cols, fit_intercept=fit_intercept)
fig, ax = plot_regression_coefficients(model_shuffled, n_back)
fig.suptitle("Logistic Regression: Shuffled Control (All Subjects)")
fig.savefig(Path("./derived") / "logistic_regression_coefficients_shuffled.png")
plt.show()

In [None]:
# Pairwise correlation between the regressor columns.
corr_matrix = regression_df[feature_cols].corr()

# Draw the matrix as a heatmap; the colour scale is pinned to [-1, 1].
fig, ax = plt.subplots(figsize=(12, 10))
heatmap = ax.imshow(corr_matrix, cmap="coolwarm", vmin=-1, vmax=1)
fig.colorbar(heatmap)

# One tick per regressor on both axes, labelled with the column name.
tick_positions = np.arange(len(feature_cols))
ax.set_xticks(tick_positions)
ax.set_yticks(tick_positions)
ax.set_xticklabels(feature_cols, rotation=90)
ax.set_yticklabels(feature_cols)

ax.set_title("Pairwise Correlation of Regressors")
fig.tight_layout()
fig.savefig(Path("./derived") / "regressor_correlation_matrix.png")
plt.show()

In [None]:
# Flag trials whose patch_index equals high_patch_index. This adds a column to the
# shared trial table, so re-running the cell simply recomputes it.
all_trials_df["is_high_reward_patch"] = all_trials_df["patch_index"] == all_trials_df["high_patch_index"]
# Keep trials that are not fixed-stop-duration and whose trials_from_last_block exceeds 5.
# `== False` is deliberate: unlike `~`, it treats missing values as "not selected".
mask_for_variable_stop = (all_trials_df["is_fixed_stop_duration"] == False) & (  # noqa: E712
    all_trials_df["trials_from_last_block"] > 5
)
filtered_df = all_trials_df[mask_for_variable_stop]


choices = [False, True]

for subject, subject_df in filtered_df.groupby("subject"):
    # One figure per subject: left panel for is_choice == False, right panel for True.
    fig, axes = plt.subplots(1, 2, figsize=(14, 6), sharey=True)

    for ax, is_choice in zip(axes, choices):
        subset = subject_df[subject_df["is_choice"] == is_choice]

        for is_high in [False, True]:
            sub_subset = subset[subset["is_high_reward_patch"] == is_high]
            stop_duration = sub_subset["longest_stop_duration"].dropna()

            # Nothing to draw for this combination.
            if len(stop_duration) == 0:
                continue

            label = "High Reward" if is_high else "Low Reward"
            color = "red" if is_high else "blue"

            ax.hist(stop_duration, bins=30, alpha=0.5, density=True, label=label, color=color)

            # Mark the median of the same distribution with a dashed vertical line.
            median_val = stop_duration.median()
            ax.axvline(median_val, color=color, linestyle="--", linewidth=2, label=f"Median: {median_val:.2f}")

        ax.set_title(f"Is Choice: {is_choice}")
        ax.set_xlabel("Longest Stop Duration (s)")
        ax.grid(True, alpha=0.3)
        ax.legend()

    axes[0].set_ylabel("Density")
    # Label the figure with its subject, as the other per-subject figures in this
    # notebook do; without it the figures cannot be told apart.
    fig.suptitle(f"Subject {subject}")
    fig.tight_layout()
    plt.show()