In [1]:
# to run GUI event loop
%autosave 180
%matplotlib qt
import matplotlib.pyplot as plt
plt.ion()

import numpy as np
#import nidaqmx
#from nidaqmx import stream_writers
import time
from tqdm import trange
#from pypylon import pylon
#from imageio import get_writer
from tqdm import trange
import os
from tqdm import tqdm

import scipy
from scipy.signal import butter, filtfilt
import scipy.signal


Autosaving every 180 seconds


In [2]:
#######################################################
#######################################################
#######################################################

#
def load_spikes(times, 
                data,
               window=30):

    spikes = []
    for time in times:
        #print (time)
        temp = data[time-window:time+window]
        if temp.shape[0]<(window*2):
            continue
        
        spikes.append(temp)
        
    return np.array(spikes)

#
def get_templates(clusters, 
                  data,
                  times):

    #
    window = 30    
    
    #
    cell_ids = np.unique(clusters)
    max_chans = np.zeros(np.max(cell_ids)+1)+np.nan
    templates = np.zeros((np.max(cell_ids)+1,window*2,64))+np.nan
    n_spikes = np.zeros(np.max(cell_ids)+1)+np.nan
    for cell_id in tqdm(cell_ids):

        idx = np.where(clusters==cell_id)[0]

        #
        spikes = load_spikes(times[idx],  
                             data,
                             window)


        # scale spikes
        spikes = spikes*scaling

        #
        #print (cell_id, spikes.shape)

        #
        template = np.mean(spikes,axis=0)
        #print (cell_id, "template: ", template.shape)

        #
        ptp = np.ptp(template, axis = 0)

        max_chan = np.argmax(ptp)

        #
        max_chans[cell_id] = max_chan
        templates[cell_id] = template
        n_spikes[cell_id] = spikes.shape[0]
        
    return templates, max_chans, n_spikes

#
def fix_clusters(clusters):
    
    cell_ids = np.unique(clusters)
    clusters2 = clusters.copy()
    ctr=0
    for cell_id in cell_ids:
        idx = np.where(clusters==cell_id)[0]
        clusters2[idx]=ctr
        ctr+=1
        
    return clusters2
    
def butter_bandpass(lowcut, highcut, fs, order=5):
    nyq = 0.5 * fs
    low = lowcut / nyq
    high = highcut / nyq
    b, a = butter(order, [low, high], btype='band')
    return b, a


def butter_bandpass_filter(data, lowcut, highcut, fs, order=5):
    b, a = butter_bandpass(lowcut, highcut, fs, order=order)
    y = filtfilt(b, a, data)
    return y

def notch_filter(data, notch_freq,
                 samp_freq=30000):

# Create/view notch filter
#samp_freq = 1000  # Sample frequency (Hz)
#notch_freq = 60.0  # Frequency to be removed from signal (Hz)
    quality_factor = 30.0  # Quality factor
    b_notch, a_notch = scipy.signal.iirnotch(notch_freq, quality_factor, samp_freq)
    freq, h = scipy.signal.freqz(b_notch, a_notch, fs = samp_freq)
    y_notched = scipy.signal.filtfilt(b_notch, a_notch, data)
    
    return y_notched
    
    


#
def plot_single_templates(templates,
                          unit_id,
                          max_chans):
    
    locs = np.zeros((64,2))
    scalex = 2
    scaley = 250
    for k in range(64):
        tempx = (k%8)*scalex
        tempy = (k//8)*scaley
        locs[k] = [tempx, tempy]

    plt.figure()

    plt.scatter(locs[:,0],
                locs[:,1],
                c='black',
               s=10)
    
    #
    
    #
    t = np.arange(templates.shape[1])/60
    template = templates[unit_id]

    for chan in range(locs.shape[0]):

        temp = template[:,chan]  

        #
        plt.plot(t+locs[chan,0],
                 temp+locs[chan,1],
                 c='black')
        
        #
        plt.text(locs[chan,0],
                 locs[chan,1],
                 str(chan))
        
    #
    plt.xticks([])
    plt.yticks([])
    plt.show()    


#
def make_tetrode_locs(tetrodes_ids):
    
        #
    locs = np.zeros((64,2))
    scalex = 1.1
    scaley = 125
    
    #
    centrex = 0
    centrey = 0
    ctr=0
    for tetrode_group in tetrodes_ids:
        
        #
        for q,k in enumerate(tetrode_group):
            locs[k,0] = (centrex + q%2)*scalex
            locs[k,1] = (centrey + q//2)*scaley
        
        #
        
        #
        ctr+=1
        centrex = (ctr%4)*3
        centrey = (ctr//4)*3 #*500
        

    return locs

#
def plot_single_templates_grouped_tetrodes(templates,
                                           tetrodes_ids, 
                                           locs,
                                           unit_id,
                                           max_chans,
                                           tetrode_clr_ctr,
                                          legend_flag=False):

    clrs = ['black','blue','red','green',
            'cyan','magenta','pink','brown',
           'yellow','darkblue']

    #
    t = np.arange(templates.shape[1])/60
    template = templates[unit_id]

    ptp = np.ptp(template,axis=0)
    max_chan = np.argmax(ptp)
    
    # find tetrode group belonging to
#    for chan in range(locs.shape[0]):

    for ctr_tr, tetrode_group in enumerate(tetrodes_ids):
        
        if max_chan in tetrode_group:
            clr = clrs[tetrode_clr_ctr[ctr_tr]%10]
            tetrode_clr_ctr[ctr_tr]+=1
            #    
            for ctr_k, k in enumerate(tetrode_group):
                #print ("plotting: ", k, locs[k])
                temp = template[:,k]  

                # label only first tetrode
                # TODO: can plot all tetrodes at once
                if ctr_k==0:
                    plt.plot(t+locs[k,0],
                         temp+locs[k,1],
                         c=clr,
                         label=str(unit_id)
                        )
                else:
                    plt.plot(t+locs[k,0],
                         temp+locs[k,1],
                         c=clr
                        )

                #
                plt.text(locs[k,0],
                         locs[k,1],
                         str(k))
        
    #
    plt.xticks([])
    plt.yticks([])
    
    if legend_flag:
        plt.legend()
    
    
#
def plot_all_templates(templates,
                   max_chans):
    
    
    locs = np.zeros((64,2))
    scalex = 2
    scaley = 250
    for k in range(64):
        tempx = (k%8)*scalex
        tempy = (k//8)*scaley
        locs[k] = [tempx, tempy]

    plt.figure()

    plt.scatter(locs[:,0],
                locs[:,1],
                c='black',
               s=10)
    
    #
    
    #
    t = np.arange(templates.shape[1])/60
    for k in range(templates.shape[0]):
        #
        max_chan = int(max_chans[k])
        temp = templates[k,:,max_chan]  
        
        #
        plt.plot(t+locs[max_chan,0],
                 temp+locs[max_chan,1])
        
    #
    plt.xticks([])
    plt.yticks([])
    plt.show()    
    
#
def make_rasters(clusters, times):
    
    #
    sample_rate = 30000
    
    # make rasters with 10ms precision
    rasters = np.zeros((np.max(times)//30, n_clusters), dtype='bool')
    print ("rasters: ", rasters.shape)
    
    #
    for c in np.unique(clusters):
        
        idx = np.where(clusters==c)[0]
        t = times[idx]//30

        for k in range(1):
            try:
                rasters[t+k,c]=1
            except:
                pass

    return rasters.T

In [45]:
#########################################################
#########################################################
#########################################################
root_dir = '//media/cat/4TB1/donato/DON-011737/2022_11_09/2022-11-09_13-35-49/Record Node 110/experiment1/recording1/continuous/Acquisition_Board-100.Rhythm Data/'
#root_dir = '/media/cat/4TB1/donato/pup_Test1/2022-11-16_11-30-09/Record Node 110/experiment1/recording1/continuous/Acquisition_Board-100.Rhythm Data/'
root_dir = '/media/cat/4TB1/donato/pup_Test1/2022-11-16_13-41-57/Record Node 110/experiment2/recording1/continuous/Acquisition_Board-100.Rhythm Data/'
fname_data = os.path.join(root_dir, 'continuous.dat')

#
data = np.fromfile(fname_data,
                  dtype=np.int16).reshape(-1,64)

print (data.shape)

# save filtered version
fname_out = fname_data.replace('.dat','_filtered.npy')
print (fname_out)


#
if os.path.exists(fname_out)==False:

    lowcut=500
    highcut=6000
    fs = 30000
    order = 2

    data_filtered = data.copy()
    for k in trange(data.shape[1], desc='filtering'):
        data_filtered[:,k] = butter_bandpass_filter(data[:,k], 
                                           lowcut,
                                           highcut,
                                           fs,
                                           order)
    
    np.save(fname_out, data_filtered)
else:
    data_filtered = np.load(fname_out)

    

(6921216, 64)
/media/cat/4TB1/donato/pup_Test1/2022-11-16_13-41-57/Record Node 110/experiment2/recording1/continuous/Acquisition_Board-100.Rhythm Data/continuous_filtered.npy


In [54]:
#############################################
############# PLOT RASTERS ##################
#############################################

plt.figure()
t = np.arange(0,data.shape[0],1)
#t = np.arange(2000000,3000000,1)
chans = np.arange(19,19+4*4,1)
#chans = np.arange(0,18,1)
#chans = np.arange(64)

#mean = np.median(data[t][:,chans],1)

#print (m#ean.shape)
subsample = 1
for k in chans:
    
    temp = data[t,k]
    mean = np.median(temp)
    temp = temp-mean
    
    temp = notch_filter(temp, 50)
    temp = notch_filter(temp, 60)
    temp = notch_filter(temp, 120)
    
    plt.plot(t[::subsample]/30000, 
             temp[::subsample]+k*30000)
plt.xlabel("Time (sec)")
plt.show()
    
    

In [20]:
#############################################
############# PLOT RASTERS ##################
#############################################

#
scaling = 0.195
sample_rate = 30000

#
clusters = np.load(os.path.join(root_dir, 'spike_clusters.npy')).squeeze()
times = np.load(os.path.join(root_dir, 'spike_times.npy')).squeeze().astype('int32')


#
clusters = fix_clusters(clusters)
n_clusters = np.unique(clusters).shape[0]
print ("# of clusters: ", n_clusters)
       




rasters = make_rasters(clusters, times)
# 
plt.figure()

plt.imshow(rasters,
           cmap='gray_r',
           aspect='auto',
           extent = [0,rasters.shape[1]/1000,0,rasters.shape[0]],
           interpolation='none')
    
#
plt.xlabel("Time (sec)")
plt.ylabel("Neuron ID")
plt.suptitle(fname_data, fontsize=12)
plt.show()


# of clusters:  53
rasters:  (230705, 53)


In [21]:
#############################################
############# PLOT TEMPLATES ################
#############################################

#
scaling = 0.195

#
clusters = np.load(os.path.join(root_dir, 'spike_clusters.npy')).squeeze()
times = np.load(os.path.join(root_dir, 'spike_times.npy')).squeeze().astype('int32')

#
clusters = fix_clusters(clusters)
print ("# of clusters: ", np.unique(clusters).shape[0])

#
templates, max_chans, n_spikes = get_templates(clusters, 
                                               data,
                                               times)

# sort cells by max channel
idx = np.argsort(max_chans)

#####################################################
#####################################################
#####################################################
t=np.arange(templates.shape[1])/30.
plt.figure()
ctr=0
for cell_id in tqdm(idx):
    ax = plt.subplot(7,10,ctr+1)
    
    #
    plt.plot(t, templates[cell_id], alpha=.25, c='black')
    plt.title("id: " + str(cell_id)+", ch: "+str(max_chans[cell_id])+ ", #spk: " + str(n_spikes[cell_id]), fontsize=8)

    #
    if ctr<34:
        plt.xticks([])
    else:
        plt.xlabel("Time (msec)")
    ctr+=1
#
plt.suptitle("10minute recording \n"+fname_data)
plt.show()


# of clusters:  53


100%|██████████| 53/53 [00:01<00:00, 47.72it/s]
100%|██████████| 53/53 [00:02<00:00, 23.46it/s]


In [20]:
#############################################
#############################################
#############################################

#
plot_all_templates(templates,
               max_chans)

#
plt.suptitle("Cell max chan locations + shape")
plt.show()

In [26]:
############################################################
############ VIEW ALL CELLS ON ALL TETRODES ################
############################################################


#
tetrodes_ids, tetrodes_no_ids = get_tetrode_geometry()
locs = make_tetrode_locs(tetrodes_ids)

#
unit_ids = np.arange(63)

#
tetrode_clr_ctr = np.zeros(16,'int32')
plt.figure(figsize=(10,10))

#
for unit_id in unit_ids:

    #
    plot_single_templates_grouped_tetrodes(templates,
                                           tetrodes_ids, 
                                           locs,
                                           unit_id, 
                                           max_chans,
                                           tetrode_clr_ctr)

#
plt.suptitle("Cell max chan locations + shape")
plt.show()

In [27]:
############################################################
########## SAME THING AS ABOVE BUT PER TETRODE #############
############################################################

#
ptps = np.ptp(templates, axis=1)
max_chans = np.argmax(ptps,axis=1)

# find cells that are on the same tetrode
legend_flag = True
for tetrode in tetrodes_ids:
    
    plt.figure()
    #
    for k in range(templates.shape[0]):
        
        if max_chans[k] in tetrode:
    
            plot_single_templates_grouped_tetrodes(templates,
                                               tetrodes_ids, 
                                               locs,
                                               k, 
                                               max_chans,
                                               tetrode_clr_ctr,
                                               legend_flag)
    
    plt.title("Tetrode: "+str(tetrode))
    if False:
        plt.show()
    else:
        plt.savefig('/home/cat/tetrode_'+str(tetrode)+'.png')
        plt.close()


In [23]:
#
def get_tetrode_geometry():
    
    tetrodes_ids = [
    [10,12,14],
    [17,19,21,23],
    [25,27,29,31],  # 23 is nearby and 21l CNOFIRMED
    [32,34,36,38],  # unclear re: 32, 34

    [33,35,37,39],
    [48,50,52,54], 
    [49,51],
    [53,55,57,59],

    [1,3,5,7],
    [16,18,20,22],
    ]

    #
    #tetrodes_ids = np.array(tetrodes_ids)

    #

    ###########
    tetrodes_flat = sum(tetrodes_ids, []) #np.array(tetrodes_ids).flatten()


    tetrodes_no_id = np.delete(np.arange(64), 
                               tetrodes_flat)

    
    return tetrodes_ids, tetrodes_no_id


tetrodes_ids, _ = get_tetrode_geometry()

print (tetrodes_ids)

#np.save(os.path.join('/media/cat/4TB1/donato/DON-011737/',"tetrodes.npy"),
#           tetrodes_ids)
        

[[10, 12, 14], [17, 19, 21, 23], [25, 27, 29, 31], [32, 34, 36, 38], [33, 35, 37, 39], [48, 50, 52, 54], [49, 51], [53, 55, 57, 59], [1, 3, 5, 7], [16, 18, 20, 22]]


In [26]:
####################################################

def make_yass_geometry_from_file(layout, dist_inter = 100):
    
    all_chans = np.arange(64)
    
    #
    ctr = 0
    geom = []
    y = 0
    
    # put tetrodes together
    for tetrode in layout:
        
        x = 0
        for line in tetrode:
            x += 20
            geom.append([line, x,y])
            
            #
            idx = np.where(all_chans==line)[0]
            all_chans = np.delete(all_chans, idx)
            
        y+=150

    # add all other channels:
    for chan in all_chans:
        
        x = 0
        y+=150
        geom.append([chan, x,y])    

    #
    geom = np.vstack(geom)

    print ("geom: ", geom)
    
    # sort by chan ID
    idx = np.argsort(geom[:,0])
    print (idx)
        
    return geom[idx][:,1:]


def make_yass_geometry_staggered(layout, dist_inter = 100):
    
    odd_chans = np.arange(1,64,2)
    even_chans = np.arange(0,64,2)
    
    #
    ctr = 0
    geom = []
    
    # put tetrodes together
    y = 0
    x = 0
    for odd_chan in odd_chans:
        
        geom.append([odd_chan, x,y])
        x += 20

    # add all other channels:
    for even_chan in even_chans:
        
        geom.append([even_chan, x,y])    
        x += 20

    #
    geom = np.vstack(geom)

    #print ("geom: ", geom)
    
    # sort by chan ID
    idx = np.argsort(geom[:,0])
    print (idx)
        
    return geom[idx][:,1:]


layout = np.load('/media/cat/4TB1/donato/DON-011737/tetrodes.npy',allow_pickle=True)

geom = make_yass_geometry_staggered(layout)

np.savetxt('/media/cat/4TB1/donato/DON-011737/geom.txt',
          geom)

#
print (geom)


[32  0 33  1 34  2 35  3 36  4 37  5 38  6 39  7 40  8 41  9 42 10 43 11
 44 12 45 13 46 14 47 15 48 16 49 17 50 18 51 19 52 20 53 21 54 22 55 23
 56 24 57 25 58 26 59 27 60 28 61 29 62 30 63 31]
[[ 640    0]
 [   0    0]
 [ 660    0]
 [  20    0]
 [ 680    0]
 [  40    0]
 [ 700    0]
 [  60    0]
 [ 720    0]
 [  80    0]
 [ 740    0]
 [ 100    0]
 [ 760    0]
 [ 120    0]
 [ 780    0]
 [ 140    0]
 [ 800    0]
 [ 160    0]
 [ 820    0]
 [ 180    0]
 [ 840    0]
 [ 200    0]
 [ 860    0]
 [ 220    0]
 [ 880    0]
 [ 240    0]
 [ 900    0]
 [ 260    0]
 [ 920    0]
 [ 280    0]
 [ 940    0]
 [ 300    0]
 [ 960    0]
 [ 320    0]
 [ 980    0]
 [ 340    0]
 [1000    0]
 [ 360    0]
 [1020    0]
 [ 380    0]
 [1040    0]
 [ 400    0]
 [1060    0]
 [ 420    0]
 [1080    0]
 [ 440    0]
 [1100    0]
 [ 460    0]
 [1120    0]
 [ 480    0]
 [1140    0]
 [ 500    0]
 [1160    0]
 [ 520    0]
 [1180    0]
 [ 540    0]
 [1200    0]
 [ 560    0]
 [1220    0]
 [ 580    0]
 [1240    0]
 [ 600    0

In [31]:
temps = np.load('/media/cat/4TB1/donato/DON-011737/2022_10_27/yass/tmp/output/templates/templates_600sec.npy')

print (temps.shape)

plt.figure()
for k in range(77):
    
    ax=plt.subplot(8,10,k+1)
    
    plt.plot(temps[k], 
             c='black')
    
    plt.title(str(k))
    plt.xticks([])
    plt.yticks([])


plt.show()

(50, 151, 64)


IndexError: index 50 is out of bounds for axis 0 with size 50

In [30]:
data = np.load('/media/cat/4TB1/donato/DON-011737/2022_10_27/yass/tmp/output/spike_train.npy')
print (data.shape)

(999885, 2)
