In [1]:
import functools
import os
import time
from pathlib import Path

import depth_pro
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image


def timeit(func):
    """Decorator that prints how long each call to ``func`` takes.

    Args:
        func: The callable to time.

    Returns:
        A wrapper that forwards all arguments to ``func``, prints the elapsed
        time in seconds, and returns ``func``'s result unchanged. Nothing is
        printed if ``func`` raises.
    """
    # functools.wraps keeps the wrapped function's __name__ and __doc__, so
    # decorated functions still look like themselves in help() and tracebacks.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # perf_counter is monotonic and high resolution; time.time() can jump
        # when the system clock is adjusted.
        start_time = time.perf_counter()
        result = func(*args, **kwargs)
        end_time = time.perf_counter()
        print(f"Function '{func.__name__}' took {end_time - start_time:.4f} seconds")
        return result
    return wrapper

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Pick the compute backend in order of preference: CUDA, then MPS, then CPU.
# The conditional expression short-circuits, so the MPS check only runs when
# CUDA is unavailable.
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

# Build the depth model on the chosen device in half precision, then switch it
# to evaluation mode before any inference is run.
model = depth_pro.create_model(device=device, precision=torch.half)
model.eval()
print(f"Model loaded on device: {device}")

  state_dict = torch.load(config.checkpoint_uri, map_location="cpu")


Model loaded on device: mps


In [3]:
@timeit
def inference(model, image_path, device):
    """Run the depth model on one image file and return its prediction.

    Args:
        model: Callable model; invoked as ``model(image, f_px_override=...)``.
        image_path: Path of the image, read with ``depth_pro.load_rgb``.
        device: Accepted for call-site symmetry; not referenced in this body.

    Returns:
        ``(depth, focal_length_px)`` exactly as produced by the model call
        (depth in metres, focal length in pixels).
    """
    # load_rgb returns three values; the middle one is not needed here.
    rgb, _, focal_from_file = depth_pro.load_rgb(image_path)

    # Gradient tracking is disabled for the forward pass.
    with torch.no_grad():
        depth, focal_length_px = model(rgb, f_px_override=focal_from_file)

    return depth, focal_length_px

# Plot an RGB image next to its depth map
def visualize_results(image_path, depth=None, depth_image_path=None):
    """Show an RGB image side by side with a depth map.

    Exactly one of ``depth`` and ``depth_image_path`` must be supplied.

    Args:
        image_path: Path of the RGB image, opened with ``Image.open``.
        depth: Depth tensor. It is detached, moved to the CPU, converted to a
            NumPy array and min-max normalized to [0, 1] before plotting. A
            constant depth map is plotted as all zeros.
        depth_image_path: Path of a depth image file. It is plotted as loaded,
            without normalization.

    Raises:
        ValueError: If neither or both of ``depth`` and ``depth_image_path``
            are given.
    """
    if depth is None and depth_image_path is None:
        raise ValueError("both the depth and the depth_image cannot be none!")

    if depth is not None and depth_image_path is not None:
        raise ValueError("both the depth and the depth_image cannot be provided!")

    image = Image.open(image_path)

    if depth is not None:
        # detach() lets tensors that still track gradients convert to NumPy.
        depth_map_np = depth.detach().cpu().numpy()

        # Min-max normalize for display. When every value is equal the range
        # is zero, so skip the division to avoid filling the map with NaNs.
        depth_min = depth_map_np.min()
        depth_range = depth_map_np.max() - depth_min
        if depth_range > 0:
            depth_map_normalized = (depth_map_np - depth_min) / depth_range
        else:
            depth_map_normalized = depth_map_np - depth_min
    else:
        depth_map_normalized = Image.open(depth_image_path)

    # Plot the images side by side
    fig, axes = plt.subplots(1, 2, figsize=(12, 6))

    # Show the original image
    axes[0].imshow(image)
    axes[0].set_title("RGB Image")
    axes[0].axis("off")

    # Show the depth map (use colormap for better visibility)
    axes[1].imshow(depth_map_normalized, cmap="viridis")
    axes[1].set_title("Depth Map")
    axes[1].axis("off")

    plt.show()

In [4]:
# Colour captures to run through the depth model.
image_names = [
    "409238553452200_color.png",
    "409243305337100_color.png",
    "409547560576900_color.png",
    "409570843067100_color.png",
    "409573045097900_color.png",
    "409592125011600_color.png",
]

root_path = Path("data/saved_images/")

# Run inference on every capture and store the depth tensor next to the image.
for image_name in image_names:
    image_path = root_path / image_name
    depth, focallength_px = inference(model=model, image_path=image_path, device=device)

    # "<id>_color.png" -> "<id>_depthpro.pt"
    depth_name = image_name.replace("_color", "_depthpro")
    tensor_path = Path(depth_name).with_suffix(".pt")

    # Hand torch.save the path itself so it opens and closes the file; passing
    # an open() handle left the file object unclosed.
    torch.save(depth, root_path / tensor_path)

Function 'inference' took 8.5100 seconds
Function 'inference' took 8.5221 seconds
Function 'inference' took 8.4280 seconds
Function 'inference' took 8.4504 seconds
Function 'inference' took 8.9466 seconds
Function 'inference' took 9.8759 seconds


In [None]:
# Show an RGB image next to a depth tensor. `image_path` and `depth` are not
# defined in this cell; they are whatever an earlier cell left in the kernel.
visualize_results(image_path=image_path, depth=depth)
# NOTE(review): this depth image is hard-coded to capture 409238553452200, while
# `image_path` may refer to a different capture — confirm the pairing is intended.
depth_image_path = "data/saved_images/409238553452200_depth_scaled.png"
visualize_results(image_path=image_path, depth_image_path=depth_image_path)

In [None]:
# Image sizes to benchmark. Other sizes previously tried:
# (1024, 1024), (1920, 1080), (2268, 3024)
image_sizes = [(512, 512)]
num_runs = 10  # Total runs per image size, including the untimed warmup runs
warmup_runs = 2  # Leading runs whose timings are discarded

if num_runs <= warmup_runs:
    raise ValueError("num_runs must be greater than warmup_runs to collect any timings")


def _sync_device():
    """Wait for queued work on the accelerator to finish (no-op on CPU)."""
    if device.type == "cuda":
        torch.cuda.synchronize()
    elif device.type == "mps":
        torch.mps.synchronize()


print(f"Using device: {device}\n")

for size in image_sizes:
    print(f"Running inference for image size: {size}")

    # Resize and save a temporary image. `image_path` is not defined in this
    # cell; it is the value left in the kernel by a previously run cell.
    image = Image.open(image_path).resize(size)
    temp_image_path = "temp_resized.jpg"
    image.save(temp_image_path)

    times = []

    for nr in range(num_runs):
        # Drain pending device work so it is not charged to this run.
        _sync_device()
        start_time = time.perf_counter()

        depth, focallength_px = inference(model=model, image_path=temp_image_path, device=device)

        # Synchronize BEFORE reading the clock: CUDA/MPS kernels run
        # asynchronously, so stopping the timer first would under-report.
        _sync_device()
        elapsed_time = (time.perf_counter() - start_time) * 1000  # convert to ms

        # Warmup runs are executed but not recorded.
        if nr >= warmup_runs:
            times.append(elapsed_time)

    mean_time = np.mean(times)
    std_time = np.std(times)

    print(f"Inference time for {size}: {mean_time:.2f} ± {std_time:.2f} ms\n")


## Image size and performance

| Image Size      | Inference Time (ms) | Std Dev (ms) |
|---------------|-------------------|-------------|
| (512, 512)   | 547.06            | ±2.49       |
| (1024, 1024) | 593.40            | ±9.39       |
| (1920, 1080) | 638.44            | ±6.34       |
| (2268, 3024) | 699.48            | ±1.97       |


In [None]:
# Directory containing images
image_dir = "data/hands"
image_paths = [os.path.join(image_dir, f) for f in os.listdir(image_dir) if f.endswith((".jpg", ".png"))]

# Define grid size: each row contains (Image, Depth Map) pairs
num_cols = 2
num_rows = len(image_paths)

fig, axes = plt.subplots(num_rows, num_cols, figsize=(10, 5 * num_rows))

# Process and display each image with its depth map
for i, image_path in enumerate(image_paths):
    depth, _ = inference(image_path)  # Perform inference
    image = Image.open(image_path)  # Load image

    depth_np = depth.cpu().numpy() if isinstance(depth, torch.Tensor) else depth
    depth_min, depth_max = depth_np.min(), depth_np.max()
    if depth_max > depth_min:  # Avoid division by zero
        depth_np = (depth_np - depth_min) / (depth_max - depth_min)
   

    # Display original image
    ax1 = axes[i, 0] if num_rows > 1 else axes[0]
    ax1.imshow(image)
    ax1.axis("off")
    ax1.set_title(f"Image {i+1}")

    # Display normalized depth map
    ax2 = axes[i, 1] if num_rows > 1 else axes[1]
    ax2.imshow(depth_np, cmap="viridis")  # "viridis" for better visibility
    ax2.axis("off")
    ax2.set_title(f"Depth Map {i+1}")

plt.tight_layout()
plt.show()