Skip to content

Commit

Permalink
Avoid raysampler dict
Browse files Browse the repository at this point in the history
Summary:
A significant speedup (e.g. >2% of a forward pass).

Move NDCMultinomialRaysampler parts of AbstractMaskRaySampler to members instead of living in a dict. The dict was hiding them from the nn.Module system so their _xy_grid members were remaining on the CPU. Therefore they were being copied to the GPU in every forward pass.

(We couldn't easily use a ModuleDict here because the enum keys are not strs.)

Reviewed By: shapovalov

Differential Revision: D39668589

fbshipit-source-id: 719b88e4a08fd7263a284e0ab38189e666bd7e3a
  • Loading branch information
bottler authored and facebook-github-bot committed Sep 21, 2022
1 parent da7fe28 commit 305cf32
Showing 1 changed file with 32 additions and 32 deletions.
64 changes: 32 additions & 32 deletions pytorch3d/implicitron/models/renderer/ray_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,34 +100,32 @@ def __post_init__(self):
),
}

self._raysamplers = {
EvaluationMode.TRAINING: NDCMultinomialRaysampler(
image_width=self.image_width,
image_height=self.image_height,
n_pts_per_ray=self.n_pts_per_ray_training,
min_depth=0.0,
max_depth=0.0,
n_rays_per_image=self.n_rays_per_image_sampled_from_mask
if self._sampling_mode[EvaluationMode.TRAINING]
== RenderSamplingMode.MASK_SAMPLE
else None,
unit_directions=True,
stratified_sampling=self.stratified_point_sampling_training,
),
EvaluationMode.EVALUATION: NDCMultinomialRaysampler(
image_width=self.image_width,
image_height=self.image_height,
n_pts_per_ray=self.n_pts_per_ray_evaluation,
min_depth=0.0,
max_depth=0.0,
n_rays_per_image=self.n_rays_per_image_sampled_from_mask
if self._sampling_mode[EvaluationMode.EVALUATION]
== RenderSamplingMode.MASK_SAMPLE
else None,
unit_directions=True,
stratified_sampling=self.stratified_point_sampling_evaluation,
),
}
self._training_raysampler = NDCMultinomialRaysampler(
image_width=self.image_width,
image_height=self.image_height,
n_pts_per_ray=self.n_pts_per_ray_training,
min_depth=0.0,
max_depth=0.0,
n_rays_per_image=self.n_rays_per_image_sampled_from_mask
if self._sampling_mode[EvaluationMode.TRAINING]
== RenderSamplingMode.MASK_SAMPLE
else None,
unit_directions=True,
stratified_sampling=self.stratified_point_sampling_training,
)
self._evaluation_raysampler = NDCMultinomialRaysampler(
image_width=self.image_width,
image_height=self.image_height,
n_pts_per_ray=self.n_pts_per_ray_evaluation,
min_depth=0.0,
max_depth=0.0,
n_rays_per_image=self.n_rays_per_image_sampled_from_mask
if self._sampling_mode[EvaluationMode.EVALUATION]
== RenderSamplingMode.MASK_SAMPLE
else None,
unit_directions=True,
stratified_sampling=self.stratified_point_sampling_evaluation,
)

def _get_min_max_depth_bounds(self, cameras: CamerasBase) -> Tuple[float, float]:
raise NotImplementedError()
Expand Down Expand Up @@ -169,11 +167,13 @@ def forward(

min_depth, max_depth = self._get_min_max_depth_bounds(cameras)

raysampler = {
EvaluationMode.TRAINING: self._training_raysampler,
EvaluationMode.EVALUATION: self._evaluation_raysampler,
}[evaluation_mode]

# pyre-fixme[29]:
# `Union[BoundMethod[typing.Callable(torch.Tensor.__getitem__)[[Named(self,
# torch.Tensor), Named(item, typing.Any)], typing.Any], torch.Tensor],
# torch.Tensor, torch.nn.Module]` is not a function.
ray_bundle = self._raysamplers[evaluation_mode](
ray_bundle = raysampler(
cameras=cameras,
mask=sample_mask,
min_depth=min_depth,
Expand Down

0 comments on commit 305cf32

Please sign in to comment.