## Spike Processing Pipeline ver 0.2
This notebook specifies a preprocessing pipeline for neural data collected on Blackrock acquisition devices in the Hatsopoulos lab. It assumes that spikeinterface (ver 0.10 - 0.13), python-neo, matplotlib, numpy, pandas, seaborn, and phy are installed.

The code needs a data file (.ns6) and a probe file (you can find Tony's at /Paul/prbfiles/TY_array.prb). The notebook automatically preconditions the data (filtering and referencing), attaches probe information (via the .prb file), and extracts the LFP. It then sorts the neural data with several different sorters and runs a comparative analysis to find the units common to all of them (see https://elifesciences.org/articles/61834). Finally, it exports the automatic sorting of one sorter to phy (IronClust is recommended as the primary sorter) and marks the common units as 'good' in the phy output.

The user then marks each cluster in phy as good (SU, single unit), MUA (multiunit), or noise. For analysis, you can read the phy output directly (numpy files) or import it back into spikeinterface and save it.

Useful websites:
- spikeinterface docs (for the old API): https://spikeinterface.readthedocs.io/en/0.13.0/
- phy docs: https://phy.readthedocs.io/en/latest/
- useful info about .prb files: https://tridesclous.readthedocs.io/en/latest/important_details.html and https://spyking-circus.readthedocs.io/en/latest/code/probe.html

## preprocessing --SESSION NAME--

In [None]:
# Import the necessary toolboxes. Note: because of an unresolved installation problem with matplotlib, you may
# need to run this cell twice. It will complain, but that is OK.

import matplotlib.pyplot as plt
import spikeinterface
import spikeinterface.extractors as se 
import spikeinterface.toolkit as st
import spikeinterface.sorters as ss
import spikeinterface.comparison as sc
import spikeinterface.widgets as sw
import numpy as np
import pandas as pd

# specify inline plotting
%matplotlib notebook

In [None]:
# User defined information about the data

pth = '/media/paul/storage/Data/Tony/' # path to data directory
sess = 'TY20210211_inHammock_night/' # directory where Ns6 file lives
file = 'TY20210211_inHammock_night-002.ns6' # name of NS6 file
prbfile = '/media/paul/storage/Data/TY_array.prb' # name of probe (.prb) file

In [None]:
# specify path to recording
recording_folder = pth+sess+file
print('recording_folder: ', recording_folder)
# load recording
recording = se.BlackrockRecordingExtractor(recording_folder, nsx_to_load=6)

In [None]:
# load probe information
probefile = prbfile
print('probefile: ', probefile)
recording_prb = recording.load_probe_file(probefile)
# check that the correct properties are present (should be gain, group, location, name, and offset)
recording_prb.get_shared_channel_property_names()

In [None]:
# visualize probe geometry: want to see that it looks correct
w_elec = sw.plot_electrode_geometry(recording_prb)

## Preprocessing
The LFP and spike signals are prepared separately.

- **LFP.** Each channel is band-pass filtered (1-350 Hz by default), downsampled to 1 kHz, and saved off as the LFP.
- **Spikes.** Bad channels are removed and the data is band-pass filtered (350-7500 Hz by default). A common median reference is then applied: at each sample, the median across all channels is subtracted from every channel.

The filter values can be changed. A few plots let you review the data.

All of the filtering is lazy: it is only applied when the data is read. To save time in later steps, the processed spike data and the LFP are therefore saved (cached) to file.
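To make the referencing step concrete, here is a minimal NumPy sketch of common median referencing on an array of shape (channels, samples). It illustrates the idea only and is not spiketoolkit's implementation. The toy data (independent noise plus one artifact shared by all channels) is made up for the example.

```python
import numpy as np


def common_median_reference(traces):
    """Subtract the across-channel median from every sample.

    traces: array of shape (n_channels, n_samples).
    """
    # median over channels at each time point, kept 2-D so it broadcasts
    ref = np.median(traces, axis=0, keepdims=True)
    return traces - ref


# toy data: 4 channels of independent noise plus one shared artifact
rng = np.random.default_rng(0)
artifact = np.sin(np.linspace(0, 2 * np.pi, 1000))
traces = rng.normal(0.0, 0.1, size=(4, 1000)) + artifact
cleaned = common_median_reference(traces)
```

The median is used rather than the mean so that a single very noisy channel does not leak into the reference for every other channel.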

In [None]:
# Automatically attempt to find bad channels; you then need to enter them in the next cell. Don't forget to add 1
# to the channel numbers found in this step, since Python is zero-based.
st.preprocessing.remove_bad_channels(recording_prb, verbose=True)
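As far as I know, when no IDs are given, `remove_bad_channels` flags channels whose standard deviation over a stretch of data exceeds a multiple (`bad_threshold`) of the median standard deviation across channels. The NumPy sketch below reproduces that idea on synthetic data. It is an assumption, not the toolkit's code. It also shows the zero-based to one-based conversion mentioned in the comment above, assuming channel IDs run 1..N in row order. `detect_bad_channels` is a hypothetical helper.

```python
import numpy as np


def detect_bad_channels(traces, bad_threshold=2.0):
    """Return zero-based indices of channels whose standard deviation is more
    than bad_threshold times the median standard deviation across channels.

    traces: array of shape (n_channels, n_samples).
    """
    stds = np.std(traces, axis=1)
    return [int(i) for i in np.flatnonzero(stds > bad_threshold * np.median(stds))]


# toy data: 8 channels of unit-variance noise, one of them 10x noisier
rng = np.random.default_rng(1)
traces = rng.normal(0.0, 1.0, size=(8, 5000))
traces[3] *= 10
bad_idx = detect_bad_channels(traces)
# hypothetical mapping: assumes channel IDs are 1..N in the same order as the rows
bad_ids = [i + 1 for i in bad_idx]
print(bad_idx, bad_ids)
```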

In [None]:
# remove bad channels. On the first run leave the list empty (nothing is removed).
recording_rmc = st.preprocessing.remove_bad_channels(recording_prb, bad_channel_ids=[])
# verify that all of the properties were transferred to the new recording object
print('properties: ', recording_rmc.get_shared_channel_property_names())
# verify bad channels have been removed
print('ids: ', recording_rmc.get_channel_ids())

In [None]:
# condition the signal for the lfp
# band-pass filter (1-350 Hz) for the lfp
recording_lp = st.preprocessing.bandpass_filter(recording_prb, freq_min=1, freq_max=350, filter_type='butter')
# downsample
recording_lfp = st.preprocessing.resample(recording_lp, resample_rate=1000)
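The LFP cell band-passes 1-350 Hz and then resamples to 1 kHz. The SciPy sketch below illustrates those two steps with a zero-phase Butterworth band-pass followed by integer downsampling. spiketoolkit's own filter design and resampling method may differ. The filter order, the 30 kHz rate (the ns6 sampling rate), and the toy signal are assumptions for the example. `extract_lfp` is a hypothetical helper.

```python
import numpy as np
from scipy import signal


def extract_lfp(traces, fs, f_lo=1.0, f_hi=350.0, fs_out=1000, order=4):
    """Zero-phase Butterworth band-pass, then downsample by an integer factor.

    traces: (n_channels, n_samples); fs must be an integer multiple of fs_out.
    """
    if fs % fs_out:
        raise ValueError("fs must be an integer multiple of fs_out")
    sos = signal.butter(order, [f_lo, f_hi], btype="bandpass", fs=fs, output="sos")
    filtered = signal.sosfiltfilt(sos, traces, axis=-1)
    # the 350 Hz upper edge is below the new Nyquist (500 Hz), so keeping
    # every k-th sample is acceptable here
    return filtered[:, :: int(fs // fs_out)]


# toy data: 10 s at 30 kHz, a 10 Hz "LFP" plus a 5 kHz component, on 2 channels
fs = 30000
t = np.arange(10 * fs) / fs
raw = np.vstack([np.sin(2 * np.pi * 10 * t) + np.sin(2 * np.pi * 5000 * t)] * 2)
lfp = extract_lfp(raw, fs)
```

If the upper cutoff were close to the new Nyquist frequency, `scipy.signal.resample_poly` (which applies its own anti-aliasing filter) would be a safer choice than plain slicing.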

In [None]:
# bandpass filter for spikes
recording_f = st.preprocessing.bandpass_filter(recording_rmc, freq_min=350, freq_max=7500, filter_type='butter')

In [None]:
# common median reference applied to the spike-band data (recording_f is built from recording_rmc, so any
# channels removed above are already excluded)
recording_cmr = st.preprocessing.common_reference(recording_f, reference='median')

In [None]:
# view the signal on some channels. channel_ids are the channels to plot and trange is the time window [start, stop] in seconds
sw.plot_timeseries(recording_cmr, channel_ids=[2, 5, 7], trange=[0, 6])

In [None]:
# View the power spectrum to check that the filtering looks reasonable. This plots the raw data (recording_prb);
# you can also look at the lfp (recording_lfp) or the spike data (recording_f or recording_cmr)
w_sp = sw.plot_spectrum(recording_prb, channels=[5])
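`plot_spectrum` gives a visual check. For a numeric check, the sketch below estimates band power with `scipy.signal.welch` before and after a 350-7500 Hz Butterworth band-pass. The synthetic signal (60 Hz line noise plus a 1 kHz tone) and the filter order of 3 are assumptions for the example, not what spiketoolkit uses internally. `band_power` is a hypothetical helper.

```python
import numpy as np
from scipy import signal


def band_power(x, fs, f_lo, f_hi, nperseg=4096):
    """Approximate power of x between f_lo and f_hi (Hz) from a Welch PSD."""
    f, pxx = signal.welch(x, fs=fs, nperseg=nperseg)
    mask = (f >= f_lo) & (f <= f_hi)
    return float(np.sum(pxx[mask]) * (f[1] - f[0]))


# toy data: 60 Hz line noise plus a 1 kHz tone, 2 s at 30 kHz
fs = 30000
t = np.arange(2 * fs) / fs
raw = np.sin(2 * np.pi * 60 * t) + np.sin(2 * np.pi * 1000 * t)
# same band as the spike filter (350-7500 Hz), zero-phase Butterworth
sos = signal.butter(3, [350, 7500], btype="bandpass", fs=fs, output="sos")
spike_band = signal.sosfiltfilt(sos, raw)

for lo, hi in [(50, 70), (900, 1100)]:
    print(f"{lo}-{hi} Hz: raw {band_power(raw, fs, lo, hi):.3g}, "
          f"filtered {band_power(spike_band, fs, lo, hi):.3g}")
```

After filtering, the 60 Hz band should be strongly attenuated while the in-band 1 kHz tone is nearly unchanged.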

In [None]:
# save preprocessed data for spikes and cache recording
recording_cache = se.CacheRecordingExtractor(recording_cmr, save_path=pth+sess+'processed/filtered_data.dat')
recording_cache.dump_to_dict()
recording_cache.dump_to_pickle(pth+sess+'processed/recording.pkl')
# save preprocessed data for lfp
se.CacheRecordingExtractor(recording_lfp, save_path=pth+sess+'processed/lfp_data.dat')

In [None]:
# reload from disk: if you need the preprocessed data above again (e.g. in a new session), just run this cell
recording_cache = se.load_extractor_from_pickle(pth+sess+'processed/recording.pkl')
# check the channel properties are correct
recording_cache.get_shared_channel_property_names()


## Spike Sorting
First, set the parameters for each sorter, then run each sorter you want. Each run writes its results to its own folder. You can also add additional sorters: for example, to add YASS, follow the pattern used here (see the documentation for installation).

In [None]:
# set the paths for the sorters you want to run
ss.Kilosort2_5Sorter.set_kilosort2_5_path('/media/paul/storage/MLToolBoxes/Kilosort-2.5/')
ss.IronClustSorter.set_ironclust_path('/media/paul/storage/MLToolBoxes/ironclust/')
ss.WaveClusSorter.set_waveclus_path('/media/paul/storage/MLToolBoxes/wave_clus/')

In [None]:
# check which sorters are installed
ss.installed_sorters()

## Spyking Circus
Spyking Circus reports when an electrode is too corrupted to sort, so run it first to find any problem electrodes. When it hits a bad electrode, it names the probe in the log and then crashes just before it completes. Each probe gets its own results folder, and a bad probe will be missing its *recording.result.hdf5* file (for example, for the first probe the file should be in processed/results_sc/0/recording/). To list the probes that did get sorted, run: ls results_sc/*/recording/*result.hdf5 . If any are missing, go back and exclude those probes in the earlier cell for removing bad channels. The folder numbers are Python 0-based, while that cell expects 1-based probe numbers, so add 1. Once Spyking Circus runs to completion, everything else should work without problems.
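The `ls` check described above can also be scripted. The sketch below assumes the layout described in this section: one numbered folder per group under results_sc, each with a recording/ subfolder that should contain a *result.hdf5 file once that group has sorted. `find_unsorted_groups` is a hypothetical helper, not part of spikeinterface. The returned folder names are 0-based, so add 1 before entering them as bad channels.

```python
import tempfile
from pathlib import Path


def _group_sort_key(path):
    # sort numeric folder names numerically ('2' before '10'), others after
    return (0, int(path.name), "") if path.name.isdigit() else (1, 0, path.name)


def find_unsorted_groups(results_dir):
    """Return names of group folders under results_dir whose recording/
    subfolder has no '*result.hdf5' file (hypothetical helper)."""
    missing = []
    for group_dir in sorted(Path(results_dir).iterdir(), key=_group_sort_key):
        if not group_dir.is_dir():
            continue
        rec_dir = group_dir / "recording"
        if not rec_dir.is_dir() or not any(rec_dir.glob("*result.hdf5")):
            missing.append(group_dir.name)
    return missing


# demo on a throwaway folder tree; in the notebook you would call
# find_unsorted_groups(pth + sess + 'processed/results_sc')
with tempfile.TemporaryDirectory() as tmp:
    for g in ["0", "1", "2", "10"]:
        (Path(tmp) / g / "recording").mkdir(parents=True)
    for g in ["0", "2"]:
        (Path(tmp) / g / "recording" / "recording.result.hdf5").touch()
    demo_missing = find_unsorted_groups(tmp)
print(demo_missing)
```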

In [None]:
# Start with spyking circus. list the parameters for the sorter
ss.get_params_description('spykingcircus')

In [None]:
# see what the default parameters are. These are the only parameters spikeinterface will let you modify.
ss.get_default_params('spykingcircus')

In [None]:
# set your own parameter values
params = {'detect_sign': -1,
 'adjacency_radius': 100,
 'detect_threshold': 6,
 'template_width_ms': 3,
 'filter': False,
 'merge_spikes': True,
 'auto_merge': 0.75,
 'num_workers': 15,
 'whitening_max_elts': 1000,
 'clustering_max_elts': 10000}

In [None]:
# run the sorter
sorting_SC = ss.run_spykingcircus(recording_cache, 
                                  output_folder=pth+sess+'processed/results_sc', 
                                  grouping_property='group',
                                  n_jobs=5,
                                  verbose=True, 
                                  **params)
print(f'SpykingCircus found {len(sorting_SC.get_unit_ids())} units')

In [None]:
# attempt to save sorting results in case of crash
sorting_SC.dump_to_dict()
sorting_SC.dump_to_pickle(pth+sess+'processed/sorting_sc.pkl')
# to reload
#sorting_SC = se.load_extractor_from_pickle(pth+sess+'processed/sorting_sc.pkl')

## Ironclust

In [None]:
# IronClust tends to do OK on our arrays. Same procedure: view the available parameters
ss.get_params_description('ironclust')

In [None]:
# parameter default values
ss.get_default_params('ironclust')

In [None]:
# create our own parameter dict
params = {'detect_sign': -1,
 'adjacency_radius': 50,
 'adjacency_radius_out': 100,
 'detect_threshold': 3.5,
 'prm_template_name': '',
 'freq_min': 300,
 'freq_max': 8000,
 'merge_thresh': 0.7,
 'pc_per_chan': 9,
 'whiten': False,
 'filter_type': 'none',
 'filter_detect_type': 'none',
 'common_ref_type': 'trimmean',
 'batch_sec_drift': 300,
 'step_sec_drift': 20,
 'knn': 30,
 'n_jobs_bin': 1,
 'chunk_mb': 500,
 'min_count': 30,
 'fGpu': True,
 'fft_thresh': 8,
 'fft_thresh_low': 0,
 'nSites_whiten': 16,
 'feature_type': 'gpca',
 'delta_cut': 1,
 'post_merge_mode': 1,
 'sort_mode': 1,
 'fParfor': True,
 'filter': False,
 'clip_pre': 0.25,
 'clip_post': 0.75,
 'merge_thresh_cc': 1,
 'nRepeat_merge': 3,
 'merge_overlap_thresh': 0.95}

In [None]:
# run sorter
sorting_IC = ss.run_ironclust(recording_cache, 
                              output_folder=pth+sess+'processed/results_ic', 
                              grouping_property='group', 
                              n_jobs=5, 
                              verbose=True,
                              **params)
print(f'Ironclust found {len(sorting_IC.get_unit_ids())} units')

In [None]:
# attempt to save sorting results in case of crash
sorting_IC.dump_to_dict()
sorting_IC.dump_to_pickle(pth+sess+'processed/sorting_ic.pkl')

## Waveclus

In [None]:
# waveclus takes a long time. see parameters
ss.get_params_description('waveclus')

In [None]:
# default values
ss.get_default_params('waveclus')

In [None]:
# modify parameters
params = {'detect_threshold': 4,
 'detect_sign': -1,
 'feature_type': 'wav',
 'scales': 4,
 'min_clus': 40,
 'maxtemp': 0.251,
 'template_sdnum': 3,
 'enable_detect_filter': True,
 'enable_sort_filter': True,
 'detect_filter_fmin': 300,
 'detect_filter_fmax': 3000,
 'detect_filter_order': 4,
 'sort_filter_fmin': 300,
 'sort_filter_fmax': 3000,
 'sort_filter_order': 2,
 'mintemp': 0,
 'w_pre': 20,
 'w_post': 44,
 'alignment_window': 10,
 'stdmax': 50,
 'max_spk': 40000,
 'ref_ms': 1.5,
 'interpolation': True}

In [None]:
# run sorter
sorting_WC = ss.run_waveclus(recording_cache, 
                             output_folder=pth+sess+'processed/results_wc_04', 
                             grouping_property='group', 
                             n_jobs=5, 
                             verbose=True,
                             **params)
print(f'Waveclus found {len(sorting_WC.get_unit_ids())} units')

In [None]:
# attempt to save sorting results in case of crash
sorting_WC.dump_to_dict()
sorting_WC.dump_to_pickle(pth+sess+'processed/sorting_wc_03.pkl')

## Klusta Kwik
Only use if Tridesclous or kilosort doesn't work

In [None]:
# get the parameters
ss.get_params_description('klusta')

In [None]:
# default values
ss.get_default_params('klusta')

In [None]:
# modify parameters
params = {'adjacency_radius': None,
 'threshold_strong_std_factor': 5,
 'threshold_weak_std_factor': 2,
 'detect_sign': -1,
 'extract_s_before': 16,
 'extract_s_after': 32,
 'n_features_per_channel': 3,
 'pca_n_waveforms_max': 10000,
 'num_starting_clusters': 50,
 'chunk_mb': 500,
 'n_jobs_bin': 1}

In [None]:
# run sorter
sorting_KL = ss.run_klusta(recording_cache, 
                           output_folder=pth+sess+'processed/results_kl',
                           grouping_property='group',
                           n_jobs=5, 
                           verbose=True, 
                           **params)
print(f'klusta found {len(sorting_KL.get_unit_ids())} units')

In [None]:
# attempt to save sorting results in case of crash
sorting_KL.dump_to_dict()
sorting_KL.dump_to_pickle(pth+sess+'processed/sorting_kl.pkl')

## Try Kilosort
You need a GPU to run Kilosort, and you also need to compile the CUDA MEX files (see the Kilosort installation instructions). Use version 2.5 or 2.0; version 3 has issues.

In [None]:
# try kilosort. it usually doesn't work with the marmoset data, but sometimes it does. Only use v2.5 or v2.0. 
# see params
ss.get_params_description('kilosort2_5')

In [None]:
# see default values
ss.get_default_params('kilosort2_5')

In [None]:
# set desired parameter values
params = {'detect_threshold': 6,
 'projection_threshold': [5, 2],
 'preclust_threshold': 5,
 'car': True,
 'minFR': 0.1,
 'minfr_goodchannels': 0.1,
 'nblocks': 5,
 'sig': 20,
 'freq_min': 150,
 'sigmaMask': 60,
 'nPCs': 3,
 'ntbuff': 64,
 'nfilt_factor': 4,
 'NT': None,
 'keep_good_only': False,
 'chunk_mb': 500,
 'n_jobs_bin': 1}

In [None]:
# Run the sorter (don't worry if it crashes; just go to the next cell). Note that this sorter does not use the group
# parameter. Kilosort assumes closely spaced electrode contacts (<20 um), so that each spike is picked up on several
# channels; it won't work with one channel per group. It might be possible to make it work with 3 channels, but that needs development.
sorting_KS = ss.run_kilosort2_5(recording_cache, 
                                output_folder=pth+sess+'processed/results_ks', 
                                #grouping_property='group', 
                                n_jobs=5, 
                                verbose=True,
                                **params)
print(f'Kilosort2.5 found {len(sorting_KS.get_unit_ids())} units')

In [None]:
# attempt to save sorting results in case of crash
sorting_KS.dump_to_dict()
sorting_KS.dump_to_pickle(pth+sess+'processed/sorting_ks.pkl')

## Try Tridesclous

In [None]:
# list the parameters of the sorter
ss.get_params_description('tridesclous')

In [None]:
# see what the default parameters are. These are the only parameters spikeinterface will let you modify.
ss.get_default_params('tridesclous')

In [None]:
# The parameters are a dict. Make a new dict with the parameters you want. One thing you generally need to change
# is to turn off the sorter's own filtering, as the data is already preprocessed. For some sorters this isn't an
# option (at least not one I have figured out how to change).
params = {
    'freq_min': 400.0,
    'freq_max': 5000.0,
    'detect_sign': -1,
    'detect_threshold': 4,
    'peak_span_ms': 0.7,
    'wf_left_ms': -2.0,
    'wf_right_ms': 3.0,
    'feature_method': 'auto',
    'cluster_method': 'auto',
    'clean_catalogue_gui': False,
    'chunk_mb': 500,
    'n_jobs_bin': 1}

In [None]:
# run the sorter. Always try and run with the group option (this sorts each probe individually, which is what we
# want for blackrock arrays since each probe is ~400um away)
sorting_TDC = ss.run_tridesclous(recording_cache, 
                                 output_folder=pth+sess+'processed/results_tdc', 
                                 grouping_property='group', 
                                 n_jobs=5, 
                                 verbose=True,
                                 **params)
print(f'Tridesclous found {len(sorting_TDC.get_unit_ids())} units')

In [None]:
# attempt to save sorting results in case of crash
sorting_TDC.dump_to_dict()
sorting_TDC.dump_to_pickle(pth+sess+'processed/sorting_tdc.pkl')

## Curating the spike sorting
Curation means checking the output of the spike sorter and deciding whether it is accurate. This can generally be done in three ways:
- Criteria: set thresholds on metrics and reject clusters that don't meet them, for example ISI violations, SNR, or distance from noise clusters.
- Comparative: find the same unit across sorters in a pairwise fashion, then identify the units that are common to all of the sorters. According to the paper (see the reference at the top), four sorters provided the optimal information and performed well with little manual intervention.
- Manual: examine each cluster to decide whether it is a single unit.

Here we combine comparative and criteria curation with manual validation. This gives the user the final say and provides the opportunity to keep multiunit clusters. There is no guarantee that a cluster identified by all of the sorters is a single unit. It could be any of the following:
- an error: the cluster is really garbage;
- a split problem: two clusters actually belong to the same unit, or one cluster is actually two cells;
- a combination: an identified cluster is actually a separable noise cluster plus a unit cluster.
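Comparative curation depends on matching spikes between sorters. As I understand the SpikeInterface paper cited at the top, two units are compared with an agreement score, n_matched / (n1 + n2 - n_matched). Here n_matched is the number of spikes that coincide within a small time window, and n1 + n2 - n_matched is the number of distinct spikes. The sketch below is a simplified, greedy version of that score for two spike trains in sample frames. The 12-frame window (0.4 ms at 30 kHz) is an assumption. `compare_multiple_sorters` does the real matching across all sorters for you, and `agreement_score` is a hypothetical helper.

```python
import numpy as np


def agreement_score(train1, train2, delta_frames=12):
    """Agreement between two spike trains given in sample frames.

    Spikes are paired greedily (two-pointer walk over both sorted trains) when
    they are within +/- delta_frames of each other; each spike is used once.
    Returns n_matched / (n1 + n2 - n_matched).
    """
    a = np.sort(np.asarray(train1))
    b = np.sort(np.asarray(train2))
    i = j = n_matched = 0
    while i < len(a) and j < len(b):
        diff = a[i] - b[j]
        if abs(diff) <= delta_frames:
            n_matched += 1
            i += 1
            j += 1
        elif diff > 0:
            j += 1  # b[j] is too early to match anything left in a
        else:
            i += 1  # a[i] is too early to match anything left in b
    denom = len(a) + len(b) - n_matched
    return n_matched / denom if denom else 0.0


# two hypothetical units from different sorters: 2 shared spikes out of 5 distinct events
unit_sorter1 = [100, 500, 1000, 2000]
unit_sorter2 = [105, 995, 3000]
print(agreement_score(unit_sorter1, unit_sorter2))
```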

In [None]:
# Compare sorter outputs. It's important here to have your primary sorter as the first sorter. The primary sorter
# is the sort that you will actually process and use. I currently recommend using IronClust with the default
# parameters as your primary sorter. The following code reflects this decision. If you want to use a different
# sorter as primary, you need to adjust the code below

mcmp = sc.compare_multiple_sorters([sorting_IC, sorting_TDC, sorting_KS, sorting_SC], ['IC', 'TDC', 'KS', 'SC'], 
                                   spiketrain_mode='union', n_jobs=1, 
                                   verbose=True)

In [None]:
# visualize comparisons
w = sw.plot_multicomp_agreement(mcmp)
w = sw.plot_multicomp_agreement_by_sorter(mcmp)

In [None]:
# Get the agreement sorting. minimum_agreement_count is the number of sorters that must agree for a unit to count.
# The default here is 4 (as per the paper), i.e. all four sorters compared above.
agreement_sorting = mcmp.get_agreement_sorting(minimum_agreement_count=4)

In [None]:
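# Get the IDs of the common units. The agreement sorting has its own unit IDs, so map each one back to the
# corresponding unit ID in the primary sorter (IC) via the 'sorter_unit_ids' property
ids = [agreement_sorting.get_unit_property(u, 'sorter_unit_ids')['IC'] for u in agreement_sorting.get_unit_ids()]
# show user common unit ids from the primary sorter
print('Common Unit IDs: ', ids)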

In [None]:
# Crucial: cache the main sorter and specify the location of the tmp directory, which must already exist on your
# system. This uses a lot of space while processing; I recommend having 2 TB+ free in the tmp directory
sorting_IC_cache = se.CacheSortingExtractor(sorting_IC, pth+sess+'processed/ic_sort_results_cache.dat')
sorting_IC_cache.dump_to_pickle(pth+sess+'processed/ic_sorting_cache.pkl')
sorting_IC_cache.set_tmp_folder(pth+sess+'processed/tmp/')

In [None]:
# Export the data to phy for manual curation. This should work, but if it crashes you will need to do the
# processing separately: run the code cells at the end of this document, then rerun this cell.
st.postprocessing.export_to_phy(recording_cache, sorting_IC_cache, 
                                output_folder=pth+sess+'processed/phy_IC', 
                                ms_before=0.5, 
                                ms_after=1, 
                                compute_pc_features=True, 
                                compute_amplitudes=True, 
                                max_spikes_per_unit=None, 
                                compute_property_from_recording=False, 
                                n_jobs=1, 
                                recompute_info=False, 
                                save_property_or_features=False, 
                                verbose=True)

In [None]:
# quality metrics
import seaborn as sns
# get quality metrics
quality_metrics = st.validation.compute_quality_metrics(sorting_IC, recording_cache, 
                                                        metric_names=['firing_rate', 'isi_violation', 'snr'], 
                                                        as_dataframe=True)
# plot the data
plt.figure()
# you can change these however you want to see the values
sns.scatterplot(data=quality_metrics, x="snr", y='isi_violation')

In [None]:
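# Decide thresholds for quality metrics and identify the units that pass the criteria
snr_thresh = 5
isi_viol_thresh = 0.5
# recording length in frames (samples), needed to compute the ISI violation rate
duration = recording_cache.get_num_frames()
# first remove units whose ISI violation rate is greater than the threshold, and see which IDs pass
sorting_auto = st.curation.threshold_isi_violations(sorting_IC, isi_viol_thresh, 'greater', duration)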
print('#: ', len(sorting_auto.get_unit_ids()))
print('IDs: ', sorting_auto.get_unit_ids())
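`threshold_isi_violations` keeps or drops units based on an ISI-violation metric. With the 'greater' sign used above, units whose value exceeds `isi_viol_thresh` are removed. As far as I know, spiketoolkit's `isi_violation` follows a Hill et al. (2011) style estimate of contamination, sketched below. A violation is an inter-spike interval shorter than a refractory threshold. The 1.5 ms threshold and `min_isi` of 0 here are assumptions, and the toolkit's defaults may differ. `isi_violation_rate` is a hypothetical helper.

```python
import numpy as np


def isi_violation_rate(spike_times_s, duration_s, isi_threshold_s=0.0015, min_isi_s=0.0):
    """Estimated contamination rate from refractory-period violations.

    Hill et al. (2011)-style estimate:
        n_violations * duration / (2 * n_spikes**2 * (isi_threshold - min_isi))
    """
    t = np.sort(np.asarray(spike_times_s, dtype=float))
    n = len(t)
    if n < 2:
        return 0.0
    n_violations = int(np.sum(np.diff(t) < isi_threshold_s))
    return n_violations * duration_s / (2 * n**2 * (isi_threshold_s - min_isi_s))


# hypothetical unit: 4 spikes in a 10 s recording, one ISI of 1 ms
rate = isi_violation_rate([0.0, 0.001, 0.5, 1.0], duration_s=10.0)
print(rate)
```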

In [None]:
# now also remove units whose SNR is less than the threshold
sorting_auto = st.curation.threshold_snrs(sorting_auto, recording_cache, snr_thresh, 'less')
print('#: ', len(sorting_auto.get_unit_ids()))
print('IDs: ', sorting_auto.get_unit_ids())
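`threshold_snrs` uses a per-unit SNR. A common definition takes the peak absolute amplitude of the unit's template on its best channel and divides it by that channel's noise level, estimated from the median absolute deviation (MAD). To my knowledge this is what spiketoolkit uses by default. The sketch below implements that definition on synthetic data. The median template and the 0.6745 MAD scaling are assumptions about the details. `unit_snr` is a hypothetical helper.

```python
import numpy as np


def unit_snr(waveforms, traces):
    """Peak template amplitude on the best channel divided by a MAD noise estimate.

    waveforms: (n_spikes, n_channels, n_samples) snippets for one unit.
    traces:    (n_channels, n_samples) filtered recording used for the noise level.
    """
    template = np.median(waveforms, axis=0)  # (n_channels, n_samples)
    best = int(np.argmax(np.max(np.abs(template), axis=1)))
    peak = np.max(np.abs(template[best]))
    x = traces[best]
    noise = np.median(np.abs(x - np.median(x))) / 0.6745  # MAD scaled to a std
    return peak / noise


# hypothetical unit: a -8 (arbitrary units) trough on channel 1, unit-variance noise
rng = np.random.default_rng(2)
waveforms = np.zeros((50, 3, 40))
waveforms[:, 1, 20] = -8.0
traces = rng.normal(0.0, 1.0, size=(3, 200_000))
print(unit_snr(waveforms, traces))
```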

In [None]:
# Auto-label based on the criteria and comparison analysis: first label all clusters that passed our criteria as
# 'mua', then label all clusters that were found by all sorters as 'good' (SU).
cfile = pth+sess+'processed/phy_IC/cluster_group.tsv'
cg = pd.read_csv(cfile, delimiter='\t')
cg.iloc[sorting_auto.get_unit_ids(), 1] = 'mua'
cg.iloc[ids, 1] = 'good'
cg.to_csv(cfile, index=False, sep='\t') # check to see if the correct units were marked
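The cell above selects rows with `iloc`, i.e. by position. It therefore only marks the right clusters if each cluster ID equals its row number in cluster_group.tsv. A safer alternative is to select rows by the `cluster_id` column, as in this pandas sketch. The column names `cluster_id` and `group` are assumed to match the phy file, and `label_clusters` is a hypothetical helper. It still assumes the IDs you pass in are phy cluster IDs. In the notebook the call would be `label_clusters(cg, sorting_auto.get_unit_ids(), ids).to_csv(cfile, index=False, sep='\t')`.

```python
import io

import pandas as pd


def label_clusters(cluster_groups, mua_ids, good_ids):
    """Return a copy of a phy cluster_group table with labels set by cluster_id.

    Units in good_ids are labelled 'good' after the 'mua' pass, so 'good' wins.
    """
    cg = cluster_groups.copy()
    cg.loc[cg["cluster_id"].isin(mua_ids), "group"] = "mua"
    cg.loc[cg["cluster_id"].isin(good_ids), "group"] = "good"
    return cg


# hypothetical table whose cluster IDs are not 0..N-1, so row position != ID
demo = pd.read_csv(io.StringIO("cluster_id\tgroup\n1\tunsorted\n2\tunsorted\n5\tunsorted\n"), sep="\t")
labelled = label_clusters(demo, mua_ids=[2, 5], good_ids=[5])
print(labelled)
```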

### Done!
You should now use phy to manually curate the results. I recommend installing phy in a separate environment. When you open phy, each cluster should be labeled "mua" (gray), "good" (green), or unlabeled (white); clusters can also be labeled "noise" (dark gray). Go through and adjust the automated curation: check that the good units really look like single units to you, and that none of the mua clusters are actually single units. Also check for errors (which should be rare), such as:
- a cluster needs to be split into two;
- two clusters need to be merged;
- the cluster is really some odd noise;
- a cluster has labeled the same events as another cluster.

Remember to save your work in phy through the menu often.

### Saving the output/working with the data

After completing the manual curation, you can access the spike information directly from the phy output using Python (phy output is in .npy files) or MATLAB (using an .npy reader toolbox). Alternatively, you can reimport the data from phy and save it as NWB or access it directly (the spikeinterface format is really neo under the hood). The NWB part is still a work in progress and may be problematic. The approach attempted here is to save the processed data first, then append the sorted data to the NWB file created for the processed data.
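For the "read the phy output directly" route, here is a minimal Python sketch. It assumes the standard phy files: spike_times.npy (spike times in samples), spike_clusters.npy, and cluster_group.tsv (columns `cluster_id` and `group`). The sampling rate is passed in by hand; phy also records it in params.py. `good_spike_times` is a hypothetical helper.

```python
import tempfile
from pathlib import Path

import numpy as np
import pandas as pd


def good_spike_times(phy_dir, sample_rate):
    """Return {cluster_id: spike times in seconds} for clusters labelled 'good'.

    Reads spike_times.npy (sample indices), spike_clusters.npy and
    cluster_group.tsv from a phy output folder.
    """
    phy_dir = Path(phy_dir)
    times = np.load(phy_dir / "spike_times.npy").ravel()
    clusters = np.load(phy_dir / "spike_clusters.npy").ravel()
    groups = pd.read_csv(phy_dir / "cluster_group.tsv", sep="\t")
    good = groups.loc[groups["group"] == "good", "cluster_id"]
    return {int(c): times[clusters == c] / sample_rate for c in good}


# demo on a throwaway folder mimicking phy's files
with tempfile.TemporaryDirectory() as tmp:
    np.save(Path(tmp) / "spike_times.npy", np.array([30000, 60000, 90000, 120000]))
    np.save(Path(tmp) / "spike_clusters.npy", np.array([0, 1, 0, 2]))
    pd.DataFrame({"cluster_id": [0, 1, 2], "group": ["good", "mua", "good"]}).to_csv(
        Path(tmp) / "cluster_group.tsv", sep="\t", index=False)
    demo_times = good_spike_times(tmp, sample_rate=30000)
print(demo_times)
```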

In [None]:
# import the phy curation of the primary sorter (IronClust), exported above to processed/phy_IC
sorting_phy = se.PhySortingExtractor(pth+sess+'processed/phy_IC/')

In [None]:
# Export to NWB. This should export the processed recording to NWB format. It fails if a file with the
# output file's name already exists (so delete it first if it does)
outputfile = pth+sess+'SPLTest01.nwb'
se.NwbRecordingExtractor.write_recording(recording_cache, outputfile)

In [None]:
# Append the sorting data to the NWB file by setting the overwrite argument to False
se.NwbSortingExtractor.write_sorting(sorting_phy, outputfile, overwrite=False)

# Extra cells that might be helpful

If export to phy crashes, try running these next three cells first, then retry export to phy.

In [None]:
# get waveforms for chosen sorter
st.postprocessing.get_unit_waveforms(recording_cache, sorting_IC_cache, 
                                     ms_before=0.5, 
                                     ms_after=1, 
                                     compute_property_from_recording=True, 
                                     n_jobs=10, 
                                     max_spikes_per_unit=None, 
                                     memmap=True, 
                                     save_property_or_features=True, 
                                     recompute_info=True, 
                                     verbose=True)

In [None]:
# get amplitudes for chosen sorter
st.postprocessing.get_unit_amplitudes(recording_cache, sorting_IC_cache, 
                                      ms_before=0.5, 
                                      ms_after=1, 
                                      max_spikes_per_unit=None, 
                                      memmap=True, 
                                      save_property_or_features=True, 
                                      n_jobs=10, 
                                      verbose=True)

In [None]:
# unit templates
st.postprocessing.get_unit_templates(recording_cache, sorting_IC_cache, 
                                     ms_before=0.5, 
                                     ms_after=1, 
                                     max_spikes_per_unit=None, 
                                     memmap=True, 
                                     save_property_or_features=True, 
                                     n_jobs=10, 
                                     verbose=True)

If the ISI violation cell throws an error, try running this cell first. It computes the ISI violations and SNRs directly.

In [None]:
duration = recording_cache.get_num_frames()
isi_violations = st.validation.compute_isi_violations(sorting_IC, duration_in_frames=duration)
print('ISI violations:', isi_violations)

snrs = st.validation.compute_snrs(sorting_IC, recording_cache)
print('SNRs:', snrs)