Skip to content

Commit

Permalink
ImplicitronRayBundle
Browse files Browse the repository at this point in the history
Summary: new implicitronRayBundle with added cameraIDs and camera counts. Added to enable a single raybundle inside Implicitron and easier extension in the future. Since RayBundle is named tuple and RayBundleHeterogeneous is dataclass and RayBundleHeterogeneous cannot inherit RayBundle. So if there was no ImplicitronRayBundle every function that uses RayBundle now would have to use Union[RayBundle, RaybundleHeterogeneous] which is confusing and unecessary complicated.

Reviewed By: bottler, kjchalup

Differential Revision: D39262999

fbshipit-source-id: ece160e32f6c88c3977e408e966789bf8307af59
  • Loading branch information
Darijan Gudelj authored and facebook-github-bot committed Oct 3, 2022
1 parent 6ae863f commit ad8907d
Show file tree
Hide file tree
Showing 18 changed files with 259 additions and 100 deletions.
5 changes: 2 additions & 3 deletions docs/tutorials/implicitron_volumes.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -145,10 +145,9 @@
"from pytorch3d.implicitron.dataset.dataset_base import FrameData\n",
"from pytorch3d.implicitron.dataset.rendered_mesh_dataset_map_provider import RenderedMeshDatasetMapProvider\n",
"from pytorch3d.implicitron.models.generic_model import GenericModel\n",
"from pytorch3d.implicitron.models.implicit_function.base import ImplicitFunctionBase\n",
"from pytorch3d.implicitron.models.implicit_function.base import ImplicitFunctionBase, ImplicitronRayBundle\n",
"from pytorch3d.implicitron.models.renderer.base import EvaluationMode\n",
"from pytorch3d.implicitron.tools.config import get_default_args, registry, remove_unused_components\n",
"from pytorch3d.renderer import RayBundle\n",
"from pytorch3d.renderer.implicit.renderer import VolumeSampler\n",
"from pytorch3d.structures import Volumes\n",
"from pytorch3d.vis.plotly_vis import plot_batch_individually, plot_scene"
Expand Down Expand Up @@ -393,7 +392,7 @@
"\n",
" def forward(\n",
" self,\n",
" ray_bundle: RayBundle,\n",
" ray_bundle: ImplicitronRayBundle,\n",
" fun_viewpool=None,\n",
" global_code=None,\n",
" ):\n",
Expand Down
14 changes: 8 additions & 6 deletions pytorch3d/implicitron/models/generic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
RegularizationMetricsBase,
ViewMetricsBase,
)
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
from pytorch3d.implicitron.tools import image_utils, vis_utils
from pytorch3d.implicitron.tools.config import (
expand_args_fields,
Expand All @@ -30,7 +31,8 @@
)
from pytorch3d.implicitron.tools.rasterize_mc import rasterize_mc_samples
from pytorch3d.implicitron.tools.utils import cat_dataclass
from pytorch3d.renderer import RayBundle, utils as rend_utils
from pytorch3d.renderer import utils as rend_utils

from pytorch3d.renderer.cameras import CamerasBase
from visdom import Visdom

Expand Down Expand Up @@ -387,7 +389,7 @@ def safe_slice_targets(
)

# (1) Sample rendering rays with the ray sampler.
ray_bundle: RayBundle = self.raysampler( # pyre-fixme[29]
ray_bundle: ImplicitronRayBundle = self.raysampler( # pyre-fixme[29]
target_cameras,
evaluation_mode,
mask=mask_crop[:n_targets]
Expand Down Expand Up @@ -568,14 +570,14 @@ def visualize(
def _render(
self,
*,
ray_bundle: RayBundle,
ray_bundle: ImplicitronRayBundle,
chunked_inputs: Dict[str, torch.Tensor],
sampling_mode: RenderSamplingMode,
**kwargs,
) -> RendererOutput:
"""
Args:
ray_bundle: A `RayBundle` object containing the parametrizations of the
ray_bundle: A `ImplicitronRayBundle` object containing the parametrizations of the
sampled rendering rays.
chunked_inputs: A collection of tensor of shape `(B, _, H, W)`. E.g.
SignedDistanceFunctionRenderer requires "object_mask", shape
Expand Down Expand Up @@ -899,7 +901,7 @@ def _tensor_collator(batch, new_dims) -> torch.Tensor:

def _chunk_generator(
chunk_size: int,
ray_bundle: RayBundle,
ray_bundle: ImplicitronRayBundle,
chunked_inputs: Dict[str, torch.Tensor],
tqdm_trigger_threshold: int,
*args,
Expand Down Expand Up @@ -932,7 +934,7 @@ def _chunk_generator(

for start_idx in iter:
end_idx = min(start_idx + chunk_size_in_rays, n_rays)
ray_bundle_chunk = RayBundle(
ray_bundle_chunk = ImplicitronRayBundle(
origins=ray_bundle.origins.reshape(batch_size, -1, 3)[:, start_idx:end_idx],
directions=ray_bundle.directions.reshape(batch_size, -1, 3)[
:, start_idx:end_idx
Expand Down
5 changes: 3 additions & 2 deletions pytorch3d/implicitron/models/implicit_function/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
from abc import ABC, abstractmethod
from typing import Optional

from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle

from pytorch3d.implicitron.tools.config import ReplaceableBase
from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.renderer.implicit import RayBundle


class ImplicitFunctionBase(ABC, ReplaceableBase):
Expand All @@ -20,7 +21,7 @@ def __init__(self):
def forward(
self,
*,
ray_bundle: RayBundle,
ray_bundle: ImplicitronRayBundle,
fun_viewpool=None,
camera: Optional[CamerasBase] = None,
global_code=None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
from typing import Optional, Tuple

import torch
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
from pytorch3d.implicitron.tools.config import registry
from pytorch3d.renderer.implicit import HarmonicEmbedding, RayBundle
from pytorch3d.renderer.implicit import HarmonicEmbedding

from torch import nn

from .base import ImplicitFunctionBase
Expand Down Expand Up @@ -127,7 +129,7 @@ def __post_init__(self):
def forward(
self,
*,
ray_bundle: Optional[RayBundle] = None,
ray_bundle: Optional[ImplicitronRayBundle] = None,
rays_points_world: Optional[torch.Tensor] = None,
fun_viewpool=None,
global_code=None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@

import torch
from pytorch3d.common.linear_with_repeat import LinearWithRepeat
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
from pytorch3d.implicitron.tools.config import expand_args_fields, registry
from pytorch3d.renderer import ray_bundle_to_ray_points, RayBundle
from pytorch3d.renderer import ray_bundle_to_ray_points
from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.renderer.implicit import HarmonicEmbedding

Expand Down Expand Up @@ -130,7 +131,7 @@ def allows_multiple_passes() -> bool:
def forward(
self,
*,
ray_bundle: RayBundle,
ray_bundle: ImplicitronRayBundle,
fun_viewpool=None,
camera: Optional[CamerasBase] = None,
global_code=None,
Expand All @@ -144,7 +145,7 @@ def forward(
RGB color and opacity respectively.
Args:
ray_bundle: A RayBundle object containing the following variables:
ray_bundle: An ImplicitronRayBundle object containing the following variables:
origins: A tensor of shape `(minibatch, ..., 3)` denoting the
origins of the sampling rays in world coords.
directions: A tensor of shape `(minibatch, ..., 3)`
Expand All @@ -165,11 +166,12 @@ def forward(
"""
# We first convert the ray parametrizations to world
# coordinates with `ray_bundle_to_ray_points`.
# pyre-ignore[6]
rays_points_world = ray_bundle_to_ray_points(ray_bundle)
# rays_points_world.shape = [minibatch x ... x pts_per_ray x 3]

embeds = create_embeddings_for_implicit_function(
xyz_world=ray_bundle_to_ray_points(ray_bundle),
xyz_world=rays_points_world,
# pyre-fixme[6]: Expected `Optional[typing.Callable[..., typing.Any]]`
# for 2nd param but got `Union[None, torch.Tensor, torch.nn.Module]`.
xyz_embedding_function=self.harmonic_embedding_xyz
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
import torch
from omegaconf import DictConfig
from pytorch3d.common.linear_with_repeat import LinearWithRepeat
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
from pytorch3d.implicitron.third_party import hyperlayers, pytorch_prototyping
from pytorch3d.implicitron.tools.config import Configurable, registry, run_auto_creation
from pytorch3d.renderer import ray_bundle_to_ray_points, RayBundle
from pytorch3d.renderer import ray_bundle_to_ray_points
from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.renderer.implicit import HarmonicEmbedding

Expand Down Expand Up @@ -68,15 +69,15 @@ def __post_init__(self):

def forward(
self,
ray_bundle: RayBundle,
ray_bundle: ImplicitronRayBundle,
fun_viewpool=None,
camera: Optional[CamerasBase] = None,
global_code=None,
**kwargs,
):
"""
Args:
ray_bundle: A RayBundle object containing the following variables:
ray_bundle: An ImplicitronRayBundle object containing the following variables:
origins: A tensor of shape `(minibatch, ..., 3)` denoting the
origins of the sampling rays in world coords.
directions: A tensor of shape `(minibatch, ..., 3)`
Expand All @@ -96,10 +97,11 @@ def forward(
"""
# We first convert the ray parametrizations to world
# coordinates with `ray_bundle_to_ray_points`.
# pyre-ignore[6]
rays_points_world = ray_bundle_to_ray_points(ray_bundle)

embeds = create_embeddings_for_implicit_function(
xyz_world=ray_bundle_to_ray_points(ray_bundle),
xyz_world=rays_points_world,
# pyre-fixme[6]: Expected `Optional[typing.Callable[..., typing.Any]]`
# for 2nd param but got `Union[torch.Tensor, torch.nn.Module]`.
xyz_embedding_function=self._harmonic_embedding,
Expand Down Expand Up @@ -175,15 +177,15 @@ def _get_colors(self, features: torch.Tensor, rays_directions: torch.Tensor):
def forward(
self,
raymarch_features: torch.Tensor,
ray_bundle: RayBundle,
ray_bundle: ImplicitronRayBundle,
camera: Optional[CamerasBase] = None,
**kwargs,
):
"""
Args:
raymarch_features: Features from the raymarching network of shape
`(minibatch, ..., self.in_features)`
ray_bundle: A RayBundle object containing the following variables:
ray_bundle: An ImplicitronRayBundle object containing the following variables:
origins: A tensor of shape `(minibatch, ..., 3)` denoting the
origins of the sampling rays in world coords.
directions: A tensor of shape `(minibatch, ..., 3)`
Expand Down Expand Up @@ -297,7 +299,7 @@ def _run_hypernet(self, global_code: torch.Tensor) -> Tuple[SRNRaymarchFunction]

def forward(
self,
ray_bundle: RayBundle,
ray_bundle: ImplicitronRayBundle,
fun_viewpool=None,
camera: Optional[CamerasBase] = None,
global_code=None,
Expand Down Expand Up @@ -350,7 +352,7 @@ def raymarch_function_tweak_args(cls, type, args: DictConfig) -> None:
def forward(
self,
*,
ray_bundle: RayBundle,
ray_bundle: ImplicitronRayBundle,
fun_viewpool=None,
camera: Optional[CamerasBase] = None,
global_code=None,
Expand Down Expand Up @@ -410,7 +412,7 @@ def hypernet_tweak_args(cls, type, args: DictConfig) -> None:
def forward(
self,
*,
ray_bundle: RayBundle,
ray_bundle: ImplicitronRayBundle,
fun_viewpool=None,
camera: Optional[CamerasBase] = None,
global_code=None,
Expand Down
7 changes: 4 additions & 3 deletions pytorch3d/implicitron/models/implicit_function/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@

import torch.nn.functional as F
from pytorch3d.common.compat import prod
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
from pytorch3d.renderer import ray_bundle_to_ray_points
from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.renderer.implicit import RayBundle


def broadcast_global_code(embeds: torch.Tensor, global_code: torch.Tensor):
Expand Down Expand Up @@ -190,15 +190,15 @@ def interpolate_volume(


def get_rays_points_world(
ray_bundle: Optional[RayBundle] = None,
ray_bundle: Optional[ImplicitronRayBundle] = None,
rays_points_world: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Converts the ray_bundle to rays_points_world if rays_points_world is not defined
and raises error if both are defined.
Args:
ray_bundle: A RayBundle object or None
ray_bundle: An ImplicitronRayBundle object or None
rays_points_world: A torch.Tensor representing ray points converted to
world coordinates
Returns:
Expand All @@ -213,5 +213,6 @@ def get_rays_points_world(
if rays_points_world is not None:
return rays_points_world
if ray_bundle is not None:
# pyre-ignore[6]
return ray_bundle_to_ray_points(ray_bundle)
raise ValueError("ray_bundle and rays_points_world cannot both be None")
44 changes: 42 additions & 2 deletions pytorch3d/implicitron/models/renderer/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

from __future__ import annotations

import dataclasses

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
Expand All @@ -25,6 +27,38 @@ class RenderSamplingMode(Enum):
FULL_GRID = "full_grid"


@dataclasses.dataclass
class ImplicitronRayBundle:
"""
Parametrizes points along projection rays by storing ray `origins`,
`directions` vectors and `lengths` at which the ray-points are sampled.
Furthermore, the xy-locations (`xys`) of the ray pixels are stored as well.
Note that `directions` don't have to be normalized; they define unit vectors
in the respective 1D coordinate systems; see documentation for
:func:`ray_bundle_to_ray_points` for the conversion formula.
camera_ids: A tensor of shape (N, ) which indicates which camera
was used to sample the rays. `N` is the number of different
sampled cameras.
camera_counts: A tensor of shape (N, ) which how many times the
coresponding camera in `camera_ids` was sampled.
`sum(camera_counts)==minibatch`
"""

origins: torch.Tensor
directions: torch.Tensor
lengths: torch.Tensor
xys: torch.Tensor
camera_ids: Optional[torch.Tensor] = None
camera_counts: Optional[torch.Tensor] = None

def is_packed(self) -> bool:
"""
Returns whether the ImplicitronRayBundle carries data in packed state
"""
return self.camera_ids is not None and self.camera_counts is not None


@dataclass
class RendererOutput:
"""
Expand Down Expand Up @@ -85,7 +119,7 @@ def requires_object_mask(self) -> bool:
@abstractmethod
def forward(
self,
ray_bundle,
ray_bundle: ImplicitronRayBundle,
implicit_functions: List[ImplicitFunctionWrapper],
evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION,
**kwargs,
Expand All @@ -95,7 +129,7 @@ def forward(
that returns an instance of RendererOutput.
Args:
ray_bundle: A RayBundle object containing the following variables:
ray_bundle: An ImplicitronRayBundle object containing the following variables:
origins: A tensor of shape (minibatch, ..., 3) denoting
the origins of the rendering rays.
directions: A tensor of shape (minibatch, ..., 3)
Expand All @@ -108,6 +142,12 @@ def forward(
xys: A tensor of shape
(minibatch, ..., 2) containing the
xy locations of each ray's pixel in the NDC screen space.
camera_ids: A tensor of shape (N, ) which indicates which camera
was used to sample the rays. `N` is the number of different
sampled cameras.
camera_counts: A tensor of shape (N, ) which how many times the
coresponding camera in `camera_ids` was sampled.
`sum(camera_counts)==minibatch`
implicit_functions: List of ImplicitFunctionWrappers which define the
implicit function methods to be used. Most Renderers only allow
a single implicit function. Currently, only the
Expand Down
Loading

0 comments on commit ad8907d

Please sign in to comment.