Skip to content

Commit

Permalink
Fixes for RayBundle plotting
Browse files Browse the repository at this point in the history
Summary:
Fixes some issues with RayBundle plotting:
- allows plotting raybundles on gpu
- view -> reshape since we do not require contiguous raybundle tensors as input

Reviewed By: bottler, shapovalov

Differential Revision: D42665923

fbshipit-source-id: e9c6c7810428365dca4cb5ec80ef15ff28644163
  • Loading branch information
davnov134 authored and facebook-github-bot committed Jan 25, 2023
1 parent a12612a commit 9dc28f5
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 3 deletions.
16 changes: 16 additions & 0 deletions pytorch3d/vis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,19 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import warnings


try:
from .plotly_vis import get_camera_wireframe, plot_batch_individually, plot_scene
except ModuleNotFoundError as err:
if "plotly" in str(err):
warnings.warn(
"Cannot import plotly-based visualization code."
" Please install plotly to enable (pip install plotly)."
)
else:
raise

from .texture_vis import texturesuv_image_matplotlib, texturesuv_image_PIL
13 changes: 10 additions & 3 deletions pytorch3d/vis/plotly_vis.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ class Lighting(NamedTuple): # pragma: no cover
vertexnormalsepsilon: float = 1e-12


@torch.no_grad()
def plot_scene(
plots: Dict[str, Dict[str, Struct]],
*,
Expand Down Expand Up @@ -407,6 +408,7 @@ def plot_scene(
return fig


@torch.no_grad()
def plot_batch_individually(
batched_structs: Union[
List[Struct],
Expand Down Expand Up @@ -888,8 +890,12 @@ def _add_ray_bundle_trace(
)

# make the ray lines for plotly plotting
nan_tensor = torch.Tensor([[float("NaN")] * 3])
ray_lines = torch.empty(size=(1, 3))
nan_tensor = torch.tensor(
[[float("NaN")] * 3],
device=ray_lines_endpoints.device,
dtype=ray_lines_endpoints.dtype,
)
ray_lines = torch.empty(size=(1, 3), device=ray_lines_endpoints.device)
for ray_line in ray_lines_endpoints:
# We combine the ray lines into a single tensor to plot them in a
# single trace. The NaNs are inserted between sets of ray lines
Expand Down Expand Up @@ -952,7 +958,7 @@ def _add_ray_bundle_trace(
current_layout = fig["layout"][plot_scene]

# update the bounds of the axes for the current trace
all_ray_points = ray_bundle_to_ray_points(ray_bundle).view(-1, 3)
all_ray_points = ray_bundle_to_ray_points(ray_bundle).reshape(-1, 3)
ray_points_center = all_ray_points.mean(dim=0)
max_expand = (all_ray_points.max(0)[0] - all_ray_points.min(0)[0]).max().item()
_update_axes_bounds(ray_points_center, float(max_expand), current_layout)
Expand Down Expand Up @@ -1002,6 +1008,7 @@ def _update_axes_bounds(
max_expand: the maximum spread in any dimension of the trace's vertices.
current_layout: the plotly figure layout scene corresponding to the referenced trace.
"""
verts_center = verts_center.detach().cpu()
verts_min = verts_center - max_expand
verts_max = verts_center + max_expand
bounds = torch.t(torch.stack((verts_min, verts_max)))
Expand Down
74 changes: 74 additions & 0 deletions tests/test_vis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from pytorch3d.renderer import HeterogeneousRayBundle, PerspectiveCameras, RayBundle
from pytorch3d.structures import Meshes, Pointclouds
from pytorch3d.transforms import random_rotations

# Some of these imports are only needed for testing code coverage
from pytorch3d.vis import ( # noqa: F401
get_camera_wireframe, # noqa: F401
plot_batch_individually, # noqa: F401
plot_scene,
texturesuv_image_PIL, # noqa: F401
)


class TestPlotlyVis(unittest.TestCase):
def test_plot_scene(
self,
B: int = 3,
n_rays: int = 128,
n_pts_per_ray: int = 32,
n_verts: int = 32,
n_edges: int = 64,
n_pts: int = 256,
):
"""
Tests plotting of all supported structures using plot_scene.
"""
for device in ["cpu", "cuda:0"]:
plot_scene(
{
"scene": {
"ray_bundle": RayBundle(
origins=torch.randn(B, n_rays, 3, device=device),
xys=torch.randn(B, n_rays, 2, device=device),
directions=torch.randn(B, n_rays, 3, device=device),
lengths=torch.randn(
B, n_rays, n_pts_per_ray, device=device
),
),
"heterogeneous_ray_bundle": HeterogeneousRayBundle(
origins=torch.randn(B * n_rays, 3, device=device),
xys=torch.randn(B * n_rays, 2, device=device),
directions=torch.randn(B * n_rays, 3, device=device),
lengths=torch.randn(
B * n_rays, n_pts_per_ray, device=device
),
camera_ids=torch.randint(
low=0, high=B, size=(B * n_rays,), device=device
),
),
"camera": PerspectiveCameras(
R=random_rotations(B, device=device),
T=torch.randn(B, 3, device=device),
),
"mesh": Meshes(
verts=torch.randn(B, n_verts, 3, device=device),
faces=torch.randint(
low=0, high=n_verts, size=(B, n_edges, 3), device=device
),
),
"point_clouds": Pointclouds(
points=torch.randn(B, n_pts, 3, device=device),
),
}
}
)

0 comments on commit 9dc28f5

Please sign in to comment.