In [None]:
# Enable the autoreload extension so edits to imported project modules are
# picked up without restarting the kernel.
%load_ext autoreload

%autoreload 2

from IPython.core.interactiveshell import InteractiveShell

# Echo the value of every expression statement in a cell, not only the last one.
# NOTE(review): this is presumably why plotting cells below assign call results
# to `_` -- otherwise each returned matplotlib object would be printed.
InteractiveShell.ast_node_interactivity = "all"

# Detection Threshold & Playing Field Filter Comparison

Compare the effect of different YOLO confidence thresholds **and** playing field filter configurations on player detection quality. The goal is to find the combination that keeps small distant players while rejecting crowd/UI false positives.

## Imports

In [None]:
import cv2
import matplotlib.pyplot as plt
import numpy as np
import pyrootutils
from matplotlib.patches import Rectangle

from football_tracking_demo.config import load_config
from football_tracking_demo.detector import PlayerDetector
from football_tracking_demo.filtering import filter_detections

## Parameters

In [None]:
# Resolve the project root by searching for pyproject.toml starting from the
# current directory; the flags below are passed through to pyrootutils as-is.
_SETUP_ROOT_OPTIONS = dict(
    search_from=".",
    indicator="pyproject.toml",
    project_root_env_var=True,
    dotenv=True,
    pythonpath=True,
    cwd=True,
)
root = pyrootutils.setup_root(**_SETUP_ROOT_OPTIONS)

# Input locations, expressed relative to the project root.
VIDEO_PATH = str(root.joinpath("data", "match.mp4"))
CONFIG_PATH = str(root.joinpath("config", "config.yaml"))

# YOLO confidence thresholds compared throughout the notebook.
THRESHOLDS = [0.05, 0.1, 0.20, 0.35, 0.50, 0.65]

# Frame indices pulled out of the video for visual inspection.
SAMPLE_FRAME_INDICES = [0, 1500, 3000, 4500, 6000]

# How many leading frames the count sweeps iterate over.
SWEEP_N_FRAMES = 200

## Load Config & Sample Frames

Load the project config and extract sample frames from the video.

In [None]:
config = load_config(CONFIG_PATH)

# Frames are read sequentially from the start of the video up to the largest
# requested index; only the requested ones are kept in memory.
frames = {}
max_idx = max(SAMPLE_FRAME_INDICES)
wanted_indices = set(SAMPLE_FRAME_INDICES)  # O(1) membership inside the read loop

cap = cv2.VideoCapture(VIDEO_PATH)
if not cap.isOpened():
    cap.release()
    # Fail here with a clear message instead of a bare KeyError in a later cell.
    raise FileNotFoundError(f"Could not open video: {VIDEO_PATH}")

try:
    for i in range(max_idx + 1):
        ret, frame = cap.read()
        if not ret:
            # Video ended (or a read failed) before the last requested frame.
            break
        if i in wanted_indices:
            frames[i] = frame
finally:
    # Release the capture handle even if reading raises.
    cap.release()

missing_indices = sorted(wanted_indices - frames.keys())
if missing_indices:
    # Later cells index `frames` by these values, so make the gap visible now.
    print(f"WARNING: could not read sample frames {missing_indices} from {VIDEO_PATH}")
print(f"Loaded {len(frames)} sample frames: {sorted(frames.keys())}")

## Build Detectors at Each Threshold

Create one `PlayerDetector` per threshold. All other settings (model, HUD mask, NMS) stay the same so only the confidence threshold varies.

In [None]:
# Everything except the confidence threshold is identical across detectors,
# so collect those arguments once and reuse them.
_detection_cfg = config["detection"]
_hud_cfg = config["hud_mask"]
_shared_detector_kwargs = {
    "model_name": _detection_cfg["model"],
    "iou_threshold": _detection_cfg["nms_iou_threshold"],
    "device": _detection_cfg["device"],
    "hud_top": _hud_cfg["top_percent"],
    "hud_bottom": _hud_cfg["bottom_percent"],
    "hud_enabled": _hud_cfg["enabled"],
    "shape_filter_config": config.get("detection_shape_filter"),
    "field_mask_config": config.get("playing_field_mask"),
}

# Map each confidence threshold to its own PlayerDetector instance.
detectors = {}
for thr in THRESHOLDS:
    detectors[thr] = PlayerDetector(conf_threshold=thr, **_shared_detector_kwargs)

print(f"Created {len(detectors)} detectors with thresholds: {THRESHOLDS}")

## Visual Comparison per Frame

For each sample frame, show all thresholds side-by-side with bounding boxes drawn. This makes it easy to spot where lower thresholds introduce false positives and where higher thresholds drop real players.

In [None]:
def draw_dets_on_ax(ax, frame_bgr, detections, title=""):
    """Show a frame on ``ax`` and outline every detection in lime.

    Args:
        ax: Matplotlib axes to draw into.
        frame_bgr: Image passed to ``cv2.cvtColor`` with ``COLOR_BGR2RGB``
            before display.
        detections: Iterable of records that unpack into exactly
            ``(x1, y1, x2, y2, conf)``.
        title: Text used as the axes title.

    The axes decorations are switched off after drawing.
    """
    ax.imshow(cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB))

    # Dark rounded background so the confidence label stays readable on grass.
    label_background = dict(boxstyle="round,pad=0.15", facecolor="black", alpha=0.6)

    for left, top, right, bottom, score in detections:
        outline = Rectangle(
            (left, top),
            right - left,
            bottom - top,
            linewidth=1.5,
            edgecolor="lime",
            facecolor="none",
        )
        ax.add_patch(outline)
        # Confidence label sits just above the box's top-left corner.
        ax.text(
            left,
            top - 3,
            f"{score:.2f}",
            color="lime",
            fontsize=6,
            bbox=label_background,
        )

    ax.set_title(title, fontsize=11)
    ax.set_axis_off()


# One figure per sample frame, one column per confidence threshold.
n = len(THRESHOLDS)

for idx, frame in sorted(frames.items()):
    fig, axes = plt.subplots(1, n, figsize=(6 * n, 6))
    if n == 1:
        # With a single column plt.subplots returns a bare Axes, not an array.
        axes = [axes]

    for ax, thr in zip(axes, THRESHOLDS):
        dets = detectors[thr].detect_and_filter(frame)
        draw_dets_on_ax(ax, frame, dets, title=f"conf={thr:.2f}  ({len(dets)} dets)")

    # Assign to `_` so the returned Text object is not echoed as cell output
    # (this notebook runs with ast_node_interactivity = "all").
    _ = fig.suptitle(f"Frame {idx}", fontsize=14, y=1.01)
    plt.tight_layout()
    plt.show()

## Detection Count Sweep

Run each threshold over the first N frames and plot the number of (filtered) detections per frame. A good threshold yields a stable count without wild spikes (false positives) or drops (missed players).

In [None]:
# counts[threshold] -> list with one filtered-detection count per processed frame.
counts = {thr: [] for thr in THRESHOLDS}

# Explicit counter: the loop index over-reports by one when the video ends
# early and is undefined (or stale from another cell) when SWEEP_N_FRAMES is 0.
n_processed = 0

cap = cv2.VideoCapture(VIDEO_PATH)
try:
    for i in range(SWEEP_N_FRAMES):
        ret, frame = cap.read()
        if not ret:
            # Video shorter than SWEEP_N_FRAMES (or a read failed): stop here.
            break
        for thr in THRESHOLDS:
            dets = detectors[thr].detect_and_filter(frame)
            counts[thr].append(len(dets))
        n_processed += 1
finally:
    # Release the capture handle even if detection raises.
    cap.release()

print(f"Processed {n_processed} frames")

In [None]:
# Single wide panel: one detection-count curve per confidence threshold.
fig, ax = plt.subplots(figsize=(14, 5))

for thr in THRESHOLDS:
    curve_label = f"conf={thr:.2f}"
    _ = ax.plot(counts[thr], label=curve_label, alpha=0.8)

# Results are bound to `_` so the returned artists are not echoed.
_ = ax.set(
    xlabel="Frame",
    ylabel="Detection Count (after filtering)",
    title="Detections per Frame at Different Confidence Thresholds",
)
_ = ax.legend()
_ = ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## Summary Statistics

Print mean, median, std, min, and max detection counts per threshold for a quick numerical comparison.

In [None]:
# Fixed-width text table: one row of summary statistics per threshold.
table_header = (
    f"{'Threshold':>10} {'Mean':>8} {'Median':>8} {'Std':>8} {'Min':>6} {'Max':>6}"
)
print(table_header)
print("-" * 52)

for thr in THRESHOLDS:
    per_frame = np.array(counts[thr])
    lowest, highest = per_frame.min(), per_frame.max()
    table_row = (
        f"{thr:>10.2f} {per_frame.mean():>8.1f} {np.median(per_frame):>8.1f} "
        f"{per_frame.std():>8.1f} {lowest:>6d} {highest:>6d}"
    )
    print(table_row)

## Confidence Distribution

Histogram of the raw detection confidences on a single sample frame, taken from the detector built with the lowest threshold in `THRESHOLDS` (so only detections below that floor are missing). This shows where the bulk of detections fall and helps pick a threshold that separates real players from noise.

In [None]:
# Use the lowest threshold detector to capture the widest range of confidences.
# Detections scoring below min(THRESHOLDS) are still excluded by that detector.
sample_frame_idx = SAMPLE_FRAME_INDICES[1]
sample_frame = frames[sample_frame_idx]
all_dets = detectors[min(THRESHOLDS)].detect(sample_frame)
confs = [d[4] for d in all_dets]

fig, ax = plt.subplots(figsize=(10, 4))
# Every call result below is bound to `_`: with ast_node_interactivity = "all"
# a bare call would echo its return value (arrays, Line2D, Text, ...) as output.
_ = ax.hist(confs, bins=30, edgecolor="black", alpha=0.7)

for k, thr in enumerate(THRESHOLDS):
    # Give each threshold line its own colour from the default cycle so the
    # legend entries can be told apart; start at C1 because the bars use C0.
    _ = ax.axvline(
        x=thr,
        linestyle="--",
        linewidth=1.5,
        color=f"C{k + 1}",
        label=f"conf={thr:.2f}",
    )

_ = ax.set_xlabel("Confidence")
_ = ax.set_ylabel("Count")
_ = ax.set_title(
    f"Detection Confidence Distribution (Frame {sample_frame_idx}, {len(confs)} raw detections)"
)
_ = ax.legend()
_ = ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

## Playing Field Filter Configurations

Define several playing field filter presets to compare. The key trade-off: tight size constraints remove false positives but also drop small/distant players. The green playing field mask can compensate by allowing loose sizes while rejecting off-field detections.

In [None]:
def _shape_filter(min_w, min_h, max_w, max_h, min_ratio, max_ratio):
    """Build an enabled shape-filter config dict from its six limits."""
    return {
        "enabled": True,
        "min_bbox_width": min_w,
        "min_bbox_height": min_h,
        "max_bbox_width": max_w,
        "max_bbox_height": max_h,
        "min_aspect_ratio": min_ratio,
        "max_aspect_ratio": max_ratio,
    }


# Limits shared by the two "loose" presets; each preset gets its own dict copy.
_LOOSE_LIMITS = (10, 15, 400, 600, 0.2, 5.0)

# Preset name -> {"shape": shape-filter config or None, "mask": field-mask config or None}
FILTER_PRESETS = {
    "current": {
        "shape": _shape_filter(20, 40, 300, 500, 0.3, 4.0),
        "mask": None,
    },
    "loose": {
        "shape": _shape_filter(*_LOOSE_LIMITS),
        "mask": None,
    },
    "loose + field mask": {
        "shape": _shape_filter(*_LOOSE_LIMITS),
        "mask": {
            "enabled": True,
            "hsv_lower": [35, 40, 40],
            "hsv_upper": [85, 255, 255],
        },
    },
    "no filter": {"shape": None, "mask": None},
}

print("Filter presets:", list(FILTER_PRESETS.keys()))

## Raw vs Filtered: What Does the Filter Drop?

Pick a single low confidence threshold and show raw detections (no playing field filter) next to each filter preset. Red boxes = detections that got filtered out, green boxes = detections that survived. This reveals exactly which players are being lost.

In [None]:
# Confidence threshold used for the kept-vs-dropped inspection.
INSPECT_THRESHOLD = 0.20

# Detector built without shape_filter_config / field_mask_config so the filter
# presets can be applied separately afterwards.
# NOTE(review): assumes PlayerDetector's defaults for those two arguments mean
# "no filtering" -- confirm against the class.
_raw_detector_kwargs = {
    "model_name": config["detection"]["model"],
    "conf_threshold": INSPECT_THRESHOLD,
    "iou_threshold": config["detection"]["nms_iou_threshold"],
    "device": config["detection"]["device"],
    "hud_top": config["hud_mask"]["top_percent"],
    "hud_bottom": config["hud_mask"]["bottom_percent"],
    "hud_enabled": config["hud_mask"]["enabled"],
}
raw_detector = PlayerDetector(**_raw_detector_kwargs)


def draw_kept_vs_dropped(ax, frame_bgr, raw_dets, kept_dets, title=""):
    """Overlay all raw detections, coloured by whether the filter kept them.

    A raw detection counts as kept when its first four values (the box
    coordinates) match those of some entry in ``kept_dets``; kept boxes are
    drawn in lime, the rest in red. Each box is labelled with its
    ``width x height`` just below its bottom edge.

    Args:
        ax: Matplotlib axes to draw into.
        frame_bgr: Image passed to ``cv2.cvtColor`` with ``COLOR_BGR2RGB``
            before display.
        raw_dets: Iterable of records that unpack into exactly
            ``(x1, y1, x2, y2, conf)``.
        kept_dets: Subset of detections that survived filtering.
        title: Text used as the axes title.
    """
    ax.imshow(cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB))

    # Box coordinates of every surviving detection, for fast lookup.
    surviving_boxes = set()
    for kept in kept_dets:
        surviving_boxes.add(tuple(kept[:4]))

    for det in raw_dets:
        x1, y1, x2, y2, _score = det
        box_w = x2 - x1
        box_h = y2 - y1

        if tuple(det[:4]) in surviving_boxes:
            color = "lime"
        else:
            color = "red"

        ax.add_patch(
            Rectangle(
                (x1, y1),
                box_w,
                box_h,
                linewidth=1.5,
                edgecolor=color,
                facecolor="none",
            )
        )
        # Size label sits just under the box.
        ax.text(
            x1,
            y2 + 10,
            f"{box_w:.0f}x{box_h:.0f}",
            color=color,
            fontsize=5,
            bbox=dict(boxstyle="round,pad=0.1", facecolor="black", alpha=0.5),
        )

    ax.set_title(title, fontsize=10)
    ax.set_axis_off()


# One figure per sample frame, one column per filter preset. The raw detections
# are computed once per frame and every preset is applied to that same list.
for idx, frame in sorted(frames.items()):
    raw_dets = raw_detector.detect(frame)
    n = len(FILTER_PRESETS)
    # squeeze=False always returns a 2-D array of Axes, so the loop below also
    # works when FILTER_PRESETS holds a single preset.
    fig, axes_grid = plt.subplots(1, n, figsize=(6 * n, 6), squeeze=False)
    axes = axes_grid[0]

    for ax, (name, preset) in zip(axes, FILTER_PRESETS.items()):
        kept = filter_detections(raw_dets, frame, preset["shape"], preset["mask"])
        draw_kept_vs_dropped(
            ax,
            frame,
            raw_dets,
            kept,
            title=f"{name}\n{len(kept)}/{len(raw_dets)} kept",
        )

    # Assign to `_` so the returned Text object is not echoed as cell output
    # (this notebook runs with ast_node_interactivity = "all").
    _ = fig.suptitle(
        f"Frame {idx} — conf={INSPECT_THRESHOLD}  (green=kept, red=dropped)",
        fontsize=13,
        y=1.02,
    )
    plt.tight_layout()
    plt.show()

## Threshold x Filter Grid

For a single sample frame, show a full grid: rows = filter presets, columns = confidence thresholds. Each cell shows the frame with the surviving boxes drawn and the detection count in its title, making it easy to find the best combination.

In [None]:
# Thresholds shown as columns of the grid (a subset of the full sweep).
GRID_THRESHOLDS = [0.10, 0.20, 0.35, 0.50]

# Second-earliest of the loaded sample frames.
grid_frame_idx = sorted(frames)[1]
grid_frame = frames[grid_frame_idx]

# Arguments common to every grid detector; no shape-filter or field-mask config
# is passed, so the filter presets can be applied separately afterwards.
_grid_detector_kwargs = {
    "model_name": config["detection"]["model"],
    "iou_threshold": config["detection"]["nms_iou_threshold"],
    "device": config["detection"]["device"],
    "hud_top": config["hud_mask"]["top_percent"],
    "hud_bottom": config["hud_mask"]["bottom_percent"],
    "hud_enabled": config["hud_mask"]["enabled"],
}

# Map each grid threshold to its own detector.
grid_detectors = {}
for thr in GRID_THRESHOLDS:
    grid_detectors[thr] = PlayerDetector(conf_threshold=thr, **_grid_detector_kwargs)

n_rows = len(FILTER_PRESETS)
n_cols = len(GRID_THRESHOLDS)
# squeeze=False keeps `axes` two-dimensional even for a single preset or
# threshold, so axes[r, c] is always valid.
fig, axes = plt.subplots(
    n_rows, n_cols, figsize=(5 * n_cols, 4.5 * n_rows), squeeze=False
)

# Raw detections depend only on the threshold, not on the filter preset, so run
# the detector once per column instead of once per grid cell.
grid_raw_dets = {
    thr: grid_detectors[thr].detect(grid_frame) for thr in GRID_THRESHOLDS
}

for r, (filter_name, preset) in enumerate(FILTER_PRESETS.items()):
    for c, thr in enumerate(GRID_THRESHOLDS):
        ax = axes[r, c]
        kept = filter_detections(
            grid_raw_dets[thr], grid_frame, preset["shape"], preset["mask"]
        )
        draw_dets_on_ax(ax, grid_frame, kept, title=f"{len(kept)} dets")

        if r == 0:
            # Top row also carries the column header (the threshold).
            _ = ax.set_title(f"conf={thr:.2f}\n{len(kept)} dets", fontsize=10)
        if c == 0:
            # draw_dets_on_ax calls set_axis_off(), which suppresses every axis
            # decoration -- including the y-label -- no matter how individual
            # components are toggled. Turn the axis back on, then hide the
            # ticks and spines so that only the row label remains.
            ax.set_axis_on()
            _ = ax.set_xticks([])
            _ = ax.set_yticks([])
            for spine in ax.spines.values():
                spine.set_visible(False)
            _ = ax.set_ylabel(
                filter_name, fontsize=11, rotation=0, labelpad=100, va="center"
            )

# Assign to `_` so the returned Text object is not echoed as cell output.
_ = fig.suptitle(f"Threshold x Filter — Frame {grid_frame_idx}", fontsize=14, y=1.0)
plt.tight_layout()
plt.show()

## Combined Sweep: Filter x Threshold over Time

Run every (filter, threshold) combination over N frames. One subplot per filter preset, each showing detection-count curves for all thresholds. This reveals which combination gives the most stable, reasonable player count.

In [None]:
# Collect counts: combo_counts[(filter_name, threshold)] = [count_per_frame]
combo_counts = {}
for fname in FILTER_PRESETS:
    for thr in GRID_THRESHOLDS:
        combo_counts[(fname, thr)] = []

cap = cv2.VideoCapture(VIDEO_PATH)

for i in range(SWEEP_N_FRAMES):
    ret, frame = cap.read()
    if not ret:
        break
    # Run detection once per threshold, then apply each filter
    raw_cache = {}
    for thr in GRID_THRESHOLDS:
        raw_cache[thr] = grid_detectors[thr].detect(frame)

    for fname, preset in FILTER_PRESETS.items():
        for thr in GRID_THRESHOLDS:
            kept = filter_detections(
                raw_cache[thr], frame, preset["shape"], preset["mask"]
            )
            combo_counts[(fname, thr)].append(len(kept))

cap.release()
print(
    f"Processed {i + 1} frames x {len(FILTER_PRESETS)} filters x {len(GRID_THRESHOLDS)} thresholds"
)

In [None]:
filter_names = list(FILTER_PRESETS.keys())
n_filters = len(filter_names)

# One stacked panel per filter preset, all sharing the frame axis.
# squeeze=False always returns a 2-D array of Axes, so iterating and indexing
# below also work when there is only one preset.
fig, axes_grid = plt.subplots(
    n_filters, 1, figsize=(14, 4 * n_filters), sharex=True, squeeze=False
)
axes = axes_grid[:, 0]

for ax, fname in zip(axes, filter_names):
    for thr in GRID_THRESHOLDS:
        c = combo_counts[(fname, thr)]
        _ = ax.plot(c, label=f"conf={thr:.2f}", alpha=0.8)
    _ = ax.set_ylabel("Detections")
    _ = ax.set_title(f"Filter: {fname}", fontsize=12)
    _ = ax.legend(loc="upper right", fontsize=8)
    _ = ax.grid(True, alpha=0.3)

# Only the bottom panel needs the shared x-axis label.
_ = axes[-1].set_xlabel("Frame")
# Assign to `_` so the returned Text object is not echoed as cell output
# (this notebook runs with ast_node_interactivity = "all").
_ = fig.suptitle("Detection Count over Time — Filter x Threshold", fontsize=14, y=1.01)
plt.tight_layout()
plt.show()

## Summary Heatmap

Mean detection count per (filter, threshold) combination displayed as a heatmap. Higher values in a cell = more detections kept. Look for the combination that lands in a reasonable range (e.g. 10-25 for a wide-angle broadcast frame with ~22 players visible).

In [None]:
# Mean detections per frame as a matrix: rows = filter presets, cols = thresholds.
_mean_rows = [
    [np.mean(combo_counts[(fname, thr)]) for thr in GRID_THRESHOLDS]
    for fname in filter_names
]
heatmap = np.array(_mean_rows, dtype=float).reshape(
    len(filter_names), len(GRID_THRESHOLDS)
)

fig, ax = plt.subplots(figsize=(8, 5))
im = ax.imshow(heatmap, cmap="YlOrRd", aspect="auto")

# Tick positions and labels: thresholds along x, preset names along y.
threshold_labels = [f"{t:.2f}" for t in GRID_THRESHOLDS]
_ = ax.set_xticks(range(len(GRID_THRESHOLDS)))
_ = ax.set_xticklabels(threshold_labels)
_ = ax.set_yticks(range(len(filter_names)))
_ = ax.set_yticklabels(filter_names)
_ = ax.set(xlabel="Confidence Threshold", ylabel="Filter Preset")

# Write each cell's value on top of it; switch to white text on the darker
# (higher-valued) cells so the number stays legible.
for r, c in np.ndindex(heatmap.shape):
    val = heatmap[r, c]
    if val > heatmap.max() * 0.6:
        color = "white"
    else:
        color = "black"
    _ = ax.text(
        c, r, f"{val:.1f}", ha="center", va="center", color=color, fontsize=12
    )

_ = fig.colorbar(im, ax=ax, label="Mean detections / frame")
_ = ax.set_title(f"Mean Detections per Frame ({SWEEP_N_FRAMES} frames)", fontsize=13)
plt.tight_layout()
plt.show()

# Print the same data as a table
print(f"\n{'Filter':<22}", end="")
for thr in GRID_THRESHOLDS:
    print(f"  conf={thr:.2f}", end="")
print()
print("-" * (22 + 12 * len(GRID_THRESHOLDS)))
for r, fname in enumerate(filter_names):
    print(f"{fname:<22}", end="")
    for c in range(len(GRID_THRESHOLDS)):
        print(f"  {heatmap[r, c]:>8.1f}", end="")
    print()