# Ball tracking using BaseballCV Model

Source: [BaseballCV Example](https://baseballcv.com/using-repository/yolov9-usage.html)

In [None]:
from baseballcv.model import YOLOv9
from baseballcv.functions import LoadTools
import cv2
import numpy as np


In [30]:

# -------------------------
# 1. Pick device (MPS on Apple Silicon, else CPU)
# -------------------------
if torch.backends.mps.is_available():
    device = "mps"   # Apple GPU
else:
    device = "cpu"

print("Using device:", device)

# -------------------------
# 2. Load ball-tracking model via BaseballCV
# -------------------------
load_tools = LoadTools()
model_path = load_tools.load_model("ball_tracking")  # downloads / resolves ball_tracking.pt
model = YOLO(model_path)  # Ultralytics YOLO model

Using device: mps
2025-11-12 23:40:44,287 - LoadTools - INFO - Model found at models/od/YOLO/ball_tracking/model_weights/ball_tracking.pt


In [31]:
# Process a baseball video
video_path = "data/baseball_pitch.mp4"
cap = cv2.VideoCapture(video_path)
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
fps = cap.get(cv2.CAP_PROP_FPS)
output_path = "tracked_pitch.mp4"

# Create output video writer
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

# Track ball through video
ball_trajectory = []
frame_idx = 0

while cap.isOpened():
    ret, frame = cap.read()
    if not ret:
        break
    
    # Run inference
    results = model.inference(
        source=frame,
        conf_thres=0.35,
        iou_thres=0.45
    )
    
    # Process results
    for detection in results:
        boxes = detection.get('boxes', [])
        scores = detection.get('scores', [])
        labels = detection.get('labels', [])
        
        for box, score, label in zip(boxes, scores, labels):
            if model.model.names[int(label)].lower() == 'baseball':
                x1, y1, x2, y2 = map(int, box)
                center_x = (x1 + x2) // 2
                center_y = (y1 + y2) // 2
                
                # Add to trajectory
                ball_trajectory.append((frame_idx, center_x, center_y))
                
                # Draw box and center
                cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 0, 255), 2)
                cv2.circle(frame, (center_x, center_y), 5, (0, 0, 255), -1)
    
    # Draw trajectory
    if len(ball_trajectory) > 1:
        for i in range(1, len(ball_trajectory)):
            if ball_trajectory[i][0] - ball_trajectory[i-1][0] <= 3:  # Only connect nearby frames
                pt1 = (ball_trajectory[i-1][1], ball_trajectory[i-1][2])
                pt2 = (ball_trajectory[i][1], ball_trajectory[i][2])
                cv2.line(frame, pt1, pt2, (255, 0, 0), 2)
    
    # Write frame
    out.write(frame)
    frame_idx += 1

cap.release()
out.release()
 

AttributeError: 'DetectionModel' object has no attribute 'inference'