In [None]:
# Automatically reload edited modules (e.g. mease_lab_to_nwb) before each cell is executed.
%load_ext autoreload
%autoreload 2
# Interactive matplotlib figures. NOTE(review): the `notebook` backend only works in the classic
# Notebook UI; JupyterLab needs `%matplotlib widget` (ipympl) instead.
%matplotlib notebook

# SpikeInterface pipeline for Mease Lab - CED

In [None]:
import os
from datetime import datetime
from pathlib import Path
from pprint import pprint

import matplotlib.pyplot as plt
import numpy as np

import spikeextractors as se
import spiketoolkit as st
import spikesorters as ss
import spikecomparison as sc
import spikewidgets as sw
from nwb_conversion_tools.conversion_tools import save_si_object

from mease_lab_to_nwb.convert_ced.cednwbconverter import CEDNWBConverter, quick_write

## 1) Load CED recording, set channel locations, load LFP, and inspect signals

In [None]:
# Path to the CED Spike2 (.smrx) recording. To use a different file, set the CED_FILE
# environment variable instead of editing (and committing) a machine-specific path.
ced_file = Path(os.environ.get("CED_FILE", "D:/CED_example_data/Other example/m365_pt1_590-1190secs-001.smrx"))

# Sorter runs, phy exports and saved SpikeInterface objects are written under this folder.
spikeinterface_folder = Path('spikeinterface')
spikeinterface_folder.mkdir(parents=True, exist_ok=True)

In [None]:
# Automatically select the SMRX channels whose title contains "Rhd"; these form the
# recording that is pre-processed and spike sorted below.
channel_info = se.CEDRecordingExtractor.get_all_channels_info(ced_file)
rhd_channels = [channel_id for channel_id, ch_meta in channel_info.items() if "Rhd" in ch_meta["title"]]

In [None]:
recording = se.CEDRecordingExtractor(ced_file, smrx_channel_ids=rhd_channels)

# Probe map: every site lies on a single column (x = 0), spaced 20 apart vertically
# (presumably micrometres).
prb_channels = [
    42, 36, 32, 38, 39, 47, 45, 37, 44, 46, 48, 35, 34, 40, 41, 43, 62, 60, 58, 56, 54, 52, 50, 33, 49, 51, 53,
    55, 57, 59, 61, 63, 1, 3, 5, 7, 9, 11, 13, 30, 14, 12, 10, 8, 6, 4, 2, 0, 21, 27, 31, 25, 24, 16, 18, 26,
    19, 17, 15, 28, 29, 23, 22, 20
]
site_pitch = 20
prb_locations = [[0, site_pitch * site_idx] for site_idx in range(len(prb_channels))]

# The j-th recording channel is placed at prb_locations[prb_channels[j]].
# NOTE(review): if prb_channels instead lists the channels in depth order, the inverse
# permutation (np.argsort(prb_channels)) is needed here -- confirm against the probe file.
site_locations = np.array(prb_locations)[prb_channels, :]
recording.set_channel_locations(locations=site_locations)

### Stub recording for fast testing; set `stub_test` to `False` to run the processing pipeline on the entire recording

In [None]:
# When stub_test is True, only the first `nsec_stub` seconds of each recording are processed.
stub_test, nsec_stub = True, 5

In [None]:
# Basic properties of the Rhd recording.
sampling_rate = recording.get_sampling_frequency()
print(f"Num channels: {recording.get_num_channels()}")
print(f"Sampling rate: {sampling_rate}")
print(f"Duration (s): {recording.get_num_frames() / sampling_rate}")

### Load LFP

In [None]:
# LFP channels are the SMRX channels whose title contains "LFP".
lfp_channels = [channel_id for channel_id, ch_meta in channel_info.items() if "LFP" in ch_meta["title"]]

In [None]:
# Separate extractor for the LFP channels; its sampling rate is compared with the Rhd recording in the next cell.
recording_lfp = se.CEDRecordingExtractor(ced_file, smrx_channel_ids=lfp_channels)

In [None]:
# Sampling rates of the spike-sorting (AP) and LFP extractors, side by side.
for band_label, band_recording in (("AP", recording), ("LF", recording_lfp)):
    print(f"Sampling frequency {band_label}: {band_recording.get_sampling_frequency()}")

### (Optional) Resample LFP

In [None]:
# Resample the LFP extractor to `lfp_resample_rate`. Later cells keep using the name `recording_lfp`,
# so re-running this cell wraps the already-resampled extractor again.
lfp_resample_rate = 1000
recording_lfp = st.preprocessing.resample(recording_lfp, resample_rate=lfp_resample_rate)

### Inspect signals

In [None]:
# Quick look at the raw Rhd-channel recording (default time range of plot_timeseries).
w_ts_ap = sw.plot_timeseries(recording)

In [None]:
# LFP traces for the window 30-40 (trange is presumably in seconds -- confirm with spikewidgets docs).
w_ts_lf = sw.plot_timeseries(recording_lfp, trange=[30, 40])

## 2) Pre-processing

In [None]:
# Pre-processing switches
apply_filter = False  # the CED data appear to be already filtered
apply_cmr = True      # apply st.preprocessing.common_reference to the recording

# Band-pass corner frequencies, only used when apply_filter is True
freq_min_hp, freq_max_hp = 300, 3000

In [None]:
# Optional band-pass filtering followed by optional common referencing.
recording_processed = (
    st.preprocessing.bandpass_filter(recording, freq_min=freq_min_hp, freq_max=freq_max_hp)
    if apply_filter
    else recording
)
if apply_cmr:
    recording_processed = st.preprocessing.common_reference(recording_processed)

In [None]:
def _first_seconds(rec, seconds):
    """Return a sub-recording of `rec` that ends after `seconds` of data."""
    return se.SubRecordingExtractor(rec, end_frame=int(seconds * rec.get_sampling_frequency()))


if stub_test:
    recording_processed = _first_seconds(recording_processed, nsec_stub)
    recording_lfp = _first_seconds(recording_lfp, nsec_stub)

In [None]:
# Number of frames that are spike sorted (reduced when stub_test is True); also used as the
# recording duration by the automatic-curation thresholds.
num_frames = recording_processed.get_num_frames()
print(num_frames)

In [None]:
# Inspect the pre-processed (and possibly stubbed) recording that is passed to the sorters.
w_ts_ap = sw.plot_timeseries(recording_processed)

## 3) Run spike sorters

In [None]:
# Sorters to run; other installed sorters (e.g. "klusta") can be appended to this list.
sorter_list = ["ironclust", "waveclus"]

In [None]:
# Register the local sorter installations first, so the installation check below reflects them.
# NOTE(review): machine-specific paths -- adjust them to your setup.
ss.IronClustSorter.set_ironclust_path("D:/GitHub/ironclust")
ss.WaveClusSorter.set_waveclus_path("D:/GitHub/wave_clus")

# Last expression, so the list of installed sorters is displayed as the cell output.
ss.installed_sorters()

In [None]:
# For each selected sorter, print what every parameter means, then its default value.
for sorter_name in sorter_list:
    sections = (
        (f"\n\n{sorter_name} params description:", ss.get_params_description),
        ("Default params:", ss.get_default_params),
    )
    for heading, fetch in sections:
        print(heading)
        pprint(fetch(sorter_name))

In [None]:
# User-specific overrides of the defaults printed above, keyed by sorter name, e.g.
#     sorter_params = {"ironclust": {...}, "waveclus": {...}}
# Presumably, sorters without an entry run with their defaults -- see ss.run_sorters docs.
sorter_params = {}

In [None]:
# Documentation of `run_sorters`, the function called in the next cell.
help(ss.run_sorters)

In [None]:
# Run every sorter in `sorter_list` on the processed recording.
# mode="keep" avoids repeating the spike sorting when results already exist in
# `working_folder`; change it to "overwrite" to force the sorters to run again
# (e.g. once after toggling `stub_test`, so stale stub/full results are not reused).
sorting_outputs = ss.run_sorters(
    sorter_list=sorter_list,
    working_folder=spikeinterface_folder / 'ced_si_output',
    recording_dict_or_list=dict(rec0=recording_processed),
    sorter_params=sorter_params,
    mode="keep",
    verbose=True,
)

`sorting_outputs` is a dictionary keyed by `(rec_name, sorter_name)` tuples.

In [None]:
# Number of units each sorter found.
for (rec_name, sorter), sorting in sorting_outputs.items():
    n_units = len(sorting.get_unit_ids())
    print(f"{sorter} found {n_units} units")

## 4) Post-processing: extract waveforms, templates, quality metrics, extracellular features

### Set postprocessing parameters

In [None]:
# Post-processing params: spiketoolkit defaults, passed as **kwargs to the waveform/template extraction below
postprocessing_params = st.postprocessing.get_common_params()
pprint(postprocessing_params)

In [None]:
# (optional) override defaults; a max_spikes_per_unit of None extracts all waveforms
postprocessing_params.update(max_spikes_per_unit=1000)

### Set quality metric list

In [None]:
# All quality metrics spiketoolkit can compute
qc_list = st.validation.get_quality_metrics_list()
print("Available quality metrics:", qc_list)

In [None]:
# (optional) restrict the quality metrics to this subset
qc_list = list(("snr", "isi_violation", "firing_rate"))

### Set extracellular features

In [None]:
# All extracellular (template) features spiketoolkit can compute
ec_list = st.postprocessing.get_template_features_list()
print("Available EC features:", ec_list)

In [None]:
# (optional) restrict the extracellular features to this subset
ec_list = list(("peak_to_valley", "halfwidth"))

### Postprocess all sorting outputs

In [None]:
# Extracellular-feature and quality-metric tables, keyed like `sorting_outputs`
# ((rec_name, sorter) tuples), so the results remain available after the loop.
ec_results = {}
qc_results = {}
for (rec_name, sorter), sorting in sorting_outputs.items():
    print(f"Postprocessing recording {rec_name} sorted with {sorter}")
    tmp_folder = Path('tmp_ced') / sorter
    tmp_folder.mkdir(parents=True, exist_ok=True)

    # set a local tmp folder for this sorting (presumably used for intermediate files)
    sorting.set_tmp_folder(tmp_folder)

    # compute waveforms (the count per unit is controlled by postprocessing_params['max_spikes_per_unit'])
    waveforms = st.postprocessing.get_unit_waveforms(recording_processed, sorting,
                                                     **postprocessing_params)

    # compute templates
    templates = st.postprocessing.get_unit_templates(recording_processed, sorting, **postprocessing_params)

    # compute the extracellular template features listed in ec_list
    ec = st.postprocessing.compute_unit_template_features(recording_processed, sorting,
                                                          feature_names=ec_list, as_dataframe=True)
    # compute the quality metrics listed in qc_list
    qc = st.validation.compute_quality_metrics(sorting, recording=recording_processed,
                                               metric_names=qc_list, as_dataframe=True)
    # keep this sorter's tables instead of overwriting them on the next iteration
    ec_results[(rec_name, sorter)] = ec
    qc_results[(rec_name, sorter)] = qc

    # export to phy
    phy_folder = spikeinterface_folder / 'phy' / sorter
    phy_folder.mkdir(parents=True, exist_ok=True)
    print("Exporting to phy")
    st.postprocessing.export_to_phy(recording_processed, sorting, phy_folder, verbose=True)

### Load Phy-curated data back to SI

In [None]:
!phy template-gui spikeinterface/phy/ironclust/params.py

In [None]:
# Load the manually curated ironclust sorting back into SpikeInterface. It gets its own name
# because the automatic-curation cell below reassigns `sorting_curated`.
phy_folder = spikeinterface_folder / 'phy' / 'ironclust'
sorting_phy_curated = se.PhySortingExtractor(str(phy_folder))

## 5) Ensemble spike sorting

In [None]:
if len(sorting_outputs) > 1:
    # retrieve sortings and sorter names (dict keys and values share the same order)
    sorter_names_comp = [sorter for (_, sorter) in sorting_outputs]
    sorting_list = list(sorting_outputs.values())

    # run multisorting comparison
    mcmp = sc.compare_multiple_sorters(sorting_list=sorting_list, name_list=sorter_names_comp)

    # plot agreement results
    w_agr = sw.plot_multicomp_agreement(mcmp)

    # extract ensemble sorting: units on which at least two sorters agree
    sorting_ensemble = mcmp.get_agreement_sorting(minimum_agreement_count=2)

    print(f"Ensemble sorting among {sorter_names_comp} found: {len(sorting_ensemble.get_unit_ids())} units")
else:
    # Clear any ensemble left in the kernel by an earlier run with more sorters.
    sorting_ensemble = None
    print("Ensemble sorting skipped: it needs the output of at least two sorters.")

### (Optional) save ensemble output for later use

In [None]:
# The ensemble only exists when the previous cell compared at least two sorters.
if globals().get("sorting_ensemble") is not None:
    save_si_object(
        "sorting_ensemble", sorting_ensemble, spikeinterface_folder,
        cache_raw=False, include_properties=True, include_features=False
    )
else:
    print("No ensemble sorting to save.")

## 6) Automatic curation

In [None]:
# Thresholds used by the automatic-curation cell below
firing_rate_threshold = 0.05   # lower bound on firing rate (presumably Hz)
snr_threshold = 3              # lower bound on SNR
isi_violation_threshold = 0.5  # upper bound on the ISI-violation metric

In [None]:
sorting_auto_curated = []
sorter_names_curation = []
for (rec_name, sorter), sorting in sorting_outputs.items():
    sorter_names_curation.append(sorter)

    # The thresholds are chained: each step filters the output of the previous one, so a unit
    # must pass all three criteria to survive.

    # firing rate threshold: remove units firing below firing_rate_threshold
    sorting_curated = st.curation.threshold_firing_rates(sorting, duration_in_frames=num_frames,
                                                         threshold=firing_rate_threshold,
                                                         threshold_sign='less')

    # isi violation threshold: remove units whose ISI-violation metric exceeds the threshold
    sorting_curated = st.curation.threshold_isi_violations(sorting_curated, duration_in_frames=num_frames,
                                                           threshold=isi_violation_threshold,
                                                           threshold_sign='greater')

    # snr threshold: remove units whose SNR is below snr_threshold
    sorting_curated = st.curation.threshold_snrs(sorting_curated, recording=recording_processed,
                                                 threshold=snr_threshold,
                                                 threshold_sign='less')
    sorting_auto_curated.append(sorting_curated)

In [None]:
# Units kept by automatic curation, per sorter (works for any number of sorters).
for sorter_name, curated_sorting in zip(sorter_names_curation, sorting_auto_curated):
    print(f"{sorter_name}: {curated_sorting.get_unit_ids()}")

## 7) Quick save to NWB; writes only the spike-sorting output (and the LFP)

To also convert the other types of data, use the full conversion in section 8 below (or the external conversion script).

In [None]:
# Output NWB file. Section 8 below uses the same default path, so running it replaces this file.
nwbfile_path = ced_file.parent / "CED.nwb"

# Session information -- replace the placeholders below.
session_description = "Enter session description here."
session_start = datetime(1971, 1, 1, 1, 1, 1)  # (Year, Month, Day, Hour, Minute, Second)

# Sorter whose output is written; it must be one of the entries of `sorter_list`.
chosen_sorter = "ironclust"
chosen_sorting_extractor = sorting_outputs[("rec0", chosen_sorter)]

quick_write(
    ced_file_path=ced_file,
    session_description=session_description,
    session_start=session_start,
    save_path=nwbfile_path,
    sorting=chosen_sorting_extractor,
    recording_lfp=recording_lfp,
    overwrite=True,
)

## 8) Full save to NWB

In [None]:
# Name your NWBFile and decide where you want it saved
nwbfile_path = ced_file.parent / "CED.nwb"

# Choose the sorting extractor from the notebook environment you would like to write to NWB
chosen_sorting_extractor = sorting_outputs[('rec0', 'ironclust')]

# Set to False to leave the LFP data out of the NWB file. An explicit flag is used because
# commenting out a variable does not remove it from an already-running kernel.
include_lfp = True
if include_lfp:
    chosen_recording_lfp = recording_lfp
else:
    # Drop any value left by a previous run so the conversion cell does not pick it up.
    globals().pop("chosen_recording_lfp", None)

# Enter Session and Subject information here
session_description = "Enter session description here."

# Manually insert the session start time
session_start = datetime(1971, 1, 1, 1, 1, 1)  # (Year, Month, Day, Hour, Minute, Second)

# Uncomment any fields you want to include
subject_info = dict(
    subject_id="Enter optional subject id here",
#    description="Enter optional subject description here",
#    weight="Enter subject weight here",
#    age="P90D",  # ISO 8601 duration string, e.g. "P90D" for a 90-day-old subject
#    species="Mus musculus",
#    genotype="Enter subject genotype here",
#    sex="Enter subject sex here"
)

# Set some global conversion options here
# It's recommended to set stub_test to True on first attempt to ensure NWBFile output looks OK
conversion_stub_test = True
overwrite = True  # If the NWBFile exists at the path, replace it

## Run this cell to perform the conversion automatically

In [None]:
# Build the converter inputs from the paths and options set in the previous cell, then convert.
ced_path_str = str(ced_file)
interface_names = ("CEDRecording", "CEDStimulus")
source_data = {interface: dict(file_path=ced_path_str) for interface in interface_names}
conversion_options = {interface: dict(stub_test=conversion_stub_test) for interface in interface_names}

converter = CEDNWBConverter(source_data)
metadata = converter.get_metadata()
# astimezone() on a naive datetime interprets it as local time and attaches the local timezone.
metadata['NWBFile'].update(session_description=session_description, session_start_time=session_start.astimezone())
metadata.update(Subject=subject_info)

run_args = {
    "nwbfile_path": str(nwbfile_path),
    "metadata": metadata,
    "save_to_file": True,
    "conversion_options": conversion_options,
    "sorting": chosen_sorting_extractor,
    "overwrite": overwrite,
}
# The LFP is only written when the previous cell defined `chosen_recording_lfp`.
if "chosen_recording_lfp" in locals():
    run_args["recording_lfp"] = chosen_recording_lfp
converter.run_conversion(**run_args)