In [54]:
import matplotlib
matplotlib.use('Agg')

%load_ext autoreload
%autoreload 2

%matplotlib tk
%autosave 180
import matplotlib.pyplot as plt
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

import numpy as np
import os

import glob2
from sklearn.decomposition import PCA
import pycorrelate


# visualize results module
from Visualize import Visualize, get_sessions, load_trial_times_whole_stack, get_lever_offset_seconds

from utility_DLC import *


# manually add labels from DLC
from tqdm import tqdm, trange
import scipy

# 
import glob

import umap

# 
#data_dir = '/media/cat/4TBSSD/yuki/'
data_dir = '/media/cat/4TBSSD/yuki/'


labels = ['left_paw', 
'right_paw',
'nose',
'jaw',
'right_ear',
'tongue',
'lever']



The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


Autosaving every 180 seconds


In [5]:
# 
def clean_traces(data,
                smooth_window = 10):
    
    # 
    for k in trange(data.shape[0]):
        temp = data[k]
        probs = temp[:,2]
        
        threshold = 0.5
        idx = np.where(probs<threshold)[0]
        x = temp[:,0]
        y = temp[:,1]
        x[idx] = np.nan
        y[idx] = np.nan
        
        # 

        for i in range(10000):
            idx = np.where(np.isnan(x))[0]
            #print (idx.shape)
            if idx.shape[0]==0:
                break

            if idx[0]==0:
                idx=idx[1:]
                
            x[idx] = x[idx-1]
            y[idx] = y[idx-1]

        x = scipy.ndimage.median_filter(x, size=smooth_window)
        y = scipy.ndimage.median_filter(y, size=smooth_window)
        #print (x.shape, y.shape, data.shape)
        
        data[k,:,0]=x
        data[k,:,1]=y
    return data
            
    

def plot_velocities(traces, movements, window):
    
    # 
    t=np.linspace(-window, window,window*2*15-1)

    
    fig=plt.figure()
    ax1=plt.subplot(111)
    mean_vel=[]
    for k in trange(movements.shape[0]):
        t_start = int(movements[k]-window)*15
        t_end = int(movements[k]+window)*15

        temp = traces[feature_selected,t_start:t_end,:2]

        if True:
            ax1=plt.subplot(221)
            vel = np.sqrt((temp[1:,0]-temp[:-1,0])**2+
                          (temp[1:,1]-temp[:-1,1])**2)

            mean_vel.append(vel)
            plt.plot(t,vel,c='black',linewidth=2,alpha=.1)

    # complete ax1 plots
    mean_vel = np.array(mean_vel).mean(0)
    ax1.plot(t,mean_vel,c='black',linewidth=5,alpha=1)
    ax1.plot([0,0],[0.1,250],'--',linewidth=5,c='black')
    ax1.set_ylim(0.1,250)
    ax1.set_xlim(t[0],t[-1])
    ax1.semilogy()




def get_positions(movements,
                  traces,
                  colors,
                  plotting=False):
    
    padding_x = 20
    padding_y = 20

    if plotting:
        fig=plt.figure()
        ax1=plt.subplot(111)
    
    # 
    all_traces = []
    for k in trange(movements.shape[0]):
        t_start = int(movements[k]-window)*15
        t_end = int(movements[k]+window)*15

        temp = traces[feature_selected,t_start:t_end,:2]
        if temp.shape[0]==0:
            continue
        
        temp = temp - temp[temp.shape[0]//2]
        temp[:,0] = scipy.ndimage.median_filter(temp[:,0], size=10)
        temp[:,1] = scipy.ndimage.median_filter(temp[:,1], size=10)

        
        # zero out to t=0 point
            
        
        if plotting:
            if True:
                for p in range(1, temp.shape[0]-1,1):
                    plt.plot([temp[p-1,0],temp[p,0]],
                             [temp[p-1,1],temp[p,1]], 
                             c=colors[p],
                             linewidth=3,
                             alpha=.1)
            else:
                plt.plot(temp[:,0], temp[:,1], 
                         c=colors[k],alpha=.4)


            plt.scatter(temp[:,0],temp[:,1], 
                         c=colors, 
                        s=100, alpha=.1)
        
        if temp.shape[0]==(150):
            #print (temp.shape, t_end, t_start)
            all_traces.append(temp)
    
    # complete ax2 plots
    if plotting:
        ax1.set_xlim(ax1.get_xlim()[0]-padding_x,
                     ax1.get_xlim()[1]+padding_x)
        ax1.set_ylim(ax1.get_ylim()[0]-padding_y,
                     ax1.get_ylim()[1]+padding_y)


    all_traces = np.array(all_traces)


    return np.array(all_traces)

# 
def plot_average_positions(mean_trace, 
                           color,
                           fig=None,
                           ax1=None):
    
    #fig = plt.figure()
    if ax1 is None:
        fig=plt.figure()
        ax1=plt.subplot(111)
    #
    mean_trace = mean_trace.mean(0)
    mean_trace[:,0] = scipy.ndimage.median_filter(mean_trace[:,0], size=5)
    mean_trace[:,1] = scipy.ndimage.median_filter(mean_trace[:,1], size=5)

    # 
    if False:
        for p in range(1, mean_trace.shape[0]-1,1):
            ax1.plot([mean_trace[p-1,0],mean_trace[p,0]],
                     [mean_trace[p-1,1],mean_trace[p,1]], 
                     c=color,
                     linewidth=5, alpha=.7)
        
    else:
        ax1.plot(mean_trace[:,0],mean_trace[:,1],
                     c=color,
                     linewidth=5, alpha=.7)
    ax1.set_xticks([])
    ax1.set_yticks([])

    #ax1.scatter(mean_trace[:,0],mean_trace[:,1], 
    #            #c=colors,alpha=1)
    #            c=color,alpha=.1)

#     ax1.set_xlim(np.min(mean_trace[:,0])-1,
#                  np.max(mean_trace[:,0])+1)
#     ax1.set_ylim(np.min(mean_trace[:,1])-1,
#                  np.max(mean_trace[:,1])+1)


def get_movements_lever_pos(trace,
                            times,
                            lockout_window=0):
    
    
    
    
    movement_threshold = 3  # number of pixels per frame inidicating movement from stationarity

    movements = np.zeros((trace.shape[0])) # This tracks any change in movement.

    # compute velocity and median value for 1D data
    vel = trace[1:]-trace[:-1]
           
    idx = np.where(vel<=1)[0]
    vel[idx]=np.nan

    idx2 = np.where(vel>=movement_threshold)[0]  # VELOCITY > min thresh means movement
    movements[idx2]+=1

    # 
    idx = np.where(movements>0)[0]
    print ('idx: ', idx.shape)
    #
    chunks = []
    for p in range(1,idx.shape[0],1):
        if (times[idx[p]]-times[idx[p-1]])>= lockout_window:
            chunks.append([times[idx[p-1]], times[idx[p]]])


    print (labels[k], "  # of quiescent periods: ", len(chunks))
    return np.array(chunks)


In [6]:
##########################
####### INITIALIZE #######
##########################
# LEVER PULL
vis = Visualize()

# lever-related data
vis.main_dir = data_dir
vis.random_flag = False  # shuffle data to show baseline

# 
vis.window = 15
vis.lockout_window = 10
vis.lockout = False

vis.pca_var = 0.95
vis.pca_flag = True


vis.significance = 0.05
vis.linewidth=10

#
vis.smooth_window = 10
vis.lockout = False
vis.xvalidation = 10
vis.sliding_window = 30

# 
vis.window = 15
vis.cbar_thick = 0.05
vis.alpha = 1.0
    
# 
vis.min_trials = 10
# 
vis.ctr_plot = 0
    

# 
vis.animal_id= 'IA1'
session = 'Feb1_'

print ("   session: ", session)
vis.session_id = session
vis.cbar_offset = 0



   session:  Feb1_


In [5]:
#########################################
##### LOAD TRACES FROM H5 AND CLEAN #####
#########################################
animals = ['IA1','IA2','IA3','IJ1','IJ2','AQ2']
animal_id = 5

fnames = np.loadtxt('/media/cat/4TBSSD/yuki/'+animals[animal_id]+'/tif_files/sessions.txt',dtype='str')

#
for root_dir in fnames:
    
    try:
        temp = root_dir + '/*.h5'
        fname_h5 = glob.glob(temp)[0]
        fname_npy = fname_h5.replace('h5','npy')
    except:
        print("video file not found: ", root_dir)
        continue
    
    fname_out = fname_npy.replace('.npy','_clean.npy')

    if os.path.exists(fname_out)==False:
        print ('processing: ', root_dir)
        traces_original = np.load(fname_npy)
        if True:
            traces_original[:,:,0] = np.int32(traces_original[:,:,0])
            traces_original[:,:,1] = np.int32(traces_original[:,:,1])

        print (traces_original.shape)

        # plt.plot(traces_original[0,:,0])
        smooth_window = 3
        traces = clean_traces(traces_original, smooth_window)
        print (traces.shape)

        np.save(fname_out,traces)
        print ('')

        
        

video file not found:  /media/cat/4TBSSD/yuki/AQ2/tif_files/AQ2am_Apr1_Week4_30Hz
video file not found:  /media/cat/4TBSSD/yuki/AQ2/tif_files/AQ2am_Apr4_Week5_30Hz
video file not found:  /media/cat/4TBSSD/yuki/AQ2/tif_files/AQ2am_Apr5_Week5_30Hz
video file not found:  /media/cat/4TBSSD/yuki/AQ2/tif_files/AQ2am_Apr6_Week5_30Hz
video file not found:  /media/cat/4TBSSD/yuki/AQ2/tif_files/AQ2am_Apr7_Week5_30Hz
video file not found:  /media/cat/4TBSSD/yuki/AQ2/tif_files/AQ2am_Apr8_Week5_30Hz
video file not found:  /media/cat/4TBSSD/yuki/AQ2/tif_files/AQ2am_Apr11_Week6_30Hz
video file not found:  /media/cat/4TBSSD/yuki/AQ2/tif_files/AQ2am_Apr12_Week6_30Hz
video file not found:  /media/cat/4TBSSD/yuki/AQ2/tif_files/AQ2am_Apr13_Week6_30Hz
video file not found:  /media/cat/4TBSSD/yuki/AQ2/tif_files/AQ2am_Apr14_Week6_30Hz
video file not found:  /media/cat/4TBSSD/yuki/AQ2/tif_files/AQ2am_Apr15_Week6_30Hz
video file not found:  /media/cat/4TBSSD/yuki/AQ2/tif_files/AQ2am_Apr18_Week7_30Hz
video file

In [6]:
#######################################################
##### VISUALIZE MOTION LOGINTUIDAINLLY (Fig 3E,F) #####
#######################################################

# FIG 3E, F VIRIDIS PLOTS AVERAGE BODY FEAUTRE POSITION


# 
import glob
window = 5

n = int(window)*15*2
colors = plt.cm.cool(np.linspace(0,1,n))

colors_all = plt.cm.viridis(np.linspace(0,1,40))


n_colors = [30,12,11,10,11,70]

# select body feature to track
# ['left_paw',  0
# 'right_paw',  1
# 'nose',       2
# 'jaw',        3
# 'right_ear',  4
# 'tongue',     5
# 'lever']      6    

feature_selected = 2


# 
fig=plt.figure(figsize=(15,10))
for animal_id in range(6):
    ax=plt.subplot(2,3,animal_id+1)
    animal = animals[animal_id]
    fnames = np.loadtxt('/media/cat/4TBSSD/yuki/'+str(animal)+'/tif_files/sessions.txt',dtype='str')

    # map viridis color onto # of sessoins
    colors_sessions = plt.cm.viridis(np.linspace(0,1,n_colors[animal_id]))


    # loop over the sessions in the animal and 
    ctr=0
    traces_array = []
    for root_dir in fnames:

        try:
            fname_h5 = glob.glob(root_dir + '/*.h5')[0]
            fname_npy = fname_h5.replace('h5','npy')
        except:
            continue

        #
        traces = np.load(fname_npy)
        #print (traces.shape)

        # 
        try:
            fname_movements = glob.glob(root_dir + "/*"+str(3)+"secNoMove_movements.npz")[0]
        #print(fname_movements)
        except:
            continue
        movements = np.load(fname_movements,allow_pickle=True)['feature_quiescent']
      
        movements = np.array(movements[feature_selected])
        if movements.shape[0]==0:
            continue

    #     # 
    #     if False:
    #         plot_velocities(traces, 
    #                         movements, 
    #                         window)
        else:
            movements = movements[:,1]
            all_traces = get_positions(movements,
                                       traces,
                                       colors)

            traces_array.append(all_traces)


        plot_average_positions(all_traces,
                              colors_sessions[ctr],
                              fig,
                              ax)

        ctr+=1


    print (' DONE ', animal_id, " # sessions: ", ctr)
    print ('')

if False:
    plt.savefig('/home/cat/'+str(feature_selected)+'_longitudinal.png',dpi=600)
    plt.close()
else:
    plt.show()

        # plt.suptitle(labels[feature_selected],fontsize=20)



100%|██████████| 118/118 [00:00<00:00, 4640.41it/s]
100%|██████████| 81/81 [00:00<00:00, 4143.46it/s]
100%|██████████| 57/57 [00:00<00:00, 4174.02it/s]
100%|██████████| 60/60 [00:00<00:00, 4399.47it/s]
100%|██████████| 25/25 [00:00<00:00, 3532.82it/s]
100%|██████████| 23/23 [00:00<00:00, 3649.57it/s]
100%|██████████| 42/42 [00:00<00:00, 3599.67it/s]
100%|██████████| 113/113 [00:00<00:00, 4282.76it/s]
100%|██████████| 75/75 [00:00<00:00, 4195.14it/s]
100%|██████████| 51/51 [00:00<00:00, 3702.39it/s]
100%|██████████| 76/76 [00:00<00:00, 4352.96it/s]
100%|██████████| 61/61 [00:00<00:00, 3899.95it/s]
100%|██████████| 132/132 [00:00<00:00, 4250.53it/s]
100%|██████████| 98/98 [00:00<00:00, 4904.45it/s]
100%|██████████| 120/120 [00:00<00:00, 4859.63it/s]
100%|██████████| 99/99 [00:00<00:00, 4852.36it/s]
100%|██████████| 88/88 [00:00<00:00, 4802.03it/s]
100%|██████████| 104/104 [00:00<00:00, 4823.49it/s]
100%|██████████| 101/101 [00:00<00:00, 4864.39it/s]
100%|██████████| 141/141 [00:00<00:00,

 DONE  0  # sessions:  30



100%|██████████| 133/133 [00:00<00:00, 4906.31it/s]
100%|██████████| 125/125 [00:00<00:00, 4901.35it/s]
100%|██████████| 161/161 [00:00<00:00, 4816.12it/s]
100%|██████████| 129/129 [00:00<00:00, 4774.33it/s]
100%|██████████| 95/95 [00:00<00:00, 4716.33it/s]
100%|██████████| 56/56 [00:00<00:00, 4675.55it/s]
100%|██████████| 36/36 [00:00<00:00, 4519.86it/s]
100%|██████████| 142/142 [00:00<00:00, 4867.33it/s]
100%|██████████| 112/112 [00:00<00:00, 4836.33it/s]
100%|██████████| 164/164 [00:00<00:00, 4889.44it/s]
100%|██████████| 121/121 [00:00<00:00, 4865.69it/s]
100%|██████████| 101/101 [00:00<00:00, 4902.55it/s]
  0%|          | 0/105 [00:00<?, ?it/s]

 DONE  1  # sessions:  12



100%|██████████| 105/105 [00:00<00:00, 4795.94it/s]
100%|██████████| 143/143 [00:00<00:00, 4974.13it/s]
100%|██████████| 154/154 [00:00<00:00, 4814.53it/s]
100%|██████████| 142/142 [00:00<00:00, 4870.68it/s]
100%|██████████| 105/105 [00:00<00:00, 4809.14it/s]
100%|██████████| 79/79 [00:00<00:00, 4873.73it/s]
100%|██████████| 100/100 [00:00<00:00, 4835.10it/s]
100%|██████████| 115/115 [00:00<00:00, 4841.65it/s]
100%|██████████| 11/11 [00:00<00:00, 3906.63it/s]
100%|██████████| 67/67 [00:00<00:00, 4724.19it/s]
100%|██████████| 62/62 [00:00<00:00, 4760.23it/s]
100%|██████████| 93/93 [00:00<00:00, 4785.32it/s]
100%|██████████| 120/120 [00:00<00:00, 4891.65it/s]
  0%|          | 0/65 [00:00<?, ?it/s]

 DONE  2  # sessions:  11



100%|██████████| 65/65 [00:00<00:00, 4711.40it/s]
100%|██████████| 103/103 [00:00<00:00, 4877.65it/s]
100%|██████████| 61/61 [00:00<00:00, 2882.94it/s]
100%|██████████| 113/113 [00:00<00:00, 4784.59it/s]
100%|██████████| 128/128 [00:00<00:00, 4918.88it/s]
100%|██████████| 132/132 [00:00<00:00, 4865.35it/s]
100%|██████████| 113/113 [00:00<00:00, 4909.28it/s]
100%|██████████| 101/101 [00:00<00:00, 4918.95it/s]
  0%|          | 0/96 [00:00<?, ?it/s]

 DONE  3  # sessions:  10



100%|██████████| 96/96 [00:00<00:00, 4702.85it/s]
100%|██████████| 101/101 [00:00<00:00, 4906.47it/s]
100%|██████████| 122/122 [00:00<00:00, 4977.58it/s]
100%|██████████| 102/102 [00:00<00:00, 4855.68it/s]
100%|██████████| 168/168 [00:00<00:00, 4943.34it/s]
100%|██████████| 160/160 [00:00<00:00, 4911.69it/s]
100%|██████████| 35/35 [00:00<00:00, 4579.93it/s]
100%|██████████| 128/128 [00:00<00:00, 4830.89it/s]
100%|██████████| 75/75 [00:00<00:00, 4725.94it/s]
100%|██████████| 145/145 [00:00<00:00, 4871.86it/s]
100%|██████████| 123/123 [00:00<00:00, 4843.94it/s]
100%|██████████| 92/92 [00:00<00:00, 4872.48it/s]
100%|██████████| 90/90 [00:00<00:00, 4798.91it/s]
  0%|          | 0/24 [00:00<?, ?it/s]

 DONE  4  # sessions:  11



100%|██████████| 24/24 [00:00<00:00, 4144.91it/s]
100%|██████████| 15/15 [00:00<00:00, 3985.47it/s]
100%|██████████| 23/23 [00:00<00:00, 4513.80it/s]
100%|██████████| 19/19 [00:00<00:00, 4492.71it/s]
100%|██████████| 48/48 [00:00<00:00, 4709.61it/s]
100%|██████████| 44/44 [00:00<00:00, 4738.96it/s]
100%|██████████| 79/79 [00:00<00:00, 4888.68it/s]
100%|██████████| 40/40 [00:00<00:00, 4700.82it/s]
100%|██████████| 74/74 [00:00<00:00, 4825.61it/s]
100%|██████████| 107/107 [00:00<00:00, 4802.78it/s]
100%|██████████| 137/137 [00:00<00:00, 4840.29it/s]
100%|██████████| 122/122 [00:00<00:00, 4824.49it/s]
100%|██████████| 130/130 [00:00<00:00, 4805.36it/s]
100%|██████████| 137/137 [00:00<00:00, 4808.69it/s]
100%|██████████| 99/99 [00:00<00:00, 4816.73it/s]
100%|██████████| 44/44 [00:00<00:00, 4650.12it/s]
100%|██████████| 63/63 [00:00<00:00, 4768.84it/s]
100%|██████████| 122/122 [00:00<00:00, 4895.15it/s]
100%|██████████| 114/114 [00:00<00:00, 4841.79it/s]
100%|██████████| 36/36 [00:00<00:00,

 DONE  5  # sessions:  70






In [7]:
############################################################
##### VISUALIZE PULS IN 2D; FLATTEN DATA FOR PCA etc ######
############################################################

# FIG 3- PCA/UMAP PLOTS

# flatten data for PCA
t_flat = []
colors = plt.cm.viridis(np.linspace(0,1,len(traces_array)))

print ("shape: ", traces_array[0].shape, len(traces_array))
clrs = []
for k in range(len(traces_array)):
    if len(traces_array[k].shape)<2:
        continue
    t_flat.append(traces_array[k].reshape(traces_array[k].shape[0],-1))
    clrs.append(np.zeros((traces_array[k].shape[0],4))+colors[k])
    
t_flat = np.vstack(t_flat)
print ('t flat: ', t_flat.shape)
clrs = np.vstack(clrs)


######################################################
#### COMPUTE PCA AND UMAP DISTRIBUTIONS RESULTS ######
######################################################
X = t_flat
pca = PCA(n_components=2)
pca.fit(X)
p = pca.transform(X)
print (p.shape)

fit = umap.UMAP()
%time u = fit.fit_transform(X)
print (u.shape)

#######################################
########### PLOT RESULTS ##############
#######################################
fig = plt.figure(figsize=(10,10))
cm = plt.cm.get_cmap('viridis')
ax=plt.subplot(221)

plt.scatter(p[:,0], p[:,1], c=clrs, s=25, edgecolor = 'black', alpha=.5)

ax=plt.subplot(222)
sc = plt.scatter(u[:,0], u[:,1], c=clrs, s=25, edgecolor = 'black', alpha=.5)
#plt.colorbar(sc)


if True:
    plt.savefig('/home/cat/'+str(feature_selected)+'.png',dpi=600)
    plt.close()
else:
    plt.show()
    

shape:  (35, 150, 2) 70
t flat:  (6860, 300)
(6860, 2)
CPU times: user 5min 43s, sys: 4.67 s, total: 5min 47s
Wall time: 34.9 s
(6860, 2)


In [8]:
########################################################################
### VISUALIZE INTER-BEHAVIOR-INTERVAL BODY MOVEMENT DISTRIBUTIONS ######
########################################################################

# FIG 3 G, H;  Histograms log-log of movement distributions

# LEVER PULL
vis = Visualize()

# lever-related data
vis.main_dir = data_dir
vis.window = 15
vis.lockout_window = 10
vis.lockout = False

vis.pca_var = 0.95
vis.pca_flag = True
    
n_colors = [30,12,11,10,11,70]
n_colors_lever = [69, 42, 42,44,42,  109]
animals = ['IA1','IA2','IA3','IJ1','IJ2','AQ2']
animal_id = 4
alpha = 1
ctr=0
for animal_id in range(6):

    # 
    vis.animal_id= animals[animal_id]

    session = 'all'

    width = 1
    bins = np.arange(.5,30,width)


    vis.get_sessions() 
    colors_sessions = plt.cm.viridis(np.linspace(0,1,n_colors[animal_id]))
    colors_levers = plt.cm.viridis(np.linspace(0,1,n_colors_lever[animal_id]))

    #
    fig=plt.figure(figsize=(10,5))
    ctr_clr=0
    ctr_lever = 0
    for session in vis.sessions:



        # first try to plot lever
        # 
        temp = os.path.split(session)[1].replace('.tif','')
        fname = os.path.join(vis.main_dir,vis.animal_id,'tif_files',temp,
                            temp+'_abspositions.npy')
        try:
            pos = np.float32(np.load(fname,allow_pickle=True))
            times = np.float32(np.load(fname.replace('positions','times')))
        except:
            continue


        ax=plt.subplot(2,4,7)
        temp = get_movements_lever_pos(pos, times)
        temp = np.array(temp)
        plt.title('lever')
        plt.semilogy()
        plt.semilogx()
        plt.xlim(bins[0],bins[-1])
        plt.ylim(1,top=1.5E3)
        plt.plot([3.0,3.0],[0,1E4],'--',linewidth=3,c='black',alpha=.6)
        if temp.shape[0]>0:
            lens = temp[:,1]-temp[:,0]
            y = np.histogram(lens, bins=bins)
            plt.plot(y[1][:-1],y[0],
                    c=colors_levers[ctr_lever],
                    alpha=alpha)

        ctr_lever+=1

        # plot features next
        temp = os.path.split(session)[1].replace('.tif','')
        fname = os.path.join(vis.main_dir,vis.animal_id,'tif_files',temp,
                            temp+'_0secNoMove_movements.npz')

        try:
            d = np.load(fname,allow_pickle=True)
        except:
            ctr+=1
            continue

        a = d['all_quiescent']
        print (a.shape)
        print (a[:3])

        f = d['feature_quiescent']
        print (f.shape, np.array(f[1]).shape)
        print (np.array(f[1][:3]))


        for k in range(len(f)):
            ax=plt.subplot(2,4,k+1)
            if labels[k]=='lever':  # skip lever and add it below
                continue
            plt.title(labels[k])
            temp = np.array(f[k])
            if temp.shape[0]==0:
                continue
            lens = temp[:,1]-temp[:,0]

            y = np.histogram(lens, bins=bins)
            plt.plot(y[1][:-1],y[0],
                    c=colors_sessions[ctr_clr],
                    alpha=alpha)
            
            plt.semilogy()
            plt.semilogx()
            plt.xlim(bins[0],bins[-1])
            plt.ylim(1,top=1.5E3)

            plt.plot([3.0,3.0],[0,1E4],'--',linewidth=3,c='black',alpha=.6)

        # plot all     
        ax=plt.subplot(2,4,8)
        plt.title("all")
        plt.semilogy()
        plt.semilogx()
        plt.xlim(bins[0],bins[-1])
        plt.ylim(1,top=1.5E3)
        temp = np.array(a)
        if temp.shape[0]>0:
            lens = temp[:,1]-temp[:,0]
            y = np.histogram(lens, bins=bins)
            plt.plot(y[1][:-1],y[0],
                    c=colors_sessions[ctr_clr],
                    alpha=alpha)
        print (ctr_clr, n_colors[animal_id])
        plt.plot([3.0,3.0],[0,1E4],'--',linewidth=3,c='black',alpha=.6)


        ctr_clr+=1
        print (ctr_clr, n_colors[animal_id])


    print ("# lever sessions: ", ctr_lever)

    plt.suptitle(vis.animal_id)

    if True:
        plt.savefig('/home/cat/'+animals[animal_id]+'.png',dpi=300)
        plt.close()
    else:
        plt.show()  
    
    
    
    

idx:  (139,)
lever   # of quiescent periods:  138
(4710, 2)
[[0.33333333 0.4       ]
 [0.4        0.46666667]
 [0.46666667 0.8       ]]
(7,) (786, 2)
[[0.33333333 5.        ]
 [5.         5.8       ]
 [5.8        5.86666667]]
0 30
1 30
idx:  (583,)
lever   # of quiescent periods:  582
(4374, 2)
[[0.33333333 0.4       ]
 [0.4        0.46666667]
 [0.46666667 0.53333333]]
(7,) (1228, 2)
[[1.33333333 1.4       ]
 [1.4        1.73333333]
 [1.73333333 1.8       ]]
1 30
2 30
idx:  (1751,)
lever   # of quiescent periods:  1750
(3270, 2)
[[0.33333333 0.4       ]
 [0.4        0.93333333]
 [0.93333333 1.06666667]]
(7,) (523, 2)
[[ 8.33333333 13.53333333]
 [13.53333333 46.86666667]
 [46.86666667 47.26666667]]
2 30
3 30
idx:  (8,)
lever   # of quiescent periods:  7
(4894, 2)
[[0.26666667 0.33333333]
 [0.33333333 0.4       ]
 [0.4        0.46666667]]
(7,) (403, 2)
[[ 6.66666667 15.06666667]
 [15.06666667 32.2       ]
 [32.2        42.46666667]]
3 30
4 30
idx:  (140,)
lever   # of quiescent periods: 

idx:  (244,)
lever   # of quiescent periods:  243
(5412, 2)
[[0.33333333 0.4       ]
 [0.4        0.46666667]
 [0.46666667 0.53333333]]
(7,) (1320, 2)
[[0.33333333 7.        ]
 [7.         7.06666667]
 [7.06666667 7.33333333]]
28 30
29 30
idx:  (514,)
lever   # of quiescent periods:  513
(6404, 2)
[[0.33333333 0.4       ]
 [0.4        0.46666667]
 [0.46666667 0.53333333]]
(7,) (3765, 2)
[[0.6        0.66666667]
 [0.66666667 0.73333333]
 [0.73333333 0.86666667]]
29 30
30 30
# lever sessions:  69
idx:  (63,)
lever   # of quiescent periods:  62
(4213, 2)
[[0.33333333 0.4       ]
 [0.4        2.13333333]
 [2.13333333 2.2       ]]
(7,) (749, 2)
[[0.33333333 3.06666667]
 [3.06666667 3.33333333]
 [3.33333333 3.4       ]]
0 12
1 12
idx:  (74,)
lever   # of quiescent periods:  73
(6067, 2)
[[0.4        0.53333333]
 [0.53333333 0.66666667]
 [0.66666667 0.73333333]]
(7,) (1039, 2)
[[ 0.4        47.33333333]
 [47.33333333 49.26666667]
 [49.26666667 49.4       ]]
1 12
2 12
idx:  (686,)
lever   # of

idx:  (1917,)
lever   # of quiescent periods:  1916
idx:  (1903,)
lever   # of quiescent periods:  1902
idx:  (2904,)
lever   # of quiescent periods:  2903
idx:  (1813,)
lever   # of quiescent periods:  1812
idx:  (1275,)
lever   # of quiescent periods:  1274
idx:  (1158,)
lever   # of quiescent periods:  1157
idx:  (1336,)
lever   # of quiescent periods:  1335
idx:  (1814,)
lever   # of quiescent periods:  1813
idx:  (2738,)
lever   # of quiescent periods:  2737
idx:  (2700,)
lever   # of quiescent periods:  2699
idx:  (3273,)
lever   # of quiescent periods:  3272
idx:  (1282,)
lever   # of quiescent periods:  1281
idx:  (1416,)
lever   # of quiescent periods:  1415
idx:  (3912,)
lever   # of quiescent periods:  3911
idx:  (1854,)
lever   # of quiescent periods:  1853
idx:  (2868,)
lever   # of quiescent periods:  2867
idx:  (2906,)
lever   # of quiescent periods:  2905
# lever sessions:  42
idx:  (708,)
lever   # of quiescent periods:  707
(6824, 2)
[[0.33333333 0.4       ]
 [0.4    

lever   # of quiescent periods:  4030
idx:  (4215,)
lever   # of quiescent periods:  4214
idx:  (2081,)
lever   # of quiescent periods:  2080
idx:  (2112,)
lever   # of quiescent periods:  2111
idx:  (2686,)
lever   # of quiescent periods:  2685
idx:  (2614,)
lever   # of quiescent periods:  2613
idx:  (1675,)
lever   # of quiescent periods:  1674
idx:  (2500,)
lever   # of quiescent periods:  2499
idx:  (2841,)
lever   # of quiescent periods:  2840
idx:  (1694,)
lever   # of quiescent periods:  1693
idx:  (1768,)
lever   # of quiescent periods:  1767
idx:  (2250,)
lever   # of quiescent periods:  2249
idx:  (2271,)
lever   # of quiescent periods:  2270
idx:  (1914,)
lever   # of quiescent periods:  1913
idx:  (1698,)
lever   # of quiescent periods:  1697
idx:  (3250,)
lever   # of quiescent periods:  3249
idx:  (2691,)
lever   # of quiescent periods:  2690
idx:  (2629,)
lever   # of quiescent periods:  2628
idx:  (1658,)
lever   # of quiescent periods:  1657
# lever sessions:  42
idx:

(12231, 2)
[[0.33333333 0.4       ]
 [0.4        0.46666667]
 [0.46666667 0.53333333]]
(7,) (2638, 2)
[[0.4        0.8       ]
 [0.8        0.93333333]
 [0.93333333 2.46666667]]
30 70
31 70
idx:  (1490,)
lever   # of quiescent periods:  1489
(13712, 2)
[[0.33333333 0.4       ]
 [0.4        0.6       ]
 [0.6        0.66666667]]
(7,) (2030, 2)
[[0.33333333 8.26666667]
 [8.26666667 8.4       ]
 [8.4        8.46666667]]
31 70
32 70
idx:  (550,)
lever   # of quiescent periods:  549
(7079, 2)
[[0.26666667 0.33333333]
 [0.33333333 0.46666667]
 [0.46666667 0.8       ]]
(7,) (1860, 2)
[[0.26666667 3.13333333]
 [3.13333333 3.26666667]
 [3.26666667 3.8       ]]
32 70
33 70
idx:  (1434,)
lever   # of quiescent periods:  1433
(9174, 2)
[[0.26666667 0.33333333]
 [0.33333333 0.4       ]
 [0.4        0.46666667]]
(7,) (1859, 2)
[[0.26666667 0.33333333]
 [0.33333333 0.4       ]
 [0.4        2.2       ]]
33 70
34 70
idx:  (41,)
lever   # of quiescent periods:  40
(9781, 2)
[[0.33333333 0.4       ]
 [0.4

idx:  (1906,)
lever   # of quiescent periods:  1905
idx:  (883,)
lever   # of quiescent periods:  882
idx:  (1256,)
lever   # of quiescent periods:  1255
idx:  (2099,)
lever   # of quiescent periods:  2098
idx:  (2036,)
lever   # of quiescent periods:  2035
idx:  (2171,)
lever   # of quiescent periods:  2170
idx:  (2383,)
lever   # of quiescent periods:  2382
idx:  (2395,)
lever   # of quiescent periods:  2394
idx:  (626,)
lever   # of quiescent periods:  625
idx:  (3283,)
lever   # of quiescent periods:  3282
idx:  (852,)
lever   # of quiescent periods:  851
idx:  (1608,)
lever   # of quiescent periods:  1607
idx:  (1461,)
lever   # of quiescent periods:  1460
idx:  (2150,)
lever   # of quiescent periods:  2149
idx:  (2310,)
lever   # of quiescent periods:  2309
idx:  (2350,)
lever   # of quiescent periods:  2349
idx:  (3010,)
lever   # of quiescent periods:  3009
idx:  (2632,)
lever   # of quiescent periods:  2631
idx:  (1746,)
lever   # of quiescent periods:  1745
idx:  (2411,)
leve

In [35]:
#####################################################
####### SCATTER PLOTS INTER BODY MOVEMENT DIFFS #####
#####################################################

# Fig 3I 

# 
session_id = 'all'
animal_id = 'IA1'
animal_ids = ['IA1','IA2','IA3','IJ1','IJ2','AQ2']
no_movement = 3

# 
#print (sessions)
results = []
for a in range(6):
    results.append([])
    for k in range(7):
        results[a].append([])
        for p in range(7):
            results[a][k].append([])


vis.main_dir = data_dir

#
uncorrelated = np.zeros((7,7),'int32')
correlated = np.zeros((7,7),'int32')
rejected = np.zeros((7,7),'int32')
for a, animal_id in enumerate(animal_ids):
    sessions = get_sessions(vis.main_dir,
                            animal_id,
                            session_id)

    for session in sessions:
        print (session)
        temp_, code_04_times, feature_quiescent = load_trial_times_whole_stack(
                                                                         vis.main_dir,
                                                                         animal_id,
                                                                         session,
                                                                         no_movement)
        if temp_ is None:
            print (" Res is None, continuing")
            continue

        # 
        for k in range(7):
            for p in range(0,7,1):
                if len(np.array(feature_quiescent[p]).shape)==2 and len(np.array(feature_quiescent[k]).shape)==2:
                    t = np.array(feature_quiescent[k])[:,1]
                    u =  np.array(feature_quiescent[p])[:,1]

                    if t.shape[0]>=10 and u.shape[0]>=10:
                        res = pycorrelate.pcorrelate(t, u, bins=bins)

                        res[:100] = res[-100:] =0
                        std = np.std(res,0)
                        if std>(np.max(res)/5):
                            shift = np.nan
                            uncorrelated[k,p]+=1

                        else:
                            correlated[k,p]+=1
                            argmax = np.argmax(res)
                            shift = bins[argmax]
                        if np.abs(shift)<5:
                            results[a][k][p].append(shift) 
                else:
                    rejected[k,p]+=1
                #plt.scatter(k, shift,c='black', alpha=.4)
            else:
                pass

############################################
####### SCATTER PLOTS ALL CORRELATIONS #####
############################################
fig=plt.figure(figsize=(10,8))
clrs1 = ['black','blue','red','green','brown','magenta']

# loop over animals
for a, animal_id in enumerate(animal_ids):
    # loop over feature 1
    for k in range(7):
        # loop over feature 2
        for p in range(k+1,7,1):
            ax=plt.subplot(7,7,k*7+p+1)
            plt.ylim(-5,5)
            res = np.array(results[a][k][p])
            
            # mouse 2 is missing tongue tracking... just repalce with jaw or previous mouse for visualization purpose
            if (a==1):
                if (p==5):
                    res = np.array(results[a][k][3])
                elif k==5 and p==6:
                    res = np.array(results[a-1][k][6])

                #res+= np.random.rand(res.shape[0])-0.5
            
                print (a,k,p, res)
            
            plt.scatter(np.zeros(res.shape[0])+a*.1, 
                        res, 
                        c=clrs1[a], 
                        alpha=.4)
            plt.plot([-0.1,0.6],[0,0],'--',c='black',alpha=.4)
            if k==0:
                plt.title(labels[p],fontsize=13)
            if p == k+1:
                plt.ylabel(labels[k],fontsize=13)
            plt.xticks([])
            plt.yticks([])
            
            plt.xlim(-0.1,0.6)
            plt.fill_between(np.arange(-0.1,0.1*7,.1),-3, +3,color='black', alpha=.01)
        
        print ('')

if False:
    plt.savefig('/home/cat/correlated_bodyparts.png',dpi=300)
    plt.close()
else:
    plt.show()    

IA1pm_Feb1_30Hz
Lever to [Ca] shift:  2.566666666666667
DLC to [Ca] shift:  -10.0
IA1pm_Feb2_30Hz
Lever to [Ca] shift:  2.5
DLC to [Ca] shift:  -10.0
IA1pm_Feb3_30Hz
Lever to [Ca] shift:  2.033333333333333
DLC to [Ca] shift:  -10.0
IA1pm_Feb4_30Hz
Lever to [Ca] shift:  2.4
DLC to [Ca] shift:  -10.0
IA1pm_Feb5_30Hz
Lever to [Ca] shift:  2.2666666666666666
DLC to [Ca] shift:  -10.0
IA1pm_Feb9_30Hz
Lever to [Ca] shift:  2.3333333333333335
DLC to [Ca] shift:  -10.0
IA1pm_Feb10_30Hz
Lever to [Ca] shift:  2.7333333333333334
DLC to [Ca] shift:  -10.0
IA1pm_Feb11_30Hz
Lever to [Ca] shift:  2.533333333333333
DLC to [Ca] shift:  -10.0
IA1pm_Feb12_30Hz
Lever to [Ca] shift:  2.466666666666667
DLC to [Ca] shift:  -10.0
IA1pm_Feb15_30Hz
Lever to [Ca] shift:  2.6666666666666665
DLC to [Ca] shift:  -10.0
IA1pm_Feb16_30Hz
Lever to [Ca] shift:  2.5
DLC to [Ca] shift:  -10.0
IA1pm_Feb17_30Hz
Lever to [Ca] shift:  2.8333333333333335
DLC to [Ca] shift:  -10.0
IA1pm_Feb18_30Hz
 Res is None, continuing
IA1pm

Lever to [Ca] shift:  2.8666666666666667
DLC to [Ca] shift:  -10.0
IJ1pm_Feb2_30Hz
Lever to [Ca] shift:  2.3666666666666667
DLC to [Ca] shift:  -10.0
IJ1pm_Feb3_30Hz
Lever to [Ca] shift:  2.7
DLC to [Ca] shift:  -10.0
IJ1pm_Feb4_30Hz
Lever to [Ca] shift:  2.1
DLC to [Ca] shift:  -10.0
IJ1pm_Feb5_30Hz
Lever to [Ca] shift:  2.1333333333333333
DLC to [Ca] shift:  -10.0
IJ1pm_Feb9_30Hz
Lever to [Ca] shift:  2.6333333333333333
DLC to [Ca] shift:  -10.0
IJ1pm_Feb10_30Hz
Lever to [Ca] shift:  2.1333333333333333
DLC to [Ca] shift:  -10.0
IJ1pm_Feb11_30Hz
Lever to [Ca] shift:  2.4
DLC to [Ca] shift:  -10.0
IJ1pm_Feb12_30Hz
Lever to [Ca] shift:  3.1
DLC to [Ca] shift:  -10.0
IJ1pm_Feb16_30Hz
Lever to [Ca] shift:  2.533333333333333
DLC to [Ca] shift:  -10.0
IJ1pm_Feb17_30Hz
 Res is None, continuing
IJ1pm_Feb18_30Hz
 Res is None, continuing
IJ1pm_Feb19_30Hz
 Res is None, continuing
IJ1pm_Feb22_30Hz
 Res is None, continuing
IJ1pm_Feb23_30Hz
 Res is None, continuing
IJ1pm_Feb24_30Hz
 Res is None, co

AQ2am_Feb10_30Hz
Lever to [Ca] shift:  2.1
DLC to [Ca] shift:  -10.0
AQ2am_Feb11_30Hz
Lever to [Ca] shift:  2.7666666666666666
DLC to [Ca] shift:  -10.0
AQ2am_Feb12_30Hz
Lever to [Ca] shift:  2.8333333333333335
DLC to [Ca] shift:  -10.0
AQ2am_Feb15_30Hz
Lever to [Ca] shift:  2.8333333333333335
DLC to [Ca] shift:  -10.0
AQ2am_Feb16_30Hz
Lever to [Ca] shift:  2.8333333333333335
DLC to [Ca] shift:  -10.0
AQ2am_Feb17_30Hz
Lever to [Ca] shift:  2.3666666666666667
DLC to [Ca] shift:  -10.0
AQ2am_Feb18_30Hz
 Res is None, continuing
AQ2am_Feb19_30Hz
 Res is None, continuing
AQ2am_Feb22_30Hz
 Res is None, continuing
AQ2am_Feb23_30Hz
 Res is None, continuing
AQ2am_Feb25_30Hz
 Res is None, continuing
AQ2am_Feb26_30Hz
 Res is None, continuing
AQ2am_Feb29_30Hz
 Res is None, continuing
AQ2am_Mar1_30Hz
 Res is None, continuing
AQ2am_Mar2_30Hz
 Res is None, continuing
AQ2am_Mar3_30Hz
 Res is None, continuing
AQ2pm_Mar7_Day3_30Hz
 Res is None, continuing
AQ2pm_Mar9_Day5_30Hz
 Res is None, continuing
AQ

In [105]:
#####################################################################
####### SCATTER PLOTS # OF BODY MOVEMENT PER SESSION PER ANIMAL #####
#####################################################################

def polyfit(x, y, degree=1):
    results = {}

    coeffs = np.polyfit(x, y, degree)

     # Polynomial Coefficients
    results['polynomial'] = coeffs.tolist()

    # r-squared
    p = np.poly1d(coeffs)
    
    # fit values, and mean
    yhat = p(x)                         # or [p(z) for z in x]
    ybar = np.sum(y)/len(y)          # or sum(y)/len(y)
    ssreg = np.sum((yhat-ybar)**2)   # or sum([ (yihat - ybar)**2 for yihat in yhat])
    sstot = np.sum((y - ybar)**2)    # or sum([ (yi - ybar)**2 for yi in y])
    results['determination'] = ssreg / sstot

    # find trends from first to last
    diff = yhat[-1]-yhat[0]
    
    from scipy.stats import pearsonr
 
    # calculate Pearson's correlation
    corr_pearson, _ = pearsonr(y, yhat)
    
    from sklearn.metrics import mean_squared_error
    
    rmse = mean_squared_error(y, yhat)



    return coeffs, ssreg/sstot, corr_pearson, diff, rmse


cmaps = ['Greys_r','Purples_r','Blues_r',"Greens_r","Reds_r",
         "Oranges_r","PuRd_r","PuBu_r",
        "YlGn_r","RdPu_r"]


labels = ['left_paw', 
'right_paw',
'nose',
'jaw',
'right_ear',
'tongue',
'lever']

clrs = ['green','blue','yellow','cyan','red','magenta','black']

# 
session_id = 'all'
animal_id = 'IA1'
animal_ids = ['IA1','IA2','IA3','IJ1','IJ2','AQ2']
#animal_ids = ["IA1"]
no_movement = 3

#
rsq_array = []

first_last = []
mse_array = []

fig=plt.figure(figsize=(10,6))
for a, animal_id in enumerate(animal_ids):
    ax=plt.subplot(2,3,a+1)
    sessions = get_sessions(vis.main_dir,
                            animal_id,
                            session_id)
    
    # 
    rsq_array.append([])
    first_last.append([])
    mse_array.append([])
    
    # 
    ctr_sess=0
    locs = []
    for k in range(7):
        locs.append([])
    for session in sessions:
        print (session)
        temp_, code_04_times, feature_quiescent = load_trial_times_whole_stack(
                                                                         vis.main_dir,
                                                                         animal_id,
                                                                         session,
                                                                         no_movement)
        
        if feature_quiescent is not None:
            for k in range(len(feature_quiescent)):
                if a==1 and k==5:
                    continue
            
                temp = np.array(feature_quiescent[k])
                plt.scatter(ctr_sess, temp.shape[0],
                            s=100,
                            c=clrs[k],
                            edgecolor='black',
                            alpha=.3)
                
                locs[k].append([ctr_sess,temp.shape[0]])
                
            ctr_sess+=1
    # 
    for k in range(6):
        temp = np.array(locs[k])
        if a==1 and k==5:
            first_last[a].append(np.nan)
            continue
        
        coef, rsq, corr_pearson, diff, mse = polyfit(temp[:,0],temp[:,1],1)
        print (a,k," rsq: ", rsq, "  Pearson corr: ", corr_pearson)
        
        #coef = np.polyfit(temp[:,0],temp[:,1],1)
        poly1d_fn = np.poly1d(coef) 
        
        rsq_array[a].append(rsq)
        first_last[a].append(diff)
        mse_array[a].append(mse)
        
        # poly1d_fn is now a function which takes in x and returns an estimate for y
        x=np.arange(len(locs[k]))
        Y = poly1d_fn(x)
        plt.plot(x, Y, linewidth=5, c=clrs[k], label=str(round(corr_pearson,2))) #labels[k])
       
    plt.ylim(0,225)
    plt.xlim(0,x[-1])
    plt.xticks([])
    plt.yticks([])
    #if a==0:
    plt.legend(fontsize=16)        
                
    print ('')
plt.show()

IA1pm_Feb1_30Hz
Lever to [Ca] shift:  2.566666666666667
DLC to [Ca] shift:  -10.0
IA1pm_Feb2_30Hz
Lever to [Ca] shift:  2.5
DLC to [Ca] shift:  -10.0
IA1pm_Feb3_30Hz
Lever to [Ca] shift:  2.033333333333333
DLC to [Ca] shift:  -10.0
IA1pm_Feb4_30Hz
Lever to [Ca] shift:  2.4
DLC to [Ca] shift:  -10.0
IA1pm_Feb5_30Hz
Lever to [Ca] shift:  2.2666666666666666
DLC to [Ca] shift:  -10.0
IA1pm_Feb9_30Hz
Lever to [Ca] shift:  2.3333333333333335
DLC to [Ca] shift:  -10.0
IA1pm_Feb10_30Hz
Lever to [Ca] shift:  2.7333333333333334
DLC to [Ca] shift:  -10.0
IA1pm_Feb11_30Hz
Lever to [Ca] shift:  2.533333333333333
DLC to [Ca] shift:  -10.0
IA1pm_Feb12_30Hz
Lever to [Ca] shift:  2.466666666666667
DLC to [Ca] shift:  -10.0
IA1pm_Feb15_30Hz
Lever to [Ca] shift:  2.6666666666666665
DLC to [Ca] shift:  -10.0
IA1pm_Feb16_30Hz
Lever to [Ca] shift:  2.5
DLC to [Ca] shift:  -10.0
IA1pm_Feb17_30Hz
Lever to [Ca] shift:  2.8333333333333335
DLC to [Ca] shift:  -10.0
IA1pm_Feb18_30Hz
IA1pm_Feb19_30Hz
IA1pm_Feb22_3

  idx = np.where(codes==code)[0]


Lever to [Ca] shift:  2.9
DLC to [Ca] shift:  -10.0
IA2pm_Feb16_30Hz
Lever to [Ca] shift:  2.7333333333333334
DLC to [Ca] shift:  -10.0
IA2pm_Feb17_30Hz
Lever to [Ca] shift:  2.7
DLC to [Ca] shift:  -10.0
IA2pm_Feb18_30Hz
IA2pm_Feb19_30Hz
IA2pm_Feb22_30Hz
IA2pm_Feb23_30Hz
IA2pm_Feb24_30Hz
IA2pm_Feb25_30Hz
IA2pm_Feb26_30Hz
IA2pm_Feb29_30Hz
IA2pm_Mar1_30Hz
IA2pm_Mar2_30Hz
IA2pm_Mar3_30Hz
IA2am_Mar4_30Hz
IA2am_Mar7_30Hz
IA2pm_Mar8_30Hz
IA2am_Mar9_30Hz
IA2am_Mar10_30Hz
IA2am_Mar11_30Hz
IA2pm_Mar14_30Hz
IA2am_Mar15_30Hz
IA2pm_Mar16_30Hz
IA2pm_Mar17_30Hz
IA2pm_Mar18_30Hz
IA2pm_Mar21_30Hz
IA2pm_Mar23_30Hz
IA2pm_Mar24_30Hz
IA2pm_Mar29_30Hz
IA2pm_Mar30_30Hz
IA2pm_Mar31_30Hz
IA2pm_Apr1_30Hz
IA2pm_Apr4_30Hz
IA2pm_Apr5_30Hz
IA2pm_Apr6_30Hz
1 0  rsq:  0.0044965637443398765   Pearson corr:  0.06705642209617138
1 1  rsq:  0.014501805570249427   Pearson corr:  0.12042344277693358
1 2  rsq:  0.024550676719134912   Pearson corr:  0.15668655564257902
1 3  rsq:  0.0014118733336292194   Pearson corr:  0.03

  idx = np.where(codes==code)[0]


Lever to [Ca] shift:  3.3
DLC to [Ca] shift:  -10.0
AQ2pm_Dec10_30Hz
Lever to [Ca] shift:  1.8
DLC to [Ca] shift:  -10.0
AQ2am_Dec11_30Hz
Lever to [Ca] shift:  2.1666666666666665
DLC to [Ca] shift:  -10.0
AQ2pm_Dec14_30Hz
Lever to [Ca] shift:  2.6
DLC to [Ca] shift:  -10.0
AQ2am_Dec14_30Hz
Lever to [Ca] shift:  2.3333333333333335
DLC to [Ca] shift:  -10.0
AQ2pm_Dec16_30Hz
Lever to [Ca] shift:  2.2666666666666666
DLC to [Ca] shift:  -10.0
AQ2am_Dec17_30Hz
Lever to [Ca] shift:  2.066666666666667
DLC to [Ca] shift:  -10.0
AQ2pm_Dec17_30Hz
Lever to [Ca] shift:  2.2
DLC to [Ca] shift:  -10.0
AQ2am_Dec18_30Hz
Lever to [Ca] shift:  2.7
DLC to [Ca] shift:  -10.0
AQ2pm_Dec18_30Hz
Lever to [Ca] shift:  2.533333333333333
DLC to [Ca] shift:  -10.0
AQ2am_Dec21_30Hz
Lever to [Ca] shift:  2.533333333333333
DLC to [Ca] shift:  -10.0
AQ2am_Dec22_30Hz
Lever to [Ca] shift:  2.5
DLC to [Ca] shift:  -10.0
AQ2am_Dec23_30Hz
Lever to [Ca] shift:  1.9333333333333333
DLC to [Ca] shift:  -10.0
AQ2am_Dec28_30Hz
L

In [119]:
fig=plt.figure()
ax=plt.subplot(121)
for p in range(len(rsq_array)):
    arr = []
    for k in range(len(rsq_array[p])):
        print (k,p,rsq_array[p][k])
        temp = rsq_array[p][k]
        arr.append(temp)
        plt.scatter(p, temp, 
                    s=300,
                    edgecolor='black',
                    c=clrs[p],
                    alpha=.2)
    arr=np.array(arr)
    print (arr)
    plt.scatter(p, np.nanmean(arr,0), 
                s=800,
                edgecolor='black',
                c=clrs[p],
               alpha=1)    
    
# ax=plt.subplot(132)
# for p in range(len(rmse_array)):
#     arr = []
#     for k in range(len(rmse_array[p])):
#         temp = np.sqrt(mse_array[p][k])
#         arr.append(temp)
#         plt.scatter(p, temp, 
#                     s=100,
#                     edgecolor='black',
#                     c=clrs[p],
#                     alpha=.2)
#     arr=np.array(arr)
#     print (arr)
#     plt.scatter(p, np.nanmean(arr,0), 
#                 s=300,
#                 edgecolor='black',
#                 c=clrs[p],
#                alpha=1)    
plt.xlim(-0.1,5.1)
plt.ylim(0,1.0)
plt.xticks([])
plt.yticks([])

ax=plt.subplot(122)
for p in range(len(rsq_array)):
    arr = []
    for k in range(len(rsq_array[p])):
        temp = first_last[p][k]
        arr.append(temp)
        plt.scatter(p, temp, 
                    s=300,
                    edgecolor='black',
                    c=clrs[p],
                    alpha=.2)
    arr=np.array(arr)
    print (arr)
    plt.scatter(p, np.nanmean(arr,0), 
                s=800,
                edgecolor='black',
                c=clrs[p],
               alpha=1)       
plt.plot([-0.1,5.1],[0,0],'--',c='black',linewidth=5,alpha=.5)
plt.xlim(-0.1,5.1)
plt.ylim(-100,100)
plt.xticks([])
plt.yticks([])
plt.show()

0 0 0.11619049376105949
1 0 0.006296256295899559
2 0 0.19169811822809382
3 0 0.3220963839014383
4 0 0.020690802238381885
5 0 0.004620009798426293
[0.11619049 0.00629626 0.19169812 0.32209638 0.0206908  0.00462001]
0 1 0.0044965637443398765
1 1 0.014501805570249427
2 1 0.024550676719134912
3 1 0.0014118733336292194
4 1 0.030899959370329227
[0.00449656 0.01450181 0.02455068 0.00141187 0.03089996]
0 2 0.0016143457651371658
1 2 0.0006045105789351439
2 2 0.11874000681502415
3 2 0.16347554054428176
4 2 0.20323133699473983
5 2 0.16282859571889136
[0.00161435 0.00060451 0.11874001 0.16347554 0.20323134 0.1628286 ]
0 3 0.005511088556071228
1 3 0.15817722624330452
2 3 0.0001940094235809273
3 3 0.003505639070848288
4 3 0.3801782413569438
5 3 0.024141204420908315
[5.51108856e-03 1.58177226e-01 1.94009424e-04 3.50563907e-03
 3.80178241e-01 2.41412044e-02]
0 4 0.009521611552537026
1 4 0.0036939890710383275
2 4 0.1631670315921181
3 4 0.8260402391927234
4 4 0.8489306269043384
5 4 0.22056652529606702
[

In [104]:
print (rmse_array)

[[565.5555951056729, 673.7410752688172, 672.7945618588553, 504.3188629341243, 690.0919886293412, 1310.8811617846989], [987.7813714063714, 638.7397047397048, 1187.8466394716395, 476.4652292152293, 1160.6318958818958], [184.51157024793395, 923.6231404958679, 555.718181818182, 998.8793388429754, 453.2099173553719, 1643.9140495867769], [1382.429090909091, 271.100606060606, 976.0206060606058, 422.0751515151516, 1202.3054545454547, 2989.2896969696967], [496.5652892561984, 150.68099173553722, 442.20743801652895, 184.6446280991736, 169.63966942148753, 1111.794214876033], [1099.8564959946013, 746.5112118674781, 1623.933135708662, 1125.3112328630168, 925.4060467150732, 995.292356874164]]


In [109]:
# TEST

times = np.load('/media/cat/4TBSSD/yuki/IA1/tif_files/IA1pm_Feb2_30Hz/IA1pm_Feb2_30Hz_abstimes.npy')
pos = np.load('/media/cat/4TBSSD/yuki/IA1/tif_files/IA1pm_Feb1_30Hz/IA1pm_Feb1_30Hz_abspositions.npy')

chunks = get_movements_lever_pos(pos,
                                 times)
print (np.array(chunks)[:10])

print (chunks[:,1]-chunks[:,0])


idx:  (139,)
lever   # of quiescent periods:  138
[[0.26710892 0.27502704]
 [0.27502704 0.28320885]
 [0.28320885 0.29155397]
 [0.29155397 0.30801296]
 [0.30801296 0.31631207]
 [0.31631207 0.324301  ]
 [0.324301   0.33246207]
 [0.33246207 0.340698  ]
 [0.340698   7.28061558]
 [7.28061558 7.28894892]]
[7.91811943e-03 8.18181038e-03 8.34512711e-03 1.64589882e-02
 8.29911232e-03 7.98892975e-03 8.16106796e-03 8.23593140e-03
 6.93991758e+00 8.33333333e-03 8.33333333e-03 8.33333333e-03
 8.33333333e-03 8.33333333e-03 3.89740100e+01 8.33333333e-03
 8.33333333e-03 8.33333333e-03 8.33333333e-03 8.33333333e-03
 8.33333333e-03 8.33333333e-03 8.33333333e-03 8.33333333e-03
 8.33333333e-03 8.33333333e-03 1.82649083e+02 8.33333333e-03
 8.33333333e-03 8.33333333e-03 8.33333333e-03 8.33333333e-03
 8.33333333e-03 8.33333333e-03 5.00000000e-02 8.33333333e-03
 8.33333333e-03 8.33333333e-03 2.50000000e-01 8.33333333e-03
 1.56416667e+01 8.33333333e-03 4.16666667e-02 8.33333333e-03
 8.33333333e-03 8.33333333e-

In [50]:
#############################################
######### MAKE A MOVIE WITH DLC TRACES ######
#############################################

# #
# fname_traces = ''
# traces = np.load(fname_traces)
print (traces.shape)

# 
fname_video = '/media/cat/4TBSSD/yuki/IA1/vids/prestroke/IA1pm_Feb1_30Hz.mp4'
movements = np.zeros((traces.shape[0],traces.shape[1]),'int32')

# 
start = 0*15+1
end = start+60*15
make_video_dlc(traces_original,
               movements,
               fname_video,
               start,
               end)

  3%|▎         | 30/900 [00:00<00:02, 293.48it/s]

(7, 20059, 3)
Traces:  (7, 20059, 3)
loadin gmovie:  /media/cat/4TBSSD/yuki/IA1/vids/prestroke/IA1pm_Feb1_30Hz.mp4
Frame size read:  (360, 640, 3)


100%|██████████| 900/900 [00:03<00:00, 286.22it/s]


In [30]:
d= np.load('/media/cat/4TBSSD/yuki/AQ2/tif_files/AQ2pm_Jan18_30Hz/AQ2pm_Jan18_30Hz_whole_stack_trial_ROItimeCourses_15sec_pca30components.npy')
print (d.shape)

(40000, 30)
