# 02 Linear Probe Segmentation

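Train a simple linear segmentation head on top of frozen DINOv3 features. The cell below builds the backbone and the head, runs a single forward pass on a demo image, and plots the prediction of the still-untrained head.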

In [None]:
import torch
import matplotlib.pyplot as plt
from PIL import Image
from transformers import AutoImageProcessor
from dinov3_lab.core.backbone import build_dinov3_hf
from dinov3_lab.tasks.segmentation.heads import LinearSegmentationHead

# 1. Setup
# Hugging Face checkpoint id used to load the image processor.
checkpoint_name = "facebook/dinov3-vitl16-pretrain-lvd1689m"
# Linear head configuration: 1024 input channels mapped to 150 classes.
# NOTE(review): 1024 presumably matches the backbone's feature width - confirm.
head_config = {"in_channels": 1024, "num_classes": 150}

backbone = build_dinov3_hf()
head = LinearSegmentationHead(**head_config)
processor = AutoImageProcessor.from_pretrained(checkpoint_name)

# 2. Load Image
image_path = "../data/test_images/demo.jpg"
try:
    image = Image.open(image_path).convert("RGB")
except FileNotFoundError:
    # Fall back to a synthetic placeholder so the rest of the cell still runs.
    # The placeholder is a uniform gray 448x448 image (not noise), and the
    # message names the missing path so the user knows what to provide.
    print(f"Demo image not found at {image_path}. Using a solid gray placeholder image.")
    image = Image.new("RGB", (448, 448), color="gray")

# Preprocess the PIL image into model inputs; `inputs.pixel_values` is what
# gets fed to the backbone.
inputs = processor(images=image, return_tensors="pt")

# 3. Forward Pass
# This cell only runs inference, so the backbone, the head, and the
# upsampling all run under no_grad: autograd then records nothing for any of
# these ops (previously only the backbone call was covered).
with torch.no_grad():
    out = backbone(inputs.pixel_values)

    # Convert tokens to grid
    grid = backbone.tokens_to_grid(out.patch_tokens, out.patch_hw)
    print(f"Grid shape: {grid.shape}")

    # 4. Head Prediction
    logits = head(grid)
    print(f"Logits shape: {logits.shape}")

    # Upsample to image size
    # PIL reports size as (W, H); reverse it to get (H, W).
    image_size = image.size[::-1]
    logits_up = backbone.upsample_grid_to_image(logits, image_size)
    print(f"Upsampled logits: {logits_up.shape}")

# Visualize argmax (untrained head, so random noise)
# Take the argmax over dim 1 of the upsampled logits and keep the first
# element of the batch as a NumPy array for plotting.
pred_mask = logits_up.argmax(dim=1)[0].cpu().numpy()

# Side-by-side panels: the input image on the left, the predicted mask on the right.
panels = [
    (image, "Input"),
    (pred_mask, "Prediction (Untrained)"),
]

plt.figure(figsize=(10, 5))
for panel_idx, (panel_img, panel_title) in enumerate(panels, start=1):
    plt.subplot(1, 2, panel_idx)
    plt.imshow(panel_img)
    plt.title(panel_title)
    plt.axis("off")
plt.show()