## MSAE training and feature analysis

Main goal: examine the features learned at each level of a 3-level MSAE.

Steps:

1. Train MSAE on spikes dataset

2. Create a DataFrame of (top-k) SAE features

3. Create a DataFrame of stimulus (meta)data

4. Interpretability / auto-interpretability
    
    a. Build infrastructure to find the stimuli present when particular feature(s) fire
    
    b. Build infrastructure to find the top-k SAE features that fire during particular stimuli

In [None]:
"""Set notebook settings."""

# Automatically reload edited modules (e.g. the local `mini` package) before each cell runs
%load_ext autoreload
%autoreload 2
# (disabled) reactive execution mode; left commented out
# %flow mode reactive

In [None]:
"""Import packages."""

# Standard library imports
import math
import os
from collections import defaultdict
from dataclasses import dataclass
from datetime import datetime
from functools import partial
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple

# IPython and Jupyter-related imports
import ipywidgets as widgets
from IPython.display import clear_output, display

# Third-party libraries
import h5py
import numpy as np
import pandas as pd
import seaborn as sns
import temporaldata as td
import torch as t
from einops import (
    asnumpy,
    einsum,
    pack,
    parse_shape,
    rearrange,
    reduce,
    repeat,
    unpack,
)
from einops.layers.torch import Rearrange, Reduce
from jaxtyping import Float, Int
from matplotlib import pyplot as plt
from plotly import express as px
from plotly import graph_objects as go
from plotly.subplots import make_subplots
from rich import print as rprint
from scipy import stats
from scipy.ndimage import uniform_filter1d
from sklearn.linear_model import LinearRegression
from sklearn.metrics import classification_report, confusion_matrix, r2_score
from temporaldata import Data
from torch import Tensor, bfloat16, nn
from torch.nn import functional as F
from torcheval.metrics.functional import r2_score as tm_r2_score
from tqdm.notebook import tqdm

# Local project modules
from mini import train as mt
from mini.util import vec_r2

In [None]:
# Widen pandas' display limits so large DataFrames show more rows and columns.
# set_option accepts several (pattern, value) pairs in one call.
pd.set_option("display.max_rows", 300, "display.max_columns", 25)

In [None]:
def _filter_trials(session, mask, description):
    """Keep only the trials where ``mask`` is True and report how many were removed.

    Args:
        session: Session whose ``trials`` attribute is replaced in place.
        mask: Boolean array with one entry per current trial; True keeps the trial.
        description: Phrase naming the removed trials, used in the log message.
    """
    num_trials = len(session.trials.start)
    session.trials = session.trials.select_by_mask(mask)
    new_num_trials = len(session.trials.start)
    if num_trials - new_num_trials > 0:
        print(f"Filtered out {description}, went from", num_trials, "trials to", new_num_trials)


def clean_session_data(session):
    """Clean session data by filtering trials and spikes based on quality criteria.

    Ported from a MATLAB script provided by Mark Churchland. Trials are removed in
    sequence (invalid / discarded / novel-maze, unsuccessful, too short after movement
    onset, inconsistent reach). Then spikes, hand and eye data are restricted to the
    remaining trials, and the recording date string is converted to a POSIX timestamp.

    Args:
        session: Session object (presumably a temporaldata ``Data`` loaded from the
            brainsets HDF5 files). Modified in place.

    Returns:
        The same session object, cleaned.

    Raises:
        ValueError: if no trials survive cleaning, or the remaining maze conditions
            are not the consecutive integers 1..N.
    """
    trials = session.trials

    # is_valid is probably already (discard_trial == 0) & (task_success == 1) in brainsets;
    # it is redundant here but kept to be safe.
    # The matlab script also drops impossible mazes ("unhittable"), but that field is not
    # in this release of the data (perhaps already filtered upstream).
    _filter_trials(
        session,
        (trials.trial_type > 0)
        & (trials.is_valid == 1)
        & (trials.discard_trial == 0)
        & (trials.novel_maze == 0)
        & (trials.trial_version < 3),
        "extraneous trials",
    )

    # Each mask is built from session.trials *after* the previous filter has been applied
    _filter_trials(session, session.trials.task_success == 1, "unsuccessful trials")

    # Trial times are in seconds (the recording date is added as a POSIX timestamp below),
    # so a trial must last at least 0.8 s after movement onset to be kept.
    post_move = 0.8
    _filter_trials(
        session,
        # should essentially always be true for successes
        session.trials.end - session.trials.move_begins_time >= post_move,
        "trials that were too short",
    )

    _filter_trials(
        session,
        session.trials.correct_reach == 1,
        "trials with inconsistent reaches (not similar enough to the \"prototypical\" trial)",
    )

    primary_conditions = np.unique(session.trials.maze_condition)
    num_conditions = len(primary_conditions)
    print("Number of primary conditions:", num_conditions)
    if num_conditions == 0:
        raise ValueError("No trials left after cleaning")
    # Conditions must be exactly 1, 2, ..., N. A constant-step check alone would accept
    # gaps such as [1, 3, 5], which breaks the grouping-by-three done later.
    if not np.array_equal(primary_conditions, np.arange(1, num_conditions + 1)):
        raise ValueError(
            f"Primary conditions are not consecutive integers starting from 1: {primary_conditions}"
        )

    # The matlab script also filters units by a 1-4 quality ranking, which is not in this
    # release of the data (perhaps already applied upstream).

    # Only keep spikes / behaviour that fall within the cleaned trials
    session.spikes = session.spikes.select_by_interval(session.trials)
    session.hand = session.hand.select_by_interval(session.trials)
    session.eye = session.eye.select_by_interval(session.trials)

    # Convert session recording date to a POSIX timestamp (seconds).
    # NOTE(review): strptime returns a naive datetime, so .timestamp() interprets it in the
    # machine's local time zone; absolute times therefore depend on where this runs.
    recording_dt = datetime.strptime(session.session.recording_date, '%Y-%m-%d %H:%M:%S')
    session.session.recording_date = recording_dt.timestamp()

    return session

In [None]:
def analyze_maze_conditions(session):
    """
    Summarize every maze_condition of ``session.trials`` and plot its hit-target position.

    For each unique condition this records the trial count, the unique barrier and
    target counts, the (single) hit-target position and that position's angle in
    degrees. It then shows a Plotly scatter of the distinct hit positions, coloured
    by the first condition that uses each position.

    Args:
        session: Session whose ``trials`` expose ``maze_condition``,
            ``maze_num_barriers``, ``maze_num_targets`` and ``hit_target_position``.

    Returns:
        pandas.DataFrame with one row per maze condition.

    Raises:
        ValueError: if a condition has more than one distinct hit position.
    """
    import plotly.colors as pc

    trials = session.trials

    def summarize(condition):
        # One summary row for a single maze condition
        in_condition = trials.maze_condition == condition
        barriers = np.unique(trials.maze_num_barriers[in_condition])
        targets = np.unique(trials.maze_num_targets[in_condition])
        positions = np.unique(trials.hit_target_position[in_condition], axis=0)
        if len(positions) > 1:
            raise ValueError(f"Condition {condition} has multiple hit positions: {positions}")
        position = positions[0]
        return {
            'Maze Condition': condition,
            'Trials': np.sum(in_condition),
            'Barriers': barriers,
            'Targets': targets,
            'Hit Position': position,
            'Hit Position Angles': str(np.degrees(np.arctan2(position[1], position[0]))),
        }

    summary_df = pd.DataFrame([summarize(c) for c in np.unique(trials.maze_condition)])

    # Keep one row per distinct hit position (first occurrence wins). Positions are
    # arrays, which are unhashable, so compare them as tuples.
    is_repeat = summary_df['Hit Position'].apply(tuple).duplicated(keep='first')
    plot_df = summary_df.loc[~is_repeat]

    plot_data = pd.DataFrame({
        'Hit Position X': plot_df['Hit Position'].map(lambda p: p[0]),
        'Hit Position Y': plot_df['Hit Position'].map(lambda p: p[1]),
        'Maze Condition': plot_df['Maze Condition'].astype(str),
    })

    # Evenly spaced viridis colours, one per plotted condition
    n_conditions = plot_data['Maze Condition'].nunique()
    denom = max(n_conditions - 1, 1)
    colors = pc.sample_colorscale('viridis', [k / denom for k in range(n_conditions)])

    fig = px.scatter(
        plot_data,
        x='Hit Position X',
        y='Hit Position Y',
        color='Maze Condition',
        labels={'Hit Position X': 'Hit Position X', 'Hit Position Y': 'Hit Position Y', 'color': 'Maze Condition'},
        title='Hit Position by Maze Condition',
        color_discrete_sequence=colors,
        hover_data=['Maze Condition'],
    )
    fig.update_layout(
        xaxis=dict(scaleanchor="y", scaleratio=1, range=[-150, 150]),
        yaxis=dict(constrain="domain", range=[-100, 100]),
        width=600,
        height=600,
    )
    fig.show()

    return summary_df

In [None]:
# Path to your data directory
data_path = r"C:\Users\pouge\Documents\mini_data\brainsets\processed\churchland_shenoy_neural_2012"
# data_path = "/ceph/aeon/aeon/SANe/brainsets_data/processed/churchland_shenoy_neural_2012"
data_path = Path(data_path)

# List all h5 files in the directory. os.listdir returns entries in arbitrary,
# filesystem-dependent order, so sort to make "the first N files" below reproducible.
h5_files = sorted(f for f in os.listdir(data_path) if f.endswith('.h5'))
print(f"Available h5 files: {h5_files}")

# User parameters
subject_name = "nitschke"  # "nitschke" or "jenkins"
num_files_to_load = 3     # Change to desired number of files, max 6 (only 3 work) for nitschke, 4 for jenkins

# Filter files by subject
subject_files = [f for f in h5_files if subject_name.lower() in f.lower()]
print(f"\nFiles for subject {subject_name}: {subject_files}")

if len(subject_files) == 0:
    print(f"No files found for subject {subject_name}")
elif len(subject_files) < num_files_to_load:
    print(f"Only {len(subject_files)} files available for {subject_name}, loading all of them")
    num_files_to_load = len(subject_files)

# Load and clean the specified number of files
sessions = []
for i, file_name in enumerate(subject_files[:num_files_to_load]):
    file_path = data_path / file_name
    print(f"\nLoading file {i+1}/{num_files_to_load}: {file_name}")

    # Read neural data from HDF5; materialize() (presumably loading lazily-read arrays
    # into memory) is called on every field used later, while the file is still open.
    with h5py.File(file_path, "r") as f:
        session = Data.from_hdf5(f)

        session.spikes.materialize()
        session.trials.materialize()
        session.hand.materialize()
        session.eye.materialize()
        session.session.materialize()

        print("Session ID: ", session.session.id)
        print("Session subject id: ", session.subject.id)
        print("Session subject sex: ", session.subject.sex)
        print("Session subject species: ", session.subject.species)
        print("Session recording date: ", session.session.recording_date)
        print("Original number of trials:", len(session.trials.start))

        # Best effort: a session that fails cleaning is reported and skipped rather than
        # aborting the whole load.
        try:
            session = clean_session_data(session)
            print("Final number of trials after cleaning:", len(session.trials.start))

            print("Summary of primary conditions:")
            primary_conditions_summary = analyze_maze_conditions(session)
            display(primary_conditions_summary)

            sessions.append(session)
        except Exception as e:
            print(f"Error processing session {session.session.id}: {e}")
            continue

print(f"\nSuccessfully loaded and cleaned {len(sessions)} sessions for subject {subject_name}")

In [None]:
def fix_maze_conditions_consistency(sessions):
    """
    Renumber maze conditions so the same maze gets the same condition id in every session.

    Conditions are handled in consecutive groups of three (sorted ids 1-3, 4-6, ...).
    Each group is identified by the (barriers, targets, hit position) signature of its
    three conditions. The session with the most conditions is the reference. Every group
    of every session is matched to the first reference group with an identical signature
    and renumbered to that group's ids (1, 4, 7, ... plus the position within the group).

    Sessions whose number of conditions is not a multiple of 3 are dropped with a warning.
    All old->new mappings are computed before any session is modified, so an unmatched
    group raises without leaving some sessions renumbered and others not.

    Args:
        sessions: List of session objects. ``session.trials.maze_condition`` is replaced
            in place for every kept session.

    Returns:
        List of the kept sessions with consistent maze condition numbering.

    Raises:
        ValueError: if no session survives filtering, the reference session has an
            incomplete group, a condition has more than one distinct maze parameter,
            or a group cannot be matched to any reference group.
    """

    def get_maze_signature(session, condition):
        """Return the (barriers, targets, hit_position) signature of one condition."""
        condition_mask = session.trials.maze_condition == condition

        barriers = tuple(np.unique(session.trials.maze_num_barriers[condition_mask]))
        targets = tuple(np.unique(session.trials.maze_num_targets[condition_mask]))
        hit_position = np.unique(session.trials.hit_target_position[condition_mask], axis=0)
        if len(barriers) > 1 or len(targets) > 1 or len(hit_position) > 1:
            raise ValueError(f"Condition {condition} has the following >1 unique values for one of the following: "
                             f"barriers={barriers}, targets={targets}, hit_position={hit_position}. "
                             "This should not be possible.")
        return (barriers, targets, tuple(hit_position[0]))

    def get_group_signature(session, group_conditions):
        """Return the combined signature of a group of conditions, in sorted order."""
        return tuple(get_maze_signature(session, condition) for condition in sorted(group_conditions))

    def split_into_groups(session):
        """Split a session's sorted unique condition ids into consecutive chunks of 3."""
        conditions = sorted(np.unique(session.trials.maze_condition))
        return [conditions[i:i + 3] for i in range(0, len(conditions), 3)]

    # Remove sessions that don't have multiples of 3 maze conditions
    valid_sessions = []
    for session in sessions:
        num_conditions = len(np.unique(session.trials.maze_condition))
        if num_conditions % 3 != 0:
            print(f"WARNING: Removing session {session.session.id} - has {num_conditions} maze conditions (not multiple of 3)")
            continue
        valid_sessions.append(session)

    if not valid_sessions:
        raise ValueError("No valid sessions remaining after filtering")

    print(f"Kept {len(valid_sessions)} out of {len(sessions)} sessions after filtering")

    # Pick reference session (most unique maze conditions; the first one wins ties)
    condition_counts = [len(np.unique(s.trials.maze_condition)) for s in valid_sessions]
    max_conditions = max(condition_counts)
    reference_session = valid_sessions[condition_counts.index(max_conditions)]

    print(f"Using session {reference_session.session.id} as reference (has {max_conditions} maze conditions)")

    # Group conditions in the reference session
    ref_groups = split_into_groups(reference_session)
    for group in ref_groups:
        if len(group) != 3:
            raise ValueError(f"Reference session has incomplete group: {group}")

    # Signature of every reference group, plus signature -> index of the FIRST group with
    # it (identical groups are reported, and later matches go to the first of them)
    ref_group_signatures = []
    signature_to_ref_idx = {}
    for group_idx, group in enumerate(ref_groups):
        signature = get_group_signature(reference_session, group)
        if signature in signature_to_ref_idx:
            existing_group = ref_groups[signature_to_ref_idx[signature]]
            print("WARNING: Duplicate group found in reference session!")
            print(f"  Group {existing_group} and Group {group} have identical maze parameters")
        else:
            signature_to_ref_idx[signature] = group_idx
        ref_group_signatures.append(signature)

    print(f"Reference session has {len(ref_groups)} groups of maze conditions")

    print("Table of reference session maze conditions:")
    primary_conditions_summary = analyze_maze_conditions(reference_session)
    display(primary_conditions_summary)

    # Pass 1: build the old -> new mapping for every session without modifying anything
    condition_mappings = []
    for session in valid_sessions:
        print(f"Processing session {session.session.id}...")

        condition_mapping = {}  # old_condition -> new_condition
        for group in split_into_groups(session):
            group_sig = get_group_signature(session, group)
            matched_ref_group_idx = signature_to_ref_idx.get(group_sig)

            if matched_ref_group_idx is None:
                print(f"ERROR: Could not match group {group} in session {session.session.id}")
                print(f"Group signature: {group_sig}")
                print("Available reference signatures:")
                for ref_group, ref_sig in zip(ref_groups, ref_group_signatures):
                    print(f"  Reference group {ref_group}: {ref_sig}")
                raise ValueError(f"Unmatchable maze conditions {group} in session {session.session.id}")

            # Reference group k owns the new ids 3k+1, 3k+2, 3k+3 (1, 4, 7, 10, ...)
            new_base_condition = matched_ref_group_idx * 3 + 1
            for offset, old_condition in enumerate(sorted(group)):
                condition_mapping[old_condition] = new_base_condition + offset

        condition_mappings.append(condition_mapping)

    # Pass 2: every session matched, so apply the mappings
    for session, condition_mapping in zip(valid_sessions, condition_mappings):
        session.trials.maze_condition = np.array(
            [condition_mapping[old] for old in session.trials.maze_condition]
        )

    print(f"\nSuccessfully processed {len(valid_sessions)} sessions with consistent maze condition numbering")
    return valid_sessions

# Run once the sessions above have been loaded and cleaned.
print("Fixing maze condition consistency across sessions...")
sessions = fix_maze_conditions_consistency(sessions)

# Print each session's condition ids after renumbering, for a quick visual check
print("\nCondition remapping summary:")
for session in sessions:
    unique_conditions = sorted(np.unique(session.trials.maze_condition))
    print(f"Session {session.session.id}: {unique_conditions}")

In [None]:
# Uncomment to train a model on a specific session post fixing maze conditions consistency
# print(len(sessions))
# sessions =  [sessions[i] for i in [2]]
# print(len(sessions))

In [None]:
# Parameters
bin_size = 0.05  # seconds (timestamps are POSIX seconds after adding recording_date)
# Also keep within-trial bins in which no unit spiked (all-zero rows). Otherwise those bins
# vanish from the count table, which biases firing-rate and sparsity statistics.
include_empty_bins = True

# Check unit consistency across sessions
unit_ids = np.unique(sessions[0].spikes.unit_index)
for session in sessions:
    unique_units = np.unique(session.spikes.unit_index)
    if not np.array_equal(unique_units, unit_ids):
        raise ValueError("Sessions do not have the same unit IDs. Cannot combine spike data.")

# Determine global bin alignment start point
global_start = min(session.session.recording_date + session.trials.start.min() for session in sessions)
global_start = np.floor(global_start / bin_size) * bin_size  # ensure clean bin alignment

# Decimal places needed to write a multiple of bin_size exactly (0.05 -> 2, 0.025 -> 3,
# 1 -> 0). Bin timestamps are rounded to this so equal bins compare equal across sessions.
n_decimals = len(np.format_float_positional(float(bin_size), trim='-').partition('.')[2])

# Accumulator for all binned trials
binned_dfs = []

for session in sessions:
    recording_date = session.session.recording_date

    # Spikes and trial bounds in absolute time
    df_spikes = pd.DataFrame({
        'timestamp': session.spikes.timestamps + recording_date,
        'unit': session.spikes.unit_index,
    }).sort_values('timestamp')
    df_trials = pd.DataFrame({
        'trial_start': session.trials.start + recording_date,
        'trial_end': session.trials.end + recording_date,
    }).sort_values('trial_start')

    # Assign each spike to the most recent trial_start <= timestamp
    df_merged = pd.merge_asof(
        df_spikes,
        df_trials,
        left_on='timestamp',
        right_on='trial_start',
        direction='backward',
    )

    # Drop spikes that fall outside their trial interval (spikes before the first trial
    # have NaN trial_end and are dropped too)
    df_merged = df_merged[df_merged['timestamp'] < df_merged['trial_end']]

    # Bin index relative to global_start. Explicit int64: `astype(int)` is int32 on
    # Windows with NumPy < 2 and overflows once the data spans ~3.4 years of 50 ms bins.
    bins = np.floor((df_merged['timestamp'].to_numpy() - global_start) / bin_size).astype(np.int64)

    # Spike count per (bin, unit) -> wide table (rows = bins, one column per unit in
    # unit_ids; units with no spikes here get a column of zeros instead of NaN after concat)
    spk_cts = (
        df_merged.assign(bin=bins)
        .groupby(['bin', 'unit'])
        .size()
        .unstack('unit', fill_value=0)
        .reindex(columns=unit_ids, fill_value=0)
    )

    if include_empty_bins:
        # Every bin overlapping a trial [start, end): floor(start) .. ceil(end) - 1
        first_bins = np.floor((df_trials['trial_start'].to_numpy() - global_start) / bin_size).astype(np.int64)
        stop_bins = np.ceil((df_trials['trial_end'].to_numpy() - global_start) / bin_size).astype(np.int64)
        trial_bins = np.concatenate(
            [np.arange(lo, hi, dtype=np.int64) for lo, hi in zip(first_bins, stop_bins)]
            + [np.empty(0, dtype=np.int64)]
        )
        # Union with the existing index so float round-off can never drop a bin holding spikes
        spk_cts = spk_cts.reindex(index=np.union1d(spk_cts.index.to_numpy(), trial_bins), fill_value=0)

    # Convert bin indices to absolute timestamps shared across sessions
    session_timestamps = np.round(global_start + spk_cts.index.to_numpy() * bin_size, n_decimals)
    spk_cts.index = pd.Index(session_timestamps, name='timestamp')
    spk_cts.columns.name = None

    binned_dfs.append(spk_cts)

# Concatenate all binned trials across sessions
spk_cts_df = pd.concat(binned_dfs)
spk_cts_df = spk_cts_df[~spk_cts_df.index.duplicated()]
spk_cts_df.sort_index(inplace=True)

# Result: each row = time bin, each col = unit, values = spike count
display(spk_cts_df)

In [None]:
# Mean firing rate (Hz) per unit: total spikes divided by the total binned time
duration_sec = spk_cts_df.shape[0] * bin_size
mean_firing_rates = spk_cts_df.sum(axis=0) / duration_sec  # spikes/sec

# Histogram of the per-unit rates
fig, ax = plt.subplots(figsize=(8, 4))
ax.hist(mean_firing_rates, bins=30, edgecolor='black')
ax.set(
    xlabel='Mean Firing Rate (Hz)',
    ylabel='Number of Units',
    title='Distribution of Mean Firing Rates',
)
ax.grid(True)
fig.tight_layout()
plt.show()

# Summary statistics across units
print("Mean firing rate across units: {:.2f} Hz".format(mean_firing_rates.mean()))
print("Median: {:.2f} Hz".format(np.median(mean_firing_rates)))
print("Range: {:.2f}–{:.2f} Hz".format(mean_firing_rates.min(), mean_firing_rates.max()))

In [None]:
# Every (time bin, unit) spike count as one flat array
flattened_spike_counts = spk_cts_df.to_numpy().ravel()

fig, ax = plt.subplots(figsize=(8, 5))
ax.hist(flattened_spike_counts, bins=50, edgecolor='black')
ax.set_title("Distribution of Spike Counts per Unit per Time Bin")
ax.set_xlabel("Spike Count")
ax.set_ylabel("Number of Unit-Bin Combinations")
ax.set_yscale("log")  # log scale keeps the tail of a skewed distribution visible
ax.grid(True)
fig.tight_layout()
plt.show()

In [None]:
"""Check sparsity of binned spike counts."""

# Fraction of (time bin, unit) entries with a nonzero count
frac_nonzero_bins = spk_cts_df.ne(0).to_numpy().sum() / spk_cts_df.size
# Fraction of time bins (rows) in which at least one unit spiked
frac_nonzero_examples = spk_cts_df.sum(axis=1).gt(0).mean()
print(f"{frac_nonzero_bins=:.4f}")
print(f"{frac_nonzero_examples=:.4f}")

In [None]:
"""Load spikes and set sae config."""

# Train on the GPU when one is available
device = t.device("cuda" if t.cuda.is_available() else "cpu")
print(f"{device=}")

# (time bins x units) spike counts -> bfloat16 tensor on `device`, scaled so the max is 1
spk_cts = t.from_numpy(spk_cts_df.to_numpy()).bfloat16().to(device)
spk_cts /= spk_cts.max()

# Per-level MSAE settings, kept in ascending key order (keys presumably the dictionary
# size of each level; values its top-k and its loss weight, respectively)
dsae_topk_map = dict(sorted({256: 8, 512: 16, 1024: 24}.items()))
dsae_loss_x_map = dict(sorted({256: 1, 512: 1.25, 1024: 1.5}.items()))
# dsae_topk_map = {1024: 12, 2048: 24, 4096: 48}
dsae = max(dsae_topk_map)  # largest key
n_inst = 2  # number of SAE instances trained side by side

display(spk_cts)

## Train MSAE

In [None]:

sae_cfg = mt.SaeConfig(
    n_input_ae=spk_cts.shape[1],
    dsae_topk_map=dsae_topk_map,
    dsae_loss_x_map=dsae_loss_x_map,
    seq_len=1,
    n_instances=n_inst,
)
sae = mt.Sae(sae_cfg).to(device)

# Optimization hyperparameters
loss_fn = mt.msle
tau = 1.0
lr = 5e-3
n_epochs = 20
batch_sz = 1024

# Step budget, derived from the number of full batches in one epoch
steps_per_epoch = spk_cts.shape[0] // batch_sz
n_steps = steps_per_epoch * n_epochs
log_freq = steps_per_epoch // 2
dead_neuron_window = steps_per_epoch // 3

optimizer = t.optim.Adam(sae.parameters(), lr=lr)
data_log = mt.optimize(  # train model
    spk_cts=spk_cts,
    sae=sae,
    optimizer=optimizer,
    loss_fn=loss_fn,
    n_steps=n_steps,
    batch_sz=batch_sz,
    log_freq=log_freq,
    dead_neuron_window=dead_neuron_window,
    use_lr_sched=True,
    tau=tau,
    log_wandb=False,
    plot_l0=False,
)

## Validate SAEs

In [None]:
"""Check for nans in weights."""

# Number of NaN entries in the decoder and encoder weights (both should be 0)
t.isnan(sae.W_dec).sum(), t.isnan(sae.W_enc).sum()

In [None]:
"""Visualize weights."""

fig, ax = plt.subplots(figsize=(8, 6))
for inst in range(n_inst):
    # Cast to float32 before converting to numpy (numpy has no bfloat16 dtype)
    W_dec_flat = asnumpy(sae.W_dec[inst].float()).ravel()
    # Pass ax explicitly instead of relying on seaborn picking up the "current" axes
    sns.histplot(W_dec_flat, bins=1000, stat="probability", alpha=0.7, label=f"SAE {inst}", ax=ax)

ax.set_title("SAE decoder weights")
ax.set_xlabel("Weight value")
# stat="probability": bar heights are fractions of all weights, not counts
ax.set_ylabel("Probability")
ax.legend()
plt.show()

In [None]:
"""Visualize metrics over all examples and units."""

eval_outputs = mt.eval_model(spk_cts, sae, batch_sz=batch_sz)
# Keep top-k activations, reconstructions, per-unit R² and per-unit cosine similarity
topk_acts_4d, recon_spk_cts, r2_per_unit, _, cossim_per_unit, _ = eval_outputs
del eval_outputs

In [None]:
"""Calculate variance explained of summed spike counts."""

# Var explained for summed spike counts: compare each example's true total (summed over
# units) with its reconstructed total, separately for every SAE instance.
n_recon_examples = recon_spk_cts.shape[0]
recon_summed_spk_cts = recon_spk_cts.sum(dim=-1)                    # (example, inst)
actual_summed_spk_cts = spk_cts.sum(dim=-1)[:n_recon_examples]      # (example,), trimmed to match
y_true = asnumpy(actual_summed_spk_cts.float())
for inst in range(n_inst):
    r2 = r2_score(y_true, asnumpy(recon_summed_spk_cts[:, inst].float()))
    print(f"SAE instance {inst} R² (summed spike count over all units per example) = {r2:.3f}")

In [None]:
# If cosine similarity is high but R² is low, the model captures each example's pattern
# across units but not its magnitude. The ratio of vector norms checks this directly.

# Trim to the reconstructed examples and add an instance axis: (example, 1, unit).
# Work in float32: spk_cts is bfloat16, which is too coarse for these summaries.
spk_cts_exp = spk_cts[:recon_spk_cts.shape[0]].unsqueeze(1).float()

# Compute norms
true_norms = t.norm(spk_cts_exp, dim=-1)               # shape: [n_examples, 1]
recon_norms = t.norm(recon_spk_cts.float(), dim=-1)    # shape: [n_examples, n_instances]

# Scale ratio per example and instance
scale = true_norms / recon_norms                       # shape: [n_examples, n_instances]

# An all-zero reconstruction gives inf (or nan for 0/0); exclude those from the mean
finite_scale = scale.masked_fill(~t.isfinite(scale), float("nan"))
print(t.nanmean(finite_scale, dim=0))  # if it's consistently >1 or <1, the model is biased in magnitude

In [None]:
spk_cts_trimmed = spk_cts[:recon_spk_cts.shape[0]]
# Signed reconstruction error (recon - true) averaged over examples -> (n_instances, n_units).
# Computed in float32 because spk_cts is bfloat16.
bias = (recon_spk_cts.float() - spk_cts_trimmed.unsqueeze(1).float()).mean(dim=0)
print(bias.mean(dim=0))  # mean bias per unit, averaged across examples and then across SAE instances

In [None]:
# Per-unit variance across examples, averaged over units. Computed in float32: bfloat16
# cannot carry the 4 decimals printed below. pred_var pools all SAE instances.
true_var = spk_cts_trimmed.float().var(dim=0).mean()
pred_var = recon_spk_cts.float().var(dim=0).mean()
print(f"True variance: {true_var.item():.4f}, Pred variance: {pred_var.item():.4f}")

### Remove bad units and retrain.

In [None]:

# Set threshold for removing units
r2_thresh = 0.1
ref_inst = 0  # judge units by this SAE instance's per-unit R²
# Take one instance's column without overwriting r2_per_unit (which stays (unit, instance))
r2_ref_inst = r2_per_unit[:, ref_inst]
keep_mask = r2_ref_inst > r2_thresh
print(f"frac units above {r2_thresh=}: {keep_mask.sum() / keep_mask.shape[0]:.2f}")
print(f"Number to keep: {keep_mask.sum()} / {keep_mask.shape[0]}")

# spk_cts_df is NOT filtered, so record which original unit ids the remaining spk_cts
# columns correspond to (needed to map SAE inputs back to units)
kept_unit_ids = spk_cts_df.columns[asnumpy(keep_mask).astype(bool)]

# Remove units and retrain
spk_cts = spk_cts[:, keep_mask]

sae_cfg = mt.SaeConfig(
    n_input_ae=spk_cts.shape[1],
    dsae_topk_map=dsae_topk_map,
    dsae_loss_x_map=dsae_loss_x_map,
    seq_len=1,
    n_instances=n_inst,
)
sae = mt.Sae(sae_cfg).to(device)
loss_fn = mt.msle
tau = 1.0
lr = 5e-3

n_epochs = 20
batch_sz = 1024
n_steps = spk_cts.shape[0] // batch_sz * n_epochs
log_freq = n_steps // n_epochs // 2
dead_neuron_window = n_steps // n_epochs // 3

data_log = mt.optimize(  # train model
    spk_cts=spk_cts,
    sae=sae,
    loss_fn=loss_fn,
    optimizer=t.optim.Adam(sae.parameters(), lr=lr),
    use_lr_sched=True,
    dead_neuron_window=dead_neuron_window,
    n_steps=n_steps,
    log_freq=log_freq,
    batch_sz=batch_sz,
    log_wandb=False,
    plot_l0=False,
    tau=tau
)

In [None]:
"""Re-Visualize metrics over all examples and units."""

eval_outputs = mt.eval_model(spk_cts, sae, batch_sz=batch_sz)
topk_acts_4d, recon_spk_cts, r2_per_unit, _, cossim_per_unit, _ = eval_outputs
del eval_outputs

# R² of the per-example spike count summed over units, for each SAE instance
n_recon_examples = recon_spk_cts.shape[0]
recon_summed_spk_cts = recon_spk_cts.sum(dim=-1)                    # (example, inst)
actual_summed_spk_cts = spk_cts.sum(dim=-1)[:n_recon_examples]      # (example,), trimmed to match
y_true = asnumpy(actual_summed_spk_cts.float())
for inst in range(n_inst):
    r2 = r2_score(y_true, asnumpy(recon_summed_spk_cts[:, inst].float()))
    print(f"SAE instance {inst} R² (summed spike count over all units per example) = {r2:.3f}")


## Get environment / behavior (meta)data

In [None]:
# Collect data from all sessions into arrays
hand_data = defaultdict(list)
eye_data = defaultdict(list)
trial_data = defaultdict(list)

for i, session in enumerate(sessions):
    recording_date = session.session.recording_date
    
    # Hand data - extract timestamps and 2D arrays
    timestamps = session.hand.timestamps + recording_date
    acc = session.hand.acc_2d
    pos = session.hand.pos_2d
    vel = session.hand.vel_2d
    
    hand_data['timestamp'].append(timestamps)
    hand_data['acc_x'].append(acc[:, 0])
    hand_data['acc_y'].append(acc[:, 1])
    hand_data['pos_x'].append(pos[:, 0])
    hand_data['pos_y'].append(pos[:, 1])
    hand_data['vel_x'].append(vel[:, 0])
    hand_data['vel_y'].append(vel[:, 1])
    hand_data['session'].append(np.full(len(timestamps), i))
    
    # Eye data - extract timestamps and position arrays
    eye_timestamps = session.eye.timestamps + recording_date
    eye_pos = session.eye.pos
    
    eye_data['timestamp'].append(eye_timestamps)
    eye_data['pos_x'].append(eye_pos[:, 0])
    eye_data['pos_y'].append(eye_pos[:, 1])
    
    # Trial data - add recording_date to all time columns
    trial_data['start'].append(session.trials.start + recording_date)
    trial_data['end'].append(session.trials.end + recording_date)
    trial_data['target_on_time'].append(session.trials.target_on_time + recording_date)
    trial_data['go_cue_time'].append(session.trials.go_cue_time + recording_date)
    trial_data['move_begins_time'].append(session.trials.move_begins_time + recording_date)
    trial_data['move_ends_time'].append(session.trials.move_ends_time + recording_date)
    trial_data['maze_condition'].append(session.trials.maze_condition)
    trial_data['barriers'].append(session.trials.maze_num_barriers)
    trial_data['targets'].append(session.trials.maze_num_targets)
    trial_data['hit_position_x'].append([pos[0] for pos in session.trials.hit_target_position])
    trial_data['hit_position_y'].append([pos[1] for pos in session.trials.hit_target_position])
    trial_data['hit_position_angle'].append([np.degrees(np.arctan2(pos[1], pos[0])) for pos in session.trials.hit_target_position])

# Concatenate all arrays into final datasets
combined_hand_data = {}
for key, arrays in hand_data.items():
    combined_hand_data[key] = np.concatenate(arrays)

combined_eye_data = {}
for key, arrays in eye_data.items():
    combined_eye_data[key] = np.concatenate(arrays)

combined_trial_data = {}
for key, arrays in trial_data.items():
    combined_trial_data[key] = np.concatenate(arrays)

# Create final DataFrames
combined_hand_df = pd.DataFrame(combined_hand_data).set_index('timestamp')
combined_eye_df = pd.DataFrame(combined_eye_data).set_index('timestamp')
combined_trials_df = pd.DataFrame(combined_trial_data)

# Create unified timestamp index from all data sources.
# Pool the four trial-event timestamp columns into one flat array.
all_event_ts = np.concatenate([
    combined_trials_df['target_on_time'].values,
    combined_trials_df['go_cue_time'].values,
    combined_trials_df['move_begins_time'].values,
    combined_trials_df['move_ends_time'].values,
])

# Get unique timestamps across all data.
# np.unique returns the values sorted ascending and de-duplicated, so the
# metadata index built from it is sorted (downstream searchsorted/digitize
# calls rely on that ordering).
all_ts = np.unique(np.concatenate([
    combined_hand_df.index.values,
    combined_eye_df.index.values,
    all_event_ts
]))

# Create master dataframe with unified timestamp index
metadata = pd.DataFrame(index=all_ts)
metadata.index.name = 'timestamp'

# Merge hand and eye data (left joins on the timestamp index: timestamps a
# source has no sample for get NaN in that source's columns). Eye columns whose
# names clash with hand columns get an '_eye' suffix.
metadata = metadata.join(combined_hand_df, how='left')
metadata = metadata.join(combined_eye_df, how='left', rsuffix='_eye')

# Add event column - mark timestamps that correspond to trial events.
# Matching is by exact timestamp equality (np.isin). If two event types share
# a timestamp, the label assigned later in event_map's order overwrites the
# earlier one.
event_map = {
    'target_on_time': 'target_on',
    'go_cue_time': 'go_cue',
    'move_begins_time': 'move_begins',
    'move_ends_time': 'move_ends',
}
event_col = pd.Series(index=metadata.index, dtype="object")
for col, label in event_map.items():
    event_times = combined_trials_df[col].values
    mask = np.isin(metadata.index.values, event_times)
    event_col.iloc[mask] = label
metadata['event'] = event_col

# Add trial_idx column - assign each timestamp to its trial.
# Use binary search to efficiently find which trial each timestamp belongs to.
# trial_idx is float64 so that NaN can mark timestamps outside every trial.
trial_idx_series = pd.Series(index=metadata.index, dtype='float64')

# Sort trials by start time for binary search. trial_sort_idx maps sorted
# position -> row position in combined_trials_df.
trial_sort_idx = np.argsort(combined_trials_df['start'].values)
starts = combined_trials_df['start'].values[trial_sort_idx]
ends = combined_trials_df['end'].values[trial_sort_idx]

# Find potential trial for each timestamp: searchsorted(side='right') - 1 is the
# last trial whose start is <= the timestamp (-1 if the timestamp precedes
# every start).
# NOTE(review): only that single latest-starting trial is considered, which
# assumes trials do not overlap - confirm.
timestamps = metadata.index.values
start_positions = np.searchsorted(starts, timestamps, side='right') - 1

# Check which timestamps are within valid trial intervals.
# (The upper-bound check can never fail because searchsorted(...) - 1 is at
# most len(starts) - 1; it is kept as a guard.)
valid_mask = (start_positions >= 0) & (start_positions < len(starts))
valid_positions = start_positions[valid_mask]
valid_timestamps = timestamps[valid_mask]

# Verify timestamps are before trial end times (end is inclusive)
end_mask = valid_timestamps <= ends[valid_positions]
final_valid_mask = np.zeros(len(timestamps), dtype=bool)
final_valid_mask[valid_mask] = end_mask

# Assign trial indices to timestamps: translate sorted positions back to row
# positions in combined_trials_df; everything else stays NaN.
trial_indices = np.full(len(timestamps), np.nan)
trial_indices[final_valid_mask] = trial_sort_idx[valid_positions[end_mask]]
trial_idx_series.iloc[:] = trial_indices

metadata['trial_idx'] = trial_idx_series

# Map trial properties onto every timestamp through its trial index.
# trial_idx is float (NaN outside trials); the nullable Int64 cast turns it into
# integer labels for combined_trials_df (a default 0..n-1 index), and timestamps
# outside any trial stay missing after the mapping.
trial_idx_int = metadata['trial_idx'].astype('Int64')
for trial_col in ['maze_condition', 'barriers', 'targets',
                  'hit_position_x', 'hit_position_y', 'hit_position_angle']:
    metadata[trial_col] = trial_idx_int.map(combined_trials_df[trial_col])

# Add movement_angle: heading (degrees, in [-180, 180]) of the hand between
# consecutive *hand* samples. The metadata index also contains eye-only and
# event-only timestamps whose pos_x/pos_y are NaN, so those rows are removed
# before differencing. Differencing the full index would make the angle NaN
# next to every such row. Rows without a hand position get NaN on re-alignment.
# NOTE(review): the first sample of each later session is still differenced
# against the last sample of the previous session - confirm whether that
# single jump matters.
hand_pos = metadata[['pos_x', 'pos_y']].dropna()
hand_pos_delta = hand_pos.diff()
metadata['movement_angle'] = np.degrees(
    np.arctan2(hand_pos_delta['pos_y'], hand_pos_delta['pos_x'])
)

# Speed and acceleration magnitudes from their x/y components
metadata['vel_magnitude'] = np.hypot(metadata['vel_x'].values, metadata['vel_y'].values)
metadata['accel_magnitude'] = np.hypot(metadata['acc_x'].values, metadata['acc_y'].values)

# Show result
print("Metadata:")
display(metadata)

In [None]:
# Create metadata_binned with consistent timestamps: one row per spike-count bin
ts = spk_cts_df.index.values
metadata_binned = pd.DataFrame(index=pd.Index(ts, name='timestamp'))

# Assign each metadata row to a bin index: digitize(...) - 1 is the j with
# ts[j] <= t < ts[j + 1] (assumes ts holds bin start times - confirm).
# Rows before the first bin are clipped into bin 0, rows after the last into
# the last bin.
bin_ids = np.digitize(metadata.index.values, ts) - 1
bin_ids = np.clip(bin_ids, 0, len(ts) - 1)

# Handle event aggregation: each bin holds at most one event label. When
# several events fall into the same bin, the extra ones spill over into the
# following free bins instead of overwriting each other.
event_values = metadata['event'].values
event_mask = pd.notna(event_values)
event_agg = np.full(len(ts), None, dtype=object)

if event_mask.any():
    event_bin_ids = bin_ids[event_mask]
    valid_events = event_values[event_mask].astype(str)

    # metadata's index is sorted, so event_bin_ids is non-decreasing. Event i
    # goes to max(its own bin, target of event i-1 + 1). Its closed form is
    #   target_i = i + max_{j <= i}(bin_j - j),
    # which equals bin_i whenever no earlier event is backed up into it.
    event_indices = np.arange(len(valid_events))
    target_bins = event_indices + np.maximum.accumulate(event_bin_ids - event_indices)

    # Only assign events that fall within valid bin range (spill-over past the
    # last bin is dropped)
    valid_targets = target_bins < len(ts)
    event_agg[target_bins[valid_targets]] = valid_events[valid_targets]

# Efficient nearest neighbor reindexing using searchsorted: for every spike bin
# timestamp in ts, find the metadata row with the closest timestamp.
# searchsorted requires metadata_timestamps to be sorted; metadata's index is
# built with np.unique, which sorts it.
metadata_timestamps = metadata.index.values
ts_positions = np.searchsorted(metadata_timestamps, ts, side='left')

# Handle edge cases and find true nearest neighbors: bins after the last
# metadata timestamp are clipped onto the last row.
ts_positions = np.clip(ts_positions, 0, len(metadata_timestamps) - 1)

# For positions not at the start, check if the previous position is closer
mask = ts_positions > 0
left_positions = ts_positions.copy()
left_positions[mask] = ts_positions[mask] - 1

# Calculate distances to determine nearest
left_distances = np.abs(ts - metadata_timestamps[left_positions])
right_distances = np.abs(ts - metadata_timestamps[ts_positions])

# Choose the nearest position (strict '<': on an exact tie the later row wins)
final_positions = np.where(left_distances < right_distances, left_positions, ts_positions)

# Copy all columns from nearest metadata.
# NOTE(review): the nearest row may be an eye-only or event-only timestamp, in
# which case its hand columns are NaN - confirm this is acceptable.
for col in metadata.columns:
    if col != 'event':
        metadata_binned[col] = metadata[col].iloc[final_positions].values

# Insert distributed events (computed by bin above rather than by nearest
# neighbour, so events are not lost or duplicated)
metadata_binned['event'] = event_agg

# Result: metadata for every spike-count bin in ts, one row per bin
print("Metadata binned:")
display(metadata_binned)

In [None]:
def plot_hand_trajectories_by_maze_condition(metadata_binned, max_conditions=None, figsize_per_plot=(3, 3),
                                             axis_lim=150):
    """
    Plot hand trajectories grouped by maze condition in a grid of subplots.

    Each subplot holds one maze condition. Every trial with more than one
    sample is drawn as a line through its (pos_x, pos_y) samples in timestamp
    order, with a green circle at the first sample and a red cross at the
    last. The grid has at most 6 columns.

    Parameters:
    metadata_binned (pd.DataFrame): timestamp-indexed DataFrame with 'pos_x',
        'pos_y', 'maze_condition' and 'trial_idx' columns; rows missing any of
        these are ignored
    max_conditions (int): Maximum number of conditions to plot, taken in
        sorted order (None for all)
    figsize_per_plot (tuple): (width, height) of each individual subplot
    axis_lim (float): every subplot spans [-axis_lim, axis_lim] on both axes

    Returns None. Prints a message and draws nothing when there is no valid
    data or when no condition is selected (e.g. max_conditions=0).
    """
    required_cols = ['pos_x', 'pos_y', 'maze_condition', 'trial_idx']
    valid_data = metadata_binned.dropna(subset=required_cols)

    if len(valid_data) == 0:
        print("No valid data found for plotting trajectories")
        return

    all_conditions = sorted(valid_data['maze_condition'].unique())
    maze_conditions = all_conditions
    if max_conditions is not None:
        maze_conditions = all_conditions[:max_conditions]
        # Only announce truncation when some conditions were actually left out
        if len(maze_conditions) < len(all_conditions):
            print(f"Showing first {len(maze_conditions)} of {len(all_conditions)} conditions")

    n_conditions = len(maze_conditions)
    if n_conditions == 0:
        # Guard against a zero-column grid (division by zero below)
        print("No maze conditions selected for plotting")
        return

    # Grid dimensions: roughly square, capped at 6 columns
    cols = min(6, int(np.ceil(np.sqrt(n_conditions))))
    rows = int(np.ceil(n_conditions / cols))

    # squeeze=False always returns a 2-D array of axes, so a single ravel()
    # covers the 1x1, 1xN and MxN layouts alike
    fig, axes = plt.subplots(
        rows, cols,
        figsize=(cols * figsize_per_plot[0], rows * figsize_per_plot[1]),
        squeeze=False,
    )
    axes = axes.ravel()

    # Split the data by condition in one pass instead of one mask per condition
    condition_data = {condition: group for condition, group in valid_data.groupby('maze_condition')}

    # Color palette (Set3 for better distinction between neighbouring trials)
    colors = plt.cm.Set3(np.linspace(0, 1, 12))

    for ax, condition in zip(axes, maze_conditions):
        trial_count = 0

        for _, trial_data in condition_data[condition].groupby('trial_idx'):
            if len(trial_data) <= 1:  # a single point is not a trajectory
                continue

            # Sort by timestamp so the line follows the movement in time
            trial_data = trial_data.sort_index()
            xs = trial_data['pos_x'].values
            ys = trial_data['pos_y'].values
            color = colors[trial_count % len(colors)]

            # One polyline per trial (much faster than individual segments)
            ax.plot(xs, ys, alpha=0.6, linewidth=1, color=color)

            # Start (green circle) and end (red cross) markers
            ax.scatter(xs[0], ys[0], color='green', marker='o', s=40, alpha=0.9,
                       edgecolor='darkgreen', linewidth=1.5, zorder=5)
            ax.scatter(xs[-1], ys[-1], color='red', marker='X', s=50, alpha=0.9,
                       edgecolor='darkred', linewidth=1.5, zorder=5)

            trial_count += 1

        ax.set_title(f'Condition {condition}\n({trial_count} trials)', fontsize=10, pad=10)
        ax.set_xlabel('X Position', fontsize=8)
        ax.set_ylabel('Y Position', fontsize=8)
        ax.set_xlim(-axis_lim, axis_lim)
        ax.set_ylim(-axis_lim, axis_lim)
        ax.tick_params(axis='both', which='major', labelsize=7)
        ax.grid(True, alpha=0.3, linewidth=0.5)
        ax.set_aspect('equal')

    # Hide unused subplots in the last grid row
    for ax in axes[n_conditions:]:
        ax.set_visible(False)

    fig.suptitle('Hand Trajectories by Maze Condition', fontsize=14, fontweight='bold', y=0.95)

    # A single shared legend for the start/end markers
    legend_elements = [
        plt.Line2D([0], [0], marker='o', color='green', linestyle='None',
                   markersize=8, label='Start', markerfacecolor='green',
                   markeredgecolor='darkgreen', markeredgewidth=1.5),
        plt.Line2D([0], [0], marker='X', color='red', linestyle='None',
                   markersize=10, label='End', markerfacecolor='red',
                   markeredgecolor='darkred', markeredgewidth=1.5)
    ]
    fig.legend(handles=legend_elements, loc='upper right', bbox_to_anchor=(0.99, 0.93), fontsize=10)

    plt.tight_layout()
    plt.subplots_adjust(top=0.9)  # Make room for suptitle
    plt.show()

# Function call - plot up to 108 maze conditions (pass a smaller max_conditions for a quicker overview)
plot_hand_trajectories_by_maze_condition(metadata_binned, max_conditions=108)

## Interpret features

In [None]:
"""Set vars for saving / loading feature act data!"""

load_acts = True  # if True, will load saved feature activations
save_acts = False
sae_datafile = "sae_0.csv"  # set name of data file to load (or save to)

# Build a tag such as "20200101_20200102" from each session's recording date.
# NOTE(review): datetime.fromtimestamp interprets the timestamp in the local
# timezone of the machine running this, so the tag (and thus the save/load
# path) could differ across machines - confirm recording_date's convention.
session_dates = []
for session in sessions:
    session_date = datetime.fromtimestamp(session.session.recording_date)
    session_date = session_date.strftime("%Y%m%d")
    session_dates.append(session_date)
session_dates_str = "_".join(session_dates)

# All SAE feature artefacts for this subject/session set share one directory
sae_features_dir = data_path / f"{subject_name}_{session_dates_str}" / "sae_features"
acts_df_save_path = sae_features_dir / sae_datafile
acts_df_save_path.parent.mkdir(parents=True, exist_ok=True)
keep_mask_save_path = sae_features_dir / "keep_mask.npy"

In [None]:
"""Load feature activations data."""

if load_acts:
    acts_df = pd.read_csv(acts_df_save_path)
    # Number of examples = highest example index + 1. max() makes this
    # independent of the CSV's row order, unlike reading the last row.
    n_recon_examples = int(acts_df["example_idx"].max()) + 1

In [None]:
"""Create dfs of feature activations."""

if not load_acts:

    # Convert tensor to numpy array for easier handling.
    # NOTE(review): columns are read positionally as
    # (example_idx, instance_idx, feature_idx, activation_value) - confirm this
    # matches how topk_acts_4d is assembled.
    acts_array = asnumpy(topk_acts_4d)

    # Create DataFrame with the data
    acts_df = pd.DataFrame({
        "example_idx": acts_array[:, 0],
        "instance_idx": acts_array[:, 1],
        "feature_idx": acts_array[:, 2],
        "activation_value": acts_array[:, 3]
    })

    # The index columns share one array with the activation values, so they
    # come out with its (non-integer) dtype; cast them back to ints.
    acts_df = acts_df.astype({"example_idx": int, "instance_idx": int, "feature_idx": int})

    # Number of examples = highest example index + 1 (independent of row order)
    n_recon_examples = int(acts_df["example_idx"].max()) + 1

In [None]:
"""Create df with info per feature."""

# Per-(instance, feature) activation statistics
features_df = acts_df.groupby(["instance_idx", "feature_idx"]).agg(
    activation_mean=("activation_value", "mean"),
    activation_std=("activation_value", "std"),
    activation_count=("activation_value", "count")
).reset_index()
features_df["act_mean_over_std"] = features_df["activation_mean"] / features_df["activation_std"]
features_df["activation_frac"] = features_df["activation_count"] / n_recon_examples
features_df = features_df.drop(columns=["activation_count"])
# Features that fired only once have an undefined (NaN, ddof=1) std and are dropped
features_df = features_df.dropna().reset_index(drop=True)

if not load_acts:
    # Keep only the (instance, feature) pairs that survived in features_df.
    # Matching on feature_idx alone would keep a feature dropped in one instance
    # whenever the same feature index survived in another instance.
    kept_pairs = pd.MultiIndex.from_frame(features_df[["instance_idx", "feature_idx"]])
    acts_pairs = pd.MultiIndex.from_frame(acts_df[["instance_idx", "feature_idx"]])
    acts_df = acts_df[acts_pairs.isin(kept_pairs)].reset_index(drop=True)

if save_acts:
    acts_df.to_csv(acts_df_save_path, index=False)
    np.save(keep_mask_save_path, keep_mask)

In [None]:
# Show the per-activation table followed by the per-feature summary
for frame in (acts_df, features_df):
    display(frame)

In [None]:
"""Compare features from ("broad" or "general") and ("specific" or "nested") groups"""

# Group boundaries are the first two keys of dsae_topk_map, used as half-open
# ranges: general = [0, k0), intermediate = [k0, k1), specific = [k1, ...).
# Using >= on the lower bounds means the features sitting exactly at k0 and k1
# are counted in a group instead of being excluded from all of them.
# NOTE(review): this treats the keys as exclusive prefix sizes (consistent with
# the `< last_feat_idx_general` test for the general group) - confirm against
# how dsae_topk_map is built.
last_feat_idx_general = list(dsae_topk_map.keys())[0]
first_feat_idx_specific = list(dsae_topk_map.keys())[1]

feat_idx_col = features_df["feature_idx"]
feature_groups = {
    "general": feat_idx_col < last_feat_idx_general,
    "intermediate": (feat_idx_col >= last_feat_idx_general) & (feat_idx_col < first_feat_idx_specific),
    "specific": feat_idx_col >= first_feat_idx_specific,
}
for group_name, group_mask in feature_groups.items():
    print(f"{group_name} features:")
    print(features_df.loc[group_mask, "activation_frac"].describe())
    print()

### Hunt for features

In [None]:
"""Automatically map features to metadata"""

def analyze_discrete_variable(
    acts_df: pd.DataFrame,
    metadata_binned: pd.DataFrame,
    variable: str,
    min_activation_frac: float
) -> List[Dict]:
    """
    Score how selectively SAE features fire for each value of a discrete variable.

    For every non-null value of ``metadata_binned[variable]``, the examples
    (row positions of ``metadata_binned``, matched against
    ``acts_df["example_idx"]``) holding that value form the "event" set.
    Features whose firing fraction inside the event set is at least
    ``min_activation_frac`` are kept. For each of them the firing fraction over
    all other examples is computed; a feature that never fires outside the
    event gets 0.0, so fully selective features are still reported.

    Returns one dict per (value, instance_idx, feature_idx) with keys:
    variable, variable_type ('discrete'), value, instance_idx, feature_idx,
    activation_ratio (during / (outside + 1e-9)), activation_frac_during,
    activation_frac_outside and rate_proportion (during / (during + outside)).
    A value whose analysis raises is reported with a print and skipped.
    """
    group_keys = ["instance_idx", "feature_idx"]
    column = metadata_binned[variable]
    n_total = len(metadata_binned)
    found = []

    for value in column.dropna().unique():
        try:
            in_event = np.flatnonzero(column == value)
            if len(in_event) == 0:
                continue
            n_during = len(in_event)

            is_event_row = acts_df["example_idx"].isin(in_event)
            during_acts = acts_df[is_event_row]
            if len(during_acts) == 0:
                continue

            during_stats = during_acts.groupby(group_keys).agg(
                activation_count=("activation_value", "count")
            ).reset_index()
            during_stats["activation_frac_event"] = during_stats["activation_count"] / n_during

            candidates = during_stats[during_stats["activation_frac_event"] >= min_activation_frac]
            if candidates.empty:
                continue

            # Activations of the candidate features outside the event set
            outside_acts = acts_df[~is_event_row].merge(
                candidates[group_keys], on=group_keys, how="inner"
            )

            if outside_acts.empty:
                # Every candidate is perfectly selective for this value
                scored = candidates.copy()
                scored["activation_frac_non_event"] = 0.0
            else:
                outside_stats = outside_acts.groupby(group_keys).agg(
                    activation_count=("activation_value", "count")
                ).reset_index()
                outside_stats["activation_frac_non_event"] = (
                    outside_stats["activation_count"] / (n_total - n_during)
                )
                scored = candidates.merge(outside_stats, on=group_keys, how="left")
                scored["activation_frac_non_event"] = scored["activation_frac_non_event"].fillna(0.0)

            during = scored["activation_frac_event"]
            outside = scored["activation_frac_non_event"]
            scored["activation_ratio"] = during / (outside + 1e-9)
            scored["rate_proportion"] = during / (during + outside)

            found.extend(
                {
                    'variable': variable, 'variable_type': 'discrete', 'value': value,
                    'instance_idx': row['instance_idx'], 'feature_idx': row['feature_idx'],
                    'activation_ratio': row['activation_ratio'],
                    'activation_frac_during': row['activation_frac_event'],
                    'activation_frac_outside': row['activation_frac_non_event'],
                    'rate_proportion': row['rate_proportion'],
                }
                for _, row in scored.iterrows()
            )
        except Exception as e:
            print(f"Could not analyze {variable}={value}: {e}")
    return found

def analyze_continuous_variable(
    acts_df: pd.DataFrame,
    metadata_binned: pd.DataFrame,
    variable: str,
    n_bins: int,
    min_activation_frac: float
) -> List[Dict]:
    """
    Bin a continuous variable, then score features per bin as a discrete variable.

    'movement_angle' is cut into ``n_bins`` equal-width bins over [-180, 180]
    with string labels. Any other variable is cut into quantile bins
    (duplicate edges dropped). The binned values are written into
    ``metadata_binned`` as a new column named ``f"{variable}_binned"``, which
    is then passed to ``analyze_discrete_variable``. In the returned records,
    'variable' is set back to the original name and 'variable_type' is
    'continuous'. Returns [] when the variable has no non-null values.
    """
    print(f"  Binning '{variable}' into {n_bins} bins...")
    binned_col_name = f"{variable}_binned"

    values = metadata_binned[variable].dropna()
    if values.empty:
        return []

    if variable != 'movement_angle':
        binned = pd.qcut(values, q=n_bins, labels=None, duplicates='drop')
    else:
        edges = np.linspace(-180, 180, n_bins + 1)
        bin_labels = [f"({lo:.0f}, {hi:.0f}]" for lo, hi in zip(edges[:-1], edges[1:])]
        binned = pd.cut(values, bins=edges, labels=bin_labels, include_lowest=True)
    # Side effect: the binned column is added to the caller's DataFrame
    metadata_binned[binned_col_name] = binned

    records = analyze_discrete_variable(acts_df, metadata_binned, binned_col_name, min_activation_frac)
    for record in records:
        record.update(variable=variable, variable_type='continuous')
    return records

def map_features_to_metadata(
    acts_df: pd.DataFrame,
    metadata_binned: pd.DataFrame,
    discrete_vars: List[str] = None,
    continuous_vars: List[str] = None,
    exclude_columns: List[str] = None,
    min_activation_frac: float = 0.1,
    n_bins_continuous: int = 10,
    top_n_features: int = 3
) -> pd.DataFrame:
    """
    Map SAE features to metadata conditions and keep the top features per condition.

    Each column of ``metadata_binned`` that is not excluded (default
    exclusions: 'trial_idx', 'session') and is listed in ``discrete_vars`` or
    ``continuous_vars`` is analysed with the matching helper. All records are
    collected into one DataFrame, with 'value' converted to str. For every
    (variable, value, instance_idx), the ``top_n_features`` rows with the
    highest activation_ratio are kept. The result is sorted by variable_type,
    variable, value and instance_idx, then by descending activation_ratio.
    An empty DataFrame is returned when nothing qualifies.
    """
    discrete_vars = [] if discrete_vars is None else discrete_vars
    continuous_vars = [] if continuous_vars is None else continuous_vars
    exclude_columns = ['trial_idx', 'session'] if exclude_columns is None else exclude_columns

    all_results = []
    print("🚀 Starting automated feature-to-metadata mapping...")

    for variable in metadata_binned.columns:
        if variable in exclude_columns:
            continue
        print(f"\nAnalyzing variable: {variable}")

        if variable in discrete_vars:
            var_kind = "discrete"
        elif variable in continuous_vars:
            var_kind = "continuous"
        else:
            print("  Skipping (not in discrete_vars or continuous_vars list)")
            continue

        print(f"  Treating as: {var_kind}")
        if var_kind == "discrete":
            results = analyze_discrete_variable(acts_df, metadata_binned, variable, min_activation_frac)
        else:
            results = analyze_continuous_variable(
                acts_df, metadata_binned, variable,
                n_bins=n_bins_continuous, min_activation_frac=min_activation_frac,
            )
        all_results.extend(results)
        print(f"  Found {len(results)} potential associations.")

    if not all_results:
        print("\nNo associations found meeting the minimum activation fraction!")
        return pd.DataFrame()

    results_df = pd.DataFrame(all_results)
    results_df['value'] = results_df['value'].astype(str)

    print(f"\nRanking features and selecting top {top_n_features} for each condition...")
    by_ratio = results_df.sort_values('activation_ratio', ascending=False)
    top_df = by_ratio.groupby(['variable', 'value', 'instance_idx']).head(top_n_features)

    final_df = top_df.sort_values(
        by=['variable_type', 'variable', 'value', 'instance_idx', 'activation_ratio'],
        ascending=[True, True, True, True, False],
    ).reset_index(drop=True)

    kind_col = final_df['variable_type']
    discrete_count = len(final_df[kind_col == 'discrete'])
    continuous_count = len(final_df[kind_col == 'continuous'])

    print(f"\n✅ Found {discrete_count} top discrete associations.")
    print(f"✅ Found {continuous_count} top continuous associations.")
    print(f"✅ Total: {len(final_df)} associations returned in single DataFrame.")

    return final_df

# Metadata columns to scan, split by how they are analysed
discrete_vars = [
    'event', 'maze_condition', 'barriers', 'targets',
    'hit_position_x', 'hit_position_y', 'hit_position_angle',
]
continuous_vars = ['vel_magnitude', 'accel_magnitude', 'movement_angle']

mapping_options = dict(
    min_activation_frac=0.5,  # feature must fire in >= 50% of a condition's bins
    n_bins_continuous=12,     # bins per continuous variable
    top_n_features=3,         # kept per (variable, value, instance)
)
results = map_features_to_metadata(
    acts_df, metadata_binned,
    discrete_vars=discrete_vars,
    continuous_vars=continuous_vars,
    **mapping_options,
)
display(results)

# # Optional filtering (ratio > 2.0, proportion > 0.5)
# results = results[
#     (results['activation_ratio'] > 2.0) & 
#     (results['rate_proportion'] > 0.5)
# ].reset_index(drop=True)
# display(results)

In [None]:
"""Calculate z-scores for spike counts across neurons."""
# Per-neuron (column-wise) mean and standard deviation
neuron_means = spk_cts_df.mean(axis=0)
neuron_stds = spk_cts_df.std(axis=0)

# Standardise every column. A neuron with zero std yields 0/0 = NaN or +/-inf;
# infinities are turned into NaN and every NaN is then reported as 0.0.
spk_z_scores_df = (
    (spk_cts_df - neuron_means)
    .div(neuron_stds, axis=1)
    .replace([np.inf, -np.inf], np.nan)
    .fillna(0.0)
)

display(spk_z_scores_df)


In [None]:
"""Visualisation functions for event-feature associations"""
def _as_position_tuple(pos):
    """Normalise a position given as tuple/list/array into a hashable tuple of Python scalars."""
    return tuple(np.asarray(pos).ravel().tolist())


def create_canonical_timeline(combined_trials_df, maze_conditions=None, hit_target_positions=None):
    """
    Create a canonical timeline based on average event durations from the data.

    The canonical time of each event is the cumulative sum of the mean
    durations between consecutive events in
    start -> target_on_time -> go_cue_time -> move_begins_time -> move_ends_time -> end,
    averaged over the selected trials that have all six events. The timeline
    starts at 0.0.

    Parameters:
    - combined_trials_df: trial metadata with absolute timestamps
    - maze_conditions: list of maze conditions to include (None for all)
    - hit_target_positions: list of (x, y) target positions to include (None
      for all). Matched exactly against a 'hit_target_position' column when
      the table has one, otherwise against the
      ('hit_position_x', 'hit_position_y') column pair.

    Returns:
    - canonical_events: dict with canonical event times
    - filtered_trials: dataframe with filtered trials
    Both are None when no trial survives the filtering.
    """
    filtered_trials = combined_trials_df.copy()

    if maze_conditions is not None:
        filtered_trials = filtered_trials[filtered_trials['maze_condition'].isin(maze_conditions)]
        print(f"Filtered to maze conditions: {maze_conditions}")

    if hit_target_positions is not None:
        hit_target_positions = [_as_position_tuple(pos) for pos in hit_target_positions]
        wanted_positions = set(hit_target_positions)
        if 'hit_target_position' in filtered_trials.columns:
            trial_positions = [_as_position_tuple(p) for p in filtered_trials['hit_target_position']]
        else:
            # The trial table stores target positions as separate x / y columns
            trial_positions = list(zip(filtered_trials['hit_position_x'],
                                       filtered_trials['hit_position_y']))
        keep = np.array([pos in wanted_positions for pos in trial_positions], dtype=bool)
        filtered_trials = filtered_trials[keep]
        print(f"Filtered to target positions: {hit_target_positions}")

    print(f"Using {len(filtered_trials)} trials (from {len(combined_trials_df)} total) to create canonical timeline")

    if len(filtered_trials) == 0:
        print("No trials match the filtering criteria!")
        return None, None

    events_sequence = ['start', 'target_on_time', 'go_cue_time', 'move_begins_time', 'move_ends_time', 'end']

    # Remove trials with missing events
    filtered_trials = filtered_trials.dropna(subset=events_sequence)

    if len(filtered_trials) == 0:
        print("No trials have all required events!")
        return None, None

    print(f"Computing durations from {len(filtered_trials)} complete trials")

    # Mean duration between each pair of consecutive events
    durations = {}
    for event1, event2 in zip(events_sequence[:-1], events_sequence[1:]):
        trial_durations = filtered_trials[event2] - filtered_trials[event1]
        avg_duration = trial_durations.mean()
        durations[f"{event1}_to_{event2}"] = avg_duration
        print(f"  {event1} to {event2}: {avg_duration:.3f}s (±{trial_durations.std():.3f})")

    # Build canonical timeline starting from 0
    canonical_events = {'start': 0.0}
    current_time = 0.0
    for event1, event2 in zip(events_sequence[:-1], events_sequence[1:]):
        current_time += durations[f"{event1}_to_{event2}"]
        canonical_events[event2] = current_time

    print(f"\nCanonical timeline: {canonical_events}")
    print(f"Total canonical duration: {current_time:.3f}s")

    return canonical_events, filtered_trials

def _nearest_axis_index(axis, times):
    """Index of the closest point of the ascending ``axis`` for each of ``times``.

    Ties go to the lower index, matching ``np.argmin(np.abs(axis - t))``.
    """
    times = np.asarray(times, dtype=float)
    right = np.clip(np.searchsorted(axis, times, side='left'), 0, len(axis) - 1)
    left = np.maximum(right - 1, 0)
    use_left = np.abs(times - axis[left]) <= np.abs(axis[right] - times)
    return np.where(use_left, left, right)


def _assign_last_wins(target, positions, values):
    """Write ``values`` into ``target`` at ``positions``; for repeated positions the last value wins."""
    if len(positions) == 0:
        return
    unique_positions, first_in_reversed = np.unique(positions[::-1], return_index=True)
    target[unique_positions] = values[len(positions) - 1 - first_in_reversed]


def warp_trials_to_canonical_timeline(combined_trials_df, acts_df, spk_z_scores_df, metadata_binned,
                                    instance_idx=0, feature_idx=None, 
                                    maze_conditions=None, hit_target_positions=None):
    """
    Warp filtered trials to a data-driven canonical timeline.

    Each trial's bins (metadata_binned rows with start <= timestamp <= end) are
    mapped to canonical time by piecewise-linear interpolation between the
    trial's six event times and the canonical event times. Each warped bin is
    then snapped to the nearest point of a canonical time axis with about
    0.05 s spacing. When several bins of a trial land on the same canonical
    point, the last one (in bin order) wins. Canonical points without data
    stay 0.

    Parameters:
    - combined_trials_df: trial metadata with absolute timestamps
    - acts_df: feature activation dataframe (example_idx is a row position in
      metadata_binned)
    - spk_z_scores_df: z-scored spike count dataframe, row-aligned with metadata_binned
    - metadata_binned: binned metadata
    - instance_idx: SAE instance to analyze
    - feature_idx: feature to analyze (if None, uses the feature with the most
      activation rows in this instance)
    - maze_conditions: list of maze conditions to include (None for all)
    - hit_target_positions: list of target positions to include (None for all)

    Returns:
    - warped_data: dict with warped activations and canonical timeline
      (None if no trial could be warped)
    - feature_idx: feature index analyzed
    - top_unit: column label of the unit with the highest mean z-score while
      the feature is active
    """
    canonical_events, filtered_trials = create_canonical_timeline(
        combined_trials_df, maze_conditions, hit_target_positions
    )

    if canonical_events is None:
        return None, None, None

    # If no feature specified, find most active feature for this instance
    if feature_idx is None:
        instance_acts = acts_df[acts_df['instance_idx'] == instance_idx]
        if len(instance_acts) == 0:
            print(f"No activations found for instance {instance_idx}")
            return None, None, None
        feature_counts = instance_acts['feature_idx'].value_counts()
        feature_idx = feature_counts.index[0]
        print(f"Using most active feature: {feature_idx} ({feature_counts.iloc[0]} activations)")

    # Get feature activations for this instance/feature
    feature_acts = acts_df[
        (acts_df['instance_idx'] == instance_idx) & 
        (acts_df['feature_idx'] == feature_idx)
    ].copy()

    # Find top unit for this feature based on z-scores when feature is active
    if len(feature_acts) > 0:
        feature_active_indices = feature_acts['example_idx'].values
        active_zscores = spk_z_scores_df.iloc[feature_active_indices]
        unit_mean_zscores = active_zscores.mean(axis=0)
        top_unit = unit_mean_zscores.idxmax()
        top_zscore = unit_mean_zscores[top_unit]
        print(f"Top co-active unit: {top_unit} (mean z-score when feature active: {top_zscore:.3f})")
    else:
        top_unit = spk_z_scores_df.mean().idxmax()
        print(f"No feature activations found, using unit with highest mean z-score: {top_unit}")

    print(f"Analyzing Feature {feature_idx}, Top Unit: {top_unit}")

    # Canonical time axis ending exactly at trial end (no buffer). At least one
    # point is kept so the nearest-point lookup is always defined.
    canonical_duration = max(canonical_events.values())
    n_canonical_points = max(int(canonical_duration / 0.05), 1)
    canonical_time_axis = np.linspace(0, canonical_duration, n_canonical_points)

    print(f"Canonical duration: {canonical_duration:.1f}s, {len(canonical_time_axis)} bins")

    required_events = ['start', 'target_on_time', 'go_cue_time', 'move_begins_time', 'move_ends_time', 'end']
    canonical_event_times = np.array([canonical_events[event] for event in required_events])

    # Loop-invariant lookups, hoisted out of the per-trial loop
    bin_times = metadata_binned.index.to_numpy()
    unit_zscores = spk_z_scores_df[top_unit].to_numpy()
    n_zscore_rows = len(spk_z_scores_df)
    act_bins_all = feature_acts['example_idx'].to_numpy().astype(np.int64)
    act_values_all = feature_acts['activation_value'].to_numpy(dtype=float)

    warped_feature_acts = []
    warped_unit_acts = []
    trial_info_list = []

    print("Warping trials...")
    valid_trials = 0

    for trial_idx, trial in filtered_trials.iterrows():
        trial_event_values = [trial[event] for event in required_events]
        if any(pd.isna(v) for v in trial_event_values):
            continue

        # Bins belonging to this trial (ending exactly at trial end)
        trial_mask = (metadata_binned.index >= trial['start']) & (metadata_binned.index <= trial['end'])
        if not trial_mask.any():
            continue
        trial_bin_indices = np.where(trial_mask)[0]

        # Piecewise-linear map from this trial's event times to canonical times
        original_event_times = np.array(trial_event_values, dtype=float)

        # Feature activations: place each active bin's value at its warped time
        trial_feature_acts = np.zeros(len(canonical_time_axis))
        in_trial = np.isin(act_bins_all, trial_bin_indices)
        act_bins = act_bins_all[in_trial]
        act_warped = np.interp(bin_times[act_bins], original_event_times, canonical_event_times)
        _assign_last_wins(trial_feature_acts,
                          _nearest_axis_index(canonical_time_axis, act_warped),
                          act_values_all[in_trial])

        # Unit z-scores for every trial bin that has a z-score row
        trial_unit_acts = np.zeros(len(canonical_time_axis))
        unit_bins = trial_bin_indices[trial_bin_indices < n_zscore_rows]
        unit_warped = np.interp(bin_times[unit_bins], original_event_times, canonical_event_times)
        _assign_last_wins(trial_unit_acts,
                          _nearest_axis_index(canonical_time_axis, unit_warped),
                          unit_zscores[unit_bins])

        warped_feature_acts.append(trial_feature_acts)
        warped_unit_acts.append(trial_unit_acts)

        trial_info = trial.copy()
        trial_info['trial_idx'] = trial_idx
        trial_info_list.append(trial_info)

        valid_trials += 1

    print(f"Successfully warped {valid_trials} trials")

    if valid_trials == 0:
        return None, feature_idx, top_unit

    warped_data = {
        'feature_activations': np.array(warped_feature_acts),
        'unit_activations': np.array(warped_unit_acts),
        'trial_info': pd.DataFrame(trial_info_list).reset_index(drop=True),
        'canonical_time_axis': canonical_time_axis,
        'canonical_events': canonical_events
    }

    return warped_data, feature_idx, top_unit

def _sem_across_trials(acts):
    """Return the per-timepoint standard error of the mean across trials (axis 0).

    Uses the sample standard deviation (ddof=1), the same convention as pandas'
    ``sem``. With fewer than two trials the SEM is undefined, so zeros are
    returned (an empty band) instead of NaNs.
    """
    n_trials = len(acts)
    if n_trials < 2:
        return np.zeros(np.shape(acts)[1:])
    return np.std(acts, axis=0, ddof=1) / np.sqrt(n_trials)


def _trial_hover_text(trial_info, trial_idx):
    """Build the 'Maze: ..., Target: ...' hover snippet for the trial at row position trial_idx."""
    parts = []
    if 'maze_condition' in trial_info.columns:
        parts.append(f"Maze: {trial_info.iloc[trial_idx]['maze_condition']}")
    if 'hit_target_position' in trial_info.columns:
        parts.append(f"Target: {trial_info.iloc[trial_idx]['hit_target_position']}")
    # Joining avoids a dangling leading ", " when only the target column exists
    return ", ".join(parts)


def plot_warped_trials(warped_data, instance_idx, feature_idx, top_unit,
                      highlight_trials=None, max_individual_trials=10,
                      smooth_window=None, show_event_regions=True):
    """
    Plot trial-warped feature and unit activations on canonical timeline

    The feature (left y-axis) and the top unit's z-score (right y-axis) are each
    drawn as a trial mean with a ±SEM band, plus a subset of individual trials.

    Parameters:
    - warped_data: output from warp_trials_to_canonical_timeline (dict with
      'feature_activations', 'unit_activations', 'trial_info',
      'canonical_time_axis', 'canonical_events'), or None
    - instance_idx: SAE instance index being analyzed (used in the title)
    - feature_idx: feature index being plotted
    - top_unit: top unit index
    - highlight_trials: list of trial row positions to highlight, or None for random selection;
      positions outside [0, n_trials) are ignored
    - max_individual_trials: maximum number of individual trials to show
    - smooth_window: optional smoothing window size in bins (applied along time)
    - show_event_regions: whether to show colored regions for each epoch

    Returns:
    - the plotly Figure (it is also shown), or None when warped_data is None
    """

    if warped_data is None:
        print("No warped data available")
        return None

    feature_acts = warped_data['feature_activations']
    unit_acts = warped_data['unit_activations']
    trial_info = warped_data['trial_info']
    time_axis = warped_data['canonical_time_axis']
    canonical_events = warped_data['canonical_events']
    n_trials = len(feature_acts)

    print(f"Plotting {n_trials} warped trials")

    # Apply smoothing if requested (rows are trials, so smooth along axis 1 = time)
    if smooth_window is not None and smooth_window > 1:
        feature_acts = uniform_filter1d(feature_acts, size=smooth_window, axis=1)
        unit_acts = uniform_filter1d(unit_acts, size=smooth_window, axis=1)

    # Across-trial statistics
    feature_mean = np.mean(feature_acts, axis=0)
    feature_sem = _sem_across_trials(feature_acts)
    unit_mean = np.mean(unit_acts, axis=0)
    unit_sem = _sem_across_trials(unit_acts)

    # Single figure with two y-axes
    fig = go.Figure()

    # Background regions between consecutive canonical events
    if show_event_regions:
        event_times = list(canonical_events.values())
        colors = ['rgba(255,200,200,0.2)', 'rgba(200,255,200,0.2)', 'rgba(200,200,255,0.2)',
                  'rgba(255,255,200,0.2)', 'rgba(255,200,255,0.2)']
        for i, (start, end) in enumerate(zip(event_times[:-1], event_times[1:])):
            fig.add_vrect(x0=start, x1=end, fillcolor=colors[i % len(colors)], layer="below", line_width=0)

    # Feature SEM band: the lower trace fills up to the preceding (upper) trace
    fig.add_trace(go.Scatter(x=time_axis, y=feature_mean + feature_sem, mode='lines', line=dict(width=0), showlegend=False, hoverinfo='skip', name='upper_bound'))
    fig.add_trace(go.Scatter(x=time_axis, y=feature_mean - feature_sem, mode='lines', line=dict(width=0), fill='tonexty', fillcolor='rgba(0,100,80,0.3)', name='Feature ±SEM', showlegend=True))

    # Mean feature line
    fig.add_trace(go.Scatter(x=time_axis, y=feature_mean, mode='lines', line=dict(color='darkgreen', width=4), name=f'Mean Feature {feature_idx}', yaxis='y'))

    # Choose which individual trials to overlay
    if highlight_trials is not None:
        # Keep only valid row positions; negatives would otherwise wrap to trials from the end
        trial_indices = [idx for idx in highlight_trials if 0 <= idx < n_trials][:max_individual_trials]
        trial_label = "Selected"
    else:
        if n_trials <= max_individual_trials:
            trial_indices = list(range(n_trials))
        else:
            trial_indices = np.random.choice(n_trials, max_individual_trials, replace=False)
        trial_label = "Random"

    hover_texts = [_trial_hover_text(trial_info, trial_idx) for trial_idx in trial_indices]

    # Individual feature trials (only the first gets a legend entry)
    for i, (trial_idx, trial_info_str) in enumerate(zip(trial_indices, hover_texts)):
        fig.add_trace(go.Scatter(x=time_axis, y=feature_acts[trial_idx], mode='lines', line=dict(color='rgba(0,150,100,0.5)', width=1), name=f'{trial_label} Feature Trials' if i == 0 else None, showlegend=i == 0, legendgroup='individual_feature_trials', hovertemplate=f'Trial {trial_idx}<br>{trial_info_str}<br>Time: %{{x}}<br>Feature: %{{y}}<extra></extra>', yaxis='y'))

    # Unit z-score SEM band (secondary y-axis)
    fig.add_trace(go.Scatter(x=time_axis, y=unit_mean + unit_sem, mode='lines', line=dict(width=0), showlegend=False, hoverinfo='skip', yaxis='y2'))
    fig.add_trace(go.Scatter(x=time_axis, y=unit_mean - unit_sem, mode='lines', line=dict(width=0), fill='tonexty', fillcolor='rgba(0,80,150,0.3)', name='Unit Z-score ±SEM', showlegend=True, yaxis='y2'))

    # Mean unit line
    fig.add_trace(go.Scatter(x=time_axis, y=unit_mean, mode='lines', line=dict(color='darkblue', width=4), name=f'Mean Unit {top_unit} Z-score', yaxis='y2'))

    # Individual unit trials (same trials as the feature overlay)
    for i, (trial_idx, trial_info_str) in enumerate(zip(trial_indices, hover_texts)):
        fig.add_trace(go.Scatter(x=time_axis, y=unit_acts[trial_idx], mode='lines', line=dict(color='rgba(100,100,255,0.5)', width=1), name=f'{trial_label} Unit Trials' if i == 0 else None, showlegend=i == 0, legendgroup='individual_unit_trials', hovertemplate=f'Trial {trial_idx}<br>{trial_info_str}<br>Time: %{{x}}<br>Unit Z-score: %{{y}}<extra></extra>', yaxis='y2'))

    # Event lines
    for event_name, event_time in canonical_events.items():
        fig.add_vline(x=event_time, line_dash="dash", line_color="red", line_width=2, annotation_text=event_name.replace('_', ' ').title(), annotation_position="top")

    # Dual y-axes layout
    fig.update_layout(
        title=f"Instance {instance_idx} Feature {feature_idx} & Top Unit {top_unit}",
        xaxis_title="Canonical Time (s)",
        yaxis=dict(title="Feature Activation", side="left", color="darkgreen"),
        yaxis2=dict(title="Unit Z-score", side="right", overlaying="y", color="darkblue"),
        height=600,
        width=1400,
        legend=dict(x=0.02, y=0.98, bgcolor="rgba(255,255,255,0.8)")
    )

    fig.show()

    return fig

def explore_trial_conditions(combined_trials_df):
    """Print how many trials exist per maze condition and per hit target position.

    Intended as a quick look before choosing the maze/target filters for the
    warping analysis.

    Parameters:
    - combined_trials_df: trial table with 'maze_condition' and 'hit_target_position' columns

    Returns:
    - (maze_counts, target_counts): value_counts Series; maze counts are ordered by
      condition label, target counts by descending frequency
    """
    print("\nAvailable Trial Conditions")

    # Each section: header text and how to tally the trials for it
    sections = (
        ("Maze Conditions", lambda df: df['maze_condition'].value_counts().sort_index()),
        ("Hit Target Positions", lambda df: df['hit_target_position'].value_counts()),
    )

    tallies = []
    for title, tally in sections:
        print(f"\n{title}:")
        counts = tally(combined_trials_df)
        for key, n_trials in counts.items():
            print(f"  {key}: {n_trials} trials")
        tallies.append(counts)

    print(f"\nTotal trials: {len(combined_trials_df)}")

    return tuple(tallies)

# Example usage: warp the (optionally filtered) trials onto the canonical timeline
# and plot one SAE feature together with its most co-active unit.
print("Trial Warping Analysis")

# Inspect which maze conditions / target positions exist before choosing filters
print("Exploring available trial conditions...")
maze_counts, target_counts = explore_trial_conditions(combined_trials_df)

# --- What to analyze ---
instance_to_analyze = 0
feature_to_analyze = 182

# --- Trial filters (None keeps every trial) ---
maze_conditions_to_include = None
target_positions_to_include = None

# --- Plot options ---
num_example_trials = 8
smooth_data = 3
show_epochs = True

print("\nRunning analysis with filters:")
print(f"  Maze conditions: {maze_conditions_to_include}")
print(f"  Target positions: {target_positions_to_include}")

warped_data, feature_idx, top_unit = warp_trials_to_canonical_timeline(
    combined_trials_df,
    acts_df,
    spk_z_scores_df,
    metadata_binned,
    instance_idx=instance_to_analyze,
    feature_idx=feature_to_analyze,
    maze_conditions=maze_conditions_to_include,
    hit_target_positions=target_positions_to_include,
)

if warped_data is None:
    print("Failed to warp trials - check your parameters and data.")
else:
    print("\nResults")
    print(f"Analyzed Feature: {feature_idx}")
    print(f"Top Co-active Unit: {top_unit}")

    plot_warped_trials(
        warped_data,
        instance_to_analyze,
        feature_idx,
        top_unit,
        highlight_trials=None,
        max_individual_trials=num_example_trials,
        smooth_window=smooth_data,
        show_event_regions=show_epochs,
    )

In [None]:
"""Interactive UI"""


# Helper to map a base variable and its type to the metadata_binned column name
def _bvar_name(var_name, var_type):
    return f"{var_name}_binned" if var_type == 'continuous' else var_name


def _build_warp_trials_ui():
    """Build the interactive trial-warping panel and return its root widget.

    Widgets and callbacks are kept local (captured by the callbacks' closures)
    rather than stored as notebook globals. Other cells in this notebook define
    globals with the same names (e.g. ``mode_radio``, ``preset_dropdown``, ``out``),
    which would otherwise silently re-point this panel's callbacks at the other
    panel's widgets.
    """
    # Mode selector: preset vs manual
    mode_radio = widgets.RadioButtons(
        options=[
            ('Preset (from results table)', 'preset'),
            ('Manual selection', 'manual'),
        ],
        value='preset',
        description='',
    )

    # Preset dropdown: label shows the full metrics, value is (instance, feature, binned var)
    preset_entries = []
    for _, r in results.iterrows():
        bvar = _bvar_name(r.variable, r.variable_type)
        if bvar not in metadata_binned.columns:
            continue
        label = (
            f"Inst:{int(r.instance_idx)} | "
            f"Feat:{int(r.feature_idx)} | "
            f"Var:{bvar} | "
            f"Val:{r['value']} | "
            f"FracDuring:{r.activation_frac_during:.3f} | "
            f"FracOutside:{r.activation_frac_outside:.3f} | "
            f"ActRatio:{r.activation_ratio:.3f} | "
            f"RateProp:{r.rate_proportion:.3f}"
        )
        preset_entries.append((label, (int(r.instance_idx), int(r.feature_idx), bvar)))

    preset_dropdown = widgets.Dropdown(
        options=preset_entries,
        description='Select Result:',
        layout=widgets.Layout(width='80%'),
    )
    preset_box = widgets.VBox([preset_dropdown])

    # Manual instance & feature selection; feature options follow the chosen instance
    instance_dropdown = widgets.Dropdown(
        options=sorted(acts_df['instance_idx'].unique()),
        description='Instance:',
    )
    feature_dropdown = widgets.Dropdown(options=[], description='Feature:')

    def _on_instance_change(change):
        inst = change['new']
        feature_dropdown.options = sorted(
            acts_df.loc[acts_df['instance_idx'] == inst, 'feature_idx'].unique()
        )

    instance_dropdown.observe(_on_instance_change, names='value')
    _on_instance_change({'new': instance_dropdown.value})

    manual_box = widgets.VBox([instance_dropdown, feature_dropdown])
    manual_box.layout.display = 'none'

    # Maze condition filter (None = no filter)
    maze_options = sorted(combined_trials_df['maze_condition'].dropna().unique())
    maze_dropdown = widgets.Dropdown(
        options=[None] + maze_options,
        description='Maze cond:',
    )

    # Target position filter. Labels are str(position); map each label back to the
    # original object instead of eval()-ing it: eval() raises NameError for
    # string-valued positions and would execute arbitrary text.
    target_lookup = {
        str(pos): pos
        for pos in combined_trials_df['hit_target_position'].dropna().unique()
    }
    target_dropdown = widgets.Dropdown(
        options=[None] + list(target_lookup),
        description='Target pos:',
    )

    # Toggle preset vs manual
    def _on_mode_change(change):
        if change['new'] == 'preset':
            preset_box.layout.display = ''
            manual_box.layout.display = 'none'
        else:
            preset_box.layout.display = 'none'
            manual_box.layout.display = ''

    mode_radio.observe(_on_mode_change, names='value')
    _on_mode_change({'new': mode_radio.value})

    # Generate button and output area
    generate_btn = widgets.Button(description='Generate Plot', button_style='info')
    out = widgets.Output()

    def _on_generate(_button):
        with out:
            clear_output()
            if mode_radio.value == 'preset':
                selection = preset_dropdown.value
                if selection is None:
                    print("No preset results available. Use manual selection.")
                    return
                inst, feat, _var = selection
            else:
                inst = instance_dropdown.value
                feat = feature_dropdown.value

            maze = [maze_dropdown.value] if maze_dropdown.value is not None else None
            tgt = target_dropdown.value
            hit_positions = [target_lookup[tgt]] if tgt is not None else None

            warped_data, used_feat, top_unit = warp_trials_to_canonical_timeline(
                combined_trials_df,
                acts_df,
                spk_z_scores_df,
                metadata_binned,
                instance_idx=inst,
                feature_idx=feat,
                maze_conditions=maze,
                hit_target_positions=hit_positions,
            )

            if warped_data is not None:
                plot_warped_trials(warped_data, inst, used_feat, top_unit)
            else:
                print("No data to display. Check your selections.")

    generate_btn.on_click(_on_generate)

    # Assemble the UI
    return widgets.VBox([
        widgets.HTML("<h2>Warp Trials Visualization</h2>"),
        mode_radio,
        preset_box,
        manual_box,
        maze_dropdown,
        target_dropdown,
        generate_btn,
        out,
    ])


display(_build_warp_trials_ui())

In [None]:
"""Visualisation functions for all feature associations"""


def _rate_proportions(metadata_binned, variable, conditions, active_example_set):
    """Return the rate proportion of feature activity for each condition value.

    For each condition, ``during`` is the fraction of bins (rows of
    ``metadata_binned``, by position) with ``metadata_binned[variable] == condition``
    whose position is in ``active_example_set``; ``outside`` is the same fraction
    over all remaining bins. The rate proportion is ``during / (during + outside)``,
    or 0 when both fractions are 0.
    """
    column = metadata_binned[variable]
    proportions = []
    for condition in conditions:
        in_condition = np.asarray(column == condition)
        during_idx = np.flatnonzero(in_condition)
        outside_idx = np.flatnonzero(~in_condition)

        frac_during = len(active_example_set.intersection(during_idx)) / len(during_idx) if len(during_idx) > 0 else 0
        frac_outside = len(active_example_set.intersection(outside_idx)) / len(outside_idx) if len(outside_idx) > 0 else 0

        total = frac_during + frac_outside
        proportions.append(frac_during / total if total > 0 else 0)
    return proportions


def plot_feature_tuning(
    acts_df: pd.DataFrame,
    spk_z_scores_df: pd.DataFrame,
    metadata_binned: pd.DataFrame,
    variable: str,
    instance_idx: int,
    feature_idx: int
):
    """Visualize SAE feature tuning to a metadata variable.

    Shows, per value of ``variable``, the mean (±SEM) feature activation (inactive
    bins counted as 0), the z-scores of the most and least co-active units, and the
    rate proportion of feature activity. Also shows a bar plot of every unit's mean
    z-score over the bins where the feature is active. Variables whose name contains
    'angle' are drawn on polar axes; others use line plots (interval-categorical
    data) or bar plots.

    ``example_idx`` in ``acts_df`` is used as a row position into both
    ``spk_z_scores_df`` and ``metadata_binned``.

    Returns None. It prints a message and returns early when there is nothing to plot.
    """
    feature_acts = acts_df[(acts_df['instance_idx'] == instance_idx) & (acts_df['feature_idx'] == feature_idx)]
    if feature_acts.empty:
        print("No feature activations found for this instance/feature.")
        return

    # Top/bottom co-active units: highest/lowest mean z-score over the active bins
    feature_active_indices = feature_acts['example_idx'].values
    active_zscores = spk_z_scores_df.iloc[feature_active_indices]
    neuron_mean_zscores = active_zscores.mean(axis=0)
    top_unit = neuron_mean_zscores.idxmax()
    bottom_unit = neuron_mean_zscores.idxmin()
    print(f"Top co-active unit: {top_unit} (z-score: {neuron_mean_zscores[top_unit]:.3f})")
    print(f"Bottom co-active unit: {bottom_unit} (z-score: {neuron_mean_zscores[bottom_unit]:.3f})")

    # Complete dataset with zeros for bins where the feature is inactive
    all_examples = pd.DataFrame({'example_idx': range(len(metadata_binned))})
    all_examples = all_examples.merge(feature_acts[['example_idx', 'activation_value']], on='example_idx', how='left').fillna(0)

    if all_examples.empty:
        print(f"⚠️ No examples found for Instance {instance_idx}, Feature {feature_idx}. Cannot generate plot.")
        return

    # Metadata and unit z-scores for all examples
    metadata_slice = metadata_binned[[variable]].iloc[all_examples['example_idx']]
    top_unit_slice = spk_z_scores_df[[top_unit]].iloc[all_examples['example_idx']]
    bottom_unit_slice = spk_z_scores_df[[bottom_unit]].iloc[all_examples['example_idx']]

    data_df = metadata_slice.reset_index(drop=True)
    data_df['activation_value'] = all_examples['activation_value'].reset_index(drop=True)
    data_df['top_unit_zscore'] = top_unit_slice[top_unit].reset_index(drop=True)
    data_df['bottom_unit_zscore'] = bottom_unit_slice[bottom_unit].reset_index(drop=True)
    data_df = data_df.dropna(subset=[variable])

    if data_df.empty:
        print(f"⚠️ No matching metadata found for feature bins. Cannot generate plot.")
        return

    # Interval data = a categorical column whose categories are intervals (binned
    # continuous variables). Uses IntervalDtype instead of the deprecated is_interval_dtype.
    values = data_df[variable]
    is_interval_data = (
        isinstance(values.dtype, pd.CategoricalDtype)
        and isinstance(values.cat.categories.dtype, pd.IntervalDtype)
    )

    # Summary statistics per variable value. observed=False is passed explicitly so
    # empty categories stay in the table regardless of the pandas version's default.
    stats_df = data_df.groupby(variable, observed=False).agg({
        'activation_value': ['mean', 'sem'],
        'top_unit_zscore': ['mean', 'sem'],
        'bottom_unit_zscore': ['mean', 'sem']
    }).reset_index()
    stats_df.columns = [variable, 'feature_mean', 'feature_sem', 'top_unit_mean', 'top_unit_sem', 'bottom_unit_mean', 'bottom_unit_sem']

    stats_df['rate_proportion'] = _rate_proportions(
        metadata_binned, variable, stats_df[variable], set(feature_acts['example_idx'])
    )

    # Per-unit mean/SEM z-score while the feature is active (bottom bar plot)
    zscore_stats = active_zscores.agg(['mean', 'sem']).T.reset_index()
    zscore_stats.columns = ['neuron', 'mean_zscore', 'sem_zscore']

    # Row-1 panels drawn as mean ± SEM: (column, stats prefix, band fill, line colour, band name, mean name)
    band_panels = [
        (1, 'feature', 'rgba(220,20,60,0.2)', 'crimson', 'Feature ±SEM', 'Feature Activation'),
        (2, 'top_unit', 'rgba(0,0,139,0.2)', 'darkblue', 'Top Unit ±SEM', 'Top Unit Z-score'),
        (3, 'bottom_unit', 'rgba(255,165,0,0.2)', 'orange', 'Bottom Unit ±SEM', 'Bottom Unit Z-score'),
    ]
    subplot_titles = ["Feature Activation", "Top Unit Z-score", "Bottom Unit Z-score", "Rate Proportion", "Mean Z-scores when Feature Active"]
    title = f"Instance {instance_idx} Feature {feature_idx} & Top Unit {top_unit} & Bottom Unit {bottom_unit}"

    if 'angle' in variable:  # Polar plot
        stats_df['theta'] = stats_df[variable].apply(lambda x: x.mid if isinstance(x, pd.Interval) else x)
        stats_df = stats_df.sort_values('theta')
        # Repeat the first point so each polar line closes its loop
        plot_df = pd.concat([stats_df, stats_df.head(1)], ignore_index=True)

        fig = make_subplots(
            rows=2, cols=4,
            specs=[[{"type": "polar"}, {"type": "polar"}, {"type": "polar"}, {"type": "polar"}],
                   [{"type": "xy", "colspan": 4}, None, None, None]],
            horizontal_spacing=0.1,
            vertical_spacing=0.15,
            subplot_titles=subplot_titles
        )

        for col, prefix, fill_color, line_color, band_name, mean_name in band_panels:
            mean, sem = plot_df[f'{prefix}_mean'], plot_df[f'{prefix}_sem']
            fig.add_trace(go.Scatterpolar(r=mean + sem, theta=plot_df['theta'], mode='lines', line=dict(width=0), showlegend=False), row=1, col=col)
            fig.add_trace(go.Scatterpolar(r=mean - sem, theta=plot_df['theta'], mode='lines', line=dict(width=0), fill='tonext', fillcolor=fill_color, name=band_name), row=1, col=col)
            fig.add_trace(go.Scatterpolar(r=mean, theta=plot_df['theta'], mode='lines+markers', line=dict(color=line_color, width=3), name=mean_name), row=1, col=col)

        fig.add_trace(go.Scatterpolar(r=plot_df['rate_proportion'], theta=plot_df['theta'], mode='lines+markers', line=dict(color='green', width=3), name='Rate Proportion'), row=1, col=4)

        # Angular tick labels
        if 'movement_angle' not in variable:
            tick_labels = [f"{int(round(theta))}°" for theta in stats_df['theta']]
        else:
            tick_labels = [f"{int(interval.mid)}°" if hasattr(interval, 'mid') else str(interval) for interval in stats_df[variable]]

        fig.update_layout(title=title, showlegend=True, height=800, width=1600, margin=dict(t=80, b=60, l=50, r=50))

        rotation = 195 if 'movement_angle' in variable else 0
        tickvals = stats_df['theta'].tolist() if 'movement_angle' in variable else [(angle % 360) for angle in stats_df['theta'].tolist()]
        # Feature activation (panel 1) and rate proportion (panel 4) are non-negative,
        # so anchor their radial axes at 0. Z-scores (panels 2, 3) can be negative and
        # would be clipped by a [0, max] range, so those axes autorange.
        nonnegative_panels = {1, 4}
        for i in range(1, 5):
            polar_key = 'polar' if i == 1 else f'polar{i}'
            radialaxis = dict(range=[0, None]) if i in nonnegative_panels else dict(autorange=True)
            fig.update_layout(**{polar_key: dict(angularaxis=dict(direction="counterclockwise", rotation=rotation, tickvals=tickvals, ticktext=tick_labels, tickfont=dict(size=10)), radialaxis=radialaxis)})

    else:  # Linear plot
        fig = make_subplots(rows=2, cols=4, specs=[[{}, {}, {}, {}], [{"colspan": 4}, None, None, None]], subplot_titles=subplot_titles, horizontal_spacing=0.1, vertical_spacing=0.3)

        # Sort BEFORE deriving labels so x labels and y values share the same row order
        if is_interval_data:
            stats_df = stats_df.sort_values(by=variable)
        x_axis_labels = stats_df[variable].astype(str)

        if is_interval_data:
            # Line plots with SEM bands for binned continuous data
            for col, prefix, fill_color, line_color, band_name, mean_name in band_panels:
                mean, sem = stats_df[f'{prefix}_mean'], stats_df[f'{prefix}_sem']
                fig.add_trace(go.Scatter(x=x_axis_labels, y=mean + sem, mode='lines', line_color='rgba(0,0,0,0)', showlegend=False), row=1, col=col)
                fig.add_trace(go.Scatter(x=x_axis_labels, y=mean - sem, mode='lines', line_color='rgba(0,0,0,0)', fill='tonexty', fillcolor=fill_color, name=band_name), row=1, col=col)
                fig.add_trace(go.Scatter(x=x_axis_labels, y=mean, mode='lines+markers', line_color=line_color, name=mean_name), row=1, col=col)

            fig.add_trace(go.Scatter(x=x_axis_labels, y=stats_df['rate_proportion'], mode='lines+markers', line=dict(color='green', width=3), name='Rate Proportion'), row=1, col=4)
        else:
            # Bar plots with SEM error bars for categorical data
            for col, prefix, _fill_color, line_color, _band_name, mean_name in band_panels:
                fig.add_trace(go.Bar(x=x_axis_labels, y=stats_df[f'{prefix}_mean'], error_y=dict(type='data', array=stats_df[f'{prefix}_sem']), marker_color=line_color, marker_line_width=0, opacity=0.7, name=mean_name), row=1, col=col)

            fig.add_trace(go.Bar(x=x_axis_labels, y=stats_df['rate_proportion'], marker_color='green', marker_line_width=0, opacity=0.7, name='Rate Proportion'), row=1, col=4)

        fig.update_layout(title=title, height=800, width=1600, showlegend=True, margin=dict(t=80, b=60, l=50, r=50))

        for col in range(1, 5):
            fig.update_xaxes(title_text=variable, tickangle=45, row=1, col=col)

        fig.update_yaxes(title_text="Feature Activation", color="crimson", range=[0, None], row=1, col=1)
        fig.update_yaxes(title_text="Top Unit Z-score", color="darkblue", row=1, col=2)
        fig.update_yaxes(title_text="Bottom Unit Z-score", color="orange", row=1, col=3)
        fig.update_yaxes(title_text="Rate Proportion", range=[0, None], color="green", row=1, col=4)

        fig.update_yaxes(rangemode='tozero', row=1, col=1)
        fig.update_yaxes(rangemode='tozero', row=1, col=4)

    # Bottom row (shared by both layouts): each unit's mean z-score while the feature is active
    fig.add_trace(go.Bar(x=zscore_stats['neuron'], y=zscore_stats['mean_zscore'], error_y=dict(type='data', array=zscore_stats['sem_zscore']), marker_color='purple', marker_line_width=0, opacity=0.7, name='Mean Z-score'), row=2, col=1)
    fig.update_xaxes(title_text="Neuron", row=2, col=1)
    fig.update_yaxes(title_text="Mean Z-score", row=2, col=1)

    fig.show()

def _show_and_plot_tuning(result_row, variable):
    """Display one results row, then plot that instance/feature's tuning to `variable`."""
    display(result_row)
    plot_feature_tuning(
        acts_df=acts_df,
        spk_z_scores_df=spk_z_scores_df,
        metadata_binned=metadata_binned,
        variable=variable,
        instance_idx=int(result_row['instance_idx']),
        feature_idx=int(result_row['feature_idx']),
    )


# For each example below, take the results row with the highest activation_ratio.

# Go-cue event feature
event_feature = results[results['value'] == 'go_cue'].sort_values('activation_ratio', ascending=False).iloc[0]
_show_and_plot_tuning(event_feature, 'event')

# Movement-angle feature (plotted against the binned angle column)
move_angle_feature = results[results['variable'] == 'movement_angle'].sort_values('activation_ratio', ascending=False).iloc[0]
_show_and_plot_tuning(move_angle_feature, 'movement_angle_binned')

# Velocity-magnitude feature (plotted against the binned magnitude column)
velocity_feature = results[results['variable'] == 'vel_magnitude'].sort_values('activation_ratio', ascending=False).iloc[0]
_show_and_plot_tuning(velocity_feature, 'vel_magnitude_binned')

In [None]:
"""Interactive UI"""

# Helper to map a base variable and its type to the metadata_binned column name
def _bvar_name(var_name, var_type):
    return f"{var_name}_binned" if var_type == 'continuous' else var_name

# Selector for preset vs manual mode
mode_radio = widgets.RadioButtons(
    options=[
        ('Preset (from results table)', 'preset'),
        ('Manual selection', 'manual')
    ],
    value='preset',
    description=''
)

# Build the preset dropdown and store (instance, feature, variable) as the value
preset_entries = []
for _, r in results.iterrows():
    bvar = _bvar_name(r.variable, r.variable_type)
    if bvar not in metadata_binned.columns:
        continue  # Skip variables you haven’t binned
    label = (
        f"Inst:{int(r.instance_idx)} | "
        f"Feat:{int(r.feature_idx)} | "
        f"Var:{bvar} | "
        f"Val:{r['value']} | "
        f"FracDuring:{r.activation_frac_during:.3f} | "
        f"FracOutside:{r.activation_frac_outside:.3f} | "
        f"ActRatio:{r.activation_ratio:.3f} | "
        f"RateProp:{r.rate_proportion:.3f}"
    )
    preset_entries.append((label, (int(r.instance_idx), int(r.feature_idx), bvar)))

preset_dropdown = widgets.Dropdown(
    options=preset_entries,
    description='Select Result:',
    layout=widgets.Layout(width='80%')
)
preset_box = widgets.VBox([preset_dropdown])

# Build the manual selection box
instance_dropdown = widgets.Dropdown(
    options=sorted(acts_df['instance_idx'].unique()),
    description='Instance:'
)

# Only include properly binned variables
unique_vars = results[['variable','variable_type']].drop_duplicates()
manual_var_options = [
    (_bvar_name(v, t), _bvar_name(v, t))
    for v, t in unique_vars.values
    if _bvar_name(v, t) in metadata_binned.columns
]
variable_dropdown = widgets.Dropdown(
    options=manual_var_options,
    description='Variable:'
)

# Precompute bvar column on results for filtering feature indices
results['bvar'] = results.apply(
    lambda r: _bvar_name(r.variable, r.variable_type), axis=1
)

# Replace fixed feature input with a dropdown that updates based on variable
feature_dropdown = widgets.Dropdown(
    description='Feature:',
    options=[]
)

# Callback to repopulate feature options when variable changes
def _on_var_change(change):
    sel_var = change['new']
    feats = sorted(
        results.loc[results['bvar'] == sel_var, 'feature_idx'].unique()
    )
    feature_dropdown.options = feats

variable_dropdown.observe(_on_var_change, names='value')
_on_var_change({'new': variable_dropdown.value})

manual_box = widgets.VBox([
    instance_dropdown,
    variable_dropdown,
    feature_dropdown
])
manual_box.layout.display = 'none'  # Start hidden

# Buttons for generating or clearing the plot, and an output area
generate_btn = widgets.Button(description='Generate Plot', button_style='info')
clear_btn    = widgets.Button(description='Clear',         button_style='warning')
button_box   = widgets.HBox([generate_btn, clear_btn])
out = widgets.Output()

# Toggle between preset and manual views
def _on_mode_change(change):
    if change['new'] == 'preset':
        preset_box.layout.display = ''
        manual_box.layout.display = 'none'
    else:
        preset_box.layout.display = 'none'
        manual_box.layout.display = ''

mode_radio.observe(_on_mode_change, names='value')

# Callback to generate the tuning plot
def _on_generate(_):
    """Render the tuning plot for the selected (instance, feature, variable).

    In 'preset' mode the triple comes from ``preset_dropdown`` (its value is
    unpacked as ``inst, feat, var``); otherwise from the three manual dropdowns.
    Everything is rendered into the ``out`` widget, which is cleared first.

    Args:
        _: The clicked button (unused).
    """
    with out:
        clear_output()
        if mode_radio.value == 'preset':
            selection = preset_dropdown.value
        else:
            selection = (
                instance_dropdown.value,
                feature_dropdown.value,
                variable_dropdown.value,
            )

        # A dropdown with no options has value None (e.g. a variable with no
        # associated features); report it instead of failing inside the plot.
        if selection is None or any(part is None for part in selection):
            print('Select an instance, variable, and feature before generating a plot.')
            return

        inst, feat, var = selection
        plot_feature_tuning(
            acts_df=acts_df,
            spk_z_scores_df=spk_z_scores_df,
            metadata_binned=metadata_binned,
            variable=var,
            instance_idx=inst,
            feature_idx=feat
        )

# Callback to clear the output
def _on_clear(_):
    """Wipe whatever was last rendered into the ``out`` widget.

    Args:
        _: The clicked button (unused).
    """
    out.clear_output(wait=False)

# Wire each button to its callback.
generate_btn.on_click(_on_generate)
clear_btn.on_click(_on_clear)

# Stack header, mode selector, both control panels, buttons and output, then render.
ui_header = widgets.HTML("<h2>SAE Feature Visualization</h2>")
ui = widgets.VBox([ui_header, mode_radio, preset_box, manual_box, button_box, out])
display(ui)

## Scratchpad below

In [None]:
# Scratch: rows of acts_df where SAE instance `instance` fired feature `feature`,
# plus the metadata index labels of those examples.
instance, feature = 0, 474
fires = (acts_df["instance_idx"] == instance) & (acts_df["feature_idx"] == feature)
feature_activity = acts_df.loc[fires]
display(feature_activity.sort_values(by="example_idx"))
feat_act_ts = metadata.iloc[feature_activity["example_idx"]].index
print(feat_act_ts)

In [None]:
# Scratch: normalized cumulative trace of the unit with the highest mean count
# over the examples in `feature_activity` (set in the previous cell).
max_col_name = spk_cts_df.iloc[feature_activity["example_idx"]].mean().idxmax()
unit_activity = spk_cts_df[max_col_name].values.cumsum()
# Skip normalization when the total is 0; dividing would give all-NaN (0 / 0).
if unit_activity[-1] != 0:
    unit_activity = unit_activity / unit_activity[-1]
print(unit_activity)

In [None]:
# Inspect the column dtypes and the index of the binned metadata table.
metadata_binned.dtypes, metadata_binned.index

In [None]:
"""Visualize feature activity vs. movement blocks from metadata."""

# Parameters
instance = 0
feature = 474

# Rows of acts_df where the chosen feature fired in the chosen SAE instance
acts = acts_df[(acts_df["instance_idx"] == instance) & (acts_df["feature_idx"] == feature)]
feat_act_ts = metadata_binned.iloc[acts["example_idx"]].index

# Build normalized cumulative feature activity (ends at 1 if the feature fired at all)
feat_timeline = np.zeros(len(metadata_binned))
feat_timeline[acts["example_idx"].values] = 1
feat_cumsum = np.cumsum(feat_timeline)
feat_norm = feat_cumsum / feat_cumsum[-1] if feat_cumsum[-1] > 0 else feat_cumsum

# Unit with the highest mean count over the feature's active examples, as a
# normalized cumulative trace. Guarded like feat_norm: a zero total would
# otherwise turn the whole trace into NaN (0 / 0).
top_unit = spk_cts_df.iloc[acts["example_idx"]].mean().idxmax()
unit_cumsum = spk_cts_df[top_unit].values.cumsum()
unit_norm = unit_cumsum / unit_cumsum[-1] if unit_cumsum[-1] != 0 else unit_cumsum

# Running speed from the velocity components, scaled so its max is 1
speed = np.sqrt(metadata_binned["vel_x"]**2 + metadata_binned["vel_y"]**2)
speed_norm = speed / speed.max()

# Create subplots: feature/unit/speed on top, movement block on bottom
fig = make_subplots(
    rows=2, cols=1,
    shared_xaxes=True,
    vertical_spacing=0.05,
    row_heights=[0.8, 0.2]
)

# Top traces: Feature activity, unit activity, speed
fig.add_trace(
    go.Scatter(
        x=metadata_binned.index,
        y=feat_norm,
        mode="lines",
        name="Feature Activity",
        line=dict(color="black")
    ),
    row=1, col=1
)
fig.add_trace(
    go.Scatter(
        x=metadata_binned.index[: len(unit_norm)],
        y=unit_norm,
        mode="lines",
        name="Top Unit Activity",
        line=dict(color="blue"),
        opacity=0.5
    ),
    row=1, col=1
)
# Subsample the speed trace to at most ~5000 points to keep the figure light
subsample = max(1, len(metadata_binned) // 5000)
fig.add_trace(
    go.Scatter(
        x=metadata_binned.index[::subsample],
        y=speed_norm[::subsample],
        mode="lines",
        name="Running Speed",
        line=dict(color="red"),
        opacity=0.3
    ),
    row=1, col=1
)

# Bottom trace: blocks where is_moving == True
# Assumes metadata_binned has an 'is_moving' boolean column
moving_int = metadata_binned["is_moving"].astype(int)
fig.add_trace(
    go.Scatter(
        x=metadata_binned.index,
        y=moving_int,
        mode="none",
        fill='tozeroy',
        name='Is Moving',
        fillcolor='rgba(0,200,0,0.3)'
    ),
    row=2, col=1
)

# Update axes. `title_font` replaces the deprecated `titlefont` property, which
# current Plotly (plotly.py 6 / plotly.js 3) rejects as an invalid property.
fig.update_yaxes(
    title_text="Normalized Cumulative Count",
    tickfont=dict(size=17),
    title_font=dict(size=18),
    row=1, col=1
)
fig.update_yaxes(
    title_text="Moving",
    showticklabels=False,
    range=[0,1],
    row=2, col=1
)
fig.update_xaxes(
    title_text="Time (timestamp)",
    tickfont=dict(size=17),
    title_font=dict(size=19),
    tickangle=-45,
    row=2, col=1
)

# Layout tweaks. The activation count is cast to int so the title reads
# e.g. "12" rather than "12.0" (feat_cumsum is a float array).
total_act = int(feat_cumsum[-1]) if len(feat_cumsum) > 0 else 0
fig.update_layout(
    title=dict(
        text=f"Instance {instance}, Feature {feature} - Total Activations: {total_act}",
        font=dict(size=23),
    ),
    margin=dict(l=40, r=20, t=40, b=20),
    legend=dict(
        y=0.99,
        x=0.01,
        font=dict(size=13),
        itemsizing='constant',
        bgcolor="rgba(255,255,255,0.3)"
    ),
    height=600,
    width=1000
)

fig.show()


In [None]:
# """Visualize feature activity vs. events from metadata."""

# import numpy as np
# import pandas as pd
# import plotly.graph_objects as go
# from plotly.subplots import make_subplots
# import plotly.express as px

# # Parameters
# instance = 0
# feature = 151

# # Get feature activity
# feature_activity = acts_df[
#     (acts_df["instance_idx"] == instance) &
#     (acts_df["feature_idx"] == feature)
# ]
# feat_act_ts = metadata_binned.iloc[feature_activity["example_idx"]].index

# # Create feature activity timeline (cumulative normalized count)
# feat_activity_timeline = np.zeros(len(metadata_binned))
# feat_indices = feature_activity["example_idx"].values
# feat_activity_timeline[feat_indices] = 1
# feat_activity_cumsum = np.cumsum(feat_activity_timeline)
# feat_activity_norm = feat_activity_cumsum / feat_activity_cumsum[-1] if feat_activity_cumsum[-1] > 0 else feat_activity_cumsum

# # Get top unit activity
# max_col_name = spk_cts_df.iloc[feature_activity["example_idx"]].mean().idxmax()
# unit_activity = spk_cts_df[max_col_name].values.cumsum()
# unit_activity = unit_activity / unit_activity[-1]

# # Calculate speed from vel_x, vel_y
# speed = np.sqrt(metadata_binned["vel_x"]**2 + metadata_binned["vel_y"]**2)
# speed_normalized = speed / speed.max()

# # Get events data - filter out empty string events
# events_df = metadata_binned[
#     (metadata_binned["event"].notna()) & 
#     (metadata_binned["event"] != '') & 
#     (metadata_binned["event"].str.strip() != '')
# ].copy()
# event_types = events_df["event"].unique()

# # Create subplots
# fig = make_subplots(
#     rows=2, 
#     cols=1,
#     shared_xaxes=True,
#     vertical_spacing=0.05,
#     row_heights=[0.8, 0.2]
# )

# # Add feature activity (top plot)
# fig.add_trace(
#     go.Scatter(
#         x=metadata_binned.index,
#         y=feat_activity_norm,
#         mode="lines",
#         name="Feature Activity",
#         line=dict(color="black")
#     ),
#     row=1, col=1
# )

# # Add top unit activity (top plot)
# fig.add_trace(
#     go.Scatter(
#         x=metadata_binned.index[:len(unit_activity)],
#         y=unit_activity,
#         mode="lines",
#         name="Top Unit Activity",
#         line=dict(color="blue"),
#         opacity=0.5,
#     ),
#     row=1, col=1
# )

# # Add speed (top plot) - subsample for performance
# subsample_rate = max(1, len(metadata_binned) // 5000)  # Adjust for performance
# fig.add_trace(
#     go.Scatter(
#         x=metadata_binned.index[::subsample_rate],
#         y=speed_normalized.fillna(0)[::subsample_rate],
#         mode="lines",
#         name="Running Speed",
#         line=dict(color="red"),
#         opacity=0.3
#     ),
#     row=1, col=1
# )

# # Add events (bottom plot) - each event type in its own row
# # First, get all individual event types (including from combined events)
# all_individual_events = set()
# for event in event_types:
#     if ',' in event:
#         # Split combined events
#         individual_events = [e.strip() for e in event.split(',')]
#         all_individual_events.update(individual_events)
#     else:
#         all_individual_events.add(event)

# all_individual_events = sorted(list(all_individual_events))

# # Create colors for individual events
# colors = px.colors.qualitative.Plotly
# event_colors = {event: colors[i % len(colors)] for i, event in enumerate(all_individual_events)}

# # Calculate row positions for each event type
# num_event_types = len(all_individual_events)
# row_height = 0.8 / num_event_types  # Use 80% of vertical space, leaving margins
# row_spacing = 0.1 / (num_event_types + 1)  # Spacing between rows

# event_row_positions = {}
# for i, event in enumerate(all_individual_events):
#     # Calculate bottom and top of each row
#     y_bottom = row_spacing + i * (row_height + row_spacing)
#     y_top = y_bottom + row_height
#     event_row_positions[event] = (y_bottom, y_top)

# print(f"Individual event types: {all_individual_events}")
# print(f"Row positions: {event_row_positions}")

# # Keep track of which events have been added to legend
# events_in_legend = set()

# # Sort event types to control drawing order
# event_types_sorted = sorted(event_types, key=lambda x: (x == 'move_ends', x))

# # Process each original event type
# for event_type in event_types_sorted:
#     event_times = events_df[events_df["event"] == event_type].index
    
#     if ',' in event_type:
#         # Handle combined events - plot each component in its own row
#         individual_events = [e.strip() for e in event_type.split(',')]
        
#         for individual_event in individual_events:
#             y_bottom, y_top = event_row_positions[individual_event]
            
#             # Create vertical lines in this event's dedicated row
#             x_coords = []
#             y_coords = []
            
#             for event_time in event_times:
#                 x_coords.extend([event_time, event_time, None])
#                 y_coords.extend([y_bottom, y_top, None])
            
#             fig.add_trace(
#                 go.Scatter(
#                     x=x_coords,
#                     y=y_coords,
#                     mode="lines",
#                     name=individual_event,
#                     line=dict(color=event_colors[individual_event], width=1),
#                     opacity=0.9,
#                     hovertemplate=f"<b>Event:</b> {individual_event}<br><b>Time:</b> %{{x}}<br><b>Combined from:</b> {event_type}<extra></extra>",
#                     showlegend=individual_event not in events_in_legend,
#                     legendgroup=individual_event
#                 ),
#                 row=2, col=1
#             )
#             events_in_legend.add(individual_event)
#     else:
#         # Handle single events in their dedicated row
#         y_bottom, y_top = event_row_positions[event_type]
        
#         x_coords = []
#         y_coords = []
        
#         for event_time in event_times:
#             x_coords.extend([event_time, event_time, None])
#             y_coords.extend([y_bottom, y_top, None])
        
#         fig.add_trace(
#             go.Scatter(
#                 x=x_coords,
#                 y=y_coords,
#                 mode="lines",
#                 name=event_type,
#                 line=dict(color=event_colors[event_type], width=1),
#                 opacity=0.9,
#                 hovertemplate=f"<b>Event:</b> {event_type}<br><b>Time:</b> %{{x}}<extra></extra>",
#                 showlegend=event_type not in events_in_legend
#             ),
#             row=2, col=1
#         )
#         events_in_legend.add(event_type)

# # Update layout
# fig.update_yaxes(
#     title_text="Normalized Cumulative Count",
#     tickfont=dict(size=17),
#     titlefont=dict(size=18),
#     row=1, col=1
# )
# fig.update_yaxes(
#     title_text="Events",
#     showticklabels=False,
#     range=[0, 1],  # Fixed range for event display
#     row=2, col=1
# )
# fig.update_xaxes(
#     title_text="Time (timestamp)",
#     tickfont=dict(size=17),
#     titlefont=dict(size=19),
#     tickangle=-45,
#     row=2, col=1
# )

# # Set title
# total_activations = feat_activity_cumsum[-1] if len(feat_activity_cumsum) > 0 else 0
# fig.update_layout(
#     title_text=f"Instance {instance}, Feature {feature} - Feature Activity (Total: {total_activations} activations)"
# )

# fig.update_layout(
#     # Reduce margin space around the entire figure
#     margin=dict(l=40, r=20, t=40, b=20),
    
#     # Make the legend more compact
#     legend=dict(
#         y=0.999,
#         x=0.001,
#         font=dict(size=13),
#         itemsizing='constant',
#         bgcolor="rgba(255, 255, 255, 0.3)"
#     ),
    
#     # Title styling
#     title=dict(
#         font=dict(size=23),
#     ),
    
#     height=600,
#     width=1000
# )

# fig.show()

Jai's code below

In [None]:
"""Visualize unit spiking variability over SAE feature activity."""

fig, ax = plt.subplots(figsize=(20, 8))

# Long-form (unit, value) table over the rows selected by f_ex_idxs
# (presumably the examples where the feature is active, per the title).
long_cts = pd.melt(spk_cts_df.iloc[f_ex_idxs])

ax = sns.boxplot(
    data=long_cts,
    x="variable",
    y="value",
    whis=0.75,
    width=1,
    showfliers=False,
    ax=ax,
)
ax.set_xlabel("units", fontsize=26)
ax.set_ylabel("normalized spike counts", fontsize=26)
# ax.set_yticks([])
ax.set_xticklabels([])
ax.tick_params(axis="y", labelsize=24)  # enlarge y tick labels
ax.set_title(f"feature ({inst_i} : {feat_i}) normalized unit spike counts when active", fontsize=30)

In [None]:
# fraction of data that is flashes
is_flash = metadata["stimulus_name"] == "flashes"
len(metadata.loc[is_flash]) / len(metadata)

In [None]:
# NOTE(review): ratio of two hard-coded counts; the origin of 11000 and 15000 is not recorded here.
11000 / 15000

In [None]:
"""Visualize SAE-natural feature confusion matrix."""

# Hard-coded fractions; bars alternate blue / black within each category.
# NOTE(review): labels 1-2 (and 3-4) are identical strings, so matplotlib's
# categorical axis places each pair at the SAME x position and the black bar is
# drawn over the blue one. The commented-out set_xticklabels call below lists
# four tick labels, which suggests four separate positions may have been
# intended — confirm the desired layout.
bar_labels = [
    "gratings",
    "gratings",
    "gratings\n(spatial frequency = 0.04)",
    "gratings\n(spatial frequency = 0.04)",
]
bar_heights = [0.56, 0.37, 0.6, 0.17]
bar_colors = ["blue", "black", "blue", "black"]

fig, ax = plt.subplots(figsize=(12, 8))
ax.bar(x=bar_labels, height=bar_heights, color=bar_colors, alpha=0.7, width=0.5)

# fig, ax = plt.subplots(figsize=(8, 8))
# ax.bar(
#     x=[
#         "flashes", 
#         "flashes", 
#     ],
#     height=[0.89, 0.03],
#     color=["blue", "black"],
#     alpha=0.7,
#     width=0.3
# )

ax.set_title(f"({inst_i} : {feat_i}) feature active on", fontsize=26)

# Enlarge tick labels and the y-axis label
ax.tick_params(axis="x", labelsize=22)
ax.tick_params(axis="y", labelsize=19)
ax.set_ylabel("fraction of presentations", fontsize=22)
# ax.set_xticklabels(["drifting_gratings", "", "drifting_gratings\n(temporal frequency < 4 Hz)", ""])
ax.set_yticks(np.arange(0, 1.1, 0.1))

# fig.tight_layout()


In [None]:
"""Visualize boxplots of feature activations across levels as a proxy for represented hierarchy."""

# Assign features to matryoshka levels by feature index. The three masks are
# complementary, so no feature falls between levels (strict `<` / `>` on both
# sides of a boundary would drop the boundary feature itself).
# NOTE(review): treats both boundary indices as inclusive, as their names
# ("last ... general", "first ... specific") suggest — confirm.
feat_idx = features_df["feature_idx"]
act_frac = features_df["activation_frac"]
is_l0 = feat_idx <= last_feat_idx_general
is_l1 = (feat_idx > last_feat_idx_general) & (feat_idx < first_feat_idx_specific)
is_l2 = feat_idx >= first_feat_idx_specific

# Raw, unshifted activation fractions per level, so the y-axis labelled
# "Activation Fraction" shows the measured values (no per-level +/- offsets).
l0_act_frac = act_frac[is_l0]
l1_act_frac = act_frac[is_l1]
l2_act_frac = act_frac[is_l2]

# Long-form table for seaborn: one row per feature, labelled with its level
boxplot_data = pd.DataFrame({
    "Layer": ["0 (General)"] * len(l0_act_frac)
             + ["1 (Intermediate)"] * len(l1_act_frac)
             + ["2 (Specific)"] * len(l2_act_frac),
    "Activation Fraction": pd.concat(
        [l0_act_frac, l1_act_frac, l2_act_frac], ignore_index=True
    ),
})

# Create the plot
plt.figure(figsize=(10, 6))
ax = sns.boxplot(
    x="Layer", 
    y="Activation Fraction", 
    data=boxplot_data,
    showfliers=False,  # This hides the outliers
    width=0.6,         # Controls the width of the boxes
    palette="viridis",  # Optional: choose a color palette
    # show mean
    showmeans=True,
    meanprops={"marker": "v", "markerfacecolor": "white", "markeredgecolor": "black", "markersize": 10},
)

# Customize the plot
plt.title("Feature activation fraction by matryoshka level", fontsize=14)
plt.xlabel("Level", fontsize=12)
plt.ylabel("Activation Fraction", fontsize=12)
plt.grid(axis="y", linestyle="--", alpha=0.7)
plt.tight_layout()

# Show the plot
plt.show()

In [None]:
"""Visualize boxplots of feature activations across levels as a proxy for represented hierarchy."""

# Computes SAE feature co-activation and the conditional matrix P(j fires | i fires)
# for instance 0 (this cell plots nothing; the heatmap is drawn in the next cell).

# Activations of instance 0 restricted to features listed in features_df
acts_df_feat_on_idx = (
    acts_df["feature_idx"].isin(features_df["feature_idx"].unique())
    & (acts_df["instance_idx"] == 0)
)
acts_df_comp = acts_df[acts_df_feat_on_idx].copy().drop(
    columns=["activation_value", "instance_idx"]
)

# Map each (sorted) feature index to its dense row/column position in the matrix
feature_idxs = np.sort(acts_df_comp["feature_idx"].unique())
corr_mat_idxs = np.arange(len(feature_idxs))
corr_mat_feat_idx_map = dict(zip(feature_idxs, corr_mat_idxs))
acts_df_comp["corr_mat_idx"] = acts_df_comp["feature_idx"].map(corr_mat_feat_idx_map)

# Binary (example x feature) mask: 1.0 where the feature fired on that example
feat_on_mask = np.zeros((n_recon_examples, len(feature_idxs)))
feat_on_mask[acts_df_comp["example_idx"].values, acts_df_comp["corr_mat_idx"].values] = True

# Number of examples on which each feature fired
tot_feat_act = feat_on_mask.sum(axis=0)

# feat_corr[j, i] counts examples where both j and i fired. Dividing column i
# by tot_feat_act[i] gives feat_corr_norm[j, i] = P(j fires | i fires).
feat_on_t = t.tensor(feat_on_mask, device=device)
feat_corr = feat_on_t.T @ feat_on_t
feat_corr_norm = (feat_corr / t.tensor(tot_feat_act, device=device).unsqueeze(0))

# Matrix positions of the first feature index > 256 and > 512 (level boundaries).
# searchsorted(side="right") on this sorted array equals np.where(... > b)[0][0],
# but yields len(feature_idxs) instead of raising IndexError when none exceed b.
# NOTE(review): 256 / 512 are hard-coded; confirm they agree with
# last_feat_idx_general / first_feat_idx_specific used for the level boxplots.
l0_co_idx = int(np.searchsorted(feature_idxs, 256, side="right"))
l1_co_idx = int(np.searchsorted(feature_idxs, 512, side="right"))

In [None]:
# Create a heatmap of the conditional co-activation matrix. Tick labels are
# disabled up front instead of building one label per feature and then
# removing them, which is slow for large feature counts.

fig, ax = plt.subplots(figsize=(12, 8))
sns.heatmap(
    asnumpy(feat_corr_norm), 
    cmap="viridis", 
    vmin=0, 
    vmax=1, 
    square=True,
    xticklabels=False,
    yticklabels=False,
    ax=ax
)
ax.set_title("SAE feature conditional correlation matrix", fontsize=20)
ax.set_xlabel("SAE feature i", fontsize=16)
ax.set_ylabel("SAE feature j ( P(j | i) )", fontsize=16)
ax.set_xticks([])
ax.set_yticks([])

# Draw red horizontal and vertical lines at the level boundaries l0_co_idx and l1_co_idx
for idx in [l0_co_idx, l1_co_idx]:
    ax.axhline(y=idx, color="red", linestyle="--", linewidth=2)
    ax.axvline(x=idx, color="red", linestyle="--", linewidth=2)

plt.tight_layout()

In [None]:
# Display the metadata table (has stimulus_name / stimulus_condition_id columns).
metadata

In [None]:
"""Visualize MSAE latents as bases of neural manifolds."""

# For SAE features: look at manifolds, feature activity over time (over stimuli xaxis bar), 
# extra: feature correlations ?

# NOTE: `features_df.sort_values("activation_std", ascending=True)` was evaluated
# here but its result was neither assigned nor displayed (not the cell's last
# expression), so it had no effect. Run it in its own cell to browse candidates.

# Activations of three hand-picked features, excluding rows from instance 1
acts_df_man = acts_df[
    acts_df["feature_idx"].isin([246, 41, 22]) & (acts_df["instance_idx"] != 1)
]

In [None]:
# Inspect the activation rows kept for the three hand-picked manifold features.
acts_df_man

In [None]:
# Dense (n_recon_examples, 3) matrix of activation values for the features
# 246, 41, 22 (one column each, in that order); inactive entries stay 0.
# The loop variable is deliberately not `feat_i`: at notebook top level that
# would overwrite the `feat_i` used in other cells' plot titles.
# NOTE(review): acts_df_man keeps every instance except 1; if it holds more than
# one instance, later rows overwrite earlier ones at the same example_idx.

man_vals = np.zeros((n_recon_examples, 3))

for col, man_feat_idx in enumerate([246, 41, 22]):
    man_rows = acts_df_man[acts_df_man["feature_idx"] == man_feat_idx]
    man_vals[man_rows["example_idx"].values, col] = man_rows["activation_value"].values

In [None]:
# Wrap the activation matrix in a DataFrame whose index is the example
# position scaled by the bin width `bin_s`.
man_df = pd.DataFrame(man_vals)
man_df = man_df.set_axis(man_df.index * bin_s, axis="index")
man_df

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from mpl_toolkits.mplot3d import Axes3D
from IPython.display import HTML
import matplotlib.animation as animation

In [None]:
# Enhanced version with better visuals and controls
def create_enhanced_3d_animation(df, save_path=None, fps=30, dpi=100):
    """Animate a 3-D trajectory built from the first three columns of ``df``.

    The trajectory is drawn progressively, with a marker at the current point,
    a line from the origin to that point, a dashed shadow on the floor of the
    axes box, a time readout and a progress bar. The camera azimuth rotates as
    the animation advances.

    Args:
        df: DataFrame whose first three columns are the x, y, z coordinates
            (column names are used as axis labels). A numeric index is shown in
            the time readout; otherwise row positions are shown.
        save_path: Optional path. When truthy, the animation is written there
            with the Pillow writer (parent directories are created first).
        fps: Frames per second for playback and for the saved file.
        dpi: Resolution used when saving.

    Returns:
        The ``FuncAnimation`` object. The figure is closed before returning.

    Raises:
        ValueError: If ``df`` has no rows.
    """
    if len(df) == 0:
        raise ValueError("df must contain at least one row to animate")

    fig = plt.figure(figsize=(12, 10))
    ax = fig.add_subplot(111, projection='3d')

    # Extract data
    x_data = df.iloc[:, 0].to_numpy()
    y_data = df.iloc[:, 1].to_numpy()
    z_data = df.iloc[:, 2].to_numpy()

    # Equal-aspect cube centred on the data: every axis spans the largest range.
    mid_x, mid_y, mid_z = ((d.max() + d.min()) / 2 for d in (x_data, y_data, z_data))
    max_range = max(np.ptp(d) for d in (x_data, y_data, z_data)) / 2
    # A constant trajectory would give identical lower/upper limits (matplotlib
    # warns and picks arbitrary limits); use a unit half-span instead.
    if max_range == 0:
        max_range = 1.0

    ax.set_xlim(mid_x - max_range, mid_x + max_range)
    ax.set_ylim(mid_y - max_range, mid_y + max_range)
    ax.set_zlim(mid_z - max_range, mid_z + max_range)

    # Set labels and title
    ax.set_xlabel(df.columns[0], fontsize=12)
    ax.set_ylabel(df.columns[1], fontsize=12)
    ax.set_zlabel(df.columns[2], fontsize=12)
    ax.set_title('Top 3 SAE latents manifold over time', fontsize=14)

    # Add grid
    ax.grid(True)

    # Initialize line and point objects
    line, = ax.plot([], [], [], lw=2, color='blue')
    point, = ax.plot([], [], [], 'ro', markersize=10)

    # Shadow of the full trajectory on the floor of the axes box (the lower z limit)
    ax.plot(x_data, y_data, mid_z - max_range, 'k--', alpha=0.2)

    # Add vector from origin to current point
    vector, = ax.plot([0, 0], [0, 0], [0, 0], 'g-', lw=1.5, alpha=0.7)

    # Time display
    time_template = 'Time = %.3f s'
    time_text = ax.text2D(0.05, 0.95, '', transform=ax.transAxes, fontsize=12)

    # Progress bar
    progress_bar_ax = fig.add_axes([0.2, 0.05, 0.6, 0.03])
    progress_bar = plt.Rectangle((0, 0), 0, 1, fc='blue', alpha=0.5)
    progress_bar_ax.add_patch(progress_bar)
    progress_bar_ax.set_xlim(0, 1)
    progress_bar_ax.set_ylim(0, 1)
    progress_bar_ax.axis('off')

    # Use the index as the time axis when it is numeric. Checking the dtype
    # (rather than isinstance on the first label) also accepts NumPy integer
    # labels such as np.int64, which are not Python ints.
    if pd.api.types.is_numeric_dtype(df.index):
        times = df.index.to_numpy()
    else:
        times = np.arange(len(df))

    last = len(df) - 1

    def init():
        line.set_data([], [])
        line.set_3d_properties([])
        point.set_data([], [])
        point.set_3d_properties([])
        vector.set_data([], [])
        vector.set_3d_properties([])
        time_text.set_text('')
        progress_bar.set_width(0)
        return line, point, vector, time_text, progress_bar

    def update(frame):
        i = min(frame, last)
        # A single-row df has last == 0; show it as complete instead of dividing by zero.
        progress = i / last if last > 0 else 1.0

        # Update line data
        line.set_data(x_data[:i+1], y_data[:i+1])
        line.set_3d_properties(z_data[:i+1])

        # Update current point
        point.set_data([x_data[i]], [y_data[i]])
        point.set_3d_properties([z_data[i]])

        # Update vector from origin
        vector.set_data([0, x_data[i]], [0, y_data[i]])
        vector.set_3d_properties([0, z_data[i]])

        # Update time text (times has one entry per row, so times[i] is always valid)
        time_text.set_text(time_template % times[i])

        # Update progress bar
        progress_bar.set_width(progress)

        # Rotate view slightly for 3D effect
        ax.view_init(elev=30, azim=i / 5 % 360)

        return line, point, vector, time_text, progress_bar

    # Sample at most 300 evenly spaced frames
    num_frames = min(len(df), 300)
    frame_indices = np.linspace(0, last, num_frames, dtype=int)

    ani = FuncAnimation(fig, update, frames=frame_indices,
                        init_func=init, blit=False, interval=1000/fps)

    if save_path:
        Path(save_path).parent.mkdir(parents=True, exist_ok=True)
        ani.save(save_path, writer='pillow', fps=fps, dpi=dpi)
        print(f"Enhanced animation saved to {save_path}")

    # Close this specific figure (not whichever figure happens to be current)
    plt.close(fig)
    return ani

# Build and save the animation from every 500th row of the feature trajectory.
save_path = out_dir / f"{session_id}" / "sae_features" / "sae_0" / "enhanced_3d_animation.gif"
anim_rows = man_df.iloc[::500]
enhanced_ani = create_enhanced_3d_animation(anim_rows, save_path=save_path, fps=15, dpi=200)
# HTML(enhanced_ani.to_jshtml())

In [None]:
# Show natural-scene template 0 from `cache` in grayscale, with axes hidden.
fig, ax = plt.subplots()
ax.axis("off")
scene_img = cache.get_natural_scene_template(0)
ax.imshow(scene_img, cmap="gray")

In [None]:
# Distinct stimulus_condition_id values among natural_scenes rows
metadata.loc[metadata["stimulus_name"] == "natural_scenes", "stimulus_condition_id"].unique()

In [None]:
# session_data.presentationwise_spike_times()

In [None]:
# session_data.conditionwise_spike_statistics()