Skip to content

Commit

Permalink
Enable mixed frame raysampling
Browse files Browse the repository at this point in the history
Summary:
Changed ray_sampler and metrics to be able to use mixed frame raysampling.

Ray_sampler now has a new member which it passes to the pytorch3d raysampler.
If the raybundle is heterogeneous metrics now samples images by padding xys first. This reduces memory consumption.

Reviewed By: bottler, kjchalup

Differential Revision: D39542221

fbshipit-source-id: a6fec23838d3049ae5c2fd2e1f641c46c7c927e3
  • Loading branch information
Darijan Gudelj authored and facebook-github-bot committed Oct 3, 2022
1 parent ad8907d commit c311a4c
Show file tree
Hide file tree
Showing 8 changed files with 102 additions and 35 deletions.
1 change: 1 addition & 0 deletions projects/implicitron_trainer/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ def run(self) -> None:
train_loader=train_loader,
val_loader=val_loader,
test_loader=test_loader,
train_dataset=datasets.train,
model=model,
optimizer=optimizer,
scheduler=scheduler,
Expand Down
2 changes: 2 additions & 0 deletions projects/implicitron_trainer/tests/experiment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ model_factory_ImplicitronModelFactory_args:
n_pts_per_ray_training: 64
n_pts_per_ray_evaluation: 64
n_rays_per_image_sampled_from_mask: 1024
n_rays_total_training: null
stratified_point_sampling_training: true
stratified_point_sampling_evaluation: false
scene_extent: 8.0
Expand All @@ -208,6 +209,7 @@ model_factory_ImplicitronModelFactory_args:
n_pts_per_ray_training: 64
n_pts_per_ray_evaluation: 64
n_rays_per_image_sampled_from_mask: 1024
n_rays_total_training: null
stratified_point_sampling_training: true
stratified_point_sampling_evaluation: false
min_depth: 0.1
Expand Down
9 changes: 8 additions & 1 deletion pytorch3d/implicitron/models/generic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,7 @@ def curried_viewpooler(pts):
self.view_metrics(
results=preds,
raymarched=rendered,
xys=ray_bundle.xys,
ray_bundle=ray_bundle,
image_rgb=safe_slice_targets(image_rgb),
depth_map=safe_slice_targets(depth_map),
fg_probability=safe_slice_targets(fg_probability),
Expand Down Expand Up @@ -932,6 +932,11 @@ def _chunk_generator(
if len(iter) >= tqdm_trigger_threshold:
iter = tqdm.tqdm(iter)

def _safe_slice(
tensor: Optional[torch.Tensor], start_idx: int, end_idx: int
) -> Optional[torch.Tensor]:
return tensor[start_idx:end_idx] if tensor is not None else None

for start_idx in iter:
end_idx = min(start_idx + chunk_size_in_rays, n_rays)
ray_bundle_chunk = ImplicitronRayBundle(
Expand All @@ -943,6 +948,8 @@ def _chunk_generator(
:, start_idx:end_idx
],
xys=ray_bundle.xys.reshape(batch_size, -1, 2)[:, start_idx:end_idx],
camera_ids=_safe_slice(ray_bundle.camera_ids, start_idx, end_idx),
camera_counts=_safe_slice(ray_bundle.camera_counts, start_idx, end_idx),
)
extra_args = kwargs.copy()
for k, v in chunked_inputs.items():
Expand Down
60 changes: 46 additions & 14 deletions pytorch3d/implicitron/models/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@
from typing import Any, Dict, Optional

import torch
from pytorch3d.implicitron.models.renderer.ray_sampler import ImplicitronRayBundle
from pytorch3d.implicitron.tools import metric_utils as utils
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
from pytorch3d.ops import packed_to_padded, padded_to_packed
from pytorch3d.renderer import utils as rend_utils

from .renderer.base import RendererOutput
Expand Down Expand Up @@ -60,7 +62,7 @@ def __post_init__(self) -> None:
def forward(
self,
raymarched: RendererOutput,
xys: torch.Tensor,
ray_bundle: ImplicitronRayBundle,
image_rgb: Optional[torch.Tensor] = None,
depth_map: Optional[torch.Tensor] = None,
fg_probability: Optional[torch.Tensor] = None,
Expand All @@ -79,10 +81,8 @@ def forward(
names of the output metrics `metric_name_i` with their corresponding
values `metric_value_i` represented as 0-dimensional float tensors.
raymarched: Output of the renderer.
xys: A tensor of shape `(B, ..., 2)` containing 2D image locations at which
the predictions are defined. All ground truth inputs are sampled at
these locations in order to extract values that correspond to the
predictions.
ray_bundle: ImplicitronRayBundle object which was used to produce the raymarched
object
image_rgb: A tensor of shape `(B, H, W, 3)` containing ground truth rgb
values.
depth_map: A tensor of shape `(B, Hd, Wd, 1)` containing ground truth depth
Expand Down Expand Up @@ -141,7 +141,7 @@ class ViewMetrics(ViewMetricsBase):
def forward(
self,
raymarched: RendererOutput,
xys: torch.Tensor,
ray_bundle: ImplicitronRayBundle,
image_rgb: Optional[torch.Tensor] = None,
depth_map: Optional[torch.Tensor] = None,
fg_probability: Optional[torch.Tensor] = None,
Expand All @@ -165,10 +165,8 @@ def forward(
input 3D coordinates used to compute the eikonal loss.
raymarched.aux["density_grid"]: A tensor of shape `(B, Hg, Wg, Dg, 1)`
containing a `Hg x Wg x Dg` voxel grid of density values.
xys: A tensor of shape `(B, ..., 2)` containing 2D image locations at which
the predictions are defined. All ground truth inputs are sampled at
these locations in order to extract values that correspond to the
predictions.
ray_bundle: ImplicitronRayBundle object which was used to produce the raymarched
object
image_rgb: A tensor of shape `(B, H, W, 3)` containing ground truth rgb
values.
depth_map: A tensor of shape `(B, Hd, Wd, 1)` containing ground truth depth
Expand Down Expand Up @@ -209,7 +207,7 @@ def forward(
"""
metrics = self._calculate_stage(
raymarched,
xys,
ray_bundle,
image_rgb,
depth_map,
fg_probability,
Expand All @@ -221,7 +219,7 @@ def forward(
metrics.update(
self(
raymarched.prev_stage,
xys,
ray_bundle,
image_rgb,
depth_map,
fg_probability,
Expand All @@ -235,7 +233,7 @@ def forward(
def _calculate_stage(
self,
raymarched: RendererOutput,
xys: torch.Tensor,
ray_bundle: ImplicitronRayBundle,
image_rgb: Optional[torch.Tensor] = None,
depth_map: Optional[torch.Tensor] = None,
fg_probability: Optional[torch.Tensor] = None,
Expand All @@ -253,6 +251,27 @@ def _calculate_stage(
_reshape_nongrid_var(x)
for x in [raymarched.features, raymarched.masks, raymarched.depths]
]
xys = ray_bundle.xys

# If ray_bundle is packed than we can sample images in padded state to lower
# memory requirements. Instead of having one image for every element in
# ray_bundle we can than have one image per unique sampled camera.
if ray_bundle.is_packed():
# pyre-ignore[6]
cumsum = torch.cumsum(ray_bundle.camera_counts, dim=0, dtype=torch.long)
first_idxs = torch.cat(
(
# pyre-ignore[16]
ray_bundle.camera_counts.new_zeros((1,), dtype=torch.long),
cumsum[:-1],
)
)
# pyre-ignore[16]
num_inputs = int(ray_bundle.camera_counts.sum())
# pyre-ignore[6]
max_size = int(torch.max(ray_bundle.camera_counts))
xys = packed_to_padded(xys, first_idxs, max_size)

# reshape the sampling grid as well
# TODO: we can get rid of the singular dimension here and in _reshape_nongrid_var
# now that we use rend_utils.ndc_grid_sample
Expand All @@ -262,7 +281,20 @@ def _calculate_stage(
def sample(tensor, mode):
if tensor is None:
return tensor
return rend_utils.ndc_grid_sample(tensor, xys, mode=mode)
if ray_bundle.is_packed():
# select images that corespond to sampled cameras if raybundle is packed
tensor = tensor[ray_bundle.camera_ids]
result = rend_utils.ndc_grid_sample(tensor, xys, mode=mode)
if ray_bundle.is_packed():
# Images after sampling are in a form [batch, 3, max_num_rays, 1],
# packed_to_padded combines first two dimensions so we need to swap 1st
# and 2nd dimension. the result is [n_rays_total_training, 1, 3, 1]
# (we use keepdim=True).
result = result.transpose(1, 2)
result = padded_to_packed(result, first_idxs, num_inputs)[:, None]
result = result.transpose(1, 2)

return result

# eval all results in this size
image_rgb = sample(image_rgb, mode="bilinear")
Expand Down
20 changes: 9 additions & 11 deletions pytorch3d/implicitron/models/renderer/ray_point_refiner.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
# LICENSE file in the root directory of this source tree.

import torch
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
from pytorch3d.implicitron.tools.config import Configurable, expand_args_fields
from pytorch3d.renderer import RayBundle

from pytorch3d.renderer.implicit.sample_pdf import sample_pdf


Expand Down Expand Up @@ -42,21 +43,21 @@ def __post_init__(self) -> None:

def forward(
self,
input_ray_bundle: RayBundle,
input_ray_bundle: ImplicitronRayBundle,
ray_weights: torch.Tensor,
**kwargs,
) -> RayBundle:
) -> ImplicitronRayBundle:
"""
Args:
input_ray_bundle: An instance of `RayBundle` specifying the
input_ray_bundle: An instance of `ImplicitronRayBundle` specifying the
source rays for sampling of the probability distribution.
ray_weights: A tensor of shape
`(..., input_ray_bundle.legths.shape[-1])` with non-negative
elements defining the probability distribution to sample
ray points from.
Returns:
ray_bundle: A new `RayBundle` instance containing the input ray
ray_bundle: A new `ImplicitronRayBundle` instance containing the input ray
points together with `n_pts_per_ray` additionally sampled
points per ray. For each ray, the lengths are sorted.
"""
Expand All @@ -79,9 +80,6 @@ def forward(
# Resort by depth.
z_vals, _ = torch.sort(z_vals, dim=-1)

return RayBundle(
origins=input_ray_bundle.origins,
directions=input_ray_bundle.directions,
lengths=z_vals,
xys=input_ray_bundle.xys,
)
new_bundle = ImplicitronRayBundle(**vars(input_ray_bundle))
new_bundle.lengths = z_vals
return new_bundle
25 changes: 23 additions & 2 deletions pytorch3d/implicitron/models/renderer/ray_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,17 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
sampling_mode_evaluation: Same as above but for evaluation.
n_pts_per_ray_training: The number of points sampled along each ray during training.
n_pts_per_ray_evaluation: The number of points sampled along each ray during evaluation.
n_rays_per_image_sampled_from_mask: The amount of rays to be sampled from the image grid
n_rays_per_image_sampled_from_mask: The amount of rays to be sampled from the image
grid. Given a batch of image grids, this many is sampled from each.
`n_rays_per_image_sampled_from_mask` and `n_rays_total_training` cannot both be
defined.
n_rays_total_training: (optional) How many rays in total to sample from the entire
batch of provided image grid. The result is as if `n_rays_total_training`
cameras/image grids were sampled with replacement from the cameras / image grids
provided and for every camera one ray was sampled.
`n_rays_per_image_sampled_from_mask` and `n_rays_total_training` cannot both be
defined, to use you have to set `n_rays_per_image` to None.
Used only for EvaluationMode.TRAINING.
stratified_point_sampling_training: if set, performs stratified random sampling
along the ray; otherwise takes ray points at deterministic offsets.
stratified_point_sampling_evaluation: Same as above but for evaluation.
Expand All @@ -85,14 +95,23 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
sampling_mode_evaluation: str = "full_grid"
n_pts_per_ray_training: int = 64
n_pts_per_ray_evaluation: int = 64
n_rays_per_image_sampled_from_mask: int = 1024
n_rays_per_image_sampled_from_mask: Optional[int] = 1024
n_rays_total_training: Optional[int] = None
# stratified sampling vs taking points at deterministic offsets
stratified_point_sampling_training: bool = True
stratified_point_sampling_evaluation: bool = False

def __post_init__(self):
super().__init__()

if (self.n_rays_per_image_sampled_from_mask is not None) and (
self.n_rays_total_training is not None
):
raise ValueError(
"Cannot both define n_rays_total_training and "
"n_rays_per_image_sampled_from_mask."
)

self._sampling_mode = {
EvaluationMode.TRAINING: RenderSamplingMode(self.sampling_mode_training),
EvaluationMode.EVALUATION: RenderSamplingMode(
Expand All @@ -110,9 +129,11 @@ def __post_init__(self):
if self._sampling_mode[EvaluationMode.TRAINING]
== RenderSamplingMode.MASK_SAMPLE
else None,
n_rays_total=self.n_rays_total_training,
unit_directions=True,
stratified_sampling=self.stratified_point_sampling_training,
)

self._evaluation_raysampler = NDCMultinomialRaysampler(
image_width=self.image_width,
image_height=self.image_height,
Expand Down
19 changes: 12 additions & 7 deletions pytorch3d/renderer/implicit/raysampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,11 +90,12 @@ def __init__(
min_depth: The minimum depth of a ray-point.
max_depth: The maximum depth of a ray-point.
n_rays_per_image: If given, this amount of rays are sampled from the grid.
`n_rays_per_image` and `n_rays_total` cannot both be defined.
n_rays_total: How many rays in total to sample from the cameras provided. The result
is as if `n_rays_total` cameras were sampled with replacement from the
cameras provided and for every camera one ray was sampled. If set, this disables
`n_rays_per_image` and returns the HeterogeneousRayBundle with
batch_size=n_rays_total.
is as if `n_rays_total_training` cameras were sampled with replacement from the
cameras provided and for every camera one ray was sampled. If set returns the
HeterogeneousRayBundle with batch_size=n_rays_total.
`n_rays_per_image` and `n_rays_total` cannot both be defined.
unit_directions: whether to normalize direction vectors in ray bundle.
stratified_sampling: if True, performs stratified random sampling
along the ray; otherwise takes ray points at deterministic offsets.
Expand Down Expand Up @@ -144,13 +145,15 @@ def forward(
min_depth: The minimum depth of a ray-point.
max_depth: The maximum depth of a ray-point.
n_rays_per_image: If given, this amount of rays are sampled from the grid.
`n_rays_per_image` and `n_rays_total` cannot both be defined.
n_pts_per_ray: The number of points sampled along each ray.
stratified_sampling: if set, overrides stratified_sampling provided
in __init__.
n_rays_total: How many rays in total to sample from the cameras provided. The result
is as if `n_rays_total_training` cameras were sampled with replacement from the
cameras provided and for every camera one ray was sampled. If set, returns the
cameras provided and for every camera one ray was sampled. If set returns the
HeterogeneousRayBundle with batch_size=n_rays_total.
`n_rays_per_image` and `n_rays_total` cannot both be defined.
Returns:
A named tuple RayBundle or dataclass HeterogeneousRayBundle with the
following fields:
Expand Down Expand Up @@ -352,13 +355,15 @@ def __init__(
min_y: The smallest y-coordinate of each ray's source pixel.
max_y: The largest y-coordinate of each ray's source pixel.
n_rays_per_image: The number of rays randomly sampled in each camera.
`n_rays_per_image` and `n_rays_total` cannot both be defined.
n_pts_per_ray: The number of points sampled along each ray.
min_depth: The minimum depth of each ray-point.
max_depth: The maximum depth of each ray-point.
n_rays_total: How many rays in total to sample from the cameras provided. The result
is as if `n_rays_total_training` cameras were sampled with replacement from the
cameras provided and for every camera one ray was sampled. If set, this returns
the HeterogeneousRayBundleyBundle with batch_size=n_rays_total.
cameras provided and for every camera one ray was sampled. If set returns the
HeterogeneousRayBundle with batch_size=n_rays_total.
`n_rays_per_image` and `n_rays_total` cannot both be defined.
unit_directions: whether to normalize direction vectors in ray bundle.
stratified_sampling: if True, performs stratified sampling in n_pts_per_ray
bins for each ray; otherwise takes n_pts_per_ray deterministic points
Expand Down
1 change: 1 addition & 0 deletions tests/implicitron/data/overrides.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ raysampler_AdaptiveRaySampler_args:
n_pts_per_ray_training: 64
n_pts_per_ray_evaluation: 64
n_rays_per_image_sampled_from_mask: 1024
n_rays_total_training: null
stratified_point_sampling_training: true
stratified_point_sampling_evaluation: false
scene_extent: 8.0
Expand Down

0 comments on commit c311a4c

Please sign in to comment.