In [None]:
# Enable IPython's autoreload extension (mode 2) for this kernel session
%load_ext autoreload
%autoreload 2

# SpikeInterface pipeline for Mease Lab - CED

In [None]:
from pathlib import Path
import numpy as np
from pprint import pprint
from datetime import datetime
import matplotlib.pyplot as plt

import spikeextractors as se
import spiketoolkit as st
import spikesorters as ss
import spikecomparison as sc
import spikewidgets as sw

from mease_lab_to_nwb.convert_ced.cednwbconverter import quick_write

%matplotlib notebook

## 1) Load CED recording, set channel locations, compute LFP, and inspect signals

In [None]:
# Path to the CED (.smrx) recording processed by this notebook -- edit for your own data
ced_file = Path(
    "/home/luiz/storage/taufferconsulting/client_ben/project_heidelberg_gui/heidelberg_data/CED_example_data/M365/pt1 15 + mech.smrx"
)
# ced_file = Path('/Users/abuccino/Documents/Data/catalyst/heidelberg/ced/m365_pt1_590-1190secs-001.smrx')
# ced_file = Path('D:/CED_example_data/Other example/m365_pt1_590-1190secs-001.smrx')
# Probe file passed to se.load_probe_file when the recording is loaded
probe_file = "../probe_files/cambridge_neurotech_H3.prb"
# Outputs (sorter working folder, phy export) are written next to the recording file
spikeinterface_folder = ced_file.parent
spikeinterface_folder.mkdir(parents=True, exist_ok=True)

In [None]:
# Automatically select Rhd channels: keep every channel whose title contains "Rhd"
channel_info = se.CEDRecordingExtractor.get_all_channels_info(ced_file)

rhd_channels = [
    channel_id
    for channel_id, channel_meta in channel_info.items()
    if "Rhd" in channel_meta["title"]
]

In [None]:
# Build the recording extractor from the Rhd channels selected above
recording = se.CEDRecordingExtractor(ced_file, smrx_channel_ids=rhd_channels)

# Load probe file to re-order channels and add location
recording = se.load_probe_file(recording, probe_file)

In [None]:
# Visual check of the electrode geometry attached to the recording
sw.plot_electrode_geometry(recording)

## Concatenate multiple recordings 

With the `MultiRecordingTimeExtractor`, you can easily concatenate multiple recordings in time. The recordings must have the same channels and locations (e.g. the same probe file).

In [None]:
# Here we concatenate the same file three times as an example; replace the
# entries with your own files (they must share channels and probe file).
recording_files = [ced_file, ced_file, ced_file]

recordings = []
for file in recording_files:
    # Automatically select the Rhd channels of the file being iterated over
    channel_info = se.CEDRecordingExtractor.get_all_channels_info(file)
    rhd_channels = [ch for ch, info in channel_info.items() if "Rhd" in info["title"]]

    # Load `file` itself: the loop previously re-loaded `ced_file` on every
    # iteration, so the entries of `recording_files` were silently ignored.
    recording = se.CEDRecordingExtractor(file, smrx_channel_ids=rhd_channels)
    recording = se.load_probe_file(recording, probe_file)
    recordings.append(recording)

# NOTE: after the loop, `recording`, `channel_info` and `rhd_channels` refer to
# the LAST file in `recording_files`; later cells use these names.

# instantiate a MultiRecording object
multirecording = se.MultiRecordingTimeExtractor(recordings)

The `multirecording` is also a `RecordingExtractor` and it can be used for further processing. It contains `epoch` information about start and end of each recording.

In [None]:
# Print the stored info of every epoch held by the concatenated recording
epoch_labels = multirecording.get_epoch_names()
for epoch_label in epoch_labels:
    epoch_details = multirecording.get_epoch_info(epoch_label)
    print(epoch_details)

If you have separate files from different probes recorded using the same device, you can concatenate them in the channel dimension. You can add separate groups to the different recordings.

In [None]:
# here we concatenate the same recording as an example
# Each input recording is assigned its own group id (0 and 1)
multirec_group = se.MultiRecordingChannelExtractor(
    [recording, recording], groups=[0, 1]
)

In [None]:
# Show the channel groups of the channel-concatenated recording
print(multirec_group.get_channel_groups())

Different groups can be spike sorted separately using the `grouping_property='group'` argument. When spike sorting by group, the output units have a property called `group` indicating which group each unit was found on.

**Note that MultiRecordingTimeExtractor and ChannelExtractor can be easily combined!**

In [None]:
# Summarise the recording: channel count, sampling rate and duration
n_channels = recording.get_num_channels()
print(f"Num channels: {n_channels}")

sampling_rate = recording.get_sampling_frequency()
print(f"Sampling rate: {sampling_rate}")

# Duration is the frame count divided by the sampling frequency
duration_s = recording.get_num_frames() / recording.get_sampling_frequency()
print(f"Duration (s): {duration_s}")

### Load LFP

In [None]:
# Keep every channel whose title contains "LFP"
# (`channel_info` is the channel-info mapping built in an earlier cell)
lfp_channels = [
    channel_id
    for channel_id, channel_meta in channel_info.items()
    if "LFP" in channel_meta["title"]
]

In [None]:
# Separate extractor restricted to the LFP channels of the same file
recording_lfp = se.CEDRecordingExtractor(ced_file, smrx_channel_ids=lfp_channels)

In [None]:
# Compare the sampling frequencies of the two extractors
print(f"Sampling frequency AP: {recording.get_sampling_frequency()}")
print(f"Sampling frequency LF: {recording_lfp.get_sampling_frequency()}")

### (Optional) Resample LFP

In [None]:
# Resample the LFP to resample_rate=1000 (presumably Hz -- confirm against spiketoolkit docs).
# NOTE(review): this rebinds `recording_lfp`, so re-running the cell wraps the already resampled extractor again.
recording_lfp = st.preprocessing.resample(recording_lfp, resample_rate=1000)

### Inspect signals

In [None]:
# Plot the AP recording over trange=[10, 12] (presumably seconds -- confirm)
w_ts_ap = sw.plot_timeseries(recording, trange=[10, 12])

In [None]:
# Plot the LFP recording over trange=[30, 40] (presumably seconds -- confirm)
w_ts_lf = sw.plot_timeseries(recording_lfp, trange=[30, 40])

## 2) Pre-processing

In [None]:
# Switches and band edges read by the pre-processing cell below
apply_filter = False  # the CED data appear to be already filtered
apply_cmr = True  # when True, st.preprocessing.common_reference is applied
freq_min_hp = 300  # passed as freq_min to bandpass_filter (only used when apply_filter is True)
freq_max_hp = 3000  # passed as freq_max to bandpass_filter (only used when apply_filter is True)

In [None]:
# IPython help lookup (interactive aid; not needed for the pipeline itself)
?st.preprocessing.common_reference

In [None]:
# Start from the raw recording and layer the optional steps on top of it
recording_processed = recording

# Optional band-pass filter between freq_min_hp and freq_max_hp
if apply_filter:
    recording_processed = st.preprocessing.bandpass_filter(
        recording_processed, freq_min=freq_min_hp, freq_max=freq_max_hp
    )

# Optional common referencing, applied after the (optional) filter
if apply_cmr:
    recording_processed = st.preprocessing.common_reference(recording_processed)

In [None]:
# Stub recording for fast testing; set to False for running processing pipeline on entire data
stub_test = True
nsec_stub = 30  # length of the stub in seconds (multiplied by the sampling frequency below)

if stub_test:
    # Truncate both the processed AP recording and the LFP recording to the first
    # `nsec_stub` seconds; each uses its own sampling frequency to compute end_frame.
    recording_processed = se.SubRecordingExtractor(
        parent_recording=recording_processed,
        end_frame=int(nsec_stub * recording_processed.get_sampling_frequency()),
    )
    recording_lfp = se.SubRecordingExtractor(
        recording_lfp, end_frame=int(nsec_stub * recording_lfp.get_sampling_frequency())
    )

# `recording` is the untruncated extractor, so the two lengths differ when stub_test is True
print(f"Original signal length: {recording.get_num_frames()}")
print(f"Processed signal length: {recording_processed.get_num_frames()}")

In [None]:
# Frame count of the processed recording; reused as `duration_in_frames` in the automatic-curation cell
num_frames = recording_processed.get_num_frames()
print(num_frames)

In [None]:
# Plot the processed recording (rebinds `w_ts_ap` from the earlier raw-signal plot)
w_ts_ap = sw.plot_timeseries(recording_processed)

## 3) Run spike sorters

In [None]:
# Show which sorters spikesorters reports as installed
ss.installed_sorters()
# ss.IronClustSorter.set_ironclust_path("D:/GitHub/ironclust")

In [None]:
# Sorters to run. "ironclust" is disabled here, so cells that look up
# ("rec0", "ironclust") in the sorting outputs only apply once it is re-enabled.
sorter_list = [
    "herdingspikes",
    #     "ironclust",
    "klusta",
]

In [None]:
# For every selected sorter, show the parameter descriptions followed by the defaults
for sorter_name in sorter_list:
    header = f"\n\n{sorter_name} params description:"
    print(header)
    pprint(ss.get_params_description(sorter_name))
    print("Default params:")
    pprint(ss.get_default_params(sorter_name))

In [None]:
# User-specific parameter overrides, keyed by sorter name (empty dict = no overrides)
sorter_params = {
    #     "ironclust": {"detect_threshold": 6},
    "klusta": {},
    "herdingspikes": {},
}

In [None]:
# IPython help lookup (interactive aid; not needed for the pipeline itself)
?ss.run_sorters

In [None]:
# Run every sorter in `sorter_list` on the processed recording, registered under the name "rec0".
# Results are written below `spikeinterface_folder / "ced_si_output"`.
sorting_outputs = ss.run_sorters(
    sorter_list=sorter_list,
    working_folder=spikeinterface_folder / "ced_si_output",
    recording_dict_or_list=dict(rec0=recording_processed),
    sorter_params=sorter_params,
    mode="overwrite",  # change to "keep" to avoid repeating the spike sorting
    verbose=True,
)

The `sorting_outputs` is a dictionary with ("rec_name", "sorter_name") as keys.

In [None]:
# Report how many units each sorter found; keys are (recording name, sorter name)
for (rec_name, sorter), sorting in sorting_outputs.items():
    n_units = len(sorting.get_unit_ids())
    print(f"{sorter} found {n_units} units")

### Split sorting output when concatenation is used

If you concatenated multiple recordings into a `MultiRecordingTimeExtractor`, you can split the sorting output using the epoch information and the `SubSortingExtractor`.

In [None]:
# Only relevant when the sorted recording is a time-concatenation of several recordings.
# NOTE(review): when common referencing or the stub is enabled above, `recording_processed`
# is wrapped by another extractor, so this check presumably evaluates to False -- confirm.
if isinstance(recording_processed, se.MultiRecordingTimeExtractor):
    sortings_split = []
    # Split the output of the first sorter that was actually run. The key
    # ("rec0", "ironclust") used previously is missing unless ironclust is in `sorter_list`.
    sorting_to_be_split = sorting_outputs[("rec0", sorter_list[0])]

    # Read epoch names AND epoch info from the same extractor that was sorted,
    # so the frame boundaries match the data the spike trains refer to.
    for epoch_name in recording_processed.get_epoch_names():
        epoch_info = recording_processed.get_epoch_info(epoch_name)
        sorting_split = se.SubSortingExtractor(
            sorting_to_be_split,
            start_frame=epoch_info["start_frame"],
            end_frame=epoch_info["end_frame"],
        )
        sortings_split.append(sorting_split)

## 4) Post-processing: extract waveforms, templates, quality metrics, extracellular features

### Set postprocessing parameters

In [None]:
# Post-processing params
# Start from spiketoolkit's common parameter dict and override two entries
postprocessing_params = st.postprocessing.get_common_params()
postprocessing_params["verbose"] = True
postprocessing_params["recompute_info"] = True
pprint(postprocessing_params)

In [None]:
# (optional) change parameters
# With None, all waveforms are extracted
postprocessing_params.update(max_spikes_per_unit=1000)

**Important note for Windows**: on Windows, we currently have some problems with the `memmap` argument. Until this is fixed, we recommend setting it to `False`.

### Set quality metric list

In [None]:
# Quality metrics
# List every metric name spiketoolkit reports as available
qc_list = st.validation.get_quality_metrics_list()
print(f"Available quality metrics: {qc_list}")

In [None]:
# (optional) define subset of qc
# Overrides the full list above; passed as `metric_names` in the post-processing loop
qc_list = ["snr", "isi_violation", "firing_rate"]

### Set extracellular features

In [None]:
# Extracellular features
# List every template feature name spiketoolkit reports as available
ec_list = st.postprocessing.get_template_features_list()
print(f"Available EC features: {ec_list}")

In [None]:
# (optional) define subset of ec
# Overrides the full list above; passed as `feature_names` in the post-processing loop
ec_list = ["peak_to_valley", "halfwidth"]

### Postprocess all sorting outputs

In [None]:
# Post-process every sorting output: waveforms, templates, EC features and quality metrics.
# NOTE: `waveforms`, `templates`, `ec` and `qc` are rebound on each iteration, so after
# the loop they only hold the results of the last sorter.
for result_name, sorting in sorting_outputs.items():
    rec_name, sorter = result_name
    print(f"Postprocessing recording {rec_name} sorted with {sorter}")
    # Relative path: resolved against the notebook's current working directory
    tmp_folder = Path("tmp_ced") / sorter
    tmp_folder.mkdir(parents=True, exist_ok=True)

    # set local tmp folder
    sorting.set_tmp_folder(tmp_folder)

    # compute waveforms
    # NOTE(review): n_jobs=16 and chunk_mb=2000 are hard-coded; adjust to the machine in use
    waveforms = st.postprocessing.get_unit_waveforms(
        recording_processed, sorting, n_jobs=16, chunk_mb=2000, **postprocessing_params
    )

    # compute templates
    templates = st.postprocessing.get_unit_templates(
        recording_processed, sorting, **postprocessing_params
    )

    # compute EC features
    ec = st.postprocessing.compute_unit_template_features(
        recording_processed, sorting, feature_names=ec_list, as_dataframe=True
    )
    # compute QCs
    qc = st.validation.compute_quality_metrics(
        sorting, recording=recording_processed, metric_names=qc_list, as_dataframe=True
    )

    # export to phy example
    # Only the ironclust output is exported; nothing is exported when ironclust was not run
    if sorter == "ironclust":
        phy_folder = spikeinterface_folder / "phy" / sorter
        phy_folder.mkdir(parents=True, exist_ok=True)
        print("Exporting to phy")
        st.postprocessing.export_to_phy(
            recording_processed, sorting, phy_folder, verbose=True
        )

In [None]:
# Inspect the properties and spike features stored on the ironclust output.
# The lookup is guarded because ("rec0", "ironclust") only exists when ironclust
# was part of the sorters that were run; an unconditional lookup raises KeyError.
ironclust_key = ("rec0", "ironclust")
if ironclust_key in sorting_outputs:
    sorting_ironclust = sorting_outputs[ironclust_key]
    print(f"Properties: {sorting_ironclust.get_shared_unit_property_names()}")
    print(f"Spikefeatures: {sorting_ironclust.get_shared_unit_spike_feature_names()}")
else:
    print(f"No ironclust output available; found: {list(sorting_outputs.keys())}")

### Load Phy-curated data back to SI

In [None]:
# Launch the phy template GUI on an exported params.py.
# NOTE(review): the path is hard-coded and user-specific; point it at the params.py inside
# the folder written by export_to_phy (spikeinterface_folder / "phy" / "ironclust").
!phy template-gui /Users/abuccino/Documents/Data/catalyst/heidelberg/ced/phy/ironclust/params.py

In [None]:
# Load the manually curated result from the folder this notebook exports to
# (spikeinterface_folder / "phy" / <sorter>) rather than a user-specific absolute path.
phy_folder = spikeinterface_folder / "phy" / "ironclust"
if (phy_folder / "params.py").exists():
    sorting_curated = se.PhySortingExtractor(str(phy_folder))
    print(f"Units after manual curation: {len(sorting_curated.get_unit_ids())}")
else:
    # Nothing was exported (e.g. ironclust was not run): report it instead of failing
    print(f"No phy export found in {phy_folder}; skipping the manual-curation import")

## 5) Ensemble spike sorting

In [None]:
# Ensemble sorting needs at least two sorter outputs to compare.
# NOTE: `sorting_ensemble` is only defined when this condition holds.
if len(sorting_outputs) > 1:
    # retrieve sortings and sorter names
    sorting_list = []
    sorter_names_comp = []
    for result_name, sorting in sorting_outputs.items():
        rec_name, sorter = result_name
        sorting_list.append(sorting)
        sorter_names_comp.append(sorter)

    # run multisorting comparison
    mcmp = sc.compare_multiple_sorters(
        sorting_list=sorting_list, name_list=sorter_names_comp
    )

    # plot agreement results
    w_agr = sw.plot_multicomp_agreement(mcmp)

    # extract ensemble sorting
    sorting_ensemble = mcmp.get_agreement_sorting(minimum_agreement_count=2)

    print(
        f"Ensemble sorting among {sorter_list} found: {len(sorting_ensemble.get_unit_ids())} units"
    )

## 6) Automatic curation

In [None]:
# define curators and thresholds
# Each value is passed as `threshold` to the matching st.curation.threshold_* call below
isi_violation_threshold = 0.5  # applied with threshold_sign="greater"
snr_threshold = 3  # applied with threshold_sign="less"
firing_rate_threshold = 0.05  # applied with threshold_sign="less"; presumably in Hz -- confirm

In [None]:
# Apply the three curation thresholds in sequence to every sorting output and
# collect the original and curated sortings in parallel lists.
# NOTE(review): threshold_sign presumably selects which side of the threshold is
# removed -- confirm against the spiketoolkit curation docs.
sortings = []
sortings_auto_curated = []
sorter_names_curation = []
for result_name, sorting in sorting_outputs.items():
    rec_name, sorter = result_name
    sorter_names_curation.append(sorter)
    sortings.append(sorting)

    # firing rate threshold
    sorting_curated = st.curation.threshold_firing_rates(
        sorting,
        duration_in_frames=num_frames,
        threshold=firing_rate_threshold,
        threshold_sign="less",
    )

    # isi violation threshold
    sorting_curated = st.curation.threshold_isi_violations(
        sorting_curated,
        duration_in_frames=num_frames,
        threshold=isi_violation_threshold,
        threshold_sign="greater",
    )

    # SNR threshold
    sorting_curated = st.curation.threshold_snrs(
        sorting_curated,
        recording=recording_processed,
        threshold=snr_threshold,
        threshold_sign="less",
    )
    sortings_auto_curated.append(sorting_curated)

In [None]:
# Per sorter, report the unit count before and after automatic curation
curation_summary = zip(sorter_names_curation, sortings, sortings_auto_curated)
for sorter_label, sorting_before, sorting_after in curation_summary:
    n_before = len(sorting_before.get_unit_ids())
    n_after = len(sorting_after.get_unit_ids())
    print(f"{sorter_label}")
    print(f"Units before curation: {n_before}")
    print(f"Units after curation: {n_after}")

## TODO Show how to split sorting outputs!


## 7) Quick save to NWB; writes only the spikes and LFP

To complete the full conversion for other types of data, either:

1. Run the external conversion script before this notebook, and append to it by setting `overwrite=False` below, or
2. Run the external conversion script after this notebook, which will append to the NWBFile you make here so long as `overwrite=False` in the external script.

In [None]:
# Name your NWBFile and decide where you want it saved
nwbfile_path = ced_file.parent / "CED_test.nwb"

# Enter Session and Subject information here
session_description = "Enter session description here."

# Manually insert the session start time
# NOTE: the value below is a placeholder; replace it with the real session start
session_start = datetime(1971, 1, 1)  # (Year, Month, Day)

# Choose the sorting extractor from the notebook environment you would like to write to NWB
# chosen_sorting_extractor = sorting_outputs[('rec0', 'ironclust')]
# chosen_sorting_extractor = sorting_ensemble

# As written, no sorting is passed (the `sorting` argument is commented out), so only
# the LFP recording is handed to quick_write; overwrite=True replaces an existing file.
quick_write(
    ced_file_path=ced_file,
    session_description=session_description,
    session_start=session_start,
    save_path=nwbfile_path,
    #     sorting=chosen_sorting_extractor,
    recording_lfp=recording_lfp,
    overwrite=True,
)

In [None]:
# Check NWB file with widgets
# NOTE(review): these imports live here rather than in the top import cell,
# and `NWBFile` is not used in this cell.
from pynwb import NWBFile, NWBHDF5IO
from nwbwidgets import nwb2widget

# Open the written file read-only; `io` is not closed anywhere in this notebook
io = NWBHDF5IO(str(nwbfile_path), "r")
nwbfile = io.read()
nwb2widget(nwbfile)

In [None]:
# Display the LFP ElectricalSeries stored in the "ecephys" processing module of the file read above
nwbfile.processing["ecephys"].data_interfaces["LFP"].electrical_series[
    "ElectricalSeries"
]