In [None]:
# Reload imported modules before each cell executes, so edits to locally
# installed packages are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# SpikeInterface pipeline for Mease Lab - CED

In [None]:
import spikeextractors as se
import spiketoolkit as st
import spikesorters as ss
import spikecomparison as sc
import spikewidgets as sw
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
from pprint import pprint
from pathlib import Path
from datetime import datetime
%matplotlib notebook

In [None]:
from nwb_conversion_tools.conversion_tools import save_si_object

In [None]:
# Stub mode: when enabled, the processed recording is truncated to its first
# `nsec_stub` seconds so the whole pipeline can be exercised quickly.
stub_test, nsec_stub = True, 5

In [None]:
# Root output folder (relative to the working directory) for all results below
spikeinterface_folder = Path('spikeinterface')
spikeinterface_folder.mkdir(exist_ok=True, parents=True)

## 1) Load CED recordings and probe, compute LFP, and inspect signals

In [None]:
import os

# Input data locations. The CED file defaults to the original author's local
# path; set the CED_FILE environment variable to point at your own copy
# instead of editing the notebook.
ced_file = Path(os.environ.get(
    'CED_FILE',
    '/Users/abuccino/Documents/Data/catalyst/heidelberg/ced/m365_pt1_590-1190secs-001.smrx',
))
probe_file = '../probe_files/cambridge_neurotech_H3.prb'

In [None]:
# Automatically select the Rhd channels: keep every channel whose title contains "Rhd"
channel_info = se.CEDRecordingExtractor.get_all_channels_info(ced_file)

rhd_channels = [ch for ch, info in channel_info.items() if "Rhd" in info["title"]]

In [None]:
# Open only the Rhd channels selected above (used as the AP-band recording below)
recording = se.CEDRecordingExtractor(ced_file, smrx_ch_inds=rhd_channels)

In [None]:
# Basic summary of the AP-band recording
fs_ap = recording.get_sampling_frequency()
print(f"Num channels: {recording.get_num_channels()}")
print(f"Sampling rate: {fs_ap}")
print(f"Duration (s): {recording.get_num_frames() / fs_ap}")

In [None]:
# Load probe file to re-order channels and add location
# Load probe file to re-order channels and add electrode locations.
# NOTE(review): this rebinds `recording`; re-running the cell without re-running
# the cell that opens the recording re-applies the probe to an already-mapped
# recording — TODO confirm that is harmless.
recording = se.load_probe_file(recording, probe_file)

In [None]:
# Plot the probe layout. Clicking on an electrode shows its channel id
# (requires the interactive matplotlib backend enabled in the imports cell).
sw.plot_electrode_geometry(recording)

### Load LFP

In [None]:
# Select the LFP channels: keep every channel whose title contains "LFP"
lfp_channels = [ch for ch, info in channel_info.items() if "LFP" in info["title"]]

In [None]:
# Open the LFP channels as a separate recording extractor
recording_lfp = se.CEDRecordingExtractor(ced_file, smrx_ch_inds=lfp_channels)

In [None]:
# Compare the sampling rates of the AP-band and LFP recordings
for band_label, band_recording in (("AP", recording), ("LF", recording_lfp)):
    print(f"Sampling frequency {band_label}: {band_recording.get_sampling_frequency()}")

### (Optional) Resample LFP

In [None]:
# Optional: resample the LFP to 1000 (presumably Hz — TODO confirm units of resample_rate).
# This rebinds `recording_lfp`, so skipping the cell keeps the native rate.
recording_lfp = st.preprocessing.resample(recording_lfp, resample_rate=1000)

### Inspect signals

In [None]:
# Plot the AP-band traces (no trange given, so the widget's default range is used)
w_ts_ap = sw.plot_timeseries(recording)

In [None]:
# Plot the LFP traces over trange=[30, 40] (presumably seconds — TODO confirm)
w_ts_lf = sw.plot_timeseries(recording_lfp, trange=[30, 40])

## 2) Pre-processing

In [None]:
# Pre-processing switches
apply_filter, apply_cmr = False, True  # the CED data appear to be already filtered

# Band-pass corner frequencies (Hz), only used when apply_filter is True
freq_min_hp, freq_max_hp = 300, 3000

In [None]:
# Chain the optional pre-processing steps. Each step consumes the output of the
# previous one, so enabling both band-pass filtering and common median
# referencing applies BOTH (the original code re-referenced the raw `recording`
# and dropped the filter). With neither enabled, the raw recording is used.
recording_processed = recording
if apply_filter:
    recording_processed = st.preprocessing.bandpass_filter(recording_processed,
                                                           freq_min=freq_min_hp,
                                                           freq_max=freq_max_hp)
if apply_cmr:
    recording_processed = st.preprocessing.common_reference(recording_processed)

In [None]:
# Stub mode: keep only the first `nsec_stub` seconds of the processed recording
if stub_test:
    stub_end_frame = int(nsec_stub * recording_processed.get_sampling_frequency())
    recording_processed = se.SubRecordingExtractor(recording_processed, end_frame=stub_end_frame)

In [None]:
# Number of frames in the processed (possibly stubbed) recording; reused as the
# recording duration by the automatic-curation thresholds.
num_frames = recording_processed.get_num_frames()
print(num_frames)

In [None]:
# Plot the pre-processed traces to check the effect of filtering / referencing
w_ts_ap = sw.plot_timeseries(recording_processed)

## 3) Run spike sorters

In [None]:
# Sorters to run. Add more names (see `ss.installed_sorters()`) to enable the
# ensemble comparison in section 5.
sorter_list = ['ironclust']

In [None]:
# List the sorters available in this environment
ss.installed_sorters()

In [None]:
# Inspect sorter-specific parameters and defaults: for each sorter, print the
# parameter descriptions followed by the default values.
for sorter in sorter_list:
    for header, getter in ((f"{sorter} params description:", ss.get_params_description),
                           ("Default params:", ss.get_default_params)):
        print(header)
        pprint(getter(sorter))

In [None]:
# User-specific sorter parameters, keyed by sorter name, overriding the defaults
# printed above. Example:
# sorter_params = dict(kilosort2=dict(car=False),
#                      ironclust=dict(),
#                      spykingcircus=dict())
sorter_params = {}  # empty dict: no overrides

In [None]:
# Run every sorter in `sorter_list` on the processed recording (registered as
# 'rec0'); sorter working files go to spikeinterface/ced_si_output.
sorting_outputs = ss.run_sorters(sorter_list=sorter_list, 
                                 working_folder=spikeinterface_folder / 'ced_si_output',
                                 recording_dict_or_list=dict(rec0=recording_processed), 
                                 sorter_params=sorter_params, verbose=True)

`sorting_outputs` is a dictionary keyed by `(rec_name, sorter_name)` tuples.

In [None]:
# (recording name, sorter name) pairs that produced a sorting
print(sorting_outputs.keys())

In [None]:
# The sorting extractor objects themselves
print(sorting_outputs.values())

In [None]:
# Plot a raster for every sorting output. Iterating over the dictionary avoids a
# hard-coded ('rec0', 'ironclust') key, which raises KeyError as soon as
# `sorter_list` is changed.
for result_name, sorting in sorting_outputs.items():
    sw.plot_rasters(sorting)

## 4) Post-processing: extract waveforms, templates, quality metrics, extracellular features

### Set postprocessing parameters

In [None]:
# Post-processing params: start from spiketoolkit's common defaults and show them
postprocessing_params = st.postprocessing.get_common_params()
pprint(postprocessing_params)

In [None]:
# (optional) change parameters: cap the number of waveforms extracted per unit
postprocessing_params['max_spikes_per_unit'] = 1000  # with None, all waveforms are extracted

### Set quality metric list

In [None]:
# Quality metrics: list every metric spiketoolkit can compute
qc_list = st.validation.get_quality_metrics_list()
print(f"Available quality metrics: {qc_list}")

In [None]:
# (optional) restrict the quality metrics to this subset
qc_list = [
    'snr',
    'isi_violation',
    'firing_rate',
]

### Set extracellular features

In [None]:
# Extracellular features: list every template feature spiketoolkit can compute
ec_list = st.postprocessing.get_template_features_list()
print(f"Available EC features: {ec_list}")

In [None]:
# (optional) restrict the extracellular features to this subset
ec_list = [
    'peak_to_valley',
    'halfwidth',
]

### Postprocess all sorting outputs

In [None]:
# Post-process every (recording, sorter) result:
#   waveforms -> templates -> extracellular features -> quality metrics -> phy export.
# Features and metrics are collected in dicts keyed like `sorting_outputs`, so
# they are still available after the loop (previously only the last iteration's
# `ec`/`qc` survived). Each result is exported to its own phy sub-folder;
# previously every sorter was exported to the same `phy` folder, so later
# sorters replaced earlier exports.
ec_results = {}
qc_results = {}
for result_name, sorting in sorting_outputs.items():
    rec_name, sorter = result_name
    tmp_folder = Path('tmp_ced') / sorter
    tmp_folder.mkdir(parents=True, exist_ok=True)

    # set local tmp folder
    sorting.set_tmp_folder(tmp_folder)

    # compute waveforms
    waveforms = st.postprocessing.get_unit_waveforms(recording_processed, sorting,
                                                     **postprocessing_params)

    # compute templates
    templates = st.postprocessing.get_unit_templates(recording_processed, sorting, **postprocessing_params)

    # compute EC features
    ec = st.postprocessing.compute_unit_template_features(recording_processed, sorting,
                                                          feature_names=ec_list, as_dataframe=True)
    ec_results[result_name] = ec

    # compute QCs
    qc = st.validation.compute_quality_metrics(sorting, recording=recording_processed,
                                               metric_names=qc_list, as_dataframe=True)
    qc_results[result_name] = qc

    # export to phy: one sub-folder per (recording, sorter) pair. The parent is
    # created up front so the export does not depend on it already existing.
    phy_folder = spikeinterface_folder / 'phy' / f'{rec_name}_{sorter}'
    phy_folder.parent.mkdir(parents=True, exist_ok=True)
    print("Exporting to phy")
    st.postprocessing.export_to_phy(recording_processed, sorting, phy_folder, verbose=True)

### Load Phy-curated data back to SI

In [None]:
# Load the phy-curated result back into SpikeInterface. `phy_folder` is the
# export folder set in the post-processing loop (the last one written if several
# sorters were run).
# NOTE(review): `sorting_curated` is reassigned by the automatic-curation loop in
# section 6, so after a top-to-bottom run this phy-curated sorting is overwritten.
sorting_curated = se.PhySortingExtractor(phy_folder)

## 5) Ensemble spike sorting

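Only meaningful if `len(sorter_list) > 1`: the ensemble keeps units found by at least two sorters (`minimum_agreement_count=2`).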

In [None]:
# retrieve sortings and sorter names
# retrieve sortings and sorter names, in the same order, for the comparison
sorting_list, sorter_names_comp = [], []
for (rec_name, sorter), sorting in sorting_outputs.items():
    sorting_list.append(sorting)
    sorter_names_comp.append(sorter)

In [None]:
# run multisorting comparison
# run multisorting comparison. `name_list` must line up with `sorting_list`; both
# are built in the previous cell (`sorter_names` was never defined -> NameError).
mcmp = sc.compare_multiple_sorters(sorting_list=sorting_list, name_list=sorter_names_comp)

In [None]:
# plot agreement results
# plot agreement results across the compared sorters
w_agr = sw.plot_multicomp_agreement(mcmp)

In [None]:
# extract ensemble sorting: keep units with an agreement count of at least 2,
# i.e. matched across at least two sorters
sorting_ensamble = mcmp.get_agreement_sorting(minimum_agreement_count=2)

## 6) Automatic curation

In [None]:
# define curators and thresholds
# Thresholds for the automatic curation below:
#   isi_violation_threshold -> ISI-violation metric
#   snr_threshold           -> signal-to-noise ratio
#   firing_rate_threshold   -> firing rate
isi_violation_threshold, snr_threshold, firing_rate_threshold = 0.5, 3, 0.05

In [None]:
# Apply the three threshold curators IN SEQUENCE for each sorter: each curator
# receives the output of the previous one, so the result reflects all three
# thresholds. (Previously every curator was applied to the uncurated `sorting`,
# so only the last, SNR-based, threshold had any effect.)
sorting_auto_curated = []
sorter_names_curation = []
for result_name, sorting in sorting_outputs.items():
    rec_name, sorter = result_name
    sorter_names_curation.append(sorter)

    # firing rate threshold (threshold_sign='less': low-firing units are thresholded out)
    sorting_curated = st.curation.threshold_firing_rates(sorting, duration_in_frames=num_frames,
                                                         threshold=firing_rate_threshold,
                                                         threshold_sign='less')

    # isi violation threshold (threshold_sign='greater': units with too many violations)
    sorting_curated = st.curation.threshold_isi_violations(sorting_curated, duration_in_frames=num_frames,
                                                           threshold=isi_violation_threshold,
                                                           threshold_sign='greater')

    # snr threshold (threshold_sign='less': low-SNR units)
    sorting_curated = st.curation.threshold_snrs(sorting_curated, recording=recording_processed,
                                                 threshold=snr_threshold,
                                                 threshold_sign='less')
    sorting_auto_curated.append(sorting_curated)

In [None]:
# Unit ids surviving automatic curation for the first entry of `sorting_outputs`
sorting_auto_curated[0].get_unit_ids()

## 7) Save all outputs for NWB conversion

In [None]:
# Output folder for everything handed over to the NWB conversion
nwb_folder = spikeinterface_folder.joinpath('nwb')
nwb_folder.mkdir(exist_ok=True, parents=True)

### Save recording

In [None]:
# Save the un-processed recording (probe-mapped, no filtering / CMR / stubbing)
save_si_object("recording_raw", recording, output_folder=nwb_folder)

### Save sorting

In [None]:
# Save the curated sorting for the NWB conversion.
# NOTE(review): after a top-to-bottom run, `sorting_curated` holds the
# automatically curated sorting of the LAST sorter (section 6 reassigns it), not
# the phy-curated sorting loaded in section 4 — confirm which one is intended.
save_si_object("sorting_curated", sorting_curated, output_folder=nwb_folder)