In [1]:
import spikeinterface as si
from spikeinterface.extractors import read_openephys
import spikeinterface.preprocessing as spre
import spikeinterface.sorters as ss
import spikeinterface.postprocessing as spost
import spikeinterface.qualitymetrics as sqm
import spikeinterface.comparison as sc
import spikeinterface.exporters as sexp
import spikeinterface.widgets as sw

from kilosort import io
from kilosort import run_kilosort
import torch

import json
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import glob
import os
from os import listdir
from os.path import isfile, join
import subprocess
import sys
from matplotlib import gridspec

import warnings
warnings.simplefilter("ignore")

In [6]:
# Spike-to-cluster assignments from one session's Kilosort4 output folder.
clusters = np.load(Path('D:/HexinData', 'Dylan_2024-03-25_13-58-44_HPC', 'kilosort4', 'spike_clusters.npy'))

In [8]:
# Sorted cluster IDs present in the loaded spike_clusters.npy.
np.unique(clusters.ravel())

array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
       17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33,
       34, 35, 36, 37, 38])

In [2]:
def run_ks_files(DATA_DIRECTORY, probe_path='D:/Kilosort-main/poly3_64.prb'):
    """Convert one Open Ephys session to a binary file and spike-sort it with Kilosort4.

    Both steps are skipped when their output already exists, so the function
    can be called repeatedly over all sessions:
    - ``<DATA_DIRECTORY>/data.bin`` (int16, written in 60000-sample chunks)
    - ``<DATA_DIRECTORY>/kilosort4`` (sorting results)

    Parameters
    ----------
    DATA_DIRECTORY : str or Path
        Session folder readable by ``read_openephys``.
    probe_path : str or Path
        Probe geometry file loaded with ``kilosort.io.load_probe``. It is only
        read, never written.

    Returns
    -------
    Path
        The ``kilosort4`` results folder of this session.
    """
    DATA_DIRECTORY = Path(DATA_DIRECTORY)
    probe = io.load_probe(probe_path)
    dtype = np.int16

    filename = DATA_DIRECTORY / 'data.bin'
    print(filename)

    if filename.is_file():
        print('FILE EXISTED')
    else:
        # Only open the raw recording when a conversion is actually needed.
        recording = read_openephys(DATA_DIRECTORY)
        # probe_name is joined onto DATA_DIRECTORY by Kilosort. Passing the
        # absolute master probe path here made the export overwrite that shared
        # file, so the exported copy now stays inside the session folder.
        filename, *_ = io.spikeinterface_to_binary(
            recording, DATA_DIRECTORY, data_name='data.bin', dtype=dtype,
            chunksize=60000, export_probe=True, probe_name='probe.prb'
        )

    # Assumes run_kilosort's default results folder (<data dir>/kilosort4);
    # the existence check relies on that.
    ks_dir = DATA_DIRECTORY / 'kilosort4'
    if ks_dir.is_dir():
        print('SORTED')
    else:
        # n_chan_bin = 64 presumably matches the 64-channel poly3 probe; update it for other probes.
        settings = {'n_chan_bin': 64, 'dmin': 25, 'dminx': 22, 'nearest_templates': 64}
        run_kilosort(settings=settings, probe=probe, filename=filename, data_dtype=dtype)

    print('~~~~~~~~~~~~~~~~~~~')
    return ks_dir

In [3]:
def extract_timing_ephys(base_dir, ttl_dir):
    """Write pulse onset times and pulse widths to ``<base_dir>/timestamps.data``.

    Reads ``timestamps.npy`` and ``states.npy`` from ``ttl_dir``. Events with
    state ``1`` are treated as rising edges and state ``-1`` as falling edges
    (presumably TTL line 1 in Open Ephys' signed-line convention — confirm).

    Each output line is ``<onset>,<width>``, formatted ``%.6f,%.1f``. The width
    is ``(fall - rise) * 1000``, i.e. milliseconds if the timestamps are in
    seconds (assumed — TODO confirm for the recording format in use).

    A falling edge before the first rising edge, or a rising edge with no
    falling edge after it, comes from a pulse cut off by the start or end of the
    recording. Such edges are dropped so the remaining edges pair up
    correctly; previously they either crashed or shifted every pairing by one
    pulse.
    """
    ttl_dir = Path(ttl_dir)
    timestamps = np.load(ttl_dir / 'timestamps.npy')
    states = np.load(ttl_dir / 'states.npy')

    rising = timestamps[states == 1]
    falling = timestamps[states == -1]

    # Recording started mid-pulse: the first falling edge has no matching rise.
    if rising.size and falling.size and falling[0] < rising[0]:
        falling = falling[1:]
    # Recording stopped mid-pulse: the last rising edge has no matching fall.
    n_pulses = min(rising.size, falling.size)
    rising = rising[:n_pulses]
    falling = falling[:n_pulses]

    widths = (falling - rising) * 1000
    np.savetxt(Path(base_dir) / 'timestamps.data',
               np.column_stack([rising, widths]), fmt='%.6f,%.1f')

In [4]:
def time_alignment(DATA_DIRECTORY, record_node='Record Node 113',
                   align_exe='D:/SpikeInterface/align_timestamps.exe'):
    """Extract the timing pulses of one session and run the external alignment tool.

    Both steps are skipped when their output file already exists:
    - ``timestamps.data``, written by ``extract_timing_ephys``
    - ``alignmentinfo_README.txt``, written by ``align_exe``

    Parameters
    ----------
    DATA_DIRECTORY : str or Path
        Session folder. The recording is expected under
        ``<record_node>/experiment1/recording1``.
    record_node : str
        Open Ephys Record Node folder name (previously hard-coded).
    align_exe : str or Path
        Executable invoked as
        ``align_exe <session .sqlite> <timestamps.data> <alignment file>``.

    Raises
    ------
    FileNotFoundError
        If the recording folder contains no ``*.sqlite`` file.
    """
    DATA_DIRECTORY = Path(DATA_DIRECTORY)
    print(DATA_DIRECTORY)

    base_folder = DATA_DIRECTORY / record_node / 'experiment1' / 'recording1'
    header_file = base_folder / 'structure.oebin'
    timestamp_file = base_folder / 'timestamps.data'
    alignment_file = base_folder / 'alignmentinfo_README.txt'

    # sorted() makes the choice deterministic if several .sqlite files exist.
    session_files = sorted(base_folder.glob('*.sqlite'))
    if not session_files:
        raise FileNotFoundError(f'No *.sqlite session file found in {base_folder}')
    session_file = session_files[0]

    with open(header_file) as file_:
        header_data = json.load(file_)

    # data_path is only used by the Neuropixels alternative commented out below.
    data_path = base_folder / 'continuous' / header_data['continuous'][0]['folder_name'] / 'continuous.dat'
    ttl_path = base_folder / 'events' / header_data['events'][0]['folder_name']

    # Extract timing signal if not done already
    if timestamp_file.exists():
        print('TIME EXTRACTED')
    else:
        extract_timing_ephys(base_folder, ttl_path)
        # For neuropixels recordings, run the following line instead
        # (sample_rate and nchannels must be defined first):
        # subprocess.run(['python','D:/SpikeInterface/extract_timing.py',data_path,str(sample_rate),str(nchannels),str(nchannels-1),"2"])

    # Align timing if not done already
    if alignment_file.exists():
        print('TIME ALIGNED')
    else:
        result = subprocess.run([str(align_exe), str(session_file),
                                 str(timestamp_file), str(alignment_file)])
        # Previously a failing aligner went unnoticed; report it so the session can be re-checked.
        if result.returncode != 0:
            print(f'WARNING: {align_exe} exited with code {result.returncode}')

    print('~~~~~~~~~~~~~~~~~~~')

In [5]:
# Sort every session folder under the data root; sessions with existing outputs are skipped.
# Forward slashes avoid the invalid '\H' escape the backslash path relied on.
mypath = Path('D:/HexinData')
# Only directories are sessions; stray files in the data root are ignored.
for session_dir in sorted(p for p in mypath.iterdir() if p.is_dir()):
    run_ks_files(session_dir)

D:\HexinData\Dylan_2024-03-18_13-29-27_HPC\data.bin
FILE EXISTED
SORTED
~~~~~~~~~~~~~~~~~~~
D:\HexinData\Dylan_2024-03-19_14-42-22_HPC\data.bin
FILE EXISTED
SORTED
~~~~~~~~~~~~~~~~~~~
D:\HexinData\Dylan_2024-03-20_13-55-20_HPC\data.bin
FILE EXISTED
SORTED
~~~~~~~~~~~~~~~~~~~
D:\HexinData\Dylan_2024-03-21_14-55-03_HPC\data.bin
FILE EXISTED
SORTED
~~~~~~~~~~~~~~~~~~~
D:\HexinData\Dylan_2024-03-25_13-58-44_HPC\data.bin
Loading recording with SpikeInterface...
number of samples: 187229152
number of channels: 64
numbef of segments: 1
sampling rate: 30000.0
dtype: int16
Converting 3121 data chunks with a chunksize of 60000 samples...
27 of 3121 chunks converted...
49 of 3121 chunks converted...
62 of 3121 chunks converted...
80 of 3121 chunks converted...
94 of 3121 chunks converted...
112 of 3121 chunks converted...
121 of 3121 chunks converted...
143 of 3121 chunks converted...
154 of 3121 chunks converted...
172 of 3121 chunks converted...
185 of 3121 chunks converted...
200 of 3121 chunk

100%|██████████████████████████████████████████████████████████████████████████████| 3121/3121 [21:44<00:00,  2.39it/s]


drift computed in  1311.04s; total  1314.73s

Extracting spikes using templates
Re-computing universal templates from data.


100%|██████████████████████████████████████████████████████████████████████████████| 3121/3121 [21:32<00:00,  2.41it/s]


921110 spikes extracted in  1295.83s; total  2610.57s

First clustering


100%|██████████████████████████████████████████████████████████████████████████████████| 11/11 [01:02<00:00,  5.66s/it]


38 clusters found, in  62.36s; total  2672.93s

Extracting spikes using cluster waveforms


100%|██████████████████████████████████████████████████████████████████████████████| 3121/3121 [02:35<00:00, 20.04it/s]


1088662 spikes extracted in  155.86s; total  2828.79s

Final clustering


100%|██████████████████████████████████████████████████████████████████████████████████| 11/11 [01:16<00:00,  6.91s/it]


39 clusters found, in  76.02s; total  2904.82s

Merging clusters
39 units found, in  0.17s; total  2904.98s

Saving to phy and computing refractory periods
6 units found with good refractory periods

Total runtime: 2907.00s = 00:48:27 h:m:s
~~~~~~~~~~~~~~~~~~~
D:\HexinData\Dylan_2024-03-26_14-09-51_HPC\data.bin
Loading recording with SpikeInterface...
number of samples: 190608204
number of channels: 64
numbef of segments: 1
sampling rate: 30000.0
dtype: int16
Converting 3177 data chunks with a chunksize of 60000 samples...
24 of 3177 chunks converted...
46 of 3177 chunks converted...
59 of 3177 chunks converted...
68 of 3177 chunks converted...
79 of 3177 chunks converted...
94 of 3177 chunks converted...
107 of 3177 chunks converted...
118 of 3177 chunks converted...
127 of 3177 chunks converted...
146 of 3177 chunks converted...
157 of 3177 chunks converted...
166 of 3177 chunks converted...
179 of 3177 chunks converted...
191 of 3177 chunks converted...
202 of 3177 chunks converted.

100%|██████████████████████████████████████████████████████████████████████████████| 3177/3177 [21:46<00:00,  2.43it/s]


drift computed in  1313.69s; total  1316.18s

Extracting spikes using templates
Re-computing universal templates from data.


100%|██████████████████████████████████████████████████████████████████████████████| 3177/3177 [21:43<00:00,  2.44it/s]


1092825 spikes extracted in  1307.39s; total  2623.56s

First clustering


100%|██████████████████████████████████████████████████████████████████████████████████| 11/11 [01:07<00:00,  6.15s/it]


70 clusters found, in  67.77s; total  2691.33s

Extracting spikes using cluster waveforms


100%|██████████████████████████████████████████████████████████████████████████████| 3177/3177 [02:42<00:00, 19.57it/s]


1157245 spikes extracted in  162.54s; total  2853.87s

Final clustering


100%|██████████████████████████████████████████████████████████████████████████████████| 11/11 [01:11<00:00,  6.50s/it]


34 clusters found, in  71.46s; total  2925.32s

Merging clusters
32 units found, in  0.20s; total  2925.53s

Saving to phy and computing refractory periods
12 units found with good refractory periods

Total runtime: 2928.00s = 00:48:48 h:m:s
~~~~~~~~~~~~~~~~~~~
D:\HexinData\Dylan_2024-03-27_12-21-08_HPC\data.bin
Loading recording with SpikeInterface...
number of samples: 267245821
number of channels: 64
numbef of segments: 1
sampling rate: 30000.0
dtype: int16
Converting 4455 data chunks with a chunksize of 60000 samples...
40 of 4455 chunks converted...
56 of 4455 chunks converted...
66 of 4455 chunks converted...
77 of 4455 chunks converted...
90 of 4455 chunks converted...
104 of 4455 chunks converted...
115 of 4455 chunks converted...
129 of 4455 chunks converted...
141 of 4455 chunks converted...
155 of 4455 chunks converted...
166 of 4455 chunks converted...
178 of 4455 chunks converted...
191 of 4455 chunks converted...
205 of 4455 chunks converted...
218 of 4455 chunks converte

100%|██████████████████████████████████████████████████████████████████████████████| 4455/4455 [34:11<00:00,  2.17it/s]


drift computed in  2062.19s; total  2068.44s

Extracting spikes using templates
Re-computing universal templates from data.


100%|██████████████████████████████████████████████████████████████████████████████| 4455/4455 [30:33<00:00,  2.43it/s]


1151345 spikes extracted in  1838.55s; total  3906.99s

First clustering


100%|██████████████████████████████████████████████████████████████████████████████████| 11/11 [01:18<00:00,  7.16s/it]


94 clusters found, in  78.85s; total  3985.84s

Extracting spikes using cluster waveforms


100%|██████████████████████████████████████████████████████████████████████████████| 4455/4455 [04:41<00:00, 15.80it/s]


1645748 spikes extracted in  282.08s; total  4267.92s

Final clustering


100%|██████████████████████████████████████████████████████████████████████████████████| 11/11 [01:51<00:00, 10.12s/it]


61 clusters found, in  111.32s; total  4379.24s

Merging clusters
53 units found, in  0.45s; total  4379.69s

Saving to phy and computing refractory periods
28 units found with good refractory periods

Total runtime: 4383.30s = 01:13:3 h:m:s
~~~~~~~~~~~~~~~~~~~
D:\HexinData\Dylan_2024-03-28_14-32-11_HPC\data.bin
Loading recording with SpikeInterface...
number of samples: 257368131
number of channels: 64
numbef of segments: 1
sampling rate: 30000.0
dtype: int16
Converting 4290 data chunks with a chunksize of 60000 samples...
30 of 4290 chunks converted...
50 of 4290 chunks converted...
65 of 4290 chunks converted...
77 of 4290 chunks converted...
91 of 4290 chunks converted...
103 of 4290 chunks converted...
114 of 4290 chunks converted...
130 of 4290 chunks converted...
143 of 4290 chunks converted...
155 of 4290 chunks converted...
168 of 4290 chunks converted...
179 of 4290 chunks converted...
192 of 4290 chunks converted...
202 of 4290 chunks converted...
217 of 4290 chunks converte

100%|██████████████████████████████████████████████████████████████████████████████| 4290/4290 [31:51<00:00,  2.24it/s]


drift computed in  1922.48s; total  1925.94s

Extracting spikes using templates
Re-computing universal templates from data.


100%|██████████████████████████████████████████████████████████████████████████████| 4290/4290 [29:23<00:00,  2.43it/s]


1429872 spikes extracted in  1768.49s; total  3694.44s

First clustering


100%|██████████████████████████████████████████████████████████████████████████████████| 11/11 [01:30<00:00,  8.21s/it]


98 clusters found, in  90.44s; total  3784.88s

Extracting spikes using cluster waveforms


100%|██████████████████████████████████████████████████████████████████████████████| 4290/4290 [04:32<00:00, 15.76it/s]


1579192 spikes extracted in  272.29s; total  4057.16s

Final clustering


100%|██████████████████████████████████████████████████████████████████████████████████| 11/11 [01:36<00:00,  8.82s/it]


59 clusters found, in  97.03s; total  4154.19s

Merging clusters
57 units found, in  0.43s; total  4154.62s

Saving to phy and computing refractory periods
27 units found with good refractory periods

Total runtime: 4158.07s = 01:09:18 h:m:s
~~~~~~~~~~~~~~~~~~~


In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
from matplotlib import gridspec, rcParams
rcParams['axes.spines.top'] = False
rcParams['axes.spines.right'] = False
gray = .5 * np.ones(3)

fig = plt.figure(figsize=(10,10), dpi=100)
grid = gridspec.GridSpec(3, 3, figure=fig, hspace=0.5, wspace=0.5)

ax = fig.add_subplot(grid[0,0])
ax.plot(np.arange(0, ops['Nbatches'])*2, dshift);
ax.set_xlabel('time (sec.)')
ax.set_ylabel('drift (um)')

ax = fig.add_subplot(grid[0,1:])
t0 = 0
t1 = np.nonzero(st > ops['fs']*5)[0][0]
ax.scatter(st[t0:t1]/30000., chan_best[clu[t0:t1]], s=0.5, color='k', alpha=0.25)
ax.set_xlim([0, 5])
ax.set_ylim([chan_map.max(), 0])
ax.set_xlabel('time (sec.)')
ax.set_ylabel('channel')
ax.set_title('spikes from units')

ax = fig.add_subplot(grid[1,0])
nb=ax.hist(firing_rates, 20, color=gray)
ax.set_xlabel('firing rate (Hz)')
ax.set_ylabel('# of units')

ax = fig.add_subplot(grid[1,1])
nb=ax.hist(camps, 20, color=gray)
ax.set_xlabel('amplitude')
ax.set_ylabel('# of units')

ax = fig.add_subplot(grid[1,2])
nb=ax.hist(np.minimum(100, contam_pct), np.arange(0,105,5), color=gray)
ax.plot([10, 10], [0, nb[0].max()], 'k--')
ax.set_xlabel('% contamination')
ax.set_ylabel('# of units')
ax.set_title('< 10% = good units')

for k in range(2):
    ax = fig.add_subplot(grid[2,k])
    is_ref = contam_pct<10.
    ax.scatter(firing_rates[~is_ref], camps[~is_ref], s=3, color='r', label='mua', alpha=0.25)
    ax.scatter(firing_rates[is_ref], camps[is_ref], s=3, color='b', label='good', alpha=0.25)
    ax.set_ylabel('amplitude (a.u.)')
    ax.set_xlabel('firing rate (Hz)')
    ax.legend()
    if k==1:
        ax.set_xscale('log')
        ax.set_yscale('log')
        ax.set_title('loglog')

In [None]:
# Gallery of template waveforms for randomly drawn good and MUA units.
# NOTE(review): expects ops, contam_pct, templates, chan_best and clu to be
# loaded from a kilosort4 folder beforehand — TODO.
probe = ops['probe']
# x and y position of probe sites
xc, yc = probe['xc'], probe['yc']
nc = 16  # number of channels to show around each unit's best channel
# contam_pct is a percentage (see the "% contamination" histogram), so the
# good/MUA cut is 10, matching the "< 10% = good units" split of the summary
# figure. The previous 0.1 threshold left almost no good units.
good_units = np.nonzero(contam_pct < 10.)[0]
mua_units = np.nonzero(contam_pct >= 10.)[0]
n_show = 40  # fills the 2 x 20 panel grid
rng = np.random.default_rng(0)  # fixed seed: the same example units on every re-run

gstr = ['good', 'mua']
for j in range(2):
    print(f'~~~~~~~~~~~~~~ {gstr[j]} units ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~')
    print('title = number of spikes from each unit')
    units = good_units if j == 0 else mua_units
    if len(units) == 0:
        print(f'no {gstr[j]} units to show')
        continue
    # Distinct units; fewer panels when the group has fewer than n_show units.
    picks = rng.choice(units, size=min(n_show, len(units)), replace=False)

    fig = plt.figure(figsize=(12,3), dpi=150)
    grid = gridspec.GridSpec(2, 20, figure=fig, hspace=0.25, wspace=0.5)

    for k, wi in enumerate(picks):
        wv = templates[wi]
        cb = chan_best[wi]
        nsp = (clu == wi).sum()

        ax = fig.add_subplot(grid[k//20, k%20])
        n_chan = wv.shape[-1]
        ic0 = max(0, cb - nc//2)
        ic1 = min(n_chan, cb + nc//2)
        wv = wv[:, ic0:ic1]
        x0, y0 = xc[ic0:ic1], yc[ic0:ic1]

        amp = 4
        # Horizontal time axis for each trace, scaled to ~20 position units;
        # identical for every channel, so build it once per unit.
        t = np.arange(-wv.shape[0]//2, wv.shape[0]//2, 1, 'float32')
        t /= wv.shape[0] / 20
        for ii, (xi, yi) in enumerate(zip(x0, y0)):
            ax.plot(xi + t, yi + wv[:, ii]*amp, lw=0.5, color='k')

        ax.set_title(f'{nsp}', fontsize='small')
        ax.axis('off')
    plt.show()

In [25]:
# To inspect a sorted session in Phy:
# 1. Activate the conda environment: 'conda activate phy2'
# 2. From that session's kilosort4 folder (where params.py is saved), run: 'phy template-gui params.py'