# 02 Linear Probe Segmentation

Train a simple linear head on top of frozen DINOv3 features using the ADE20K dataset and PyTorch Lightning.

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from torch.utils.data import DataLoader
from transformers import AutoImageProcessor

from dinov3_lab.core.backbone import build_dinov3_hf
from dinov3_lab.data.datasets.ade20k import ADE20KDataset
from dinov3_lab.tasks.segmentation.heads import LinearSegmentationHead

# 1. Setup Lightning Module
class SegmentationModule(pl.LightningModule):
    """Linear segmentation probe trained on top of a frozen DINOv3 backbone.

    Only ``self.head`` is optimized. The backbone has ``requires_grad`` disabled,
    runs under ``torch.no_grad`` and is kept in eval mode even while fitting.

    Args:
        model_id: Hugging Face model id used for both the backbone and its image
            processor.
        num_classes: Number of output channels of the linear head.
        in_channels: Channel width of the backbone feature grid fed to the head.
            The default 1024 matches ViT-L checkpoints such as the one below;
            other backbone sizes need a different value.
        lr: AdamW learning rate for the head.
    """

    def __init__(self, model_id, num_classes=150, in_channels=1024, lr=1e-3):
        super().__init__()
        self.save_hyperparameters()
        self.backbone = build_dinov3_hf(model_id=model_id)
        self.head = LinearSegmentationHead(in_channels=in_channels, num_classes=num_classes)
        # NOTE(review): label 0 is ignored, and every other mask value must be a valid
        # head channel (< num_classes). Raw ADE20K masks use 0 = "other" and 1..150 for
        # classes, which would put label 150 out of range for a 150-way head — confirm
        # how ADE20KDataset encodes labels.
        self.criterion = nn.CrossEntropyLoss(ignore_index=0)
        self.processor = AutoImageProcessor.from_pretrained(model_id)

        # Freeze backbone: only the head receives gradients.
        for param in self.backbone.parameters():
            param.requires_grad = False

    def train(self, mode=True):
        """Set train/eval mode on the module while keeping the frozen backbone in eval mode.

        Lightning switches the whole module to train mode before fitting; a frozen
        feature extractor should not run with training-time behaviour (e.g. dropout).
        """
        super().train(mode)
        self.backbone.eval()
        return self

    def forward(self, pixel_values):
        """Return head logits on the backbone's patch grid for preprocessed images.

        Presumably shaped (B, num_classes, h, w), since callers bilinearly upsample
        them to mask resolution — TODO confirm against LinearSegmentationHead.
        The backbone runs under ``no_grad`` because it is frozen.
        """
        with torch.no_grad():
            out = self.backbone(pixel_values)
            grid = self.backbone.tokens_to_grid(out.patch_tokens, out.patch_hw)
        # Cast to fp32 in case the backbone produced reduced-precision features.
        return self.head(grid.float())

    def _masks_to_tensors(self, masks):
        """Return the batch's masks as a list of int64 tensors on ``self.device``.

        Accepts a stacked tensor or any sequence of PIL images / arrays / tensors.
        Masks are kept separate because they may differ in resolution.
        """
        return [
            torch.as_tensor(
                mask if isinstance(mask, torch.Tensor) else np.array(mask),
                dtype=torch.long,  # cross-entropy targets must be int64
                device=self.device,
            )
            for mask in masks
        ]

    def training_step(self, batch, batch_idx):
        """Compute the pixel-wise cross-entropy of the head on one ``(images, masks)`` batch.

        Images are preprocessed on the fly (ideal: move to dataset/collate). Each
        sample's logits are upsampled to its own mask size, so masks of different
        resolutions are supported. The loss is averaged over all non-ignored pixels
        in the batch, matching one ``CrossEntropyLoss(reduction="mean")`` call on a
        stacked batch.
        """
        images, masks = batch
        inputs = self.processor(images=list(images), return_tensors="pt")
        logits = self(inputs["pixel_values"].to(self.device))
        targets = self._masks_to_tensors(masks)

        ignore_index = self.criterion.ignore_index
        loss_sums, valid_counts = [], []
        for sample_logits, target in zip(logits, targets):
            sample_up = nn.functional.interpolate(
                sample_logits.unsqueeze(0),
                size=target.shape[-2:],
                mode="bilinear",
                align_corners=False,
            )
            loss_sums.append(
                nn.functional.cross_entropy(
                    sample_up, target.unsqueeze(0), ignore_index=ignore_index, reduction="sum"
                )
            )
            valid_counts.append((target != ignore_index).sum())
        # clamp_min(1) avoids 0/0 = NaN when every pixel in the batch is ignored.
        loss = torch.stack(loss_sums).sum() / torch.stack(valid_counts).sum().clamp_min(1)
        self.log("train_loss", loss, prog_bar=True, batch_size=len(targets))
        return loss

    def configure_optimizers(self):
        """Optimize only the linear head; the backbone stays frozen."""
        return optim.AdamW(self.head.parameters(), lr=self.hparams.lr)

model_id = "facebook/dinov3-vitl16-pretrain-lvd1689m"

In [None]:
# 2. Dataset & Training
def collate_as_lists(batch):
    """Collate ``(image, mask)`` samples into ``(list_of_images, list_of_masks)``.

    The dataset appears to yield PIL images, which torch's ``default_collate``
    cannot batch (it raises TypeError). Samples may also differ in resolution,
    so they are passed through unstacked and preprocessed in ``training_step``.
    A named function (not a lambda) keeps this picklable for worker processes.
    """
    images, masks = zip(*batch)
    return list(images), list(masks)


data_root = "../data/ade20k"
try:
    dataset = ADE20KDataset(root=data_root, split="train")
except FileNotFoundError:
    print("ADE20K dataset not found.")
    dataset = None
    dataloader = None
else:
    dataloader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=collate_as_lists)
    print(f"Loaded ADE20K dataset with {len(dataset)} samples.")

if dataset:
    module = SegmentationModule(model_id=model_id)

    use_cuda = torch.cuda.is_available()
    trainer = pl.Trainer(
        max_epochs=1,
        limit_train_batches=10,
        # Fall back to CPU instead of crashing when no GPU is present.
        accelerator="gpu" if use_cuda else "cpu",
        devices=1,
        # bf16 autocast needs hardware support (e.g. A100); otherwise use full fp32.
        precision="bf16-mixed" if use_cuda and torch.cuda.is_bf16_supported() else "32-true",
    )

    trainer.fit(module, dataloader)

In [None]:
# 3. Evaluation / Visualization
if dataset:
    # Use the GPU when available instead of assuming one exists.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    module.eval()
    module.to(device)
    image, mask = dataset[0]
    # Work on raw label indices: if the mask is a palette-mode PIL image, matplotlib
    # would otherwise convert it to RGBA and ignore the shared colour scale below.
    mask_array = np.array(mask)
    inputs = module.processor(images=image, return_tensors="pt")
    pixel_values = inputs["pixel_values"].to(module.device)

    with torch.no_grad():
        logits = module(pixel_values)
        # Upsample to the mask's (H, W) so prediction and ground truth align pixel-wise.
        logits_up = nn.functional.interpolate(
            logits, size=mask_array.shape[:2], mode="bilinear", align_corners=False
        )
        pred_mask = logits_up.argmax(dim=1)[0].cpu().numpy()

    # Shared colour range so a given class index gets the same colour in both mask
    # panels; by default imshow rescales each panel to its own min/max.
    label_style = {"vmin": 0, "vmax": module.hparams.num_classes, "interpolation": "nearest"}
    panels = (
        ("Input", image, {}),
        ("Ground Truth", mask_array, label_style),
        ("Prediction", pred_mask, label_style),
    )

    plt.figure(figsize=(15, 5))
    for position, (title, data, style) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.imshow(data, **style)
        plt.title(title)
        plt.axis("off")
    plt.show()