In [None]:
import time
import cv2
import json
import numpy as np
from tqdm import tqdm
import torch
from torchreid import models
from collections import OrderedDict
from torchvision.transforms import Normalize, ToTensor, Resize, Compose
from PIL import Image, ImageDraw, ImageFont

from tracker import Tracker

In [None]:
VIDEO = 'video0015_cut.mp4'
# JSON mapping str(frame index) -> {'boxes': [[xmin, ymin, xmax, ymax], ...], 'scores': [...]}
DETECTIONS = 'detections.json'
OUTPUT = 'output.avi'
# fall back to the CPU so the notebook still runs on machines without a GPU
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# make the video smaller by using:
# ffmpeg -i output.avi -s 960x540 -c:v libx264 -c:a copy output.mp4 -y

# Load model

In [None]:
CHECKPOINT = 'checkpoints/osnet_ain_x1_0_msmt17_256x128_amsgrad_ep50_lr0.0015_coslr_b64_fb10_softmax_labsmth_flip_jitter.pth'
DATAPARALLEL_PREFIX = 'module.'

# Load onto the CPU first so a checkpoint saved from a GPU also loads where that
# GPU is unavailable; the model is moved to DEVICE at the end of this cell.
checkpoint = torch.load(CHECKPOINT, map_location='cpu')
# some torchreid checkpoints wrap the weights as {'state_dict': ..., ...};
# a bare state dict is used as-is
checkpoint = checkpoint.get('state_dict', checkpoint)
# Weights saved from nn.DataParallel carry a 'module.' prefix on every key.
# Strip it only where it is actually present instead of blindly cutting 7 chars.
state_dict = OrderedDict(
    (k[len(DATAPARALLEL_PREFIX):] if k.startswith(DATAPARALLEL_PREFIX) else k, v)
    for k, v in checkpoint.items()
)

# num_classes must match the classifier size stored in the checkpoint,
# otherwise the (strict) load_state_dict below fails
model = models.osnet_ain.osnet_ain_x1_0(num_classes=4101, pretrained=False)
model.load_state_dict(state_dict)

# Turn the classifier into a feature extractor: replace fc[1], fc[2]
# (presumably BatchNorm1d / ReLU of the embedding head — verify in torchreid)
# and the identity classifier with no-ops, so the model outputs fc features.
model.fc[1] = torch.nn.Identity()
model.fc[2] = torch.nn.Identity()
model.classifier = torch.nn.Identity()
model = model.eval().to(DEVICE)

# Define preprocessing

In [None]:
# ImageNet normalisation statistics and the 256x128 (height x width) input size
# of the OSNet re-ID checkpoint (see its filename).
# Named REID_* so they cannot be confused with / clobbered by the video
# `width` / `height` that the run cell assigns.
REID_MEAN = [0.485, 0.456, 0.406]
REID_STD = [0.229, 0.224, 0.225]
REID_HEIGHT = 256
REID_WIDTH = 128

transform = Compose([
    Resize((REID_HEIGHT, REID_WIDTH)),  # PIL image -> fixed (height, width)
    ToTensor(),                         # HWC uint8 in [0, 255] -> CHW float in [0, 1]
    Normalize(mean=REID_MEAN, std=REID_STD),
])

# Define embedding

In [None]:
def get_descriptors(frame, boxes):
    """
    Compute one re-identification descriptor per box.

    Arguments:
        frame: a BGR uint8 image as returned by cv2, shape [h, w, 3].
        boxes: a sequence of (xmin, ymin, xmax, ymax) pixel coordinates.
    Returns:
        [] if there are no boxes, otherwise a numpy array with one row per
        box (in the order of `boxes`) holding the output of `model`.
    """
    if len(boxes) == 0:
        return []

    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    frame_h, frame_w = frame_rgb.shape[:2]

    crops = []
    for xmin, ymin, xmax, ymax in boxes:
        # Detections may be floats (invalid as slice indices) or slightly outside
        # the frame (negative indices would wrap around), so cast and clamp.
        xmin, xmax = max(int(xmin), 0), min(int(xmax), frame_w)
        ymin, ymax = max(int(ymin), 0), min(int(ymax), frame_h)
        crop = frame_rgb[ymin:ymax, xmin:xmax]
        crops.append(transform(Image.fromarray(crop)))

    # One batched forward pass instead of one per box. The model is put in eval
    # mode where it is loaded, so results should not depend on batching
    # (up to floating-point noise).
    batch = torch.stack(crops).to(DEVICE)
    with torch.no_grad():
        descriptors = model(batch)
    return descriptors.cpu().numpy()

# Box drawer

In [None]:
FONT_PATH = '/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf'
try:
    font = ImageFont.truetype(FONT_PATH, size=40)
except OSError:
    # FONT_PATH is distribution-specific; use PIL's built-in font instead of crashing
    font = ImageFont.load_default()

COLORS = {}  # track id -> (r, g, b) colour, assigned on first sight
TRACKS = {}  # track id -> list of (x, y) top-centre points drawn so far


def draw(image, boxes, ids):
    """
    Draw tracked boxes, their ids and their trajectories.

    Arguments:
        image: a numpy uint8 array with shape [h, w, 3], BGR order (as from cv2).
        boxes: a list of arrays with shape [4], (xmin, ymin, xmax, ymax).
        ids: a list of track ids, one per box.
    Returns:
        a numpy uint8 array with shape [h, w, 3], BGR order.
    Side effects:
        adds a random colour to COLORS for every new id and appends the
        top-centre point of each box to TRACKS[id].
    """
    pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
    drawer = ImageDraw.Draw(pil_image, 'RGBA')  # 'RGBA' so the trajectory alpha blends

    for box, i in zip(boxes, ids):
        if i not in COLORS:
            # convert to plain Python ints so PIL gets ordinary colour values
            COLORS[i] = tuple(int(c) for c in np.random.randint(0, 256, size=3))
        outline = COLORS[i]

        xmin, ymin, xmax, ymax = box
        drawer.rectangle([(xmin, ymin), (xmax, ymax)], outline=outline, width=10)
        drawer.text((xmin, ymin), f'{i}', font=font)

        # trajectory: top-centre of the box in every frame this id was drawn in
        points = TRACKS.setdefault(i, [])
        points.append(((xmin + xmax) // 2, ymin))
        for start, end in zip(points, points[1:]):
            drawer.line([start, end], fill=outline + (150,), width=10)

    return cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)

# Run

In [None]:
# Per-frame detections keyed by str(frame index); each entry holds
# parallel 'boxes' and 'scores' lists (read that way in the run cell).
with open(DETECTIONS, 'r') as f:
    detections = json.load(f)

In [None]:
SCORE_THRESHOLD = 0.9     # keep detections scoring above this
MAX_TOP_FRACTION = 0.6    # keep boxes whose top edge lies above this fraction of the frame height
SAMPLE_EVERY_SECONDS = 2  # process one frame per this many seconds of video
OUTPUT_FPS = 1            # output frame rate; with 2 s sampling the result plays 2x real time
MAX_FRAMES = 10000        # scan at most this many frames (stops earlier at the end of the video)
WARMUP_STEPS = 10         # first timings are dropped from the statistics

times = []  # seconds spent on descriptors + tracker update, per processed frame
tracker = Tracker(threshold=0.6, wait=4)

# draw() keeps per-id colours and trajectories in COLORS / TRACKS across calls.
# Clear them so re-running this cell does not extend trajectories from the
# previous run (ids from a fresh Tracker may repeat earlier ones).
COLORS.clear()
TRACKS.clear()

cap = cv2.VideoCapture(VIDEO)
out = None
try:
    if not cap.isOpened():
        raise IOError(f'cannot open video {VIDEO}')

    fps = int(cap.get(cv2.CAP_PROP_FPS))
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    num_frames = MAX_FRAMES
    step = SAMPLE_EVERY_SECONDS * fps
    if step <= 0:
        raise ValueError(f'unusable FPS ({fps}) reported for {VIDEO}')

    fourcc = cv2.VideoWriter_fourcc(*'MJPG')
    out = cv2.VideoWriter(OUTPUT, fourcc, OUTPUT_FPS, (width, height))

    for i in tqdm(range(num_frames)):
        # read() is grab() + retrieve(); retrieving only the sampled frames
        # skips converting frames that would be thrown away anyway
        if not cap.grab():
            break  # end of video: stop instead of passing a None frame on
        if i % step != 0:
            continue
        ok, frame = cap.retrieve()
        if not ok:
            break

        detection = detections[str(i)]
        boxes = np.array(detection['boxes'])
        scores = np.array(detection['scores'])
        boxes = boxes[scores > SCORE_THRESHOLD]
        if len(boxes) > 0:
            # column 1 is ymin of (xmin, ymin, xmax, ymax)
            boxes = boxes[(boxes[:, 1] / height) < MAX_TOP_FRACTION]

        start_time = time.perf_counter()
        descriptors = get_descriptors(frame, boxes)
        tracks = tracker.update(boxes, descriptors)
        # NOTE(review): t['u'] == 1 presumably marks tracks matched in this frame — confirm in Tracker
        shown = [t for t in tracks if t['u'] == 1]
        boxes = [t['x'] for t in shown]
        ids = [t['i'] for t in shown]
        times.append(time.perf_counter() - start_time)

        out.write(draw(frame, boxes, ids))
finally:
    # always release, so the output file is finalised even if the loop fails
    cap.release()
    if out is not None:
        out.release()

times = np.array(times[WARMUP_STEPS:])
if len(times) > 0:
    print(times.mean(), times.std())
else:
    print(f'fewer than {WARMUP_STEPS + 1} processed frames; no timing statistics')