# SoccerNet SynLoc: Training

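This notebook covers:
1. Environment and training configuration setup (Google Drive is mounted when running on Colab)
2. Dataset/dataloader creation and model initialization, optionally from a COCO-pretrained checkpoint
3. The training loop with progress tracking, optionally resumed from a checkpoint
4. Training curves and a quick visual check on validation images
5. Saving the final checkpoint and config (to Google Drive when running on Colab)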

## 1. Setup

In [None]:
import os
import subprocess
import sys
from pathlib import Path

# True only when this kernel is running inside a Google Colab runtime.
IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    # Mount Drive so data and checkpoints survive runtime resets.
    from google.colab import drive
    drive.mount('/content/drive')

    # TODO: replace YOUR_USERNAME with the actual GitHub account/organisation.
    REPO_URL = 'https://github.com/YOUR_USERNAME/soccernet-synloc.git'
    # Absolute path: the existence check must not depend on the current working
    # directory, which this cell changes (a cwd-relative check would clone a
    # nested copy when the cell is re-run from inside the repo).
    REPO_DIR = Path('/content/soccernet-synloc')

    # Clone and install once per runtime. check=True makes a failed clone or
    # install stop the cell instead of failing silently.
    if not REPO_DIR.exists():
        subprocess.run(['git', 'clone', REPO_URL, str(REPO_DIR)], check=True)
        subprocess.run(
            [sys.executable, '-m', 'pip', 'install', '-q', '-e', '.[dev]'],
            cwd=REPO_DIR,
            check=True,
        )
    # Always enter the repo, even if it already existed (e.g. after a kernel
    # restart, which resets the working directory).
    os.chdir(REPO_DIR)

    DATA_ROOT = Path('/content/drive/MyDrive/SoccerNet/synloc')
    CHECKPOINT_DIR = Path('/content/drive/MyDrive/SoccerNet/checkpoints')
else:
    DATA_ROOT = Path('./data/synloc')
    CHECKPOINT_DIR = Path('./checkpoints')

CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
print(f"Data root: {DATA_ROOT}")
print(f"Checkpoint dir: {CHECKPOINT_DIR}")

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

# Pick the compute device and report which accelerator (if any) we got.
has_cuda = torch.cuda.is_available()
device = 'cuda' if has_cuda else 'cpu'

if not has_cuda:
    print("No GPU available")
else:
    gpu_name = torch.cuda.get_device_name(0)
    total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"GPU: {gpu_name}")
    print(f"Memory: {total_gb:.1f} GB")

## 2. Configuration

In [None]:
import random

# Training configuration. The final dict (after the GPU adjustments below) is
# what gets saved to config.json in section 8.
config = {
    # Model
    'model_variant': 'tiny',  # tiny, s, m, l
    'num_keypoints': 2,
    'input_size': (640, 640),

    # Training
    'batch_size': 16,
    'epochs': 100,
    'lr': 1e-3,
    'weight_decay': 5e-4,
    'warmup_epochs': 5,
    'seed': 42,  # seeds python/numpy/torch RNGs below for reproducible runs

    # Data
    'num_workers': 4,

    # Loss weights
    # NOTE(review): these keys are not passed to the model or the trainer
    # anywhere in this notebook; confirm how they are consumed before tuning.
    'loss_cls_weight': 1.0,
    'loss_bbox_weight': 5.0,
    'loss_obj_weight': 1.0,
    'loss_kpt_weight': 1.0,

    # Augmentation
    'use_mosaic': True,  # Enable mosaic augmentation
    'mosaic_prob': 0.5,  # NOTE(review): not passed to get_train_transforms

    # Misc
    'use_amp': True,  # Automatic mixed precision
    'save_interval': 10,  # Save every N epochs; NOTE(review): not passed to SynLocTrainer
}

# Seed every RNG that affects shuffling, augmentation and weight init.
random.seed(config['seed'])
np.random.seed(config['seed'])
torch.manual_seed(config['seed'])  # also seeds all CUDA devices

# Scale the batch size to the available GPU memory. This overrides the
# batch_size above (and, on small GPUs, model_variant too).
if device == 'cuda':
    gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1e9
    if gpu_mem < 8:  # small GPUs: smallest batch and force the tiny model
        config['batch_size'] = 8
        config['model_variant'] = 'tiny'
    elif gpu_mem < 20:  # e.g. Colab T4 (~15 GB)
        config['batch_size'] = 16
    else:  # >= 20 GB, e.g. A100
        config['batch_size'] = 32

print("Configuration:")
for k, v in config.items():
    print(f"  {k}: {v}")

## 3. Create Datasets

In [None]:
from synloc.data import SynLocDataset, get_train_transforms, get_val_transforms


def _build_split(split, transforms):
    """Create a SynLocDataset for the `split` sub-folder of DATA_ROOT."""
    split_dir = DATA_ROOT / split
    return SynLocDataset(
        ann_file=str(split_dir / 'annotations.json'),
        img_dir=str(split_dir / 'images'),
        transforms=transforms,
        input_size=config['input_size'],
    )


# The transform factories take a single side length (first entry of input_size).
img_side = config['input_size'][0]

# Training split, with optional mosaic augmentation.
# NOTE(review): config['mosaic_prob'] is not forwarded here — confirm the default.
train_transforms = get_train_transforms(img_side, use_mosaic=config['use_mosaic'])
train_dataset = _build_split('train', train_transforms)

# Validation split, with the evaluation transforms.
val_transforms = get_val_transforms(img_side)
val_dataset = _build_split('valid', val_transforms)

print(f"Train dataset: {len(train_dataset)} images")
print(f"Val dataset: {len(val_dataset)} images")

In [None]:
# Options shared by both loaders. Pinned host memory only speeds up
# host->GPU copies, so request it only when training on CUDA (recent PyTorch
# versions warn and ignore it on CPU-only machines).
loader_kwargs = dict(
    batch_size=config['batch_size'],
    num_workers=config['num_workers'],
    collate_fn=SynLocDataset.collate_fn,
    pin_memory=(device == 'cuda'),
)

# drop_last discards the final incomplete training batch; validation keeps
# every sample.
train_loader = DataLoader(train_dataset, shuffle=True, drop_last=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

print(f"Train batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")

## 4. Create Model

In [None]:
from synloc.models import YOLOXPose

# Build the pose model from the configuration above.
model = YOLOXPose(
    variant=config['model_variant'],
    num_keypoints=config['num_keypoints'],
    input_size=config['input_size'],
)

# Parameter counts: all parameters vs. those that will receive gradients.
param_stats = [(p.numel(), p.requires_grad) for p in model.parameters()]
total_params = sum(count for count, _ in param_stats)
trainable_params = sum(count for count, trainable in param_stats if trainable)

print(f"Model: YOLOX-Pose {config['model_variant']}")
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

In [None]:
# Load pretrained weights (optional).
# Place COCO-pretrained YOLOX weights in CHECKPOINT_DIR under the name below.
# Only tensors whose name AND shape match this model are loaded, so heads that
# depend on num_keypoints are skipped automatically.
pretrained_path = CHECKPOINT_DIR / f'yolox_{config["model_variant"]}_coco.pth'

if pretrained_path.exists():
    print(f"Loading pretrained weights from {pretrained_path}")
    # NOTE: torch.load unpickles the file — only load checkpoints you trust.
    checkpoint = torch.load(pretrained_path, map_location='cpu')

    # Training checkpoints often wrap the weights ({'model': ...} or
    # {'state_dict': ...}); without unwrapping, no key would match below.
    state_dict = checkpoint
    for wrapper_key in ('model', 'state_dict'):
        if isinstance(state_dict, dict) and isinstance(state_dict.get(wrapper_key), dict):
            state_dict = state_dict[wrapper_key]
            break

    # Strip the 'module.' prefix written by DataParallel/DDP-wrapped models.
    prefix = 'module.'
    state_dict = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }

    # Keep only tensors that exist in this model with an identical shape.
    model_state = model.state_dict()
    filtered_state = {
        k: v for k, v in state_dict.items()
        if k in model_state and v.shape == model_state[k].shape
    }

    model.load_state_dict(filtered_state, strict=False)
    print(f"Loaded {len(filtered_state)}/{len(model_state)} layers")
    if not filtered_state:
        # Surface the silent failure mode: a format/key mismatch loads nothing.
        print("WARNING: no tensors matched this model; check the checkpoint format and key names")
else:
    print("No pretrained weights found, training from scratch")

## 5. Set Up Training

In [None]:
from synloc.training import SynLocTrainer

# Optimisation hyper-parameters forwarded from `config`.
# NOTE(review): the loss weights and save_interval in `config` are not
# forwarded here — confirm whether SynLocTrainer accepts them.
trainer_hparams = {
    'lr': config['lr'],
    'weight_decay': config['weight_decay'],
    'epochs': config['epochs'],
    'warmup_epochs': config['warmup_epochs'],
    'use_amp': config['use_amp'],
}

trainer = SynLocTrainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    device=device,
    checkpoint_dir=str(CHECKPOINT_DIR),
    **trainer_hparams,
)

print("Trainer initialized")
print(f"Using AMP: {config['use_amp']}")

## 6. Training Loop

In [None]:
# Optional: resume from a checkpoint. Leave as None to train from the current
# weights; set to a checkpoint path (str or Path) to resume.
resume_path = None

if resume_path:
    resume_file = Path(resume_path)
    # Fail loudly: a mistyped path must not silently restart training from
    # scratch (and later overwrite checkpoints).
    if not resume_file.exists():
        raise FileNotFoundError(f"Resume checkpoint not found: {resume_file}")
    trainer.load_checkpoint(str(resume_file))
    print(f"Resumed from {resume_path}")

In [None]:
# Run the full training schedule configured above; the returned `history`
# feeds the plotting cells in section 7.
history = trainer.train()
print("\nTraining complete!")

## 7. Training Curves

In [None]:
def _plot_loss_panel(ax, history, epochs, suffix, title):
    """Draw one loss panel: ``train_<suffix>`` plus ``val_<suffix>`` if recorded.

    Args:
        ax: Matplotlib axes to draw on.
        history: History dict returned by ``trainer.train()``.
        epochs: x-values shared by every series (1-based epoch numbers).
        suffix: Metric name after the ``train_``/``val_`` prefix, e.g. ``'cls_loss'``.
        title: Panel title.

    Returns:
        True if the panel was drawn; False if ``train_<suffix>`` is missing, in
        which case the panel is hidden instead of being left as an empty frame.
    """
    train_key, val_key = f'train_{suffix}', f'val_{suffix}'
    if train_key not in history:
        ax.axis('off')
        return False

    ax.plot(epochs, history[train_key], label='Train')
    if val_key in history:
        # Assumes validation is logged every epoch (same length as train);
        # matplotlib raises if the lengths differ.
        ax.plot(epochs, history[val_key], label='Val')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.set_title(title)
    ax.legend()
    ax.grid(True)
    return True


def plot_training_history(history):
    """Plot total, classification, bbox and keypoint loss curves in a 2x2 grid.

    Args:
        history: History dict returned by ``trainer.train()``. Must contain
            'train_loss' (its length defines the epoch axis). 'val_loss' and the
            'train_'/'val_' variants of 'cls_loss', 'bbox_loss' and 'kpt_loss'
            are optional; panels whose train series is absent are hidden.
    """
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    epochs = range(1, len(history['train_loss']) + 1)

    # (history key suffix, panel title), in row-major panel order.
    panels = [
        ('loss', 'Total Loss'),
        ('cls_loss', 'Classification Loss'),
        ('bbox_loss', 'Bbox Loss'),
        ('kpt_loss', 'Keypoint Loss'),
    ]
    for ax, (suffix, title) in zip(axes.flat, panels):
        _plot_loss_panel(ax, history, epochs, suffix, title)

    plt.tight_layout()
    plt.show()


plot_training_history(history)

In [None]:
# Learning-rate schedule, shown only if the trainer recorded it.
# NOTE(review): the x-axis is labelled 'Epoch' — confirm history['lr'] is
# logged once per epoch rather than once per iteration.
if 'lr' in history:
    fig, ax = plt.subplots(figsize=(10, 4))
    ax.plot(history['lr'])
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Learning Rate')
    ax.set_title('Learning Rate Schedule')
    ax.grid(True)
    plt.show()

## 8. Save Final Model

In [None]:
import json

# Write the final checkpoint through the trainer.
final_path = CHECKPOINT_DIR / 'final_model.pth'
trainer.save_checkpoint(str(final_path))
print(f"Saved final model to {final_path}")

# Store the config actually used (including any GPU-based adjustments) next to
# the checkpoint; tuples such as input_size are serialised as JSON lists.
config_path = CHECKPOINT_DIR / 'config.json'
config_path.write_text(json.dumps(config, indent=2))
print(f"Saved config to {config_path}")

## 9. Quick Validation

In [None]:
from synloc.evaluation import visualize_predictions

# Detection score threshold shared by inference and drawing.
VIS_SCORE_THR = 0.3

# Statistics used to undo input normalization for display.
# NOTE(review): assumes the val transforms normalize RGB images with these
# ImageNet mean/std values — confirm against get_val_transforms.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
IMAGENET_STD = np.array([0.229, 0.224, 0.225])


def denormalize_image(img_chw):
    """Convert a normalized channel-first image tensor to a uint8 (H, W, C) array."""
    img = img_chw.detach().cpu().permute(1, 2, 0).numpy()
    img = img * IMAGENET_STD + IMAGENET_MEAN
    return np.clip(img * 255, 0, 255).astype(np.uint8)


# Run inference on the first validation batch.
model.eval()
model = model.to(device)

batch = next(iter(val_loader))
images = batch['image'].to(device)

with torch.no_grad():
    results = model.predict(
        images,
        input_size=config['input_size'],
        score_thr=VIS_SCORE_THR,
    )

# Visualize up to 4 images in a 2x2 grid.
fig, axes = plt.subplots(2, 2, figsize=(16, 12))
for i, ax in enumerate(axes.flat):
    ax.axis('off')
    if i >= len(images):
        # Fewer images than panels: leave the remaining panels blank rather
        # than showing empty axes with ticks.
        continue

    preds = {
        key: results[i][key].cpu().numpy()
        for key in ('bboxes', 'scores', 'keypoints', 'keypoint_scores')
    }
    vis_img = visualize_predictions(denormalize_image(images[i]), preds, score_thr=VIS_SCORE_THR)

    ax.imshow(vis_img)
    ax.set_title(f"Image {i}: {len(results[i]['bboxes'])} detections")

plt.tight_layout()
plt.show()

## Summary

Training complete! Key outputs are written to `CHECKPOINT_DIR` (`/content/drive/MyDrive/SoccerNet/checkpoints` on Colab, `./checkpoints` locally):
- Model checkpoint: `final_model.pth`
- Training config: `config.json`

Next steps:
- Proceed to `03_evaluation.ipynb` for mAP-LocSim evaluation
- Tune hyperparameters based on validation metrics