@@ -27,7 +27,7 @@

class DepthAnythingConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`DepthAnythingModel`]. It is used to instantiate an DepthAnything
This is the configuration class to store the configuration of a [`DepthAnythingModel`]. It is used to instantiate a DepthAnything
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the DepthAnything
[LiheYoung/depth-anything-small-hf](https://huggingface.co/LiheYoung/depth-anything-small-hf) architecture.
@@ -67,6 +67,11 @@ class DepthAnythingConfig(PretrainedConfig):
The index of the features to use in the depth estimation head.
head_hidden_size (`int`, *optional*, defaults to 32):
The number of output channels in the second convolution of the depth estimation head.
depth_estimation_type (`str`, *optional*, defaults to `"relative"`):
The type of depth estimation to use. Can be one of `["relative", "metric"]`.
max_depth (`float`, *optional*):
The maximum depth to use for the "metric" depth estimation head. 20 should be used for indoor models
and 80 for outdoor models. For "relative" depth estimation, this value is ignored.

Example:

@@ -100,6 +105,8 @@ def __init__(
fusion_hidden_size=64,
head_in_index=-1,
head_hidden_size=32,
depth_estimation_type="relative",
max_depth=None,
**kwargs,
):
super().__init__(**kwargs)
@@ -139,6 +146,10 @@ def __init__(
self.fusion_hidden_size = fusion_hidden_size
self.head_in_index = head_in_index
self.head_hidden_size = head_hidden_size
if depth_estimation_type not in ["relative", "metric"]:
raise ValueError("depth_estimation_type must be one of ['relative', 'metric']")
self.depth_estimation_type = depth_estimation_type
self.max_depth = max_depth if max_depth else 1

def to_dict(self):
"""
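For reference, a minimal sketch of how the new configuration options are meant to be used (the config's default DINOv2-small backbone settings are assumed; the values are illustrative, not a recommendation):

from transformers import DepthAnythingConfig, DepthAnythingForDepthEstimation

# "metric" swaps the head's final ReLU for a sigmoid scaled by max_depth,
# so predictions come out in meters; 20 matches the indoor checkpoints,
# 80 the outdoor ones.
metric_config = DepthAnythingConfig(depth_estimation_type="metric", max_depth=20)
metric_model = DepthAnythingForDepthEstimation(metric_config)

# The default remains relative depth; max_depth is ignored and falls back to 1.
relative_config = DepthAnythingConfig()
assert relative_config.depth_estimation_type == "relative"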
@@ -56,12 +56,21 @@ def get_dpt_config(model_name):
else:
raise NotImplementedError(f"Model not supported: {model_name}")

if "metric" in model_name:
depth_estimation_type = "metric"
max_depth = 20 if "indoor" in model_name else 80
else:
depth_estimation_type = "relative"
max_depth = None

config = DepthAnythingConfig(
reassemble_hidden_size=backbone_config.hidden_size,
patch_size=backbone_config.patch_size,
backbone_config=backbone_config,
fusion_hidden_size=fusion_hidden_size,
neck_hidden_sizes=neck_hidden_sizes,
depth_estimation_type=depth_estimation_type,
max_depth=max_depth,
)

return config
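A quick sanity check of the naming convention the conversion logic relies on, sketched under the assumption that `get_dpt_config` can be imported from the conversion script (the module path below is assumed, not shown in this diff):

# Assumed location of the conversion script being modified here.
from transformers.models.depth_anything.convert_depth_anything_to_hf import get_dpt_config

# "metric" in the checkpoint name selects the metric head, and "indoor"
# decides whether max_depth is 20 (indoor) or 80 (outdoor).
indoor = get_dpt_config("depth-anything-v2-metric-indoor-small")
assert indoor.depth_estimation_type == "metric"
assert indoor.max_depth == 20

outdoor = get_dpt_config("depth-anything-v2-metric-outdoor-large")
assert outdoor.max_depth == 80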
@@ -178,6 +187,12 @@ def prepare_img():
"depth-anything-v2-small": "depth_anything_v2_vits.pth",
"depth-anything-v2-base": "depth_anything_v2_vitb.pth",
"depth-anything-v2-large": "depth_anything_v2_vitl.pth",
"depth-anything-v2-metric-indoor-small": "depth_anything_v2_metric_hypersim_vits.pth",
"depth-anything-v2-metric-indoor-base": "depth_anything_v2_metric_hypersim_vitb.pth",
"depth-anything-v2-metric-indoor-large": "depth_anything_v2_metric_hypersim_vitl.pth",
"depth-anything-v2-metric-outdoor-small": "depth_anything_v2_metric_vkitti_vits.pth",
"depth-anything-v2-metric-outdoor-base": "depth_anything_v2_metric_vkitti_vitb.pth",
"depth-anything-v2-metric-outdoor-large": "depth_anything_v2_metric_vkitti_vitl.pth",
# v2-giant pending
}

@@ -198,6 +213,12 @@ def convert_dpt_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub, ve
"depth-anything-v2-small": "depth-anything/Depth-Anything-V2-Small",
"depth-anything-v2-base": "depth-anything/Depth-Anything-V2-Base",
"depth-anything-v2-large": "depth-anything/Depth-Anything-V2-Large",
"depth-anything-v2-metric-indoor-small": "depth-anything/Depth-Anything-V2-Metric-Hypersim-Small",
"depth-anything-v2-metric-indoor-base": "depth-anything/Depth-Anything-V2-Metric-Hypersim-Base",
"depth-anything-v2-metric-indoor-large": "depth-anything/Depth-Anything-V2-Metric-Hypersim-Large",
"depth-anything-v2-metric-outdoor-small": "depth-anything/Depth-Anything-V2-Metric-VKITTI-Small",
"depth-anything-v2-metric-outdoor-base": "depth-anything/Depth-Anything-V2-Metric-VKITTI-Base",
"depth-anything-v2-metric-outdoor-large": "depth-anything/Depth-Anything-V2-Metric-VKITTI-Large",
}

# load original state_dict
@@ -272,6 +293,30 @@ def convert_dpt_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub, ve
expected_slice = torch.tensor(
[[162.2751, 161.8504, 162.8788], [160.3138, 160.8050, 161.9835], [159.3812, 159.9884, 160.0768]]
)
elif model_name == "depth-anything-v2-metric-indoor-small":
expected_slice = torch.tensor(
[[1.3349, 1.2946, 1.2801], [1.2793, 1.2337, 1.2899], [1.2629, 1.2218, 1.2476]]
)
elif model_name == "depth-anything-v2-metric-indoor-base":
expected_slice = torch.tensor(
[[1.4601, 1.3824, 1.4904], [1.5031, 1.4349, 1.4274], [1.4570, 1.4578, 1.4200]]
)
elif model_name == "depth-anything-v2-metric-indoor-large":
expected_slice = torch.tensor(
[[1.5040, 1.5019, 1.5218], [1.5087, 1.5195, 1.5149], [1.5437, 1.5128, 1.5252]]
)
elif model_name == "depth-anything-v2-metric-outdoor-small":
expected_slice = torch.tensor(
[[9.5804, 8.0339, 7.7386], [7.9890, 7.2464, 7.7149], [7.7021, 7.2330, 7.3304]]
)
elif model_name == "depth-anything-v2-metric-outdoor-base":
expected_slice = torch.tensor(
[[10.2916, 9.0933, 8.8622], [9.1964, 9.3393, 9.0644], [8.9618, 9.4201, 9.2262]]
)
elif model_name == "depth-anything-v2-metric-outdoor-large":
expected_slice = torch.tensor(
[[14.0137, 13.3627, 13.1080], [13.2522, 13.3943, 13.3705], [13.0581, 13.4505, 13.3925]]
)
else:
raise ValueError("Not supported")

14 changes: 10 additions & 4 deletions src/transformers/models/depth_anything/modeling_depth_anything.py
@@ -54,7 +54,6 @@
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`DPTImageProcessor.__call__`]
for details.

output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
@@ -318,7 +317,8 @@ class DepthAnythingDepthEstimationHead(nn.Module):
"""
Output head consisting of 3 convolutional layers. It progressively halves the feature dimension and upsamples
the predictions to the input resolution after the first convolutional layer (details can be found in the DPT paper's
supplementary material).
supplementary material). The final activation function is either ReLU or Sigmoid, depending on the depth estimation
type (relative or metric). For metric depth estimation, the output is scaled by the maximum depth used during pretraining.
"""

def __init__(self, config):
@@ -332,7 +332,13 @@ def __init__(self, config):
self.conv2 = nn.Conv2d(features // 2, config.head_hidden_size, kernel_size=3, stride=1, padding=1)
self.activation1 = nn.ReLU()
self.conv3 = nn.Conv2d(config.head_hidden_size, 1, kernel_size=1, stride=1, padding=0)
self.activation2 = nn.ReLU()
if config.depth_estimation_type == "relative":
self.activation2 = nn.ReLU()
elif config.depth_estimation_type == "metric":
self.activation2 = nn.Sigmoid()
else:
raise ValueError(f"Unknown depth estimation type: {config.depth_estimation_type}")
self.max_depth = config.max_depth

def forward(self, hidden_states: List[torch.Tensor], patch_height, patch_width) -> torch.Tensor:
hidden_states = hidden_states[self.head_in_index]
@@ -347,7 +353,7 @@ def forward(self, hidden_states: List[torch.Tensor], patch_height, patch_width)
predicted_depth = self.conv2(predicted_depth)
predicted_depth = self.activation1(predicted_depth)
predicted_depth = self.conv3(predicted_depth)
predicted_depth = self.activation2(predicted_depth)
predicted_depth = self.activation2(predicted_depth) * self.max_depth
predicted_depth = predicted_depth.squeeze(dim=1) # shape (batch_size, height, width)

return predicted_depth
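A minimal, self-contained sketch of what the head change amounts to; the tensor and the scale values are illustrative, and the real module also runs the three convolutions and the bilinear upsampling shown above:

import torch
import torch.nn as nn

logits = torch.randn(1, 1, 518, 686)  # stand-in for the output of conv3

# Relative estimation keeps the unbounded ReLU head; max_depth defaults to 1,
# so the multiplication is a no-op.
relative_depth = nn.ReLU()(logits) * 1

# Metric estimation bounds the output to (0, max_depth) meters via a sigmoid,
# e.g. max_depth=20 for the indoor checkpoints and 80 for the outdoor ones.
metric_depth = nn.Sigmoid()(logits) * 20.0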
26 changes: 25 additions & 1 deletion tests/models/depth_anything/test_modeling_depth_anything.py
@@ -246,6 +246,7 @@ def prepare_img():
@slow
class DepthAnythingModelIntegrationTest(unittest.TestCase):
def test_inference(self):
# -- `relative` depth model --
image_processor = DPTImageProcessor.from_pretrained("LiheYoung/depth-anything-small-hf")
model = DepthAnythingForDepthEstimation.from_pretrained("LiheYoung/depth-anything-small-hf").to(torch_device)

@@ -265,4 +266,27 @@ def test_inference(self):
[[8.8204, 8.6468, 8.6195], [8.3313, 8.6027, 8.7526], [8.6526, 8.6866, 8.7453]],
).to(torch_device)

self.assertTrue(torch.allclose(outputs.predicted_depth[0, :3, :3], expected_slice, atol=1e-6))
self.assertTrue(torch.allclose(predicted_depth[0, :3, :3], expected_slice, atol=1e-6))

# -- `metric` depth model --
image_processor = DPTImageProcessor.from_pretrained("depth-anything/depth-anything-V2-metric-indoor-small-hf")
model = DepthAnythingForDepthEstimation.from_pretrained(
"depth-anything/depth-anything-V2-metric-indoor-small-hf"
).to(torch_device)

inputs = image_processor(images=image, return_tensors="pt").to(torch_device)

# forward pass
with torch.no_grad():
outputs = model(**inputs)
predicted_depth = outputs.predicted_depth

# verify the predicted depth
expected_shape = torch.Size([1, 518, 686])
self.assertEqual(predicted_depth.shape, expected_shape)

expected_slice = torch.tensor(
[[1.3349, 1.2946, 1.2801], [1.2793, 1.2337, 1.2899], [1.2629, 1.2218, 1.2476]],
).to(torch_device)

self.assertTrue(torch.allclose(predicted_depth[0, :3, :3], expected_slice, atol=1e-4))
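For completeness, a sketch of running the converted metric checkpoint outside the test suite; the repo id is the one exercised by the integration test above, and the COCO image URL is only an illustrative stand-in for `prepare_img()`:

import requests
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForDepthEstimation

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

repo_id = "depth-anything/depth-anything-V2-metric-indoor-small-hf"
processor = AutoImageProcessor.from_pretrained(repo_id)
model = AutoModelForDepthEstimation.from_pretrained(repo_id)

inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    # Metric checkpoints return absolute depth in meters, bounded by max_depth.
    predicted_depth = model(**inputs).predicted_depth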