# 03 Detection Head

Demonstrates the detection head interface using the COCO dataset.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from PIL import Image
from transformers import AutoImageProcessor
from torch.utils.data import DataLoader

from dinov3_lab.core.backbone import build_dinov3_hf
from dinov3_lab.tasks.detection.heads import DenseDetectionHead
from dinov3_lab.data.datasets.coco import COCODataset

# 1. Setup
# The same checkpoint id is used for the backbone weights and for the
# matching image processor, so preprocessing stays consistent with the model.
model_id = "facebook/dinov3-vitl16-pretrain-lvd1689m"
backbone = build_dinov3_hf(model_id=model_id)
# NOTE(review): in_channels=1024 presumably matches the ViT-L feature width and
# num_classes=80 the COCO category count -- confirm if the checkpoint changes.
head = DenseDetectionHead(in_channels=1024, num_classes=80)
processor = AutoImageProcessor.from_pretrained(model_id)

# Freeze backbone: only the detection head receives gradient updates.
for param in backbone.parameters():
    param.requires_grad = False

# Freezing parameters does not change the module's train/eval mode, so switch
# the backbone to eval explicitly; this keeps any train-only layers (dropout,
# stochastic depth, ...) inactive while it is used as a feature extractor.
# Assumes the backbone is a torch.nn.Module (it exposes .parameters()).
backbone.eval()

In [None]:
# 2. Dataset & Dataloader
# NOTE: Point 'data_root' at the directory holding your COCO dataset.
data_root = "../data/coco"


def collate_fn(batch):
    """Regroup a list of per-sample tuples into one tuple per field.

    The samples are returned as-is (no stacking), so fields of differing
    size across samples can share a batch.
    """
    return tuple(zip(*batch))


try:
    dataset = COCODataset(root=data_root, split="train2017")
    dataloader = DataLoader(
        dataset,
        batch_size=2,
        shuffle=True,
        collate_fn=collate_fn,
    )
    print(f"Loaded COCO dataset with {len(dataset)} samples.")
except (FileNotFoundError, ImportError) as e:
    # Missing data or a missing optional dependency is not fatal: the later
    # cells check `dataset` and skip themselves when it is None.
    print(f"COCO dataset not found or pycocotools missing: {e}")
    dataset = dataloader = None

# Skipped entirely when the dataset could not be loaded (dataset is None).
if dataset:
    # 3. Training Loop
    # Only the head's parameters are handed to the optimizer.
    optimizer = optim.AdamW(head.parameters(), lr=1e-3)
    criterion = nn.MSELoss()  # Placeholder loss; the dummy objective below does not use it

    print("Starting training...")
    head.train()
    max_steps = 10
    step = 0

    for epoch in range(1):
        # Checked before touching the dataloader so that no batch is loaded
        # once the step budget is already spent (e.g. max_steps <= 0).
        if step >= max_steps:
            break

        for images, targets in dataloader:
            # Prepare inputs
            inputs = processor(images=list(images), return_tensors="pt")

            # Forward pass backbone; no autograd graph is recorded for it
            with torch.no_grad():
                out = backbone(inputs.pixel_values)
                grid = backbone.tokens_to_grid(out.patch_tokens, out.patch_hw)

            # Forward pass head
            detections = head(grid.float())

            # Compute dummy loss since we don't have a real detection loss implemented yet.
            # In a real scenario, we'd map COCO targets to the grid.
            # Multiplying by 0.0 keeps the head in the autograd graph (so
            # backward() runs) while the loss value stays constant at 1.0.
            loss = detections.mean() * 0.0 + 1.0

            # Backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            step += 1
            print(f"Step {step}, Loss: {loss.item():.4f}")

            # Stop right after the last permitted step instead of at the top
            # of the next iteration, which would load and collate one extra
            # batch only to throw it away.
            if step >= max_steps:
                break

In [None]:
# 4. Evaluation / Visualization
# Runs only when the dataset was loaded; shows the first sample next to the
# first channel of the head's output for that sample.
if dataset:
    head.eval()
    image, target = dataset[0]
    inputs = processor(images=image, return_tensors="pt")

    # Inference only: no gradients are needed for the backbone or the head.
    with torch.no_grad():
        out = backbone(inputs.pixel_values)
        grid = backbone.tokens_to_grid(out.patch_tokens, out.patch_hw)
        detections = head(grid.float())

    # First batch element, first channel of the detection map.
    first_channel = detections[0, 0].cpu().numpy()

    # (content, title) for each side-by-side panel, left to right.
    panels = (
        (image, "Input"),
        (first_channel, "Prediction (Ch 0)"),
    )

    plt.figure(figsize=(10, 5))
    for position, (content, caption) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.imshow(content)
        plt.title(caption)
        plt.axis("off")
    plt.show()