In [2]:
%load_ext autoreload
%autoreload 2

import matplotlib
import matplotlib.pyplot as plt
#matplotlib.use('Agg')
%matplotlib tk
%autosave 180

import sys
sys.path.append('/home/cat/code/gerbil/utils/')

#
import numpy as np

#
#from visualize import visualize
from track import track as Track


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


Autosaving every 180 seconds


In [6]:
######################################################## 
################ GENERATE FIXED TRACK ##################
########################################################

# STEP 1: Convert both .slp files (before and after Id-swtch) to .npy files and fix them
# RUN: this cell 2 x

# 
fnames_slp = [
    '/home/cat/data/dan/id_switch/post_fix_metrics/2020_07_30_19_25_43_717047_compressed_Day.slp',
    '/home/cat/data/dan/id_switch/post_fix_metrics/2020_07_30_14_11_50_567584_compressed_Day.slp'
]

for fname_slp in fnames_slp:
    track = Track.Track  (fname_slp)
    track.track_type = 'features'

    ###### parameters for computing body centroid #######
    track.use_dynamic_centroid = True   # True: alg. serches for the first non-nan value in this body order [2,3,1,0,4,5]
                                         # - advantage: much more robust to lost features
                                         # False: we fix the centroid to a specific body part
                                         # - advantage less jitter for some applications
    track.centroid_body_id = [2]         # if centroid flag is False; we use this body part instead

    ##### run track fixer #######
    track.fix_all_tracks()

    ##### join spatially close but temporally distant chunks #####
    if True:
        #
        track.memory_interpolate_tracks_spine()

    ##### save the fixed spines will overwrite the previous/defatul spine values####
    track.save_centroid()

    #
    print ("Done...")

KeyboardInterrupt: 

In [8]:
####################################################
########## MAKE PLOTS OF SPINE CENTRES #############
####################################################

# OPTIONAL VISUALIZE TRACKS
def plot_track(track, animal_id):
    
    #
    plt.plot(track[:,animal_id,0],
             track[:,animal_id,1],
             #label=str(k),

            )

    plt.xlim(0,1000)
    plt.ylim(0,800)

#
animal_id = 0
    
#    
track = np.load(fnames_slp[0].replace('.slp','_spine.npy'))
plot_track(track, animal_id)

#    
track = np.load(fnames_slp[1].replace('.slp','_spine.npy'))
plot_track(track, animal_id)


In [44]:
###########################################################
###########################################################
###########################################################
def check_switch(tracks, track, animal_id, threshold=10):
    
    #
    tracks = np.delete(tracks,animal_id, axis=0)

    #
    dists = np.linalg.norm(tracks-track, axis=1)
    if np.nanmin(dists)<threshold:
        return True
    
    return False
    
    
#
def find_bouts(track,animal_id):
    from tqdm import trange
    bouts = []
    track1 = track[:,animal_id]
    
    starts = []
    ends = []
    switches=0
    if np.isnan(track1[0,0])==False:
        starts.append(0)
    #
    for k in trange(1,track1.shape[0]-1,1):
        if np.isnan(track1[k,0]) and np.isnan(track1[k+1,0])==False:
            starts.append(k)
        #
        elif np.isnan(track1[k,0])==False and np.isnan(track1[k+1,0])==True:
            ends.append(k)
            
            # also check if the next frame is a switch
            res = check_switch(track[k+1], track1[k], animal_id)
            
            if res:
                switches+=1
            
    if len(starts)!=len(ends):
        starts=starts[:-1]
    
    bouts = np.vstack((starts,ends)).T
    #print ("bouts: ", bouts)
    
    return bouts, switches



def get_n_frames(track, animal_id):
    
    # get # of frames with animal detected
    #print (track.shape)
    s = track[:,animal_id].sum(1)
    
    idx = np.where(np.isnan(s)==False)[0]
    #print (s.shape, idx.shape)
    
    n_frames = idx.shape[0]
    
    return n_frames
    
def get_ave_features(full_track, animal_id):
    # get average # of detected frames
    
    #
    full_track = full_track[:,animal_id].sum(2)
    
    # find non-nan values 
    idx2 = np.where(np.isnan(full_track)==False)

    # make an array
    idx2 = np.array(idx2)[0]
    #print (idx2)
    
    # for each frame count the # of non-nan features
    n_unique, counts = np.unique(idx2, return_counts=True)
 
    # make empty list to start
    n_feats=np.zeros(full_track.shape[0])
 
    return np.mean(counts)


##########################################
##########################################
##########################################
animal_ids = np.arange(6)
plt.figure()

# 
for ctr, animal_id in enumerate(animal_ids):
    print ("processing animal: ", animal_id)
    #
    track = np.load(fnames_slp[0].replace('.slp','_spine.npy'))
    bouts1,switches1 = find_bouts(track, animal_id)

    track = np.load(fnames_slp[1].replace('.slp','_spine.npy'))
    bouts2,switches2 = find_bouts(track, animal_id)
    
    #
    spine_track = np.load(fnames_slp[0].replace('.slp','_spine.npy'))
    n_frames1 = get_n_frames(spine_track, animal_id)
    spine_track = np.load(fnames_slp[1].replace('.slp','_spine.npy'))
    n_frames2 = get_n_frames(spine_track, animal_id)
    
    #
    full_track = np.load(fnames_slp[0].replace('.slp','.npy'))
    ave_feats1 = get_ave_features(full_track, animal_id)
    full_track = np.load(fnames_slp[1].replace('.slp','.npy'))
    ave_feats2 = get_ave_features(full_track, animal_id)
    #print ("averages: ", ave_feats1, ave_feats2)

    # 
    ax=plt.subplot(2,3,1)
    
    if ctr==0:
        plt.bar(ctr, bouts1.shape[0], 0.4, color='black', label='original')
        plt.bar(ctr+0.4, bouts2.shape[0], 0.4, color='blue', label='350fixed')
    else:
        plt.bar(ctr, bouts1.shape[0], 0.4, color='black')
        plt.bar(ctr+0.4, bouts2.shape[0], 0.4, color='blue')
    plt.legend()
    plt.ylabel("# of segments")
    
    #
    ax2=plt.subplot(2,3,2)
    max_ = 1800
    plt.title("350fixed - original")
    #plt.semilogy()
    diffs1 = bouts1[:,1]-bouts1[:,0]
    y1= np.histogram(diffs1, bins=np.arange(0,max_,10))
    
    #
    diffs2 = bouts2[:,1]-bouts2[:,0]
    y2= np.histogram(diffs2, bins=np.arange(0,max_,10))
    
    plt.plot(y1[1][:-1],y2[0]-y1[0], label=str(ctr))
    plt.plot([0,max_],[0,0],'--',c='grey')
    plt.xlabel("duration of bout")
    plt.ylabel("# of bouts")
    plt.legend()
    
    #
    ax3=plt.subplot(2,3,3)
    plt.bar(ctr, diffs1.sum()/track.shape[0], 0.4, color='black')
    plt.bar(ctr+0.4, diffs2.sum()/track.shape[0], 0.4, color='blue')  
    plt.ylabel("total duration of tracked bouts ")
    plt.xlabel("animal id")
    
    ax3=plt.subplot(2,3,4)
    plt.bar(ctr, switches1, 0.4, color='black')
    plt.bar(ctr+0.4, switches2, 0.4, color='blue')  
    plt.ylabel("Likely ID switches")
    plt.xlabel("animal id")
    
    ax4=plt.subplot(2,3,5)
    plt.bar(ctr, n_frames1, 0.4, color='black')
    plt.bar(ctr+0.4, n_frames2, 0.4, color='blue')  
    plt.ylabel("n_frames detected")
    plt.xlabel("animal id")
    
    ax4=plt.subplot(2,3,6)
    plt.bar(ctr, ave_feats1, 0.4, color='black')
    plt.bar(ctr+0.4, ave_feats2, 0.4, color='blue')  
    plt.ylabel("ave # of features/frame")
    plt.xlabel("animal id")
    
#
plt.show()



processing animal:  0


100%|██████████████████████████████████| 28800/28800 [00:00<00:00, 63203.15it/s]
100%|██████████████████████████████████| 28800/28800 [00:00<00:00, 65361.67it/s]


processing animal:  1


100%|██████████████████████████████████| 28800/28800 [00:00<00:00, 65913.56it/s]
100%|██████████████████████████████████| 28800/28800 [00:00<00:00, 62162.68it/s]


processing animal:  2


100%|██████████████████████████████████| 28800/28800 [00:00<00:00, 54805.41it/s]
100%|██████████████████████████████████| 28800/28800 [00:00<00:00, 62804.29it/s]


processing animal:  3


100%|██████████████████████████████████| 28800/28800 [00:00<00:00, 63119.17it/s]
100%|██████████████████████████████████| 28800/28800 [00:00<00:00, 61684.25it/s]


processing animal:  4


100%|██████████████████████████████████| 28800/28800 [00:00<00:00, 61637.73it/s]
100%|██████████████████████████████████| 28800/28800 [00:00<00:00, 59888.45it/s]


processing animal:  5


100%|██████████████████████████████████| 28800/28800 [00:00<00:00, 57411.33it/s]
100%|██████████████████████████████████| 28800/28800 [00:00<00:00, 60250.06it/s]
