# Spike sort

Notebook within the chronic ephys processing pipeline
- 1-preprocess_acoustics
- 2-curate_acoustics
- **3-sort_spikes**
- 4-curate_spikes

Use the **spikeproc** environment to run this notebook.

In [None]:
import numpy as np
import os
import pickle
os.environ["NPY_MATLAB_PATH"] = '/mnt/cube/chronic_ephys/code/npy-matlab'
os.environ["KILOSORT2_PATH"] = '/mnt/cube/chronic_ephys/code/Kilosort2'
os.environ["KILOSORT3_PATH"] = '/mnt/cube/chronic_ephys/code/Kilosort'
import spikeinterface.full as si
import sys
import traceback
import torch
sys.path.append('/mnt/cube/lo/envs/ceciestunepipe/')
from ceciestunepipe.file import bcistructure as et
from ceciestunepipe.mods import probe_maps as pm

## Set parameters

In [2]:
si.get_default_sorter_params('kilosort4')

{'batch_size': 60000,
 'nblocks': 1,
 'Th_universal': 9,
 'Th_learned': 8,
 'do_CAR': True,
 'invert_sign': False,
 'nt': 61,
 'artifact_threshold': None,
 'nskip': 25,
 'whitening_range': 32,
 'binning_depth': 5,
 'sig_interp': 20,
 'nt0min': None,
 'dmin': None,
 'dminx': None,
 'min_template_size': 10,
 'template_sizes': 5,
 'nearest_chans': 10,
 'nearest_templates': 100,
 'templates_from_data': True,
 'n_templates': 6,
 'n_pcs': 6,
 'Th_single_ch': 6,
 'acg_threshold': 0.2,
 'ccg_threshold': 0.25,
 'cluster_downsampling': 20,
 'cluster_pcs': 64,
 'duplicate_spike_bins': 15,
 'do_correction': True,
 'keep_good_only': False,
 'save_extra_kwargs': False,
 'skip_kilosort_preprocessing': False,
 'scaleproc': None}

#### Set `dmin` and `dminx`
**Setting these appropriately will greatly reduce sort time**
- `dmin` defaults to the median distance between contacts. If contacts are irregularly spaced, as in a modular Neuropixels 2.0 setup, specify a value explicitly.
- `dminx` defaults to 32 µm (designed for Neuropixels probes).

Support documentation [here](https://kilosort.readthedocs.io/en/latest/parameters.html#dmin-and-dminx)
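
To pick a `dmin` value for a custom layout, it can help to look at the contact depths directly. The sketch below is a minimal illustration with NumPy only; the helper names (`contact_depth_steps`, `median_vertical_pitch`, `spacing_is_uniform`) and the toy geometry are hypothetical and not part of Kilosort or SpikeInterface. It takes the median step between distinct contact depths and flags layouts where the steps are not all equal.

```python
import numpy as np

def contact_depth_steps(contact_positions):
    """Differences between consecutive distinct contact depths (y column)."""
    y = np.unique(np.asarray(contact_positions, dtype=float)[:, 1])
    if y.size < 2:
        raise ValueError("need at least two distinct contact depths")
    return np.diff(y)

def median_vertical_pitch(contact_positions):
    """Median vertical step between contacts, in the units of the input."""
    return float(np.median(contact_depth_steps(contact_positions)))

def spacing_is_uniform(contact_positions):
    """True when every vertical step matches the first one."""
    steps = contact_depth_steps(contact_positions)
    return bool(np.allclose(steps, steps[0]))

# Hypothetical layout: two columns 32 um apart, rows every 15 um
regular = np.array([[x, 15.0 * row] for row in range(8) for x in (0.0, 32.0)])

# Same layout with a 300 um gap between the lower and upper four rows
gapped = regular.copy()
gapped[8:, 1] += 300.0

print(median_vertical_pitch(regular), spacing_is_uniform(regular))
print(median_vertical_pitch(gapped), spacing_is_uniform(gapped))
```

When the spacing is not uniform, the median may or may not reflect the true pitch, which is the case where setting `dmin` by hand is worth considering.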

In [3]:
# non default spike sorting parameters
sort_params_dict_ks3 = {'minFR':0.001, 'minfr_goodchannels':0.001} # kilosort 3
sort_params_dict_ks4_npx = {'nblocks':5, 'Th_universal':8, 'Th_learned':7, 'dmin':15, 'dminx':32} # kilosort 4, neuropixels (set dmin and dminx to true pitch)
sort_params_dict_ks4_nnx64 = {'nblocks':0, 'nearest_templates':64,
                              'Th_universal':8, 'Th_learned':7} # kilosort 4, neuronexus 64 chan

# waveform extraction parameters
wave_params_dict = {'ms_before':1, 'ms_after':2, 'max_spikes_per_unit':500,
                    'sparse':True, 'num_spikes_for_sparsity':100, 'method':'radius',
                    'radius_um':40, 'n_components':5, 'mode':'by_channel_local'}

# print verbose sorter output
verbose = True

# if True, sorter errors are raised; if False, they are only written to the sorter log
raise_error = False

# restrict sorting to a specific GPU (index 0 or 1), or None to leave all GPUs visible
restrict_to_gpu = 1

# use specific GPU if specified
if restrict_to_gpu is not None:
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = "{}".format(restrict_to_gpu)

# parallel processing params
job_kwargs = dict(n_jobs=28,chunk_duration="1s",progress_bar=False)
si.set_global_job_kwargs(**job_kwargs)

# skip epochs whose previous sort failed (False = re-run them)
skip_failed = False
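
The non-default parameter dictionaries above are later copied key by key onto the sorter defaults. A misspelled key would be added silently at that point and only surface once the sorter runs. The sketch below shows one way to catch that early; `merge_sorter_params` is a hypothetical helper, and the defaults shown are a hand-copied subset of the `kilosort4` defaults listed earlier in this notebook.

```python
def merge_sorter_params(defaults, overrides):
    """Return a copy of `defaults` updated with `overrides`.

    Raises KeyError if an override names a parameter the sorter does not
    list among its defaults (most likely a typo).
    """
    unknown = sorted(set(overrides) - set(defaults))
    if unknown:
        raise KeyError(f"unknown sorter parameters: {unknown}")
    merged = dict(defaults)   # copy, so the defaults stay untouched
    merged.update(overrides)
    return merged

# Subset of the kilosort4 defaults shown above (assumption: copied by hand)
defaults = {'nblocks': 1, 'Th_universal': 9, 'Th_learned': 8,
            'dmin': None, 'dminx': None}
overrides = {'nblocks': 5, 'Th_universal': 8, 'Th_learned': 7,
             'dmin': 15, 'dminx': 32}

merged = merge_sorter_params(defaults, overrides)
print(merged)
```

In the notebook itself, `defaults` would come from `si.get_default_sorter_params(...)` and `overrides` from the `'sort_params'` entry of each session configuration.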

In [4]:
bird_rec_dict = {
    'z_p5y10_23':[
        {'sess_par_list':['2024-05-16'], # sessions (will process all epochs within)
         'probe':{'probe_type':'neuropixels-2.0'}, # probe specs
         'sort':'sort_0', # label for this sort instance
         'sorter':'kilosort4', # sort method
         'sort_params':sort_params_dict_ks4_npx, # non-default sort params
         'wave_params':wave_params_dict, # waveform extraction params
         'ephys_software':'sglx' # sglx or oe
        },
    ],
}
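
The first epoch in the output further down finished sorting after roughly 49 hours and then failed with `Permission denied` on a directory under `/tmp/spikeinterface_cache`. A pre-flight check along the following lines could catch that kind of problem before a long sort starts. This is a standard-library sketch only; `is_writable_dir` is a hypothetical helper, and which directory SpikeInterface actually uses for its cache on a given machine is an assumption to verify.

```python
import os
import tempfile

def is_writable_dir(path):
    """Return True if a temporary file can be created inside `path`."""
    try:
        # TemporaryFile is removed automatically when the block exits
        with tempfile.TemporaryFile(dir=path):
            return True
    except OSError:
        # covers missing directories and permission errors
        return False

# Example: the system temp directory should normally be writable
tmp_root = tempfile.gettempdir()
print(tmp_root, is_writable_dir(tmp_root))

# A directory that does not exist reports False instead of raising
missing_dir = os.path.join(tmp_root, 'no_such_dir_for_preflight_check')
print(missing_dir, is_writable_dir(missing_dir))
```

Creating a file is used here rather than `os.access`, since the latter can report success in cases where the write would still fail.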

## Run sorts
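
Each epoch is skipped, re-run, or run for the first time depending on the first line of its log file. The decision logic can be read in isolation as the sketch below. The helper names `read_sort_status` and `should_run_sort` are hypothetical; the log line formats are the ones written by the cell that follows. One small difference: this sketch strips only a trailing newline, whereas the notebook drops the last character of the line unconditionally.

```python
import os
import tempfile

def read_sort_status(log_path, bird, sess, epoch):
    """Classify a sort log as 'complete', 'failed', or 'unknown'."""
    prefix = f"{bird} {sess} {epoch} "
    try:
        with open(log_path, "r") as f:
            first_line = f.readline().rstrip("\n")
    except OSError:
        return "unknown"  # no readable log file
    if first_line == prefix + "sort complete without error":
        return "complete"
    if first_line == prefix + "sort failed":
        return "failed"
    return "unknown"      # log exists but is not interpretable

def should_run_sort(status, skip_failed):
    """Mirror the notebook's rule for whether an epoch gets sorted."""
    if status == "complete":
        return False
    if status == "failed":
        return not skip_failed
    return True

# Demo with a throwaway log file
with tempfile.TemporaryDirectory() as tmp_dir:
    log_path = os.path.join(tmp_dir, "1246_g0_spikesort_sort_0.log")
    missing = read_sort_status(log_path, "z_p5y10_23", "2024-05-16", "1246_g0")
    with open(log_path, "w") as f:
        f.write("z_p5y10_23 2024-05-16 1246_g0 sort failed\n")
    failed = read_sort_status(log_path, "z_p5y10_23", "2024-05-16", "1246_g0")

print(missing, should_run_sort(missing, skip_failed=False))
print(failed, should_run_sort(failed, skip_failed=False))
```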

In [5]:
%%time

# store sort summaries
sort_summary_all = []

# loop through all birds / recordings
for this_bird in bird_rec_dict.keys():
    # get session configurations
    sess_all = bird_rec_dict[this_bird]
    
    # loop through session configurations
    for this_sess_config in sess_all:
        
        # loop through sessions
        for this_sess in this_sess_config['sess_par_list']:
            log_dir = os.path.join('/mnt/cube/chronic_ephys/log', this_bird, this_sess)
            
            # build session parameter dictionary
            sess_par = {'bird':this_bird,
                        'sess':this_sess,
                        'ephys_software':this_sess_config['ephys_software'],
                        'sorter':this_sess_config['sorter'],
                        'sort':this_sess_config['sort']}
            # get epochs
            sess_epochs = et.list_ephys_epochs(sess_par)
            
            for this_epoch in sess_epochs:
                
                # set output directory
                epoch_struct = et.sgl_struct(sess_par,this_epoch,ephys_software=sess_par['ephys_software'])
                sess_par['epoch'] = this_epoch
                sort_path = epoch_struct['folders']['derived'] + '/{}/{}/'.format(sess_par['sorter'],sess_par['sort'])
                sorting_analyzer_path = sort_path + 'sorting_analyzer/'
                
                # get spike sort log
                try:
                    with open(os.path.join(log_dir, this_epoch+'_spikesort_'+this_sess_config['sort']+'.log'), 'r') as f:
                        log_message=f.readline() # read the first line of the log file
                    if log_message[:-1] == sess_par['bird']+' '+sess_par['sess']+' '+this_epoch+' sort complete without error':
                        print(sess_par['bird'],sess_par['sess'],this_epoch,'already exists -- skipping sort')
                        run_proc = False
                    elif log_message[:-1] == sess_par['bird']+' '+sess_par['sess']+' '+this_epoch+' sort failed':
                        if skip_failed:
                            print(sess_par['bird'],sess_par['sess'],this_epoch,'previously failed -- skipping sort')
                            run_proc = False
                        else:
                            run_proc = True
                    else: # uninterpretable log file
                        run_proc = True
                except OSError: # no existing (or readable) log file
                    run_proc = True
                
                # run sort
                if run_proc:
                    try:
                        print('___________',this_bird,this_sess,this_epoch,'___________')
                        # prepare recording for sorting
                        print('prep..')
                        if sess_par['ephys_software'] == 'sglx':
                            # load recording
                            rec_path = epoch_struct['folders']['sglx']
                            this_rec = si.read_spikeglx(folder_path=rec_path,stream_name='imec0.ap')
                            # save probe map prior to re-ordering for sorting
                            probe_df = this_rec.get_probe().to_dataframe()
                            probe_df.to_pickle(os.path.join(epoch_struct['folders']['derived'],'probe_map_df.pickle'))
                            # ibl destriping
                            this_rec = si.highpass_filter(recording=this_rec)
                            this_rec = si.phase_shift(recording=this_rec)
                            bad_good_channel_ids = si.detect_bad_channels(recording=this_rec)
                            if len(bad_good_channel_ids[0]) > 0:
                                this_rec = si.interpolate_bad_channels(recording=this_rec,bad_channel_ids=bad_good_channel_ids[0])
                            if this_sess_config['probe']['probe_type'] == 'neuropixels-2.0':
                                # highpass by shank
                                split_rec = this_rec.split_by(property='group',outputs='list') # split recording by shank
                                split_rec = [si.highpass_spatial_filter(recording=r,n_channel_pad=min(r.get_num_channels(),60)) for r in split_rec]
                                this_rec_p = si.aggregate_channels(split_rec) # recombine shanks
                                # stack shanks
                                p,_ = pm.stack_shanks(probe_df) # make new Probe object with shanks stacked
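                                # NOTE: set_probe is called on this_rec, so the per-shank filtered this_rec_p built above is overwritten here
                                this_rec_p = this_rec.set_probe(p,group_mode='by_probe') # assign new Probe object to probe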
                            else:
                                this_rec_p = si.highpass_spatial_filter(recording=this_rec)
                        elif sess_par['ephys_software'] =='oe':
                            # load recording
                            rec_path = [f.path for f in os.scandir(epoch_struct['folders']['oe']) if f.is_dir()][0]
                            this_rec = si.read_openephys(folder_path=rec_path)
                            # add probe
                            this_probe = pm.make_probes(this_sess_config['probe']['probe_type'],this_sess_config['probe']['probe_model']) # neuronexus, Buzsaki64
                            this_rec_p = this_rec.set_probe(this_probe,group_mode='by_shank')
                        # set sort params
                        this_rec_p = si.concatenate_recordings([this_rec_p])
                        sort_params = si.get_default_sorter_params(this_sess_config['sorter'])
                        for this_param in this_sess_config['sort_params'].keys():
                            sort_params[this_param] = this_sess_config['sort_params'][this_param]
                        # run sort
                        print('sort..')
                        torch.cuda.empty_cache()
                        this_sort = si.run_sorter(sorter_name=this_sess_config['sorter'],recording=this_rec_p,output_folder=sort_path,
                                             remove_existing_folder=True,delete_output_folder=False,delete_container_files=False,
                                             verbose=verbose,raise_error=raise_error,**sort_params)
                        torch.cuda.empty_cache()
                        # bandpass recording before running analyzer
                        this_rec_pf = si.bandpass_filter(recording=this_rec_p)
                        # run sorting analyzer
                        print('sorting analyzer..')
                        analyzer = si.create_sorting_analyzer(sorting=this_sort,recording=this_rec_pf,format="binary_folder",
                                                              sparse=True,return_scaled=True,folder=sorting_analyzer_path)
                        ext_compute_all = analyzer.get_computable_extensions()
                        for this_ext in ext_compute_all:
                            print(this_ext + '..')
                            analyzer.compute(this_ext)
                        
                        # mark complete
                        print('COMPLETE!!')

                        # log complete sort
                        os.makedirs(log_dir, exist_ok=True)
                        with open(os.path.join(log_dir, this_epoch+'_spikesort_'+this_sess_config['sort']+'.log'), 'w') as f:
                            f.write(sess_par['bird']+' '+sess_par['sess']+' '+this_epoch+' sort complete without error\n\n')
                            f.write('Sort method: '+this_sess_config['sorter']+'\n\n')
                            f.write('Sort params: '+str(sort_params)+'\n\n')
                            f.write('Computable extensions: '+str(ext_compute_all)+'\n\n')
                        sort_summary = [this_bird,this_sess,sess_par['ephys_software'],this_epoch,'COMPLETE']
                    
                    except Exception as e:
                        # mark exception
                        print("An exception occurred:", e)
                        
                        # log failed sort
                        os.makedirs(log_dir, exist_ok=True)
                        with open(os.path.join(log_dir, this_epoch+'_spikesort_'+this_sess_config['sort']+'.log'), 'w') as f:
                            f.write(sess_par['bird']+' '+sess_par['sess']+' '+this_epoch+' sort failed\n')
                            f.write(traceback.format_exc())
                        sort_summary = [this_bird,this_sess,sess_par['ephys_software'],this_epoch,'FAIL']
                else:
                    sort_summary = [this_bird,this_sess,sess_par['ephys_software'],this_epoch,'EXISTS']
                
                # report and store sort summary
                print(sort_summary)
                sort_summary_all.append(sort_summary)

___________ z_p5y10_23 2024-05-16 1246_g0 ___________
prep..
sort..
Loading recording with SpikeInterface...
number of samples: 368121306
number of channels: 384
number of segments: 1
sampling rate: 30000.0
dtype: int16
Preprocessing filters computed in  2077.34s; total  2077.34s

computing drift
Re-computing universal templates from data.


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6136/6136 [15:52:26<00:00,  9.31s/it]


drift computed in  59153.07s; total  61230.41s

Extracting spikes using templates
Re-computing universal templates from data.


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6136/6136 [16:23:41<00:00,  9.62s/it]


19999820 spikes extracted in  61300.19s; total  122530.61s

First clustering


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 107/107 [19:49<00:00, 11.11s/it]


1217 clusters found, in  1197.43s; total  123728.03s

Extracting spikes using cluster waveforms


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6136/6136 [14:52:18<00:00,  8.73s/it]


37481401 spikes extracted in  53546.84s; total  177274.87s

Final clustering


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 107/107 [26:00<00:00, 14.58s/it]


829 clusters found, in  1560.73s; total  178835.74s

Merging clusters
742 units found, in  140.57s; total  178976.31s

Saving to phy and computing refractory periods
338 units found with good refractory periods

Total runtime: 179044.84s = 49:2984:4 h:m:s
kilosort4 run time 179047.45s
bandpass..
waveform..
metrics..
An exception occurred: [Errno 13] Permission denied: '/tmp/spikeinterface_cache/tmpnq7ml3iz'
['z_p5y10_23', '2024-05-16', 'sglx', '1246_g0', 'FAIL']
___________ z_p5y10_23 2024-05-16 1611_g0 ___________
prep..
sort..
Loading recording with SpikeInterface...
number of samples: 158273746
number of channels: 384
number of segments: 1
sampling rate: 30000.0
dtype: int16
Preprocessing filters computed in  878.17s; total  878.28s

computing drift
Re-computing universal templates from data.


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2638/2638 [7:23:47<00:00, 10.09s/it]


drift computed in  27475.47s; total  28353.88s

Extracting spikes using templates
Re-computing universal templates from data.


 48%|████████████████████████████████████████████████████████████████████████████▉                                                                                  | 1276/2638 [3:13:08<3:26:09,  9.08s/it]


KeyboardInterrupt: 
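
The `Total runtime` line in the first epoch's sorter output reads `179044.84s = 49:2984:4 h:m:s`. The middle field appears to be total elapsed minutes rather than minutes past the hour (179044 s is about 2984 min). A small helper, hypothetically named `format_hms` here, converts the reported seconds into a conventional reading:

```python
def format_hms(total_seconds):
    """Format a duration in seconds as h:mm:ss (fractional seconds dropped)."""
    total = int(total_seconds)
    hours, remainder = divmod(total, 3600)
    minutes, seconds = divmod(remainder, 60)
    return f"{hours}:{minutes:02d}:{seconds:02d}"

print(format_hms(179044.84))  # 49:44:04
```

So the first epoch's Kilosort 4 run took a little under 50 hours.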