# SpikeInterface pipeline for Brody Lab

In [None]:
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
from pprint import pprint

import spikeextractors as se  # for reading raw data
import spiketoolkit as st  # pre- and post- processed, validation, and curation
import spikesorters as ss
import spikecomparison as sc  # ensemble sorting and spike train comparisons
import spikewidgets as sw  # visualizations

from brody_lab_to_nwb import make_extractor

%matplotlib notebook

### Set global parameters

In [None]:
# Parallelization: number of worker jobs and the per-job memory budget in MB
# (forwarded as `chunk_mb` to the spiketoolkit / spikesorters calls).
n_jobs, chunk_mb = 4, 2000

# Logging verbosity, and memmap caching of waveforms (recommended for Linux/macOS).
verbose, use_memmap = True, False

## 1a) Load AP and LF recordings

### Choose one of the following formats

In [None]:
base_path = Path("D:/Brody")

# SpikeGLX session ("Chronic Rat Neuropixels" folder). Alternatives: A256_2020_10_07, T219_2019_11_22.
spikeglx_folder_name = "A242_2019_05_30"
# Usually identical to spikeglx_folder_name; A242 is the exception.
spikeglx_session_name = "2019-05-30_g0"
spikeglx_file_path = base_path.joinpath(
    "Chronic Rat Neuropixels (Poisson Clicks Task)",
    spikeglx_folder_name,
    "Raw",
    spikeglx_session_name,
    f"{spikeglx_session_name}_imec0",
    f"{spikeglx_session_name}_t0.imec0.ap.bin",
)

# SpikeGadgets session ("WirelessTetrodes" folder). Alternative: W122_09_02_2019_1_fromSD.
spikegadgets_session_name = "W122_06_09_2019_1_fromSD"
spikegadgets_file_path = base_path.joinpath("WirelessTetrodes", f"{spikegadgets_session_name}.rec")

# Neuralynx session folder ("Neuralynx Tetrode Data"). Alternatives: A182_2018_10_05, K236_2017_09_06.
neuralynx_session_name = "A182_2018_10_05"
neuralynx_folder_path = base_path.joinpath("Neuralynx Tetrode Data", neuralynx_session_name, "Raw")

### Make spikeinterface folders

In [None]:
# Derive the working folder from the SAME file handed to make_extractor() in the
# "Make RecordingExtractor" cell (the SpikeGadgets .rec file); keep the two in sync
# when switching formats. Outputs are nested per session so that sorter runs done
# with mode="keep" are never silently reused for a different session stored in the
# same raw-data folder.
recording_folder = spikegadgets_file_path.parent
spikeinterface_folder = recording_folder / "spikeinterface" / spikegadgets_session_name
spikeinterface_folder.mkdir(parents=True, exist_ok=True)

### (Optional) Stub the recording for fast testing; set `stub_test = False` to run the processing pipeline on the entire recording

In [None]:
# When stub_test is True the recording is truncated to its first `nsec_stub` seconds;
# set stub_test to False to run the pipeline on the entire recording.
stub_test, nsec_stub = True, 30

### Make RecordingExtractor

In [None]:
recording = make_extractor(file_or_folder_path=spikegadgets_file_path)

# Stub mode: keep only the first `nsec_stub` seconds (frames = sampling rate * seconds).
if stub_test:
    recording = se.SubRecordingExtractor(
        recording,
        end_frame=int(recording.get_sampling_frequency() * nsec_stub),
    )

In [None]:
# Report the sampling rate of `recording`. The "AP" label is kept from the SpikeGLX
# .ap.bin workflow; the extractor here is built from spikegadgets_file_path.
print(f"Sampling frequency AP: {recording.get_sampling_frequency()}")

### Inspect signals

In [None]:
# Plot every 10th channel of the loaded recording over trange=[1, 5]. The extractor
# built in the "Make RecordingExtractor" cell is named `recording`; no `recording_ap`
# is ever defined, so referencing it raised NameError.
w_ts_ap = sw.plot_timeseries(recording, channel_ids=recording.get_channel_ids()[::10], trange=[1, 5])

In [None]:
# The LF time-series plot needs a `recording_lf` extractor, which the
# "Make RecordingExtractor" cell does not create (it builds only `recording`).
# Look it up instead of referencing it directly, so the cell skips the plot
# rather than raising NameError when no LF recording has been loaded.
recording_lf = globals().get("recording_lf")
if recording_lf is not None:
    w_ts_lf = sw.plot_timeseries(recording_lf, channel_ids=recording_lf.get_channel_ids()[::10])
else:
    print("No LF recording loaded; skipping LF time-series plot.")

## 2) Pre-processing

In [None]:
# Preprocessing switches. Common referencing stays disabled for now: temporary,
# until a spiketoolkit fix for SpikeGLX-only recordings lands.
apply_bandpass, apply_cmr = True, False

In [None]:
# Build the preprocessing chain on top of `recording` (the extractor created in the
# "Make RecordingExtractor" cell; `recording_ap` is never defined): an optional
# band-pass filter followed by optional common referencing. Each step wraps the
# output of the previous one, so disabling a step simply passes the input through.
recording_processed = recording
if apply_bandpass:
    recording_processed = st.preprocessing.bandpass_filter(recording_processed)
if apply_cmr:
    recording_processed = st.preprocessing.common_reference(recording_processed)

In [None]:
# Per-channel spiking activity (rates and amplitudes) over the whole processed
# recording, using threshold detection with detect_threshold=5.
rates, amps = st.postprocessing.compute_channel_spiking_activity(
    recording=recording_processed,
    start_frame=0,
    end_frame=recording_processed.get_num_frames(),
    detect_threshold=5,
    recompute_info=True,
    n_jobs=n_jobs,
    chunk_mb=chunk_mb,
    verbose=verbose,
)

In [None]:
# Two stacked panels: spike-rate map (top) and amplitude map (bottom), each with a
# horizontal colorbar.
fig, axs = plt.subplots(nrows=2)
for ax, activity in zip(axs, ("rate", "amplitude")):
    sw.plot_activity_map(
        recording_processed,
        activity=activity,
        colorbar=True,
        ax=ax,
        transpose=True,
        colorbar_orientation="horizontal",
    )

## 3) Run spike sorters

In [None]:
# Spike sorters to run. IronClust is currently disabled; to enable it, add
# "ironclust" to this list and point spikesorters at a local checkout, e.g.
# ss.IronClustSorter.set_ironclust_path("D:/GitHub/ironclust")
sorter_list = ["herdingspikes", "tridesclous"]

In [None]:
# Inspect sorter-specific parameters and defaults: for each sorter in sorter_list,
# print the parameter descriptions followed by the default values, so overrides can
# be chosen for sorter_params below.
for sorter in sorter_list:
    print(f"{sorter} params description:")
    pprint(ss.get_params_description(sorter))
    print("Default params:")
    pprint(ss.get_default_params(sorter))    

In [None]:
# User-specific overrides of each sorter's defaults, keyed by sorter name.
# HerdingSpikes' own filter is turned off (filter=False); presumably because the
# recording may already be band-pass filtered in preprocessing. TODO confirm.
sorter_params = {
    "herdingspikes": {"filter": False},
    "tridesclous": {"n_jobs_bin": n_jobs, "chunk_mb": chunk_mb},
    # "ironclust": {"n_jobs_bin": n_jobs, "chunk_mb": chunk_mb},
}

In [None]:
# Run every sorter in sorter_list on the processed recording.
# mode="keep" reuses results already present in `working_folder` instead of
# re-sorting; use mode="overwrite" to force a fresh run.
# With raise_error=False, an error in one sorter is not raised to the notebook.
sorting_outputs = ss.run_sorters(
    sorter_list=sorter_list,
    recording_dict_or_list=dict(rec0=recording_processed),
    working_folder=spikeinterface_folder / "working1",
    mode="keep",
    sorter_params=sorter_params,
    verbose=verbose,
    run_sorter_kwargs=dict(raise_error=False),
)

In [None]:
for result_name, sorting in sorting_outputs.items():
    rec_name, sorter = result_name
    print(f"{sorter} found {len(sorting.get_unit_ids())} units")

    # Some sorters (tridesclous in particular) can return clusters without spikes;
    # keep only units whose spike train is non-empty.
    active_units = [
        unit_id
        for unit_id in sorting.get_unit_ids()
        if len(sorting.get_unit_spike_train(unit_id)) > 0
    ]

    # No empty units: leave this sorter's output untouched.
    if len(active_units) >= len(sorting.get_unit_ids()):
        continue

    # Replacing the value of an existing key is safe while iterating over .items().
    sorting_outputs[result_name] = se.SubSortingExtractor(sorting, unit_ids=active_units)
    print(f"{sorter} found {len(active_units)} units after removing empty")

## 4) Post-processing: extract waveforms, templates, quality metrics, extracellular features

### Set quality metric list
#### Reference: https://allensdk.readthedocs.io/en/latest/_static/examples/nb/ecephys_quality_metrics.html

In [None]:
# List every quality metric spiketoolkit can compute; qm_list selects a subset.
print(f"Available quality metrics: {st.validation.get_quality_metrics_list()}")

In [None]:
# Quality metrics to compute for every sorter (a subset of the list printed above).
qm_list = ["snr", "isi_violation", "firing_rate"]

### Set extracellular features
#### Reference: https://github.com/AllenInstitute/ecephys_spike_sorting/tree/master/ecephys_spike_sorting/modules/mean_waveforms

In [None]:
# List every extracellular template feature spiketoolkit can compute; ecf_list
# selects a subset. (This lists template features, not quality metrics.)
print(f"Available template features: {st.postprocessing.get_template_features_list()}")

In [None]:
# Extracellular template features to compute (a subset of the list printed above).
ecf_list = ["peak_to_valley", "halfwidth"]

### Set postprocessing parameters

In [None]:
# Start from spiketoolkit's default postprocessing parameters and show them.
# The dict is pretty-printed directly: pprint of an f-string would only print the
# whole dict as a single quoted string.
postprocessing_params = st.postprocessing.get_postprocessing_params()
print("Default parameters:")
pprint(postprocessing_params)

In [None]:
# Override the defaults: cap waveform extraction at 1000 spikes per unit (with None,
# all waveforms are extracted) and forward the global parallelization, verbosity and
# memmap settings.
postprocessing_params.update(
    max_spikes_per_unit=1000,
    n_jobs=n_jobs,
    chunk_mb=chunk_mb,
    verbose=verbose,
    recompute_info=True,
    memmap=use_memmap,
)

# Exporting to phy can be expensive; enable only when manual curation is planned.
export_to_phy = False

### Postprocess all sorting outputs

In [None]:
# Per-sorting result tables keyed like sorting_outputs ((rec_name, sorter) tuples), so
# the metrics of EVERY sorter remain available after the loop. The plain `qm` / `ecf`
# names are still assigned and end up holding the last sorting's tables.
qm_by_sorter = {}
ecf_by_sorter = {}

for result_name, sorting in sorting_outputs.items():
    rec_name, sorter = result_name
    print(f"Postprocessing recording {rec_name} sorted with {sorter}")
    tmp_folder = spikeinterface_folder / 'tmp' / sorter
    tmp_folder.mkdir(parents=True, exist_ok=True)

    # Keep this sorting's temporary files inside the project folder.
    sorting.set_tmp_folder(tmp_folder)

    # Compute waveforms and templates. The return values are not used below;
    # presumably these calls are kept for their side effects on `sorting`. TODO confirm.
    waveforms = st.postprocessing.get_unit_waveforms(recording_processed, sorting, **postprocessing_params)
    templates = st.postprocessing.get_unit_templates(recording_processed, sorting, **postprocessing_params)

    # Quality metrics selected in qm_list, as a DataFrame.
    qm = st.validation.compute_quality_metrics(
        sorting,
        recording=recording_processed,
        metric_names=qm_list,
        as_dataframe=True,
    )

    # Extracellular template features selected in ecf_list, as a DataFrame.
    ecf = st.postprocessing.compute_unit_template_features(
        recording_processed,
        sorting,
        feature_names=ecf_list,
        as_dataframe=True,
    )

    qm_by_sorter[result_name] = qm
    ecf_by_sorter[result_name] = ecf

    # Export to phy (under spikeinterface_folder/phy/<sorter>) - can be expensive,
    # disable if not used.
    if export_to_phy:
        phy_folder = spikeinterface_folder / 'phy' / sorter
        phy_folder.mkdir(parents=True, exist_ok=True)
        st.postprocessing.export_to_phy(
            recording_processed,
            sorting,
            phy_folder,
            compute_pc_features=False,
            compute_amplitudes=False,
            save_property_or_features=False,
        )

In [None]:
# `qm` is reassigned on every iteration of the postprocessing loop, so this shows the
# quality metrics of the LAST sorting processed there.
display(qm)

In [None]:
# `ecf` is reassigned on every iteration of the postprocessing loop, so this shows the
# template features of the LAST sorting processed there.
display(ecf)

### Visualize templates

In [None]:
# Plot the template of one HerdingSpikes unit. The first available unit id is used
# instead of a hard-coded 0, because units with empty spike trains may have been
# removed (via SubSortingExtractor) right after sorting, so id 0 is not guaranteed
# to exist. (`sorting_ic` is a historical name; it holds the HerdingSpikes output.)
sorting_ic = sorting_outputs[("rec0", "herdingspikes")]
w = sw.plot_unit_templates(
    recording_processed,
    sorting_ic,
    unit_ids=list(sorting_ic.get_unit_ids())[:1],
    radius=100,  # spatial distance specifying channels around the unit - default is None
    lw=0.5,
)
# w.figure.set_size_inches((5, 15))  # handy if radius=None above, to give a better view of the probe layout

### Run phy and load curated data

####  Reference: https://phy.readthedocs.io/en/latest/

In [None]:
!phy template-gui spikeinterface/phy/ironclust/params.py

In [None]:
# Load the phy-curated sorting, discarding clusters labelled "noise". The folder is set
# explicitly: the loop-defined `phy_folder` only exists when export_to_phy is True, and
# then points at whichever sorter was processed last.
phy_folder = spikeinterface_folder / "phy" / "herdingspikes"
sorting_manual_curated = se.PhySortingExtractor(phy_folder, exclude_cluster_groups=['noise'])

In [None]:
# phy folders are laid out as .../phy/<sorter>, so the folder name is the sorter that
# was curated (IronClust is disabled, so it must not be hard-coded here).
print(f"{phy_folder.name} found {len(sorting_manual_curated.get_unit_ids())} units after manual curation")

## 5) Ensemble spike sorting

#### If using more than one spike sorter, this automated method has been shown to give similar results to manual curation under certain conditions: https://elifesciences.org/articles/61834

In [None]:
# Ensemble sorting only makes sense with at least two sorter outputs.
if len(sorting_outputs) > 1:
    # Collect the sortings and their sorter names in matching order.
    sorting_list = list(sorting_outputs.values())
    sorter_names_comp = [sorter for _rec_name, sorter in sorting_outputs.keys()]

    # Match units across all sorters.
    mcmp = sc.compare_multiple_sorters(sorting_list=sorting_list, name_list=sorter_names_comp)

    # Plot agreement results.
    w_agr = sw.plot_multicomp_agreement(mcmp)

    # Extract the ensemble sorting: units agreed on by at least two sorters.
    sorting_ensemble = mcmp.get_agreement_sorting(minimum_agreement_count=2)

    # Report the sorters that were actually compared; with raise_error=False a failed
    # sorter may be absent from sorting_outputs even though it is listed in sorter_list.
    print(f"Ensemble sorting among {sorter_names_comp} found: {len(sorting_ensemble.get_unit_ids())} units")
    sw.plot_rasters(sorting_ensemble)

## 6) Automatic curation

In [None]:
# Auto-curation thresholds used by the next cell: firing rate and SNR are thresholded
# with threshold_sign='less', the ISI-violation score with threshold_sign='greater'.
firing_rate_threshold, isi_violation_threshold, snr_threshold = 0.1, 0.5, 3

In [None]:
# Apply the three threshold curations in sequence to every sorter's output, printing
# the surviving unit count after each step. Results are appended to
# sorting_auto_curated in the same order as sorter_names_curation.
sorting_auto_curated = []
sorter_names_curation = []
for result_name, sorting in sorting_outputs.items():
    rec_name, sorter = result_name
    sorter_names_curation.append(sorter)
    n_frames = recording_processed.get_num_frames()

    # Step 1: firing-rate threshold (threshold_sign='less').
    sorting_curated = st.curation.threshold_firing_rates(
        sorting, duration_in_frames=n_frames, threshold=firing_rate_threshold, threshold_sign='less'
    )
    print(f"{sorter} found {len(sorting_curated.get_unit_ids())} units after thresholding firing rates")

    # Step 2: ISI-violation threshold (threshold_sign='greater').
    sorting_curated = st.curation.threshold_isi_violations(
        sorting_curated, duration_in_frames=n_frames, threshold=isi_violation_threshold, threshold_sign='greater'
    )
    print(f"{sorter} found {len(sorting_curated.get_unit_ids())} units after thresholding isi violations")

    # Step 3: SNR threshold (threshold_sign='less'), computed on the processed recording.
    sorting_curated = st.curation.threshold_snrs(
        sorting_curated, recording=recording_processed, threshold=snr_threshold, threshold_sign='less'
    )
    sorting_auto_curated.append(sorting_curated)
    print(f"{sorter} found {len(sorting_curated.get_unit_ids())} units after thresholding snr")

## 7) Save to NWB (writes only the spike-sorting results)

In [None]:
# The NWB file that may already contain behavioral and/or full recording data; the
# sorting is written into it. The session name is that of the recording actually
# sorted (the SpikeGadgets file passed to make_extractor); no `session_name` variable
# exists in this notebook.
nwbfile_path = base_path.parent / f"Brody_PoissonClicks_{spikegadgets_session_name}_full_autocurated.nwb"

# Choose the sorting extractor to write. `sorting_curated` is the auto-curated output
# of the LAST sorter in the curation loop; pick an entry of `sorting_auto_curated`
# (ordered like `sorter_names_curation`) or of `sorting_outputs` for a specific sorter.
# chosen_sorting_extractor = sorting_outputs[("rec0", "tridesclous")]
chosen_sorting_extractor = sorting_curated

se.NwbSortingExtractor.write_sorting(
    sorting=chosen_sorting_extractor,
    save_path=nwbfile_path,
    # False appends to an existing file, preserving its behavioral/recording data;
    # True would replace it with a new file containing only the spikes.
    overwrite=False,
    skip_features=["waveforms"],
)