## Multi Object Tracking sample
This notebook is an example of how to perform object detection with multi-object tracking (MOT) on a video file to count vehicle traffic.
**ByteTrack** is used for multi-object tracking (see https://github.com/ifzhang/ByteTrack).

### This sample uses the following external packages, which need to be installed:
1. **cython_bbox**: `pip install -e git+https://github.com/samson-wang/cython_bbox.git#egg=cython-bbox`
1. **lap**: `pip install lap`


### Specify the model and the input video here

In [None]:
# name of the object detection model to be loaded from the model zoo and used for inference
model_name = "yolo_v5s_coco--512x512_quant_n2x_orca_1"
# input video file (relative path); the annotated output video is written next to it
# under the same name with an "_annotated" suffix
input_filename = 'images/Traffic.mp4'

In [None]:
import degirum as dg
import numpy as np
import mytools, cv2
from pathlib import Path
import IPython.display
# lap and cython_bbox are the external packages listed at the top of this notebook.
# They are loaded through mytools.import_optional_package before the tracker modules are imported
# (presumably because the mot package depends on them -- confirm against mot.byte_tracker).
lap = mytools.import_optional_package("lap")
cython_bbox = mytools.import_optional_package("cython_bbox")
from mot.byte_tracker import BYTETracker
from mot.basetrack import BaseTrack

### Specify inference option here

In [None]:
# Please uncomment and edit one of the following inference options to specify your system configuration case according to
# https://cs.degirum.com/doc/0.5.0/degirum.html#system-configuration-for-specific-use-cases

# 1. DeGirum Cloud Zoo inference:
#zoo = dg.connect_model_zoo("dgcps://cs.degirum.com", token=mytools.token_get())

# 2. AIServer inference via IP address using models from DeGirum Cloud model zoo
#zoo = dg.connect_model_zoo(("192.168.0.7", "https://cs.degirum.com/degirum_com/public"), token=mytools.token_get())

# 3. AIServer inference via IP address using local model zoo
#zoo = dg.connect_model_zoo("192.168.0.1")

# 4. ORCA board installed locally using models from DeGirum Cloud Model Zoo
#zoo = dg.connect_model_zoo("https://cs.degirum.com/degirum_com/public", token=mytools.token_get())

# 5. Local inference with locally deployed model
#zoo = dg.connect_model_zoo("full/path/to/model.json")

In [None]:
# load object detection model
# NOTE: `zoo` is only defined after one of the options in the inference-option cell above
# has been uncommented and that cell has been run
model = zoo.load_model(model_name)

# set model parameters
model.image_backend = 'opencv' # select OpenCV backend: needed to have overlay image in OpenCV format
model.input_numpy_colorspace = 'BGR' # presumably matches the channel order of the OpenCV video frames -- confirm
# NOTE(review): the prediction loop below draws its own boxes on batch_result.image;
# whether the two overlay_* settings affect that image is not visible from this notebook
model.overlay_show_probabilities = True
model.overlay_line_width = 1
# NOTE(review): this writes to the private _model_parameters object directly;
# prefer a public model property for the input image format if the SDK version in use provides one
model._model_parameters.InputImgFmt = ['JPEG']

In [None]:
# Location of the source video, taken from the notebook parameters.
orig_path = Path(input_filename)
# The annotated video is saved beside the source as "<stem>_annotated<suffix>";
# assign a different path here to write it somewhere else.
ann_path = orig_path.with_name(f"{orig_path.stem}_annotated{orig_path.suffix}")

In [None]:
class dict_dot_notation(dict):
    """Dictionary whose entries can also be read and written as attributes.

    Accepts the same constructor arguments as ``dict``. After construction,
    ``d.key`` and ``d['key']`` refer to the same stored value.
    """

    def __init__(self, *args, **kwargs):
        dict.__init__(self, *args, **kwargs)
        # Use the mapping itself as the instance attribute namespace, so that
        # attribute access and item access share a single storage.
        self.__dict__ = self

def intersect(a, b, c, d):
    """Return True if segment a-b and segment c-d share at least one point.

    Each argument is an (x, y) pair. Segments are treated as closed, so
    touching at an endpoint counts as an intersection.

    The test checks that c and d are not strictly on the same side of the line
    through a and b, and that a and b are not strictly on the same side of the
    line through c and d. When all four points lie on one line (including
    zero-length segments) the side tests cannot separate the segments, so the
    function additionally requires their x and y extents to overlap. Without
    that check, collinear but disjoint segments would be reported as crossing.
    """
    def side(p, q, r):
        # Cross-product sign: which side of the line through p and q the point
        # r lies on; 0 means r is on that line (or p == q).
        return (p[0] - q[0]) * (r[1] - p[1]) - (p[1] - q[1]) * (r[0] - p[0])

    c_side = side(a, b, c)
    d_side = side(a, b, d)
    if c_side * d_side > 0:
        return False

    a_side = side(c, d, a)
    b_side = side(c, d, b)
    if a_side * b_side > 0:
        return False

    if c_side == 0 and d_side == 0 and a_side == 0 and b_side == 0:
        # Collinear (or degenerate) case: the segments intersect only if their
        # extents overlap on both axes.
        return all(
            max(min(a[i], b[i]), min(c[i], d[i])) <= min(max(a[i], b[i]), max(c[i], d[i]))
            for i in (0, 1)
        )
    return True

In [None]:
# AI prediction loop
# Runs the detection model on the frames of the input video, tracks the detections with BYTETracker,
# counts tracks whose trail crosses a horizontal count line, and writes every annotated frame
# to the output video at ann_path (input file name with the "_annotated" suffix) while also displaying it.
with mytools.open_video_stream(input_filename) as stream:
    
    image_w = int(stream.get(cv2.CAP_PROP_FRAME_WIDTH))
    image_h = int(stream.get(cv2.CAP_PROP_FRAME_HEIGHT))

    # count line (x, y): horizontal, spanning the full frame width at two thirds of the frame height
    line_start = (0, 2 * image_h // 3)
    line_end = (image_w, line_start[1])

    # counters for each direction
    # (left/right follow the sign of the x displacement, top/bottom the sign of the y displacement)
    left = right = top = bottom = 0
    
    BaseTrack._count = 0 # reset track counter
    
    with mytools.Display("MoT") as display, \
         mytools.open_video_writer(str(ann_path), image_w, image_h) as writer:
    
        fps = 30 # you can specify input video FPS if you want
        tracker = BYTETracker(
            args=dict_dot_notation({
                'track_thresh': 0.3,
                'track_buffer': fps * 2,
                'match_thresh': 0.8,
                'mot20': False,
            }),
            frame_rate=fps
        )
        # per-track bookkeeping, all keyed by the track id converted to a string
        timeout_count_dict = {} # frames left before the state of a no-longer-reported track is dropped
        is_counted_dict = {} # True once the track has been counted as crossing the line
        trail_dict = {} # list of box centers recorded for the track, oldest first
        timeout_count_initial = fps # countdown start value: number of frames a track may be missing

        # progress indicator sized by the frame count reported by the stream
        progress = mytools.Progress(int(stream.get(cv2.CAP_PROP_FRAME_COUNT)))
        for batch_result in model.predict_batch(mytools.video_source(stream, report_error=False)):
            # object detection
            results = batch_result.results
            # one row per detection: the bbox values followed by the score
            # (assumes every 'bbox' has 4 values, as the default does)
            bboxes = np.zeros((len(results), 5))
            image = batch_result.image

            # byte track
            for index, result in enumerate(results):
                # detections without 'bbox' / 'score' fall back to an all-zero box / zero score
                bbox = np.array(result.get('bbox', [0, 0, 0, 0]))
                score = result.get('score', 0)
                bbox_and_score = np.append(bbox, score)
                bboxes[index] = bbox_and_score

            # NOTE(review): the two (1, 1) arguments are presumably the image info / image size that
            # ByteTrack uses to rescale boxes, so equal values leave them unscaled -- confirm in mot.byte_tracker
            online_targets = tracker.update(bboxes, (1, 1), (1, 1))
            # ids of the tracks reported for this frame
            online_target_set = set([])

            # tracking start or continue
            for target in online_targets:
                tid = str(target.track_id)
                online_target_set.add(str(tid))

                box = tuple(map(int, target.tlbr)) # x1 y1 x2 y2
                # first two components of the xyah form of the box, used as its center point
                center = tuple(map(int, target.tlwh_to_xyah(target.tlwh)[:2]))
                # create the per-track state the first time a track id is seen
                if trail_dict.get(tid, None) is None:
                    trail_dict[tid] = []
                if is_counted_dict.get(tid, None) is None:
                    is_counted_dict[tid] = False
                # crossing test: only for tracks not counted yet, and only once the trail
                # already holds more than one center
                if not is_counted_dict[tid] and len(trail_dict[tid]) > 1:
                    # the tested segment runs from the first recorded center to the current one
                    trail_start = trail_dict[tid][0]
                    trail_end = center
                    is_cross = intersect(line_start, line_end, trail_start, trail_end)
                    if is_cross:
                        # one crossing can increment one counter of each pair:
                        # left/right by x displacement, top/bottom by y displacement
                        if trail_start[0] > trail_end[0]:
                            left += 1
                        if trail_start[0] < trail_end[0]:
                            right += 1
                        if trail_start[1] < trail_end[1]:
                            top += 1
                        if trail_start[1] > trail_end[1]:
                            bottom += 1
                        # a track is counted at most once
                        is_counted_dict[tid] = True
                trail_dict[tid].append(center)
                # the track was seen in this frame: restart its countdown
                timeout_count_dict[tid] = timeout_count_initial
                # draw the trail, the track id at the (x1, y2) corner of the box, the box and a center marker
                if len(trail_dict[tid]) > 1:
                    cv2.polylines(image, [np.array(trail_dict[tid])], False, (255, 255, 0))
                mytools.Display.put_text(image, tid, (box[0], box[3]), (255,255,255), (0,0,0), cv2.FONT_HERSHEY_PLAIN)
                cv2.rectangle(image, box[0:2], box[2:4], color=(0, 255, 0), thickness=1)
                cv2.drawMarker(image, center, (255, 255, 0), markerType=cv2.MARKER_CROSS)
                

            # tracking terminate
            # tracks that have state but were not reported in this frame count down;
            # when the countdown reaches zero all state kept for the track is deleted
            for tid in set(timeout_count_dict.keys()) - online_target_set:
                timeout_count_dict[tid] -= 1
                if timeout_count_dict[tid] == 0:
                    del timeout_count_dict[tid], is_counted_dict[tid], trail_dict[tid]

            # draw the current counter values and the count line on the frame
            text = 'Top={} Bottom={} Left={} Right={}'.format(top, bottom, left, right)
            mytools.Display.put_text(image, text, (image_w // 3, 0), (255,255,255), (0,0,0), cv2.FONT_HERSHEY_PLAIN)
            cv2.line(image, line_start, line_end, (0, 255, 0))

            # save the annotated frame, show it, and advance the progress indicator
            writer.write(image)
            display.show(image)
            progress.step()

In [None]:
# display result: play the annotated video that the prediction loop wrote to ann_path
IPython.display.Video(filename=str(ann_path))

In [None]:
# display original video (the unmodified input file) for comparison with the annotated one
IPython.display.Video(filename=str(orig_path))

In [None]:
# NOTE(review): leftover scratch cell -- it only echoes the frame width set by the prediction loop
# and can be removed from the finished notebook
image_w

In [None]:
# NOTE(review): leftover scratch cell -- it only echoes an OpenCV font constant that the notebook
# does not use (the loop uses cv2.FONT_HERSHEY_PLAIN) and can be removed
cv2.FONT_HERSHEY_SIMPLEX