In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import os
from glob import glob
from pathlib import Path

from tkinter import Tk
from tkinter import filedialog
import pathlib

import datetime

from tqdm import tqdm
import tempfile

# Set KMP_DUPLICATE_LIB_OK so a duplicated OpenMP runtime does not abort the
# kernel (presumably needed for the numeric stack on this machine — TODO confirm).
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"

# Shared parallel-job settings: 4 workers, 1-second chunks.
global_job_kwargs = {"n_jobs": 4, "chunk_duration": "1s"}

In [None]:
'''
spike_mask: boolean array indicating whether the spike is a good unit or not
spike_seconds: time of the spike in seconds
spike_clusters: cluster ID of the spike
spike_positions: position of the spike in the probe
strobe_seconds: time of the strobe in seconds
estimated_strobe_seconds: estimated time of the strobe in seconds (artificially created by uniform interval between the first and the last strobe time stamps)
tagged_good_units: list of good units
tagged_mua_units: list of multi-unit activity (MUA) units
dlc: DeepLabCut data
events: event data
vame_labels: VAME labels

The DLC output tends to be a bit longer than the strobe and event files, so it is trimmed to start at the first index in the event file.
'''

# Ask the user for the session folder with a native folder-selection dialog.
root = Tk()
root.withdraw()                       # hide the empty Tk main window
root.attributes("-topmost", True)     # keep the dialog above other windows
try:
    chosen_folder = filedialog.askdirectory(title="Select folder")
finally:
    # Always tear the hidden Tk root down, even if the dialog raised.
    root.destroy()

# A cancelled dialog returns an empty value; pathlib would turn that into
# Path('.') and the notebook would silently read from the working directory.
if not chosen_folder:
    raise RuntimeError("No folder selected; re-run this cell and choose a session folder.")

selected_dir = pathlib.Path(chosen_folder)
base_dir = str(selected_dir)
print(base_dir)

######## load neural data ##########
# Build paths with pathlib so no platform-specific separator is baked into
# the file names.
ks_dir = selected_dir / "kilosort4"

# spike_mask selects the spikes to keep; it is applied to every per-spike array.
spike_mask = np.load(ks_dir / "spike_mask.npy")
spike_seconds = np.load(ks_dir / "spike_seconds_adj.npy")[spike_mask]
spike_clusters = np.load(ks_dir / "spike_clusters.npy")[spike_mask]
spike_positions = np.load(ks_dir / "spike_positions.npy")[spike_mask]
templates = np.load(ks_dir / "templates.npy")

strobe_seconds = np.load(ks_dir / "strobe_seconds.npy")

tagged_good_units = np.load(ks_dir / "tagged_good_units.npy")
tagged_mua_units = np.load(ks_dir / "tagged_mua_units.npy")

# Per-unit labels table (tab-separated) from the kilosort4qMetrics folder.
unit_label = pd.read_csv(
    selected_dir / "kilosort4qMetrics" / "templates._bc_unit_labels.tsv", sep="\t"
)

######## load behavioral data ##########
# The session folder name is expected to look like "<prefix>_<MMDDYYYY>_...";
# the second underscore-separated field is parsed as the session date.
session_name = selected_dir.parts[-1]
date_str = session_name.split('_')[1]

# Convert MMDDYYYY to YYYY-MM-DD
date_obj = datetime.datetime.strptime(date_str, "%m%d%Y")
date_str = date_obj.strftime("%Y-%m-%d")

# Search for the matching event file.
# NOTE(review): absolute, animal-specific directories are hard-coded here and below.
event_dir = pathlib.Path(r"D:\Neuropixels\Event\9153")
event_file = event_dir.glob(f"*{date_str}*.csv")
event_path = next(event_file, None)
if event_path is None:
    # Fail with an actionable message instead of a bare StopIteration.
    raise FileNotFoundError(f"No event CSV matching '*{date_str}*.csv' found in {event_dir}")
events = pd.read_csv(event_path)
events = events.drop_duplicates(subset='Index', keep='last') # remove duplicates

# Search for the matching DeepLabCut output.
dlc_dir = pathlib.Path(r"D:\Neuropixels\DLC\9153")
dlc_file = dlc_dir.glob(f"*{date_str}*.h5")
dlc_path = next(dlc_file, None)
if dlc_path is None:
    raise FileNotFoundError(f"No DLC file matching '*{date_str}*.h5' found in {dlc_dir}")
dlc = pd.read_hdf(dlc_path)
dlc.columns = dlc.columns.droplevel(0)  # drop the outermost column level
dlc = dlc.loc[dlc.index >= events['Index'].iloc[0]] # trim DLC data to the first index in the event file

# VAME_dir = pathlib.Path(r"D:\Neuropixels\VAME\9153")
# VAME_label_file = VAME_dir.glob(f"*{date_str}\VAME\hmm*\*hmm_label*.npy")
# vame_label_path = next(VAME_label_file, None)
# vame_labels = np.load(vame_label_path)[events['Index'].iloc[0]:]

# VAME_motif_usage_file = VAME_dir.glob(f"*{date_str}\VAME\hmm*\motif_usage*.npy")
# vame_motif_usage_path = next(VAME_motif_usage_file, None)
# vame_motif_usage = np.load(vame_motif_usage_path)

# VAME_community_label_file = VAME_dir.glob(f"*{date_str}\VAME\hmm*\community\*label*.npy")
# vame_community_label_path = next(VAME_community_label_file, None)
# vame_community_labels = np.load(vame_community_label_path)[events['Index'].iloc[0]:]

# VAME_community_file = VAME_dir.glob("community_cohort\hmm*\cohort_community_bag.npy")
# vame_community_path = next(VAME_community_file, None)
# vame_community = np.load(vame_community_path, allow_pickle=True)




# Check the number of frames in each file
# Frames-per-second factor used to turn the strobe duration into an expected
# frame count (presumably the measured camera rate — TODO confirm).
NOMINAL_FRAME_RATE_HZ = 89.97

start = events['Index'].iloc[0] # When the first event starts
end = events['Index'].iloc[-1]
print("First index: ", start)
print("Last index: ", end)
stim = events['Stim'].to_numpy()
print("Event file length: ", len(stim))
print("Strobe file length: ", len(strobe_seconds))

snout_x = dlc[('Snout', 'x')].to_numpy()
snout_y = dlc[('Snout', 'y')].to_numpy()
print("DLC file length: ", len(snout_x))

print("Estimated # of frames: ", (strobe_seconds[-1] - strobe_seconds[0]) * NOMINAL_FRAME_RATE_HZ)

# Later cells index DLC-length arrays (snout_x, speed) with indices derived from
# the event-length time base, so flag a mismatch instead of letting it pass silently.
if len(snout_x) != len(events):
    print(f"WARNING: DLC has {len(snout_x)} frames but the event file has {len(events)} rows; "
          "frame-indexed arrays will not line up one-to-one.")

# Create estimated strobe timings array by uniform interval between the first and the last strobe time stamps
estimated_strobe_seconds = np.linspace(strobe_seconds[0], strobe_seconds[-1], len(events))
print("Strobe file length(estimate): ", len(estimated_strobe_seconds))

In [None]:
### Plot strobe intervals ###
strobe_interval = np.diff(strobe_seconds)
mean_si = np.mean(strobe_interval)
std_si = np.std(strobe_interval)

# Report the summary statistics explicitly; a bare tuple expression in the
# middle of a cell is evaluated and discarded, so it was never displayed.
print(f"Strobe interval (s): max={np.max(strobe_interval)}, min={np.min(strobe_interval)}, "
      f"mean={mean_si}, std={std_si}")

# Intervals more than 5 standard deviations from the mean are treated as outliers.
outlier_mask = np.abs(strobe_interval - mean_si) > 5 * std_si
outliers = strobe_interval[outlier_mask]

print("# of outlier strobe intervals:", len(outliers))

fig, (ax_hist, ax_series, ax_outliers) = plt.subplots(3, 1, figsize=(10, 10))

ax_hist.hist(strobe_interval, bins=300)
ax_hist.set_title("Strobe interval histogram", fontsize=20)
ax_hist.set_ylabel("Counts", fontsize=16)

ax_series.plot(strobe_interval)
ax_series.set_title("Strobe interval time series", fontsize=20)
ax_series.set_ylabel("Strobe interval (s)", fontsize=16)

ax_outliers.plot(np.sort(outliers))
ax_outliers.set_title("Strobe interval outliers", fontsize=20)
ax_outliers.set_ylabel("Interval (s)", fontsize=16)

# Same tick-label size on every panel.
for ax in (ax_hist, ax_series, ax_outliers):
    ax.tick_params(axis='both', labelsize=14)

fig.tight_layout()
plt.show()

In [None]:
### Interpolate DLC data ###
from scipy.interpolate import UnivariateSpline

# Body parts come from column level 0; the first two names of column level 1
# are used as the coordinates (presumably 'x' and 'y' — TODO confirm).
bodyparts = dlc.columns.get_level_values(0).unique().tolist()
coords = dlc.columns.get_level_values(1).unique().tolist()[0:2]

# Samples whose likelihood is below this value are replaced by interpolation.
likelihood_threshold = 0.7

dlc_interpolated = dlc.copy()
frame_positions = np.arange(len(dlc))

# Process each body part
for bp in bodyparts:
    low_confidence = (dlc[(bp, 'likelihood')] < likelihood_threshold).to_numpy()
    for axis in coords:
        # Mask low-confidence samples
        values = dlc[(bp, axis)].to_numpy(dtype=float, copy=True)
        values[low_confidence] = np.nan
        valid = ~np.isnan(values)

        if valid.sum() < 2:
            # A degree-1 spline needs at least two anchor points.
            raise ValueError(
                f"Cannot interpolate {bp}/{axis}: only {int(valid.sum())} samples "
                f"have likelihood >= {likelihood_threshold}."
            )

        # k=1, s=0 is a piecewise-linear spline through every valid sample.
        spline = UnivariateSpline(frame_positions[valid], values[valid], k=1, s=0)
        values[~valid] = spline(frame_positions[~valid])

        # Assign through the full (bodypart, coord) key. The chained form
        # frame[bp][axis] = ... writes into a temporary sub-frame and may never
        # reach dlc_interpolated.
        dlc_interpolated[(bp, axis)] = values

In [None]:
### xaxis: 150--1250
### yaxis: 50--1050

# Side-by-side snout trajectories: raw DLC output (left) and the interpolated
# copy (right). The y-axis is inverted for image-like coordinates.
fig, axes = plt.subplots(1, 2, figsize=(10, 4))
trajectory_panels = (
    (dlc, 'Snout Trajectory'),
    (dlc_interpolated, 'Snout Trajectory (interpolated)'),
)
for ax, (frame, panel_title) in zip(axes, trajectory_panels):
    ax.plot(frame[('Snout', 'x')], frame[('Snout', 'y')], lw=1)
    ax.set_xlabel('Snout x')
    ax.set_ylabel('Snout y')
    ax.set_title(panel_title)
    ax.invert_yaxis()
plt.show()

In [None]:
from scipy.stats import zscore, ttest_1samp
from scipy.ndimage import label as nd_label

# --- Align spikes to events and plot PETHs for significantly tuned neurons ---


# Helper: Find event onsets (False to True transitions)
def find_event_onsets(bool_array):
    """Return the indices at which a 1-D signal switches from falsy to truthy.

    Parameters
    ----------
    bool_array : array-like
        Event indicator, one entry per frame. Any dtype is accepted; values are
        cast to bool first, so 0/1 integers and floats behave like booleans
        (NaN counts as True).

    Returns
    -------
    numpy.ndarray
        Integer indices ``i`` such that element ``i - 1`` is False and element
        ``i`` is True. Empty when there is no such transition.
    """
    # Cast explicitly: `~` on a non-bool array is a bitwise NOT, which raises for
    # floats and gives wrong answers for integers other than 0/1.
    state = np.asarray(bool_array).astype(bool)
    return np.where(~state[:-1] & state[1:])[0] + 1

# 1. Corner visits (any corner)
# Onsets from the four corner columns are pooled and sorted into one index array.
corner_cols = ['Corner1', 'Corner2', 'Corner3', 'Corner4']
corner_onsets = np.sort(
    np.concatenate([find_event_onsets(events[col].values) for col in corner_cols])
)

# 2. Acceleration events (computed from the mean position of six body parts)
# Get x and y coordinates for all six body parts
bodyparts = ['Snout', 'RF', 'LF', 'RH', 'LH', 'Tail']
x_coords = np.stack([dlc_interpolated[(bp, 'x')].values for bp in bodyparts], axis=1)
y_coords = np.stack([dlc_interpolated[(bp, 'y')].values for bp in bodyparts], axis=1)

# Average across body parts
mean_x = np.mean(x_coords, axis=1)
mean_y = np.mean(y_coords, axis=1)

# Calculate velocity and acceleration; dt is the median spacing of the
# estimated (uniform) strobe times.
dt = np.median(np.diff(estimated_strobe_seconds))
vx = np.gradient(mean_x, dt)
vy = np.gradient(mean_y, dt)
speed = np.sqrt(vx**2 + vy**2)
acceleration = np.gradient(speed, dt)

# Threshold for acceleration events (top 1% of |acceleration|)
# NOTE(review): every frame above the threshold is kept, so runs of consecutive
# frames all count as separate "events" — confirm this is intended.
accel_thresh = np.percentile(np.abs(acceleration), 99)
accel_events = np.where(np.abs(acceleration) > accel_thresh)[0]

# 3. VAME movement events (each motif transition)
# The VAME loading code is commented out in the data-loading cell, so
# vame_labels may not exist; fall back to "no transitions" instead of
# raising NameError and aborting the whole cell.
_vame_labels = globals().get('vame_labels')
if _vame_labels is None:
    print("vame_labels is not loaded; skipping VAME motif transitions.")
    vame_transitions = np.array([], dtype=int)
else:
    vame_transitions = np.where(np.diff(_vame_labels) != 0)[0] + 1

# --- Align spikes to events and compute PETHs ---
def compute_peth(spike_times, event_times, window=(-1, 1), bin_size=0.02):
    """Count spikes in fixed-width bins around each event.

    Parameters
    ----------
    spike_times : array-like
        Spike times of one unit, in the same units as ``event_times``.
    event_times : array-like
        Times of the events to align to.
    window : tuple of float
        (start, stop) of the peri-event window relative to each event.
    bin_size : float
        Width of each bin.

    Returns
    -------
    bin_starts : numpy.ndarray
        Left edge of every bin, relative to the event.
    counts : numpy.ndarray
        Integer array of shape (n_events, n_bins). With no events the shape is
        (0, n_bins), so callers can still slice along the bin axis.
    """
    # Bin edges are built with np.arange on floats, so the exact number of edges
    # can be sensitive to rounding for some window/bin_size combinations.
    bins = np.arange(window[0], window[1] + bin_size, bin_size)
    n_bins = len(bins) - 1

    spikes = np.sort(np.asarray(spike_times, dtype=float).ravel())
    events = np.atleast_1d(np.asarray(event_times, dtype=float)).ravel()

    all_counts = np.zeros((len(events), n_bins), dtype=np.intp)
    for row, et in enumerate(events):
        # Pre-select spikes near the window (padded by one bin on each side so
        # rounding cannot drop an edge spike); np.histogram discards the rest.
        lo = np.searchsorted(spikes, et + bins[0] - bin_size, side='left')
        hi = np.searchsorted(spikes, et + bins[-1] + bin_size, side='right')
        all_counts[row], _ = np.histogram(spikes[lo:hi] - et, bins)
    return bins[:-1], all_counts

# --- Test units for tuning to behavioural events and plot their PETHs ---
# NOTE(review): the original comment said "good units", but the loops iterate
# tagged_mua_units — confirm which population is intended.

def find_tuned_units(units, event_times, spike_times, spike_unit_ids,
                     window=(-1, 1), bin_size=0.02, alpha=0.01):
    """Find units whose post-event spike count exceeds the pre-event count.

    For every unit, a PETH is computed with ``compute_peth``. The per-trial
    difference (mean post-event count minus mean pre-event count) is tested
    against zero with a one-sample t-test. A unit is kept when ``p < alpha``
    and the overall post-event mean is larger than the pre-event mean.

    Returns
    -------
    bins : numpy.ndarray or None
        Bin start times from ``compute_peth`` (None if nothing was computed).
    tuned_units : list
        IDs of the units that passed the test.
    tuned_counts : list of numpy.ndarray
        Per-trial count matrices (n_events, n_bins), one per tuned unit.
    """
    event_times = np.asarray(event_times)
    if event_times.size == 0:
        # Nothing to align to; avoids slicing an empty PETH.
        return None, [], []

    n_pre = int(abs(window[0]) / bin_size)  # number of pre-event bins
    bins = None
    tuned_units, tuned_counts = [], []
    for unit in np.ravel(units):
        unit_spike_times = spike_times[spike_unit_ids == unit]
        bins, counts = compute_peth(unit_spike_times, event_times, window, bin_size)
        baseline = counts[:, :n_pre]   # pre-event
        response = counts[:, n_pre:]   # post-event
        # Test: mean response > baseline
        stat, pval = ttest_1samp(response.mean(axis=1) - baseline.mean(axis=1), 0)
        if pval < alpha and response.mean() > baseline.mean():
            tuned_units.append(unit)
            tuned_counts.append(counts)
    return bins, tuned_units, tuned_counts


def plot_tuned_peths(bins, units, trial_counts, event_label, spacing=3):
    """Plot z-scored mean PETHs, vertically offset, with a +/- SEM band.

    The mean PETH is z-scored across bins. The band is the across-trial
    standard error of the mean, divided by the same scale as the z-scored
    trace so that both are in the same units.
    """
    fig, ax = plt.subplots(figsize=(8, 4))
    for i, (unit, counts) in enumerate(zip(units, trial_counts)):
        mean_counts = counts.mean(axis=0)
        mean_peth = zscore(mean_counts)
        # z-scoring the trial matrix along its default axis and then taking the
        # std along that same axis is always 1, so the band is computed from
        # the raw counts instead.
        sem_peth = counts.std(axis=0) / np.sqrt(counts.shape[0]) / mean_counts.std()
        offset = i * spacing
        ax.plot(bins, mean_peth + offset, label=f'Unit {unit}')
        ax.fill_between(bins, mean_peth + offset - sem_peth, mean_peth + offset + sem_peth, alpha=0.2)
        ax.axhline(offset, color='k', lw=0.5)
    ax.axvline(0, color='k', ls='--')
    ax.set_xlabel('Time from event (s)')
    ax.set_ylabel('Z-scored firing rate (offset for clarity)')
    ax.set_title(f'PETHs for Significantly Tuned Units ({event_label})')
    ax.legend()
    fig.tight_layout()
    plt.show()


window = (-1, 1)
bin_size = 0.02

# Corner visits
bins, significant_units, peths_all = find_tuned_units(
    tagged_mua_units, estimated_strobe_seconds[corner_onsets],
    spike_seconds, spike_clusters, window, bin_size,
)
peths = [counts.mean(axis=0) for counts in peths_all]
if significant_units:
    plot_tuned_peths(bins, significant_units, peths_all, 'Corner Visits')
else:
    print("No significantly tuned units found for corner visits.")

# Acceleration events
# accel_events indexes DLC frames, which can outnumber the estimated strobe
# times; drop indices that have no matching time stamp.
accel_frames = accel_events[accel_events < len(estimated_strobe_seconds)]
bins, significant_units, accel_peths_all = find_tuned_units(
    tagged_mua_units, estimated_strobe_seconds[accel_frames],
    spike_seconds, spike_clusters, window, bin_size,
)
peths = [counts.mean(axis=0) for counts in accel_peths_all]
if significant_units:
    plot_tuned_peths(bins, significant_units, accel_peths_all, 'acceleration events')
else:
    print("No significantly tuned units found for acceleration.")

In [None]:
import numpy as np
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt

# Use all unique neuron IDs from spike_clusters
unit_ids = np.unique(spike_clusters)
bin_size = 0.05  # seconds
t_start = spike_seconds.min()
t_end = spike_seconds.max()
n_bins = int(np.ceil((t_end - t_start) / bin_size))

# Build spike count matrix: shape (n_units, n_bins)
spike_matrix = np.zeros((len(unit_ids), n_bins))
for i, unit in enumerate(unit_ids):
    times = spike_seconds[spike_clusters == unit]
    bin_idx = ((times - t_start) / bin_size).astype(int)
    # Drop indices outside the matrix (e.g. a spike exactly on the final edge).
    bin_idx = bin_idx[(bin_idx >= 0) & (bin_idx < n_bins)]
    spike_matrix[i] = np.bincount(bin_idx, minlength=n_bins)

# Z-score across time for each neuron
unit_mean = spike_matrix.mean(axis=1, keepdims=True)
unit_std = spike_matrix.std(axis=1, keepdims=True)
# A unit with constant counts has zero std; dividing by 1 keeps its row at
# zero instead of producing NaN, which PCA rejects.
unit_std[unit_std == 0] = 1.0
spike_matrix_z = (spike_matrix - unit_mean) / unit_std

# PCA
# PCA cannot extract more components than min(n_samples, n_features), so clamp
# the requested number to what the data supports.
n_components = min(300, *spike_matrix_z.shape)  # You can adjust the 300 as needed
# random_state is fixed so the fit is reproducible when a randomized solver is chosen.
pca = PCA(n_components=n_components, random_state=0)
assembly_patterns = pca.fit_transform(spike_matrix_z.T)  # shape: (n_bins, n_components)

# Plot the first few assembly activations
# Never ask for more components than were actually fitted.
n_shown = min(30, assembly_patterns.shape[1])
fig, ax = plt.subplots(figsize=(10, 4))
for i in range(n_shown):
    ax.plot(assembly_patterns[:, i], label=f'Assembly {i+1}')
ax.set_xlabel('Time bin')
ax.set_ylabel('Assembly activation')
ax.legend()
ax.set_title('Cell Assembly Activations (PCA, all neurons)')
plt.show()

# Scree plot (cumulative explained variance, in percent)
fig, ax = plt.subplots()
ax.plot(np.cumsum(pca.explained_variance_ratio_)*100, marker='o')
ax.set_xlabel('Number of components')
ax.set_ylabel('Cumulative explained variance (%)')
ax.set_title('Scree plot (all neurons)')
ax.grid(True)
plt.show()

# Determine the number of units involved in a cell assembly (data-driven way)
# Example for PCA, but works similarly for ICA (just use ica.mixing_)

# Inspect at most the first 50 components, and never more than were fitted.
n_inspected = min(50, pca.components_.shape[0])

for assembly_idx in range(n_inspected):
    weights = pca.components_[assembly_idx]  # shape: (n_units,)

    # Data-driven threshold: units with absolute weights above mean + 2*std
    abs_weights = np.abs(weights)
    threshold = abs_weights.mean() + 2 * abs_weights.std()
    involved_units_idx = np.where(abs_weights > threshold)[0]
    involved_units = unit_ids[involved_units_idx]

    print(f"Number of units involved in assembly {assembly_idx+1}: {len(involved_units)}")
    print("Unit IDs:", involved_units)

    # Optional: visualize the sorted weights and threshold
    fig, ax = plt.subplots(figsize=(8, 4))
    ax.plot(np.sort(abs_weights)[::-1], marker='o', label='Absolute weights')
    ax.axhline(threshold, color='r', linestyle='--', label='Threshold')
    ax.set_xlabel('Unit rank')
    ax.set_ylabel('Absolute component weight')
    ax.set_title(f'Assembly {assembly_idx+1}: Absolute Weights and Threshold')
    ax.legend()
    plt.show()

In [None]:
from sklearn.decomposition import FastICA

# Number of independent components to extract (adjust as needed).
n_components = 30
ica = FastICA(n_components=n_components, random_state=0, max_iter=1000)
# Rows are time bins, columns are components: shape (n_bins, n_components).
assembly_patterns_ica = ica.fit_transform(spike_matrix_z.T)

# Overlay the activation time course of every component.
fig, ax = plt.subplots(figsize=(10, 4))
for i, activation in enumerate(assembly_patterns_ica.T[:n_components]):
    ax.plot(activation, label=f'Assembly {i+1}')
ax.set_xlabel('Time bin')
ax.set_ylabel('Assembly activation (ICA)')
ax.legend()
ax.set_title('Cell Assembly Activations (ICA, all neurons)')
plt.show()

# There is no meaningful scree plot for ICA; inspect the mixing matrix instead.
# Each column of ica.mixing_ holds one component's weight on every unit.
for assembly_idx, weights in enumerate(ica.mixing_.T[:n_components]):
    # Data-driven threshold: a unit belongs to the assembly when its absolute
    # weight exceeds mean + 4*std of the absolute weights.
    abs_weights = np.abs(weights)
    threshold = abs_weights.mean() + 4 * abs_weights.std()
    involved_units_idx = np.flatnonzero(abs_weights > threshold)
    involved_units = unit_ids[involved_units_idx]

    print(f"Number of units involved in assembly {assembly_idx+1}: {len(involved_units)}")
    print("Unit IDs:", involved_units)

    # Ranked absolute weights with the membership threshold overlaid.
    ranked_weights = np.sort(abs_weights)[::-1]
    fig, ax = plt.subplots(figsize=(8, 4))
    ax.plot(ranked_weights, marker='o', label='Absolute weights')
    ax.axhline(threshold, color='r', linestyle='--', label='Threshold')
    ax.set_xlabel('Unit rank')
    ax.set_ylabel('Absolute component weight')
    ax.set_title(f'Assembly {assembly_idx+1}: Absolute Weights and Threshold (ICA)')
    ax.legend()
    plt.show()


In [None]:
from sklearn.decomposition import NMF

# NMF requires non-negative input, so shift spike_matrix_z by its global
# minimum to make every entry >= 0.
spike_matrix_nmf = spike_matrix_z - spike_matrix_z.min()

# Requested number of components, clamped to what the data supports
# (the 'nndsvda' initialisation needs n_components <= min(n_samples, n_features)).
n_top_components = min(50, *spike_matrix_nmf.shape)  # You can adjust the 50
nmf = NMF(n_components=n_top_components, init='nndsvda', random_state=0, max_iter=1000)
W = nmf.fit_transform(spike_matrix_nmf.T)  # shape: (n_bins, n_components)
H = nmf.components_  # shape: (n_components, n_units)

# Plot the first few assembly activations (W)
fig, ax = plt.subplots(figsize=(10, 4))
for i in range(min(10, n_top_components)):
    ax.plot(W[:, i], label=f'Assembly {i+1}')
ax.set_xlabel('Time bin')
ax.set_ylabel('Assembly activation (NMF)')
ax.legend()
ax.set_title('Cell Assembly Activations (NMF, all neurons)')
fig.tight_layout()
plt.show()

# Number of involved units per assembly: units whose weight exceeds half of
# that component's maximum weight.
involved_units_counts_nmf = np.sum(H > (0.5 * H.max(axis=1, keepdims=True)), axis=1)
fig, ax = plt.subplots(figsize=(8, 4))
# Plot every component: pairing arange(1, n) with counts[:-1] left out the last one.
ax.plot(np.arange(1, n_top_components + 1), involved_units_counts_nmf, marker='o')
ax.set_xlabel('Assembly (NMF component) index')
ax.set_ylabel('Number of involved units')
ax.set_title('Number of units involved in each NMF assembly')
ax.grid(True)
fig.tight_layout()
plt.show()

In [None]:
# Sorted unit weights of every NMF component, one small panel per component.
# The red dashed line marks half of the component's maximum weight.
n_panels = H.shape[0]
n_cols = 5
n_rows = int(np.ceil(n_panels / n_cols))

# 3 inches of height per row reproduces the 16x30 figure for 50 components.
fig, axes = plt.subplots(n_rows, n_cols, figsize=(16, 3 * n_rows), squeeze=False)
flat_axes = axes.ravel()

for i in range(n_panels):
    ax = flat_axes[i]
    sorted_weights = np.sort(H[i])[::-1]
    ax.plot(sorted_weights, marker='o')
    ax.axhline(H[i].max() / 2, color='r', linestyle='--', label='Threshold')
    ax.set_title(f'NMF Component {i+1}')
    ax.set_xlabel('Unit rank')
    ax.set_ylabel('Weight')

# Hide grid cells that have no component to show.
for ax in flat_axes[n_panels:]:
    ax.set_visible(False)

# Lay the grid out once, after all panels exist, instead of once per panel.
fig.tight_layout()
fig.suptitle(f'Sorted NMF Weights for All {n_panels} Components', y=1.02, fontsize=16)
plt.show()

In [None]:
# Scatter of every interpolated snout position; tiny, translucent markers keep
# densely visited regions readable.
fig, ax = plt.subplots(figsize=(10, 10))
ax.scatter(
    dlc_interpolated[('Snout', 'x')],
    dlc_interpolated[('Snout', 'y')],
    s=0.15,
    color='black',
    alpha=0.2,
    label='DLC Snout Trajectory',
)
ax.set_title('Snout Trajectory')

In [None]:
from collections import Counter
# Flatten to 1-D so the loop below always yields scalar unit IDs.
tagged_units = np.ravel(np.concatenate([tagged_good_units, tagged_mua_units]))
# Analyze what happened around the time that tagged units fired

# Choose a window around each spike (e.g., +/- 1 second)
peri_window = 1.0  # seconds
frame_rate = 90  # Hz
n_bins = int(2 * peri_window * frame_rate) + 1
peri_time_axis = np.linspace(-peri_window, peri_window, n_bins)
half_width = (n_bins - 1) // 2                        # frames on each side of the centre
frame_offsets = np.arange(-half_width, half_width + 1)
max_frame_gap = 0.02  # s; spikes farther than this from any frame are not mapped to a position

# Arena bins and occupancy are the same for every unit, so compute them once.
x_bins = np.linspace(200, 1230, 15)
y_bins = np.linspace(40, 1070, 15)
occupancy_hist, _, _ = np.histogram2d(snout_x, snout_y, bins=[x_bins, y_bins])
# Convert occupancy to seconds (assuming 90 Hz)
occupancy_sec = occupancy_hist / frame_rate

# Probe drawing parameters
shank_width = 24       # in microns
shank_length = -5000   # in microns (negative: drawn downwards from 0)
shank_spacing = 250    # center-to-center in microns
n_shanks = 4

# Output folder for the per-unit figures. selected_dir is a pathlib.Path, so
# it must be joined with "/" (Path + str raises TypeError).
eda_dir = selected_dir / "EDA"
eda_dir.mkdir(parents=True, exist_ok=True)

n_frames = len(estimated_strobe_seconds)

# For each tagged unit: peri-spike speed, firing positions, rate map, probe
# location and template waveform.
for unit in tagged_units:
    unit_spike_times = spike_seconds[spike_clusters == unit]

    # Nearest frame for every spike, vectorised. estimated_strobe_seconds is a
    # linspace (ascending as long as the strobe times increase), so searchsorted
    # finds the two neighbouring frames; ties go to the earlier frame.
    hi_idx = np.clip(np.searchsorted(estimated_strobe_seconds, unit_spike_times), 1, n_frames - 1)
    lo_idx = hi_idx - 1
    nearer_hi = (estimated_strobe_seconds[hi_idx] - unit_spike_times) < (unit_spike_times - estimated_strobe_seconds[lo_idx])
    nearest_frame = np.where(nearer_hi, hi_idx, lo_idx)
    timegap = np.abs(estimated_strobe_seconds[nearest_frame] - unit_spike_times)

    # Speed in a +/- half_width frame window around each spike.
    # NOTE(review): the window is centred on (nearest frame - 1), as in the
    # original code — confirm that the one-frame shift is intended.
    centers = nearest_frame - 1
    in_bounds = (centers - half_width >= 0) & (centers + half_width < len(speed))
    peri_speed_stack = speed[centers[in_bounds][:, None] + frame_offsets]

    # Snout position at the frame nearest to each spike (only when close enough).
    fired_frames = nearest_frame[timegap <= max_frame_gap]
    fired_coords = np.column_stack([snout_x[fired_frames], snout_y[fired_frames]])

    if len(peri_speed_stack) == 0 or len(fired_coords) == 0:
        print(f"Unit #{unit}: no spikes inside the tracked period; skipping.")
        continue

    # 2D histogram of spike positions (firing events)
    spike_hist, xedges, yedges = np.histogram2d(
        fired_coords[:, 0], fired_coords[:, 1], bins=[x_bins, y_bins]
    )
    # Rate = spikes / occupancy; bins visited for 0.5 s or less are left blank.
    with np.errstate(divide='ignore', invalid='ignore'):
        firing_rate_map = np.where(occupancy_sec > 0.5, spike_hist / occupancy_sec, np.nan)

    fig, (ax_speed, ax_scatter, ax_heat, ax_probe, ax_wave) = plt.subplots(1, 5, figsize=(20, 4))

    # Peri-spike speed (mean across all spikes)
    ax_speed.plot(peri_time_axis, np.nanmean(peri_speed_stack, axis=0), color='gray', alpha=0.8)
    ax_speed.set_xlabel('Time from spike (s)')
    ax_speed.set_ylabel('Speed')
    ax_speed.set_ylim(100, 800)
    ax_speed.set_title(f'Unit #{unit}: Mean speed around tagged unit spikes')

    # Scatter of snout positions at spike times
    ax_scatter.scatter(fired_coords[:, 0], fired_coords[:, 1], s=1, color='black', alpha=0.5)
    ax_scatter.set_xlim(200, 1230)
    ax_scatter.set_ylim(40, 1070)
    ax_scatter.set_ylabel('Y coordinate')
    ax_scatter.set_title('Snout position when the unit fired')

    # Firing-rate heatmap
    im = ax_heat.imshow(
        np.rot90(firing_rate_map),
        extent=[x_bins[0], x_bins[-1], y_bins[0], y_bins[-1]],
        aspect='auto', cmap='coolwarm'
    )
    ax_heat.set_xlabel('X coordinate')
    ax_heat.set_title('Firing Rate Heatmap')
    fig.colorbar(im, ax=ax_heat, label='Hz')
    ax_heat.set_aspect('equal', adjustable='box')

    # Position on the probe: draw the 4 shanks, then mark the unit.
    for shank in range(n_shanks):
        center_x = (shank - 3) * shank_spacing  # shank centres at -750, -500, -250 and 0
        rect = plt.Rectangle((center_x - shank_width / 2, 0), shank_width, shank_length,
                             linewidth=2, edgecolor='k', facecolor='none', zorder=1)
        ax_probe.add_patch(rect)
    ax_probe.set_xlabel('Lateral (μm) 0 = ML:-1500μm')
    ax_probe.set_ylabel('Depth (μm) 0 = brain surface')
    ax_probe.set_title('Neuropixels 2.0 4-shank Probe Layout')

    # Median spike position, shifted by fixed offsets into the drawing frame.
    unit_position = spike_positions[spike_clusters == unit]
    ax_probe.scatter(np.median(unit_position[:, 0]) - 800, np.median(unit_position[:, 1]) - 5000, s=10, c='red')
    ax_probe.invert_xaxis()

    # Template waveform on the channel with the largest absolute deflection
    wf = templates[unit, :, :]  # [n_timepoints, n_channels]
    main_ch_idx = np.argmax(np.abs(wf).max(axis=0))
    waveform = wf[:, main_ch_idx]
    ax_wave.plot(waveform, color='k')
    ax_wave.invert_yaxis()  # Invert y-axis for waveform
    ax_wave.set_title('Waveform')
    ax_wave.set_xlabel('Time (samples)')
    ax_wave.set_ylabel('Template used for waveform extraction')

    fig.tight_layout()
    fig.savefig(eda_dir / f"unit_{unit}_EDA.png", dpi=300)
    plt.show()
    plt.close(fig)  # release the figure's memory before the next unit

In [None]:
from collections import Counter
# Get all non-tagged units
all_units = np.unique(spike_clusters)
tagged_units = np.ravel(np.concatenate([tagged_good_units, tagged_mua_units]))
non_tagged_units = np.setdiff1d(all_units, tagged_units)

# Choose a window around each spike (e.g., +/- 1 second)
peri_window = 1.0  # seconds
frame_rate = 90  # Hz
n_bins = int(2 * peri_window * frame_rate) + 1
peri_time_axis = np.linspace(-peri_window, peri_window, n_bins)
half_width = (n_bins - 1) // 2                        # frames on each side of the centre
frame_offsets = np.arange(-half_width, half_width + 1)
max_frame_gap = 0.02  # s; spikes farther than this from any frame are not mapped to a position

# Arena bins and occupancy are the same for every unit, so compute them once.
x_bins = np.linspace(200, 1230, 20)
y_bins = np.linspace(40, 1070, 20)
occupancy_hist, _, _ = np.histogram2d(snout_x, snout_y, bins=[x_bins, y_bins])
# Convert occupancy to seconds (assuming 90 Hz)
occupancy_sec = occupancy_hist / frame_rate

n_frames = len(estimated_strobe_seconds)

# For each NON-tagged unit: peri-spike speed, firing positions and rate map.
# This loop previously iterated tagged_units, which repeated the tagged-unit
# analysis and left non_tagged_units unused.
for unit in non_tagged_units:
    unit_spike_times = spike_seconds[spike_clusters == unit]

    # Nearest frame for every spike, vectorised. estimated_strobe_seconds is a
    # linspace (ascending as long as the strobe times increase), so searchsorted
    # finds the two neighbouring frames; ties go to the earlier frame.
    hi_idx = np.clip(np.searchsorted(estimated_strobe_seconds, unit_spike_times), 1, n_frames - 1)
    lo_idx = hi_idx - 1
    nearer_hi = (estimated_strobe_seconds[hi_idx] - unit_spike_times) < (unit_spike_times - estimated_strobe_seconds[lo_idx])
    nearest_frame = np.where(nearer_hi, hi_idx, lo_idx)
    timegap = np.abs(estimated_strobe_seconds[nearest_frame] - unit_spike_times)

    # Speed in a +/- half_width frame window around each spike.
    # NOTE(review): the window is centred on (nearest frame - 1), as in the
    # original code — confirm that the one-frame shift is intended.
    centers = nearest_frame - 1
    in_bounds = (centers - half_width >= 0) & (centers + half_width < len(speed))
    peri_speed_stack = speed[centers[in_bounds][:, None] + frame_offsets]

    # Snout position at the frame nearest to each spike (only when close enough).
    fired_frames = nearest_frame[timegap <= max_frame_gap]
    fired_coords = np.column_stack([snout_x[fired_frames], snout_y[fired_frames]])

    if len(peri_speed_stack) == 0 or len(fired_coords) == 0:
        print(f"Unit #{unit}: no spikes inside the tracked period; skipping.")
        continue

    # 2D histogram of spike positions (firing events)
    spike_hist, xedges, yedges = np.histogram2d(
        fired_coords[:, 0], fired_coords[:, 1], bins=[x_bins, y_bins]
    )
    # Rate = spikes / occupancy; bins visited for 0.5 s or less are left blank.
    with np.errstate(divide='ignore', invalid='ignore'):
        firing_rate_map = np.where(occupancy_sec > 0.5, spike_hist / occupancy_sec, np.nan)

    fig, (ax_speed, ax_scatter, ax_heat) = plt.subplots(1, 3, figsize=(13, 4))

    # Peri-spike speed (mean across all spikes)
    ax_speed.plot(peri_time_axis, np.nanmean(peri_speed_stack, axis=0), color='gray', alpha=0.8)
    ax_speed.set_xlabel('Time from spike (s)')
    ax_speed.set_ylabel('Speed')
    ax_speed.set_title(f'Unit #{unit}: Mean speed around non-tagged unit spikes')

    # Scatter of snout positions at spike times
    ax_scatter.scatter(fired_coords[:, 0], fired_coords[:, 1], s=1, color='black', alpha=0.5)

    # Firing-rate heatmap
    im = ax_heat.imshow(
        np.rot90(firing_rate_map),
        extent=[x_bins[0], x_bins[-1], y_bins[0], y_bins[-1]],
        aspect='auto', cmap='coolwarm'
    )
    ax_heat.set_xlabel('Snout x')
    ax_heat.set_ylabel('Snout y')
    ax_heat.set_title('Firing Rate Heatmap')
    fig.colorbar(im, ax=ax_heat, label='Hz')

    fig.tight_layout()
    plt.show()
    plt.close(fig)  # release the figure's memory before the next unit

## CEBRA Analysis

In [None]:
### Data conversion
# Turn spike time stamps into a per-frame spike-count matrix, using
# estimated_strobe_seconds as the left edges of the time bins.
# Requires: spike_seconds, spike_clusters, unit_label, estimated_strobe_seconds.

import cebra
from cebra import CEBRA

# Keep the rows of the label table whose unitType is 1 or 2
# (presumably good and MUA units — TODO confirm the coding).
unit_ids = unit_label[unit_label['unitType'].isin([1, 2])]
n_units = len(unit_ids)
n_timebins = len(estimated_strobe_seconds)

# Output matrix: one row per time bin, one column per unit.
binned_counts = np.zeros((n_timebins, n_units), dtype=int)
plt.figure()
# The row index of the label table is used as the cluster ID
# (assumes the table is ordered by cluster ID — verify).
for column, unit in enumerate(unit_ids.index):
    unit_spike_times = spike_seconds[spike_clusters == unit]
    # Index of the last bin edge at or before each spike.
    frame_of_spike = np.searchsorted(estimated_strobe_seconds, unit_spike_times, side='right') - 1
    # Discard spikes before the first edge or at/after the last edge.
    inside = (frame_of_spike >= 0) & (frame_of_spike <= n_timebins - 2)
    unit_counts = np.bincount(frame_of_spike[inside], minlength=n_timebins)
    plt.plot(unit_counts)
    binned_counts[:, column] = unit_counts

# binned_counts is (n_timebins, n_units), the layout passed to CEBRA below.
print("Binned firing count matrix shape:", binned_counts.shape)

neural_data = binned_counts  # shape: (n_bins, n_neurons)

In [None]:
### Split data and labels (labels we use later!)
from sklearn.model_selection import train_test_split


# Smooth speed with a moving average of 5 bins
window_size = 5
speed_smooth = np.convolve(speed, np.ones(window_size)/window_size, mode='same')
# Cap the speed at 2500
speed_smooth = np.clip(speed_smooth, None, 2500)

# Chronological split: first 80% for training, last 20% for validation.
split_idx = int(0.8 * len(neural_data)) #suggest: 5%-20% depending on your dataset size
train_data = neural_data[:split_idx]
valid_data = neural_data[split_idx:]

# Continuous label columns:
#   0: x-coordinate of the snout
#   1: y-coordinate of the snout
#   2: x-coordinate of the snout + y-coordinate of the snout
#   3: smoothed speed (smoothing window = 5, capped at 2500)
# Stack all four arrays first, then slice for train/valid
snout_x_all = dlc[('Snout', 'x')].to_numpy()
snout_y_all = dlc[('Snout', 'y')].to_numpy()
snout_xy_all = snout_x_all + snout_y_all
speed_smooth_all = speed_smooth

continuous_label = np.stack([snout_x_all, snout_y_all, snout_xy_all, speed_smooth_all], axis=1)
if len(continuous_label) != len(neural_data):
    # Both are sliced with the same split_idx, so unequal lengths mean the
    # validation labels do not cover the same frames as the validation data.
    print(f"WARNING: {len(continuous_label)} label rows vs {len(neural_data)} neural rows; "
          "train/valid label slices will not match the neural data exactly.")
train_continuous_label = continuous_label[:split_idx]
valid_continuous_label = continuous_label[split_idx:]

# VAME label as discrete label
# The VAME loading code is commented out in the data-loading cell, so
# vame_labels may not exist; leave the discrete labels unset instead of
# raising NameError.
_vame_labels = globals().get('vame_labels')
if _vame_labels is None:
    print("vame_labels is not loaded; discrete labels are set to None.")
    train_discrete_label = None
    valid_discrete_label = None
else:
    train_discrete_label = _vame_labels[:split_idx]
    valid_discrete_label = _vame_labels[split_idx:]

# 1. Define a CEBRA model
cebra_model = CEBRA(
    model_architecture="offset10-model", #consider: "offset10-model-mse" if Euclidean
    batch_size=2048,
    learning_rate=3e-4,
    temperature=0.1,
    max_iterations=5000, #we will sweep later; start with default
    conditional='time', #for supervised, put 'time_delta', or 'delta'
    output_dimension=32,
    num_hidden_units=256,
    distance='cosine', #consider 'euclidean'; if you set this, output_dimension min=2
    device="cuda",
    verbose=True,
    time_offsets=5
)


In [None]:
#%mkdir saved_models

# Hyper-parameter grid: scalar entries are fixed, list entries are swept.
# The swept lists have 2 x 4 x 3 x 3 = 72 combinations.
params_grid = {
    "model_architecture": "offset10-model",
    "device": "cuda",  # or "cpu"
    "batch_size": 2048,
    "learning_rate": 3e-4,
    "max_iterations": 5000,
    "temperature": [0.1, 1.0],
    "output_dimension": [3, 6, 12, 32],
    "time_offsets": [5, 10, 50],
    "verbose": True,
    "num_hidden_units": [32, 64, 256],
    "hybrid": False,
}

# Only the training split is searched, under the name "dataset1".
datasets = dict(dataset1=train_data)

# Fit the models for the grid and store them under saved_models/.
grid_search = cebra.grid_search.GridSearch()
grid_search.fit_models(datasets, params=params_grid, models_dir="saved_models")

In [None]:
# Folder the grid search wrote its models to.
models_dir = "saved_models"

# Results table for the fitted models.
df_results = grid_search.get_df_results(models_dir=models_dir)

# Best model (and its name) for the training dataset.
best_model, best_model_name = grid_search.get_best_model(
    dataset_name="dataset1",
    models_dir=models_dir,
)
print("The best model is:", best_model_name)

In [None]:
# Load the top model ✨
# plot_embedding_interactive lives in the cebra.integrations.plotly submodule;
# import it explicitly rather than relying on it having been loaded already.
import cebra.integrations.plotly

model_path = Path("saved_models") / f"{best_model_name}.pt"
top_model = cebra.CEBRA.load(model_path)

# Transform both splits with the selected model.
top_train_embedding = top_model.transform(train_data)
top_valid_embedding = top_model.transform(valid_data)

# plot the loss curve
ax = cebra.plot_loss(top_model)

# Plot embeddings, coloured by continuous label column 3.
embedding_panels = (
    (top_train_embedding, train_continuous_label, "top model - train"),
    (top_valid_embedding, valid_continuous_label, "top model - validation"),
)
for embedding, labels, panel_title in embedding_panels:
    fig = cebra.integrations.plotly.plot_embedding_interactive(embedding,
                                                               embedding_labels=labels[:, 3],
                                                               title=panel_title,
                                                               markersize=3,
                                                               cmap="rainbow")
    fig.show()

In [None]:
### Now we are going to run our train/val. 5-10 times to be sure they are consistent!
# Each run fits `cebra_model` on train_data, saves the fitted model to the
# system temp directory, and the saved models are then reloaded to embed both
# the training and validation data.

X = 5  # Number of training runs
model_paths = []  # Store file paths

for i in range(X):
    # The emoji in this message had been mis-decoded into mojibake; restored here.
    print(f"Training 🦓CEBRA model {i+1}/{X}")

    # Train and save model
    # NOTE(review): the same `cebra_model` object is re-fitted every iteration;
    # this presumably restarts training from scratch each time — confirm.
    cebra_train_model = cebra_model.fit(train_data)
    # Files are named cebra_<run index>.pt, so a re-run overwrites earlier files.
    tmp_file = Path(tempfile.gettempdir(), f'cebra_{i}.pt')
    cebra_train_model.save(tmp_file)
    model_paths.append(tmp_file)

### Reload models and transform data
train_embeddings = []
valid_embeddings = []

for tmp_file in model_paths:
    cebra_train_model = cebra.CEBRA.load(tmp_file)
    train_embeddings.append(cebra_train_model.transform(train_data))
    valid_embeddings.append(cebra_train_model.transform(valid_data))

In [None]:
# Between-run consistency of the training embeddings
scores, pairs, ids_runs = cebra.sklearn.metrics.consistency_score(
    between="runs",
    embeddings=train_embeddings,
)

cebra.plot_consistency(scores, pairs, ids_runs)

In [None]:
# Between-run consistency of the validation embeddings
scores, pairs, ids_runs = cebra.sklearn.metrics.consistency_score(
    between="runs",
    embeddings=valid_embeddings,
)

cebra.plot_consistency(scores, pairs, ids_runs)

In [None]:
### Check the behavioral encoding
# Define the model
# consider changing based on search/results above
cebra_behavior_model = CEBRA(model_architecture='offset10-model',
                        batch_size=2048,
                        learning_rate=3e-4,
                        temperature=0.1,
                        output_dimension=32,
                        num_hidden_units=256,
                        max_iterations=5000,
                        distance='cosine',
                        conditional='time_delta', #using labels
                        # GPU when present, otherwise CPU, so this cell does not
                        # fail on CPU-only machines; same setting as the
                        # shuffle-control model it is compared against.
                        device='cuda_if_available',
                        verbose=True,
                        time_offsets=5)
# fit on the full recording (neural_data) with all continuous label columns
cebra_behavior_full_model = cebra_behavior_model.fit(neural_data, continuous_label)
# transform
cebra_behavior_full = cebra_behavior_full_model.transform(neural_data)
# GoF, scored on the same data and labels the model was fitted on
gof_full = cebra.sklearn.metrics.goodness_of_fit_score(cebra_behavior_full_model,neural_data, continuous_label)
print(" GoF in bits - full:", gof_full)

In [None]:
# One interactive figure per label column (columns 0-3), all showing the
# same behaviour embedding with a different colouring.
for label_idx in range(4):
    fig = cebra.integrations.plotly.plot_embedding_interactive(
        cebra_behavior_full,
        embedding_labels=continuous_label[:, label_idx],
        title="CEBRA-Behavior (full)",
        markersize=3,
        cmap="rainbow",
    )
    fig.show()

In [None]:
# Shuffle-control model. Architecture and optimisation settings are the same
# as the behaviour model's; only the device string differs.
cebra_shuffled_model = CEBRA(
    model_architecture='offset10-model',
    conditional='time_delta',
    distance='cosine',
    output_dimension=32,
    num_hidden_units=256,
    time_offsets=5,
    temperature=0.1,
    batch_size=2048,
    learning_rate=3e-4,
    max_iterations=5000,
    device='cuda_if_available',
    verbose=True,
)

In [None]:
# Control: permute the label rows (first axis) and train on the permuted labels.
# NOTE(review): no random seed is set here, so the permutation differs between runs.
shuffled_pos = np.random.permutation(continuous_label)

# Fit on (neural_data, shuffled labels), then embed the neural data
cebra_shuffled_model.fit(neural_data, shuffled_pos)
cebra_pos_shuffled = cebra_shuffled_model.transform(neural_data)

# Goodness of fit
# NOTE(review): scored against the unshuffled continuous_label rather than
# shuffled_pos, and stored in the same `gof_full` name used for the
# behaviour model — confirm both are intended.
gof_full = cebra.sklearn.metrics.goodness_of_fit_score(
    cebra_shuffled_model,
    neural_data,
    continuous_label,
)
print(" GoF in bits - full:", gof_full)

# Embedding coloured by label column 0
fig = cebra.integrations.plotly.plot_embedding_interactive(
    cebra_pos_shuffled,
    embedding_labels=continuous_label[:, 0],
    title="CEBRA-Behavior (labels shuffled)",
    markersize=3,
    cmap="rainbow",
)
fig.show()

# Loss curve of the label-shuffled fit
ax = cebra.plot_loss(cebra_shuffled_model)

In [None]:
### Control: permute the rows of the neural data and train on the permuted data
# NOTE(review): no random seed is set here, so the permutation differs between runs.
shuffle_idx = np.random.permutation(len(neural_data))
shuffled_neural = neural_data[shuffle_idx]

# Fit and embed. This re-fits the same `cebra_shuffled_model` object that the
# label-shuffle cell fitted, so that earlier fit is replaced.
cebra_shuffled_model.fit(shuffled_neural, continuous_label)
cebra_neural_shuffled = cebra_shuffled_model.transform(shuffled_neural)

# Goodness of fit on the shuffled neural data (again stored in `gof_full`)
gof_full = cebra.sklearn.metrics.goodness_of_fit_score(
    cebra_shuffled_model,
    shuffled_neural,
    continuous_label,
)
print(" GoF in bits - full:", gof_full)

# Embedding coloured by label column 0
fig = cebra.integrations.plotly.plot_embedding_interactive(
    cebra_neural_shuffled,
    embedding_labels=continuous_label[:, 0],
    title="CEBRA-Behavior (neural shuffled)",
    markersize=3,
    cmap="rainbow",
)
fig.show()

# Loss curve of the neural-shuffled fit
ax = cebra.plot_loss(cebra_shuffled_model)

In [None]:
# Hybrid model: fitted on the same neural data and continuous labels as the
# behaviour model, but with hybrid=True and its own hyperparameters
# (temperature 1.12, 8 output dimensions, 64 hidden units).
cebra_hybrid_model = CEBRA(model_architecture='offset10-model',
                        batch_size=2048,
                        learning_rate=3e-4,
                        temperature=1.12,
                        output_dimension=8,
                        max_iterations=5000,
                        num_hidden_units=64,
                        distance='cosine',
                        conditional='time_delta',
                        # GPU when present, otherwise CPU, so this cell does not
                        # fail on CPU-only machines; same setting as the
                        # shuffle-control model.
                        device='cuda_if_available',
                        verbose=True,
                        time_offsets=5,
                        hybrid = True)

# fit, then embed the full recording
cebra_hybrid_model.fit(neural_data, continuous_label)
cebra_hybrid = cebra_hybrid_model.transform(neural_data)

In [None]:
# One interactive figure per label column (columns 0-3) of the hybrid embedding.
for i in range(4):
    # plot embedding
    # The title previously read "CEBRA-Behavior (full)", copied from the
    # behaviour-model cell, which made these figures indistinguishable from those.
    fig = cebra.integrations.plotly.plot_embedding_interactive(cebra_hybrid, embedding_labels=continuous_label[:,i], title = "CEBRA-Hybrid (full)", markersize=3, cmap = "rainbow")
    fig.show()

### PP-seq

!pip install PPseq