# 04 Depth Estimation

Predict dense depth maps by training a lightweight depth head on top of frozen DINOv3 features, using the ScanNet dataset.

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from torch.utils.data import DataLoader
from transformers import AutoImageProcessor

from dinov3_lab.core.backbone import build_dinov3_hf
from dinov3_lab.tasks.depth.heads import DepthHead
from dinov3_lab.data.datasets.scannet import ScanNetDataset

# 1. Setup
model_id = "facebook/dinov3-vitl16-pretrain-lvd1689m"
backbone = build_dinov3_hf(model_id=model_id)
# in_channels must equal the backbone's patch-token width (1024 for ViT-L);
# update it together with model_id if you switch backbone size.
head = DepthHead(in_channels=1024)
processor = AutoImageProcessor.from_pretrained(model_id)

# Freeze backbone: disable gradients AND switch to eval mode, so that any
# dropout / drop-path layers stay deterministic while only the head trains.
for param in backbone.parameters():
    param.requires_grad = False
backbone.eval()

In [None]:
# 2. Dataset & Dataloader
# NOTE: Set `data_root` to your ScanNet dataset location.
data_root = "../data/scannet"


def collate_keep_pil(batch):
    """Batch (image, depth) samples without forcing PIL images into tensors.

    torch's default collate raises a TypeError on PIL images, so images are
    returned as a plain list (the HF image processor accepts a list of
    images). Depths are stacked into one tensor when every sample already
    provides a tensor; otherwise they are returned as a list and converted
    in the training loop.
    """
    images, depths = zip(*batch)
    if all(isinstance(d, torch.Tensor) for d in depths):
        return list(images), torch.stack(depths)
    return list(images), list(depths)


try:
    dataset = ScanNetDataset(root=data_root, split="train")
except FileNotFoundError:
    print("ScanNet dataset not found. Please check the path.")
    dataset = None
    dataloader = None
else:
    dataloader = DataLoader(
        dataset, batch_size=2, shuffle=True, collate_fn=collate_keep_pil
    )
    print(f"Loaded ScanNet dataset with {len(dataset)} samples.")

if dataset:
    # 3. Training Loop
    # Only the head is optimized; the backbone stays frozen.
    optimizer = optim.AdamW(head.parameters(), lr=1e-3)
    criterion = nn.MSELoss()  # depth regression, computed on valid pixels only
    # Inputs/targets go wherever the head lives.
    # NOTE(review): assumes backbone and head are on the same device.
    device = next(head.parameters()).device

    print("Starting training...")
    backbone.eval()  # frozen: keep any dropout / drop-path layers disabled
    head.train()
    max_steps = 10
    step = 0

    for _epoch in range(1):
        for images, depths in dataloader:
            if step >= max_steps:
                break

            # Prepare inputs.
            # NOTE(review): if the processor crops (not just resizes), the
            # prediction is not pixel-aligned with the full-frame depth target.
            inputs = processor(images=images, return_tensors="pt")
            pixel_values = inputs.pixel_values.to(device)

            # Forward pass backbone (no grad)
            with torch.no_grad():
                out = backbone(pixel_values)
                grid = backbone.tokens_to_grid(out.patch_tokens, out.patch_hw)

            # Forward pass head
            depth_pred = head(grid.float())

            # Targets: non-tensor depths (e.g. PIL images) are assumed to be
            # 16-bit PNGs in millimetres and are converted to metres; tensor
            # depths are used as provided.
            if not isinstance(depths, torch.Tensor):
                depths = torch.as_tensor(np.array(depths, dtype=np.float32) / 1000.0)
            depths = depths.float()
            if depths.dim() == 3:  # (B, H, W) -> (B, 1, H, W)
                depths = depths.unsqueeze(1)
            depths = depths.to(device)

            # Upsample prediction to the target resolution.
            depth_pred_up = nn.functional.interpolate(
                depth_pred, size=depths.shape[-2:], mode="bilinear", align_corners=False
            )

            # Exclude pixels with depth 0 (missing sensor readings in ScanNet-style
            # depth maps — NOTE(review): confirm for this loader) so the head is not
            # trained to predict 0. Boolean indexing assumes the head emits a single
            # channel, i.e. prediction and target shapes match.
            valid = depths > 0
            if not valid.any():
                continue  # an all-invalid batch would yield a NaN loss
            loss = criterion(depth_pred_up[valid], depths[valid])

            # Backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            step += 1
            print(f"Step {step}, Loss: {loss.item():.4f}")

        if step >= max_steps:
            break

In [None]:
# 4. Evaluation / Visualization
if dataset:
    # Inference mode for both networks.
    backbone.eval()
    head.eval()
    device = next(head.parameters()).device
    image, depth_gt = dataset[0]
    inputs = processor(images=image, return_tensors="pt")

    with torch.no_grad():
        out = backbone(inputs.pixel_values.to(device))
        grid = backbone.tokens_to_grid(out.patch_tokens, out.patch_hw)
        depth_pred = head(grid.float())
        # Resize for viz. Assumes `image` is a PIL image, whose `size` is
        # (W, H); interpolate expects (H, W).
        depth_pred_up = nn.functional.interpolate(
            depth_pred, size=image.size[::-1], mode="bilinear", align_corners=False
        )
    pred_m = depth_pred_up[0, 0].cpu().numpy()

    # Put GT in the same units as the training targets (non-tensor depths are
    # divided by 1000, tensors used as-is) and hide zero-depth pixels.
    if isinstance(depth_gt, torch.Tensor):
        gt_m = depth_gt.detach().cpu().float().numpy()
    else:
        gt_m = np.asarray(depth_gt, dtype=np.float32) / 1000.0
    gt_m = np.squeeze(gt_m)
    valid = gt_m > 0
    gt_m = np.where(valid, gt_m, np.nan)

    # Shared colour range so GT and prediction are directly comparable;
    # fall back to per-image auto-scaling when no GT pixel is valid.
    if valid.any():
        vmin, vmax = float(np.nanmin(gt_m)), float(np.nanmax(gt_m))
    else:
        vmin = vmax = None

    plt.figure(figsize=(15, 5))
    plt.subplot(1, 3, 1)
    plt.imshow(image)
    plt.title("Input")
    plt.axis("off")

    plt.subplot(1, 3, 2)
    plt.imshow(gt_m, cmap="plasma", vmin=vmin, vmax=vmax)
    plt.colorbar(fraction=0.046, pad=0.04)
    plt.title("Ground Truth Depth")
    plt.axis("off")

    plt.subplot(1, 3, 3)
    plt.imshow(pred_m, cmap="plasma", vmin=vmin, vmax=vmax)
    plt.colorbar(fraction=0.046, pad=0.04)
    plt.title("Predicted Depth")
    plt.axis("off")
    plt.show()