# Spike sort

This notebook is a modified version of the *3-sort_spikes* notebook from the chronic ephys processing pipeline.

This notebook allows you to concatenate multiple recordings so that they are spike sorted together. *Be careful: this is only recommended for consecutive recordings made on the same date.*

Use the environment **spikeproc** to run this notebook

In [1]:
import numpy as np
import os
os.environ["NPY_MATLAB_PATH"] = '/mnt/cube/chronic_ephys/code/npy-matlab'
os.environ["KILOSORT2_PATH"] = '/mnt/cube/chronic_ephys/code/Kilosort2'
os.environ["KILOSORT3_PATH"] = '/mnt/cube/chronic_ephys/code/Kilosort'
import spikeinterface.full as si
import sys
import traceback
sys.path.append('/mnt/cube/lo/envs/ceciestunepipe/')
from ceciestunepipe.file import bcistructure as et
from ceciestunepipe.mods import probe_maps as pm

## Set parameters

In [2]:
# sorter parameters that override the sorter's defaults
sort_params_dict = dict(minFR=0.001, minfr_goodchannels=0.001)

# settings used for waveform extraction and the principal-component step
wave_params_dict = dict(
    ms_before=1,
    ms_after=2,
    max_spikes_per_unit=500,
    sparse=True,
    num_spikes_for_sparsity=100,
    method='radius',
    radius_um=40,
    n_components=5,
    mode='by_channel_local',
)

# forwarded as `verbose` to the sorter and to the quality-metric computation
verbose = True

# forwarded as `raise_error` to the sorter
raise_error = False

# GPU index to expose to the sort (0, 1, or None for no restriction)
restrict_to_gpu = 1

# when a GPU is chosen, hide every other device from CUDA
if restrict_to_gpu is not None:
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = str(restrict_to_gpu)

# parallel processing settings, also registered as spikeinterface's global defaults
job_kwargs = {'n_jobs': 28, 'chunk_duration': "1s", 'progress_bar': False}
si.set_global_job_kwargs(**job_kwargs)

# True: leave sessions whose log records a failed sort untouched
# False: run the sort again for those sessions
skip_failed = False

In [3]:
# description of the session to sort; consumed by the sorting cell below
sess_par = dict(
    bird='z_y19o20_21',              # bird ID
    sess='2021-10-27',               # session date (every epoch recorded on this date is processed)
    probes={},                       # probes of interest (needed for oe, not needed for sglx)
    sort='sort_0',                   # label for this sort instance
    sorter='kilosort3',              # sort method
    sort_params=sort_params_dict,    # non-default sort params
    wave_params=wave_params_dict,    # waveform extraction params
    ephys_software='sglx',           # sglx or oe
)

## Run sorts

In [4]:
%%time

# Sort all epochs of the session as one concatenated recording.
#
# Reads:  sess_par, skip_failed, verbose, raise_error, job_kwargs (defined in the cells above)
# Writes: sorter output, waveforms and metrics under `sort_folder`, and a one-line status
#         log (plus traceback on failure) at `log_path`
# Sets:   sort_summary = [bird, sess, ephys_software, concat_epochs, status],
#         where status is 'COMPLETE', 'FAIL' or 'EXISTS'

# get epochs
sess_epochs = et.list_ephys_epochs(sess_par)
concat_epochs = '-'.join(sess_epochs)

# set output directory
epoch_struct = et.sgl_struct(sess_par,concat_epochs,ephys_software=sess_par['ephys_software'])
sort_folder = epoch_struct['folders']['derived'] + '/{}/{}/'.format(sess_par['sorter'],sess_par['sort'])

# spike sort log location and the prefix shared by every status line
log_dir = os.path.join('/mnt/cube/chronic_ephys/log', sess_par['bird'], sess_par['sess'])
log_path = os.path.join(log_dir, concat_epochs+'_spikesort.log')
log_prefix = sess_par['bird']+' '+sess_par['sess']

# read the first line of an existing log; a missing or unreadable log means "never sorted".
# Only I/O and decoding problems are tolerated here so that interrupts and genuine
# programming errors are not silently swallowed.
try:
    with open(log_path, 'r') as f:
        # strip the line ending only, so a log without a trailing newline still matches
        log_message = f.readline().rstrip('\r\n')
except (OSError, UnicodeError):
    log_message = None

# decide whether the sort has to run (an uninterpretable log also triggers a run)
run_proc = True
if log_message == log_prefix+' sort complete without error':
    print(sess_par['bird'],sess_par['sess'],'already exists -- skipping sort')
    run_proc = False
elif log_message == log_prefix+' sort failed' and skip_failed:
    print(sess_par['bird'],sess_par['sess'],'previously failed -- skipping sort')
    run_proc = False

# run sort
if run_proc:
    try:
        # load recordings: every sub-directory next to the session's sglx folder is treated as an epoch
        rec_path = '/'.join(epoch_struct['folders']['sglx'].split('/')[:-1])
        # os.listdir returns entries in arbitrary order; sort them so the concatenation
        # order is deterministic (epoch folders are time-prefixed, e.g. 1033_..., 1142_...)
        epoch_list = sorted(e for e in os.listdir(rec_path) if os.path.isdir(os.path.join(rec_path,e)))
        if not epoch_list:
            raise FileNotFoundError('no epoch folders found in {}'.format(rec_path))

        recording_list = []
        for this_epoch in epoch_list:
            this_rec = si.read_spikeglx(folder_path=os.path.join(rec_path,this_epoch), stream_name='imec0.ap')

            # ibl destriping
            this_rec = si.highpass_filter(recording=this_rec)
            this_rec = si.phase_shift(recording=this_rec)
            bad_good_channel_ids = si.detect_bad_channels(recording=this_rec)
            # first element holds the bad channel ids; interpolate only when there are any
            if len(bad_good_channel_ids[0]) > 0:
                this_rec = si.interpolate_bad_channels(recording=this_rec,bad_channel_ids=bad_good_channel_ids[0])
            this_rec = si.highpass_spatial_filter(recording=this_rec)

            recording_list.append(this_rec)

        # concatenate recordings
        this_rec_p = si.concatenate_recordings(recording_list)

        # set sort params: sorter defaults overridden by the session's non-default values
        sort_params = si.get_default_sorter_params(sess_par['sorter'])
        sort_params.update(sess_par['sort_params'])
        # run sort
        print('sort..')
        this_sort = si.run_sorter(sorter_name=sess_par['sorter'],recording=this_rec_p,output_folder=sort_folder,
                             remove_existing_folder=True,delete_output_folder=False,delete_container_files=False,
                             verbose=verbose,raise_error=raise_error,**sort_params)
        # bandpass recording before waveform extraction
        print('bandpass..')
        this_rec_pf = si.bandpass_filter(recording=this_rec_p)
        # extract waveforms
        print('waveform..')
        wave_params = sess_par['wave_params']
        wave = si.extract_waveforms(this_rec_pf,this_sort,folder=os.path.join(sort_folder,'waveforms'),
                                    ms_before=wave_params['ms_before'],ms_after=wave_params['ms_after'],
                                    max_spikes_per_unit=wave_params['max_spikes_per_unit'],
                                    sparse=wave_params['sparse'],num_spikes_for_sparsity=wave_params['num_spikes_for_sparsity'],
                                    method=wave_params['method'],radius_um=wave_params['radius_um'],overwrite=True,**job_kwargs)
        # compute metrics (results are kept in notebook variables for later inspection)
        print('metrics..')
        loc = si.compute_unit_locations(waveform_extractor=wave)
        cor = si.compute_correlograms(waveform_or_sorting_extractor=wave)
        sim = si.compute_template_similarity(waveform_extractor=wave)
        amp = si.compute_spike_amplitudes(waveform_extractor=wave,**job_kwargs)
        pca = si.compute_principal_components(waveform_extractor=wave,n_components=wave_params['n_components'],
                                              mode=wave_params['mode'],**job_kwargs)
        met = si.compute_quality_metrics(waveform_extractor=wave,verbose=verbose,**job_kwargs)

        # mark complete
        print('COMPLETE!!')

        # log complete sort
        os.makedirs(log_dir, exist_ok=True)
        with open(log_path, 'w') as f:
            f.write(log_prefix+' sort complete without error\n')
        sort_summary = [sess_par['bird'],sess_par['sess'],sess_par['ephys_software'],concat_epochs,'COMPLETE']

    except Exception as e:
        # mark exception
        print("An exception occurred:", e)

        # record the outcome first so sort_summary exists even if writing the log fails
        sort_summary = [sess_par['bird'],sess_par['sess'],sess_par['ephys_software'],concat_epochs,'FAIL']

        # log failed sort, followed by the full traceback
        os.makedirs(log_dir, exist_ok=True)
        with open(log_path, 'w') as f:
            f.write(log_prefix+' sort failed\n')
            f.write(traceback.format_exc())
else:
    sort_summary = [sess_par['bird'],sess_par['sess'],sess_par['ephys_software'],concat_epochs,'EXISTS']

sort..
RUNNING SHELL SCRIPT: /mnt/cube/chronic_ephys/der/z_y19o20_21/2021-10-27/sglx/1033_undirected_g0-1142_directed_g0/kilosort3/sort_0/sorter_output/run_kilosort3.sh


                            < M A T L A B (R) >

                  Copyright 1984-2023 The MathWorks, Inc.

                  R2023a (9.14.0.2206163) 64-bit (glnxa64)

                             February 22, 2023




 

To get started, type doc.

For product information, visit www.mathworks.com.

 

Time   0s. Computing whitening matrix.. 

Getting channel whitening matrix... 

Channel-whitening matrix computed. 

Time  84s. Loading raw data and applying filters... 

Time 1855s. Finished preprocessing 3051 batches. 

Drift correction ENABLED

vertical pitch size is 20 

horizontal pitch size is 32 

     0    16    32    48



   766



0.35 sec, 1 batches, 4812 spikes 

42.25 sec, 101 batches, 487940 spikes 

87.85 sec, 201 batches, 948560 spikes 

134.16 sec, 301 batches, 1297559 spikes 

175.41 sec, 401 batches, 



Computing amplitude_median
COMPLETE!!
['z_y19o20_21', '2021-10-27', 'sglx', '1033_undirected_g0-1142_directed_g0', 'COMPLETE']
CPU times: user 1h 22min 31s, sys: 37.3 s, total: 1h 23min 8s
Wall time: 4h 39min 7s


In [None]:
# show the outcome of the sort cell:
# [bird, sess, ephys_software, concatenated epoch label, status]
# where status is 'COMPLETE', 'FAIL' or 'EXISTS'
print(sort_summary)