"""
END-TO-END: Build dataset from a labeled video, train YOLO (Ultralytics), then compute IoU per class vs Ground Truth.

You have:
- football.mp4
- labels/ frame_000000.txt ... frame_001529.txt   (YOLO normalized GT)
Classes (GT):
0 Team A
1 Team B
2 GoalKeeper A
3 GoalKeeper B
4 Ball
5 Referees

What this script does:
1) Extract frames from the video into frames_all/
2) Split into train/val (by frames) and copy GT labels accordingly
3) Create data.yaml for Ultralytics
4) Train YOLO using a chosen base model (e.g., yolov8n.pt)
5) Run inference on VAL and compute IoU PER CLASS

IMPORTANT EVALUATION RULE:
- We compute IoU ONLY for GT-labeled objects.
- Any prediction with NO GT match is IGNORED (does not affect IoU).
- Matching is per-frame + per-class, greedy: each GT picks the best unused prediction of same class.

Install:
pip install ultralytics opencv-python pyyaml numpy

Run:
python train_and_iou.py

"""


In [1]:
import os
import random
import shutil
from pathlib import Path
import yaml
import cv2
import numpy as np
from ultralytics import YOLO

In [None]:

# Input video that the frames are extracted from.
VIDEO_PATH = "football.mp4"


# Directory holding the ground-truth label files (one .txt per frame).
GT_LABELS_DIR = "labels" 


# Label file naming scheme: <LABEL_PREFIX><zero-padded frame id><LABEL_EXT>,
# e.g. frame_000000.txt. Extracted images reuse the same prefix and padding.
LABEL_PREFIX = "frame_"
LABEL_DIGITS = 6
LABEL_EXT = ".txt"
# Offset added to the 0-based extraction index to get the id used in file names.
FRAME_INDEX_START = 0  


# Root of the generated dataset; build_dataset() deletes and recreates it.
DATASET_ROOT = "dataset_soccer"

# Train/Val split
# Fraction of frames assigned to the validation split, and the shuffle seed.
VAL_RATIO = 0.2
RANDOM_SEED = 42

In [None]:

# --- Training settings (forwarded to Ultralytics ``model.train``) ---
BASE_MODEL = "yolov8n.pt"    # checkpoint loaded by YOLO() before training
IMGSZ = 640                  # passed as imgsz=
EPOCHS = 10                  # passed as epochs=
BATCH = 2                    # passed as batch=
DEVICE = "cpu"               # passed as device=
PROJECT_DIR = "runs_soccer"  # passed as project=


# --- Evaluation settings ---
# Confidence threshold passed to model.predict() during evaluation.
CONF_THRES = 0.25
# A GT/prediction match is only reported when its IoU is >= this value.
IOU_REPORT_THRES = 0.0
# When True, an overlay video with GT and predicted boxes is written.
SAVE_VIS_VIDEO = True
VIS_OUT_PATH = "val_iou_overlay.mp4"

# Class names
NAMES = {
    0: "TeamA",
    1: "TeamB",
    2: "GK_A",
    3: "GK_B",
    4: "Ball",
    5: "Referee",
}
# Derived from NAMES so the class count cannot drift out of sync with the
# name table (NAMES keys are expected to be 0..NUM_CLASSES-1).
NUM_CLASSES = len(NAMES)


In [None]:

def safe_rmtree(path: Path):
    """Recursively delete ``path`` if it exists; do nothing otherwise."""
    if not path.exists():
        return
    shutil.rmtree(path)


def label_path_for_frame(frame_idx: int) -> Path:
    """Return the ground-truth label path for the 0-based ``frame_idx``.

    The file name is ``LABEL_PREFIX`` + the frame id (``frame_idx`` shifted by
    ``FRAME_INDEX_START``, zero-padded to ``LABEL_DIGITS``) + ``LABEL_EXT``,
    located inside ``GT_LABELS_DIR``.
    """
    frame_id = frame_idx + FRAME_INDEX_START
    padded_id = format(frame_id, f"0{LABEL_DIGITS}d")
    return Path(GT_LABELS_DIR, f"{LABEL_PREFIX}{padded_id}{LABEL_EXT}")


def write_yaml(path: Path, data: dict):
    """Write ``data`` to ``path`` as UTF-8 YAML, preserving key order."""
    with open(path, "w", encoding="utf-8") as handle:
        # sort_keys=False keeps the keys in the order the caller supplied them.
        handle.write(yaml.safe_dump(data, sort_keys=False))


def yolo_norm_to_xyxy(xc, yc, w, h, img_w, img_h):
    """Convert a normalized centre/size box to pixel corner coordinates.

    ``xc``/``yc`` are the box centre and ``w``/``h`` its size, all as fractions
    of the image; ``img_w``/``img_h`` scale them to pixels.

    Returns a float32 array ``[x1, y1, x2, y2]``.
    """
    half_w = w / 2.0
    half_h = h / 2.0
    corners = [
        (xc - half_w) * img_w,
        (yc - half_h) * img_h,
        (xc + half_w) * img_w,
        (yc + half_h) * img_h,
    ]
    return np.array(corners, dtype=np.float32)


In [5]:
def read_gt_labels(label_file: Path, img_w: int, img_h: int):
    """Parse a YOLO label file into pixel-space ground-truth boxes.

    Each usable line is ``cls xc yc w h`` (normalized); extra fields are
    ignored, and blank lines or lines with fewer than five fields are skipped.
    A missing file yields an empty list.

    Returns a list of dicts: {"cls": int, "box": np.array([x1,y1,x2,y2])}
    """
    if not label_file.exists():
        return []

    boxes = []
    with open(label_file, "r", encoding="utf-8") as handle:
        for raw in handle:
            # split() with no argument also discards surrounding whitespace,
            # so an empty or whitespace-only line produces no fields.
            fields = raw.split()
            if len(fields) < 5:
                continue
            # The class id may be written as "3" or "3.0".
            class_id = int(float(fields[0]))
            cx, cy, bw, bh = (float(v) for v in fields[1:5])
            boxes.append(
                {"cls": class_id, "box": yolo_norm_to_xyxy(cx, cy, bw, bh, img_w, img_h)}
            )
    return boxes


def iou_xyxy(a, b) -> float:
    """Return the intersection-over-union of two ``[x1, y1, x2, y2]`` boxes.

    Negative extents are clamped to zero, and a small epsilon is added to the
    union so two empty boxes give 0.0 instead of dividing by zero.
    """

    def span(lo, hi):
        # Length of the interval [lo, hi], or 0 when it is inverted.
        return max(0.0, hi - lo)

    overlap = span(max(a[0], b[0]), min(a[2], b[2])) * span(max(a[1], b[1]), min(a[3], b[3]))
    area_a = span(a[0], a[2]) * span(a[1], a[3])
    area_b = span(b[0], b[2]) * span(b[1], b[3])

    return float(overlap / (area_a + area_b - overlap + 1e-6))


In [None]:
def greedy_match_iou(gt_list, pred_list):
    """Greedily pair each ground-truth box with a same-class prediction.

    GT boxes are processed in list order. Each one takes the not-yet-used
    prediction of its own class with the highest IoU (the earliest prediction
    wins ties). Predictions that no GT picks are never looked at, so they do
    not influence the result. A chosen prediction is marked as used even when
    its IoU is below ``IOU_REPORT_THRES``; in that case the pair is simply not
    reported.

    Returns list of (cls, iou)
    """
    # Bucket predictions by class, remembering each one's index in pred_list.
    buckets = {}
    for idx, pred in enumerate(pred_list):
        if pred["cls"] not in buckets:
            buckets[pred["cls"]] = []
        buckets[pred["cls"]].append((idx, pred))

    taken = set()
    results = []
    for gt in gt_list:
        cls_id = gt["cls"]
        winner = None
        winner_iou = -1.0  # below any real IoU, so the first candidate always wins

        for idx, pred in buckets.get(cls_id, []):
            if idx in taken:
                continue
            overlap = iou_xyxy(gt["box"], pred["box"])
            if overlap > winner_iou:
                winner, winner_iou = idx, overlap

        if winner is None:
            # No free prediction of this class: the GT contributes nothing.
            continue

        taken.add(winner)
        if winner_iou >= IOU_REPORT_THRES:
            results.append((cls_id, winner_iou))

    return results


def draw_boxes(frame, boxes, color, prefix=""):
    """Draw labelled rectangles for ``boxes`` onto ``frame`` in place.

    Each entry of ``boxes`` is a dict with "box" (array ``[x1, y1, x2, y2]``)
    and "cls"; an optional "conf" is appended to the label with two decimals.
    The label is ``prefix`` followed by the class name from ``NAMES`` (or the
    raw class id when it has no name).
    """
    for item in boxes:
        left, top, right, bottom = item["box"].astype(int)
        class_id = item["cls"]
        class_name = NAMES.get(class_id, str(class_id))
        score = f" {item['conf']:.2f}" if "conf" in item else ""
        caption = f"{prefix}{class_name}{score}"

        cv2.rectangle(frame, (left, top), (right, bottom), color, 2)
        # Place the caption just above the box, clamped so it stays in the image.
        caption_origin = (left, max(0, top - 5))
        cv2.putText(frame, caption, caption_origin,
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2, cv2.LINE_AA)



# Step 1: Extract frames
def extract_frames(video_path: str, out_dir: Path):
    """Decode every frame of ``video_path`` and save it as a JPEG in ``out_dir``.

    Frames are named ``<LABEL_PREFIX><zero-padded id>.jpg`` with the same
    prefix, padding and ``FRAME_INDEX_START`` offset as the label files, so
    image and label stems line up.

    Args:
        video_path: Path of the video to read.
        out_dir: Directory that receives the images (created if missing).

    Returns:
        ``(num_written, (width, height, fps, reported_frames))``.
        ``num_written`` counts the frames actually decoded, which may differ
        from the frame count the container reports. ``fps`` falls back to
        25.0 when the video reports 0.

    Raises:
        FileNotFoundError: The video cannot be opened.
        OSError: A frame image could not be written.
    """
    out_dir.mkdir(parents=True, exist_ok=True)
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise FileNotFoundError(f"Cannot open video: {video_path}")

        total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = cap.get(cv2.CAP_PROP_FPS) or 25.0

        idx = 0
        while True:
            ok, frame = cap.read()
            if not ok:
                break

            fname = f"{LABEL_PREFIX}{(idx + FRAME_INDEX_START):0{LABEL_DIGITS}d}.jpg"
            out_path = out_dir / fname
            # imwrite reports failure through its return value; fail loudly
            # here instead of leaving a gap that only surfaces later as a
            # missing file when the dataset is built.
            if not cv2.imwrite(str(out_path), frame):
                raise OSError(f"Failed to write frame image: {out_path}")
            idx += 1
    finally:
        # Always release the capture, including when an error interrupts the loop.
        cap.release()

    return idx, (w, h, fps, total)



In [None]:

def build_dataset(frames_dir: Path, dataset_root: Path, num_frames: int):
    """Build the Ultralytics dataset tree from extracted frames and GT labels.

    Creates:
    dataset_root/
      images/train, images/val
      labels/train, labels/val
      data.yaml

    Any existing ``dataset_root`` is deleted first. Frames are assigned to
    train/val by a seeded random shuffle of the frame indices. A frame with no
    GT label file gets an empty label file.

    NOTE(review): the split is random per frame, so neighbouring (near
    identical) video frames can end up on both sides of the split.

    Args:
        frames_dir: Directory containing the extracted ``.jpg`` frames.
        dataset_root: Output directory for the dataset.
        num_frames: Number of frames (0-based indices ``0..num_frames-1``).

    Returns:
        Path to the generated ``data.yaml``.
    """
    safe_rmtree(dataset_root)
    for kind in ("images", "labels"):
        for split in ("train", "val"):
            (dataset_root / kind / split).mkdir(parents=True, exist_ok=True)

    # A private Random instance yields the same shuffle as seeding the global
    # generator, without reseeding the process-wide RNG as a side effect.
    indices = list(range(num_frames))
    random.Random(RANDOM_SEED).shuffle(indices)
    val_count = int(round(num_frames * VAL_RATIO))
    val_set = set(indices[:val_count])

    for i in range(num_frames):
        stem = f"{LABEL_PREFIX}{(i + FRAME_INDEX_START):0{LABEL_DIGITS}d}"
        img_name = f"{stem}.jpg"
        split = "val" if i in val_set else "train"

        src_img = frames_dir / img_name
        src_lbl = label_path_for_frame(i)
        dst_img = dataset_root / "images" / split / img_name
        # Build the label name from the stem rather than with
        # str.replace(".jpg", ".txt"), which would also rewrite a ".jpg"
        # occurring elsewhere in the name.
        dst_lbl = dataset_root / "labels" / split / f"{stem}.txt"

        shutil.copy2(src_img, dst_img)

        if src_lbl.exists():
            shutil.copy2(src_lbl, dst_lbl)
        else:
            # No GT for this frame: write an empty label file.
            dst_lbl.write_text("", encoding="utf-8")

    data_yaml = {
        "path": str(dataset_root.resolve()),
        "train": "images/train",
        "val": "images/val",
        "nc": NUM_CLASSES,
        "names": [NAMES[i] for i in range(NUM_CLASSES)],
    }
    data_yaml_path = dataset_root / "data.yaml"
    write_yaml(data_yaml_path, data_yaml)

    return data_yaml_path



In [None]:

def train_yolo(data_yaml_path: Path):
    """Train ``BASE_MODEL`` on the dataset described by ``data_yaml_path``.

    Training settings come from the module-level constants. Returns the path
    of the ``best.pt`` checkpoint inside the run's save directory.

    Raises:
        FileNotFoundError: No ``best.pt`` exists anywhere under the save
            directory once training has finished.
    """
    train_kwargs = dict(
        data=str(data_yaml_path),
        imgsz=IMGSZ,
        epochs=EPOCHS,
        batch=BATCH,
        device=DEVICE,
        project=PROJECT_DIR,
        name="train",
        pretrained=True,
        verbose=True,
    )
    results = YOLO(BASE_MODEL).train(**train_kwargs)

    run_dir = Path(results.save_dir)
    expected = run_dir / "weights" / "best.pt"
    if expected.exists():
        return expected

    # Fall back to searching the whole run directory for the checkpoint.
    found = list(run_dir.glob("**/best.pt"))
    if not found:
        raise FileNotFoundError("Could not find best.pt after training.")
    return found[0]


In [None]:

def _tensor_to_numpy(values):
    """Return ``values`` as a NumPy array, calling ``.cpu()`` first when available."""
    return values.cpu().numpy() if hasattr(values, "cpu") else values.numpy()


def _collect_predictions(result):
    """Turn one prediction result into a list of {"cls", "box", "conf"} dicts.

    Detections whose class id is outside ``range(NUM_CLASSES)`` are dropped.
    """
    boxes = result.boxes
    if boxes is None:
        return []
    xyxy = _tensor_to_numpy(boxes.xyxy)
    cls = _tensor_to_numpy(boxes.cls).astype(int)
    conf = _tensor_to_numpy(boxes.conf)
    return [
        {"cls": int(c), "box": box.astype(np.float32), "conf": float(score)}
        for box, c, score in zip(xyxy, cls, conf)
        if 0 <= int(c) < NUM_CLASSES
    ]


def evaluate_iou_on_val(best_weights: Path, data_yaml_path: Path):
    """Run the trained model on the validation images and print per-class IoU.

    For every validation image the GT labels are read, predictions are made
    with ``CONF_THRES``, and GT boxes are matched to predictions with
    ``greedy_match_iou``. The mean IoU of the reported matches is printed per
    class and overall. When ``SAVE_VIS_VIDEO`` is set, an overlay video
    (GT green, predictions red) is written to ``VIS_OUT_PATH``.

    Args:
        best_weights: Path of the trained weights to load.
        data_yaml_path: Path of the dataset ``data.yaml``; its ``path`` and
            ``val`` entries locate the validation images, and labels are read
            from ``<path>/labels/val``.

    Raises:
        RuntimeError: No validation images exist, or the first one cannot be
            read (its size defines the overlay video frame size).
    """
    model = YOLO(str(best_weights))

    data = yaml.safe_load(Path(data_yaml_path).read_text(encoding="utf-8"))
    root = Path(data["path"])
    val_images_dir = root / data["val"]
    val_labels_dir = root / "labels" / "val"

    val_imgs = sorted(val_images_dir.glob("*.jpg"))
    if not val_imgs:
        raise RuntimeError("No validation images found. Check dataset build.")

    sample = cv2.imread(str(val_imgs[0]))
    if sample is None:
        # imread returns None instead of raising; without this guard the
        # failure would surface as an AttributeError on ``.shape``.
        raise RuntimeError(f"Cannot read validation image: {val_imgs[0]}")
    H, W = sample.shape[:2]
    fps = 25.0

    writer = None
    if SAVE_VIS_VIDEO:
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        writer = cv2.VideoWriter(VIS_OUT_PATH, fourcc, fps, (W, H))
        if not writer.isOpened():
            # Keep the evaluation running, but do not pretend a video exists.
            print(f"Warning: could not open {VIS_OUT_PATH} for writing; overlay video disabled.")
            writer.release()
            writer = None
    video_saved = writer is not None

    per_class_ious = {c: [] for c in range(NUM_CLASSES)}
    all_ious = []

    try:
        for img_path in val_imgs:
            frame = cv2.imread(str(img_path))
            if frame is None:
                continue

            # GT coordinates are normalized per image, so scale them with this
            # frame's own size rather than the first image's.
            frame_h, frame_w = frame.shape[:2]
            lbl_path = val_labels_dir / (img_path.stem + ".txt")
            gt = read_gt_labels(lbl_path, frame_w, frame_h)

            res = model.predict(frame, conf=CONF_THRES, verbose=False)
            preds = _collect_predictions(res[0]) if res else []

            matches = greedy_match_iou(gt, preds)
            for c, val in matches:
                per_class_ious[c].append(val)
                all_ious.append(val)

            if writer is not None:
                vis = frame.copy()
                draw_boxes(vis, gt, (0, 255, 0), prefix="GT: ")
                draw_boxes(vis, preds, (0, 0, 255), prefix="PR: ")
                cv2.putText(vis, f"{img_path.name} | GT:{len(gt)} PR:{len(preds)} M:{len(matches)}",
                            (10, 25), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2, cv2.LINE_AA)
                if (frame_w, frame_h) != (W, H):
                    # The writer was opened for (W, H); resize frames of any
                    # other size to match it.
                    vis = cv2.resize(vis, (W, H))
                writer.write(vis)
    finally:
        # Finalize the video file even if prediction fails part-way through.
        if writer is not None:
            writer.release()

    def mean_or_nan(arr):
        return float(np.mean(arr)) if len(arr) > 0 else float("nan")

    print("\n========== IoU Report (VAL, GT-only) ==========")
    print(f"Best weights: {best_weights}")
    print(f"Validation images: {len(val_imgs)}")
    print(f"Total matched GT boxes: {len(all_ious)}")
    print("------------------------------------------------")
    for c in range(NUM_CLASSES):
        miou = mean_or_nan(per_class_ious[c])
        n = len(per_class_ious[c])
        print(f"Class {c:1d} ({NAMES[c]}): matches={n:5d} | mean IoU={miou:.4f}")
    print("------------------------------------------------")
    print(f"Overall mean IoU: {mean_or_nan(all_ious):.4f}")
    if video_saved:
        print(f"Saved overlay video: {VIS_OUT_PATH} (GT green, PR red)")
    print("================================================\n")



In [None]:

def main():
    """Run the whole pipeline: extract frames, build the dataset, train, evaluate."""
    frames_dir = Path("frames_all")
    frames_dir.mkdir(parents=True, exist_ok=True)
    dataset_root = Path(DATASET_ROOT)

    print("[1/4] Extracting frames from video...")
    num_frames, (w, h, fps, total) = extract_frames(VIDEO_PATH, frames_dir)
    print(f"Extracted {num_frames} frames. Video meta: {w}x{h}, fps={fps:.2f}, reported_frames={total}")

    print("[2/4] Building train/val split and dataset structure...")
    data_yaml_path = build_dataset(frames_dir, dataset_root, num_frames)
    print(f"Created: {data_yaml_path}")

    print("[3/4] Training YOLO...")
    best_weights = train_yolo(data_yaml_path)
    print(f"Training done. Best weights: {best_weights}")

    print("[4/4] Evaluating IoU on validation set...")
    evaluate_iou_on_val(best_weights, data_yaml_path)



# Only run the pipeline when this file is executed directly.
if __name__ == "__main__":
    main()

[1/4] Extracting frames from video...
Extracted 1530 frames. Video meta: 1920x1080, fps=29.97, reported_frames=1565
[2/4] Building train/val split and dataset structure...
Created: dataset_soccer\data.yaml
[3/4] Training YOLO...
New https://pypi.org/project/ultralytics/8.4.6 available  Update with 'pip install -U ultralytics'
Ultralytics 8.3.202  Python-3.12.1 torch-2.8.0+cpu CPU (11th Gen Intel Core i7-1165G7 @ 2.80GHz)
[34m[1mengine\trainer: [0magnostic_nms=False, amp=True, augment=False, auto_augment=randaugment, batch=2, bgr=0.0, box=7.5, cache=False, cfg=None, classes=None, close_mosaic=10, cls=0.5, compile=False, conf=None, copy_paste=0.0, copy_paste_mode=flip, cos_lr=False, cutmix=0.0, data=dataset_soccer\data.yaml, degrees=0.0, deterministic=True, device=cpu, dfl=1.5, dnn=False, dropout=0.0, dynamic=False, embed=None, epochs=10, erasing=0.4, exist_ok=False, fliplr=0.5, flipud=0.0, format=torchscript, fraction=1.0, freeze=None, half=False, hsv_h=0.015, hsv_s=0.7, hsv_v=0.4, i