In [1]:
import matplotlib
%matplotlib tk
%autosave 180
%load_ext autoreload
%autoreload 2

#
import matplotlib.pyplot as plt
# `display` and `HTML` are part of the public `IPython.display` API; the
# `IPython.core.display` import path for `display` is deprecated.
from IPython.display import display, HTML
# Stretch the classic-notebook page container to the full browser width.
display(HTML("<style>.container { width:100% !important; }</style>"))

# 
import matplotlib.cm as cm

# 
import numpy as np
import os
from tqdm import trange
import parmap
import glob
from sklearn.decomposition import PCA
#import umap

# 
from scipy.io import loadmat
import scipy

# 

Autosaving every 180 seconds


In [323]:
######################################
############ FILTER DATA #############
######################################
import scipy
import scipy.ndimage

class CentreBody():
    """Pre-processing pipeline for multi-animal pose tracks stored as ``.npy`` files.

    Track arrays are indexed as ``[frame, animal, node, coordinate]``.  The three
    stages are chained through file-name suffixes, and every stage skips a file
    whose output already exists on disk::

        <name>_compressed.npy
          -> filter_data()       writes  ..._median_filtered.npy
          -> reject_outliers()   writes  ..._median_filtered_outliers.npy
          -> centre_and_align()  writes  ..._median_filtered_outliers_centre_aligned.npy

    Attributes
    ----------
    root_dir : str
        Directory that holds the ``*_compressed.npy`` input files.
    parallel : bool
        If True the per-file workers are dispatched with ``parmap``; otherwise
        the files are processed one after another in this process.
    n_processes : int
        Number of worker processes used when ``parallel`` is True.
    fnames : list of str
        Input files found by :meth:`get_fnames`.
    node_names : list of str
        Name of every tracked node, in array order.
    feature_ids : numpy.ndarray
        Indices (into ``node_names``) of the body points used downstream:
        nose plus spine1..spine5.
    """

    def __init__(self, root_dir=None, parallel=True, n_processes=8):
        """Set up node metadata; ``root_dir``/``parallel`` may also be assigned later."""

        #
        self.node_names = ['nose',          # 0
                           'lefteye',       # 1
                           'righteye',      # 2
                           'leftear',       # 3
                           'rightear',      # 4
                           'spine1',        # 5
                           'spine2',        # 6
                           'spine3',        # 7
                           'spine4',        # 8
                           'spine5',        # 9
                           'tail1',         # 10
                           'tail2',         # 11
                           'tail3',         # 12
                           'tail4']         # 13

        self.feature_ids = np.array([0,5,6,7,8,9])

        self.root_dir = root_dir
        self.parallel = parallel
        self.n_processes = n_processes
        self.fnames = []

    def get_fnames(self):
        """Populate ``self.fnames`` with every ``*_compressed.npy`` file in ``root_dir``."""

        if self.root_dir is None:
            raise ValueError("root_dir must be set before looking for input files")

        self.fnames = glob.glob(self.root_dir+"/*_compressed.npy")

    def _run_over_files(self, worker, *args):
        """Apply ``worker(fname, *args)`` to every file, in parallel or serially."""

        if self.parallel:
            parmap.map(worker, self.fnames, *args,
                       pm_processes=self.n_processes,
                       pm_pbar=True)
        else:
            for fname in self.fnames:
                worker(fname, *args)

    def filter_data(self):
        """Median-filter every input file (see :meth:`filter_data1`)."""

        print ("  ... median filtering ...")

        # allow calling this stage directly without a prior get_fnames()
        if not self.fnames:
            self.get_fnames()

        self._run_over_files(self.filter_data1)

    def filter_data1(self, fname):
        """Filter one file trace-by-trace and save it as ``*_median_filtered.npy``.

        Nothing is done when the output file already exists.
        """

        fname_out = fname.replace('.npy','_median_filtered.npy')
        if os.path.exists(fname_out):
            return

        data = np.load(fname)

        # every (animal, node, coordinate) trace is filtered independently along axis 0
        data_filtered = data.copy()
        for a in range(data.shape[1]):
            for f in range(data.shape[2]):
                for l in range(data.shape[3]):
                    data_filtered[:,a,f,l] = self.filter_data2(data[:,a,f,l])

        np.save(fname_out, data_filtered)

    def filter_data2(self, x, width=25, max_gap=1000):
        """Forward-fill NaN gaps in a 1-D trace, then median-filter it.

        Parameters
        ----------
        x : 1-D array
            Trace to filter.  It is copied, the caller's array is left untouched.
        width : int
            Size of the median-filter window, in samples.
        max_gap : int
            A NaN is replaced by the last valid sample only if that sample is at
            most ``max_gap`` positions earlier.  NaNs at the very start of the
            trace have no earlier sample and therefore stay NaN.

        Returns
        -------
        numpy.ndarray
            Filtered copy of ``x``.
        """

        x = np.array(x, copy=True)
        n_samples = x.shape[0]
        if n_samples == 0:
            return x

        # index of the most recent non-NaN sample at or before each position
        positions = np.arange(n_samples)
        missing = np.isnan(x)
        last_valid = np.where(missing, 0, positions)
        np.maximum.accumulate(last_valid, out=last_valid)

        # replace data with previous valid value (bounded look-back)
        fillable = missing & ((positions - last_valid) <= max_gap)
        x[fillable] = x[last_valid[fillable]]

        # scipy names the window argument `size`
        x = scipy.ndimage.median_filter(x, size=width)

        return x

    def reject_outliers2(self, x,y,
                        max_dist_pairwise,
                        max_dist_all=100):
        ''' Blank out body points that are too far from their neighbours in the chain.

            Goal: to generate very clean data which has 6 body features well connected
            for downstream analysis.

            The points are treated as an ordered chain.  An interior point is rejected
            when its distance to the previous OR the next point exceeds
            ``max_dist_pairwise``; the first/last point is rejected when its single
            neighbour is too far.  A rejected interior point becomes unreachable for
            the points checked after it, so one rejection propagates down the chain.

            Parameters
            ----------
            x, y : 1-D float arrays
                Coordinates of the chain; rejected entries are set to NaN IN PLACE.
            max_dist_pairwise : float
                Largest accepted distance between consecutive points.
            max_dist_all : float
                Unused; kept so existing calls keep working.

            Returns
            -------
            (x, y) : the same two arrays that were passed in.
        '''

        # imported here so the method does not rely on scipy.spatial having been
        # loaded as a side effect of some other import
        from scipy.spatial.distance import cdist

        n_pts = len(x)
        if n_pts < 2:
            # a single point has no neighbour to be compared against
            return x, y

        pts = np.vstack((x,y)).T
        dists = cdist(pts, pts)

        # first check points inside the chain to ensure both neighbours are close;
        # if they aren't, remove them so other points can't be connected to them.
        idx_far = []
        for k in range(1, n_pts-1):
            if dists[k,k-1]>max_dist_pairwise or dists[k,k+1]>max_dist_pairwise:
                idx_far.append(k)
                dists[:,k] = np.inf

        # check start and end points to ensure they have a nearby neighbour
        if dists[0,1]>max_dist_pairwise:
            idx_far.append(0)

        if dists[n_pts-1,n_pts-2]>max_dist_pairwise:
            idx_far.append(n_pts-1)

        if idx_far:
            x[idx_far] = np.nan
            y[idx_far] = np.nan

        return x, y

    def reject_outliers1(self, fname, feature_ids, max_dist):
        """Run :meth:`reject_outliers2` on every frame/animal of one filtered file.

        Reads ``*_median_filtered.npy`` and writes ``*_median_filtered_outliers.npy``;
        nothing is done when the output file already exists.  Only the nodes listed
        in ``feature_ids`` are examined, all other nodes are copied unchanged.
        """

        fname2 = fname.replace('.npy','_median_filtered.npy')
        fname_out = fname2.replace('.npy','_outliers.npy')

        if os.path.exists(fname_out):
            return

        data = np.load(fname2)
        for f in range(data.shape[0]):
            for k in range(data.shape[1]):
                # fancy indexing returns copies, so the result is written back explicitly
                x = data[f,k,feature_ids,0]
                y = data[f,k,feature_ids,1]

                x, y = self.reject_outliers2(x,y, max_dist)

                data[f,k,feature_ids,0] = x
                data[f,k,feature_ids,1] = y

        np.save(fname_out, data)

    def reject_outliers(self, max_dist=40):
        """Reject badly connected body points in every median-filtered file."""

        print ("  ... rejecting outliers....")

        self.get_fnames()

        self._run_over_files(self.reject_outliers1, self.feature_ids, max_dist)

    def centre_and_align2(self, data, frame, centre_pt=0):
        """Translate one posture so ``centre_pt`` is the origin and rotate it upright.

        After the rotation the second point (row 1) lies on the negative y axis.

        Parameters
        ----------
        data : (n_points, 2) array
            x/y location of every body point; not modified.
        frame : int
            Frame index; unused, kept so existing calls keep working.
        centre_pt : int
            Row that is moved to the origin (default 0, the nose).

        Returns
        -------
        (n_points, 2) numpy.ndarray
            Centred and rotated copy of ``data``.
        """

        # centre the data on the chosen point
        centred = data - data[centre_pt]

        # angle between the +x axis and the 2nd point, mapped onto the -y axis
        t = -np.arctan2(centred[1,1], centred[1,0]) - np.pi/2

        # counter-clockwise rotation by t
        rotmat = np.array([[np.cos(t), -np.sin(t)],
                           [np.sin(t),  np.cos(t)]])

        # apply the rotation to each row
        return (rotmat @ centred.T).T

    def centre_and_align(self):
        """Centre and align the postures of every outlier-rejected file."""

        print ("  ... center and aligning ...")

        self.get_fnames()

        self._run_over_files(self.centre_and_align1, self.feature_ids)

    def centre_and_align1(self, fname,
                         feature_ids):
        """Centre/align every complete posture of one file.

        Reads ``*_median_filtered_outliers.npy`` and writes
        ``*_median_filtered_outliers_centre_aligned.npy`` holding a float32 array of
        shape ``(n_frames, n_animals, len(feature_ids), 2)``.  Postures with any
        missing (NaN) body point are left as NaN.  Nothing is done when the output
        file already exists.
        """

        fname2 = fname.replace('.npy','_median_filtered_outliers.npy')
        fname_out = fname2.replace('.npy','_centre_aligned.npy')

        if os.path.exists(fname_out):
            return

        data = np.load(fname2)

        #
        centre_pt = 0

        features_full = np.full((data.shape[0], data.shape[1], len(feature_ids), 2),
                                np.nan, dtype='float32')

        for f in range(data.shape[0]):

            # loop over each animal
            for k in range(data.shape[1]):

                x = data[f,k,feature_ids,0]
                y = data[f,k,feature_ids,1]

                # only complete postures can be aligned
                if np.isnan(x).any() or np.isnan(y).any():
                    continue

                locs = np.vstack((x,y)).T

                # centre and align data
                locs_aligned = self.centre_and_align2(locs,f,centre_pt)

                if locs_aligned is None or np.isnan(locs_aligned).any():
                    continue

                features_full[f,k] = locs_aligned

        np.save(fname_out, features_full)

            
import sklearn.experimental
from sklearn.experimental import enable_iterative_imputer
from sklearn.impute import IterativeImputer

#print (len(features_array))

import seaborn as sns
import pandas as pd

def evaluate_imputation_error(features_array, animal_id, res, idx_train, idx_test):
    """Violin-plot the per-body-point imputation error on the held-out frames.

    Parameters
    ----------
    features_array : list
        ``features_array[animal_id]`` is a list of ``(n, 6, 2)`` posture arrays
        that are stacked along the first axis.
    animal_id : int
        Which animal's postures to evaluate.
    res : (n_test, 6, 2) array
        Imputed postures, one per entry of ``idx_test``.
    idx_train : array
        Unused; kept so existing calls keep working.
    idx_test : array of int
        Rows of the stacked postures that ``res`` corresponds to.

    A new figure with two panels is created (full range on top, 0-50 zoom below)
    and left as the current pyplot figure.  Points whose error is exactly zero
    (i.e. points that were not changed by the imputation) are excluded.
    """

    truth = np.vstack(features_array[animal_id])[idx_test]

    # Euclidean error of every body point in every test frame: (n_test, n_points)
    dist = np.linalg.norm(truth - res, axis=-1)

    # keep strictly positive errors only (NaN also fails the > 0 test)
    errors = [dist[:, p][dist[:, p] > 0] for p in range(dist.shape[1])]

    # columns hold different numbers of errors; pad with NaN up to the longest
    # one instead of assuming a fixed maximum number of frames
    n_rows = max([len(e) for e in errors] + [1])
    data = np.full((n_rows, len(errors)), np.nan)
    for p, e in enumerate(errors):
        data[:len(e), p] = e

    columns = ['nose','spine1','spine2', 'spine3', 'spine4', 'spine5']
    df = pd.DataFrame(data, columns = columns)

    # plot
    fig, (ax_full, ax_zoom) = plt.subplots(2, 1)

    sns.violinplot(data=df, ax=ax_full)
    ax_full.set_ylim(bottom=0)
    ax_full.set_ylabel(" pixel error")

    ax_zoom.set_title("Zoom",fontsize=20)
    sns.violinplot(data=df, ax=ax_zoom)
    ax_zoom.set_ylabel(" pixel error")
    ax_zoom.set_ylim(0,50)
            
        
# 
def generate_imputation_model(features_array, animal_id, idx=None):
    """Fit an ``IterativeImputer`` on the flattened postures of one animal.

    Parameters
    ----------
    features_array : list
        ``features_array[animal_id]`` is a list of posture arrays that are
        stacked along the first axis and flattened to one row per frame.
    animal_id : int
        Which animal's postures to train on.
    idx : array of int, optional
        Rows (frames) used for fitting.  When None, every frame is used.

    Returns
    -------
    IterativeImputer
        The fitted imputer (``max_iter=10``, ``random_state=0``).
    """

    temp = np.vstack(features_array[animal_id])
    print (temp.shape)

    X_all = temp.reshape(temp.shape[0],-1)
    print ("X_all: ", X_all.shape)

    # fall back to the full data set when no training subset is requested
    X_train = X_all if idx is None else X_all[idx]

    #
    print ("fitting...")
    imp = IterativeImputer(max_iter=10, random_state=0)
    imp.fit(X_train)
    print ("done")

    return imp

def predict_imputation(imp, features_array, animal_id, drops='fixed', idx_test=None, n_drops = 3):
    """Blank out body points in test postures and let ``imp`` fill them back in.

    Parameters
    ----------
    imp : object with a ``transform(X)`` method
        Fitted imputer; receives one flattened posture per row with the dropped
        points set to NaN.
    features_array : list
        ``features_array[animal_id]`` is a list of ``(n, n_points, n_dims)``
        posture arrays that are stacked along the first axis.
    animal_id : int
        Which animal's postures to use.
    drops : str or None
        ``'fixed'`` drops body points 1, 3, 5 (the first ``n_drops`` of them) in
        every frame; any other value drops ``n_drops`` randomly chosen points
        per frame (drawn with ``np.random.choice``).
    idx_test : array of int, optional
        Rows (frames) to predict.  When None, every frame is used.
    n_drops : int
        Number of body points removed per frame.

    Returns
    -------
    res : (n_test, n_points, n_dims) array
        Postures returned by the imputer.
    idx_drop : (n_drops, n_test) int32 array
        Index of every body point that was dropped, per frame.
    """

    #
    temp = np.vstack(features_array[animal_id])
    print (temp.shape)
    n_points, n_dims = temp.shape[1], temp.shape[2]

    X_all = temp.reshape(temp.shape[0],-1)

    # select frames to predict (all of them when no subset is given)
    X_test = X_all if idx_test is None else X_all[idx_test]

    # do drop outs in the test set:
    X_test = X_test.reshape(-1, n_points, n_dims)
    n_test = X_test.shape[0]

    idx_drop = np.zeros((n_drops, n_test),'int32')
    if drops=='fixed':
        fixed_ids = np.array([1,3,5])
        if n_drops > fixed_ids.shape[0]:
            raise ValueError("drops='fixed' supports at most %d dropped points"
                             % fixed_ids.shape[0])
        idx_drop[:] = fixed_ids[:n_drops, None]
    else:
        # independent draw per frame, without replacement
        for k in range(n_test):
            idx_drop[:,k] = np.random.choice(np.arange(n_points),n_drops,replace=False)

    print ("idx drop: ", idx_drop.shape[1:])

    # blank both coordinates of every dropped point
    frame_ids = np.arange(n_test)
    for k in range(n_drops):
        X_test[frame_ids, idx_drop[k]] = np.nan

    # TEST STEP
    X_test = X_test.reshape(n_test,-1)
    res = imp.transform(X_test).reshape(-1, n_points, n_dims)

    return res, idx_drop

    
# Run the three pre-processing stages on one cohort.
# `parallel` and `root_dir` are not set by CentreBody.__init__ and must be
# assigned here before any stage is called.
cb = CentreBody()
cb.parallel = True
# NOTE(review): absolute, machine-specific path -- adjust for your own setup.
cb.root_dir = '/media/cat/1TB/dan/cohort1/slp/'
# collect every "*_compressed.npy" file found in root_dir
cb.get_fnames()

# stage 1: median filter each input file
cb.filter_data()

# stage 2: reject outlying body points (default max_dist=40)
cb.reject_outliers()

# stage 3: centre and align every posture
cb.centre_and_align()

#



  ... median filtering ...


100%|██████████| 23/23 [00:00<00:00, 792.13it/s]

  ... rejecting outliers....



100%|██████████| 23/23 [00:00<00:00, 3217.78it/s]

  ... center and aligning ...



100%|██████████| 23/23 [00:00<00:00, 2067.05it/s]


In [309]:
##############################################
############## STACK ALL THE DATA ############
##############################################

# 
# Recording-name prefixes in the order they are stacked below.
# The names use '-' separators here; the stacking loop replaces them with '_'
# before globbing for the matching files.
# NOTE(review): the time stamps look like 12-hour clock values without an AM/PM
# marker, hence the manual ordering into two blocks -- confirm with the
# acquisition logs.
file_order = [

'2020-3-16_11-56-56-704655',  # daytime block; starts on the correct day
'2020-3-16_12-57-12-418305',
'2020-3-16_01-57-27-327194',
'2020-3-16_02-57-41-995158',
'2020-3-16_03-57-56-902379',
'2020-3-16_04-58-11-998956',
'2020-3-16_05-58-27-193818',
'2020-3-16_06-58-43-678014',
'2020-3-16_07-59-00-362242',
'2020-3-16_08-59-17-534732',
'2020-3-16_09-59-34-731308',
'2020-3-16_10-59-50-448686',

'2020-3-16_12-54-07-193951',  # night-time block; belongs to the previous day though
'2020-3-16_01-54-23-358257',
'2020-3-16_02-54-39-170978',
'2020-3-16_03-54-54-231226',
'2020-3-16_04-55-09-841582',
'2020-3-16_05-55-25-305681',
'2020-3-16_06-55-40-714236',
'2020-3-16_07-55-55-775234',
'2020-3-16_08-56-11-096689',
'2020-3-16_09-56-26-362091',
'2020-3-16_10-56-41-406701',
]

# stack the postures for each animal, keeping complete postures only
N_ANIMALS = 4

# Which processing stage to stack:
#   "centre_aligned"            -> egocentric postures (already reduced to the body features)
#   "median_filtered_outliers"  -> filtered + outlier-rejected tracks, all nodes
#   "median_filtered"           -> filtered tracks only, all nodes
STACK_STAGE = "centre_aligned"

# body features (nose + spine1..spine5) picked from the all-node stages
BODY_FEATURE_IDS = np.array([0,5,6,7,8,9])

features_array = [[] for _ in range(N_ANIMALS)]

#
from tqdm import tqdm
for file in tqdm(file_order):

    # files on disk use '_' where the recording names use '-'; only the recording
    # name is converted so that dashes in root_dir are left alone
    pattern = os.path.join(cb.root_dir, file.replace("-","_")+"*_"+STACK_STAGE+".npy")
    matches = glob.glob(pattern)
    if len(matches)==0:
        raise FileNotFoundError("no file matches: "+pattern)
    fname = matches[0]

    d3 = np.load(fname)
    if STACK_STAGE != "centre_aligned":
        d3 = d3[:,:,BODY_FEATURE_IDS]

    # loop over animals and keep only complete data (i.e. 6 pts)
    for k in range(N_ANIMALS):
        # drop any frame that is missing even a single value for that animal
        has_nan = np.isnan(d3[:,k]).reshape(d3.shape[0],-1).any(axis=1)

        features_array[k].append(d3[~has_nan,k])

        
    #break


100%|██████████| 23/23 [00:02<00:00, 10.78it/s]


In [325]:
#

def plot_imputation_results(features_array, animal_id, idx_test, res, idx_drop):
    """Show 10 randomly chosen test postures: ground truth (blue) vs. imputed (red).

    Parameters
    ----------
    features_array : list
        ``features_array[animal_id]`` is a list of ``(n, 6, 2)`` posture arrays
        that are stacked along the first axis.
    animal_id : int
        Which animal's postures to show.
    idx_test : array of int
        Rows of the stacked postures that ``res`` corresponds to.
    res : (n_test, 6, 2) array
        Imputed postures, one per entry of ``idx_test``.
    idx_drop : (n_drops, n_test) array of int
        Body points that were dropped per test frame; each one is connected to
        its imputed location with a dashed line.

    A new 2 x 5 figure is created and shown with ``plt.show()``.  Frames are
    picked with ``np.random.choice``, so the selection differs between calls.
    """

    #
    labels = ['n','s1','s2','s3','s4','s5']
    marker_sizes = np.arange(1,7,1)[::-1]*20    # largest marker for the nose

    # grab the selected data:
    temp = np.vstack(features_array[animal_id])
    idx_drop = np.array(idx_drop)

    fig = plt.figure()
    shift = 0    # optional horizontal offset of the imputed posture
    for k in range(10):
        # keep a handle on THIS panel so the aspect ratio is applied to it
        ax = plt.subplot(2,5,k+1)

        ############ PLOT GROUND TRUTH ##############
        id2 = np.random.choice(idx_test,1)[0]

        plt.scatter(temp[id2,:,0],
                    temp[id2,:,1],
                    c='blue',
                    s=marker_sizes,
                    alpha=.7,
                    edgecolor='black', label='truth')

        # position of the chosen frame inside the test set (1-element index array)
        id3 = np.where(idx_test==id2)[0]

        ############ PLOT IMPUTED LOCS ##############
        plt.scatter(res[id3,:,0]+shift,
                    res[id3,:,1],
                    c='red',
                    s=marker_sizes,
                    alpha=.7,
                    edgecolor='black', label='imputed')

        # label the ground-truth points
        for p in range(len(labels)):
            plt.text(temp[id2,p,0],
                     temp[id2,p,1],labels[p])

        # draw lines between each dropped point and its imputed location
        for p in range(len(idx_drop)):
            dropped = idx_drop[p][id3]
            plt.plot([temp[id2,dropped,0], res[id3,dropped,0]+shift],
                     [temp[id2,dropped,1], res[id3,dropped,1]],
                     '--',c='black')

        if k==0:
            plt.legend(fontsize=8)

        plt.title("frame id: "+str(id2) + "\ndrops: "+str(idx_drop[:,id3]),fontsize=8)

        # vertical range: 20% beyond the largest coordinate of either posture
        x1 = (np.max(np.abs(temp[id2,:,0])),
              np.max(np.abs(res[id3,:,0])),
              np.max(np.abs(temp[id2,:,1])),
              np.max(np.abs(res[id3,:,1])))

        max_ = np.max(x1)*1.2
        plt.ylim(-max_,10)
        ax.set_aspect('equal', 'datalim')

    plt.show()

    
    
# Train / test an imputation model on the complete postures of one animal.
animal_id = 0

# random 90% / 10% split of the stacked frames
# NOTE(review): np.random is not seeded in this cell, so the split (and the
# frames shown below) change from run to run.
n_frames = np.vstack(features_array[animal_id]).shape[0]
split = 0.9
idx_train = np.random.choice(np.arange(n_frames),int(n_frames*split),replace=False)
imp = generate_imputation_model(features_array, animal_id, idx_train)

# predict the held-out frames; drops=None is not 'fixed', so predict_imputation
# removes n_drops randomly chosen body points per frame
idx_test = np.delete(np.arange(n_frames), idx_train)
res, idx_drop = predict_imputation(imp, features_array, animal_id, drops=None, idx_test=idx_test, n_drops = 3)

# example postures: ground truth vs. imputed
# (plot_imputation_results calls plt.show() itself before this title is set)
plot_imputation_results(features_array, animal_id, idx_test, res, idx_drop)
plt.suptitle("animal "+str(animal_id)+ " imputed vs. ground truth",fontsize=20)

# error distribution per body point; pick the title that matches the data that
# was stacked into features_array
evaluate_imputation_error(features_array, animal_id, res, idx_train, idx_test)
#plt.suptitle("Clean data only ",fontsize=20)
plt.suptitle("animal "+str(animal_id)+ "  Egocentric (fixed nose) errors (pixels)",fontsize=20)
#plt.suptitle("animal "+str(animal_id)+ " NON-Egocentric (fixed nose) errors (pixels)",fontsize=20)


plt.show()


(218641, 6, 2)
X_all:  (218641, 12)
fitting...
done
(218641, 6, 2)
idx drop:  (21865,)


In [241]:

#         

        

        

diff:  (10000, 6, 2)
(10000, 6)
DF:        nose    spine1     spine2     spine3     spine4     spine5
0      NaN  8.053196  15.022601  20.459671  17.801094  20.735806
1      NaN  6.009436   7.212002   7.303462  28.604742  53.004890
2      NaN  1.747940   9.411833   8.524810   3.233682   9.551811
3      NaN  4.166721   2.071897  13.535925  25.377176  37.943233
4      NaN  0.956593   8.782355   2.688964  14.177631   5.399886
...    ...       ...        ...        ...        ...        ...
9995   NaN       NaN        NaN        NaN        NaN        NaN
9996   NaN       NaN        NaN        NaN        NaN        NaN
9997   NaN       NaN        NaN        NaN        NaN        NaN
9998   NaN       NaN        NaN        NaN        NaN        NaN
9999   NaN       NaN        NaN        NaN        NaN        NaN

[10000 rows x 6 columns]


In [135]:
df

Unnamed: 0,nose,spine1,spine2,spine3,spine4,spine5
0,,3.818413,4.814587,24.999798,25.545969,19.456072
1,,0.858467,4.052742,19.418888,24.486452,19.967733
2,,4.557945,8.433762,24.798075,24.866077,56.824924
3,,0.740223,8.106142,26.428474,25.756119,56.594051
4,,2.175694,7.933312,17.512733,42.814701,63.951710
...,...,...,...,...,...,...
9995,,,,,,
9996,,,,,,
9997,,,,,,
9998,,,,,,


In [None]:
##########################################################
##########################################################
##########################################################
from scipy.spatial import cKDTree
import joblib

def knn_triage(th, pca_wf, n_neighbors=6):
    """Flag the points that sit in the densest ``th`` percent of a point cloud.

    Every point is scored by the summed distance to its ``n_neighbors`` nearest
    neighbours.  The query is run on the same points the tree was built from, so
    the nearest "neighbour" of each point is the point itself at distance 0.
    Points whose score is at or below the ``th``-th percentile are kept.

    Parameters
    ----------
    th : float
        Percentage (0-100) of points to keep.
    pca_wf : (n_points, n_dims) array
        Point cloud.
    n_neighbors : int
        Number of nearest neighbours that are summed (default 6).

    Returns
    -------
    (n_points,) bool numpy.ndarray
        True for the points to keep.
    """
    pca_wf = np.asarray(pca_wf)
    n_points = pca_wf.shape[0]
    if n_points == 0:
        return np.zeros(0, dtype=bool)

    # asking for more neighbours than there are points makes the tree pad the
    # result with inf, which would poison the percentile threshold
    k = max(1, min(n_neighbors, n_points))

    tree = cKDTree(pca_wf)
    dist, _ = tree.query(pca_wf, k=k)

    # query() returns a 1-D array when k == 1
    score = np.asarray(dist).reshape(n_points, -1).sum(axis=1)

    idx_keep1 = score <= np.percentile(score, th)
    return idx_keep1



# Low-dimensional embedding of the stacked postures, one panel per animal.
# Points are coloured by the index of the recording they came from.
for k in range(4):
    ax=plt.subplot(2,2,k+1)
    
    temp = features_array[k]
    d = []
    clrs = []
    for p in range(len(temp)):
        d.append(temp[p])
        clrs.extend(np.zeros(temp[p].shape[0])+p)
    
    clrs = np.array(clrs)
    d = np.vstack(d)
    print ("D: ", d.shape)
    d = d.reshape(d.shape[0],-1)

    # NOTE(review): this unconditional `continue` stops each pass right after the
    # shape print, so nothing below is executed and the panels stay empty.
    # Remove it to actually compute and plot the embeddings.
    continue
    #d = sklearn.preprocessing.normalize(d)

    # drop the postures with the most distant neighbours before embedding
    if True:
        th = 95  # % of data to keep
        idx_keep = knn_triage(th, d)
        print (" d before traige: ", d.shape)
        d = d[idx_keep]
        print (" d after traige: ", d.shape)
        clrs = clrs[idx_keep]
    
    
    if False:
        pca = PCA(2)

        print ("... data into pca: ", d.shape)

        feats_pca = pca.fit_transform(d)
        print (feats_pca.shape)

        # 
        plt.scatter(feats_pca[::5,0],
           feats_pca[::5,1],
            #c=np.arange(feats_pca.shape[0])[::5],
            c=clrs[::5],
            alpha=.05)
        
    if True:
        
#         import gpumap
#         #from sklearn.datasets import load_digits

#         #digits = load_digits()
#         print ("Data into gpumap: ", d.shape)
#         feats_pca = gpumap.GPUMAP().fit_transform(d)
#         print ("Data out of gpumap: ", feats_pca.shape)

        import umap
    
        # keep the module name `umap` intact: binding the instance to the same
        # name would break `umap.UMAP` on the next pass through the loop
        reducer = umap.UMAP(n_components=2,
                            init='random',
                            random_state=0)
        
        # embed every 2nd posture only
        d = d[::2]
        clrs = clrs[::2]
        
        print ("... data into umap: ", d.shape)
        
        if False:
            umap_ = reducer.fit(d) #[::10])
            feats_pca = umap_.transform(d)
        else:
            feats_pca = reducer.fit_transform(d) #[::10])
        
        
        # drop the most isolated embedded points before plotting
        if True:
            th = 90  # % of data to keep
            idx_keep = knn_triage(th, feats_pca)
            print (" d before traige: ", feats_pca.shape)
            feats_pca = feats_pca[idx_keep]
            print (" d after traige: ", feats_pca.shape)
            clrs = clrs[idx_keep]
        
        plt.scatter(feats_pca[:,0],
               feats_pca[:,1],
                #c=np.arange(feats_pca.shape[0])[::5],
                c=clrs,
                alpha=.05)
    if False:
        
        #from openTSNE import TSNE
        #print ("... data into tsne: ", d.shape)
        #feats_pca = TSNE().fit(d)
        
        
        from fastTSNE import TSNE

        tsne = TSNE(
            n_components=2, perplexity=30, learning_rate=100, early_exaggeration=12,
            n_jobs=4, 
            #angle=0.5, 
            initialization='random', metric='euclidean',
            n_iter=750, early_exaggeration_iter=250, neighbors='exact',
            negative_gradient_method='bh', min_num_intervals=10,
            #ints_in_inverval=2, 
            #late_exaggeration_iter=100, 
            #late_exaggeration=4,
        )
        
        # 
        feats_pca = tsne.fit(d)

        print (" output: ", feats_pca.shape)


        plt.scatter(feats_pca[:,0],
            feats_pca[:,1],
            #c=np.arange(feats_pca.shape[0])[::5],
            c=clrs,
            alpha=.05)

    # 
    plt.title("Animal:"+str(k))
    
    
plt.suptitle("Static vertically aligned postures",fontsize=20)
plt.show()
#plt.show()

In [15]:
# Frame counts, one entry per animal.
# NOTE(review): the first value equals the stacked-posture count printed for
# animal 0 earlier in this notebook; the others presumably come from the same
# print for animals 1-3 -- confirm.
lens = np.array([218641, 94647, 132861, 176982])

# Expressed as a percentage of 23*89900.
# NOTE(review): presumably 23 recordings x 89900 frames each -- confirm.
print (np.round(lens / (23 * 89900) * 100, 2))

[10.57  4.58  6.43  8.56]


In [44]:
#################################################
############### IMPUTE MISSING DATA #############
#################################################


from sklearn.experimental import enable_iterative_imputer
from sklearn.impute import IterativeImputer
# Toy IterativeImputer example: fit on rows where the second feature is twice
# the first (two rows have one value missing), then fill gaps in unseen rows.
# NOTE(review): this rebinds the global name `imp`, which another cell of this
# notebook uses for the posture imputation model -- mind the execution order.
imp = IterativeImputer(max_iter=10, random_state=0)
imp.fit([[1, 2], [3, 6], [4, 8], [np.nan, 3], [7, np.nan]])
# each test row has exactly one missing value to be filled in
X_test = [[np.nan, 2], [6, np.nan], [np.nan, 6]]
# the model learns that the second feature is double the first

# print the completed rows, rounded to whole numbers
print(np.round(imp.transform(X_test)))

[[ 1.  2.]
 [ 6. 12.]
 [ 3.  6.]]


In [None]:
##############################################
########## FEATURIZE BEHAVIOR CHUNKS #########
##############################################
from sklearn import decomposition
import sklearn

# Flatten every behaviour chunk into one feature row per event and stack all
# animals into a single matrix.
# NOTE(review): `animal_ids` and `X4` are not defined anywhere in this notebook;
# they must already exist in the kernel for this cell to run.
# NOTE(review): the loop variable rebinds the global `animal_id` used by the
# imputation cells.
X_all = []
n_events = []
for animal_id in animal_ids:
    X = X4[animal_id].copy()
    X = X.reshape(X.shape[0], -1)
    print (X.shape)
    X_all.append(X)
    n_events.append(X.shape[0])

#     
X_all = np.vstack(X_all)
print (X_all.shape)
# NOTE(review): the normalised copy `X` is not used below -- both embeddings
# are computed on the un-normalised `X_all`. Confirm which one is intended.
X = sklearn.preprocessing.normalize(X_all)

#
if True:
    pca = decomposition.PCA(n_components=3)

    X_pca = pca.fit_transform(X_all)
    print (X_pca.shape)
    
if False:
    import umap
    # keep the module name `umap` intact instead of rebinding it to the instance
    reducer = umap.UMAP(n_components=2,
                        init='random',
                        random_state=0)

    # fit on every 10th event, then embed all of them
    umap_ = reducer.fit(X_all[::10])

    X_pca = umap_.transform(X_all)
        

print ("plotting: ", X_pca.shape)


print (n_events)
# one panel per animal; rows of X_pca are sliced by the per-animal event counts
fig=plt.figure()
for k in range(4):
    ax = plt.subplot(2,2,k+1)
    start = int(np.sum(n_events[:k]))
    end = int(np.sum(n_events[:k+1]))
    print (start, end)
    plt.scatter(X_pca[start:end,0],
                X_pca[start:end,1],
               alpha=.1)

plt.show()