In [None]:
import os
import torch
import torchvision.transforms as T
from PIL import Image
from torchvision.ops import box_convert

class SoccerNetDataset(torch.utils.data.Dataset):
    """Detection dataset built from per-sequence image folders and ground-truth files.

    Every sub-directory ``<root_dir>/<seq>`` that contains both an ``img1``
    image folder and a ``gt/gt.txt`` annotation file contributes one sample per
    annotated frame. Frames without any usable box are left out.

    Each sample is ``(image, target)`` where ``target`` is a dict with:
      * ``"boxes"``  - float32 tensor of ``[x1, y1, x2, y2]`` rows
      * ``"labels"`` - int64 tensor of ones (label 1 = player)

    Args:
        root_dir: Directory holding one sub-directory per sequence.
        transform: Callable applied to the RGB PIL image in ``__getitem__``.
            Defaults to ``torchvision.transforms.ToTensor()``.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform or T.ToTensor()
        self.image_paths = []
        self.targets = []

        for seq_id in sorted(os.listdir(root_dir)):
            img_dir = os.path.join(root_dir, seq_id, "img1")
            gt_path = os.path.join(root_dir, seq_id, "gt", "gt.txt")
            # A sequence is only usable when both annotations and images exist.
            if not os.path.exists(gt_path) or not os.path.isdir(img_dir):
                continue

            gt_map = self._read_gt_boxes(gt_path)

            for img_file in sorted(os.listdir(img_dir)):
                if not img_file.lower().endswith((".jpg", ".png")):
                    continue
                try:
                    frame_id = int(img_file.split('.')[0])
                except ValueError:
                    # File stem is not a frame number, so it cannot match any annotation.
                    continue

                boxes = gt_map.get(frame_id)
                if not boxes:
                    continue

                # Convert [x, y, w, h] → [x1, y1, x2, y2]
                boxes_xyxy = box_convert(torch.tensor(boxes, dtype=torch.float32), in_fmt='xywh', out_fmt='xyxy')
                labels = torch.ones((len(boxes),), dtype=torch.int64)  # label 1 = player
                self.image_paths.append(os.path.join(img_dir, img_file))
                self.targets.append({"boxes": boxes_xyxy, "labels": labels})

    @staticmethod
    def _read_gt_boxes(gt_path):
        """Parse a ground-truth file into ``{frame: [[x, y, w, h], ...]}``.

        Each non-blank line starts with ``frame, id, x, y, w, h``; any further
        columns are ignored. Blank lines are skipped, and boxes whose width or
        height is not strictly positive are dropped because they would become
        degenerate ``xyxy`` boxes.

        Raises:
            ValueError: If a non-blank line has fewer than six fields or a
                field that cannot be parsed as a number.
        """
        gt_map = {}
        with open(gt_path, 'r') as f:
            for line_no, raw in enumerate(f, start=1):
                line = raw.strip()
                if not line:
                    continue
                fields = line.split(',')
                if len(fields) < 6:
                    raise ValueError(
                        f"{gt_path}:{line_no}: expected at least 6 comma-separated fields, got {len(fields)}"
                    )
                # Go through float() so values written as "12.0" are accepted too.
                frame = int(float(fields[0]))
                x, y, w, h = (float(v) for v in fields[2:6])
                # Written as a positive test so NaN sizes are rejected as well.
                if not (w > 0 and h > 0):
                    continue
                gt_map.setdefault(frame, []).append([x, y, w, h])
        return gt_map

    def __len__(self):
        """Return the number of annotated frames collected at construction."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB, apply the transform, and return it with its target."""
        # convert() produces an independent, fully loaded image, so the
        # underlying file can be closed as soon as the `with` block exits.
        with Image.open(self.image_paths[idx]) as img:
            img = img.convert("RGB")
        return self.transform(img), self.targets[idx]


In [None]:
from torchvision.models.detection import fasterrcnn_resnet50_fpn
import torchvision
import torch
import torchvision.transforms as T

from tqdm import tqdm

# Train on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load pretrained model
# NOTE(review): recent torchvision releases deprecate `pretrained=` in favour of
# `weights=`; kept as-is because the installed version is not visible here.
model = fasterrcnn_resnet50_fpn(pretrained=True)

# Replace head (optional, but useful for fine-tuning)
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(in_features, num_classes=2)

# Move the model only AFTER the head swap: the new predictor is created on the
# CPU, so moving earlier would leave it on a different device than the rest of
# the network and the first forward pass would fail.
model.to(device)
model.train()

# Dataset + DataLoader
train_dataset = SoccerNetDataset("../soccernet_data/tracking/train")
# Images can differ in size, so a batch is kept as a tuple of images and a
# tuple of targets instead of being stacked into one tensor.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=lambda x: tuple(zip(*x)))

# Optimizer
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)

# Training loop
num_epochs = 5  # fine-tune for 5 epochs
epoch_losses = []  # summed batch loss of each epoch, kept for later plotting
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    for images, targets in tqdm(train_loader):
        images = [img.to(device) for img in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        # In train mode the model returns a dict of loss terms; optimise their sum.
        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()
        total_loss += losses.item()

    epoch_losses.append(total_loss)
    print(f"Epoch {epoch+1} loss: {total_loss:.4f}")


In [None]:
# Save
version = "v1.0"
# Single source of truth for the checkpoint file name, so the load example
# below always points at the file that was actually written.
checkpoint_path = f"fasterrcnn_finetuned_{version}.pth"

torch.save(model.state_dict(), checkpoint_path)

# Load (the model must already have the same architecture as the one saved)
# model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))


In [None]:
import matplotlib.pyplot as plt

def plot_loss(epoch_losses, num_epochs=None):
    """Plot one training-loss value per epoch as a line chart.

    Args:
        epoch_losses: Iterable of loss values, one per epoch, in epoch order.
        num_epochs: Number of epochs to label on the x axis. Optional; defaults
            to the number of values in ``epoch_losses``.

    Raises:
        ValueError: If ``num_epochs`` is given and differs from the number of
            loss values (the x and y series could not be plotted together).
    """
    losses = list(epoch_losses)
    if num_epochs is None:
        num_epochs = len(losses)
    elif num_epochs != len(losses):
        raise ValueError(
            f"num_epochs ({num_epochs}) does not match the number of loss values ({len(losses)})"
        )

    epochs = range(1, num_epochs + 1)  # epochs are numbered from 1 on the x axis

    # Plot loss
    fig, ax = plt.subplots(figsize=(8, 5))
    ax.plot(epochs, losses, marker='o', color='blue')
    ax.set_title("Training Loss per Epoch")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.grid(True)
    ax.set_xticks(epochs)
    fig.tight_layout()
    plt.show()
