# SpikeInterface pipeline for Feldman Lab

In [None]:
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
from pprint import pprint

import spikeextractors as se
import spiketoolkit as st
import spikesorters as ss
import spikecomparison as sc
import spikewidgets as sw

%matplotlib notebook

## 1a) Load AP recordings, LF recordings and TTL signals

In [None]:
# Legacy data paths from an earlier dataset, kept for reference only.
# Every line in this cell is commented out, so nothing here is executed;
# in particular `base_data_path` is NOT defined by this cell.
# base_path = Path("D:/Feldman")
# #base_path = Path("/Users/abuccino/Documents/Data/catalyst/brody/A256_bank1_2020_09_30_g0")
# #base_data_path = Path("D:/Neuropixels/Neuropixels/A256_bank1_2020_09_30/A256_bank1_2020_09_30_g0")
# base_data_path = Path("20210115_NPX_and_behavior/2021_01_15_E105/towersTask_g0")
# ap_bin_path = base_data_path / "towersTask_g0_imec0" / "towersTask_g0_t0.imec0.ap.bin"
# lf_bin_path = base_data_path / "towersTask_g0_imec0" / "towersTask_g0_t0.imec0.lf.bin"

In [None]:
# Root folder that holds one sub-folder per SpikeGLX session.
base_path = Path("/Users/abuccino/Documents/Data/catalyst/feldman/")

# Session to process; alternatives are kept commented out for quick switching.
session_name = "LR_210209_2_g1"
# session_name = "LR_210209_g1"
# session_name = "LR_210209_2_g0"

# Probe streams live in <session>/<session>_imec0/<session>_t0.imec0.<stream>.bin,
# the nidq stream directly in <session>/<session>_t0.nidq.bin.
ap_bin_path, lf_bin_path = (
    base_path / session_name / f"{session_name}_imec0" / f"{session_name}_t0.imec0.{stream}.bin"
    for stream in ("ap", "lf")
)
nidq_bin_path = base_path / session_name / f"{session_name}_t0.nidq.bin"

### Create the SpikeInterface output folder

In [None]:
# Every SpikeInterface output of this notebook goes under ./spikeinterface
# (relative to the notebook's working directory).
recording_folder = Path('.')
spikeinterface_folder = recording_folder.joinpath("spikeinterface")
spikeinterface_folder.mkdir(exist_ok=True, parents=True)

### (Optional) Stub the recordings for fast testing; set `stub_test = False` to run the processing pipeline on the entire dataset

In [None]:
stub_test = False  # True -> the next cell truncates the AP/LF recordings to the first `nsec_stub` seconds
nsec_stub = 5  # stub length in seconds (multiplied by the sampling frequency to get frames)

In [None]:
# Open the AP and LF streams of the imec0 probe.
recording_ap = se.SpikeGLXRecordingExtractor(ap_bin_path)
recording_lf = se.SpikeGLXRecordingExtractor(lf_bin_path)


def _truncate_to_stub(rec):
    """Return a view of `rec` limited to its first `nsec_stub` seconds."""
    stub_end_frame = int(nsec_stub * rec.get_sampling_frequency())
    return se.SubRecordingExtractor(rec, end_frame=stub_end_frame)


# For quick test runs, keep only the beginning of each stream.
if stub_test:
    recording_ap = _truncate_to_stub(recording_ap)
    recording_lf = _truncate_to_stub(recording_lf)

In [None]:
# Report the sampling frequency of each stream (they are used to convert times to frames below).
for stream_label, stream_recording in (("AP", recording_ap), ("LF", recording_lf)):
    print(f"Sampling frequency {stream_label}: {stream_recording.get_sampling_frequency()}")

## Load TTL signals
### TODO: not finalized yet; a general sketch of common methods is shown below

### (Option 1): Use TTLs from the ap.bin file

In [None]:
# TTL events recorded on channel 6 of the AP file. `ttl` holds frame indices
# (the next cell converts them with frame_to_time); entries whose state is 1
# are treated as rising edges.
ttl, states = recording_ap.get_ttl_events(channel_id=6)
is_rising_edge = states == 1
rising_times = ttl[is_rising_edge]
n_rising_edges = len(rising_times)
print(f"Number of TTL events in ap file: {n_rising_edges}")

In [None]:
# Align both streams to the first TTL rising edge found in the AP file.
# Both start frames default to 0 (no trimming) so the synchronization cell
# still runs when the AP file contains no TTL events.
start_frame_ap = 0
start_frame_lf = 0
if len(rising_times) > 0:
    # Convert the AP frame index to seconds, then to a frame index of each
    # stream (the streams may be sampled at different rates).
    start_time = recording_ap.frame_to_time(rising_times[0])
    start_frame_ap = int(recording_ap.time_to_frame(start_time))
    start_frame_lf = int(recording_lf.time_to_frame(start_time))
    print(f"Start frame AP: {start_frame_ap}")
    print(f"Start frame LF: {start_frame_lf}")
else:
    print("No TTL events found in ap file; using start frame 0 for AP and LF.")

### (Option 2): Use TTLs from the nidq.bin file

In [None]:
# Open the nidq.bin stream, the TTL source for Option 2.
recording_nidq = se.SpikeGLXRecordingExtractor(nidq_bin_path)

In [None]:
# Display the nidq sampling frequency as the cell output.
recording_nidq.get_sampling_frequency()

In [None]:
# Plot the first 120 s of the nidq channels, e.g. to identify which channel carries the TTL signal.
sw.plot_timeseries(recording_nidq, trange=[0, 120])

### Synchronize recordings (trim the AP and LF streams to start at the first TTL)

In [None]:
# Drop everything before the TTL-derived start frame of each stream so that
# the AP and LF recordings begin at the same time point.
recording_ap_sync, recording_lf_sync = (
    se.SubRecordingExtractor(stream_recording, start_frame=stream_start)
    for stream_recording, stream_start in (
        (recording_ap, start_frame_ap),
        (recording_lf, start_frame_lf),
    )
)

### Inspect signals

In [None]:
# Every 4th AP channel over the first 5 s of the (unsynchronized) AP recording.
ap_channels_to_show = recording_ap.get_channel_ids()[::4]
w_ts_ap = sw.plot_timeseries(recording_ap, channel_ids=ap_channels_to_show, trange=[0, 5])

In [None]:
# Every 10th LF channel, using the widget's default time range.
lf_channels_to_show = recording_lf.get_channel_ids()[::10]
w_ts_lf = sw.plot_timeseries(recording_lf, channel_ids=lf_channels_to_show)

## 2) Pre-processing

In [None]:
apply_cmr = True  # True -> apply st.preprocessing.common_reference to the synchronized AP stream (next cell)

In [None]:
# Optionally re-reference the synchronized AP stream (common_reference with its
# default settings); otherwise process the synchronized stream as-is.
recording_processed = (
    st.preprocessing.common_reference(recording_ap_sync) if apply_cmr else recording_ap_sync
)

In [None]:
# Length of the processed recording in frames; passed as `duration_in_frames` during auto curation.
num_frames = recording_processed.get_num_frames()

In [None]:
# Per-channel spiking activity (rates and amplitudes) estimated on a 10-s window
# starting 10 s into the processed recording; with recompute_info=True this is
# presumably what the activity maps in the next cell display (confirm).
# The window is derived from the actual sampling frequency rather than a
# hard-coded 30000 Hz, and falls back to the whole recording when it is shorter
# than 20 s (e.g. stub_test=True with nsec_stub=5).
activity_fs = recording_processed.get_sampling_frequency()
activity_start_frame = int(10 * activity_fs)
activity_end_frame = int(20 * activity_fs)
if activity_end_frame > recording_processed.get_num_frames():
    activity_start_frame = 0
    activity_end_frame = recording_processed.get_num_frames()

rates, amps = st.postprocessing.compute_channel_spiking_activity(
    recording_processed,
    n_jobs=16,
    chunk_mb=4000,
    start_frame=activity_start_frame,
    end_frame=activity_end_frame,
    detect_threshold=8,
    recompute_info=True,
    verbose=True,
)

In [None]:
# Two probe maps side by side: spike rate on the left, amplitude on the right.
fig, axs = plt.subplots(ncols=2)
for activity_ax, activity_kind in zip(axs, ("rate", "amplitude")):
    sw.plot_activity_map(recording_processed, activity=activity_kind, colorbar=True, ax=activity_ax)

## 3) Run spike sorters

In [None]:
# Sorters to run; uncomment entries to enable them. Note that the phy export and
# phy-curation cells below only act on 'kilosort2' output.
sorter_list = [
    'tridesclous',
    #'spykingcircus',
    'herdingspikes',
    #'kilosort2',
]

In [None]:
# Inspect sorter-specific parameters and defaults
# Show, for every selected sorter, what each parameter means and its default value.
for sorter_name in sorter_list:
    print(f"{sorter_name} params description:")
    pprint(ss.get_params_description(sorter_name))
    print("Default params:")
    pprint(ss.get_default_params(sorter_name))

In [None]:
# user-specific parameters
# user-specific parameters, keyed by sorter name.
# NOTE(review): the 'spykingcircus' entry is presumably only used when
# 'spykingcircus' is enabled in `sorter_list` — confirm how run_sorters treats extra keys.
sorter_params = dict(
    #kilosort2=dict(car=False, n_jobs_bin=12, chunk_mb=4000),
    #ironclust=dict(filter=True),
    tridesclous=dict(n_jobs_bin=12, chunk_mb=4000),
    spykingcircus=dict(filter=True, num_workers=16),
    herdingspikes=dict(filter=True)
)

In [None]:
# Run every sorter in `sorter_list` on the processed recording. Results are
# returned keyed by (recording name, sorter name), e.g. ('rec0', 'tridesclous').
# mode="keep" reuses outputs already present in the working folder instead of
# repeating the spike sorting; use mode="overwrite" to force a fresh run.
# raise_error=False presumably lets the remaining sorters continue if one fails.
sorting_outputs = ss.run_sorters(
    sorter_list=sorter_list,
    recording_dict_or_list=dict(rec0=recording_processed),
    working_folder=spikeinterface_folder / "working1",
    mode="keep",
    sorter_params=sorter_params,
    verbose=True,
    run_sorter_kwargs=dict(raise_error=False)
)

In [None]:
# Report unit counts per sorter and replace each sorting that contains units
# without any spike by a sub-sorting restricted to the non-empty units.
for result_name, sorting in list(sorting_outputs.items()):
    rec_name, sorter = result_name
    all_unit_ids = sorting.get_unit_ids()
    print(f"{sorter} found {len(all_unit_ids)} units")

    # tridesclous sometimes has empty clusters: keep units with at least one spike
    active_units = [
        unit_id for unit_id in all_unit_ids
        if len(sorting.get_unit_spike_train(unit_id)) > 0
    ]

    if len(active_units) < len(all_unit_ids):
        sorting_outputs[result_name] = se.SubSortingExtractor(sorting, unit_ids=active_units)
        print(f"{sorter} found {len(active_units)} units after removing empty")

In [None]:
# Spot-check a couple of tridesclous templates. The preferred example units
# (11 and 21) do not exist in every session (or may have been removed as empty
# clusters), so plot those that exist and otherwise fall back to the first two units.
sorting_tdc = sorting_outputs[('rec0', 'tridesclous')]
tdc_unit_ids = list(sorting_tdc.get_unit_ids())
example_unit_ids = [u for u in (11, 21) if u in tdc_unit_ids] or tdc_unit_ids[:2]
sw.plot_unit_templates(recording_processed, sorting_tdc, unit_ids=example_unit_ids, radius=None)

## 4) Post-processing: extract waveforms, templates, quality metrics, extracellular features

### Set postprocessing parameters

In [None]:
# Post-processing params
postprocessing_params = st.postprocessing.get_postprocessing_params()
pprint(postprocessing_params)

In [None]:
# (optional) change parameters
postprocessing_params['max_spikes_per_unit'] = 1000  # with None, all waveforms are extracted
postprocessing_params['n_jobs'] = 16  # n jobs
postprocessing_params['chunk_mb'] = 4000  # max RAM usage in Mb
postprocessing_params['verbose'] = True  # max RAM usage in Mb

### Set quality metric list

In [None]:
# Quality metrics
qc_list = st.validation.get_quality_metrics_list()
print(f"Available quality metrics: {qc_list}")

In [None]:
# (optional) define subset of qc
# Metrics computed in the post-processing loop; names must appear in the list printed above.
qc_list = ['snr', 'isi_violation', 'firing_rate']

### Set extracellular features

In [None]:
# Extracellular features
ec_list = st.postprocessing.get_template_features_list()
print(f"Available EC features: {ec_list}")

In [None]:
# (optional) define subset of ec
# Template features computed in the post-processing loop; names must appear in the list printed above.
ec_list = ['peak_to_valley', 'halfwidth']

### Postprocess all sorting outputs

In [None]:
# Post-process every sorting output: waveforms, templates, extracellular (EC)
# template features and quality metrics, plus a phy export for kilosort2.
# `ec` / `qc` are reassigned on every iteration, so the per-sorter dataframes
# are also collected in `ec_dict` / `qc_dict`, keyed like `sorting_outputs`.
ec_dict = {}
qc_dict = {}
for result_name, sorting in sorting_outputs.items():
    rec_name, sorter = result_name
    print(f"Postprocessing recording {rec_name} sorted with {sorter}")
    tmp_folder = spikeinterface_folder / 'tmp' / sorter
    tmp_folder.mkdir(parents=True, exist_ok=True)

    # set a per-sorter local tmp folder
    sorting.set_tmp_folder(tmp_folder)

    # compute waveforms
    waveforms = st.postprocessing.get_unit_waveforms(recording_processed, sorting, **postprocessing_params)

    # compute templates
    templates = st.postprocessing.get_unit_templates(recording_processed, sorting, **postprocessing_params)

    # compute EC features
    ec = st.postprocessing.compute_unit_template_features(
        recording_processed,
        sorting,
        feature_names=ec_list,
        as_dataframe=True
    )
    ec_dict[result_name] = ec

    # compute QCs
    qc = st.validation.compute_quality_metrics(
        sorting,
        recording=recording_processed,
        metric_names=qc_list,
        as_dataframe=True
    )
    qc_dict[result_name] = qc

    # export to phy (only kilosort2 output is exported)
    if sorter == "kilosort2":
        phy_folder = spikeinterface_folder / 'phy' / sorter
        phy_folder.mkdir(parents=True, exist_ok=True)
        st.postprocessing.export_to_phy(recording_processed, sorting, phy_folder)

### Run phy and load curated data (requires `'kilosort2'` in `sorter_list`, since only its output is exported to phy)

In [None]:
# IPython shell escape: launch the phy GUI on the kilosort2 export written to
# spikeinterface/phy/kilosort2 by the post-processing loop.
!phy template-gui spikeinterface/phy/kilosort2/params.py

In [None]:
# Load the manually curated kilosort2 sorting from phy, excluding clusters
# labeled 'noise'. The folder is defined here with the same layout as the phy
# export in the post-processing loop, because `phy_folder` is only set there
# when kilosort2 was actually run.
phy_folder = spikeinterface_folder / 'phy' / 'kilosort2'
sorting_manual_curated = se.PhySortingExtractor(phy_folder, exclude_cluster_groups=['noise'])

In [None]:
# Number of kilosort2 units left after manual curation in phy (noise clusters excluded).
print(f"Kilosort2 found {len(sorting_manual_curated.get_unit_ids())} units after manual curation")

## 5) Ensemble spike sorting

In [None]:
# Ensemble sorting: only possible when at least two sorting outputs exist.
if len(sorting_outputs) > 1:
    # retrieve sortings and sorter names (in matching order)
    sorting_list = []
    sorter_names_comp = []
    for result_name, sorting in sorting_outputs.items():
        rec_name, sorter = result_name
        sorting_list.append(sorting)
        sorter_names_comp.append(sorter)

    # run multi-sorter comparison
    mcmp = sc.compare_multiple_sorters(sorting_list=sorting_list, name_list=sorter_names_comp)

    # plot agreement results
    w_agr = sw.plot_multicomp_agreement(mcmp)

    # extract ensemble sorting (minimum_agreement_count=2: agreement of at least 2 sorters)
    sorting_ensemble = mcmp.get_agreement_sorting(minimum_agreement_count=2)

    # report the sorters actually compared, which may differ from `sorter_list`
    # if a sorter produced no output
    print(f"Ensemble sorting among {sorter_names_comp} found: {len(sorting_ensemble.get_unit_ids())} units")

In [None]:
# `sorting_ensemble` is only created when there is more than one sorting output,
# so guard the raster plot with the same condition to avoid a NameError.
if len(sorting_outputs) > 1:
    sw.plot_rasters(sorting_ensemble)

## 6) Automatic curation

In [None]:
# define curators and thresholds
isi_violation_threshold = 0.5
snr_threshold = 5
firing_rate_threshold = 0.1

In [None]:
# Automatic curation: apply the three thresholds in sequence to every sorting
# output. Per spiketoolkit's curation API, each step removes the units whose
# metric compares to the threshold as given by `threshold_sign`, i.e.
# firing rate < firing_rate_threshold, ISI violations > isi_violation_threshold,
# SNR < snr_threshold.
sorting_auto_curated = []
sorter_names_curation = []
for result_name, sorting in sorting_outputs.items():
    rec_name, sorter = result_name
    sorter_names_curation.append(sorter)

    # firing rate threshold (duration = length of the processed recording)
    sorting_curated = st.curation.threshold_firing_rates(
        sorting,
        duration_in_frames=num_frames,
        threshold=firing_rate_threshold,
        threshold_sign='less'
    )

    # isi violation threshold
    sorting_curated = st.curation.threshold_isi_violations(
        sorting_curated,
        duration_in_frames=num_frames,
        threshold=isi_violation_threshold,
        threshold_sign='greater'
    )

    # snr threshold (needs the recording to compute SNRs)
    sorting_curated = st.curation.threshold_snrs(
        sorting_curated,
        recording=recording_processed,
        threshold=snr_threshold,
        threshold_sign='less'
    )
    sorting_auto_curated.append(sorting_curated)
    print(f"{sorter} found {len(sorting_curated.get_unit_ids())} units after auto curation")

## 7) Save to NWB (writes only the spike trains)

In [None]:
# The name of the NWBFile containing behavioral or full recording data
nwbfile_path = base_data_path / f"Feldman_{session_name}.nwb"

# Choose the sorting extractor from the notebook environment you would like to write to NWB
chosen_sorting_extractor = sorting_outputs[('rec0', 'ironclust')]

se.NwbSortingExtractor.write_sorting(
    sorting=chosen_sorting_extractor,
    save_path=nwbfile_path,
    overwrite=False  # this appends the file. True would write a new file
)