In [None]:
"""Set notebook settings."""

# Enable IPython's autoreload extension so edits to local modules (e.g. `mini`)
# are picked up without restarting the kernel.
%load_ext autoreload
# Mode 2: reload all modules before executing code in each cell.
%autoreload 2
# Optional reactive execution mode, currently disabled. Presumably requires an
# extension that provides the `%flow` magic — TODO confirm before enabling.
# %flow mode reactive

In [None]:
"""Import packages."""

# Standard library imports
import math
import os
from collections import defaultdict
from dataclasses import dataclass
from datetime import datetime
from functools import partial
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple

# IPython and Jupyter-related imports
import ipywidgets as widgets
from IPython.display import clear_output, display

# Third-party libraries
import h5py
import numpy as np
import pandas as pd
import seaborn as sns
import temporaldata as td
import torch as t
from einops import (
    asnumpy,
    einsum,
    pack,
    parse_shape,
    rearrange,
    reduce,
    repeat,
    unpack,
)
from einops.layers.torch import Rearrange, Reduce
from jaxtyping import Float, Int
from matplotlib import pyplot as plt
from plotly import express as px
from plotly import graph_objects as go
from plotly.subplots import make_subplots
from rich import print as rprint
from scipy import stats
from scipy.ndimage import uniform_filter1d, gaussian_filter1d
from sklearn.linear_model import LinearRegression, Ridge
from sklearn.metrics import classification_report, confusion_matrix, r2_score
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from temporaldata import Data
from torch import Tensor, bfloat16, nn
from torch.nn import functional as F
from torcheval.metrics.functional import r2_score as tm_r2_score
from tqdm.notebook import tqdm

# Local project modules
from mini import train as mt
from mini.util import vec_r2

In [None]:
# Widen pandas display limits so the larger frames below render legibly
pd.options.display.max_rows = 300
pd.options.display.max_columns = 25

In [None]:
def _filter_trials(session, mask, description):
    """Keep only the trials selected by ``mask`` and report how many were removed.

    Args:
        session: Session whose ``trials`` support ``select_by_mask``;
            ``session.trials`` is replaced in place.
        mask: Boolean array with one entry per current trial.
        description: Phrase naming the removed trials, used in the printed message.
    """
    num_before = len(session.trials.start)
    session.trials = session.trials.select_by_mask(mask)
    num_after = len(session.trials.start)
    if num_before - num_after > 0:
        print(f"Filtered out {description}, went from", num_before, "trials to", num_after)


def clean_session_data(session):
    """Clean session data by filtering trials and spikes based on quality criteria.

    Python port of a MATLAB cleaning script provided by Mark Churchland.
    Trials are filtered in successive passes, and each pass prints how many
    trials it removed. Spikes, hand and eye data are then restricted to the
    surviving trial intervals. Finally, the recording date string is converted
    to a POSIX timestamp, interpreted as UTC.

    Args:
        session: Loaded session exposing ``trials``, ``spikes``, ``hand``,
            ``eye`` and ``session.recording_date``, which must be a
            ``'%Y-%m-%d %H:%M:%S'`` string. The session is modified in place.

    Returns:
        The same ``session`` object, cleaned.

    Raises:
        ValueError: If no trials survive cleaning. Also raised if the surviving
            ``maze_condition`` values do not start at 1 or are not evenly spaced.
    """
    from datetime import timezone  # only needed for the UTC conversion below

    # brainsets probably already defines is_valid as
    # (discard_trial == 0) & (task_success == 1), so these checks partly
    # overlap; they are all kept to be safe.
    # TODO: also drop impossible mazes ("unhittable"). That field is not
    # present in this release of the data.
    _filter_trials(
        session,
        (session.trials.trial_type > 0)
        & (session.trials.is_valid == 1)
        & (session.trials.discard_trial == 0)
        & (session.trials.novel_maze == 0)
        & (session.trials.trial_version < 3),
        "extraneous trials",
    )

    _filter_trials(session, session.trials.task_success == 1, "unsuccessful trials")

    # Keep only trials lasting at least `post_move` after movement onset.
    # Trial times are in seconds, so 0.8 s = 800 ms. This should essentially
    # always hold for successful trials.
    post_move = 0.8
    _filter_trials(
        session,
        session.trials.end - session.trials.move_begins_time >= post_move,
        "trials that were too short",
    )

    _filter_trials(
        session,
        session.trials.correct_reach == 1,
        "trials with inconsistent reaches (not similar enough to the \"prototypical\" trial)",
    )

    primary_conditions = np.unique(session.trials.maze_condition)
    num_conditions = len(primary_conditions)
    print("Number of primary conditions:", num_conditions)
    if num_conditions == 0:
        raise ValueError("No trials left after cleaning")
    # Conditions must start at 1 and be evenly spaced. A single condition
    # trivially satisfies this (its np.diff is empty), so only more than one
    # distinct step counts as a violation.
    if primary_conditions[0] != 1 or np.unique(np.diff(primary_conditions)).size > 1:
        raise ValueError("Primary conditions are not monotonic or do not start from 1")

    # TODO: filter units by their 1-4 quality ranking. That ranking is not
    # present in this release of the data.

    # Only keep spikes (and hand/eye samples) that lie within the cleaned trials
    session.spikes = session.spikes.select_by_interval(session.trials)
    session.hand = session.hand.select_by_interval(session.trials)
    session.eye = session.eye.select_by_interval(session.trials)

    # Convert the recording date to a POSIX timestamp. The string carries no
    # timezone, and a naive datetime's .timestamp() uses the machine's local
    # timezone/DST. Pin it to UTC so timestamps are identical on every machine.
    recording_date = datetime.strptime(session.session.recording_date, '%Y-%m-%d %H:%M:%S')
    session.session.recording_date = recording_date.replace(tzinfo=timezone.utc).timestamp()

    return session

In [None]:
def analyze_maze_conditions(session):
    """
    Summarise, and plot, the maze parameters behind each ``maze_condition``.

    For every unique condition in ``session.trials``, the summary collects:
    - the distinct barrier counts,
    - the distinct target counts,
    - the single hit-target position and its angle in degrees.

    A Plotly scatter is then shown with one point per distinct hit position,
    labelled by the first condition that uses it.

    Args:
        session: Session whose ``trials`` expose ``maze_condition``,
            ``maze_num_barriers``, ``maze_num_targets`` and
            ``hit_target_position`` (indexed per trial as an (x, y) pair below).

    Returns:
        pd.DataFrame with one row per maze condition.

    Raises:
        ValueError: If a condition has more than one distinct hit position.
    """
    trials = session.trials

    def _describe(condition):
        """Build the summary row for a single maze condition."""
        in_condition = trials.maze_condition == condition
        positions = np.unique(trials.hit_target_position[in_condition], axis=0)
        if len(positions) > 1:
            raise ValueError(f"Condition {condition} has multiple hit positions: {positions}")
        position = positions[0]
        return {
            'Maze Condition': condition,
            'Trials': np.sum(in_condition),
            'Barriers': np.unique(trials.maze_num_barriers[in_condition]),
            'Targets': np.unique(trials.maze_num_targets[in_condition]),
            'Hit Position': position,
            'Hit Position Angles': str(np.degrees(np.arctan2(position[1], position[0]))),
        }

    summary_df = pd.DataFrame([_describe(c) for c in np.unique(trials.maze_condition)])

    # Positions are arrays (unhashable), so compare them as tuples and keep the
    # first condition that uses each distinct hit position.
    is_first_use = ~summary_df['Hit Position'].apply(tuple).duplicated(keep='first')
    unique_hits = summary_df[is_first_use]

    hit_positions = unique_hits['Hit Position']
    plot_data = pd.DataFrame({
        'Hit Position X': hit_positions.apply(lambda p: p[0]),
        'Hit Position Y': hit_positions.apply(lambda p: p[1]),
        'Maze Condition': unique_hits['Maze Condition'].astype(str),
    })

    # One evenly spaced viridis colour per plotted condition
    import plotly.colors as pc
    n_colors = len(plot_data['Maze Condition'].unique())
    denominator = max(n_colors - 1, 1)
    colors = pc.sample_colorscale('viridis', [k / denominator for k in range(n_colors)])

    fig = px.scatter(
        plot_data,
        x='Hit Position X',
        y='Hit Position Y',
        color='Maze Condition',
        labels={'Hit Position X': 'Hit Position X', 'Hit Position Y': 'Hit Position Y', 'color': 'Maze Condition'},
        title='Hit Position by Maze Condition',
        color_discrete_sequence=colors,
        hover_data=['Maze Condition']
    )
    # Square, equal-aspect axes so angles are not distorted
    fig.update_layout(
        xaxis=dict(scaleanchor="y", scaleratio=1, range=[-150, 150]),
        yaxis=dict(constrain="domain", range=[-150, 150]),
        width=600,
        height=600
    )
    fig.show()

    return summary_df

In [None]:
# Path to the processed brainsets data. Set the BRAINSETS_DATA_PATH environment
# variable to load from elsewhere, e.g. on the cluster:
#   /ceph/aeon/aeon/SANe/brainsets_data/processed/churchland_shenoy_neural_2012
# This avoids editing the hard-coded local default below.
data_path = Path(os.environ.get(
    "BRAINSETS_DATA_PATH",
    r"C:\Users\pouge\Documents\mini_data\brainsets\processed\churchland_shenoy_neural_2012",
))

# List all h5 files in the directory. Sort them because os.listdir order is
# arbitrary and platform-dependent. Which files fall within the first
# `num_files_to_load`, and therefore the session indices used later, must not
# change between machines.
h5_files = sorted(f for f in os.listdir(data_path) if f.endswith('.h5'))
print(f"Available h5 files: {h5_files}")

# User parameters
subject_name = "jenkins"  # "nitschke" or "jenkins"
num_files_to_load = 4     # max 6 for nitschke (only 3 work), 4 for jenkins

# Filter files by subject
subject_files = [f for f in h5_files if subject_name.lower() in f.lower()]
print(f"\nFiles for subject {subject_name}: {subject_files}")

if len(subject_files) == 0:
    print(f"No files found for subject {subject_name}")
elif len(subject_files) < num_files_to_load:
    print(f"Only {len(subject_files)} files available for {subject_name}, loading all of them")
    num_files_to_load = len(subject_files)

# Load and clean the requested number of files
sessions = []
for i, file_name in enumerate(subject_files[:num_files_to_load]):
    file_path = os.path.join(data_path, file_name)
    print(f"\nLoading file {i+1}/{num_files_to_load}: {file_name}")

    # Read neural data from HDF5
    with h5py.File(file_path, "r") as f:
        session = Data.from_hdf5(f)

        # Presumably pulls the HDF5-backed fields into memory while the file is
        # still open. TODO confirm against the temporaldata docs.
        session.spikes.materialize()
        session.trials.materialize()
        session.hand.materialize()
        session.eye.materialize()
        session.session.materialize()
        session.units.materialize()

        print("Session ID: ", session.session.id)
        print("Session subject id: ", session.subject.id)
        print("Session subject sex: ", session.subject.sex)
        print("Session subject species: ", session.subject.species)
        print("Session recording date: ", session.session.recording_date)
        print("Original number of trials:", len(session.trials.start))

        # Best-effort: a session that fails cleaning (e.g. the condition check)
        # is reported and skipped rather than aborting the whole load.
        try:
            session = clean_session_data(session)
            print("Final number of trials after cleaning:", len(session.trials.start))

            print("Summary of primary conditions:")
            primary_conditions_summary = analyze_maze_conditions(session)
            display(primary_conditions_summary)

            sessions.append(session)
        except Exception as e:
            print(f"Error processing session {session.session.id}: {e}")
            continue

print(f"\nSuccessfully loaded and cleaned {len(sessions)} sessions for subject {subject_name}")

In [None]:
def fix_maze_conditions_consistency(sessions):
    """
    Renumber ``maze_condition`` so that a given maze has the same number in every session.

    Conditions are assumed to come in consecutive groups of three after
    sorting. Each group is fingerprinted by the (barriers, targets, hit
    position) of its three conditions. The session with the most conditions
    serves as the reference. Every group in every session is matched to the
    first reference group with an identical fingerprint, and its conditions
    are renumbered to ``3 * ref_group_index + 1 + offset``. As a result,
    reference groups with identical fingerprints are merged onto the first one.

    Args:
        sessions: List of session objects. Their ``trials.maze_condition``
            arrays are overwritten in place.

    Returns:
        The sessions (same objects) whose condition count is a multiple of 3,
        with consistent maze condition numbering.

    Raises:
        ValueError: If any of the following hold:
            - no session has a multiple-of-3 condition count;
            - a condition has more than one parameter set;
            - a group cannot be matched to any reference group.
    """

    def maze_signature(session, condition):
        """Return hashable (barriers, targets, hit_position) tuples for one condition."""
        in_condition = session.trials.maze_condition == condition
        barriers = tuple(np.unique(session.trials.maze_num_barriers[in_condition]))
        targets = tuple(np.unique(session.trials.maze_num_targets[in_condition]))
        hit_positions = np.unique(session.trials.hit_target_position[in_condition], axis=0)
        if len(barriers) <= 1 and len(targets) <= 1 and len(hit_positions) <= 1:
            return (barriers, targets, tuple(hit_positions[0]))
        raise ValueError(f"Condition {condition} has the following >1 unique values for one of the following: "
                         f"barriers={barriers}, targets={targets}, hit_position={hit_positions}. "
                         "This should not be possible.")

    def group_signature(session, group):
        """Per-condition signatures of a group, in sorted condition order."""
        return tuple(maze_signature(session, c) for c in sorted(group))

    def split_into_triplets(session):
        """Sorted unique conditions chunked into consecutive groups of three."""
        conditions = sorted(np.unique(session.trials.maze_condition))
        return [conditions[k:k + 3] for k in range(0, len(conditions), 3)]

    # Drop sessions whose condition count is not a multiple of 3
    valid_sessions = []
    for session in sessions:
        n_conditions = len(np.unique(session.trials.maze_condition))
        if n_conditions % 3:
            print(f"WARNING: Removing session {session.session.id} - has {n_conditions} maze conditions (not multiple of 3)")
        else:
            valid_sessions.append(session)

    if not valid_sessions:
        raise ValueError("No valid sessions remaining after filtering")
    print(f"Kept {len(valid_sessions)} out of {len(sessions)} sessions after filtering")

    # Reference = first session with the most unique conditions
    counts = [len(np.unique(s.trials.maze_condition)) for s in valid_sessions]
    best = max(range(len(counts)), key=counts.__getitem__)
    reference_session = valid_sessions[best]
    print(f"Using session {reference_session.session.id} as reference (has {counts[best]} maze conditions)")

    ref_groups = split_into_triplets(reference_session)
    for group in ref_groups:
        if len(group) != 3:
            raise ValueError(f"Reference session has incomplete group: {group}")

    # Fingerprint the reference groups, warning about duplicates
    ref_signatures = []
    last_group_for = {}
    for group in ref_groups:
        signature = group_signature(reference_session, group)
        if signature in last_group_for:
            print("WARNING: Duplicate group found in reference session!")
            print(f"  Group {last_group_for[signature]} and Group {group} have identical maze parameters")
        ref_signatures.append(signature)
        last_group_for[signature] = group

    print(f"Reference session has {len(ref_groups)} groups of maze conditions")

    print("Table of reference session maze conditions:")
    display(analyze_maze_conditions(reference_session))

    # Renumber every session's conditions to the reference numbering
    processed_sessions = []
    for session in valid_sessions:
        print(f"Processing session {session.session.id}...")

        old_to_new = {}
        for group in split_into_triplets(session):
            signature = group_signature(session, group)
            match = next((k for k, ref_sig in enumerate(ref_signatures) if signature == ref_sig), None)

            if match is None:
                print(f"ERROR: Could not match group {group} in session {session.session.id}")
                print(f"Group signature: {signature}")
                print("Available reference signatures:")
                for ref_group, ref_sig in zip(ref_groups, ref_signatures):
                    print(f"  Reference group {ref_group}: {ref_sig}")
                raise ValueError(f"Unmatchable maze conditions {group} in session {session.session.id}")

            base = 3 * match + 1  # 1, 4, 7, 10, ...
            old_to_new.update({old: base + offset for offset, old in enumerate(sorted(group))})

        session.trials.maze_condition = np.array([old_to_new[old] for old in session.trials.maze_condition])
        processed_sessions.append(session)

    print(f"\nSuccessfully processed {len(processed_sessions)} sessions with consistent maze condition numbering")
    return processed_sessions

# Renumber maze conditions consistently across the loaded, cleaned sessions
print("Fixing maze condition consistency across sessions...")

sessions = fix_maze_conditions_consistency(sessions)

# Verify: list each session's (renumbered) conditions
print("\nCondition remapping summary:")
for session in sessions:
    print(f"Session {session.session.id}: {sorted(np.unique(session.trials.maze_condition))}")

In [None]:
# Uncomment to restrict `sessions` to a subset (e.g. to train on specific sessions only); run after the maze conditions have been made consistent
# print(len(sessions))
# sessions =  [sessions[i] for i in [0, 1]]
# print(len(sessions))

In [None]:
# Parameters
bin_size = 0.05  # seconds

# Check unit consistency across sessions.
# NOTE(review): this compares per-session `unit_index` values, not unit
# identities. Confirm that the same index means the same unit in every session
# before pooling them into shared columns.
unit_ids = np.unique(sessions[0].spikes.unit_index)
for session in sessions:
    unique_units = np.unique(session.spikes.unit_index)
    if not np.array_equal(unique_units, unit_ids):
        raise ValueError("Sessions do not have the same unit IDs. Cannot combine spike data.")

# Determine the global bin alignment start point
global_start = min(session.session.recording_date + session.trials.start.min() for session in sessions)
global_start = np.floor(global_start / bin_size) * bin_size  # ensure clean bin alignment

# Decimal places needed to write multiples of bin_size exactly
# (0.05 -> 2, 0.025 -> 3, 1.0 -> 0). A log10-based rule under-counts for sizes
# like 0.025 or 1.5, which makes rounded bin timestamps collide.
n_decimals = len(np.format_float_positional(bin_size).partition('.')[2])

# Accumulator for each session's binned counts
binned_dfs = []

for session in sessions:
    recording_date = session.session.recording_date

    # Spikes and trial bounds, shifted to absolute time
    df_spikes = pd.DataFrame({
        'timestamp': session.spikes.timestamps + recording_date,
        'unit': session.spikes.unit_index,
    }).sort_values('timestamp')
    df_trials = pd.DataFrame({
        'trial_start': session.trials.start + recording_date,
        'trial_end': session.trials.end + recording_date,
    }).sort_values('trial_start')

    # Assign each spike to the most recent trial_start <= timestamp
    df_merged = pd.merge_asof(
        df_spikes,
        df_trials[['trial_start', 'trial_end']],
        left_on='timestamp',
        right_on='trial_start',
        direction='backward'
    )

    # Keep spikes inside their trial interval [start, end). Spikes before the
    # first trial have NaN trial_end and also fail the test. Using .assign on
    # the filtered frame avoids a SettingWithCopyWarning.
    in_trial = df_merged['timestamp'] < df_merged['trial_end']
    df_merged = df_merged.loc[in_trial].assign(
        bin=lambda d: ((d['timestamp'] - global_start) / bin_size).astype(int)
    )

    # Spike counts per (bin, unit), with units as columns
    spk_cts = df_merged.groupby(['bin', 'unit']).size().unstack(fill_value=0)

    # Bin k covers [global_start + k*bin_size, global_start + (k+1)*bin_size).
    # Include every bin overlapping a trial, even bins where no unit spiked
    # (counting spikes alone would silently drop those rows). Also include
    # every unit, so all sessions share the same columns.
    first_bins = np.floor((df_trials['trial_start'].to_numpy() - global_start) / bin_size).astype(int)
    last_bins = np.ceil((df_trials['trial_end'].to_numpy() - global_start) / bin_size).astype(int) - 1
    bin_ranges = [np.arange(first, last + 1) for first, last in zip(first_bins, last_bins)]
    trial_bins = np.concatenate(bin_ranges) if bin_ranges else np.empty(0, dtype=int)
    trial_bins = np.union1d(trial_bins, spk_cts.index.to_numpy())  # never lose a counted bin
    spk_cts = spk_cts.reindex(index=trial_bins, columns=unit_ids, fill_value=0)

    # Label rows with the bin's (rounded) absolute start time
    session_timestamps = np.round(global_start + spk_cts.index.to_numpy() * bin_size, n_decimals)
    spk_cts.index = pd.Index(session_timestamps, name='timestamp')
    spk_cts.columns.name = None

    binned_dfs.append(spk_cts)

# Concatenate all sessions and order chronologically
spk_cts_df = pd.concat(binned_dfs)
spk_cts_df = spk_cts_df[~spk_cts_df.index.duplicated()].sort_index()

# Result: each row = time bin, each col = unit, values = spike count
display(spk_cts_df)

In [None]:
# Compute mean firing rate (Hz) per unit
duration_sec = len(spk_cts_df) * bin_size
mean_firing_rates = spk_cts_df.sum(axis=0) / duration_sec  # spikes/sec

# Plot histogram
plt.figure(figsize=(8, 4))
plt.hist(mean_firing_rates, bins=30, edgecolor='black')
plt.xlabel('Mean Firing Rate (Hz)')
plt.ylabel('Number of Units')
plt.title('Distribution of Mean Firing Rates')
plt.grid(True)
plt.tight_layout()
plt.show()

# Print summary stats
print("Mean firing rate across units: {:.2f} Hz".format(mean_firing_rates.mean()))
print("Median: {:.2f} Hz".format(np.median(mean_firing_rates)))
print("Range: {:.2f}–{:.2f} Hz".format(mean_firing_rates.min(), mean_firing_rates.max()))

In [None]:
# Flatten the spike count matrix to a 1D array
flattened_spike_counts = spk_cts_df.values.flatten()

# Plot histogram of spike counts
plt.figure(figsize=(8, 5))
plt.hist(flattened_spike_counts, bins=50, edgecolor='black')
plt.title("Distribution of Spike Counts per Unit per Time Bin")
plt.xlabel("Spike Count")
plt.ylabel("Number of Unit-Bin Combinations")
plt.yscale("log")  # Optional: log scale to better visualize skewed distributions
plt.grid(True)
plt.tight_layout()
plt.show()

In [None]:
"""Check sparsity of binned spike counts."""

# Fraction of (bin, unit) cells holding at least one spike
is_nonzero = spk_cts_df.to_numpy() != 0
frac_nonzero_bins = is_nonzero.sum() / spk_cts_df.size
# Fraction of bins (rows) in which any unit spiked
frac_nonzero_examples = (spk_cts_df.sum(axis=1) > 0).mean()
print(f"{frac_nonzero_bins=:.4f}")
print(f"{frac_nonzero_examples=:.4f}")

## Get environment / behavior (meta)data

In [None]:
# Collect hand, eye and trial data from all sessions. Every time column is
# shifted to absolute time by adding the session's recording_date.
hand_data = defaultdict(list)
eye_data = defaultdict(list)
trial_data = defaultdict(list)

# Trial time columns (shifted to absolute time), plus trial attributes that are
# copied unchanged (output name -> session.trials attribute)
trial_time_cols = ['start', 'end', 'target_on_time', 'go_cue_time', 'move_begins_time', 'move_ends_time']
trial_attr_cols = {'maze_condition': 'maze_condition', 'barriers': 'maze_num_barriers', 'targets': 'maze_num_targets'}

for i, session in enumerate(sessions):
    recording_date = session.session.recording_date

    # Hand kinematics: split each 2D array into x / y columns, tag with session index
    timestamps = session.hand.timestamps + recording_date
    hand_data['timestamp'].append(timestamps)
    for name, values in (('acc', session.hand.acc_2d), ('pos', session.hand.pos_2d), ('vel', session.hand.vel_2d)):
        hand_data[f'{name}_x'].append(values[:, 0])
        hand_data[f'{name}_y'].append(values[:, 1])
    hand_data['session'].append(np.full(len(timestamps), i))

    # Eye position
    eye_pos = session.eye.pos
    eye_data['timestamp'].append(session.eye.timestamps + recording_date)
    eye_data['pos_x'].append(eye_pos[:, 0])
    eye_data['pos_y'].append(eye_pos[:, 1])

    # Trials: times shifted to absolute time, attributes copied, hit target split into x/y/angle
    for col in trial_time_cols:
        trial_data[col].append(getattr(session.trials, col) + recording_date)
    for out_col, src_col in trial_attr_cols.items():
        trial_data[out_col].append(getattr(session.trials, src_col))
    hit = np.asarray(session.trials.hit_target_position)
    trial_data['hit_position_x'].append(hit[:, 0])
    trial_data['hit_position_y'].append(hit[:, 1])
    trial_data['hit_position_angle'].append(np.degrees(np.arctan2(hit[:, 1], hit[:, 0])))

# Concatenate per-session arrays into final DataFrames
combined_hand_df = pd.DataFrame({k: np.concatenate(v) for k, v in hand_data.items()}).set_index('timestamp')
combined_eye_df = pd.DataFrame({k: np.concatenate(v) for k, v in eye_data.items()}).set_index('timestamp')
combined_trials_df = pd.DataFrame({k: np.concatenate(v) for k, v in trial_data.items()})

# Add hit target position column
combined_trials_df['hit_target_position'] = list(zip(
    combined_trials_df['hit_position_x'],
    combined_trials_df['hit_position_y']
))

# Unified, sorted timestamp index over hand samples, eye samples and trial events
all_event_ts = np.concatenate([
    combined_trials_df['target_on_time'].values,
    combined_trials_df['go_cue_time'].values,
    combined_trials_df['move_begins_time'].values,
    combined_trials_df['move_ends_time'].values,
])
all_ts = np.unique(np.concatenate([
    combined_hand_df.index.values,
    combined_eye_df.index.values,
    all_event_ts
]))

metadata = pd.DataFrame(index=all_ts)
metadata.index.name = 'timestamp'

# Merge hand and eye data. The hand columns keep their names; overlapping eye
# columns get an "_eye" suffix.
metadata = metadata.join(combined_hand_df, how='left')
metadata = metadata.join(combined_eye_df, how='left', rsuffix='_eye')

# Event column: label timestamps that coincide with trial events. If several
# event types share a timestamp, the later entry in event_map wins.
event_map = {
    'target_on_time': 'target_on',
    'go_cue_time': 'go_cue',
    'move_begins_time': 'move_begins',
    'move_ends_time': 'move_ends',
}
event_col = pd.Series(index=metadata.index, dtype="object")
for col, label in event_map.items():
    mask = np.isin(metadata.index.values, combined_trials_df[col].values)
    event_col.iloc[mask] = label
metadata['event'] = event_col

# trial_idx: positional index into combined_trials_df of the trial whose
# [start, end] contains each timestamp (NaN outside all trials). It is found by
# binary search over trials sorted by start.
trial_sort_idx = np.argsort(combined_trials_df['start'].values)
starts = combined_trials_df['start'].values[trial_sort_idx]
ends = combined_trials_df['end'].values[trial_sort_idx]

timestamps = metadata.index.values
# Sorted position of the last trial starting at or before each timestamp (-1 if none)
start_positions = np.searchsorted(starts, timestamps, side='right') - 1
has_start = start_positions >= 0
in_trial = np.zeros(len(timestamps), dtype=bool)
in_trial[has_start] = timestamps[has_start] <= ends[start_positions[has_start]]

trial_indices = np.full(len(timestamps), np.nan)
trial_indices[in_trial] = trial_sort_idx[start_positions[in_trial]]
metadata['trial_idx'] = trial_indices

# Broadcast per-trial properties onto every timestamp (NaN outside trials)
trial_idx_int = metadata['trial_idx'].astype('Int64')
for col in ('maze_condition', 'barriers', 'targets', 'hit_position_x', 'hit_position_y', 'hit_position_angle'):
    metadata[col] = trial_idx_int.map(combined_trials_df[col])

# Direction of hand movement between consecutive rows (degrees)
pos_delta_x = metadata['pos_x'].diff()
pos_delta_y = metadata['pos_y'].diff()
metadata['movement_angle'] = np.degrees(np.arctan2(pos_delta_y, pos_delta_x))

# Speed and acceleration magnitudes from vector components
metadata['vel_magnitude'] = np.sqrt(
    metadata['vel_x'].values**2 + metadata['vel_y'].values**2
)
metadata['accel_magnitude'] = np.sqrt(
    metadata['acc_x'].values**2 + metadata['acc_y'].values**2
)

# Show result
print("Metadata:")
display(metadata)

In [None]:
"""Bin the metadata to match spike counts."""

# Create metadata_binned with consistent timestamps
ts = spk_cts_df.index.values
metadata_binned = pd.DataFrame(index=pd.Index(ts, name='timestamp'))

# Assign each metadata row to a bin index (half-open [t, t+bin))
bin_ids = np.searchsorted(ts, metadata.index.values, side='right') - 1
bin_ids = np.clip(bin_ids, 0, len(ts) - 1)

# Handle event aggregation 
event_values = metadata['event'].values
event_mask = pd.notna(event_values)

if event_mask.any():
    event_bin_ids = bin_ids[event_mask]
    valid_events = event_values[event_mask].astype(str)
    
    # Create event assignments using vectorized operations
    event_agg = np.full(len(ts), None, dtype=object)
    
    # Calculate target bins for all events at once
    event_indices = np.arange(len(valid_events))
    target_bins = event_bin_ids + event_indices
    
    # Only assign events that fall within valid bin range
    valid_targets = target_bins < len(ts)
    event_agg[target_bins[valid_targets]] = valid_events[valid_targets]
else:
    event_agg = np.full(len(ts), None, dtype=object)

# Efficient nearest neighbor reindexing using searchsorted
metadata_timestamps = metadata.index.values
ts_positions = np.searchsorted(metadata_timestamps, ts, side='left')

# Handle edge cases and find true nearest neighbors
ts_positions = np.clip(ts_positions, 0, len(metadata_timestamps) - 1)

# For positions not at the start, check if the previous position is closer
mask = ts_positions > 0
left_positions = ts_positions.copy()
left_positions[mask] = ts_positions[mask] - 1

# Calculate distances to determine nearest
left_distances = np.abs(ts - metadata_timestamps[left_positions])
right_distances = np.abs(ts - metadata_timestamps[ts_positions])

# Choose the nearest position
final_positions = np.where(left_distances < right_distances, left_positions, ts_positions)

# Copy all columns from nearest metadata except event and trial_idx
for col in metadata.columns:
    if col not in ('event', 'trial_idx'):
        metadata_binned[col] = metadata[col].iloc[final_positions].values

# Insert distributed events
metadata_binned['event'] = event_agg

# Override trial_idx from GT intervals [start, end) to avoid NN bleed across gaps.
# Each bin [left, right) gets the positional row index (into combined_trials_df) of
# the latest-starting trial that overlaps it; bins that overlap no trial get NaN.
starts = combined_trials_df['start'].to_numpy()
ends = combined_trials_df['end'].to_numpy()
order = np.argsort(starts)
s, e = starts[order], ends[order]

left  = ts
right = ts + bin_size  # each bin is [left, right)
# Candidate trial: last trial that starts strictly before this bin ends.
# side='left' matters: side='right' also picked a trial starting exactly at `right`,
# which then failed the overlap test and hid an earlier trial that does overlap the
# bin (e.g. back-to-back trials), leaving that bin unassigned (NaN).
cand = np.searchsorted(s, right, side='left') - 1
# Bin belongs if it overlaps the candidate's interval. cand == -1 means no trial has
# started yet; its wrap-around lookup is masked out by `cand >= 0`.
valid = (cand >= 0) & (left < e[cand]) & (right > s[cand])

trial_idx_binned = np.full(len(ts), np.nan)
trial_idx_binned[valid] = order[cand[valid]]

metadata_binned['trial_idx'] = pd.Series(
    trial_idx_binned, index=metadata_binned.index, dtype='float64'
)

# Mark the first bin of each trial as 'start' and the last as 'end'.
# groupby drops NaN keys by default, so bins outside any trial are not touched here.
# These assignments overwrite any event already placed in those bins.
# NOTE(review): .loc with index labels assumes metadata_binned's index is unique — confirm.
first = metadata_binned.groupby('trial_idx', sort=False).head(1).index
last  = metadata_binned.groupby('trial_idx', sort=False).tail(1).index
metadata_binned.loc[first, 'event'] = 'start'
# difference() skips single-bin trials, whose only bin stays 'start'
metadata_binned.loc[last.difference(first), 'event'] = 'end'

# Fill "A -> B" on empty rows between events, per trial (keeps existing 'start'/'end'/'event' rows)
metadata_binned['event'] = metadata_binned['event'].astype(object)
prev_ev = metadata_binned.groupby('trial_idx', sort=False)['event'].ffill()
next_ev = metadata_binned.groupby('trial_idx', sort=False)['event'].bfill()
mask = metadata_binned['event'].isna() & prev_ev.notna() & next_ev.notna()
metadata_binned.loc[mask, 'event'] = prev_ev[mask] + ' -> ' + next_ev[mask]
# Keep only the canonical event sequence and its direct transitions; any other label
# (e.g. a transition that skips an event, such as 'target_on -> move_begins') becomes NaN.
allowed = {'start', 'start -> target_on', 'target_on', 'target_on -> go_cue', 'go_cue', 'go_cue -> move_begins',
           'move_begins', 'move_begins -> move_ends', 'move_ends', 'move_ends -> end', 'end'}
metadata_binned['event'] = metadata_binned['event'].where(metadata_binned['event'].isin(allowed))

# Result: metadata with one row per time bin; bins outside every trial have NaN trial_idx
print("Metadata binned:")
display(metadata_binned)

## Train/val split (by session, trial-aware) and normalisation

In [None]:
# Dense float32 matrices: spike counts (rows = time bins, columns = units) and the
# 2-D velocity target (column 0 = vel_x, column 1 = vel_y), row-aligned with metadata_binned.
spikes = spk_cts_df.to_numpy().astype(np.float32)
velocity = np.stack(
    [metadata_binned[vel_col].to_numpy(dtype=np.float32) for vel_col in ("vel_x", "vel_y")],
    axis=1,
)

In [None]:
# Split trials into train/val sets (split by session).
# Bins outside every trial carry trial_idx = NaN. Drop NaN before collecting trial ids:
# pandas `isin` treats NaN as equal to NaN, so a NaN in train_trials/val_trials would
# pull the inter-trial bins of *every* session into both masks (train/val leakage)
# and inflate the reported trial counts by one.
train_sessions = [0, 1, 2, 3]
train_trials = metadata_binned[metadata_binned['session'].isin(train_sessions)]['trial_idx'].dropna().unique()
if len(train_sessions) == len(sessions):
    print("WARNING: All sessions are included in the training set. Validation set will be identical to training set.")
    val_trials = train_trials  # if all sessions are train, val is same as train
else:
    val_trials = metadata_binned[~metadata_binned['session'].isin(train_sessions)]['trial_idx'].dropna().unique()

# # OR Split trials into train/val sets (80/20 split)
# # Get unique trial indices (NaN = inter-trial bins excluded) and shuffle them
# unique_trials = metadata_binned['trial_idx'].dropna().unique()
# np.random.shuffle(unique_trials)
# # Split 
# n_train_trials = int(len(unique_trials) * 0.8)
# train_trials = unique_trials[:n_train_trials]
# val_trials = unique_trials[n_train_trials:]

# Create boolean masks for train/val based on trial membership
train_mask = metadata_binned['trial_idx'].isin(train_trials)
val_mask = metadata_binned['trial_idx'].isin(val_trials)

# # Split 
# spikes_train = spikes[train_mask]
# spikes_val = spikes[val_mask]
# OR split and smooth spikes. Smoothing each split separately avoids train/val
# leakage; within a split the filter still runs across trial/session boundaries
# because the selected bins are concatenated.
sigma = 0.05 / bin_size  # 0.05 (presumably seconds) expressed in bins — TODO confirm bin_size units
spikes_train = gaussian_filter1d(spikes[train_mask], sigma=sigma, axis=0)
spikes_val = gaussian_filter1d(spikes[val_mask], sigma=sigma, axis=0)

# Normalize spikes by the single global training max (fit on training data only)
train_max = spikes_train.max()
assert train_max > 0, "Training spikes are all zero; cannot normalise by their max."
spikes_train = spikes_train / train_max
spikes_val = spikes_val / train_max

# Split velocity
velocity_train = velocity[train_mask]
velocity_val = velocity[val_mask]

# Extract trial IDs for reference
trial_ids_train = metadata_binned['trial_idx'][train_mask].values
trial_ids_val = metadata_binned['trial_idx'][val_mask].values

# Summary
print(f"Train set: {len(train_trials)} trials ({train_mask.sum()} time bins)")
print(f"Val set: {len(val_trials)} trials ({val_mask.sum()} time bins)")
print(f"Spike data shapes: train {spikes_train.shape}, val {spikes_val.shape}")

## Set SAE config

In [None]:
# Training device: CUDA when available, otherwise CPU.
device = t.device("cuda") if t.cuda.is_available() else t.device("cpu")
print(f"{device=}")

# Training spikes as a bfloat16 tensor on the training device.
spk_train_t = t.from_numpy(spikes_train).to(device=device, dtype=t.bfloat16)

# Per-dictionary-size settings, kept sorted from smallest to largest size:
# top-k (per the map name) and a loss multiplier (presumably; see mini.train).
dsae_topk_map = dict(sorted({256: 8, 512: 16, 1024: 24}.items()))
dsae_loss_x_map = dict(sorted({256: 1, 512: 1.25, 1024: 1.5}.items()))
# dsae_topk_map = {1024: 12, 2048: 24, 4096: 48}
dsae = max(dsae_topk_map)  # largest dictionary size
n_inst = 2

display(spk_train_t)

## Train MSAE

In [None]:
sae_cfg = mt.SaeConfig(
    n_input_ae=spk_train_t.shape[1],   # input dimension = #units
    dsae_topk_map=dsae_topk_map,
    dsae_loss_x_map=dsae_loss_x_map,
    seq_len=1,  # presumably one time bin per example — see mt.SaeConfig
    n_instances=n_inst,
)
sae = mt.Sae(sae_cfg).to(device)
loss_fn = mt.msle  # loss implementation lives in mini.train
tau = 1.0  # forwarded to mt.optimize; its meaning is defined there
lr = 5e-3

n_epochs = 20
batch_sz = 1024
# Only full batches are counted; n_steps is 0 if there are fewer than batch_sz examples.
n_steps = (spk_train_t.shape[0] // batch_sz) * n_epochs
log_freq = max(1, n_steps // n_epochs // 2)  # about twice per epoch
dead_neuron_window = max(1, n_steps // n_epochs // 3)  # about a third of an epoch (in steps, presumably)

data_log = mt.optimize(  # train model
    spk_cts=spk_train_t,
    sae=sae,
    loss_fn=loss_fn,
    optimizer=t.optim.Adam(sae.parameters(), lr=lr),
    use_lr_sched=True,
    dead_neuron_window=dead_neuron_window,
    n_steps=n_steps,
    log_freq=log_freq,
    batch_sz=batch_sz,
    log_wandb=False,
    plot_l0=False,
    tau=tau,
)

## Validate SAEs

In [None]:
"""Check for nans in weights."""

tuple(t.isnan(weights).sum() for weights in (sae.W_dec, sae.W_enc))  # (decoder NaNs, encoder NaNs)

In [None]:
"""Visualize weights."""

# Overlay one histogram of flattened decoder weights per SAE instance on a shared axis.
fig, ax = plt.subplots(figsize=(8, 6))
for inst in range(n_inst):
    W_dec_flat = asnumpy(sae.W_dec[inst].float()).ravel()
    # Pass ax explicitly instead of relying on pyplot's "current axes".
    sns.histplot(W_dec_flat, bins=1000, stat="probability", alpha=0.7, label=f"SAE {inst}", ax=ax)

ax.set_title("SAE decoder weights")
ax.set_xlabel("Weight value")
ax.set_ylabel("Probability")  # matches stat="probability" (bar heights sum to 1 per instance)
ax.legend();  # trailing ';' suppresses the Legend object repr

In [None]:
"""Visualize metrics over all examples and units."""

# Evaluate the trained SAE on the full training tensor. eval_model returns a 6-tuple;
# the 4th and 6th elements are unused here. r2_per_unit is indexed [unit, instance]
# later on (r2_per_unit[:, inst]).
topk_acts_4d, recon_spk_cts, r2_per_unit, _, cossim_per_unit, _ = mt.eval_model(
    spk_train_t, sae, batch_sz=batch_sz
)

In [None]:
"""Calculate variance explained of summed spike counts."""

n_recon_examples = recon_spk_cts.shape[0]
# Population activity per example: sum over units.
recon_summed_spk_cts = recon_spk_cts.sum(dim=-1)                          # [example, inst]
# Same for the data, trimmed to the examples that were reconstructed.
actual_summed_spk_cts = spk_train_t.sum(dim=-1)[:n_recon_examples]        # [example]

for inst in range(n_inst):
    r2 = r2_score(
        y_true=asnumpy(actual_summed_spk_cts.float()),
        y_pred=asnumpy(recon_summed_spk_cts[:, inst].float()),
    )
    print(f"SAE instance {inst} R² (summed spike count over all units per example) = {r2:.3f}")


In [None]:
# If cosine similarity is high but r2 is low, it suggests that the model is capturing the structure of the data but not the magnitude.
# Calculate scale ratio of norms to check this
# NOTE: a zero-norm reconstruction row gives an infinite ratio (NaN if the true row is
# also zero), which propagates into the printed mean.

# Expand to [n_examples, 1, n_units]
spk_train_exp = spk_train_t[:n_recon_examples].unsqueeze(1)

true_norms  = t.norm(spk_train_exp, dim=-1)   # [n_examples, 1]
recon_norms = t.norm(recon_spk_cts, dim=-1)   # [n_examples, n_instances]
scale = true_norms / recon_norms              # [n_examples, n_instances] (true norm broadcast over instances)

print(scale.mean(dim=0))  # if it’s consistently >1 or <1, your model is biased in magnitude

In [None]:
# Mean signed reconstruction error. recon_spk_cts is [n_examples, n_instances, n_units],
# so after averaging over examples `bias` is [n_instances, n_units].
spk_train_trim = spk_train_t[:n_recon_examples]
bias = (recon_spk_cts - spk_train_trim.unsqueeze(1)).mean(dim=0)
print(bias.mean(dim=0))  # per-unit bias: averaged over examples, then over SAE instances

In [None]:
# Per-unit variance across examples, averaged over units (and, for the reconstruction,
# also over SAE instances). pred_var well below true_var suggests the SAE shrinks
# reconstructions toward the mean.
true_var = spk_train_trim.var(dim=0).mean()
pred_var = recon_spk_cts.var(dim=0).mean()
print(f"True variance: {true_var.item():.4f}, Pred variance: {pred_var.item():.4f}")

### Remove bad units and retrain.

In [None]:
# Set threshold for removing units
r2_thresh = 0.1
inst = 0
r2_inst = r2_per_unit[:, inst]              # [n_units]
keep_mask = r2_inst > r2_thresh
print(f"frac units above {r2_thresh=}: {keep_mask.sum() / keep_mask.shape[0]:.2f}")
print(f"Number to keep: {keep_mask.sum()} / {keep_mask.shape[0]}")

# Apply mask to train data (can be applied to val later)
spk_train_pruned = spk_train_t[:, keep_mask]

# Retrain SAE on pruned train data
sae_cfg = mt.SaeConfig(
    n_input_ae=spk_train_pruned.shape[1],
    dsae_topk_map=dsae_topk_map,
    dsae_loss_x_map=dsae_loss_x_map,
    seq_len=1,
    n_instances=n_inst,
)
sae = mt.Sae(sae_cfg).to(device)
loss_fn = mt.msle
tau = 1.0
lr = 5e-3

n_epochs = 20
batch_sz = 1024
n_steps = (spk_train_pruned.shape[0] // batch_sz) * n_epochs
log_freq = max(1, n_steps // n_epochs // 2)
dead_neuron_window = max(1, n_steps // n_epochs // 3)

data_log = mt.optimize(
    spk_cts=spk_train_pruned,
    sae=sae,
    loss_fn=loss_fn,
    optimizer=t.optim.Adam(sae.parameters(), lr=lr),
    use_lr_sched=True,
    dead_neuron_window=dead_neuron_window,
    n_steps=n_steps,
    log_freq=log_freq,
    batch_sz=batch_sz,
    log_wandb=False,
    plot_l0=False,
    tau=tau,
)

In [None]:
"""Re-visualize metrics over all examples and units."""

# Evaluate the retrained SAE on the pruned training data.
Xtr_t = spk_train_pruned

topk_acts_4d_tr, recon_spk_cts_tr, r2_per_unit_tr, _, cossim_per_unit_tr, _ = mt.eval_model(
    Xtr_t, sae, batch_sz=batch_sz
)

# Summed population count per example: reconstruction (per instance) vs. data,
# with the data trimmed to the reconstructed examples.
n_recon_examples_tr = recon_spk_cts_tr.shape[0]
recon_summed_tr = recon_spk_cts_tr.sum(dim=-1)                     # [example, inst]
actual_summed_tr = Xtr_t.sum(dim=-1)[:n_recon_examples_tr]         # [example]

for inst in range(n_inst):
    r2 = r2_score(
        y_true=asnumpy(actual_summed_tr.float()),
        y_pred=asnumpy(recon_summed_tr[:, inst].float()),
    )
    print(f"[TRAIN] SAE instance {inst} R² (summed spike count per example) = {r2:.3f}")

In [None]:
"""Do the same on validation data to check generalisation"""

# keep_mask comes from mt.eval_model's R² output and may be a torch bool tensor,
# possibly on the GPU. numpy cannot index with a CUDA tensor, so convert it to a host
# numpy bool array first (asnumpy returns numpy input unchanged).
keep_mask_np = asnumpy(keep_mask).astype(bool)
Xva_np = spikes_val[:, keep_mask_np]  # apply same unit mask
Xva_t = t.from_numpy(Xva_np).to(Xtr_t.device).to(Xtr_t.dtype)

topk_acts_4d_va, recon_spk_cts_va, r2_per_unit_va, _, cossim_per_unit_va, _ = mt.eval_model(
    Xva_t, sae, batch_sz=batch_sz
)

n_recon_examples_va = recon_spk_cts_va.shape[0]
recon_summed_va = reduce(recon_spk_cts_va, "example inst unit -> example inst", "sum")

actual_summed_va = reduce(Xva_t, "example unit -> example", "sum")
actual_summed_va = actual_summed_va[:n_recon_examples_va]

for inst in range(n_inst):
    r2 = r2_score(
        asnumpy(actual_summed_va.float()),
        asnumpy(recon_summed_va[:, inst].float()),
    )
    print(f"[VAL] SAE instance {inst} R² (summed spike count per example) = {r2:.3f}")


## Save/load activations

In [None]:
"""Load saved activations (Etr/Eval) if available; otherwise (only if save_activations=True) densify from top-k and save."""

# Behaviour: when load_activations=True and all three saved files exist, load them.
# Otherwise densify the top-k activations computed above, and save the result when
# save_activations=True.
load_activations = False
save_activations = True
dense_activations_file = "sae_dense_activations_and_targets.npz"
sparse_activations_file_tr = "sae_sparse_activations_train.csv"
sparse_activations_file_va = "sae_sparse_activations_val.csv"

# Build save path (same style as before)
session_dates = []
for session in sessions:
    session_date = datetime.fromtimestamp(session.session.recording_date).strftime("%Y%m%d")
    session_dates.append(session_date)
session_dates_str = "_".join(session_dates)

activations_save_path = data_path / f"{subject_name}_{session_dates_str}" / "sae_activations_all_sessions_w_smoothing"
activations_save_path.mkdir(parents=True, exist_ok=True)

# The directory always exists after mkdir, so test for the files themselves.
# Checking only the directory made load_activations=True crash on a first run.
saved_activations_available = all(
    (activations_save_path / fname).exists()
    for fname in (dense_activations_file, sparse_activations_file_tr, sparse_activations_file_va)
)
sparse_acts_dtypes = {"example_idx": int, "instance_idx": int, "feature_idx": int, "activation_value": np.float32}

if load_activations and saved_activations_available:
    # Load pre-saved activations/targets (cast to float32).
    # Only numeric arrays are stored, so pickle support stays off (np.load's default).
    # That avoids executing code from a tampered file; `with` closes the .npz handle.
    with np.load(activations_save_path / dense_activations_file) as data:
        Etr, Eval = data["Etr"].astype(np.float32), data["Eval"].astype(np.float32)
        ytr, yva  = data["ytr"].astype(np.float32), data["yva"].astype(np.float32)

    acts_df_tr = pd.read_csv(activations_save_path / sparse_activations_file_tr, dtype=sparse_acts_dtypes)
    acts_df_va = pd.read_csv(activations_save_path / sparse_activations_file_va, dtype=sparse_acts_dtypes)
    print(f"Loaded activations from {activations_save_path}")
else:
    if load_activations:
        print(f"Saved activations not found in {activations_save_path}; rebuilding from top-k activations.")
    # Build Etr/Eval from top-k activations (sparse → dense).
    # Rows of topk_acts_4d_*: [example_idx, instance_idx, feature_idx, activation_value].
    # The dense column for (instance, feature) is instance * dsae + feature.
    # Train
    arr_tr = asnumpy(topk_acts_4d_tr)
    # Sparse activations (tight dtypes on indices, fp32 values)
    acts_df_tr = pd.DataFrame({
        "example_idx":      arr_tr[:, 0].astype(int),
        "instance_idx":     arr_tr[:, 1].astype(int),
        "feature_idx":      arr_tr[:, 2].astype(int),
        "activation_value": arr_tr[:, 3].astype(np.float32),
    })
    # Dense activations: one row per evaluated example. Size by the evaluated-example
    # count, not only the largest index seen, so trailing examples with no recorded
    # activation keep their (all-zero) rows and stay aligned with their targets.
    N_tr = max(int(n_recon_examples_tr), int(arr_tr[:, 0].max()) + 1 if arr_tr.size else 0)
    Etr = np.zeros((N_tr, n_inst * dsae), dtype=np.float32)
    cols_tr = (arr_tr[:, 1].astype(np.int64) * int(dsae)) + arr_tr[:, 2].astype(np.int64)
    Etr[arr_tr[:, 0].astype(np.int64), cols_tr] = arr_tr[:, 3].astype(np.float32)

    # Val
    arr_va = asnumpy(topk_acts_4d_va)
    acts_df_va = pd.DataFrame({
        "example_idx":      arr_va[:, 0].astype(int),
        "instance_idx":     arr_va[:, 1].astype(int),
        "feature_idx":      arr_va[:, 2].astype(int),
        "activation_value": arr_va[:, 3].astype(np.float32),
    })
    N_va = max(int(n_recon_examples_va), int(arr_va[:, 0].max()) + 1 if arr_va.size else 0)
    Eval = np.zeros((N_va, n_inst * dsae), dtype=np.float32)
    cols_va = (arr_va[:, 1].astype(np.int64) * int(dsae)) + arr_va[:, 2].astype(np.int64)
    Eval[arr_va[:, 0].astype(np.int64), cols_va] = arr_va[:, 3].astype(np.float32)

    # Targets aligned to available rows (float32)
    ytr = velocity_train[:N_tr].astype(np.float32)
    yva = velocity_val[:N_va].astype(np.float32)
    # Slicing silently truncates, so make a row/target mismatch loud here.
    assert len(ytr) == N_tr and len(yva) == N_va, "More activation rows than velocity targets."

    # Prune zero/near-zero-variance features (judged on train only; already in fp32)
    std_tr = Etr.std(axis=0, ddof=0)
    keep = std_tr > 1e-6
    # # Optional: also require a few nonzeros to avoid “almost dead” columns
    # nnz_tr = (Etr != 0).sum(axis=0)
    # keep &= nnz_tr >= 5

    dropped = int((~keep).sum())
    if dropped:
        # Prune sparse activations (map dense column back to (instance, feature))
        keep_cols = np.nonzero(keep)[0]
        keep_inst = (keep_cols // dsae).astype(int)
        keep_feat = (keep_cols %  dsae).astype(int)
        _keep_pairs = pd.DataFrame({"instance_idx": keep_inst, "feature_idx": keep_feat})
        acts_df_tr = acts_df_tr.merge(_keep_pairs, on=["instance_idx", "feature_idx"], how="inner")
        acts_df_va = acts_df_va.merge(_keep_pairs, on=["instance_idx", "feature_idx"], how="inner")

        # Prune dense activations
        Etr = Etr[:, keep]
        Eval = Eval[:, keep]
        print(f"Pruned {dropped} / {keep.size} features. New D = {Etr.shape[1]}")

    if save_activations:
        np.savez(activations_save_path / dense_activations_file, Etr=Etr, Eval=Eval, ytr=ytr, yva=yva)
        acts_df_tr.to_csv(activations_save_path / sparse_activations_file_tr, index=False)
        acts_df_va.to_csv(activations_save_path / sparse_activations_file_va, index=False)
        print(f"Saved activations to {activations_save_path}")

print(f"Dense activations: \nEtr shape: {Etr.shape}, Eval shape: {Eval.shape}")
print(f"Sparse activations: \nTrain shape: {acts_df_tr.shape}, Val shape: {acts_df_va.shape}")


: 

## Decode

In [None]:
"""Check feature variance, sparsity, and conditioning"""

X = Etr  # or Eval
std = X.std(axis=0, ddof=0)   # per-feature standard deviation
nnz = (X != 0).sum(axis=0)    # per-feature count of rows where it is active

print("zero columns:", (std == 0).sum(), " / ", X.shape[1])
print("near-constant (std<1e-6):", (std < 1e-6).sum())
print("median nnz per column:", int(np.median(nnz)))

# Rough condition number using only the first r rows. X is already float32 (Etr/Eval
# are cast to float32 when built or loaded), so no upcast happens here.
# If r < number of features, the SVD only returns r singular values and ignores the
# null space, so this understates ill-conditioning in that case.
r = min(5000, X.shape[0])
svals = np.linalg.svd(X[:r], compute_uv=False)
cond = svals.max() / max(svals.min(), 1e-12)
print("approx cond:", cond)

In [None]:
"""Train on Etr, decode on Eval with a small lag sweep."""

def apply_lag(E_tr, E_va, y_tr, y_va, lag_bins: int):
    """Shift features against targets by `lag_bins` rows, for both train and val.

    lag_bins > 0 pairs features at row i with targets at row i + lag_bins;
    lag_bins < 0 pairs features at row i + |lag_bins| with targets at row i;
    lag_bins == 0 returns the inputs unchanged (same objects).

    Returns:
        (E_tr, E_va, y_tr, y_va) after shifting, in that order.
    """
    if lag_bins == 0:
        return E_tr, E_va, y_tr, y_va

    shift = abs(lag_bins)
    if lag_bins > 0:
        feat_rows, targ_rows = slice(None, -shift), slice(shift, None)
    else:
        feat_rows, targ_rows = slice(shift, None), slice(None, -shift)
    return E_tr[feat_rows], E_va[feat_rows], y_tr[targ_rows], y_va[targ_rows]

# Track the lag with the best mean validation R². The lag is chosen on the same
# validation set its R² is reported on, so the printed "Mean R²" is optimistically biased.
best = {"lag": 0, "r2_mean": -np.inf, "r2_per_dim": None, "decoder": None}

# Non-negative lags only: features at row i predict velocity at row i + lag.
# The shift is applied to the concatenated bins, so it crosses trial/session boundaries.
for lag in range(0, 6):
    Et, Ev, yt, yv = apply_lag(Etr, Eval, ytr, yva, lag)

    # Train on train, evaluate on val (standardise features, then ridge regression)
    decoder = make_pipeline(
        StandardScaler(with_mean=True, with_std=True),
        Ridge(alpha=30.0, solver="lsqr"),
    )
    decoder.fit(Et, yt)
    y_pred = decoder.predict(Ev)

    r2_per_dim = r2_score(yv, y_pred, multioutput='raw_values')
    r2_mean = float(np.mean(r2_per_dim))

    if r2_mean > best["r2_mean"]:
        best.update(lag=lag, r2_mean=r2_mean, r2_per_dim=r2_per_dim, decoder=decoder)

# After the loop, `decoder`/`y_pred` hold the LAST lag's model; use best['decoder'] for the best one.
print(f"Best lag (bins): {best['lag']}")
print(f"R² per dimension: {best['r2_per_dim']}")
print(f"Mean R²: {best['r2_mean']}")

## Hunt for features

In [None]:
# Pick whether to find features in training or validation set
search_set = "train" # or "train" or "val"

if search_set == "train":
    acts_df = acts_df_tr
    metadata_binned_subset = metadata_binned[train_mask].copy()
else:
    acts_df = acts_df_va
    metadata_binned_subset = metadata_binned[val_mask].copy()

In [None]:
"""Automatically map features to metadata"""

def analyze_discrete_variable(
    acts_df: pd.DataFrame,
    metadata_binned: pd.DataFrame,
    variable: str,
    min_activation_frac: float
) -> List[Dict]:
    """
    Score how selectively each SAE feature fires for each value of a discrete metadata column.

    For every non-null value ``v`` of ``metadata_binned[variable]``:
      * "event" examples are the row positions where the column equals ``v``;
      * a feature is kept if it is active in at least ``min_activation_frac`` of them;
      * its activation fraction over all other rows is the "non-event" fraction
        (0 when it never fires outside the event; 100% event-selective features are kept).

    Args:
        acts_df: sparse activations with columns ``example_idx`` (row position in
            ``metadata_binned``), ``instance_idx``, ``feature_idx``, ``activation_value``.
        metadata_binned: per-example metadata, row-aligned with ``example_idx``.
        variable: column of ``metadata_binned`` to analyse.
        min_activation_frac: minimum fraction of event examples in which a feature must fire.

    Returns:
        One dict per kept (value, instance, feature) with keys ``variable``,
        ``variable_type`` ('discrete'), ``value``, ``instance_idx`` and ``feature_idx``
        (Python ints), ``activation_ratio`` (event frac / (non-event frac + 1e-9)),
        ``activation_frac_during``, ``activation_frac_outside`` and ``rate_proportion``
        (event frac / (event frac + non-event frac)).

    An exception while processing one value is printed and that value is skipped.
    """
    key_cols = ["instance_idx", "feature_idx"]
    results = []
    column = metadata_binned[variable]
    n_examples = len(metadata_binned)
    unique_values = column.dropna().unique()

    for value in unique_values:
        try:
            event_idxs = np.where(column == value)[0]
            if len(event_idxs) == 0:
                continue

            # Membership mask computed once; its negation selects the non-event rows.
            in_event = acts_df["example_idx"].isin(event_idxs)
            event_acts_df = acts_df[in_event]
            if len(event_acts_df) == 0:
                continue

            event_features_df = event_acts_df.groupby(key_cols).agg(
                activation_count=("activation_value", "count")
            ).reset_index()
            n_event_examples = len(event_idxs)
            event_features_df["activation_frac_event"] = event_features_df["activation_count"] / n_event_examples

            promising_features = event_features_df[event_features_df["activation_frac_event"] >= min_activation_frac]
            if promising_features.empty:
                continue

            non_event_acts_df = acts_df[~in_event].merge(
                promising_features[key_cols], on=key_cols, how="inner"
            )

            if not non_event_acts_df.empty:
                non_event_features_df = non_event_acts_df.groupby(key_cols).agg(
                    activation_count=("activation_value", "count")
                ).reset_index()
                n_non_event_examples = n_examples - n_event_examples
                non_event_features_df["activation_frac_non_event"] = (
                    non_event_features_df["activation_count"] / n_non_event_examples
                )
                ratio_df = promising_features.merge(
                    non_event_features_df[key_cols + ["activation_frac_non_event"]],
                    on=key_cols, how="left",
                )
                # Features with no activity outside the event were left unmatched -> 0.
                ratio_df["activation_frac_non_event"] = ratio_df["activation_frac_non_event"].fillna(0.0)
            else:
                ratio_df = promising_features.copy()
                ratio_df["activation_frac_non_event"] = 0.0

            ratio_df["activation_ratio"] = ratio_df["activation_frac_event"] / (ratio_df["activation_frac_non_event"] + 1e-9)
            ratio_df["rate_proportion"] = ratio_df["activation_frac_event"] / (ratio_df["activation_frac_event"] + ratio_df["activation_frac_non_event"])

            # itertuples keeps each column's dtype. iterrows would build a float Series
            # per row and turn the integer instance/feature indices into floats.
            for row in ratio_df.itertuples(index=False):
                results.append({
                    'variable': variable, 'variable_type': 'discrete', 'value': value,
                    'instance_idx': int(row.instance_idx), 'feature_idx': int(row.feature_idx),
                    'activation_ratio': float(row.activation_ratio),
                    'activation_frac_during': float(row.activation_frac_event),
                    'activation_frac_outside': float(row.activation_frac_non_event),
                    'rate_proportion': float(row.rate_proportion)
                })
        except Exception as e:
            print(f"Could not analyze {variable}={value}: {e}")
            continue
    return results

def analyze_continuous_variable(
    acts_df: pd.DataFrame,
    metadata_binned: pd.DataFrame,
    variable: str,
    n_bins: int,
    min_activation_frac: float
) -> List[Dict]:
    """
    Discretise a continuous metadata column, then score features per bin.

    Side effect: the bin labels are written in place to a new column
    ``f"{variable}_binned"`` of ``metadata_binned``. ``movement_angle`` gets ``n_bins``
    fixed-width bins over [-180, 180]; every other variable gets quantile bins with
    duplicate edges dropped (so possibly fewer than ``n_bins``). The per-bin results
    come from ``analyze_discrete_variable`` on the binned column and are relabelled
    with the original variable name and ``variable_type='continuous'``. Returns an
    empty list when the column has no non-null values.
    """
    print(f"  Binning '{variable}' into {n_bins} bins...")
    binned_col_name = f"{variable}_binned"

    values = metadata_binned[variable].dropna()
    if values.empty:
        return []

    if variable != 'movement_angle':
        binned = pd.qcut(values, q=n_bins, labels=None, duplicates='drop')
    else:
        edges = np.linspace(-180, 180, n_bins + 1)
        edge_labels = [f"({lo:.0f}, {hi:.0f}]" for lo, hi in zip(edges[:-1], edges[1:])]
        binned = pd.cut(values, bins=edges, labels=edge_labels, include_lowest=True)
    metadata_binned[binned_col_name] = binned

    results = analyze_discrete_variable(acts_df, metadata_binned, binned_col_name, min_activation_frac)
    for res in results:
        res.update(variable=variable, variable_type='continuous')
    return results

def map_features_to_metadata(
    acts_df: pd.DataFrame,
    metadata_binned: pd.DataFrame,
    discrete_vars: List[str] = None,
    continuous_vars: List[str] = None,
    exclude_columns: List[str] = None,
    min_activation_frac: float = 0.1,
    n_bins_continuous: int = 10,
    top_n_features: int = 3
) -> pd.DataFrame:
    """
    Find, for every metadata condition, the SAE features most selective for it.

    Each column of ``metadata_binned`` not in ``exclude_columns`` (default
    ``['trial_idx', 'session']``) is analysed as discrete if listed in
    ``discrete_vars`` and as continuous (binned into ``n_bins_continuous`` bins) if
    listed in ``continuous_vars``; other columns are skipped. For each
    (variable, value, instance) the ``top_n_features`` rows with the highest
    ``activation_ratio`` are kept.

    Returns:
        One DataFrame holding both discrete and continuous results, sorted by
        variable_type, variable, value, instance_idx and descending activation_ratio
        (``value`` cast to str); an empty DataFrame if nothing qualified.
    """
    discrete_vars = [] if discrete_vars is None else discrete_vars
    continuous_vars = [] if continuous_vars is None else continuous_vars
    exclude_columns = ['trial_idx', 'session'] if exclude_columns is None else exclude_columns

    print("🚀 Starting automated feature-to-metadata mapping...")
    all_results = []
    # Snapshot the columns first: continuous analysis adds '<var>_binned' columns.
    candidate_vars = [col for col in metadata_binned.columns if col not in exclude_columns]

    for variable in candidate_vars:
        print(f"\nAnalyzing variable: {variable}")
        if variable in discrete_vars:
            print(f"  Treating as: discrete")
            found = analyze_discrete_variable(acts_df, metadata_binned, variable, min_activation_frac)
        elif variable in continuous_vars:
            print(f"  Treating as: continuous")
            found = analyze_continuous_variable(
                acts_df, metadata_binned, variable,
                n_bins=n_bins_continuous, min_activation_frac=min_activation_frac,
            )
        else:
            print(f"  Skipping (not in discrete_vars or continuous_vars list)")
            continue
        all_results.extend(found)
        print(f"  Found {len(found)} potential associations.")

    if not all_results:
        print("\nNo associations found meeting the minimum activation fraction!")
        return pd.DataFrame()

    results_df = pd.DataFrame(all_results).astype({'value': str})

    print(f"\nRanking features and selecting top {top_n_features} for each condition...")
    top_per_condition = (
        results_df
        .sort_values('activation_ratio', ascending=False)
        .groupby(['variable', 'value', 'instance_idx'])
        .head(top_n_features)
    )
    final_df = top_per_condition.sort_values(
        by=['variable_type', 'variable', 'value', 'instance_idx', 'activation_ratio'],
        ascending=[True, True, True, True, False],
    ).reset_index(drop=True)

    discrete_count = int((final_df['variable_type'] == 'discrete').sum())
    continuous_count = int((final_df['variable_type'] == 'continuous').sum())

    print(f"\n✅ Found {discrete_count} top discrete associations.")
    print(f"✅ Found {continuous_count} top continuous associations.")
    print(f"✅ Total: {len(final_df)} associations returned in single DataFrame.")

    return final_df

# Metadata columns to scan: discrete_vars are analysed value by value, continuous_vars
# are binned first. Fuller lists, kept for reference:
# discrete_vars = ['event', 'maze_condition', 'barriers', 'targets', 'hit_position_x', 'hit_position_y', 'hit_position_angle']
# continuous_vars = ['vel_magnitude', 'accel_magnitude', 'movement_angle']
discrete_vars = ['event', 'maze_condition']
continuous_vars = ['vel_magnitude', 'accel_magnitude']


# Per (variable, value, SAE instance), keep the 3 highest-activation_ratio features among
# those active in >= 50% of that condition's bins. Continuous analysis adds
# '<var>_binned' columns to metadata_binned_subset in place.
results = map_features_to_metadata(
    acts_df, metadata_binned_subset,
    discrete_vars=discrete_vars,
    continuous_vars=continuous_vars,
    min_activation_frac=0.5,
    n_bins_continuous=12,
    top_n_features=3
)
display(results)

# # Optional filtering (ratio > 2.0, proportion > 0.5)
# results = results[
#     (results['activation_ratio'] > 2.0) & 
#     (results['rate_proportion'] > 0.5)
# ].reset_index(drop=True)
# display(results)

In [None]:
"""Calculate z-scores for spike counts across neurons."""
# Per-neuron (per-column) mean and sample standard deviation.
neuron_means = spk_cts_df.mean(axis=0)
neuron_stds = spk_cts_df.std(axis=0)

# z = (x - mean) / std. Zero-std columns give NaN (0/0) or ±inf; those, and any
# other NaN, are mapped to a z-score of 0.
spk_z_scores_df = (
    spk_cts_df
    .sub(neuron_means, axis=1)
    .div(neuron_stds, axis=1)
    .replace([np.inf, -np.inf], np.nan)
    .fillna(0.0)
)

display(spk_z_scores_df)

In [None]:
"""Visualisation functions for event-feature associations"""

# make sure these only get created once: setdefault returns the existing dict when this
# cell is re-run, so cached timelines/warps survive re-execution.
# NOTE: the cache keys are built from the filter arguments only (not from the trials or
# activation data), so clear these dicts after changing that data.
_canonical_cache = globals().setdefault('_canonical_cache', {})
_warp_cache = globals().setdefault('_warp_cache', {})

def create_canonical_timeline(combined_trials_df, maze_conditions=None, hit_target_positions=None):
    """
    Create (or reuse) a canonical trial timeline from mean inter-event durations.

    Args:
        combined_trials_df: one row per trial with event-time columns 'start',
            'target_on_time', 'go_cue_time', 'move_begins_time', 'move_ends_time', 'end',
            plus 'maze_condition' / 'hit_target_position' used for filtering.
        maze_conditions: optional iterable of maze conditions to keep.
        hit_target_positions: optional iterable of target positions to keep; list/ndarray
            positions are converted to tuples so the filter is hashable and sortable.

    Returns:
        (canonical_events, filtered_trials): a dict mapping each event name to its time
        since 'start' on the canonical timeline (the log output labels these as seconds),
        and the complete trials used. (None, None) if no trial survives filtering.

    Results are memoised in the module-level ``_canonical_cache``, keyed by the normalised
    filters only (not by ``combined_trials_df``), so clear that cache if the trials change.
    """
    # --- normalize filters BEFORE keying ---
    if maze_conditions is not None:
        maze_conditions = tuple(sorted(maze_conditions))
    if hit_target_positions is not None:
        # Convert every position, not only when the first one is a non-tuple, so that a
        # mix of tuples and lists/arrays still sorts and hashes.
        hit_target_positions = tuple(sorted(
            tuple(pos) if isinstance(pos, (list, np.ndarray)) else pos
            for pos in hit_target_positions
        ))

    key = (maze_conditions, hit_target_positions)

    if key in _canonical_cache:
        print(f"Reusing cached canonical timeline for key={key}")
        return _canonical_cache[key]

    # --- filter trials (copy so the cached frame never aliases the caller's) ---
    filtered_trials = combined_trials_df.copy()
    if maze_conditions is not None:
        filtered_trials = filtered_trials[filtered_trials['maze_condition'].isin(maze_conditions)]
        print(f"Filtered to maze conditions: {maze_conditions}")
    if hit_target_positions is not None:
        filtered_trials = filtered_trials[filtered_trials['hit_target_position'].isin(hit_target_positions)]
        print(f"Filtered to target positions: {hit_target_positions}")

    print(f"Using {len(filtered_trials)} trials (from {len(combined_trials_df)} total) to create canonical timeline")
    if len(filtered_trials) == 0:
        print("No trials match the filtering criteria!")
        return None, None

    # --- compute canonical ---
    events_sequence = ['start', 'target_on_time', 'go_cue_time', 'move_begins_time', 'move_ends_time', 'end']
    event_pairs = list(zip(events_sequence[:-1], events_sequence[1:]))
    # Keep only trials in which every event was recorded.
    filtered_trials = filtered_trials.dropna(subset=events_sequence)
    if len(filtered_trials) == 0:
        print("No trials have all required events!")
        return None, None

    print(f"Computing durations from {len(filtered_trials)} complete trials")
    # Local names avoid shadowing the notebook's `td` (temporaldata) and `t` (torch) imports.
    durations = {}
    for e1, e2 in event_pairs:
        gap = filtered_trials[e2] - filtered_trials[e1]
        durations[f"{e1}_to_{e2}"] = gap.mean()
        print(f"  {e1} to {e2}: {gap.mean():.3f}s (±{gap.std():.3f})")

    canonical_events = {'start': 0.0}
    elapsed = 0.0
    for e1, e2 in event_pairs:
        elapsed += durations[f"{e1}_to_{e2}"]
        canonical_events[e2] = elapsed

    print(f"\nCanonical timeline: {canonical_events}")
    print(f"Total canonical duration: {elapsed:.3f}s")

    # cache & return
    _canonical_cache[key] = (canonical_events, filtered_trials)
    return canonical_events, filtered_trials

def warp_trials_to_canonical_timeline(combined_trials_df, acts_df, spk_z_scores_df, metadata_binned,
                                      instance_idx=0, feature_idx=None,
                                      maze_conditions=None, hit_target_positions=None):
    """
    Warp filtered trials onto a (cached) canonical timeline. Vectorised for speed.

    Every trial's bins are mapped onto canonical time by piecewise-linear
    interpolation between the trial's own event times and the canonical event
    times returned by ``create_canonical_timeline``. The warped values are then
    scattered onto an evenly spaced canonical grid (roughly 0.05 apart). When
    several source bins land in the same canonical bin, the maximum is kept.
    Canonical bins that receive no value are 0.

    Parameters
    ----------
    combined_trials_df : pd.DataFrame
        Trial table forwarded to ``create_canonical_timeline``. It must contain the
        event columns in ``required_events`` below.
    acts_df : pd.DataFrame
        Feature activations with columns ``instance_idx``, ``feature_idx``,
        ``example_idx`` and ``activation_value``. ``example_idx`` is treated as a
        positional row index into ``spk_z_scores_df`` / ``metadata_binned``.
    spk_z_scores_df : pd.DataFrame
        Unit z-scores, one column per unit. Rows must be aligned positionally
        with ``metadata_binned``.
    metadata_binned : pd.DataFrame
        Binned metadata. Its index holds bin timestamps in the same time base as
        the trial event columns.
    instance_idx : int
        SAE instance to analyse.
    feature_idx : int or None
        Feature to analyse. If None, the feature with the most activation rows
        in ``instance_idx`` is used.
    maze_conditions, hit_target_positions : sequence or None
        Optional trial filters forwarded to ``create_canonical_timeline``.

    Returns
    -------
    tuple
        ``(warped_data, feature_idx, top_unit)``.

        - ``warped_data`` is None if nothing could be warped. In that case
          ``feature_idx`` / ``top_unit`` may also be None.
        - Otherwise ``warped_data`` is a dict with keys ``'feature_activations'``
          and ``'unit_activations'`` (arrays of shape (n_trials, n_bins)),
          ``'trial_info'``, ``'canonical_time_axis'`` and ``'canonical_events'``.

    Results are memoised in the module-level ``_warp_cache``.
    """
    # 1) Canonical (already cached by create_canonical_timeline)
    canonical_events, filtered_trials = create_canonical_timeline(
        combined_trials_df, maze_conditions, hit_target_positions
    )
    if canonical_events is None:
        return None, None, None

    # 2) Resolve feature: default to the feature with the most activation rows in this instance
    if feature_idx is None:
        instance_acts = acts_df[acts_df['instance_idx'] == instance_idx]
        if len(instance_acts) == 0:
            print(f"No activations found for instance {instance_idx}")
            return None, None, None
        feature_idx = int(instance_acts['feature_idx'].value_counts().index[0])

    # 3) Warp cache key (filters normalised to hashable, order-independent tuples)
    mc_key = tuple(sorted(maze_conditions)) if maze_conditions is not None else None
    tp_key = hit_target_positions
    if tp_key is not None:
        if len(tp_key) > 0 and not isinstance(tp_key[0], tuple):
            tp_key = [tuple(pos) if isinstance(pos, (list, np.ndarray)) else pos for pos in tp_key]
        tp_key = tuple(sorted(tp_key))
    warp_key = (instance_idx, feature_idx, mc_key, tp_key)

    if warp_key in _warp_cache:
        print(f"Reusing cached warped data for key={warp_key}")
        return _warp_cache[warp_key]

    # 4) Slice activations once
    feature_acts = acts_df[
        (acts_df['instance_idx'] == instance_idx) &
        (acts_df['feature_idx'] == feature_idx)
    ][['example_idx', 'activation_value']].copy()

    # 5) Top unit = highest mean z-score over the bins where the feature is active
    #    (falls back to all bins if the feature never fires)
    if len(feature_acts) > 0:
        active_rows = feature_acts['example_idx'].values
        unit_mean_zscores = spk_z_scores_df.iloc[active_rows].mean(axis=0)
        top_unit = unit_mean_zscores.idxmax()
    else:
        top_unit = spk_z_scores_df.mean(axis=0).idxmax()

    # 6) Canonical time axis
    canonical_duration = float(max(canonical_events.values()))
    n_bins = max(int(canonical_duration / 0.05), 1)
    canonical_time_axis = np.linspace(0.0, canonical_duration, n_bins)
    step = canonical_time_axis[1] - canonical_time_axis[0] if n_bins > 1 else 1.0

    required_events = ['start', 'target_on_time', 'go_cue_time', 'move_begins_time', 'move_ends_time', 'end']
    # Canonical anchor times are identical for every trial, so build them once.
    cano_ev = np.array([canonical_events[ev] for ev in required_events], dtype=float)

    # Bin timestamps as a numpy array for fast masking
    mb_index = metadata_binned.index.values

    # 7) Scatter with duplicate indices -> keep max
    def scatter_max(dst_len, idx, vals):
        """Scatter ``vals`` at positions ``idx`` into a length-``dst_len`` vector.

        Colliding positions keep the largest value; NaNs are ignored via
        ``np.fmax``. Positions that receive no value are 0. Seeding with -inf
        (rather than 0) keeps genuinely negative values such as below-mean
        z-scores.
        """
        dst = np.full(dst_len, -np.inf, dtype=float)
        if idx.size:
            np.fmax.at(dst, np.clip(idx, 0, dst_len - 1), vals)
        dst[np.isneginf(dst)] = 0.0
        return dst

    # 8) Pre-sort feature acts by example_idx for fast per-trial slicing
    feature_acts_sorted = feature_acts.sort_values('example_idx', kind='mergesort')
    fa_idx = feature_acts_sorted['example_idx'].values
    fa_val = feature_acts_sorted['activation_value'].values

    warped_feature_acts = []
    warped_unit_acts = []
    trial_info_list = []
    n_out_of_order = 0

    print("Warping trials...")

    # 9) Iterate trials, but vectorise inside each trial
    for trial_idx, trial in filtered_trials.iterrows():
        # ensure required events present
        if any(pd.isna(trial[ev]) for ev in required_events):
            continue

        orig_ev = np.array([trial[ev] for ev in required_events], dtype=float)
        # np.interp requires increasing anchor points. It does not check this and
        # returns meaningless values otherwise, so skip out-of-order trials.
        if np.any(np.diff(orig_ev) < 0):
            n_out_of_order += 1
            continue

        # trial bin indices (positional)
        t0, t1 = float(trial['start']), float(trial['end'])
        trial_mask = (mb_index >= t0) & (mb_index <= t1)
        if not trial_mask.any():
            continue
        trial_bins = np.nonzero(trial_mask)[0]

        # warp all timestamps for this trial at once
        warped_times = np.interp(mb_index[trial_bins], orig_ev, cano_ev)
        time_idx = np.clip(np.rint(warped_times / step).astype(int), 0, n_bins - 1)

        # --- feature activations for these bins ---
        # fa_idx and trial_bins are both sorted -> narrow with searchsorted, then exact membership
        pos_lo = np.searchsorted(fa_idx, trial_bins[0], side='left')
        pos_hi = np.searchsorted(fa_idx, trial_bins[-1], side='right')
        fa_slice_idx = fa_idx[pos_lo:pos_hi]
        fa_slice_val = fa_val[pos_lo:pos_hi]
        mask_in = np.isin(fa_slice_idx, trial_bins, assume_unique=False)
        fa_slice_idx = fa_slice_idx[mask_in]
        fa_slice_val = fa_slice_val[mask_in]
        if fa_slice_idx.size:
            # map example_idx -> local position in trial_bins
            loc = np.clip(np.searchsorted(trial_bins, fa_slice_idx), 0, trial_bins.size - 1)
            feat_vec = scatter_max(n_bins, time_idx[loc], fa_slice_val.astype(float))
        else:
            feat_vec = np.zeros(n_bins, dtype=float)

        # --- unit z-scores for these bins ---
        unit_vals = spk_z_scores_df.iloc[trial_bins][top_unit].to_numpy(dtype=float, copy=False)
        unit_vec = scatter_max(n_bins, time_idx, unit_vals)

        warped_feature_acts.append(feat_vec)
        warped_unit_acts.append(unit_vec)

        tr = trial.copy()
        tr['trial_idx'] = trial_idx
        trial_info_list.append(tr)

    valid_trials = len(warped_feature_acts)
    if n_out_of_order:
        print(f"Skipped {n_out_of_order} trials with out-of-order event times")
    print(f"Successfully warped {valid_trials} trials")
    if valid_trials == 0:
        return None, feature_idx, top_unit

    warped_data = {
        'feature_activations': np.stack(warped_feature_acts, axis=0),
        'unit_activations': np.stack(warped_unit_acts, axis=0),
        'trial_info': pd.DataFrame(trial_info_list).reset_index(drop=True),
        'canonical_time_axis': canonical_time_axis,
        'canonical_events': canonical_events
    }

    _warp_cache[warp_key] = (warped_data, feature_idx, top_unit)
    return warped_data, feature_idx, top_unit

def plot_warped_trials(warped_data, instance_idx, feature_idx, top_unit, top_unit_id,
                      highlight_trials=None, max_individual_trials=10,
                      smooth_window=None, show_event_regions=True, random_state=None):
    """
    Plot trial-warped feature and unit activations on canonical timeline

    The feature activation (left y-axis) and the top unit's z-score (right
    y-axis) are drawn as across-trial mean ± SEM. A few individual trials are
    overlaid, and a dashed vertical line marks every canonical event.

    Parameters:
    - warped_data: output from warp_trials_to_canonical_timeline (may be None)
    - instance_idx: SAE instance index being analyzed
    - feature_idx: feature index being plotted
    - top_unit: top unit index
    - top_unit_id: location of top unit (PMd or M1)
    - highlight_trials: list of trial indices to highlight, or None for random selection
    - max_individual_trials: maximum number of individual trials to show
    - smooth_window: optional uniform smoothing window size in bins (along time)
    - show_event_regions: whether to show colored regions for each epoch
    - random_state: optional seed for the random trial selection. None keeps
      using NumPy's global random state.

    Returns:
    - the plotly Figure, or None when warped_data is None
    """
    if warped_data is None:
        print("No warped data available")
        return

    feature_acts = warped_data['feature_activations']
    unit_acts = warped_data['unit_activations']
    trial_info = warped_data['trial_info']
    time_axis = warped_data['canonical_time_axis']
    canonical_events = warped_data['canonical_events']
    n_trials = len(feature_acts)

    print(f"Plotting {n_trials} warped trials")

    # Apply smoothing if requested
    if smooth_window is not None and smooth_window > 1:
        feature_acts = uniform_filter1d(feature_acts, size=smooth_window, axis=1)
        unit_acts = uniform_filter1d(unit_acts, size=smooth_window, axis=1)

    def _mean_sem(acts):
        """Across-trial mean and standard error of the mean.

        Uses the sample std (ddof=1), matching pandas ``sem``. A single trial
        gets zero SEM.
        """
        mean = np.mean(acts, axis=0)
        if len(acts) > 1:
            sem = np.std(acts, axis=0, ddof=1) / np.sqrt(len(acts))
        else:
            sem = np.zeros_like(mean, dtype=float)
        return mean, sem

    feature_mean, feature_sem = _mean_sem(feature_acts)
    unit_mean, unit_sem = _mean_sem(unit_acts)

    # Create single plot with two y-axes
    fig = go.Figure()

    # Add background regions for different epochs if requested
    if show_event_regions:
        event_times = list(canonical_events.values())
        colors = ['rgba(255,200,200,0.2)', 'rgba(200,255,200,0.2)', 'rgba(200,200,255,0.2)',
                  'rgba(255,255,200,0.2)', 'rgba(255,200,255,0.2)']
        for i in range(len(event_times) - 1):
            fig.add_vrect(x0=event_times[i], x1=event_times[i+1], fillcolor=colors[i % len(colors)], layer="below", line_width=0)

    # Feature SEM band: upper bound first, then the lower bound filled back to it ('tonexty')
    fig.add_trace(go.Scatter(x=time_axis, y=feature_mean + feature_sem, mode='lines', line=dict(width=0), showlegend=False, hoverinfo='skip', name='upper_bound'))
    fig.add_trace(go.Scatter(x=time_axis, y=feature_mean - feature_sem, mode='lines', line=dict(width=0), fill='tonexty', fillcolor='rgba(0,100,80,0.3)', name='Feature ±SEM', showlegend=True))

    # Mean feature line
    fig.add_trace(go.Scatter(x=time_axis, y=feature_mean, mode='lines', line=dict(color='darkgreen', width=4), name=f'Mean Feature {feature_idx}', yaxis='y'))

    # Choose which individual trials to overlay
    if highlight_trials is not None:
        trial_indices = [idx for idx in highlight_trials if idx < n_trials][:max_individual_trials]
        trial_label = "Selected"
    else:
        if n_trials <= max_individual_trials:
            trial_indices = list(range(n_trials))
        else:
            rng = np.random if random_state is None else np.random.default_rng(random_state)
            trial_indices = rng.choice(n_trials, max_individual_trials, replace=False)
        trial_label = "Random"

    def _trial_details(pos):
        """Hover text describing trial ``pos`` by whichever condition columns exist."""
        row = trial_info.iloc[pos]
        parts = []
        if 'maze_condition' in trial_info.columns:
            parts.append(f"Maze: {row['maze_condition']}")
        if 'hit_target_position' in trial_info.columns:
            parts.append(f"Target: {row['hit_target_position']}")
        return ", ".join(parts)

    # Shared by the feature and unit overlays
    trial_details = [_trial_details(pos) for pos in trial_indices]

    # Individual feature trials
    for i, (trial_idx, details) in enumerate(zip(trial_indices, trial_details)):
        fig.add_trace(go.Scatter(x=time_axis, y=feature_acts[trial_idx], mode='lines', line=dict(color='rgba(0,150,100,0.5)', width=1), name=f'{trial_label} Feature Trials' if i == 0 else None, showlegend=i == 0, legendgroup='individual_feature_trials', hovertemplate=f'Trial {trial_idx}<br>{details}<br>Time: %{{x}}<br>Feature: %{{y}}<extra></extra>', yaxis='y'))

    # Unit SEM band on the secondary y-axis
    fig.add_trace(go.Scatter(x=time_axis, y=unit_mean + unit_sem, mode='lines', line=dict(width=0), showlegend=False, hoverinfo='skip', yaxis='y2'))
    fig.add_trace(go.Scatter(x=time_axis, y=unit_mean - unit_sem, mode='lines', line=dict(width=0), fill='tonexty', fillcolor='rgba(0,80,150,0.3)', name='Unit Z-score ±SEM', showlegend=True, yaxis='y2'))

    # Mean unit line
    fig.add_trace(go.Scatter(x=time_axis, y=unit_mean, mode='lines', line=dict(color='darkblue', width=4), name=f'Mean Unit {top_unit} Z-score', yaxis='y2'))

    # Individual unit trials
    for i, (trial_idx, details) in enumerate(zip(trial_indices, trial_details)):
        fig.add_trace(go.Scatter(x=time_axis, y=unit_acts[trial_idx], mode='lines', line=dict(color='rgba(100,100,255,0.5)', width=1), name=f'{trial_label} Unit Trials' if i == 0 else None, showlegend=i == 0, legendgroup='individual_unit_trials', hovertemplate=f'Trial {trial_idx}<br>{details}<br>Time: %{{x}}<br>Unit Z-score: %{{y}}<extra></extra>', yaxis='y2'))

    # Add event lines
    for event_name, event_time in canonical_events.items():
        fig.add_vline(x=event_time, line_dash="dash", line_color="red", line_width=2, annotation_text=event_name.replace('_', ' ').title(), annotation_position="top")

    # Update layout with dual y-axes
    fig.update_layout(
        title=f"Instance {instance_idx} Feature {feature_idx} & Top Unit {top_unit} ({top_unit_id})",
        xaxis_title="Canonical Time (s)",
        yaxis=dict(title="Feature Activation", side="left", color="darkgreen"),
        yaxis2=dict(title="Unit Z-score", side="right", overlaying="y", color="darkblue"),
        height=600,
        width=1400,
        legend=dict(x=0.02, y=0.98, bgcolor="rgba(255,255,255,0.8)")
    )

    fig.show()

    return fig

def explore_trial_conditions(combined_trials_df):
    """Print how many trials exist per maze condition and per hit-target position.

    Use this to choose values for the ``maze_conditions`` /
    ``hit_target_positions`` filters.

    Parameters
    ----------
    combined_trials_df : pd.DataFrame
        Trial table with ``maze_condition`` and ``hit_target_position`` columns.

    Returns
    -------
    tuple of pd.Series
        ``(maze_counts, target_counts)``: maze-condition counts sorted by
        condition, and hit-target-position counts in ``value_counts`` order.
    """
    def _list_counts(counts):
        # One indented line per category
        for category, n_trials in counts.items():
            print(f"  {category}: {n_trials} trials")

    print("\nAvailable Trial Conditions")

    print("\nMaze Conditions:")
    maze_counts = combined_trials_df['maze_condition'].value_counts().sort_index()
    _list_counts(maze_counts)

    print("\nHit Target Positions:")
    target_counts = combined_trials_df['hit_target_position'].value_counts()
    _list_counts(target_counts)

    print(f"\nTotal trials: {len(combined_trials_df)}")

    return maze_counts, target_counts

# Example usage: warp one SAE feature's trials onto the canonical timeline and plot them
print("Trial Warping Analysis")
# First, explore what conditions are available
print("Exploring available trial conditions...")
maze_counts, target_counts = explore_trial_conditions(combined_trials_df)

# Parameters to customize
instance_to_analyze = 0
feature_to_analyze = 183
# FILTERING OPTIONS - Set these to filter trials (None = no filtering)
maze_conditions_to_include = None
target_positions_to_include = None
# OTHER PARAMETERS
num_example_trials = 8   # individual trials overlaid on the mean
smooth_data = 3          # uniform smoothing window, in canonical bins
show_epochs = True       # shade the epochs between canonical events

print("\nRunning analysis with filters:")
print(f"  Maze conditions: {maze_conditions_to_include}")
print(f"  Target positions: {target_positions_to_include}")

# Run warping analysis
warped_data, feature_idx, top_unit = warp_trials_to_canonical_timeline(
    combined_trials_df, acts_df, spk_z_scores_df, metadata_binned_subset,
    instance_idx=instance_to_analyze,
    feature_idx=feature_to_analyze,
    maze_conditions=maze_conditions_to_include,
    hit_target_positions=target_positions_to_include
)

# Translate the top unit into an area label. Warping returns top_unit=None when
# no canonical timeline could be built, so only look it up when a unit exists.
top_unit_id = None
if top_unit is not None:
    # NOTE(review): assumes spk_z_scores_df column labels index session.units.id — confirm
    top_unit_id = session.units.id[top_unit]
    if isinstance(top_unit_id, (bytes, bytearray)):
        top_unit_id = top_unit_id.decode("utf-8")
    if "group_1" in top_unit_id:
        top_unit_id = "PMd"
    elif "group_2" in top_unit_id:
        top_unit_id = "M1"

if warped_data is not None:
    print(f"\nResults")
    print(f"Analyzed Feature: {feature_idx}")
    print(f"Top Co-active Unit: {top_unit} ({top_unit_id})")

    # Create the warped trial plot
    plot_warped_trials(
        warped_data, instance_to_analyze, feature_idx, top_unit, top_unit_id,
        highlight_trials=None,
        max_individual_trials=num_example_trials,
        smooth_window=smooth_data,
        show_event_regions=show_epochs
    )
else:
    print("Failed to warp trials - check your parameters and data.")

In [None]:
"""Interactive UI"""

# Helper: translate a (variable, type) pair from the results table into the
# column name it has in metadata_binned_subset.
def _bvar_name(var_name, var_type):
    """Return the metadata column for a variable.

    Continuous variables map to ``<name>_binned``; every other type keeps its
    own name.
    """
    if var_type == 'continuous':
        return f"{var_name}_binned"
    return var_name

# Mode selector: 'preset' takes (instance, feature, variable) from a row of the
# `results` table; 'manual' lets the user pick instance and feature directly.
mode_radio = widgets.RadioButtons(
    options=[
        ('Preset (from results table)', 'preset'),
        ('Manual selection',       'manual')
    ],
    value='preset',
    description=''
)

# Preset options: one per `results` row whose binned variable exists in
# metadata_binned_subset, labelled with that row's association metrics.
preset_entries = []
for _, row in results.iterrows():
    binned_var = _bvar_name(row.variable, row.variable_type)
    if binned_var in metadata_binned_subset.columns:
        option_label = " | ".join([
            f"Inst:{int(row.instance_idx)}",
            f"Feat:{int(row.feature_idx)}",
            f"Var:{binned_var}",
            f"Val:{row['value']}",
            f"FracDuring:{row.activation_frac_during:.3f}",
            f"FracOutside:{row.activation_frac_outside:.3f}",
            f"ActRatio:{row.activation_ratio:.3f}",
            f"RateProp:{row.rate_proportion:.3f}",
        ])
        option_value = (int(row.instance_idx), int(row.feature_idx), binned_var)
        preset_entries.append((option_label, option_value))

preset_dropdown = widgets.Dropdown(
    options=preset_entries,
    description='Select Result:',
    layout=widgets.Layout(width='80%')
)
preset_box = widgets.VBox([preset_dropdown])

# Manual selection: choose an instance; the feature list for that instance is
# filled in by _on_instance_change.
instance_dropdown = widgets.Dropdown(
    options=sorted(acts_df['instance_idx'].unique()),
    description='Instance:'
)
feature_dropdown = widgets.Dropdown(
    options=[],
    description='Feature:'
)

def _on_instance_change(change):
    """Repopulate the feature dropdown with the features that have activations in the newly selected instance."""
    selected_instance = change['new']
    in_instance = acts_df['instance_idx'] == selected_instance
    feature_dropdown.options = sorted(acts_df.loc[in_instance, 'feature_idx'].unique())

# Keep the feature list in sync with the chosen instance, and fill it once now.
instance_dropdown.observe(_on_instance_change, names='value')
_on_instance_change(dict(new=instance_dropdown.value))

# Manual panel starts hidden because the default mode is 'preset'.
manual_box = widgets.VBox([instance_dropdown, feature_dropdown])
manual_box.layout.display = 'none'

# Optional trial filters; the leading None option means "do not filter".
maze_options = sorted(combined_trials_df['maze_condition'].dropna().unique())
maze_dropdown = widgets.Dropdown(
    options=[None, *maze_options],
    description='Maze cond:'
)

# Target positions are offered by their str() label.
target_positions = combined_trials_df['hit_target_position'].dropna().unique()
target_strs = list(map(str, target_positions))
target_dropdown = widgets.Dropdown(
    options=[None, *target_strs],
    description='Target pos:'
)

# Show exactly one selector panel depending on the chosen mode
def _on_mode_change(change):
    """Display the preset panel for mode 'preset' and the manual panel for any other mode."""
    show_preset = change['new'] == 'preset'
    preset_box.layout.display = '' if show_preset else 'none'
    manual_box.layout.display = 'none' if show_preset else ''

# React to mode switches, and apply the initial mode once so only one panel shows.
mode_radio.observe(_on_mode_change, names='value')
_on_mode_change(dict(new=mode_radio.value))

# The button that triggers plotting, and the output area the callback renders into.
generate_btn = widgets.Button(description='Generate Plot', button_style='info')
out = widgets.Output()

def _on_generate(_):
    """Button callback: warp trials for the selected instance/feature and plot them into ``out``.

    In 'preset' mode the (instance, feature) pair comes from the selected
    results row; otherwise it comes from the manual dropdowns. Each optional
    maze/target filter is passed on as a one-element list, or None when unset.
    """
    with out:
        clear_output()
        if mode_radio.value == 'preset':
            # An empty results table leaves the dropdown without a value
            if preset_dropdown.value is None:
                print("No preset available. Switch to manual selection.")
                return
            inst, feat, _ = preset_dropdown.value
        else:
            inst = instance_dropdown.value
            feat = feature_dropdown.value

        maze = [maze_dropdown.value] if maze_dropdown.value is not None else None

        # The target dropdown shows str(position). Map the label back to the
        # original value instead of eval-ing it, so the filter compares equal.
        tgt = target_dropdown.value
        hit_positions = None
        if tgt is not None:
            label_to_position = {str(pos): pos for pos in target_positions}
            hit_positions = [label_to_position[tgt]]

        warped_data, used_feat, top_unit = warp_trials_to_canonical_timeline(
            combined_trials_df,
            acts_df,
            spk_z_scores_df,
            metadata_binned_subset,
            instance_idx=inst,
            feature_idx=feat,
            maze_conditions=maze,
            hit_target_positions=hit_positions
        )
        # Bail out before looking up the unit: top_unit may be None here
        if warped_data is None:
            print("No data to display. Check your selections.")
            return

        # NOTE(review): assumes spk_z_scores_df column labels index session.units.id — confirm
        top_unit_id = session.units.id[top_unit]
        if isinstance(top_unit_id, (bytes, bytearray)):
            top_unit_id = top_unit_id.decode("utf-8")
        if "group_1" in top_unit_id:
            top_unit_id = "PMd"
        elif "group_2" in top_unit_id:
            top_unit_id = "M1"

        plot_warped_trials(warped_data, inst, used_feat, top_unit, top_unit_id)

# Wire the button to the plotting callback.
generate_btn.on_click(_on_generate)

# Stack the UI vertically: title, mode switch, both selector panels (only one
# visible at a time), trial filters, the button and the output area.
ui = widgets.VBox(children=[
    widgets.HTML("<h2>Warp Trials Visualization</h2>"),
    mode_radio,
    preset_box,
    manual_box,
    maze_dropdown,
    target_dropdown,
    generate_btn,
    out,
])
display(ui)

In [None]:
"""Visualisation functions for all feature associations"""

def plot_feature_tuning(
    acts_df: pd.DataFrame,
    spk_z_scores_df: pd.DataFrame,
    metadata_binned: pd.DataFrame,
    variable: str,
    instance_idx: int,
    feature_idx: int
):
    """Visualizes SAE feature tuning to metadata variables."""
    # Get feature activations
    feature_acts = acts_df[(acts_df['instance_idx'] == instance_idx) & (acts_df['feature_idx'] == feature_idx)]
    
    # Find top/bottom co-active units
    if len(feature_acts) > 0:
        feature_active_indices = feature_acts['example_idx'].values
        neuron_mean_zscores = spk_z_scores_df.iloc[feature_active_indices].mean(axis=0)
        top_unit = neuron_mean_zscores.idxmax()
        bottom_unit = neuron_mean_zscores.idxmin()
        top_zscore = neuron_mean_zscores[top_unit]
        bottom_zscore = neuron_mean_zscores[bottom_unit]
        print(f"Top co-active unit: {top_unit} (z-score: {top_zscore:.3f})")
        print(f"Bottom co-active unit: {bottom_unit} (z-score: {bottom_zscore:.3f})")
    else:
        print("No feature activations found for this instance/feature.")
        return
    
    # Create complete dataset with zeros for inactive features
    all_examples = pd.DataFrame({'example_idx': range(len(metadata_binned))})
    all_examples = all_examples.merge(feature_acts[['example_idx', 'activation_value']], on='example_idx', how='left').fillna(0)
    
    if all_examples.empty:
        print(f"⚠️ No examples found for Instance {instance_idx}, Feature {feature_idx}. Cannot generate plot.")
        return
    
    # Get metadata and z-scores for all examples
    metadata_slice = metadata_binned[[variable]].iloc[all_examples['example_idx']]
    top_unit_slice = spk_z_scores_df[[top_unit]].iloc[all_examples['example_idx']]
    bottom_unit_slice = spk_z_scores_df[[bottom_unit]].iloc[all_examples['example_idx']]
    
    # Create plotting dataframe
    data_df = metadata_slice.reset_index(drop=True)
    data_df['activation_value'] = all_examples['activation_value'].reset_index(drop=True)
    data_df['top_unit_zscore'] = top_unit_slice[top_unit].reset_index(drop=True)
    data_df['bottom_unit_zscore'] = bottom_unit_slice[bottom_unit].reset_index(drop=True)
    data_df = data_df.dropna(subset=[variable])
    
    if data_df.empty:
        print(f"⚠️ No matching metadata found for feature bins. Cannot generate plot.")
        return
    
    # Check if data is interval type
    try:
        is_interval_data = pd.api.types.is_interval_dtype(data_df[variable].cat.categories)
    except AttributeError:
        is_interval_data = False
    
    # Calculate summary statistics
    stats_df = data_df.groupby(variable).agg({
        'activation_value': ['mean', 'sem'],
        'top_unit_zscore': ['mean', 'sem'],
        'bottom_unit_zscore': ['mean', 'sem']
    }).reset_index()
    
    # Flatten column names
    stats_df.columns = [variable, 'feature_mean', 'feature_sem', 'top_unit_mean', 'top_unit_sem', 'bottom_unit_mean', 'bottom_unit_sem']
    
    # Calculate rate proportions
    if not feature_acts.empty:
        condition_masks = {condition: metadata_binned[variable] == condition for condition in stats_df[variable]}
        active_example_set = set(feature_acts['example_idx'])
        
        rate_props = []
        for _, row in stats_df.iterrows():
            condition = row[variable]
            condition_mask = condition_masks[condition]
            
            condition_example_idxs = np.where(condition_mask)[0]
            condition_activations = len(active_example_set.intersection(condition_example_idxs))
            activation_frac_during = condition_activations / len(condition_example_idxs) if len(condition_example_idxs) > 0 else 0
            
            non_condition_example_idxs = np.where(~condition_mask)[0]
            non_condition_activations = len(active_example_set.intersection(non_condition_example_idxs))
            activation_frac_outside = non_condition_activations / len(non_condition_example_idxs) if len(non_condition_example_idxs) > 0 else 0
            
            rate_proportion = activation_frac_during / (activation_frac_during + activation_frac_outside) if (activation_frac_during + activation_frac_outside) > 0 else 0
            rate_props.append(rate_proportion)
        
        stats_df['rate_proportion'] = rate_props
    else:
        stats_df['rate_proportion'] = 0
    
    # Calculate z-score stats for bar plot
    if len(feature_acts) > 0:
        zscore_stats = spk_z_scores_df.iloc[feature_active_indices].agg(['mean', 'sem']).T
        zscore_stats.columns = ['mean_zscore', 'sem_zscore']
        zscore_stats = zscore_stats.reset_index()
        zscore_stats.columns = ['neuron', 'mean_zscore', 'sem_zscore']
    else:
        zscore_stats = pd.DataFrame({'neuron': spk_z_scores_df.columns, 'mean_zscore': 0, 'sem_zscore': 0})
    
    # Create plots based on variable type
    if 'angle' in variable:  # Polar plot
        stats_df['theta'] = stats_df[variable].apply(lambda x: x.mid if isinstance(x, pd.Interval) else x)
        stats_df = stats_df.sort_values('theta')
        plot_df = pd.concat([stats_df, stats_df.head(1)], ignore_index=True)
        
        fig = make_subplots(
            rows=2, cols=4,
            specs=[[{"type": "polar"}, {"type": "polar"}, {"type": "polar"}, {"type": "polar"}],
                   [{"type": "xy", "colspan": 4}, None, None, None]],
            horizontal_spacing=0.1,
            vertical_spacing=0.15,
            subplot_titles=["Feature Activation", "Top Unit Z-score", "Bottom Unit Z-score", "Rate Proportion", "Mean Z-scores when Feature Active"]
        )
        
        # Feature activation polar plot
        fig.add_trace(go.Scatterpolar(r=plot_df['feature_mean'] + plot_df['feature_sem'], theta=plot_df['theta'], mode='lines', line=dict(width=0), showlegend=False), row=1, col=1)
        fig.add_trace(go.Scatterpolar(r=plot_df['feature_mean'] - plot_df['feature_sem'], theta=plot_df['theta'], mode='lines', line=dict(width=0), fill='tonext', fillcolor='rgba(220,20,60,0.2)', name='Feature ±SEM'), row=1, col=1)
        fig.add_trace(go.Scatterpolar(r=plot_df['feature_mean'], theta=plot_df['theta'], mode='lines+markers', line=dict(color='crimson', width=3), name='Feature Activation'), row=1, col=1)
        
        # Top unit polar plot
        fig.add_trace(go.Scatterpolar(r=plot_df['top_unit_mean'] + plot_df['top_unit_sem'], theta=plot_df['theta'], mode='lines', line=dict(width=0), showlegend=False), row=1, col=2)
        fig.add_trace(go.Scatterpolar(r=plot_df['top_unit_mean'] - plot_df['top_unit_sem'], theta=plot_df['theta'], mode='lines', line=dict(width=0), fill='tonext', fillcolor='rgba(0,0,139,0.2)', name='Top Unit ±SEM'), row=1, col=2)
        fig.add_trace(go.Scatterpolar(r=plot_df['top_unit_mean'], theta=plot_df['theta'], mode='lines+markers', line=dict(color='darkblue', width=3), name='Top Unit Z-score'), row=1, col=2)
        
        # Bottom unit polar plot
        fig.add_trace(go.Scatterpolar(r=plot_df['bottom_unit_mean'] + plot_df['bottom_unit_sem'], theta=plot_df['theta'], mode='lines', line=dict(width=0), showlegend=False), row=1, col=3)
        fig.add_trace(go.Scatterpolar(r=plot_df['bottom_unit_mean'] - plot_df['bottom_unit_sem'], theta=plot_df['theta'], mode='lines', line=dict(width=0), fill='tonext', fillcolor='rgba(255,165,0,0.2)', name='Bottom Unit ±SEM'), row=1, col=3)
        fig.add_trace(go.Scatterpolar(r=plot_df['bottom_unit_mean'], theta=plot_df['theta'], mode='lines+markers', line=dict(color='orange', width=3), name='Bottom Unit Z-score'), row=1, col=3)
        
        # Rate proportion polar plot
        fig.add_trace(go.Scatterpolar(r=plot_df['rate_proportion'], theta=plot_df['theta'], mode='lines+markers', line=dict(color='green', width=3), name='Rate Proportion'), row=1, col=4)
        
        # Z-score bar plot
        fig.add_trace(go.Bar(x=zscore_stats['neuron'], y=zscore_stats['mean_zscore'], error_y=dict(type='data', array=zscore_stats['sem_zscore']), marker_color='purple', marker_line_width=0, opacity=0.7, name='Mean Z-score'), row=2, col=1)
        
        # Create tick labels
        if 'movement_angle' not in variable:
            tick_labels = [f"{int(round(theta))}°" for theta in stats_df['theta']]
        else:
            tick_labels = [f"{int(interval.mid)}°" if hasattr(interval, 'mid') else str(interval) for interval in stats_df[variable]]    
        
        fig.update_layout(title=f"Instance {instance_idx} Feature {feature_idx} & Top Unit {top_unit} & Bottom Unit {bottom_unit}", showlegend=True, height=800, width=1600, margin=dict(t=80, b=60, l=50, r=50))
        
        # Update polar plots
        rotation = 195 if 'movement_angle' in variable else 0
        tickvals = stats_df['theta'].tolist() if 'movement_angle' in variable else [(angle % 360) for angle in stats_df['theta'].tolist()]
        for i in range(1, 5):
            polar_key = f'polar{i if i > 1 else ""}'
            fig.update_layout(**{polar_key: dict(angularaxis=dict(direction="counterclockwise", rotation=rotation, tickvals=tickvals, ticktext=tick_labels, tickfont=dict(size=10)), radialaxis=dict(range=[0, None]))})

        fig.update_xaxes(title_text="Neuron", row=2, col=1)
        fig.update_yaxes(title_text="Mean Z-score", row=2, col=1)
            
    else:  # Linear plot
        fig = make_subplots(rows=2, cols=4, specs=[[{}, {}, {}, {}], [{"colspan": 4}, None, None, None]], subplot_titles=["Feature Activation", "Top Unit Z-score", "Bottom Unit Z-score", "Rate Proportion", "Mean Z-scores when Feature Active"], horizontal_spacing=0.1, vertical_spacing=0.3)
        
        # Setup x-axis labels
        if is_interval_data:
            x_axis_labels = stats_df[variable].apply(lambda x: str(x))
            stats_df = stats_df.sort_values(by=variable)
        else:
            x_axis_labels = stats_df[variable].astype(str)
        
        # Plot based on data type
        if is_interval_data:
            # Line plots for continuous data
            fig.add_trace(go.Scatter(x=x_axis_labels, y=stats_df['feature_mean'] + stats_df['feature_sem'], mode='lines', line_color='rgba(0,0,0,0)', showlegend=False), row=1, col=1)
            fig.add_trace(go.Scatter(x=x_axis_labels, y=stats_df['feature_mean'] - stats_df['feature_sem'], mode='lines', line_color='rgba(0,0,0,0)', fill='tonexty', fillcolor='rgba(220,20,60,0.2)', name='Feature ±SEM'), row=1, col=1)
            fig.add_trace(go.Scatter(x=x_axis_labels, y=stats_df['feature_mean'], mode='lines+markers', line_color='crimson', name='Feature Activation'), row=1, col=1)
            
            fig.add_trace(go.Scatter(x=x_axis_labels, y=stats_df['top_unit_mean'] + stats_df['top_unit_sem'], mode='lines', line_color='rgba(0,0,0,0)', showlegend=False), row=1, col=2)
            fig.add_trace(go.Scatter(x=x_axis_labels, y=stats_df['top_unit_mean'] - stats_df['top_unit_sem'], mode='lines', line_color='rgba(0,0,0,0)', fill='tonexty', fillcolor='rgba(0,0,139,0.2)', name='Top Unit ±SEM'), row=1, col=2)
            fig.add_trace(go.Scatter(x=x_axis_labels, y=stats_df['top_unit_mean'], mode='lines+markers', line_color='darkblue', name='Top Unit Z-score'), row=1, col=2)
            
            fig.add_trace(go.Scatter(x=x_axis_labels, y=stats_df['bottom_unit_mean'] + stats_df['bottom_unit_sem'], mode='lines', line_color='rgba(0,0,0,0)', showlegend=False), row=1, col=3)
            fig.add_trace(go.Scatter(x=x_axis_labels, y=stats_df['bottom_unit_mean'] - stats_df['bottom_unit_sem'], mode='lines', line_color='rgba(0,0,0,0)', fill='tonexty', fillcolor='rgba(255,165,0,0.2)', name='Bottom Unit ±SEM'), row=1, col=3)
            fig.add_trace(go.Scatter(x=x_axis_labels, y=stats_df['bottom_unit_mean'], mode='lines+markers', line_color='orange', name='Bottom Unit Z-score'), row=1, col=3)
            
            fig.add_trace(go.Scatter(x=x_axis_labels, y=stats_df['rate_proportion'], mode='lines+markers', line=dict(color='green', width=3), name='Rate Proportion'), row=1, col=4)
        else:
            # Bar plots for categorical data
            fig.add_trace(go.Bar(x=x_axis_labels, y=stats_df['feature_mean'], error_y=dict(type='data', array=stats_df['feature_sem']), marker_color='crimson', marker_line_width=0, opacity=0.7, name='Feature Activation'), row=1, col=1)
            fig.add_trace(go.Bar(x=x_axis_labels, y=stats_df['top_unit_mean'], error_y=dict(type='data', array=stats_df['top_unit_sem']), marker_color='darkblue', marker_line_width=0, opacity=0.7, name='Top Unit Z-score'), row=1, col=2)
            fig.add_trace(go.Bar(x=x_axis_labels, y=stats_df['bottom_unit_mean'], error_y=dict(type='data', array=stats_df['bottom_unit_sem']), marker_color='orange', marker_line_width=0, opacity=0.7, name='Bottom Unit Z-score'), row=1, col=3)
            fig.add_trace(go.Bar(x=x_axis_labels, y=stats_df['rate_proportion'], marker_color='green', marker_line_width=0, opacity=0.7, name='Rate Proportion'), row=1, col=4)
        
        # Z-score bar plot
        fig.add_trace(go.Bar(x=zscore_stats['neuron'], y=zscore_stats['mean_zscore'], error_y=dict(type='data', array=zscore_stats['sem_zscore']), marker_color='purple', marker_line_width=0, opacity=0.7, name='Mean Z-score'), row=2, col=1)
        
        fig.update_layout(title=f"Instance {instance_idx} Feature {feature_idx} & Top Unit {top_unit} & Bottom Unit {bottom_unit}", height=800, width=1600, showlegend=True, margin=dict(t=80, b=60, l=50, r=50))
        
        # Update axes
        fig.update_xaxes(title_text=variable, tickangle=45, row=1, col=1)
        fig.update_xaxes(title_text=variable, tickangle=45, row=1, col=2)
        fig.update_xaxes(title_text=variable, tickangle=45, row=1, col=3)
        fig.update_xaxes(title_text=variable, tickangle=45, row=1, col=4)
        fig.update_xaxes(title_text="Neuron", row=2, col=1)
        
        fig.update_yaxes(title_text="Feature Activation", color="crimson", range=[0, None], row=1, col=1)
        fig.update_yaxes(title_text="Top Unit Z-score", color="darkblue", row=1, col=2)
        fig.update_yaxes(title_text="Bottom Unit Z-score", color="orange", row=1, col=3)
        fig.update_yaxes(title_text="Rate Proportion", range=[0, None], color="green", row=1, col=4)
        fig.update_yaxes(title_text="Mean Z-score", row=2, col=1)
        
        fig.update_yaxes(rangemode='tozero', row=1, col=1)
        fig.update_yaxes(rangemode='tozero', row=1, col=4)
    
    fig.show()

# # Find go_cue feature and plot
# event_feature = results[results['value'] == 'go_cue'].sort_values('activation_ratio', ascending=False).iloc[0]
# display(event_feature)
# plot_feature_tuning(acts_df=acts_df, spk_z_scores_df=spk_z_scores_df, metadata_binned=metadata_binned_subset, variable='event', instance_idx=int(event_feature['instance_idx']), feature_idx=int(event_feature['feature_idx']))

# # Find movement_angle feature and plot  
# move_angle_feature = results[results['variable'] == 'movement_angle'].sort_values('activation_ratio', ascending=False).iloc[0]
# display(move_angle_feature)
# plot_feature_tuning(acts_df=acts_df, spk_z_scores_df=spk_z_scores_df, metadata_binned=metadata_binned, variable='movement_angle_binned', instance_idx=int(move_angle_feature['instance_idx']), feature_idx=int(move_angle_feature['feature_idx']))

# # Find vel_magnitude feature and plot
# velocity_feature = results[results['variable'] == 'vel_magnitude'].sort_values('activation_ratio', ascending=False).iloc[0]
# display(velocity_feature)
# plot_feature_tuning(acts_df=acts_df, spk_z_scores_df=spk_z_scores_df, metadata_binned=metadata_binned_subset, variable='vel_magnitude_binned', instance_idx=int(velocity_feature['instance_idx']), feature_idx=int(velocity_feature['feature_idx']))

In [None]:
"""Interactive UI"""

# Helper to map a base variable and its type to the metadata_binned_subset column name
def _bvar_name(var_name, var_type):
    return f"{var_name}_binned" if var_type == 'continuous' else var_name

def _build_feature_tuning_ui(mode_toggle):
    """Build the SAE feature-tuning explorer and return its root widget.

    Every widget is local to this function and every callback is a closure
    over those locals, so later cells that reuse names such as ``out`` or
    ``preset_dropdown`` cannot redirect this UI's buttons and dropdowns to
    their own widgets.

    Args:
        mode_toggle: RadioButtons whose value is ``'preset'`` or ``'manual'``.
            It decides which selector box is visible and which selection
            "Generate Plot" uses.

    Returns:
        widgets.VBox: the assembled UI, ready for ``display``.
    """
    # ---- Preset selection: one entry per results row whose binned column exists ----
    preset_entries = []
    for _, r in results.iterrows():
        bvar = _bvar_name(r.variable, r.variable_type)
        if bvar not in metadata_binned_subset.columns:
            continue  # Skip variables that have not been binned
        label = (
            f"Inst:{int(r.instance_idx)} | "
            f"Feat:{int(r.feature_idx)} | "
            f"Var:{bvar} | "
            f"Val:{r['value']} | "
            f"FracDuring:{r.activation_frac_during:.3f} | "
            f"FracOutside:{r.activation_frac_outside:.3f} | "
            f"ActRatio:{r.activation_ratio:.3f} | "
            f"RateProp:{r.rate_proportion:.3f}"
        )
        # The option value is the (instance, feature, variable) triple to plot.
        preset_entries.append((label, (int(r.instance_idx), int(r.feature_idx), bvar)))

    preset_dropdown = widgets.Dropdown(
        options=preset_entries,
        description='Select Result:',
        layout=widgets.Layout(width='80%'),
    )
    preset_box = widgets.VBox([preset_dropdown])

    # ---- Manual selection: instance -> variable -> feature (strict coupling) ----
    # Precompute the binned column name on `results` for filtering; existing
    # columns are kept intact.
    if 'bvar' not in results.columns:
        results['bvar'] = results.apply(
            lambda r: _bvar_name(r.variable, r.variable_type), axis=1
        )

    instance_dropdown = widgets.Dropdown(
        options=sorted(acts_df['instance_idx'].unique()),
        description='Instance:',
    )
    # Variable and feature options are filled in by the refresh functions below.
    variable_dropdown = widgets.Dropdown(options=[], description='Variable:')
    feature_dropdown = widgets.Dropdown(options=[], description='Feature:')

    def refresh_variable_options(*_):
        """Restrict variables to those with results for the instance and a binned column."""
        inst = instance_dropdown.value
        if inst is None:
            variable_dropdown.options = []
            variable_dropdown.value = None
            return

        vars_for_inst = [
            v
            for v in sorted(results.loc[results['instance_idx'] == inst, 'bvar'].unique())
            if v in metadata_binned_subset.columns
        ]
        # Captured before `options` is replaced, which can change the selection.
        prev = variable_dropdown.value
        variable_dropdown.options = [(v, v) for v in vars_for_inst]
        # Preserve the selection if still valid; else pick the first or None.
        if prev in vars_for_inst:
            variable_dropdown.value = prev
        else:
            variable_dropdown.value = vars_for_inst[0] if vars_for_inst else None

    def refresh_feature_options(*_):
        """Restrict features to those with results for (instance, variable)."""
        inst = instance_dropdown.value
        sel_var = variable_dropdown.value
        if (inst is None) or (sel_var is None):
            feature_dropdown.options = []
            feature_dropdown.value = None
            feature_dropdown.disabled = True
            return

        mask = (results['instance_idx'] == inst) & (results['bvar'] == sel_var)
        feats = sorted(results.loc[mask, 'feature_idx'].unique())
        feature_dropdown.options = feats
        feature_dropdown.value = feats[0] if feats else None
        feature_dropdown.disabled = not feats

    # Wire dependencies: instance -> variables, and (instance or variable) -> features.
    instance_dropdown.observe(refresh_variable_options, names='value')
    instance_dropdown.observe(refresh_feature_options, names='value')
    variable_dropdown.observe(refresh_feature_options, names='value')

    # Prime once now that all widgets exist.
    refresh_variable_options()
    refresh_feature_options()

    manual_box = widgets.VBox([instance_dropdown, variable_dropdown, feature_dropdown])

    # ---- Preset/manual toggle ----
    def apply_mode(mode):
        """Show the selector box matching `mode` and hide the other one."""
        if mode == 'preset':
            preset_box.layout.display = ''
            manual_box.layout.display = 'none'
        else:
            preset_box.layout.display = 'none'
            manual_box.layout.display = ''
            # Ensure the dependent dropdowns are in a valid state when shown.
            refresh_variable_options()
            refresh_feature_options()

    mode_toggle.observe(lambda change: apply_mode(change['new']), names='value')
    # Sync visibility with the toggle's current value rather than assuming 'preset'.
    apply_mode(mode_toggle.value)

    # ---- Plot / clear controls ----
    generate_btn = widgets.Button(description='Generate Plot', button_style='info')
    clear_btn = widgets.Button(description='Clear', button_style='warning')
    out = widgets.Output()

    def on_generate(_):
        """Render plot_feature_tuning for the current selection into `out`."""
        with out:
            out.clear_output()
            if mode_toggle.value == 'preset':
                if preset_dropdown.value is None:
                    # No entries: every results variable lacks a binned column.
                    print("No preset results with a binned variable are available.")
                    return
                inst, feat, var = preset_dropdown.value
            else:
                inst = instance_dropdown.value
                var = variable_dropdown.value
                feat = feature_dropdown.value
                if (inst is None) or (var is None) or (feat is None):
                    print("No matching (instance, variable, feature) for the current selection.")
                    return
                # Pass plain ints, as preset mode does.
                inst, feat = int(inst), int(feat)

            plot_feature_tuning(
                acts_df=acts_df,
                spk_z_scores_df=spk_z_scores_df,
                metadata_binned=metadata_binned_subset,
                variable=var,
                instance_idx=inst,
                feature_idx=feat,
            )

    generate_btn.on_click(on_generate)
    clear_btn.on_click(lambda _: out.clear_output())

    return widgets.VBox([
        widgets.HTML("<h2>SAE Feature Visualization</h2>"),
        mode_toggle,
        preset_box,
        manual_box,
        widgets.HBox([generate_btn, clear_btn]),
        out,
    ])


# Selector for preset vs manual mode. Kept at module level (not inside the
# builder) so other cells can share this toggle.
mode_radio = widgets.RadioButtons(
    options=[
        ('Preset (from results table)', 'preset'),
        ('Manual selection', 'manual')
    ],
    value='preset',
    description=''
)

ui = _build_feature_tuning_ui(mode_radio)
display(ui)


In [None]:
"""Interactive UI: trial-averaged feature activations."""


def plot_trial_avg_scatter(instance_idx, feature_idx, variable, colour_var):
    """Scatter the per-trial mean activation of one SAE feature against time.

    Rows of ``acts_df`` for ``(instance_idx, feature_idx)`` are left-joined to
    ``metadata_binned_subset``, averaged per ``trial_idx``, and plotted against
    each trial's first ``timestamp``. Prints a message and returns early when
    there are no matching activations.

    Args:
        instance_idx: SAE instance to plot.
        feature_idx: Feature within that instance.
        variable: Binned variable chosen in the selector; not used by this plot.
        colour_var: Column used for point colour: ``"maze_condition"``,
            ``"barriers"``, ``"targets"`` or ``"hit_position"`` (derived here
            from the signs of ``hit_position_x`` / ``hit_position_y``).
    """
    acts = acts_df[(acts_df.instance_idx == instance_idx) & (acts_df.feature_idx == feature_idx)]
    if acts.empty:
        print("No activations found")
        return

    # NOTE(review): reset_index() followed by right_index=True joins example_idx
    # against the *positional* row index of metadata_binned_subset, not its
    # original index labels. Confirm this is intended if the subset's index is
    # non-contiguous.
    merged = acts.merge(metadata_binned_subset.reset_index(), left_on="example_idx", right_index=True, how="left")

    trial_avg = merged.groupby("trial_idx").agg(
        avg_activation=("activation_value", "mean"),
        timestamp=("timestamp", "first"),
        maze_condition=("maze_condition", "first"),
        barriers=("barriers", "first"),
        targets=("targets", "first"),
        hit_position_x=("hit_position_x", "first"),
        hit_position_y=("hit_position_y", "first"),
    ).reset_index()

    if colour_var == "hit_position":
        # Quadrant label from coordinate signs; zero or NaN coordinates -> "Other".
        x, y = trial_avg["hit_position_x"], trial_avg["hit_position_y"]
        trial_avg["hit_position"] = np.select(
            [(y > 0) & (x < 0), (y > 0) & (x > 0), (y < 0) & (x < 0), (y < 0) & (x > 0)],
            ["Top-Left", "Top-Right", "Bottom-Left", "Bottom-Right"],
            default="Other",
        )

    fig = px.scatter(
        trial_avg,
        x="timestamp",
        y="avg_activation",
        color=colour_var,
        title=f"Inst {instance_idx}, Feat {feature_idx} | Avg activation per trial",
        color_discrete_sequence=px.colors.qualitative.Alphabet * 10
    )
    fig.update_layout(
        legend=dict(orientation="v", x=1.02, y=1),
        margin=dict(r=200)
    )
    fig.show()


def _build_trial_avg_ui():
    """Build the trial-average explorer and return its root widget.

    Widgets and callbacks are local to this function (callbacks are closures),
    so this cell neither rebinds nor depends on widget globals from other
    cells. It has its own preset/manual toggle and only reuses the module-level
    ``_bvar_name`` helper.

    Returns:
        widgets.VBox: the assembled UI, ready for ``display``.
    """
    # Colour-by selector; option values are column names in plot_trial_avg_scatter.
    colour_by = widgets.Dropdown(
        options=[
            ("Maze condition", "maze_condition"),
            ("Hit position", "hit_position"),
            ("Number of barriers", "barriers"),
            ("Number of targets", "targets")
        ],
        value="maze_condition",
        description="Colour by:"
    )

    # Preset vs manual mode, owned by this UI.
    mode_toggle = widgets.RadioButtons(
        options=[
            ('Preset (from results table)', 'preset'),
            ('Manual selection', 'manual')
        ],
        value='preset',
        description=''
    )

    # ---- Preset selection: one entry per results row whose binned column exists ----
    preset_entries = []
    for _, r in results.iterrows():
        bvar = _bvar_name(r.variable, r.variable_type)
        if bvar not in metadata_binned_subset.columns:
            continue
        label = (
            f"Inst:{int(r.instance_idx)} | "
            f"Feat:{int(r.feature_idx)} | "
            f"Var:{bvar} | "
            f"Val:{r['value']} | "
            f"FracDuring:{r.activation_frac_during:.3f} | "
            f"FracOutside:{r.activation_frac_outside:.3f} | "
            f"ActRatio:{r.activation_ratio:.3f} | "
            f"RateProp:{r.rate_proportion:.3f}"
        )
        preset_entries.append((label, (int(r.instance_idx), int(r.feature_idx), bvar)))

    preset_dropdown = widgets.Dropdown(
        options=preset_entries,
        description='Select Result:',
        layout=widgets.Layout(width='90%')
    )
    preset_box = widgets.VBox([preset_dropdown])

    # ---- Manual selection: instance -> variable -> feature dependent dropdowns ----
    if 'bvar' not in results.columns:
        results['bvar'] = results.apply(lambda r: _bvar_name(r.variable, r.variable_type), axis=1)

    instance_dropdown = widgets.Dropdown(
        options=sorted(acts_df['instance_idx'].unique()),
        description='Instance:'
    )
    variable_dropdown = widgets.Dropdown(options=[], description='Variable:')
    feature_dropdown = widgets.Dropdown(options=[], description='Feature:')

    def refresh_variable_options(*_):
        """Restrict variables to those with results for the instance and a binned column."""
        inst = instance_dropdown.value
        if inst is None:
            variable_dropdown.options = []
            variable_dropdown.value = None
            return
        vars_for_inst = [
            v
            for v in sorted(results.loc[results['instance_idx'] == inst, 'bvar'].unique())
            if v in metadata_binned_subset.columns
        ]
        # Captured before `options` is replaced, which can change the selection.
        prev = variable_dropdown.value
        variable_dropdown.options = vars_for_inst
        # Preserve the selection if still valid; else pick the first or None.
        if prev in vars_for_inst:
            variable_dropdown.value = prev
        else:
            variable_dropdown.value = vars_for_inst[0] if vars_for_inst else None

    def refresh_feature_options(*_):
        """Restrict features to those with results for (instance, variable)."""
        inst = instance_dropdown.value
        sel_var = variable_dropdown.value
        if (inst is None) or (sel_var is None):
            feature_dropdown.options = []
            feature_dropdown.value = None
            feature_dropdown.disabled = True
            return
        mask = (results['instance_idx'] == inst) & (results['bvar'] == sel_var)
        feats = sorted(results.loc[mask, 'feature_idx'].unique())
        feature_dropdown.options = feats
        feature_dropdown.value = feats[0] if feats else None
        feature_dropdown.disabled = not feats

    instance_dropdown.observe(refresh_variable_options, names='value')
    instance_dropdown.observe(refresh_feature_options, names='value')
    variable_dropdown.observe(refresh_feature_options, names='value')

    refresh_variable_options()
    refresh_feature_options()

    manual_box = widgets.VBox([instance_dropdown, variable_dropdown, feature_dropdown])

    # ---- Preset/manual toggle ----
    def apply_mode(mode):
        """Show the selector box matching `mode` and hide the other one."""
        if mode == 'preset':
            preset_box.layout.display = ''
            manual_box.layout.display = 'none'
        else:
            preset_box.layout.display = 'none'
            manual_box.layout.display = ''
            refresh_variable_options()
            refresh_feature_options()

    mode_toggle.observe(lambda change: apply_mode(change['new']), names='value')
    apply_mode(mode_toggle.value)

    # ---- Plot / clear controls ----
    generate_btn = widgets.Button(description='Generate Plot', button_style='info')
    clear_btn = widgets.Button(description='Clear', button_style='warning')
    out = widgets.Output()

    def on_generate(_):
        """Render plot_trial_avg_scatter for the current selection into `out`."""
        with out:
            out.clear_output()
            if mode_toggle.value == 'preset':
                if preset_dropdown.value is None:
                    # No entries: every results variable lacks a binned column.
                    print("No preset results with a binned variable are available.")
                    return
                inst, feat, var = preset_dropdown.value
            else:
                inst = instance_dropdown.value
                var = variable_dropdown.value
                feat = feature_dropdown.value
                if (inst is None) or (var is None) or (feat is None):
                    print("No matching (instance, variable, feature)")
                    return
                # Pass plain ints, as preset mode does.
                inst, feat = int(inst), int(feat)
            plot_trial_avg_scatter(inst, feat, var, colour_by.value)

    generate_btn.on_click(on_generate)
    clear_btn.on_click(lambda _: out.clear_output())

    return widgets.VBox([
        widgets.HTML("<h2>Trial-average Feature Activations</h2>"),
        mode_toggle,
        preset_box,
        manual_box,
        colour_by,
        widgets.HBox([generate_btn, clear_btn]),
        out,
    ])


ui = _build_trial_avg_ui()
display(ui)
