In [1]:
import os
from PIL import Image
import torch
from torch.utils.data import DataLoader
from torchvision.transforms import functional as F
from pycocotools.coco import COCO

class CocoDataset(torch.utils.data.Dataset):
    def __init__(self, images_dir, annotations_file, transforms=None):
        self.images_dir = images_dir
        self.coco = COCO(annotations_file)
        self.img_ids = self.coco.getImgIds()
        self.transforms = transforms

    def __getitem__(self, idx):
        img_id = self.img_ids[idx]
        img_info = self.coco.loadImgs(img_id)[0]
        img_path = os.path.join(self.images_dir, img_info['file_name'])
        img = Image.open(img_path).convert("RGB")
        img_tensor = F.to_tensor(img)

        ann_ids = self.coco.getAnnIds(imgIds=img_id)
        anns = self.coco.loadAnns(ann_ids)

        boxes = []
        labels = []
        for ann in anns:
            x, y, w, h = ann['bbox']
            boxes.append([x, y, x+w, y+h])
            labels.append(ann['category_id'])

        target = {
            "boxes": torch.tensor(boxes, dtype=torch.float32),
            "labels": torch.tensor(labels, dtype=torch.int64)
        }

        if self.transforms:
            img_tensor = self.transforms(img_tensor)

        return img_tensor, target

    def __len__(self):
        return len(self.img_ids)


In [2]:
TRAIN_IMAGES = "handball-detection-8/train"
TRAIN_ANNOTATIONS = "handball-detection-8/train/_annotations.coco.json"
VAL_IMAGES = "handball-detection-8/valid"
VAL_ANNOTATIONS = "handball-detection-8/valid/_annotations.coco.json"

train_dataset = CocoDataset(TRAIN_IMAGES, TRAIN_ANNOTATIONS)
val_dataset = CocoDataset(VAL_IMAGES, VAL_ANNOTATIONS)

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, collate_fn=lambda x: tuple(zip(*x)))
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, collate_fn=lambda x: tuple(zip(*x)))


loading annotations into memory...
Done (t=0.02s)
creating index...
index created!
loading annotations into memory...
Done (t=0.01s)
creating index...
index created!


In [5]:
from torchvision.models.detection import ssd300_vgg16

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
NUM_CLASSES = 3  # background + goalpost + handball

# Create model with correct number of classes
model = ssd300_vgg16(weights=None, num_classes=NUM_CLASSES)
model.to(DEVICE)


optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)


In [6]:
from copy import deepcopy

NUM_EPOCHS = 200
PATIENCE = 10

best_model_wts = deepcopy(model.state_dict())
best_val_loss = float('inf')
patience_counter = 0

# ----------------------------
# Training loop
# ----------------------------
for epoch in range(NUM_EPOCHS):
    model.train()
    running_loss = 0.0

    # --- Training ---
    for images, targets in train_loader:
        images = [img.to(DEVICE) for img in images]
        targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]

        optimizer.zero_grad()
        loss_dict = model(images, targets)  # returns dict of losses
        loss = sum(loss for loss in loss_dict.values())
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    epoch_loss = running_loss / len(train_loader)
    print(f"Epoch [{epoch+1}/{NUM_EPOCHS}] Training Loss: {epoch_loss:.4f}")

    # --- Validation ---
    model.train()  # keep train mode to get loss dict
    val_loss = 0.0
    with torch.no_grad():
        for images, targets in val_loader:
            images = [img.to(DEVICE) for img in images]
            targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]
            loss_dict = model(images, targets)
            val_loss += sum(loss for loss in loss_dict.values()).item()

    val_loss /= len(val_loader)
    print(f"Epoch [{epoch+1}/{NUM_EPOCHS}] Validation Loss: {val_loss:.4f}")

    # --- Early Stopping ---
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_model_wts = deepcopy(model.state_dict())
        patience_counter = 0
    else:
        patience_counter += 1
        if patience_counter >= PATIENCE:
            print(f"Early stopping triggered at epoch {epoch+1}")
            break

# ----------------------------
# Save best model
# ----------------------------
model.load_state_dict(best_model_wts)
torch.save(model.state_dict(), "ssd_trained_handball.pth")
print("Training complete, model saved as 'ssd_trained_handball.pth'")

Epoch [1/200] Training Loss: 7.4078
Epoch [1/200] Validation Loss: 4.6794
Epoch [2/200] Training Loss: 131.8626
Epoch [2/200] Validation Loss: 6.5755
Epoch [3/200] Training Loss: 7.1831
Epoch [3/200] Validation Loss: 4.8387
Epoch [4/200] Training Loss: 4.9009
Epoch [4/200] Validation Loss: 4.5526
Epoch [5/200] Training Loss: 4.5748
Epoch [5/200] Validation Loss: 4.3842
Epoch [6/200] Training Loss: 4.4597
Epoch [6/200] Validation Loss: 4.3756
Epoch [7/200] Training Loss: 4.4223
Epoch [7/200] Validation Loss: 4.3227
Epoch [8/200] Training Loss: 4.3503
Epoch [8/200] Validation Loss: 4.3109
Epoch [9/200] Training Loss: 4.2931
Epoch [9/200] Validation Loss: 4.3573
Epoch [10/200] Training Loss: 4.2354
Epoch [10/200] Validation Loss: 4.2101
Epoch [11/200] Training Loss: 4.1642
Epoch [11/200] Validation Loss: 4.1535
Epoch [12/200] Training Loss: 4.0622
Epoch [12/200] Validation Loss: 4.1045
Epoch [13/200] Training Loss: 3.9755
Epoch [13/200] Validation Loss: 4.0024
Epoch [14/200] Training Loss