# Notebook to investigate the performance of SpikeInterface in localizing neurons

In [1]:
import MEArec as mr # what we will use to create a synthetic recording
import spikeinterface.full as si  # what we will use to sort the spikes

# Other useful imports 
import matplotlib.pyplot as plt
import math 
import numpy as np
import yaml
import warnings
from probeinterface.plotting import plot_probe
from matplotlib import cm
from probeinterface import read_prb
from spikeinterface.sortingcomponents.peak_detection import detect_peaks
from os import listdir
from scipy.ndimage.filters import gaussian_filter1d
from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar
import matplotlib.font_manager as fm
import spikeinterface.widgets as sw
import time
from matplotlib.transforms import Bbox, TransformedBbox
from matplotlib.image import BboxImage
from matplotlib.legend_handler import HandlerBase

  from scipy.ndimage.filters import gaussian_filter1d


In [2]:
# Machine-specific paths: adapt them before running the notebook on another computer.
cell_folder = "C:\\Users\\melin\\Anaconda3\\envs\\si_env\\Lib\\site-packages\\MEArec\\cell_models\\bbp" # folder where we downloaded some cell models (13)
working_folder = "C:\\Users\\melin\\Desktop\\COURS_M2_CNN\\Projet\\" # the current folder 
param_folder = working_folder + 'params_locations\\' # where we put the yaml files with all the parameters
files_folder = working_folder + 'temporary_files_locations\\' # where we want to put the created files 
param_ext = '.yml' # yaml files contain dictionaries for our parameters 
files_ext = ".h5" # h5 files are created to save what we do 
parallel_compute = 4 # n_jobs passed to the heavy MEArec generation steps (templates, recordings)
templates_file = working_folder + 'data_real\\templates\\full_templates.npy'

# One colormap per localization method, and a representative color of each.
# plt.get_cmap is used because matplotlib.cm.get_cmap was removed in matplotlib 3.9.
maps = 'Blues', 'Purples'
cmaps = plt.get_cmap(maps[0]), plt.get_cmap(maps[1])
colors_binary = cmaps[0](0.7), cmaps[1](0.7)

## Generate data with MEArec

In [3]:
def parameters_modification(param_folder, parameters, ext=None) : 
    '''
    Read the YAML parameter files used by MEArec and override selected entries.

    For every key `param` of `parameters` (e.g. 'templates', 'spiketrains', 'recordings'),
    the file `param_folder + param + ext` is parsed with yaml.safe_load. Then every entry of
    `parameters[param]` replaces the value of the same key in the loaded dictionary.

    Parameters
    ----------
    param_folder : str
        Folder containing the YAML files. The file name is appended by string
        concatenation, so the folder must end with a path separator.
    parameters : dict of dict
        Overrides grouped by parameter file name. Keys that do not exist in the file are
        reported and ignored.
    ext : str, optional
        File extension. Defaults to the notebook-level `param_ext`.

    Returns
    -------
    dict
        One dictionary per successfully parsed file. A file that cannot be parsed is
        reported and left out.

    Raises
    ------
    FileNotFoundError
        If one of the parameter files does not exist.
    '''
    if ext is None:
        ext = param_ext
    new_parameters = {}
    for param, overrides in parameters.items() :
        filename = param_folder + param + ext
        with open(filename, encoding='utf-8') as file:
            try:
                loaded = yaml.safe_load(file)
            except yaml.YAMLError as exception:
                print(f"Could not parse {filename}: {exception}")
                continue
        # An empty YAML file loads as None: treat it as an empty dictionary.
        if loaded is None:
            loaded = {}
        for name, value in overrides.items() :
            if name in loaded :
                loaded[name] = value
            else :
                print(f"parameter '{name}' not found in {filename}")
        new_parameters[param] = loaded
    return new_parameters

probe = 'Fake_probe'
# Only the probe used for template generation is overridden; the spike-train and
# recording parameters are taken unchanged from their YAML files.
parameters = dict(
    templates={'probe': probe},
    spiketrains={},
    recordings={},
)

params = parameters_modification(param_folder, parameters)


FileNotFoundError: [Errno 2] No such file or directory: 'C:\\Users\\melin\\Desktop\\COURS_M2_CNN\\Projet\\params_locations\\templates.yml'

Here, we generate artificial data with the parameters defined in the YAML files. The template and recording files are saved.

In [None]:
def generating_data(params, files_folder, probe) : 
    '''
    Simulate templates, spike trains and a recording with MEArec, then save them.

    params : dict with the 'templates', 'spiketrains' and 'recordings' parameter
             dictionaries (as returned by parameters_modification).
    files_folder, probe : the outputs are written to
             f"{files_folder}templates{probe}{files_ext}" and
             f"{files_folder}recordings{probe}{files_ext}".
    Uses the notebook-level cell_folder, parallel_compute and files_ext.
    '''
    templates_path = f"{files_folder}templates{probe}{files_ext}"
    recordings_path = f"{files_folder}recordings{probe}{files_ext}"

    with warnings.catch_warnings():
        # MEArec emits many deprecation warnings that clutter the output: silence them here only.
        warnings.simplefilter("ignore")

        template_generator = mr.gen_templates(cell_models_folder=cell_folder, params=params['templates'], n_jobs=parallel_compute, verbose=False)
        spiketrain_generator = mr.gen_spiketrains(params=params['spiketrains'], verbose=False)
        recording_generator = mr.gen_recordings(params=params['recordings'], spgen=spiketrain_generator, tempgen=template_generator, n_jobs=parallel_compute, verbose=False)

        mr.save_template_generator(template_generator, filename=templates_path, verbose=False)
        mr.save_recording_generator(recording_generator, filename=recordings_path, verbose=False)


time_gener = time.time()
generating_data(params, files_folder, probe)
# Elapsed wall-clock time in seconds: time.time must be called, not subtracted as a function.
time_gener = time.time() - time_gener

## Analyze the artificial data

In [None]:
class artificial_data:
    '''
    SpikeInterface localization results for one MEArec simulation.

    Loads the templates and the recording saved as
    f"{files_folder}templates{probe}{files_ext}" and f"{files_folder}recordings{probe}{files_ext}".
    Localizes every unit of the recording and compares the predictions with the
    ground-truth template positions.

    Typical call order: load_in_spike_interface() -> predicted_locations() -> find_errors()
    -> plotting / feature methods.
    '''

    def __init__(self, files_folder, probe) :
        self.probe = probe
        self.folder = files_folder
        self.tempgen = mr.load_templates(f"{files_folder}templates{probe}{files_ext}")
        self.templates = self.tempgen.templates
        self.recgen = mr.load_recordings(f"{files_folder}recordings{probe}{files_ext}")
        self.templates_in_recording()

    def load_in_spike_interface(self) :
        '''Read the MEArec recording file as a SpikeInterface recording and its ground-truth sorting.'''
        self.recording_si, self.sorting_si = si.read_mearec(f"{self.folder}recordings{self.probe}{files_ext}")

    def templates_in_recording (self) :
        '''
        Not all the templates created are used in the recording.
        Store in 'self.template_ids' the identification numbers of the templates of interest,
        their count in 'self.n_templates', and in 'self.n_probes' the length of the first
        template's first axis (used as the number of channels by templates_features).
        '''
        self.template_ids = self.recgen.template_ids
        self.n_templates = len(self.template_ids)
        self.n_probes = len(self.tempgen.templates[0])

    def predicted_locations (self, method = 'monopolar_triangulation', ms_before=1., ms_after=1.5, radius_um=50, max_distance_um=1000) :
        '''
        Extract the waveforms of each unit of interest and localize them with `method`.

        Results:
        self.locations_pred[num] holds one location per spike (si.compute_spike_locations).
        self.unit_pred[num] holds the location of the averaged template (si.compute_unit_locations).

        ms_before, ms_after, radius_um and max_distance_um are only forwarded for
        'monopolar_triangulation'; other methods use the SpikeInterface defaults.
        Waveforms are cached in f"{self.folder}//waveforms//wv_{self.probe}_{num}" and reused.
        '''
        self.locations_pred = [0]*self.n_templates
        self.unit_pred = [0]*self.n_templates
        for num in range (self.n_templates) :
            # NOTE(review): assumes unit "#num" of the MEArec sorting corresponds to
            # self.template_ids[num] (find_errors pairs them that way) — confirm.
            sorting_num = self.sorting_si.select_units([f"#{num}"]) 
            wv_num = si.extract_waveforms(recording = self.recording_si, sorting= sorting_num, overwrite= False, folder=f"{self.folder}//waveforms//wv_{self.probe}_{num}", load_if_exists=True, max_spikes_per_unit=None)
            if method == 'monopolar_triangulation':
                self.locations_pred[num]  = si.compute_spike_locations(wv_num, method = method, ms_before=ms_before, ms_after= ms_after) 
                self.unit_pred[num] = si.compute_unit_locations(wv_num, method = method, radius_um=radius_um, max_distance_um=max_distance_um)
            else :
                self.locations_pred[num]  = si.compute_spike_locations(wv_num, method = method) 
                self.unit_pred[num] = si.compute_unit_locations(wv_num, method = method)

    @staticmethod
    def _candidate_positions(location) :
        '''
        3D points a predicted location may stand for.

        A location with fewer than 3 values (x, y) is placed at z = 0.
        For 3D locations the sign of z is treated as ambiguous, so both (x, y, z) and
        (x, y, -z) are returned.
        '''
        if len(location) < 3 :
            return [(location[0], location[1], 0)]
        return [(location[0], location[1], location[2]),
                (location[0], location[1], -location[2])]

    @staticmethod
    def _min_distance(xyz_real, location) :
        '''Smallest Euclidean distance between xyz_real and the candidate positions of `location`.'''
        return min(math.dist(xyz_real, candidate) for candidate in artificial_data._candidate_positions(location))

    def find_errors(self) :
        '''
        We get the real positions of the templates of interest and compare them with the predictions made by spike interface.

        Sets:
        - self.locations_real: ground-truth position of each template of interest.
        - self.pred_errors[t]: one error per spike of template t.
        - self.unit_errors[t]: error of the averaged-template location.
        - self.median_errors / self.std_errors: per-template median and std of pred_errors.
        Each error is the distance to the closest z-mirrored candidate (see _candidate_positions).
        '''
        self.locations_real = [self.tempgen.locations[template] for template in self.template_ids]
        self.pred_errors = [0]*self.n_templates
        self.unit_errors = [0]*self.n_templates
        for template in range (self.n_templates) :
            real = self.locations_real[template]
            # MEArec location components (1, 2, 0) are mapped onto the (x, y, z) of the predictions
            xyzreal = (real[1], real[2], real[0])
            self.pred_errors[template] = [self._min_distance(xyzreal, location) for location in self.locations_pred[template]]
            # The averaged-template location is the first row of the unit prediction
            self.unit_errors[template] = self._min_distance(xyzreal, self.unit_pred[template][0])

        self.median_errors = [np.median(errors) for errors in self.pred_errors]
        self.std_errors = [np.std(errors) for errors in self.pred_errors]

    def alpha_and_z(self) :
        '''
        When a location is estimated by spike interface, it is done in 3 dimensions and with an estimation of the spike amplitude (alpha)
        Here we get the predicted z positions (field 2) and the alpha values (field 3) of every spike.
        Requires 4-field spike locations (as produced by monopolar_triangulation).
        '''
        self.z_values = [0]*self.n_templates
        self.alpha_values = [0]*self.n_templates
        for template in range(self.n_templates) :
            self.z_values[template] = [self.locations_pred[template][time][2] for time in range (len(self.locations_pred[template]))]
            self.alpha_values[template]= [self.locations_pred[template][time][3] for time in range (len(self.locations_pred[template]))]

    def templates_features(self):
        '''
        We want to understand why some templates positions are better predicted than others.
        We hypothesized that the variance of the peak to peak and the Frobenius norms could affect the localisation in spike interface.
        So here we compute these values for the templates of interest and store them in self.templates_infos.
        '''
        self.templates_infos = {
            'Peak to peak variances' : [] , 'Frobenius norms' : []
        }
        for template_id in self.template_ids :
            ptp = [np.max(self.templates[template_id][channel][:]) - np.min(self.templates[template_id][channel][:]) for channel in range(self.n_probes)]
            self.templates_infos['Peak to peak variances'].append(np.var(ptp))
            self.templates_infos['Frobenius norms'].append(np.linalg.norm(self.templates[template_id]))

    def excitatory_or_inhibitory(self) :
        '''
        Excitatory and inhibitory cells exhibit different spike shapes.
        Store in self.exc_num / self.inh_num the positions (in self.template_ids) of the templates
        whose MEArec cell type name contains an excitatory / inhibitory cell-type tag.
        Each template is listed at most once per group.
        '''
        self.exc_num = []
        self.inh_num = []
        exc = ['STPC', 'TTPC1', 'TTPC2', 'UTPC'] 
        inh = ['BP', 'BTC', 'ChC', 'DBC', 'LBC', 'MC', 'NBC', 'NGC', 'SBC']
        for num, template in enumerate(self.template_ids) :
            celltype_name = self.tempgen.celltypes[template]
            if any(celltype in celltype_name for celltype in exc) :
                self.exc_num.append(num)
            if any(celltype in celltype_name for celltype in inh) :
                self.inh_num.append(num)

    def plot_parameters(self, n_to_plot = 15) :
        '''
        Set self.n_to_plot (at most n_to_plot, at most the number of templates) and one
        'PuBuGn' color per plotted template in self.colors.
        '''
        cmap = plt.get_cmap('PuBuGn')
        self.n_to_plot = min(self.n_templates, n_to_plot)
        # exactly one color per plotted template, so scatter(color=self.colors) matches the points
        self.colors = [cmap(c/self.n_to_plot) for c in range (self.n_to_plot)]

    def plot_predictions (self, n_to_plot=15) :
        '''
        Function to plot the predicted locations (small dots), the real positions (stars)
        and the averaged-template locations (large dots) on top of the probe.
        '''
        self.plot_parameters(n_to_plot)
        fig, ax = plt.subplots(1, 1, figsize=(6,8))
        plot_probe(self.recording_si.get_probe(), ax=ax)
        for template in range (self.n_to_plot) : 
            x_pred = [location[0] for location in self.locations_pred[template]]
            y_pred = [location[1] for location in self.locations_pred[template]]
            ax.scatter(x_pred, y_pred, s = 20, edgecolors= 'none', color = self.colors[template]) 
        x_real = [self.locations_real[template][1] for template in range (self.n_to_plot)]
        y_real = [self.locations_real[template][2] for template in range (self.n_to_plot)]
        x_unit = [self.unit_pred[template][0][0] for template in range (self.n_to_plot)]
        y_unit = [self.unit_pred[template][0][1] for template in range (self.n_to_plot)]
        ax.scatter(x_real, y_real, s=100, color = self.colors, edgecolors= 'black', marker = '*')  
        ax.scatter(x_unit, y_unit, s=100, color = self.colors, edgecolors= 'black')
        ax.set_title('Predicted positions compared to real positions')
        ax.set_xlabel('x coordinate in um')
        ax.set_ylabel('y coordinate in um')
        plt.show()

    def boxplot_errors(self, n_to_plot=15) :
        '''
        For each template multiple location predictions are made and this function helps to visualize the range of errors for each template.
        '''
        self.plot_parameters(n_to_plot)
        fig = plt.figure(figsize =(12, 8))
        ax = plt.subplot()
        bp = plt.boxplot(self.pred_errors[:self.n_to_plot], notch=False, sym='', patch_artist=True)
        # At most n_to_plot boxplots (15 by default) are shown for readability
        ax.set_xticklabels([f"#{i}" for i in range (self.n_to_plot)])
        ax.set_ylabel('Prediction error in um')
        ax.set_xlabel('Templates')
        ax.set_title('Distance between template position and spike location')
        for patch, color in zip(bp['boxes'], self.colors):
            patch.set_facecolor(color)      
        plt.show()

    def plot_error_against_features (self):
        '''
        Plot the median localization error of each template against each extracted feature
        (green: excitatory, red: inhibitory). Requires find_errors() to have been called.
        '''
        self.templates_features()
        self.excitatory_or_inhibitory()
        y_titles = list(self.templates_infos.keys())
        fig, ax = plt.subplots(len(y_titles), figsize = (8,10))
        fig.suptitle('Plot ' + self.probe)
        for n in range(len(y_titles)):
            for num in self.exc_num :
                ax[n].scatter(self.templates_infos[y_titles[n]][num], self.median_errors[num],color = 'green')
            for num in self.inh_num :
                ax[n].scatter(self.templates_infos[y_titles[n]][num], self.median_errors[num], color = 'red')
            ax[n].set_ylabel('Prediction error in um')
            ax[n].set_xlabel(y_titles[n])

        plt.show()
        return()

    def predictor(self, feature) :
        '''
        We hypothesize that the feature and the error are inversely proportional:
        1 / median_error = slope * feature + rectification, fitted by least squares.
        Sets self.slope and self.rectification.
        '''
        xarray = np.array(self.templates_infos[feature])
        yarray = np.array([1/yi for yi in self.median_errors])
        X = np.vstack([xarray, np.ones(len(xarray))]).T
        coefficients = np.linalg.lstsq(X, yarray, rcond=None)[0]
        # lstsq returns (solution, residuals, rank, singular values): both parameters are in the solution
        self.slope = coefficients[0]
        self.rectification = coefficients[1]

    def plot_fit_curve(self, feature) :
        '''
        Scatter the measured median errors against `feature` (blue) and overlay the model
        1 / (slope * feature + rectification) fitted by predictor() (black), evaluated on a
        regular grid spanning the observed feature values.
        '''
        values = self.templates_infos[feature]
        plt.scatter(values, self.median_errors, color = 'blue')
        # float grid over the observed range, so features smaller than 1 still get a curve
        x_fit = np.linspace(np.min(values), np.max(values), 200)
        y_fit = 1/(self.slope*x_fit + self.rectification)
        plt.plot(x_fit, y_fit, color = 'black')
        plt.title(f"Finding a model for predicting the error based on the {feature}")

In [None]:
mea_data = {}
methods = 'center_of_mass', 'monopolar_triangulation'
probe = 'Fake_probe'
timing = {}
# Arguments forwarded to predicted_locations (artificial_data only uses them for
# 'monopolar_triangulation'):
t = 1       # ms_before
t_2 = 1.5   # ms_after (same value as the predicted_locations default)
r = 50      # radius_um

for method in methods: 
    mea_data[method] = artificial_data(files_folder, probe)
    print('initialisation done')
    mea_data[method].load_in_spike_interface() 
    print('Data loaded in Spike Interface')
    timing[method] = time.time() 
    # ms_after takes t_2, the post-spike window defined above
    mea_data[method].predicted_locations(method = method, ms_before = t, ms_after = t_2, radius_um = r)
    timing[method] = time.time()  - timing[method]
    print('Locations predicted')
    mea_data[method].find_errors() 
    print('Errors found, plotting starting')
    mea_data[method].plot_predictions()
    mea_data[method].boxplot_errors() 
    mea_data[method].plot_error_against_features()

print(timing)


In [None]:
# Monopolar triangulation again, with shorter waveform windows and a smaller radius,
# timing only the localization step.
method, probe = 'monopolar_triangulation', 'Fake_probe'
t, t_2, r = 0.4, 0.4, 45   # ms_before, ms_after, radius_um

mea_data_opti = artificial_data(files_folder, probe)
print('initialisation done')
mea_data_opti.load_in_spike_interface()
print('Data loaded in Spike Interface')

time_opti = time.time()
mea_data_opti.predicted_locations(method=method, ms_before=t, ms_after=t_2, radius_um=r)
time_opti = time.time() - time_opti
print('Locations predicted')

print(time_opti)

## Looking at real data (ground truth recordings)

In [None]:
class real_data_templates:
    '''
    Templates of real recordings, loaded from a .npy file, with the same template
    features as the artificial data (peak-to-peak variance, Frobenius norm).
    '''
    def __init__(self, file, probe) :
        self.file = file
        # Reverse all axes of the loaded array, then keep the templates as nested lists
        self.templates = np.array(np.load(file)).T.tolist()
        self.probe = probe

    def templates_in_recording (self) :
        '''Set the number of templates, their ids (0..n-1) and the length of a template's first axis.'''
        count = len(self.templates)
        self.n_templates = count
        self.template_ids = range(count)
        self.n_probes = len(self.templates[0])

    def templates_features(self):
        '''Compute, for every template, the variance of the per-channel peak-to-peak and the Frobenius norm.'''
        self.templates_in_recording()
        ptp_variances = []
        frobenius_norms = []
        self.template_infos = {
            'Peak to peak variances' : ptp_variances, 'Frobenius norms' : frobenius_norms
        }
        for template_id in self.template_ids :
            template = self.templates[template_id]
            channel_ptps = []
            for channel in range(self.n_probes):
                samples = template[channel][:]
                channel_ptps.append(np.max(samples) - np.min(samples))
            ptp_variances.append(np.var(channel_ptps))
            frobenius_norms.append(np.linalg.norm(template))

    def predict_error(self, feature, slope, rectification) :
        '''Plot the error predicted by 1 / (slope * feature + rectification) for every real template.'''
        self.templates_features()
        feature_values = self.template_infos[feature]
        predicted = [1/(slope*value+rectification) for value in feature_values]
        plt.scatter(feature_values, predicted, color = 'black')
        plt.title(f"Predicting the error based on the {feature} for real data")

In [None]:
# Only the first `time_recording` seconds of each real recording are analysed,
# sampled at `sampling_frequency`.
time_recording, sampling_frequency = 30, 20000

class real_data_recordings :
    '''
    One real ground-truth recording: the extracellular MCS raw data (first
    `time_recording` seconds) plus a sorting built from the peaks detected on the
    juxtacellular channel.
    '''
    def __init__(self,folder, probe, rec) :
        self.rec = rec
        self.folder = folder
        self.probe = probe
        n_frames = sampling_frequency*time_recording

        recording = si.read_mcsraw(f"{folder}recordings\\{rec}\\{rec}.raw")
        recording.annotate(is_filtered=True)
        recording = recording.frame_slice(0, n_frames)
        self.recording_si = recording.set_probegroup(read_prb(f"{folder}\\{probe}.prb"))

        # Ground-truth spike times: peaks of the juxtacellular signal inside the analysed window
        juxta = si.read_binary(f"{folder}recordings\\{rec}\\{rec}.juxta.raw", sampling_frequency=sampling_frequency, num_chan=1, dtype='float32')
        juxta_peaks = detect_peaks(juxta, exclude_sweep_ms=2, detect_threshold=8)
        spike_frames = np.array([frame for frame in juxta_peaks['sample_ind'] if frame < n_frames])
        self.sorting_si = si.NumpySorting.from_times_labels(spike_frames, np.zeros(len(spike_frames)), sampling_frequency=sampling_frequency)

    def predicted_locations (self, method = 'monopolar_triangulation') :
        '''Localize every spike and the averaged template with `method` (waveforms cached per recording and method).'''
        waveforms_folder = f"{self.folder}waveforms\\waforms_{self.rec}_{method}"
        wv_real = si.extract_waveforms(recording=self.recording_si, sorting=self.sorting_si, overwrite=False, folder=waveforms_folder, load_if_exists=True, max_spikes_per_unit=None)
        self.locations_pred = si.compute_spike_locations(wv_real, method=method)
        self.unit_pred = si.compute_unit_locations(wv_real, method=method)

    def find_error (self, real_positions) :
        '''2D distances between the real position and each spike / the averaged-template location.'''
        self.locations_real = real_positions
        self.pred_errors = [math.dist((location[0], location[1]), real_positions) for location in self.locations_pred]
        self.median_errors = np.median(self.pred_errors)
        self.std_errors = np.std(self.pred_errors)
        unit_xy = (self.unit_pred[0][0], self.unit_pred[0][1])
        self.unit_error = math.dist(self.locations_real, unit_xy)


In [None]:
class multiple_recordings :
    '''
    Gathers the results of several real_data_recordings objects (one per recording,
    all localized with the same method) so they can be plotted together.

    recording_objects : dict mapping a recording name to a real_data_recordings object on
    which predicted_locations() and find_error() have been called. Every list attribute
    follows the dict's iteration order.
    '''
    def __init__(self, recording_objects) :
        recordings = list(recording_objects.values())
        self.n_recordings = len(recordings)
        self.locations_real = [recording.locations_real for recording in recordings]
        self.locations_pred = [recording.locations_pred for recording in recordings]
        self.unit_pred = [recording.unit_pred for recording in recordings]
        self.pred_errors = [recording.pred_errors for recording in recordings]
        self.unit_errors = [recording.unit_error for recording in recordings]
        self.recording_si = [recording.recording_si for recording in recordings]
        self.median_errors = [recording.median_errors for recording in recordings]
        self.std_errors = [recording.std_errors for recording in recordings]

    def plot_parameters(self) :
        '''One 'PuBuGn' color per recording, stored in self.colors.'''
        cmap = plt.get_cmap('PuBuGn')
        self.colors = [cmap(c/self.n_recordings) for c in range (self.n_recordings)]

    def _probe_to_plot(self, probe_object) :
        '''
        Single probe to draw: the first probe of a ProbeGroup (as returned by read_prb),
        probe_object itself if it is already a probe, or the probe of the first recording
        when probe_object is None.
        '''
        if probe_object is None :
            return self.recording_si[0].get_probe()
        probes = getattr(probe_object, 'probes', None)
        return probes[0] if probes else probe_object

    def plot_predictions (self, probe_object) :
        '''
        Plot, on top of the real probe, the spike locations (small dots), the real positions
        (stars) and the averaged-template locations (large dots) of every recording.
        '''
        self.plot_parameters()
        plot_probe(self._probe_to_plot(probe_object))
        for rec in range (self.n_recordings) :
            x_pred = [pred[0] for pred in self.locations_pred[rec]]
            y_pred = [pred[1] for pred in self.locations_pred[rec]]
            plt.scatter(x_pred, y_pred, s = 20, edgecolors= 'none', color = self.colors[rec]) # the predictions are represented by small dots 
            plt.scatter(self.locations_real[rec][0], self.locations_real[rec][1], s=100, color = self.colors[rec],  edgecolors= 'black', marker = '*') 
            plt.scatter(self.unit_pred[rec][0][0], self.unit_pred[rec][0][1], s=100, color = self.colors[rec],  edgecolors= 'black')
        plt.title('Predictions vs positions')
        plt.xlabel('x coordinate in um')
        plt.ylabel('y coordinate in um')

    def violinplot_errors(self, colors) :
        '''
        One violin per recording with the per-spike errors, colored with `colors`;
        the averaged-template error of each recording is overlaid as a dot.
        '''
        self.plot_parameters()
        plt.figure()
        ax = plt.subplot()
        vp = plt.violinplot(self.pred_errors,  showmedians=True, showextrema=False)
        ax.set_ylabel('Location error (µm)', fontsize=13)
        ax.tick_params(axis='both', which='both', labelsize=11)
        ax.set_xticks([],[])
        ax.spines[["top", "right"]].set_visible(False)
        for body, color in zip(vp['bodies'], colors): 
            body.set_color(color)
            body.set_edgecolor('black')
            body.set_alpha(1)
        for key in vp.keys():
            if key != 'bodies':
                vp[key].set_edgecolor('black') 
        for n in range (self.n_recordings):
            plt.scatter(n+1,self.unit_errors[n], color = colors[n], edgecolors= 'black', s=50, linewidths=2, marker='o')
        plt.show()

In [None]:
probe ='mea_256'
probe_object = read_prb(f"{working_folder}data_real\\{probe}.prb")
# The real templates do not depend on the localization method: load them once.
physio_templates = real_data_templates(templates_file, probe) 

In [None]:
p = probe_object.to_dict()
# np.asarray makes "+ [dx, dy]" below an element-wise offset (same units as the probe
# coordinates) rather than a list concatenation, whatever container to_dict() returns.
channel_positions = np.asarray(p['probes'][0]['contact_positions'])

# Real neuron position of each recording: a probe channel position plus a manual offset.
real_positions = {'20160426_patch3' : channel_positions[226] + [5,-4],
                  '20170726_patch1' : channel_positions[30] + [6,6],
                  '20170728_patch2' : channel_positions[118] + [-2,9],
                  '20160426_patch2' : channel_positions[200] + [-18, -12], 
                  '20160415_patch2' : channel_positions[69] +  [3,-15]}

folder = f"{working_folder}data_real\\"
# sorted() makes the recording order (and therefore the plot colors) reproducible
recording_names = sorted(listdir(working_folder + 'data_real\\recordings'))

physio_recordings = {}
for method in methods : 
    physio_recordings[method] = {}
    for rec in recording_names: 
        physio_recordings[method][rec] = real_data_recordings(folder, probe, rec)
        physio_recordings[method][rec].predicted_locations(method = method)
        physio_recordings[method][rec].find_error(real_positions[rec])

In [None]:
# Colormap used for each localization method
maps = {'center_of_mass':'Blues', 'monopolar_triangulation' :'Purples'}
n_real_recordings = len(listdir(working_folder + 'data_real\\recordings'))

physio_all_recordings = {}
for method in methods:
    combined = multiple_recordings(physio_recordings[method])
    physio_all_recordings[method] = combined
    combined.plot_predictions(probe_object)
    cmapp = cm.get_cmap(maps[method])
    # one shade of the method's colormap per recording
    colors = [cmapp(c / n_real_recordings) for c in range(n_real_recordings)]
    combined.violinplot_errors(colors)

# Figures for poster 

In [None]:
def subplots_methods(dico, n_to_plot, xy_ind_real) :
    '''
    Poster figure with one panel per localization method (global `methods`).

    Each panel shows, on top of the probe map, the location of every spike (small dots),
    the real positions (stars) and the averaged-template locations (large dots) for the
    first `n_to_plot` units / recordings, colored with the method's colormap (`cmaps`).

    Parameters
    ----------
    dico : dict
        Maps each method name to an object with recording_si (a recording or a list of
        recordings), locations_pred, locations_real and unit_pred
        (artificial_data or multiple_recordings).
    n_to_plot : int
        Number of units / recordings drawn.
    xy_ind_real : tuple of int
        Indices of the x and y coordinates inside each entry of locations_real.
    '''
    fig, axes = plt.subplots(nrows= 1, ncols= len(methods), figsize=(22,12))
    x_ind, y_ind = xy_ind_real
    fontprops = fm.FontProperties(size=22)

    for m, method in enumerate(methods):
        data = dico[method]
        ax = axes[m]
        colors = [cmaps[m](c/n_to_plot) for c in range(n_to_plot)]

        # multiple_recordings keeps a list of recordings; the first one provides the probe map
        recording = data.recording_si[0] if isinstance(data.recording_si, list) else data.recording_si
        si.plot_probe_map(recording, with_channel_ids=False, ax=ax)

        for template in range(n_to_plot):
            spikes = data.locations_pred[template]
            ax.scatter([loc[0] for loc in spikes], [loc[1] for loc in spikes], s=50, edgecolors='none', color=colors[template])

        x_real = [data.locations_real[template][x_ind] for template in range(n_to_plot)]
        y_real = [data.locations_real[template][y_ind] for template in range(n_to_plot)]
        x_unit = [data.unit_pred[template][0][0] for template in range(n_to_plot)]
        y_unit = [data.unit_pred[template][0][1] for template in range(n_to_plot)]
        ax.scatter(x_real, y_real, s=500, color=colors, edgecolors='black', marker='*', linewidth=2, label='Real position')
        ax.scatter(x_unit, y_unit, s=400, color=colors, edgecolors='black', linewidth=2, label='Averaged template')
        # black dot at (-300, -300), presumably outside the probe area: only provides the legend entry
        ax.scatter(-300, -300, marker='.', s=100, color='black', label='Individual spike')

        ax.set_title(method.replace('_', ' ').capitalize(), fontsize=30, pad=10)
        ax.set_ylabel('')
        ax.set_xlabel('')
        ax.set_xticks([], [])
        ax.set_yticks([], [])

    # A single 100 µm scale bar, drawn on the first panel with that panel's data transform
    scalebar = AnchoredSizeBar(axes[0].transData,
                               100, '100 µm', 'lower right',
                               pad=0.5,
                               color='black',
                               frameon=False,
                               size_vertical=1, fontproperties=fontprops)
    axes[0].add_artist(scalebar)
    fig.tight_layout(pad=2)
    axes[-1].legend(loc='lower right', fontsize=24)



subplots_methods(mea_data, 30, (1,2))
subplots_methods(physio_all_recordings, 5, (0,1))

In [None]:
sigma = 2          # Gaussian smoothing width, in points of the sorted curves (None disables it)
center = (0,0)     # reference point of the "distance from center" row
colors_binary_fill = cmaps[0](0.4), cmaps[1](0.4)


def _plot_sorted_errors(axes, x_values, data, m, method, sigma):
    '''
    Draw one method on a pair of axes against x_values sorted in increasing order.

    Left panel: median spike error with a +/- std band (data.median_errors, data.stds).
    Right panel: averaged-template error (data.unit_errors).
    Every y curve is smoothed with gaussian_filter1d when sigma is not None.
    Returns the sort order of x_values.
    '''
    order = np.argsort(np.array(x_values))
    sorted_x = [x_values[i] for i in order]
    error = [data.median_errors[i] for i in order]
    std_down = [data.median_errors[i] - data.stds[i] for i in order]
    std_up = [data.median_errors[i] + data.stds[i] for i in order]
    unit_error = [data.unit_errors[i] for i in order]

    if sigma is not None:
        error = gaussian_filter1d(error, sigma=sigma)
        std_down = gaussian_filter1d(std_down, sigma=sigma)
        std_up = gaussian_filter1d(std_up, sigma=sigma)
        unit_error = gaussian_filter1d(unit_error, sigma=sigma)

    label = method.replace('_', ' ').capitalize()
    axes[0].plot(sorted_x, error, color = colors_binary[m], label = label, linewidth=4)
    axes[0].fill_between(sorted_x, std_down, y2= std_up, color = colors_binary_fill[m], alpha=0.5, edgecolor=None)
    axes[0].set_ylabel('Location error (µm)', fontsize = 20)
    axes[1].plot(sorted_x, unit_error, color = colors_binary[m], label = label, linewidth=4)
    return order


def _format_error_row(axes, xlabel, x_values, ymax):
    '''Shared cosmetics of the curve rows: x label, x range of x_values, y range [0, ymax], no top/right spines.'''
    for axis in axes:
        axis.set_xlabel(xlabel, fontsize = 20)
        axis.set_xlim([np.min(x_values), np.max(x_values)])
        axis.set_ylim([0, ymax])
        axis.tick_params(axis='both', which='both', labelsize=18)
        axis.spines[["top", "right"]].set_visible(False)


def errors_fct_dist(object, sigma, center, xyreal): 
    '''
    Summary figure (3 rows x 2 columns) comparing the two methods of the global `methods`.

    Row 1: violin plots of the per-unit median / std spike errors (left) and of the
           averaged-template errors (right).
    Row 2: errors against the distance between the real position and `center`.
    Row 3: errors against the template Frobenius (L2) norm. It is drawn only when the
           objects carry `templates_infos` (artificial_data after templates_features());
           otherwise the row is hidden.

    Parameters
    ----------
    object : dict
        Maps each method name to a results object (artificial_data or multiple_recordings).
        The name shadows the builtin but is kept for keyword callers.
    sigma : float or None
        Smoothing width passed to gaussian_filter1d; None disables smoothing.
    center : tuple
        Reference point for row 2.
    xyreal : tuple of int
        Indices of x and y inside each entry of locations_real.

    Side effects: sets distances, stds, unit_std and order on each results object.
    '''
    results = object
    a, b = xyreal
    fig, axes_all = plt.subplots(nrows= 3, ncols= 2, figsize=(14,12)) 
    features = 'Frobenius norms', 'Peak to peak variances'
    first, second = results[methods[0]], results[methods[1]]

    # Row 1: distributions of the per-unit errors
    ax = axes_all[0]
    vps = [
        ax[0].violinplot([first.median_errors, second.median_errors, first.std_errors, second.std_errors], showmedians=True, showextrema=False),
        ax[1].violinplot([first.unit_errors, second.unit_errors], showmedians=True, showextrema=False),
    ]
    ax[0].set_ylabel('Location error (µm)', fontsize=20)

    # violins alternate between the two methods, so the method colors are repeated
    colors_binary_plus = colors_binary + colors_binary
    for vp in vps: 
        for body, color in zip(vp['bodies'], colors_binary_plus): 
            body.set_color(color)
            body.set_alpha(0.7)
            body.set_edgecolor('black')
            body.set_linewidth(2)
        for key in vp.keys():
            if key != 'bodies':
                vp[key].set_edgecolor('black') 

    ax[0].set_xticks([1.5,3.5],['Median', 'Standard deviation'])
    ax[1].set_xticks([],[])
    for axis in ax:
        axis.tick_params(axis='both', which='both', labelsize=18)
        axis.spines[["top", "right"]].set_visible(False)
        axis.set_ylim([0,200])
    ax[0].set_title('Individual spike', fontsize = 25, pad=15)
    ax[1].set_title('Averaged template', fontsize = 25, pad=15)

    # Row 2: errors against the distance of the real position from `center`
    axes = axes_all[1]
    for m, method in enumerate(methods):
        data = results[method]
        data.distances = [math.dist((loc[a], loc[b]), center) for loc in data.locations_real]
        data.stds = [np.std(errors) for errors in data.pred_errors]
        data.unit_std = np.std(data.unit_errors)
        data.order = _plot_sorted_errors(axes, data.distances, data, m, method, sigma)
    _format_error_row(axes, 'Distance from center (µm)', data.distances, 150)

    # Row 3: errors against the template Frobenius norm (needs templates_infos)
    axes = axes_all[2]
    if all(hasattr(results[method], 'templates_infos') for method in methods):
        for m, method in enumerate(methods):
            data = results[method]
            feature_values = data.templates_infos[features[0]]
            data.unit_std = np.std(data.unit_errors)
            data.order = _plot_sorted_errors(axes, feature_values, data, m, method, sigma)
        _format_error_row(axes, 'L2 norm', feature_values, 120)
    else:
        # multiple_recordings (real data) has no template features: hide this row
        for axis in axes:
            axis.set_visible(False)

    for ax in axes_all:
        ax[1].spines['left'].set_visible(False)
        ax[1].set_yticks([],[])

    fig.subplots_adjust(wspace=0.2, hspace=0.4)

    plt.show()

errors_fct_dist(mea_data, sigma, center, (1,2))
errors_fct_dist(physio_all_recordings, sigma, center, (0,1))




In [None]:
mono_data = mea_data['monopolar_triangulation']
mono_data.alpha_and_z()

# Per template: median and standard deviation, over its spikes, of the amplitude
# estimate (alpha) and of the predicted z from monopolar triangulation.
alpha_median = [np.median(values) for values in mono_data.alpha_values]
alpha_std = [np.std(values) for values in mono_data.alpha_values]
z_median = [np.median(values) for values in mono_data.z_values]
z_std = [np.std(values) for values in mono_data.z_values]

In [None]:
colors_type = ['forestgreen', 'firebrick']   # tick-label colors: excitatory, inhibitory
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(22,6))

for m, method in enumerate(methods):
    data = mea_data[method]
    data.excitatory_or_inhibitory()
    errors_exc_unit = [data.unit_errors[i] for i in data.exc_num]
    errors_inh_unit = [data.unit_errors[i] for i in data.inh_num]
    vp = ax[m].violinplot([errors_exc_unit, errors_inh_unit], showmedians=True, showextrema=False)

    # Both violins use the method color; the cell type is conveyed by the tick-label color
    for body in vp['bodies'][:len(colors_type)]:
        body.set_color(colors_binary[m])
        body.set_alpha(0.7)
        body.set_edgecolor('black')
        body.set_linewidth(3)
    for key, artist in vp.items():
        if key != 'bodies':
            artist.set_edgecolor('black')

    ax[m].tick_params(axis='both', which='both', labelsize=18)
    ax[m].set_xticks([1,2],['Exc', 'Inh'])
    for ticklabel, tickcolor in zip(ax[m].get_xticklabels(), colors_type):
        ticklabel.set_color(tickcolor)
        ticklabel.set_fontsize(25)

    ax[m].set_ylim([0,180])
    ax[m].spines[["top", "right"]].set_visible(False)

ax[1].set_yticks([],[])
ax[1].spines['left'].set_visible(False)
ax[0].set_ylabel('Location error (µm)', fontsize=20)
plt.show()

In [None]:


def add_patch(legend, handle, label):
    '''
    Append an extra (handle, label) entry to an existing legend, in place.

    The legend box is rebuilt from the handles/labels currently returned by
    its axes' get_legend_handles_labels(), plus the new entry. The legend's
    location and title are then re-applied.

    Parameters
    ----------
    legend : matplotlib Legend
        Legend to extend. Assumes it is attached to an Axes (uses `legend.axes`)
        — a figure-level legend is presumably not supported; TODO confirm.
    handle : matplotlib Artist
        Artist to show as the new legend key.
    label : str
        Text for the new entry.

    NOTE(review): relies on private matplotlib Legend internals
    (`_legend_box`, `_init_legend_box`, `_set_loc`, `_loc`), which may change
    between matplotlib versions.
    '''
    ax = legend.axes
    # Start from the entries matplotlib would collect from the axes, then add ours.
    handles, labels = ax.get_legend_handles_labels()
    handles.append(handle)
    labels.append(label)
    # Drop the old box and build a new one from the extended lists.
    legend._legend_box = None
    legend._init_legend_box(handles, labels)
    # Re-apply the previous location and title to the rebuilt box.
    legend._set_loc(legend._loc)
    legend.set_title(legend.get_title().get_text())


class ImageHandler(HandlerBase):
    '''
    Legend handler that draws an image (read from a file) as the legend key.

    Usage: create the handler, call `set_image(path, image_stretch)`, then pass
    it through `handler_map` when building the legend.
    '''

    # Hand-tuned vertical shift of the image inside the legend handle box,
    # expressed in the handle box's coordinates before `trans` is applied.
    y_offset = 12

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Populated by set_image(); None means no image has been configured yet.
        self.image_data = None
        self.image_stretch = (0, 0)

    def create_artists(self, legend, orig_handle,
                       xdescent, ydescent, width, height, fontsize,
                       trans):
        '''Return a single BboxImage drawn in (and stretched beyond) the handle box.'''
        if self.image_data is None:
            raise RuntimeError("ImageHandler.set_image() must be called before the legend is drawn")

        # enlarge the image by these margins
        sx, sy = self.image_stretch

        # Bounding box housing the image: its origin is moved by (-sx, -sy)
        # and its size grown by (sx, sy), then shifted up by y_offset.
        bb = Bbox.from_bounds(xdescent - sx,
                              ydescent - sy + self.y_offset,
                              width + sx,
                              height + sy)

        tbb = TransformedBbox(bb, trans)
        image = BboxImage(tbb)
        image.set_data(self.image_data)

        # Let HandlerBase propagate properties from orig_handle to the new artist.
        self.update_prop(image, orig_handle, legend)

        return [image]

    def set_image(self, image_path, image_stretch=(0, 0)):
        '''
        Load the image used as legend key.

        Parameters
        ----------
        image_path : str
            Path to an image file readable by `plt.imread`.
        image_stretch : tuple of 2 numbers
            (x, y) margins by which the image box is enlarged in create_artists.
        '''
        self.image_data = plt.imread(image_path)
        self.image_stretch = image_stretch



# Location error on the real recordings, one panel per localisation method:
# violin = errors of individual spikes, dot = error of the averaged template,
# one colour per recording.
recordings_dir = working_folder + 'data_real\\recordings'
n_recording_files = len(listdir(recordings_dir))

fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12,6))
for m, method in enumerate(methods):
    physio_data = physio_all_recordings[method]
    # `maps` is a tuple of colormap names indexed by panel position
    # (like cmaps / colors_binary), not by method name.
    cmapp = cm.get_cmap(maps[m])
    colors = [cmapp(c / n_recording_files) for c in range(n_recording_files)]
    vp = ax[m].violinplot(physio_data.pred_errors, showmedians=True, showextrema=False)
    ax[m].set_ylim([0,200])
    ax[m].tick_params(axis='both', which='both', labelsize=16)
    ax[m].set_xticks([],[])
    ax[m].spines[["top", "right"]].set_visible(False)
    for body, color in zip(vp['bodies'], colors):
        body.set_color(color)
        body.set_edgecolor('black')
        body.set_alpha(1)
        body.set_linewidth(2)
    for key in vp.keys():
        if key != 'bodies':
            vp[key].set_edgecolor('black')
    for n in range(physio_data.n_recordings):
        point = ax[m].scatter(n+1, physio_data.unit_errors[n], color=colors[n], edgecolors='black', s=300, linewidth=2, marker='o')
        # Keep two artists of the first panel as anchors for the image legend below
        # (each point is drawn exactly once).
        if (m, n) == (0, 0):
            a = point
        elif (m, n) == (0, 1):
            b = point

ax[0].set_ylabel('Location error (µm)', fontsize=20)
ax[1].set_yticks([],[])
ax[1].spines['left'].set_visible(False)

# Legend keys are images (violin glyph / circle) drawn by ImageHandler.
custom_handler = ImageHandler()
custom_handler.set_image("./poster/vp_symbol.png", image_stretch=(8, 16))

custom_handler2 = ImageHandler()
custom_handler2.set_image("./poster/circle.png", image_stretch=(6, 18))

fig.legend([a, b], ['Individual spike', 'Averaged template'], handler_map={a: custom_handler, b: custom_handler2}, labelspacing=2, frameon=False, fontsize=15, bbox_to_anchor=(0.36, 0.87), loc="upper center", bbox_transform=fig.transFigure)

plt.show()

In [None]:
def _style_violins(vp, colors):
    '''
    Fill each violin body with its colour and outline everything in black.

    Parameters
    ----------
    vp : dict
        Return value of `Axes.violinplot`.
    colors : sequence
        One fill colour per violin body (paired with `zip`, so extra bodies stay unstyled).
    '''
    for body, color in zip(vp['bodies'], colors):
        body.set_color(color)
        body.set_edgecolor('black')
        body.set_alpha(1)
        body.set_linewidth(2)
    # Non-body artists (here the median lines) are outlined once, outside the body loop.
    for key in vp.keys():
        if key != 'bodies':
            vp[key].set_edgecolor('black')


# Sensitivity of the location error to the detection parameters:
# left = temporal window at a fixed radius, right = radius at a fixed temporal window.
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12,6))
cmap = cm.get_cmap(maps[1])
# One shade per violin from the lighter part of the colormap
# (reduces to c/7 for c = 1..5 when each sweep has five values).
n_shades = max(len(ms), len(radius))
colors = [cmap(c / (n_shades + 2)) for c in range(1, n_shades + 1)]

r = 45  # fixed radius (µm) for the temporal-window sweep
vp = ax[0].violinplot([mea_data_params[f"{t}ms_{r}um"].median_errors for t in ms], showmedians=True, showextrema=False)
ax[0].set_xticks(range(1, len(ms) + 1), ms)
ax[0].set_xlabel('Temporal window (ms)', fontsize=18)
_style_violins(vp, colors)

t = 1.2  # fixed temporal window (ms) for the radius sweep
all_units = []
for r in radius:
    # NOTE(review): this panel plots `unit_errors` while the left one plots
    # `median_errors`; confirm both are meant to share the same y axis.
    all_units.append(mea_data_params[f"{t}ms_{r}um"].unit_errors)

vp = ax[1].violinplot(all_units, showmedians=True, showextrema=False)
ax[1].set_xticks(range(1, len(radius) + 1), radius)
ax[1].set_xlabel('Radius (µm)', fontsize=18)
_style_violins(vp, colors)

for m in range(2):
    ax[m].set_ylim([0,125])
    ax[m].tick_params(axis='both', which='both', labelsize=16)
    ax[m].spines[["top", "right"]].set_visible(False)

ax[0].set_ylabel('Location error (µm)', fontsize=18)
ax[1].set_yticks([],[])
ax[1].spines['left'].set_visible(False)
plt.show()


In [None]:
# Raw traces of the first `toplot` channels over the first `time_dur` samples.
template_number = 4  # unit whose waveforms are plotted in the next cell
time_dur = 12000  # number of samples to display
toplot = 4  # number of channels (stacked panels) to draw
# `method` was previously whatever the last `for m, method in enumerate(methods)`
# loop left behind (methods[-1]); set it explicitly so this cell does not
# depend on which cells ran before it.
method = methods[-1]
traces_recording = mea_data[method].recording_si
trace0 = np.asarray(traces_recording.get_traces(segment_index=0, start_frame=0, end_frame=time_dur, channel_ids=[str(i) for i in range(1, 11)])).T
# Sample index -> time in ms. Replaces the hard-coded 0.03125 ms/sample,
# which is only correct for a 32 kHz recording.
x = np.arange(time_dur) * (1000.0 / traces_recording.get_sampling_frequency())

fig, axes = plt.subplots(nrows=toplot, ncols=1, squeeze=False)
ax = axes[:, 0]  # always a 1-D array of Axes, even when toplot == 1
for i in range(toplot):
    ax[i].plot(x, trace0[i], color='grey')
    ax[i].set_ylim(-500, 500)
    ax[i].set_yticks([], [])
    if i != toplot - 1:
        # Panels are stacked on a shared time axis: only the bottom one keeps
        # its x ticks and bottom spine (the old bare `spines["bottom"]` did nothing).
        ax[i].set_xticks([], [])
        ax[i].spines["bottom"].set_visible(False)
    else:
        ax[i].set_xlabel('time (ms)', fontsize=14)
    if i != 0:
        ax[i].spines["top"].set_visible(False)
fig.subplots_adjust(wspace=0, hspace=0)
plt.show()

#fig, axes = plt.subplots(nrows= 1, ncols= 1, figsize=(22,24))
#mr.plot_templates(mea_data[method].tempgen, template_ids=70, drifting=True, cmap='Reds_r', ax = axes)

In [None]:
# Overlaid waveforms of unit `template_number` (set in the previous cell) on up
# to 4 channels. Waveforms are stored under the first method's folder and
# reused if already there (load_if_exists=True).
first_method_data = mea_data[methods[0]]
unit_id = '#' + str(template_number)
sorting_num = first_method_data.sorting_si.select_units([unit_id])
wv_num = si.extract_waveforms(recording=first_method_data.recording_si, sorting=sorting_num, overwrite=False, folder=f"{first_method_data.folder}//waveforms//wv_{first_method_data.probe}_{template_number}", load_if_exists=True, max_spikes_per_unit=None)
fig, ax = plt.subplots(1, 1)
sw.plot_unit_waveforms(wv_num, max_channels=4, unit_colors={unit_id: 'indianred'}, ax=ax, same_axis=True, plot_legend=False)
ax.spines[["top", "right"]].set_visible(False)
ax.set_ylim([-220, -150])  # hand-picked zoom on the waveform
ax.set_title('')
# End with plt.show() rather than set_title(), whose Text repr was echoed as cell output.
plt.show()
