In [1]:
#SETUP

# KEY FUNCTIONS:
# next frame = right arrow key
# previous frame = left arrow key
# jump to the frame where the blue track begins = "b"
# jump to the frame right after the end of the green (focal) track = "g"
# next blue track = up arrow
# previous blue track = down arrow
# next green (focal) track = "." (same key as '>')
# previous green (focal) track = "," (same key as '<')
# add point (appended to the end of the current green track) = click anywhere on the picture
# remove the last added point = 'delete/backspace' key
# append the blue track to the end of the green track = space bar
# permanently remove the blue track = "-"
# quit = Esc

# SETTINGS FOR THE USER TO FILL IN
# Paths are plain strings, e.g. '/Users/dbasili/koger_drive/long-buffalo-data/positions.npy'

# .npy file with the dots marking the animals on each frame
positions_path = '/Users/dbasili/koger_drive/positions.npy'
# .npy file with the track information
tracks_path = '/Users/dbasili/koger_drive/tracks.npy'
# glob pattern matching the frame images
picture_folder_path = '/Users/dbasili/koger_drive/frames/*jpg'
# display scale applied to every frame (0.33 -> one third of the original size)
factor = 0.33
# step between displayed frames (2 -> every second image is shown)
skip = 2

#import
import cv2
import numpy as np
import glob

In [2]:
#CLASSES

#button class
class Button():
    """A filled, labelled rectangle painted onto the control bar of an image.

    The button paints itself on construction. ``check_if_pressed`` hit-tests a
    click, briefly flashes the button black in the 'pic0' window, and then asks
    the owning window object to run the handler named by ``function``.
    """

    WINDOW_NAME = 'pic0'  # OpenCV window the press animation is shown in
    FLASH_MS = 200        # how long (ms) the button stays black after a press

    def __init__(self,image,x1,x2,y1,y2,function,color1,color2,color3,text):
        """Store the button geometry/handler and paint it onto ``image`` in place.

        Args:
            image: numpy image (rows x cols x 3) the button is painted onto.
            x1, x2, y1, y2: button rectangle; painted as ``image[y1:y2, x1:x2]``.
            function: handler name passed to ``window.function_caller`` on a press.
            color1, color2, color3: fill colour channel values, in the image's
                channel order.
            text: label drawn near the top-left corner of the button.
        """
        self.x1 = x1
        self.x2 = x2
        self.y1 = y1
        self.y2 = y2
        self.color1 = color1
        self.color2 = color2
        self.color3 = color3
        self.text = text
        self.function = function
        self._paint(image)

    def _paint(self, image, fill=None, label=True):
        """Fill the button rectangle (own colour unless ``fill`` is given) and optionally draw its label."""
        if fill is None:
            fill = [self.color1, self.color2, self.color3]
        image[self.y1:self.y2, self.x1:self.x2] = fill
        if label:
            cv2.putText(image, self.text, (self.x1 + 5, self.y1 + 22), cv2.FONT_HERSHEY_DUPLEX,
                        0.5, (0, 0, 0), 1, cv2.LINE_AA)

    def check_if_pressed(self, x, y, image, window):
        """Handle a click at (x, y) if it lands on this button (bounds inclusive).

        Args:
            x, y: click position in ``image`` pixel coordinates.
            image: the canvas the button lives on (redrawn for the press animation).
            window: object exposing ``function_caller(name)``.

        Returns:
            True if the button was hit (and its handler was run), otherwise False.
        """
        if not (self.x1 <= x <= self.x2 and self.y1 <= y <= self.y2):
            return False

        # Flash the button black, then restore it, so the press is visible.
        self._paint(image, fill=[0, 0, 0], label=False)
        cv2.imshow(self.WINDOW_NAME, image)
        cv2.waitKey(self.FLASH_MS)
        self._paint(image)
        cv2.imshow(self.WINDOW_NAME, image)
        cv2.waitKey(1)

        window.function_caller(self.function)
        return True

#class for the window
class Window():
    """Interactive editor for joining and cleaning animal tracks, shown in the OpenCV window 'pic0'.

    Each redraw shows one down-scaled frame with:
      * red dots   - detections of the displayed frame, ``listofpositions[framecount]``;
      * green dots - the focal track being edited, ``listoftracks[focaltrackcount]``;
      * blue dots  - a candidate track that can be appended to it, ``listoftracks[trackcount]``;
    plus a bar of clickable buttons underneath the picture.

    Each element of ``listoftracks`` is a dict with (at least) ``'track'`` - rows of
    ``[y, x]`` in full-resolution pixels, y counted up from the bottom edge of the
    frame (see ``draw_points``) - and ``'first_frame'``. ``'connected'`` is added on
    load and collects the row indices at which manually added points and appended
    tracks start.
    """

    # Handler names that function_caller may dispatch to (button callbacks).
    _HANDLERS = frozenset([
        'next_frame_function', 'back_frame_function',
        'next_track_function', 'back_track_function',
        'next_focal_track_function', 'back_focal_track_function',
        'remove_point_function', 'add_to_track_function',
        'remove_track_function', 'blue_function', 'green_function',
    ])

    # Key code (cv2.waitKey(...) & 0xff) -> handler name. Arrow-key codes are
    # platform dependent: 0-3 are the codes this tool always used; 81-84 are the
    # arrow keys as reported by OpenCV's GTK backend on Linux (NOTE: confirm on
    # your platform). 127 and 8 cover delete/backspace on different systems.
    _KEY_HANDLERS = {
        3: 'next_frame_function', 83: 'next_frame_function',    # right arrow
        2: 'back_frame_function', 81: 'back_frame_function',    # left arrow
        0: 'next_track_function', 82: 'next_track_function',    # up arrow
        1: 'back_track_function', 84: 'back_track_function',    # down arrow
        ord('.'): 'next_focal_track_function',
        ord(','): 'back_focal_track_function',
        127: 'remove_point_function', 8: 'remove_point_function',
        32: 'add_to_track_function',                            # space
        45: 'remove_track_function',                            # '-'
        ord('g'): 'green_function',
        ord('b'): 'blue_function',
    }

    def __init__(self, positions_path, tracks_path, picture_folder_path, factor, frame_skip=None):
        """Load detections, tracks and the frame list, then build the canvas and buttons.

        Args:
            positions_path: .npy file with the per-frame detections.
            tracks_path: .npy file with the track dicts.
            picture_folder_path: glob pattern matching the frame images; file names
                must end in '_<number>.<ext>' so they can be sorted numerically.
            factor: display scale applied to every frame.
            frame_skip: show only every ``frame_skip``-th image. Defaults to the
                notebook-level ``skip`` setting.

        Raises:
            FileNotFoundError: if no image matches ``picture_folder_path`` or the
                first image cannot be read.
        """
        # Ragged per-frame arrays and track dicts are numpy object arrays, which
        # np.load only unpickles with allow_pickle=True (numpy >= 1.16.3).
        # Unpickling can execute code: only load files you produced yourself.
        self.listofpositions = np.load(positions_path, allow_pickle=True)
        self.listoftracks = np.load(tracks_path, allow_pickle=True)
        for track in self.listoftracks:
            track['connected'] = []

        self.files = glob.glob(picture_folder_path)
        if not self.files:
            raise FileNotFoundError('no images match %r' % picture_folder_path)
        self.files.sort(key=lambda file: int(file.split('.')[-2].split('_')[-1]))
        self.factor = factor
        self.skip = skip if frame_skip is None else frame_skip

        image = cv2.imread(self.files[0])  # only used to get the frame size
        if image is None:
            raise FileNotFoundError('could not read image %r' % self.files[0])
        self.full_h, self.full_w = image.shape[:2]
        self.h = int(self.full_h * factor)
        self.w = int(self.full_w * factor)
        # Picture plus a control bar; the buttons reach down to h + h/12, so the
        # bar must be that tall or the bottom of the buttons gets clipped.
        self.full_pic = np.zeros((int(self.h + self.h / 12), self.w, 3), dtype=np.uint8)

        self.focaltrackcount = 0
        self.trackcount = min(1, len(self.listoftracks) - 1)
        self.framecount = self._clamp_frame(self._focal_end_frame())
        self.buttons = self._make_buttons()

    def _make_buttons(self):
        """Paint the control-bar buttons into full_pic and return them in hit-test order."""
        w, h, n = self.w, self.h, 9
        top, bottom = h + 1, int(h + h / 12)   # y-range of the regular buttons
        middle = int(h + h / 24)               # split between the blue and green swatches
        swatch_left, swatch_right = int(w / 10) + 1, int(w / 10) + 28

        def col(k):
            # right edge of the k-th of n equal-width columns
            return int(k * w / n)

        red, blue, green = (0, 0, 204), (204, 0, 0), (0, 204, 0)  # BGR
        specs = [
            (1, int(w / 10), top, bottom, 'back_frame_function', red, 'frame -'),
            (swatch_right + 1, col(2), top, bottom, 'next_frame_function', blue, 'frame +'),
            (col(2) + 1, col(3), top, bottom, 'back_track_function', red, 'track -'),
            (col(3) + 1, col(4), top, bottom, 'next_track_function', blue, 'track +'),
            (col(4) + 1, col(5), top, bottom, 'back_focal_track_function', red, 'focal track-'),
            (col(5) + 1, col(6), top, bottom, 'next_focal_track_function', blue, 'focal track+'),
            (col(6) + 1, col(7), top, bottom, 'remove_point_function', red, 'remove point'),
            (col(7) + 1, col(8), top, bottom, 'add_to_track_function', blue, 'add to track'),
            (col(8) + 1, col(9), top, bottom, 'remove_track_function', red, 'remove track'),
            (swatch_left, swatch_right, h, middle, 'blue_function', blue, ' '),
            (swatch_left, swatch_right, middle + 1, bottom, 'green_function', green, ' '),
        ]
        return [Button(self.full_pic, x1, x2, y1, y2, name, *color, text)
                for x1, x2, y1, y2, name, color, text in specs]

    def draw_points(self, picture, listofpositions, color1, color2, color3, r):
        """Draw a filled circle of radius ``r`` on ``picture`` for every [y, x] row.

        y is counted up from the bottom edge, so it is flipped to an image row here.
        """
        height = np.size(picture, 0)
        for point in listofpositions:
            cv2.circle(picture, (int(point[1]), height - int(point[0])), r, (color1, color2, color3), -1)

    def clicked(self, event, x, y, flags, param):
        """OpenCV mouse callback: a left click presses a button or, failing that, adds a point."""
        if event != cv2.EVENT_LBUTTONDOWN:
            return
        for button in self.buttons:
            if button.check_if_pressed(x, y, self.full_pic, self):
                return
        self.add_point(x, y)

    def _focal_end_frame(self):
        """Frame index right after the focal track (its first_frame + number of points)."""
        focal = self.listoftracks[self.focaltrackcount]
        return focal['first_frame'] + len(focal['track'])

    def _frames_ahead(self, index):
        """Frames between the focal track's end and the start of track ``index`` (negative = before)."""
        return self.listoftracks[index]['first_frame'] - self._focal_end_frame()

    def draw_window(self):
        """Render the current frame with detections, focal and candidate tracks, and show it."""
        image = cv2.imread(self.files[self.skip * self.framecount])
        self.draw_points(image, self.listofpositions[self.framecount], 0, 0, 255, 20)
        self.draw_points(image, self.listoftracks[self.focaltrackcount]['track'], 0, 255, 0, 0)
        self.draw_points(image, self.listoftracks[self.trackcount]['track'], 255, 0, 0, 0)
        self.full_pic[0:self.h, 0:self.w] = cv2.resize(image, (self.w, self.h))
        num = self._frames_ahead(self.trackcount)
        cv2.putText(self.full_pic, 'frames ahead: %d'%num, (self.w-200,20), cv2.FONT_HERSHEY_DUPLEX, 0.5, (0,0,0), 1, cv2.LINE_AA)
        cv2.imshow('pic0', self.full_pic)

    def add_point(self, x, y):
        """Append the clicked display position to the focal track and the current frame.

        Clicks outside the picture area (e.g. on the control bar between buttons)
        are ignored.
        """
        if not (0 <= x < self.w and 0 <= y < self.h):
            return
        # Exact inverse of the cv2.resize in draw_window, then flip y to count up from the bottom.
        x_full = int(x * self.full_w / self.w)
        y_full = int(self.full_h - y * self.full_h / self.h)
        point = np.array([y_full, x_full])
        focal = self.listoftracks[self.focaltrackcount]
        focal['connected'].append(len(focal['track']))
        focal['track'] = np.vstack((focal['track'], point))
        self.listofpositions[self.framecount] = np.vstack((self.listofpositions[self.framecount], point))
        self.draw_window()

    def find_trackcount(self):
        """Select as blue track the first track (other than the focal one) starting at or after the focal track's end.

        If there is no such track, the current blue index is kept, clamped to the valid range.
        """
        for i in range(len(self.listoftracks)):
            if i != self.focaltrackcount and self._frames_ahead(i) >= 0:
                self.trackcount = i
                return
        self.trackcount = max(0, min(self.trackcount, len(self.listoftracks) - 1))

    def function_caller(self, function):
        """Run the handler method named ``function``; names that are not handlers are ignored."""
        if function in self._HANDLERS:
            getattr(self, function)()

    def detect_keys(self, key):
        """Run the handler bound to ``key`` (a cv2.waitKey(...) & 0xff code), if any."""
        handler = self._KEY_HANDLERS.get(key)
        if handler is not None:
            getattr(self, handler)()

    def _last_frame(self):
        """Largest framecount whose image ``files[skip * framecount]`` exists."""
        return (len(self.files) - 1) // self.skip

    def _clamp_frame(self, n):
        """Clamp a frame index into [0, _last_frame()]."""
        return max(0, min(int(n), self._last_frame()))

    def _delete_track(self, index):
        """Delete track ``index``, keeping the focal index on the same track and both indices in range."""
        self.listoftracks = np.delete(self.listoftracks, index, 0)
        if index < self.focaltrackcount:
            self.focaltrackcount -= 1  # the focal track shifted down by one
        last = len(self.listoftracks) - 1
        self.focaltrackcount = min(self.focaltrackcount, last)
        self.trackcount = min(self.trackcount, last)

    def _select_focal(self, index):
        """Make track ``index`` focal (if valid), jump to its end frame and pick a blue track."""
        if 0 <= index < len(self.listoftracks):
            self.focaltrackcount = index
        self.framecount = self._clamp_frame(self._focal_end_frame())
        self.find_trackcount()  # make blue track start after green track ends
        self.draw_window()

    # functions for buttons / keys

    def next_frame_function(self):
        """Show the next (skipped) frame, if there is one."""
        if self.framecount + 1 <= self._last_frame():
            self.framecount += 1
        self.draw_window()

    def back_frame_function(self):
        """Show the previous (skipped) frame, if there is one."""
        if self.framecount > 0:
            self.framecount -= 1
        self.draw_window()

    def next_track_function(self):
        """Move the blue selection to the next track, if there is one."""
        if self.trackcount + 1 < len(self.listoftracks):
            self.trackcount += 1
        self.draw_window()

    def back_track_function(self):
        """Move the blue selection to the previous track, if there is one."""
        if self.trackcount > 0:
            self.trackcount -= 1
        self.draw_window()

    def next_focal_track_function(self):
        """Make the next track the focal (green) track."""
        self._select_focal(self.focaltrackcount + 1)

    def back_focal_track_function(self):
        """Make the previous track the focal (green) track."""
        self._select_focal(self.focaltrackcount - 1)

    def remove_point_function(self):
        """Remove the last point of the focal (green) track.

        The point is also removed from the current frame's detections, but only
        when it is that frame's last detection (i.e. it was added there by
        ``add_point``), so real detections are never deleted.
        """
        focal = self.listoftracks[self.focaltrackcount]
        track = focal['track']
        if len(track) > 0:
            last_row = len(track) - 1
            removed = track[last_row]
            focal['track'] = np.delete(track, last_row, 0)
            # Markers at or past the removed row no longer point at a point of this track.
            focal['connected'] = [m for m in focal['connected'] if m < last_row]
            positions = self.listofpositions[self.framecount]
            if len(positions) > 0 and np.array_equal(positions[-1], removed):
                self.listofpositions[self.framecount] = np.delete(positions, len(positions) - 1, 0)
        self.draw_window()

    def add_to_track_function(self):
        """Append the blue track's points to the focal track and delete the blue track."""
        if self.trackcount != self.focaltrackcount:  # a track cannot be appended to itself
            focal = self.listoftracks[self.focaltrackcount]
            candidate = self.listoftracks[self.trackcount]
            offset = len(focal['track'])
            focal['connected'].append(offset)  # row where the appended track starts
            # Keep the appended track's own markers, shifted to their new rows.
            focal['connected'].extend(offset + m for m in candidate['connected'])
            focal['track'] = np.vstack((focal['track'], candidate['track']))
            self._delete_track(self.trackcount)
        self.find_trackcount()  # make blue track start after green track ends
        self.draw_window()

    def remove_track_function(self):
        """Permanently delete the blue track (the last remaining track is kept)."""
        # draw_window indexes listoftracks, so it must never become empty.
        if len(self.listoftracks) > 1:
            self._delete_track(self.trackcount)
        self.draw_window()

    def blue_function(self):
        """Jump to the frame where the blue track begins."""
        self.framecount = self._clamp_frame(self.listoftracks[self.trackcount]['first_frame'])
        self.draw_window()

    def green_function(self):
        """Jump to the frame right after the end of the focal (green) track."""
        self.framecount = self._clamp_frame(self._focal_end_frame())
        self.draw_window()

In [None]:
#RUN

# Build the editor, open the window, hook up the mouse, and show the first view.
window = Window(positions_path, tracks_path, picture_folder_path, factor)
cv2.namedWindow('pic0')
cv2.setMouseCallback('pic0', window.clicked)
window.find_trackcount()
window.draw_window()

# Keep the image updating and dispatch key presses until Esc or 'x' is pressed.
# The finally clause closes the window even if a handler raises, so a failed
# run does not leave a frozen OpenCV window behind the notebook.
try:
    while True:
        cv2.imshow('pic0', window.full_pic)
        key = cv2.waitKey(2) & 0xff
        if key in (27, ord('x')):  # Esc or 'x' quits
            break
        window.detect_keys(key)
finally:
    cv2.destroyAllWindows()
    cv2.waitKey(1)  # let the GUI event loop process the close (needed on some platforms)