In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
from pprint import pprint
from pathlib import Path

import spikeinterface.core as si  # import core only
# from spikeinterface.core import load_extractor
import spikeinterface.extractors as se
import spikeinterface.preprocessing as spre
import spikeinterface.sorters as ss
import spikeinterface.postprocessing as spost
import spikeinterface.qualitymetrics as sqm
import spikeinterface.comparison as sc
import spikeinterface.exporters as sexp
import spikeinterface.curation as scur
import spikeinterface.widgets as sw

from probeinterface.plotting import plot_probe
from spikeinterface.sortingcomponents.peak_detection import detect_peaks
from spikeinterface.sortingcomponents.peak_localization import localize_peaks

In [None]:
# set directory path for reading/saving data
# raw string: the backslashes of a Windows path must not be parsed as escape sequences
DATA_DIR = r"E:\Crick\ephys_test"

# label of the session being processed; later cells build their output folders
# as "<DATA_DIR>\<recording_date>_out\..."
# TODO: replace the placeholder with the actual recording date before running
recording_date = "YYYY-MM-DD"

In [None]:
# default parallel-processing settings, registered globally with SpikeInterface
global_job_kwargs = {"n_jobs": 4, "chunk_duration": "1s"}
si.set_global_job_kwargs(**global_job_kwargs)

## 1. Read recording

In [None]:
# # if loading previously preprocessed recording
# # (points at the folder the `.save(...)` cell in section 2 writes to)
# preprocessed_recording = si.load_extractor(rf"{DATA_DIR}\{recording_date}_out\preprocessed_recording")
# print(preprocessed_recording)

In [None]:
# load the Open Ephys recording from the Record Node folder
# raw f-string so the backslash in the Windows path is not read as an escape sequence
data_path = rf"{DATA_DIR}\Record_Node_104" # full data path (90GB) --> "E:\Crick\ephys_test\Record_Node_104"
# stream_id=0: "NI-DAQmx-102.PXIe-6341"; stream_id=1: "Neuropix-PXI-100.ProbeA"
recording = se.read_openephys(data_path, stream_id="1")
print('done')

In [None]:
# print the string representation of the recording loaded above
print(recording)

In [None]:
# details about probe (optional)
# convert the probe attached to the recording into a dataframe (shown as the cell output)
recording.get_probe().to_dataframe()

In [None]:
# optional: draw the probe map with channel ids, restricted to y in [-100, 100]
fig, ax = plt.subplots(figsize=(15, 10))
sw.plot_probe_map(
    recording,
    with_channel_ids=True,
    ax=ax,
)
ax.set_ylim(-100, 100)

## 2. Preprocessing

In [None]:
# preprocessing chain:
#   rec1: high-pass filter (freq_min=400)
#   rec2: rec1 without the channels flagged by detect_bad_channels
#   rec3: phase shift
#   rec4: global median common reference
rec1 = spre.highpass_filter(recording=recording, freq_min=400.)

bad_channel_ids, channel_labels = spre.detect_bad_channels(rec1)
rec2 = rec1.remove_channels(bad_channel_ids)
print('bad_channel_ids: ', bad_channel_ids)

rec3 = spre.phase_shift(recording=rec2)
rec4 = spre.common_reference(
    recording=rec3,
    reference='global',
    operator='median',
)

# the last stage of the chain is what every later cell works on
preprocessed_recording = rec4
preprocessed_recording

In [None]:
# # plot details about preprocessing steps (optional)
# # this method is interactive (a bit laggy), the ones below are static
# %matplotlib widget
# sw.plot_traces({'filter': rec1, 'cmr': rec4}, backend='ipywidgets')

In [None]:
# optional: one static panel per preprocessing stage, all with the same colour limits
fig, axs = plt.subplots(ncols=3, figsize=(20, 10))
stages = (('filter', rec1), ('cmr', rec4), ('final', preprocessed_recording))
for axis, (label, stage_rec) in zip(axs, stages):
    sw.plot_traces(stage_rec, backend='matplotlib', clim=(-50, 50), ax=axis)
    axis.set_title(label)

In [None]:
# optional: line plot of three channels, filtered vs. common-referenced traces
fig, ax = plt.subplots(figsize=(20, 10))
some_chans = preprocessed_recording.channel_ids[[100, 150, 200]]
sw.plot_traces(
    {'filter': rec1, 'cmr': rec4},
    channel_ids=some_chans,
    mode='line',
    backend='matplotlib',
    ax=ax,
)

In [None]:
# only run if preprocessed recording not already saved
# job settings forwarded to `.save()` for this call
job_kwargs = dict(n_jobs=40, chunk_duration='1s', progress_bar=True)
# raw f-string: "\{" and "\p" are invalid escape sequences in a normal string literal
preprocessed_recording = preprocessed_recording.save(
    format="binary",
    folder=rf"{DATA_DIR}\{recording_date}_out\preprocessed_recording",
    **job_kwargs,
)
preprocessed_recording

## 3. Check noise level

In [None]:
# noise is estimated twice: on scaled traces (microV) and on the raw ones (int16 in our case)
noise_levels_microV, noise_levels_int16 = (
    si.get_noise_levels(preprocessed_recording, return_scaled=scaled)
    for scaled in (True, False)
)

In [None]:
# histogram of the scaled noise levels, bin edges from 5 up to (not including) 15 in steps of 2.5
fig, ax = plt.subplots()
_ = ax.hist(
    noise_levels_microV,
    bins=np.arange(5, 15, 2.5),
)
ax.set_xlabel('noise  [microV]')

In [None]:
# same histogram with a wider range: bin edges from 5 up to (not including) 30 in steps of 2.5
fig, ax = plt.subplots()
_ = ax.hist(
    noise_levels_microV,
    bins=np.arange(5, 30, 2.5),
)
ax.set_xlabel('noise  [microV]')

## 4. Detect and localize peaks

In [None]:
# job settings reused by the peak detection/localization and analyzer cells below
job_kwargs = {'n_jobs': 40, 'chunk_duration': '1s', 'progress_bar': True}

# peak detection uses the noise levels estimated on the unscaled traces
peaks = detect_peaks(
    preprocessed_recording,
    method='locally_exclusive',
    noise_levels=noise_levels_int16,
    detect_threshold=5,
    radius_um=50.,
    **job_kwargs,
)
peaks

In [None]:
# estimate a location for every detected peak (center-of-mass method)
peak_locations = localize_peaks(
    preprocessed_recording,
    peaks,
    method='center_of_mass',
    radius_um=50.,
    **job_kwargs,
)

## 5. Check for drift (optional)

In [None]:
# drift check: peak y-position against sample_index divided by the sampling frequency
fs = preprocessed_recording.sampling_frequency
fig, ax = plt.subplots(figsize=(10, 8))
ax.scatter(
    peaks['sample_index'] / fs,
    peak_locations['y'],
    marker='.',
    color='k',
    alpha=0.002,  # very low alpha so dense regions stand out
)

In [None]:
# overlay the estimated peak locations on the probe map to get a feel for
# cluster separation before sorting
fig, ax = plt.subplots(figsize=(10, 10))
sw.plot_probe_map(
    preprocessed_recording,
    with_channel_ids=True,
    ax=ax,
)
ax.set_ylim(-100, 300)

ax.scatter(
    peak_locations['x'],
    peak_locations['y'],
    color='purple',
    alpha=0.002,
)

## 6. Run Spike Sorter

In [None]:
# check which sorters are implemented/available and which are installed
# (the sorter requested in the run cell below must appear in the "Installed" list)
print("Available sorters", ss.available_sorters())
print("Installed sorters", ss.installed_sorters())

In [None]:
# show the sorter's default parameters; a bare expression in the middle of a
# cell is not displayed, so print it explicitly
pprint(ss.get_default_sorter_params("kilosort4"))

# parameters overriding the defaults when the sorter is run
# NOTE(review): presumably 'do_correction' toggles kilosort4's drift correction -- confirm
params_kilosort4 = {'do_correction': False}
print(params_kilosort4)

In [None]:
# run kilosort4 (specific to neuropixels)
# raw f-string: "\{" and "\o" are invalid escape sequences in a normal string literal
sorting_kilo4 = ss.run_sorter(
    sorter_name="kilosort4",
    recording=preprocessed_recording,
    folder=rf"{DATA_DIR}\{recording_date}_out\output_kilosort4",
    verbose=True,
    **params_kilosort4,
)
print(sorting_kilo4)

In [None]:
# list the ids of the units returned by the sorter
print("Units found by kilo4:", sorting_kilo4.get_unit_ids())

## 7. Create sorting analyzer + compute postprocessing and metrics

In [None]:
# postprocessing: a SortingAnalyzer pairs a BaseRecording with a BaseSorting.
# It can live in memory or in a folder; here it is written as a binary folder,
# which holds waveforms, templates and other postprocessing data.
analyzer_kilo4 = si.create_sorting_analyzer(
    sorting=sorting_kilo4,
    recording=preprocessed_recording,
    format='binary_folder',
    folder=rf"{DATA_DIR}\{recording_date}_out\analyzer_kilo4_binary",
)

In [None]:
### NEUROPIXEL SPECIFIC -- https://spikeinterface.readthedocs.io/en/latest/how_to/analyze_neuropixels.html

In [None]:
# core extensions, computed in this order: spike subsample -> waveforms -> templates -> noise levels
analyzer_kilo4.compute(
    "random_spikes",
    method="uniform",
    max_spikes_per_unit=500,
)
analyzer_kilo4.compute(
    "waveforms",
    ms_before=1.5,
    ms_after=2.,
    **job_kwargs,
)
analyzer_kilo4.compute(
    "templates",
    operators=["average", "median", "std"],
)
analyzer_kilo4.compute("noise_levels")
analyzer_kilo4

In [None]:
# further postprocessing extensions; only spike_amplitudes is given the job settings
analyzer_kilo4.compute("correlograms")
analyzer_kilo4.compute("unit_locations")
analyzer_kilo4.compute(
    "spike_amplitudes",
    **job_kwargs,
)
analyzer_kilo4.compute("template_similarity")
analyzer_kilo4

In [None]:
# quality metrics
# restrict the computation to the metrics listed here; the list was previously
# built but never handed to compute()
metric_names = ['firing_rate', 'presence_ratio', 'snr', 'isi_violation', 'amplitude_cutoff']
metrics = analyzer_kilo4.compute("quality_metrics", metric_names=metric_names).get_data()
metrics

In [None]:
# curation thresholds
amplitude_cutoff_thresh = 0.1
isi_violations_ratio_thresh = 1
presence_ratio_thresh = 0.9

# query string (applied to the metrics table in the next cell): all three conditions must hold
our_query = " & ".join(
    (
        f"(amplitude_cutoff < {amplitude_cutoff_thresh})",
        f"(isi_violations_ratio < {isi_violations_ratio_thresh})",
        f"(presence_ratio > {presence_ratio_thresh})",
    )
)
print(our_query)

In [None]:
# rows of the metrics table that satisfy the query built above
keep_units = metrics.query(our_query)
# the index of the filtered table is used as the list of unit ids to keep
keep_unit_ids = keep_units.index.values
keep_unit_ids

In [None]:
# new analyzer restricted to the units that passed the thresholds, saved in its own folder
analyzer_clean = analyzer_kilo4.select_units(
    keep_unit_ids,
    format='binary_folder',
    folder=rf"{DATA_DIR}\{recording_date}_out\analyzer_clean",
)

In [None]:
# write the spike sorting report (PNG figures) for the curated analyzer
sexp.export_report(
    analyzer_clean,
    rf"{DATA_DIR}\{recording_date}_out\report",
    format='png',
)

In [None]:
# load in existing sorting analyzer_clean if already created and saved
# analyzer_clean = si.load_sorting_analyzer(rf"{DATA_DIR}\{recording_date}_out\analyzer_clean")
# analyzer_clean

In [None]:
# # load in existing sorting analyzer if already created and saved
# analyzer_kilo4 = si.load_sorting_analyzer(rf"{DATA_DIR}\{recording_date}_out\analyzer_kilo4_binary", load_extensions=True)
# analyzer_kilo4

## 8. Export to Phy (for manual curation)

In [None]:
# export the (uncurated) kilosort4 analyzer to a local Phy folder
# Phy is a GUI for manual curation of spike sorting output
sexp.export_to_phy(
    analyzer_kilo4,
    rf"{DATA_DIR}\{recording_date}_out\phy_folder",
    verbose=True,
)

In [None]:
# once curation in Phy is done, load the curated sorting back,
# dropping the units labeled "noise"
sorting_curated_phy = se.read_phy(
    rf"{DATA_DIR}\{recording_date}_out\phy_folder",
    exclude_cluster_groups=["noise"],
)
print(sorting_curated_phy)

In [None]:
# stricter variant: load the curated sorting back,
# dropping the units labeled "noise" as well as those labeled "mua"
sorting_curated_phy = se.read_phy(
    rf"{DATA_DIR}\{recording_date}_out\phy_folder",
    exclude_cluster_groups=["noise", "mua"],
)
print(sorting_curated_phy)