In [1]:
import matplotlib
%matplotlib tk
#matplotlib.use('Agg')

%load_ext autoreload
%autoreload 2

%autosave 180

import matplotlib.pyplot as plt
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

import numpy as np
import os

import sys
sys.path.append('..')

# visualize results module
from visualize import visualize_svm as Visualize



Autosaving every 180 seconds


  from IPython.core.display import display, HTML


In [88]:

def load_right_paw_traces(n_sec,
                          root_dir, 
                          animal_id, 
                          session_id):
    
    # 
    #session_id = sessions[0]
    print ("Loading session: ", session_id)
    
    try:
        fname = os.path.join(root_dir, 
                             animal_id,
                             'tif_files',
                             session_id,
                             session_id + '_0secNoMove_movements.npz')

        data = np.load(fname, allow_pickle=True)
        feature_initiations = data['feature_initiations']
        feature_quiescent = data['feature_quiescent']
        labels = data['labels']
        video_shift = data['video_shift']
        #print ("video shifts: ", video_shift)
        #print ("labels: ", labels)
        label_id = 1
        #print (feature_initiations.shape)
        right_paw_initiations = feature_initiations[label_id]
    except:
        print ("no movements file")
        return None


    # load blue light trigger
    try: 
        sample_rate_imaging = 30.
        fname_blue_light = os.path.join(root_dir, 
                             animal_id,
                             'tif_files',
                             session_id,
                             'blue_light_frame_triggers.npz')


        blue_light = np.load(fname_blue_light, allow_pickle=True)
        light_start = blue_light['start_blue']/sample_rate_imaging
        light_end = blue_light['end_blue']/sample_rate_imaging
        img_frame_triggers = blue_light['img_frame_triggers']
    except:
        print ("no blue light trigger")
        return None
        
    #print ("blue light: ", light_start, light_end, img_frame_triggers)

    # load rewarded times
    fname_rewards = os.path.join(root_dir, 
                         animal_id,
                         'tif_files',
                         session_id,
                         'rewarded_times.txt')

    reward_times = np.loadtxt(fname_rewards)
    reward_times += video_shift[0]

    #################
    try: 
        import glob
        fname_dlc = os.path.join(root_dir, 
                             animal_id,
                             'tif_files',
                             session_id,
                             "*DLC*.h5")
        fname_dlc = glob.glob(fname_dlc)[0]

        import pandas as pd
        import h5py

        f = h5py.File(fname_dlc, 'r')
        for key in f.keys():
            df = pd.read_hdf(fname_dlc,key)

        #
        locs = df.to_numpy()

        #
        locs_right_paw = locs[:,3:6]
    except:
        print ("no DLC file")
        return None
        
    # load water rewarded snippets
    threshold = 0.5             # NN threshold for identifying a location
    sample_rate_video = 15
    idx = np.where(locs_right_paw[:,2]<threshold)[0]
    locs_right_paw[idx]=np.nan
    t = np.arange(locs_right_paw.shape[0])/sample_rate_video
    
    # 
    right_paw_reward_snippets = []
    for reward_time in reward_times:
        idx = np.argmin(np.abs(t-reward_time))
        temp = locs_right_paw[int(idx-n_sec*sample_rate_video):
                              int(idx+n_sec*sample_rate_video)]
        
        #
        if temp.shape[0]==n_sec*sample_rate_video*2:
            right_paw_reward_snippets.append(temp)
        
    # load spontaneous snippets 
    right_paw_spontaneous_snippets = []
    n_sec_lockout = 3
    last_time = 0
    idx_initiations = np.where(right_paw_initiations==1)[0]
    for initiation_time in idx_initiations:
        
        if initiation_time>(last_time+n_sec_lockout*sample_rate_video):
            
            # find nearest location in time
            temp = locs_right_paw[int(initiation_time-n_sec*sample_rate_video):
                                  int(initiation_time+n_sec*sample_rate_video)]
            if temp.shape[0]==n_sec*sample_rate_video*2:
                right_paw_spontaneous_snippets.append(temp)
            
            last_time = initiation_time
        
    right_paw_reward_snippets = np.array(right_paw_reward_snippets)
    right_paw_spontaneous_snippets = np.array(right_paw_spontaneous_snippets)
    

    # 
    try:
        fname_vid =  os.path.join(root_dir, 
                             animal_id,
                             'vids',
                             'prestroke',
                             session_id+'.mp4')

        import cv2
        vidcap = cv2.VideoCapture(fname_vid)
        success,image = vidcap.read()
        count = 0
        frame_id = 500
        while count<frame_id:
            #cv2.imwrite("frame%d.jpg" % count, image)     # save frame as JPEG file      
            success,image = vidcap.read()
            #print('Read a new frame: ', success)
            count += 1
    except:
        print ("no video")
        return None
    
    
    return (t, 
            locs_right_paw, 
            right_paw_initiations, 
            reward_times, 
            locs_right_paw,
            right_paw_reward_snippets, 
            right_paw_spontaneous_snippets,
            image)

def plot_paw_traces(res, session_id, n_sec):
    from tqdm import trange

    (t, 
     locs_right_paw, 
     right_paw_initiations,
     reward_times,
     locs_right_paw,
     right_paw_reward_snippets, 
     right_paw_spontaneous_snippets,
     image) = res

        
    ##
    plt.figure(figsize=(15,10))
    #plt.imshow(image,
    #          cmap='Blues')

    ax=plt.subplot(1,1,1)

    ########
    print (right_paw_reward_snippets.shape)
    print (right_paw_spontaneous_snippets.shape)

    cmap = plt.cm.viridis(right_paw_reward_snippets.shape[1])
    t= np.arange(-5,5,1/15.)
    cmap = matplotlib.cm.get_cmap('viridis', 
                                  right_paw_reward_snippets.shape[1])
    #
    line_plot = True
    plot_type = 'spontaneous'  # 'rewards', 'spontaneous'
    print ("PLOT TYPE: ", plot_type)
    # 
    if plot_type=='rewarded':
        data_to_plot = right_paw_reward_snippets
    if plot_type=='spontaneous':
        data_to_plot = right_paw_spontaneous_snippets

    #     
    for k in trange(data_to_plot.shape[0]):

        if line_plot:
            for p in range(data_to_plot.shape[1]-1):
                loc1 = data_to_plot[k,p,:]
                loc2 = data_to_plot[k,p+1,:]
                prob = loc1[2]
                if np.linalg.norm(loc1-loc2)<100:
                        plt.plot(
                            [loc1[0],loc2[0]],
                            [loc1[1],loc2[1]],
                            c=cmap(p),
                            alpha=.5
                        )
        if False:
            plt.scatter(data_to_plot[k,:,0],
                        data_to_plot[k,:,1],
                        c=np.arange(data_to_plot.shape[1]),
                        cmap='viridis',
                        alpha=.3
                       )

    # 
    plt.title(session_id)
    #plt.xlim(100,350)
    #plt.ylim(350,150)

    # # 
    # import matplotlib as mpl
    # norm = mpl.colors.Normalize(vmin=-5,vmax=5)
    # sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
    # sm.set_array([])
    # plt.colorbar(sm, ticks=np.linspace(-5,5,11), 
    #              label='Time (sec)'
    #             )


    plt.title("# of " + plot_type+ " movements: "+str(data_to_plot.shape[0]))
    plt.xlabel("Pixels")
    plt.ylabel("Pixels")
    plt.savefig("/home/cat/self_init_paper/Revision_figures/temp/"+session_id+"_"+plot_type+".png")
    plt.close()
    #plt.show()
    
def plt_paw_scatter_pca(res, n_sec):
    from sklearn.decomposition import PCA
    (t, 
     _, 
     _,
     _,
     locs_right_paw,
     _, 
     _,
     _) = res

    print (locs_right_paw.shape)

    sample_rate_video = 15
    length_snippet = 1 * sample_rate_video
    X = np.array_split(locs_right_paw[:, :2],
                       np.arange(length_snippet,locs_right_paw.shape[0],length_snippet),axis=0)
    #print ("X: ", len(X))
    X = np.array(X[:-1])
    # find 1 sec snippets that don't have any nans in them
    idx = np.isnan(X).sum(axis=1).sum(axis=1)
    idx2 = np.where(idx==0)[0]
    #print ("X: ", X.shape, "idx: ", idx.shape, idx[:10], idx2.shape, idx2[:10])

    X = X[idx2]
    X_pca = X.reshape(X.shape[0], -1)
    print ("X_pca: ", X_pca.shape)

    # run pca
    pca = PCA(n_components=2)
    print ("start pca: ", X_pca.shape)
    X_pca = pca.fit_transform(X_pca)
    print (" pca done: ", X_pca.shape)

    fig=plt.figure()
    plt.scatter(X_pca[:,0],
                X_pca[:,1],
               c=np.arange(X_pca.shape[0]),
               alpha=.4)

    # 
    plt.title("# of 1sec segments: "+str(X_pca.shape[0]))
    plt.xlabel("PC1")
    plt.ylabel("PC2")
    plt.xlim(-500,500)
    plt.ylim(-200,200)

    if True:
        plt.savefig("/home/cat/self_init_paper/Revision_figures/temp/"+session_id+"_pca.png")
        plt.close()
    else:
        plt.show()

     
def plt_paw_scatter_pca_all(X, n_sec_partition):
    from sklearn.decomposition import PCA

    #
    sample_rate_video = 15
    length_snippet = n_sec_partition * sample_rate_video
    X = np.array_split(X[:, :2],
                       np.arange(length_snippet,X.shape[0],length_snippet),axis=0)
    
    # remove last bit
    X = np.array(X[:-1])
    
    # find n sec snippets that don't have any nans in them
    idx = np.isnan(X).sum(axis=1).sum(axis=1)
    idx2 = np.where(idx==0)[0]
    X = X[idx2]
    print ("X: ", X.shape)
    
    # zero-out each snippet to the beginning of the array
    X = X - X[:,0,:][:,None]
    
    # flatten data for PCA
    X_nsec = X.reshape(X.shape[0], -1)
    print ("X_nsec flattened: ", X_nsec.shape)

    # run pca
    
    if False:
        pca = PCA(n_components=2)
        print ("data input: ", X_nsec.shape)
        X_pca = pca.fit_transform(X_nsec)
        print (" pca done: ", X_pca.shape)
    else:
        
        import umap.umap_ as umap
        print ("Running Umap: ", X_nsec.shape)
        X_pca = umap.UMAP().fit_transform(X_nsec)
    
    fig=plt.figure(figsize=(5,5))
    plt.scatter(X_pca[:,0],
                X_pca[:,1],
                c=np.arange(X_pca.shape[0]),
                alpha=.4)

    # 
    plt.title("# of "+str(n_sec)+" sec segments: "+str(X_pca.shape[0]))
    plt.xlabel("Dim 1")
    plt.ylabel("Dim 2")
    #plt.xlim(-500,500)
    #plt.ylim(-200,200)

    if False:
        plt.savefig("/home/cat/self_init_paper/Revision_figures/temp/"+session_id+"_pca.png")
        plt.close()
    else:
        plt.show()   

In [251]:
###################################################
############ PLOT TRACES OF LEVER PULLS ###########
###################################################
root_dir = '/media/cat/4TBSSD/yuki/'
animal_id = 'AQ2'
session_id = 'all'

sessions = Visualize.get_sessions(root_dir,
                                  animal_id,
                                  session_id)
print ("# sessions: ", len(sessions))
############## 
ctr=0
right_paw_traces_allsessions = []
right_paw_ = []
for session_id in sessions:
    print ('')
    n_sec = 3
    res = load_right_paw_traces(n_sec,
                                root_dir, 
                                animal_id, 
                                session_id)
    
    # 
    if res==None:
        continue
        
    # OPTION 1: plot traces of each lever pull
    #plot_paw_traces(res, session_id, n_sec)
    
    
    # OPTION 2: accumulat all traces and then chop them into segments for PCA
    (t, 
         _, 
         _,
         _,
         locs_right_paw,
         right_paw_reward_snippets, 
         right_paw_spontaneous_snippets,
         _) = res
        
    right_paw_traces_allsessions.append(locs_right_paw.T)
    right_paw_.append(right_paw_reward_snippets)
# 
print (" # of sessions loaded: ", len(right_paw_traces_allsessions))

# 


# sessions:  110

Loading session:  AQ2am_Dec9_30Hz
no blue light trigger

Loading session:  AQ2am_Dec10_30Hz
no blue light trigger

Loading session:  AQ2pm_Dec10_30Hz
no blue light trigger

Loading session:  AQ2am_Dec11_30Hz

Loading session:  AQ2pm_Dec14_30Hz

Loading session:  AQ2am_Dec14_30Hz
no blue light trigger

Loading session:  AQ2pm_Dec16_30Hz

Loading session:  AQ2am_Dec17_30Hz

Loading session:  AQ2pm_Dec17_30Hz

Loading session:  AQ2am_Dec18_30Hz

Loading session:  AQ2pm_Dec18_30Hz

Loading session:  AQ2am_Dec21_30Hz

Loading session:  AQ2am_Dec22_30Hz
no blue light trigger

Loading session:  AQ2am_Dec23_30Hz
no blue light trigger

Loading session:  AQ2am_Dec28_30Hz

Loading session:  AQ2am_Dec29_30Hz

Loading session:  AQ2am_Dec30_30Hz
no blue light trigger

Loading session:  AQ2am_Dec31_30Hz
no movements file

Loading session:  AQ2am_Jan4_30Hz

Loading session:  AQ2am_Jan5_30Hz

Loading session:  AQ2am_Jan6_30Hz

Loading session:  AQ2am_Jan7_30Hz

Loading session:  AQ2am

In [258]:
###################################
######## PLOT VARIANCES ###########
###################################

print ("# of sessions: ", len(right_paw_))
xx = []
yy = []
plt.figure()
ctr=0
for k in range(len(right_paw_)):
    #ax=plt.subplot(5,5,ctr+1)
    X = right_paw_[k][:,:,:2]

    # zero centre the traces:
    X = X - X[:,0][:,None]
    #print ("X: ", X.shape)
    
    #     
    x = X[:,X.shape[0]//2:X.shape[0]//2+2,0]
    y = X[:,X.shape[0]//2:X.shape[0]//2+2,1]
    
    # 
    #x = X[:,:,0]
    #y = X[:,:,1]
    
    # 
    if x.shape[0]<10:
        #print ("too few: ", x.shape)
        continue
    
    # 
    std_x = np.nansum(np.nanstd(x,axis=1))
    std_y = np.nansum(np.nanstd(y,axis=1))
    
    # 
    y = np.abs(std_x)+np.abs(std_y)
    xmean = np.nanmean(x,axis=1)
    xstd = np.nanmedian(np.nanstd(x,axis=1))
    ystd = np.nanmedian(np.nanstd(x,axis=1))
    if np.isnan(xstd+ystd)==False: # and (xstd+ystd)<4:
        plt.scatter(k, xstd+ystd)
        
        # 
        xx.append(k)
        yy.append(xstd+ystd)
        
    #     
    if False:
        t = np.arange(x.shape[0])
        #print (xmean.shape, xstd.shape)
        if False:
            plt.fill_between(t, 
                         xmean-xstd,
                         xmean+xstd,                     
                     alpha=0.5)
        else:
            plt.plot(xstd)
    
        
    ctr+=1
#
xx=np.array(xx)
yy=np.array(yy)
m, b = np.polyfit(xx, yy, 1)

# 
plt.plot(xx, m*xx + b)
plt.ylabel("# of right paw intiations (preceded by 3sec of quiescence)")
plt.xlabel("Session ID (Chronlogical)")

from scipy import stats
print (yy)
pcorr = stats.pearsonr(xx,yy)
print (pcorr)
plt.title(str(np.round(pcorr,3)))
plt.ylim(bottom=0)
#plt.savefig("/home/cat/paw_movements_"+animal_id+".svg")
#plt.close()


plt.show()

# of sessions:  51
[ 0.34686279  0.32293701  0.21264648  0.49107361  0.21679688  0.81965637
  0.57815552  0.25947571  0.33599854  0.91677856  0.35125732  0.28631592
  0.24523926  0.39981079  0.19033813  0.78353882  1.38293457  1.41766357
  1.03747559  1.08544922  0.88739014  0.95884705  0.55236816  1.88926697
  1.25497437  0.33151245 10.36746216  1.75814819  1.1289978   0.42158508
  0.7053833   1.03085327  0.42999268  0.70848083  0.96316528  1.20001221
  1.0380249   4.85801697  1.75726318]
(0.3282579930124826, 0.04133098570110634)


  var = nanvar(a, axis=axis, dtype=dtype, out=out, ddof=ddof,
  xmean = np.nanmean(x,axis=1)
  r, k = function_base._ureduce(a, func=_nanmedian, axis=axis, out=out,


In [87]:
##################################################       
####### PLOT UMAP OR PCA OF MOVEMENT SEGS ########
##################################################
print (len(right_paw_traces_allsessions))
X = np.hstack(right_paw_traces_allsessions).T
print ("X in: ", X.shape)
n_sec = 5
plt_paw_scatter_pca_all(X,n_sec)
    
    #break
    

51
X in:  (1019252, 3)
X:  (4812, 75, 2)
X_nsec flattened:  (4812, 150)
Running Umap:  (4812, 150)


In [122]:
#########################################################
####### PLOT # OF RIGHT PAW MOVEMENTS PER SESSION #######
#########################################################
plt.figure(figsize=(7,6))
x = np.arange(len(right_paw_spontaneous_movements))
y =right_paw_spontaneous_movements
plt.scatter(x,
            y)

m, b = np.polyfit(x, y, 1)

plt.plot(x, m*x + b)
plt.ylabel("# of right paw intiations (preceded by 3sec of quiescence)")
plt.xlabel("Session ID (Chronlogical)")

from scipy import stats

pcorr = stats.pearsonr(x,y)
print (pcorr)
plt.title(str(np.round(pcorr,3)))
plt.ylim(bottom=0)
#plt.savefig("/home/cat/paw_movements_"+animal_id+".svg")
#plt.close()
plt.show()


(-0.18775994497314316, 0.3483258322549543)
