@@ -50,14 +50,14 @@ class MarigoldDepthOutput(BaseOutput):
5050 Args:
5151 depth_np (`np.ndarray`):
5252 Predicted depth map, with depth values in the range of [0, 1].
53- depth_colored (`PIL.Image.Image`):
53+ depth_colored (`None` or ` PIL.Image.Image`):
5454 Colorized depth map, with the shape of [3, H, W] and values in [0, 1].
5555 uncertainty (`None` or `np.ndarray`):
5656 Uncalibrated uncertainty(MAD, median absolute deviation) coming from ensembling.
5757 """
5858
5959 depth_np : np .ndarray
60- depth_colored : Image .Image
60+ depth_colored : Union [ None , Image .Image ]
6161 uncertainty : Union [None , np .ndarray ]
6262
6363
@@ -139,14 +139,15 @@ def __call__(
139139 If set to 0, the script will automatically decide the proper batch size.
140140 show_progress_bar (`bool`, *optional*, defaults to `True`):
141141 Display a progress bar of diffusion denoising.
142- color_map (`str`, *optional*, defaults to `"Spectral"`):
142+ color_map (`str`, *optional*, defaults to `"Spectral"`, pass `None` to skip colorized depth map generation ):
143143 Colormap used to colorize the depth map.
144144 ensemble_kwargs (`dict`, *optional*, defaults to `None`):
145145 Arguments for detailed ensembling settings.
146146 Returns:
147147 `MarigoldDepthOutput`: Output class for Marigold monocular depth prediction pipeline, including:
148148 - **depth_np** (`np.ndarray`) Predicted depth map, with depth values in the range of [0, 1]
149- - **depth_colored** (`PIL.Image.Image`) Colorized depth map, with the shape of [3, H, W] and values in [0, 1]
149+ - **depth_colored** (`None` or `PIL.Image.Image`) Colorized depth map, with the shape of [3, H, W] and
150+ values in [0, 1]. None if `color_map` is `None`
150151 - **uncertainty** (`None` or `np.ndarray`) Uncalibrated uncertainty(MAD, median absolute deviation)
151152 coming from ensembling. None if `ensemble_size = 1`
152153 """
@@ -155,15 +156,19 @@ def __call__(
155156 input_size = input_image .size
156157
157158 if not match_input_res :
158- assert processing_res is not None , "Value error: `resize_output_back` is only valid with "
159+ assert (
160+ processing_res is not None
161+ ), "Value error: `resize_output_back` is only valid with "
159162 assert processing_res >= 0
160163 assert denoising_steps >= 1
161164 assert ensemble_size >= 1
162165
163166 # ----------------- Image Preprocess -----------------
164167 # Resize image
165168 if processing_res > 0 :
166- input_image = self .resize_max_res (input_image , max_edge_resolution = processing_res )
169+ input_image = self .resize_max_res (
170+ input_image , max_edge_resolution = processing_res
171+ )
167172 # Convert the image to RGB, to 1.remove the alpha channel 2.convert B&W to 3-channel
168173 input_image = input_image .convert ("RGB" )
169174 image = np .asarray (input_image )
@@ -188,12 +193,16 @@ def __call__(
188193 dtype = self .dtype ,
189194 )
190195
191- single_rgb_loader = DataLoader (single_rgb_dataset , batch_size = _bs , shuffle = False )
196+ single_rgb_loader = DataLoader (
197+ single_rgb_dataset , batch_size = _bs , shuffle = False
198+ )
192199
193200 # Predict depth maps (batched)
194201 depth_pred_ls = []
195202 if show_progress_bar :
196- iterable = tqdm (single_rgb_loader , desc = " " * 2 + "Inference batches" , leave = False )
203+ iterable = tqdm (
204+ single_rgb_loader , desc = " " * 2 + "Inference batches" , leave = False
205+ )
197206 else :
198207 iterable = single_rgb_loader
199208 for batch in iterable :
@@ -209,7 +218,9 @@ def __call__(
209218
210219 # ----------------- Test-time ensembling -----------------
211220 if ensemble_size > 1 :
212- depth_pred , pred_uncert = self .ensemble_depths (depth_preds , ** (ensemble_kwargs or {}))
221+ depth_pred , pred_uncert = self .ensemble_depths (
222+ depth_preds , ** (ensemble_kwargs or {})
223+ )
213224 else :
214225 depth_pred = depth_preds
215226 pred_uncert = None
@@ -233,12 +244,15 @@ def __call__(
233244 depth_pred = depth_pred .clip (0 , 1 )
234245
235246 # Colorize
236- depth_colored = self .colorize_depth_maps (
237- depth_pred , 0 , 1 , cmap = color_map
238- ).squeeze () # [3, H, W], value in (0, 1)
239- depth_colored = (depth_colored * 255 ).astype (np .uint8 )
240- depth_colored_hwc = self .chw2hwc (depth_colored )
241- depth_colored_img = Image .fromarray (depth_colored_hwc )
247+ if color_map is not None :
248+ depth_colored = self .colorize_depth_maps (
249+ depth_pred , 0 , 1 , cmap = color_map
250+ ).squeeze () # [3, H, W], value in (0, 1)
251+ depth_colored = (depth_colored * 255 ).astype (np .uint8 )
252+ depth_colored_hwc = self .chw2hwc (depth_colored )
253+ depth_colored_img = Image .fromarray (depth_colored_hwc )
254+ else :
255+ depth_colored_img = None
242256 return MarigoldDepthOutput (
243257 depth_np = depth_pred ,
244258 depth_colored = depth_colored_img ,
@@ -261,7 +275,9 @@ def _encode_empty_text(self):
261275 self .empty_text_embed = self .text_encoder (text_input_ids )[0 ].to (self .dtype )
262276
263277 @torch .no_grad ()
264- def single_infer (self , rgb_in : torch .Tensor , num_inference_steps : int , show_pbar : bool ) -> torch .Tensor :
278+ def single_infer (
279+ self , rgb_in : torch .Tensor , num_inference_steps : int , show_pbar : bool
280+ ) -> torch .Tensor :
265281 """
266282 Perform an individual depth prediction without ensembling.
267283
@@ -285,12 +301,16 @@ def single_infer(self, rgb_in: torch.Tensor, num_inference_steps: int, show_pbar
285301 rgb_latent = self ._encode_rgb (rgb_in )
286302
287303 # Initial depth map (noise)
288- depth_latent = torch .randn (rgb_latent .shape , device = device , dtype = self .dtype ) # [B, 4, h, w]
304+ depth_latent = torch .randn (
305+ rgb_latent .shape , device = device , dtype = self .dtype
306+ ) # [B, 4, h, w]
289307
290308 # Batched empty text embedding
291309 if self .empty_text_embed is None :
292310 self ._encode_empty_text ()
293- batch_empty_text_embed = self .empty_text_embed .repeat ((rgb_latent .shape [0 ], 1 , 1 )) # [B, 2, 1024]
311+ batch_empty_text_embed = self .empty_text_embed .repeat (
312+ (rgb_latent .shape [0 ], 1 , 1 )
313+ ) # [B, 2, 1024]
294314
295315 # Denoising loop
296316 if show_pbar :
@@ -304,10 +324,14 @@ def single_infer(self, rgb_in: torch.Tensor, num_inference_steps: int, show_pbar
304324 iterable = enumerate (timesteps )
305325
306326 for i , t in iterable :
307- unet_input = torch .cat ([rgb_latent , depth_latent ], dim = 1 ) # this order is important
327+ unet_input = torch .cat (
328+ [rgb_latent , depth_latent ], dim = 1
329+ ) # this order is important
308330
309331 # predict the noise residual
310- noise_pred = self .unet (unet_input , t , encoder_hidden_states = batch_empty_text_embed ).sample # [B, 4, h, w]
332+ noise_pred = self .unet (
333+ unet_input , t , encoder_hidden_states = batch_empty_text_embed
334+ ).sample # [B, 4, h, w]
311335
312336 # compute the previous noisy sample x_t -> x_t-1
313337 depth_latent = self .scheduler .step (noise_pred , t , depth_latent ).prev_sample
@@ -375,7 +399,9 @@ def resize_max_res(img: Image.Image, max_edge_resolution: int) -> Image.Image:
375399 `Image.Image`: Resized image.
376400 """
377401 original_width , original_height = img .size
378- downscale_factor = min (max_edge_resolution / original_width , max_edge_resolution / original_height )
402+ downscale_factor = min (
403+ max_edge_resolution / original_width , max_edge_resolution / original_height
404+ )
379405
380406 new_width = int (original_width * downscale_factor )
381407 new_height = int (original_height * downscale_factor )
@@ -384,7 +410,9 @@ def resize_max_res(img: Image.Image, max_edge_resolution: int) -> Image.Image:
384410 return resized_img
385411
386412 @staticmethod
387- def colorize_depth_maps (depth_map , min_depth , max_depth , cmap = "Spectral" , valid_mask = None ):
413+ def colorize_depth_maps (
414+ depth_map , min_depth , max_depth , cmap = "Spectral" , valid_mask = None
415+ ):
388416 """
389417 Colorize depth maps.
390418 """
@@ -526,7 +554,9 @@ def inter_distances(tensors: torch.Tensor):
526554 if max_res is not None :
527555 scale_factor = torch .min (max_res / torch .tensor (ori_shape [- 2 :]))
528556 if scale_factor < 1 :
529- downscaler = torch .nn .Upsample (scale_factor = scale_factor , mode = "nearest" )
557+ downscaler = torch .nn .Upsample (
558+ scale_factor = scale_factor , mode = "nearest"
559+ )
530560 input_images = downscaler (torch .from_numpy (input_images )).numpy ()
531561
532562 # init guess
0 commit comments