In [102]:
import matplotlib
#matplotlib.use('Agg')
%matplotlib tk
%autosave 180
import matplotlib.pyplot as plt
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

import matplotlib.cm as cm
from matplotlib import gridspec

import numpy as np
import pandas as pd
import os
import shutil
import cv2
import glob2
import parmap

from numba import jit

clrs = ['blue', 'red', 'green', 'cyan',
'black','grey','brown','slategrey','darkviolet','darkmagenta',
'orange','firebrick','lawngreen','dodgerblue','crimson','orchid','slateblue',
'darkgreen','darkorange','indianred','darkviolet','deepskyblue','greenyellow',
'peru','cadetblue','forestgreen','slategrey','lightsteelblue','rebeccapurple',
'darkmagenta','yellow','hotpink']


Autosaving every 180 seconds


In [205]:
# functions
def binary_reader_waveforms(filename, n_channels, n_times, spikes, data_type='float32'):
    ''' Reader for loading raw binaries
    
        standardized_filename:  name of file contianing the raw binary
        n_channels:  number of channels in the raw binary recording 
        n_times:  length of waveform 
        spikes: 1D array containing spike times in sample rate of raw data
        channels: load specific channels only
        data_type: float32 for standardized data
    
    '''

    # ***** LOAD RAW RECORDING *****
    wfs=[]
    if data_type =='float32':
        data_len = 4
    else:
        data_len = 2
    
    #filename = '/media/cat/1TB/data/synthetic/run5/data_int16.bin'
    with open(filename, "rb") as fin:
        for ctr,s in enumerate(spikes):
            # index into binary file: time steps * 4  4byte floats * n_channels
            if True:
            #try:
                fin.seek(s * data_len * n_channels, os.SEEK_SET)
                
                temp = np.fromfile(
                    fin,
                    dtype='int16',
                    count=(n_times * n_channels))
                                
                wfs.append(temp.reshape(n_times, n_channels))

            #except:
            #    print ("can't load spike: ", s)
            #    pass
    fin.close()
    return np.array(wfs)

@jit
def match_units(s1, s2):
    ctr=0
    max_diff = 5
    matched_spike_times = []
    for s in s1:
        if np.min(np.abs(s2-s))<=max_diff:
            ctr+=1
            matched_spike_times.append(s)
            
    return (ctr, matched_spike_times)


def search_spikes_parallel(units_ground_truth,
                      templates_gt,
                      max_chans_sorted,
                      spike_train_gt,
                      spike_train_sorted):
    n_spikes = []
    ids_matched = []
    ptp_gt_unit = []
    purity = []
    completeness = []
    matched_spikes_array = []
    n_spikes_array = []
    venn_array = []
    unit_ids = []
    print ("units_ground_truth: ", units_ground_truth)
    for unit in units_ground_truth:
        #ptps_unit = templates_sorted[unit].ptp(0)
        ptps_unit = templates_gt[unit].ptp(0)
        max_chans_unit = np.argsort(ptps_unit)[::-1][:40]  #check largest nearby channels;
        # find nearest gt templates
        ids_nearest_sorted = np.where(np.in1d(max_chans_sorted, max_chans_unit))[0]

        print (" Matching unit: ", unit, " , with ", ids_nearest_sorted)

        # search spike time matches
        match_spikes = []
        all_spikes = []
        match_ids = []
        matched_spike_times_local = []
        idx2 = np.where(spike_train_gt[:,1]==unit)[0]
        # save n_spikes for each unit
        venn_spikes = []
        for id_ in ids_nearest_sorted:
            #print ("searching match unit: ", id_)
            idx_3 = np.where(spike_train_sorted[:,1]==id_)[0]
            matches, matched_spike_times = match_units(spike_train_gt[idx2,0], 
    #                              spike_train[idx2,0])
                                  spike_train_sorted[idx_3,0])

            #if matches>match_spikes:
            match_ids.append(id_)
            match_spikes.append(matches)
            all_spikes.append(idx_3.shape[0])
            matched_spike_times_local.append(matched_spike_times)
            venn_spikes.append(matches)

        venn_array.append(venn_spikes)
        n_spikes_array.append(idx2.shape[0])
        #purity.append(match_spikes/float(idx2.shape[0]))
        #completeness.append(match_spikes/float(all_spikes))

        matched_spikes_array.append(matched_spike_times_local)

        # save original id plus match_id
        ids_matched.append([unit,match_ids])

        # save ptp of gt unit
        ptps_unit = templates_gt[unit].ptp(0).max(0)
        ptp_gt_unit.append(ptps_unit)
        print ("done matching: ", unit)
        unit_ids.append(unit)
        
    return (n_spikes_array, ids_matched, ptp_gt_unit, purity,
            completeness, matched_spikes_array,
            n_spikes_array, venn_array, unit_ids)

def load_ks2_spikes(root_dir, n_channels,n_times):

    fname_out = root_dir + 'spike_train_final.npy'

    if os.path.exists(fname_out)==False:

        fname = root_dir+'/data_int16.bin'
        data_type = 'int16'

        #n_channels = 384
        #n_times = 101

        # load KS2 sorted times
        times = np.load(root_dir+'spike_times.npy')-n_times//2
        ids = np.load(root_dir +'spike_clusters.npy')

        # reorer the KS2 spike_train
        ctr=0
        for k in np.unique(ids):
            idx = np.where(ids==k)[0]
            ids[idx]=ctr
            ctr+=1

        spike_train = np.hstack((times,ids))
        print (spike_train.shape)

        units = np.unique(spike_train[:,1])
        print ("units: ", units.shape)
        ids, counts = np.unique(spike_train[:,1], return_counts=True)

        # parse KS2 units and keep only ones with minimum or max firing rates
        min_spikes = 600/4
        max_spikes = 6000*500

        good_units_ids = np.where(np.logical_and(counts>=min_spikes, counts<=max_spikes))[0]
        print ("# of good units: ", good_units_ids.shape[0], " of total KS2 units: ", ids.shape[0])

        # reorder the spikes_going forward
        ctr=0
        spike_train_final = np.zeros((0,2),'int32')
        for good_unit in good_units_ids:
            idx= np.where(spike_train[:,1]==good_unit)[0]
            spike_train[idx,1]=ctr

            spike_train_final = np.int32(np.vstack((spike_train_final,spike_train[idx])))
            ctr+=1

        print (spike_train_final)

        spike_train = spike_train_final.copy()
        print (" DONE ")

        # save spike train and templates_reloaded
        np.save(fname_out, spike_train)
        print(spike_train)

    else:
        spike_train = np.load(fname_out)
    
    return spike_train

def reload_ks2_templates(root_dir, spike_train, data_type):
    
    fname_out = root_dir + '/templates_reloaded_good.npy'
    if os.path.exists(fname_out)==False:
        
        # name of binary file
        fname = root_dir + 'data_int16.bin'
        
        templates = []
        ptps = []
        time_start = 0
        time_end = time_start+600
        wfs_array = []
        times = []
        for ctr, unit in enumerate(np.unique(spike_train[:,1])):#[:20]:
        #for ctr, unit in enumerate([248]):#[:20]:
            idx = np.where(spike_train[:,1]==unit)[0]
            spikes = np.int32(spike_train[idx,0]) #-n_times//2#-30

            # sub sample spikes to speed up loading
            idx = np.where(np.logical_and(spikes>=(time_start*30000), spikes<(time_end*30000)))[0]
            spikes = spikes[idx]
            if idx.shape[0]==0:
                ptps.append([])
                templates.append(np.zeros((n_times,n_channels)))
                times.append([])
                continue

            times.append(spikes)
            #print (spikes.shape)

            idx2 = np.where(spikes<60*30000)[0]
            spikes=spikes[idx2]

            wfs = binary_reader_waveforms(fname, n_channels, n_times, spikes, data_type)

            if wfs.shape[0]==0:
                wfs = np.zeros((10,n_times,n_channels))
                continue 
            # save template using only first 60 sec of data;
            temp = wfs.mean(0)
            templates.append(temp)

            print (ctr, '/', np.unique(spike_train[:,1]).shape[0], ' raw id: ', unit, 
                   wfs.shape, temp.shape, "oiringla spikes: ", idx.shape)

        np.save(fname_out, templates)
    else:
        templates = np.load(fname_out)

    return np.array(templates)

class Match_to_ground_truth(object):

    def __init__(self, root_dir, spike_train_sorted, templates_sorted):
        
        self.root_dir = root_dir
        self.spike_train_sorted = spike_train_sorted
        self.templates_sorted = templates_sorted
        
        # load ground truth data
        self.spike_train_gt = np.load(self.root_dir + 'ground_truth/spike_train_ground_truth.npy')
        self.templates_gt = np.load(self.root_dir + 'ground_truth/templates_ground_truth.npy')
        print (" ground truth templates: ", self.templates_gt.shape)
        self.units_ground_truth = np.arange(self.templates_gt.shape[0])
        self.max_chans_gt = self.templates_gt.ptp(1).argmax(1)
        #print ("max chans gt: ", self.max_chans_gt)

        # load sorted templates
        #print ("Templates sorted: ", self.templates_sorted.shape)
        self.units_sorted = np.arange(self.templates_sorted.shape[0])
        self.max_chans_sorted = self.templates_sorted.ptp(1).argmax(1)

        #print ("calling matchi units")
        self.match_units()
        
    def match_units(self):
        fname = self.root_dir + 'matches_res.npy'
        if os.path.exists(fname)==False:

            self.units_split = np.array_split(self.units_ground_truth, 6)
            #print (units_split)

            res = parmap.map(search_spikes_parallel,
                              self.units_split,
                              self.templates_gt,
                              self.max_chans_sorted,
                              self.spike_train_gt,
                              self.spike_train_sorted,
                              pm_processes=6)

            np.save(fname, res)
        else:
            res = np.load(fname, allow_pickle=True)

        self.n_spikes = []
        self.ids_matched = []
        self.ptp_gt_unit = []
        self.purity = []
        self.completeness = []
        self.matched_spikes_array = []
        self.n_spikes_array = []
        self.venn_array = []
        self.unit_ids = []
        #print (" # chunks: ", len(res))

        for k in range(len(res)):
            self.n_spikes.extend(res[k][0])
            self.ids_matched.extend(res[k][1])
            self.ptp_gt_unit.extend(res[k][2])
            self.purity.extend(res[k][3])
            self.completeness.extend(res[k][4])
            self.matched_spikes_array.extend(res[k][5])
            self.n_spikes_array.extend(res[k][6])
            self.venn_array.extend(res[k][7])
            self.unit_ids.extend(res[k][8])


    def make_pie_charts(self, n_matches2):
        
        print ("n matches: ", n_matches2)
        fig =plt.figure()
        colors = ['blue','red','green','magenta',
                  'cyan','yellow','pink','orange',
                  'brown','darkgreen']

        ptps = self.templates_gt.ptp(1).max(1)

        #n_spikes_array.
        for k in range(len(self.venn_array)):
            ax=plt.subplot(10,10,k+1)

            n_matches = len(self.venn_array[k])
            sizes = np.sort(self.venn_array[k])[::-1][:n_matches2]

            clrs = colors[:sizes.shape[0]]
            print ("sizes:" , sizes, " , clrs: ", clrs)
            # if the total number of spikes found is smaler than all spikes injected, add black piechart
            if sizes.sum(0)<self.n_spikes_array[k]:
                sizes = np.append(sizes, np.array(self.n_spikes_array[k]-sizes.sum(0)))
                clrs = np.append(clrs,'black')

            plt.pie(sizes, colors=clrs)
            plt.title(str(k)+" ptp: " +str(np.round(ptps[k],1)), fontsize=8)

        plt.suptitle("Drift simulation of injected neurons drift: ..."+
                     " sorted neurons (blue=best match, red=second best, green=third...)"+
                     "\n(sum matches can be > 100% depending on oversplits/duplicate units)",fontsize=12)
        plt.show()

# plot single unit max channel template and scatter plot
def plot_single_unit(root_dir, selected_unit,
                     matcher, n_matches2, clrs,
                     scale_amplitude):
    
    spike_train_gt = np.load(root_dir + 'ground_truth/spike_train_ground_truth.npy')
    templates_sorted = np.load(root_dir + 'templates_reloaded_good.npy')
    max_chans_sorted = templates_sorted.ptp(1).argmax(1)

    # fix n-chans and sample rate for computations below
    n_chans = templates_sorted.shape[2]
    sample_rate = 30000
    
    fname_int16 = root_dir + 'data_int16.bin'
    data_type = 'int16'

    # find units that match current unit
    matched_units = np.array(matcher.ids_matched[selected_unit][1])
    print ("matched_units: ", matched_units)

    # count # of spikes in each matching unit;
    n_spk = []
    for k in range(len(matcher.matched_spikes_array[selected_unit])):
        n_spk.append(len(matcher.matched_spikes_array[selected_unit][k]))
    
    # select top 5 matching units by size and plot them;
    n_spk = np.hstack(n_spk)
    idx = np.argsort(n_spk)[::-1][:n_matches2]

    wfs=[]
    ptps_sorted = []
    times_sorted = []
    for id_ in idx:
        print (" unit: ", selected_unit,", matching unit: ", id_)
        spk = np.array(matcher.matched_spikes_array[selected_unit][id_])
        idx = np.argsort(spk)
        spk=spk[idx]

        temp1 = binary_reader_waveforms(fname_int16, matcher.n_channels, matcher.n_times, spk, data_type)
            #def binary_reader_waveforms(filename, n_channels, n_times, spikes, data_type='float32'):

        wfs.append(temp1.copy())

        if temp1.shape[0]==0:
            ptps_sorted.append([])
            times_sorted.append([])
            continue
        
        times_sorted.append(spk)

        # compute ptps
        temp2 = temp1.mean(0)
        max_chan = temp2.ptp(0).argmax(0)

        # select fixed points of waveform max/min at which to compute PTP
        max_ = np.argmax(temp2[:,max_chan])
        min_ = np.argmin(temp2[:,max_chan])
        ptps_local =np.array(temp1[:,max_,max_chan]-
                             temp1[:,min_,max_chan])
        
        ptps_local = ptps_local/scale_amplitude
        ptps_sorted.append(ptps_local)

        print ("")

    # ****************** GROUND TRUTH UNIT COMPUTATION ***************
    # load spikes for ground truth unit:
    print (" loading injected unit: ", selected_unit)
    idx = np.where(spike_train_gt[:,1]==selected_unit)[0]
    spk_gt = spike_train_gt[idx,0]
    tot_spikes_gt = spk_gt.shape[0]
    #idx = np.argsort(spk_gt)
    #spk_gt=spk_gt[idx]
    
    wfs_gt = binary_reader_waveforms(fname_int16, matcher.n_channels, matcher.n_times, spk_gt, data_type)
    print (" ground truth wfs: ", wfs_gt.shape)    

    # compute ptps for the ground truth unit
    # select only first 1 minute of data to find peak/trough to limit drift artifacts
    idxt = np.where(spk_gt<sample_rate*60)[0]    
    temp2 = wfs_gt[idxt].mean(0)
    max_chan = temp2.ptp(0).argmax(0)
    max_ = np.argmax(temp2[:,max_chan])
    min_ = np.argmin(temp2[:,max_chan])
    
    # then compute ptp for all loaded data;
    ptps_gt =np.array(wfs_gt[:,max_,max_chan]-
                         wfs_gt[:,min_,max_chan])
    ptps_gt = ptps_gt/scale_amplitude

    # *************** PLOT RESULTS ************
    fig = plt.figure()
    gs = gridspec.GridSpec(n_matches2+1, 2)
    #gs.update(wspace=0.05, hspace=0.05)
        
    # plot time elapse unit
    cmap = cm.get_cmap('viridis',wfs_gt.shape[0])
    ax = plt.subplot(gs[0, 0])

    #for k in range(0, wfs_gt.shape[0],10):        
    for k in range(0, wfs_gt.shape[0],1):        
        temp = wfs_gt[k,:,max_chan].T/10.
        plt.plot(temp, c=cmap(k),alpha=.05)

    #spk_redone = np.array(spk_redone) 
    #ptps_redone = np.array(ptps_redone)
    ax.set_title("Max chan template (color=time)",fontsize=14)
    plt.ylabel("Template (SU)", fontsize=14)
    
    # plot scatter of ground truth unit
    ax = plt.subplot(gs[0, 1])
    #print (spk_redone.shape)
    #plt.scatter(spk_gt/30000., ptps_gt/10.,s=150, c='black', alpha=.5)
    plt.scatter(spk_gt/30000., ptps_gt,s=150, c='black', alpha=.5)
    plt.xticks([])
    plt.ylim(bottom=0)
    plt.title("PTP of ground truth injected unit", fontsize=14)

    # plot scatter of sorted units:
    for k in range(len(ptps_sorted)):
        ax = plt.subplot(gs[k+1, 1])
        if len(times_sorted[k])>0:
            plt.scatter(times_sorted[k]/30000., ptps_sorted[k],s=150, c=clrs[k], alpha=.5)
        plt.ylim(bottom=0)
        
        size = np.sort(matcher.venn_array[selected_unit])[::-1][k]

        tot_spikes = matcher.n_spikes_array[selected_unit]
        tp_rate = size/float(tot_spikes)   
        
        plt.title("Match #"+str(k)+ ", TP: "+str(round(tp_rate*100.,1))+"%", fontsize=14)
        if k <(len(ptps_sorted)-1):
            plt.xticks([])

    
#     #n_matches = len(matcher.venn_array[selected_unit])
#     sizes = np.sort(matcher.venn_array[selected_unit])[::-1][:n_matches2]

#     tot_spikes = matcher.n_spikes_array[selected_unit]
#     tp_rate = sizes/float(tot_spikes)
    
#     plt.title("PTP of best "+str(n_matches2)+" matching neurons, TP: "+str(np.round(tp_rate,3)*100)+"%", fontsize=14)
    plt.xlabel("Time (sec)", fontsize=14)
    
    temp = os.path.getsize(fname_int16)
    rec_len_sec = temp//2//n_channels//sample_rate
    print ("rec len sec: ", rec_len_sec)
    plt.suptitle("Injected unit: "+str(selected_unit)+ " with firing rate: "+\
                 str(round(tot_spikes_gt/float(rec_len_sec),1))+"Hz",fontsize=14)
    plt.show()

In [206]:
# dataset aprameters:
root_dir = ('/media/cat/1TB/Dropbox/data_temp/liam/data/neuropixels/run8/')
n_channels = 384
n_times = 101

# convert KS2 spike trains to 2 column standard and reorder spikes by sequential index
spike_train_sorted = load_ks2_spikes(root_dir, n_channels, n_times)
#print ("Spike train: \n", spike_train_sorted)

# remake templates from data
print ("Remaking templates ")
data_type='int16'
templates_sorted = reload_ks2_templates(root_dir, spike_train_sorted, data_type)
print ("templates sorted: ", templates_sorted.shape)


Remaking templates 
templates sorted:  (73, 101, 384)


In [207]:
# match to ground truth data;
matcher = Match_to_ground_truth(root_dir, spike_train_sorted, templates_sorted)
print ("n_spikes: ", matcher.n_spikes)
matcher.n_times = n_times
matcher.n_channels = n_channels

 ground truth templates:  (100, 101, 384)
n_spikes:  [388, 162, 136, 512, 349, 313, 316, 305, 131, 227, 58, 347, 391, 370, 204, 299, 368, 273, 316, 500, 499, 363, 306, 183, 145, 518, 127, 374, 267, 107, 372, 162, 317, 335, 558, 164, 314, 163, 297, 233, 103, 573, 366, 569, 140, 80, 469, 172, 491, 202, 327, 168, 585, 392, 159, 195, 86, 608, 372, 127, 318, 524, 411, 474, 447, 93, 620, 232, 82, 464, 360, 481, 142, 114, 510, 435, 263, 100, 145, 185, 535, 496, 380, 455, 572, 96, 546, 212, 75, 552, 616, 102, 393, 604, 496, 528, 111, 127, 250, 548]


In [208]:
# Make Pie Charts from results;
n_matches=3
matcher.make_pie_charts(n_matches)

n matches:  3
sizes: [388   9   6]  , clrs:  ['blue', 'red', 'green']
sizes: [162   5   2]  , clrs:  ['blue', 'red', 'green']
sizes: [136   2   1]  , clrs:  ['blue', 'red', 'green']
sizes: [507   9   8]  , clrs:  ['blue', 'red', 'green']
sizes: [347   6   6]  , clrs:  ['blue', 'red', 'green']
sizes: [311   5   5]  , clrs:  ['blue', 'red', 'green']
sizes: [315   5   3]  , clrs:  ['blue', 'red', 'green']
sizes: [303   3   3]  , clrs:  ['blue', 'red', 'green']
sizes: [131   2   2]  , clrs:  ['blue', 'red', 'green']
sizes: [226   6   5]  , clrs:  ['blue', 'red', 'green']
sizes: [1 0 0]  , clrs:  ['blue', 'red', 'green']
sizes: [345   4   3]  , clrs:  ['blue', 'red', 'green']
sizes: [387   7   7]  , clrs:  ['blue', 'red', 'green']
sizes: [364   9   8]  , clrs:  ['blue', 'red', 'green']
sizes: [203   4   3]  , clrs:  ['blue', 'red', 'green']
sizes: [290 184 114]  , clrs:  ['blue', 'red', 'green']
sizes: [338   5   4]  , clrs:  ['blue', 'red', 'green']
sizes: [271   7   5]  , clrs:  ['blue', 

In [194]:
scale_amplitude = 10.
selected_unit = 90 #15 #47 #39 #21
# plot ptp of spikes from different units:

n_matches = 3
plot_single_unit(root_dir, selected_unit, matcher, n_matches, clrs, scale_amplitude)

matched_units:  [16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 34 35 71]
 unit:  90 , matching unit:  9

 unit:  90 , matching unit:  13

 unit:  90 , matching unit:  4

 loading injected unit:  90
 ground truth wfs:  (616, 101, 384)
rec len sec:  60


(162, 101, 384)
(162,)
1 [162   5   2   2   2   2   2   1]
Tp rate:  1.0


In [15]:
print (len(ptp_gt_unit), len(purity),len(completeness))
print (ptp_gt_unit)
print (purity)
ax=plt.subplot(121)
plt.scatter(ptp_gt_unit, purity, c='blue')
plt.title("True Positive Rate \n(# spikes found / # spikes injected)", fontsize=15)
plt.xlabel("PTP of unit", fontsize=20)
plt.xlim(left=0)
plt.ylim(0,1.05)
ax.tick_params(axis='both', which='major', labelsize=20)

# plot purity
ax=plt.subplot(122)
ax.tick_params(axis='both', which='major', labelsize=20)
plt.scatter(ptp_gt_unit,completeness, c='red')
plt.title("Purity \n(# spikes found / # spikes in unit)", fontsize=15)
plt.xlabel("PTP of injected unit", fontsize=20)
plt.xlim(left=0)
plt.ylim(0,1.05)
plt.suptitle("Neuropixels Drift Simulation run3\n shifts: "+
             str(np.array([-0.1,  -0.05,  0.,    0.05])*40)+" um"+ " @ "+
             str(np.array([0.,  6000000., 12000000., 18000000.])/30000)+ " sec"+
            "\n10mins recording, 384 channels"+
            "\n Sorted using KS2")
plt.show()

49 49 49
[57.585655, 49.991535, 48.35278, 44.536102, 41.26409, 40.65353, 37.95051, 35.745094, 35.424194, 35.34573, 34.917114, 34.795994, 33.814716, 31.321787, 31.216097, 31.076012, 30.922415, 30.80615, 30.603268, 28.782059, 28.524315, 28.427055, 28.194874, 27.681364, 27.405285, 27.243063, 27.127403, 27.033249, 26.95942, 26.792904, 26.654663, 26.588776, 26.568539, 26.40509, 26.169048, 25.969437, 25.380676, 25.221592, 25.078396, 24.911406, 24.725485, 24.550064, 24.500408, 24.251526, 23.799553, 23.74379, 23.687141, 23.245668, 23.198347]
[0.6089331927537621, 0.5922794431633785, 0.9848484848484849, 0.574947084856648, 0.9981648169186402, 0.5800302571860817, 0.9991024490319271, 0.005543633762517883, 0.9980595084087969, 0.5879856349983676, 0.9989862757330006, 0.9976643041237113, 0.639399293286219, 0.517592387696393, 0.997495183044316, 0.8070210038457746, 0.5308740268931351, 0.9986566362170876, 0.6766889724081034, 0.6419079089145895, 0.9983930902882394, 0.5576145552560647, 0.8261060393258427, 0

In [13]:
t1 = np.load('/media/cat/1TB/data/synthetic/run3/ground_truth/templates_ground_truth.npy')
print (t1.shape)
t2 = np.load('/media/cat/1TB/data/synthetic/run3/templates_reloaded_good.npy')
print (t2.shape)

max1 = t1[2].ptp(0).argmax(0)
max2 = t2[50].ptp(0).argmax(0)
print (max1, max2)

ax=plt.subplot(121)
plt.plot(t1[2],c='black')
ax=plt.subplot(122)
plt.plot(t2[50]/10.,c='blue')
plt.show()


(49, 101, 384)
(104, 101, 384)
249 249


In [32]:
d1 = max_chans_sorted
d2 = max_chans_unit
t1, t2, t3 = np.intersect1d(d1,d2,return_indices=True)
print (t1,t2,t3)

t1 = np.where(np.in1d(d1, d2))[0]
print (t1)



[249 251] [46 47] [0 1]
[0 1]


In [67]:
spike_train = np.load('/media/cat/1TB/data/synthetic/run3/ground_truth/spike_train_ground_truth.npy')
print (spike_train)
templates = np.load('/media/cat/1TB/data/synthetic/run3/ground_truth/templates_ground_truth.npy')
print (templates.shape)

units = np.load('/media/cat/1TB/data/synthetic/run3/ground_truth/units.npy')
print (units)

ctr=0
spike_train_corrected = spike_train.copy()
for unit in units:
    idx = np.where(spike_train[:,1]==unit)[0]
    spike_train_corrected[idx,1]=ctr
    ctr+=1
    
print (spike_train_corrected)
np.save('/media/cat/1TB/data/synthetic/run3/ground_truth/spike_train_ground_truth_corrected.npy',spike_train_corrected)

[[      35      420]
 [    6119      420]
 [   12204      420]
 ...
 [17983395      247]
 [17990820      247]
 [17998245      247]]
(49, 101, 384)
[420 339 421 348 473 418  62 362 565 239 498 474 324 298 471 441 417 443
 356 424  63 437 485 160 156 302 346 334 431 486  66 204  14 258 300  16
 150  52 317 277 184 456  65 481  38 436 558 194 247]
[[      35        0]
 [    6119        0]
 [   12204        0]
 ...
 [17983395       48]
 [17990820       48]
 [17998245       48]]


In [21]:
units=np.arange(5)
units = [good_units[3]]

for unit in units:
    fig = plt.figure()
    ax = plt.subplot(131)
    #plt.plot(templates[unit].T)
    plt.plot(wfs_array[unit].mean(0))

    ax = plt.subplot(132)
    #spikes = spikes[idx]
    spikes = times[unit]
    print ("spikes:", spikes.shape, "ptps: ", ptps[unit].shape)
    plt.scatter(spikes/30000., ptps[unit])

    ax=plt.subplot(133)
    max_chan = wfs_array[unit].mean(0).ptp(0).argmax(0)#templates[unit].ptp(0).argmax(0)
    print (wfs_array[unit].shape)

    cmap = cm.get_cmap('viridis',wfs_array[unit].shape[0])
    clrs = cmap(np.arange(wfs_array[unit].shape[0]))
    print (clrs)
    for k in range(0,wfs_array[unit].shape[0],10):
        plt.plot(wfs_array[unit][k].T,c=cmap(k))


plt.show()



spikes: (8210,) ptps:  (8210,)
(8210, 101)
[[0.267004 0.004874 0.329415 1.      ]
 [0.267004 0.004874 0.329415 1.      ]
 [0.267004 0.004874 0.329415 1.      ]
 ...
 [0.993248 0.906157 0.143936 1.      ]
 [0.993248 0.906157 0.143936 1.      ]
 [0.993248 0.906157 0.143936 1.      ]]


In [9]:
fname = '/home/cat/Downloads/yass/samples/10chan/phy/whitening_mat.npy'

n_channels = 10

whitening_mat = np.random.rand(n_channels,n_channels)

np.save(fname, whitening_mat)
print (whitening_mat.shape)
print (whitening_mat)

(10, 10)
[[0.61659694 0.20125714 0.48515898 0.45059603 0.89409843 0.8158124
  0.30200264 0.06808628 0.60955651 0.7760584 ]
 [0.19814057 0.21134196 0.27791917 0.52312241 0.20263871 0.21097603
  0.34816364 0.1852993  0.90506002 0.80562087]
 [0.970196   0.70466193 0.08746122 0.33371635 0.44791585 0.00678562
  0.41421277 0.65443008 0.8541354  0.41522301]
 [0.88052786 0.34072116 0.12298287 0.23288985 0.98943396 0.56546028
  0.38736253 0.87851743 0.83882004 0.96018176]
 [0.5622632  0.19456103 0.89831444 0.56587043 0.12599613 0.24817026
  0.47188959 0.55809954 0.72628178 0.72821972]
 [0.59236306 0.06332327 0.35662913 0.93425548 0.15622572 0.77164442
  0.62013196 0.67339586 0.97380542 0.9211003 ]
 [0.55011567 0.22242238 0.41892959 0.23396425 0.80036351 0.00612786
  0.72244072 0.1119026  0.84247935 0.13666388]
 [0.1674969  0.17014984 0.70297569 0.28534628 0.87704159 0.62894665
  0.38268201 0.47207143 0.53495952 0.46683818]
 [0.58624526 0.67612232 0.60880662 0.66169478 0.10866143 0.88197916
  0.

In [24]:
max_chans = templates.ptp(1).argmax(1)
ptps = templates.ptp(1).max(1)
ptps_all = templates.ptp(1)
print (ptps_all.shape)

print (max_chans)
print (ptps)


(426, 364)
[  0   1   1   1   2   3   4   6   7   5  10  13  14  14  15  16  17  18
  20  21  30  26  29  29  30  31  32  35  37  41  42  43  45  46  49  50
  38  56  58  60  69  68  68  68  72  70  71  74  75  76 342  80  79  80
  80  83  82  84 333  85  86  87  88  88  88  88  89  89  90  88  87  92
 336  94  95  98  99  84 101 103 104 108 109 109 111 111 111 112 114 115
 119 119 120 121 125 127 127 124 131 133 134 134 132 140 137 138 140 332
 140 140 141 142 142 140 143 145 146 147 147 149 149 149 149 150 284 149
 154 154 154 162 164 165 166 164 169 170 173 173 174 177 178 173 180 181
 182 186 182 185 185 186 186 189 190 192 194 197 198 199 201 204 204 206
 208 208 209 209 213 209 210 212 205 209 213 213 213 213 216 216 216 217
 219 220 226 228 229 229 229 230 230 233 235 215 236 236 238 239 239 242
 245 245 248 248 248 248 249 249 252 252 253 253 262 262 262 262 262 262
 262 262 262 262 262 262 252 265 265 265 266 266 257 268 268 253 273 273
 274 277 277 277 280 262 280 282 280 284

In [33]:
# plot scatter vs. depth
fig=plt.figure()
ax=plt.subplot(111)
for k in range(426):
    idx = np.where(spike_clusters==k)[0]
    times = spike_times[idx]/30000.
    plt.scatter(times, times*0+geom[max_chans[k],1]+(np.random.rand()*80-40),s=1, alpha=.9)
    
ax.tick_params(axis = 'both', which = 'major', labelsize = 20)
plt.ylim(3800,-100)
plt.xlim(-10,times[-1]+10)
plt.ylabel("Depth (um)",fontsize=20)
plt.xlabel("Time (sec)",fontsize=20)

plt.show()

