In [5]:

import os
import torch
import torchvision.transforms as T
import matplotlib.pyplot as plt
from collections import defaultdict
from PIL import Image
from tqdm import tqdm
from ultralytics import YOLO  # Ensure you have `pip install ultralytics`
import random


# --- CONFIG ---
# DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {DEVICE}") 
BASE_DATA_ROOT = "../soccernet_data/tracking"
GT_FILENAME = "gt.txt"
IMAGE_FOLDER = "img1"
IMAGE_EXTS = ['.jpg', '.png']
NUM_VISUALS = 10
SCORE_THRESH = 0.5
IOU_THRESH = 0.5
SAMPLE_PER_SEQ = 30

Using device: mps


In [6]:


# --- LOAD MODEL ---
model = YOLO("discovery-runs/detect/ft-50epoch/weights/best.pt").to(DEVICE)
# model.half().to(DEVICE)
# model.to(DEVICE)

transform = T.ToTensor()

def load_gt_boxes(gt_path):
    gt_dict = defaultdict(list)
    if not os.path.exists(gt_path):
        return gt_dict
    with open(gt_path, 'r') as f:
        for line in f:
            parts = line.strip().split(',')
            frame, _, x, y, w, h, cls, _, _ = map(int, parts[:9])
            gt_dict[frame].append(torch.tensor([x, y, x + w, y + h], device=DEVICE))
    return gt_dict

def compute_iou(box1, box2):
    if box1.size(0) == 0 or box2.size(0) == 0:
        return torch.zeros((box1.size(0), box2.size(0)), device=box1.device)
    area1 = (box1[:, 2] - box1[:, 0]) * (box1[:, 3] - box1[:, 1])
    area2 = (box2[:, 2] - box2[:, 0]) * (box2[:, 3] - box2[:, 1])
    lt = torch.max(box1[:, None, :2], box2[:, :2])
    rb = torch.min(box1[:, None, 2:], box2[:, 2:])
    wh = (rb - lt).clamp(min=0)
    inter = wh[:, :, 0] * wh[:, :, 1]
    union = area1[:, None] + area2 - inter
    iou = inter / union
    return iou

def plot_gt_and_detections(image_tensor, detections, gt_boxes):
    from torchvision.utils import draw_bounding_boxes
    all_boxes = []
    labels = []
    colors = []

    for box in detections:
        all_boxes.append(box)
        labels.append("pred")
        colors.append("red")

    for box in gt_boxes:
        all_boxes.append(box)
        labels.append("gt")
        colors.append("green")

    if not all_boxes:
        return T.ToPILImage()(image_tensor)
    
    boxes_tensor = torch.stack(all_boxes).cpu()

    
    x1 = torch.min(boxes_tensor[:, 0], boxes_tensor[:, 2])
    y1 = torch.min(boxes_tensor[:, 1], boxes_tensor[:, 3])
    x2 = torch.max(boxes_tensor[:, 0], boxes_tensor[:, 2])
    y2 = torch.max(boxes_tensor[:, 1], boxes_tensor[:, 3])
    boxes_tensor = torch.stack([x1, y1, x2, y2], dim=1).to(torch.int)
    img_uint8 = (image_tensor * 255).byte().cpu()
    drawn = draw_bounding_boxes(img_uint8, boxes_tensor, labels=labels, colors=colors, width=2)
    return T.ToPILImage()(drawn)

# --- EXECUTION ---
results = []
sample_frames = []
total_tp = total_fp = total_fn = 0

seq_dirs = []
for split in ["train", "test"]:
    split_dir = os.path.join(BASE_DATA_ROOT, split)
    if not os.path.exists(split_dir):
        continue
    for d in sorted(os.listdir(split_dir)):
        full_path = os.path.join(split_dir, d)
        if os.path.isdir(full_path):
            seq_dirs.append((split, d))

print("Using device:", DEVICE)
print("Processing sequences...")

for split, seq_id in tqdm(seq_dirs, desc="Sequences", dynamic_ncols=True):
    seq_path = os.path.join(BASE_DATA_ROOT, split, seq_id)
    img_dir = os.path.join(seq_path, IMAGE_FOLDER)
    gt_path = os.path.join(seq_path, "gt", GT_FILENAME)
    gt_dict = load_gt_boxes(gt_path)

    if not os.path.exists(img_dir):
        continue

    all_img_paths = sorted([
        os.path.join(img_dir, file)
        for file in os.listdir(img_dir)
        if any(file.lower().endswith(ext) for ext in IMAGE_EXTS)
    ])

    random.shuffle(all_img_paths)
    all_img_paths = all_img_paths[:SAMPLE_PER_SEQ]

    for path in all_img_paths:
        try:
            img = Image.open(path).convert("RGB")
        except:
            continue

        img_tensor = transform(img).unsqueeze(0)
        img_tensor = img_tensor.half()
        img_tensor = img_tensor.squeeze(0) 

        img_tensor = img_tensor.to(DEVICE)



        filename = os.path.basename(path)
        try:
            frame_id = int(filename.split('.')[0])
        except:
            continue

        gt_boxes = gt_dict.get(frame_id, [])

        with torch.no_grad():
            yolo_result = model(img)[0]  # Get first result
            preds = yolo_result.boxes.data.to(DEVICE) if yolo_result.boxes is not None else torch.empty((0, 6)).to(DEVICE)

        pred_boxes = preds[:, :4][preds[:, 4] > SCORE_THRESH] if len(preds) else torch.empty((0, 4), device=DEVICE)

        # Accuracy
        if gt_boxes:
            gt_tensor = torch.stack(gt_boxes).to(DEVICE)
            if len(pred_boxes) > 0:
                ious = compute_iou(pred_boxes, gt_tensor)
                max_ious = ious.max(dim=1)[0]
                acc = (max_ious > IOU_THRESH).float().mean().item()
            else:
                acc = 0.0
        else:
            acc = 1.0 if len(pred_boxes) == 0 else 0.0

        results.append(acc)

        # Precision/Recall
        matched_gt = set()
        tp = fp = 0
        if len(pred_boxes) > 0 and len(gt_boxes) > 0:
            ious = compute_iou(pred_boxes, gt_tensor)
            for i in range(len(pred_boxes)):
                max_iou, gt_idx = ious[i].max(0)
                if max_iou > IOU_THRESH and gt_idx.item() not in matched_gt:
                    tp += 1
                    matched_gt.add(gt_idx.item())
                else:
                    fp += 1
        else:
            tp = 0
            fp = len(pred_boxes)

        fn = len(gt_boxes) - len(matched_gt)
        total_tp += tp
        total_fp += fp
        total_fn += fn

        if len(sample_frames) < NUM_VISUALS and seq_id not in [s[0] for s in sample_frames]:
            img_vis = plot_gt_and_detections(img_tensor, pred_boxes, gt_boxes)
            sample_frames.append((seq_id, filename, img_vis))
            

# --- METRICS ---
precision = total_tp / (total_tp + total_fp + 1e-6)
recall = total_tp / (total_tp + total_fn + 1e-6)
avg_acc = sum(results) / len(results) if results else 0
print(f"\nAverage Detection Accuracy over {len(results)} frames: {avg_acc * 100:.2f}%")
print(f"Precision: {precision:.3f}, Recall: {recall:.3f}")

# --- SHOW EXAMPLES ---
for seq_id, filename, img in sample_frames:
    plt.imshow(img)
    plt.title(f"Sequence {seq_id}, Frame {filename}\nRed = Prediction, Green = Ground Truth")
    plt.axis("off")
    plt.show()


Using device: mps
Processing sequences...


Sequences:   0%|          | 0/106 [00:00<?, ?it/s]


0: 384x640 18 players, 1 referee, 48.9ms
Speed: 2.4ms preprocess, 48.9ms inference, 13.9ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 12 players, 1 referee, 15.2ms
Speed: 2.1ms preprocess, 15.2ms inference, 14.0ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 1 goalkeeper, 15 players, 2 referees, 17.9ms
Speed: 2.1ms preprocess, 17.9ms inference, 13.9ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 13 players, 1 referee, 15.8ms
Speed: 2.1ms preprocess, 15.8ms inference, 14.5ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 14 players, 2 referees, 16.2ms
Speed: 1.9ms preprocess, 16.2ms inference, 8.9ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 1 goalkeeper, 15 players, 2 referees, 15.8ms
Speed: 2.0ms preprocess, 15.8ms inference, 12.7ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 14 players, 15.4ms
Speed: 2.1ms preprocess, 15.4ms inference, 9.2ms postprocess per image at shape (1, 3, 384, 640)

0

Sequences:   1%|          | 1/106 [00:05<10:17,  5.88s/it]


0: 384x640 13 players, 30.6ms
Speed: 2.2ms preprocess, 30.6ms inference, 14.2ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 22 players, 2 referees, 15.3ms
Speed: 2.1ms preprocess, 15.3ms inference, 13.7ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 1 ball, 20 players, 2 referees, 17.2ms
Speed: 2.1ms preprocess, 17.2ms inference, 14.4ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 1 ball, 7 players, 16.9ms
Speed: 2.8ms preprocess, 16.9ms inference, 14.1ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 1 ball, 19 players, 1 referee, 16.8ms
Speed: 2.2ms preprocess, 16.8ms inference, 16.1ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 13 players, 1 referee, 58.3ms
Speed: 2.3ms preprocess, 58.3ms inference, 19.0ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 21 players, 3 referees, 15.6ms
Speed: 2.3ms preprocess, 15.6ms inference, 13.8ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 21 p

Sequences:   2%|▏         | 2/106 [00:11<09:56,  5.73s/it]


0: 384x640 16 players, 1 referee, 35.0ms
Speed: 2.1ms preprocess, 35.0ms inference, 23.2ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 15 players, 2 referees, 15.3ms
Speed: 2.3ms preprocess, 15.3ms inference, 13.3ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 19 players, 16.3ms
Speed: 2.5ms preprocess, 16.3ms inference, 18.3ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 17 players, 2 referees, 17.2ms
Speed: 2.0ms preprocess, 17.2ms inference, 14.8ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 14 players, 1 referee, 14.4ms
Speed: 2.0ms preprocess, 14.4ms inference, 12.5ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 1 ball, 16 players, 2 referees, 14.6ms
Speed: 2.6ms preprocess, 14.6ms inference, 13.7ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 23 players, 14.3ms
Speed: 2.0ms preprocess, 14.3ms inference, 7.9ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 19 players, 13.9ms
S

Sequences:   3%|▎         | 3/106 [00:17<09:55,  5.78s/it]


0: 384x640 17 players, 3 referees, 35.2ms
Speed: 2.2ms preprocess, 35.2ms inference, 9.7ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 16 players, 2 referees, 15.2ms
Speed: 2.5ms preprocess, 15.2ms inference, 14.9ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 17 players, 2 referees, 14.9ms
Speed: 2.0ms preprocess, 14.9ms inference, 13.4ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 24 players, 3 referees, 14.8ms
Speed: 2.2ms preprocess, 14.8ms inference, 13.0ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 17 players, 1 referee, 15.8ms
Speed: 2.1ms preprocess, 15.8ms inference, 13.2ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 19 players, 3 referees, 14.2ms
Speed: 2.2ms preprocess, 14.2ms inference, 13.2ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 1 ball, 26 players, 2 referees, 14.3ms
Speed: 2.0ms preprocess, 14.3ms inference, 13.0ms postprocess per image at shape (1, 3, 384, 640)

0: 384

Sequences:   4%|▍         | 4/106 [00:22<09:35,  5.64s/it]


0: 384x640 21 players, 2 referees, 35.4ms
Speed: 2.1ms preprocess, 35.4ms inference, 16.1ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 26 players, 3 referees, 17.4ms
Speed: 2.4ms preprocess, 17.4ms inference, 8.5ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 26 players, 2 referees, 18.7ms
Speed: 2.4ms preprocess, 18.7ms inference, 15.4ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 23 players, 2 referees, 16.6ms
Speed: 2.4ms preprocess, 16.6ms inference, 13.2ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 23 players, 2 referees, 16.4ms
Speed: 2.2ms preprocess, 16.4ms inference, 13.9ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 26 players, 2 referees, 14.3ms
Speed: 2.4ms preprocess, 14.3ms inference, 7.5ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 20 players, 1 referee, 16.5ms
Speed: 2.8ms preprocess, 16.5ms inference, 14.1ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 1 go

Sequences:   5%|▍         | 5/106 [00:29<10:05,  5.99s/it]


0: 384x640 8 players, 1 referee, 35.8ms
Speed: 2.3ms preprocess, 35.8ms inference, 67.1ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 22 players, 1 referee, 15.7ms
Speed: 2.2ms preprocess, 15.7ms inference, 9.6ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 13 players, 1 referee, 14.9ms
Speed: 1.9ms preprocess, 14.9ms inference, 14.2ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 22 players, 1 referee, 14.7ms
Speed: 2.2ms preprocess, 14.7ms inference, 15.1ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 13 players, 3 referees, 18.9ms
Speed: 2.6ms preprocess, 18.9ms inference, 16.6ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 20 players, 1 referee, 17.8ms
Speed: 2.8ms preprocess, 17.8ms inference, 14.6ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 1 goalkeeper, 23 players, 1 referee, 16.7ms
Speed: 2.1ms preprocess, 16.7ms inference, 14.7ms postprocess per image at shape (1, 3, 384, 640)

0: 384

Sequences:   5%|▍         | 5/106 [00:33<11:24,  6.77s/it]


KeyboardInterrupt: 