# Setup

Imports

In [None]:
from pathlib import Path
import platform
import os
import shutil
import time
import numpy as np
import pandas as pd
from pprint import pprint
import matplotlib.pylab as plt
import spikeinterface.full as si    # may need to run pip install in the spikeinterface folder first
import probeinterface as pi
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)
# import ipywidgets
%matplotlib widget

# Import my custom modules
from src import spikeinterface_hp 
# from src import rhdutilities 
from src import load_probe_rhd 
import importlib  # Allow module reloads in case any changes were made after starting the kernel: importlib.reload(spikeinterface_hp)

# Tell SpikeInterface where the local Kilosort2 and IronClust checkouts live
# (relative to the notebook's working directory, resolved to absolute paths)
kilosort2_repo = Path('../../Kilosort-2.0').resolve()
ironclust_repo = Path('../../ironclust').resolve()
si.Kilosort2Sorter.set_kilosort2_path(kilosort2_repo)
si.IronClustSorter.set_ironclust_path(ironclust_repo)

# Report which SpikeInterface version this kernel is using
print(f"SpikeInterface version: {si.__version__}")

# Print list of installed sorters
# print(si.installed_sorters())

Load data from the server

In [None]:
# Set paths that are used throughout: results and data
# Other sessions (uncomment exactly one results_path):
# results_path = Path('Z:/Hannah/ephys/project2/HC07_221014') 
# results_path = Path('Z:/Hannah/ephys/project2/HC07_221015') 
# results_path = Path('Z:/Hannah/ephys/project2/HC07_221019')  Y
# results_path = Path('Z:/Hannah/ephys/project2/HC07_221021')  Y
# results_path = Path('Z:/Hannah/ephys/project2/HC07_221024')

results_path = Path('Z:/Hannah/ephys/project2/HC05_220819')   #  A 32 channel recording, shank A only

# The raw data lives in the subfolder whose name starts with "raw". There should
# be exactly one; sort the candidates so the choice is deterministic and fail
# with a clear message (instead of a bare IndexError) when none is found.
raw_dirs = sorted(f for f in results_path.iterdir() if f.is_dir() and f.name.startswith('raw'))
if not raw_dirs:
    raise FileNotFoundError(f"No subfolder starting with 'raw' found in {results_path}")
if len(raw_dirs) > 1:
    warnings.warn(
        f"Expected exactly one 'raw*' subfolder in {results_path}, "
        f"found {len(raw_dirs)}; using '{raw_dirs[0].name}'"
    )
data_path = raw_dirs[0] # Folder containing info.rhd and the recorded data
print(data_path)

# Reload the custom module so edits made after kernel start are picked up
importlib.reload(spikeinterface_hp)
recording_raw = spikeinterface_hp.read_intan_dat(data_path)

# Load the probe for this file - should automatically deal with missing channels!
importlib.reload(load_probe_rhd)
probe = load_probe_rhd.H6(data_path/'info.rhd', results_path/'probe.json')
recording_raw = recording_raw.set_probe(probe, group_mode="by_shank")

# Visualizations/checks
probe = recording_raw.get_probe() # sorts by index in .dat file
print(probe.to_dataframe(complete=True).loc[:, ['x','y','contact_ids', 'shank_ids', 'device_channel_indices']])
# pi.plotting.plot_probe(probe, with_channel_index=True) 
pi.plotting.plot_probe(probe, with_device_index=True) # Now the channel index matches the device index bc it was sorted


High-pass filter and apply common median reference

In [None]:
# High-pass only: it is not recommended to throw out high frequencies before sorting
recording_f = si.highpass_filter(recording_raw, freq_min=300)

# Common median reference across all channels
recording_cmr = si.common_reference(recording_f, reference='global', operator='median')

# Compare one second of traces (10 s - 11 s) before and after applying CMR
w = si.plot_timeseries(
    {"filt": recording_f, "common": recording_cmr},
    order_channel_by_depth=True,
    time_range=[10,11],
    channel_ids=range(16),
    show_channel_ids=True,
)

General sorting/waveform/phy settings

In [None]:
# --- Waveform extraction (passed to si.extract_waveforms) ---
ms_before = 1.5            # default 3
ms_after = 2.5             # default 4
max_spikes_per_unit = 500

# --- Phy export / parallel chunking (passed to si.export_to_phy and si.extract_waveforms) ---
max_ch_per_template = 8
chunk_duration = '10s'
n_jobs = -1

Load the Kilosort2 sorting (Kilosort2 itself must be run separately in MATLAB beforehand!)

In [None]:
# Read the Kilosort2 output folder as a sorting (needed to export the original sortings to phy)
sorting_ks = si.read_kilosort(results_path / 'kilosort2_output')

# Extract waveforms for every Kilosort2 unit from the CMR recording
waveform_folder = results_path / 'kilosort2_waveforms'
print(waveform_folder)
# TODO check effect of raw or cmr??
# NOTE: return_scaled: If True and recording has gain_to_uV/offset_to_uV properties, waveforms are converted to uV.
we = si.extract_waveforms(
    recording_cmr,
    sorting_ks,
    waveform_folder,
    return_scaled=True,
    overwrite=True,
    n_jobs=n_jobs,
    chunk_duration=chunk_duration,
    max_spikes_per_unit=max_spikes_per_unit,
    ms_before=ms_before,
    ms_after=ms_after,
)  # total_memory="10M",


Run Ironclust

In [None]:

# Run IronClust on the raw recording; any previous output folder is removed first
sorter_name = 'ironclust'
sorting_folder = results_path / f"{sorter_name}_output"
sorting_ic = si.run_sorter(
    sorter_name,
    recording_raw,
    sorting_folder,
    verbose=True,
    remove_existing_folder=True,
)  # filter=False, detect_threshold=1
print(sorting_ic)

# Persist the sorting in SpikeInterface's own (.npz) format in a separate folder
sorting_folder = results_path / f"{sorter_name}_si_output"
sorting_ic.save(folder=sorting_folder)
print(sorting_folder)



# Post-processing



Get Quality metrics for KS sorting

In [None]:
# Quality metrics for every unit in the Kilosort2 sorting
waveform_folder = results_path / 'kilosort2_waveforms'
# Re-load the waveform results from disk (only needed if picking up from here)
we = si.WaveformExtractor.load_from_folder(waveform_folder)

metric_names = [
    'snr',
    'isi_violation',
    'amplitude_cutoff',
    'isolation_distance',
    'firing_rate',
]
# NOTE: isi_violations_rate: "Rate of contaminating spikes as a fraction of overall rate. Higher values indicate more contamination" (== #sp_bad/#sp_total)
# Q: how to set bin width and refractory period for ISI violations?
# NOTE(review): 'isolation_distance' is a PCA-based metric; presumably principal components
# must already exist for this waveform folder - confirm against the installed SI version.

# load_if_exists=True reuses metrics already stored with the waveforms, if any
qm = si.compute_quality_metrics(we, load_if_exists=True, metric_names=metric_names)

# Save the metrics table next to the other results
score_filename = results_path / 'si_qm.csv'
qm.to_csv(score_filename)


Compare with Kilosort2 as reference

In [None]:

# Load both sortings (only needed if picking up from here):
#   - Kilosort2: read straight from the Kilosort output folder
#   - IronClust: read from the SpikeInterface-format folder saved after sorting
sorting_ks = si.read_kilosort(results_path/'kilosort2_output')
sorting_ic = si.load_extractor(results_path/'ironclust_si_output')
# sorting_ks = si.read_phy(results_path/'kilosort2_phy') # Alternate if saved as phy

sub = '' #add onto output folders

# Compare two sorters, with Kilosort2 as sorting1 (the reference)
comp_ks_ic = si.compare_two_sorters(sorting1=sorting_ks, sorting2=sorting_ic)  # returns SortingComparison object
print('Relative to Kilosort2:')
match_to_ks = comp_ks_ic.hungarian_match_12  # for each KS unit id: matched IC unit id, or -1 if unmatched
match_to_ic = comp_ks_ic.hungarian_match_21  # for each IC unit id: matched KS unit id, or -1 if unmatched

# Save the full agreement-score table (every KS unit vs every IC unit, not only the matched pairs)
# score_df = pd.DataFrame({'ks_ind': range(len(match_to_ks)), 'ic_ind':match_to_ks,'score': comp_ks_ic.agreement_scores.max(axis=1) })
score_filename = results_path/'si_score_all.csv'
comp_ks_ic.agreement_scores.to_csv(score_filename)

# Export the matching IC units relative to KS
match_to_ks = match_to_ks.rename('match')
match_filename = results_path/'si_match.csv'
match_to_ks.to_csv(match_filename) # header=False

# Select only the matched units from the kilosort sorting.
# select_units() expects unit *ids*, and Kilosort cluster ids are not guaranteed to be
# 0..N-1, so take the ids from the Series index rather than positional indices.
ks_ind = match_to_ks.index[match_to_ks != -1].to_numpy()
print(ks_ind)
sorting_clean = sorting_ks.select_units(ks_ind)
print(sorting_clean)


# Get waveforms for the matched units only
waveform_folder = results_path/('clean_waveforms'+sub)
we = si.extract_waveforms(recording_cmr, sorting_clean, waveform_folder,return_scaled=True, overwrite=True,
    n_jobs=n_jobs, chunk_duration=chunk_duration, max_spikes_per_unit=max_spikes_per_unit,ms_before=ms_before, ms_after=ms_after)


# Phy-export-specific overrides. Kept under their own names so the shared
# chunk_duration / max_ch_per_template settings are not clobbered, and so that
# re-running this cell extracts waveforms with the same parameters every time.
phy_chunk_duration = '1s'
phy_max_ch_per_template = 1

# Export to Phy
phy_folder = results_path/('clean_phy'+sub)
we.recording = recording_raw # Just do this so that the path to the raw .dat file is stored in params.py
si.export_to_phy(we, output_folder=phy_folder,remove_if_exists=True,copy_binary=False,compute_amplitudes=True,
    max_channels_per_template=phy_max_ch_per_template, chunk_duration=phy_chunk_duration, n_jobs=n_jobs)


Example: Quality metrics & automatic curation for single sorter

In [None]:
%%script false
# Example workflow: compute quality metrics for the Kilosort2 waveforms and keep
# only the units that pass simple thresholds.
waveform_folder = results_path/'kilosort2_waveforms'

# Load waveforms
we = si.WaveformExtractor.load_from_folder(waveform_folder) # Only needed if picking up from here

# Specify quality metrics
print(si.get_quality_metric_list())
metric_names=['snr', 'isi_violation', 'amplitude_cutoff','isolation_distance','firing_rate']
# NOTE: isi_violations_rate: "Rate of contaminating spikes as a fraction of overall rate. Higher values indicate more contamination" (== #sp_bad/#sp_total)
# Q: how to set bin width and refractory period for ISI violations?

# Compute quality metrics
qm = si.compute_quality_metrics(we,load_if_exists=True,metric_names=metric_names)
pprint(qm)

# Plot quality metrics
si.plot_quality_metrics(we, include_metrics=["amplitude_cutoff", "isi_violations_rate","firing_rate"])

# Screen sorting based on quality metrics: keep units below all three thresholds
firing_rate_cutoff = 50
isi_viol_thresh =  0.1 # 0.2 is KS2 default
amplitude_cutoff_thresh = 0.05
our_query = f"firing_rate < {firing_rate_cutoff} & isi_violations_rate < {isi_viol_thresh} & amplitude_cutoff < {amplitude_cutoff_thresh} "

qm_keep = qm.query(our_query)
keep_unit_ids = qm_keep.index.values
sorting_clean = we.sorting.select_units(keep_unit_ids)
# Show the sorting before and after screening. The un-screened sorting is the one
# attached to the waveform extractor (no standalone `sorting` variable exists in this notebook).
print(we.sorting)
print(sorting_clean)


Launch phy


In [None]:

%%script false # don't run this cell!

# Open the "clean" phy export in the phy template GUI
from phy.apps.template import template_gui  # local import: phy is only used in this cell

phy_folder = results_path / 'clean_phy'
template_gui(phy_folder / 'params.py')
Parking lot
%%script false # don't run this cell!

# Scratch: inspect the container type, element type, contents and shape of a saved .npy file
from pathlib import Path
from pprint import pprint
import numpy as np

# file = Path(r"D:\hannah\Dropbox\code\spikesort\spikesort-hp\results\kilosort2_waveforms\spike_amplitudes\amplitude_segment_0.npy")
file = Path(r"Z:\Hannah\ephys\project2\HC05_220819\intersect_waveforms\spike_amplitudes\amplitude_segment_0.npy")
print(file)

temp = np.load(file)
print(type(temp))       # type of the loaded object
print(type(temp[0]))    # type of its first element
pprint(temp)
np.shape(temp)


# Example: read a curated phy folder back in (dropping clusters in the 'noise' group)
# and keep only the units whose 'KSLabel' property is "good"
sorting_phy = si.PhySortingExtractor('path-to-phy-folder', exclude_cluster_groups=['noise'])
good_ks_units = [
    u for u in sorting_phy.get_unit_ids()
    if sorting_phy.get_unit_property(u, 'KSLabel') == 'good'
]
sorting_ks_good = sorting_phy.select_units(good_ks_units)
