In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib notebook
%config Completer.use_jedi = False

# SpikeInterface pipeline for Movshon Lab - Blackrock

In [2]:
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
from pprint import pprint
from datetime import datetime, timedelta
from isodate import duration_isoformat

import spikeextractors as se
import spiketoolkit as st
import spikesorters as ss
import spikecomparison as sc
import spikewidgets as sw
from nwb_conversion_tools.json_schema_utils import dict_deep_update
from nwb_conversion_tools.conversion_tools import save_si_object

## 1) Load recordings, compute LFP, and inspect signals

In [3]:
# Data files directory
dir_path = Path('/media/luiz/storage/taufferconsulting/client_ben/project_movshon/data')

# Test if file exists
file_path = dir_path / 'XX_LE_textures_20191128_002.ns6'
print(f'File exists: {file_path.is_file()}')

# Spikeinterface directory
dir_spikeinterface = dir_path / "spikeinterface"
dir_spikeinterface.mkdir(parents=True, exist_ok=True)
print(dir_spikeinterface)

File exists: True
/media/luiz/storage/taufferconsulting/client_ben/project_movshon/data/spikeinterface


In [4]:
# Choose Blackrock extension (ns*) to be used
nsx_to_load = 6

# Create recording extractor
recording = se.BlackrockRecordingExtractor(
    filename=str(file_path),
    nsx_to_load=nsx_to_load
)

# Set indvidual groups
recording.set_channel_groups(range(recording.get_num_channels()))

### Stub recording for fast testing; set to False for running processing pipeline on entire data

In [5]:
print(f"Num channels: {recording.get_num_channels()}")
print(f"Sampling rate: {recording.get_sampling_frequency()}")
print(f"Duration (s): {recording.get_num_frames() / recording.get_sampling_frequency()}")

Num channels: 192
Sampling rate: 30000.0
Duration (s): 1942.1709333333333


### Compute LFP

In [6]:
freq_min_lfp = 1
freq_max_lfp = 300
freq_resample_lfp = 1000.

# Apply bandpass filter
recording_lfp = st.preprocessing.bandpass_filter(recording, freq_min=freq_min_lfp, freq_max=freq_max_lfp)

# Resample lfp
recording_lfp = st.preprocessing.resample(recording_lfp, freq_resample_lfp)

In [7]:
print(f"Sampling frequency AP: {recording.get_sampling_frequency()}")
print(f"Sampling frequency LF: {recording_lfp.get_sampling_frequency()}")      

Sampling frequency AP: 30000.0
Sampling frequency LF: 1000.0


### Inspect signals

In [8]:
w_ts_ap = sw.plot_timeseries(recording, trange=[0, 5], channel_ids=[1, 2, 3])

<IPython.core.display.Javascript object>

In [9]:
# w_ts_lf = sw.plot_timeseries(recording_lfp, trange=[10, 15])

## 2) Pre-processing

In [10]:
apply_filter = True
apply_cmr = True
freq_min_hp = 300
freq_max_hp = 3000

In [11]:
if apply_filter:
    recording_processed = st.preprocessing.bandpass_filter(recording, freq_min=freq_min_hp, freq_max=freq_max_hp)
else:
    recording_processed = recording

if apply_cmr:
    recording_processed = st.preprocessing.common_reference(recording_processed)
    
# Stub recording for fast testing; set to False for running processing pipeline on entire data
stub_test = True
nsec_stub = 5
subr_ids = [i + 1 for i in range(10)]
if stub_test:
    recording_processed = se.SubRecordingExtractor(recording_processed, end_frame=int(nsec_stub*recording_processed.get_sampling_frequency()))
    recording_lfp = se.SubRecordingExtractor(recording_lfp, end_frame=int(nsec_stub*recording_lfp.get_sampling_frequency()))
    
# Subrecording for fast testing
subrec_processed = se.SubRecordingExtractor(
    parent_recording=recording_processed, 
    channel_ids=subr_ids
)
    
print(f"Original signal length: {recording.get_num_frames()}")
print(f"Processed signal length: {recording_processed.get_num_frames()}")

Original signal length: 58265128
Processed signal length: 150000


In [12]:
w_ts_ap = sw.plot_timeseries(recording_processed, trange=[0, 5], channel_ids=[1, 2, 3])

<IPython.core.display.Javascript object>

## 3) Run spike sorters    -     TODO

In [13]:
ss.installed_sorters()

['klusta']

In [14]:
sorter_list = [
    'klusta'
]

In [15]:
# Inspect sorter-specific parameters and defaults
for sorter in sorter_list:
    print(f"{sorter} params description:")
    pprint(ss.get_params_description(sorter))
    print("Default params:")
    pprint(ss.get_default_params(sorter))

klusta params description:
{'adjacency_radius': 'Radius in um to build channel neighborhood ',
 'chunk_mb': 'Chunk size in Mb for saving to binary format (default 500Mb)',
 'detect_sign': 'Use -1 (negative), 1 (positive) or 0 (both) depending on the '
                'sign of the spikes in the recording',
 'extract_s_after': 'Number of samples to cut out after the peak',
 'extract_s_before': 'Number of samples to cut out before the peak',
 'n_features_per_channel': 'Number of PCA features per channel',
 'n_jobs_bin': 'Number of jobs for saving to binary format (Default 1)',
 'num_starting_clusters': 'Number of initial clusters',
 'pca_n_waveforms_max': 'Maximum number of waveforms for PCA',
 'threshold_strong_std_factor': 'Strong threshold for spike detection',
 'threshold_weak_std_factor': 'Weak threshold for spike detection'}
Default params:
{'adjacency_radius': None,
 'chunk_mb': 500,
 'detect_sign': -1,
 'extract_s_after': 32,
 'extract_s_before': 16,
 'n_features_per_channel': 3,


In [16]:
# user-specific parameters
sorter_params = dict(
    klusta=dict()
)

In [17]:
# Choose which recording to use for sorting
rec_to_sort = recording_processed
# rec_to_sort = subrec_processed

# spike sort by group
sorting = ss.run_klusta(
    recording=rec_to_sort, 
    grouping_property='group',
    output_folder=dir_spikeinterface / "si_output",
)



RUNNING SHELL SCRIPT: /media/luiz/storage/taufferconsulting/client_ben/project_movshon/data/spikeinterface/si_output/0/run_klusta.sh
RUNNING SHELL SCRIPT: /media/luiz/storage/taufferconsulting/client_ben/project_movshon/data/spikeinterface/si_output/1/run_klusta.sh
RUNNING SHELL SCRIPT: /media/luiz/storage/taufferconsulting/client_ben/project_movshon/data/spikeinterface/si_output/2/run_klusta.sh
RUNNING SHELL SCRIPT: /media/luiz/storage/taufferconsulting/client_ben/project_movshon/data/spikeinterface/si_output/3/run_klusta.sh
RUNNING SHELL SCRIPT: /media/luiz/storage/taufferconsulting/client_ben/project_movshon/data/spikeinterface/si_output/4/run_klusta.sh
RUNNING SHELL SCRIPT: /media/luiz/storage/taufferconsulting/client_ben/project_movshon/data/spikeinterface/si_output/5/run_klusta.sh
RUNNING SHELL SCRIPT: /media/luiz/storage/taufferconsulting/client_ben/project_movshon/data/spikeinterface/si_output/6/run_klusta.sh
RUNNING SHELL SCRIPT: /media/luiz/storage/taufferconsulting/client_be

RUNNING SHELL SCRIPT: /media/luiz/storage/taufferconsulting/client_ben/project_movshon/data/spikeinterface/si_output/62/run_klusta.sh
RUNNING SHELL SCRIPT: /media/luiz/storage/taufferconsulting/client_ben/project_movshon/data/spikeinterface/si_output/63/run_klusta.sh
RUNNING SHELL SCRIPT: /media/luiz/storage/taufferconsulting/client_ben/project_movshon/data/spikeinterface/si_output/64/run_klusta.sh
RUNNING SHELL SCRIPT: /media/luiz/storage/taufferconsulting/client_ben/project_movshon/data/spikeinterface/si_output/65/run_klusta.sh
RUNNING SHELL SCRIPT: /media/luiz/storage/taufferconsulting/client_ben/project_movshon/data/spikeinterface/si_output/66/run_klusta.sh
RUNNING SHELL SCRIPT: /media/luiz/storage/taufferconsulting/client_ben/project_movshon/data/spikeinterface/si_output/67/run_klusta.sh
RUNNING SHELL SCRIPT: /media/luiz/storage/taufferconsulting/client_ben/project_movshon/data/spikeinterface/si_output/68/run_klusta.sh
RUNNING SHELL SCRIPT: /media/luiz/storage/taufferconsulting/cl

RUNNING SHELL SCRIPT: /media/luiz/storage/taufferconsulting/client_ben/project_movshon/data/spikeinterface/si_output/123/run_klusta.sh
RUNNING SHELL SCRIPT: /media/luiz/storage/taufferconsulting/client_ben/project_movshon/data/spikeinterface/si_output/124/run_klusta.sh
RUNNING SHELL SCRIPT: /media/luiz/storage/taufferconsulting/client_ben/project_movshon/data/spikeinterface/si_output/125/run_klusta.sh
RUNNING SHELL SCRIPT: /media/luiz/storage/taufferconsulting/client_ben/project_movshon/data/spikeinterface/si_output/126/run_klusta.sh
RUNNING SHELL SCRIPT: /media/luiz/storage/taufferconsulting/client_ben/project_movshon/data/spikeinterface/si_output/127/run_klusta.sh
RUNNING SHELL SCRIPT: /media/luiz/storage/taufferconsulting/client_ben/project_movshon/data/spikeinterface/si_output/128/run_klusta.sh
RUNNING SHELL SCRIPT: /media/luiz/storage/taufferconsulting/client_ben/project_movshon/data/spikeinterface/si_output/129/run_klusta.sh
RUNNING SHELL SCRIPT: /media/luiz/storage/taufferconsul

RUNNING SHELL SCRIPT: /media/luiz/storage/taufferconsulting/client_ben/project_movshon/data/spikeinterface/si_output/184/run_klusta.sh
RUNNING SHELL SCRIPT: /media/luiz/storage/taufferconsulting/client_ben/project_movshon/data/spikeinterface/si_output/185/run_klusta.sh
RUNNING SHELL SCRIPT: /media/luiz/storage/taufferconsulting/client_ben/project_movshon/data/spikeinterface/si_output/186/run_klusta.sh
RUNNING SHELL SCRIPT: /media/luiz/storage/taufferconsulting/client_ben/project_movshon/data/spikeinterface/si_output/187/run_klusta.sh
RUNNING SHELL SCRIPT: /media/luiz/storage/taufferconsulting/client_ben/project_movshon/data/spikeinterface/si_output/188/run_klusta.sh
RUNNING SHELL SCRIPT: /media/luiz/storage/taufferconsulting/client_ben/project_movshon/data/spikeinterface/si_output/189/run_klusta.sh
RUNNING SHELL SCRIPT: /media/luiz/storage/taufferconsulting/client_ben/project_movshon/data/spikeinterface/si_output/190/run_klusta.sh
RUNNING SHELL SCRIPT: /media/luiz/storage/taufferconsul

In [61]:
# Visualize pike waverforms
unit_id = 10
spike_id = 11

wv = sorting.get_unit_spike_features(unit_id=unit_id, feature_name='waveforms')
n_spikes, n_channels, n_samples = wv.shape
print(wv.shape)

plt.figure()
spk = np.squeeze(np.mean(wv, axis=1))
plt.plot(spk[spike_id, :])

(39, 192, 180)


<IPython.core.display.Javascript object>

[<matplotlib.lines.Line2D at 0x7f9fc4afeeb8>]

## 4) Post-processing: extract waveforms, templates, quality metrics, extracellular features

### Set postprocessing parameters

In [21]:
# Post-processing params
postprocessing_params = st.postprocessing.get_common_params()
pprint(postprocessing_params)

OrderedDict([('max_spikes_per_unit', 300),
             ('recompute_info', False),
             ('save_property_or_features', True),
             ('memmap', True),
             ('seed', 0),
             ('verbose', False),
             ('joblib_backend', 'loky')])


In [22]:
# (optional) change parameters
postprocessing_params['max_spikes_per_unit'] = 1000  # with None, all waveforms are extracted

### Set quality metric list

In [23]:
# Quality metrics
qc_list = st.validation.get_quality_metrics_list()
print(f"Available quality metrics: {qc_list}")

Available quality metrics: ['num_spikes', 'firing_rate', 'presence_ratio', 'isi_violation', 'amplitude_cutoff', 'snr', 'max_drift', 'cumulative_drift', 'silhouette_score', 'isolation_distance', 'l_ratio', 'd_prime', 'noise_overlap', 'nn_hit_rate', 'nn_miss_rate']


In [24]:
# (optional) define subset of qc
qc_list = ["snr", "isi_violation", "firing_rate"]

### Set extracellular features

In [25]:
# Extracellular features
ec_list = st.postprocessing.get_template_features_list()
print(f"Available EC features: {ec_list}")

Available EC features: ['peak_to_valley', 'halfwidth', 'peak_trough_ratio', 'repolarization_slope', 'recovery_slope']


In [26]:
# (optional) define subset of ec
ec_list = ["peak_to_valley", "halfwidth"]

### Postprocess all sorting outputs

In [27]:
# st.validation.compute_quality_metrics?

In [28]:
tmp_folder = dir_spikeinterface / 'tmp' / 'klusta'
tmp_folder.mkdir(parents=True, exist_ok=True)

# set local tmp folder
sorting.set_tmp_folder(tmp_folder)

# compute waveforms
waveforms = st.postprocessing.get_unit_waveforms(rec_to_sort, sorting, **postprocessing_params)

# compute templates
templates = st.postprocessing.get_unit_templates(rec_to_sort, sorting, **postprocessing_params)

# comput EC features
ec = st.postprocessing.compute_unit_template_features(rec_to_sort, sorting,
                                                      feature_names=ec_list, as_dataframe=True)
# compute QCs
qc = st.validation.compute_quality_metrics(sorting, recording=rec_to_sort, 
                                           metric_names=qc_list, as_dataframe=True)

# export to phy
phy_folder = dir_spikeinterface / 'phy' / 'klusta'
phy_folder.mkdir(parents=True, exist_ok=True)
st.postprocessing.export_to_phy(rec_to_sort, sorting, phy_folder)

Recomputing info


## 5) Compare with baseline - Extracting sorted spikes from Blackrock .nev

In [29]:
spikes_file = str(dir_path / 'XX_LE_textures_20191128_002.nev')

br_spike_extractor = se.BlackrockSortingExtractor(
    filename=spikes_file, 
    nsx_to_load=6
)

In [30]:
cmp_KL_BR = sc.compare_two_sorters(
    sorting1=sorting, sorting2=br_spike_extractor,
    sorting1_name='klusta', sorting2_name='blackrock'
)

In [31]:
# sw.plot_agreement_matrix(cmp_KL_BR)
cmp_KL_BR.agreement_scores

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,290,291,292,293,294,295,296,297,298,299
0,0.000597,0.000276,0.000148,0.001019,0.000112,0.000000,0.000236,0.000317,0.000183,0.000222,...,0.000000,0.00000,0.000000,0.000000,0.0,0.0,0.0,0.000000,0.000000,0.000000
1,0.000000,0.000551,0.000000,0.001011,0.000112,0.000167,0.000157,0.000739,0.000549,0.000222,...,0.000095,0.00000,0.000000,0.000000,0.0,0.0,0.0,0.000000,0.000000,0.000000
2,0.000000,0.000000,0.000149,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000056,...,0.000000,0.00000,0.000000,0.000000,0.0,0.0,0.0,0.000000,0.000000,0.000000
3,0.000199,0.000276,0.000148,0.002037,0.000335,0.000167,0.000315,0.000634,0.000428,0.000278,...,0.000095,0.00000,0.000000,0.000097,0.0,0.0,0.0,0.000000,0.000000,0.000178
4,0.000396,0.000412,0.000295,0.001505,0.001227,0.000499,0.000472,0.001054,0.000305,0.000555,...,0.000000,0.00000,0.000000,0.000000,0.0,0.0,0.0,0.000000,0.000000,0.000177
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
179,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000061,0.000000,...,0.000000,0.00000,0.000000,0.000000,0.0,0.0,0.0,0.000087,0.000000,0.000000
180,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,...,0.001139,0.00103,0.003922,0.000097,0.0,0.0,0.0,0.000087,0.000000,0.000177
181,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000061,0.000000,...,0.000000,0.00104,0.000000,0.000873,0.0,0.0,0.0,0.000000,0.000000,0.000177
182,0.000198,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000056,...,0.000189,0.00000,0.000000,0.000097,0.0,0.0,0.0,0.001214,0.004603,0.000000


In [32]:
# # retrieve sortings and sorter names
# sorting_list = []
# sorter_names_comp = []
# for result_name, sorting in sorting_outputs.items():
#     rec_name, sorter = result_name
#     sorting_list.append(sorting)
#     sorter_names_comp.append(sorter)
    
# # run multisorting comparison
# mcmp = sc.compare_multiple_sorters(sorting_list=sorting_list, name_list=sorter_names)

# # plot agreement results
# w_agr = sw.plot_multicomp_agreement(mcmp)

# # extract ensemble sorting
# sorting_ensemble = mcmp.get_agreement_sorting(minimum_agreement_count=2)
# sorting_outputs.update(sorting_ensemble=sorting_ensemble)

# save_si_object(
#     "sorting_ensemble", sorting_ensemble, dir_spikeinterface,
#     cache_raw=False, include_properties=True, include_features=False
# )

# 6) Automatic curation

In [33]:
# define curators and thresholds
isi_violation_threshold = 0.5
snr_threshold = 5
firing_rate_threshold = 0.1

In [34]:
sorting_auto_curated = []
sorter_names_curation = []
# for result_name, sorting in sorting_outputs.items():
#     rec_name, sorter = result_name
#     sorter_names_curation.append(sorter)

num_frames = rec_to_sort.get_num_frames()

# firing rate threshold
sorting_curated = st.curation.threshold_firing_rates(
    sorting,
    duration_in_frames=num_frames,
    threshold=firing_rate_threshold, 
    threshold_sign='less'
)

# isi violation threshold
sorting_curated = st.curation.threshold_isi_violations(
    sorting,
    duration_in_frames=num_frames,
    threshold=isi_violation_threshold, 
    threshold_sign='greater'
)

# isi violation threshold
sorting_curated = st.curation.threshold_snrs(
    sorting,
    recording=recording_processed,
    threshold=snr_threshold, 
    threshold_sign='less'
)
sorting_auto_curated.append(sorting_curated)

# 7) Quick save to NWB; writes only the spikes

## To complete the full conversion for other types of data, use the external script

In [62]:
# # Name your NWBFile and decide where you want it saved
# nwbfile_path = base_path / "blackrock.nwb"

# # Enter Session and Subject information here
# session_description = "Enter session description here."

# # Choose the sorting extractor from the notebook environment you would like to write to NWB
# chosen_sorting_extractor = sorting_outputs[('rec0', 'ironclust')]

# quick_write(
#     intan_folder_path=syntalos_folder,
#     session_description=session_description,
#     save_path=nwbfile_path,
#     sorting=chosen_sorting_extractor,
#     lfp=recording_lfp,
#     timestamps=recording.get_timestamps(),
#     overwrite=False
# )