# Tennis Ball Detector Training

This notebook trains a YOLO model to detect tennis balls in video frames. It loads the labelled dataset, preprocesses the images, trains the model, and saves the learned weights.

In [1]:
import os
import yaml
import cv2
import numpy as np
import torch
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from models.yolo import Model  # Assuming a YOLO model implementation is available

# Define paths. Expected dataset layout: <data_dir>/{train,valid}/{images,labels}/
data_dir = 'tennis-ball-detection-6/'
train_images_dir = os.path.join(data_dir, 'train/images/')
train_labels_dir = os.path.join(data_dir, 'train/labels/')
valid_images_dir = os.path.join(data_dir, 'valid/images/')
valid_labels_dir = os.path.join(data_dir, 'valid/labels/')

# Load the dataset configuration. The encoding is explicit so parsing does not
# depend on the platform's locale encoding (e.g. cp1252 on Windows).
with open(os.path.join(data_dir, 'data.yaml'), encoding='utf-8') as f:
    config = yaml.safe_load(f)

# Define a custom dataset class
class TennisBallDataset(Dataset):
    def __init__(self, images_dir, labels_dir, transform=None):
        self.images_dir = images_dir
        self.labels_dir = labels_dir
        self.transform = transform
        self.images = os.listdir(images_dir)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_name = os.path.join(self.images_dir, self.images[idx])
        image = cv2.imread(img_name)
        label_name = os.path.join(self.labels_dir, self.images[idx].replace('.jpg', '.txt'))
        with open(label_name, 'r') as f:
            labels = f.readlines()
        
        if self.transform:
            image = self.transform(image)
        return image, labels

# Preprocessing pipeline: ToTensor turns an H x W x C uint8 array into a
# C x H x W float tensor scaled to [0, 1]; Resize then fixes the spatial size
# at 640 x 640 for every sample.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize(size=(640, 640)),
    ]
)

def collate_detection_batch(batch):
    """Collate ``(image, labels)`` samples into one batch.

    Images are stacked into a single tensor. This assumes every image tensor
    has the same shape, which holds here because ``transform`` resizes to
    640x640. Labels stay a per-image list because each image can contain a
    different number of objects. PyTorch's default collate requires all lists
    in a batch to have equal length and raises ``RuntimeError`` otherwise.

    Returns:
        ``(images, labels)``: ``images`` is the stacked tensor, with a new
        leading batch dimension, and ``labels`` is a list holding each
        sample's labels in batch order.
    """
    images, labels = zip(*batch)
    return torch.stack(images), list(labels)


# Create datasets and dataloaders
train_dataset = TennisBallDataset(train_images_dir, train_labels_dir, transform)
train_loader = DataLoader(
    train_dataset, batch_size=16, shuffle=True, collate_fn=collate_detection_batch
)

valid_dataset = TennisBallDataset(valid_images_dir, valid_labels_dir, transform)
valid_loader = DataLoader(
    valid_dataset, batch_size=16, shuffle=False, collate_fn=collate_detection_batch
)

# Build the YOLO model from the 'model' entry of the loaded data.yaml config.
# NOTE(review): indexing config['model'] raises KeyError if data.yaml has no
# 'model' key -- confirm the file actually defines it.
model = Model(config['model'])

# Adam over all model parameters, learning rate 1e-3.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

# NOTE(review): CrossEntropyLoss expects tensor targets (class indices or
# probabilities), but TennisBallDataset yields raw label-file text lines.
# A detection loss that matches the YOLO implementation is presumably needed
# here -- confirm before relying on this training run.
criterion = torch.nn.CrossEntropyLoss()

# Training loop: one optimizer step per batch. The value reported each epoch is
# the mean loss over all batches, not just the last batch's loss.
# NOTE(review): valid_loader is built but never evaluated here.
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    num_batches = 0
    for images, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # .item() extracts a Python float, so no autograd graph is retained.
        running_loss += loss.item()
        num_batches += 1
    if num_batches == 0:
        # Fail loudly instead of "training" on an empty dataset.
        raise RuntimeError('Training loader yielded no batches; check train_images_dir.')
    epoch_loss = running_loss / num_batches
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}')

# Persist only the learned parameters (the state_dict), not the module object.
# Load them back with model.load_state_dict(torch.load(path)).
torch.save(obj=model.state_dict(), f='tennis_ball_detector.pth')
print('Model saved successfully!')