Skip to content

Commit

Permalink
render_flyaround bugfix
Browse files Browse the repository at this point in the history
Summary: Fixes a bug which would crash render_flyaround anytime  visualize_preds_keys is adjusted

Reviewed By: shapovalov

Differential Revision: D41124462

fbshipit-source-id: 127045a91a055909f8bd56c8af81afac02c00f60
  • Loading branch information
davnov134 authored and facebook-github-bot committed Nov 28, 2022
1 parent 35f8cb9 commit 94f321f
Showing 1 changed file with 25 additions and 7 deletions.
32 changes: 25 additions & 7 deletions pytorch3d/implicitron/models/visualization/render_flyaround.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,17 @@
import math
import os
import random
from typing import Any, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union
from typing import (
Any,
Dict,
Iterable,
List,
Optional,
Sequence,
Tuple,
TYPE_CHECKING,
Union,
)

import numpy as np
import torch
Expand Down Expand Up @@ -180,7 +190,7 @@ def render_flyaround(
preds.update(net_input) # merge everything into one big dict

# Render the predictions to images
rendered_pred = _images_from_preds(preds)
rendered_pred = _images_from_preds(preds, extract_keys=visualize_preds_keys)
preds_total.append(rendered_pred)

# show the preds every 5% of the export iterations
Expand Down Expand Up @@ -223,17 +233,20 @@ def _load_whole_dataset(
return next(iter(load_all_dataloader))


def _images_from_preds(preds: Dict[str, Any]) -> Dict[str, torch.Tensor]:
imout = {}
for k in (
def _images_from_preds(
preds: Dict[str, Any],
extract_keys: Iterable[str] = (
"image_rgb",
"images_render",
"fg_probability",
"masks_render",
"depths_render",
"depth_map",
"_all_source_images",
):
),
) -> Dict[str, torch.Tensor]:
imout = {}
for k in extract_keys:
if k == "_all_source_images" and "image_rgb" in preds:
src_ims = preds["image_rgb"][1:].cpu().detach().clone()
v = _stack_images(src_ims, None)[None]
Expand Down Expand Up @@ -343,6 +356,9 @@ def _generate_prediction_videos(
# init a video writer for each predicted key
vws = {}
for k in predicted_keys:
if k not in preds[0]:
logger.warn(f"Cannot generate video for prediction key '{k}'")
continue
cache_dir = (
None
if video_frames_dir is None
Expand All @@ -355,13 +371,15 @@ def _generate_prediction_videos(
)

for rendered_pred in tqdm(preds):
for k in predicted_keys:
for k in vws:
vws[k].write_frame(
rendered_pred[k][0].clip(0.0, 1.0).detach().cpu().numpy(),
resize=resize,
)

for k in predicted_keys:
if k not in vws:
continue
vws[k].get_video()
logger.info(f"Generated {vws[k].out_path}.")
if viz is not None:
Expand Down

0 comments on commit 94f321f

Please sign in to comment.