Skip to content

Commit

Permalink
voxel_grid_implicit_function scaffold fixes
Browse files Browse the repository at this point in the history
Summary: Fix indexing of directions after filtering of points by scaffold.

Reviewed By: shapovalov

Differential Revision: D40853482

fbshipit-source-id: 9cfdb981e97cb82edcd27632c5848537ed2c6837
  • Loading branch information
bottler authored and facebook-github-bot committed Nov 3, 2022
1 parent e4a3298 commit a1f2ded
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import math
import warnings
from dataclasses import fields
from typing import Callable, Dict, Optional, Tuple, Union
from typing import Callable, Dict, Optional, Tuple

import torch

Expand Down Expand Up @@ -118,11 +118,11 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
the calculation.)
scaffold_resolution (Tuple[int, int, int]): (width, height, depth) of the underlying
voxel grid which stores scaffold
scaffold_empty_space_threshold (float): if `self.get_density` evaluates to less than
scaffold_empty_space_threshold (float): if `self._get_density` evaluates to less than
this it will be considered as empty space and the scaffold at that point would
evaluate as empty space.
scaffold_occupancy_chunk_size (str or int): Number of xy scaffold planes to calculate
at the same time. To calculate the scaffold we need to query `get_density()` at
at the same time. To calculate the scaffold we need to query `_get_density()` at
every voxel, this calculation can be split into scaffold depth number of xy plane
calculations if you want the lowest memory usage, one calculation to calculate the
whole scaffold, but with higher memory footprint or any other number of planes.
Expand Down Expand Up @@ -242,14 +242,16 @@ def forward(
points = ray_bundle_to_ray_points(ray_bundle)
directions = ray_bundle.directions.reshape(-1, 3)
input_shape = points.shape
num_points_per_ray = input_shape[-2]
points = points.view(-1, 3)
non_empty_points = None

# ########## filter the points using the scaffold ########## #
if self._scaffold_ready and self.scaffold_filter_points:
# pyre-ignore[29]
non_empty_points = self.voxel_grid_scaffold(points)[..., 0] > 0
with torch.no_grad():
# pyre-ignore[29]
non_empty_points = self.voxel_grid_scaffold(points)[..., 0] > 0
points = points[non_empty_points]
directions = directions[non_empty_points]
if len(points) == 0:
warnings.warn(
"The scaffold has filtered all the points."
Expand All @@ -262,8 +264,8 @@ def forward(
)

# ########## calculate color and density ########## #
rays_densities, rays_colors = self.calculate_density_and_color(
points, directions, camera
rays_densities, rays_colors = self._calculate_density_and_color(
points, directions, camera, non_empty_points, num_points_per_ray
)

if not (self._scaffold_ready and self.scaffold_filter_points):
Expand All @@ -283,9 +285,8 @@ def forward(
rays_colors_combined = rays_colors.new_zeros(
(math.prod(input_shape[:-1]), rays_colors.shape[-1])
)
# pyre-ignore[61]
assert non_empty_points is not None
rays_densities_combined[non_empty_points] = rays_densities
# pyre-ignore[61]
rays_colors_combined[non_empty_points] = rays_colors

return (
Expand All @@ -294,23 +295,28 @@ def forward(
{},
)

def calculate_density_and_color(
def _calculate_density_and_color(
self,
points: torch.Tensor,
directions: torch.Tensor,
camera: Optional[CamerasBase] = None,
camera: Optional[CamerasBase],
non_empty_points: Optional[torch.Tensor],
num_points_per_ray: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Calculates density and color at `points`.
If enabled use cuda streams.
Args:
points: points at which to calculate density and color.
Tensor of shape [..., 3].
directions: from which directions are the points viewed
Tensor of shape [..., 3].
Tensor of shape [n_points, 3].
directions: from which directions are the points viewed.
One per ray. Tensor of shape [n_rays, 3].
camera: A camera model which will be used to transform the viewing
directions
non_empty_points: indices of points which weren't filtered out;
used for expanding directions
num_points_per_ray: number of points per ray, needed to expand directions.
Returns:
Tuple of color (tensor of shape [..., 3]) and density
(tensor of shape [..., 1])
Expand All @@ -323,20 +329,24 @@ def calculate_density_and_color(
with torch.cuda.stream(other_stream):
# rays_densities.shape =
# [minibatch x n_rays_width x n_rays_height x pts_per_ray x density_dim]
rays_densities = self.get_density(points)
rays_densities = self._get_density(points)

# rays_colors.shape =
# [minibatch x n_rays_width x n_rays_height x pts_per_ray x color_dim]
rays_colors = self.get_color(points, camera, directions)
rays_colors = self._get_color(
points, camera, directions, non_empty_points, num_points_per_ray
)

current_stream.wait_stream(other_stream)
else:
# Same calculation as above, just serial.
rays_densities = self.get_density(points)
rays_colors = self.get_color(points, camera, directions)
rays_densities = self._get_density(points)
rays_colors = self._get_color(
points, camera, directions, non_empty_points, num_points_per_ray
)
return rays_densities, rays_colors

def get_density(self, points: torch.Tensor) -> torch.Tensor:
def _get_density(self, points: torch.Tensor) -> torch.Tensor:
"""
Calculates density at points:
1) Evaluates the voxel grid on points
Expand All @@ -356,11 +366,13 @@ def get_density(self, points: torch.Tensor) -> torch.Tensor:
# shape = [..., density_dim]
return self.decoder_density(harmonic_embedding_density)

def get_color(
def _get_color(
self,
points: torch.Tensor,
camera: Optional[CamerasBase],
directions: torch.Tensor,
non_empty_points: Optional[torch.Tensor],
num_points_per_ray: int,
) -> torch.Tensor:
"""
Calculates color at points using the viewing direction:
Expand All @@ -376,6 +388,9 @@ def get_color(
directions
directions: A tensor of shape `(..., 3)`
containing the direction vectors of sampling rays in world coords.
non_empty_points: indices of points which weren't filtered out;
used for expanding directions
num_points_per_ray: number of points per ray, needed to expand directions.
"""
# ########## transform direction ########## #
if self.xyz_ray_dir_in_camera_coords:
Expand All @@ -400,12 +415,11 @@ def get_color(
rays_directions_normed
)

n_rays = directions.shape[0]
points_per_ray: int = points.shape[0] // n_rays

harmonic_embedding_dir = torch.repeat_interleave(
harmonic_embedding_dir, points_per_ray, dim=0
harmonic_embedding_dir, num_points_per_ray, dim=0
)
if non_empty_points is not None:
harmonic_embedding_dir = harmonic_embedding_dir[non_empty_points]

# total color embedding is concatenation of the harmonic embedding of voxel grid
# output and harmonic embedding of the normalized direction
Expand Down Expand Up @@ -505,7 +519,7 @@ def _get_scaffold(self, epoch: int) -> bool:
)
for k in range(0, points.shape[-1], chunk_size):
points_in_planes = points[..., k : k + chunk_size]
planes.append(self.get_density(points_in_planes)[..., 0])
planes.append(self._get_density(points_in_planes)[..., 0])

density_cube = torch.cat(planes, dim=-1)
density_cube = torch.nn.functional.max_pool3d(
Expand Down
8 changes: 4 additions & 4 deletions tests/implicitron/test_voxel_grid_implicit_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def new_density(points):
out.append(torch.tensor([[0.0]]))
return torch.cat(out).view(*inshape[:-1], 1).to(device)

func.get_density = new_density
func._get_density = new_density
func._get_scaffold(0)

points = torch.tensor(
Expand Down Expand Up @@ -136,15 +136,15 @@ def new_density(points):
assert torch.all(scaffold(points)), (scaffold(points), points.shape)
return points.sum(dim=-1, keepdim=True)

def new_color(points, camera, directions):
def new_color(points, camera, directions, non_empty_points, num_points_per_ray):
# check if all passed points should be passed here
assert torch.all(scaffold(points)) # , (scaffold(points), points)
return points * 2

# check both computation paths that they contain only points
# which are not in empty space
func.get_density = new_density
func.get_color = new_color
func._get_density = new_density
func._get_color = new_color
func.voxel_grid_scaffold.forward = scaffold
func._scaffold_ready = True

Expand Down

0 comments on commit a1f2ded

Please sign in to comment.