In [1]:
import matplotlib
%matplotlib tk
%autosave 180
%load_ext autoreload
%autoreload 2

import nest_asyncio
%config Completer.use_jedi = False

#
import matplotlib.pyplot as plt
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

# 
import numpy as np
import os
import scipy

# `sys` is imported so the repository root can be added to sys.path if needed
# (no path is added in this cell at present).
# todo: make all packages installable so they can be called/imported by environment
import sys

#
from tqdm import tqdm, trange

#
import opexebo as op

import astropy


Autosaving every 180 seconds


  from IPython.core.display import display, HTML


In [3]:
#################################################
#################################################
#################################################

#
def load_locs_traces(fname,
                     arena_size,
                     fname_traces='/media/cat/4TB1/donato/nathalie/DON-007050/FS1/binarized_traces.npz'):
    '''Load tracked locations and binarized traces (no running/velocity filter).

    Parameters
    ----------
    fname : str
        Path to a ``.npy`` file with tracked positions; column 0 is read as x
        and column 1 as y (assumed shape (n_frames, 2) -- TODO confirm).
    arena_size : sequence of 2 numbers
        x is min-max rescaled to ``[0, arena_size[0]]`` and y to
        ``[0, arena_size[1]]``.
    fname_traces : str, optional
        ``.npz`` file holding the keys ``'F_upphase'`` and ``'F_filtered'``.
        Defaults to the path that used to be hard-coded inside this function.

    Returns
    -------
    locs : np.ndarray (float)
        Rescaled positions.
    upphases : np.ndarray
        ``F_upphase`` array from the traces file.
    times : np.ndarray
        ``np.arange(n_frames)``, one integer stamp per position sample.
    filtered_Fs : np.ndarray
        ``F_filtered`` array from the traces file.
    '''
    # Cast to float: the rescaling below writes fractional values back into
    # ``locs``; an integer array would silently truncate them.
    locs = np.load(fname)
    if not np.issubdtype(locs.dtype, np.floating):
        locs = locs.astype(float)
    print (locs.shape)

    times = np.arange(locs.shape[0])

    # Min-max rescale each axis into [0, arena_size[axis]].
    for axis in (0, 1):
        lo = np.min(locs[:, axis])
        span = np.max(locs[:, axis]) - lo
        if span > 0:
            locs[:, axis] = (locs[:, axis] - lo) / span * arena_size[axis]
        else:
            # A constant coordinate would otherwise give 0/0 = NaN everywhere.
            locs[:, axis] = 0.0

    # NOTE(review): allow_pickle=True lets a crafted .npz execute code; only
    # load trusted files. The context manager closes the underlying file.
    with np.load(fname_traces, allow_pickle=True) as traces:
        upphases = traces['F_upphase']
        filtered_Fs = traces['F_filtered']

    return locs, upphases, times, filtered_Fs

#
def get_t_exp_from_filtered_F(upphase,
                              filtered,
                              locs):
    '''Expand up-phase frames into pseudo-spikes weighted by the filtered [Ca].

    For every frame with ``upphase > 0`` the filtered value is truncated
    toward zero to an integer ``n``. If ``n > 0``, that frame contributes
    ``n`` pseudo-spikes, all placed at the frame's (x, y) location.

    Parameters
    ----------
    upphase : 1-D array
        Binarized up-phase trace of one cell; frames with values > 0 are used.
    filtered : 1-D array
        Filtered fluorescence of the same cell, aligned frame-by-frame.
    locs : 2-D array
        Per-frame positions; column 0 is x, column 1 is y.

    Returns
    -------
    tt, x, y : 1-D float arrays of equal length
        ``tt`` is all ones (one entry per pseudo-spike); ``x``/``y`` are the
        pseudo-spike positions. All three are empty when no frame qualifies.
        Previously ``np.hstack([])`` raised ValueError in that case.
    '''
    active = np.where(np.asarray(upphase) > 0)[0]

    print ("locs: ", np.shape(locs))

    amplitude = np.asarray(filtered, dtype=float)[active]
    # Non-finite samples carry no weight (int(nan) used to raise ValueError).
    amplitude = np.where(np.isfinite(amplitude), amplitude, 0.0)
    # np.trunc matches int(): truncation toward zero; non-positive counts drop out.
    counts = np.clip(np.trunc(amplitude), 0, None).astype(int)

    locs = np.asarray(locs)
    x = np.repeat(locs[active, 0].astype(float), counts)
    y = np.repeat(locs[active, 1].astype(float), counts)
    tt = np.ones(x.shape[0])

    return tt, x, y

#
def get_t_exp_from_gradient(upphase,
                            filtered,
                            locs,
                            gradient_scaling=20):
    '''Expand up-phase frames into pseudo-spikes weighted by the [Ca] rise.

    For every frame ``t`` with ``upphase[t] > 0`` the forward difference
    ``filtered[t+1] - filtered[t]`` is multiplied by ``gradient_scaling`` and
    truncated toward zero to ``n``. If ``n > 0``, the frame contributes ``n``
    pseudo-spikes at its (x, y) location.

    Parameters
    ----------
    upphase : 1-D array
        Binarized up-phase trace of one cell; frames with values > 0 are used.
    filtered : 1-D array
        Filtered fluorescence of the same cell, aligned frame-by-frame.
    locs : 2-D array
        Per-frame positions; column 0 is x, column 1 is y.
    gradient_scaling : float, optional
        Multiplier applied to the forward difference (default 20, the value
        previously hard-coded).

    Returns
    -------
    tt, x, y : 1-D float arrays of equal length
        ``tt`` is all ones; ``x``/``y`` are the pseudo-spike positions. All
        three are empty when no frame qualifies.
    '''
    active = np.where(np.asarray(upphase) > 0)[0]

    print ("locs: ", np.shape(locs))

    filtered = np.asarray(filtered)
    # F[t+1] - F[t]: this equals np.gradient(filtered[t:t+2])[0]. The final
    # frame has no successor; it gets weight 0 instead of crashing np.gradient.
    diffs = np.diff(filtered)
    diffs = np.append(diffs, diffs.dtype.type(0))

    weights = diffs[active] * gradient_scaling
    weights = np.where(np.isfinite(weights), weights, 0)
    counts = np.clip(np.trunc(weights), 0, None).astype(int)

    locs = np.asarray(locs)
    x = np.repeat(locs[active, 0].astype(float), counts)
    y = np.repeat(locs[active, 1].astype(float), counts)
    tt = np.ones(x.shape[0])

    return tt, x, y


#
def get_rms_and_place_field(cell_id,
                            upphases,
                            filtered_Fs,
                            locs,
                            occ_map,
                            arena_size,
                            sigma = 1.0,
                            scale_flag=False,
                            scale_value =1,
                            bin_edges=None):
    '''Rate map, smoothed rate map and place-field map for one cell (opexebo).

    Pseudo-spikes come from ``get_t_exp_from_filtered_F`` (up-phase frames
    weighted by the filtered trace). Their positions are fed to
    ``op.analysis.rate_map``. The map is smoothed with
    ``op.general.smooth(rm, sigma)`` and passed to ``op.analysis.place_field``
    (``search_method='sep'``, ``init_thresh=0.95``, ``min_bins=5``). opexebo
    defaults are used for ``min_peak`` / ``min_mean``.

    Parameters
    ----------
    cell_id : int
        Row of ``upphases`` / ``filtered_Fs`` to analyse.
    upphases, filtered_Fs : 2-D arrays
        Per-cell up-phase and filtered traces (rows = cells).
    locs : 2-D array
        Per-frame positions; column 0 = x, column 1 = y.
    occ_map : array
        Occupancy map, passed to ``op.analysis.rate_map``.
    arena_size : sequence
        Passed through to ``op.analysis.rate_map``.
    sigma : float
        Smoothing width for ``op.general.smooth``.
    scale_flag, scale_value :
        When ``scale_flag`` is true the smoothed map is multiplied by
        ``scale_value`` (default: no scaling, as before).
    bin_edges : optional
        Bin edges for ``op.analysis.rate_map``. If omitted, the notebook-level
        ``bin_edges`` variable is used (the previous implicit behaviour).

    Returns
    -------
    rm, rms, fields_map, occ_map
        Raw rate map, smoothed rate map, place-field map and the unchanged
        ``occ_map``. If the cell has no pseudo-spikes, the three maps are
        zeros shaped like ``occ_map``.
    '''
    upphase = upphases[cell_id]
    filtered = filtered_Fs[cell_id]

    t, x, y = get_t_exp_from_filtered_F(upphase,
                                        filtered,
                                        locs)

    if len(t) == 0:
        print (" no spiking found... cell: ", cell_id)
        empty = np.zeros(np.shape(occ_map))
        # Return the real occupancy map: callers that unpack it as occ_map
        # must not receive zeros.
        return empty, empty, empty, occ_map

    if bin_edges is None:
        # Backward compatibility: this function used to read the global implicitly.
        if 'bin_edges' not in globals():
            raise NameError("bin_edges was not passed and no notebook-level 'bin_edges' exists")
        bin_edges = globals()['bin_edges']

    # Rows are [t, x, y] per pseudo-spike, the layout handed to opexebo.
    # TODO(review): confirm opexebo's expected axis order for spikes_tracking.
    spikes_tracking = np.vstack((t, x, y))

    rm = op.analysis.rate_map(occ_map,
                              spikes_tracking,
                              bin_edges=bin_edges,
                              arena_size=arena_size)

    rms = op.general.smooth(rm, sigma)

    if scale_flag:
        rms = rms * scale_value

    init_thresh = 0.95
    min_bins = 5
    fields, fields_map = op.analysis.place_field(rms,
                                                 init_thresh=init_thresh,
                                                 min_bins=min_bins,
                                                 search_method='sep')

    return rm, rms, fields_map, occ_map
    
#
def load_locs_traces_running(fname_locs,
                             fname_traces,
                             arena_size,
                             n_frames_per_sec=20,
                             n_pixels_per_cm=15,   # not used
                             min_vel=4,            # speed threshold, see NOTE in docstring
                             plot_velocity=False):
    '''Load tracking and binarized traces, keeping only "running" frames.

    Steps:
      1. load ``F_upphase`` / ``F_filtered`` from ``fname_traces`` (.npz);
      2. load positions from ``fname_locs`` (.npy; column 0 = x, column 1 = y);
      3. speed of step i -> i+1 = Euclidean step length * ``n_frames_per_sec``,
         smoothed by a Savitzky-Golay filter (window ``n_frames_per_sec``,
         polynomial order 2);
      4. min-max rescale x / y to ``[0, arena_size[0]]`` / ``[0, arena_size[1]]``;
      5. crop the traces to the number of tracking frames;
      6. delete frame i from positions and traces wherever the smoothed speed
         of step i -> i+1 is below ``min_vel``. The last frame has no outgoing
         step and is therefore never deleted.

    NOTE(review): the speed is computed on the raw tracking coordinates
    (before rescaling), and ``n_pixels_per_cm`` is never applied. So
    ``min_vel`` is compared against raw tracking units per second, not
    cm/sec. Confirm the intended unit.

    Parameters
    ----------
    fname_locs, fname_traces : str
        Paths to the positions (.npy) and traces (.npz).
    arena_size : sequence of 2 numbers
        Target extent of the rescaled x / y coordinates.
    n_frames_per_sec : int
        Frame rate; also used as the Savitzky-Golay window length.
    n_pixels_per_cm : float
        Unused (kept for interface compatibility).
    min_vel : float
        Frames whose smoothed speed is below this value are removed.
    plot_velocity : bool
        If True, plot the smoothed speed, the threshold and the kept frames.

    Returns
    -------
    locs, upphases, times, filtered_Fs
        Positions and traces restricted to running frames. ``times`` is
        ``np.arange`` over the kept frames (re-indexed, not original frame numbers).

    Raises
    ------
    ValueError
        If the traces have fewer frames than the tracking data.
    '''
    from scipy.signal import savgol_filter

    ####################### LOAD SPIKES ########################
    # NOTE(review): allow_pickle=True lets a crafted .npz execute code; only
    # load trusted files. The context manager closes the file handle.
    with np.load(fname_traces, allow_pickle=True) as data:
        upphases = data['F_upphase']
        filtered_Fs = data['F_filtered']

    ####################### LOAD LOCATIONS ###################
    locs = np.load(fname_locs)
    # The rescaling below writes fractional values back into ``locs``; an
    # integer array would silently truncate them.
    if not np.issubdtype(locs.dtype, np.floating):
        locs = locs.astype(float)
    print (locs.shape)

    #################### COMPUTE VELOCITY ####################
    # dists[i] is the step length from frame i to frame i+1 (n_frames - 1 values).
    dists = np.linalg.norm(locs[1:, :] - locs[:-1, :], axis=1)
    print (dists.shape)

    vel_all = savgol_filter(dists * n_frames_per_sec, n_frames_per_sec, 2)

    idx_stationary = np.where(vel_all < min_vel)[0]

    if plot_velocity:
        vel = vel_all.copy()
        vel[idx_stationary] = np.nan
        t = np.arange(vel_all.shape[0]) / n_frames_per_sec
        plt.figure()
        plt.plot(t, vel_all)
        plt.plot([t[0], t[-1]], [min_vel, min_vel], '--')
        plt.plot(t, vel, c='red')
        plt.xlim(t[0], t[-1])
        plt.xlabel("Time (sec)")
        plt.ylabel("Vel (tracking units/sec)")
        plt.show()

    ####################### NORMALIZE SIZE OF ARENA  ###########################
    for axis in (0, 1):
        lo = np.min(locs[:, axis])
        span = np.max(locs[:, axis]) - lo
        if span > 0:
            locs[:, axis] = (locs[:, axis] - lo) / span * arena_size[axis]
        else:
            # A constant coordinate would otherwise give 0/0 = NaN everywhere.
            locs[:, axis] = 0.0

    ####################### DELETE EXTRA IMAGING TIME ###########################
    rec_duration = locs.shape[0]
    n_trace_frames = min(upphases.shape[1], filtered_Fs.shape[1])
    if n_trace_frames < rec_duration:
        # Cropping could not align traces to positions; fail loudly instead
        # of deleting misaligned frames.
        raise ValueError("traces have %d frames but tracking has %d; cannot align"
                         % (n_trace_frames, rec_duration))

    upphases = upphases[:, :rec_duration]
    filtered_Fs = filtered_Fs[:, :rec_duration]

    ####################### REMOVE STATIONARY PERIODS ###########################
    locs = np.delete(locs, idx_stationary, axis=0)
    upphases = np.delete(upphases, idx_stationary, axis=1)
    filtered_Fs = np.delete(filtered_Fs, idx_stationary, axis=1)

    ################### COMPUTE TIMES BASED ON THE MOVING PERIODS ######################
    times = np.arange(locs.shape[0])

    print ("Locs: ", locs.shape, " uphases: ", upphases.shape)

    return locs, upphases, times, filtered_Fs


In [4]:

def calc_tuningmap(occupancy, x_edges, y_edges, signaltracking, params):
    '''
    Calculate tuningmap
    Parameters
    ----------
    occupancy : masked np.array
        Smoothed occupancy. Masked where occupancy low
    x_edges : np.array
        Bin edges in x. Not modified (a local copy is used).
    y_edges : np.array
        Bin edges in y. Not modified (a local copy is used).
    signaltracking : dict
        # Added by Horst 10-17-2022
          keys:
          signal       # Signal (events or calcium) amplitudes
          x_pos_signal # Tracking x position for signal
          y_pos_signal # Tracking y position for signal

    params : dict
        MapParams table entry; ``params['sigma_signal']`` is the standard
        deviation (in bins) of the Gaussian smoothing kernel.

    Returns
    -------
    tuningmap_dict : dict
        - binned_raw    : np.array: summed raw (unsmoothed) signal per bin, indexed [y_bin, x_bin]
        - tuningmap_raw : np masked array: binned_raw / occupancy (mask where occupancy masked)
        - tuningmap     : np masked array: smoothed binned signal / occupancy
        - bin_max       : tuple: (row, col) == (y_bin, x_bin) index of the tuningmap maximum
        - max           : float: maximum of the smoothed tuningmap

    Notes
    -----
    Bins are half-open [edge_k, edge_k+1). The last edge is widened by 1 so
    samples exactly on the outer border are counted. Samples with NaN
    position fall in no bin; NaN signal values contribute 0 (nansum).
    '''
    # Local import: this function used to rely on a later cell importing it.
    from astropy.convolution import Gaussian2DKernel, convolve

    tuningmap_dict = {}

    # Work on copies. The previous in-place ``+= 1`` widened the caller's
    # (shared) edge arrays by one more unit on every call.
    x_edges = np.array(x_edges, dtype=float, copy=True)
    y_edges = np.array(y_edges, dtype=float, copy=True)
    x_edges[-1] += 1
    y_edges[-1] += 1

    x_pos = np.asarray(signaltracking['x_pos_signal'])
    y_pos = np.asarray(signaltracking['y_pos_signal'])
    signal = np.asarray(signaltracking['signal'], dtype=float)

    n_x = len(x_edges) - 1
    n_y = len(y_edges) - 1
    # searchsorted(side='right') - 1 gives k for edge_k <= pos < edge_k+1.
    # NaN sorts past the end and is rejected by the range check.
    ix = np.searchsorted(x_edges, x_pos, side='right') - 1
    iy = np.searchsorted(y_edges, y_pos, side='right') - 1
    in_range = (ix >= 0) & (ix < n_x) & (iy >= 0) & (iy < n_y)

    weights = signal[in_range]
    weights = np.where(np.isnan(weights), 0.0, weights)

    binned_signal = np.zeros_like(occupancy.data)
    # Unbuffered accumulation, so repeated (iy, ix) pairs all add up.
    np.add.at(binned_signal, (iy[in_range], ix[in_range]), weights)

    tuningmap_dict['binned_raw'] = binned_signal
    binned_signal = np.ma.masked_where(occupancy.mask, binned_signal)  # copy; binned_raw untouched
    tuningmap_dict['tuningmap_raw'] = binned_signal / occupancy

    # Instead of smoothing the raw binned events, substitute those values that are masked in
    # occupancy map with nans, then smooth a symmetric-padded copy with astropy's convolve.
    binned_signal[occupancy.mask] = np.nan
    sigma_signal = params['sigma_signal']
    kernel = Gaussian2DKernel(x_stddev=sigma_signal)

    pad_width = int(5 * sigma_signal)
    binned_signal_padded = np.pad(binned_signal, pad_width=pad_width, mode='symmetric')  # as in BNT
    smoothed = convolve(binned_signal_padded, kernel, boundary='extend')
    # Explicit stop indices: ``[0:-0]`` would be empty when pad_width == 0.
    n_rows, n_cols = smoothed.shape
    smoothed = smoothed[pad_width:n_rows - pad_width, pad_width:n_cols - pad_width]
    smoothed = np.ma.masked_where(occupancy.mask, smoothed)
    masked_tuningmap = smoothed / occupancy

    tuningmap_dict['tuningmap']       = masked_tuningmap
    tuningmap_dict['bin_max']         = np.unravel_index(masked_tuningmap.argmax(), masked_tuningmap.shape)
    tuningmap_dict['max']             = np.max(masked_tuningmap)

    return tuningmap_dict

  

In [5]:
#############################################################
################### DEFAULT LOADING OF DATA #################
#############################################################

#
#############################################

# Session to analyse. All inputs live in one session directory. The FS1
# session of the same animal is in the sibling "FS1" directory, with locs file
# 'DON-007050_20211022_TR-BSL_FS1-ACQDLC_resnet50_open_arena_white_floorSep8shuffle1_200000_locs.npy'.
SESSION_DIR = '/media/cat/4TB1/donato/nathalie/DON-007050/FS9'

fname_locs = os.path.join(
    SESSION_DIR,
    'DON-007050_20211030_TR-BSL_FS9-ACQDLC_resnet50_open_arena_white_floorSep8shuffle1_200000_locs.npy')
fname_traces = os.path.join(SESSION_DIR, 'binarized_traces.npz')

# Arena geometry. arena_x / arena_y / arena_shape are not read anywhere in
# this section; they are kept for reference only.
arena_x = [300, 1550]
arena_y = [175, 1400]
arena_size = [80, 80]
arena_shape = 'square'
bin_width = 2.5

# Positions + traces restricted to running frames (see load_locs_traces_running).
locs, upphases, times, filtered_Fs = load_locs_traces_running(
    fname_locs, fname_traces, arena_size)

# Occupancy map from the running-frame positions (opexebo expects locs transposed).
occ_map, coverage, bin_edges = op.analysis.spatial_occupancy(
    times, locs.T, arena_size, bin_width=bin_width)
#





(36000, 2)
(35999,)
Locs:  (35110, 2)  uphases:  (365, 35110)


In [6]:
#############################################################
################### DEFAULT LOADING OF DATA #################
#############################################################
from astropy.convolution import Gaussian2DKernel

#
def get_rms_and_place_field_from_tunning_map(cell_id,
                                            upphases,
                                            filtered_Fs,
                                            locs,
                                            occ_map,
                                            arena_size,
                                            sigma = 1.0,
                                            scale_flag=False,
                                            scale_value =1,
                                            x_edges=None,
                                            y_edges=None,
                                            sigma_signal=2):
    '''Tuning map and place fields for one cell, using ``calc_tuningmap``.

    The per-frame signal is ``upphases[cell_id] * filtered_Fs[cell_id]``. It
    is binned at the ``locs`` positions by ``calc_tuningmap``, which smooths
    with a Gaussian of width ``sigma_signal`` bins and divides by ``occ_map``.
    The result is then smoothed again with ``op.general.smooth(rm, sigma)``,
    and ``op.analysis.place_field`` is run on it.

    NOTE(review): that is two smoothing passes (``sigma_signal`` then
    ``sigma``). Also, ``min_peak`` (1e-4 * max) is far below ``min_mean``
    (0.1 * max), although a field's peak is never below its mean. The two
    thresholds may be swapped. Confirm both.

    Parameters
    ----------
    cell_id : int
        Row of ``upphases`` / ``filtered_Fs`` to analyse.
    upphases, filtered_Fs : 2-D arrays
        Per-cell traces (rows = cells), aligned with ``locs``.
    locs : 2-D array
        Per-frame positions; column 0 = x, column 1 = y.
    occ_map : masked array
        Occupancy map handed to ``calc_tuningmap`` and returned unchanged.
    arena_size :
        Unused; kept for interface compatibility.
    sigma : float
        Width for the second smoothing (``op.general.smooth``).
    scale_flag, scale_value :
        When ``scale_flag`` is true the smoothed map is multiplied by ``scale_value``.
    x_edges, y_edges : optional
        Bin edges. If omitted, the notebook-level ``x_edges`` / ``y_edges``
        are used (the previous implicit behaviour). They are never modified.
    sigma_signal : float, optional
        Kernel width passed to ``calc_tuningmap`` (default 2, as before).

    Returns
    -------
    rm, rms, fields_map, occ_map
        Tuning map, smoothed tuning map, place-field map, unchanged ``occ_map``.
    '''
    # Backward compatibility: the edges used to be read from notebook globals.
    if x_edges is None:
        x_edges = globals().get('x_edges')
    if y_edges is None:
        y_edges = globals().get('y_edges')
    if x_edges is None or y_edges is None:
        raise NameError("x_edges / y_edges were not passed and are not defined at notebook level")

    upphase = upphases[cell_id]
    filtered = filtered_Fs[cell_id]

    signaltracking_entry = {"x_pos_signal": locs[:,0],
                            "y_pos_signal": locs[:,1],
                            "signal": upphase*filtered
                           }

    params = {'sigma_signal': sigma_signal}

    # Pass copies so no version of calc_tuningmap can mutate the caller's edges.
    tuningmap_dict = calc_tuningmap(occ_map,
                                    np.array(x_edges, copy=True),
                                    np.array(y_edges, copy=True),
                                    signaltracking_entry,
                                    params)

    rm = tuningmap_dict['tuningmap']

    rms = op.general.smooth(rm, sigma)

    if scale_flag:
        rms = rms * scale_value

    init_thresh = 0.75
    min_bins = 5
    rm_peak = rm.max()
    min_mean = rm_peak * 0.1
    min_peak = rm_peak * 0.0001

    fields, fields_map = op.analysis.place_field(rms,
                                                 min_bins=min_bins,
                                                 min_peak=min_peak,
                                                 min_mean=min_mean,
                                                 init_thresh=init_thresh,
                                                 search_method='sep',
                                                 debug=False)

    return rm, rms, fields_map, occ_map


#############################################################
#############################################################
#############################################################
scale_flag = False
scale_value = 1
sigma = 1.5

# Output figure; set SHOW_FIELDS_FIGURE = True to display instead of saving.
FIELDS_FIG_PATH = "/home/cat/fields.svg"
SHOW_FIELDS_FIGURE = False

cell_ids = np.arange(100)

# Subplot grid keyed by the number of cells. Any other count draws every cell
# into one shared axis, as before.
n_cells = cell_ids.shape[0]
if n_cells == 100:
    grid = (10, 10)
elif n_cells == 10:
    grid = (2, 5)
else:
    grid = (1, 1)

plt.figure(figsize=(15, 15))
for ctr, cell_id in enumerate(tqdm(cell_ids)):
    # Fresh edge copies on every call. The map function reads these globals,
    # and calc_tuningmap used to add 1 to the last edge in place. Through the
    # old ``x_edges = bin_edges[0]`` alias that drifted the shared bin_edges.
    x_edges = np.array(bin_edges[0], copy=True)
    y_edges = np.array(bin_edges[1], copy=True)

    rm, rms, fields_map, _ = get_rms_and_place_field_from_tunning_map(cell_id,
                                                                      upphases,
                                                                      filtered_Fs,
                                                                      locs,
                                                                      occ_map,
                                                                      arena_size,
                                                                      sigma,
                                                                      scale_flag,
                                                                      scale_value)

    panel = ctr + 1 if grid != (1, 1) else 1
    ax = plt.subplot(grid[0], grid[1], panel)

    # Top: smoothed rate map. Bottom: field map rescaled to the rate map's
    # peak so both share one colour scale. One NaN row separates them.
    img1 = rms
    fields_peak = np.max(fields_map)
    if fields_peak != 0:
        img2 = fields_map / fields_peak * np.nanmax(img1)
    else:
        img2 = fields_map
    separator = np.full(img1.shape[1], np.nan)
    img = np.vstack((img1, separator, img2))

    plt.imshow(img,
               vmin=np.min(img1),
               vmax=np.max(img1))
    plt.xticks([])
    plt.yticks([])
    plt.title(str(cell_id), fontsize=10, pad=0.9)

    res = op.analysis.rate_map_stats(rms,
                                     occ_map,
                                     debug=False)
    coh = op.analysis.rate_map_coherence(rm)

    text = ("SI_rate: " + str(round(res['spatial_information_rate'], 2)) +
            "  SI_cont: " + str(round(res['spatial_information_content'], 2)) +
            "  Sparse: " + str(round(res['sparsity'], 2)) + ' \n ' +
            "Select: " + str(round(res['selectivity'], 2)) +
            "  Peak_r: " + str(round(res['peak_rate'], 2)) +
            "  Mean_r: " + str(round(res['mean_rate'], 2)) +
            "  Coh: " + str(round(coh, 2)))
    ax.set_ylabel(text, labelpad=.3, fontsize=3)

# The figure covers many cells, so title it by the range rather than by the
# last cell's statistics (which the old suptitle showed).
plt.suptitle("Rate maps (top) and place fields (bottom), cells "
             + str(cell_ids[0]) + "-" + str(cell_ids[-1])
             + ", sigma=" + str(sigma), fontsize=10)

if SHOW_FIELDS_FIGURE:
    plt.show()
else:
    plt.savefig(FIELDS_FIG_PATH)
    plt.close()


100%|██████████| 100/100 [00:27<00:00,  3.61it/s]


In [39]:
# Summary of `res`, the rate-map stats of the last cell processed above.
# Fields are separated by two spaces, matching the per-panel labels; the old
# version concatenated them with no separator ("SI_rate: 2.65SI_cont: ...").
text = ("SI_rate: " + str(round(res['spatial_information_rate'], 2)) +
        "  SI_cont: " + str(round(res['spatial_information_content'], 2)) +
        "  Sparse: " + str(round(res['sparsity'], 2)) + ' \n ' +
        "Select: " + str(round(res['selectivity'], 2)) +
        "  Peak_r: " + str(round(res['peak_rate'], 2)) +
        "  Mean_r: " + str(round(res['mean_rate'], 2)))

print (text)


['SI_rate: 2.65SI_cont: 1.14Sparse: 0.45 \n Select: 4.59Peak_r: 10.64Mean_r: 0.57']


In [9]:
###############################################################
############# SPATIAL INFORMATION CONTENT VS. SELECTIVITY ##############
###############################################################
scale_flag = False
scale_value = 1
sigma = 1.5

cell_ids = np.arange(upphases.shape[0])
spc_array = []
sel_array = []
for cell_id in tqdm(cell_ids):
    # `locs` must be passed. Previously it was missing, so every later
    # argument shifted by one slot (occ_map -> locs, arena_size -> occ_map,
    # sigma -> arena_size, ...). The large trace arrays are only indexed by
    # get_rms_and_place_field, so they are not copied on every iteration.
    _, rms, _, _ = get_rms_and_place_field(cell_id,
                                           upphases,
                                           filtered_Fs,
                                           locs,
                                           occ_map.copy(),
                                           arena_size,
                                           sigma=sigma,
                                           scale_flag=scale_flag,
                                           scale_value=scale_value)

    res = op.analysis.rate_map_stats(rms.copy(),
                                     occ_map.copy(),
                                     debug=False)

    spc = res['spatial_information_content']
    sel = res['selectivity']
    if np.isnan(spc) or np.isnan(sel):
        print ("cell id: ", cell_id, " spatial_information_content: ", spc, " selectivity: ", sel)

    spc_array.append(spc)
    sel_array.append(sel)

fig, ax = plt.subplots()
ax.scatter(spc_array, sel_array)
ax.set_xlabel("spatial_information_content")
ax.set_ylabel("selectivity")
ax.set_title("Spatial information content vs. selectivity (sigma=" + str(sigma) + ")")

plt.show()



 22%|██▏       | 103/476 [00:09<00:29, 12.64it/s]

ceel Id:  nan nan


 37%|███▋      | 177/476 [00:16<00:27, 11.02it/s]


KeyboardInterrupt: 