# 03 Detection Head

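Demonstrates the detection head interface: an image is passed through the DINOv3 backbone, the patch tokens are reshaped into a spatial grid, and an untrained `DenseDetectionHead` is applied to that grid.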

In [None]:
import torch
import matplotlib.pyplot as plt
from PIL import Image
from transformers import AutoImageProcessor
from dinov3_lab.core.backbone import build_dinov3_hf
from dinov3_lab.tasks.detection.heads import DenseDetectionHead

# 1. Setup
# The same checkpoint id is used for the backbone and for the image
# processor so that preprocessing matches the pretrained weights.
model_id = "facebook/dinov3-vitl16-pretrain-lvd1689m"
backbone = build_dinov3_hf(model_id=model_id)
# NOTE(review): in_channels=1024 presumably matches the ViT-L embedding
# width of this checkpoint -- confirm against the backbone config if the
# model_id is changed. The head is freshly constructed here (untrained).
head = DenseDetectionHead(in_channels=1024, num_classes=80)
processor = AutoImageProcessor.from_pretrained(model_id)

# 2. Load Image
image_path = "../data/test_images/demo.jpg"
try:
    image = Image.open(image_path).convert("RGB")
except FileNotFoundError:
    # Fall back to a uniform gray 448x448 placeholder so the rest of the
    # notebook still runs. The message states what is actually generated
    # (a solid gray image, not random noise).
    print(
        f"Demo image not found at {image_path}. "
        "Using a solid gray placeholder image."
    )
    image = Image.new("RGB", (448, 448), color="gray")

# Preprocess into PyTorch tensors ("pt"); the forward pass reads
# inputs.pixel_values.
inputs = processor(images=image, return_tensors="pt")

# 3. Forward Pass
# Inference only: run the backbone AND the head under no_grad so that no
# autograd graph is recorded for either of them.
with torch.no_grad():
    out = backbone(inputs.pixel_values)
    # Rearrange the patch tokens into a spatial grid using the patch
    # height/width reported by the backbone output.
    grid = backbone.tokens_to_grid(out.patch_tokens, out.patch_hw)

    # Cast grid to float32 to match head weights
    detections = head(grid.float())

print(f"Detection map shape: {detections.shape}")

# Show the input next to a single slice of the (untrained) detection output.
# Index [0, 0] picks the first entry along each of the first two axes --
# presumably batch 0 / channel 0; confirm against the head's output layout.
first_slice = detections[0, 0]
det_map = first_slice.float().detach().cpu().numpy()

plt.figure(figsize=(10, 5))

# One (data, title) pair per panel, drawn left to right in a 1x2 layout.
panels = (
    (image, "Input"),
    (det_map, "Detection Map (Ch 0, Untrained)"),
)
for slot, (pixels, caption) in enumerate(panels, start=1):
    plt.subplot(1, 2, slot)
    plt.imshow(pixels)
    plt.title(caption)
    plt.axis("off")

plt.show()