# 04 Depth Estimation

Predict dense depth maps from frozen DINOv3 features on the ScanNet dataset, training a lightweight depth head with PyTorch Lightning.

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from torch.utils.data import DataLoader
from transformers import AutoImageProcessor

from dinov3_lab.core.backbone import build_dinov3_hf
from dinov3_lab.data.datasets.scannet import ScanNetDataset
from dinov3_lab.tasks.depth.heads import DepthHead

# 1. Setup Lightning Module
class DepthModule(pl.LightningModule):
    """Dense depth regression: a frozen DINOv3 backbone feeding a trainable ``DepthHead``.

    Only the head is optimised (see ``configure_optimizers``). The backbone's
    parameters have ``requires_grad=False`` and the backbone is kept in eval
    mode even when the module as a whole is switched to train mode.
    """

    def __init__(self, model_id, depth_scale=1000.0):
        """Build the backbone, head, loss and image processor.

        Args:
            model_id: Hugging Face model id passed to both ``build_dinov3_hf``
                and ``AutoImageProcessor.from_pretrained``.
            depth_scale: Raw ground-truth depth values are divided by this
                before the loss is computed. The default (1000.0) reproduces the
                original scaling; presumably raw depth is in millimetres, so
                targets end up in metres — confirm for the dataset in use.
        """
        super().__init__()
        self.save_hyperparameters()
        self.depth_scale = float(depth_scale)
        self.backbone = build_dinov3_hf(model_id=model_id)
        # NOTE(review): in_channels must equal the backbone's token width;
        # 1024 presumably matches the ViT-L checkpoint used in this notebook —
        # confirm, or derive it from the backbone config for other model_ids.
        self.head = DepthHead(in_channels=1024)
        self.criterion = nn.MSELoss()
        self.processor = AutoImageProcessor.from_pretrained(model_id)

        # Freeze the backbone: it is used purely as a feature extractor.
        for param in self.backbone.parameters():
            param.requires_grad = False
        self.backbone.eval()

    def train(self, mode=True):
        """Set train/eval mode on the module but keep the frozen backbone in eval mode.

        ``nn.Module.train`` recurses into every submodule, so a call such as
        ``module.train()`` (from user code or the training loop) would
        otherwise put the frozen backbone into train mode as well, enabling any
        dropout / stochastic layers it contains.
        """
        super().train(mode)
        self.backbone.eval()
        return self

    def forward(self, pixel_values):
        """Predict a low-resolution depth map from preprocessed pixel values.

        The backbone runs without gradients. Its patch tokens are reshaped to a
        spatial grid with ``tokens_to_grid``, and the grid is cast to float32
        before the head. The head output is bilinearly interpolated by callers,
        so it is expected to be 4-D ``(B, C, h, w)``.
        """
        with torch.no_grad():
            out = self.backbone(pixel_values)
            grid = self.backbone.tokens_to_grid(out.patch_tokens, out.patch_hw)
        return self.head(grid.float())

    def _prepare_target(self, depths):
        """Convert a batch of raw depths to a float32 target tensor on ``self.device``.

        Tensor and non-tensor inputs (e.g. a list of PIL images / arrays) are
        both divided by ``depth_scale``, so the target units do not depend on
        how the batch was collated. A channel axis is inserted at dim 1; this
        assumes the depths stack to ``(B, H, W)`` — TODO confirm.
        """
        if isinstance(depths, torch.Tensor):
            # Cast explicitly: an integer depth tensor would make MSELoss fail
            # against the floating-point prediction.
            target = depths.to(device=self.device, dtype=torch.float32)
        else:
            target = torch.as_tensor(
                np.array(depths, dtype=np.float32), device=self.device
            )
        return (target / self.depth_scale).unsqueeze(1)

    def training_step(self, batch, batch_idx):
        """Compute and log the MSE between upsampled predictions and scaled GT depth.

        ``batch`` is unpacked as ``(images, depths)``; ``images`` must be
        accepted by the Hugging Face image processor.
        """
        images, depths = batch
        # Preprocessing (resizing / normalisation) is done by the HF processor.
        inputs = self.processor(images=images, return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        depth_pred = self(inputs["pixel_values"])
        target = self._prepare_target(depths)

        # Upsample the coarse prediction to the ground-truth resolution.
        depth_pred_up = nn.functional.interpolate(
            depth_pred, size=target.shape[-2:], mode="bilinear", align_corners=False
        )
        loss = self.criterion(depth_pred_up, target)
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """Optimise only the depth head with AdamW (lr=1e-3)."""
        return optim.AdamW(self.head.parameters(), lr=1e-3)

# Hugging Face checkpoint id; DepthModule uses it to load both the DINOv3
# backbone and the matching image processor.
model_id = (
    "facebook/dinov3-vitl16-pretrain-lvd1689m"
)

In [None]:
# 2. Dataset & Training
data_root = "../data/scannet"


def _collate_as_lists(batch):
    """Group a batch of ``(image, depth)`` samples into ``(list_of_images, list_of_depths)``.

    Samples are assumed to be PIL images with their depth maps. PyTorch's
    default collate function cannot batch PIL images (it raises ``TypeError``),
    so the raw samples are kept in lists. The image list goes straight to the
    Hugging Face processor, and ``DepthModule`` converts the depth list to a
    tensor itself.
    """
    images, depths = zip(*batch)
    return list(images), list(depths)


try:
    dataset = ScanNetDataset(root=data_root, split="train")
except FileNotFoundError:
    print("ScanNet dataset not found.")
    dataset = None
    dataloader = None
else:
    dataloader = DataLoader(
        dataset, batch_size=2, shuffle=True, collate_fn=_collate_as_lists
    )
    print(f"Loaded ScanNet dataset with {len(dataset)} samples.")

if dataset:
    module = DepthModule(model_id=model_id)

    # Short smoke-test run: a single epoch capped at 10 batches.
    trainer = pl.Trainer(
        max_epochs=1,
        limit_train_batches=10,
        accelerator="gpu",
        devices=1,
        precision="bf16-mixed",
    )

    trainer.fit(module, dataloader)

In [None]:
# 3. Evaluation / Visualization
if dataset:
    import numpy as np

    module.eval()
    module.cuda()
    image, depth_gt = dataset[0]
    inputs = module.processor(images=image, return_tensors="pt")
    inputs = {k: v.to(module.device) for k, v in inputs.items()}

    with torch.no_grad():
        depth_pred = module(inputs["pixel_values"])
        # PIL's ``size`` is (W, H); interpolate expects (H, W).
        depth_pred_up = nn.functional.interpolate(
            depth_pred, size=image.size[::-1], mode="bilinear", align_corners=False
        )

    # Express GT in the same units as the training target (raw / depth_scale;
    # 1000.0 is the scaling DepthModule applies by default) and use one shared
    # colour range so the two depth maps are directly comparable.
    depth_scale = getattr(module, "depth_scale", 1000.0)
    depth_gt_scaled = np.asarray(depth_gt, dtype=np.float32) / depth_scale
    depth_pred_np = depth_pred_up[0, 0].float().cpu().numpy()
    vmin = float(depth_gt_scaled.min())
    vmax = float(depth_gt_scaled.max())

    plt.figure(figsize=(15, 5))
    plt.subplot(1, 3, 1)
    plt.imshow(image)
    plt.title("Input")
    plt.axis("off")

    plt.subplot(1, 3, 2)
    plt.imshow(depth_gt_scaled, cmap="plasma", vmin=vmin, vmax=vmax)
    plt.title("Ground Truth Depth")
    plt.axis("off")

    plt.subplot(1, 3, 3)
    plt.imshow(depth_pred_np, cmap="plasma", vmin=vmin, vmax=vmax)
    plt.title("Predicted Depth")
    plt.axis("off")
    plt.show()