In [None]:
# Copyright (c) 2026, ETH Zurich, Manthan Patel
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import torch
import os
from pathlib import Path
from defm.utils import preprocess_depth_image

# Add the project root (defm) to sys.path
root_dir = Path(os.getcwd()).parent.parent.resolve()

%load_ext autoreload
%autoreload 2


# Available DeFM ViT models
model_list = [
    "defm_vit_s14",
    "defm_vit_l14",
]

# Model selected for this notebook
MODEL_NAME = "defm_vit_l14"
# ViT patch size
PATCH_SIZE = 14
# Prefer the GPU when PyTorch can see one, otherwise fall back to the CPU
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Load the selected DeFM model through torch.hub from the local checkout.
# torch.hub.load documents `repo_or_dir` as a string; for local sources the
# value is temporarily inserted into sys.path, and Python's import system
# ignores sys.path entries that are not strings. Converting the Path
# explicitly keeps imports made from the repo's hubconf.py resolvable.
model = torch.hub.load(
    repo_or_dir=str(root_dir),  # Path to your DeFM root folder
    model=MODEL_NAME,
    source="local",
    pretrained=True,
)
# Switch to evaluation mode and move the weights to the selected device
model.eval().to(DEVICE)

# Report the parameter count in millions
n_params_m = sum(p.numel() for p in model.parameters()) / 1e6
print(f"Loaded model: {MODEL_NAME} with {n_params_m:.2f}M parameters.")

In [None]:
# Forward Inference Example with Dummy Data
# Metric depth is non-negative, so draw uniformly from [0, 100) meters.
# (torch.randn would yield negative values and values well beyond 100.)
dummy_depth = torch.rand(224, 224, 1) * 100  # 224 x 224 x 1 dummy depth, max depth 100 meters

# Target Size must be a multiple of the patch size
# If target size is None, the input size is resized to the nearest multiple of the patch size
# The passed depth image should be in meters
# This is very important for correct metric-depth based normalization

normalized_depth = preprocess_depth_image(
    dummy_depth,
    patch_size=PATCH_SIZE,
).to(DEVICE)

# No gradients are needed for a forward-only example
with torch.no_grad():
    output = model.get_intermediate_layers(
        normalized_depth, n=1, reshape=True, return_class_token=True
    )

# n=1 was requested, so the first entry is read; it holds the spatial tokens
# followed by the class token.
spatial_tokens = output[0][0]
class_token = output[0][1]

print(f"Output Spatial Tokens: {spatial_tokens.shape}")  # (B, C, H', W')
print(f"Output Class Token: {class_token.shape}")  # (B, C)