In [7]:
import matplotlib
%matplotlib tk
%autosave 180
%load_ext autoreload
%autoreload 2

import nest_asyncio
%config Completer.use_jedi = False

#
import matplotlib.pyplot as plt
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

# 
import numpy as np
import os
import scipy

# add root directory to be able to import packages
# todo: make all packages installable so they can be called/imported by environment
import sys

#
from tqdm import tqdm, trange

#
import opexebo

import astropy

from utils import process_ephys


Autosaving every 180 seconds
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [174]:
#############################################################
################### DEFAULT LOADING OF DATA #################
#############################################################


#fname_locs = ('/media/cat/4TB1/donato/nathalie/DON-007050/FS1/DON-007050_20211022_'+
#         'TR-BSL_FS1-ACQDLC_resnet50_open_arena_white_floorSep8shuffle1_200000_locs.npy')

root_dir = '/media/cat/4TB1/donato/DON-011737/2022_11_14/'

#
fname_locs = root_dir+'/Test_011737_Test1_60minDLC_resnet50_OF darkbox white floorMar21shuffle1_350000.npy'
fname_locs_times = root_dir + '/011737_Test1_60min_frame_times.npy'
fname_shift = root_dir+'/events/MessageCenter/timestamps.npy'

#
#fname_traces = '/media/cat/4TB1/donato/nathalie/DON-007050/FS1/binarized_traces.npz'
fname_spikes = root_dir + '/yass/spike_train.npy'

#
#arena_x = [300,1550]
#arena_y = [175,1400]
arena_size = [80,80]
arena_shape = 'square'
bin_width = 2.5

#
locs, locs_binned, locs_times, spikes = process_ephys(fname_locs,
                                                      fname_spikes,
                                                      fname_locs_times,
                                                      fname_shift,
                                                      arena_size,
                                                      bin_width)

# compute spatial occpuancy map; only requires the locatoins and the time at which the mouse was at some location
occ_map, coverage, bin_edges = op.analysis.spatial_occupancy(locs_times, 
                                                             locs.T, 
                                                             arena_size, 
                                                             bin_width=bin_width)

#
plt.figure()
plt.imshow(occ_map)
plt.show()



locs:  (138617, 2)
locs tims:  (138617,)
locs times:  [0.00000000e+00 5.98552227e-02 8.77804756e-02 ... 3.59992525e+03
 3.59994918e+03 3.59997511e+03]
distances:  (138616,)
vel_all:  [ 0.93317371  2.03108485  1.07346283 ... 29.88362156 12.8680298
 26.38388726]
loaded spikes:  [[3.0000000e+00 7.0000000e+00]
 [4.0000000e+00 2.9000000e+01]
 [5.0000000e+00 1.0500000e+02]
 ...
 [1.0956083e+08 9.9000000e+01]
 [1.0956083e+08 9.8000000e+01]
 [1.0956085e+08 1.0000000e+01]]
time shift (sec):  46.754133333333336
spikes pre:  (10223352, 2)
spikes post:  (10099500, 2)


In [199]:
###################################################################################################
###################################################################################################
###################################################################################################

def get_spikes(locs,
               locs_times,
               spikes_all, 
               cell_id):
               
    #
    idx = np.where(spikes_all[:,1]==cell_id)[0]
    
    #
    spikes = spikes_all[idx,0]
    idx = np.argsort(spikes)
    spikes=spikes[idx]
    
    #
    print ("cell: ", cell_id, ", spike times: ", spikes.shape)
    
    # make the spikes_tracking vstack;  not completely clear this is correct; X->Y inverstion etc? not clear
    min_time = 0.0333
    min_spikes = 3

    #
    spikes_tracking = []
    ctr=0
    for k in trange(locs_times.shape[0]):
        temp = locs_times[k]
        idx1 = np.where(np.logical_and(spikes>(temp-min_time), spikes<=temp))[0]

        #
        if idx1.shape[0]>= min_spikes:
            ctr+=1
            spikes_tracking.append([k,
                                    locs[k,0], 
                                    locs[k,1]])
        
        # delete spikes before
        if idx1.shape[0]>0:
            spikes = spikes[idx1[0]:]

    #
    spikes_tracking = np.vstack(spikes_tracking)
        
                       
    return spikes_tracking.T

#
def get_field(occ_map,
              cell_spikes,
              arena_size):
    
    #
    try:
        rm = opexebo.analysis.rate_map(occ_map, 
                                   cell_spikes, 
                                   arena_size)
    except:
        rm = np.zeros((32,32))

    #
    sigma = 1
    rms = op.general.smooth(rm, sigma)

    #
    init_thresh = 1   
    min_bins = 2
    min_mean = rms.max()*0.1               # 
    min_peak = rms.max()*0.1            # 

    #
    fields, fields_map = op.analysis.place_field(rms, 
                                                 min_bins = min_bins,
                                                 min_peak = min_peak,
                                                 min_mean = min_mean,
                                                 init_thresh = init_thresh,
                                                 search_method='sep',
                                                 debug=False
                                              #    limits=limits
                                                )
    #
    return rms, fields_map
    
#
def compute_place_field(cell_id,
                        locs,
                        locs_times,
                        spikes,
                        occ_map,
                        arena_size,
                       ):
    
    #
    cell_spikes = get_spikes(
                             #locs_binned,
                             locs,
                             locs_times,
                             spikes,
                             cell_id)
    
    #
    n_spikes = cell_spikes.shape[1]
    print ("cell spikes: ", cell_spikes.shape)
    
    # find mid-point of time:
    mid_pt = locs_times.shape[0]//2
    mid_pt_idx = np.argmin(np.abs(mid_pt-cell_spikes[0]))
    print ("mid pt: ", mid_pt)
    print ("mid_pt_idx: ", mid_pt_idx)
    
    #
    plt.figure(figsize=(10,10))

    
    #
    rms, fields_map = get_field(occ_map,
                                cell_spikes,
                                arena_size)

    #
    ax=plt.subplot(3,3,1)
    plt.imshow(rms)
    plt.ylabel("rate map (smoothed)")
    
    #
    ax=plt.subplot(3,3,4)
    plt.imshow(rms/occ_map)
    plt.ylabel("rate map/occ_map")
    

    #
    ax=plt.subplot(3,3,7)
    plt.imshow(fields_map)
    plt.ylabel("fields")

    
    ########################################
    rms, fields_map = get_field(occ_map,
                                cell_spikes[:,:mid_pt_idx//2],
                                arena_size)

    #
    ax=plt.subplot(3,3,2)
    plt.imshow(rms)
    
    #
    ax=plt.subplot(3,3,5)
    plt.imshow(rms/occ_map)
    
    #
    ax=plt.subplot(3,3,8)
    plt.imshow(fields_map)

    #
    rms, fields_map = get_field(occ_map,
                                cell_spikes[:,mid_pt_idx//2:],
                                arena_size)
    ax=plt.subplot(3,3,3)
    plt.imshow(rms)
    
    #
    ax=plt.subplot(3,3,6)
    plt.imshow(rms/occ_map)
    
    
    ax=plt.subplot(3,3,9)
    plt.imshow(fields_map)
    plt.suptitle("Cell: "+str(cell_id) + ", # spks: "+str(n_spikes))

    if False:
        plt.show()
    else:
        plt.savefig('/media/cat/4TB1/temp/place_field_'+str(cell_id)+'.png')
        plt.close()

####################################################################
#               
cell_ids = np.arange(np.max(spikes[:,1]))

cell_ids = [14]
cell_ids = np.arange(10,20,1)

for cell_id in cell_ids:
    compute_place_field(cell_id,
                        locs,
                        locs_times,
                        spikes,
                        occ_map,
                        arena_size)


cell:  10 , spike times:  (177069,)


100%|█████████████████████████████████████████████████████████████████████████████████████████| 94125/94125 [00:07<00:00, 12283.12it/s]


cell spikes:  (3, 23518)
mid pt:  47062
mid_pt_idx:  11822
cell:  11 , spike times:  (61498,)


100%|█████████████████████████████████████████████████████████████████████████████████████████| 94125/94125 [00:03<00:00, 26004.63it/s]


cell spikes:  (3, 3432)
mid pt:  47062
mid_pt_idx:  2009
cell:  12 , spike times:  (304772,)


100%|██████████████████████████████████████████████████████████████████████████████████████████| 94125/94125 [00:11<00:00, 8107.77it/s]


cell spikes:  (3, 49903)
mid pt:  47062
mid_pt_idx:  25891
cell:  13 , spike times:  (82886,)


100%|█████████████████████████████████████████████████████████████████████████████████████████| 94125/94125 [00:04<00:00, 21735.99it/s]


cell spikes:  (3, 5922)
mid pt:  47062
mid_pt_idx:  2917
cell:  14 , spike times:  (22087,)


100%|█████████████████████████████████████████████████████████████████████████████████████████| 94125/94125 [00:02<00:00, 39143.82it/s]


cell spikes:  (3, 296)
mid pt:  47062
mid_pt_idx:  87
cell:  15 , spike times:  (5613,)


100%|█████████████████████████████████████████████████████████████████████████████████████████| 94125/94125 [00:01<00:00, 58065.14it/s]


cell spikes:  (3, 243)
mid pt:  47062
mid_pt_idx:  109
cell:  16 , spike times:  (33782,)


100%|█████████████████████████████████████████████████████████████████████████████████████████| 94125/94125 [00:02<00:00, 32309.59it/s]


cell spikes:  (3, 1955)
mid pt:  47062
mid_pt_idx:  679
cell:  17 , spike times:  (18637,)


100%|█████████████████████████████████████████████████████████████████████████████████████████| 94125/94125 [00:02<00:00, 39003.97it/s]


cell spikes:  (3, 302)
mid pt:  47062
mid_pt_idx:  49
cell:  18 , spike times:  (19146,)


100%|█████████████████████████████████████████████████████████████████████████████████████████| 94125/94125 [00:02<00:00, 38214.04it/s]


cell spikes:  (3, 515)
mid pt:  47062
mid_pt_idx:  169
cell:  19 , spike times:  (4315,)


100%|█████████████████████████████████████████████████████████████████████████████████████████| 94125/94125 [00:01<00:00, 56917.48it/s]


cell spikes:  (3, 127)
mid pt:  47062
mid_pt_idx:  43


In [166]:
plt.figure()

print (cell_spikes.shape)
x = cell_spikes[1]
y = cell_spikes[2]
print (x.shape)

plt.scatter(x,
            y
           )

plt.show()

(3, 7120)
(7120,)


In [167]:
print (cell_spikes[0])

[  184.   185.   186. ... 94049. 94094. 94095.]


In [6]:
#############################################################
################### DEFAULT LOADING OF DATA #################
#############################################################


scale_flag = False
scale_value = 1
sigma = 1.5

#
x_edges = bin_edges[0]
y_edges = bin_edges[1]

#
plt.figure(figsize=(15,15))
#cell_ids = np.arange(0,100,1)
cell_ids = np.arange(100)
#cell_ids = [27]
ctr=0
for cell_id in tqdm(cell_ids): 
    
   
    #
    rm, rms, fields_map, occ_map = get_rms_and_place_field_from_tunning_map(cell_id,
                                                                           upphases,
                                                                           filtered_Fs,
                                                                           locs,
                                                                           occ_map,
                                                                           arena_size,
                                                                           sigma,
                                                                           scale_flag,
                                                                           scale_value
                                                                           #limits
                                                                          )


    # #
    #ax=plt.subplot(1,2,ctr+1)
    #plt.imshow(rms/occ_map)
    #plt.title("rms/occ_map")

    #
    if cell_ids.shape[0]==100:
        ax=plt.subplot(10,10,ctr+1)
    elif cell_ids.shape[0]==10:
        ax=plt.subplot(2,5,ctr+1)
    else:
        ax=plt.subplot(1,1,1)
        
    #
    img1 = rms#/(occ_map+0.00001)
    if np.max(fields_map)!=0:
        img2 = fields_map/np.max(fields_map)*np.nanmax(img1)
    else:
        img2 = fields_map
    #
    img = np.vstack((img1, 
                     np.zeros(32)+np.nan,
                     img2))
    #img[32]=np.nan

    plt.imshow(img,
              vmin=np.min(img1),
              vmax=np.max(img1)
              )
    plt.xticks([])
    plt.yticks([])
    plt.title(str(cell_id),fontsize=10,pad=0.9)
    
    #
    res = op.analysis.rate_map_stats(rms, 
                                 occ_map, 
                                 debug=False)

    #
    coh = op.analysis.rate_map_coherence(rm)

    #
    text = "SI_rate: "+ str(round(res['spatial_information_rate'],2)) + \
        "  SI_cont: "+ str(round(res['spatial_information_content'],2)) + \
        "  Sparse: "+ str(round(res['sparsity'],2)) + ' \n '+ \
        "Select: "+  str(round(res['selectivity'],2))+ \
        "  Peak_r: "+  str(round(res['peak_rate'],2))+ \
        "  Mean_r: "+  str(round(res['mean_rate'],2))+ \
        "  Coh: "+str(round(coh,2))


    ax.set_ylabel(text, labelpad=.3, fontsize=3)
    
    ctr+=1

    #


plt.suptitle("cell "+str(cell_id)+ "\n"+str(res)+"\ncoherence "+str(coh),fontsize=10)

if False:
    plt.show()
else:
    plt.savefig("/home/cat/fields.svg")
    plt.close()


100%|██████████| 100/100 [00:27<00:00,  3.61it/s]


In [39]:
#print (res)
text = "SI_rate: "+ str(round(res['spatial_information_rate'],2)) + \
        "SI_cont: "+ str(round(res['spatial_information_content'],2)) + \
        "Sparse: "+ str(round(res['sparsity'],2)) + ' \n '+ \
        "Select: "+  str(round(res['selectivity'],2))+ \
        "Peak_r: "+  str(round(res['peak_rate'],2))+ \
        "Mean_r: "+  str(round(res['mean_rate'],2))

print (text)


['SI_rate: 2.65SI_cont: 1.14Sparse: 0.45 \n Select: 4.59Peak_r: 10.64Mean_r: 0.57']


In [9]:
###############################################################
############# SPATIAL SELETIVITY VS. SELECTIVITY ##############
###############################################################
scale_flag = False
scale_value = 1
sigma = 1.5

cell_ids = np.arange(upphases.shape[0])
#cell_ids = np.arange(103,200,1)
spc_array = []
sel_array = []
for cell_id in tqdm(cell_ids): 
    #print (cell_id)
    
    _, rms, _, _ = get_rms_and_place_field(cell_id,
                                           upphases.copy(),
                                           filtered_Fs.copy(),
                                           occ_map.copy(),
                                           arena_size.copy(),
                                           sigma,
                                           scale_flag,
                                           scale_value
                                           #limits
                                          )
    #
    res = op.analysis.rate_map_stats(rms.copy(), 
                                     occ_map.copy(), 
                                     debug=False)
    
   # print (res)
    spc = res['spatial_information_content']
    sel = res['selectivity']
    if np.isnan(spc) or np.isnan(sel):
        print ("ceel Id: ", spc, sel)

    spc_array.append(spc)
    sel_array.append(sel)
    
plt.figure()

plt.scatter(spc_array, sel_array)
plt.xlabel("special_information_content")
plt.ylabel("selectivity")

plt.show()



 22%|██▏       | 103/476 [00:09<00:29, 12.64it/s]

ceel Id:  nan nan


 37%|███▋      | 177/476 [00:16<00:27, 11.02it/s]


KeyboardInterrupt: 

In [3]:
import numpy as np
data = np.load('/media/cat/4TB1/ephys/spike_times.npy')
print (data.shape)

d2 = np.load('/media/cat/4TB1/ephys/spike_clusters.npy')
print (d2.shape)

(10223352,)
(10223352,)


In [22]:
d = np.load('/media/cat/4TB1/donato/DON-011737/2022_11_14/011737_Test1_60min_frame_times.npy')
print (d)

plt.figure()
plt.plot(d[1:]-d[:-1])

plt.show()

[1.6684163e+09 1.6684163e+09 1.6684163e+09 ... 1.6684199e+09 1.6684199e+09
 1.6684199e+09]


In [None]:
root_dir = '/media/cat/4TB1/donato/DON-011737/2022_11_18/events/'

d1 = np.load(root_dir+'/MessageCenter/sample_numbers.npy')
print (d1)

d2 = np.load(root_dir+'/MessageCenter/text.npy')
print (d2)

d3 = np.load(root_dir+'/MessageCenter/timestamps.npy')
print (d3)

d4 = np.load(root_dir +'/Acquisition_Board-100.Rhythm Data/TTL/full_words.npy')
print (d4)

d5 = np.load(root_dir +'/Acquisition_Board-100.Rhythm Data/TTL/sample_numbers.npy')
print (d5)

d6 = np.load(root_dir+'/Acquisition_Board-100.Rhythm Data/TTL/states.npy')
print (d6)

d7 = np.load(root_dir+'/Acquisition_Board-100.Rhythm Data/TTL/timestamps.npy')
print (d7)