In [None]:
# Reload imported modules before each cell executes, so edits to local packages
# (e.g. mease_lab_to_nwb) are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# SpikeInterface pipeline for Mease Lab - CED

In [None]:
from pathlib import Path
from os import getenv
import numpy as np
from pprint import pprint
import matplotlib.pyplot as plt

In [None]:
import spikeinterface as si
import spikeinterface.extractors as se
import spikeinterface.toolkit as st
import spikeinterface.sorters as ss
import spikeinterface.comparison as sc
import spikeinterface.widgets as sw
import spikeinterface.exporters as sx
import probeinterface as sp

In [None]:
from mease_lab_to_nwb.convert_ced.cednwbconverter import quick_write

In [None]:
import datetime
import resource

# Timestamped print helper used to log pipeline progress.
def tprint(string=""):
    """Print ``string`` prefixed by the wall-clock time and the peak memory so far.

    The memory figure is ``ru_maxrss / 1e6`` for this process; ru_maxrss is reported
    in kilobytes on Linux (getrusage(2)), so there it is roughly gigabytes.
    """
    usage = resource.getrusage(resource.RUSAGE_SELF)
    peak_memory = usage.ru_maxrss / 1e6
    timestamp = datetime.datetime.now()
    print(f"|| {timestamp:%H:%M:%S} || {peak_memory} || {string}")

## 1) Load bin recording instead of smrx file

In [None]:
tprint(string="Loading recording")

In [None]:
# Binary recording that is loaded instead of the original .smrx file
bin_file = Path(
    r"/mnt/sds-hd/sd19b001/PainData/Corrected_Channel_Map/L6/Cortex/20.8.21/KS2/m6.bin"
)
# Probe layout file; relative path, so it is resolved against the notebook's working directory
recording_prb = "cambridge_neurotech_H3.prb"
# ~30030.03 samples/s; presumably 1 / 33.3 us (the CED sample interval) — TODO confirm
sampling_frequency = 3.003003003003003e04
# Sample dtype and number of channels stored in the binary file
data_type = "int16"
numChan = 64

In [None]:
# Output folder, created next to the binary file
spikeinterface_folder = bin_file.parent.joinpath("liam_new_api")
spikeinterface_folder.mkdir(exist_ok=True, parents=True)

In [None]:
# smrx2bin should already have kept only the Rhd channels, so the binary is read as-is
recording = se.BinaryRecordingExtractor(bin_file, sampling_frequency, numChan, data_type)

In [None]:
# load probe file: read the probe layout from the .prb file with probeinterface
probegroup = sp.read_prb(recording_prb)

In [None]:
# add probe file to recording: the recording returned by set_probegroup (carrying the
# probe layout) replaces `recording` for the rest of the notebook
recording = recording.set_probegroup(probegroup)

In [None]:
# Show which properties the recording carries.
# this looks like the closest to get_shared_channel_property_names:
recording.get_property_keys()

In [None]:
# Plot the first 30000 frames of channel 0 of the raw recording.
# note: previously get_traces returned array[channel][time]
# with new spikeinterface API order is swapped: array[time][channel]
fig, ax = plt.subplots()
ax.plot(recording.get_traces(end_frame=30000)[:, 0], label="channel 0")
ax.set_title("Raw recording")
ax.set_xlabel("frame")
ax.set_ylabel("amplitude (raw)")
ax.legend()
plt.show()

In [None]:
# Show the probe layout attached to the recording
sw.plot_probe_map(recording)
plt.show()

In [None]:
# Summary of the loaded recording, one fact per line
print(
    f"Num channels: {recording.get_num_channels()}",
    f"Channel ids: {recording.get_channel_ids()}",
    f"Sampling rate: {recording.get_sampling_frequency()}",
    f"Duration (s): {recording.get_num_frames() / recording.get_sampling_frequency()}",
    sep="\n",
)

## Get LFPs from multichannel data

In [None]:
tprint(string="LFPs")

In [None]:
# Start the LFP branch from the raw recording; it is optionally band-pass filtered below
recording_lfp = recording

In [None]:
# note: no "resample" method in new API: skipping this line
# recording_lfp = st.preprocessing.resample(recording_lfp, resample_rate=1000)

In [None]:
# LFP band-pass settings (0.1-300, in Hz). Note: the same names are reassigned to the
# spike-band settings in the pre-processing section.
apply_filter, freq_min_hp, freq_max_hp = True, 0.1, 300

In [None]:
# Band-pass the LFP branch. With apply_filter=False, recording_lfp simply stays the
# unfiltered recording assigned above, so no else-branch is needed.
if apply_filter:
    recording_lfp = st.preprocessing.bandpass_filter(
        recording_lfp, freq_min=freq_min_hp, freq_max=freq_max_hp
    )

In [None]:
# recording_lfp is derived from the original recording with an additional bandpass_filter pre-processing step
# the pre-processing is "lazy", which means it isn't done until needed, e.g. when get_traces is called:
fig, ax = plt.subplots()
ax.plot(recording.get_traces(end_frame=3000)[:, 0], label="channel 0")
ax.plot(recording_lfp.get_traces(end_frame=3000)[:, 0], label="channel 0 LFP")
ax.set_title("Raw vs. LFP signal (first 3000 frames)")
ax.set_xlabel("frame")
ax.set_ylabel("amplitude")
ax.legend()
plt.show()

## 2) Pre-processing

In [None]:
tprint(string="Pre-processing")

In [None]:
# Display the list of pre-processing steps available in the toolkit
st.preprocessing.preprocessers_full_list

In [None]:
# Spike-band settings: band-pass 600-6000 (Hz), then a common reference (apply_cmr).
# These reassign the LFP settings defined earlier.
apply_filter, apply_cmr = True, True
freq_min_hp, freq_max_hp = 600, 6000

In [None]:
# st.preprocessing.common_reference?

In [None]:
# Spike-band pre-processing chain: optional band-pass of the raw recording,
# followed by an optional common reference.
recording_processed = (
    st.preprocessing.bandpass_filter(
        recording, freq_min=freq_min_hp, freq_max=freq_max_hp
    )
    if apply_filter
    else recording
)
if apply_cmr:
    recording_processed = st.preprocessing.common_reference(recording_processed)

In [None]:
# Set stub_test = True to crop both the processed and the LFP recordings to their first
# nsec_stub seconds for quick test runs; keep it False to process the entire data.
stub_test = False
nsec_stub = 60

if stub_test:
    recording_processed, recording_lfp = (
        rec.frame_slice(0, int(nsec_stub * rec.get_sampling_frequency()))
        for rec in (recording_processed, recording_lfp)
    )

print(f"Original signal length: {recording.get_num_frames()}")
print(f"Processed signal length: {recording_processed.get_num_frames()}")

## Inspect signals

In [None]:
# Side-by-side view of the same time window (time_range=[4, 8], presumably seconds)
# for the raw, LFP and processed recordings
fig, (ax0, ax1, ax2) = plt.subplots(ncols=3, figsize=(24, 8))
ax0.set_title("Recording")
sw.plot_timeseries(recording, time_range=[4, 8], ax=ax0)
ax1.set_title("LFP")
sw.plot_timeseries(recording_lfp, time_range=[4, 8], ax=ax1)
ax2.set_title("Processed")
sw.plot_timeseries(recording_processed, time_range=[4, 8], ax=ax2)
plt.show()

## 3) Run spike sorters

In [None]:
# print('Installed sorters:', ss.installed_sorters())

In [None]:
# sorter_list = [
#     "kilosort2_5"
# ]

In [None]:
# sorter_params = dict()
# for sorter in sorter_list:
#     # start with defaults
#     params = ss.get_default_params(sorter)
#     # make changes
#     params['chunk_mb'] = 2000
#     params['n_jobs_bin'] = 16
#     # print params
#     print(f"\n\n{sorter} params description:")
#     pprint(ss.get_params_description(sorter))
#     print(f"\n\n{sorter} params:")
#     pprint(params)
#     sorter_params[sorter] = params

In [None]:
# ss.available_sorters()

In [None]:
# ss.run_sorters?

In [None]:
# set this to True to use the local SSD on the cluster node to store temporary files
# if False, all data is written to SDS, which works fine but is a bit slower
# BUT: local SSD only has 120GB of space on most GPU nodes - for large recordings this may not be enough space!

In [None]:
# use_scratch_dir = True

In [None]:
# sorting_working_folder = spikeinterface_folder / 'simple_bin_output'
# if use_scratch_dir:
#     local_scratch_dir = Path(getenv("TMPDIR"))
#     tprint(f"using local_scratch_dir = {local_scratch_dir}")
#     sorting_working_folder = spikeinterface_folder / 'simple_bin_output'

In [None]:
tprint(string="Running sorters")

In [None]:
# sorting_outputs = ss.run_sorters(
#     sorter_list=sorter_list,
#     working_folder=sorting_working_folder,
#     recording_dict_or_list=dict(rec0=recording_processed),
#     sorter_params=sorter_params,
#     mode="keep", # "overwrite" to overwrite # change to "keep" to avoid repeating the spike sorting
#     verbose=True,
# )

`sorting_outputs` is a dictionary keyed by `(rec_name, sorter_name)` tuples.

In [None]:
# tprint("Finished running sorters")
# for result_name, sorting in sorting_outputs.items():
#     rec_name, sorter = result_name
#     print(f"{sorter} found {len(sorting.get_unit_ids())} units")

In [None]:
# Load an existing kilosort2_5 output (phy format), presumably produced by an earlier
# run of the (now commented-out) sorter cells above.
# NOTE(review): this path is under /mnt/sds/PainData/... while bin_file is under
# /mnt/sds-hd/sd19b001/PainData/... — presumably two mounts of the same share; confirm,
# and consider deriving it from bin_file.parent so the two cannot drift apart.
sorting = se.PhySortingExtractor(
    "/mnt/sds/PainData/Corrected_Channel_Map/L6/Cortex/20.8.21/KS2/simple_bin_output/rec0/kilosort2_5"
)

In [None]:
# Extract the waveforms of every unit from the pre-processed recording into the
# "temp_waveforms" folder (relative to the working directory); overwrite=True
# presumably replaces any previous extraction in that folder.
we = si.extract_waveforms(
    recording_processed, sorting, "temp_waveforms",
    overwrite=True, verbose=True, progress_bar=True,
    n_jobs=6, total_memory="1G",
)

In [None]:
tprint(string="post extract_waveforms")

In [None]:
# Export the waveform extractor to a phy folder named "export_to_phy"
# (relative to the working directory)
sx.export_to_phy(
    we,
    "export_to_phy",
    verbose=True,
    progress_bar=True,
    n_jobs=6,
    total_memory="2G",
)
tprint(string="post export_to_phy")

In [None]:
# Collect the extracted waveforms of each unit from the waveform extractor, keyed by unit id.
# (`sx.export_to_phy` is the phy export function, not waveform data.)
# NOTE(review): each value is presumably a (num_spikes, num_samples, num_channels) array — confirm.
waveforms = {unit_id: we.get_waveforms(unit_id) for unit_id in sorting.get_unit_ids()}

In [None]:
# Post-process the kilosort2_5 sorting loaded above using the waveform extractor `we`.
# This replaces old-API toolkit calls (get_unit_templates, compute_unit_template_features,
# st.postprocessing.export_to_phy). Those calls also read `postprocessing_params`,
# `ec_list` and `sorter`, which are only defined in later (or commented-out) cells,
# so this cell could not run top-to-bottom.
sorter = "kilosort2_5"  # matches the output folder `sorting` was loaded from

tprint("compute templates")
# NOTE(review): presumably one average template per unit, shaped
# (num_units, num_samples, num_channels) — confirm for the installed spikeinterface version.
templates = we.get_all_templates()

tprint("compute EC features")
# NOTE(review): the template-metrics entry point may be named differently across
# spikeinterface releases (calculate_template_metrics / compute_template_metrics) — confirm.
ec = st.postprocessing.calculate_template_metrics(we)
## compute QCs
# qc = st.validation.compute_quality_metrics(sorting, recording=recording_processed,
#                                           metric_names=qc_list, as_dataframe=True, memmap = False)

# The phy export for this sorting is already done with sx.export_to_phy(we, "export_to_phy", ...)
# two cells above, so no second export is performed here.

## 4) Post-processing: extract waveforms, templates, quality metrics, extracellular features

### Set postprocessing parameters

In [None]:
# List the template-related post-processing functions available in this spikeinterface version
# (replaces an unfinished tab-completion line that was a SyntaxError)
sorted(name for name in dir(st.postprocessing) if "template" in name)

In [None]:
# Post-processing params: start from the toolkit's common defaults and override a few.
# NOTE(review): these keys (recompute_info, memmap, max_spikes_per_unit) feed the legacy
# st.postprocessing.get_unit_* calls below — confirm get_common_params exists in the
# installed spikeinterface version.
postprocessing_params = st.postprocessing.get_common_params()
postprocessing_params["verbose"] = True
postprocessing_params["recompute_info"] = True
postprocessing_params["memmap"] = True
postprocessing_params[
    "max_spikes_per_unit"
] = 1000  # with None, all waveforms are extracted
pprint(postprocessing_params)

### Set quality metric list

In [None]:
# Quality metrics: show every metric the validation module offers
qc_list = st.validation.get_quality_metrics_list()
print("Available quality metrics: {}".format(qc_list))

In [None]:
# (optional) restrict the quality metrics to this subset
qc_list = list(
    (
        "num_spikes", "firing_rate", "presence_ratio", "isi_violation",
        "amplitude_cutoff", "snr", "max_drift", "cumulative_drift",
        "silhouette_score", "isolation_distance", "l_ratio", "noise_overlap",
        "nn_hit_rate", "nn_miss_rate",
    )
)

### Set extracellular features

In [None]:
# Extracellular features: show every template feature the toolkit offers
ec_list = st.postprocessing.get_template_features_list()
print("Available EC features: {}".format(ec_list))

### Postprocess all sorting outputs

In [None]:
# Post-process every sorter output (old spikeinterface API).
# NOTE(review): depends on `sorting_outputs` and `local_scratch_dir`, which are only defined in
# the commented-out sorter cells above, and rebinds the global `sorting` / `sorter` names.
# The calls used here (set_tmp_folder, get_unit_waveforms, get_unit_templates,
# compute_unit_template_features, st.postprocessing.export_to_phy) belong to the old API.
for result_name, sorting in sorting_outputs.items():
    rec_name, sorter = result_name
    tprint(f"Postprocessing recording {rec_name} sorted with {sorter}")
    tmp_folder = local_scratch_dir / 'tmp_ced' / sorter
    tmp_folder.mkdir(parents=True, exist_ok=True)

    # set local tmp folder
    sorting.set_tmp_folder(tmp_folder)

    tprint("compute waveforms")
    waveforms = st.postprocessing.get_unit_waveforms(
        recording_processed, sorting, n_jobs=16, chunk_mb=2000, **postprocessing_params
    )

    tprint("compute templates")
    templates = st.postprocessing.get_unit_templates(
        recording_processed, sorting, n_jobs=16, chunk_mb=2000, **postprocessing_params
    )

    tprint("compute EC features")
    ec = st.postprocessing.compute_unit_template_features(
        recording_processed,
        sorting,
        n_jobs=16,
        chunk_mb=2000,
        feature_names=ec_list,
        as_dataframe=True,
        memmap=True,
    )
    ## compute QCs
    # qc = st.validation.compute_quality_metrics(sorting, recording=recording_processed,
    #                                           metric_names=qc_list, as_dataframe=True, memmap = False)

    # export to phy (kilosort2_5 output only), without PC features
    if sorter == "kilosort2_5":
        phy_folder = spikeinterface_folder / 'phy' / sorter
        phy_folder.mkdir(parents=True, exist_ok=True)
        tprint("Exporting to phy")
        st.postprocessing.export_to_phy(
            recording_processed,
            sorting,
            phy_folder,
            compute_pc_features=False,
            verbose=True,
            memmap=True,
            recompute_info=True,
            n_jobs=16,
        )
        # st.postprocessing.export_to_phy(recording_processed, sorting, phy_folder, verbose=True, compute_pc_features=False, compute_amplitudes=False, memmap = False, recompute_info = True, n_jobs=24)

In [None]:
# Inspect the unit properties / spike features of the kilosort2_5 result.
# NOTE(review): uses `sorting_outputs` (only defined in the commented-out sorter cells) and the
# old-API get_shared_* methods (see the get_property_keys note near the top of the notebook).
sorting_kilosort = sorting_outputs[("rec0", "kilosort2_5")]
print(f"Properties: {sorting_kilosort.get_shared_unit_property_names()}")
print(f"Spikefeatures: {sorting_kilosort.get_shared_unit_spike_feature_names()}")

In [None]:
tprint("Postprocessing done")
# stop here for now:
# the exception deliberately aborts "Run All", so the phy / NWB cells below only run when triggered by hand
raise Exception("Stopping notebook")

### Load Phy-curated data back to SI

In [None]:
# Open the phy GUI on an exported phy folder.
# NOTE(review): this is a Windows path (Z:\...) for a different dataset (m365) than the one
# loaded above; under a POSIX shell the unquoted backslashes are treated as escapes — update before running.
!phy template-gui Z:\PainData\m365\10min\phy\kilosort2_5\params.py

In [None]:
# Load the manually curated phy output back into spikeinterface.
# NOTE(review): Windows path pointing at a kilosort3 / 16.12.20 folder, not the kilosort2_5 /
# 20.8.21 output processed above — confirm this is the intended dataset.
phy_folder = r"Z:\PainData\Corrected_Channel_Map\L6\Cortex\16.12.20\phy\kilosort3"
recording_phy = se.PhyRecordingExtractor(phy_folder)
# every cluster in the phy folder
sorting_curated = se.PhySortingExtractor(phy_folder)
# same folder, excluding the "noise" cluster group
sorting_phy = se.PhySortingExtractor(phy_folder, exclude_cluster_groups=["noise"])
print(f"Units after manual curation: {len(sorting_curated.get_unit_ids())}")

In [None]:
# Keep only the units whose phy "quality" is "good" (noise clusters were already excluded on load).
# NOTE(review): get_unit_property / SubSortingExtractor look like old-API names — confirm they exist
# in the installed spikeinterface version.
good_units = [
    unit_id
    for unit_id in sorting_phy.get_unit_ids()
    if sorting_phy.get_unit_property(unit_id, "quality") == "good"
]
sorting_good = se.SubSortingExtractor(sorting_phy, unit_ids=good_units)
print(good_units)

In [None]:
# Show the docstring of threshold_num_spikes (IPython help syntax; only valid inside IPython)
?st.curation.threshold_num_spikes

In [None]:
# Spike-count curation with threshold 50 and threshold_sign="less"
# (presumably removes units with fewer than 50 spikes)
sorting_curated = st.curation.threshold_num_spikes(
    sorting_curated,
    threshold_sign="less",
    threshold=50,
)
print(f"Units after num spikes curation: {len(sorting_curated.get_unit_ids())}")

In [None]:
# First 30000 frames of the phy recording (new API order: array[time][channel])
tr_phy = recording_phy.get_traces(end_frame=30000)

In [None]:
# Plot channel 0 of the phy recording. With the new spikeinterface API get_traces returns
# array[time][channel], so tr_phy[0] would be the first sample of every channel instead.
plt.figure()
plt.plot(tr_phy[:, 0], label="channel 0")
plt.legend()
plt.show()

## 7) Quick save to NWB: writes only the spikes (and the LFP, if `recording_lfp` is passed)

To complete the full conversion for other types of data, either:

1. run the external conversion script before this notebook, and append to its output by setting `overwrite=False` below; or
2. run the external conversion script after this notebook; it will append to the NWBFile written here as long as `overwrite=False` in the external script.

In [None]:
# Name your NWBFile and decide where you want it saved
nwbfile_path = r"Z:\PainData\Corrected_Channel_Map\L6\Cortex\16.12.20\phy\kilosort3\m380_NewChanMap.nwb"

# Enter Session and Subject information here
session_description = "m380 spikes without TTL - 10min test"

# Manually insert the session start time.
# `datetime` is imported as the module (`import datetime`), so the class must be spelled
# datetime.datetime — calling the module itself raises TypeError.
# NOTE(review): session_start (2020-10-08) disagrees with session_start_time (2020-12-16,
# matching the 16.12.20 folder); only session_start is passed to quick_write below — confirm the date.
session_start = datetime.datetime(2020, 10, 8)  # (Year, Month, Day)
session_start_time = "2020-12-16T16:30:00"

# Choose the sorting extractor from the notebook environment you would like to write to NWB
# chosen_sorting_extractor = sorting_outputs[('rec0', 'ironclust')]
# chosen_sorting_extractor = sorting_ensemble

# Alternative: write the spike-count-curated sorting instead of sorting_phy
# quick_write(
#     ced_file_path=bin_file,
#     session_description=session_description,
#     session_start=session_start,
#     save_path=nwbfile_path,
#     sorting=sorting_curated,
#     recording_lfp=None,
#     overwrite=True,
# )

# se.NwbRecordingExtractor.write_recording(recording_lfp, 'LFPs.nwb')

# Write the phy sorting (noise cluster group excluded) to NWB. recording_lfp=None presumably
# means no LFP is written; overwrite=True presumably replaces an existing file at nwbfile_path.
quick_write(
    ced_file_path=bin_file,
    session_description=session_description,
    session_start=session_start,
    save_path=nwbfile_path,
    sorting=sorting_phy,
    recording_lfp=None,
    overwrite=True,
)