In [14]:
# to run GUI event loop
%autosave 180
%matplotlib qt
import matplotlib.pyplot as plt
plt.ion()

import numpy as np
#import nidaqmx
#from nidaqmx import stream_writers
import time
from tqdm import trange
#from pypylon import pylon
#from imageio import get_writer
from tqdm import trange
import os
from tqdm import tqdm

import scipy
from scipy.signal import butter, filtfilt


Autosaving every 180 seconds


In [266]:
#######################################################
#######################################################
#######################################################

#
def load_spikes(times, 
                data,
                window=30):
    """Cut a fixed-length snippet of `data` around every spike time.

    Parameters
    ----------
    times : array-like of int
        Sample indices of the spikes (indices into the first axis of `data`).
    data : array-like
        Recording; the first axis is indexed by sample number. Any trailing
        axes are carried over unchanged into the result.
    window : int
        Number of samples taken before and after each spike time, i.e. each
        snippet is ``data[t-window : t+window]`` and is ``2*window`` long.

    Returns
    -------
    np.ndarray
        Array of shape ``(n_kept, 2*window) + data.shape[1:]``.

    Notes
    -----
    Spikes whose window would extend past either end of the recording are
    dropped, so ``n_kept`` can be smaller than ``len(times)``. Without this
    a negative start index wraps around (giving an empty slice) and a stop
    index past the end gives a short slice; either way the snippets no longer
    share a shape and cannot be stacked into one array.
    """
    data = np.asarray(data)
    spike_times = np.asarray(times, dtype=np.int64).ravel()

    # keep only spikes with a complete window inside the recording
    inside = (spike_times >= window) & (spike_times + window <= data.shape[0])
    spike_times = spike_times[inside]

    # (n_kept, 2*window) matrix of sample indices, one row per spike
    offsets = np.arange(-window, window)
    sample_idx = spike_times[:, None] + offsets[None, :]

    # a single fancy-index pulls every snippet at once; an empty `times`
    # naturally yields an array with zero rows and the right trailing shape
    return data[sample_idx]

#
def get_templates(clusters, 
                  data,
                  times,
                  window=30,
                  scaling=None):
    """Compute the mean waveform (template) of every cluster.

    Parameters
    ----------
    clusters : np.ndarray of int
        Cluster id of every spike; same length as `times`.
    data : np.ndarray
        Recording, indexed as ``data[sample, channel]``.
    times : np.ndarray of int
        Sample index of every spike.
    window : int
        Samples taken before and after each spike (template length is
        ``2*window``). Defaults to the previously hard-coded value of 30.
    scaling : float or None
        Factor the raw samples are multiplied by. When None, the
        notebook-level variable `scaling` is used, which is what this
        function always did before the keyword existed.

    Returns
    -------
    templates : np.ndarray, shape (max_id+1, 2*window, n_channels)
        Scaled mean waveform per cluster id; rows of ids that do not occur
        (or have no usable spikes) stay NaN.
    max_chans : np.ndarray of float, shape (max_id+1,)
        Channel with the largest peak-to-peak template amplitude, NaN where
        no template exists.
    n_spikes : np.ndarray of float, shape (max_id+1,)
        Number of snippets averaged per cluster id, NaN for ids that do not
        occur.

    Raises
    ------
    NameError
        If `scaling` is None and no notebook-level `scaling` is defined.
    """
    if scaling is None:
        # legacy behaviour: fall back on the notebook-level constant
        if 'scaling' not in globals():
            raise NameError("get_templates: pass `scaling=` or define a "
                            "notebook-level `scaling` before calling")
        scaling = globals()['scaling']

    #
    cell_ids = np.unique(clusters)
    n_chans = data.shape[1]
    n_units = int(np.max(cell_ids)) + 1 if cell_ids.size else 0

    # NaN marks "no data" for ids that never occur in `clusters`
    max_chans = np.full(n_units, np.nan)
    templates = np.full((n_units, window*2, n_chans), np.nan)
    n_spikes = np.full(n_units, np.nan)

    for cell_id in tqdm(cell_ids):

        idx = np.where(clusters==cell_id)[0]

        #
        spikes = load_spikes(times[idx],
                             data,
                             window)

        n_spikes[cell_id] = spikes.shape[0]
        if spikes.shape[0] == 0:
            # nothing to average; leave template / max channel as NaN
            continue

        # average first, scale afterwards: same template, but avoids making
        # a scaled floating-point copy of every single snippet
        template = np.mean(spikes, axis=0) * scaling

        # channel on which the template has the largest peak-to-peak swing
        ptp = np.ptp(template, axis=0)
        max_chans[cell_id] = np.argmax(ptp)
        templates[cell_id] = template

    return templates, max_chans, n_spikes

#
def fix_clusters(clusters):
    """Relabel cluster ids as consecutive integers starting at 0.

    The smallest id present becomes 0, the next one 1, and so on (order of
    `np.unique`). The input array is left untouched; a relabelled copy with
    the same dtype is returned.
    """
    relabelled = clusters.copy()

    for new_id, old_id in enumerate(np.unique(clusters)):
        # always match against the untouched input so that a freshly
        # written label can never be mistaken for an old id
        members = np.nonzero(clusters == old_id)[0]
        relabelled[members] = new_id

    return relabelled
    
def butter_bandpass(lowcut, highcut, fs, order=5):
    """Design a Butterworth band-pass filter.

    `lowcut` and `highcut` are the band edges and `fs` the sampling rate, all
    in the same unit; the edges are normalised by the Nyquist frequency
    (fs/2) before being handed to `scipy.signal.butter`.

    Returns the filter coefficients ``(b, a)``.
    """
    nyquist = 0.5 * fs
    normalized_band = [lowcut / nyquist, highcut / nyquist]
    return butter(order, normalized_band, btype='band')


def butter_bandpass_filter(data, lowcut, highcut, fs, order=5):
    """Band-pass filter `data` with a Butterworth filter.

    The coefficients come from `butter_bandpass`; they are applied with
    `scipy.signal.filtfilt` (forward and backward pass) along the default
    (last) axis of `data`. Returns the filtered array.
    """
    numerator, denominator = butter_bandpass(lowcut, highcut, fs, order=order)
    return filtfilt(numerator, denominator, data)


#
def plot_single_templates(templates,
                          unit_id,
                          max_chans):
    """Plot the template of one unit on all 64 channels in an 8x8 grid.

    Channel k is drawn at column ``k % 8`` and row ``k // 8`` of the grid.
    A new figure is opened, each channel position is marked with a dot and
    labelled with its channel number, and the unit's waveform on that channel
    is drawn starting at the dot.

    `templates` is indexed as ``templates[unit, sample, channel]``.
    `max_chans` is accepted but not used by this function.
    """
    n_chans = 64
    x_pitch = 2      # horizontal distance between grid columns
    y_pitch = 250    # vertical distance between grid rows

    # grid position of every channel, as floats like the waveforms
    chan_idx = np.arange(n_chans)
    locs = np.column_stack(((chan_idx % 8) * x_pitch,
                            (chan_idx // 8) * y_pitch)).astype(float)

    plt.figure()
    plt.scatter(locs[:,0], locs[:,1], c='black', s=10)

    # sample index / 60 is used as the horizontal extent of a waveform
    t = np.arange(templates.shape[1])/60
    template = templates[unit_id]

    for chan, (x0, y0) in enumerate(locs):
        # waveform shifted to the channel's grid position
        plt.plot(t + x0,
                 template[:,chan] + y0,
                 c='black')

        # channel number next to its marker
        plt.text(x0, y0, str(chan))

    # axes carry no physical meaning here, so hide the ticks
    plt.xticks([])
    plt.yticks([])
    plt.show()

#
def get_tetrode_geometry():
    """Return the hand-curated grouping of channels into tetrodes.

    Returns
    -------
    tetrodes_ids : list of list of int
        Channel numbers of each tetrode group (some groups are incomplete).
    tetrodes_no_id : np.ndarray of int
        Channels in 0..63 that are not part of any group, in ascending order.
    """
    tetrodes_ids = [
        [10,12,14],
        [17,19,21,23],
        [25,27,29,31],  # 23 is nearby and 21l -- CONFIRMED (meaning of "21l" unclear in original note)
        [32,34,36,38],  # unclear re: 32, 34

        [33,35,37,39],
        [48,50,52,54],
        [49,51],
        [53,55,57,59],

        [1,3,5,7],
        [16,18,20,22],
    ]

    # every channel that has been assigned to some group
    assigned = [chan for group in tetrodes_ids for chan in group]

    # whatever is left of the 64 channels has no tetrode id
    all_chans = np.arange(64)
    unassigned = np.ones(all_chans.shape, dtype=bool)
    unassigned[assigned] = False

    return tetrodes_ids, all_chans[unassigned]

#
def make_tetrode_locs(tetrodes_ids):
    """Build plotting coordinates that place each tetrode's channels together.

    Groups are laid out four per row; within a group the channels fill a
    two-column block in the order they are listed. Channels that belong to no
    group keep the coordinate (0, 0).

    Parameters
    ----------
    tetrodes_ids : list of list of int
        Channel numbers per tetrode group.

    Returns
    -------
    np.ndarray, shape (64, 2)
        x / y plotting position of every channel.
    """
    x_scale = 1.1
    y_scale = 125

    locs = np.zeros((64,2))

    for group_idx, tetrode_group in enumerate(tetrodes_ids):
        # origin of this group's block: 4 groups per row, 3 grid units apart
        origin_x = (group_idx % 4) * 3
        origin_y = (group_idx // 4) * 3

        for slot, chan in enumerate(tetrode_group):
            # slot 0,1 -> first row of the block, slot 2,3 -> second row
            locs[chan, 0] = (origin_x + slot % 2) * x_scale
            locs[chan, 1] = (origin_y + slot // 2) * y_scale

    return locs

#
def plot_single_templates_grouped_tetrodes(templates,
                                           tetrodes_ids, 
                                           locs,
                                           unit_id,
                                           max_chans,
                                           tetrode_clr_ctr,
                                          legend_flag=False):
    """Draw one unit's template on the channels of the tetrode it peaks on.

    The unit's peak channel (largest peak-to-peak amplitude of its template)
    is recomputed here; every tetrode group containing that channel gets the
    unit's waveform drawn on all of its channels, into the current figure.

    Parameters
    ----------
    templates : np.ndarray
        Indexed as ``templates[unit, sample, channel]``.
    tetrodes_ids : list of list of int
        Channel numbers per tetrode group.
    locs : np.ndarray
        x / y plotting position per channel.
    unit_id : int
        Unit to draw; also used as the legend label.
    max_chans :
        Accepted but not used by this function.
    tetrode_clr_ctr : np.ndarray of int
        Per-group counter selecting the next colour. It is incremented in
        place, so successive units on the same tetrode get different colours.
    legend_flag : bool
        When True, `plt.legend()` is called after drawing.
    """
    palette = ['black','blue','red','green',
               'cyan','magenta','pink','brown',
               'yellow','darkblue']

    # sample index / 60 is used as the horizontal extent of a waveform
    time_axis = np.arange(templates.shape[1])/60
    waveform = templates[unit_id]

    # channel with the largest peak-to-peak amplitude
    peak_chan = np.argmax(np.ptp(waveform,axis=0))

    for group_idx, tetrode_group in enumerate(tetrodes_ids):

        # only draw on the tetrode(s) that contain the peak channel
        if peak_chan not in tetrode_group:
            continue

        # next colour for this tetrode, then advance its counter
        colour = palette[tetrode_clr_ctr[group_idx]%10]
        tetrode_clr_ctr[group_idx]+=1

        for position, chan in enumerate(tetrode_group):

            # label only the first channel so the unit shows up once
            # in the legend
            line_kwargs = {'c': colour}
            if position==0:
                line_kwargs['label'] = str(unit_id)

            plt.plot(time_axis+locs[chan,0],
                     waveform[:,chan]+locs[chan,1],
                     **line_kwargs)

            # channel number at the channel's position
            plt.text(locs[chan,0],
                     locs[chan,1],
                     str(chan))

    #
    plt.xticks([])
    plt.yticks([])

    if legend_flag:
        plt.legend()
    
    
#
def plot_all_templates(templates,
                   max_chans):
    """Plot every unit's peak-channel waveform on an 8x8 channel grid.

    Channel k sits at column ``k % 8`` and row ``k // 8``. A new figure is
    opened, every channel position is marked with a dot, and for each unit
    the waveform on its peak channel is drawn starting at that channel's dot.

    Parameters
    ----------
    templates : np.ndarray
        Indexed as ``templates[unit, sample, channel]``.
    max_chans : array-like
        Peak channel per unit. Entries may be NaN (``get_templates`` leaves
        NaN for ids without spikes); such units are skipped, because
        ``int(NaN)`` would otherwise raise a ValueError.
    """
    locs = np.zeros((64,2))
    scalex = 2      # horizontal distance between grid columns
    scaley = 250    # vertical distance between grid rows
    for k in range(64):
        locs[k] = [(k%8)*scalex, (k//8)*scaley]

    plt.figure()

    plt.scatter(locs[:,0],
                locs[:,1],
                c='black',
               s=10)

    # sample index / 60 is used as the horizontal extent of a waveform
    t = np.arange(templates.shape[1])/60
    for k in range(templates.shape[0]):

        # units without a template have no peak channel: nothing to draw
        if not np.isfinite(max_chans[k]):
            continue

        max_chan = int(max_chans[k])
        temp = templates[k,:,max_chan]

        #
        plt.plot(t+locs[max_chan,0],
                 temp+locs[max_chan,1])

    #
    plt.xticks([])
    plt.yticks([])
    plt.show()
    


In [16]:
#########################################################
#########################################################
#########################################################
root_dir = '/media/cat/4TB1/donato/DON-011737/2022_11_08/'
fname_data = os.path.join(root_dir,
                          'continuous.dat')

# raw recording: flat int16 file reshaped to (n_samples, 64)
data = np.fromfile(fname_data,
                  dtype=np.int16).reshape(-1,64)

print (data.shape)

# Band-pass filter every column once and cache the result on disk.
# np.save() appends '.npy' to any file name that lacks it, so the cache
# name has to carry the extension explicitly; otherwise the existence
# check below never finds the file and the filtering is redone on every
# run. This is the name np.save() was already producing, so caches written
# by earlier runs are picked up.
fname_out = fname_data.replace('.dat','_filtered.dat.npy')

if os.path.exists(fname_out):
    data = np.load(fname_out)
else:
    lowcut=500
    highcut=6000
    fs = 30000
    order = 2

    for k in trange(data.shape[1]):
        # NOTE: filtfilt returns floats; writing them back into the int16
        # array casts them to int16 again
        data[:,k] = butter_bandpass_filter(data[:,k], 
                                           lowcut,
                                           highcut,
                                           fs,
                                           order)
    np.save(fname_out, data)


(19537920, 64)


100%|██████████| 64/64 [00:35<00:00,  1.78it/s]


In [20]:
#############################################
#############################################
#############################################

# factor applied to the raw int16 samples when building templates
# (read by get_templates) -- presumably the ADC step in microvolts, TODO confirm
scaling = 0.195

# spike sorting output: cluster id and sample index of every spike
clusters = np.load(os.path.join(root_dir, 'spike_clusters.npy')).squeeze()
times = np.load(os.path.join(root_dir, 'spike_times.npy')).squeeze().astype('int32')

# relabel cluster ids to 0..n-1 so they can index the template arrays
clusters = fix_clusters(clusters)

#
templates, max_chans, n_spikes = get_templates(clusters, 
                                               data,
                                               times)

# sort cells by max channel
idx = np.argsort(max_chans)

#####################################################
#####################################################
#####################################################
# 30 samples per msec, assuming the 30 kHz rate used for filtering
t=np.arange(templates.shape[1])/30.

# size the subplot grid from the number of units instead of assuming 70
n_cols = 10
n_rows = max(1, int(np.ceil(len(idx) / n_cols)))

plt.figure()
for ctr, cell_id in enumerate(tqdm(idx)):
    plt.subplot(n_rows, n_cols, ctr+1)

    # all channels of this unit overlaid
    plt.plot(t, templates[cell_id], alpha=.25, c='black')
    plt.title("id: " + str(cell_id)+", ch: "+str(max_chans[cell_id])+ ", #spk: " + str(n_spikes[cell_id]), fontsize=8)

    # only the lowest panel of each column keeps its time axis
    if ctr < len(idx) - n_cols:
        plt.xticks([])
    else:
        plt.xlabel("Time (msec)")

#
plt.suptitle("10minute recording \n"+fname_data)
plt.show()


100%|██████████| 70/70 [00:05<00:00, 12.44it/s]
100%|██████████| 70/70 [00:03<00:00, 19.53it/s]


In [23]:
#############################################
#############################################
#############################################

# overview figure: every unit's waveform drawn at its peak channel
plot_all_templates(templates, max_chans)

# title goes on whatever figure is current after the call above
plt.suptitle("Cell max chan locations + shape")
plt.show()

Text(0.5, 0.98, 'Cell max chan locations + shape')

In [267]:
############################################################
############ VIEW ALL CELLS ON ALL TETRODES ################
############################################################


# channel grouping and the matching plotting coordinates
tetrodes_ids, tetrodes_no_ids = get_tetrode_geometry()
locs = make_tetrode_locs(tetrodes_ids)

# every unit that has a row in `templates` (was hard-coded to 70, which
# breaks or silently drops units whenever the sorting result changes)
unit_ids = np.arange(templates.shape[0])

# one colour counter per tetrode group; incremented in place by the
# plotting function so units on the same tetrode get different colours
tetrode_clr_ctr = np.zeros(len(tetrodes_ids),'int32')
plt.figure(figsize=(10,10))

#
for unit_id in unit_ids:

    #
    plot_single_templates_grouped_tetrodes(templates,
                                           tetrodes_ids, 
                                           locs,
                                           unit_id, 
                                           max_chans,
                                           tetrode_clr_ctr)

#
plt.suptitle("Cell max chan locations + shape")
plt.show()

In [270]:
############################################################
########## SAME THING AS ABOVE BUT PER TETRODE #############
############################################################

# peak channel of every unit, recomputed from the templates as integers
# NOTE: this overwrites the float `max_chans` returned by get_templates
ptps = np.ptp(templates, axis=1)
max_chans = np.argmax(ptps,axis=1)

# start from fresh colour counters: the plotting function increments them
# in place, so reusing the array left over from an earlier cell would make
# the colours depend on which cells ran before (and how often)
tetrode_clr_ctr = np.zeros(len(tetrodes_ids),'int32')

# find cells that are on the same tetrode
legend_flag = True
show_figures = False   # True: show interactively, False: save png and close
for tetrode in tetrodes_ids:
    
    plt.figure()
    #
    for k in range(templates.shape[0]):
        
        if max_chans[k] in tetrode:
    
            plot_single_templates_grouped_tetrodes(templates,
                                               tetrodes_ids, 
                                               locs,
                                               k, 
                                               max_chans,
                                               tetrode_clr_ctr,
                                               legend_flag)
    
    plt.title("Tetrode: "+str(tetrode))
    if show_figures:
        plt.show()
    else:
        plt.savefig('/home/cat/tetrode_'+str(tetrode)+'.png')
        # close saved figures so they do not pile up in memory
        plt.close()
