In [None]:
# Reload edited modules (e.g. the local axon_tracking package) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
import os, time, shutil, sys
#from pathlib import Path
from glob import glob
from pprint import pprint
import numpy as np
import sklearn as sk
import spikeinterface.full as si
import matplotlib.pyplot as plt
import warnings

# Make the local axon_tracking checkout importable, then load its spike-sorting and
# template-extraction helpers.
# NOTE(review): hardcoded absolute, user-specific path — consider an environment variable
# or installing the package instead.
sys.path.append("/home/phornauer/Git/axon_tracking/")
from axon_tracking import spike_sorting as ss
from axon_tracking import template_extraction as te

In [None]:
# Fixed path root that all recordings have in common
root_path = "/net/bs-filesvr02/export/group/hierlemann/recordings/Maxtwo/phornauer/"
# Variable part of the path; wildcards (*) collect all matching combinations.
# It is still recommended to be as specific as possible to avoid ambiguities.
path_pattern = ["231207", "Chemo*", "T002443", "AxonTracking", "0*"]
file_name = "data.raw.h5"  # File name of the recording

full_path = os.path.join(root_path, *path_pattern, file_name)
# glob() returns matches in a filesystem-dependent order; sort so the listed paths
# (and the order in which later cells process them) are reproducible.
path_list = sorted(glob(full_path))
print(f'Found {len(path_list)} recording paths matching the description:\n{full_path}\n')
pprint(path_list)

In [None]:
# Replacement rules for ss.convert_rec_path_to_save_path: 'pos' lists path-component
# indices and 'vals' the value substituted at each of them (paired by position).
# Presumably this maps the raw-recording path onto the intermediate_data tree — verify
# against the helper.
save_path_changes = {
    'pos': [0, 6, 7, 8, 9, 10, 11, 12, 13, 14],
    'vals': ['/', 'intermediate_data', 'Maxtwo', 'phornauer', 'Chemogenetics_2',
             'Week_2', 'T002443', 'AxonTracking', '', ''],
}

save_path = ss.convert_rec_path_to_save_path(full_path, save_path_changes)
print(f'The save path corresponds to the pattern:\n {save_path}\n')

In [None]:
# Look up finished sortings for the matched recordings; each value of the returned dict
# is a sized collection (its len() is counted) — presumably the sortings per recording.
sorting_dict = te.find_successful_sortings(path_list, save_path_changes)
print(f'Found {sum(len(sortings) for sortings in sorting_dict.values())} successful sortings')

In [None]:
# Template-extraction options, unpacked into te.combine_templates and passed to
# te.extract_templates_from_sorting_dict further down.
te_params = {
    'align_cutout': True,   # Align waveforms by max waveform peak
    'upsample': 2,          # Factor by which to upsample waveforms
    'rm_outliers': True,    # Remove outlier waveforms
    'n_jobs': 16,           # Number of cores to use for waveform extraction
    'n_neighbors': 10,      # Number of neighbors for outlier detection
    'peak_cutout': 2,       # Look for the peak +- this value around the expected peak (removes minor offsets)
    'overwrite_wf': False,  # Repeat waveform extraction (e.g. different cutouts)
    'overwrite_tmp': True,  # Recalculate templates even if they already exist
}

# Unit quality criteria, unpacked into te.select_units and passed to the extraction.
qc_params = {
    'min_n_spikes': 500,  # Minimum number of detected spikes for a unit to get a template
    'exclude_mua': True,  # Exclude units labelled multi-unit activity by kilosort
}

In [None]:
#To suppress warning when the outlier detection has too few samples
warnings.simplefilter("ignore")

te.extract_templates_from_sorting_dict(sorting_dict, qc_params, te_params)

In [None]:
# Single unit picked for a closer look at the extraction options.
# Note: this rebinds root_path, which earlier pointed at the raw-recordings root.
root_path = ("/net/bs-filesvr02/export/group/hierlemann/intermediate_data/"
             "Maxtwo/phornauer/iNeurons/230731/T002443/AxonTracking/")
stream_id = 'well009'
template_id = 189

In [None]:
# Kilosort output of the selected well. Built from root_path and stream_id (previous cell)
# instead of repeating the literal, so the path and the well id cannot drift apart.
# For the values above this is ".../AxonTracking/well009/sorter_output/" (trailing slash
# kept via the empty final component).
sorting_path = os.path.join(root_path, stream_id, "sorter_output", "")
sorting = si.KiloSortSortingExtractor(sorting_path)
# TODO: find out which well the sorting belongs to instead of setting stream_id by hand.
# Previously sketched here: rec_names, common_el, pos = ss.find_common_electrodes(rec_path, stream_id)

In [None]:
# Raw recording whose slices are concatenated in the next cell.
# NOTE(review): this lives under .../mpriouret/..., while the sorting is read from
# .../phornauer/... — confirm both refer to the same recording (230731, T002443).
rec_path = ("/net/bs-filesvr02/export/group/hierlemann/recordings/Maxtwo/mpriouret/"
            "iNeurons/230731/T002443/AxonTracking/000150/data.raw.h5")

In [None]:
# Combine the recording slices of this well; `pos` is later handed to te.convert_to_grid.
# Presumably multirecording has one segment per slice (it is split per segment below).
(multirecording, pos) = ss.concatenate_recording_slices(rec_path, stream_id)

In [None]:
# Keep only units passing the qc_params criteria (min_n_spikes, exclude_mua).
cleaned_sorting = te.select_units(sorting, **qc_params)
# Optional step, currently disabled: drop spikes beyond the recording's end.
#cleaned_sorting = si.remove_excess_spikes(cleaned_sorting, multirecording) #Relevant if last spike time == recording_length
cleaned_sorting.register_recording(multirecording)
# Presumably splits the sorting into segments matching multirecording — confirm in spikeinterface docs.
segment_sorting = si.SplitSegmentSorting(cleaned_sorting, multirecording)

In [None]:
# Unit to inspect: reuse template_id from the selection cell instead of repeating 189.
sel_unit_id = template_id
save_root = sorting_path
# Pin the two options being compared on a copy, so this cell really yields "aligned,
# outliers removed" even when re-run after later cells that modify te_params in place.
aligned_params = dict(te_params, align_cutout=True, rm_outliers=True)
template_matrix = te.combine_templates(stream_id, segment_sorting, sel_unit_id, save_root, **aligned_params)

In [None]:
# Arrange the aligned, outlier-removed template onto the electrode layout given by `pos`.
aligned_and_removed = te.convert_to_grid(template_matrix, pos)

In [None]:
# Minimum over axis 2 of the grid, transposed for display; colour range fixed to [-5, 0].
plt.imshow(np.min(aligned_and_removed, axis=2).T, vmin=-5, vmax=0)
plt.title('Aligned, outliers removed')
plt.colorbar()
plt.show()

In [None]:
# Same unit with waveform alignment disabled (outlier removal still on).
# te_params is still updated in place, as before, so later cells that read it see
# alignment off. rm_outliers is pinned on a copy so re-running this cell after outlier
# removal was switched off elsewhere still gives "not aligned, removed".
te_params['align_cutout'] = False  # Do NOT align waveforms by max waveform peak
unaligned_params = dict(te_params, rm_outliers=True)
template_matrix = te.combine_templates(stream_id, segment_sorting, sel_unit_id, save_root, **unaligned_params)
not_aligned_but_removed = te.convert_to_grid(template_matrix, pos)

In [None]:
# Minimum over axis 2 of the grid, transposed for display; same [-5, 0] range as the aligned plot.
plt.imshow(np.min(not_aligned_but_removed, axis=2).T, vmin=-5, vmax=0)
plt.title('Not aligned, outliers removed')
plt.colorbar()
plt.show()

In [None]:
# Same unit with both alignment and outlier removal off. Both flags are set explicitly on a
# copy, so the result no longer depends on which earlier cells modified te_params, and
# te_params itself is left untouched.
raw_params = dict(te_params, align_cutout=False, rm_outliers=False)
template_matrix = te.combine_templates(stream_id, segment_sorting, sel_unit_id, save_root, **raw_params)
not_aligned_not_removed = te.convert_to_grid(template_matrix, pos)

In [None]:
# Minimum over axis 2 of the grid, transposed for display.
# NOTE: the colour range here is [-2, 0], unlike [-5, 0] in the two previous plots —
# compare the images with that in mind.
plt.imshow(np.min(not_aligned_not_removed, axis=2).T, vmin=-2, vmax=0)
plt.title('Not aligned, outliers kept')
plt.colorbar()
plt.show()

In [None]:
# First difference along the last axis (np.diff default), then minimum over axis 2.
plot_data = np.diff(not_aligned_not_removed)
plt.imshow(np.min(plot_data, axis=2).T, vmin=-2, vmax=0)
plt.title('First difference (last axis), not aligned, outliers kept')
plt.colorbar()
plt.show()

In [None]:
# Raw traces of the most recent template_matrix (the "not aligned, outliers kept" variant);
# each column of a 2-D input is drawn as its own line.
plt.plot(template_matrix)
plt.title('Template matrix (not aligned, outliers kept)')
plt.show()