In [None]:
import os, sys
import spikeinterface.full as si
import h5py
import numpy as np
from tqdm import tqdm
from glob import glob
sys.path.append("/home/phornauer/Git/axon_tracking/")
from axon_tracking import spike_sorting as ss
from axon_tracking import template_extraction as te
import matplotlib.pyplot as plt

In [None]:
# Template-extraction parameters.
te_params = {
    "align_cutout": True,  # align waveforms by the max waveform peak
    "upsample": 2,  # factor by which to upsample waveforms
    "rm_outliers": True,  # whether outliers should be removed
    "n_jobs": 16,  # number of cores to use for waveform extraction
    "n_neighbors": 10,  # number of neighbors for outlier detection
    "peak_cutout": 2,  # search for the peak +- this value around the expected peak (removes minor offsets)
    "overwrite_wf": False,  # repeat waveform extraction (e.g. after changing cutouts)
    "overwrite_tmp": True,  # recalculate templates even if they already exist
}

# Unit quality-control parameters.
qc_params = {
    "min_n_spikes": 500,  # minimum spike count for a unit to get a template extracted
    "exclude_mua": True,  # drop units labelled multi-unit activity by kilosort
}

In [None]:
# Spike-sorting output folders to process; each path ends in a well folder.
sorting_list = [
    "/net/bs-filesvr02/export/group/hierlemann/intermediate_data/Maxtwo/phornauer/iNeurons/240618/T002523/AxonTracking/well006",
]

In [None]:
# Load the Kilosort output and its associated recording for each sorting folder.
# NOTE(review): only the objects from the LAST entry of sorting_list survive this
# loop (sorting_path, sorting, multirecording); later cells use those.
for sorting_path in tqdm(sorting_list):
    output_path, json_path = (
        os.path.join(sorting_path, "sorter_output"),
        os.path.join(sorting_path, "spikeinterface_recording.json"),
    )
    sorting = si.KiloSortSortingExtractor(output_path)
    multirecording = si.load_extractor(json_path, base_folder=True)
    


In [None]:
# Resolve the raw recording and the well (stream) this sorting belongs to.
rec_path = ss.get_recording_path(multirecording)
stream_id = [
    part for part in sorting_path.split("/") if part.startswith("well")
][0]  # first path component that names a well

rec_names, common_el, pos = ss.find_common_electrodes(rec_path, stream_id)

# Keep only units passing QC, then drop spikes beyond the recording end
# (relevant if last spike time == recording_length).
cleaned_sorting = si.remove_excess_spikes(
    te.select_good_units(sorting, **qc_params), multirecording
)
cleaned_sorting.register_recording(multirecording)

# Split the concatenated sorting back into one segment per recording segment.
segment_sorting = si.SplitSegmentSorting(cleaned_sorting, multirecording)

In [None]:
# Destination folder for templates and waveforms produced below.
output_path = (
    "/net/bs-filesvr02/export/group/hierlemann/intermediate_data/Maxtwo/phornauer/AxonScan/Test"
)

In [None]:
# Output location for templates; create it on first use.
template_save_path = os.path.join(output_path, "templates")
sel_unit_ids = segment_sorting.get_unit_ids()
if not os.path.exists(template_save_path):
    os.makedirs(template_save_path)

# Cutout lengths (samples and ms) presumably come from the assay metadata of
# the raw recording file — see te.get_assay_information.
full_path = ss.get_recording_path(segment_sorting)
assay_cutouts = te.get_assay_information(full_path)
cutout_samples, cutout_ms = assay_cutouts

In [None]:
# List the recording names stored for this well in the raw HDF5 file.
full_path = ss.get_recording_path(segment_sorting)

# Open read-only and close immediately: only the key names are needed, and a
# dangling handle keeps the file locked/open for the rest of the session.
with h5py.File(full_path, "r") as h5_file:
    rec_names = list(h5_file["wells"][stream_id].keys())

In [None]:
# Build the per-segment waveform folder path.
# NOTE(review): wf_path is never used afterwards; after the loop, sel_idx and
# rec_name hold the LAST segment, which the following exploration cells rely on.
for sel_idx, rec_name in enumerate(rec_names):
    wf_path = os.path.join(output_path, "waveforms", f"seg{sel_idx}")

In [None]:
# Load one recording segment of the selected well.
rec = si.MaxwellRecordingExtractor(full_path, stream_id=stream_id, rec_name=rec_name)

# Cap the chunk at 10000 samples, minus 100; the cap is a fallback for ultra
# short recordings (too little activity).
chunk_size = np.min([rec.get_num_samples(), 10000]) - 100

In [None]:
# Bandpass-filter the segment (300-4999 Hz) before waveform extraction.
rec_centered = si.bandpass_filter(
    rec,
    freq_min=300,
    freq_max=4999,
)

In [None]:
# Pick the sorting of the selected segment, trim spikes past the end of the
# filtered recording, and attach that recording to it.
seg_sort = si.remove_excess_spikes(
    si.SelectSegmentSorting(segment_sorting, sel_idx),
    rec_centered,
)
seg_sort.register_recording(rec_centered)

In [None]:
overwrite_wf = te_params["overwrite_wf"]
cutout = cutout_ms
n_jobs = te_params["n_jobs"]

In [None]:
# Re-read the cutout lengths (repeats the template-setup cell; harmless).
(
    cutout_samples,
    cutout_ms,
) = te.get_assay_information(full_path)

In [None]:
# Display the per-segment sorting object as a quick sanity check.
segment_sorting

In [None]:
# Assemble one template matrix of shape (n_units, n_samples, N_ELECTRODES).
# Each recording segment covers a different electrode subset, so the matrix is
# pre-filled with NaN and only the electrodes recorded in a segment are written.
N_ELECTRODES = 26400  # size of the electrode axis; presumably the MaxWell array size — TODO confirm
RANDOM_SPIKES_SEED = 0  # fixed so the spike subsample (and templates) is reproducible

# Read segment names and release the file handle right away.
with h5py.File(full_path, "r") as h5_file:
    rec_names = list(h5_file["wells"][stream_id].keys())

# Unit count is the same for every segment; take it from segment_sorting
# directly instead of a seg_sort left over from an earlier cell.
n_units = segment_sorting.get_num_units()
template_matrix = np.full([n_units, sum(cutout_samples), N_ELECTRODES], np.nan)

for sel_idx, rec_name in enumerate(rec_names):
    rec = si.MaxwellRecordingExtractor(full_path, stream_id=stream_id, rec_name=rec_name)
    rec_centered = si.bandpass_filter(rec, freq_min=300, freq_max=4999)

    # Same preparation as the single-segment cell: drop spikes that fall past
    # the end of this segment and attach the filtered recording.
    seg_sort = si.SelectSegmentSorting(segment_sorting, sel_idx)
    seg_sort = si.remove_excess_spikes(seg_sort, rec_centered)
    seg_sort.register_recording(rec_centered)

    analyzer = si.create_sorting_analyzer(
        sorting=seg_sort,
        recording=rec_centered,
        sparse=False,
        overwrite=overwrite_wf,
    )

    # random_spikes is not a parallel (job-kwargs) extension, so no n_jobs here.
    analyzer.compute(
        "random_spikes", max_spikes_per_unit=900, seed=RANDOM_SPIKES_SEED
    )
    analyzer.compute(
        "waveforms", ms_before=cutout[0], ms_after=cutout[1], n_jobs=n_jobs
    )
    analyzer.compute("templates", n_jobs=n_jobs)
    tmp_data = analyzer.get_extension(extension_name="templates").get_data()

    # Scatter this segment's channels into their electrode columns.
    els = rec.get_property("contact_vector")["electrode"]
    template_matrix[:, :, els] = tmp_data

In [None]:
# Overlay one unit's template on every electrode column.
# Electrodes never recorded are all-NaN and draw nothing.
UNIT_IDX = 23  # row index into template_matrix (position, not a unit ID)

fig, ax = plt.subplots()
ax.plot(template_matrix[UNIT_IDX, :, :])
ax.set(
    title=f"Template across electrodes, unit index {UNIT_IDX}",
    xlabel="Sample",
    ylabel="Amplitude",
)
plt.show()

In [None]:
# Spatial footprint of unit index 23: peak absolute amplitude per grid site.
# Assumes convert_to_grid returns (x, y, time) so axis 2 is time — TODO confirm.
grid = te.convert_to_grid(template_matrix[23, :, :], pos)
fig, ax = plt.subplots()
ax.imshow(np.abs(grid).max(axis=2).T, vmax=20)


In [None]:
import h5py as h5

In [None]:
# Open the raw recording file for the exploration cells below. Read-only mode is
# explicit so this inspection handle can never modify the raw data.
# NOTE(review): the handle stays open for the following cells; call
# mxw.close() when done exploring.
mxw = h5.File(full_path, "r")

In [None]:
# Exploration: top-level keys of the raw recording HDF5 file.
mxw.keys()

In [None]:
# Exploration: contents of the first recording for the well being analysed
# (uses stream_id instead of a hardcoded well name).
mxw["recordings"]["rec0000"][stream_id].keys()

In [None]:
# Exploration: the 'spikes' entry of the first recording for this well.
mxw["recordings"]["rec0000"][stream_id]["spikes"]

In [None]:
# NOTE(review): unexplained scratch arithmetic with magic numbers — presumably a
# manual size check against one of the datasets inspected here; TODO document or remove.
34897 * 65

In [None]:
# Exploration: keys of the 'routed' group of the first recording for this well.
mxw["recordings"]["rec0000"][stream_id]["groups"]["routed"].keys()

In [None]:
# Exploration: the 'raw' entry of the routed group for this well.
mxw["recordings"]["rec0000"][stream_id]["groups"]["routed"]["raw"]

In [None]:
# Exploration: the 'events' entry of the first recording for this well.
mxw["recordings"]["rec0000"][stream_id]["events"]

In [None]:
# Exploration: keys under the assay's 'script_id' entry.
mxw["assay/script_id"].keys()

In [None]:
# Exploration: first element of the assay's input 'electrodes' dataset.
mxw["assay/inputs/electrodes"][0]