# **NeuralTennis**

In [56]:
!pip install --quiet torch torchvision opencv-python-headless matplotlib plotly albumentations scikit-learn ultralytics

In [57]:
# Mount Google Drive (Colab-specific) so the notebook can read input clips from and
# write trained weights / annotated videos to /content/drive/MyDrive/NeuralTennis.
from google.colab import drive
drive.mount("/content/drive")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [58]:
import os

BASE_DIR = "/content/drive/MyDrive/NeuralTennis"

# Project layout on Drive: source clips in input/, YOLO weights in models/,
# annotated videos in output/. exist_ok=True makes this cell safe to re-run.
for folder in ("input", "models", "output"):
    os.makedirs(os.path.join(BASE_DIR, folder), exist_ok=True)

In [4]:
import cv2

# Sanity check: make sure the sample clip opens and report its resolution
# (frame.shape is (height, width, ...), so width is index 1).
video_path = f"{BASE_DIR}/input/4.mp4"
capture = cv2.VideoCapture(video_path)
if not capture.isOpened():
    raise IOError(f"Cannot open {video_path}")
grabbed, frame = capture.read()
capture.release()
if not grabbed:
    raise IOError("Cannot read the first frame")
print("Video loaded – resolution:", frame.shape[1], "×", frame.shape[0])

Video loaded – resolution: 1920 × 1080


In [31]:
import cv2

def read_video(video_path):
    """Read every frame of a video file into memory.

    All decoded frames are kept in a list, so memory use grows with clip
    length and resolution.

    Args:
        video_path: Path of the video file to decode with OpenCV.

    Returns:
        list of frames in file order, as returned by ``cv2.VideoCapture.read``.

    Raises:
        IOError: If OpenCV cannot open ``video_path``. Without this check a bad
            path silently yields an empty list and the failure surfaces much later.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        raise IOError(f"Cannot open {video_path}")

    frames = []
    try:
        while True:
            ret, frame = cap.read()
            if not ret:  # end of stream (or a decode failure)
                break
            frames.append(frame)
    finally:
        # Always free the decoder, even if reading raises part-way through.
        cap.release()
    return frames

def save_video(output_video_frames, output_video_path, fps=24):
    """Write a sequence of frames to an MP4 (mp4v) file.

    Args:
        output_video_frames: Non-empty sequence of frames (e.g. as returned by
            ``read_video``). The writer is sized from the first frame, so all
            frames are expected to share that size.
        output_video_path: Destination file path.
        fps: Frame rate stored in the output file (default 24). Pass the source
            clip's rate to keep playback at real-time speed.

    Raises:
        ValueError: If ``output_video_frames`` is empty.
        IOError: If OpenCV cannot open a writer for ``output_video_path``.
    """
    # len() rather than truthiness so a numpy array of frames also works.
    if len(output_video_frames) == 0:
        raise ValueError("output_video_frames is empty; nothing to save")

    height, width = output_video_frames[0].shape[:2]
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))
    if not out.isOpened():
        # Otherwise OpenCV silently drops every write and no file is produced.
        raise IOError(f"Cannot open video writer for {output_video_path}")
    try:
        for frame in output_video_frames:
            out.write(frame)
    finally:
        out.release()
    print(f"Video saved to {output_video_path}")


In [None]:
!pip --quiet install roboflow
from roboflow import Roboflow

# Download the tennis-ball dataset from Roboflow and fine-tune YOLOv8x on it.
# SECURITY: never hard-code the Roboflow API key in the notebook. A key that was
# ever committed or shared in a notebook should be considered leaked and rotated.
# Provide it via the ROBOFLOW_API_KEY environment variable, or type it at the prompt.
import os
from getpass import getpass

roboflow_api_key = os.environ.get("ROBOFLOW_API_KEY") or getpass("Roboflow API key: ")
rf = Roboflow(api_key=roboflow_api_key)
project = rf.workspace("labia-5re0j").project("tennis-ball-detection-7hmzh")
version = project.version(1)
dataset = version.download("yolov8")


from ultralytics import YOLO

model = YOLO("yolov8x.pt")

model.train(
    # dataset.location is where download() actually wrote the files, so this
    # no longer depends on the current working directory.
    data      = f"{dataset.location}/data.yaml",
    device    = 0,  # assumes a CUDA GPU is available
    epochs    = 100,
    batch     = 32,
    imgsz     = 640,
    patience  = 10,  # early-stop after 10 epochs without improvement
    project   = f"{BASE_DIR}/models/ball_training",
    name      = "run",
    exist_ok  = True
)

In [32]:
#!pip install ultralytics
from ultralytics import YOLO
import cv2
import pandas as pd


class PlayerTracker:
    """Detect and track people in video frames with an Ultralytics YOLO model."""

    def __init__(self,model_path):
        """Load the YOLO weights stored at ``model_path``."""
        self.model = YOLO(model_path)

    def detect_frame(self,frame):
        """Track one frame and return person boxes keyed by tracker ID.

        Args:
            frame: Image passed to ``self.model.track``.

        Returns:
            dict mapping int track ID -> ``[x1, y1, x2, y2]`` (from ``box.xyxy``)
            for every detection whose class name is ``"person"``. Person boxes
            the tracker has not assigned an ID to are skipped.
        """
        # persist=True keeps tracker state between calls so IDs carry across frames.
        results = self.model.track(frame, persist=True)[0]
        id_name_dict = results.names

        player_dict = {}
        for box in results.boxes:
            # Filter by class first so non-person boxes are never inspected further.
            object_cls_name = id_name_dict[int(box.cls.tolist()[0])]
            if object_cls_name != "person":
                continue
            # box.id is None when the tracker has no ID for the box (e.g. a track
            # not yet confirmed); calling .tolist() on it would raise.
            if box.id is None:
                continue
            track_id = int(box.id.tolist()[0])
            player_dict[track_id] = box.xyxy.tolist()[0]

        return player_dict

    def draw_bboxes(self,video_frames, player_detections):
        """Draw each player's box and ``Player ID`` label onto the frames.

        Note: cv2 drawing writes into the frame arrays, so the input frames are
        modified in place and the returned list holds those same frame objects.

        Args:
            video_frames: Frames to annotate.
            player_detections: One ``{track_id: [x1, y1, x2, y2]}`` dict per frame.

        Returns:
            list of the annotated frames.
        """
        color = (197, 197, 197)
        output_video_frames = []
        for frame, player_dict in zip(video_frames, player_detections):
            for track_id, bbox in player_dict.items():
                x1, y1, x2, y2 = bbox
                # Label sits 10 px above the box's top-left corner.
                cv2.putText(frame, f"Player ID: {track_id}", (int(x1), int(y1 - 10)),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2)
                cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), color, 2)
            output_video_frames.append(frame)

        return output_video_frames


class BallTracker:
    """Detect the tennis ball in video frames with a (fine-tuned) YOLO model.

    Only one ball is tracked, so detections always use the key ``1``.
    """

    def __init__(self,model_path):
        """Load the YOLO weights stored at ``model_path``."""
        self.model = YOLO(model_path)

    def interpolate_ball_positions(self, ball_positions):
        """Fill frames where the ball was missed by interpolating its box.

        Args:
            ball_positions: One dict per frame, ``{1: [x1, y1, x2, y2]}`` when the
                ball was detected and ``{}`` otherwise (the ``detect_frame`` format).

        Returns:
            A list of the same length. Each frame is ``{1: [x1, y1, x2, y2]}``:
            gaps are linearly interpolated, and frames before the first or after
            the last detection reuse the nearest detected box. If the ball was
            never detected there is nothing to interpolate from, and every frame
            is returned as ``{}``.
        """
        # An explicit NaN row keeps every row 4 columns wide. With empty lists,
        # DataFrame construction fails when no frame has a detection.
        nan_box = [float("nan")] * 4
        rows = [x.get(1, nan_box) for x in ball_positions]
        df_ball_positions = pd.DataFrame(rows, columns=['x1', 'y1', 'x2', 'y2'], dtype=float)

        if df_ball_positions.isna().all().all():  # no detections (or no frames)
            return [{} for _ in ball_positions]

        df_ball_positions = df_ball_positions.interpolate().bfill().ffill()

        return [{1: x} for x in df_ball_positions.to_numpy().tolist()]

    def detect_frame(self, frame, conf=0.07):
        """Detect the ball in one frame.

        Args:
            frame: Image passed to ``self.model.predict``.
            conf: Confidence threshold forwarded to the model (default 0.07, a
                low threshold; missed frames can later be filled by
                ``interpolate_ball_positions``).

        Returns:
            ``{1: [x1, y1, x2, y2]}`` for the highest-confidence detection, or
            ``{}`` when the model returns no boxes.
        """
        results = self.model.predict(frame, conf=conf)[0]

        # Pick the most confident box explicitly instead of relying on iteration
        # order. Keeping the last box picks the lowest-confidence one when boxes
        # are sorted by descending score.
        best_bbox, best_score = None, float("-inf")
        for box in results.boxes:
            score = box.conf.tolist()[0]
            if score > best_score:
                best_bbox, best_score = box.xyxy.tolist()[0], score

        return {} if best_bbox is None else {1: best_bbox}

    def draw_bboxes(self,video_frames, player_detections):
        """Draw the ball box and ``Ball ID`` label onto the frames.

        Note: cv2 drawing writes into the frame arrays, so the input frames are
        modified in place and the returned list holds those same frame objects.

        Args:
            video_frames: Frames to annotate.
            player_detections: One ball dict per frame (``{1: [x1, y1, x2, y2]}``
                or ``{}``). Despite its name it holds ball detections; the name
                is kept for backward compatibility.

        Returns:
            list of the annotated frames.
        """
        color = (96, 255, 168)
        output_video_frames = []
        for frame, ball_dict in zip(video_frames, player_detections):
            for track_id, bbox in ball_dict.items():
                x1, y1, x2, y2 = bbox
                # Label sits 10 px above the box's top-left corner.
                cv2.putText(frame, f"Ball ID: {track_id}", (int(x1), int(y1 - 10)),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2)
                cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), color, 2)
            output_video_frames.append(frame)

        return output_video_frames

In [84]:
def main():
    """Detect players and the ball in one clip and save an annotated copy.

    Both detectors run on the unmodified frames before anything is drawn. The
    ``draw_bboxes`` methods draw into the frame arrays in place, so drawing
    first would feed the ball model frames overlaid with player boxes and labels.

    Raises:
        IOError: If no frames can be read from the input clip.
    """
    input_video_path = f"{BASE_DIR}/input/5.mp4"
    output_video_path = f"{BASE_DIR}/output/5_o_trained.mp4"

    video_frames = read_video(input_video_path)
    if len(video_frames) == 0:
        raise IOError(f"No frames could be read from {input_video_path}")

    player_tracker = PlayerTracker(f"{BASE_DIR}/models/yolov8x.pt")
    ball_tracker = BallTracker(f"{BASE_DIR}/models/ball_training/run/weights/best.pt")

    # 1) Detection on clean frames.
    player_detections = [player_tracker.detect_frame(frame) for frame in video_frames]
    ball_detections = [ball_tracker.detect_frame(frame) for frame in video_frames]
    ball_detections = ball_tracker.interpolate_ball_positions(ball_detections)

    # 2) Annotation (mutates the frames in place).
    video_frames = player_tracker.draw_bboxes(video_frames, player_detections)
    video_frames = ball_tracker.draw_bboxes(video_frames, ball_detections)

    save_video(video_frames, output_video_path)

if __name__ == "__main__":
    main()


0: 384x640 12 persons, 1 tennis racket, 1 potted plant, 14.3ms
Speed: 2.5ms preprocess, 14.3ms inference, 1.9ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 12 persons, 1 tennis racket, 1 potted plant, 12.6ms
Speed: 2.4ms preprocess, 12.6ms inference, 1.7ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 12 persons, 1 tennis racket, 1 potted plant, 12.6ms
Speed: 2.5ms preprocess, 12.6ms inference, 1.9ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 12 persons, 1 tennis racket, 1 potted plant, 12.5ms
Speed: 2.4ms preprocess, 12.5ms inference, 1.7ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 12 persons, 1 tennis racket, 1 potted plant, 13.1ms
Speed: 2.5ms preprocess, 13.1ms inference, 1.8ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 12 persons, 1 tennis racket, 1 potted plant, 13.0ms
Speed: 2.5ms preprocess, 13.0ms inference, 1.8ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 12 persons, 1 potted p