In [14]:
%pip install deep-sort-realtime

Note: you may need to restart the kernel to use updated packages.


In [18]:
%matplotlib inline

In [1]:
import torch
import numpy as np
import cv2 as cv
import matplotlib.pyplot as plt
from PIL import Image
import os
import uuid
from colors import sky_blue, blue
from VideoExporter import VideoExporter
from HSVClassifier import HSVClassifier
from video import display_video
from Detector import PlayerDetector
from tracker import DeepSortTracker

In [2]:
long_video_path = "./input/long-input.mp4"
short_video_path = "./input/short-input.mp4"
output_video_path = "./output.mp4"
output_folder = "./output/"
teams = [
    {"name": "chelsea", "colors": blue, "bgr_color": (255, 0, 0)},
    {"name": "mancity", "colors": sky_blue, "bgr_color": (250, 206, 135)},
]
model_path = "/Users/haithemsaida/Projects/Perso/yolov5/"

In [3]:
# ball_model = torch.hub.load(model_path, "custom", path="./Ball.pt", source="local")

In [4]:
model = torch.hub.load(
    "/Users/haithemsaida/Projects/Perso/yolov5", "yolov5n", source="local",
)

YOLOv5 🚀 v7.0-284-g95ebf68f Python-3.11.7 torch-2.2.1 CPU

Fusing layers... 
YOLOv5n summary: 213 layers, 1867405 parameters, 0 gradients, 4.5 GFLOPs
Adding AutoShape... 


# Detecting players using YOLOv5 :

In [5]:
def save_image(arr):
    if arr["label"].split(" ")[0] == "person":
        if not os.path.exists(f"./images"):
            os.makedirs(f"./images")
        im = Image.fromarray(arr["im"])
        im.save(f"./images/{uuid.uuid4()}.png")

# Ball detection using YOLOv5 and a custom-built dataset :

In [6]:
def draw(frame,box,team,color):
    xB = int(box[2])
    xA = int(box[0])
    yB = int(box[3])
    yA = int(box[1])
    cv.rectangle(frame, (xA, yA), (xB, yB), color, 3)
    font_scale = 0.7
    thickness = 2
    text_width, text_height = cv.getTextSize(
        team, cv.FONT_HERSHEY_SIMPLEX, font_scale, thickness
    )[0]
    text_x = (xB + xA) // 2 - text_width // 2
    text_y = (
                yA - 8
            )  

    cv.putText(
        frame,
        team,
        (text_x, text_y),
        cv.FONT_HERSHEY_SIMPLEX,
        font_scale,
        color,
        thickness,
    )

In [7]:
tracker = DeepSortTracker()
track_history = {}

In [11]:
cap = cv.VideoCapture(short_video_path)

frame_width = int(cap.get(cv.CAP_PROP_FRAME_WIDTH))
frame_height = int(cap.get(cv.CAP_PROP_FRAME_HEIGHT))
fps = int(cap.get(cv.CAP_PROP_FPS))
fourcc = cv.VideoWriter_fourcc(*"MP4V")
out = cv.VideoWriter(output_video_path, fourcc, fps, (frame_width, frame_height))

# A way to limit FPS
FPS = -1
skipped = 0
frames_counter = 0
hsv_classifier = HSVClassifier(teams)
player_detector = PlayerDetector(model, teams)
while cap.isOpened():
    ret, frame = cap.read()
    if not ret:
        break
    if skipped < FPS:
        skipped += 1
    else:
        frame_copy = frame.copy()
        frames_counter += 1
        skipped = 0
        players = player_detector.inference(frame_copy)
        predictions = []
        for player in players:
            box=player['box']
            conf=player['conf']
            label=player['label']
            left = int(box[0])
            top = int(box[1])
            width = int(box[2] - box[0])
            height = int(box[3] - box[1])
            bbox = [left, top, width, height]
            predictions.append((bbox, conf, label))
        tracks_current = tracker.object_tracker.update_tracks(
            predictions, frame=frame_copy
        )
        tracker.display_track(track_history, tracks_current, frame_copy)
        # for player in players:
        #     box, current_player_team, color = player.values()
        #     draw(frame_copy, box, current_player_team, color)
        out.write(frame_copy)
cap.release()
out.release()

OpenCV: FFMPEG: tag 0x5634504d/'MP4V' is not supported with codec id 12 and format 'mp4 / MP4 (MPEG-4 Part 14)'
OpenCV: FFMPEG: fallback to use tag 0x7634706d/'mp4v'


In [None]:
def calculate_centroid(tl_x, tl_y, w, h):
    mid_x = int(tl_x + w / 2)
    mid_y = int(tl_y + h / 2)
    return mid_x, mid_y


def convert_output(outputs: torch.Tensor):
    # Output of format []
    return


def convert_history_to_dict(track_history):
    history_dict = {}
    for frame_content in track_history:
        obj_ids, tlwhs, _ = frame_content
        for obj_id, tlwh in zip(obj_ids, tlwhs):
            tl_x, tl_y, w, h = tlwh
            mid_x, mid_y = calculate_centroid(tl_x, tl_y, w, h)

            if obj_id not in history_dict.keys():
                history_dict[obj_id] = [[mid_x, mid_y]]
            else:
                history_dict[obj_id].append([mid_x, mid_y])

    return history_dict


def plot_tracking(image, track_history):
    obj_ids, tlwhs, class_ids = track_history[-1]
    history_dict = convert_history_to_dict(track_history)

    im = np.ascontiguousarray(np.copy(image))
    im_h, im_w = im.shape[:2]

    top_view = np.zeros([im_w, im_w, 3], dtype=np.uint8) + 255

    num_detections = len(tlwhs)
    label_count = {class_name: 0 for class_name in CLASSES}
    for label_idx in class_ids:
        label_count[ID2CLASSES[label_idx]] += 1

    for i, tlwh in enumerate(tlwhs):
        x1, y1, w, h = tlwh
        intbox = tuple(map(int, (x1, y1, x1 + w, y1 + h)))
        obj_id = int(obj_ids[i])
        class_id = class_ids[i]
        id_text = "{}".format(int(obj_id))
        color = COLORS[class_id]
        cv.rectangle(
            im, intbox[0:2], intbox[2:4], color=color, thickness=line_thickness
        )
        cv.putText(
            im,
            id_text,
            (intbox[0], intbox[1]),
            cv.FONT_HERSHEY_PLAIN,
            text_scale,
            color,
            thickness=text_thickness,
        )
        cv.putText(
            im,
            ID2CLASSES[class_id],
            (intbox[0], intbox[3] + 20),
            cv.FONT_HERSHEY_PLAIN,
            text_scale,
            color,
            thickness=text_thickness,
        )

        for idx in range(len(history_dict[obj_id]) - 1):
            prev_point, next_point = (
                history_dict[obj_id][idx],
                history_dict[obj_id][idx + 1],
            )
            cv.line(im, prev_point, next_point, color, 2)

    return im