In [138]:
import numpy as np
import os
import json
import sys
import glob
import h5py
from itertools import groupby
from scipy import signal
import pickle as pkl

%matplotlib qt5
%matplotlib auto
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle, Circle
from matplotlib.collections import PatchCollection
from matplotlib import cm
from matplotlib import colors

Using matplotlib backend: Qt5Agg


In [152]:
%qtconsole

In [217]:
vid_locations = '/media/gravishlab/SeagateExpansionDrive/AntTrack/Tunnel_20180313-14'

file_list = []

file_list = glob.glob(os.path.join(vid_locations, '**/**/*antennae_predictions.h5'))
file_list = sorted(file_list)
print('Total Number of Videos: ',len(file_list))
len(file_list)


pix2mm = 959.7563/30 # Measured 3cm in 4 cameras
fps = 239.16

Total Number of Videos:  65


In [197]:
def remove_lowconf_pts(arr, conf, conf_cutoff, jump_limit):
    arr_highconf = arr.copy()
    # get rid of low confidence pts
    arr_highconf[conf<conf_cutoff]=np.nan
    return arr_highconf

def remove_jumps(arr, jump_limit):
    yy = np.isnan(arr)
    xx = range(len(yy))
    arr_nojump = np.empty(yy.shape)*np.nan
    for k,g in groupby(iter(xx), lambda x: yy[x]):
        if k == False: # if is not a group of nan
            g = np.array(list(g))
#             print(len(g))
            if len(g)>3:
                # get rid of drastic changes
                if jump_limit != None:
                    arr_OI = arr[g]
                    d_arr = np.abs(np.diff(arr_OI))
                    d_jump = np.abs(d_arr) > jump_limit
                    d_jump_cumsum = np.cumsum(np.insert(d_jump,0,0))
                    d_jump_opp = (d_jump_cumsum%2).astype(bool)
                    if np.sum(d_jump_opp==True)>np.sum(d_jump_opp == False):
                        d_jump_opp = np.logical_not(d_jump_opp)
                    arr_OI[d_jump_opp]=np.nan
                arr_nojump[g]=arr_OI
                    
    return arr_nojump
    


def middle_half(alist, wanted_parts=4):
    alist= alist[np.logical_not(np.isnan(alist))]
    length = len(alist)
    sections = np.array([ alist[i*length // wanted_parts: (i+1)*length // wanted_parts] 
             for i in range(wanted_parts) ])
    middle_half = np.concatenate(sections[1:3])
    if not len(middle_half)>0:
        return  np.nan, np.nan
    else:
        full_range = (np.max(alist)-np.min(alist))
        middle_range = (np.max(middle_half)-np.min(middle_half))
        med = np.mean(middle_half)
#         print(middle_range/full_range)
        if middle_range/full_range < 0.93: # theoretically for normal distribution mid_range/full_range = 0.16625
#             print('not  normal dist')
            sigma = middle_range/3
#             print('removed outliers')
        else:
            sigma = full_range/4
#         print(med, sigma)
    return med, sigma

def remove_outliers(arr):
    med, sigma = middle_half(arr)
    where_far_away = (np.abs(arr-med)> 2*sigma)
    arr[where_far_away]=np.nan
    return arr

def remove_outliers2d(arr_x, arr_y):
    med_x, sigma_x = middle_half(arr_x)
    med_y, sigma_y = middle_half(arr_y)
#     if sigma_x < 15:
#         sigma_x = 15
    where_far_away = np.logical_or(np.abs(arr_x-med_x)> 2*sigma_x, np.abs(arr_y-med_y)> 2*sigma_y)
#     print('removing %i outliers'%np.sum(where_far_away))
    arr_x[where_far_away]=np.nan
    arr_y[where_far_away]=np.nan
    return arr_x, arr_y

def remove_outliers(arr_x, arr_y):
    med_x, sigma_x = middle_half(arr_x)
    med_y, sigma_y = middle_half(arr_y)
    
    where_far_away = np.logical_or(np.abs(arr_x-med_x)> 2*sigma_x, np.abs(arr_y-med_y)> 2*sigma_y)
    arr_x[where_far_away]=np.nan
    arr_y[where_far_away]=np.nan
    
    return arr_x, arr_y

def find_nan_gaps(arr, limit):  
    yy = np.isnan(arr)
    xx = range(len(yy))
    where_gapOI = np.full(arr.shape, False)
    where_othergaps = np.full(arr.shape, False)
    for k,g in groupby(iter(xx), lambda x: yy[x]):
        if k == True: # if is a group of nan
            g = list(g)
            if any(x in g for x in [0, len(arr)-1]): # if first or last group
                where_othergaps[np.array(g)]=True
#                 print('remove: ', len(g), g)
                continue
                
            if len(g)<= limit: # length is below limit
                where_gapOI[np.array(g)]=True
#                 print('interp: ', len(g), g)
            else:
                where_othergaps[np.array(g)]=True
#                 print('remove: ', len(g), g)
    return where_gapOI, where_othergaps

def find_interp_idcs(where_interpolate):
    interp_idcs = []
    for val in [-1,0,1]:
        interp_idcs = np.concatenate([interp_idcs,np.where(where_interpolate)[0]+val])
    interp_idcs = np.sort(np.array(list(set(interp_idcs)))) # get of repeat elements
    interp_idcs = interp_idcs[np.logical_and(interp_idcs>-1, interp_idcs < len(where_interpolate))].astype(np.uint32) # only elements in range
    return interp_idcs

def interp_vals(arr, interp_idcs): # array includes nan values
    interp = arr.copy()
    if len(interp_idcs)>0:
        temp = arr[interp_idcs]
        interpolated_vals = np.interp(
            interp_idcs, 
            interp_idcs[np.logical_not(np.isnan(temp))], temp[np.logical_not(np.isnan(temp))] )
        interp[interp_idcs] = interpolated_vals
    return interp
    
def lowpass_filt_sections(arr):
    yy = np.isnan(arr)
    xx = range(len(yy))
    full_filtered = np.empty(yy.shape)*np.nan
    for k,g in groupby(iter(xx), lambda x: yy[x]):
        if k == False: # if is a group of nan
            g = list(g)
#             print('section to lowpass fitler: ', len(g))
            if len(g)>9:
                b, a = signal.butter(2,0.3,btype='low')
                filtered = signal.filtfilt(b, a, arr[np.array(g)])
                full_filtered[np.array(g)]=filtered
    return full_filtered



def interpolate_filter_tracking(arr, offsets, conf, conf_cutoff, jump_limit, nan_gap_limit, plots = False):
    
    all_frames = np.arange(arr.shape[0])
    
    # initialize variables
    joint_x, joint_y, x_offset, y_offset, conf, \
    joint_x_highconf, joint_y_highconf, joint_x_interp, joint_y_interp, joint_x_filt, joint_y_filt = \
    (np.empty(all_frames.shape)*np.nan for i in range(11))
    
    # set up variables
    joint_x = arr[:,0]
    joint_y = arr[:,1]
    x_offset = offsets[:,0]
    y_offset = offsets[:,0]
    joint_x = joint_x - x_offset
    joint_y = joint_y - y_offset

    joint_x_highconf = remove_lowconf_pts(joint_x, conf, conf_cutoff, jump_limit)
    joint_y_highconf = remove_lowconf_pts(joint_y, conf, conf_cutoff, jump_limit)
    
    # remove big jumps
    joint_x_highconf = remove_jumps(joint_x_highconf, jump_limit)
    joint_y_highconf = remove_jumps(joint_y_highconf, jump_limit)
    
    # remove outliers
#     print('before removing outliers: ', np.sum(np.isnan(joint_x_highconf)), ' nan of ', len(joint_x_highconf) )
    joint_x_highconf, joint_y_highconf = remove_outliers2d(joint_x_highconf, joint_y_highconf)
#     print('after removing outliers: ', np.sum(np.isnan(joint_x_highconf)), ' nan of ', len(joint_x_highconf) )
    
    # interpolate 
    
    where_interpolate, where_remove = find_nan_gaps(joint_x_highconf, nan_gap_limit)
#     print(np.sum(np.isfinite(joint_y_highconf)), np.sum(where_interpolate))
    joint_x_interp = interp_vals(joint_x_highconf, find_interp_idcs(where_interpolate))
    joint_y_interp = interp_vals(joint_y_highconf, find_interp_idcs(where_interpolate))

    # lowpass filter
    joint_x_filt = lowpass_filt_sections(joint_x_interp)
    joint_y_filt = lowpass_filt_sections(joint_y_interp)
#     print(joint_x_filt.shape)
    
    if plots:
        # PLOT THINGS
        fig = plt.figure(figsize=(15,5))
        ax1=plt.subplot(2,1,1)
        ax_limits=[]
        ax_limits.append([np.nanmin(joint_x)-10, np.nanmax(joint_x)+10])
        ax2 = plt.subplot(2,1,2)
        ax_limits.append([np.nanmin(joint_y)-10, np.nanmax(joint_y)+10])

        for xx,ax in enumerate([ax1, ax2]):
            for kk, inter in enumerate(all_frames[where_interpolate]):
                if kk == 0:
                    rect = Rectangle((inter-0.5, ax_limits[xx][0]),
                                     1, np.diff(ax_limits[xx]), alpha = 0.2, fc = 'm', ec = None, label = 'interpolated')
                else:
                    rect = Rectangle((inter-0.5, ax_limits[xx][0]),
                                     1, np.diff(ax_limits[xx]), alpha = 0.2, fc = 'm', ec = None)
                ax.add_patch(rect)
            for kk,remov in enumerate(all_frames[where_remove]):
                if kk == 0:
                    rect = Rectangle((remov-0.5, ax_limits[xx][0]),
                                     1, np.diff(ax_limits[xx]), alpha = 0.05, fc = 'k', ec = None, label = 'removed')
                else:
                    rect = Rectangle((remov-0.5, ax_limits[xx][0]),
                                     1, np.diff(ax_limits[xx]), alpha = 0.05, fc = 'k', ec = None)
                ax.add_patch(rect)

        plt.sca(ax1)
        cmap = cm.bwr
        plt.scatter(all_frames, joint_x, c = conf, s = 10, 
                cmap = cmap, norm = colors.Normalize(vmin=0, vmax=1), label = 'raw tracking')
        plt.plot(all_frames, joint_x_interp, '.k', alpha = 0.5, MarkerSize = 2)#, label = 'interpolated')
        plt.plot(all_frames, joint_x_filt, '-g', alpha = 0.5, label = 'filtered', )
        plt.ylabel('x (pix)')
        plt.legend(loc = 'upper right', frameon=False, fontsize = 7)

        


        plt.sca(ax2)
        plt.plot(all_frames, joint_y_interp, '.k', alpha = 0.5, label = 'interpolated', MarkerSize = 2)
        plt.plot(all_frames, joint_y_filt, '-g', alpha = 0.5, label = 'filtered', )
        plt.scatter(all_frames, joint_y, c = conf, s = 10, 
                cmap = cmap, norm = colors.Normalize(vmin=0, vmax=1))
        plt.ylabel('y (pix)')
        plt.gca().invert_yaxis()

        
        
        

        cax = plt.axes([0.93,0.1,0.02,0.8])
        plt.colorbar(cax=cax, label='confidence')
        plt.clim(0,1)

#     print('final filter len: ', np.sum(np.isnan(joint_x_filt)), ' nan of ', len(joint_x_filt) )
    return np.concatenate((joint_x_filt[:,np.newaxis], joint_y_filt[:,np.newaxis]),axis=1);



# RUN ON SPECIFIC TRIAL & PLOT
# plt.close('all')

# thorax_filt = interpolate_filter_tracking(thorax, np.zeros(thorax.shape), thorax_conf, 
#                                 conf_cutoff = 0.6, jump_limit = 10, nan_gap_limit = 5, plots = True)
# antenna_filt = np.empty([thorax.shape[0],2,2])*np.nan
# for anten_num in range(0,2):
#     antenna_filt[:,:,anten_num] = interpolate_filter_tracking(antenna[:,:,anten_num], thorax_filt, thorax_conf, 
#                                     conf_cutoff = 0.6, jump_limit = 10, nan_gap_limit = 5, plots = True)

    
        

In [220]:
plt.close('all')
f1 = plt.figure(figsize = (4,12))
# a1 = plt.subplot(111)
tr_num = 0

for save_path in file_list:
#     print(save_path)

    if os.path.exists(save_path):
        
        subtype = ['0mm','1mm','3mm','5mm'].index(save_path.split('/')[-3])
        pltcolors = ['k','r','g','b']
        
        hf = h5py.File(save_path,'r')
        joint_loc = hf['positions_pred'][()].astype(np.float32) #hf.get('positions_pred')
        joint_conf = hf['conf_pred'][()] #hf.get('conf_pred')
        hf.close
        del hf
        
        
        if joint_loc.shape[0] < 100:
            continue
        
#         print(subtype, ' -- ', joint_loc.shape[0])
        
        thorax = np.full(joint_loc.shape[0:2],np.nan)
        antenna = np.full((joint_loc.shape[0],2,2),np.nan)
        
        thorax[:,0]=joint_loc[:,0,2]
        thorax[:,1]=joint_loc[:,1,2]
        thorax_conf=joint_conf[:,2]
        thorax[thorax_conf<0.6,:]=np.nan     
        
        antenna_conf = [[]]*2
        for anten_num in range(0,2):
            antenna[:,0,anten_num]=joint_loc[:,0,28+4*anten_num]
            antenna[:,1,anten_num]=joint_loc[:,1,28+4*anten_num]
            antenna_conf[anten_num]=joint_conf[:,28+4*anten_num]
            antenna[thorax_conf<0.6,:, anten_num]=np.nan
            antenna[antenna_conf[anten_num]<0.6,:, anten_num]=np.nan
        
        
        # REMOVE LOW CON AND LOWPASS FILT
        thorax_filt = interpolate_filter_tracking(thorax, np.zeros(thorax.shape), thorax_conf, 
                                        conf_cutoff = 0.6, jump_limit = 10, nan_gap_limit = 5, plots = False)
        antenna_filt = np.empty([thorax.shape[0],2,2])*np.nan
        for anten_num in range(0,2):
            antenna_filt[:,:,anten_num] = interpolate_filter_tracking(antenna[:,:,anten_num], thorax_filt, thorax_conf, 
                                            conf_cutoff = 0.6, jump_limit = 10, nan_gap_limit = 5, plots = False)
            
        if np.sum(np.isfinite(antenna_filt[:,0,0]))<50:
            continue
            
        print(subtype, ' -- ', np.sum(np.isfinite(antenna_filt[:,0,0])))
            
        # SAVE AS PICKLE
        pname = ('/').join(save_path.split('/')[:-1]) + '/antennae_%s_tr%i.pkl'%(save_path.split('/')[-3],tr_num)
        with open(pname, 'wb') as f:
            pkl.dump(antenna_filt,f)
        tr_num=tr_num+1
            
        
        # PLOT
        for anten_num in range(0,2):
            plt.subplot(4,1,subtype+1)
            plt.plot(antenna_filt[:,0,anten_num],antenna_filt[:,1,anten_num],'-', color =pltcolors[subtype], alpha=0.1)
            plt.plot(antenna_filt[:,0,anten_num],antenna_filt[:,1,anten_num],'.', color =pltcolors[subtype], alpha=0.1, MarkerSize = 2)
        plt.plot(0,0,'.k')
        plt.xlim([0,100])
        plt.ylim([-60,60])
        plt.axis('equal')
        plt.gca().invert_yaxis()
#             a1.plot(antenna_filt[:,0,anten_num],antenna_filt[:,1,anten_num],'-', color =pltcolors[subtype], alpha=0.1)
#             a1.plot(antenna_filt[:,0,anten_num],antenna_filt[:,1,anten_num],'.', color =pltcolors[subtype], alpha=0.1, MarkerSize = 2)
#         a1.plot(0,0,'.k')
            
            
        # individual forward and lateral speeds
#         plt.figure()
#         for anten_num in range(0,2):
#             plt.subplot(2,1,anten_num+1)
#             plt.plot(np.diff(antenna_filt[:,0,anten_num])*fps/pix2mm, '-r')
#             plt.plot((anten_num*2-1)*np.diff(antenna_filt[:,1,anten_num])*fps/pix2mm, '-b')
#             plt.text(len(thorax)-50, 20, 'forward' , color = 'r')
#             plt.text(len(thorax)-50, 15, 'lateral' , color = 'b')
#             plt.ylabel('%s antennae speed (mm/s)'%['L','R'][anten_num])
#             plt.ylim([-40,40])
#             if anten_num ==0:
#                 plt.title('%s'%pname.split('/')[-1])
            
            
#         # total speed
#         plt.figure()
#         for anten_num in range(0,2):
#             plt.subplot(2,1,anten_num+1)
#             plt.plot(np.linalg.norm(np.diff(antenna_filt[:,:,anten_num],axis=0),axis=1)*fps/pix2mm, '-k', alpha = 0.4)
#             plt.ylabel('%s antennae speed (mm/s)'%['L','R'][anten_num])
            
            




  after removing the cwd from sys.path.


0  --  199
0  --  199
0  --  326
0  --  271
0  --  121
0  --  68
0  --  153
0  --  105
0  --  297
0  --  167
1  --  434
1  --  177
1  --  106
1  --  156
1  --  221
1  --  233
1  --  504
1  --  449
1  --  104
1  --  53
2  --  669
2  --  333
2  --  682
2  --  628
2  --  361
2  --  196
2  --  152




2  --  85
2  --  644
2  --  707
2  --  504
3  --  344
3  --  632
3  --  478
3  --  236
3  --  244
3  --  554
3  --  234
3  --  488
3  --  381
3  --  552


In [75]:
plt.close('all')
for anten_num in range(0,2):

    plt.plot(antenna_filt[:,0,anten_num],antenna_filt[:,1,anten_num],'-k',alpha=0.2)
    plt.plot(antenna_filt[:,0,anten_num],antenna_filt[:,1,anten_num],'.k',alpha=0.4)
    plt.plot(0,0,'.k')

In [None]:
#         # LOAD JOINT TRACKING DATA FOR ****36*** TRACKED POINTS
#         # legs
#         for joint_num in range(0,6):
#             video.objects[ant_num]['joint%i_x'%joint_num]=joint_loc[:,0,27-4*joint_num]
#             video.objects[ant_num]['joint%i_y'%joint_num]=joint_loc[:,1,27-4*joint_num]
#             video.objects[ant_num]['joint%i_conf'%joint_num]=joint_conf[:,27-4*joint_num]
#         # antennae
#         for anten_num in range(0,2):
#             video.objects[ant_num]['antenna%i_x'%anten_num]=joint_loc[:,0,28+4*anten_num]
#             video.objects[ant_num]['antenna%i_y'%anten_num]=joint_loc[:,1,28+4*anten_num]
#             video.objects[ant_num]['antenna%i_conf'%anten_num]=joint_conf[:,28+4*anten_num]
#         # thorax & neck
#         video.objects[ant_num]['thorax_x']=joint_loc[:,0,2]
#         video.objects[ant_num]['thorax_y']=joint_loc[:,1,2]
#         video.objects[ant_num]['thorax_conf']=joint_conf[:,2]
#         video.objects[ant_num]['neck_x']=joint_loc[:,0,1]
#         video.objects[ant_num]['neck_y']=joint_loc[:,1,1]
#         video.objects[ant_num]['neck_conf']=joint_conf[:,1]
