def save_stats(stats, fl, cfg=None):
    """
    Save `stats` to the stats file associated with the checkpoint path `fl`.

    Args:
        stats: Object exposing a `save(path)` method; it is asked to write
            itself to the path derived from `fl`.
        fl: Checkpoint path that `get_stats_path` turns into the location
            of the stats file.
        cfg: Unused by this function.
            NOTE(review): presumably kept so the signature mirrors
            `save_model(..., cfg=None)` -- confirm with the callers.

    Returns:
        The path the stats were written to.
    """
    flstats = get_stats_path(fl)
    # Hand the path to the logger as an argument so the message is only
    # formatted when the INFO level is actually enabled.
    logger.info("saving model stats to %s", flstats)
    stats.save(flstats)
    return flstats
Returns: volume_features: Output volume of shape `(minibatch, feature_dim, D, H, W)` volume_densities: Occupancy volume of shape `(minibatch, 1, D, H, W)` @@ -409,7 +412,7 @@ def add_points_features_to_volume_densities_features( grid_sizes, 1.0, # point_weight mask, - True, # align_corners + align_corners, # align_corners splat, ) diff --git a/pytorch3d/renderer/implicit/renderer.py b/pytorch3d/renderer/implicit/renderer.py index 98fa5d285..ffd7578e4 100644 --- a/pytorch3d/renderer/implicit/renderer.py +++ b/pytorch3d/renderer/implicit/renderer.py @@ -382,9 +382,9 @@ def forward( rays_densities = torch.nn.functional.grid_sample( volumes_densities, rays_points_local_flat, - align_corners=True, mode=self._sample_mode, padding_mode=self._padding_mode, + align_corners=self._volumes.get_align_corners(), ) # permute the dimensions & reshape densities after sampling @@ -400,9 +400,9 @@ def forward( rays_features = torch.nn.functional.grid_sample( volumes_features, rays_points_local_flat, - align_corners=True, mode=self._sample_mode, padding_mode=self._padding_mode, + align_corners=self._volumes.get_align_corners(), ) # permute the dimensions & reshape features after sampling diff --git a/pytorch3d/structures/volumes.py b/pytorch3d/structures/volumes.py index 93784caff..23ed743da 100644 --- a/pytorch3d/structures/volumes.py +++ b/pytorch3d/structures/volumes.py @@ -85,7 +85,7 @@ class Volumes: are linearly interpolated over the spatial dimensions of the volume. - Note that the convention is the same as for the 5D version of the `torch.nn.functional.grid_sample` function called with - `align_corners==True`. + the same value of the `align_corners` argument. 
def get_align_corners(self) -> bool:
    """
    Return the `align_corners` flag stored on this volume's locator.

    The value is read from `self.locator._align_corners`. It is the flag
    to pass as `align_corners` to `torch.nn.functional.grid_sample` when
    sampling this volume with its local coordinates.

    Returns:
        The boolean `align_corners` setting of the underlying locator.
    """
    locator = self.locator
    return locator._align_corners
- Note that the local coordinate convention of `VolumeLocator` (+X = left to right, +Y = top to bottom, +Z = away from the user) is *different* from the world coordinate convention of the @@ -634,7 +647,7 @@ class VolumeLocator: torch.nn.functional.grid_sample( v.densities(), v.get_coord_grid(world_coordinates=False), - align_corners=True, + align_corners=align_corners, ) == v.densities(), i.e. sampling the volume at trivial local coordinates @@ -651,6 +664,7 @@ def __init__( device: torch.device, voxel_size: _VoxelSize = 1.0, volume_translation: _Translation = (0.0, 0.0, 0.0), + align_corners: bool = True, ): """ **batch_size** : Batch size of the underlying grids @@ -674,15 +688,21 @@ def __init__( b) a Tensor of shape (3,) c) a Tensor of shape (minibatch, 3) d) a Tensor of shape (1,) (square voxels) + **align_corners**: If set (default), the centers of the corner voxels are at + exactly −1 or +1 in the local coordinate system. Otherwise, −1 and +1 are the + outer faces of the corner voxels (centers lie half a voxel inside). Cf. the namesake argument to + `torch.nn.functional.grid_sample`. 
""" self.device = device self._batch_size = batch_size self._grid_sizes = self._convert_grid_sizes2tensor(grid_sizes) self._resolution = tuple(torch.max(self._grid_sizes.cpu(), dim=0).values) + self._align_corners = align_corners # set the local_to_world transform self._set_local_to_world_transform( - voxel_size=voxel_size, volume_translation=volume_translation + voxel_size=voxel_size, + volume_translation=volume_translation, ) def _convert_grid_sizes2tensor( @@ -806,8 +826,17 @@ def _calculate_coordinate_grid( grid_sizes = self.get_grid_sizes() # generate coordinate axes + def corner_coord_adjustment(r): + return 0.0 if self._align_corners else 1.0 / r + vol_axes = [ - torch.linspace(-1.0, 1.0, r, dtype=torch.float32, device=self.device) + torch.linspace( + -1.0 + corner_coord_adjustment(r), + 1.0 - corner_coord_adjustment(r), + r, + dtype=torch.float32, + device=self.device, + ) for r in (de, he, wi) ] diff --git a/tests/test_volumes.py b/tests/test_volumes.py index fc6c7a401..76a30413f 100644 --- a/tests/test_volumes.py +++ b/tests/test_volumes.py @@ -312,6 +312,49 @@ def test_coord_grid_convention( ).permute(0, 2, 3, 4, 1) self.assertClose(grid_world_resampled, grid_world, atol=1e-7) + for align_corners in [True, False]: + v_trivial = Volumes(densities=densities, align_corners=align_corners) + + # check the case with x_world=(0,0,0) + pts_world = torch.zeros( + num_volumes, 1, 3, device=device, dtype=torch.float32 + ) + pts_local = v_trivial.world_to_local_coords(pts_world) + pts_local_expected = torch.zeros_like(pts_local) + self.assertClose(pts_local, pts_local_expected) + + # check the case with x_world=(-2, 3, -2) + pts_world_tuple = [-2, 3, -2] + pts_world = torch.tensor( + pts_world_tuple, device=device, dtype=torch.float32 + )[None, None].repeat(num_volumes, 1, 1) + pts_local = v_trivial.world_to_local_coords(pts_world) + pts_local_expected = torch.tensor( + [-1, 1, -1], device=device, dtype=torch.float32 + )[None, None].repeat(num_volumes, 1, 1) + 
self.assertClose(pts_local, pts_local_expected) + + # check that the central voxel has coords x_world=(0, 0, 0) and x_local=(0, 0, 0) + grid_world = v_trivial.get_coord_grid(world_coordinates=True) + grid_local = v_trivial.get_coord_grid(world_coordinates=False) + for grid in (grid_world, grid_local): + x0 = grid[0, :, :, 2, 0] + y0 = grid[0, :, 3, :, 1] + z0 = grid[0, 2, :, :, 2] + for coord_line in (x0, y0, z0): + self.assertClose( + coord_line, torch.zeros_like(coord_line), atol=1e-7 + ) + + # resample grid_world using grid_sampler with local coords + # -> make sure the resampled version is the same as original + grid_world_resampled = torch.nn.functional.grid_sample( + grid_world.permute(0, 4, 1, 2, 3), + grid_local, + align_corners=align_corners, + ).permute(0, 2, 3, 4, 1) + self.assertClose(grid_world_resampled, grid_world, atol=1e-7) + def test_coord_grid_convention_heterogeneous( self, num_channels=4, dtype=torch.float32 ):