In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import os, sys
import spikeinterface.full as si
import h5py
import numpy as np
from tqdm import tqdm
from glob import glob
sys.path.append("/home/phornauer/Git/axon_tracking/")
from axon_tracking import spike_sorting as ss
from axon_tracking import template_extraction as te
from axon_tracking import utils as ut
import matplotlib.pyplot as plt

In [None]:
# Settings for template extraction.
te_params = {
    'n_jobs': 16,  # number of cores to use for waveform extraction
    'filter_band': 150,  # float -> highpass cutoff frequency; list -> bandpass frequencies
    'overwrite': False,  # recalculate templates even if they already exist
    'max_spikes_per_unit': 1000,  # upper bound on spikes used per unit for template extraction
}

# Settings for unit quality control.
qc_params = {
    'min_n_spikes': 1500,  # a unit needs at least this many detected spikes to get a template
    'exclude_mua': True,  # drop units that kilosort labelled as multi unit activity
    'use_bc': False,  # use bombcell for QC
}

In [None]:
# Folders to process; this list is passed to the template-extraction call and
# iterated by the loader loop in later cells.
# The single entry points at the "sorter_output" directory of well001.
sorting_list = ['/net/bs-filesvr02/export/group/hierlemann/intermediate_data/Maxtwo/phornauer/Torsten_2/241010/T002523/AxonTracking/well001/sorter_output']

In [None]:
# Run the packaged extraction routine on every entry of sorting_list, using the
# QC and template-extraction settings defined in the configuration cell.
te.extract_templates_from_sorting_list(sorting_list, qc_params=qc_params, te_params=te_params)

In [None]:
# Read back one saved template file ("18.npy" in the well000 templates folder) for a visual check.
test = np.load('/net/bs-filesvr02/export/group/hierlemann/intermediate_data/Maxtwo/phornauer/Torsten_2/241010/T002523/AxonTracking/well000/sorter_output/templates/18.npy')

In [None]:
# Collapse the third axis with a minimum and show the transposed 2-D result as an image.
plt.imshow(test.min(axis=2).T, aspect='auto')

In [None]:
# NOTE(review): `full_analyzer` is first assigned in a later cell of the visible
# notebook, so on a fresh kernel this cell raises NameError when run in order.
# The same two statements appear again after the analyzer extensions are
# computed — this copy looks like a leftover; confirm and remove.
merge_unit_groups = si.get_potential_auto_merge(
    full_analyzer,
     resolve_graph=True
     )
# Apply the proposed merge groups and keep the merged analyzer.
analyzer_merged = full_analyzer.merge_units(merge_unit_groups=merge_unit_groups)

In [None]:
# Show the 'phy_folder' annotation attached to the sorting object.
# Uses the public accessor instead of reaching into the private `_annotations`
# dict; the accessor yields None (rather than raising KeyError) when the
# annotation is absent.
# NOTE(review): `segment_sorting` is first assigned in a later cell of the
# visible notebook, so this cell only works after that cell has been run.
segment_sorting.get_annotation('phy_folder')

In [None]:
# Load the Kilosort result and the saved recording description for each entry.
# Only the objects from the last entry survive the loop (`sorting`,
# `multirecording`, `sorting_path`, `output_path`); later cells work on those.
for sorting_path in tqdm(sorting_list):
    # Entries may point either at the sorter folder or directly at its
    # "sorter_output" subfolder. Normalise to the sorter folder so that
    # "sorter_output" is not appended twice and the JSON file is looked up
    # next to (not inside) "sorter_output".
    sorting_path = os.path.normpath(sorting_path)
    if os.path.basename(sorting_path) == "sorter_output":
        sorting_path = os.path.dirname(sorting_path)

    output_path = os.path.join(sorting_path, "sorter_output")
    sorting = si.KiloSortSortingExtractor(output_path)
    json_path = os.path.join(sorting_path, "spikeinterface_recording.json")
    multirecording = si.load_extractor(json_path, base_folder=True)

In [None]:
# Path of the recording file behind the loaded recording object.
rec_path = ss.get_recording_path(multirecording)

# The well this sorting belongs to: first path component starting with "well".
stream_id = list(
    filter(lambda part: part.startswith("well"), sorting_path.split("/"))
)[0]

rec_names, common_el, pos = ss.find_common_electrodes(rec_path, stream_id)

# Apply the QC selection, then drop excess spikes
# (relevant if last spike time == recording_length).
cleaned_sorting = si.remove_excess_spikes(
    te.select_good_units(sorting, **qc_params), multirecording
)
cleaned_sorting.register_recording(multirecording)

# Split the sorting into one segment per recording segment.
segment_sorting = si.SplitSegmentSorting(cleaned_sorting, multirecording)

In [None]:
# Find out which well this belongs to: first component of output_path that starts with "well".
stream_id = list(
    filter(lambda part: part.startswith("well"), output_path.split("/"))
)[0]

In [None]:
# Recording file path and cutout window used for waveform extraction.
full_path = ss.get_recording_path(segment_sorting)
cutout_samples, cutout_ms = te.get_assay_information(full_path)

# The configuration cell defines 'overwrite', not 'overwrite_wf', so a plain
# te_params["overwrite_wf"] lookup raises KeyError. Honour 'overwrite_wf' when
# it is supplied and otherwise fall back to the general 'overwrite' flag.
overwrite_wf = te_params.get("overwrite_wf", te_params.get("overwrite", False))

# `cutout` is indexed as (before, after) by the waveform computations below.
cutout = cutout_ms
n_jobs = te_params["n_jobs"]

In [None]:
# Use the configured number of workers everywhere and keep progress bars off.
si.set_global_job_kwargs(progress_bar=False, n_jobs=n_jobs)

# Analyzer over the QC-filtered sorting and the loaded recording.
full_analyzer = si.create_sorting_analyzer(sorting=cleaned_sorting, recording=multirecording)

In [None]:
# Compute the analyzer extensions in one call; only the spike sub-sampling and
# the waveform window get explicit parameters.
# NOTE(review): te_params['max_spikes_per_unit'] exists in the configuration
# cell, but 900 is hard-coded here — confirm which value is intended.
full_analyzer.compute(
    [
        "random_spikes",
        "waveforms",
        "templates",
        "spike_amplitudes",
        "unit_locations",
        "template_similarity",
        "correlograms",
    ],
    extension_params=dict(
        random_spikes=dict(max_spikes_per_unit=900),
        waveforms=dict(ms_before=cutout[0], ms_after=cutout[1]),
    ),
)

In [None]:
# Ask for candidate merge groups (with the merge graph resolved), then apply them.
merge_unit_groups = si.get_potential_auto_merge(full_analyzer, resolve_graph=True)
analyzer_merged = full_analyzer.merge_units(merge_unit_groups=merge_unit_groups)

In [None]:
# Drop redundant units from the merged analyzer (threshold 0.8, "minimum_shift" strategy).
removed_sorting = si.remove_redundant_units(
    analyzer_merged, remove_strategy="minimum_shift", duplicate_threshold=0.8
)

In [None]:
# Compute the "noise_levels" extension first, then the quality metrics.
# NOTE(review): presumably some quality metrics depend on noise levels — confirm.
analyzer_merged.compute("noise_levels")
# `metrics` is indexed by column name and with .iloc in later cells.
metrics = si.compute_quality_metrics(analyzer_merged, n_jobs=n_jobs)

In [None]:
# Quick size check of the quality-metrics result.
metrics.shape

In [None]:
# Rebinds `output_path` (previously set inside the loader loop) to a test folder;
# the template save directory is created underneath it in the next cell.
# NOTE(review): this path contains no "well..." component, so re-running the
# earlier cell that derives stream_id from output_path after this one will fail.
output_path = '/net/bs-filesvr02/export/group/hierlemann/intermediate_data/Maxtwo/phornauer/AxonScan/Test'

In [None]:
# Unit ids that survived redundancy removal.
sel_unit_ids = removed_sorting.get_unit_ids()

# Make sure the template output folder exists. exist_ok=True replaces the
# separate existence check and is safe if the folder appears in between.
template_save_path = os.path.join(output_path, "templates")
os.makedirs(template_save_path, exist_ok=True)

In [None]:
# List the entries stored under this well in the recording file.
# The file is opened read-only and closed again as soon as the names are read,
# instead of leaving an open handle bound to `h5`.
with h5py.File(full_path, "r") as h5:
    rec_names = list(h5["wells"][stream_id].keys())

In [None]:
# Scatter of two quality-metric columns against each other; axes are labelled
# with the column names so the figure can be read on its own.
fig, ax = plt.subplots()
ax.scatter(metrics["firing_rate"], metrics["rp_contamination"])
ax.set_xlabel("firing_rate")
ax.set_ylabel("rp_contamination")
plt.show()

In [None]:
# Build one template matrix (units x samples x electrodes) by extracting
# templates separately from every recording of this well.

# Read the recording names, then close the file again right away; the original
# left the HDF5 handle open.
with h5py.File(full_path, "r") as h5:
    rec_names = list(h5["wells"][stream_id].keys())

n_units = cleaned_sorting.get_num_units()
# NaN marks electrodes that no recording covered.
# NOTE(review): 26400 is presumably the total number of electrodes on the chip — confirm.
template_matrix = np.full([n_units, sum(cutout_samples), 26400], np.nan)

for sel_idx, rec_name in enumerate(rec_names):
    rec = si.MaxwellRecordingExtractor(
            full_path, stream_id=stream_id, rec_name=rec_name
        )

    # NOTE(review): te_params['filter_band'] holds the same value (150) but is
    # not used here — confirm whether the cutoff should come from the config.
    rec_centered = si.highpass_filter(rec, freq_min=150)

    # Sorting segment that corresponds to this recording (same index).
    seg_sort = si.SelectSegmentSorting(segment_sorting, sel_idx)
    seg_sort = si.remove_excess_spikes(seg_sort, rec_centered)
    seg_sort.register_recording(rec_centered)

    # Dense (non-sparse) analyzer so templates cover every channel of the recording.
    analyzer = si.create_sorting_analyzer(
        sorting=seg_sort,
        recording=rec_centered,
        sparse=False,
        overwrite=overwrite_wf
        )

    analyzer.compute("random_spikes",n_jobs=n_jobs,max_spikes_per_unit=900)
    analyzer.compute("waveforms",ms_before=cutout[0], ms_after=cutout[1],n_jobs=n_jobs)
    analyzer.compute("templates",n_jobs=n_jobs)
    tmp = analyzer.get_extension(
        extension_name="templates"
    )
    tmp_data = tmp.get_data()

    # Write this recording's templates into the columns of its electrodes.
    els = rec.get_property("contact_vector")["electrode"]
    template_matrix[:, :, els] = tmp_data


In [None]:
# Shape of the "device_channel_indices" field for the last recording handled by the loop above.
rec.get_property("contact_vector")["device_channel_indices"].shape

In [None]:
# Line plot of the first unit's slice of the template matrix.
plt.plot(template_matrix[0].squeeze())
plt.show()

In [None]:
# Quality metrics of the row at position 1 (second row).
metrics.iloc[1]

In [None]:
# Map the template at index 1 onto the electrode grid and show, per grid
# position, the largest absolute value along the third axis (colour scale capped at 20).
grid = te.convert_to_grid(template_matrix[1], pos)
fig, ax = plt.subplots()
ax.imshow(np.abs(grid).max(axis=2).T, vmax=20)

In [None]:
# Template at index 1 of the template matrix, with any length-1 axes removed.
noise_tmp = template_matrix[1].squeeze()

In [None]:
# Boolean mask of entries below the threshold, mapped onto the electrode grid.
th = -5
tmp_th = np.less(noise_tmp, th)
grid_th = te.convert_to_grid(tmp_th, pos)

In [None]:
# Show the gridded threshold mask, reduced along axis 2 with a maximum and transposed.
plt.imshow(grid_th.max(axis=2).T)

In [None]:
# Spacing between the second-axis indices that have at least one entry below the threshold.
np.diff(np.nonzero(tmp_th.max(axis=0)))

In [None]:
import h5py as h5

In [None]:
# Open the recording file read-only for interactive browsing in the cells below.
# Uses the top-level `h5py` import rather than the `h5` alias, because the name
# `h5` is also bound to File objects in earlier cells and `h5.File` then fails
# depending on execution order.
# The handle is intentionally left open for the following cells; call
# mxw.close() when done browsing.
mxw = h5py.File(full_path, "r")

In [None]:
# Top-level groups of the opened file.
mxw.keys()

In [None]:
# NOTE(review): 'rec0000' and 'well005' are hard-coded in the cells below rather
# than taken from stream_id — confirm they exist in the file being browsed.
mxw['recordings']['rec0000']['well005'].keys()

In [None]:
# First field of the first entry of the 'spikes' dataset.
mxw['recordings']['rec0000']['well005']['spikes'][0][0]

In [None]:
# Handle to the 'frame_nos' dataset of the 'routed' group (presumably frame numbers — confirm).
fnos = mxw['recordings']['rec0000']['well005']['groups']['routed']['frame_nos']

In [None]:
# Difference between the last and the first entry.
fnos[-1] - fnos[0]

In [None]:
# First ten entries.
fnos[:10]

In [None]:
# Smallest step between consecutive entries.
fdiff = np.diff(fnos)
np.unique(fdiff).min()

In [None]:
# NOTE(review): these cells look at 'well006', while the cells above use
# 'well005' — confirm the switch is intentional.
mxw['recordings']['rec0000']['well006']['groups']['routed'].keys()

In [None]:
# Dataset object for 'raw' (displaying it does not read the data into memory).
mxw['recordings']['rec0000']['well006']['groups']['routed']['raw']

In [None]:
# The 'events' entry of the same well.
mxw['recordings']['rec0000']['well006']['events']

In [None]:
# Keys under assay/script_id.
mxw['assay']['script_id'].keys()

In [None]:
# First entry of the assay's electrode input.
mxw['assay']['inputs']['electrodes'][0]