Skip to content

Commit

Permalink
Improve memory efficiency in VolumeSampler
Browse files Browse the repository at this point in the history
Summary: Avoids use of `torch.cat` operation when rendering a volume by instead issuing multiple calls to `torch.nn.functional.grid_sample`. Density and color tensors can be large.

Reviewed By: bottler

Differential Revision: D40072399

fbshipit-source-id: eb4cd34f6171d54972bbf2877065f973db497de0
  • Loading branch information
khundman authored and facebook-github-bot committed Oct 6, 2022
1 parent 0d8608b commit 4c8338b
Showing 1 changed file with 22 additions and 17 deletions.
39 changes: 22 additions & 17 deletions pytorch3d/renderer/implicit/renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,35 +363,40 @@ def forward(
volumes_densities = self._volumes.densities()
dim_density = volumes_densities.shape[1]
volumes_features = self._volumes.features()
# adjust the volumes_features variable in case we have a feature-less volume
if volumes_features is None:
dim_feature = 0
data_to_sample = volumes_densities
else:
dim_feature = volumes_features.shape[1]
data_to_sample = torch.cat((volumes_densities, volumes_features), dim=1)

# reshape to a size which grid_sample likes
rays_points_local_flat = rays_points_local.view(
rays_points_local.shape[0], -1, 1, 1, 3
)

# run the grid sampler
data_sampled = torch.nn.functional.grid_sample(
data_to_sample,
# run the grid sampler on the volumes densities
rays_densities = torch.nn.functional.grid_sample(
volumes_densities,
rays_points_local_flat,
align_corners=True,
mode=self._sample_mode,
)

# permute the dimensions & reshape after sampling
data_sampled = data_sampled.permute(0, 2, 3, 4, 1).view(
*rays_points_local.shape[:-1], data_sampled.shape[1]
# permute the dimensions & reshape densities after sampling
rays_densities = rays_densities.permute(0, 2, 3, 4, 1).view(
*rays_points_local.shape[:-1], volumes_densities.shape[1]
)

# split back to densities and features
rays_densities, rays_features = data_sampled.split(
[dim_density, dim_feature], dim=-1
)
# if features exist, run grid sampler again on the features densities
if volumes_features is None:
dim_feature = 0
_, rays_features = rays_densities.split([dim_density, dim_feature], dim=-1)
else:
rays_features = torch.nn.functional.grid_sample(
volumes_features,
rays_points_local_flat,
align_corners=True,
mode=self._sample_mode,
)

# permute the dimensions & reshape features after sampling
rays_features = rays_features.permute(0, 2, 3, 4, 1).view(
*rays_points_local.shape[:-1], volumes_features.shape[1]
)

return rays_densities, rays_features

0 comments on commit 4c8338b

Please sign in to comment.