In [None]:
try:
    import norfair
except:
    !pip install -q norfair
    import norfair

import tqdm
import gc
import numpy as np
import torch, torchvision
import matplotlib.pyplot as plt
import cv2
from PIL import Image
from scipy.optimize import linear_sum_assignment

from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2, FasterRCNN_ResNet50_FPN_V2_Weights
from torchvision.utils import draw_bounding_boxes
from torchvision.transforms.functional import to_pil_image


# Select the compute device once: CUDA when a GPU is visible, CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

In [None]:
# Detector: torchvision's pretrained Faster R-CNN (ResNet-50 FPN v2).
# box_score_thresh=0.9 makes the model's own post-processing drop any box
# scoring below 0.9, so downstream code only sees confident detections.
weights = FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT
model = fasterrcnn_resnet50_fpn_v2(weights=weights, box_score_thresh=0.9).to(device)
model.eval()  # eval mode: the detection model returns predictions instead of training losses

# Input transform bundled with these weights; applied to every frame before inference.
preprocess = weights.transforms()

In [None]:
# norfair tracker that associates detections to tracks by bounding-box overlap.
# NOTE: the tracking cell mutates this object; re-run this cell before
# re-running that one, or track ids continue from the previous run.
tracker = norfair.Tracker(
    initialization_delay=5,       # hits a new track needs before update() reports it
    distance_threshold=0.5,       # matches with IoU distance above this are rejected
    distance_function="iou",      # norfair's IoU distance (1 - IoU) between boxes
)

In [None]:
# Detect and track vehicles in every frame of the input clip and write an
# annotated copy (green boxes + track ids) to OUTPUT_PATH.
# Source clip: https://www.pexels.com/video/cars-on-highway-854671/
# NOTE: uses `model`, `preprocess`, `device` and the stateful `tracker` from
# earlier cells; re-run the tracker cell first so ids restart from scratch.
INPUT_PATH = "854671-hd_1280_720_50fps.mp4"
OUTPUT_PATH = "prediction.mp4"

# COCO category ids kept for tracking: car (3), motorcycle (4), bus (6), truck (8),
# as listed in weights.meta["categories"].
VEHICLE_LABELS = [3, 4, 6, 8]
BOX_COLOR = (0, 255, 0)  # green; frames are drawn in RGB before conversion back to BGR


def _detect_vehicles(rgb_frame):
    """Run the detector on one RGB frame.

    Returns ``(raw, bboxes, labels, scores)``: ``raw`` is the model's output
    dict with every tensor converted to numpy; the other three are the subset
    whose label is in VEHICLE_LABELS, with boxes rounded to int32 pixels.
    """
    batch = [preprocess(Image.fromarray(rgb_frame)).to(device)]
    with torch.no_grad():
        raw = {key: value.cpu().numpy() for key, value in model(batch)[0].items()}
    keep = np.isin(raw["labels"], VEHICLE_LABELS)
    bboxes = np.round(raw["boxes"][keep]).astype(np.int32)
    return raw, bboxes, raw["labels"][keep], raw["scores"][keep]


def _match_track_ids(detections, tracked_objects):
    """Return one display id per detection, or -1 when no reported track owns it.

    norfair stores the Detection that last updated a track in
    ``TrackedObject.last_detection``. A track matched in this frame therefore
    points at one of ``detections`` by identity. Tracks that were not matched
    this frame point at an older Detection and are skipped, so they can never
    lend their id to an unrelated box.
    """
    index_of = {id(det): i for i, det in enumerate(detections)}
    ids = [-1] * len(detections)
    for obj in tracked_objects:
        i = index_of.get(id(obj.last_detection))
        if i is not None:
            ids[i] = obj.id - 1  # norfair numbers tracks from 1; show them 0-based
    return ids


def _draw_tracks(rgb_frame, bboxes, ids):
    """Return a copy of ``rgb_frame`` with each box and its track id drawn."""
    vis = rgb_frame.copy()
    # tolist() yields plain Python ints, which OpenCV's point arguments accept.
    for (x1, y1, x2, y2), track_id in zip(bboxes.tolist(), ids):
        cv2.rectangle(vis, (x1, y1), (x2, y2), color=BOX_COLOR, thickness=2)
        cv2.putText(vis, str(track_id), (x1, y2),
                    fontFace=cv2.FONT_HERSHEY_PLAIN, fontScale=2,
                    color=BOX_COLOR, thickness=2)
    return vis


vid_cap = cv2.VideoCapture(INPUT_PATH)
if not vid_cap.isOpened():
    raise FileNotFoundError(f"Could not open input video {INPUT_PATH!r}")

FPS = vid_cap.get(cv2.CAP_PROP_FPS)
WIDTH = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
HEIGHT = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
# VideoCapture.get returns a float; the progress bar wants an integer total.
FRAME_COUNT = int(vid_cap.get(cv2.CAP_PROP_FRAME_COUNT))

print("Video capture FPS: {}, frames: {}, frame dimensions: {}x{}.".format(
    FPS, FRAME_COUNT, WIDTH, HEIGHT))

desample_rate = 2  # NOTE(review): currently unused; every frame is processed.

# Lower-case 'mp4v' is the MPEG-4 tag FFmpeg accepts for .mp4 containers;
# upper-case 'MP4V' typically only works after an FFmpeg "fallback" warning.
vid_write = cv2.VideoWriter(OUTPUT_PATH,
                            cv2.VideoWriter_fourcc(*"mp4v"),
                            FPS,
                            (WIDTH, HEIGHT),
                            True)
if not vid_write.isOpened():
    vid_cap.release()
    raise RuntimeError(f"Could not open video writer for {OUTPUT_PATH!r}")

n = 0        # frames processed
preds = []   # raw per-frame model outputs (numpy), kept for later inspection
try:
    with tqdm.tqdm(total=FRAME_COUNT if FRAME_COUNT > 0 else None) as pbar:
        while True:
            ok, bgr_frame = vid_cap.read()
            if not ok:
                break

            rgb_frame = cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGB)
            raw, bboxes, labels, scores = _detect_vehicles(rgb_frame)
            preds.append(raw)

            # Each box becomes two points (top-left, bottom-right); the score is
            # repeated so there is one score per point.
            detections = [
                norfair.Detection(points=box.reshape(2, 2), scores=np.array([s, s]), label=l)
                for box, l, s in zip(bboxes, labels, scores)
            ]
            tracked_objects = tracker.update(detections)
            ids = _match_track_ids(detections, tracked_objects)

            vis_frame = _draw_tracks(rgb_frame, bboxes, ids)
            vid_write.write(cv2.cvtColor(vis_frame, cv2.COLOR_RGB2BGR))
            n += 1
            pbar.update(1)
finally:
    # Release even if inference fails mid-video so the output file is closed.
    vid_cap.release()
    vid_write.release()

In [None]:
# Optional: convert the annotated video written above (prediction.mp4) into a GIF.
# !ffmpeg -i prediction.mp4 \
#    -vf "fps=10,scale=800:-1:flags=lanczos,split[s0][s1];[s0]palettegen=max_colors=32[p];[s1][p]paletteuse=dither=bayer" \
#    -loop 0 output.gif