# Basketball AI: Streamlined Player Detection and Tracking

This notebook takes a video and a team configuration as input and produces an annotated video with player detection, tracking, and identification.

## Configuration

Set your API keys, source video path, and team rosters here.

In [1]:
import os
from pathlib import Path

# Set your API keys (replace the placeholders; never commit real keys)
os.environ["HF_TOKEN"] = "YOUR_HF_TOKEN_HERE"
os.environ["ROBOFLOW_API_KEY"] = "YOUR_ROBOFLOW_API_KEY_HERE"

# Configure paths (everything is resolved relative to the current working directory)
HOME = Path.cwd()
SOURCE_VIDEO_PATH = HOME / "your_video.mp4"  # Change this to your video path
OUTPUT_VIDEO_PATH = HOME / "output_annotated.mp4"

# Configure team rosters: team name -> {jersey number (as a string) -> player surname}
TEAM_ROSTERS = {
    "New York Knicks": {
        "55": "Hukporti",
        "1": "Payne",
        "0": "Wright",
        "11": "Brunson",
        "3": "Hart",
        "32": "Towns",
        "44": "Shamet",
        "25": "Bridges",
        "2": "McBride",
        "23": "Robinson",
        "8": "Anunoby",
        "4": "Dadiet",
        "5": "Achiuwa",
        "13": "Kolek"
    },
    "Boston Celtics": {
        "42": "Horford",
        "55": "Scheierman",
        "9": "White",
        "20": "Davison",
        "7": "Brown",
        "0": "Tatum",
        "27": "Walsh",
        "4": "Holiday",
        "8": "Porzingis",
        "40": "Kornet",
        "88": "Queta",
        "11": "Pritchard",
        "30": "Hauser",
        "12": "Craig",
        "26": "Tillman"
    }
}

# Annotation color (hex) per team name
TEAM_COLORS = {
    "New York Knicks": "#006BB6",
    "Boston Celtics": "#007A33"
}

# Set which team is which cluster (you may need to adjust after first run)
TEAM_NAMES = {
    0: "Boston Celtics",
    1: "New York Knicks",
}

# Fail fast on configuration typos: the annotation cell indexes TEAM_NAMES[0],
# TEAM_NAMES[1], TEAM_COLORS[...] and TEAM_ROSTERS[...], and a mismatch would
# otherwise only surface as a KeyError after the whole video has been processed.
assert {0, 1} <= set(TEAM_NAMES), "TEAM_NAMES must define cluster ids 0 and 1"
_unconfigured_teams = [
    name for name in TEAM_NAMES.values()
    if name not in TEAM_ROSTERS or name not in TEAM_COLORS
]
assert not _unconfigured_teams, (
    f"TEAM_NAMES refers to teams missing from TEAM_ROSTERS/TEAM_COLORS: {_unconfigured_teams}"
)

print("HOME:", HOME)
print("Source video:", SOURCE_VIDEO_PATH)
print("Output video:", OUTPUT_VIDEO_PATH)

HOME: /content
Source video: /content/your_video.mp4
Output video: /content/output_annotated.mp4


## Environment Setup

In [None]:
# Check GPU availability
# Later cells request CUDA explicitly (TeamClassifier(device="cuda") and
# torch.autocast("cuda", ...)), so confirm a GPU is visible before continuing.
!nvidia-smi

In [None]:
# Install SAM2 real-time
!git clone https://github.com/Gy920/segment-anything-2-real-time.git
%cd {HOME}/segment-anything-2-real-time
!pip install -e . -q
!python setup.py build_ext --inplace
!(cd checkpoints && bash download_ckpts.sh)
%cd {HOME}

In [None]:
# Install dependencies
!pip install -q gdown
!pip install -q inference-gpu
!pip install -q git+https://github.com/roboflow/supervision.git
!pip install -q git+https://github.com/roboflow/sports.git@feat/basketball
!pip install -q transformers num2words
!pip install -q flash-attn --no-build-isolation

os.environ["ONNXRUNTIME_EXECUTION_PROVIDERS"] = "[CUDAExecutionProvider]"

In [None]:
# Download fonts for annotations
# Fetches a Google Drive folder into HOME/fonts; the annotation cell later loads
# fonts/Staatliches-Regular.ttf from that directory for the player labels.
!gdown https://drive.google.com/drive/folders/1RBjpI5Xleb58lujeusxH0W5zYMMA4ytO -O {HOME / "fonts"} --folder

## Import Dependencies

In [None]:
from IPython.display import Video
from typing import Dict, List, Optional, Union, Iterable, Tuple
from operator import itemgetter

import cv2
import numpy as np
import torch
from tqdm import tqdm

import supervision as sv
from inference import get_model
from sports.common.team import TeamClassifier

## Load Models

In [None]:
# Load detection model
# NOTE(review): get_model presumably authenticates with the ROBOFLOW_API_KEY
# environment variable set in the configuration cell — confirm.
PLAYER_DETECTION_MODEL_ID = "basketball-player-detection-3-ycjdo/4"
# Confidence and IoU thresholds passed to every PLAYER_DETECTION_MODEL.infer call in this notebook.
PLAYER_DETECTION_MODEL_CONFIDENCE = 0.4
PLAYER_DETECTION_MODEL_IOU_THRESHOLD = 0.9
PLAYER_DETECTION_MODEL = get_model(model_id=PLAYER_DETECTION_MODEL_ID)

# Class ids of the detection model: the first list is used to keep player
# detections, the second to keep jersey-number detections.
PLAYER_CLASS_IDS = [3, 4, 5, 6, 7]  # player classes
NUMBER_CLASS_ID = 2

# Load number recognition model
# The prompt is passed together with each jersey-number crop to NUMBER_RECOGNITION_MODEL.predict.
NUMBER_RECOGNITION_MODEL_ID = "basketball-jersey-numbers-ocr/3"
NUMBER_RECOGNITION_MODEL = get_model(model_id=NUMBER_RECOGNITION_MODEL_ID)
NUMBER_RECOGNITION_MODEL_PROMPT = "Read the number."

# Load SAM2 tracking model
# Work from inside the cloned repository: the checkpoint path below is relative
# to it (the checkpoints were downloaded into its checkpoints/ directory).
%cd {HOME}/segment-anything-2-real-time
from sam2.build_sam import build_sam2_camera_predictor

SAM2_CHECKPOINT = "checkpoints/sam2.1_hiera_large.pt"
# NOTE(review): the config name is presumably resolved by the sam2 package itself — confirm.
SAM2_CONFIG = "configs/sam2.1/sam2.1_hiera_l.yaml"
predictor = build_sam2_camera_predictor(SAM2_CONFIG, SAM2_CHECKPOINT)
# Restore the working directory; if the lines above raise, the notebook stays
# inside the repository directory until this is run manually.
%cd {HOME}

print("All models loaded successfully!")

## Utility Functions

In [None]:
def filter_segments_by_distance(mask: np.ndarray, distance_threshold: float = 300) -> np.ndarray:
    """Drop connected components that lie far away from the largest one.

    Args:
        mask: Boolean mask to clean up.
        distance_threshold: Maximum centroid-to-centroid distance from the
            largest component for a component to be kept.

    Returns:
        A new boolean mask holding the largest 8-connected component plus every
        component whose centroid is within ``distance_threshold`` of it. A mask
        without any foreground component is returned as a copy.
    """
    assert mask.dtype == bool, "Input mask must be boolean."
    component_count, label_map, stats, centroids = cv2.connectedComponentsWithStats(
        mask.astype(np.uint8), connectivity=8
    )
    # Label 0 is skipped throughout (background), so a count of 1 means no foreground.
    if component_count <= 1:
        return mask.copy()

    largest = 1 + np.argmax(stats[1:, cv2.CC_STAT_AREA])
    anchor = centroids[largest]

    # Per-label keep flag; index 0 (skipped label) is never kept.
    keep = np.array([False] + [
        bool(idx == largest or np.linalg.norm(centroids[idx] - anchor) <= distance_threshold)
        for idx in range(1, component_count)
    ])
    # Look the flag up for every pixel's label to build the filtered mask.
    return keep[label_map]


def shrink_boxes(xyxy: np.ndarray, scale: float) -> np.ndarray:
    """Scale boxes around their own centers.

    Args:
        xyxy: Array of shape (N, 4) holding ``x1, y1, x2, y2`` per row.
        scale: Factor applied to each box's width and height.

    Returns:
        Array of shape (N, 4) with the rescaled boxes, centers unchanged.
    """
    top_left, bottom_right = xyxy[:, 0:2], xyxy[:, 2:4]
    centers = (top_left + bottom_right) / 2
    # Half of the scaled width/height, measured from the center in each direction.
    half_extent = (bottom_right - top_left) * scale / 2
    return np.concatenate([centers - half_extent, centers + half_extent], axis=1)


def xyxy_to_mask(boxes: np.ndarray, resolution_wh: Tuple[int, int]) -> np.ndarray:
    """Rasterize boxes into one boolean mask per box.

    Args:
        boxes: Array of shape (N, 4) holding ``x_min, y_min, x_max, y_max`` per row.
        resolution_wh: ``(width, height)`` of the masks to produce.

    Returns:
        Boolean array of shape (N, height, width). Coordinates are truncated to
        integers and clipped to the image; both box edges are inclusive. A box
        that ends up empty after clipping yields an all-False mask.
    """
    width, height = resolution_wh
    masks = np.zeros((boxes.shape[0], height, width), dtype=bool)
    for box_mask, box in zip(masks, boxes):
        left, top, right, bottom = (int(coord) for coord in box)
        left, top = max(0, left), max(0, top)
        right, bottom = min(width - 1, right), min(height - 1, bottom)
        if right < left or bottom < top:
            continue
        # box_mask is a view into masks, so this writes into the result array.
        box_mask[top:bottom + 1, left:right + 1] = True
    return masks


def coords_above_threshold(matrix: np.ndarray, threshold: float, sort_desc: bool = True) -> List[Tuple[int, int]]:
    """List the (row, column) positions of a 2-D matrix whose value exceeds ``threshold``.

    Args:
        matrix: 2-D array-like of comparable values.
        threshold: Strict lower bound; entries equal to it are excluded.
        sort_desc: When True, order the positions by value, largest first
            (ties keep row-major order); otherwise keep row-major order.

    Returns:
        List of ``(row_index, col_index)`` tuples of Python ints.
    """
    values = np.asarray(matrix)
    row_idx, col_idx = np.nonzero(values > threshold)
    hits = [(int(r), int(c)) for r, c in zip(row_idx, col_idx)]
    if not sort_desc:
        return hits
    # sorted() is stable even with reverse=True, so equal values keep their order.
    return sorted(hits, key=lambda hit: values[hit], reverse=True)


# Raw value reported for a tracker; it is normalized to a stripped string (or None).
Value = Union[int, str, None]


class PropertyValidator:
    """Lock in a per-tracker property once it was reported N times in a row.

    For every tracker id the validator remembers the last non-empty value and
    how many times in a row it was reported. As soon as that streak reaches
    ``n_consecutive`` the value becomes the tracker's validated value and is
    never changed again (until ``reset_id`` / ``reset_all``). Empty reports
    (``None`` or blank strings) are ignored: they neither extend nor break a
    streak.
    """

    def __init__(self, n_consecutive: int):
        """Create a validator.

        Args:
            n_consecutive: Number of identical reports in a row required to
                validate a value.

        Raises:
            ValueError: If ``n_consecutive`` is smaller than 1.
        """
        if n_consecutive < 1:
            raise ValueError("n_consecutive must be >= 1")
        self.n = n_consecutive
        # tracker_id -> length of the current run of identical reports
        self._streak: Dict[int, int] = {}
        # tracker_id -> last non-empty normalized value seen
        self._last: Dict[int, Optional[str]] = {}
        # tracker_id -> validated value (None while still unvalidated)
        self._validated: Dict[int, Optional[str]] = {}

    def reset_id(self, tracker_id: int) -> None:
        """Forget everything known about a single tracker id."""
        self._streak.pop(tracker_id, None)
        self._last.pop(tracker_id, None)
        self._validated.pop(tracker_id, None)

    def reset_all(self) -> None:
        """Forget everything known about all tracker ids."""
        self._streak.clear()
        self._last.clear()
        self._validated.clear()

    def _normalize(self, value: Value) -> Optional[str]:
        """Convert a raw value to a stripped string; None/blank becomes None."""
        if value is None:
            return None
        s = str(value).strip()
        return s if s else None

    def update(self, tracker_ids: List[int], values: List[Value]) -> List[Optional[str]]:
        """Feed one report per tracker id and return the validated values.

        Args:
            tracker_ids: Tracker ids, aligned with ``values``.
            values: Reported value for each tracker id.

        Returns:
            For each tracker id, its validated value after this update, or
            None if it is not validated yet.

        Raises:
            ValueError: If the two sequences differ in length.
        """
        if len(tracker_ids) != len(values):
            raise ValueError("tracker_ids and values must have the same length")
        output: List[Optional[str]] = []
        for tid, raw in zip(tracker_ids, values):
            locked = self._validated.get(tid)
            if locked is not None:
                # Already validated: later reports can no longer change it.
                output.append(locked)
                continue
            # Register the tracker so it is known even before validation.
            self._validated.setdefault(tid, None)
            val = self._normalize(raw)
            if val is None:
                # Empty report: keep the current streak untouched.
                self._streak.setdefault(tid, 0)
                self._last.setdefault(tid, None)
                output.append(None)
                continue
            if self._last.get(tid) == val:
                self._streak[tid] = self._streak.get(tid, 0) + 1
            else:
                # A different value starts a fresh streak.
                self._streak[tid] = 1
                self._last[tid] = val
            if self._streak[tid] >= self.n:
                self._validated[tid] = val
            output.append(self._validated[tid])
        return output

    def get_validated(self, tracker_ids: Union[int, Iterable[int]]) -> Union[Optional[str], List[Optional[str]]]:
        """Query validated properties for one or many tracker_ids.

        Args:
            tracker_ids: A single id (Python int or NumPy integer scalar) or an
                iterable of ids.

        Returns:
            The validated value (or None) for a single id, otherwise a list
            with one entry per id.
        """
        # NumPy integer scalars (e.g. one element of a tracker_id array) are not
        # instances of int; without this they would be treated as an iterable
        # and raise TypeError.
        if isinstance(tracker_ids, (int, np.integer)):
            return self._validated.get(tracker_ids)
        return [self._validated.get(tid) for tid in tracker_ids]

    def validated_dict(self) -> Dict[int, str]:
        """Snapshot of validated assignments. Excludes None values."""
        return {tid: val for tid, val in self._validated.items() if val is not None}

## Train Team Classifier

Collect player crops from the video and train a classifier to distinguish between teams.

In [None]:
print("Collecting training samples for team classification...")
STRIDE = 30  # sample every 30th frame of the source video
crops = []

# Pass the path as a string, as the annotation cell does for VideoSink.
# NOTE(review): the video reader may not accept pathlib.Path objects — confirm.
frame_generator = sv.get_video_frames_generator(source_path=str(SOURCE_VIDEO_PATH), stride=STRIDE)

for frame in tqdm(frame_generator, desc="Extracting player crops"):
    result = PLAYER_DETECTION_MODEL.infer(
        frame,
        confidence=PLAYER_DETECTION_MODEL_CONFIDENCE,
        iou_threshold=PLAYER_DETECTION_MODEL_IOU_THRESHOLD,
        class_agnostic_nms=True
    )[0]
    detections = sv.Detections.from_inference(result)
    # Keep player classes only (drops e.g. jersey-number detections).
    detections = detections[np.isin(detections.class_id, PLAYER_CLASS_IDS)]

    # Crop only the central 40% of each box, the same crop used at prediction time.
    boxes = shrink_boxes(xyxy=detections.xyxy, scale=0.4)
    for box in boxes:
        crop = sv.crop_image(frame, box)
        # A tiny detection can shrink to a zero-area crop, which carries no
        # pixels to learn from; skip it instead of feeding it to the classifier.
        if crop.size > 0:
            crops.append(crop)

print(f"Collected {len(crops)} player crops")
if not crops:
    # Fail with a clear message instead of an opaque error inside fit().
    raise RuntimeError(
        f"No player crops were collected from {SOURCE_VIDEO_PATH}; "
        "check the video path and the detection thresholds."
    )

# Train team classifier
print("Training team classifier...")
team_classifier = TeamClassifier(device="cuda")
team_classifier.fit(crops)
print("Team classifier trained successfully!")

## Process Video

This is the main processing loop that:
1. Detects players in the first frame
2. Tracks them throughout the video using SAM2
3. Classifies them into teams
4. Recognizes jersey numbers
5. Creates annotated output

In [None]:
print("Starting video processing...")

# Initialize validators: a jersey number must be read identically 3 times in a
# row before it is accepted; a team label is accepted on its first report.
number_validator = PropertyValidator(n_consecutive=3)
team_validator = PropertyValidator(n_consecutive=1)

# Storage for frames and detections (consumed by the annotation cell).
# NOTE: every frame is kept in memory, so long videos need a lot of RAM.
frames_history = []
detections_history = []

# Get first frame and initialize tracking
frame_generator = sv.get_video_frames_generator(str(SOURCE_VIDEO_PATH))
frame = next(frame_generator, None)
if frame is None:
    raise RuntimeError(f"Could not read any frame from {SOURCE_VIDEO_PATH}")

print("Detecting players in first frame...")
result = PLAYER_DETECTION_MODEL.infer(
    frame,
    confidence=PLAYER_DETECTION_MODEL_CONFIDENCE,
    iou_threshold=PLAYER_DETECTION_MODEL_IOU_THRESHOLD,
    class_agnostic_nms=True
)[0]
detections = sv.Detections.from_inference(result)
detections = detections[np.isin(detections.class_id, PLAYER_CLASS_IDS)]
if len(detections) == 0:
    # Without at least one prompt there is nothing for SAM2 to track.
    raise RuntimeError("No players detected in the first frame; cannot initialize tracking.")

# Assign tracker IDs (1..N, one per first-frame detection)
TRACKER_ID = list(range(1, len(detections) + 1))

# Classify teams from the central part of each player box
boxes = shrink_boxes(xyxy=detections.xyxy, scale=0.4)
crops = [sv.crop_image(frame, box) for box in boxes]
TEAMS = np.array(team_classifier.predict(crops))
team_validator.update(tracker_ids=TRACKER_ID, values=TEAMS)

# Initialize SAM2 tracking: one box prompt per detected player on the first frame
print("Initializing SAM2 tracking...")
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
    predictor.load_first_frame(frame)
    for xyxy, tracker_id in zip(detections.xyxy, TRACKER_ID):
        predictor.add_new_prompt(
            frame_idx=0,
            obj_id=tracker_id,
            bbox=np.array([xyxy])
        )

# Process all frames (the generator is recreated, so the first frame is included)
print("Processing video frames...")
frame_generator = sv.get_video_frames_generator(str(SOURCE_VIDEO_PATH))

for index, frame in tqdm(enumerate(frame_generator), desc="Processing frames"):
    frame_h, frame_w, *_ = frame.shape

    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
        tracker_ids, mask_logits = predictor.track(frame)
        tracker_ids = np.array(tracker_ids)
        masks = (mask_logits > 0.0).cpu().numpy().astype(bool)
        # Collapse every leading axis into one object axis while keeping the two
        # spatial axes. Unlike a bare np.squeeze this keeps the object axis when
        # only a single player is tracked.
        # NOTE(review): assumes the last two axes of mask_logits are (H, W) — confirm.
        masks = masks.reshape(-1, *masks.shape[-2:])

        # Remove stray mask fragments far away from each player's main segment
        player_masks = np.array([
            filter_segments_by_distance(mask, distance_threshold=300)
            for mask in masks
        ])

        player_detections = sv.Detections(
            xyxy=sv.mask_to_xyxy(masks=player_masks),
            mask=player_masks,
            tracker_id=tracker_ids
        )

        frames_history.append(frame)
        detections_history.append(player_detections)

        # Perform number recognition at intervals (every 5th frame)
        if index % 5 == 0:
            result = PLAYER_DETECTION_MODEL.infer(
                frame,
                confidence=PLAYER_DETECTION_MODEL_CONFIDENCE,
                iou_threshold=PLAYER_DETECTION_MODEL_IOU_THRESHOLD
            )[0]
            number_detections = sv.Detections.from_inference(result)
            number_detections = number_detections[number_detections.class_id == NUMBER_CLASS_ID]

            if len(number_detections) > 0:
                # Box-shaped masks so numbers can be compared against player masks
                number_detections.mask = xyxy_to_mask(
                    boxes=number_detections.xyxy,
                    resolution_wh=(frame_w, frame_h)
                )

                # Recognize numbers on slightly padded crops
                number_crops = [
                    sv.crop_image(frame, xyxy)
                    for xyxy in sv.pad_boxes(xyxy=number_detections.xyxy, px=10, py=10)
                ]
                numbers = [
                    NUMBER_RECOGNITION_MODEL.predict(number_crop, NUMBER_RECOGNITION_MODEL_PROMPT)[0]
                    for number_crop in number_crops
                ]

                # Match numbers to players: rows are players, columns are numbers
                iou = sv.mask_iou_batch(
                    masks_true=player_masks,
                    masks_detection=number_detections.mask,
                    overlap_metric=sv.OverlapMetric.IOS
                )

                pairs = coords_above_threshold(iou, 0.9)
                if pairs:
                    # Look the tracker id up from the tracker output instead of
                    # assuming it equals row index + 1.
                    matched_tracker_ids = [int(tracker_ids[player_row]) for player_row, _ in pairs]
                    # Index explicitly: itemgetter returns a bare string for a
                    # single match, and list("23") would split it into digits.
                    matched_numbers = [numbers[number_col] for _, number_col in pairs]
                    number_validator.update(tracker_ids=matched_tracker_ids, values=matched_numbers)

print(f"Processed {len(frames_history)} frames")

## Generate Annotated Output Video

In [None]:
print("Generating annotated video...")

TEMP_OUTPUT_PATH = HOME / "temp_output.mp4"
video_info = sv.VideoInfo.from_video_path(SOURCE_VIDEO_PATH)

# Setup annotators
team_colors = sv.ColorPalette.from_hex([
    TEAM_COLORS[TEAM_NAMES[0]],
    TEAM_COLORS[TEAM_NAMES[1]]
])

team_mask_annotator = sv.MaskAnnotator(
    color=team_colors,
    opacity=0.5,
    color_lookup=sv.ColorLookup.INDEX
)

team_label_annotator = sv.RichLabelAnnotator(
    font_path=f"{HOME}/fonts/Staatliches-Regular.ttf",
    font_size=40,
    color=team_colors,
    text_color=sv.Color.WHITE,
    text_position=sv.Position.BOTTOM_CENTER,
    text_offset=(0, 10),
    color_lookup=sv.ColorLookup.INDEX
)

# Write annotated video
with sv.VideoSink(str(TEMP_OUTPUT_PATH), video_info) as sink:
    for frame, detections in tqdm(zip(frames_history, detections_history), desc="Annotating frames", total=len(frames_history)):
        detections = detections[detections.area > 100]

        teams = team_validator.get_validated(tracker_ids=detections.tracker_id)
        teams = np.array(teams).astype(int)
        numbers = number_validator.get_validated(tracker_ids=detections.tracker_id)
        numbers = np.array(numbers)

        labels = [
            f"#{number} {TEAM_ROSTERS[TEAM_NAMES[team]].get(number, '')}"
            for number, team in zip(numbers, teams)
        ]

        annotated_frame = frame.copy()
        annotated_frame = team_mask_annotator.annotate(
            scene=annotated_frame,
            detections=detections,
            custom_color_lookup=teams
        )
        annotated_frame = team_label_annotator.annotate(
            scene=annotated_frame,
            detections=detections,
            labels=labels,
            custom_color_lookup=teams
        )

        sink.write_frame(annotated_frame)

# Compress output
print("Compressing output video...")
!ffmpeg -y -loglevel error -i {TEMP_OUTPUT_PATH} -vcodec libx264 -crf 28 {OUTPUT_VIDEO_PATH}

# Clean up temp file
import os
if os.path.exists(TEMP_OUTPUT_PATH):
    os.remove(TEMP_OUTPUT_PATH)

print(f"\nProcessing complete! Output saved to: {OUTPUT_VIDEO_PATH}")

## Display Output Video

In [None]:
# Embed the compressed result inline; embed=True stores the video data in the
# notebook output itself, so the saved notebook grows with the video's size.
Video(OUTPUT_VIDEO_PATH, embed=True, width=1080)