In [7]:
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter1d
from kilosort.io import load_ops
from kilosort.data_tools import (
    mean_waveform, cluster_templates, get_good_cluster, get_cluster_spikes, mean_waveform_with_bounds,
    get_spike_waveforms, get_best_channels
    )

from spikeinterface.extractors import read_phy

def make_cluster_summary_table(results_dir):
    """
    Load summary information from a Kilosort4 results directory and save it as a CSV.

    One row is produced per cluster id present in ``spike_clusters.npy``. The
    template-derived columns (``chan``, ``amp``) are looked up by cluster id, so
    cluster ids must be valid row indices into ``templates.npy`` (as in uncurated
    Kilosort output).

    Parameters:
        results_dir (str or Path): Path to the Kilosort4 output directory.

    Returns:
        pd.DataFrame: Summary dataframe indexed by ``cluster`` with columns
            ``chan``     - channel (via ``channel_map.npy``) where the template has most energy,
            ``depth``    - mean of column 1 of ``spike_positions.npy`` over the cluster's spikes,
            ``fr``       - spike count divided by (last spike time / fs),
            ``amp``      - L2 norm of the cluster's template,
            ``n_spikes`` - number of spikes assigned to the cluster.
        The same table is also written to ``results_dir / 'cluster_summary.csv'``.

    Raises:
        ValueError: if a cluster id has no corresponding template row (e.g. new ids
            created by curation merges/splits).
    """
    results_dir = Path(results_dir)
    ops = load_ops(results_dir / 'ops.npy')
    fs = ops['fs']

    chan_map = np.load(results_dir / 'channel_map.npy')
    templates = np.load(results_dir / 'templates.npy')
    # Energy of each template on each channel (summed over the time axis).
    template_energy = (templates ** 2).sum(axis=1)
    # Per-template values; these are indexed by template number, not yet by cluster.
    chan_best_per_template = chan_map[template_energy.argmax(axis=-1)]
    amplitude_per_template = np.sqrt(template_energy.sum(axis=-1))

    st = np.load(results_dir / 'spike_times.npy')
    clu = np.load(results_dir / 'spike_clusters.npy').ravel()
    pos = np.load(results_dir / 'spike_positions.npy')

    cluster_ids, inverse, spike_counts = np.unique(
        clu, return_inverse=True, return_counts=True
    )
    inverse = np.asarray(inverse).ravel()

    # Align template-derived columns with the clusters that actually have spikes;
    # templates without spikes would otherwise shift every row.
    n_templates = templates.shape[0]
    if cluster_ids.size and (cluster_ids.min() < 0 or cluster_ids.max() >= n_templates):
        raise ValueError(
            f"Cluster ids must index templates.npy (0..{n_templates - 1}); got ids "
            f"{cluster_ids.min()}..{cluster_ids.max()}. Cluster ids created by curation "
            f"have no template row."
        )
    chan_best = chan_best_per_template[cluster_ids]
    template_amplitudes = amplitude_per_template[cluster_ids]

    # Assumes spike_times are in samples: duration (s) is approximated by st.max() / fs.
    firing_rates = spike_counts * fs / st.max()

    # Mean depth per cluster in a single pass (sum of positions / spike count).
    depth = np.bincount(inverse, weights=pos[:, 1], minlength=cluster_ids.size) / spike_counts

    df = pd.DataFrame.from_dict({
        'cluster': cluster_ids,
        'chan': chan_best,
        'depth': depth,
        'fr': firing_rates,
        'amp': template_amplitudes,
        'n_spikes': spike_counts
    }).set_index('cluster')

    # Save to CSV
    csv_path = results_dir / 'cluster_summary.csv'
    df.to_csv(csv_path)

    print(f"Saved cluster summary to: {csv_path}")
    return df


def _plot_mean_waveform(t_ms, mean_wv, lower_bnd, upper_bnd, mean_temp, cluster_id):
    """Build the figure overlaying the mean waveform (with its bounds band) and the mean template."""
    fig, ax = plt.subplots(1, 1)
    ax.fill_between(t_ms, lower_bnd, upper_bnd, color='gray', alpha=0.25, label='±1 STD')
    ax.plot(t_ms, mean_wv, c='black', linestyle='dashed', linewidth=2, label='waveform')
    ax.plot(t_ms, mean_temp, linewidth=2, label='template')
    ax.set_title(f'Mean single-channel template and spike waveform for cluster {cluster_id}')
    ax.set_xlabel('Time (ms)')
    ax.set_ylabel('Voltage (µV)')
    ax.legend()
    fig.tight_layout()
    return fig


def _plot_waveform_image(t_ms, waves, spike_min):
    """Build an image of single-spike waveforms, one row per spike, earliest spike at the bottom."""
    fig, ax = plt.subplots(1, 1, figsize=(6, 6))
    # origin='lower' places row 0 (the earliest spike, after sorting) at the bottom,
    # matching extent's bottom=spike_min[0] / top=spike_min[-1]. The default
    # origin='upper' would draw row 0 at the top and invert the time axis.
    # imshow spaces rows evenly, so the y-axis is approximate for uneven spike times.
    img = ax.imshow(
        waves.T,
        aspect='auto',
        origin='lower',
        extent=[t_ms[0], t_ms[-1], spike_min[0], spike_min[-1]],
    )
    cbar = fig.colorbar(img, ax=ax)
    cbar.set_label('Voltage (µV)')  # Change to the appropriate unit if needed
    ax.set_xlabel('Time (ms)')
    ax.set_ylabel('Spike time (min)')
    fig.tight_layout()
    return fig


def _plot_mse_over_time(spike_min, scores, smooth_sigma):
    """Build the figure of the Gaussian-smoothed per-spike MSE against spike time."""
    fig, ax = plt.subplots(1, 1, figsize=(8, 4))
    ax.plot(spike_min, gaussian_filter1d(scores, sigma=smooth_sigma), linestyle='-', color='black')
    ax.axhline(y=0, linestyle='--', color='gray', linewidth=1)
    ax.set_xlabel('Spike Time (min)')
    ax.set_ylabel('MSE (µV²)')
    ax.set_title('Mean Squared Error Between Mean and Single-Spike Waveforms Over Time')
    fig.tight_layout()
    return fig


def _save_and_show(fig, save_dir, filename):
    """Save ``fig`` as ``save_dir / filename`` when ``save_dir`` is set, then display it."""
    if save_dir is not None:
        fig.savefig(save_dir / filename, dpi=150, bbox_inches='tight')
    plt.show()


def plot_cluster_waveforms(results_dir, save_path=False, cluster_id=None, n_spikes=1000, smooth_sigma=3):
    """
    Plot waveform diagnostics for one Kilosort4 cluster.

    Three figures are produced:
      1. the mean best-channel waveform with its bounds band, overlaid on the mean template;
      2. an image of individual spike waveforms on the best channel, one row per spike,
         ordered by spike time;
      3. the per-spike mean squared error to the mean waveform over time, smoothed with a
         Gaussian filter.

    Parameters:
        results_dir (str or Path): Kilosort4 output directory.
        save_path (str, Path or False): if truthy, a directory (created if missing) into which
            the three figures are saved as PNG files. ``False`` (default) only displays them.
        cluster_id (int or None): cluster to plot. ``None`` (default) lets
            ``get_good_cluster(results_dir, n=1)`` choose one.
        n_spikes (int): number of spikes requested both for the mean waveform and for the
            single-spike waveforms (default 1000).
        smooth_sigma (float): standard deviation, in samples of the MSE series, of the Gaussian
            smoothing applied to the MSE trace (default 3).

    Raises:
        ValueError: if no single-spike waveforms are returned for the cluster.
    """
    results_dir = Path(results_dir)
    ops = load_ops(results_dir / 'ops.npy')
    fs = ops['fs']
    # Time axis of one waveform (ops['nt'] samples), in milliseconds.
    t_ms = (np.arange(ops['nt']) / fs) * 1000

    if cluster_id is None:
        cluster_id = get_good_cluster(results_dir, n=1)

    mean_wv, lower_bnd, upper_bnd, spike_subset = mean_waveform_with_bounds(
        cluster_id, results_dir, n_spikes=n_spikes, bfile=None, best=True
    )
    # Pass the mean waveform's spike subset so the template presumably averages the same spikes.
    mean_temp = cluster_templates(
        cluster_id, results_dir, mean=True, best=True, spike_subset=spike_subset
    )

    spike_times, _ = get_cluster_spikes(cluster_id, results_dir, n_spikes=n_spikes)
    chan = get_best_channels(results_dir)[cluster_id]
    # Columns are individual spikes (the original indexed waves[:, i] per spike).
    waves = np.asarray(get_spike_waveforms(spike_times, results_dir, chan=chan))
    n_waves = waves.shape[1]
    if n_waves == 0:
        raise ValueError(f'No spike waveforms were returned for cluster {cluster_id}.')

    # Spike times are treated as sample indices: samples -> seconds -> minutes.
    spike_min = np.asarray(spike_times, dtype=float)[:n_waves] / fs / 60
    # Sort by time so image rows and the smoothed MSE trace follow recording order.
    order = np.argsort(spike_min, kind='stable')
    spike_min = spike_min[order]
    waves = waves[:, order]
    scores = np.array([wave_diff(waves[:, i], mean_wv) for i in range(n_waves)])

    save_dir = Path(save_path) if save_path else None
    if save_dir is not None:
        save_dir.mkdir(parents=True, exist_ok=True)

    _save_and_show(
        _plot_mean_waveform(t_ms, mean_wv, lower_bnd, upper_bnd, mean_temp, cluster_id),
        save_dir, f'cluster_{cluster_id}_mean_waveform.png',
    )
    _save_and_show(
        _plot_waveform_image(t_ms, waves, spike_min),
        save_dir, f'cluster_{cluster_id}_waveform_image.png',
    )
    _save_and_show(
        _plot_mse_over_time(spike_min, scores, smooth_sigma),
        save_dir, f'cluster_{cluster_id}_mse_over_time.png',
    )

def wave_diff(a, b):
    """
    Mean squared error between two waveforms.

    Parameters:
        a, b (array-like): waveforms; shapes must be broadcast-compatible.

    Returns:
        numpy scalar: mean of the element-wise squared difference (lower = more similar).
    """
    residual = np.subtract(np.asarray(a), np.asarray(b))
    return np.square(residual).mean()


In [10]:
# Kilosort4 output folder inside the imec0 probe folder of this SpikeGLX session.
results_dir = (
    '/ix1/pmayo/lab_NHPdata/kendra_scrappy_0142a_g0/'
    'kendra_scrappy_0142a_g0_imec0/kilosort4_unleashed'
)
plot_cluster_waveforms(results_dir=results_dir)

In [21]:
import spikeinterface.full as si

# results_dir is <session>/<session>_imecN/<kilosort folder>: the parent is the probe
# folder, and its parent is the SpikeGLX session folder that read_spikeglx scans.
spikeglx_folder = Path(results_dir).parent
session_folder = spikeglx_folder.parent

# Use the streams actually discovered in the sorted probe's folder instead of a
# hard-coded list, so the AP stream matches the probe that was sorted rather than
# always resolving to 'imec0.ap'.
stream_names, stream_ids = si.get_neo_streams('spikeglx', spikeglx_folder)
ap_stream_names = [name for name in stream_names if name.endswith('.ap')]
if not ap_stream_names:
    raise ValueError(
        f'No AP stream found in {spikeglx_folder}; available streams: {stream_names}'
    )
ap_stream_name = ap_stream_names[0]

recording = si.read_spikeglx(session_folder, stream_name=ap_stream_name, load_sync_channel=False)
sorting = read_phy(results_dir)

analyzer = si.create_sorting_analyzer(sorting=sorting, recording=recording, format="memory")
print(analyzer)

estimate_sparsity (no parallelization):   0%|          | 0/7046 [00:00<?, ?it/s]

SortingAnalyzer: 202 channels - 303 units - 1 segments - memory - sparse - has recording
Loaded 0 extensions




In [40]:
from spikeinterface.postprocessing import compute_principal_components
from spikeinterface.qualitymetrics import (
    compute_snrs,
    compute_firing_rates,
    compute_isi_violations,
    calculate_pc_metrics,
    compute_quality_metrics,
)

# Extensions computed on the in-memory analyzer (waveforms is the expensive step).
analyzer.compute("random_spikes", method="uniform", max_spikes_per_unit=600, seed=2205)
analyzer.compute("waveforms", ms_before=1.3, ms_after=2.6, n_jobs=40)
analyzer.compute("templates", operators=["average", "median", "std"])
analyzer.compute("noise_levels")
analyzer.compute(input="isi_histograms", window_ms=50.0, bin_ms=1.0, method="auto")
analyzer.compute(input="correlograms", window_ms=50.0, bin_ms=1.0, method="auto")

# 'isi' / 'ccg' are not registered quality-metric names. The ISI- and refractory-period
# violation metrics are 'isi_violation' and 'rp_violation'.
qm_ext = analyzer.compute(
    input="quality_metrics",
    metric_names=['isi_violation', 'rp_violation'],
    skip_pc_metrics=True,
)
metrics = qm_ext.get_data()

print(analyzer)
metrics.head()

compute_waveforms (workers: 40 processes):   0%|          | 0/7046 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
import spikeinterface.widgets as sw

# include_metrics selects columns of the quality-metrics table, e.g. 'isi_violations_ratio'
# (from the 'isi_violation' metric) and 'rp_contamination' (from 'rp_violation').
qm = sw.plot_quality_metrics(
    analyzer,
    unit_ids=[1],
    include_metrics=['isi_violations_ratio', 'rp_contamination'],
    figsize=(20, 8),
)