# Setup

In [None]:
from pathlib import Path
import platform
import os
import shutil
import numpy as np
import pandas as pd
from pprint import pprint
import matplotlib.pylab as plt
import spikeinterface.full as si    # may need to run pip install in the spikeinterface folder first
import probeinterface as pi
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)
# import ipywidgets
%matplotlib widget

# Import my custom modules
from src import spikeinterface_hp 
from src import load_probe_rhd 
import importlib  # Allow module reloads in case any changes were made after starting the kernel: importlib.reload(spikeinterface_hp)

# Tell SpikeInterface where the Kilosort2 and IronClust source checkouts live
# (resolved to absolute paths, relative to the notebook's working directory).
si.Kilosort2Sorter.set_kilosort2_path(Path('../../Kilosort-2.0').resolve())
si.IronClustSorter.set_ironclust_path(Path('../../ironclust').resolve())

# Record which SpikeInterface version this notebook ran against
print(f"SpikeInterface version: {si.__version__}")

# Uncomment to list every sorter SpikeInterface can find:
# print(si.installed_sorters())

# Spike sort with Ironclust

In [None]:
# Optional: inspect / modify IronClust parameters before sorting
# default_ic_params = si.IronClustSorter.default_params()
# pprint(default_ic_params)
# si.get_params_description('ironclust')

# Spike sort every non-excluded recording listed in the spreadsheet with IronClust
T = pd.read_excel(Path('D:/hannah/Dropbox/alab/Analysis/RECORDING_DEPTH_CHICK.xlsx'), sheet_name=0, header=0)
T = T.drop(T[T.exclude == 1].index)  # drop rows flagged for exclusion
T_height = T.shape[0]  # number of recordings to process
T_filename = T["filename"]
T_bad_chan = T["bad_chan"]

pprint(T)

overwrite = True  # re-sort recordings whose '<sorter>_si_output' folder already exists

# Settings that do not change between recordings
sorter_name = 'ironclust'
ephys_root = Path('Z:/Hannah/ephys/project2/')
failed = []  # (filename, error repr) for every recording that could not be sorted

# for ii in range(21,T_height):
for ii in range(T_height):
    try:

        # Load the raw ephys recording
        results_path = ephys_root/T_filename.iloc[ii]   #  A 32 channel recording, shank A only
        # The raw data live in a subfolder whose name starts with 'raw' - there should be exactly one
        raw_dirs = [f for f in results_path.iterdir() if f.is_dir() and f.name.startswith('raw')]
        if not raw_dirs:
            raise FileNotFoundError(f"no 'raw*' subfolder in {results_path}")
        if len(raw_dirs) > 1:
            print(f"WARNING: several 'raw*' subfolders in {results_path}; using {raw_dirs[0]}")
        data_path = raw_dirs[0]  # folder containing the intan .rhd file etc
        print(data_path)
        recording_raw = spikeinterface_hp.read_intan_dat(data_path)
        # Load the probe for this file - should automatically deal with missing channels!
        probe = load_probe_rhd.H6(data_path/'info.rhd', results_path/'probe.json')
        recording_raw = recording_raw.set_probe(probe, group_mode="by_shank")

        # TODO: remove bad_chan (debugging below)
        # channel_ids = recording_raw.channel_ids
        # good_ids = []
        # bad_ids = T_bad_chan.iloc[ii]
        # good_ids = [recording_raw.channel_ids[i] for i in range(recording_raw.get_num_channels()) if i not in bad_ids]
        # recording_raw = recording_raw.channel_slice(good_ids)
        # probe2 = recording_raw.get_probe()
        # si.plot_probe(probe2)

        sorting_folder = results_path/(sorter_name+'_output')        # raw sorter output (large, deleted below)
        sorting_folder_si = results_path/(sorter_name+'_si_output')  # SpikeInterface copy that is kept
        print(sorting_folder_si)

        # TEMP: always clear any leftover raw sorter output, even for recordings that are skipped
        if sorting_folder.is_dir():
            print(f'Removing leftover {sorting_folder}')
            shutil.rmtree(sorting_folder, ignore_errors=True)

        # Run sorter if not already run (or if overwriting)
        if sorting_folder_si.is_dir():
            if not overwrite:
                print('skipped, already run')
                continue
            print('Overwriting')
            shutil.rmtree(sorting_folder_si, ignore_errors=True)

        sorting_ic = si.run_sorter(sorter_name, recording_raw, sorting_folder, verbose=True, remove_existing_folder=True) # filter=False, detect_threshold=1
        print(sorting_ic)

        # Save sorting in spike interface .npz format
        sorting_ic.save(folder=sorting_folder_si)

        # Delete the useless & large original ironclust output folder
        shutil.rmtree(sorting_folder, ignore_errors=True)

    except Exception as err:
        # Best-effort batch: keep going with the other recordings, but report which one
        # failed and why instead of silently discarding the error.
        print(f'skipped, error in {T_filename.iloc[ii]}: {err!r}')
        failed.append((T_filename.iloc[ii], repr(err)))

if failed:
    print(f'{len(failed)} recording(s) failed:')
    pprint(failed)


In [None]:

# Debug cell: inspect the probe map and the bad-channel list for a single recording.

# Reload the non-excluded recording list (same spreadsheet as the IronClust sorting cell)
T = pd.read_excel(Path('D:/hannah/Dropbox/alab/Analysis/RECORDING_DEPTH_CHICK.xlsx'), sheet_name=0, header=0)
T = T.drop(T[T.exclude == 1].index)
T_height = T.shape[0]
T_filename = T["filename"]
T_bad_chan = T["bad_chan"]

# Choose the recording explicitly. This cell used to rely on `ii`, `results_path` and
# `data_path` leaking out of the sorting loop, which left them at the last row of T.
ii = T_height - 1
results_path = Path('Z:/Hannah/ephys/project2/')/T_filename.iloc[ii]
data_path = [f for f in results_path.iterdir() if f.is_dir() and f.name.startswith('raw')][0]  # the single raw* subfolder

recording_raw = spikeinterface_hp.read_intan_dat(data_path)
probe = load_probe_rhd.H6(data_path/'info.rhd', results_path/'probe.json')
recording_raw = recording_raw.set_probe(probe, group_mode="by_shank")

channel_ids = recording_raw.channel_ids
pprint(channel_ids)

# 'bad_chan' entries are converted to 0-based indices with -1 below (presumably the
# spreadsheet stores 1-based channel numbers - confirm). Empty cells arrive from pandas as
# NaN and single numbers as int/float, neither of which eval() accepts, so handle them here.
bad_chan = T_bad_chan.iloc[ii]
if isinstance(bad_chan, str):
    # NOTE(review): eval() executes arbitrary spreadsheet text; ast.literal_eval would be
    # safer if the cells only ever hold literal lists such as "[3, 17]".
    bad_chan = eval(bad_chan)
elif pd.isna(bad_chan):
    bad_chan = []
bad_ids = np.atleast_1d(np.asarray(bad_chan)) - 1
# pprint(bad_ids)
good_ids = [channel_ids[i] for i in range(recording_raw.get_num_channels()) if i not in bad_ids]
# pprint(good_ids)
# recording_raw = recording_raw.channel_slice(range(32))

probe = recording_raw.get_probe()
# pi.plotting.plot_probe(probe,with_contact_id=True)
si.plot_probe_map(recording_raw, with_channel_index=True) # alternatives: with_device_index, with_contact_id, with_channel_ids
# si.plot_probe_map(recording, with_channel_ids=True)



# Find consensus units

In [None]:
# Find units that Kilosort2 and IronClust agree on, for one recording.
# Uses T_filename from the spreadsheet loaded in the sorting cell above.
# for ii in range(T_height):
ii = 21  # row of the spreadsheet to process

# Locate this recording's results folder
results_path = Path('Z:/Hannah/ephys/project2/')/T_filename.iloc[ii]   #  A 32 channel recording, shank A only
print(results_path)

# Load two sortings
sorting_ks = si.read_kilosort(results_path/'kilosort2_output')
sorting_ic = si.load_extractor(results_path/'ironclust_si_output') # Load the sorting results (only needed if picking up from here)
print(sorting_ks)
print(sorting_ic)

# Compare two sorters
comp_ks_ic = si.compare_two_sorters(sorting1=sorting_ks, sorting2=sorting_ic, match_score=.2)  # returns SortingComparison object
# Series of matched IC unit ids, -1 where a KS unit has no match
# (indexed by KS unit id per the SpikeInterface comparison docs)
match_to_ks = comp_ks_ic.hungarian_match_12
print('Relative to Kilosort2:')
pprint(match_to_ks)

# Plot agreement matrix *** comment if looping
si.plot_agreement_matrix(comp_ks_ic)

# Select only the matched units from the kilosort sorting.
# select_units() expects unit *ids*, so take them from the Series index. Positional
# indices (what np.nonzero gave before) only coincide with the ids when the KS cluster
# ids are exactly 0..N-1, which is not guaranteed (e.g. after manual curation in Phy).
ks_ind = match_to_ks.index[match_to_ks != -1].values
print(ks_ind)
sorting_clean = sorting_ks.select_units(ks_ind)
print(sorting_clean)

# Export the matching IC units relative to KS
match_to_ks = match_to_ks.rename('match')
match_filename = results_path/'si_match.csv'
match_to_ks.to_csv(match_filename) # header=False


# # Save all the best agreement scores for each KS cell (not necessarily the matched scores)
# # score_df = pd.DataFrame({'ks_ind': range(len(match_to_ks)), 'ic_ind':match_to_ks,'score': comp_ks_ic.agreement_scores.max(axis=1) })
# score_filename = results_path/'si_score_all.csv'
# comp_ks_ic.agreement_scores.to_csv(score_filename)

# Example of loading the curated phy output back in and selecting units marked "good"
# sorting_phy = si.PhySortingExtractor('path-to-phy-folder', exclude_cluster_groups=['noise'])
# good_ks_units = []
# for u in sorting_phy.get_unit_ids():
#     if sorting_phy.get_unit_property(u, 'KSLabel') == 'good':
#         good_ks_units.append(u)
# sorting_ks_good = sorting_phy.select_units(good_ks_units)


# More post-processing

Export final comparison sorting to Phy?
(Not necessary if I label good cells in the existing Kilosort folder — but save any manual sorting first.)

In [None]:
%%script false
data_path = [f for f in results_path.iterdir() if f.is_dir() and f.name.startswith('raw')] # subfolder starting with raw*. Should only be one!
data_path = data_path[0] # Set data_path to the folder containing the intan.rhd file etc
print(data_path)
recording_raw = spikeinterface_hp.read_intan_dat(data_path)

# Load the probe for this file - should automatically deal with missing channels!
probe = load_probe_rhd.H6(data_path/'info.rhd', results_path/'probe.json')
recording_raw = recording_raw.set_probe(probe, group_mode="by_shank")

# Filter
recording_f = si.highpass_filter(recording_raw, freq_min=300) # Not recommended to throw out high frequencies before sortin
recording_cmr = si.common_reference(recording_f, reference='global', operator='median')

To export the original sortings to phy:

# Export the clean (KS∩IC) sorting to Phy.
# Needs `recording_cmr` / `recording_raw` (preprocessing code above) and `sorting_clean`
# (consensus cell).

# Waveform-extraction / job settings, defined before their first use.
# TODO(review): n_jobs, max_spikes_per_unit, ms_before and ms_after were referenced here but never
# defined in the notebook. The values below are assumptions (believed to be SpikeInterface's
# extract_waveforms defaults). Confirm they match the values used for earlier runs.
n_jobs = -1                  # -1 = use all CPU cores
chunk_duration = '1s'
max_spikes_per_unit = 500
ms_before = 3.0
ms_after = 4.0
max_ch_per_template = 1

# Get waveforms
sub = '' # suffix added onto output folder names
waveform_folder = results_path/('clean_waveforms'+sub)
we = si.extract_waveforms(recording_cmr, sorting_clean, waveform_folder, return_scaled=True, overwrite=True,
    n_jobs=n_jobs, chunk_duration=chunk_duration, max_spikes_per_unit=max_spikes_per_unit, ms_before=ms_before, ms_after=ms_after)

# Export to Phy
phy_folder = results_path/('clean_phy'+sub)
we.recording = recording_raw # Just do this so that the path to the raw .dat file is stored in params.py
si.export_to_phy(we, output_folder=phy_folder, remove_if_exists=True, copy_binary=False, compute_amplitudes=True,
    max_channels_per_template=max_ch_per_template, chunk_duration=chunk_duration, n_jobs=n_jobs)




Get quality metrics for the KS sorting

In [None]:
# Quality metrics for the full Kilosort2 sorting of the recording at `results_path`.

# Rebuild the preprocessed recording here. Otherwise `recording_cmr` is only defined by a
# disabled (%%script false) cell and would be undefined on a fresh kernel.
data_path = [f for f in results_path.iterdir() if f.is_dir() and f.name.startswith('raw')][0]  # the single raw* subfolder
recording_raw = spikeinterface_hp.read_intan_dat(data_path)
probe = load_probe_rhd.H6(data_path/'info.rhd', results_path/'probe.json')
recording_raw = recording_raw.set_probe(probe, group_mode="by_shank")
recording_f = si.highpass_filter(recording_raw, freq_min=300)
recording_cmr = si.common_reference(recording_f, reference='global', operator='median')

# Waveform-extraction / job settings.
# TODO(review): these were referenced but never defined in the notebook. The values are assumptions
# (believed to be SpikeInterface's extract_waveforms defaults). Confirm they match earlier runs.
n_jobs = -1                  # -1 = use all CPU cores
chunk_duration = '1s'
max_spikes_per_unit = 500
ms_before = 3.0
ms_after = 4.0

# Extract waveforms
waveform_folder = results_path/'kilosort2_waveforms'
print(waveform_folder)
we = si.extract_waveforms(recording_cmr, sorting_ks, waveform_folder, return_scaled=True, overwrite=True, # TODO check effect of raw or cmr??
    n_jobs=n_jobs, chunk_duration=chunk_duration, max_spikes_per_unit=max_spikes_per_unit, ms_before=ms_before, ms_after=ms_after) #total_memory="10M",
# NOTE: return_scaled: If True and recording has gain_to_uV/offset_to_uV properties, waveforms are converted to uV.
# To resume later without re-extracting: we = si.WaveformExtractor.load_from_folder(waveform_folder)

# Get quality metrics for all KS sorting
metric_names = ['snr', 'isi_violation', 'amplitude_cutoff', 'isolation_distance', 'firing_rate']
# NOTE: isi_violations_rate: "Rate of contaminating spikes as a fraction of overall rate. Higher values indicate more contamination" (== #sp_bad/#sp_total)
# Q: how to set bin width and refractory period for ISI violations?

qm = si.compute_quality_metrics(we, load_if_exists=True, metric_names=metric_names)

score_filename = results_path/'si_qm.csv'
qm.to_csv(score_filename)


Example: quality metrics & automatic curation for a single sorter

In [None]:
%%script false
waveform_folder = results_path/'kilosort2_waveforms'

# Load waveforms
we = si.WaveformExtractor.load_from_folder(waveform_folder) # Only needed if picking up from here

# Specifiy quality metrics
print(si.get_quality_metric_list())
metric_names=['snr', 'isi_violation', 'amplitude_cutoff','isolation_distance','firing_rate']
# NOTE: isi_violations_rate: "Rate of contaminating spikes as a fraction of overall rate. Higher values indicate more contamination" (== #sp_bad/#sp_total)
# Q: how to set bin width and refractory period for ISI violations?

# Compute quality metrics
qm = si.compute_quality_metrics(we,load_if_exists=True,metric_names=metric_names)
pprint(qm)

# Plot quality metrics
si.plot_quality_metrics(we, include_metrics=["amplitude_cutoff", "isi_violations_rate","firing_rate"])

# Screen sorting based on quality metrics
firing_rate_cutoff = 50
isi_viol_thresh =  0.1 # 0.2 is KS2 default
amplitude_cutoff_thresh = 0.05
our_query = f"firing_rate < {firing_rate_cutoff} & isi_violations_rate < {isi_viol_thresh} & amplitude_cutoff < {amplitude_cutoff_thresh} "

qm_keep = qm.query(our_query)
keep_unit_ids = qm_keep.index.values
sorting_clean = we.sorting.select_units(keep_unit_ids)
print(sorting)
print(sorting_clean)


Launch phy


In [None]:

%%script false # don't run this cell!

# Launch Phy
phy_folder = results_path/'clean_phy'
from phy.apps.template import template_gui
template_gui(phy_folder/'params.py')
Parking lot
%%script false # don't run this cell!

# Check data type of .npy file
from pprint import pprint
from pathlib import Path
import numpy as np
# file = Path(r"D:\hannah\Dropbox\code\spikesort\spikesort-hp\results\kilosort2_waveforms\spike_amplitudes\amplitude_segment_0.npy")
file = Path(r"Z:\Hannah\ephys\project2\HC05_220819\intersect_waveforms\spike_amplitudes\amplitude_segment_0.npy")
print(file)
temp = np.load(file)
print(type(temp))
print(type(temp[0]))
pprint(temp)
np.shape(temp)



