diff --git a/examples/community/marigold_depth_estimation.py b/examples/community/marigold_depth_estimation.py index 05fba7daa1a9..18e1b090e13e 100644 --- a/examples/community/marigold_depth_estimation.py +++ b/examples/community/marigold_depth_estimation.py @@ -170,10 +170,10 @@ def __call__( # Normalize rgb values rgb = np.transpose(image, (2, 0, 1)) # [H, W, rgb] -> [rgb, H, W] - rgb_norm = rgb / 255.0 + rgb_norm = rgb / 255.0 * 2.0 - 1.0 # [0, 255] -> [-1, 1] rgb_norm = torch.from_numpy(rgb_norm).to(self.dtype) rgb_norm = rgb_norm.to(device) - assert rgb_norm.min() >= 0.0 and rgb_norm.max() <= 1.0 + assert rgb_norm.min() >= -1.0 and rgb_norm.max() <= 1.0 # ----------------- Predicting depth ----------------- # Batch repeated input image