# SoccerNet SynLoc: Evaluation

This notebook covers:
1. Loading the trained model
2. mAP-LocSim evaluation
3. Error analysis
4. Ablation study framework

## 1. Setup

In [None]:
import sys
import os
from pathlib import Path

IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    from google.colab import drive
    drive.mount('/content/drive')
    
    if not os.path.exists('soccernet-synloc'):
        !git clone https://github.com/YOUR_USERNAME/soccernet-synloc.git
        %cd soccernet-synloc
        !pip install -e .[dev] -q
    
    DATA_ROOT = Path('/content/drive/MyDrive/SoccerNet/synloc')
    CHECKPOINT_DIR = Path('/content/drive/MyDrive/SoccerNet/checkpoints')
else:
    DATA_ROOT = Path('./data/synloc')
    CHECKPOINT_DIR = Path('./checkpoints')

print(f"Data root: {DATA_ROOT}")
print(f"Checkpoint dir: {CHECKPOINT_DIR}")

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import json
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
print(f"Device: {device}")

## 2. Load Model

In [None]:
from synloc.models import YOLOXPose

# Load the training config saved next to the checkpoints. Values missing from
# the file fall back to these defaults.
DEFAULT_CONFIG = {
    'model_variant': 'tiny',
    'num_keypoints': 2,
    'input_size': (640, 640),
    'batch_size': 16
}

config_path = CHECKPOINT_DIR / 'config.json'
if config_path.exists():
    with open(config_path) as f:
        loaded_config = json.load(f)
    # Merge over the defaults so a partial/older config.json (e.g. one without
    # 'batch_size') does not raise KeyError in later cells.
    config = {**DEFAULT_CONFIG, **loaded_config}
    print("Loaded config:")
    for k, v in config.items():
        print(f"  {k}: {v}")
else:
    # Copy so later edits to `config` never mutate the defaults.
    config = dict(DEFAULT_CONFIG)
    print("Using default config")

In [None]:
# Build the network from the architecture settings stored in `config`.
model = YOLOXPose(
    variant=config['model_variant'],
    num_keypoints=config['num_keypoints'],
    input_size=tuple(config['input_size']),
)

# Restore trained weights if the final checkpoint is present; otherwise the
# model keeps its freshly initialised parameters.
checkpoint_path = CHECKPOINT_DIR / 'final_model.pth'
if not checkpoint_path.exists():
    print("No checkpoint found, using random weights")
else:
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    trained_epoch = checkpoint.get('epoch', 'unknown')
    print(f"Loaded checkpoint from epoch {trained_epoch}")

# Move to the evaluation device and switch to eval mode.
model = model.to(device)
model.eval()
print("Model ready")

## 3. Prepare Test Data

In [None]:
from synloc.data import SynLocDataset, get_val_transforms

# Evaluation runs on the held-out test split, using the validation transforms.
test_split_dir = DATA_ROOT / 'test'
test_input_size = tuple(config['input_size'])

test_dataset = SynLocDataset(
    ann_file=str(test_split_dir / 'annotations.json'),
    img_dir=str(test_split_dir / 'images'),
    transforms=get_val_transforms(test_input_size[0]),
    input_size=test_input_size,
)

# Keep the file order (no shuffling) so detections are produced deterministically.
loader_kwargs = dict(
    batch_size=config['batch_size'],
    shuffle=False,
    num_workers=4,
    collate_fn=SynLocDataset.collate_fn,
)
test_loader = DataLoader(test_dataset, **loader_kwargs)

print(f"Test dataset: {len(test_dataset)} images")

## 4. Run Inference

In [None]:
from synloc.evaluation import run_inference

# A very low score threshold keeps almost every detection so the full
# precision-recall curve is available. The operating threshold is picked later.
inference_kwargs = dict(score_thr=0.01, nms_thr=0.65, max_per_img=100)
results = run_inference(model, test_loader, device=device, **inference_kwargs)

print(f"Total detections: {len(results)}")

## 5. mAP-LocSim Evaluation

In [None]:
from synloc.evaluation import evaluate_predictions

# Score the detections with mAP-LocSim, localising each player by keypoint 1
# (pelvis_ground). score_threshold=None lets the evaluator auto-select it via F1.
gt_ann_file = str(DATA_ROOT / 'test/annotations.json')
metrics = evaluate_predictions(
    gt_file=gt_ann_file,
    results=results,
    position_from_keypoint_index=1,
    score_threshold=None,
)

banner = "=" * 50
print("\n" + banner)
print("Evaluation Results:")
print(banner)
for metric_name, metric_value in metrics.items():
    print(f"  {metric_name}: {metric_value:.4f}")

## 6. Precision-Recall Curve

In [None]:
from synloc.evaluation.locsim import LocSimCOCOeval
from xtcocotools.coco import COCO

# Run the LocSim COCO evaluation directly to keep its per-image matching tables
# (`evalImgs`) and accumulated curves (`eval`) for the plots and analysis below.
coco_gt = COCO(str(DATA_ROOT / 'test/annotations.json'))
coco_dt = coco_gt.loadRes(results)

coco_eval = LocSimCOCOeval(coco_gt, coco_dt, 'bbox')
eval_params = coco_eval.params
eval_params.useSegm = None
eval_params.position_from_keypoint_index = 1  # pelvis_ground, as in Section 5

coco_eval.evaluate()
coco_eval.accumulate()

In [None]:
# Left: precision-recall at LocSim=0.5. Right: F1 as a function of the score
# threshold, with the F1-maximising threshold marked.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
ax_pr, ax_f1 = axes

curves = coco_eval.eval
precision = curves['precision_50']
recall = curves['recall_50']
f1 = curves['f1_50']
scores = curves['scores_50']

ax_pr.plot(recall, precision, 'b-', linewidth=2)
ax_pr.fill_between(recall, precision, alpha=0.2)
ax_pr.set(
    xlabel='Recall',
    ylabel='Precision',
    title='Precision-Recall Curve @ LocSim=0.5',
    xlim=[0, 1],
    ylim=[0, 1],
)
ax_pr.grid(True)

# Only points with a positive score are drawn on the F1 curve.
has_score = scores > 0
ax_f1.plot(scores[has_score], f1[has_score], 'g-', linewidth=2)
best_idx = f1.argmax()
best_thr = scores[best_idx]
ax_f1.axvline(x=best_thr, color='r', linestyle='--',
              label=f'Best threshold: {best_thr:.3f}')
ax_f1.set(
    xlabel='Score Threshold',
    ylabel='F1 Score',
    title='F1 Score vs Score Threshold',
)
ax_f1.legend()
ax_f1.grid(True)

plt.tight_layout()
plt.show()

print(f"Best F1: {f1[best_idx]:.4f} at threshold {best_thr:.4f}")

## 7. Error Analysis

In [None]:
# Break the LocSim=0.5 matching down into TP / FP / FN per image, at the score
# threshold selected by `evaluate_predictions`.
from collections import defaultdict

# Use optimal threshold
optimal_threshold = metrics['score_threshold']

# Per-image counts; accumulated so several categories add up instead of overwriting.
error_analysis = defaultdict(lambda: {'tp': 0, 'fp': 0, 'fn': 0, 'gt_count': 0})

# evalImgs holds one entry per (category, area range, image); only the 'all'
# range is used so each detection is counted once.
rng = coco_eval.params.areaRng[coco_eval.params.areaRngLbl.index('all')]
# Tolerant lookup instead of exact float equality on the threshold array.
iou_idx = int(np.flatnonzero(np.isclose(coco_eval.params.iouThrs, 0.5))[0])

for e in coco_eval.evalImgs:
    if e is None or e['aRng'] != rng:
        continue

    img_id = e['image_id']
    dt_scores = np.asarray(e['dtScores'])
    dt_matches = np.asarray(e['dtMatches'][iou_idx])
    # Same bookkeeping as COCOeval.accumulate: detections flagged as ignored are
    # neither TP nor FP, and GT flagged as ignored (e.g. crowd) cannot be missed.
    dt_ignore = np.asarray(e['dtIgnore'][iou_idx], dtype=bool)
    n_gt = int(np.count_nonzero(np.asarray(e['gtIgnore']) == 0))

    considered = (dt_scores >= optimal_threshold) & ~dt_ignore
    matched = dt_matches > 0  # dtMatches holds the matched GT id, 0 if unmatched

    tp = int(np.count_nonzero(considered & matched))
    fp = int(np.count_nonzero(considered & ~matched))

    stats = error_analysis[img_id]
    stats['gt_count'] += n_gt
    stats['tp'] += tp
    stats['fp'] += fp
    stats['fn'] += max(0, n_gt - tp)

# Aggregate
total_tp = sum(s['tp'] for s in error_analysis.values())
total_fp = sum(s['fp'] for s in error_analysis.values())
total_fn = sum(s['fn'] for s in error_analysis.values())

# Guard against empty denominators (no detections above threshold / no GT).
n_pred = total_tp + total_fp
n_true = total_tp + total_fn
precision_at_thr = total_tp / n_pred if n_pred else float('nan')
recall_at_thr = total_tp / n_true if n_true else float('nan')

print(f"Error Analysis @ threshold={optimal_threshold:.3f}")
print(f"  True Positives:  {total_tp}")
print(f"  False Positives: {total_fp}")
print(f"  False Negatives: {total_fn}")
print(f"  Precision: {precision_at_thr:.4f}")
print(f"  Recall: {recall_at_thr:.4f}")

In [None]:
# Rank images by missed detections and by spurious detections (worst first).
images_by_fn = sorted(error_analysis.items(),
                      key=lambda item: item[1]['fn'], reverse=True)
images_by_fp = sorted(error_analysis.items(),
                      key=lambda item: item[1]['fp'], reverse=True)

TOP_K = 10

print("\nImages with most false negatives (missed detections):")
for img_id, stats in images_by_fn[:TOP_K]:
    print(f"  Image {img_id}: {stats['fn']} FN, {stats['gt_count']} GT")

print("\nImages with most false positives (wrong detections):")
for img_id, stats in images_by_fp[:TOP_K]:
    print(f"  Image {img_id}: {stats['fp']} FP")

## 8. Visualize Errors

In [None]:
from synloc.visualization import draw_pitch, visualize_bev_predictions
from synloc.data.camera import keypoint_to_world
from PIL import Image

def visualize_errors(img_id, coco_gt, results, data_root, threshold=0.3, keypoint_index=1):
    """Plot predictions against ground truth for one test image.

    The left panel shows the image with GT boxes (blue, dashed) and predicted
    boxes (green, annotated with their score). The right panel is a bird's-eye
    view comparing GT ``position_on_pitch`` with predicted positions. Those are
    obtained by projecting one predicted keypoint through the image's camera
    parameters with ``keypoint_to_world``.

    Args:
        img_id: COCO image id.
        coco_gt: ground-truth ``COCO`` object.
        results: COCO-style detection dicts with ``image_id``, ``score``,
            ``bbox`` and ``keypoints``.
        data_root: dataset root (``str`` or ``Path``); the image is read from
            ``<data_root>/test/images/<file_name>``.
        threshold: minimum detection score to draw.
        keypoint_index: predicted keypoint projected onto the pitch
            (default 1 = pelvis_ground).
    """
    # Get image info, GT annotations and thresholded predictions for this image
    img_info = coco_gt.loadImgs(img_id)[0]
    anns = coco_gt.loadAnns(coco_gt.getAnnIds(imgIds=img_id))
    preds = [r for r in results if r['image_id'] == img_id and r['score'] >= threshold]

    fig, axes = plt.subplots(1, 2, figsize=(16, 6))

    # Image view. The context manager closes the file; convert to an array while
    # it is still open because PIL reads pixel data lazily.
    img_path = Path(data_root) / 'test/images' / img_info['file_name']
    with Image.open(img_path) as img:
        axes[0].imshow(np.asarray(img))

    # Draw GT bboxes in blue
    for ann in anns:
        x, y, w, h = ann['bbox']
        rect = plt.Rectangle((x, y), w, h, fill=False,
                             edgecolor='blue', linewidth=2, linestyle='--')
        axes[0].add_patch(rect)

    # Draw predictions in green
    for pred in preds:
        x, y, w, h = pred['bbox']
        rect = plt.Rectangle((x, y), w, h, fill=False,
                             edgecolor='green', linewidth=2)
        axes[0].add_patch(rect)
        axes[0].text(x, y - 5, f"{pred['score']:.2f}", color='green', fontsize=8)

    axes[0].set_title(f"Image {img_id}: {len(anns)} GT (blue), {len(preds)} Pred (green)")
    axes[0].axis('off')

    # BEV view. reshape keeps a (0, 2) array when the image has no GT players.
    gt_positions = np.array([ann['position_on_pitch'][:2] for ann in anns],
                            dtype=float).reshape(-1, 2)

    camera_matrix = torch.tensor(img_info['camera_matrix'], dtype=torch.float32)
    undist_poly = torch.tensor(img_info['undist_poly'], dtype=torch.float32)
    # Pixel coords are shifted to the image centre and divided by the image
    # width before projection; the centre is the same for every prediction.
    img_center = torch.tensor([(img_info['width'] - 1) / 2, (img_info['height'] - 1) / 2])

    pred_positions = []
    for pred in preds:
        kpts = np.array(pred['keypoints']).reshape(-1, 3)
        kpt = torch.tensor(kpts[keypoint_index, :2], dtype=torch.float32).unsqueeze(0)
        kpt_norm = (kpt - img_center) / img_info['width']
        world = keypoint_to_world(camera_matrix, undist_poly, kpt_norm)
        pred_positions.append(world[0, :2].numpy())
    pred_positions = np.array(pred_positions).reshape(-1, 2)

    visualize_bev_predictions(
        pred_positions, gt_positions,
        ax=axes[1],
        title='BEV View'
    )

    plt.tight_layout()
    plt.show()

In [None]:
# Draw the three images with the most missed detections. A failure on one
# image is reported and does not prevent the others from being drawn.
N_EXAMPLES = 3
for img_id, stats in images_by_fn[:N_EXAMPLES]:
    fn_count, fp_count = stats['fn'], stats['fp']
    print(f"\nImage {img_id}: {fn_count} FN, {fp_count} FP")
    try:
        visualize_errors(img_id, coco_gt, results, DATA_ROOT, threshold=optimal_threshold)
    except Exception as exc:
        print(f"Error visualizing: {exc}")

## 9. Localization Error Analysis

In [None]:
# Compute on-pitch localization errors (printed in metres below) for detections
# that are true positives at LocSim=0.5 and pass the optimal score threshold.
#
# COCOeval bookkeeping this relies on:
# * `evalImgs` has one entry per (category, area range, image). Only the 'all'
#   area range is used; iterating over every entry counts each match once per
#   area range.
# * `dtIds`, `dtScores`, `dtMatches` and `dtIgnore` are aligned with each other
#   (score-sorted inside COCOeval), NOT with the order of `results`, so the
#   detections are fetched by id from `coco_dt`.
all_area_rng = coco_eval.params.areaRng[coco_eval.params.areaRngLbl.index('all')]
KPT_INDEX = 1  # pelvis_ground, the keypoint used by the evaluation above

loc_errors = []
n_projection_failures = 0

for e in coco_eval.evalImgs:
    if e is None or e['aRng'] != all_area_rng:
        continue

    dt_matches = np.asarray(e['dtMatches'][iou_idx])
    dt_scores = np.asarray(e['dtScores'])
    dt_ignore = np.asarray(e['dtIgnore'][iou_idx], dtype=bool)
    # Matched to a GT (id > 0), not ignored, and above the operating threshold.
    keep = (dt_matches > 0) & ~dt_ignore & (dt_scores >= optimal_threshold)
    if not keep.any():
        continue

    img_id = int(e['image_id'])
    img_info = coco_gt.loadImgs(img_id)[0]
    camera_matrix = torch.tensor(img_info['camera_matrix'], dtype=torch.float32)
    undist_poly = torch.tensor(img_info['undist_poly'], dtype=torch.float32)
    # Pixel coords are shifted to the image centre and divided by the width.
    img_center = torch.tensor([(img_info['width'] - 1) / 2, (img_info['height'] - 1) / 2])

    # GT positions keyed by annotation id (the values stored in dtMatches)
    gt_anns = coco_gt.loadAnns(coco_gt.getAnnIds(imgIds=img_id))
    gt_id_to_pos = {ann['id']: ann['position_on_pitch'][:2] for ann in gt_anns}

    kept_dt_ids = [int(d) for d in np.asarray(e['dtIds'])[keep]]
    for det, gt_id in zip(coco_dt.loadAnns(kept_dt_ids), dt_matches[keep]):
        kpts = np.array(det['keypoints']).reshape(-1, 3)
        kpt = torch.tensor(kpts[KPT_INDEX, :2], dtype=torch.float32).unsqueeze(0)
        kpt_norm = (kpt - img_center) / img_info['width']
        try:
            world = keypoint_to_world(camera_matrix, undist_poly, kpt_norm)
        except Exception:
            # Best effort: skip detections that cannot be projected, but count them.
            n_projection_failures += 1
            continue
        pred_pos = world[0, :2].numpy()
        gt_pos = np.asarray(gt_id_to_pos[int(gt_id)], dtype=float)
        loc_errors.append(float(np.linalg.norm(pred_pos - gt_pos)))

loc_errors = np.array(loc_errors)
print(f"Localization errors for {len(loc_errors)} matched detections:")
if n_projection_failures:
    print(f"  (skipped {n_projection_failures} detections whose keypoint could not be projected)")
if len(loc_errors):
    print(f"  Mean: {loc_errors.mean():.3f} m")
    print(f"  Median: {np.median(loc_errors):.3f} m")
    print(f"  Std: {loc_errors.std():.3f} m")
    print(f"  90th percentile: {np.percentile(loc_errors, 90):.3f} m")
else:
    print("  No matched detections above the threshold.")

In [None]:
# Histogram of per-detection localization errors and the cumulative fraction of
# detections within a given distance, with the LocSim tau (1 m) marked on both.
LOCSIM_TAU_M = 1.0
fig, (ax_hist, ax_cdf) = plt.subplots(1, 2, figsize=(12, 4))

ax_hist.hist(loc_errors, bins=50, edgecolor='black')
ax_hist.axvline(x=LOCSIM_TAU_M, color='r', linestyle='--', label='LocSim tau=1m')
ax_hist.set_xlabel('Localization Error (meters)')
ax_hist.set_ylabel('Count')
ax_hist.set_title('Distribution of Localization Errors')
ax_hist.legend()

thresholds = np.linspace(0, 5, 100)
fractions = [np.mean(loc_errors <= t) for t in thresholds]
ax_cdf.plot(thresholds, fractions)
ax_cdf.axvline(x=LOCSIM_TAU_M, color='r', linestyle='--', label='LocSim tau=1m')
ax_cdf.set_xlabel('Distance Threshold (meters)')
ax_cdf.set_ylabel('Fraction of Detections')
ax_cdf.set_title('Cumulative Localization Accuracy')
ax_cdf.legend()
ax_cdf.grid(True)

plt.tight_layout()
plt.show()

## 10. Ablation Framework

In [None]:
def run_ablation(model_variant, input_size, checkpoint_path=None, *,
                 num_keypoints=2, batch_size=16, num_workers=4,
                 score_thr=0.01, nms_thr=0.65, max_per_img=100,
                 position_from_keypoint_index=1):
    """Evaluate one model configuration on the test split with mAP-LocSim.

    Inference and evaluation settings default to those of the main evaluation
    earlier in this notebook: low ``score_thr`` for the full P-R curve, the same
    NMS and per-image limits, and an auto-selected score threshold. This keeps
    the returned metrics comparable with the 'current' row of the ablation table.

    Args:
        model_variant: YOLOXPose variant name (e.g. 'tiny').
        input_size: (width, height)-style pair; converted to a tuple.
        checkpoint_path: optional path to a checkpoint with 'model_state_dict'.
            If falsy, the model is evaluated with freshly initialised weights.
        num_keypoints, batch_size, num_workers: model / loader settings.
        score_thr, nms_thr, max_per_img: forwarded to ``run_inference``.
        position_from_keypoint_index: keypoint used as the pitch position.

    Returns:
        The metrics dict returned by ``evaluate_predictions``.

    Raises:
        FileNotFoundError: if ``checkpoint_path`` is given but does not exist,
            instead of silently evaluating random weights.
    """
    input_size = tuple(input_size)

    model = YOLOXPose(
        variant=model_variant,
        num_keypoints=num_keypoints,
        input_size=input_size
    )

    if checkpoint_path:
        checkpoint_path = Path(checkpoint_path)
        if not checkpoint_path.exists():
            raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        model.load_state_dict(checkpoint['model_state_dict'])

    model = model.to(device)
    model.eval()

    ann_file = str(DATA_ROOT / 'test/annotations.json')
    test_dataset = SynLocDataset(
        ann_file=ann_file,
        img_dir=str(DATA_ROOT / 'test/images'),
        transforms=get_val_transforms(input_size[0]),
        input_size=input_size
    )

    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=SynLocDataset.collate_fn
    )

    results = run_inference(
        model, test_loader, device=device,
        score_thr=score_thr, nms_thr=nms_thr, max_per_img=max_per_img
    )

    return evaluate_predictions(
        gt_file=ann_file,
        results=results,
        position_from_keypoint_index=position_from_keypoint_index,
        score_threshold=None  # auto-select, as in the main evaluation
    )

In [None]:
# Example ablation: compare input sizes. Each size needs its own checkpoint
# trained at that resolution (CHECKPOINT_DIR / 'model_<width>.pth').
RUN_ABLATION = False  # set True to evaluate the checkpoints below
ABLATION_SIZES = [(640, 640), (960, 960)]

ablation_results = {
    'current': metrics  # already computed in Section 5
}

if RUN_ABLATION:
    for size in ABLATION_SIZES:
        # Separate name so the checkpoint loaded in Section 2 is not clobbered.
        ckpt_path = CHECKPOINT_DIR / f'model_{size[0]}.pth'
        if ckpt_path.exists():
            ablation_results[f'{size[0]}x{size[1]}'] = run_ablation(
                config['model_variant'], size, ckpt_path
            )
        else:
            print(f"Skipping {size[0]}x{size[1]}: {ckpt_path} not found")

print("\nAblation Results:")
print("-" * 60)
print(f"{'Setting':<20} {'mAP-LocSim':>12} {'Precision':>12} {'Recall':>12}")
print("-" * 60)
# .get: a metric missing from the evaluator's output prints as nan instead of
# raising KeyError.
missing = float('nan')
for name, m in ablation_results.items():
    print(f"{name:<20} {m.get('mAP_locsim', missing):>12.4f} "
          f"{m.get('precision', missing):>12.4f} {m.get('recall', missing):>12.4f}")

## Summary

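Evaluation complete! Key findings. The numbers are printed by the cells above. Markdown cells cannot evaluate Python expressions, so copy the values in after running:

1. **mAP-LocSim**: `metrics['mAP_locsim']` (Section 5)
2. **Best F1**: `metrics['f1']` at threshold `metrics['score_threshold']` (Sections 5–6)
3. **Mean localization error**: `loc_errors.mean()`, in metres (Section 9)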

Next steps:
- Proceed to `04_submission.ipynb` to generate challenge submission
- Consider improvements based on error analysis