This notebook provides an easy-to-use interface for the depth estimation model "HybridDepth".

It accompanies our paper: <a href="https://arxiv.org/pdf/2407.18443">Hybrid Depth: Robust Depth Fusion By Leveraging Depth from Focus and Single-Image Priors</a>
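The paper's title refers to fusing depth from focus (DFF) with single-image priors. A single-image depth map is usually only correct up to an unknown scale and offset, while DFF can supply metric depth. One standard way to bring the two into agreement is a least-squares scale-and-shift fit. The sketch below shows that generic technique under this assumption. It is not necessarily HybridDepth's exact fusion procedure, and the function name `align_scale_shift` is hypothetical.

```python
import numpy as np

def align_scale_shift(relative, metric, mask=None):
    """Fit scale s and shift t minimizing ||s * relative + t - metric||^2.

    Hypothetical helper (generic technique, not taken from the HybridDepth code).
    relative: H x W depth from a single-image model (arbitrary scale and offset).
    metric:   H x W depth in meters, e.g. from depth from focus.
    mask:     optional boolean H x W array selecting the pixels to trust.
    Returns the aligned map in metric units and the fitted (s, t).
    """
    if mask is None:
        mask = np.isfinite(metric) & np.isfinite(relative)
    x = relative[mask].ravel()
    y = metric[mask].ravel()
    # Design matrix [x, 1] so that A @ [s, t] approximates y in the least-squares sense.
    A = np.stack([x, np.ones_like(x)], axis=1)
    (s, t), *_ = np.linalg.lstsq(A, y, rcond=None)
    return s * relative + t, (s, t)
```

Single-image models such as MiDaS predict relative inverse depth, so in practice the fit is often done in inverse-depth space and the result inverted afterwards.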

# Setup

In [None]:
!git clone https://github.com/cake-lab/HybridDepth.git

In [None]:
!pip install kornia==0.6.7
!pip install pytorch_lightning

In [None]:
%cd ./HybridDepth

In [None]:
import torch
from utils.io import prepare_input_image

import numpy as np
import matplotlib.pyplot as plt

# Load model

Select one of our pre-trained HybridDepth models, each fine-tuned on a different dataset or configuration. Set `model_name` to the desired configuration and load it with `pretrained=True` to download the pre-trained weights.

Available Pre-trained Models:

* `"HybridDepth_NYU5"`: Pre-trained on the NYU Depth V2 dataset with 5-image focal stacks; both the DFF branch and the refinement layer are trained.
* `"HybridDepth_NYU10"`: Pre-trained on the NYU Depth V2 dataset with 10-image focal stacks; both the DFF branch and the refinement layer are trained.
* `"HybridDepth_DDFF5"`: Pre-trained on the DDFF dataset with 5-image focal stacks.
* `"HybridDepth_NYU_PretrainedDFV5"`: Only the refinement layer is trained, on the NYU Depth V2 dataset with 5-image focal stacks, following pre-training with DFV.
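Each model name encodes the focal-stack size it was trained with. The sketch below records the list above in a lookup table and checks that an input stack has the matching number of slices. The table and `check_stack_size` are hypothetical conveniences, not part of the HybridDepth repository, and whether a model strictly rejects other stack sizes at inference is an assumption not stated here.

```python
# Hypothetical lookup built from the list above: model name -> (training dataset, focal-stack size).
PRETRAINED_MODELS = {
    "HybridDepth_NYU5": ("NYU Depth V2", 5),
    "HybridDepth_NYU10": ("NYU Depth V2", 10),
    "HybridDepth_DDFF5": ("DDFF", 5),
    "HybridDepth_NYU_PretrainedDFV5": ("NYU Depth V2", 5),
}

def check_stack_size(model_name, num_slices):
    """Raise ValueError if num_slices differs from the stack size the model was trained with."""
    if model_name not in PRETRAINED_MODELS:
        raise ValueError(
            f"Unknown model {model_name!r}; choose from {sorted(PRETRAINED_MODELS)}"
        )
    dataset, expected = PRETRAINED_MODELS[model_name]
    if num_slices != expected:
        raise ValueError(
            f"{model_name} was trained on {dataset} with {expected}-image focal stacks, "
            f"but the input has {num_slices} slices"
        )
    return expected
```

Pass the number of slices in your focal stack. Which axis of `focal_stack` holds the slices depends on `prepare_input_image`, so inspect `focal_stack.shape` before relying on a particular index.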

In [None]:
# Choose a model by setting model_name to one of the options above.
# Example: Load the HybridDepth model pre-trained on NYU with DFV pre-training (5-focal stack).
model_name = 'HybridDepth_NYU_PretrainedDFV5'
model = torch.hub.load('cake-lab/HybridDepth', model_name, pretrained=True)
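model.eval()  # inference mode: disables dropout and uses running BatchNorm statistics
model.cuda()  # move the weights to the GPU (requires a CUDA-capable runtime)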

# Prediction

### Download Sample Images

In [None]:
!wget https://github.com/cake-lab/HybridDepth/releases/download/v2.0/examples.zip

!unzip examples.zip
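If `wget` or `unzip` is unavailable (for example on a local Windows setup), the same download and extraction can be done with the Python standard library. This is a minimal sketch; `download_and_extract` is a hypothetical helper, and it assumes the release URL above is reachable.

```python
import urllib.request
import zipfile
from pathlib import Path, PurePosixPath
from urllib.parse import urlparse

EXAMPLES_URL = "https://github.com/cake-lab/HybridDepth/releases/download/v2.0/examples.zip"

def download_and_extract(url, dest_dir="."):
    """Download a zip archive from url into dest_dir, extract it there, and return the member names."""
    dest = Path(dest_dir)
    dest.mkdir(parents=True, exist_ok=True)
    # Name the local archive after the last path component of the URL (e.g. "examples.zip").
    archive = dest / PurePosixPath(urlparse(url).path).name
    urllib.request.urlretrieve(url, archive)
    with zipfile.ZipFile(archive) as zf:
        zf.extractall(dest)
        return zf.namelist()
```

Calling `download_and_extract(EXAMPLES_URL)` from inside `./HybridDepth` places the example folders next to the notebook, as the shell commands above do.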

### Select example 00 or 01

In [None]:
# Uncomment the line for the example you want; example01 is used by default.
# focal_stack, rgb_img, focus_dist = prepare_input_image('./example00')
focal_stack, rgb_img, focus_dist = prepare_input_image('./example01')

### Run inference

In [None]:
with torch.no_grad():
    out = model(rgb_img, focal_stack, focus_dist)
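The model receives a focal stack together with the focus distance of each slice. To build intuition for why these inputs carry metric depth, the sketch below implements classical depth from focus: for every pixel, pick the slice where the image is sharpest and report that slice's focus distance. This is a textbook baseline on a grayscale `(N, H, W)` stack, not the learned DFF branch inside HybridDepth, and the function names are hypothetical.

```python
import numpy as np

def laplacian_energy(img):
    """Per-pixel focus measure: squared response of a 4-neighbour Laplacian (edge-padded)."""
    p = np.pad(img, 1, mode="edge")
    lap = p[:-2, 1:-1] + p[2:, 1:-1] + p[1:-1, :-2] + p[1:-1, 2:] - 4.0 * p[1:-1, 1:-1]
    return lap ** 2

def box3(x):
    """3 x 3 mean filter (edge-padded) to aggregate the focus measure over a small window."""
    p = np.pad(x, 1, mode="edge")
    h, w = x.shape
    return sum(p[i:i + h, j:j + w] for i in range(3) for j in range(3)) / 9.0

def depth_from_focus(stack, focus_dists):
    """Classical DFF: return, per pixel, the focus distance of the sharpest slice.

    stack:       (N, H, W) grayscale focal stack.
    focus_dists: length-N sequence of focus distances, one per slice.
    """
    sharpness = np.stack([box3(laplacian_energy(s)) for s in stack])  # (N, H, W)
    best = np.argmax(sharpness, axis=0)  # (H, W) index of the sharpest slice
    return np.asarray(focus_dists)[best]
```

This baseline can only output the discrete focus distances in the stack and fails in textureless regions, which is one motivation for combining DFF with a single-image prior.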

### Visualize

In [None]:
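# Take the depth prediction and move it to the CPU as an H x W NumPy array.
metric_depth = out[0].squeeze().cpu().numpy()
# Convert the RGB tensor from C x H x W to H x W x C and rescale it to 8-bit for display.
rgb_img = rgb_img.squeeze().cpu().numpy().transpose(1, 2, 0)
rgb_img = (np.clip(rgb_img, 0.0, 1.0) * 255).astype(np.uint8)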

# Visualize the results: RGB input (left) and predicted depth (right)
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.imshow(rgb_img)
plt.title('RGB Image')
plt.axis('off')

plt.subplot(1, 2, 2)
plt.imshow(metric_depth, cmap='plasma')
plt.title('Depth Map')
plt.axis('off')
cbar = plt.colorbar()
cbar.set_label('Depth (meters)')

plt.tight_layout()
plt.show()
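A few very near or very far pixels can stretch the colormap so that most of the scene collapses into a narrow band of colors. A common remedy is to clip the color range to percentiles of the valid depths. The helper below is a hypothetical sketch of that idea; the 2nd and 98th percentiles are an arbitrary default, not a value from the paper.

```python
import numpy as np

def robust_range(depth, lo_pct=2.0, hi_pct=98.0):
    """Return (vmin, vmax) from percentiles of the finite depth values, ignoring outliers.

    Raises if depth contains no finite values.
    """
    valid = depth[np.isfinite(depth)]
    vmin, vmax = np.percentile(valid, [lo_pct, hi_pct])
    return float(vmin), float(vmax)
```

Usage: compute `vmin, vmax = robust_range(metric_depth)` and pass them to `plt.imshow(metric_depth, cmap='plasma', vmin=vmin, vmax=vmax)`; the colorbar then spans only the clipped range.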