# VGGT-based Player Localization Training

This notebook trains a DETR-style decoder on top of a frozen VGGT encoder for player localization in soccer images.

**Pipeline:** 1080p Image → VGGT Encoder (frozen) → DETR Decoder (trainable) → Player Positions

## 1. Setup Environment

In [None]:
# Check GPU: print the NVIDIA driver's view of the accelerator attached to this runtime
!nvidia-smi

In [None]:
# Install the third-party dependencies (-q keeps pip's output short)
!pip install -q einops safetensors huggingface_hub scipy xtcocotools

In [None]:
# Clone repositories
import os

# Clone sskit (your fork) into ./sskit; skipped when that path already exists,
# so re-running the cell does not fail on an existing checkout
if not os.path.exists('sskit'):
    !git clone https://github.com/hiteacherIamhumble/soccernet.git sskit

# Clone VGGT into ./vggt; skipped when that path already exists
if not os.path.exists('vggt'):
    !git clone https://github.com/facebookresearch/vggt.git

In [None]:
# Install both checkouts in development (editable) mode, so the packages are
# imported straight from the cloned source trees
!pip install -q -e sskit/
!pip install -q -e vggt/

In [None]:
# Add paths
import sys
# Put both checkouts at the front of the import path. Each insert goes to
# index 0, so the last repository listed ends up first on sys.path.
for repo_dir in ('sskit', 'vggt'):
    sys.path.insert(0, repo_dir)

## 2. Download Data

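Upload your data to Google Drive or download it from the source. The cells below expect `sskit/data/` to contain the image folders `train/` and `val/` plus an `annotations/` folder holding `train.json` and `val.json`.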

In [None]:
# Mount Google Drive (optional - if data is stored there).
# google.colab only exists inside a Colab runtime, so the import is guarded:
# on any other Jupyter host the cell reports that it was skipped instead of
# aborting a "Run all".
try:
    from google.colab import drive
except ImportError:
    print("google.colab is not available; skipping Google Drive mount.")
else:
    drive.mount('/content/drive')

In [None]:
# Option 1: Copy data from Google Drive (requires the Drive mount; uncomment to use)
# !cp -r /content/drive/MyDrive/soccernet_data/* sskit/data/

# Option 2: Download and extract data (update URLs as needed; uncomment to use)
# !cd sskit/data && wget <train.zip_url> && unzip -q train.zip
# !cd sskit/data && wget <val.zip_url> && unzip -q val.zip
# !cd sskit/data && wget <annotations.zip_url> && unzip -q annotations.zip

# Verify data structure: the listing should show train/, val/ and annotations/,
# which are the paths the dataset cell reads from
!ls -la sskit/data/

## 3. Configuration

In [None]:
from dataclasses import dataclass
from typing import Optional, Tuple

@dataclass
class TrainingConfig:
    """Hyperparameters and paths for one training run.

    Every field has a default, so ``TrainingConfig()`` gives the baseline
    setup; individual values can be overridden by keyword.
    """

    # Data
    # Directory that holds train/, val/ and annotations/{train,val}.json.
    data_root: str = 'sskit/data'
    # Passed to both datasets as target_size; 1078 = 77 * 14 and 1918 = 137 * 14.
    target_size: Tuple[int, int] = (1078, 1918)  # (H, W) divisible by 14
    
    # Model (all forwarded to DETRPlayerDecoder)
    hidden_dim: int = 256
    num_queries: int = 30
    num_decoder_layers: int = 6
    num_heads: int = 8
    dropout: float = 0.1
    
    # Training
    # The training loop runs epochs start_epoch .. epochs - 1.
    epochs: int = 100
    # Used for both the train and the validation DataLoader.
    batch_size: int = 4
    # Peak AdamW learning rate; the cosine phase decays to 1% of this value.
    lr: float = 1e-4
    weight_decay: float = 1e-4
    # Number of scheduler steps (one per epoch) spent in the linear warm-up.
    warmup_epochs: int = 5
    # Max gradient norm for clip_grad_norm_; clipping is skipped when <= 0.
    grad_clip: float = 0.1
    
    # Loss
    # The position/confidence weights are used twice: as matching costs in
    # HungarianMatcher and as loss weights in SetCriterion.
    weight_position: float = 5.0
    weight_confidence: float = 1.0
    weight_no_object: float = 0.1
    
    # Other
    # Worker processes per DataLoader.
    num_workers: int = 2
    # Directory for checkpoints, config.json, history.json and the loss plot.
    save_dir: str = 'checkpoints'
    # An extra epoch_<n>.pt is written every save_freq epochs.
    save_freq: int = 10
    # The training progress bar is refreshed every log_freq batches.
    log_freq: int = 50
    # Identifier handed to VGGT.from_pretrained.
    vggt_model: str = 'facebook/VGGT-1B'
    # Path of a checkpoint to resume from; None (or empty) starts from scratch.
    resume: Optional[str] = None

# Single shared configuration object read by all following cells
config = TrainingConfig()
print(f"Training config: {config}")

## 4. Load Models

In [None]:
import torch
import torch.nn as nn
from pathlib import Path

# Pick the compute device: CUDA when a GPU is visible, otherwise the CPU
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)
print(f"Using device: {device}")

In [None]:
# Load VGGT encoder (frozen)
from vggt.models.vggt import VGGT

print(f"Loading VGGT encoder from {config.vggt_model}...")

# Fetch the pretrained model, move it to the target device and put it in
# inference mode.
encoder = VGGT.from_pretrained(config.vggt_model).to(device)
encoder.eval()

# The encoder is never updated: disable gradients for every parameter at once
encoder.requires_grad_(False)

print("VGGT encoder loaded and frozen.")

In [None]:
# Create decoder
from sskit.models import DETRPlayerDecoder, HungarianMatcher, SetCriterion

# Collect the constructor arguments first so the architecture is readable
# at a glance.
decoder_kwargs = dict(
    dim_in=2048,  # presumably the channel width of the VGGT aggregator tokens - confirm
    hidden_dim=config.hidden_dim,
    num_queries=config.num_queries,
    num_decoder_layers=config.num_decoder_layers,
    num_heads=config.num_heads,
    dropout=config.dropout,
)
decoder = DETRPlayerDecoder(**decoder_kwargs).to(device)

# Report the size of the decoder (counts every parameter of the module)
num_params = sum(param.numel() for param in decoder.parameters())
print(f"Decoder parameters: {num_params:,}")

In [None]:
# Create loss functions
# The same two weights drive both the matching cost and the training loss
position_weight = config.weight_position
confidence_weight = config.weight_confidence

matcher = HungarianMatcher(
    cost_position=position_weight,
    cost_confidence=confidence_weight,
)

criterion = SetCriterion(
    matcher=matcher,
    weight_position=position_weight,
    weight_confidence=confidence_weight,
    weight_no_object=config.weight_no_object,
)

print("Loss functions created.")

## 5. Create Dataloaders

In [None]:
from sskit.data import SynLocDataset
from sskit.data.dataset import collate_fn

data_root = Path(config.data_root)


def _build_split(split):
    """Build the SynLocDataset for one split ('train' or 'val') under data_root."""
    return SynLocDataset(
        root_dir=str(data_root / split),
        coco_json=str(data_root / 'annotations' / f'{split}.json'),
        target_size=config.target_size,
    )


train_dataset = _build_split('train')
val_dataset = _build_split('val')

print(f"Train dataset: {len(train_dataset)} images")
print(f"Val dataset: {len(val_dataset)} images")

In [None]:
# Settings shared by the train and validation loaders
loader_kwargs = dict(
    batch_size=config.batch_size,
    num_workers=config.num_workers,
    pin_memory=True,
    collate_fn=collate_fn,
)

# Training: reshuffled every epoch, incomplete final batch dropped
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    shuffle=True,
    drop_last=True,
    **loader_kwargs,
)

# Validation: fixed order, every sample is evaluated
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    shuffle=False,
    **loader_kwargs,
)

print(f"Train batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")

## 6. Setup Optimizer and Scheduler

In [None]:
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR

# Only the decoder is optimised; the encoder's parameters are never passed in.
optimizer = AdamW(
    decoder.parameters(),
    lr=config.lr,
    weight_decay=config.weight_decay,
)

# Phase 1: linear warm-up from 1% to 100% of config.lr over warmup_epochs
# scheduler steps (the training loop steps the scheduler once per epoch).
warmup_scheduler = LinearLR(
    optimizer,
    start_factor=0.01,
    end_factor=1.0,
    total_iters=config.warmup_epochs,
)

# Phase 2: cosine decay over the remaining epochs, down to 1% of config.lr.
# NOTE(review): T_max is epochs - warmup_epochs, so warmup_epochs must stay
# below epochs for this schedule to be well defined.
main_scheduler = CosineAnnealingLR(
    optimizer,
    T_max=config.epochs - config.warmup_epochs,
    eta_min=config.lr * 0.01,
)

# Chain both phases; the hand-over to the cosine schedule happens at the
# warmup_epochs milestone.
scheduler = SequentialLR(
    optimizer,
    schedulers=[warmup_scheduler, main_scheduler],
    milestones=[config.warmup_epochs],
)

print("Optimizer and scheduler created.")

In [None]:
# Resume from checkpoint (optional). Without config.resume, training starts
# at epoch 0 with no best validation loss recorded yet.
start_epoch, best_val_loss = 0, float('inf')

if config.resume:
    checkpoint = torch.load(config.resume, map_location=device)

    # Restore the trainable state: decoder weights first, then the optimizer
    decoder.load_state_dict(checkpoint['decoder_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    # The scheduler entry is None when the checkpoint was saved without one
    scheduler_state = checkpoint['scheduler_state_dict']
    if scheduler_state:
        scheduler.load_state_dict(scheduler_state)

    # The checkpoint stores the last finished epoch, so continue with the next
    start_epoch = checkpoint['epoch'] + 1
    best_val_loss = checkpoint.get('best_val_loss', float('inf'))
    print(f"Resumed from epoch {start_epoch}")

## 7. Training Functions

In [None]:
from tqdm.notebook import tqdm
import json

def train_one_epoch(encoder, decoder, criterion, dataloader, optimizer, device, epoch, config):
    """Train the decoder for one epoch on features from the frozen encoder.

    Args:
        encoder: Model exposing ``aggregator(images)``; it is only run under
            ``torch.no_grad()`` and is never updated here.
        decoder: Trainable module called as
            ``decoder(aggregated_tokens_list, patch_start_idx)``.
        criterion: Callable returning a dict with the tensor entries
            ``'loss'``, ``'loss_position'`` and ``'loss_confidence'``.
        dataloader: Iterable of ``(images, targets)`` batches, where targets
            is a list of dicts; it must support ``len()``.
        optimizer: Optimizer holding the decoder parameters.
        device: Device the batch is moved to.
        epoch: Epoch index, used only for the progress-bar label.
        config: Object providing ``grad_clip`` and ``log_freq``.

    Returns:
        Dict with the per-batch mean of ``'loss'``, ``'loss_position'`` and
        ``'loss_confidence'`` over the epoch.

    Raises:
        ValueError: If the dataloader yields no batches (for example a
            dataset smaller than the batch size combined with
            ``drop_last=True``); previously this surfaced as an unexplained
            ZeroDivisionError.
    """
    decoder.train()

    total_loss = 0.0
    total_loss_pos = 0.0
    total_loss_conf = 0.0
    num_batches = 0

    pbar = tqdm(dataloader, desc=f'Epoch {epoch}', leave=True)

    for i, (images, targets) in enumerate(pbar):
        # Move the batch to the device; non-tensor target entries stay as-is
        images = images.to(device)
        targets = [
            {k: v.to(device) if torch.is_tensor(v) else v for k, v in t.items()}
            for t in targets
        ]

        # Frozen encoder: no autograd graph is needed for its activations
        with torch.no_grad():
            aggregated_tokens_list, patch_start_idx = encoder.aggregator(images)

        # Trainable decoder and loss
        outputs = decoder(aggregated_tokens_list, patch_start_idx)
        loss_dict = criterion(outputs, targets)
        loss = loss_dict['loss']

        # Backward pass with optional gradient-norm clipping
        optimizer.zero_grad()
        loss.backward()
        if config.grad_clip > 0:
            torch.nn.utils.clip_grad_norm_(decoder.parameters(), config.grad_clip)
        optimizer.step()

        # Running sums for the epoch averages
        total_loss += loss.item()
        total_loss_pos += loss_dict['loss_position'].item()
        total_loss_conf += loss_dict['loss_confidence'].item()
        num_batches += 1

        # Refresh the progress bar every log_freq batches and on the last one
        if (i + 1) % config.log_freq == 0 or i == len(dataloader) - 1:
            pbar.set_postfix({
                'loss': f'{total_loss / num_batches:.4f}',
                'pos': f'{total_loss_pos / num_batches:.4f}',
                'conf': f'{total_loss_conf / num_batches:.4f}',
            })

    if num_batches == 0:
        raise ValueError(
            "train_one_epoch: the dataloader yielded no batches; check that the "
            "dataset is not empty and is at least batch_size large when "
            "drop_last=True"
        )

    return {
        'loss': total_loss / num_batches,
        'loss_position': total_loss_pos / num_batches,
        'loss_confidence': total_loss_conf / num_batches,
    }

In [None]:
@torch.no_grad()
def validate(encoder, decoder, criterion, dataloader, device):
    """Evaluate the decoder on ``dataloader`` without computing gradients.

    The decoder is switched to eval mode and left in that mode on return.

    Args:
        encoder: Model exposing ``aggregator(images)``.
        decoder: Module called as
            ``decoder(aggregated_tokens_list, patch_start_idx)``.
        criterion: Callable returning a dict with the tensor entries
            ``'loss'``, ``'loss_position'`` and ``'loss_confidence'``.
        dataloader: Iterable of ``(images, targets)`` batches, where targets
            is a list of dicts.
        device: Device the batch is moved to.

    Returns:
        Dict with the per-batch mean of ``'loss'``, ``'loss_position'`` and
        ``'loss_confidence'``.

    Raises:
        ValueError: If the dataloader yields no batches; previously this
            surfaced as an unexplained ZeroDivisionError.
    """
    decoder.eval()

    total_loss = 0.0
    total_loss_pos = 0.0
    total_loss_conf = 0.0
    num_batches = 0

    pbar = tqdm(dataloader, desc='Validation', leave=True)

    for images, targets in pbar:
        # Move the batch to the device; non-tensor target entries stay as-is
        images = images.to(device)
        targets = [
            {k: v.to(device) if torch.is_tensor(v) else v for k, v in t.items()}
            for t in targets
        ]

        # Forward pass and loss
        aggregated_tokens_list, patch_start_idx = encoder.aggregator(images)
        outputs = decoder(aggregated_tokens_list, patch_start_idx)
        loss_dict = criterion(outputs, targets)

        total_loss += loss_dict['loss'].item()
        total_loss_pos += loss_dict['loss_position'].item()
        total_loss_conf += loss_dict['loss_confidence'].item()
        num_batches += 1

        # Show the running mean of the total loss
        pbar.set_postfix({
            'loss': f'{total_loss / num_batches:.4f}',
        })

    if num_batches == 0:
        raise ValueError(
            "validate: the dataloader yielded no batches; check that the "
            "validation dataset is not empty"
        )

    return {
        'loss': total_loss / num_batches,
        'loss_position': total_loss_pos / num_batches,
        'loss_confidence': total_loss_conf / num_batches,
    }

In [None]:
def _atomic_torch_save(obj, path):
    """Write ``obj`` to ``path`` via a temp file so ``path`` is never left half-written."""
    import os

    tmp_path = path.with_name(path.name + '.tmp')
    try:
        torch.save(obj, tmp_path)
        # Same-directory rename: the old file stays intact until the new one
        # is completely on disk.
        os.replace(tmp_path, path)
    finally:
        # Only still present when torch.save or the rename failed
        if tmp_path.exists():
            tmp_path.unlink()


def save_checkpoint(decoder, optimizer, scheduler, epoch, config, best_val_loss, is_best=False):
    """Save the training state to ``config.save_dir``.

    ``latest.pt`` is always written, ``epoch_<n>.pt`` every
    ``config.save_freq`` epochs and ``best.pt`` when ``is_best`` is true.
    Each file is written atomically, so an interrupted session (e.g. a
    runtime disconnect mid-write) cannot corrupt an existing checkpoint.

    Args:
        decoder: Module whose ``state_dict()`` is stored.
        optimizer: Optimizer whose ``state_dict()`` is stored.
        scheduler: LR scheduler, or a falsy value to store ``None`` instead.
        epoch: Zero-based index of the epoch that just finished.
        config: Object providing ``save_dir`` and ``save_freq``; its
            attributes are stored in the checkpoint under ``'config'``.
        best_val_loss: Best validation loss so far, stored for resuming.
        is_best: Whether this epoch achieved the best validation loss.
    """
    save_dir = Path(config.save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    checkpoint = {
        'epoch': epoch,
        'decoder_state_dict': decoder.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
        'best_val_loss': best_val_loss,
        'config': vars(config) if hasattr(config, '__dict__') else config,
    }

    # Save latest
    _atomic_torch_save(checkpoint, save_dir / 'latest.pt')

    # Save periodic
    if (epoch + 1) % config.save_freq == 0:
        _atomic_torch_save(checkpoint, save_dir / f'epoch_{epoch+1}.pt')

    # Save best
    if is_best:
        _atomic_torch_save(checkpoint, save_dir / 'best.pt')

    print(f"Checkpoint saved: epoch {epoch + 1}" + (" (best)" if is_best else ""))

## 8. Training Loop

In [None]:
# Make sure the output directory exists before anything is written to it
save_dir = Path(config.save_dir)
save_dir.mkdir(parents=True, exist_ok=True)

# Store the run configuration next to the checkpoints
config_dict = vars(config) if hasattr(config, '__dict__') else config
with open(save_dir / 'config.json', 'w') as f:
    json.dump(config_dict, f, indent=2)

print(f"Saving checkpoints to: {save_dir}")

In [None]:
# Training history: one independent list per metric, appended to once per epoch
history = {
    metric: []
    for metric in (
        'train_loss',
        'train_loss_pos',
        'train_loss_conf',
        'val_loss',
        'val_loss_pos',
        'val_loss_conf',
        'lr',
    )
}

In [None]:
print(f"Starting training from epoch {start_epoch}...")
print(f"Total epochs: {config.epochs}")
print(f"Batch size: {config.batch_size}")
print(f"Learning rate: {config.lr}")
print("-" * 50)

for epoch in range(start_epoch, config.epochs):
    # Learning rate actually applied during this epoch. It must be read
    # before scheduler.step(): afterwards the optimizer already carries the
    # next epoch's rate, which would shift the log and history['lr'] by one
    # epoch and drop the first warm-up value.
    current_lr = optimizer.param_groups[0]['lr']

    # Train
    train_metrics = train_one_epoch(
        encoder, decoder, criterion, train_loader, optimizer, device, epoch, config
    )

    # Validate
    val_metrics = validate(encoder, decoder, criterion, val_loader, device)

    # Advance the schedule once per epoch, after the optimizer steps
    scheduler.step()

    # Log
    print(
        f'Epoch {epoch + 1}/{config.epochs} | '
        f'Train: loss={train_metrics["loss"]:.4f}, pos={train_metrics["loss_position"]:.4f}, conf={train_metrics["loss_confidence"]:.4f} | '
        f'Val: loss={val_metrics["loss"]:.4f} | '
        f'LR: {current_lr:.6f}'
    )

    # Update history
    history['train_loss'].append(train_metrics['loss'])
    history['train_loss_pos'].append(train_metrics['loss_position'])
    history['train_loss_conf'].append(train_metrics['loss_confidence'])
    history['val_loss'].append(val_metrics['loss'])
    history['val_loss_pos'].append(val_metrics['loss_position'])
    history['val_loss_conf'].append(val_metrics['loss_confidence'])
    history['lr'].append(current_lr)

    # Save checkpoint; "best" means the lowest validation loss seen so far
    is_best = val_metrics['loss'] < best_val_loss
    if is_best:
        best_val_loss = val_metrics['loss']

    save_checkpoint(decoder, optimizer, scheduler, epoch, config, best_val_loss, is_best=is_best)

print("\nTraining complete!")
print(f"Best validation loss: {best_val_loss:.4f}")

## 9. Plot Training History

In [None]:
import matplotlib.pyplot as plt

fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# The three loss panels share one layout and differ only in title and in
# the history keys they read.
loss_panels = [
    (axes[0, 0], 'Total Loss', 'train_loss', 'val_loss'),
    (axes[0, 1], 'Position Loss', 'train_loss_pos', 'val_loss_pos'),
    (axes[1, 0], 'Confidence Loss', 'train_loss_conf', 'val_loss_conf'),
]
for ax, title, train_key, val_key in loss_panels:
    ax.plot(history[train_key], label='Train')
    ax.plot(history[val_key], label='Val')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.set_title(title)
    ax.legend()
    ax.grid(True)

# Learning rate: a single curve, so no legend
ax = axes[1, 1]
ax.plot(history['lr'])
ax.set_xlabel('Epoch')
ax.set_ylabel('Learning Rate')
ax.set_title('Learning Rate Schedule')
ax.grid(True)

# Write the figure next to the checkpoints, then display it inline
plt.tight_layout()
plt.savefig(save_dir / 'training_history.png', dpi=150)
plt.show()

## 10. Save History and Download Checkpoints

In [None]:
# Save history as JSON next to the checkpoints
history_path = save_dir / 'history.json'
with history_path.open('w') as history_file:
    json.dump(history, history_file, indent=2)

print(f"Training history saved to {save_dir / 'history.json'}")

In [None]:
# List saved files with human-readable sizes ({config.save_dir} is filled in by IPython)
!ls -lh {config.save_dir}/

In [None]:
# Option 1: Copy to Google Drive
# !cp -r {config.save_dir} /content/drive/MyDrive/soccernet_checkpoints/

# Option 2: Download directly.
# google.colab only exists inside a Colab runtime, so the import is guarded:
# elsewhere the checkpoint is already on the local disk and the cell just
# says where to find it instead of failing.
try:
    from google.colab import files
except ImportError:
    print(f"google.colab is not available; checkpoints remain in {config.save_dir}/")
else:
    files.download(f'{config.save_dir}/best.pt')

## 11. Quick Inference Test

In [None]:
# Load best model: restore the decoder weights with the lowest validation loss
best_checkpoint_path = f'{config.save_dir}/best.pt'
checkpoint = torch.load(best_checkpoint_path, map_location=device)

decoder.load_state_dict(checkpoint['decoder_state_dict'])
decoder.eval()  # inference mode for the cells below

# Checkpoints store the zero-based epoch index; report it one-based
print(f"Loaded best model from epoch {checkpoint['epoch'] + 1}")

In [None]:
# Run inference on a sample
@torch.no_grad()
def inference_sample(encoder, decoder, dataset, idx, conf_threshold=0.5, device=None):
    """Run the encoder/decoder on one dataset sample and keep confident predictions.

    The decoder's train/eval mode is not changed here; call ``decoder.eval()``
    beforehand for inference.

    Args:
        encoder: Model exposing ``aggregator(images)``.
        decoder: Module called as
            ``decoder(aggregated_tokens_list, patch_start_idx)`` and returning
            a dict with ``'positions'`` and ``'confidences'``.
        dataset: Indexable collection yielding ``(image, target)``, where
            target provides ``'positions'``, ``'num_players'`` and
            ``'image_size'``.
        idx: Index of the sample to run.
        conf_threshold: Predictions with a confidence strictly above this
            value are kept.
        device: Device to run on. Defaults to the device of the decoder's
            parameters (or the image's own device if the decoder has none),
            so the function no longer depends on a module-level ``device``.

    Returns:
        Dict with ``'pred_positions'`` and ``'pred_scores'`` (moved to the
        CPU) plus ``'gt_positions'``, ``'num_gt'`` and ``'image_size'``
        copied from the sample's target.
    """
    image, target = dataset[idx]

    if device is None:
        first_param = next(decoder.parameters(), None)
        device = first_param.device if first_param is not None else image.device

    image = image.unsqueeze(0).to(device)  # Add batch dim

    # Forward
    aggregated_tokens_list, patch_start_idx = encoder.aggregator(image)
    outputs = decoder(aggregated_tokens_list, patch_start_idx)

    # Predictions for the single sample in the batch
    positions = outputs['positions'][0]  # [N, 2]
    confidences = outputs['confidences'][0]  # [N]

    # Keep only predictions above the threshold
    keep = confidences > conf_threshold

    return {
        'pred_positions': positions[keep].cpu(),
        'pred_scores': confidences[keep].cpu(),
        'gt_positions': target['positions'],
        'num_gt': target['num_players'],
        'image_size': target['image_size'],
    }

# Test on first validation sample, using the default confidence threshold of
# 0.5 (the value quoted in the message below)
result = inference_sample(encoder, decoder, val_dataset, 0)
print(f"Ground truth players: {result['num_gt']}")
print(f"Predicted players (conf > 0.5): {len(result['pred_positions'])}")
# Only the first five scores are shown to keep the output short
print(f"Confidence scores: {result['pred_scores'].tolist()[:5]}...")

In [None]:
# Visualize predictions
import numpy as np
from PIL import Image as PILImage

def visualize_predictions(dataset, idx, result):
    """Plot ground-truth and predicted player positions over a dataset image.

    Args:
        dataset: Indexable collection yielding ``(image, target)``; only the
            image is used here.
        idx: Index of the sample to draw (also shown in the title).
        result: Dict from ``inference_sample`` for the same index, providing
            ``'image_size'``, ``'gt_positions'`` and ``'pred_positions'``.
    """
    image, _ = dataset[idx]

    # Convert image tensor to numpy
    # assumes a CPU tensor of shape [1, 3, H, W] or [3, H, W] with values in
    # [0, 1] - TODO confirm against the dataset
    img = image.squeeze(0).permute(1, 2, 0).numpy()  # [H, W, 3]
    # Clip before the cast: values slightly outside [0, 1] would otherwise
    # wrap around in uint8 and show up as wrongly coloured pixels.
    img = np.clip(img * 255, 0, 255).astype(np.uint8)

    H, W = result['image_size']

    # Denormalize positions
    # assumes (x, y) pairs normalised to [0, 1] - TODO confirm
    scale = np.array([W, H])
    gt_pos = result['gt_positions'].numpy() * scale
    pred_pos = result['pred_positions'].numpy() * scale

    # Plot
    fig, ax = plt.subplots(1, 1, figsize=(16, 9))
    ax.imshow(img)

    # Ground truth (green)
    if len(gt_pos) > 0:
        ax.scatter(gt_pos[:, 0], gt_pos[:, 1], c='green', s=100, marker='o',
                   label=f'GT ({len(gt_pos)})', edgecolors='white', linewidths=2)

    # Predictions (red)
    if len(pred_pos) > 0:
        ax.scatter(pred_pos[:, 0], pred_pos[:, 1], c='red', s=100, marker='x',
                   label=f'Pred ({len(pred_pos)})', linewidths=2)

    ax.legend(loc='upper right', fontsize=12)
    ax.set_title(f'Player Localization - Image {idx}')
    ax.axis('off')
    plt.tight_layout()
    plt.show()

visualize_predictions(val_dataset, 0, result)