In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms

import json
import cv2
import numpy as np

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(device)

cuda


## Baseball Detection Model

In [15]:
from ultralytics import YOLO

model = YOLO("../runs/detect/train13/weights/best.pt")
# results = model.predict(source="baseball-detection-2-4/baseball-detection-2-4/test/images/0eb9ef3c-255_jpg.rf.a9da7151055bf3395bedbdc72c61c4e3.jpg", task="detect", save=True)
results = model.predict("../kate work/media/shohei_pitch(14).png", task="detect", save=True)



image 1/1 /home/kmvu/private/AI-Pitch-Recognition/kate work/../kate work/media/shohei_pitch(14).png: 320x544 (no detections), 14.7ms
Speed: 1.9ms preprocess, 14.7ms inference, 0.7ms postprocess per image at shape (1, 3, 320, 544)
Results saved to [1mruns/detect/predict3[0m


## Batter and Home Plate Model Predict Image

In [13]:
from ultralytics import YOLO

model = YOLO("../nicks_work/runs/detect/train12/weights/best.pt")
# results = model.predict(source="baseball-detection-2-4/baseball-detection-2-4/test/images/0eb9ef3c-255_jpg.rf.a9da7151055bf3395bedbdc72c61c4e3.jpg", task="detect", save=True)
results = model.predict("../kate work/media/shohei_pitch(14).png", task="detect", save=True)



image 1/1 /home/kmvu/private/AI-Pitch-Recognition/kate work/../kate work/media/shohei_pitch(14).png: 320x544 1 batter, 12.6ms
Speed: 2.3ms preprocess, 12.6ms inference, 446.1ms postprocess per image at shape (1, 3, 320, 544)
Results saved to [1mruns/detect/predict2[0m


## Batter and Home Plate Model Predict Video

In [20]:
from ultralytics import YOLO, settings

model = YOLO("../nicks_work/runs/detect/train12/weights/best.pt")

result = model.track("../kate work/strike_videos/All of Shohei Ohtani's 2022 Strikeouts 720 (14).mp4",conf=0.2, save=True)



errors for large sources or long-running streams and videos. See https://docs.ultralytics.com/modes/predict/ for help.

Example:
    results = model(source=..., stream=True)  # generator of Results objects
    for r in results:
        boxes = r.boxes  # Boxes object for bbox outputs
        masks = r.masks  # Masks object for segment masks outputs
        probs = r.probs  # Class probabilities for classification outputs

video 1/1 (frame 1/69) /home/kmvu/private/AI-Pitch-Recognition/kate work/../kate work/strike_videos/All of Shohei Ohtani's 2022 Strikeouts 720 (14).mp4: 320x544 2 batters, 1 home plate, 10.3ms
video 1/1 (frame 2/69) /home/kmvu/private/AI-Pitch-Recognition/kate work/../kate work/strike_videos/All of Shohei Ohtani's 2022 Strikeouts 720 (14).mp4: 320x544 2 batters, 1 home plate, 9.3ms
video 1/1 (frame 3/69) /home/kmvu/private/AI-Pitch-Recognition/kate work/../kate work/strike_videos/All of Shohei Ohtani's 2022 Strikeouts 720 (14).mp4: 320x544 2 batters, 1 home plate, 9

In [23]:
!brew install opencv
!pip install opencv-python --upgrade

/bin/bash: line 1: brew: command not found
Defaulting to user installation because normal site-packages is not writeable
Collecting opencv-python
  Downloading opencv_python-4.11.0.86-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (20 kB)
Downloading opencv_python-4.11.0.86-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (63.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.0/63.0 MB[0m [31m20.1 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hInstalling collected packages: opencv-python
Successfully installed opencv-python-4.11.0.86


## Predict Strike Zone

In [1]:
import cv2
import torch
from ultralytics import YOLO

# Load trained models
batter_homeplate_model = YOLO("../nicks_work/runs/detect/train12/weights/best.pt")  # Detects batter & home plate
baseball_model = YOLO("../runs/detect/train13/weights/best.pt")  # Detects baseball

# Load input video
video_path = "../kate work/strike_videos/All of Shohei Ohtani's 2022 Strikeouts 720 (14).mp4"
cap = cv2.VideoCapture(video_path)

# Get video properties
fps = int(cap.get(cv2.CAP_PROP_FPS))
frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

# Define VideoWriter to save output video
output_path = "output.mp4"
fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # Codec
out = cv2.VideoWriter(output_path, fourcc, fps, (frame_width, frame_height))

frame_count = 0

while cap.isOpened():
    ret, frame = cap.read()
    if not ret:
        break

    frame_count += 1  # Track frame count

    # Run detection on batter & home plate
    batter_results = batter_homeplate_model(frame)[0]
    baseball_results = baseball_model(frame)[0]

    batter_bbox, plate_bbox, ball_bbox = None, None, None

    # Debugging: Print detection results
    # print(f"Frame {frame_count}:")
    # print(f"Batter & Home Plate Detections: {batter_results}")
    # print(f"Baseball Detections: {baseball_results}")

    # Extract batter & home plate bounding boxes
    for r in batter_results.boxes.data:  # Access detection results correctly
        class_id = int(r[-1])  # Class ID is the last value
        conf = r[-2]  # Confidence score
        if conf < 0.3:  # Ignore low-confidence detections
            continue
        if class_id == 0:  # Batter
            batter_bbox = r[:4].tolist()
        elif class_id == 1:  # Home Plate
            plate_bbox = r[:4].tolist()

    # Extract baseball bounding box
    for r in baseball_results.boxes.data:
        conf = r[-2]
        if conf < 0.3:  # Ignore low-confidence detections
            continue
        if int(r[-1]) == 0:  # Baseball
            ball_bbox = r[:4].tolist()

    # Draw detected objects
    if batter_bbox:
        cv2.rectangle(frame, (int(batter_bbox[0]), int(batter_bbox[1])),
                      (int(batter_bbox[2]), int(batter_bbox[3])), (255, 255, 0), 2)  # Yellow for Batter

    if plate_bbox:
        cv2.rectangle(frame, (int(plate_bbox[0]), int(plate_bbox[1])),
                      (int(plate_bbox[2]), int(plate_bbox[3])), (255, 0, 255), 2)  # Purple for Home Plate

    if ball_bbox:
        ball_center_x = (ball_bbox[0] + ball_bbox[2]) / 2
        ball_center_y = (ball_bbox[1] + ball_bbox[3]) / 2
        cv2.circle(frame, (int(ball_center_x), int(ball_center_y)), 5, (0, 255, 0), -1)  # Green for Ball

    # Compute Strike Zone if batter & home plate are detected
    # Compute Strike Zone if batter & home plate are detected
    if batter_bbox and plate_bbox:
        batter_top = batter_bbox[1]
        batter_bottom = batter_bbox[3]

    # Make strike zone smaller
        strike_zone_top = batter_top + (batter_bottom - batter_top) * 0.3  # Change from 0.2 to 0.3
        strike_zone_bottom = batter_top + (batter_bottom - batter_top) * 0.7  # Change from 0.8 to 0.7
        strike_zone_x_min = plate_bbox[0] + 10  # Narrow the width
        strike_zone_x_max = plate_bbox[2] - 10  # Narrow the width

    # Draw Strike Zone
    cv2.rectangle(frame, (int(strike_zone_x_min), int(strike_zone_top)),
                  (int(strike_zone_x_max), int(strike_zone_bottom)), (255, 0, 0), 2)  # Blue for Strike Zone
    # Check if the ball is a strike or ball
    if ball_bbox and batter_bbox and plate_bbox:
        if (strike_zone_x_min <= ball_center_x <= strike_zone_x_max and
            strike_zone_top <= ball_center_y <= strike_zone_bottom):
            result = "Strike!"
            color = (0, 255, 0)  # Green for Strike
        else:
            result = "Ball!"
            color = (0, 0, 255)  # Red for Ball

        # Display result
        cv2.putText(frame, result, (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, color, 2)
        print(result)
        print(result)
        print(result)
        print(result)
        print(result)

    # Write frame to output video file
    out.write(frame)

cap.release()
out.release()

print(f"Video saved as {output_path}")


0: 320x544 2 batters, 1 home plate, 87.8ms
Speed: 8.1ms preprocess, 87.8ms inference, 362.8ms postprocess per image at shape (1, 3, 320, 544)

0: 320x544 (no detections), 29.6ms
Speed: 1.6ms preprocess, 29.6ms inference, 0.9ms postprocess per image at shape (1, 3, 320, 544)

0: 320x544 2 batters, 1 home plate, 11.3ms
Speed: 2.7ms preprocess, 11.3ms inference, 2.0ms postprocess per image at shape (1, 3, 320, 544)

0: 320x544 (no detections), 10.7ms
Speed: 1.4ms preprocess, 10.7ms inference, 0.7ms postprocess per image at shape (1, 3, 320, 544)

0: 320x544 2 batters, 1 home plate, 10.6ms
Speed: 1.5ms preprocess, 10.6ms inference, 1.7ms postprocess per image at shape (1, 3, 320, 544)

0: 320x544 (no detections), 10.6ms
Speed: 1.3ms preprocess, 10.6ms inference, 0.6ms postprocess per image at shape (1, 3, 320, 544)

0: 320x544 2 batters, 1 home plate, 10.5ms
Speed: 2.8ms preprocess, 10.5ms inference, 1.5ms postprocess per image at shape (1, 3, 320, 544)

0: 320x544 (no detections), 10.5ms