In [None]:
# Testing depth estimation speed on a subset of ImageNet made of 1 class (n15075141)
# 1300 elements
# NOTE(review): the path below is named like the ILSVRC2012 validation folder rather
# than a single-class subset — confirm it really holds only the n15075141 images.
# Folder listed by the image-loading cell to pick a random sample; being a relative
# path, it is resolved against the notebook's working directory.
input_folder = "image-net/ILSVRC2012_img_val"

## Load the Marigold depth pipeline (optional optimizations are left commented out)

In [None]:
import diffusers
import torch
from diffusers.models.attention_processor import AttnProcessor2_0
from tqdm.auto import tqdm

# Pick the best available accelerator: CUDA first, then Apple Silicon's MPS backend,
# and only fall back to the CPU when neither is present.
# getattr guards against torch builds that do not ship the `torch.backends.mps` module.
_mps_backend = getattr(torch.backends, "mps", None)
if torch.cuda.is_available():
    device = "cuda"
elif _mps_backend is not None and _mps_backend.is_available():
    device = "mps"
else:
    device = "cpu"

# Load the fp16 weight files. torch_dtype is deliberately left unset (see the
# disabled argument below) so the pipeline runs smoothly on Apple Silicon M1+.
pipe = diffusers.MarigoldDepthPipeline.from_pretrained(
    "prs-eth/marigold-depth-v1-1",
    variant="fp16",
    # torch_dtype=torch.float16,  # commented this to make it run smoothly on Apple Silicon M1+
).to(device)

# pipe.vae = diffusers.AutoencoderTiny.from_pretrained(
#     "madebyollin/taesd", torch_dtype=torch.float16
# ).cuda()

# pipe.vae.set_attn_processor(AttnProcessor2_0()) 
# pipe.unet.set_attn_processor(AttnProcessor2_0())

# pipe.vae = torch.compile(pipe.vae, mode="reduce-overhead", fullgraph=True)
# pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)

In [None]:
import os
import random
from PIL import Image

# Get list of JPEG files from the folder. The extension check is case-insensitive and
# the list is sorted so the candidates do not depend on the OS directory order.
image_files = sorted(
    f for f in os.listdir(input_folder) if f.lower().endswith(('.jpg', '.jpeg'))
)

# Fail with a clear message instead of the bare IndexError random.choice would raise.
if not image_files:
    raise FileNotFoundError(f"No JPEG images found in: {input_folder}")

# Select random image
image_path = os.path.join(input_folder, random.choice(image_files))

# Load the image. The context manager closes the file as soon as the pixels are read,
# and convert("RGB") returns a standalone 3-channel copy, which also normalises
# grayscale or CMYK JPEGs.
with Image.open(image_path) as source_image:
    image = source_image.convert("RGB")
print(f"Loaded image from: {image_path}")
print(f"Image size: {image.size}")


In [None]:
import matplotlib.pyplot as plt

# Preview the sampled image on its own axes (frame and ticks hidden).
fig, ax = plt.subplots()
ax.imshow(image)
ax.set_axis_off()
ax.set_title("Random Image")
plt.show()

# Keep a PNG copy of the sample next to the notebook.
image.save("image.png")

In [None]:
# Run the pipeline on the sampled image with its default arguments; the `.prediction`
# field of the result feeds the visualization and 16-bit export cells below.
depth_est = pipe(image)

In [None]:
# Turn the prediction into a displayable image; only the first result is shown.
depth_image = pipe.image_processor.visualize_depth(depth_est.prediction)[0]

fig, ax = plt.subplots()
ax.imshow(depth_image)
ax.set_axis_off()
ax.set_title("Depth Estimation")
plt.show()

In [None]:
import numpy as np

# Export the first prediction as a 16-bit PNG: the single-channel values are mapped
# linearly from the [0, 1] range into [0, 65535].
first_prediction = depth_est.prediction[0]
visualized_depth = pipe.image_processor.export_depth_to_16bit_png(first_prediction)

# The export result is indexable; its first entry is the image written to disk.
visualized_depth[0].save("depth_estimation.png")

## Color maps

In [None]:
from vcs2425 import colormap

# Re-open the exported 16-bit depth PNG and render it with several colormaps.
i = Image.open("depth_estimation.png")
colormaps = ['viridis', 'Spectral', 'plasma', 'gray']

for cmap in colormaps:
    # A fresh array is built on every pass, exactly as before, so the project-local
    # `colormap` helper always receives its own copy of the pixels.
    ci = colormap(np.array(i), cmap)

    fig, ax = plt.subplots()
    ax.imshow(ci)
    ax.set_title(f"Colormap: {cmap}")
    ax.set_axis_off()
    plt.show()

    # Save the colormapped image as is
    Image.fromarray(ci).save(f"depth_est_{cmap}.png")

In [None]:
# Sweep the number of inference steps and keep one depth visualization per setting.
depth_preds = []
num_steps = range(1, 21)
sweep_seed = 0  # arbitrary; any fixed value makes the sweep repeatable

for steps in num_steps:
    # Re-seed before every call so all step counts share the same random draws;
    # otherwise the comparison mixes the effect of `steps` with run-to-run randomness.
    # A CPU generator is used so the same code works whatever `pipe`'s device is.
    generator = torch.Generator().manual_seed(sweep_seed)
    result = pipe(image, num_inference_steps=steps, generator=generator)
    vis = pipe.image_processor.visualize_depth(result.prediction)[0]
    depth_preds.append(vis)

# Show all depth visualizations in a grid sized from the sweep length
# (5 columns; 20 settings give the original 4x5 layout at 20x16 inches).
n_cols = 5
n_rows = max(1, -(-len(depth_preds) // n_cols))  # ceiling division
fig, axes = plt.subplots(
    n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows), squeeze=False
)
for ax in axes.flat:
    ax.axis('off')  # also blanks any unused trailing cells
for ax, steps, pred in zip(axes.flat, num_steps, depth_preds):
    ax.imshow(pred)
    ax.set_title(f"Steps: {steps}")  # label with the actual step count, not the index
plt.tight_layout()
plt.show()


In [None]:
import numpy as np
from torch.nn.functional import mse_loss

# Compute the mean squared pixel difference between consecutive images of the sweep.
# NOTE(review): these are the images returned by visualize_depth, not the raw depth
# predictions, so the metric measures change in the rendered picture — confirm that
# this is the intended quantity.
frames = [np.array(pred).astype(float) for pred in depth_preds]  # convert each image once

diff_means = [
    float(np.mean((curr - prev) ** 2))
    for prev, curr in zip(frames, frames[1:])
]

# Each difference belongs to the later of the two step counts being compared.
step_counts = list(num_steps)[1:]

# Plot the mean squared difference against the number of inference steps
fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(step_counts, diff_means, marker='o')
ax.set_xlabel('Number of Inference Steps')
ax.set_ylabel('Mean Squared Pixel Difference')
ax.set_title('Mean Squared Pixel Difference vs Number of Inference Steps')
ax.grid(True)
plt.show()


In [None]:
# Single run with five inference steps; its prediction is used by the side-by-side
# comparison in the next cell.
depth = pipe(image, num_inference_steps=5)

In [None]:
# Produce both renderings of the 5-step prediction: the visualization image and the
# 16-bit PNG export.
vis = pipe.image_processor.visualize_depth(depth.prediction)
depth_16bit = pipe.image_processor.export_depth_to_16bit_png(depth.prediction)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))

# One row per panel: target axes, image, title, extra imshow options.
panels = (
    (ax1, vis[0], 'Depth Visualization', {}),
    (ax2, depth_16bit[0], 'Depth 16-bit PNG', {'cmap': 'gray'}),
)
for panel_ax, panel_img, panel_title, imshow_kwargs in panels:
    panel_ax.imshow(panel_img, **imshow_kwargs)
    panel_ax.set_title(panel_title)
    panel_ax.axis('off')

# fig.colorbar(ax1.imshow(vis[0]), ax=ax1, orientation='vertical')
plt.show()