In [2]:
import os
import cv2
import torch
from ultralytics import YOLO
from collections import deque
from tqdm import tqdm
from scipy.optimize import linear_sum_assignment
import numpy as np

In [None]:
# Mount Google Drive at /content/drive (Colab-only) so that paths under
# /content/drive/MyDrive, such as INPUT_VIDEO, can be read.
from google.colab import drive
drive.mount('/content/drive')

# ***Sort***

In [4]:
class KalmanFilter:
    """Constant-velocity Kalman filter for a single 2-D point.

    The state is the column vector ``[x, y, dx, dy]`` (shape ``(4, 1)``);
    only the position ``[x, y]`` is observed.

    Args:
        dt: Time step between consecutive ``predict`` calls, used in the
            state-transition matrix.
        process_noise: Value placed on the diagonal of the process-noise
            covariance ``Q``.
        measurement_noise: Value placed on the diagonal of the
            measurement-noise covariance ``R``.
        initial_uncertainty: Value placed on the diagonal of the initial
            state covariance ``P``.

    The defaults reproduce the previously hard-coded configuration.
    """

    def __init__(self, dt=1.0, process_noise=0.01, measurement_noise=10.0,
                 initial_uncertainty=500.0):
        # State: [x, y, dx, dy]
        self.dt = dt
        # Explicit float dtype so an integer ``dt`` cannot produce an int matrix.
        self.A = np.array([[1, 0, dt, 0],
                           [0, 1, 0, dt],
                           [0, 0, 1, 0],
                           [0, 0, 0, 1]], dtype=float)
        self.H = np.eye(2, 4)  # we observe x, y only
        self.Q = np.eye(4) * process_noise
        self.R = np.eye(2) * measurement_noise
        self.P = np.eye(4) * initial_uncertainty
        self.x = np.zeros((4, 1))  # initial state

    def initiate(self, cx, cy):
        """Set the position part of the state to ``(cx, cy)``.

        The velocity components and the covariance are left untouched.
        """
        self.x[:2] = np.array([[cx], [cy]])

    def predict(self):
        """Advance state and covariance by one step.

        Returns:
            A 1-D array ``[x, y]`` holding the predicted position.
        """
        self.x = self.A @ self.x
        self.P = self.A @ self.P @ self.A.T + self.Q
        return self.x[:2].flatten()

    def update(self, z):
        """Correct the state with a position measurement.

        Args:
            z: Two values ``[x, y]`` (any array-like that reshapes to (2, 1)).
        """
        z = np.asarray(z, dtype=float).reshape(2, 1)
        S = self.H @ self.P @ self.H.T + self.R      # innovation covariance
        K = self.P @ self.H.T @ np.linalg.inv(S)     # Kalman gain
        y = z - self.H @ self.x                      # innovation (residual)
        self.x += K @ y
        self.P = (np.eye(4) - K @ self.H) @ self.P

#-------------------------------------------------------------------------------

class Track:
    """One tracked object: its last bounding box plus a Kalman filter on the box centre.

    Attributes:
        kf: KalmanFilter seeded with the centre of the initial box.
        bbox: most recent box assigned to this track, as ``[x1, y1, x2, y2]``.
        id: identifier supplied by the caller.
        time_since_update: number of ``predict`` calls since the last ``update``.
    """

    @staticmethod
    def _center(box):
        """Return the ``(cx, cy)`` midpoint of an ``[x1, y1, x2, y2]`` box."""
        return (box[0] + box[2]) / 2, (box[1] + box[3]) / 2

    def __init__(self, bbox, track_id):
        mid_x, mid_y = self._center(bbox)
        self.kf = KalmanFilter()
        self.kf.initiate(mid_x, mid_y)
        self.id = track_id
        self.bbox = bbox
        self.time_since_update = 0

    def predict(self):
        """Step the filter once, count one more frame without an update, and return the filter's prediction."""
        predicted = self.kf.predict()
        self.time_since_update += 1
        return predicted

    def update(self, bbox):
        """Feed the centre of ``bbox`` to the filter, store the box and reset the staleness counter."""
        mid_x, mid_y = self._center(bbox)
        self.kf.update([mid_x, mid_y])
        self.time_since_update = 0
        self.bbox = bbox

#-------------------------------------------------------------------------------

def iou(bb1, bb2):
    """Intersection-over-union of two ``[x1, y1, x2, y2]`` boxes.

    A small constant (1e-5) is added to the union so the division never
    hits zero; as a result identical boxes score marginally below 1.
    """
    # Width/height of the overlap rectangle, clamped at zero when the boxes are disjoint.
    overlap_w = max(0, min(bb1[2], bb2[2]) - max(bb1[0], bb2[0]))
    overlap_h = max(0, min(bb1[3], bb2[3]) - max(bb1[1], bb2[1]))
    intersection = overlap_w * overlap_h

    area_first = (bb1[2] - bb1[0]) * (bb1[3] - bb1[1])
    area_second = (bb2[2] - bb2[0]) * (bb2[3] - bb2[1])

    union = area_first + area_second - intersection + 1e-5
    return intersection / float(union)

#-------------------------------------------------------------------------------

class Sort:
    """Minimal IoU-based multi-object tracker.

    Each call to ``update`` assigns the given detections to existing tracks
    with the Hungarian algorithm on an IoU matrix, starts a new track for
    every unassigned detection, and drops tracks that have gone too long
    without a match.
    """

    def __init__(self, max_age=5, iou_threshold=0.3):
        """
        max_age: how many frames a track is kept alive without any association
        iou_threshold: minimum IoU required to associate a detection with a track
        """
        self.tracks = []
        self.track_id_count = 0
        self.max_age = max_age
        self.iou_threshold = iou_threshold

    def update(self, detections):
        """Process one frame of detections.

        Args:
            detections: sequence of ``[x1, y1, x2, y2]`` boxes.

        Returns:
            List of ``(track_id, bbox)`` for every track still alive after
            this frame, including tracks that were not matched this frame
            (those carry their last stored box).
        """
        # 1) Predict: advance every track's filter and staleness counter.
        # NOTE(review): the position returned by predict() is not used here;
        # association below compares detections with each track's last stored bbox.
        for trk in self.tracks:
            trk.predict()

        n_tracks = len(self.tracks)
        n_dets = len(detections)

        # 2) Association (Hungarian). With no tracks there is nothing to match.
        pairs = []
        if n_tracks:
            overlaps = np.array(
                [[iou(trk.bbox, det) for det in detections] for trk in self.tracks],
                dtype=float,
            ).reshape(n_tracks, n_dets)
            # Negate so that minimising cost maximises total IoU.
            trk_idx, det_idx = linear_sum_assignment(-overlaps)
            pairs = [
                (t, d)
                for t, d in zip(trk_idx, det_idx)
                if not (overlaps[t, d] < self.iou_threshold)
            ]

        # 3) Update matched tracks with their assigned detection.
        for t, d in pairs:
            self.tracks[t].update(detections[d])

        # 4) Start a new track for every detection left unassigned (ascending index order).
        assigned = {int(d) for _, d in pairs}
        for d in range(n_dets):
            if d in assigned:
                continue
            self.tracks.append(Track(detections[d], self.track_id_count))
            self.track_id_count += 1

        # 5) Drop tracks that have been unmatched for more than max_age frames.
        self.tracks = [trk for trk in self.tracks
                       if trk.time_since_update <= self.max_age]

        # Return list of (id, bbox)
        return [(trk.id, trk.bbox) for trk in self.tracks]

#-------------------------------------------------------------------------------


# ***Main***

In [None]:
# --- CONFIG ---
# Tunable settings read by main().
INPUT_VIDEO   = "/content/drive/MyDrive/LabIA/2.mp4"  # source video (under the mounted Drive path)
OUTPUT_VIDEO  = "output/tracked_2.mp4"  # annotated video written by main()
YOLO_WEIGHTS  = "yolov8x.pt"  # weights file handed to YOLO()
PLAYER_CONF   = 0.4  # minimum confidence to keep a class-0 (player) detection
BALL_CONF     = 0.1         # minimum confidence for class-32 (ball) detections; lowered
IMG_SIZE      = 1600  # imgsz passed to the YOLO inference call
MAX_AGE_BALL  = 10          # keep the ball track alive longer without matches
IOU_THRESH_B  = 0.2         # IoU threshold used when matching the ball
SMOOTH_WINDOW = 5           # how many frames to average when smoothing the ball box
# ----------------

def _draw_labeled_box(frame, bbox, label, color):
    """Draw ``bbox`` (x1, y1, x2, y2) on ``frame`` with ``label`` just above its top-left corner."""
    x1, y1, x2, y2 = map(int, bbox)
    cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
    cv2.putText(frame, label, (x1, y1 - 6),
                cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)


def main():
    """Detect, track and annotate players and the ball in INPUT_VIDEO.

    Every frame is run through YOLO, class-0 detections above PLAYER_CONF
    and class-32 detections above BALL_CONF are fed to two separate Sort
    trackers, the resulting boxes are drawn (ball boxes averaged over the
    last SMOOTH_WINDOW frames) and the frame is written to OUTPUT_VIDEO.

    Raises:
        RuntimeError: if the input video cannot be opened or the output
            writer cannot be created.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # 1) load YOLO
    model = YOLO(YOLO_WEIGHTS)
    model.model.to(device)

    # 2) init trackers
    player_tracker = Sort()
    ball_tracker   = Sort(max_age=MAX_AGE_BALL, iou_threshold=IOU_THRESH_B)

    # 3) frame reader / writer
    cap = cv2.VideoCapture(INPUT_VIDEO)
    out = None
    pbar = None
    try:
        # Fail loudly instead of silently producing an empty output file.
        if not cap.isOpened():
            raise RuntimeError(f"Cannot open input video: {INPUT_VIDEO}")

        w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            # Some containers report 0/NaN fps; fall back to a usable default.
            fps = 30.0
        total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        # Create whatever directory OUTPUT_VIDEO actually lives in.
        out_dir = os.path.dirname(OUTPUT_VIDEO)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        out = cv2.VideoWriter(OUTPUT_VIDEO,
                              cv2.VideoWriter_fourcc(*"mp4v"),
                              fps, (w, h))
        if not out.isOpened():
            raise RuntimeError(f"Cannot open output video for writing: {OUTPUT_VIDEO}")

        # Frame count may be unknown (<= 0); tqdm then runs without a total.
        pbar = tqdm(total=total if total > 0 else None,
                    desc="Processing", unit="frame", dynamic_ncols=True)

        # 4) per-track history used for smoothing
        ball_history = {}  # tid -> deque([bbox,...])

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # 5) YOLO inference
            results = model(frame, device=device, imgsz=IMG_SIZE)[0]
            dets = results.boxes

            # 6) split detections by class id (0 -> players, 32 -> ball)
            player_dets, ball_dets = [], []
            for box in dets:
                cls  = int(box.cls[0])
                conf = float(box.conf[0])
                x1, y1, x2, y2 = box.xyxy[0].tolist()
                if cls == 0 and conf >= PLAYER_CONF:
                    player_dets.append([x1, y1, x2, y2])
                elif cls == 32 and conf >= BALL_CONF:
                    ball_dets.append([x1, y1, x2, y2])

            # 7) update trackers
            tracks_p = player_tracker.update(player_dets)
            tracks_b = ball_tracker.update(ball_dets)

            # 8) draw players
            for tid, bbox in tracks_p:
                _draw_labeled_box(frame, bbox, f"P{tid}", (0, 255, 0))

            # 9) smoothing + draw balls
            for tid, bbox in tracks_b:
                if tid not in ball_history:
                    ball_history[tid] = deque(maxlen=SMOOTH_WINDOW)
                ball_history[tid].append(bbox)

                # mean of the stored boxes
                smoothed = np.array(ball_history[tid]).mean(axis=0)
                _draw_labeled_box(frame, smoothed, f"B{tid}", (255, 200, 100))

            # Forget history of tracks the tracker has dropped, so the dict
            # does not grow for the whole length of the video.
            live_ids = {tid for tid, _ in tracks_b}
            for stale_id in [tid for tid in ball_history if tid not in live_ids]:
                del ball_history[stale_id]

            # 10) write & progress
            out.write(frame)
            pbar.update()
    finally:
        # Release resources even if inference or drawing raised.
        if pbar is not None:
            pbar.close()
        cap.release()
        if out is not None:
            out.release()

    print(f"\nVideo saved to {OUTPUT_VIDEO}")

# Run the full pipeline when this cell/script is executed as the main module.
if __name__=="__main__":
    main()