Skip to content

Commit

Permalink
raybundle input to ImplicitFunctions -> api unification
Browse files Browse the repository at this point in the history
Summary: Currently some implicit functions in implicitron take a raybundle, others take ray_points_world. raybundle is what they really need. However, the raybundle is going to become a bit more flexible later, as it will contain different numbers of rays for each camera.

Reviewed By: bottler

Differential Revision: D39173751

fbshipit-source-id: ebc038e426d22e831e67a18ba64655d8a61e1eb9
  • Loading branch information
Darijan Gudelj authored and facebook-github-bot committed Sep 5, 2022
1 parent 70dc9c4 commit 72c3a0e
Show file tree
Hide file tree
Showing 9 changed files with 60 additions and 19 deletions.
1 change: 1 addition & 0 deletions pytorch3d/implicitron/models/implicit_function/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def __init__(self):
@abstractmethod
def forward(
self,
*,
ray_bundle: RayBundle,
fun_viewpool=None,
camera: Optional[CamerasBase] = None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
# implicit_differentiable_renderer.py
# Copyright (c) 2020 Lior Yariv
import math
from typing import Tuple
from typing import Optional, Tuple

import torch
from pytorch3d.implicitron.tools.config import registry
from pytorch3d.renderer.implicit import HarmonicEmbedding
from pytorch3d.renderer.implicit import HarmonicEmbedding, RayBundle
from torch import nn

from .base import ImplicitFunctionBase
from .utils import get_rays_points_world


@registry.register
Expand Down Expand Up @@ -125,14 +126,16 @@ def __post_init__(self):
# inconsistently.
def forward(
self,
# ray_bundle: RayBundle,
rays_points_world: torch.Tensor, # TODO: unify the APIs
*,
ray_bundle: Optional[RayBundle] = None,
rays_points_world: Optional[torch.Tensor] = None,
fun_viewpool=None,
global_code=None,
**kwargs,
):
# this field only uses point locations
# rays_points_world = ray_bundle_to_ray_points(ray_bundle)
# rays_points_world.shape = [minibatch x ... x pts_per_ray x 3]
rays_points_world = get_rays_points_world(ray_bundle, rays_points_world)

if rays_points_world.numel() == 0 or (
self.embed_fn is None and fun_viewpool is None and global_code is None
Expand Down Expand Up @@ -179,4 +182,4 @@ def forward(
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
x = self.softplus(x)

return x # TODO: unify the APIs
return x
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def allows_multiple_passes() -> bool:

def forward(
self,
*,
ray_bundle: RayBundle,
fun_viewpool=None,
camera: Optional[CamerasBase] = None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,7 @@ def raymarch_function_tweak_args(cls, type, args: DictConfig) -> None:

def forward(
self,
*,
ray_bundle: RayBundle,
fun_viewpool=None,
camera: Optional[CamerasBase] = None,
Expand Down Expand Up @@ -408,6 +409,7 @@ def hypernet_tweak_args(cls, type, args: DictConfig) -> None:

def forward(
self,
*,
ray_bundle: RayBundle,
fun_viewpool=None,
camera: Optional[CamerasBase] = None,
Expand Down
30 changes: 30 additions & 0 deletions pytorch3d/implicitron/models/implicit_function/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@

import torch.nn.functional as F
from pytorch3d.common.compat import prod
from pytorch3d.renderer import ray_bundle_to_ray_points
from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.renderer.implicit import RayBundle


def broadcast_global_code(embeds: torch.Tensor, global_code: torch.Tensor):
Expand Down Expand Up @@ -185,3 +187,31 @@ def interpolate_volume(
**kwargs,
)
return out[:, :, :, 0, 0].permute(0, 2, 1)


def get_rays_points_world(
ray_bundle: Optional[RayBundle] = None,
rays_points_world: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Converts the ray_bundle to rays_points_world if rays_points_world is not defined
and raises error if both are defined.
Args:
ray_bundle: A RayBundle object or None
rays_points_world: A torch.Tensor representing ray points converted to
world coordinates
Returns:
A torch.Tensor representing ray points converted to world coordinates
of shape [minibatch x ... x pts_per_ray x 3].
"""
if rays_points_world is not None and ray_bundle is not None:
raise ValueError(
"Cannot define both rays_points_world and ray_bundle,"
+ " one has to be None."
)
if rays_points_world is not None:
return rays_points_world
if ray_bundle is not None:
return ray_bundle_to_ray_points(ray_bundle)
raise ValueError("ray_bundle and rays_points_world cannot both be None")
2 changes: 1 addition & 1 deletion pytorch3d/implicitron/models/renderer/lstm_renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def forward(

# eval the raymarching function
raymarch_features, _ = implicit_function(
ray_bundle_t,
ray_bundle=ray_bundle_t,
raymarch_features=None,
)
if self.verbose:
Expand Down
2 changes: 1 addition & 1 deletion pytorch3d/implicitron/models/renderer/multipass_ea.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def _run_raymarcher(
)

output = self.raymarcher(
*implicit_functions[0](ray_bundle),
*implicit_functions[0](ray_bundle=ray_bundle),
ray_lengths=ray_bundle.lengths,
density_noise_std=density_noise_std,
)
Expand Down
20 changes: 11 additions & 9 deletions pytorch3d/implicitron/models/renderer/sdf_renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def forward(
object_mask = object_mask.bool()

implicit_function = implicit_functions[0]
implicit_function_gradient = functools.partial(gradient, implicit_function)
implicit_function_gradient = functools.partial(_gradient, implicit_function)

# object_mask: silhouette of the object
batch_size, *spatial_size, _ = ray_bundle.lengths.shape
Expand All @@ -113,7 +113,7 @@ def forward(

with torch.no_grad(), evaluating(implicit_function):
points, network_object_mask, dists = self.ray_tracer(
sdf=lambda x: implicit_function(x)[
sdf=lambda x: implicit_function(rays_points_world=x)[
:, 0
], # TODO: get rid of this wrapper
cam_loc=cam_loc,
Expand All @@ -125,7 +125,7 @@ def forward(
depth = dists.reshape(batch_size, num_pixels, 1)
points = (cam_loc + depth * ray_dirs).reshape(-1, 3)

sdf_output = implicit_function(points)[:, 0:1]
sdf_output = implicit_function(rays_points_world=points)[:, 0:1]
# NOTE most of the intermediate variables are flattened for
# no apparent reason (here and in the ray tracer)
ray_dirs = ray_dirs.reshape(-1, 3)
Expand Down Expand Up @@ -157,7 +157,7 @@ def forward(

points_all = torch.cat([surface_points, eikonal_points], dim=0)

output = implicit_function(surface_points)
output = implicit_function(rays_points_world=surface_points)
surface_sdf_values = output[
:N, 0:1
].detach() # how is it different from sdf_output?
Expand All @@ -181,7 +181,9 @@ def forward(
grad_theta = None

empty_render = differentiable_surface_points.shape[0] == 0
features = implicit_function(differentiable_surface_points)[None, :, 1:]
features = implicit_function(rays_points_world=differentiable_surface_points)[
None, :, 1:
]
normals_full = features.new_zeros(
batch_size, *spatial_size, 3, requires_grad=empty_render
)
Expand Down Expand Up @@ -260,13 +262,13 @@ def _sample_network(


@torch.enable_grad()
def gradient(module, x):
x.requires_grad_(True)
y = module.forward(x)[:, :1]
def _gradient(module, rays_points_world):
rays_points_world.requires_grad_(True)
y = module.forward(rays_points_world=rays_points_world)[:, :1]
d_output = torch.ones_like(y, requires_grad=False, device=y.device)
gradients = torch.autograd.grad(
outputs=y,
inputs=x,
inputs=rays_points_world,
grad_outputs=d_output,
create_graph=True,
retain_graph=True,
Expand Down
6 changes: 4 additions & 2 deletions tests/implicitron/test_srn.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def test_srn_implicit_function(self):
implicit_function = SRNImplicitFunction()
device = torch.device("cpu")
bundle = self._get_bundle(device=device)
rays_densities, rays_colors = implicit_function(bundle)
rays_densities, rays_colors = implicit_function(ray_bundle=bundle)
out_features = implicit_function.raymarch_function.out_features
self.assertEqual(
rays_densities.shape,
Expand All @@ -62,7 +62,9 @@ def test_srn_hypernet_implicit_function(self):
implicit_function.to(device)
global_code = torch.rand(_BATCH_SIZE, latent_dim_hypernet, device=device)
bundle = self._get_bundle(device=device)
rays_densities, rays_colors = implicit_function(bundle, global_code=global_code)
rays_densities, rays_colors = implicit_function(
ray_bundle=bundle, global_code=global_code
)
out_features = implicit_function.hypernet.out_features
self.assertEqual(
rays_densities.shape,
Expand Down

0 comments on commit 72c3a0e

Please sign in to comment.