# Label video frames for training supervised ML models

Simple approach to loading in a movie and allowing the user to label each frame using keystrokes assigned to particular behaviors

In [1]:
import cv2
import numpy as np
from tqdm.notebook import tqdm
from tqdm import tnrange
import numpy as np

In [59]:
#####################################################
####### Load Video Frames ##########################
#####################################################
def LoadVideoFrames(video_file,num_frames=None):
    video = cv2.VideoCapture(file)
    frames = []
    
    if num_frames is None:
        num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))

    for i in tqdm(range(num_frames),desc='Loading video'):
        # Read video capture
        ret, frame = video.read()
        gray=cv2.cvtColor(frame,cv2.COLOR_BGR2GRAY)

        frames.append(gray)

        key = cv2.waitKey(1)

    video.release()
    
    return frames

#####################################################
####### Play Video Frames ###########################
#####################################################
def PlayVideoFrames(frames):
    
    playVideo = True

    frame_counter = 0
    while playVideo is True:

        frame = frames[frame_counter]
        cv2.imshow('video',frame)

        key = cv2.waitKey(0)

        while key not in [ord('q'),ord(','),ord('.')]:
            key = cv2.waitKey(0)

        if key == ord('.'):
            frame_counter += 1
        elif key == ord(','):
            frame_counter -= 1
        if key == ord('q'):
            break
            
    cv2.destroyAllWindows()

#####################################################
####### Play & Label Video Frames ###################
#####################################################
def on_trackbar(val):
    return

def setFrameCounter(frame_counter,num_frames):
    
    #if the user has scrolled past the end, go to the beginning 
    if frame_counter == num_frames:
        frame_counter = 0 
    #if user has scrolled to the left of the beginning, go to the end
    elif frame_counter == -1:
        frame_counter = num_frames -1
    
    return frame_counter

def PlayAndLabelFrames(frames,label_dict = {'w':'walking','t':'turning','s':'standing'},return_labeled_frames=False):
    '''
    Set up variables
    '''
    #create numpy array to store the labels. Initialize as strings of zeros
    labels = np.zeros(len(frames)).astype('str')

    #get the key ords and names for each label
    label_ords = [ord(k) for k in list(label_dict.keys())]
    label_names = list(label_dict.values())
    #create a dictionary that maps the key ords to the label names
    #i.e. replacing keystrokes with key ords as the dict keys
    label_key_dict = {}
    for label_ord,label_name in zip(label_ords,label_names):
        label_key_dict[label_ord] = label_name
    
    #get number of frames
    num_frames = len(frames)
    
    #initialize frame_counter, set PlayVideo boolean to True, and start displaying video
    #for labeling
    playVideo = True
    frame_counter = 0

    # create display window
    cv2.namedWindow('Video',cv2.WINDOW_NORMAL)
    cv2.resizeWindow('Video',800,800)
    cv2.createTrackbar('frame', 'Video', 0,num_frames,on_trackbar)
    '''
    Play & Label Video
    '''
    while playVideo is True:

        #get current frame & display it
        frame_counter = cv2.getTrackbarPos('frame','Video')
        frame = frames[frame_counter]
        cv2.imshow('Video',frame)

        #wait for keypress
        key = cv2.waitKey(0)

        '''
        Check to see if the user pressed any of the label keys
        '''
        if key in label_ords:
            #get the label name
            label = label_key_dict[key]
            
            '''
            #annotate the frame with the label text
            cv2.rectangle(frame,(0,1024),(250,950),(0,0,0),-1) #need a solid background so that...
            #...the labels can be overwritten
            cv2.putText(frame,label,(0,1000),cv2.FONT_HERSHEY_COMPLEX,1,(255,255,255),2,cv2.LINE_AA)
            '''
            
            #annotate the frame with the label text
            cv2.rectangle(frame,(0,900),(250,800),(0,0,0),-1) #need a solid background so that...
            #...the labels can be overwritten
            cv2.putText(frame,label,(0,875),cv2.FONT_HERSHEY_COMPLEX,1,(255,255,255),2,cv2.LINE_AA)
            
            
            #update the frame (with annotation)
            frames[frame_counter] = frame
            #update the label array with current label
            labels[frame_counter] = label

            '''
        Now check to see if the user to trying to control the playback
        '''
        elif key == ord(','): # if `<` then go back
            frame_counter -= 1
            frame_counter = setFrameCounter(frame_counter,num_frames)
            cv2.setTrackbarPos("frame","Video", frame_counter)

        elif key == ord('.'): # if `>` then advance
            frame_counter += 1
            frame_counter = setFrameCounter(frame_counter,num_frames)
            cv2.setTrackbarPos("frame","Video", frame_counter)

        elif key == ord('q'): #if `q` then quit
            playVideo = False


    #close any opencv windows    
    cv2.destroyAllWindows()
    
    #if return_labeled_frames is True, return them along with the labels. Else just returns labels
    if return_labeled_frames:
        return labels,frames
    else:
        #return labels
        return labels
    

#####################################################
############# Interpolate Labels ####################
#####################################################

def interpolate_labels(labels):
    '''
    Interpolate frame labels for unlabeled frames where the previous labeled frame and the next labeled frame have the same label. 
    IE. a sequence of ['standing','nolabel','nolabel','standing','walking','nolabel'] is transformed to:
    ['standing','standing','standing','standing','walking','nolabel']
    
    Note that the last unlabeled frame is not interpolated, since there was no labeled frame following it. The same thing happens
    for unlabeled frames that occur before the first labeled frame. 
    '''
    nolab_frames = np.where(labels == '0.0')[0]
    lab_frames = np.where(labels != '0.0')[0]
    labels_interp = labels.copy()

    for nolab_frame in nolab_frames:

        if ((nolab_frame > 0) & (np.where(lab_frames > nolab_frame)[0].size>0)):

            label_prev = labels_interp[nolab_frame-1]
            label_next = labels_interp[lab_frames[np.where(lab_frames > nolab_frame)[0][0]]]

            if label_prev == label_next:
                labels_interp[nolab_frame] = label_prev
    
    return labels_interp


#####################################################
########## Annotate Frames with Labels ##############
#####################################################
def annotate_frames(frames,labels):
    
    frames_out = frames.copy()
    
    num_frames = len(frames)
    num_labels = labels.shape[0]
    
    assert num_frames == num_labels,'number of frames must equal number of labels'
    

    for i in range(num_frames):
        
        frame = frames_out[i]
        label = labels[i]
        
        if label is not '0.0':
            '''
            for 1024x1280
            #annotate the frame with the label text
            cv2.rectangle(frame,(0,1024),(250,950),(0,0,0),-1) #need a solid background so that...
            #...the labels can be overwritten
            cv2.putText(frame,label,(0,1000),cv2.FONT_HERSHEY_COMPLEX,1,(255,255,255),2,cv2.LINE_AA)
            '''
            #annotate the frame with the label text
            cv2.rectangle(frame,(0,900),(250,800),(0,0,0),-1) #need a solid background so that...
            #...the labels can be overwritten
            cv2.putText(frame,label,(0,875),cv2.FONT_HERSHEY_COMPLEX,1,(255,255,255),2,cv2.LINE_AA)
            
            
            #overwrite the frame
            frames_out[i] = frame
    
    return frames_out

In [3]:
label_dict = {'w':'walking','t':'turning','s':'standing'}

In [56]:
#%%time
#file = '/home/sneufeld/Desktop/752_openfield.avi'
file = '/home/sneufeld/Downloads/Basler_acA1300-60gm__21503351__20191212_131714734.mp4'
num_frames = 100
frames = LoadVideoFrames(file,num_frames)

HBox(children=(IntProgress(value=0, description='Loading video', style=ProgressStyle(description_width='initia…




In [48]:
frames[0].shape

(1024, 1280)

In [54]:
frames[0].shape

(900, 1100)

In [28]:
cv2.destroyAllWindows()

In [60]:
labels,frames_new = PlayAndLabelFrames(frames,return_labeled_frames=True)

In [58]:
labels

array(['walking', '0.0', '0.0', '0.0', '0.0', '0.0', '0.0', 'walking',
       '0.0', '0.0', '0.0', 'walking', '0.0', '0.0', '0.0', '0.0', '0.0',
       '0.0', '0.0', '0.0', '0.0', '0.0', '0.0', '0.0', '0.0', '0.0',
       '0.0', '0.0', '0.0', '0.0', '0.0', '0.0', '0.0', '0.0', '0.0',
       '0.0', '0.0', '0.0', '0.0', '0.0', '0.0', '0.0', '0.0', '0.0',
       '0.0', '0.0', '0.0', '0.0', '0.0', '0.0', '0.0', '0.0', '0.0',
       '0.0', '0.0', '0.0', '0.0', '0.0', '0.0', '0.0', '0.0', '0.0',
       '0.0', '0.0', '0.0', '0.0', '0.0', '0.0', '0.0', '0.0', '0.0',
       '0.0', '0.0', '0.0', '0.0', '0.0', '0.0', '0.0', 'standing',
       'standing', 'standing', '0.0', '0.0', '0.0', '0.0', '0.0', '0.0',
       '0.0', '0.0', '0.0', '0.0', '0.0', '0.0', '0.0', '0.0', '0.0',
       '0.0', '0.0', '0.0', '0.0'], dtype='<U32')

In [25]:
def frames2video(video_name,frames,fps):
    
    size = frames[0].shape
    print(size)
    #video = cv2.VideoWriter(video_name,cv2.VideoWriter_fourcc(*'DIVX'),fps,size)
    video = cv2.VideoWriter(video_name,-1,fps,(size[1],size[0]))
    for frame in frames:
        video.write(frame)
    video.release()

In [34]:
frames[0].shape

(1024, 1280)

In [37]:
video_name = '/home/sneufeld/Desktop/video.avi'

size = frames[0].shape
width = size[1]
height = size[0]
print(width)
print(height)
fps = 30.0
#video = cv2.VideoWriter(video_name,cv2.VideoWriter_fourcc(*'DIVX'),fps,size)
video = cv2.VideoWriter(video_name,cv2.VideoWriter_fourcc('M','J','P','G'),fps,(width,height))
for frame in frames:
    video.write(frame)
video.release()

1280
1024


In [91]:
labels_interp

array(['0.0', '0.0', '0.0', '0.0', 'turning', 'turning', 'turning',
       'turning', 'turning', 'turning', 'turning', 'turning', 'turning',
       'turning', 'turning', 'turning', 'turning', 'turning', 'turning',
       'turning', 'turning', 'turning', 'turning', 'turning', 'turning',
       'turning', 'turning', 'turning', 'turning', 'turning', 'turning',
       'turning', 'turning', 'turning', 'turning', 'turning', 'turning',
       'turning', 'turning', 'turning', 'turning', 'turning', 'turning',
       'turning', 'turning', 'turning', 'turning', 'turning', 'turning',
       'turning', 'turning', 'turning', 'turning', 'turning', 'turning',
       'turning', 'turning', 'turning', 'turning', 'turning', 'turning',
       'turning', 'turning', 'turning', 'turning', 'turning', 'turning',
       'walking', 'walking', 'walking', 'walking', 'walking', 'walking',
       'standing', 'standing', 'standing', 'standing', 'standing',
       'standing', 'standing', 'standing', 'standing', 'standi

In [58]:
np.where(labels == 'walking')[0]

array([67, 68, 69, 70, 71, 72])

In [59]:
np.where(labels_interp == 'walking')[0]

array([67, 68, 69, 70, 71, 72])

In [24]:
lab_frames[0]

array([ 4,  5,  6,  7,  8,  9, 12, 19, 26, 55, 59, 60, 61, 62, 63, 64, 65,
       66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 84])

In [33]:
nolab_frames

array([ 0,  1,  2,  3, 10, 11, 13, 14, 15, 16, 17, 18, 20, 21, 22, 23, 24,
       25, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42,
       43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 56, 57, 58, 78, 79,
       80, 81, 82, 83, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97,
       98, 99])

In [27]:
lab_frames

array([ 4,  5,  6,  7,  8,  9, 12, 19, 26, 55, 59, 60, 61, 62, 63, 64, 65,
       66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 84])

In [32]:
labels[lab_frames[np.where(lab_frames > 10)[0][0]]]

'turning'

In [9]:
for nolab_start,nolab_end in zip(unlabeled_frames[:-1],unlabeled_frames[1:]:
    
    if nolab_start > 0:
        #previous label is easy
        label_prev = labels[nolab_start - 1]
        #next label is harder - requires peeking into the labeled frame indices array
                            
        label_next = labels[nolab_end + 1]
        
        if label_prev == label_next:
            
        

(array([ 0,  1,  2,  3, 10, 11, 13, 14, 15, 16, 17, 18, 20, 21, 22, 23, 24,
        25, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42,
        43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 56, 57, 58, 78, 79,
        80, 81, 82, 83, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97,
        98, 99]),)

In [12]:
nolab_start = labels[4]
nolab_start

'turning'

In [10]:
nolab_frames = np.where(labels != '0.0')
nolab_frames

(array([ 4,  5,  6,  7,  8,  9, 12, 19, 26, 55, 59, 60, 61, 62, 63, 64, 65,
        66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 84]),)