In [None]:
"""Set notebook settings."""

# Reload imported modules automatically before executing each cell, so edits to
# imported helper code take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Presumably ipyflow's reactive execution mode; left disabled.
# %flow mode reactive

In [None]:
"""Import packages."""

# Standard library
from collections import defaultdict
from datetime import datetime
from pathlib import Path
import os

# Numerics / Data
from scipy.ndimage import gaussian_filter1d
from scipy.linalg import orthogonal_procrustes
import numpy as np
import pandas as pd

# Plotting
import matplotlib.pyplot as plt
import plotly.colors as pc
import plotly.express as px

# Machine learning
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split
from sklearn.linear_model import Ridge
from sklearn.metrics import r2_score
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# CEBRA
from cebra import CEBRA
import cebra.datasets

# Jupyter
from IPython.display import display

# Other third-party
from temporaldata import Data
import h5py

In [None]:
# Widen pandas' display limits so the per-condition summaries and binned
# frames shown below are not truncated.
pd.options.display.max_rows = 300
pd.options.display.max_columns = 25

## Load data

In [None]:
def _filter_trials(session, predicate, description):
    """Keep only the trials for which ``predicate(session.trials)`` is True and report removals.

    ``predicate`` is evaluated on the *current* trials, so each filtering stage sees
    the result of the previous ones.
    """
    num_before = len(session.trials.start)
    session.trials = session.trials.select_by_mask(predicate(session.trials))
    num_after = len(session.trials.start)
    if num_before - num_after > 0:
        print(f"Filtered out {description}, went from", num_before, "trials to", num_after)


def clean_session_data(session):
    """Clean session data by filtering trials and spikes based on quality criteria.

    Trials are filtered in successive stages (each stage prints how many trials it
    removed). Then the spikes, hand and eye streams are restricted to the surviving
    trial intervals, and the recording date is converted to a POSIX timestamp.

    Args:
        session: a ``temporaldata.Data`` session (as loaded with ``Data.from_hdf5``)
            exposing ``trials``, ``spikes``, ``hand``, ``eye`` and ``session``.
            It is modified in place.

    Returns:
        The same ``session`` object, cleaned. Calling this again on an already-cleaned
        session is safe: the recording date is only converted once.

    Raises:
        ValueError: if no trials survive filtering, or if the surviving maze
            conditions do not start at 1 or are not evenly spaced.
    """
    # Mark Churchland gave me a matlab script which I have converted to python

    # I think is_valid is already defined by brainsets as (session.trials.discard_trial == 0) & (session.trials.task_success == 1) so it's a bit redundant but including it all just to be sure
    # In theory I should also filter based on whether the maze was possible or not (a field called "unhittable") but I cannot find this in the data, perhaps this has already been done in this release of the data
    _filter_trials(
        session,
        lambda t: (
            (t.trial_type > 0)
            & (t.is_valid == 1)
            & (t.discard_trial == 0)
            & (t.novel_maze == 0)
            & (t.trial_version < 3)
        ),
        "extraneous trials",
    )

    _filter_trials(session, lambda t: t.task_success == 1, "unsuccessful trials")

    # Trial times are in seconds: a kept trial must last at least 0.8 s after movement onset
    # (should essentially always be true for successes).
    post_move = 0.8
    _filter_trials(
        session,
        lambda t: t.end - t.move_begins_time >= post_move,
        "trials that were too short",
    )

    _filter_trials(
        session,
        lambda t: t.correct_reach == 1,
        'trials with inconsistent reaches (not similar enough to the "prototypical" trial)',
    )

    primary_conditions = np.unique(session.trials.maze_condition)
    num_conditions = len(primary_conditions)
    print("Number of primary conditions:", num_conditions)
    if num_conditions == 0:
        raise ValueError("No trials left after cleaning; cannot determine primary conditions")
    # Conditions must start at 1 and be evenly spaced. np.unique sorts, so the steps are
    # positive. A single condition has no steps and is accepted.
    # NOTE(review): evenly spaced gaps such as 1, 3, 5 are also accepted.
    condition_steps = np.unique(np.diff(primary_conditions))
    if primary_conditions[0] != 1 or len(condition_steps) > 1:
        raise ValueError("Primary conditions are not monotonic or do not start from 1")

    # In theory I should filter units based on a ranking from 1-4 but I cannot find the ranking in the data, perhaps this has already been done in this release of the data

    # Only keep spikes (and hand/eye samples) that are within the cleaned trials
    session.spikes = session.spikes.select_by_interval(session.trials)
    session.hand = session.hand.select_by_interval(session.trials)
    session.eye = session.eye.select_by_interval(session.trials)

    # Convert the recording date to a POSIX timestamp. Guarded so a re-run on an
    # already-converted session does not crash in strptime.
    # NOTE(review): .timestamp() interprets the naive datetime in local time.
    recording_date = session.session.recording_date
    if isinstance(recording_date, str):
        recording_date = datetime.strptime(recording_date, '%Y-%m-%d %H:%M:%S')
    if isinstance(recording_date, datetime):
        recording_date = recording_date.timestamp()
    session.session.recording_date = recording_date

    return session

In [None]:
def analyze_maze_conditions(session):
    """
    Summarise every maze_condition of a session and plot its hit-target position.

    Each unique ``maze_condition`` becomes one summary row with its trial count, the
    unique barrier and target counts, the single hit-target position, and that
    position's angle in degrees (stored as a string). A plotly scatter of hit
    positions is shown with one marker per distinct position. When several
    conditions share a position, the first of them is plotted.

    Args:
        session: session whose ``trials`` expose ``maze_condition``,
            ``maze_num_barriers``, ``maze_num_targets`` and ``hit_target_position``.

    Returns:
        pandas.DataFrame with one row per maze condition.

    Raises:
        ValueError: if the trials of one condition have more than one hit position.
    """
    trials = session.trials

    def describe_condition(condition):
        # Build the summary row for a single maze condition.
        in_condition = trials.maze_condition == condition
        positions = np.unique(trials.hit_target_position[in_condition], axis=0)
        if len(positions) > 1:
            raise ValueError(f"Condition {condition} has multiple hit positions: {positions}")
        position = positions[0]
        return {
            'Maze Condition': condition,
            'Trials': np.sum(in_condition),
            'Barriers': np.unique(trials.maze_num_barriers[in_condition]),
            'Targets': np.unique(trials.maze_num_targets[in_condition]),
            'Hit Position': position,
            'Hit Position Angles': str(np.degrees(np.arctan2(position[1], position[0]))),
        }

    summary_df = pd.DataFrame(
        [describe_condition(condition) for condition in np.unique(trials.maze_condition)]
    )

    # Keep the first condition for each distinct hit position.
    # Positions are arrays (unhashable), so compare them as tuples.
    position_keys = summary_df['Hit Position'].apply(tuple)
    plot_df = summary_df[~position_keys.duplicated(keep='first')]

    positions = plot_df['Hit Position']
    plot_data = pd.DataFrame({
        'Hit Position X': positions.map(lambda p: p[0]),
        'Hit Position Y': positions.map(lambda p: p[1]),
        'Maze Condition': plot_df['Maze Condition'].astype(str),
    })

    # One evenly spaced viridis colour per plotted condition
    n_conditions = plot_data['Maze Condition'].nunique()
    scale_points = [i / max(n_conditions - 1, 1) for i in range(n_conditions)]
    colors = pc.sample_colorscale('viridis', scale_points)

    fig = px.scatter(
        plot_data,
        x='Hit Position X',
        y='Hit Position Y',
        color='Maze Condition',
        labels={'Hit Position X': 'Hit Position X', 'Hit Position Y': 'Hit Position Y', 'color': 'Maze Condition'},
        title='Hit Position by Maze Condition',
        color_discrete_sequence=colors,
        hover_data=['Maze Condition'],
    )
    fig.update_layout(
        xaxis=dict(scaleanchor="y", scaleratio=1, range=[-150, 150]),
        yaxis=dict(constrain="domain", range=[-100, 100]),
        width=600,
        height=600,
    )
    fig.show()

    return summary_df

In [None]:
# Path to the processed churchland_shenoy_neural_2012 brainsets directory.
# To point elsewhere (e.g. a local copy), set the CHURCHLAND_SHENOY_2012_DIR
# environment variable instead of editing this hard-coded cluster path.
# data_path = r"C:\Users\pouge\Documents\mini_data\brainsets\processed\churchland_shenoy_neural_2012"
data_path = Path(os.environ.get(
    "CHURCHLAND_SHENOY_2012_DIR",
    "/ceph/aeon/aeon/mini/brainsets_data/processed/churchland_shenoy_neural_2012",
))

# List all h5 files in the directory
h5_files = [f for f in os.listdir(data_path) if f.endswith('.h5')]
print(f"Available h5 files: {h5_files}")

# User parameters
subject_name = "nitschke"  # "nitschke" or "jenkins"
# Number of files to load: max 6 for nitschke (only 3 work), 4 for jenkins
num_files_to_load = 3

# Filter files by subject (case-insensitive substring match on the file name)
subject_files = sorted(f for f in h5_files if subject_name.lower() in f.lower())
print(f"\nFiles for subject {subject_name}: {subject_files}")

if len(subject_files) == 0:
    print(f"No files found for subject {subject_name}")
elif len(subject_files) < num_files_to_load:
    print(f"Only {len(subject_files)} files available for {subject_name}, loading all of them")
    num_files_to_load = len(subject_files)

# Load and clean the requested files. A session that fails cleaning is reported
# and skipped, so the remaining sessions can still be used.
sessions = []
for i, file_name in enumerate(subject_files[:min(num_files_to_load, len(subject_files))]):
    print(f"\nLoading file {i+1}/{num_files_to_load}: {file_name}")

    with h5py.File(data_path / file_name, "r") as f:
        session = Data.from_hdf5(f)

        # Read these fields into memory while the HDF5 file is still open
        # (presumably they are otherwise read lazily from the file).
        for field in (session.spikes, session.trials, session.hand, session.eye, session.session):
            field.materialize()

        print("Session ID: ", session.session.id)
        print("Session subject id: ", session.subject.id)
        print("Session subject sex: ", session.subject.sex)
        print("Session subject species: ", session.subject.species)
        print("Session recording date: ", session.session.recording_date)
        print("Original number of trials:", len(session.trials.start))

        try:
            session = clean_session_data(session)
            print("Final number of trials after cleaning:", len(session.trials.start))

            print("Summary of primary conditions:")
            primary_conditions_summary = analyze_maze_conditions(session)
            display(primary_conditions_summary)

            sessions.append(session)
        except Exception as e:
            # Deliberate best-effort: skip sessions that fail the cleaning checks
            print(f"Error processing session {session.session.id}: {type(e).__name__}: {e}")

print(f"\nSuccessfully loaded and cleaned {len(sessions)} sessions for subject {subject_name}")

In [None]:
def fix_maze_conditions_consistency(sessions):
    """
    Fix maze condition numbering to be consistent across all sessions.

    Within each session the sorted condition ids are split into consecutive groups
    of 3. Each group is fingerprinted by the (barriers, targets, hit position)
    signature of its conditions and matched against the groups of a reference
    session (the one with the most unique conditions). A group matched to reference
    group k has its conditions renumbered to 3*k+1, 3*k+2 and 3*k+3.

    Args:
        sessions: List of session objects

    Returns:
        List of sessions with consistent maze condition numbering. Sessions whose
        number of conditions is not a multiple of 3 are dropped. The returned
        sessions are modified in place (``trials.maze_condition`` is overwritten).

    Raises:
        ValueError: if no session survives filtering, a condition has more than one
            value for barriers/targets/hit position, the reference session has an
            incomplete group, or a group cannot be matched to the reference.
    """

    def get_maze_signature(session, condition):
        """Return the (barriers, targets, hit_position) signature of one maze condition."""
        condition_mask = session.trials.maze_condition == condition

        barriers = tuple(np.unique(session.trials.maze_num_barriers[condition_mask]))
        targets = tuple(np.unique(session.trials.maze_num_targets[condition_mask]))
        hit_position = np.unique(session.trials.hit_target_position[condition_mask], axis=0)
        if len(barriers) > 1 or len(targets) > 1 or len(hit_position) > 1:
            raise ValueError(f"Condition {condition} has the following >1 unique values for one of the following: "
                             f"barriers={barriers}, targets={targets}, hit_position={hit_position}. "
                             "This should not be possible.")
        return (barriers, targets, tuple(hit_position[0]))

    def get_group_signature(session, group_conditions):
        """Return the per-condition signatures of a group of 3 conditions, in condition order."""
        return tuple(get_maze_signature(session, condition) for condition in sorted(group_conditions))

    # Remove sessions that don't have multiples of 3 maze conditions
    valid_sessions = []
    for session in sessions:
        num_conditions = len(np.unique(session.trials.maze_condition))
        if num_conditions % 3 != 0:
            print(f"WARNING: Removing session {session.session.id} - has {num_conditions} maze conditions (not multiple of 3)")
            continue
        valid_sessions.append(session)

    if len(valid_sessions) == 0:
        raise ValueError("No valid sessions remaining after filtering")

    print(f"Kept {len(valid_sessions)} out of {len(sessions)} sessions after filtering")

    # Pick the reference session: most unique maze conditions (first one wins ties)
    condition_counts = [len(np.unique(s.trials.maze_condition)) for s in valid_sessions]
    max_conditions = max(condition_counts)
    reference_session = valid_sessions[condition_counts.index(max_conditions)]

    print(f"Using session {reference_session.session.id} as reference (has {max_conditions} maze conditions)")

    # Group the reference conditions into consecutive triples
    ref_conditions = sorted(np.unique(reference_session.trials.maze_condition))
    ref_groups = []
    for i in range(0, len(ref_conditions), 3):
        group = ref_conditions[i:i + 3]
        if len(group) != 3:
            raise ValueError(f"Reference session has incomplete group: {group}")
        ref_groups.append(group)

    # Signature of each reference group, and a duplicate check among them
    ref_group_signatures = {}  # reference group index -> signature
    ref_signatures_to_group = {}  # signature -> last reference group seen with it
    for group_idx, group in enumerate(ref_groups):
        signature = get_group_signature(reference_session, group)

        if signature in ref_signatures_to_group:
            existing_group = ref_signatures_to_group[signature]
            print("WARNING: Duplicate group found in reference session!")
            print(f"  Group {existing_group} and Group {group} have identical maze parameters")

        ref_group_signatures[group_idx] = signature
        ref_signatures_to_group[signature] = group

    print(f"Reference session has {len(ref_groups)} groups of maze conditions")

    print("Table of reference session maze conditions:")
    primary_conditions_summary = analyze_maze_conditions(reference_session)
    display(primary_conditions_summary)

    # Process all sessions to match groups and renumber
    processed_sessions = []
    for session in valid_sessions:
        print(f"Processing session {session.session.id}...")

        conditions = sorted(np.unique(session.trials.maze_condition))
        session_groups = [conditions[i:i + 3] for i in range(0, len(conditions), 3)]

        condition_mapping = {}  # old_condition -> new_condition
        used_ref_groups = set()  # reference groups already claimed by this session
        for group in session_groups:
            group_sig = get_group_signature(session, group)

            # All reference groups with an identical signature, in reference order
            candidates = [idx for idx, ref_sig in ref_group_signatures.items() if group_sig == ref_sig]

            if not candidates:
                print(f"ERROR: Could not match group {group} in session {session.session.id}")
                print(f"Group signature: {group_sig}")
                print("Available reference signatures:")
                for ref_group_idx, ref_sig in ref_group_signatures.items():
                    print(f"  Reference group {ref_groups[ref_group_idx]}: {ref_sig}")
                raise ValueError(f"Unmatchable maze conditions {group} in session {session.session.id}")

            # When the reference contains duplicate groups, give each session group its
            # own reference slot. Otherwise two distinct conditions of this session
            # would be renumbered to the same id.
            unused = [idx for idx in candidates if idx not in used_ref_groups]
            if unused:
                matched_ref_group_idx = unused[0]
            else:
                matched_ref_group_idx = candidates[0]
                print(f"WARNING: Group {group} in session {session.session.id} only matches reference "
                      "groups already used by this session; their condition ids will be shared")
            used_ref_groups.add(matched_ref_group_idx)

            # Map old conditions to new conditions: 1, 4, 7, 10, ... for each group base
            new_base_condition = matched_ref_group_idx * 3 + 1
            for offset, old_condition in enumerate(sorted(group)):
                condition_mapping[old_condition] = new_base_condition + offset

        # Apply the mapping to the session
        session.trials.maze_condition = np.array(
            [condition_mapping[old] for old in session.trials.maze_condition]
        )
        processed_sessions.append(session)

    print(f"\nSuccessfully processed {len(processed_sessions)} sessions with consistent maze condition numbering")
    return processed_sessions

# Renumber maze conditions so that identical mazes share the same condition id in
# every loaded session.
print("Fixing maze condition consistency across sessions...")

sessions = fix_maze_conditions_consistency(sessions)

# List each session's (renumbered) conditions for a quick visual check
print("\nCondition remapping summary:")
for session in sessions:
    unique_conditions = sorted(np.unique(session.trials.maze_condition))
    print(f"Session {session.session.id}: {unique_conditions}")

In [None]:
# Uncomment to train a model on a specific session post fixing maze conditions consistency
# print(len(sessions))
# sessions =  [sessions[i] for i in [2]]
# print(len(sessions))

In [None]:
# Parameters
bin_size = 0.05  # seconds (50 ms bins)

# Check unit consistency across sessions
unit_ids = np.unique(sessions[0].spikes.unit_index)
for session in sessions:
    unique_units = np.unique(session.spikes.unit_index)
    if not np.array_equal(unique_units, unit_ids):
        raise ValueError("Sessions do not have the same unit IDs. Cannot combine spike data.")

# Determine the global bin alignment start point on the absolute time axis
# (recording_date + session-relative time).
global_start = min(session.session.recording_date + session.trials.start.min() for session in sessions)
global_start = np.floor(global_start / bin_size) * bin_size  # ensure clean bin alignment

# Decimals used to round bin timestamps so they are exactly reproducible
n_decimals = int(-np.log10(bin_size)) + 1 if bin_size < 1 else 0

# Accumulator for all binned trials
binned_dfs = []

for session in sessions:
    # Shift spike timestamps to absolute time
    df_spikes = pd.DataFrame({
        'timestamp': session.spikes.timestamps + session.session.recording_date,
        'unit': session.spikes.unit_index,
    }).sort_values('timestamp')

    # Convert trial times to absolute time
    df_trials = pd.DataFrame({
        'trial_start': session.trials.start + session.session.recording_date,
        'trial_end': session.trials.end + session.session.recording_date,
    }).sort_values('trial_start')

    # Assign each spike to the most recent trial_start <= timestamp
    df_merged = pd.merge_asof(
        df_spikes,
        df_trials[['trial_start', 'trial_end']],
        left_on='timestamp',
        right_on='trial_start',
        direction='backward'
    )

    # Drop spikes outside their trial interval (spikes before the first trial have
    # NaN trial_end and are dropped too). .copy() so the 'bin' column below is
    # written to a real frame rather than a slice (SettingWithCopyWarning).
    df_merged = df_merged[df_merged['timestamp'] < df_merged['trial_end']].copy()

    # Compute bin index relative to global bin start
    df_merged['bin'] = ((df_merged['timestamp'] - global_start) / bin_size).astype(int)

    # Count spikes per (bin, unit), then pivot to wide format: units as columns
    df_counts = (
        df_merged
        .groupby(['bin', 'unit'], observed=True)
        .size()
        .reset_index(name='count')
    )
    spk_cts = df_counts.pivot_table(
        index='bin',
        columns='unit',
        values='count',
        fill_value=0
    )

    # pivot_table only creates rows for bins holding at least one spike, and columns
    # for units that spiked here. Reindex onto every bin overlapping a trial and every
    # unit so silent bins/units become zeros instead of disappearing. Dropping them
    # would, e.g., shrink the duration used for firing-rate estimates downstream.
    trial_first_bins = np.floor((df_trials['trial_start'].to_numpy() - global_start) / bin_size).astype(int)
    trial_last_bins = np.ceil((df_trials['trial_end'].to_numpy() - global_start) / bin_size).astype(int) - 1
    trial_bins = np.concatenate(
        [np.empty(0, dtype=int)]
        + [np.arange(first, last + 1) for first, last in zip(trial_first_bins, trial_last_bins)]
    )
    all_bins = np.union1d(trial_bins, spk_cts.index.to_numpy())
    spk_cts = spk_cts.reindex(index=all_bins, columns=unit_ids, fill_value=0)

    # Convert bin indices to consistent absolute timestamps (bin start times)
    session_timestamps = np.round(global_start + spk_cts.index.to_numpy() * bin_size, n_decimals)
    spk_cts.index = pd.Index(session_timestamps, name='timestamp')
    spk_cts.columns.name = None

    binned_dfs.append(spk_cts)

# Concatenate all binned trials across sessions
spk_cts_df = pd.concat(binned_dfs)
spk_cts_df = spk_cts_df[~spk_cts_df.index.duplicated()]
spk_cts_df = spk_cts_df.sort_index()

# Result: each row = time bin, each col = unit, values = spike count
display(spk_cts_df)

In [None]:
# Mean firing rate (Hz) per unit: total spikes divided by the total binned duration
# (number of rows in spk_cts_df times the bin width).
duration_sec = len(spk_cts_df) * bin_size
mean_firing_rates = spk_cts_df.sum(axis=0) / duration_sec  # spikes/sec

fig, ax = plt.subplots(figsize=(8, 4))
ax.hist(mean_firing_rates, bins=30, edgecolor='black')
ax.set_xlabel('Mean Firing Rate (Hz)')
ax.set_ylabel('Number of Units')
ax.set_title('Distribution of Mean Firing Rates')
ax.grid(True)
fig.tight_layout()
plt.show()

# Summary statistics
print(f"Mean firing rate across units: {mean_firing_rates.mean():.2f} Hz")
print(f"Median: {np.median(mean_firing_rates):.2f} Hz")
print(f"Range: {mean_firing_rates.min():.2f}–{mean_firing_rates.max():.2f} Hz")

In [None]:
# One value per (time bin, unit) pair, in row-major order
flattened_spike_counts = spk_cts_df.to_numpy().ravel()

fig, ax = plt.subplots(figsize=(8, 5))
ax.hist(flattened_spike_counts, bins=50, edgecolor='black')
ax.set_title("Distribution of Spike Counts per Unit per Time Bin")
ax.set_xlabel("Spike Count")
ax.set_ylabel("Number of Unit-Bin Combinations")
ax.set_yscale("log")  # log scale makes the heavily skewed count distribution readable
ax.grid(True)
fig.tight_layout()
plt.show()

In [None]:
"""Check sparsity of binned spike counts."""

frac_nonzero_bins = (spk_cts_df != 0).values.sum() / spk_cts_df.size
frac_nonzero_examples = (spk_cts_df.sum(axis=1) > 0).mean()
print(f"{frac_nonzero_bins=:.4f}")
print(f"{frac_nonzero_examples=:.4f}")

In [None]:
"""Arrange the metadata."""

# Collect data from all sessions into arrays
hand_data = defaultdict(list)
eye_data = defaultdict(list)
trial_data = defaultdict(list)

for i, session in enumerate(sessions):
    recording_date = session.session.recording_date
    
    # Hand data - extract timestamps and 2D arrays
    timestamps = session.hand.timestamps + recording_date
    acc = session.hand.acc_2d
    pos = session.hand.pos_2d
    vel = session.hand.vel_2d
    
    hand_data['timestamp'].append(timestamps)
    hand_data['acc_x'].append(acc[:, 0])
    hand_data['acc_y'].append(acc[:, 1])
    hand_data['pos_x'].append(pos[:, 0])
    hand_data['pos_y'].append(pos[:, 1])
    hand_data['vel_x'].append(vel[:, 0])
    hand_data['vel_y'].append(vel[:, 1])
    hand_data['session'].append(np.full(len(timestamps), i))
    
    # Eye data - extract timestamps and position arrays
    eye_timestamps = session.eye.timestamps + recording_date
    eye_pos = session.eye.pos
    
    eye_data['timestamp'].append(eye_timestamps)
    eye_data['pos_x'].append(eye_pos[:, 0])
    eye_data['pos_y'].append(eye_pos[:, 1])
    
    # Trial data - add recording_date to all time columns
    trial_data['start'].append(session.trials.start + recording_date)
    trial_data['end'].append(session.trials.end + recording_date)
    trial_data['target_on_time'].append(session.trials.target_on_time + recording_date)
    trial_data['go_cue_time'].append(session.trials.go_cue_time + recording_date)
    trial_data['move_begins_time'].append(session.trials.move_begins_time + recording_date)
    trial_data['move_ends_time'].append(session.trials.move_ends_time + recording_date)
    trial_data['maze_condition'].append(session.trials.maze_condition)
    trial_data['barriers'].append(session.trials.maze_num_barriers)
    trial_data['targets'].append(session.trials.maze_num_targets)
    trial_data['hit_position_x'].append([pos[0] for pos in session.trials.hit_target_position])
    trial_data['hit_position_y'].append([pos[1] for pos in session.trials.hit_target_position])
    trial_data['hit_position_angle'].append([np.degrees(np.arctan2(pos[1], pos[0])) for pos in session.trials.hit_target_position])

# Concatenate all arrays into final datasets
combined_hand_data = {}
for key, arrays in hand_data.items():
    combined_hand_data[key] = np.concatenate(arrays)

combined_eye_data = {}
for key, arrays in eye_data.items():
    combined_eye_data[key] = np.concatenate(arrays)

combined_trial_data = {}
for key, arrays in trial_data.items():
    combined_trial_data[key] = np.concatenate(arrays)

# Create final DataFrames
combined_hand_df = pd.DataFrame(combined_hand_data).set_index('timestamp')
combined_eye_df = pd.DataFrame(combined_eye_data).set_index('timestamp')
combined_trials_df = pd.DataFrame(combined_trial_data)

# Create unified timestamp index from all data sources
all_event_ts = np.concatenate([
    combined_trials_df['target_on_time'].values,
    combined_trials_df['go_cue_time'].values,
    combined_trials_df['move_begins_time'].values,
    combined_trials_df['move_ends_time'].values,
])

# Get unique timestamps across all data
all_ts = np.unique(np.concatenate([
    combined_hand_df.index.values,
    combined_eye_df.index.values,
    all_event_ts
]))

# Create master dataframe with unified timestamp index
metadata = pd.DataFrame(index=all_ts)
metadata.index.name = 'timestamp'

# Merge hand and eye data
metadata = metadata.join(combined_hand_df, how='left')
metadata = metadata.join(combined_eye_df, how='left', rsuffix='_eye')

# Add event column - mark timestamps that correspond to trial events
event_map = {
    'target_on_time': 'target_on',
    'go_cue_time': 'go_cue',
    'move_begins_time': 'move_begins',
    'move_ends_time': 'move_ends',
}
event_col = pd.Series(index=metadata.index, dtype="object")
for col, label in event_map.items():
    event_times = combined_trials_df[col].values
    mask = np.isin(metadata.index.values, event_times)
    event_col.iloc[mask] = label
metadata['event'] = event_col

# Add trial_idx column - assign each timestamp to its trial
# Use binary search to efficiently find which trial each timestamp belongs to
trial_idx_series = pd.Series(index=metadata.index, dtype='float64')

# Sort trials by start time for binary search
trial_sort_idx = np.argsort(combined_trials_df['start'].values)
starts = combined_trials_df['start'].values[trial_sort_idx]
ends = combined_trials_df['end'].values[trial_sort_idx]

# Find potential trial for each timestamp
timestamps = metadata.index.values
start_positions = np.searchsorted(starts, timestamps, side='right') - 1

# Check which timestamps are within valid trial intervals
valid_mask = (start_positions >= 0) & (start_positions < len(starts))
valid_positions = start_positions[valid_mask]
valid_timestamps = timestamps[valid_mask]

# Verify timestamps are before trial end times
end_mask = valid_timestamps <= ends[valid_positions]
final_valid_mask = np.zeros(len(timestamps), dtype=bool)
final_valid_mask[valid_mask] = end_mask

# Assign trial indices to timestamps
trial_indices = np.full(len(timestamps), np.nan)
trial_indices[final_valid_mask] = trial_sort_idx[valid_positions[end_mask]]
trial_idx_series.iloc[:] = trial_indices

metadata['trial_idx'] = trial_idx_series

# Map trial properties using the trial indices
metadata['maze_condition'] = metadata['trial_idx'].astype('Int64').map(
    combined_trials_df['maze_condition']
)
metadata['barriers'] = metadata['trial_idx'].astype('Int64').map(
    combined_trials_df['barriers']
)
metadata['targets'] = metadata['trial_idx'].astype('Int64').map(
    combined_trials_df['targets']
)
metadata['hit_position_x'] = metadata['trial_idx'].astype('Int64').map(
    combined_trials_df['hit_position_x']
)
metadata['hit_position_y'] = metadata['trial_idx'].astype('Int64').map(
    combined_trials_df['hit_position_y']
)
metadata['hit_position_angle'] = metadata['trial_idx'].astype('Int64').map(
    combined_trials_df['hit_position_angle']
)

# Add movement_angle column based on position difference
pos_delta_x = metadata['pos_x'].diff()
pos_delta_y = metadata['pos_y'].diff()
metadata['movement_angle'] = np.degrees(np.arctan2(pos_delta_y, pos_delta_x))

# Calculate speed and acceleration magnitudes from vector components
metadata['vel_magnitude'] = np.sqrt(
    metadata['vel_x'].values**2 + metadata['vel_y'].values**2
)
metadata['accel_magnitude'] = np.sqrt(
    metadata['acc_x'].values**2 + metadata['acc_y'].values**2
)

# Show result
print("Metadata:")
display(metadata)

In [None]:
"""Bin the metadata to match spike counts."""

# Create metadata_binned with consistent timestamps
ts = spk_cts_df.index.values
metadata_binned = pd.DataFrame(index=pd.Index(ts, name='timestamp'))

# Assign each metadata row to a bin index
bin_ids = np.digitize(metadata.index.values, ts) - 1
bin_ids = np.clip(bin_ids, 0, len(ts) - 1)

# Handle event aggregation 
event_values = metadata['event'].values
event_mask = pd.notna(event_values)

if event_mask.any():
    event_bin_ids = bin_ids[event_mask]
    valid_events = event_values[event_mask].astype(str)
    
    # Create event assignments using vectorized operations
    event_agg = np.full(len(ts), None, dtype=object)
    
    # Calculate target bins for all events at once
    event_indices = np.arange(len(valid_events))
    target_bins = event_bin_ids + event_indices
    
    # Only assign events that fall within valid bin range
    valid_targets = target_bins < len(ts)
    event_agg[target_bins[valid_targets]] = valid_events[valid_targets]
else:
    event_agg = np.full(len(ts), None, dtype=object)

# Efficient nearest neighbor reindexing using searchsorted
metadata_timestamps = metadata.index.values
ts_positions = np.searchsorted(metadata_timestamps, ts, side='left')

# Handle edge cases and find true nearest neighbors
ts_positions = np.clip(ts_positions, 0, len(metadata_timestamps) - 1)

# For positions not at the start, check if the previous position is closer
mask = ts_positions > 0
left_positions = ts_positions.copy()
left_positions[mask] = ts_positions[mask] - 1

# Calculate distances to determine nearest
left_distances = np.abs(ts - metadata_timestamps[left_positions])
right_distances = np.abs(ts - metadata_timestamps[ts_positions])

# Choose the nearest position
final_positions = np.where(left_distances < right_distances, left_positions, ts_positions)

# Copy all columns from nearest metadata
for col in metadata.columns:
    if col != 'event':
        metadata_binned[col] = metadata[col].iloc[final_positions].values

# Insert distributed events
metadata_binned['event'] = event_agg

# Result: metadata for all trial bins, one row per bin
print("Metadata binned:")
display(metadata_binned)

## Train CEBRA

In [None]:
# """Prepare data for CEBRA model training."""

# spikes = spk_cts_df.values.astype(np.float32)
# sigma = 0.05 / bin_size 
# spikes = gaussian_filter1d(spikes, sigma=sigma, axis=0)
# spikes = spikes / np.max(spikes, axis=0, keepdims=True) # normalise 
# feature = metadata_binned["vel_x"].to_numpy(dtype=np.float32)
# velocity = np.column_stack([
#     metadata_binned["vel_x"].to_numpy(dtype=np.float32),
#     metadata_binned["vel_y"].to_numpy(dtype=np.float32)
# ])

In [None]:
# """Create a train/validation split."""

# spikes_train, spikes_val, feature_train, feature_val, velocity_train, velocity_val = train_test_split(
#     spikes, feature, velocity, test_size=0.2, random_state=42, shuffle=False
# )

In [None]:
"""Prepare data for CEBRA model training."""

spikes = spk_cts_df.values.astype(np.float32)
feature  = metadata_binned["vel_x"].to_numpy(dtype=np.float32)
velocity = np.column_stack([
    metadata_binned["vel_x"].to_numpy(dtype=np.float32),
    metadata_binned["vel_y"].to_numpy(dtype=np.float32),
])

In [None]:
"""Create a train/validation split."""

n = spikes.shape[0]
idx = np.arange(n)
train_idx, val_idx = train_test_split(idx, test_size=0.2, shuffle=False)

In [None]:
"""Split and normalise."""

# Split spikes and smooth within each split (no cross-boundary leakage).
# 0.05 / bin_size expresses a 0.05-unit Gaussian width in bins (assumes bin_size
# is in seconds, i.e. a 50 ms kernel — confirm).
sigma = 0.05 / bin_size
spikes_train = gaussian_filter1d(spikes[train_idx], sigma=sigma, axis=0)
spikes_val = gaussian_filter1d(spikes[val_idx], sigma=sigma, axis=0)

# Normalise spikes (fit on train only, apply to both).
# A column that is all zero in the training split has max 0; dividing by it
# would give NaN (train) / inf (val) and poison CEBRA/PCA downstream, so such
# columns get a divisor of 1 (left unscaled) instead.
train_max = np.max(spikes_train, axis=0, keepdims=True)
train_max = np.where(train_max > 0, train_max, 1.0).astype(spikes_train.dtype)
spikes_train = spikes_train / train_max
spikes_val = spikes_val / train_max

# Split behavioural targets with the same indices
feature_train = feature[train_idx]
feature_val = feature[val_idx]
velocity_train = velocity[train_idx]
velocity_val = velocity[val_idx]

In [None]:
"""Grid search for hyperparameter tuning."""

# CEBRA GridSearch parameters: list-valued entries are swept, scalar entries are
# held fixed for every model.
params_grid = {
    "output_dimension": [48],
    # Author's open question: the paper used offsets of 1-2 for 20 ms bins; with
    # this notebook's bin size, 1 alone (or [0, 1]) may be preferable — confirm.
    "time_offsets": [1, 2],
    "model_architecture": 'offset10-model',
    "temperature_mode": 'constant',
    "temperature": np.linspace(0.0001, 0.004, 10).tolist(),
    "max_iterations": [5000],
    "batch_size": [512],
    "device": 'cuda_if_available',
    # NOTE(review): the outer list has one element, so this is presumably a single
    # configuration whose num_hidden_units is the list [128, 256, 512]. If the
    # intent was to sweep hidden sizes, use [128, 256, 512] instead — confirm.
    "num_hidden_units": [[128, 256, 512]],
    "verbose": True,
}

datasets = {"dataset1": spikes_train}

# Models are saved under <data_path>/<subject>_<YYYYMMDD[_YYYYMMDD...]>/CEBRA_models.
# NOTE(review): datetime.fromtimestamp uses the machine's local timezone, so the
# folder name could differ between machines in different timezones.
session_dates = []
for session in sessions:
    session_date = datetime.fromtimestamp(session.session.recording_date).strftime("%Y%m%d")
    session_dates.append(session_date)
session_dates_str = "_".join(session_dates)
save_dir = data_path / f"{subject_name}_{session_dates_str}" / "CEBRA_models"

# Run the grid search; every fitted model is written into save_dir
grid_search = cebra.grid_search.GridSearch()
grid_search.fit_models(datasets, params=params_grid, models_dir=save_dir)

In [None]:
"""Load top model."""

# Results table for all models in save_dir.
# NOTE(review): df_results is neither displayed nor used further in this cell.
df_results = grid_search.get_df_results(models_dir=save_dir)
# Best model for "dataset1". The returned best_model object is not used below;
# the checkpoint is reloaded from disk by name instead.
best_model, best_model_name = grid_search.get_best_model(dataset_name="dataset1", models_dir=save_dir)
# Manual override (disabled): force a specific model name.
# best_model_name = "dataset1"
print("The best model is:", best_model_name)

# Assumes GridSearch saved each model as "<model name>.pt" in save_dir — TODO confirm.
# weights_only=False is presumably forwarded to torch.load and permits arbitrary
# pickled objects, so only load checkpoints from a trusted source.
model_path = save_dir / f"{best_model_name}.pt"
top_model = cebra.CEBRA.load(model_path, weights_only=False)

# Transform: embed both splits with the selected model
top_train_embedding = top_model.transform(spikes_train)
top_val_embedding = top_model.transform(spikes_val)

# InfoNCE loss, estimated from 200 batches per split.
# NOTE: the name loss_val is reassigned inside the model-alignment cell later on.
loss_train = cebra.sklearn.metrics.infonce_loss(top_model, spikes_train, num_batches=200)
loss_val = cebra.sklearn.metrics.infonce_loss(top_model, spikes_val, num_batches=200)
print("InfoNCE loss (train):", loss_train)
print("InfoNCE loss (validation):", loss_val)

In [None]:
"""Create consistent random samples for plotting (reuse for CEBRA & PCA)."""

# Seeded generator: the plotted subsets are identical on every Restart & Run All
# (the unseeded np.random.choice drew a different subset each run).
plot_rng = np.random.default_rng(0)

# Train sample: at most 10k bins, drawn without replacement
n_plot_train = min(10_000, top_train_embedding.shape[0])
idx_train = plot_rng.choice(top_train_embedding.shape[0], size=n_plot_train, replace=False)

# Val sample: at most 10k bins, drawn without replacement
n_plot_val = min(10_000, top_val_embedding.shape[0])
idx_val = plot_rng.choice(top_val_embedding.shape[0], size=n_plot_val, replace=False)

# Subset embeddings and features once and reuse (same rows for CEBRA and PCA)
top_train_embedding_sample = top_train_embedding[idx_train, :]
feature_train_sample = feature_train[idx_train]
top_val_embedding_sample = top_val_embedding[idx_val, :]
feature_val_sample = feature_val[idx_val]

In [None]:
"""Plot CEBRA embeddings (train / validation)."""

# One interactive scatter per split, each coloured by the behavioural feature.
for emb_sample, label_sample, split_name in (
    (top_train_embedding_sample, feature_train_sample, "train"),
    (top_val_embedding_sample, feature_val_sample, "validation"),
):
    fig = cebra.integrations.plotly.plot_embedding_interactive(
        emb_sample,
        embedding_labels=label_sample,
        title=f"CEBRA-Time ({split_name})",
        markersize=3,
        cmap="rainbow",
    )
    fig.show()

# Training loss curve of the selected model
ax = cebra.plot_loss(top_model)

In [None]:
"""PCA with same samples."""

# Fit PCA on train spikes only; transform both splits.
# random_state pins the result when sklearn's "auto" solver selects a randomized
# or ARPACK SVD; it is ignored by the exact solvers.
pca = PCA(n_components=3, random_state=0)
pcs_train = pca.fit_transform(spikes_train)
pcs_val = pca.transform(spikes_val)

print("Explained variance ratio (PC1..PC3):", np.round(pca.explained_variance_ratio_, 4))
print("Total explained variance (3 PCs):", np.round(pca.explained_variance_ratio_.sum(), 4))

# Use the exact same indices as the CEBRA plots so the point clouds are comparable
pcs_train_sample = pcs_train[idx_train, :]
pcs_val_sample = pcs_val[idx_val, :]


def _pcs_frame(pcs_sample, feature_sample):
    """Return a DataFrame with columns PC1..PC3 (first three PC scores) and 'feature'."""
    return pd.DataFrame({
        "PC1": pcs_sample[:, 0],
        "PC2": pcs_sample[:, 1],
        "PC3": pcs_sample[:, 2],
        "feature": feature_sample,
    })


def _plot_pcs_3d(df_plot, title):
    """Show a 3D scatter of PC1..PC3 coloured by 'feature' and return the figure."""
    pcs_fig = px.scatter_3d(
        df_plot, x="PC1", y="PC2", z="PC3",
        color="feature", title=title,
        opacity=0.9,
    )
    pcs_fig.update_traces(marker=dict(size=3))
    pcs_fig.update_layout(margin=dict(l=0, r=0, t=40, b=0))
    pcs_fig.show()
    return pcs_fig


# Plot (train)
df_plot_train = _pcs_frame(pcs_train_sample, feature_train_sample)
fig = _plot_pcs_3d(df_plot_train, "PCA (train, top 3 PCs)")

# Plot (validation)
df_plot_val = _pcs_frame(pcs_val_sample, feature_val_sample)
fig = _plot_pcs_3d(df_plot_val, "PCA (validation, top 3 PCs)")

## Align embeddings from multiple models

In [None]:
save_dir = Path(save_dir)

# Load all .pt files in folder
pt_files = sorted(save_dir.glob("*.pt"))
if not pt_files:
    raise FileNotFoundError(f"No .pt files found in {save_dir}")

print(f"Found {len(pt_files)} models.")

# Compute validation InfoNCE losses and store together with paths
model_losses = []
for p in pt_files:
    model = cebra.CEBRA.load(p, weights_only=False)
    loss_val = cebra.sklearn.metrics.infonce_loss(model, spikes_val, num_batches=200)
    model_losses.append((p, loss_val))

# Sort by loss, ascending (lowest first), so the first entry is the best model
model_losses.sort(key=lambda x: x[1])
print("Model losses (sorted):")
for p, p_loss in model_losses:
    print(f"{p.name}: {p_loss:.4f}")

# Apply loss threshold: keep only models at or below it
loss_threshold = -200  # change this as needed
model_losses = [(p, p_loss) for p, p_loss in model_losses if p_loss <= loss_threshold]
print(f"Using {len(model_losses)} models after applying loss threshold {loss_threshold}.")
# Fail here with an actionable message rather than with an IndexError when the
# alignment step indexes the (empty) embedding list.
if not model_losses:
    raise ValueError(
        f"No models have validation InfoNCE loss <= {loss_threshold}; "
        "relax loss_threshold."
    )

# Get train/val embeddings for every retained model (float64 for the SVD-based alignment)
emb_train_list = []
emb_val_list = []
for p, _ in model_losses:
    model = cebra.CEBRA.load(p, weights_only=False)
    emb_tr = model.transform(spikes_train).astype(np.float64)
    emb_va = model.transform(spikes_val).astype(np.float64)
    emb_train_list.append(emb_tr)
    emb_val_list.append(emb_va)

# Procrustes alignment: fit on train then apply to train and val
def center(X, mu):
    """Subtract the row vector ``mu`` from every row of ``X``."""
    return X - mu


def fit_alignment_to_ref(ref_tr, X_tr, allow_scaling=False):
    """Fit an orthogonal (optionally isotropically scaled) map from ``X_tr`` onto ``ref_tr``.

    Both inputs are centred on their own column means before fitting. Pass TRAIN
    embeddings only, so the validation split never influences the alignment.

    Parameters
    ----------
    ref_tr : ndarray, 2-D
        Reference embedding that others are aligned to.
    X_tr : ndarray, 2-D
        Embedding to align; must have the same shape as ``ref_tr``.
    allow_scaling : bool, default False
        If True, also fit the scale ``s`` minimising ``||A @ R * s - B||_F``
        (A, B = centred X_tr, ref_tr); otherwise ``s = 1.0``.

    Returns
    -------
    R : ndarray
        Orthogonal matrix (reflections allowed).
    s : float
        Isotropic scale factor.
    mu_X, mu_ref : ndarray, shape (1, n_columns)
        Column means of ``X_tr`` and ``ref_tr`` used for centring.
    """
    mu_ref = ref_tr.mean(axis=0, keepdims=True)
    mu_X = X_tr.mean(axis=0, keepdims=True)
    A, B = center(X_tr, mu_X), center(ref_tr, mu_ref)
    # orthogonal_procrustes returns (R, sum of singular values of A.T @ B).
    # That second value is NOT a scale factor. The least-squares scale for
    # A @ R * s ≈ B is trace(R.T @ A.T @ B) / ||A||_F^2 = sv_sum / ||A||_F^2.
    R, sv_sum = orthogonal_procrustes(A, B)
    if allow_scaling:
        a_norm_sq = float(np.sum(A * A))
        s = float(sv_sum) / a_norm_sq if a_norm_sq > 0 else 1.0
    else:
        s = 1.0
    return R, s, mu_X, mu_ref  # return means so we can apply consistently


def apply_alignment(X, R, s, mu_src, mu_ref):
    """Map ``X`` into the reference frame fitted by ``fit_alignment_to_ref``.

    ``X`` is centred with the source TRAIN mean ``mu_src`` (also when ``X`` is
    validation data), rotated by ``R`` and scaled by ``s``. ``mu_ref`` is accepted
    for symmetry but not added back: the output stays centred, and re-centring
    happens after averaging across models.
    """
    return center(X, mu_src) @ R * s

# Reference = first model's TRAIN embedding (models were sorted by ascending
# validation loss, so this is the lowest-loss model that passed the threshold)
ref_tr = emb_train_list[0]

# Fit each model's alignment on its train embedding, then apply that same map
# to both splits
aligned_train = []
aligned_val = []
for Etr, Eva in zip(emb_train_list, emb_val_list):
    R, s, mu_src, mu_ref = fit_alignment_to_ref(ref_tr, Etr, allow_scaling=False)
    Etr_al = apply_alignment(Etr, R, s, mu_src, mu_ref)
    Eva_al = apply_alignment(Eva, R, s, mu_src, mu_ref)  # note: center with train mean
    aligned_train.append(Etr_al)
    aligned_val.append(Eva_al)

# Average across models, per split
avg_train_embedding = np.mean(np.stack(aligned_train, axis=0), axis=0)
avg_val_embedding = np.mean(np.stack(aligned_val, axis=0), axis=0)

# Final recentring with the TRAIN mean for both splits. Centring validation with
# its own mean would use validation statistics, which this pipeline otherwise
# avoids (smoothing, normalisation and alignment are all fit on train only).
avg_train_mean = avg_train_embedding.mean(axis=0, keepdims=True)
avg_train_embedding -= avg_train_mean
avg_val_embedding -= avg_train_mean

print("Train averaged embedding shape:", avg_train_embedding.shape)
print("Val averaged embedding shape:", avg_val_embedding.shape)

## Decode

In [None]:
# Train on avg_train_embedding, decode on avg_val_embedding with a small lag sweep.
# Short aliases: E* are the aligned, model-averaged embeddings; y* are the
# (vel_x, vel_y) velocity targets for the matching split.
Etr = avg_train_embedding
Eval = avg_val_embedding
ytr = velocity_train
yva = velocity_val

def apply_lag(E_tr, E_va, y_tr, y_va, lag_bins: int):
    """Shift embeddings against targets by ``lag_bins`` rows, for both splits.

    Positive lag: embedding row t is paired with target row t + lag (the embedding
    leads). Negative lag: embedding row t + |lag| is paired with target row t.
    Zero lag returns the inputs unchanged.

    Returns ``(Et, Ev, yt, yv)``: shifted train embedding, val embedding, train
    target, val target.
    """
    if lag_bins == 0:
        return E_tr, E_va, y_tr, y_va

    shift = abs(lag_bins)
    if lag_bins > 0:
        emb_slice, target_slice = slice(None, -shift), slice(shift, None)
    else:
        emb_slice, target_slice = slice(shift, None), slice(None, -shift)
    return E_tr[emb_slice], E_va[emb_slice], y_tr[target_slice], y_va[target_slice]

# Sweep non-negative lags (embedding leads velocity by `lag` bins) and keep the
# decoder with the highest mean validation R².
# NOTE(review): the lag is selected on the same validation split whose R² is
# reported, so the reported score is optimistically biased. A held-out test
# split (or CV inside train) would give an unbiased estimate.
max_lag_bins = 5
# "decoder" is initialised so the key always exists, even if no lag improves on
# -inf (e.g. every R² is NaN).
best = {"lag": 0, "r2_mean": -np.inf, "r2_per_dim": None, "decoder": None}

for lag in range(0, max_lag_bins + 1):
    Et, Ev, yt, yv = apply_lag(Etr, Eval, ytr, yva, lag)

    # Train on train, evaluate on val (scaler statistics come from train only)
    decoder = make_pipeline(StandardScaler(with_mean=True, with_std=True),
                            Ridge(alpha=1.0))
    decoder.fit(Et, yt)
    y_pred = decoder.predict(Ev)

    r2_per_dim = r2_score(yv, y_pred, multioutput='raw_values')
    r2_mean = float(np.mean(r2_per_dim))

    if r2_mean > best["r2_mean"]:
        best.update(lag=lag, r2_mean=r2_mean, r2_per_dim=r2_per_dim, decoder=decoder)

print(f"Best lag (bins): {best['lag']}")
print(f"R² per dimension: {best['r2_per_dim']}")
print(f"Mean R²: {best['r2_mean']}")