In [None]:
%load_ext autoreload
%autoreload 2

# SpikeInterface pipeline for Mease Lab - Syntalos

In [None]:
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
from pprint import pprint
from datetime import datetime, timedelta
from isodate import duration_isoformat

import spikeextractors as se
import spiketoolkit as st
import spikesorters as ss
import spikecomparison as sc
import spikewidgets as sw
from nwb_conversion_tools.json_schema_utils import dict_deep_update
from nwb_conversion_tools.conversion_tools import save_si_object

from mease_lab_to_nwb import SyntalosNWBConverter, SyntalosRecordingExtractor, SyntalosRecordingInterface
from mease_lab_to_nwb.convert_syntalos.syntalosnwbconverter import quick_write

%matplotlib notebook

## 1) Load Intan recordings, compute LFP, and inspect signals

In [None]:
# Location of the Syntalos "intan-signals" folder (edit for your machine), e.g. on Windows:
# syntalos_folder = Path('D:/Syntalos/Latest Syntalos Recording _20200730/intan-signals')
syntalos_folder = Path(
    '/Users/abuccino/Documents/Data/catalyst/heidelberg/syntalos/Latest_Syntalos_Recording_20200730/intan-signals/'
)
# All SpikeInterface outputs (sorter folders, tmp files, phy exports) go in this sub-folder.
spikeinterface_folder = syntalos_folder.joinpath("spikeinterface")
# Probe description used to assign channels to tetrode groups.
probe_file = '../probe_files/tetrode_32.prb'

In [None]:
# Open the recording stored in `syntalos_folder` as a recording extractor.
recording = SyntalosRecordingExtractor(syntalos_folder)

In [None]:
# Load probe file for tetrodes; the extractor returned by `load_probe_file`
# replaces `recording` from here on.
recording = se.load_probe_file(recording, probe_file)
# Show the group of each channel (the 'group' property is used later to sort
# each tetrode separately).
print(recording.get_channel_groups())

In [None]:
# the probe file loads a dummy geometry to identify the different tetrodes,
# so the plotted positions are placeholders rather than real electrode coordinates
sw.plot_electrode_geometry(recording)

### Stub recording for fast testing; set `stub_test` to False to run the processing pipeline on the entire recording

In [None]:
# When True, the processed and LFP recordings are truncated to their first
# `nsec_stub` seconds (see the stub cell in the pre-processing section).
stub_test = True
# Stub length in seconds; it is multiplied by the sampling frequency to get an end frame.
nsec_stub = 10

### Compute LFP

In [None]:
# Band-pass limits and target sampling rate for the LFP signal
# (presumably all in Hz — confirm against the spiketoolkit preprocessing docs).
freq_min_lfp = 1
freq_max_lfp = 300
freq_resample_lfp = 1000.

In [None]:
# Build the LFP recording: band-pass the raw recording between `freq_min_lfp`
# and `freq_max_lfp`, then resample the filtered signal to `freq_resample_lfp`.
recording_lfp = st.preprocessing.resample(
    st.preprocessing.bandpass_filter(
        recording,
        freq_min=freq_min_lfp,
        freq_max=freq_max_lfp,
    ),
    freq_resample_lfp,
)

In [None]:
# Compare the sampling rate of the raw (AP) recording with that of the resampled LFP.
fs_ap = recording.get_sampling_frequency()
print(f"Sampling frequency AP: {fs_ap}")
fs_lf = recording_lfp.get_sampling_frequency()
print(f"Sampling frequency LF: {fs_lf}")

### Inspect signals

In [None]:
# Plot the raw traces; `color_groups=True` presumably colors channels by their
# group (tetrode) — confirm in the spikewidgets documentation.
w_ts_ap = sw.plot_timeseries(recording, color_groups=True)

In [None]:
# Plot the LFP traces restricted to `trange` (presumably seconds 10 to 15 of the
# recording — confirm the unit in the spikewidgets documentation).
w_ts_lf = sw.plot_timeseries(recording_lfp, trange=[10, 15])

## 2) Pre-processing

In [None]:
# Switches for the two pre-processing steps and the band-pass limits of the
# spike band (presumably in Hz — confirm).
apply_filter = True  # band-pass filter the recording before sorting
apply_cmr = True  # apply st.preprocessing.common_reference after the optional filter
freq_min_hp = 300
freq_max_hp = 3000

In [None]:
# Optional band-pass filter: fall back to the unfiltered recording when disabled.
recording_processed = (
    st.preprocessing.bandpass_filter(recording, freq_min=freq_min_hp, freq_max=freq_max_hp)
    if apply_filter
    else recording
)

# Optional common referencing, applied on top of the (possibly filtered) signal.
if apply_cmr:
    recording_processed = st.preprocessing.common_reference(recording_processed)

In [None]:
# Stub mode: keep only the first `nsec_stub` seconds of the processed and LFP
# recordings. Each recording converts seconds to frames with its own sampling
# frequency, since the LFP has been resampled.
if stub_test:
    stub_end_frame_ap = int(nsec_stub * recording_processed.get_sampling_frequency())
    stub_end_frame_lf = int(nsec_stub * recording_lfp.get_sampling_frequency())
    recording_processed = se.SubRecordingExtractor(recording_processed, end_frame=stub_end_frame_ap)
    recording_lfp = se.SubRecordingExtractor(recording_lfp, end_frame=stub_end_frame_lf)

In [None]:
# Plot the pre-processed (and, in stub mode, truncated) traces; this rebinds `w_ts_ap`.
w_ts_ap = sw.plot_timeseries(recording_processed)

## 3) Run spike sorters

In [None]:
# List the sorters available in this environment; every name in `sorter_list`
# below should appear here.
ss.installed_sorters()

In [None]:
# Sorters to run; each name is passed to `ss.run_sorter` in the sorting loop.
sorter_list = [
    "klusta", # ironclust requires channel locations
    "tridesclous"
    # "waveclus" # waveclust errors out, "File type '' isn't supportedERROR: MATLAB error Exit Status: 0x00000001"
]

# this can also be done by setting global env variables: IRONCLUST_PATH, WAVECLUS_PATH
# NOTE(review): neither ironclust nor waveclus is in `sorter_list`, and these are
# machine-specific Windows paths — adjust them to your installation before
# enabling those sorters.
ss.IronClustSorter.set_ironclust_path("D:/GitHub/ironclust")
ss.WaveClusSorter.set_waveclus_path("D:/GitHub/wave_clus")

In [None]:
# Inspect sorter-specific parameters and defaults
# For every selected sorter, print the description of each parameter followed
# by its default value; use these names as keys in `sorter_params`.
for sorter in sorter_list:
    print(f"{sorter} params description:")
    pprint(ss.get_params_description(sorter))
    print("Default params:")
    pprint(ss.get_default_params(sorter))

In [None]:
# Per-sorter parameter overrides, keyed by sorter name. An empty dict means
# "use the sorter defaults"; the entries are forwarded as keyword arguments
# to `ss.run_sorter`.
sorter_params = {
    "klusta": {},
    "tridesclous": {},
}

**Important** To spike sort by group, set `grouping_property='group'`. This way, the different tetrodes will be sorted separately and re-assembled after spike sorting.

The `run_sorters` function does not support the `grouping_property` argument, so you need to launch spike sorters in a different way if you have groups. This next cell will create the same output dictionary as the CED pipeline.

Note that the `parallel` argument allows one to run different groups (tetrodes) in parallel. This is different than the intrinsic parallelization of the spike sorters (which can be controlled by their params).

In [None]:
# Results are collected as {(recording name, sorter name): sorting extractor}.
sorting_outputs = {}
# Each sorter writes its files to a sub-folder of this directory.
working_folder = spikeinterface_folder

In [None]:
# Run every selected sorter on the processed recording, one channel group
# (tetrode) at a time, and store the result under ('rec0', sorter name).
for sorter_name in sorter_list:
    print(f"Running {sorter_name}")
    sorter_output_folder = working_folder / sorter_name
    sorting = ss.run_sorter(
        sorter_name,
        recording_processed,
        output_folder=sorter_output_folder,
        grouping_property='group',
        parallel=True,
        n_jobs=4,
        verbose=True,
        raise_error=False,
        **sorter_params[sorter_name]
    )
    sorting_outputs[('rec0', sorter_name)] = sorting
    n_units_found = len(sorting.get_unit_ids())
    print(f"{sorter_name} found {n_units_found} units!")

## 4) Post-processing: extract waveforms, templates, quality metrics, extracellular features

### Set postprocessing parameters

In [None]:
# Post-processing params
# Start from the defaults returned by spiketoolkit and print them so they can
# be reviewed before being overridden in the next cell.
postprocessing_params = st.postprocessing.get_common_params()
pprint(postprocessing_params)

**Important note for Windows**: on Windows, we currently have some problems with the `memmap` argument. Until this is fixed, we recommend setting it to `False`.

In [None]:
# (optional) change parameters
postprocessing_params['max_spikes_per_unit'] = 1000  # with None, all waveforms are extracted

# by setting 'grouping_property' to "group" (the channel property loaded from the
# probe file), everything is computed tetrode-wise (handy for manual curation)
postprocessing_params['grouping_property'] = "group"

### Set quality metric list

In [None]:
# Quality metrics
# List every metric name that `st.validation` reports as available.
qc_list = st.validation.get_quality_metrics_list()
print(f"Available quality metrics: {qc_list}")

In [None]:
# (optional) define subset of qc
# This overwrites the full list from the previous cell; the names are passed as
# `metric_names` to `st.validation.compute_quality_metrics`.
qc_list = ["snr", "isi_violation", "firing_rate"]

### Set extracellular features

In [None]:
# Extracellular features
# List every template feature name that `st.postprocessing` reports as available.
ec_list = st.postprocessing.get_template_features_list()
print(f"Available EC features: {ec_list}")

In [None]:
# (optional) define subset of ec
# This overwrites the full list from the previous cell; the names are passed as
# `feature_names` to `st.postprocessing.compute_unit_template_features`.
ec_list = ["peak_to_valley", "halfwidth"]

### Postprocess all sorting outputs

In [None]:
# Post-process every sorting output: waveforms, templates, extracellular
# features, quality metrics, and an export for manual curation in phy.
for result_name, sorting in sorting_outputs.items():
    # keys of `sorting_outputs` are (recording name, sorter name) tuples
    rec_name, sorter = result_name
    print(f"Postprocessing recording {rec_name} sorted with {sorter}")
    tmp_folder = spikeinterface_folder / 'tmp' / sorter
    tmp_folder.mkdir(parents=True, exist_ok=True)
    
    # set local tmp folder
    sorting.set_tmp_folder(tmp_folder)
    
    # compute waveforms
    # NOTE(review): the returned `waveforms` is not used later in this notebook;
    # presumably the call is made so the results are stored on `sorting` — confirm.
    waveforms = st.postprocessing.get_unit_waveforms(recording_processed, sorting, **postprocessing_params)
    
    # compute templates
    # NOTE(review): `templates` is likewise not used later in this notebook.
    templates = st.postprocessing.get_unit_templates(recording_processed, sorting, **postprocessing_params)
    
    # compute EC features
    # (`ec` is overwritten at every iteration, so only the last sorter's table
    # remains available after the loop)
    ec = st.postprocessing.compute_unit_template_features(recording_processed, sorting,
                                                          feature_names=ec_list, as_dataframe=True)
    # compute QCs
    # (`qc` is overwritten at every iteration as well)
    qc = st.validation.compute_quality_metrics(sorting, recording=recording_processed, 
                                               metric_names=qc_list, as_dataframe=True)
    
    # export to phy
    # one phy folder per sorter, under <spikeinterface_folder>/phy/<sorter>
    phy_folder = spikeinterface_folder / 'phy' / sorter
    phy_folder.mkdir(parents=True, exist_ok=True)
    st.postprocessing.export_to_phy(recording_processed, sorting, phy_folder)

## 5) Ensemble spike sorting


In [None]:
# Ensemble sorting needs the output of at least two sorters.
if len(sorting_outputs) > 1:
    # split the {(recording name, sorter name): sorting} dict into parallel lists
    sorting_list = list(sorting_outputs.values())
    sorter_names_comp = [sorter_key for _, sorter_key in sorting_outputs]

    # compare all sorting outputs with each other
    mcmp = sc.compare_multiple_sorters(sorting_list=sorting_list, name_list=sorter_names_comp)

    # plot how many units are agreed upon by the different sorters
    w_agr = sw.plot_multicomp_agreement(mcmp)

    # extract the ensemble sorting: keep units found by at least 2 sorters
    sorting_ensemble = mcmp.get_agreement_sorting(minimum_agreement_count=2)

    n_ensemble_units = len(sorting_ensemble.get_unit_ids())
    print(f"Ensemble sorting among {sorter_list} found: {n_ensemble_units} units")

### (Optional) save ensemble output for later use

In [None]:
# Save the ensemble sorting inside `spikeinterface_folder` for later use.
# `sorting_ensemble` only exists when at least two sorting outputs were compared,
# so apply the same condition here instead of failing with a NameError.
if len(sorting_outputs) > 1:
    save_si_object(
        "sorting_ensemble", sorting_ensemble, spikeinterface_folder,
        cache_raw=False, include_properties=True, include_features=False
    )
else:
    print("No ensemble sorting available (fewer than two sorting outputs): nothing to save.")

## 6) Automatic curation

In [None]:
# define curators and thresholds
# Each threshold is passed to the matching `st.curation.threshold_*` function in
# the next cell, together with a `threshold_sign` of 'less' or 'greater'.
isi_violation_threshold = 0.5  # used with threshold_sign='greater'
snr_threshold = 5  # used with threshold_sign='less'
firing_rate_threshold = 0.1  # used with threshold_sign='less'

In [None]:
# Automatic curation of every sorting output.
# The three curators are chained: each one receives the output of the previous
# one, so a unit is kept only if it passes ALL thresholds. (Feeding the original
# `sorting` to every curator would discard the earlier results and leave only
# the last threshold in effect.)
sorting_auto_curated = []
sorter_names_curation = []

# the recording duration is the same for every sorting output
num_frames = recording_processed.get_num_frames()

for result_name, sorting in sorting_outputs.items():
    # keys of `sorting_outputs` are (recording name, sorter name) tuples
    rec_name, sorter = result_name
    sorter_names_curation.append(sorter)

    # firing rate threshold (applied to units whose firing rate is 'less' than the threshold)
    sorting_curated = st.curation.threshold_firing_rates(
        sorting,
        duration_in_frames=num_frames,
        threshold=firing_rate_threshold,
        threshold_sign='less'
    )

    # isi violation threshold (applied to units whose metric is 'greater' than the threshold)
    sorting_curated = st.curation.threshold_isi_violations(
        sorting_curated,
        duration_in_frames=num_frames,
        threshold=isi_violation_threshold,
        threshold_sign='greater'
    )

    # snr threshold (applied to units whose SNR is 'less' than the threshold)
    sorting_curated = st.curation.threshold_snrs(
        sorting_curated,
        recording=recording_processed,
        threshold=snr_threshold,
        threshold_sign='less'
    )
    sorting_auto_curated.append(sorting_curated)

## 7) Quick save to NWB; writes only the spikes

## To complete the full conversion for other types of data, use the external script

In [None]:
# Name your NWBFile and decide where you want it saved.
# The file is written next to the other SpikeInterface outputs; the folder is
# created if it does not exist yet.
nwbfile_path = spikeinterface_folder / "Syntalos.nwb"
nwbfile_path.parent.mkdir(parents=True, exist_ok=True)

# Enter Session and Subject information here
session_description = "Enter session description here."

# Choose the sorting extractor from the notebook environment you would like to write to NWB.
# The name must be one of the sorters that were actually run (see `sorter_list`);
# keys of `sorting_outputs` are (recording name, sorter name) tuples.
chosen_sorter_name = sorter_list[0]
chosen_sorting_extractor = sorting_outputs[('rec0', chosen_sorter_name)]

# NOTE(review): when `stub_test` is True, `recording_lfp` has been truncated while
# the timestamps below come from the full `recording` — confirm that `quick_write`
# handles this combination.
quick_write(
    intan_folder_path=syntalos_folder,
    session_description=session_description,
    save_path=nwbfile_path,
    sorting=chosen_sorting_extractor,
    lfp=recording_lfp,
    timestamps=recording.get_timestamps(),
    overwrite=False
)