In [1]:
import json
import cv2
import numpy as np
import pandas as pd
from imageai.Detection import ObjectDetection
import requests as req
import matplotlib.pyplot as plt
import matplotlib.image as img
import os
from random import randint
from glob import glob

In [2]:
# Folder that receives the annotated videos, person crops and JSON crossing logs.
# exist_ok=True makes this cell idempotent and avoids a check-then-create race;
# it still raises FileExistsError if 'output' exists as a regular file.
os.makedirs('output', exist_ok=True)

In [3]:
# Download the YOLOv3 weights used by ImageAI (skipped if the file already exists).
def getModel(model_url='https://github.com/OlafenwaMoses/ImageAI/releases/download/1.0/yolo.h5',
             model_path='yolo.h5', timeout=30, chunk_size=1 << 20):
    """Fetch the YOLOv3 ``.h5`` weights into ``model_path`` unless it already exists.

    The body is streamed to ``model_path + '.part'`` and only renamed to
    ``model_path`` once it has been written completely. An interrupted download
    therefore never leaves a truncated file that the existence check would later
    accept as a valid model.

    Args:
        model_url: URL to download the weights from.
        model_path: destination file (relative to the working directory by default).
        timeout: seconds ``requests`` waits for the connection and between
            received bytes.
        chunk_size: bytes written per iteration, so the file is never held
            fully in memory.

    Returns:
        ``model_path``.

    Raises:
        requests.HTTPError: if the server answers with an error status (instead
            of silently saving the error page as the model file).
        requests.RequestException: on connection problems or timeouts.
    """
    if os.path.exists(model_path):
        return model_path
    partial_path = model_path + '.part'
    with req.get(model_url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        with open(partial_path, 'wb') as outfile:
            for chunk in response.iter_content(chunk_size=chunk_size):
                outfile.write(chunk)
    # Atomic rename: readers see either no file or the complete file.
    os.replace(partial_path, model_path)
    return model_path

In [4]:
#getModel()

In [5]:
# Build the ImageAI YOLOv3 detector from the weights file in the working directory.
def loadModel():
    """Create and load an ImageAI YOLOv3 detector together with a people-only filter.

    The weights are read from ``yolo.h5`` in the current working directory
    (the file ``getModel`` downloads).

    Returns:
        (detector, peopleOnly): the loaded ``ObjectDetection`` instance and the
        custom-object selection produced by ``CustomObjects(person=True)``.
    """
    weights_path = os.path.join(os.getcwd(), "yolo.h5")
    yolo = ObjectDetection()
    yolo.setModelTypeAsYOLOv3()
    yolo.setModelPath(weights_path)
    yolo.loadModel()
    return yolo, yolo.CustomObjects(person=True)



In [6]:
trackerTypes = ['KCF', 'CSRT']

def createTracker(trackerType):
    """Return a new OpenCV tracker for ``trackerType`` ('KCF' or 'CSRT').

    For any other name, print the list of supported names and return None.
    """
    # Guard clauses: each supported name returns immediately.
    if trackerType == trackerTypes[0]:
        return cv2.TrackerKCF_create()
    if trackerType == trackerTypes[1]:
        return cv2.TrackerCSRT_create()

    print('Incorrect tracker name')
    print('Available trackers are:')
    for known_name in trackerTypes:
        print(known_name)
    return None

In [7]:
# Load the YOLOv3 people detector once; `detection` below binds these two
# objects as its default arguments, so this cell must run before that one.
detector, peopleOnly = loadModel()

In [8]:


# Run the object detector on one frame and convert its output into tracker-ready boxes.
def detection(image, detector=detector, detection_objects=peopleOnly,
              minimum_percentage_probability=30):
    """Detect objects in ``image`` and return ``(x, y, w, h)`` boxes plus random colours.

    Note: the defaults for ``detector`` and ``detection_objects`` are the global
    objects that existed when this cell was executed (Python binds default
    values at definition time), so the model-loading cell must run first.

    Args:
        image: frame passed to ImageAI with ``input_type="array"`` (in this
            notebook, a frame from ``cv2.VideoCapture.read``).
        detector: loaded ImageAI ``ObjectDetection`` instance.
        detection_objects: ImageAI custom-object selection (people only by default).
        minimum_percentage_probability: detections with a lower confidence
            percentage are discarded (forwarded to ImageAI).

    Returns:
        (BBoxes, colors): ``BBoxes`` is a list of ``(x, y, w, h)`` int tuples,
        one per detection, as expected by the OpenCV trackers; ``colors`` is a
        parallel list of colour tuples for ``cv2.rectangle`` with every channel
        in [64, 255] (presumably to avoid near-black boxes).
    """
    # The annotated image ImageAI returns is not needed; only the detections are used.
    _, detections = detector.detectObjectsFromImage(
        custom_objects=detection_objects,
        input_type="array",
        input_image=image,
        output_type="array",
        minimum_percentage_probability=minimum_percentage_probability)
    BBoxes = []
    colors = []
    for found in detections:
        # box_points is (x1, y1, x2, y2); cast every coordinate so the tuple is all ints.
        x1, y1, x2, y2 = (int(v) for v in found["box_points"])
        BBoxes.append((x1, y1, x2 - x1, y2 - y1))
        colors.append((randint(64, 255), randint(64, 255), randint(64, 255)))
    return BBoxes, colors
    


In [9]:
def cross(p1, p2, p3):
    """Return the z-component of the 2-D cross product (p2 - p1) x (p3 - p1).

    The sign tells which side of the directed line p1->p2 the point p3 lies on;
    zero means the three points are collinear.
    """
    ax, ay = p2[0] - p1[0], p2[1] - p1[1]
    bx, by = p3[0] - p1[0], p3[1] - p1[1]
    return ax * by - bx * ay


def segment(p1, p2, p3, p4):
    """Return 1 if the closed segments p1-p2 and p3-p4 intersect (touching counts), else 0.

    A cheap bounding-box overlap test runs first. Only when it passes are the
    cross-product "straddle" tests evaluated: each segment's endpoints must lie
    on opposite sides of (or on) the other segment's line.
    """
    boxes_overlap = (
        max(p1[0], p2[0]) >= min(p3[0], p4[0])
        and max(p3[0], p4[0]) >= min(p1[0], p2[0])
        and max(p1[1], p2[1]) >= min(p3[1], p4[1])
        and max(p3[1], p4[1]) >= min(p1[1], p2[1])
    )
    if not boxes_overlap:
        return 0

    straddles = (cross(p1, p2, p3) * cross(p1, p2, p4) <= 0
                 and cross(p3, p4, p1) * cross(p3, p4, p2) <= 0)
    return 1 if straddles else 0

def check(l1, l2, sq):
    """Return 1 if the segment l1-l2 touches the axis-aligned rectangle ``sq``, else 0.

    ``sq`` is ``(x_min, y_min, x_max, y_max)``; borders count as inside.
    """
    def inside(pt):
        return sq[0] <= pt[0] <= sq[2] and sq[1] <= pt[1] <= sq[3]

    # Step 1: an endpoint inside the rectangle is already a hit.
    if inside(l1) or inside(l2):
        return 1

    # Step 2: with both endpoints outside, the segment can only pass through the
    # rectangle by crossing one of its diagonals (they split it into triangles).
    main_diagonal = ([sq[0], sq[1]], [sq[2], sq[3]])
    anti_diagonal = ([sq[2], sq[1]], [sq[0], sq[3]])
    hits = segment(l1, l2, *main_diagonal) or segment(l1, l2, *anti_diagonal)
    return 1 if hits else 0

In [10]:
def _read_fps(video):
    """Return the frame rate reported by ``video`` (OpenCV 2 used a different property name)."""
    major_ver = int(cv2.__version__.split('.')[0])
    if major_ver < 3:
        return video.get(cv2.cv.CV_CAP_PROP_FPS)
    return video.get(cv2.CAP_PROP_FPS)


def _crop(frame, box):
    """Return the part of ``frame`` covered by ``box`` = (x, y, w, h), clipped to the image.

    Tracker boxes can drift partly outside the frame; unclipped negative
    coordinates would be interpreted by NumPy slicing as offsets from the end.
    """
    x, y, w, h = (int(v) for v in box)
    frame_h, frame_w = frame.shape[:2]
    return frame[max(y, 0):min(y + h, frame_h), max(x, 0):min(x + w, frame_w)]


def objectTracker(filename, p1, p2, output_dir='output'):
    """Track detected people through a video and log when they first touch a counting line.

    Objects are detected once on the first frame and then followed by one CSRT
    tracker each. The first time a tracked box touches the segment ``p1``-``p2``,
    its crop is saved as a PNG and ``{"id": index, "time": seconds}`` is appended
    to a JSON log. An annotated copy of the video is written, and a preview
    window is shown (press ESC to stop early).

    Args:
        filename: path of the input video.
        p1, p2: ``(x, y)`` integer pixel endpoints of the counting line.
        output_dir: folder for ``output_<name>``, ``<name>Person_<id>.png`` and
            ``<name>.json``; created if missing.

    Raises:
        IOError: if the first frame of ``filename`` cannot be read.
    """
    base = os.path.basename(filename)
    os.makedirs(output_dir, exist_ok=True)
    video = cv2.VideoCapture(filename)
    out = None
    try:
        ok, frame = video.read()
        if not ok:
            raise IOError('Could not read the first frame of ' + filename)

        bboxes, colors = detection(frame)
        frame_width = int(video.get(3))   # CAP_PROP_FRAME_WIDTH
        frame_height = int(video.get(4))  # CAP_PROP_FRAME_HEIGHT
        fps = _read_fps(video)

        multiTracker = cv2.MultiTracker_create()
        for bbox in bboxes:
            multiTracker.add(createTracker('CSRT'), frame, bbox)

        out = cv2.VideoWriter(os.path.join(output_dir, 'output_' + base),
                              cv2.VideoWriter_fourcc(*'XVID'), fps,
                              (frame_width, frame_height))
        # Create and size the preview window once rather than on every frame.
        cv2.namedWindow('Tracking', cv2.WINDOW_NORMAL)
        cv2.resizeWindow('Tracking', int(frame_width * 0.5), int(frame_height * 0.5))

        counted = set()
        record = []
        frame_count = 0
        while True:
            ok, frame = video.read()
            if not ok:
                break
            frame_count += 1
            _, boxes = multiTracker.update(frame)

            # Box corners get their own names so the counting-line endpoints
            # p1/p2 are never overwritten.
            corners = [((int(b[0]), int(b[1])), (int(b[0] + b[2]), int(b[1] + b[3])))
                       for b in boxes]

            # Crossing checks and crops happen before any drawing, so saved
            # crops do not contain the annotation overlays.
            for i, (top_left, bottom_right) in enumerate(corners):
                # check() expects (x_min, y_min, x_max, y_max).
                if i in counted or not check(p1, p2, top_left + bottom_right):
                    continue
                counted.add(i)
                crop = _crop(frame, boxes[i])
                if crop.size:  # cv2.imwrite fails on an empty image
                    cv2.imwrite(os.path.join(output_dir, base + 'Person_' + str(i) + '.png'), crop)
                seconds = frame_count / fps if fps > 0 else float('nan')
                record.append({"id": i, "time": str(seconds)})

            for i, (top_left, bottom_right) in enumerate(corners):
                cv2.rectangle(frame, top_left, bottom_right, colors[i], 2, 1)
            cv2.line(frame, p1, p2, (255, 0, 0), 12, lineType=8, shift=0)
            out.write(frame)
            cv2.imshow('Tracking', frame)
            if cv2.waitKey(1) & 0xFF == 27:  # ESC
                break

        with open(os.path.join(output_dir, base + '.json'), 'w') as f:
            json.dump(record, f)
    finally:
        # Release capture/writer and close windows even if tracking raised.
        video.release()
        if out is not None:
            out.release()
        cv2.destroyAllWindows()
        
       
    
    
    

In [11]:
# Process every .mp4 under videos/ (sorted for a deterministic order).
# The counting line runs from (x=0, y=20) to (x=0, y=2000) in pixel
# coordinates, i.e. vertically along the left edge of the frame; adjust as needed.
fileLists = sorted(glob(os.path.join("videos", "*.mp4")))
for path in fileLists:
    objectTracker(path, (0, 20), (0, 2000))
    