# 03 Detection Head

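Demonstrate the detection-head interface on the COCO dataset with PyTorch Lightning: a frozen DINOv3 backbone feeds a trainable `DenseDetectionHead`. The training loss is a placeholder, so this notebook exercises the data/training plumbing rather than producing a useful detector.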

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import pytorch_lightning as pl
import matplotlib.pyplot as plt
from PIL import Image
from transformers import AutoImageProcessor
from torch.utils.data import DataLoader

from dinov3_lab.core.backbone import build_dinov3_hf
from dinov3_lab.tasks.detection.heads import DenseDetectionHead
from dinov3_lab.data.datasets.coco import COCODataset

# 1. Setup Lightning Module
class DetectionModule(pl.LightningModule):
    """Frozen DINOv3 backbone feeding a trainable ``DenseDetectionHead``.

    Only the head is optimized. The backbone's parameters are frozen and the
    backbone is pinned to eval mode (see ``train``).

    Args:
        model_id: Checkpoint id passed to both ``build_dinov3_hf`` and
            ``AutoImageProcessor.from_pretrained``.
        num_classes: Number of classes given to ``DenseDetectionHead``.
        in_channels: Channel count of the backbone feature grid fed to the head.
            NOTE(review): must match the backbone's token width; 1024 is the
            value used for the default checkpoint. Confirm when switching models.
        lr: AdamW learning rate for the head parameters.
    """

    def __init__(self, model_id, num_classes=80, in_channels=1024, lr=1e-3):
        super().__init__()
        # Records every constructor argument (including in_channels / lr) in self.hparams.
        self.save_hyperparameters()
        self.backbone = build_dinov3_hf(model_id=model_id)
        self.head = DenseDetectionHead(in_channels=in_channels, num_classes=num_classes)
        self.processor = AutoImageProcessor.from_pretrained(model_id)

        # Freeze backbone: only the head receives gradients / optimizer updates.
        for param in self.backbone.parameters():
            param.requires_grad = False
        self.backbone.eval()

    def train(self, mode: bool = True):
        """Set train/eval mode on the module while keeping the frozen backbone in eval mode.

        Lightning calls ``train()`` on the whole module when fitting starts. Without
        this override, that would flip the frozen backbone into train mode and
        re-enable any dropout / stochastic-depth layers it may contain.
        """
        super().train(mode)
        self.backbone.eval()
        return self

    def forward(self, pixel_values):
        """Run the frozen backbone, reshape its patch tokens to a grid, and apply the head.

        Args:
            pixel_values: Preprocessed image batch from ``self.processor``
                (presumably ``(batch, 3, H, W)``; TODO confirm).

        Returns:
            The output of ``DenseDetectionHead`` for the float32 feature grid.
        """
        # The backbone is frozen, so skip building an autograd graph through it.
        with torch.no_grad():
            out = self.backbone(pixel_values)
            grid = self.backbone.tokens_to_grid(out.patch_tokens, out.patch_hw)
        # Cast to float32 before the head; the backbone may emit a lower-precision dtype.
        return self.head(grid.float())

    def training_step(self, batch, batch_idx):
        """Compute the (placeholder) loss for one ``(images, targets)`` batch.

        ``targets`` are not used yet; the loss below is a constant stand-in.
        """
        images, targets = batch
        # Preprocess the raw images on the fly with the checkpoint's processor.
        inputs = self.processor(images=list(images), return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        detections = self(inputs['pixel_values'])

        # Placeholder loss from the original notebook: always 1.0, with zero gradient
        # w.r.t. the head. It only exercises the forward/backward/optimizer plumbing.
        # NOTE: AdamW's decoupled weight decay (default 0.01) still shrinks the head
        # weights every step even though the gradient is zero.
        loss = detections.mean() * 0.0 + 1.0
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """Optimize only the head parameters; the backbone stays frozen."""
        return optim.AdamW(self.head.parameters(), lr=self.hparams.lr)

# Checkpoint id shared by DetectionModule's backbone and its AutoImageProcessor.
model_id = (
    "facebook/dinov3-vitl16-pretrain-lvd1689m"
)

In [None]:
# 2. Dataset & Training
data_root = "../data/coco"


def collate_fn(batch):
    """Turn a list of ``(image, target)`` samples into ``(images, targets)`` tuples.

    Samples stay unstacked so ``training_step`` can hand the raw images to the
    image processor.
    """
    return tuple(zip(*batch))


# Keep the try narrow: only dataset construction is expected to fail here
# (missing files, or presumably an optional dependency the dataset imports lazily).
try:
    dataset = COCODataset(root=data_root, split="train2017")
except (FileNotFoundError, ImportError) as e:
    print(f"COCO dataset not found: {e}")
    dataset = None
    dataloader = None
else:
    dataloader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)
    print(f"Loaded COCO dataset with {len(dataset)} samples.")

# Explicit form of a truthiness check: skip training when the dataset is missing or empty.
if dataset is not None and len(dataset) > 0:
    module = DetectionModule(model_id=model_id)

    trainer = pl.Trainer(
        max_epochs=1,
        limit_train_batches=10,  # short smoke-test run
        # "auto" uses a GPU when available instead of failing on CPU-only hosts.
        accelerator="auto",
        devices=1,
        precision="bf16-mixed",
    )

    trainer.fit(module, dataloader)

In [None]:
# 3. Evaluation / Visualization
if dataset is not None and len(dataset) > 0:
    # Run on a GPU when one is available, otherwise fall back to CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    module.eval()
    module.to(device)

    image, target = dataset[0]
    inputs = module.processor(images=image, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        detections = module(inputs['pixel_values'])

    # NOTE(review): assumes detections is (batch, channels, H, W); this shows channel 0
    # of the first image. .float() guards against dtypes .numpy() rejects (e.g. bfloat16).
    prediction = detections[0, 0].float().cpu().numpy()

    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    plt.imshow(image)
    plt.title("Input")
    plt.axis("off")

    plt.subplot(1, 2, 2)
    plt.imshow(prediction)
    plt.title("Prediction (Ch 0)")
    plt.axis("off")
    plt.show()