In [1]:
# Automatically reload imported modules (e.g. the local uhd_resampled_sorting_analysis.src
# helpers) before each cell runs, so edits to them are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

## This notebook contains the code to run the resampled analysis

### More precisely, it consists of:

1. Code to compute relocalized positions and amplitudes of spikes (a unit's amplitude is the median over its spikes)
    a. The location spread of each unit is computed with the `spread()` function in `src/resampled_locs_amps.py`
2. Code to compute the template SNR


In [2]:
import numpy as np
from pathlib import Path
import os
import h5py
from tqdm.auto import tqdm
import torch
import pickle 

from dartsort.util.waveform_util import make_channel_index
from dartsort.util.spikeio import read_waveforms_channel_index

from dartsort.templates import TemplateData

from dartsort.util.drift_util import registered_geometry, get_spike_pitch_shifts, get_waveforms_on_static_channels
from dartsort.util.data_util import (
    chunk_time_ranges, 
    DARTsortSorting,
    check_recording,
    keep_only_most_recent_spikes,
)
from dartsort.config import TemplateConfig
# Template-building options passed to compute_template_SNR in the SNR cell below.
# All smoothing / realignment options are disabled; templates are tracked over time.
template_config = TemplateConfig(
    spikes_per_unit=100,
    time_tracking=True,
    superres_templates=False,
    spatial_svdsmoothing=False,
    subchunk_time_smoothing=False,
    realign_peaks=False,
)


from uhd_resampled_sorting_analysis.src.resampled_locs_amps import relocalize_after_clustering, spread
from uhd_resampled_sorting_analysis.src.template_snr import compute_template_SNR

import spikeinterface.core as sc

import matplotlib.pyplot as plt
import matplotlib.cm as cm
import colorcet as cc 

# Categorical palette for per-unit colors: the first 31 Glasbey colors, cycled.
ccolors = cc.glasbey[:31]


def get_ccolor(k):
    """Return the plotting color for cluster label ``k``.

    Label -1 (spikes not assigned to a unit) maps to grey ("#808080"); every other
    label cycles through ``ccolors`` with wrap-around.
    """
    if k == -1:
        return "#808080"
    return ccolors[k % len(ccolors)]


# `matplotlib.cm.get_cmap` is deprecated (it triggered the warning in this cell's
# output) and was removed in Matplotlib 3.9; `pyplot.get_cmap` is the supported call.
jet = plt.get_cmap("jet")

from spike_psvae.cluster_viz import array_scatter

  from .autonotebook import tqdm as notebook_tqdm
  jet = cm.get_cmap("jet")


In [3]:
# Recording / analysis constants shared by the cells below.
sampling_rate = 30000  # samples per second of the standardized.bin recordings
dtype_preprocessed = "float32"  # on-disk dtype of standardized.bin
loc_radius = 100  # NOTE(review): not referenced by the cells shown here
uv_bit_gain_factor = 2.34  # scales relocalized amplitudes before saving (presumably µV per bit)


In [4]:
# Sessions to process. Entry i of `data_names` is paired with entry i of `slices_all`.
data_names = [
    # Session names
    "ZYE_0021___2021-05-01___1",
    "ZYE_0031___2021-12-03___1___p0_g0_imec0",
    "ZYE_0040___2021-08-15___3___p1_g0_imec0",
    "ZYE_0021___2021-05-01___4___p2_g0_imec0",
    "ZYE_0057___2022-02-07",
    "ZYE_0057___2022-02-04___1___p0_g0___p0_g0_imec0",
    "ZYE_0031___2021-12-01___1___p0_g0_imec0",
    "ZYE_0057___2022-02-07",
    "ZYE_0021___2021-05-01___4___p2_g0_imec0",
    "ZYE_0057___2022-02-04___1___p0_g0___p0_g0_imec0",
    "ZYE_0031___2021-12-02___4___p2_g0_imec0",
    "ZYE_0057___2022-02-03___2___p1_g0___p1_g0_imec0",
    "LK_0011___2021-12-06___1___p0_g0_imec0",
]

slices_all = [
    # Here, cut parts of recording with low-quality data.
    # Presumably [start_s, end_s] in seconds; None keeps the recording boundary.
    [280, None],
    [350, None],
    [None, None],
    [200, None],
    [None, None],
    [None, 3750],
    [None, None],
    [None, None],
    [200, None],
    [None, 3750],
    [None, 3700],
    [None, 3400],
    [None, None],
]


def dedupe_sessions(names, slices):
    """Drop repeated sessions while keeping first-occurrence order.

    Returns ``(unique_names, matching_slices)`` as two lists.

    Raises ValueError if the two lists differ in length (zip would silently
    truncate) or if a repeated session is listed with a different slice.
    """
    if len(names) != len(slices):
        raise ValueError(
            f"data_names has {len(names)} entries but slices_all has {len(slices)}"
        )
    slice_by_name = {}
    for session, window in zip(names, slices):
        if session not in slice_by_name:
            slice_by_name[session] = window
        elif slice_by_name[session] != window:
            raise ValueError(
                f"conflicting slices for {session}: {slice_by_name[session]} vs {window}"
            )
    return list(slice_by_name), list(slice_by_name.values())


# Several sessions are listed twice. Processing them again only repeats the
# (hour-long) relocalization / SNR runs and overwrites identical outputs.
data_names, slices_all = dedupe_sessions(data_names, slices_all)

full_dir = Path("UHD_DATA")




In [8]:
# Relocalize every UHD-detected spike on each resampled probe pattern and save the
# resulting locations and amplitudes next to the resampled recording.
#
# Per session: spike times / channels / UHD geometry are read from the UHD ("_pat1")
# subtraction.h5. Per pattern: the resampled standardized.bin is relocalized and
# reloc_locations.npy / reloc_amplitudes.npy are written (amplitudes are multiplied by
# uv_bit_gain_factor before saving).

relocalization_device = "cuda:3"
# Set to True to skip (session, pattern) pairs whose outputs already exist on disk.
skip_existing_relocs = False
reloc_patterns = [1, 2, 3, 4]

# Keep the UHD root and the resampled root in separate names. The previous version
# rebound `full_dir` to "UW_DATA" inside the pattern loop. As a result, every session
# after the first read its UHD data from a different root than the first one.
uhd_root = full_dir  # set in the configuration cell
resampled_root = Path("UW_DATA")

# Probe geometries do not depend on the session, so each one is loaded once.
geoms_by_pattern = {pat: np.load(f"geom_array_pat{pat}.npy") for pat in reloc_patterns}

for name_recording in data_names:

    print(f"relocalizing {name_recording}")

    name = name_recording + "_pat1"  # to get UHD data
    sub_h5 = uhd_root / name / "subtraction_results" / "subtraction.h5"

    # Open read-only: this cell never writes to the subtraction file.
    with h5py.File(sub_h5, "r") as h5:
        times_samples = np.array(h5["times_samples"][:])
        channels = np.array(h5["channels"][:])
        geom_uhd = np.array(h5["geom"][:])

    for pat_reloc in reloc_patterns:

        geom = geoms_by_pattern[pat_reloc]

        name_resampled = name_recording + f"_pat{pat_reloc}"
        resampled_data_dir = resampled_root / name_resampled

        name_relocs = resampled_data_dir / "reloc_locations.npy"
        name_reamps = resampled_data_dir / "reloc_amplitudes.npy"

        if skip_existing_relocs and name_relocs.exists() and name_reamps.exists():
            print(f"pattern {pat_reloc}: outputs exist, skipping")
            continue

        print(f"pattern {pat_reloc}")

        recording = sc.read_binary(
            resampled_data_dir / "standardized.bin",
            sampling_rate,
            dtype_preprocessed,
            num_channels=geom.shape[0],
            is_filtered=True,
        )
        recording.set_dummy_probe_from_locations(geom, shape_params=dict(radius=10))

        # Median of the stored "sd" values (presumably per-channel noise SDs).
        rec_sd = np.median(
            np.load(resampled_data_dir / "mean_and_standard_dev_value.npz")["sd"]
        )

        loc_vector, amp_vector = relocalize_after_clustering(
            recording,
            geom,
            times_samples,
            channels,
            rec_sd,
            geom_uhd,
            model="pointsource",
            device=relocalization_device,
            n_spikes_fit_tpca=10_000,
            batch_size=2048,
        )

        np.save(name_relocs, loc_vector)
        np.save(name_reamps, amp_vector * uv_bit_gain_factor)

    

relocalizing ZYE_0021___2021-05-01___4___p2_g0_imec0
pattern 1
fitting pca
relocalizing


  torch.tensor(wfs.max(1).values - wfs.min(1).values, device=device),
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 220/220 [41:16<00:00, 11.26s/it]


relocalizing ZYE_0057___2022-02-07
pattern 1
fitting pca
relocalizing


  torch.tensor(wfs.max(1).values - wfs.min(1).values, device=device),
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 331/331 [1:07:09<00:00, 12.17s/it]


In [57]:
# Quick check of which session the processing loops start with.
first_session = data_names[0]
first_session

'ZYE_0021___2021-05-01___4___p2_g0_imec0'

In [6]:

# Compute the template SNR of every UHD unit on each resampled probe pattern.
# Inputs are the UHD spike train, clustering and motion estimate. The time chunks are
# built from the session's slice_s window (see slices_all).
# Output per pattern: template_SNR.npy in the resampled recording directory.

# Root holding the UHD ("_pat1") subtraction / clustering results. Override it with the
# UW_DATA_ROOT environment variable rather than editing the machine-specific default.
# Unlike the previous version, this no longer overwrites the global `full_dir`.
snr_uhd_root = Path(os.environ.get("UW_DATA_ROOT", "/mnt/ssd2tb2/julien/UW_DATA"))
resampled_root = Path("UW_DATA")
snr_patterns = [1, 2, 3, 4]
chunk_length_s = 300  # length of each template time chunk, in seconds

# zip() silently drops trailing sessions when the two lists are not aligned.
assert len(data_names) == len(slices_all), "data_names and slices_all must be aligned"

for name_recording, slice_s in zip(data_names, slices_all):

    print(f"computing SNR for {name_recording}")

    name = name_recording + "_pat1"  # to get UHD data
    data_dir = snr_uhd_root / name

    subtraction_dir = data_dir / "subtraction_results"
    sub_h5 = subtraction_dir / "subtraction.h5"
    motion_estimate_name = subtraction_dir / "motion_estimate.obj"
    data_dir_cluster = data_dir / "initial_clustering"

    # Open read-only: this cell never writes to the subtraction file.
    with h5py.File(sub_h5, "r") as h5:
        times_samples = np.array(h5["times_samples"][:])
        times_seconds = np.array(h5["times_seconds"][:])
        channels = np.array(h5["channels"][:])
        geom_uhd = np.array(h5["geom"][:])
        localization_results = np.array(h5["point_source_localizations"][:])
        amps = np.array(h5["denoised_ptp_amplitudes"][:])

    cluster_labels = np.load(data_dir_cluster / "clustering_labels.npy")
    # Relabel clustered spikes to consecutive ids 0..K-1. Labels <= -1 are left untouched.
    is_clustered = cluster_labels > -1
    _, cluster_labels[is_clustered] = np.unique(
        cluster_labels[is_clustered], return_inverse=True
    )

    # The motion estimate depends only on the session, so it is loaded once here.
    # The previous version reopened it for every pattern and never closed the handle.
    # NOTE: pickle.load can execute arbitrary code; only use trusted, locally produced files.
    with open(motion_estimate_name, "rb") as filehandler:
        me = pickle.load(filehandler)

    for pat_reloc in snr_patterns:

        geom = np.load(f"geom_array_pat{pat_reloc}.npy")

        name_resampled = name_recording + f"_pat{pat_reloc}"
        resampled_data_dir = resampled_root / name_resampled

        print(f"pattern {pat_reloc}")

        recording = sc.read_binary(
            resampled_data_dir / "standardized.bin",
            sampling_rate,
            dtype_preprocessed,
            num_channels=geom.shape[0],
            is_filtered=True,
        )
        recording.set_dummy_probe_from_locations(geom, shape_params=dict(radius=10))

        # 1. Make sorting. Each spike's main channel is the resampled-probe channel
        #    closest (squared Euclidean distance) to the position of its UHD channel.
        main_channels = ((geom_uhd[channels][None] - geom[:, None]) ** 2).sum(2).argmin(0)

        sorting = DARTsortSorting(
            times_samples=times_samples,
            channels=main_channels,
            labels=cluster_labels,
            extra_features={
                "point_source_localizations": localization_results,
                "denoised_ptp_amplitudes": amps,
                "times_seconds": times_seconds,
            },
        )

        # 2. Time chunks over which templates are computed.
        chunk_time_ranges_s = chunk_time_ranges(
            recording,
            chunk_length_samples=chunk_length_s * sampling_rate,
            slice_s=slice_s,
        )

        # 3. Template SNR of each unit on this probe pattern.
        template_SNR = compute_template_SNR(
            sorting,  # UHD sorting
            recording,
            me,
            chunk_time_ranges_s,
            geom,
            geom_uhd,
            # NOTE(review): these are the UHD localizations from subtraction.h5, not the
            # per-pattern reloc_locations.npy written by the relocalization cell. The
            # original comment said re-localization results must be computed beforehand;
            # confirm which one compute_template_SNR expects.
            localization_results,
            template_config=template_config,
            template_data_list=None,
        )

        np.save(resampled_data_dir / "template_SNR.npy", template_SNR)

    

computing SNR for ZYE_0057___2022-02-04___1___p0_g0___p0_g0_imec0
pattern 2
fitting tsvd
keeping all necessary spikes
computing pitch shifts


Computing templates for all chunks: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 120/120 [02:10<00:00,  1.09s/it]


FOR LOOP DONE
snrs_by_chan DONE
templates linear done
templates denoised
spatial_svdsmoothing done
realign_peaks done
computing main chans
computing SNR


  0%|                                                                                                                                                                               | 0/192 [00:00<?, ?it/s]
12it [00:00, 159.54it/s]
  1%|▊                                                                                                                                                                    | 1/192 [00:22<1:12:48, 22.87s/it]
12it [00:00, 142.46it/s]
  1%|█▋                                                                                                                                                                   | 2/192 [00:50<1:20:53, 25.54s/it]
12it [00:00, 169.63it/s]
  2%|██▌                                                                                                                                                                  | 3/192 [01:17<1:22:18, 26.13s/it]
12it [00:00, 298.91it/s]
  2%|███▍                                                                       

pattern 3
fitting tsvd
keeping all necessary spikes
computing pitch shifts


Computing templates for all chunks: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 120/120 [02:19<00:00,  1.17s/it]


FOR LOOP DONE
snrs_by_chan DONE
templates linear done
templates denoised
spatial_svdsmoothing done
realign_peaks done
computing main chans
computing SNR


  0%|                                                                                                                                                                               | 0/192 [00:00<?, ?it/s]
12it [00:00, 153.70it/s]
  1%|▊                                                                                                                                                                    | 1/192 [00:22<1:12:10, 22.67s/it]
12it [00:00, 142.99it/s]
  1%|█▋                                                                                                                                                                   | 2/192 [00:49<1:20:27, 25.41s/it]
12it [00:00, 143.34it/s]
  2%|██▌                                                                                                                                                                  | 3/192 [01:16<1:21:12, 25.78s/it]
12it [00:00, 233.46it/s]
  2%|███▍                                                                       

In [7]:
print("ok")

ok
