## Setup

In [1]:
import MEArec as mr
import numpy as np
import scipy.optimize
import os
import sys
import re
import ast
import matplotlib.pyplot as plt
import pandas as pd
import pickle
import time
from collections import defaultdict
from joblib import Parallel, delayed

import spikeinterface as si
import spikeinterface.core as sc
import spikeinterface.extractors as se
import spikeinterface.sorters as ss
import spikeinterface.preprocessing as spre
import spikeinterface.postprocessing as spost
import spikeinterface.widgets as sw
import spikeinterface.comparison as scomp
import spikeinterface.curation as scu

sys.path.append('src')
from src.util_eval import *
from src.util_loc import *

In [2]:
# ---- Experiment-wide settings ----
mearec_seed = 43                    # seed handed to MEArec for templates
dead_indices_seeds = [42, 43, 44]   # one dead-electrode draw per seed
run_id = 'toy_{}'.format(mearec_seed)
sorter_name = 'mountainsort4'
days = ['D%d' % day_number for day_number in range(6)]  # D0 ... D5
methods = [
    'center_of_mass',
    'monopolar_triangulation',
    'grid_convolution',
]
stable_days = 1                     # days with index < stable_days get no dead electrodes
dead_electrodes_per_day = 75
gt_toggle = 0                       # if 0, use ground truth data as experimental data
correct_radius = 30                 # microns

# ---- MEArec simulation settings ----
mearec_probe = 'Neuropixels-384'
num_channels = 384
recgen_duration = 15                # seconds
mearec_noise_level = 10
output_mearec = 'output/' + run_id

In [3]:
# Create (or load from the on-disk cache) the template generator and one MEArec
# recording per day, separately for every dead-electrode seed folder.
cell_folder = mr.get_default_cell_models_folder()
temp_params = mr.get_default_templates_params()
temp_params['n'] = 25
temp_params['probe'] = mearec_probe

for dead_indices_seed in dead_indices_seeds:
    output_mearec_seed = f'{output_mearec}/seed_{dead_indices_seed}'
    templates_path = f'{output_mearec_seed}/templates.h5'

    if os.path.exists(templates_path):
        tempgen = mr.load_templates(templates_path)
        # The cache is checked per day: a run interrupted after the templates
        # were saved may have left some (or all) day recordings unwritten.
        days_to_generate = [
            day for day in days
            if not os.path.exists(f'{output_mearec_seed}/{day}/recording.h5')
        ]
    else:
        temp_params['seed'] = mearec_seed
        tempgen = mr.gen_templates(cell_models_folder=cell_folder, params=temp_params, templates_tmp_folder=None, intraonly=False,
                                   parallel=True, recompile=False, n_jobs=None, delete_tmp=True, verbose=False)
        mr.save_template_generator(tempgen, templates_path)
        # Fresh templates: (re)generate the recording of every day.
        days_to_generate = list(days)

    if not days_to_generate:
        # Everything is cached; load the first day's recording, as the cached path always did.
        recgen = mr.load_recordings(f'{output_mearec_seed}/{days[0]}/recording.h5')

    for day in days_to_generate:
        # Parameters are rebuilt for every day so each call starts from pristine defaults.
        rec_params = mr.get_default_recordings_params()
        rec_params['spiketrains']['n_exc'] = 35 # Number of excitatory neurons
        rec_params['spiketrains']['n_inh'] = 15 # Number of inhibitory neurons

        rec_params['spiketrains']['duration'] = recgen_duration # Duration of recording in seconds
        rec_params['recordings']['fs'] = 30000 # Sampling frequency
        rec_params['recordings']['noise_level'] = mearec_noise_level
        # Only the template seed is fixed; the other MEArec seeds are left at their defaults.
        rec_params['seeds']['templates'] = mearec_seed

        recgen = mr.gen_recordings(params=rec_params, tempgen=tempgen)
        mr.save_recording_generator(recgen, f'{output_mearec_seed}/{day}/recording.h5')

Setting n_jobs to 20 CPUs
Starting simulation 1/13 - cell: L5_BP_bAC217_1

Starting simulation 4/13 - cell: L5_DBC_bAC217_1

Starting simulation 2/13 - cell: L5_BTC_bAC217_1

Starting simulation 7/13 - cell: L5_NBC_bAC217_1

Starting simulation 10/13 - cell: L5_STPC_cADpyr232_1

Starting simulation 13/13 - cell: L5_UTPC_cADpyr232_1

Starting simulation 3/13 - cell: L5_ChC_cACint209_1

Starting simulation 5/13 - cell: L5_LBC_bAC217_1

Starting simulation 9/13 - cell: L5_SBC_bNAC219_1

Starting simulation 6/13 - cell: L5_MC_bAC217_1

Starting simulation 8/13 - cell: L5_NGC_bNAC219_1

Starting simulation 11/13 - cell: L5_TTPC1_cADpyr232_1

Starting simulation 12/13 - cell: L5_TTPC2_cADpyr232_1

Intracellular simulation: /home/hao-zhao/.config/mearec/1.9.1/cell_models/bbp/L5_BP_bAC217_1
Extracellular simulation: /home/hao-zhao/.config/mearec/1.9.1/cell_models/bbp/L5_BP_bAC217_1
Intracellular simulation: /home/hao-zhao/.config/mearec/1.9.1/cell_models/bbp/L5_DBC_bAC217_1
Extracellular simul

In [4]:
# Function to process each seed
def process_seed(dead_indices_seed_i, dead_indices_seed):
    """Build per-day recordings, ground-truth waveforms and localization estimates for one seed.

    Every stage is cached under ``{output_mearec}/seed_{dead_indices_seed}``: if the
    stage's files already exist they are loaded, otherwise they are computed and saved.
    Stages, in order: dead electrode indices, per-day recordings (written through
    ``get_recording_noise``), ground-truth waveform extractors, unit (template)
    location estimates, and per-spike location estimates for every entry of ``methods``.

    Parameters
    ----------
    dead_indices_seed_i : int
        Position of the seed in ``dead_indices_seeds``. Not used by the body; kept so
        existing call sites keep working.
    dead_indices_seed : int
        Seed passed to ``np.random.seed`` before drawing the dead electrode indices;
        also selects the output folder.

    Returns
    -------
    None
        Results are only written to disk.
    """
    output_mearec_seed = f'{output_mearec}/seed_{dead_indices_seed}'

    # Dead indices
    dead_indices_path = f'{output_mearec_seed}/dead_indices.pkl'
    if os.path.exists(dead_indices_path):
        print('Loading dead indices')
        with open(dead_indices_path, 'rb') as f:
            dead_indices = pickle.load(f)
    else:
        print('Creating dead indices')
        # The global NumPy RNG is seeded on purpose: code called later in this
        # function may draw from it as well (NOTE(review): confirm for get_recording_noise).
        np.random.seed(dead_indices_seed)

        dead_indices = []
        n_dying_days = len(days) - stable_days
        dead_indices_temp = np.random.choice(range(num_channels), size=dead_electrodes_per_day*n_dying_days, replace=False)
        for day_i, day in enumerate(days):
            if day_i < stable_days:
                dead_indices.append([])
            else:
                # Cumulative schedule: the first non-stable day loses
                # dead_electrodes_per_day electrodes, the next one twice as many, etc.
                # (The previous hard-coded `day_i-1` was only consistent with
                # stable_days == 2 and left the first non-stable day without dead electrodes.)
                n_dead = dead_electrodes_per_day*(day_i - stable_days + 1)
                dead_indices.append(dead_indices_temp[:n_dead])

        with open(dead_indices_path, "wb") as file:
            pickle.dump(dead_indices, file)

    # Recording data
    recording_mearec = se.MEArecRecordingExtractor(f'{output_mearec_seed}/D0/recording.h5')
    probe = recording_mearec.get_probe()

    recordings = []
    for day_i, day in enumerate(days):
        recording_folder = f'{output_mearec_seed}/{day}/recording'

        if not os.path.exists(recording_folder):
            recording_mearec = se.MEArecRecordingExtractor(f'{output_mearec_seed}/{day}/recording.h5')
            probe = recording_mearec.get_probe()
            # Helper from src.util_*; presumably writes the processed recording to
            # recording_folder using this day's dead electrodes -- verify against its source.
            get_recording_noise(recording_mearec, dead_indices[day_i], recording_folder)

        recordings.append(sc.load_extractor(recording_folder).set_probe(probe, in_place=True))

    # Sorting data
    sortings_gt = [
        se.MEArecSortingExtractor(f'{output_mearec_seed}/{day}/recording.h5')
        for day in days
    ]

    # Waveform extractor data
    wes_gt = []
    for day_i, day in enumerate(days):
        wes_gt.append(si.extract_waveforms(recordings[day_i], sortings_gt[day_i], folder=f'{output_mearec_seed}/{day}/waveforms_gt', ms_before=1, ms_after=2, load_if_exists=True))
        # NOTE(review): extraction is triggered again unconditionally here, even when the
        # extractor was just created or loaded above -- confirm whether this is needed.
        wes_gt[-1].run_extract_waveforms()

    # Calculate template locations
    loc_units_dir = f'{output_mearec_seed}/loc_est_units'
    loc_units_path = f'{loc_units_dir}/loc_est_units.pkl'
    time_units_path = f'{loc_units_dir}/time_units.pkl'
    # Check the files rather than the folder, so a folder left behind by a failed
    # run triggers a recomputation instead of a crash on open().
    if os.path.exists(loc_units_path) and os.path.exists(time_units_path):
        print('Loading template localization estimates')
        with open(loc_units_path, 'rb') as f:
            loc_est_units = pickle.load(f)
        with open(time_units_path, 'rb') as f:
            time_units = pickle.load(f)

    else:
        print('Calculating template localization estimates')
        loc_est_units = {} # First entry is method. Then the entry is a list of locations across days.
        time_units = {}    # Wall-clock seconds spent per method.

        for method in methods:
            start_time = time.time()
            loc_est_units[method] = get_unit_loc_est(method, wes_gt)
            time_units[method] = time.time() - start_time

        os.makedirs(loc_units_dir, exist_ok=True)
        with open(loc_units_path, 'wb') as f:
            pickle.dump(loc_est_units, f)
        with open(time_units_path, 'wb') as f:
            pickle.dump(time_units, f)

    # Calculate spike location estimates
    loc_spikes_dir = f'{output_mearec_seed}/loc_est_spikes'
    loc_spikes_path = f'{loc_spikes_dir}/loc_est_spikes.pkl'
    time_spikes_path = f'{loc_spikes_dir}/time_spikes.pkl'
    if os.path.exists(loc_spikes_path) and os.path.exists(time_spikes_path):
        print('Loading spike localization estimates')
        with open(loc_spikes_path, 'rb') as f:
            loc_est_spikes = pickle.load(f)
        with open(time_spikes_path, 'rb') as f:
            time_spikes = pickle.load(f)

    else:
        print('Calculating spike localization estimates')
        loc_est_spikes = {} # First entry is method, second is day.
        time_spikes = {}    # Same layout; wall-clock seconds per method and day.

        for method in methods:

            loc_est_spikes[method] = {}
            time_spikes[method] = {}
            for day_i, day in enumerate(days):

                start_time = time.time()
                if method == 'monopolar_triangulation':
                    unit_loc_est = spost.compute_spike_locations(wes_gt[day_i], method=method, outputs='by_unit', method_kwargs={'optimizer': 'least_square'})
                else:
                    unit_loc_est = spost.compute_spike_locations(wes_gt[day_i], method=method, outputs='by_unit')
                time_spikes[method][day] = time.time() - start_time
                loc_est_spikes[method][day] = unit_loc_est

        os.makedirs(loc_spikes_dir, exist_ok=True)
        with open(loc_spikes_path, 'wb') as f:
            pickle.dump(loc_est_spikes, f)
        with open(time_spikes_path, 'wb') as f:
            pickle.dump(time_spikes, f)
    
# Dispatch one process_seed task per dead-electrode seed, two joblib workers at a time
seed_tasks = [
    delayed(process_seed)(seed_position, seed_value)
    for seed_position, seed_value in enumerate(dead_indices_seeds)
]
Parallel(n_jobs=2)(seed_tasks)

Creating dead indices
Creating dead indices


  ar2 = np.asarray(ar2).ravel()
  mask |= (ar1 == a)
  ar2 = np.asarray(ar2).ravel()
  mask |= (ar1 == a)
write_binary_recording:   0%|          | 0/15 [00:00<?, ?it/s]

write_binary_recording with n_jobs = 1 and chunk_size = 30000
write_binary_recording with n_jobs = 1 and chunk_size = 30000


write_binary_recording: 100%|##########| 15/15 [00:06<00:00,  2.46it/s]
write_binary_recording: 100%|##########| 15/15 [00:06<00:00,  2.45it/s]


write_binary_recording with n_jobs = 1 and chunk_size = 30000


write_binary_recording:   0%|          | 0/15 [00:00<?, ?it/s]

write_binary_recording with n_jobs = 1 and chunk_size = 30000


write_binary_recording: 100%|##########| 15/15 [00:06<00:00,  2.44it/s]
write_binary_recording: 100%|##########| 15/15 [00:06<00:00,  2.44it/s]


write_binary_recording with n_jobs = 1 and chunk_size = 30000


write_binary_recording:  20%|##        | 3/15 [00:01<00:04,  2.40it/s]

write_binary_recording with n_jobs = 1 and chunk_size = 30000


write_binary_recording: 100%|##########| 15/15 [00:06<00:00,  2.45it/s]
write_binary_recording: 100%|##########| 15/15 [00:06<00:00,  2.45it/s]
write_binary_recording:   0%|          | 0/15 [00:00<?, ?it/s]

write_binary_recording with n_jobs = 1 and chunk_size = 30000


write_binary_recording:  13%|#3        | 2/15 [00:00<00:05,  2.36it/s]

write_binary_recording with n_jobs = 1 and chunk_size = 30000


write_binary_recording: 100%|##########| 15/15 [00:06<00:00,  2.45it/s]
write_binary_recording: 100%|##########| 15/15 [00:06<00:00,  2.48it/s]


write_binary_recording with n_jobs = 1 and chunk_size = 30000


write_binary_recording:   0%|          | 0/15 [00:00<?, ?it/s].42it/s]

write_binary_recording with n_jobs = 1 and chunk_size = 30000


write_binary_recording: 100%|##########| 15/15 [00:06<00:00,  2.42it/s]
write_binary_recording: 100%|##########| 15/15 [00:06<00:00,  2.47it/s]


write_binary_recording with n_jobs = 1 and chunk_size = 30000


write_binary_recording:   7%|6         | 1/15 [00:00<00:05,  2.38it/s]

write_binary_recording with n_jobs = 1 and chunk_size = 30000


write_binary_recording: 100%|##########| 15/15 [00:06<00:00,  2.44it/s]
write_binary_recording: 100%|##########| 15/15 [00:05<00:00,  2.50it/s]
extract waveforms shared_memory multi buffer: 100%|##########| 15/15 [00:00<00:00, 24.56it/s]
extract waveforms shared_memory multi buffer: 100%|##########| 15/15 [00:00<00:00, 21.61it/s]
extract waveforms memmap multi buffer: 100%|##########| 15/15 [00:00<00:00, 27.54it/s]
extract waveforms memmap multi buffer: 100%|##########| 15/15 [00:00<00:00, 29.39it/s]
extract waveforms memmap multi buffer: 100%|##########| 15/15 [00:00<00:00, 26.72it/s]
extract waveforms memmap multi buffer: 100%|##########| 15/15 [00:00<00:00, 30.70it/s]0it/s]
extract waveforms shared_memory multi buffer: 100%|##########| 15/15 [00:00<00:00, 31.31it/s]
extract waveforms shared_memory multi buffer: 100%|##########| 15/15 [00:00<00:00, 32.78it/s]
extract waveforms memmap multi buffer: 100%|##########| 15/15 [00:00<00:00, 27.52it/s]
extract waveforms memmap multi buffer: 

Calculating template localization estimates


extract waveforms memmap multi buffer: 100%|##########| 15/15 [00:00<00:00, 35.38it/s]
extract waveforms memmap multi buffer: 100%|##########| 15/15 [00:00<00:00, 35.78it/s]


Calculating template localization estimates
Calculating spike localization estimates


localize peaks using center_of_mass: 100%|##########| 15/15 [00:00<00:00, 129.70it/s]
localize peaks using center_of_mass: 100%|##########| 15/15 [00:00<00:00, 107.37it/s]
localize peaks using center_of_mass: 100%|##########| 15/15 [00:00<00:00, 118.23it/s]
localize peaks using center_of_mass: 100%|##########| 15/15 [00:00<00:00, 111.52it/s]
localize peaks using center_of_mass: 100%|##########| 15/15 [00:00<00:00, 118.24it/s]
localize peaks using center_of_mass: 100%|##########| 15/15 [00:00<00:00, 122.45it/s]


Calculating spike localization estimates


localize peaks using center_of_mass: 100%|##########| 15/15 [00:00<00:00, 141.74it/s]
localize peaks using center_of_mass: 100%|##########| 15/15 [00:00<00:00, 130.37it/s]
localize peaks using center_of_mass: 100%|##########| 15/15 [00:00<00:00, 131.80it/s]
localize peaks using center_of_mass: 100%|##########| 15/15 [00:00<00:00, 131.65it/s]
localize peaks using center_of_mass: 100%|##########| 15/15 [00:00<00:00, 135.13it/s]
localize peaks using center_of_mass: 100%|##########| 15/15 [00:00<00:00, 131.92it/s]
localize peaks using monopolar_triangulation: 100%|##########| 15/15 [00:14<00:00,  1.03it/s]
localize peaks using monopolar_triangulation: 100%|##########| 15/15 [00:14<00:00,  1.04it/s]
localize peaks using monopolar_triangulation: 100%|##########| 15/15 [00:13<00:00,  1.11it/s]
localize peaks using monopolar_triangulation: 100%|##########| 15/15 [00:15<00:00,  1.00s/it]
localize peaks using monopolar_triangulation: 100%|##########| 15/15 [00:14<00:00,  1.02it/s]
localize peaks

Creating dead indices


localize peaks using grid_convolution: 100%|##########| 15/15 [00:01<00:00, 10.18it/s]
extract waveforms shared_memory mono buffer: 100%|##########| 15/15 [00:00<00:00, 1598.68it/s]
localize peaks using grid_convolution: 100%|##########| 15/15 [00:01<00:00, 10.35it/s]
extract waveforms shared_memory mono buffer: 100%|##########| 15/15 [00:00<00:00, 1575.50it/s]
localize peaks using grid_convolution: 100%|##########| 15/15 [00:01<00:00, 10.95it/s]
extract waveforms shared_memory mono buffer: 100%|##########| 15/15 [00:00<00:00, 1549.47it/s]
localize peaks using grid_convolution: 100%|##########| 15/15 [00:01<00:00, 11.92it/s]
  ar2 = np.asarray(ar2).ravel()
  mask |= (ar1 == a)
write_binary_recording:   0%|          | 0/15 [00:00<?, ?it/s]

write_binary_recording with n_jobs = 1 and chunk_size = 30000


write_binary_recording: 100%|##########| 15/15 [00:05<00:00,  2.76it/s]


write_binary_recording with n_jobs = 1 and chunk_size = 30000


write_binary_recording: 100%|##########| 15/15 [00:05<00:00,  2.75it/s]


write_binary_recording with n_jobs = 1 and chunk_size = 30000


write_binary_recording: 100%|##########| 15/15 [00:05<00:00,  2.78it/s]
write_binary_recording:   0%|          | 0/15 [00:00<?, ?it/s]

write_binary_recording with n_jobs = 1 and chunk_size = 30000


write_binary_recording: 100%|##########| 15/15 [00:05<00:00,  2.74it/s]


write_binary_recording with n_jobs = 1 and chunk_size = 30000


write_binary_recording: 100%|##########| 15/15 [00:05<00:00,  2.77it/s]


write_binary_recording with n_jobs = 1 and chunk_size = 30000


write_binary_recording: 100%|##########| 15/15 [00:05<00:00,  2.74it/s]
extract waveforms shared_memory multi buffer: 100%|##########| 15/15 [00:00<00:00, 79.05it/s]
extract waveforms memmap multi buffer: 100%|##########| 15/15 [00:00<00:00, 39.98it/s]
extract waveforms memmap multi buffer: 100%|##########| 15/15 [00:00<00:00, 39.86it/s]
extract waveforms shared_memory multi buffer: 100%|##########| 15/15 [00:00<00:00, 80.00it/s]
extract waveforms memmap multi buffer: 100%|##########| 15/15 [00:00<00:00, 39.35it/s]
extract waveforms memmap multi buffer: 100%|##########| 15/15 [00:00<00:00, 38.62it/s]
extract waveforms shared_memory multi buffer: 100%|##########| 15/15 [00:00<00:00, 80.66it/s]
extract waveforms memmap multi buffer: 100%|##########| 15/15 [00:00<00:00, 38.97it/s]
extract waveforms memmap multi buffer: 100%|##########| 15/15 [00:00<00:00, 38.69it/s]
extract waveforms shared_memory multi buffer: 100%|##########| 15/15 [00:00<00:00, 77.34it/s]
extract waveforms memmap multi

Calculating template localization estimates
Calculating spike localization estimates


localize peaks using center_of_mass: 100%|##########| 15/15 [00:00<00:00, 140.72it/s]
localize peaks using center_of_mass: 100%|##########| 15/15 [00:00<00:00, 131.63it/s]
localize peaks using center_of_mass: 100%|##########| 15/15 [00:00<00:00, 137.23it/s]
localize peaks using center_of_mass: 100%|##########| 15/15 [00:00<00:00, 143.67it/s]
localize peaks using center_of_mass: 100%|##########| 15/15 [00:00<00:00, 141.43it/s]
localize peaks using center_of_mass: 100%|##########| 15/15 [00:00<00:00, 143.45it/s]
localize peaks using monopolar_triangulation: 100%|##########| 15/15 [00:13<00:00,  1.15it/s]
localize peaks using monopolar_triangulation: 100%|##########| 15/15 [00:13<00:00,  1.11it/s]
localize peaks using monopolar_triangulation: 100%|##########| 15/15 [00:14<00:00,  1.04it/s]
localize peaks using monopolar_triangulation: 100%|##########| 15/15 [00:15<00:00,  1.03s/it]
localize peaks using monopolar_triangulation: 100%|##########| 15/15 [00:16<00:00,  1.09s/it]
localize peaks

[None, None, None]