In [1]:
import argparse
import json
import os 

import cv2
import numpy as np
import PIL.ImageColor as ImageColor

from tqdm import tqdm
from glob import glob

from norfair import Detection, Tracker

In [2]:
def euclidean_distance(detection, tracked_object):
    """Return the Euclidean distance between a detection and a tracked object.

    Only the first row of ``detection.points`` and of ``tracked_object.estimate``
    is compared (for detections built by ``get_coordinates`` that row is the box
    center); all other rows are ignored.
    """
    offset = detection.points[0] - tracked_object.estimate[0]
    return np.linalg.norm(offset)

def get_coordinates(candidate):
    """Build the 3x2 point array that is tracked for one detection row.

    ``candidate[2:6]`` is read as ``(xmin, ymin, w, h)``. Rows of the result:
    0 -> box center (each coordinate truncated with ``int``),
    1 -> ``(xmin, ymin)``,
    2 -> ``(w, h)``.
    """
    left, top, width, height = candidate[2:6]
    center = [int(left + width / 2), int(top + height / 2)]
    return np.array([center, [left, top], [width, height]])

def trunc_coords(bbox, img_h, img_w):
    """Clip an ``(xmin, ymin, w, h)`` box to the image area ``[0, img_w] x [0, img_h]``.

    The top-left and bottom-right corners are clamped independently, and the box
    is returned again as ``[xmin, ymin, w, h]`` with each value passed through ``int``.
    A box lying entirely outside the image collapses to zero width and/or height.
    """
    def clamp(value, upper):
        # Same result as "if < 0 -> 0; then if > upper -> upper".
        return min(max(value, 0), upper)

    left, top, width, height = bbox
    x0, x1 = clamp(left, img_w), clamp(left + width, img_w)
    y0, y1 = clamp(top, img_h), clamp(top + height, img_h)
    return [int(x0), int(y0), int(x1 - x0), int(y1 - y0)]

def bbox_valid(bbox, img_h, img_w):
    """Return True if an ``(xmin, ymin, w, h)`` box has positive extent.

    Boxes squashed onto an image border (both x-edges at 0 or ``img_w``, or both
    y-edges at 0 or ``img_h``) are rejected explicitly, as are boxes whose right
    edge is not to the right of the left edge or whose bottom edge is not below
    the top edge. The border checks are implied by the extent check but are kept
    explicit for readability.
    """
    x0, y0, width, height = bbox
    x1 = x0 + width
    y1 = y0 + height

    on_border = (
        x0 == 0 and x1 == 0,
        y0 == 0 and y1 == 0,
        y0 == img_h and y1 == img_h,
        x0 == img_w and x1 == img_w,
    )
    if any(on_border):
        return False
    # Written as a negation (not "x0 < x1 and y0 < y1") so NaN inputs yield True
    # exactly as before.
    return not (x0 >= x1 or y0 >= y1)

In [3]:
def _load_detections_by_frame(gt_path):
    """Parse a comma-separated detection file into ``{frame_id: [row, ...]}``.

    Every field is converted to ``float``; column 0 is the frame id. Rows keep
    their file order within a frame. Blank lines are skipped.
    """
    by_frame = {}
    with open(gt_path) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # tolerate blank / trailing lines instead of crashing in float('')
            row = [float(v) for v in line.split(',')]
            by_frame.setdefault(row[0], []).append(row)
    return by_frame


def _extra_columns(data):
    """Return the trailing columns of a detection row that are appended to its output line.

    9-column rows contribute their last 3 values as ints; 10-column rows contribute
    column -4 unchanged followed by the last 3 values as ints.

    Raises:
        ValueError: if the row has neither 9 nor 10 columns.
    """
    if len(data) == 9:
        return [int(x) for x in data[-3:]]
    if len(data) == 10:
        return [data[-4]] + [int(x) for x in data[-3:]]
    # Previously this only printed a message and then reused a stale `adds` from an
    # earlier object (or hit a NameError); fail loudly instead of writing bad rows.
    raise ValueError('Invalid format: expected 9 or 10 columns, got %d' % len(data))


def track_dir(gt_path, out_path, tracker, W=400, H=320, split='val',
              score_threshold=0.2):
    """Run ``tracker`` over one per-sequence detection file and write the tracks.

    Args:
        gt_path: comma-separated detection file, one detection per line. Column 0
            is the frame id, columns 2-5 are ``xmin, ymin, w, h``, column 6 is the
            score, and each row has 9 or 10 columns in total.
        out_path: output text file; its parent directory is created if needed.
            Each output line is ``frame, track_id, xmin, ymin, w, h`` followed by
            the extra columns of the object's last matched detection.
        tracker: norfair ``Tracker``; ``update`` is called once per frame id present
            in ``gt_path``, in ascending frame order.
        W, H: image width/height used to clip and validate estimated boxes.
        split: unused; kept for backward compatibility.
        score_threshold: detections whose column-6 score is below this are dropped.

    Raises:
        ValueError: if a matched detection row has neither 9 nor 10 columns.
    """
    out_dir = os.path.dirname(out_path)
    if out_dir:  # os.makedirs('') raises when out_path has no directory part
        os.makedirs(out_dir, exist_ok=True)

    # Group once (O(N)) instead of rescanning every annotation for each frame.
    detections_by_frame = _load_detections_by_frame(gt_path)

    mot_objs = []
    # The tracker is stateful, so frames must be fed in temporal order. Iterating a
    # bare set() of frame ids follows hash order, which is not guaranteed ascending.
    for frame in sorted(detections_by_frame):
        candidates = [x for x in detections_by_frame[frame] if x[6] >= score_threshold]
        detections = [Detection(get_coordinates(candidate), data=candidate)
                      for candidate in candidates]
        tracked_objects = tracker.update(detections=detections)

        for tracked_object in tracked_objects:
            if not tracked_object.live_points.any():
                continue

            estimate = tracked_object.estimate
            # Rows 1 and 2 hold the tracked (xmin, ymin) and (w, h), matching the
            # point layout built by get_coordinates.
            bbox = [int(estimate[1, 0]), int(estimate[1, 1]),
                    int(estimate[2, 0]), int(estimate[2, 1])]
            bbox = trunc_coords(bbox, H, W)
            if not bbox_valid(bbox, H, W):
                continue

            adds = _extra_columns(tracked_object.last_detection.data)
            mot_objs.append([int(frame), tracked_object.id] + bbox + adds)

    with open(out_path, 'w') as f:
        for obj in mot_objs:
            f.write(','.join(str(x) for x in obj) + '\n')



In [None]:
# Track every per-sequence detection file of this run and write one result file
# per sequence. A fresh Tracker is built for each sequence so no track state
# carries over between sequences.
DETECTION_DIR = './results/011-best-val/detection'
TRACK_OUT_DIR = './results/011-best-val/norfair'

# sorted() gives a deterministic processing order (glob order is filesystem-dependent).
for det_path in tqdm(sorted(glob(os.path.join(DETECTION_DIR, '*.txt')))):
    tracker = Tracker(distance_function=euclidean_distance,
                      distance_threshold=30,  # max center distance, same units as the box coordinates
                      hit_inertia_min=5,
                      hit_inertia_max=20,
                      initialization_delay=1)

    seq = os.path.splitext(os.path.basename(det_path))[0]
    out_path = os.path.join(TRACK_OUT_DIR, '%s.txt' % seq)
    track_dir(det_path, out_path, tracker)

 31%|████████████▉                             | 4/13 [05:30<10:48, 72.11s/it]