# Spike-finding comparison


In [1]:
from matplotlib import pyplot as plt
from matplotlib.patches import Polygon
import numpy as np

import os
import openephys_utils

from scipy import signal

%matplotlib qt

from tqdm.notebook import tqdm

import kilosort

Pop out the threshold crossings from each channel for a particular directory

In [2]:
directory = 'Z:\\BrainPatch\\20241002\\lateral\\Crimson__2024-10-02_12-21-01__20mA_2ms_400um'
filename = os.path.join(directory, 'sig_eraasr.npy')

if not os.path.exists(os.path.join(directory, 'sig_eraasr.npy')):
    # read everything
    sig, timestamps, stims, stim_ts = openephys_utils.open_sig_stims(directory)

    # clean, pull out threshold crossings, get firing rates
    sig_eraasr = openephys_utils.ERAASR(sig)
    spike_dict = openephys_utils.threshold_crossings(sig_eraasr, multi_rejection=None, low_cutoff=-20)

    filename = os.path.join(directory, 'sig_eraasr.npy')
    np.save(filename, sig_eraasr.astype(np.float32))

else:
    sig_eraasr = np.load(filename)
    spike_dict = openephys_utils.threshold_crossings(sig_eraasr, multi_rejection=None, low_cutoff=-20)

kilosort

In [3]:
settings = {'filename':filename,
            'probe_name':'Z:\\BrainPatch\\20241002\\64-4shank-poly-brainpatch-chanMap.mat',
            'n_chan_bin':64,
            'nearest_chans':1}

kilosort.run_kilosort(settings, file_object= sig_eraasr.astype(np.float32), data_dtype='float32', results_dir=os.path.join(directory, 'kilosort4_unfiltered'))

# load spike times, template shape, and channel number
spike_times = np.load(os.path.join(directory, 'kilosort4_unfiltered\\spike_times.npy'))
channel_map = np.load(os.path.join(directory, 'kilosort4_unfiltered\\channel_map.npy'))
templates = np.load(os.path.join(directory, 'kilosort4_unfiltered\\templates.npy'))
spike_templates = np.load(os.path.join(directory, 'kilosort4_unfiltered\\spike_templates.npy'))

# "best channel" for each template -- based on the greatest variance
channel_best = (templates**2).sum(axis=1).argmax(axis=-1)

kilosort.run_kilosort: Kilosort version 4.0.17
kilosort.run_kilosort: Sorting Z:\BrainPatch\20241002\lateral\Crimson__2024-10-02_12-21-01__20mA_2ms_400um\sig_eraasr.npy
kilosort.run_kilosort: ----------------------------------------
kilosort.run_kilosort: Using GPU for PyTorch computations. Specify `device` to change this.
kilosort.run_kilosort:  
kilosort.run_kilosort: Computing preprocessing variables.
kilosort.run_kilosort: ----------------------------------------
kilosort.run_kilosort: N samples: 2164480
kilosort.run_kilosort: N seconds: 72.14933333333333
kilosort.run_kilosort: N batches: 37
kilosort.run_kilosort: Preprocessing filters computed in  3.06s; total  3.07s
kilosort.run_kilosort:  
kilosort.run_kilosort: Computing drift correction.
kilosort.run_kilosort: ----------------------------------------
kilosort.spikedetect: Re-computing universal templates from data.
100%|██████████| 37/37 [00:16<00:00,  2.25it/s]
kilosort.run_kilosort: drift computed in  17.49s; total  20.60s
k

In [4]:
num_matches = np.empty_like(channel_best)

aligned_unfiltered = []
for i_channel, channel in enumerate(channel_best):
    # find the times of a particular template
    same_template = spike_templates == i_channel
    same_ts = spike_times[same_template]

    # find the closest timestamp 
    aligned_unfiltered.extend([np.abs(spike_dict[channel]['sample_no'] - ts).min() for ts in same_ts])


kilosort on filtered data

In [5]:
# sos = signal.butter(8, [300, 6000], 'bandpass', output='sos', fs=30000)
# sig_filter = signal.sosfiltfilt(sos, sig_eraasr, axis=0)

# filename = os.path.join(directory, 'sig_filter.npy')
# np.save(filename, sig_filter)

# settings = {'filename':filename,
#             'probe_name':'Z:\\BrainPatch\\20241002\\64-4shank-poly-brainpatch-chanMap.mat',
#             'n_chan_bin':64,
#             'nearest_chans':1}

# kilosort.run_kilosort(settings, file_object= sig_filter.astype(np.float32), data_dtype='float32', results_dir=os.path.join(directory, 'kilosort4_filter'))

# load spike times, template shape, and channel number
spike_times = np.load(os.path.join(directory, 'kilosort4_filter\\spike_times.npy'))
channel_map = np.load(os.path.join(directory, 'kilosort4_filter\\channel_map.npy'))
templates = np.load(os.path.join(directory, 'kilosort4_filter\\templates.npy'))
spike_templates = np.load(os.path.join(directory, 'kilosort4_filter\\spike_templates.npy'))

# "best channel" for each template -- based on the greatest variance
channel_best = (templates**2).sum(axis=1).argmax(axis=-1)

In [6]:
num_matches = np.empty_like(channel_best)

aligned_filtered = []
for i_channel, channel in enumerate(channel_best):
    # find the times of a particular template
    same_template = spike_templates == i_channel
    same_ts = spike_times[same_template]

    # find the closest timestamp 
    aligned_filtered.extend([np.abs(spike_dict[channel]['sample_no'] - ts).min() for ts in same_ts])


Plot the filtered and unfiltered

In [7]:
fig,ax = plt.subplots()

ax.hist(aligned_unfiltered, bins=np.arange(50), label='Unfiltered', histtype='step')
ax.hist(aligned_filtered, bins=np.arange(50), label='Filtered', histtype='step')

ax.legend()
ax.set_xlabel('Closest spike on channel (samples)')
ax.set_ylabel('Number of spikes')

for spine in ax.spines:
    # ax.spines[spine].set(['visible',False])
    ax.spines[spine].set_visible(False)

In [10]:
len(aligned_unfiltered)

31954