# Set up

In [1]:
# NOTE(review): `x` is not read by any later cell visible in this notebook -- presumably a scratch/kernel-check cell; confirm before removing
x=0

In [2]:
%matplotlib inline
import matplotlib.pyplot as plt
import os
import numpy as np
import pandas as pd
import pickle
import spikeinterface.full as si
import sys
sys.path.append('/mnt/cube/tsmcpher/code/')
from ephys_tsm import spike_util as su

# Prep data

In [3]:
# Thresholds for quality metric curation, each given as [low, high].
# The labelling cell below marks values strictly between low and high as 'm',
# values above high as 'h', and values at or below low as 'l'.
isi_vr_thresh = [0.1,0.5]  # applied to the 'isi_violations_ratio' metric
snr_thresh = [1,2]  # applied to the 'snr' metric

In [68]:
# Probe absolute location (unit locations are relative)
# SI formatting: probe width (x), depth (y), orthogonal (z)
# Assuming the flat of the probe extends M/L, the foot is anterior, and the implant is vertical, use:
# Note: for an angled implant, VENTRAL is how far the probe is lowered into the brain
# Angle is the deviation from vertical
# M/L (x), D/V (y), A/P (z)
# Values noted for other birds / sites:
# s_b1484_24, HVC right: [3800,-200,800], angle: 0
# s_b1357_23, HVC left: [-3500,-800,1000], angle: 0
# s_b1253_21, RA right: [3380,-4500,630], angle: 0
# s_b1253_21, RA left: [-3380,-4500,630], angle: 52
# NOTE(review): units are not stated here (presumably micrometers and degrees) -- TODO confirm
probe_angle_deg = 38
probe_abs_loc = np.array([-3000,-4000,500])

In [5]:
# Session to curate: bird ID, recording date and acquisition software
bird_in, sess_in, ephys_software_in = 's_b1360_24', '2024-07-30', 'sglx'
path_in = '/mnt/cube/chronic_ephys/der/{}/{}/{}/'.format(bird_in,sess_in,ephys_software_in)
# List the epoch folders of this session.
# NOTE: os.listdir returns entries in arbitrary order, so the position of a
# given epoch in this list is not guaranteed to be stable between runs/machines.
epochs = os.listdir(path_in)
print(epochs)

['1108_g0', '0806_g0', '1410_g0']


In [6]:
# Pick one epoch of the session by its position in `epochs`
epoch_i = 1
epoch_in = epochs[epoch_i]
# Parameters identifying the sort to curate
sess_par = dict(
    bird=bird_in,                      # bird id
    sess=sess_in,                      # session date
    epoch=epoch_in,                    # epoch
    ephys_software=ephys_software_in,  # recording software, sglx or oe
    sorter='kilosort4',                # spike sorting algorithm
    sort=0,                            # sort index
)
# Folder of this sort: <root>/<bird>/<sess>/<software>/<epoch>/<sorter>/<sort>/
sort_dir_fields = ('bird','sess','ephys_software','epoch','sorter','sort')
sort_dir = '/mnt/cube/chronic_ephys/der/{}/{}/{}/{}/{}/{}/'.format(
    *(sess_par[field] for field in sort_dir_fields))
sess_par,sort_dir

({'bird': 's_b1360_24',
  'sess': '2024-07-30',
  'epoch': '0806_g0',
  'ephys_software': 'sglx',
  'sorter': 'kilosort4',
  'sort': 0},
 '/mnt/cube/chronic_ephys/der/s_b1360_24/2024-07-30/sglx/0806_g0/kilosort4/0/')

In [7]:
def _label_by_thresholds(values, thresholds):
    """Bin metric values into 'l' / 'm' / 'h' labels.

    Parameters
    ----------
    values : array-like
        One metric value per unit.
    thresholds : sequence of two numbers
        [low, high] cut-offs.

    Returns
    -------
    np.ndarray of str
        'h' where value >= high, 'm' where low < value < high, and 'l'
        everywhere else (value <= low, or NaN since NaN fails every comparison).
    """
    values = np.asarray(values)
    low, high = thresholds
    labels = np.full(values.shape,'l')
    labels[(values > low) & (values < high)] = 'm'
    # '>=' so a value sitting exactly on the upper threshold is 'h' instead of
    # falling through both comparisons and keeping the default 'l'
    labels[values >= high] = 'h'
    return labels


# Locations of the sorter output and of the two possible post-processing folders
sort_path = sort_dir + 'sorter_output/'
analyzer_path = sort_dir + 'sorting_analyzer/'
waveforms_path = sort_dir + 'waveforms/'

# Prefer the sorting analyzer folder, fall back to the waveforms folder
if os.path.exists(analyzer_path):
    print('sorting analyzer..')
    use_analyzer_not_wave = True
    metrics_path = analyzer_path + 'extensions/quality_metrics/metrics.csv'
    analyzer = si.load_sorting_analyzer(analyzer_path)
elif os.path.exists(waveforms_path):
    print('waveforms..')
    use_analyzer_not_wave = False
    metrics_path = waveforms_path + 'quality_metrics/metrics.csv'
    analyzer = si.load_waveforms(waveforms_path)
else:
    # Stop here: continuing would either raise a NameError below or, in a live
    # kernel, silently reuse `analyzer` / `metrics_path` from a previous session
    raise FileNotFoundError('no analyzer or waveforms found in ' + sort_dir)

# Attach every column of the quality-metrics table to the sorting as a unit property
metrics_pd = pd.read_csv(metrics_path)
metrics_list = metrics_pd.keys().tolist()
for this_metric in metrics_list:
    analyzer.sorting.set_property(this_metric,metrics_pd[this_metric].values)

# Low / medium / high label per unit for ISI violations ratio and for SNR
isi_vr_label = _label_by_thresholds(analyzer.sorting.get_property('isi_violations_ratio'),isi_vr_thresh)
analyzer.sorting.set_property('isi_vr_thresh',isi_vr_label)
snr_label = _label_by_thresholds(analyzer.sorting.get_property('snr'),snr_thresh)
analyzer.sorting.set_property('snr_thresh',snr_label)

# Combine both labels into one quality class per unit.
# Assignment order matters: later lines overwrite earlier ones, so any unit
# with low SNR ends up as 'noise' even if it was first marked 'mua_4'.
quality_labels = np.full(analyzer.sorting.get_num_units(),'_____')
quality_labels[isi_vr_label == 'h'] = 'mua_4'
quality_labels[(isi_vr_label == 'l') & (snr_label == 'h')] = 'sua_1'
quality_labels[(isi_vr_label == 'l') & (snr_label == 'm')] = 'sua_2'
quality_labels[(isi_vr_label == 'm') & (snr_label == 'h')] = 'sua_2'
quality_labels[(isi_vr_label == 'm') & (snr_label == 'm')] = 'sua_3'
quality_labels[snr_label == 'l'] = 'noise'
analyzer.sorting.set_property('quality_labels',quality_labels)
su.print_unit_counts(quality_labels)
print(analyzer.sorting); print(analyzer)

sorting analyzer..
sua_1: 49
sua_2: 65
sua_3: 65
mua_4: 194
noise: 9
total: 319
NumpySorting: 319 units - 1 segments - 30.0kHz
SortingAnalyzer: 384 channels - 319 units - 1 segments - binary_folder - sparse - has recording
Loaded 14 extensions: correlograms, template_similarity, principal_components, random_spikes, templates, unit_locations, waveforms, template_metrics, isi_histograms, amplitude_scalings, noise_levels, spike_locations, quality_metrics, spike_amplitudes


# Auto curation

In [8]:
%%time
# Per-preset results, all in the same order as presets_all
merges_auto_init_all = []  # raw output of si.get_potential_auto_merge
merges_auto_all = []       # merge groups after su.merge_lists
sort_auto_all = []         # sorting with the merge groups applied

presets_all = ['similarity_correlograms','x_contaminations','temporal_splits','feature_neighbors']
for this_preset in presets_all:
    print(this_preset + '..')
    merges_auto_init = si.get_potential_auto_merge(analyzer,preset=this_preset)
    merges_auto = su.merge_lists(merges_auto_init)
    print(merges_auto)
    # Only build a merged sorting when the preset proposed something to merge
    if len(merges_auto) > 0:
        sort_auto = si.MergeUnitsSorting(analyzer.sorting,merges_auto)
    else:
        sort_auto = analyzer.sorting
    print(sort_auto)

    merges_auto_init_all.append(merges_auto_init)
    # Store this preset's merge groups (the list used to be appended to itself)
    merges_auto_all.append(merges_auto)
    sort_auto_all.append(sort_auto)

similarity_correlograms..
[]
NumpySorting: 319 units - 1 segments - 30.0kHz
x_contaminations..
[[136, 143]]
MergeUnitsSorting: 318 units - 1 segments - 30.0kHz
temporal_splits..
[]
NumpySorting: 319 units - 1 segments - 30.0kHz
feature_neighbors..
[[1], [33, 2, 34, 5, 15, 16, 20, 21, 30, 31], [3], [4], [22], [26], [39], [40], [41], [44], [45], [48], [51, 54, 73, 59], [56], [61], [65], [66], [68], [69], [71], [72], [75], [76], [80], [82], [84], [88], [100], [101], [104], [109], [131, 136, 143], [137], [200], [202], [203], [208], [216], [217], [226], [232], [233], [256, 261], [293]]
MergeUnitsSorting: 304 units - 1 segments - 30.0kHz
CPU times: user 5min 7s, sys: 8.66 s, total: 5min 16s
Wall time: 4min 32s


# Manual curation

In [9]:
# Unit properties shown as columns in the curation table
unit_table_properties = ['quality_labels','KSLabel','isi_violations_ratio','snr','num_spikes']
# Labels offered to the curator in the sortingview GUI
label_choices = ['sua_1','sua_2','sua_3','mua_4','noise']
# Open the sorting summary with curation enabled on the sortingview backend
summary_kwargs = dict(
    curation=True,
    backend='sortingview',
    unit_table_properties=unit_table_properties,
    label_choices=label_choices,
)
pss = si.plot_sorting_summary(analyzer,**summary_kwargs)

Computing sha1 of /home/AD/tsmcpher/.kachery-cloud/tmp_chVMxZKl/file.dat
https://figurl.org/f?v=npm://@fi-sci/figurl-sortingview@12/dist&d=sha1://8bd38be2b2a07d44872172e04718c9c740300c18


In [10]:
# sha_uri = 'sha1://eb06f8e626926b4dc905f7561617b2978530b2f7'

In [11]:
# %%capture cap
# sort_curated = si.apply_sortingview_curation(sorting=analyzer.sorting,uri_or_json=sha_uri,verbose=True)

In [12]:
# merge_str_all = cap.stdout
# merge_starts = su.str_find('[',merge_str_all)
# merge_stops = su.str_find(']',merge_str_all)
# merges_curated = [merge_str_all[merge_starts[i]+1:merge_stops[i]].split(',') for i in range(len(merge_starts))]
# quality_labels = np.full(sort_curated.get_num_units(),'_____')
# for this_label in label_choices:
#     quality_labels[np.where(sort_curated.get_property(this_label) == True)[0]] = this_label
# sort_curated.set_property('quality_labels',quality_labels)
# su.print_unit_counts(quality_labels)

# Depth labels

In [53]:
# Region name -> list of [lower, upper] depth ranges along the probe (both ends
# inclusive). None means "open ended": it is replaced by the smallest / largest
# unit depth found in this sort.
depth_dict = {'hvc':[[1910,None]],
              'ncm':[[840,1890]],
              'bad':[[None,840],[1890,1910]]
             }
unit_locations = list(si.compute_unit_locations(analyzer))
# Column 1 of the unit locations is the position along the probe depth axis
probe_depth = np.vstack(unit_locations)[:,1]
# Size the string dtype to the longest region name so that names longer than
# the 'XXX' placeholder are not silently truncated on assignment
label_width = max([len('XXX')] + [len(region) for region in depth_dict])
depth_labels = np.full(probe_depth.shape,'XXX',dtype='<U{}'.format(label_width))
labels_is_all = []
for this_label, depth_ranges in depth_dict.items():
    for lower_bound, upper_bound in depth_ranges:
        if lower_bound is None: lower_bound = np.min(probe_depth)
        if upper_bound is None: upper_bound = np.max(probe_depth)
        label_is = list(np.where((probe_depth >= lower_bound) & (probe_depth <= upper_bound))[0])
        labels_is_all.append(label_is)
        depth_labels[label_is] = this_label
        print(this_label,len(label_is))
# Every unit must fall in exactly one range: a gap leaves units unlabelled and
# an overlap (e.g. a unit sitting exactly on a shared boundary) counts it twice
n_assigned = sum(len(unit_is) for unit_is in labels_is_all)
assert n_assigned == len(depth_labels), \
    'depth ranges assigned {} labels for {} units'.format(n_assigned,len(depth_labels))
depth_dict['depth_labels'] = depth_labels
depth_dict

hvc 44
ncm 32
bad 117
bad 126


{'hvc': [[1910, None]],
 'ncm': [[840, 1890]],
 'bad': [[None, 840], [1890, 1910]],
 'depth_labels': array(['bad', 'bad', 'bad', 'bad', 'bad', 'bad', 'bad', 'bad', 'bad',
        'bad', 'bad', 'bad', 'bad', 'bad', 'bad', 'bad', 'bad', 'bad',
        'bad', 'bad', 'bad', 'bad', 'bad', 'bad', 'bad', 'bad', 'bad',
        'bad', 'bad', 'bad', 'bad', 'bad', 'bad', 'bad', 'bad', 'bad',
        'bad', 'bad', 'bad', 'bad', 'bad', 'bad', 'bad', 'bad', 'bad',
        'bad', 'bad', 'bad', 'bad', 'bad', 'bad', 'bad', 'bad', 'bad',
        'bad', 'bad', 'bad', 'bad', 'bad', 'bad', 'bad', 'bad', 'bad',
        'bad', 'bad', 'bad', 'bad', 'bad', 'bad', 'bad', 'bad', 'bad',
        'bad', 'bad', 'bad', 'bad', 'bad', 'bad', 'bad', 'bad', 'bad',
        'bad', 'bad', 'bad', 'bad', 'bad', 'bad', 'bad', 'bad', 'bad',
        'bad', 'bad', 'bad', 'bad', 'bad', 'bad', 'bad', 'bad', 'bad',
        'bad', 'bad', 'bad', 'bad', 'bad', 'bad', 'bad', 'bad', 'bad',
        'bad', 'bad', 'bad', 'bad', 'bad', 'bad'

# Save out

In [63]:
# Sorting to save. Entry 0 of sort_auto_all is the result for the first preset
# in presets_all; other candidates are sort_curated or another sort_auto entry.
sort_in = sort_auto_all[0]  # sort_curated sort_auto
# Merge groups that belong to sort_in. Left empty here, so the per-merge loop
# in the next cell does not run. Alternatives: merges_auto_all[0], merges_curated.
# NOTE(review): presumably this must list exactly the merges applied in sort_in -- confirm when switching sort_in
merges_in = []#merges_auto_all[0] # merges_curated merges_auto

In [65]:
# get unit IDs
unit_ids = sort_in.get_unit_ids()
print(f"{len(unit_ids)} units after curation:")
iui = analyzer.sorting.get_unit_ids() # initial unit IDs
utm = [[int(x) for x in m] for m in merges_in] # units to merge
nui = np.arange(max(iui)+1, max(iui)+len(utm)+1) # new unit IDs
# set merged properties to unit with highest original spike rate
orig_unit_ids = [[x] for x in unit_ids]
not_max_spikes_is_all = []
for i, u in enumerate(utm):
    print(f'- Units {u} merged to {nui[i]}')
    idx = [np.where(iui == x)[0][0] for x in u]
    u_n_spks = analyzer.sorting.get_property('num_spikes')[idx]
    max_spikes_i = idx[np.argmax(u_n_spks)]
    not_max_spikes_is = [idx[nmi] for nmi in list(np.where(idx != max_spikes_i)[0])]
    nui_i = np.where(unit_ids == nui[i])[0][0]
    for this_metric in analyzer.sorting.get_property_keys():
        sort_in.get_property(this_metric)[nui_i] = analyzer.sorting.get_property(this_metric)[max_spikes_i]
    sort_in.get_property('num_spikes')[nui_i] = np.sum(u_n_spks)
    not_max_spikes_is_all.append(not_max_spikes_is)
if len(not_max_spikes_is_all) > 0:
    merged_unit_locations = np.delete(np.array(unit_locations),np.concatenate(not_max_spikes_is_all),axis=0)
else:
    merged_unit_locations = unit_locations
sort_final = sort_in

319 units after curation:


In [66]:
def _spike_train_for(unit_id):
    """Return the spike train of one unit from segment 0 of sort_final."""
    return sort_final.get_unit_spike_train(unit_id=unit_id, segment_index=0)


# One row per curated unit
spk_df = pd.DataFrame({'unit': unit_ids})
spk_df['spike_train'] = spk_df['unit'].apply(_spike_train_for)
spk_df['unit_locations'] = list(merged_unit_locations)
spk_df['depth_labels'] = list(depth_labels)
# Row-wise helpers from spike_util fill in the probe placement columns
spk_df['probe_location'] = spk_df.apply(su.add_probe_loc,axis=1)
spk_df['probe_angle'] = spk_df.apply(su.add_probe_angle,axis=1)
# Copy every property of the final sorting into a column of the same name
for prop_name in sort_final.get_property_keys():
    spk_df[prop_name] = sort_final.get_property(prop_name)
spk_df = spk_df.drop(columns=['original_cluster_id'])
spk_df['orig_unit'] = orig_unit_ids
spk_df.keys()

NameError: name 'probe_abs_loc' is not defined

In [67]:
# Show the assembled per-unit table as a visual check before saving
spk_df

Unnamed: 0,unit,spike_train,unit_locations,depth_labels
0,0,"[123951, 147783, 152969, 153105, 191184, 23773...","[-14.073065567470673, 17.670251844640863, 1.40...",bad
1,1,"[73474, 96036, 138946, 160850, 173590, 266314,...","[-6.960734860035025, 36.56431000266426, 9.0094...",bad
2,2,"[363629, 1322175, 1322249, 2609953, 4171529, 4...","[28.243145302930486, 16.397182659697762, 1.009...",bad
3,3,"[501, 1010, 1484, 1906, 2419, 3113, 3427, 3741...","[3.345220529178426, 32.80625933021123, 1.00000...",bad
4,4,"[3434, 13800, 23463, 32599, 33698, 36838, 3748...","[6.003732926989229, 59.70841970244248, 1.00000...",bad
...,...,...,...,...
314,314,"[1321901, 1642489, 3872859, 4350387, 5900139, ...","[32.29415469029789, 1903.1954691834296, 2.2189...",bad
315,315,"[2116, 2620, 6116, 10611, 11111, 11118, 11613,...","[28.47801948110492, 1901.4852407894023, 1.0001...",bad
316,316,"[107, 605, 636, 1107, 1608, 3606, 3635, 4134, ...","[28.080712798162818, 1901.98421932352, 1.00004...",bad
317,317,"[345041, 361299, 363166, 793182, 1313623, 1643...","[21.119370966110907, 3828.3214553920816, 1.000...",hvc


In [27]:
# Write the per-unit table next to the sort it was derived from
spk_df_path = os.path.join(sort_dir,'spk_df.pkl')
with open(spk_df_path, 'wb') as pkl_file:
    pickle.dump(spk_df,pkl_file)