In [None]:
# Notebook settings
# Load the autoreload extension and set it to mode 2
%load_ext autoreload
%autoreload 2
# Interactive matplotlib backend, currently disabled
# %matplotlib notebook
# Disable Jedi for the IPython tab completer
%config Completer.use_jedi = False

# SpikeInterface pipeline for Movshon Lab - OpenEphys

In [None]:
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
from pprint import pprint
import pytz

import spikeextractors as se
import spiketoolkit as st
import spikesorters as ss
import spikecomparison as sc
import spikewidgets as sw

## 1) Load recordings, compute LFP, and inspect signals

In [None]:
# Folder holding the OpenEphys recording (alternative dataset kept commented out)
# dir_path = Path('/home/luiz/storage/taufferconsulting/client_ben/project_movshon/movshon_data/oephys')
dir_path = Path('/home/luiz/storage/taufferconsulting/client_ben/project_movshon/movshon_data/expo/exampledata/expo_openephys/m666l3#7/openephys')

# Output folder for SpikeInterface results, created next to the data if missing
dir_spikeinterface = dir_path.joinpath("spikeinterface")
dir_spikeinterface.mkdir(parents=True, exist_ok=True)
print(dir_spikeinterface)

# Build the recording extractor from the OpenEphys folder
recording = se.OpenEphysRecordingExtractor(folder_path=dir_path)

# Short summary of the loaded recording (a blank line first, then one item per line)
print(
    "",
    f"Num channels: {recording.get_num_channels()}",
    f"Sampling rate: {recording.get_sampling_frequency()}",
    f"Duration (s): {recording.get_num_frames() / recording.get_sampling_frequency()}",
    sep="\n",
)

### Compute LFP

In [None]:
# LFP band limits and the sampling rate the filtered signal is resampled to
freq_min_lfp, freq_max_lfp = 1, 300
freq_resample_lfp = 1000.

# Band-pass the raw recording first, then resample the filtered result
recording_lfp = st.preprocessing.resample(
    recording=st.preprocessing.bandpass_filter(
        recording=recording,
        freq_min=freq_min_lfp,
        freq_max=freq_max_lfp,
    ),
    resample_rate=freq_resample_lfp,
)

# Compare the sampling frequency before and after resampling
print(
    f"Sampling frequency Raw: {recording.get_sampling_frequency()}",
    f"Sampling frequency LF: {recording_lfp.get_sampling_frequency()}",
    sep="\n",
)

### Inspect signals

In [None]:
# Raw traces of channels 1-3 over trange [0, 5] (presumably seconds - confirm)
w_ts_raw = sw.plot_timeseries(
    recording,
    trange=[0, 5],
    channel_ids=[1, 2, 3],
)

In [None]:
# Same channels and time range as the raw plot, drawn from the LFP recording
w_ts_lf = sw.plot_timeseries(
    recording_lfp,
    trange=[0, 5],
    channel_ids=[1, 2, 3],
)

## 2) Pre-processing
- Filters
- Common-reference removal
- Remove bad channels
- Remove stimulation artifacts

Ref: https://spikeinterface.readthedocs.io/en/latest/modules/toolkit/plot_1_preprocessing.html#preprocessing-tutorial

In [None]:
# Switches for the pre-processing steps: band-pass filter and common reference
apply_filter = apply_cmr = True

# Lower / upper bounds of the band-pass filter applied when apply_filter is set
freq_min_hp, freq_max_hp = 300, 3000

In [None]:
# Start from the raw recording; each enabled step wraps the previous result
recording_processed = recording
if apply_filter:
    recording_processed = st.preprocessing.bandpass_filter(
        recording_processed,
        freq_min=freq_min_hp,
        freq_max=freq_max_hp,
    )
if apply_cmr:
    recording_processed = st.preprocessing.common_reference(recording_processed)

# Stub recording for fast testing; set to False for running processing pipeline on entire data
stub_test = True
nsec_stub = 10                   # stub length, multiplied by the sampling frequency below
subr_ids = list(range(1, 11))    # channel ids 1..10 kept in the stub
if stub_test:
    # Last frame of the stub for each recording (their sampling frequencies differ)
    stub_end_processed = int(nsec_stub*recording_processed.get_sampling_frequency())
    stub_end_lfp = int(nsec_stub*recording_lfp.get_sampling_frequency())

    recording_processed = se.SubRecordingExtractor(
        parent_recording=recording_processed,
        channel_ids=subr_ids,
        end_frame=stub_end_processed,
    )
    recording_lfp = se.SubRecordingExtractor(recording_lfp, end_frame=stub_end_lfp)

print(
    f"Original signal length: {recording.get_num_frames()}",
    f"Processed signal length: {recording_processed.get_num_frames()}",
    sep="\n",
)

In [None]:
# Pre-processed traces of channels 1-3 over trange [0, 5], for comparison with the raw plot
w_ts_processed = sw.plot_timeseries(
    recording_processed,
    trange=[0, 5],
    channel_ids=[1, 2, 3],
)

## 3) Run spike sorters

Ref: https://spikeinterface.readthedocs.io/en/latest/sortersinfo.html

In [None]:
# Show the sorters reported by spikesorters; the value is displayed as the cell output
ss.installed_sorters()

In [None]:
# Sorters whose parameters are inspected below
sorter_list = ['klusta']

# For every sorter, print the description of its parameters followed by its defaults
for sorter_name in sorter_list:
    print(f"{sorter_name} params description:")
    pprint(ss.get_params_description(sorter_name))

    print("Default params:")
    pprint(ss.get_default_params(sorter_name))

    print()

In [None]:
# Recording handed to the sorter (the pre-processed one)
rec_to_sort = recording_processed

# Klusta writes its output under the SpikeInterface directory
klusta_output_folder = dir_spikeinterface / "si_output"

# Run klusta and keep the resulting sorting
sorting = ss.run_klusta(recording=rec_to_sort, output_folder=klusta_output_folder)

## 4) Post-processing
- Compute spike waveforms
- Compute unit templates
- Compute extracellular features

Ref: https://spikeinterface.readthedocs.io/en/latest/modules/toolkit/plot_2_postprocessing.html

### Postprocess sorting results

In [None]:
# Post-processing params
# Fetch the common parameter dict from spiketoolkit and show its contents
postprocessing_params = st.postprocessing.get_common_params()
pprint(postprocessing_params)

In [None]:
# (optional) change parameters
# This dict is forwarded as keyword arguments to the waveform and template calls
postprocessing_params['max_spikes_per_unit'] = 1000  # with None, all waveforms are extracted

In [None]:
# Local tmp folder handed to the sorting object
tmp_folder = dir_spikeinterface.joinpath('tmp', 'klusta')
tmp_folder.mkdir(parents=True, exist_ok=True)
sorting.set_tmp_folder(tmp_folder)

# Waveforms and templates, both computed with the shared post-processing parameters
waveforms = st.postprocessing.get_unit_waveforms(
    rec_to_sort,
    sorting,
    **postprocessing_params,
)
templates = st.postprocessing.get_unit_templates(
    rec_to_sort,
    sorting,
    **postprocessing_params,
)

# Export recording + sorting to a phy folder
phy_folder = dir_spikeinterface.joinpath('phy', 'klusta')
phy_folder.mkdir(parents=True, exist_ok=True)
st.postprocessing.export_to_phy(rec_to_sort, sorting, phy_folder)

In [None]:
# Visualize spike template waveforms
unit_id = 1

# `templates` was computed for all units of `sorting`; per the spiketoolkit docs it is
# a list with one entry per unit, in sorting.get_unit_ids() order. It must therefore be
# indexed by position, not by unit id (the two only coincide when ids are 0..N-1).
unit_ids = list(sorting.get_unit_ids())
if unit_id not in unit_ids:
    raise ValueError(f"Unit {unit_id} not found in sorting; available unit ids: {unit_ids}")
unit_index = unit_ids.index(unit_id)

fig, ax = plt.subplots(figsize=(12, 6))
spk = np.squeeze(templates[unit_index])
# Transpose so that each channel is drawn as its own line
ax.plot(spk.T)
ax.set_title(f'Template of unit {unit_id}', fontsize=12)
ax.set_ylabel('Avg spike trace per channel', fontsize=12);

In [None]:
# Extracellular features
# Ask spiketoolkit for the template feature names it offers and print them
ec_list = st.postprocessing.get_template_features_list()
print(f"Available EC features: {ec_list}")

In [None]:
# (optional) define subset of ec
# Overrides the full list obtained above; skip this cell to compute every feature
ec_list = ["peak_to_valley", "halfwidth"]

In [None]:
# Compute the selected EC features for every unit, returned as a DataFrame
ec = st.postprocessing.compute_unit_template_features(
    rec_to_sort, sorting, feature_names=ec_list, as_dataframe=True
)

# Preview the first rows
ec.head()

## 5) Automatic curation

You can automatically curate the spike sorting output using the quality metrics.

Ref: https://spikeinterface.readthedocs.io/en/latest/modules/toolkit/plot_4_curation.html

In [None]:
# Tip: run `st.validation.compute_quality_metrics?` in a scratch cell to see its signature

# Quality metrics
# Ask spiketoolkit for the quality metric names it offers and print them
qc_list = st.validation.get_quality_metrics_list()
print(f"Available quality metrics: {qc_list}")

In [None]:
# (optional) define subset of qc
# Overrides the full list obtained above; these are the metrics thresholded during curation
qc_list = ["snr", "isi_violation", "firing_rate"]

In [None]:
# Compute the selected quality metrics for every unit, returned as a DataFrame
qc = st.validation.compute_quality_metrics(
    sorting=sorting, recording=rec_to_sort, metric_names=qc_list, as_dataframe=True
)

# Preview the first 10 rows
qc.head(10)

In [None]:
# Curation thresholds, one per quality metric used in the curation step
firing_rate_threshold, isi_violation_threshold, snr_threshold = 0.1, 0.6, 4

In [None]:
# Recording length in frames, required by the rate-based curators
num_frames = rec_to_sort.get_num_frames()

# Step 1 - firing rate threshold (sign 'less': presumably drops units below it - confirm)
sorting_after_rate = st.curation.threshold_firing_rates(
    sorting,
    duration_in_frames=num_frames,
    threshold=firing_rate_threshold,
    threshold_sign='less',
)

# Step 2 - ISI violation threshold (sign 'greater'), applied to the output of step 1
sorting_after_isi = st.curation.threshold_isi_violations(
    sorting_after_rate,
    duration_in_frames=num_frames,
    threshold=isi_violation_threshold,
    threshold_sign='greater',
)

# Step 3 - SNR threshold (sign 'less'), applied to the output of step 2
sorting_curated = st.curation.threshold_snrs(
    sorting_after_isi,
    recording=rec_to_sort,
    threshold=snr_threshold,
    threshold_sign='less',
)

In [None]:
# Unit counts before and after the automatic curation
n_units_before = len(sorting.get_unit_ids())
n_units_after = len(sorting_curated.get_unit_ids())

print(f'Number of sorted units before curation: {n_units_before}')
print(f'Number of sorted units after curation: {n_units_after}')

## 6) Quick save to NWB using SpikeInterface

Ref: https://pynwb.readthedocs.io/en/stable/pynwb.file.html#pynwb.file.NWBFile

In [None]:
# Save results to NWB file
output_nwb = 'openephys_si_results.nwb'

# Add customized Metadata info - Optional
# NOTE(review): this reads the extractor's private `_fileobj` attribute, which may
# change between spikeextractors versions - confirm when upgrading.
session_start_time = recording._fileobj.experiments[0].datetime

# Localize with the regional zone instead of 'EST': in pytz, 'EST' is a fixed UTC-5
# offset with no daylight-saving rules, so sessions recorded during DST would be
# stamped one hour off. Assumes the recording was made in New York (institution: NYU).
session_start_time_tzaware = pytz.timezone('America/New_York').localize(session_start_time)

# Start from the metadata derived from the recording, then fill in the session details
metadata = se.NwbRecordingExtractor.get_nwb_metadata(recording=rec_to_sort)
metadata['NWBFile'].update(
    session_start_time=session_start_time_tzaware,
    session_description='a detailed description of this experimental session...',
    institution='NYU',
    lab='Movshon lab',
    pharmacology='Description of drugs used',
    experimenter=['Person1', 'Person2'],
    keywords=['openephys', 'tutorial', 'etc']
)
metadata['Ecephys']['Device'][0].update(description='a detailed description of this device')

# Write voltage traces data (overwrite=True: replaces any existing file at this path)
se.NwbRecordingExtractor.write_recording(
    recording=rec_to_sort,
    save_path=output_nwb,
    overwrite=True,
    metadata=metadata
)

# Write spiking data (overwrite=False: added to the file written just above)
se.NwbSortingExtractor.write_sorting(
    sorting=sorting_curated,
    save_path=output_nwb,
    overwrite=False
)

In [None]:
# Check NWB file with widgets
from pynwb import NWBFile, NWBHDF5IO
from nwbwidgets import nwb2widget

# Open the NWB file read-only and display it with the widget.
# The handle stays open after this cell and is closed in the following cell
# (presumably so the widget can keep reading from the file - confirm).
io = NWBHDF5IO(output_nwb, 'r')
nwbfile = io.read()
nwb2widget(nwbfile)

In [None]:
# Close the NWB handle opened for the widget before the file is written to again
io.close()

## 7) Include Expo trials with NWB conversion tools

In [None]:
from movshon_lab_to_nwb import MovshonExpoNWBConverter

In [None]:
# Source data: the Expo XML file and the TTL file inside the OpenEphys folder
base_path = Path('/home/luiz/storage/taufferconsulting/client_ben/project_movshon/movshon_data/expo/exampledata/expo_openephys/m666l3#7')
expo_file = base_path.joinpath('m666l3#7[ori16].xml')
ttl_file = base_path.joinpath('openephys/100_ADC1.continuous')

# The converter expects paths as strings, keyed by data interface name
source_data = {
    "ExpoDataInterface": {
        "expo_file": str(expo_file),
        "ttl_file": str(ttl_file),
    }
}

# Initialize converter
converter = MovshonExpoNWBConverter(source_data=source_data)

# Conversion options, keyed by data interface name
conversion_options = {"ExpoDataInterface": {"convert_expo": True}}

# Run against the existing NWB file (overwrite=False) with empty extra metadata
converter.run_conversion(
    metadata={},
    nwbfile_path=output_nwb,
    overwrite=False,
    conversion_options=conversion_options,
)

In [None]:
# Re-open the NWB file read-only after the Expo conversion and display it again.
# The handle stays open after this cell and is closed in the following cell.
io = NWBHDF5IO(output_nwb, 'r')
nwbfile = io.read()
nwb2widget(nwbfile)

In [None]:
# Close the NWB handle opened for the widget
io.close()