In [1]:
from collections import Counter
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import os
from glob import glob
from pathlib import Path
import time

from tkinter import Tk
from tkinter import filedialog
import pathlib

import datetime

from tqdm import tqdm
import tempfile
import sys
sys.path.append(str(pathlib.Path.cwd() / "pipeline" / "spikeinterface_waveform_extraction"))
from joblib import Parallel, delayed
from scipy.ndimage import gaussian_filter
import numbers
import warnings
warnings.filterwarnings('default')

from pprint import pprint

import spikeinterface_v2 as si
from DemoReadSGLXData.readSGLX import readMeta, SampRate, makeMemMapRaw, ExtractDigital
from findT import find_tEnd_from_KS, find_tTurn_from_KS
from collections import defaultdict
from my_spike_tools import *

from spikeinterface_v2 import download_dataset
from spikeinterface_v2 import create_sorting_analyzer, load_sorting_analyzer
import spikeinterface_v2.extractors as se

import spikeinterface_v2.full as si
import spikeinterface_v2.extractors as se
import spikeinterface_v2.preprocessing as spre
import spikeinterface_v2.sorters as ss
import spikeinterface_v2.postprocessing as spost
import spikeinterface_v2.qualitymetrics as sqm
import spikeinterface_v2.comparison as sc
import spikeinterface_v2.exporters as sexp
import spikeinterface_v2.curation as scur
import spikeinterface_v2.widgets as sw


# Batch exploratory analysis.
# For every recording session this script loads the spike-sorting output,
# the behavioural event log and the DLC tracking, derives speed/acceleration
# from the tracked body parts, and writes one summary figure per good/MUA unit
# into "<session>_imec0/EDA".
from scipy.interpolate import UnivariateSpline

core_dir = r"D:/Neuropixels"
date_strs = [
    '01292025', '01302025', '01312025', '02012025',
    '02022025', '02032025', '02042025', '02062025', '02072025',
    '02092025', '02102025', '02112025', '02122025', '02132025',
    '02142025', '02152025', '02162025'
]
session_names = [f"9153_{d}_tagging_g0" for d in date_strs]

for session in session_names:
    selected_dir = pathlib.Path(core_dir) / session / f"{session}_imec0"
    base_dir = str(selected_dir)
    print(base_dir)

    ######## load neural data ##########
    # Paths are joined with pathlib so no OS-specific separator is baked in.
    ks_dir = selected_dir / "kilosort4"
    spike_mask = np.load(ks_dir / "spike_mask.npy")
    spike_seconds = np.load(ks_dir / "spike_seconds_adj.npy")[spike_mask]
    spike_clusters = np.load(ks_dir / "spike_clusters.npy")[spike_mask]
    spike_positions = np.load(ks_dir / "spike_positions.npy")[spike_mask]
    templates = np.load(ks_dir / "templates.npy")

    strobe_seconds = np.load(ks_dir / "strobe_seconds.npy")

    tagged_good_units = np.load(ks_dir / "tagged_good_units.npy")
    tagged_mua_units = np.load(ks_dir / "tagged_mua_units.npy")

    unit_label = pd.read_csv(
        selected_dir / "kilosort4qMetrics" / "templates._bc_unit_labels.tsv",
        sep="\t",
    )

    ######## load behavioral data ##########
    session_name = selected_dir.parts[-1]
    date_str = session_name.split('_')[1]

    # Convert MMDDYYYY to YYYY-MM-DD
    date_obj = datetime.datetime.strptime(date_str, "%m%d%Y")
    date_str = date_obj.strftime("%Y-%m-%d")

    # Search for the matching event file; fail with a readable message
    # instead of a bare StopIteration when the file is missing.
    event_dir = pathlib.Path(r"D:\Neuropixels\Event\9153")
    event_file = event_dir.glob(f"*{date_str}*.csv")
    event_path = next(event_file, None)
    if event_path is None:
        raise FileNotFoundError(
            f"No event file matching '*{date_str}*.csv' in {event_dir}"
        )
    events = pd.read_csv(event_path)
    events = events.drop_duplicates(subset='Index', keep='last')  # remove duplicates

    dlc_dir = pathlib.Path(r"D:\Neuropixels\DLC\9153")
    dlc_file = dlc_dir.glob(f"*{date_str}*.h5")
    dlc_path = next(dlc_file, None)
    if dlc_path is None:
        raise FileNotFoundError(
            f"No DLC file matching '*{date_str}*.h5' in {dlc_dir}"
        )
    dlc = pd.read_hdf(dlc_path)
    dlc.columns = dlc.columns.droplevel(0)
    # Trim DLC data to the first index in the event file
    dlc = dlc.loc[dlc.index >= events['Index'].iloc[0]]

    # Check the number of frames in each file
    start = events['Index'].iloc[0]  # When the first event starts
    end = events['Index'].iloc[-1]
    print("First index: ", start)
    print("Last index: ", end)
    stim = events['Stim'].to_numpy()
    print("Event file length: ", len(stim))
    print("Strobe file length: ", len(strobe_seconds))

    snout_x = dlc[('Snout', 'x')].to_numpy()
    snout_y = dlc[('Snout', 'y')].to_numpy()
    print("DLC file length: ", len(snout_x))
    print("Estimated # of frames: ", (strobe_seconds[-1] - strobe_seconds[0]) * 89.97)

    # Estimated strobe timings: uniformly spaced between the first and the
    # last strobe time stamps, one per event row.
    estimated_strobe_seconds = np.linspace(strobe_seconds[0], strobe_seconds[-1], len(events))
    print("Strobe file length(estimate): ", len(estimated_strobe_seconds))

    ### Interpolate DLC data ###
    bodyparts = dlc.columns.get_level_values(0).unique().tolist()
    coords = dlc.columns.get_level_values(1).unique().tolist()[0:2]

    # Samples whose likelihood is below this value are replaced by interpolation
    likelihood_threshold = 0.7

    dlc_interpolated = dlc.copy()
    frame_idx = np.arange(len(dlc))

    for bp in bodyparts:
        low_confidence = (dlc[(bp, 'likelihood')] < likelihood_threshold).to_numpy()
        for axis in coords:
            # Mask low-confidence samples, then fill every NaN from a
            # degree-1 interpolating spline through the remaining samples.
            masked = dlc[(bp, axis)].mask(low_confidence).to_numpy(dtype=float).copy()
            valid = ~np.isnan(masked)

            spline = UnivariateSpline(frame_idx[valid], masked[valid], k=1, s=0)
            interpolated = spline(frame_idx)
            masked[~valid] = interpolated[~valid]

            # Assign with a single tuple key. The previous chained form
            # `dlc_interpolated[bp][axis] = ...` writes to a temporary frame,
            # so the interpolated values were not reliably stored.
            dlc_interpolated[(bp, axis)] = masked

    # Kinematic analysis on the mean position of the selected body parts
    bodyparts = ['Snout', 'Tail']
    x_coords = np.stack([dlc_interpolated[(bp, 'x')].values for bp in bodyparts], axis=1)
    y_coords = np.stack([dlc_interpolated[(bp, 'y')].values for bp in bodyparts], axis=1)
    mean_x = np.mean(x_coords, axis=1)
    mean_y = np.mean(y_coords, axis=1)
    dt = np.median(np.diff(estimated_strobe_seconds))
    vx = np.gradient(mean_x, dt)
    vy = np.gradient(mean_y, dt)
    speed = np.sqrt(vx**2 + vy**2)
    acceleration = np.gradient(speed, dt)

    # Classify the units (unitType 1 -> good, 2 -> MUA)
    all_units = np.where(np.isin(unit_label['unitType'], [1, 2]))[0]
    # atleast_1d: squeeze() alone yields a 0-d array when there is exactly one
    # tagged unit, which breaks the `unit in tagged_units` tests below.
    tagged_units = np.atleast_1d(
        np.concatenate([tagged_good_units, tagged_mua_units]).squeeze()
    )
    good_units = np.where(unit_label['unitType'] == 1)[0]
    mua_units = np.where(unit_label['unitType'] == 2)[0]

    # Window around each spike (+/- peri_window seconds, sampled per frame)
    peri_window = 1.0  # seconds
    frame_rate = 90  # Hz
    n_bins = int(2 * peri_window * frame_rate) + 1
    peri_time_axis = np.linspace(-peri_window, peri_window, n_bins)
    half_window = int(peri_window * frame_rate)
    peri_idx = np.arange(-half_window, half_window + 1)  # frame offsets, n_bins long

    # Plot settings that do not depend on the unit
    arena_xlim = (200, 1230)
    arena_ylim = (40, 1070)
    min_spikes_per_bin = 5  # bins with fewer spikes are left blank in the maps

    shank_width = 24       # in microns
    shank_length = -5000   # in microns
    shank_spacing = 250    # center-to-center in microns
    n_shanks = 4

    # Occupancy depends only on the tracking data, so compute it once per session
    rate_x_bins = np.linspace(arena_xlim[0], arena_xlim[1], 15)
    rate_y_bins = np.linspace(arena_ylim[0], arena_ylim[1], 15)
    occupancy_hist, _, _ = np.histogram2d(snout_x, snout_y, bins=[rate_x_bins, rate_y_bins])
    occupancy_sec = occupancy_hist / frame_rate

    vec_x_bins = np.linspace(arena_xlim[0], arena_xlim[1], 10)
    vec_y_bins = np.linspace(arena_ylim[0], arena_ylim[1], 10)
    x_centers = (vec_x_bins[:-1] + vec_x_bins[1:]) / 2
    y_centers = (vec_y_bins[:-1] + vec_y_bins[1:]) / 2
    X, Y = np.meshgrid(x_centers, y_centers, indexing='ij')

    # Analyzers and waveform dictionaries are loaded lazily, once per session,
    # instead of being reloaded from disk for every unit.
    # NOTE: the waveform files are pickled dicts (allow_pickle=True); only
    # load files produced by this pipeline.
    behavioral_analyzer = None
    beh_median_wfs = None
    tag_analyzer = None
    tag_median_wfs = None

    tagged_units_final = []
    for unit in tqdm(all_units):
        fig = plt.figure(figsize=(20, 7))
        plot_times = {}  # per-panel wall-clock timings, kept for ad-hoc profiling
        is_tagged = unit in tagged_units
        unit_mask = spike_clusters == unit

        # --------------Plot the waveform of the unit-----------------
        t0 = time.time()
        plt.subplot(2, 5, 1)
        if behavioral_analyzer is None:
            behavioral_analyzer = load_sorting_analyzer(
                folder=str(selected_dir / 'analyzer_beh'), format='binary_folder'
            )
            beh_median_wfs = np.load(
                ks_dir / "waveform_beh_median.npy", allow_pickle=True
            ).item()
        beh_median_wf_mc = beh_median_wfs[unit]

        if beh_median_wf_mc is None:
            plt.text(0.5, 0.5, 'No waveform data available', horizontalalignment='center', verticalalignment='center', fontsize=12)
        elif is_tagged:
            if tag_analyzer is None:
                tag_analyzer = load_sorting_analyzer(
                    folder=str(selected_dir / 'analyzer_tag'), format='binary_folder'
                )
                tag_median_wfs = np.load(
                    ks_dir / "waveform_tag_median.npy", allow_pickle=True
                ).item()
            tag_median_wf_mc = tag_median_wfs[unit]
            if tag_median_wf_mc is not None:
                # Compare the behavioural and tagging-epoch waveforms
                conv_similarity_median = compute_waveform_similarity_convolution(
                    beh_median_wf_mc,
                    tag_median_wf_mc,
                    behavioral_analyzer,
                    tag_analyzer,
                    unit,
                    use_main_channel=True,
                )
                plt.plot(beh_median_wf_mc, color='k')
                plt.plot(tag_median_wf_mc, color='blue', alpha=0.5)
                plt.gca().invert_yaxis()  # Invert y-axis for waveform
                plt.title('Waveform')
                plt.xlabel('Time (samples)')
                plt.ylabel('Template used for waveform extraction')
                plt.legend(['Behavioral', 'Tagged'], loc='upper right')
                plt.text(0.5, 0.80, f'Conv similarity: {conv_similarity_median:.2f}', transform=plt.gca().transAxes, fontsize=10, verticalalignment='top')
                plt.tight_layout()
                # Keep the unit as tagged only if both waveforms agree
                if conv_similarity_median > 0.8:
                    tagged_units_final.append(unit)
        else:
            plt.plot(beh_median_wf_mc, color='k')
            plt.gca().invert_yaxis()  # Invert y-axis for waveform
            plt.title('Waveform')
            plt.xlabel('Time (samples)')
            plt.ylabel('Waveform(median)')
            plt.tight_layout()
        plot_times['waveform'] = time.time() - t0

        #-------------plot the position of the unit on the probe------------------
        t0 = time.time()
        plt.subplot(2, 5, 2)
        for i in range(n_shanks):
            center_x = (i - 3) * shank_spacing
            left = center_x - shank_width / 2
            bottom = 0
            rect = plt.Rectangle((left, bottom), shank_width, shank_length,
                                 linewidth=2, edgecolor='k', facecolor='none', zorder=1)
            plt.gca().add_patch(rect)
        plt.xlabel('Lateral (μm) 0 = ML:-1500μm')
        plt.ylabel('Depth (μm) 0 = brain surface')
        plt.title('Neuropixels 2.0 4-shank Probe Layout')
        unit_position = spike_positions[unit_mask]
        plt.scatter(np.median(unit_position[:, 0]) - 800, np.median(unit_position[:, 1]) - 5000, s=10, c='red')
        plt.gca().invert_xaxis()
        plot_times['probe_layout'] = time.time() - t0

        # ---------------plot peri-spike speed (mean across all spikes)-----------------
        t0 = time.time()
        unit_spike_times = spike_seconds[unit_mask]
        # Map every spike to the first estimated strobe time at or after it
        spike_indices = np.searchsorted(estimated_strobe_seconds, unit_spike_times, side='left')
        spike_indices = np.clip(spike_indices, 0, len(estimated_strobe_seconds) - 1)
        # Keep only spikes whose whole window lies inside the speed trace
        in_bounds = (spike_indices + peri_idx[0] >= 0) & (spike_indices + peri_idx[-1] < len(speed))
        valid_spikes = spike_indices[in_bounds]
        if valid_spikes.size == 0:
            print(f"Skipping unit {unit}: no peri-event speed data available.")
            plt.close(fig)  # do not leak the figure of a skipped unit
            continue
        peri_speed_stack = speed[valid_spikes[:, None] + peri_idx]
        plt.subplot(2, 5, 3)
        plt.plot(peri_time_axis, np.nanmean(peri_speed_stack, axis=0), color='gray', alpha=0.8)
        plt.xlabel('Time from spike (s)')
        plt.ylabel('Speed')
        plt.ylim(100, 1100)
        plt.title(f'Unit #{unit}: Mean speed around tagged unit spikes')
        plot_times['peri_speed'] = time.time() - t0

        # Empty placeholder panels
        plt.subplot(2, 5, 4)
        plt.subplot(2, 5, 5)

        # ----------------scatter plot of the spike on the arena-----------------
        t0 = time.time()
        # Snout position at the frame assigned to each spike (all spikes,
        # including those dropped from the peri-spike window above)
        fired_coords = np.column_stack((snout_x[spike_indices], snout_y[spike_indices]))
        plt.subplot(2, 5, 6)
        plt.scatter(fired_coords[:, 0], fired_coords[:, 1], s=1, color='black', alpha=0.5)
        plt.xlim(*arena_xlim)
        plt.ylim(*arena_ylim)
        plt.ylabel('Y coordinate')
        plt.title('Snout position when the unit fired')
        plot_times['scatter_arena'] = time.time() - t0

        # ----------- Plot firing rate heatmap------------------
        t0 = time.time()
        spike_hist, xedges, yedges = np.histogram2d(
            fired_coords[:, 0], fired_coords[:, 1], bins=[rate_x_bins, rate_y_bins]
        )
        with np.errstate(divide='ignore', invalid='ignore'):
            firing_rate_map = np.where(
                spike_hist >= min_spikes_per_bin, spike_hist / occupancy_sec, np.nan
            )
        plt.subplot(2, 5, 7)
        im = plt.imshow(np.rot90(firing_rate_map), extent=[rate_x_bins[0], rate_x_bins[-1], rate_y_bins[0], rate_y_bins[-1]], aspect='auto', cmap='coolwarm')
        plt.xlabel('X coordinate')
        plt.colorbar(im, label='Hz')
        plt.title('Firing Rate Heatmap')
        plt.gca().set_aspect('equal', adjustable='box')
        plot_times['firing_rate_heatmap'] = time.time() - t0

        # ----------- Plot mean speed vectors at spikes------------------
        t0 = time.time()
        sx = snout_x[spike_indices]
        sy = snout_y[spike_indices]
        svx = vx[spike_indices]
        svy = vy[spike_indices]
        xi = np.digitize(sx, vec_x_bins) - 1
        yi = np.digitize(sy, vec_y_bins) - 1
        valid = (xi >= 0) & (xi < len(x_centers)) & (yi >= 0) & (yi < len(y_centers))
        grid_shape = (len(x_centers), len(y_centers))
        vx_spike_grid = np.zeros(grid_shape)
        vy_spike_grid = np.zeros(grid_shape)
        spike_count_grid = np.zeros(grid_shape)
        # np.add.at accumulates correctly when several spikes share a bin
        np.add.at(vx_spike_grid, (xi[valid], yi[valid]), svx[valid])
        np.add.at(vy_spike_grid, (xi[valid], yi[valid]), svy[valid])
        np.add.at(spike_count_grid, (xi[valid], yi[valid]), 1)
        with np.errstate(invalid='ignore', divide='ignore'):
            enough_spikes = spike_count_grid >= min_spikes_per_bin
            vx_spike_mean = np.where(enough_spikes, vx_spike_grid / spike_count_grid, np.nan)
            vy_spike_mean = np.where(enough_spikes, vy_spike_grid / spike_count_grid, np.nan)
        mask = ~np.isnan(vx_spike_mean) & ~np.isnan(vy_spike_mean)
        plt.subplot(2, 5, 8)
        plt.quiver(X[mask], Y[mask], vx_spike_mean[mask], vy_spike_mean[mask], angles='xy', scale_units='xy', scale=2, color='red', width=0.008)
        plt.xlabel('X coordinate')
        plt.ylabel('Y coordinate')
        # Title reports the threshold actually applied above
        plt.title(f'Mean Speed Vectors at Spikes (min {min_spikes_per_bin} spikes/bin)')
        plt.xlim(*arena_xlim)
        plt.ylim(*arena_ylim)
        plt.gca().set_aspect('equal', adjustable='box')
        plot_times['mean_speed_vectors'] = time.time() - t0

        # Empty placeholder panels
        plt.subplot(2, 5, 9)
        plt.subplot(2, 5, 10)

        # File-name prefix. Every entry of all_units has unitType 1 or 2, so a
        # unit that is not "good" is MUA.
        quality = "good" if unit in good_units else "mua"
        prefix = f"tagged_{quality}" if is_tagged else quality

        # Ensure the EDA directory exists before saving
        eda_dir = selected_dir / "EDA"
        eda_dir.mkdir(parents=True, exist_ok=True)

        plt.tight_layout()
        plt.savefig(str(eda_dir / f"{prefix}_unit{unit}_EDA.png"), dpi=300)
        plt.close(fig)

  right=ast.Str(s=sentinel),
  return Constant(*args, **kwargs)


KeyboardInterrupt: 