# 02 Linear Probe Segmentation

Train a simple linear head on top of frozen DINOv3 features using the ADE20K dataset.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from PIL import Image
from transformers import AutoImageProcessor
from torch.utils.data import DataLoader

from dinov3_lab.core.backbone import build_dinov3_hf
from dinov3_lab.tasks.segmentation.heads import LinearSegmentationHead
from dinov3_lab.data.datasets.ade20k import ADE20KDataset

# 1. Setup
# One checkpoint id drives both the image processor and the backbone so that
# preprocessing and features come from the same pretrained model.
model_id = "facebook/dinov3-vitl16-pretrain-lvd1689m"
processor = AutoImageProcessor.from_pretrained(model_id)
backbone = build_dinov3_hf(model_id=model_id)

# Linear probe: the only trainable module in this notebook.
# NOTE(review): in_channels=1024 presumably matches the ViT-L feature width
# of the checkpoint above; change it together with model_id.
head = LinearSegmentationHead(in_channels=1024, num_classes=150)

# Freeze backbone: no gradients are tracked for any of its weights.
for weight in backbone.parameters():
    weight.requires_grad_(False)

In [None]:
# 2. Dataset & Dataloader
# NOTE: Update 'root' to point to your ADE20K dataset location
import numpy as np

data_root = "../data/ade20k"


def collate_segmentation_batch(batch):
    """Collate ``(image, mask)`` samples into ``(list_of_images, mask_tensor)``.

    The DataLoader's default collate function only accepts tensors, arrays,
    numbers, dicts and lists, so it raises ``TypeError`` on PIL images, and it
    cannot stack samples of different sizes. This function instead:

    * keeps the images as a plain list (the image processor used in the
      training loop is handed that list directly), and
    * converts every mask to a ``torch.long`` tensor, resizes masks to the
      largest height/width in the batch with nearest-neighbour interpolation
      (so no new label values are invented), and stacks them.

    Assumes each mask is a single-channel ``(H, W)`` label map, either a
    tensor or something ``np.array`` can convert (e.g. a PIL image) -- TODO
    confirm against ``ADE20KDataset``.
    NOTE(review): stretching masks to a common size is only spatially
    consistent if the image processor rescales the whole image (no cropping)
    -- confirm for the processor in use.

    Args:
        batch: Sequence of ``(image, mask)`` pairs as returned by the dataset.

    Returns:
        A tuple ``(images, masks)`` where ``images`` is a list with one entry
        per sample and ``masks`` is a long tensor of shape ``(B, H, W)``.
    """
    images, masks = zip(*batch)
    targets = [
        mask if isinstance(mask, torch.Tensor) else torch.as_tensor(np.array(mask))
        for mask in masks
    ]
    targets = [target.long() for target in targets]

    # Use the largest mask in the batch as the common size: upscaling with
    # nearest-neighbour never drops a label that is present in a mask.
    height = max(target.shape[-2] for target in targets)
    width = max(target.shape[-1] for target in targets)

    resized = []
    for target in targets:
        if target.shape[-2:] != (height, width):
            # interpolate needs a float (N, C, H, W) input; label ids are
            # small integers, so the float round-trip is exact.
            target = nn.functional.interpolate(
                target[None, None].float(), size=(height, width), mode="nearest"
            )[0, 0].long()
        resized.append(target)

    return list(images), torch.stack(resized)


try:
    dataset = ADE20KDataset(root=data_root, split="train")
    dataloader = DataLoader(
        dataset,
        batch_size=2,
        shuffle=True,
        collate_fn=collate_segmentation_batch,
    )
    print(f"Loaded ADE20K dataset with {len(dataset)} samples.")
except FileNotFoundError:
    print("ADE20K dataset not found. Please download it or check the path.")
    # Leave both names defined as None so the following cells can skip
    # training and visualization instead of failing with a NameError.
    dataset = None
    dataloader = None

if dataloader is not None:
    # 3. Training Loop
    # Trains only the linear head; the backbone is run under torch.no_grad().
    from itertools import islice

    import numpy as np

    optimizer = optim.AdamW(head.parameters(), lr=1e-3)
    # Assuming 0 is background/ignore in ADE20K raw.
    # NOTE(review): if masks carry raw ADE20K ids in 0..150, id 150 exceeds
    # the 150-logit head and CrossEntropyLoss will reject it -- confirm the
    # label range produced by ADE20KDataset.
    criterion = nn.CrossEntropyLoss(ignore_index=0)

    print("Starting training...")
    head.train()
    # Train for a few batches for demo
    max_steps = 10
    step = 0

    for epoch in range(1):
        total_loss = 0.0
        epoch_steps = 0
        # islice stops after the remaining step budget without pulling an
        # extra, unused batch from the dataloader.
        for images, masks in islice(dataloader, max_steps - step):
            # Prepare inputs
            inputs = processor(images=images, return_tensors="pt")

            # Forward pass backbone (no grad)
            with torch.no_grad():
                out = backbone(inputs.pixel_values)
                grid = backbone.tokens_to_grid(out.patch_tokens, out.patch_hw)

            # Forward pass head
            logits = head(grid.float())

            # Targets must be a long tensor of class indices for
            # CrossEntropyLoss, whatever the collate function produced
            # (PIL/array-like masks or tensors of another integer dtype).
            if not isinstance(masks, torch.Tensor):
                masks = torch.as_tensor(np.array(masks))
            masks = masks.long()

            # Upsample logits to mask size for loss computation
            logits_up = nn.functional.interpolate(
                logits, size=masks.shape[-2:], mode="bilinear", align_corners=False
            )

            loss = criterion(logits_up, masks)

            # Backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_value = loss.item()
            total_loss += loss_value
            epoch_steps += 1
            step += 1
            print(f"Step {step}, Loss: {loss_value:.4f}")

        if epoch_steps:
            print(f"Epoch {epoch + 1} mean loss over {epoch_steps} steps: {total_loss / epoch_steps:.4f}")

        if step >= max_steps:
            break

In [None]:
# 4. Evaluation / Visualization
# Runs the head on the first dataset sample and shows the input, its
# ground-truth mask and the predicted label map side by side.
if dataset:
    head.eval()
    image, mask = dataset[0]
    inputs = processor(images=image, return_tensors="pt")

    with torch.no_grad():
        out = backbone(inputs.pixel_values)
        grid = backbone.tokens_to_grid(out.patch_tokens, out.patch_hw)
        logits = head(grid.float())
        # The reversed mask.size is used as the output size for interpolate.
        # Assumes mask exposes a PIL-style (width, height) `size` -- TODO confirm.
        logits_up = nn.functional.interpolate(
            logits,
            size=mask.size[::-1],
            mode="bilinear",
            align_corners=False,
        )
        # Highest-scoring class per pixel, for the single image in the batch.
        pred_mask = torch.argmax(logits_up, dim=1)[0]

    panels = (
        ("Input", image),
        ("Ground Truth", mask),
        ("Prediction", pred_mask.cpu().numpy()),
    )

    plt.figure(figsize=(15, 5))
    for column, (title, picture) in enumerate(panels, start=1):
        plt.subplot(1, 3, column)
        plt.imshow(picture)
        plt.title(title)
        plt.axis("off")
    plt.show()