diff --git a/pytorch3d/structures/volumes.py b/pytorch3d/structures/volumes.py index 7f3b66f02..08ad02ee4 100644 --- a/pytorch3d/structures/volumes.py +++ b/pytorch3d/structures/volumes.py @@ -23,6 +23,7 @@ _Translation = _Vector _TensorBatch = Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]] +_ALL_CONTENT: slice = slice(0, None) class Volumes: @@ -65,9 +66,9 @@ class Volumes: VOLUME COORDINATES - Additionally, the `Volumes` class keeps track of the locations of the - centers of the volume cells in the local volume coordinates as well as in - the world coordinates. + Additionally, using the `VolumeLocator` class the `Volumes` class keeps track + of the locations of the centers of the volume cells in the local volume + coordinates as well as in the world coordinates. Local coordinates: - Represent the locations of the volume cells in the local coordinate @@ -125,7 +126,7 @@ class Volumes: appropriate `world_coordinates` argument. Internally, the mapping between `x_local` and `x_world` is represented - as a `Transform3d` object `Volumes._local_to_world_transform`. + as a `Transform3d` object `Volumes.VolumeLocator._local_to_world_transform`. Users can access the relevant transformations with the `Volumes.get_world_to_local_coords_transform()` and `Volumes.get_local_to_world_coords_transform()` @@ -197,21 +198,24 @@ def __init__( # assign to the internal buffers self._densities = densities_ - self._grid_sizes = grid_sizes + + # assign a coordinate transformation member + self.locator = VolumeLocator( + batch_size=len(self), + grid_sizes=grid_sizes, + voxel_size=voxel_size, + volume_translation=volume_translation, + device=self.device, + ) # handle features self._features = None if features is not None: self._set_features(features) - # set the local_to_world transform - self._set_local_to_world_transform( - voxel_size=voxel_size, volume_translation=volume_translation - ) - def _convert_densities_features_to_tensor( self, x: _TensorBatch, var_name: str - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> Tuple[torch.Tensor, torch.LongTensor]: """ Handle the `densities` or `features` arguments to the constructor. """ @@ -251,252 +255,9 @@ def _convert_densities_features_to_tensor( f"{var_name} must be either a list or a tensor with " f"shape (batch_size, {var_name}_dim, H, W, D)." ) + # pyre-ignore[7] return x_tensor, x_shapes - def _voxel_size_translation_to_transform( - self, - voxel_size: torch.Tensor, - volume_translation: torch.Tensor, - batch_size: int, - ) -> Transform3d: - """ - Converts the `voxel_size` and `volume_translation` constructor arguments - to the internal `Transform3d` object `local_to_world_transform`. - """ - volume_size_zyx = self.get_grid_sizes().float() - volume_size_xyz = volume_size_zyx[:, [2, 1, 0]] - - # x_local = ( - # (x_world + volume_translation) / (0.5 * voxel_size) - # ) / (volume_size - 1) - - # x_world = ( - # x_local * (volume_size - 1) * 0.5 * voxel_size - # ) - volume_translation - - local_to_world_transform = Scale( - (volume_size_xyz - 1) * voxel_size * 0.5, device=self.device - ).translate(-volume_translation) - - return local_to_world_transform - - def _handle_voxel_size( - self, voxel_size: _VoxelSize, batch_size: int - ) -> torch.Tensor: - """ - Handle the `voxel_size` argument to the `Volumes` constructor. - """ - err_msg = ( - "voxel_size has to be either a 3-tuple of scalars, or a scalar, or" - " a torch.Tensor of shape (3,) or (1,) or (minibatch, 3) or (minibatch, 1)." 
- ) - if isinstance(voxel_size, (float, int)): - # convert a scalar to a 3-element tensor - voxel_size = torch.full( - (1, 3), voxel_size, device=self.device, dtype=torch.float32 - ) - elif isinstance(voxel_size, torch.Tensor): - if voxel_size.numel() == 1: - # convert a single-element tensor to a 3-element one - voxel_size = voxel_size.view(-1).repeat(3) - elif len(voxel_size.shape) == 2 and ( - voxel_size.shape[0] == batch_size and voxel_size.shape[1] == 1 - ): - voxel_size = voxel_size.repeat(1, 3) - return self._convert_volume_property_to_tensor(voxel_size, batch_size, err_msg) - - def _handle_volume_translation( - self, translation: _Translation, batch_size: int - ) -> torch.Tensor: - """ - Handle the `volume_translation` argument to the `Volumes` constructor. - """ - err_msg = ( - "`volume_translation` has to be either a 3-tuple of scalars, or" - " a Tensor of shape (1,3) or (minibatch, 3) or (3,)`." - ) - return self._convert_volume_property_to_tensor(translation, batch_size, err_msg) - - def _convert_volume_property_to_tensor( - self, x: _Vector, batch_size: int, err_msg: str - ) -> torch.Tensor: - """ - Handle the `volume_translation` or `voxel_size` argument to - the Volumes constructor. - Return a tensor of shape (N, 3) where N is the batch_size. - """ - if isinstance(x, (list, tuple)): - if len(x) != 3: - raise ValueError(err_msg) - x = torch.tensor(x, device=self.device, dtype=torch.float32)[None] - x = x.repeat((batch_size, 1)) - elif isinstance(x, torch.Tensor): - ok = ( - (x.shape[0] == 1 and x.shape[1] == 3) - or (x.shape[0] == 3 and len(x.shape) == 1) - or (x.shape[0] == batch_size and x.shape[1] == 3) - ) - if not ok: - raise ValueError(err_msg) - if x.device != self.device: - x = x.to(self.device) - if x.shape[0] == 3 and len(x.shape) == 1: - x = x[None] - if x.shape[0] == 1: - x = x.repeat((batch_size, 1)) - else: - raise ValueError(err_msg) - - return x - - def get_coord_grid(self, world_coordinates: bool = True) -> torch.Tensor: - """ - Return the 3D coordinate grid of the volumetric grid - in local (`world_coordinates=False`) or world coordinates - (`world_coordinates=True`). - - The grid records location of each center of the corresponding volume voxel. - - Local coordinates are scaled s.t. the values along one side of the - volume are in range [-1, 1]. - - Args: - **world_coordinates**: if `True`, the method - returns the grid in the world coordinates, - otherwise, in local coordinates. - - Returns: - **coordinate_grid**: The grid of coordinates of shape - `(minibatch, depth, height, width, 3)`, where `minibatch`, - `height`, `width` and `depth` are the batch size, height, width - and depth of the volume `features` or `densities`. - """ - # TODO(dnovotny): Implement caching of the coordinate grid. - return self._calculate_coordinate_grid(world_coordinates=world_coordinates) - - def _calculate_coordinate_grid( - self, world_coordinates: bool = True - ) -> torch.Tensor: - """ - Calculate the 3D coordinate grid of the volumetric grid either in - in local (`world_coordinates=False`) or - world coordinates (`world_coordinates=True`) . - """ - - densities = self.densities() - ba, _, de, he, wi = densities.shape - grid_sizes = self.get_grid_sizes() - - # generate coordinate axes - vol_axes = [ - torch.linspace(-1.0, 1.0, r, dtype=torch.float32, device=self.device) - for r in (de, he, wi) - ] - - # generate per-coord meshgrids - Z, Y, X = meshgrid_ij(vol_axes) - - # stack the coord grids ... 
this order matches the coordinate convention - # of torch.nn.grid_sample - vol_coords_local = torch.stack((X, Y, Z), dim=3)[None].repeat(ba, 1, 1, 1, 1) - - # get grid sizes relative to the maximal volume size - grid_sizes_relative = ( - torch.tensor([[de, he, wi]], device=grid_sizes.device, dtype=torch.float32) - - 1 - ) / (grid_sizes - 1).float() - - if (grid_sizes_relative != 1.0).any(): - # if any of the relative sizes != 1.0, adjust the grid - grid_sizes_relative_reshape = grid_sizes_relative[:, [2, 1, 0]][ - :, None, None, None - ] - vol_coords_local *= grid_sizes_relative_reshape - vol_coords_local += grid_sizes_relative_reshape - 1 - - if world_coordinates: - vol_coords = self.local_to_world_coords(vol_coords_local) - else: - vol_coords = vol_coords_local - - return vol_coords - - def get_local_to_world_coords_transform(self) -> Transform3d: - """ - Return a Transform3d object that converts points in the - the local coordinate frame of the volume to world coordinates. - Local volume coordinates are scaled s.t. the coordinates along one - side of the volume are in range [-1, 1]. - - Returns: - **local_to_world_transform**: A Transform3d object converting - points from local coordinates to the world coordinates. - """ - return self._local_to_world_transform - - def get_world_to_local_coords_transform(self) -> Transform3d: - """ - Return a Transform3d object that converts points in the - world coordinates to the local coordinate frame of the volume. - Local volume coordinates are scaled s.t. the coordinates along one - side of the volume are in range [-1, 1]. - - Returns: - **world_to_local_transform**: A Transform3d object converting - points from world coordinates to local coordinates. - """ - return self.get_local_to_world_coords_transform().inverse() - - def world_to_local_coords(self, points_3d_world: torch.Tensor) -> torch.Tensor: - """ - Convert a batch of 3D point coordinates `points_3d_world` of shape - (minibatch, ..., dim) in the world coordinates to - the local coordinate frame of the volume. Local volume - coordinates are scaled s.t. the coordinates along one side of the volume - are in range [-1, 1]. - - Args: - **points_3d_world**: A tensor of shape `(minibatch, ..., 3)` - containing the 3D coordinates of a set of points that will - be converted from the local volume coordinates (ranging - within [-1, 1]) to the world coordinates - defined by the `self.center` and `self.voxel_size` parameters. - - Returns: - **points_3d_local**: `points_3d_world` converted to the local - volume coordinates of shape `(minibatch, ..., 3)`. - """ - pts_shape = points_3d_world.shape - return ( - self.get_world_to_local_coords_transform() - .transform_points(points_3d_world.view(pts_shape[0], -1, 3)) - .view(pts_shape) - ) - - def local_to_world_coords(self, points_3d_local: torch.Tensor) -> torch.Tensor: - """ - Convert a batch of 3D point coordinates `points_3d_local` of shape - (minibatch, ..., dim) in the local coordinate frame of the volume - to the world coordinates. - - Args: - **points_3d_local**: A tensor of shape `(minibatch, ..., 3)` - containing the 3D coordinates of a set of points that will - be converted from the local volume coordinates (ranging - within [-1, 1]) to the world coordinates - defined by the `self.center` and `self.voxel_size` parameters. - - Returns: - **points_3d_world**: `points_3d_local` converted to the world - coordinates of the volume of shape `(minibatch, ..., 3)`. 
- """ - pts_shape = points_3d_local.shape - return ( - self.get_local_to_world_coords_transform() - .transform_points(points_3d_local.view(pts_shape[0], -1, 3)) - .view(pts_shape) - ) - def __len__(self) -> int: return self._densities.shape[0] @@ -530,8 +291,7 @@ def __getitem__( densities=self.densities()[index], ) # dont forget to update grid_sizes! - new._grid_sizes = self.get_grid_sizes()[index] - new._local_to_world_transform = self._local_to_world_transform[index] + self.locator._copy_transform_and_sizes(new.locator, index=index) return new def features(self) -> Optional[torch.Tensor]: @@ -593,16 +353,6 @@ def _features_densities_list(self, x: torch.Tensor) -> List[torch.Tensor]: x_list = struct_utils.padded_to_list(x, pad_sizes.tolist()) return x_list - def get_grid_sizes(self) -> torch.LongTensor: - """ - Returns the sizes of individual volumetric grids in the structure. - - Returns: - **grid_sizes**: Tensor of spatial sizes of each of the volumes - of size (batchsize, 3), where i-th row holds (D_i, H_i, W_i). - """ - return self._grid_sizes - def update_padded( self, new_densities: torch.Tensor, new_features: Optional[torch.Tensor] = None ) -> "Volumes": @@ -656,6 +406,525 @@ def _set_densities_features(self, var_name: str, x: _TensorBatch) -> None: ) setattr(self, "_" + var_name, x_tensor) + def clone(self) -> "Volumes": + """ + Deep copy of Volumes object. All internal tensors are cloned + individually. + + Returns: + new Volumes object. + """ + return copy.deepcopy(self) + + def to(self, device: Device, copy: bool = False) -> "Volumes": + """ + Match the functionality of torch.Tensor.to() + If copy = True or the self Tensor is on a different device, the + returned tensor is a copy of self with the desired torch.device. + If copy = False and the self Tensor already has the correct torch.device, + then self is returned. + + Args: + device: Device (as str or torch.device) for the new tensor. + copy: Boolean indicator whether or not to clone self. Default False. + + Returns: + Volumes object. + """ + device_ = make_device(device) + if not copy and self.device == device_: + return self + + other = self.clone() + if self.device == device_: + return other + + other.device = device_ + other._densities = self._densities.to(device_) + if self._features is not None: + # pyre-fixme[16]: `Optional` has no attribute `to`. + other._features = self.features().to(device_) + self.locator._copy_transform_and_sizes(other.locator, device=device_) + other.locator = other.locator.to(device, copy) + return other + + def cpu(self) -> "Volumes": + return self.to("cpu") + + def cuda(self) -> "Volumes": + return self.to("cuda") + + def get_grid_sizes(self) -> torch.LongTensor: + """ + Returns the sizes of individual volumetric grids in the structure. + + Returns: + **grid_sizes**: Tensor of spatial sizes of each of the volumes + of size (batchsize, 3), where i-th row holds (D_i, H_i, W_i). + """ + return self.locator.get_grid_sizes() + + def get_local_to_world_coords_transform(self) -> Transform3d: + """ + Return a Transform3d object that converts points in the + the local coordinate frame of the volume to world coordinates. + Local volume coordinates are scaled s.t. the coordinates along one + side of the volume are in range [-1, 1]. + + Returns: + **local_to_world_transform**: A Transform3d object converting + points from local coordinates to the world coordinates. 
+ """ + return self.locator.get_local_to_world_coords_transform() + + def get_world_to_local_coords_transform(self) -> Transform3d: + """ + Return a Transform3d object that converts points in the + world coordinates to the local coordinate frame of the volume. + Local volume coordinates are scaled s.t. the coordinates along one + side of the volume are in range [-1, 1]. + + Returns: + **world_to_local_transform**: A Transform3d object converting + points from world coordinates to local coordinates. + """ + return self.get_local_to_world_coords_transform().inverse() + + def world_to_local_coords(self, points_3d_world: torch.Tensor) -> torch.Tensor: + """ + Convert a batch of 3D point coordinates `points_3d_world` of shape + (minibatch, ..., dim) in the world coordinates to + the local coordinate frame of the volume. Local volume + coordinates are scaled s.t. the coordinates along one side of the volume + are in range [-1, 1]. + + Args: + **points_3d_world**: A tensor of shape `(minibatch, ..., 3)` + containing the 3D coordinates of a set of points that will + be converted from the local volume coordinates (ranging + within [-1, 1]) to the world coordinates + defined by the `self.center` and `self.voxel_size` parameters. + + Returns: + **points_3d_local**: `points_3d_world` converted to the local + volume coordinates of shape `(minibatch, ..., 3)`. + """ + return self.locator.world_to_local_coords(points_3d_world) + + def local_to_world_coords(self, points_3d_local: torch.Tensor) -> torch.Tensor: + """ + Convert a batch of 3D point coordinates `points_3d_local` of shape + (minibatch, ..., dim) in the local coordinate frame of the volume + to the world coordinates. + + Args: + **points_3d_local**: A tensor of shape `(minibatch, ..., 3)` + containing the 3D coordinates of a set of points that will + be converted from the local volume coordinates (ranging + within [-1, 1]) to the world coordinates + defined by the `self.center` and `self.voxel_size` parameters. + + Returns: + **points_3d_world**: `points_3d_local` converted to the world + coordinates of the volume of shape `(minibatch, ..., 3)`. + """ + return self.locator.local_to_world_coords(points_3d_local) + + def get_coord_grid(self, world_coordinates: bool = True) -> torch.Tensor: + """ + Return the 3D coordinate grid of the volumetric grid + in local (`world_coordinates=False`) or world coordinates + (`world_coordinates=True`). + + The grid records location of each center of the corresponding volume voxel. + + Local coordinates are scaled s.t. the values along one side of the + volume are in range [-1, 1]. + + Args: + **world_coordinates**: if `True`, the method + returns the grid in the world coordinates, + otherwise, in local coordinates. + + Returns: + **coordinate_grid**: The grid of coordinates of shape + `(minibatch, depth, height, width, 3)`, where `minibatch`, + `height`, `width` and `depth` are the batch size, height, width + and depth of the volume `features` or `densities`. + """ + return self.locator.get_coord_grid(world_coordinates) + + +class VolumeLocator: + """ + The `VolumeLocator` class keeps track of the locations of the + centers of the volume cells in the local volume coordinates as well as in + the world coordinates for a voxel grid structure in 3D. + + Local coordinates: + - Represent the locations of the volume cells in the local coordinate + frame of the volume. 
+        - The center of the voxel indexed with `[·, ·, 0, 0, 0]` in the volume
+          has its 3D local coordinate set to `[-1, -1, -1]`, while the voxel
+          at index `[·, ·, depth_i-1, height_i-1, width_i-1]` has its
+          3D local coordinate set to `[1, 1, 1]`.
+        - The first/second/third coordinate of each of the 3D per-voxel
+          XYZ vector denotes the horizontal/vertical/depth-wise position
+          respectively. I.e. the order of the coordinate dimensions in the
+          volume is reversed w.r.t. the order of the 3D coordinate vectors.
+        - The intermediate coordinates between `[-1, -1, -1]` and `[1, 1, 1]`
+          are linearly interpolated over the spatial dimensions of the volume.
+        - Note that the convention is the same as for the 5D version of the
+          `torch.nn.functional.grid_sample` function called with
+          `align_corners==True`.
+        - Note that the local coordinate convention of `VolumeLocator`
+          (+X = left to right, +Y = top to bottom, +Z = away from the user)
+          is *different* from the world coordinate convention of the
+          renderer for `Meshes` or `Pointclouds`
+          (+X = right to left, +Y = bottom to top, +Z = away from the user).
+
+    World coordinates:
+        - These define the locations of the centers of the volume cells
+          in the world coordinates.
+        - They are specified with the following mapping that converts
+          points `x_local` in the local coordinates to points `x_world`
+          in the world coordinates:
+          ```
+          x_world = (
+              x_local * (volume_size - 1) * 0.5 * voxel_size
+          ) - volume_translation,
+          ```
+          here `voxel_size` specifies the size of each voxel of the volume,
+          and `volume_translation` is the 3D offset of the central voxel of
+          the volume w.r.t. the origin of the world coordinate frame.
+          Both `voxel_size` and `volume_translation` are specified in
+          the world coordinate units. `volume_size` is the spatial size of
+          the volume in form of a 3D vector `[width, height, depth]`.
+        - Given the above definition of `x_world`, one can derive the
+          inverse mapping from `x_world` to `x_local` as follows:
+          ```
+          x_local = (
+              (x_world + volume_translation) / (0.5 * voxel_size)
+          ) / (volume_size - 1)
+          ```
+        - For a trivial volume with `volume_translation==[0, 0, 0]`
+          and `voxel_size=1.`, `x_world` would range
+          from `-(volume_size-1)/2` to `+(volume_size-1)/2`.
+
+    Coordinate tensors that denote the locations of each of the volume cells in
+    local / world coordinates (with shape `(depth x height x width x 3)`)
+    can be retrieved by calling the `VolumeLocator.get_coord_grid()` getter with the
+    appropriate `world_coordinates` argument.
+
+    Internally, the mapping between `x_local` and `x_world` is represented
+    as a `Transform3d` object `VolumeLocator._local_to_world_transform`.
+    Users can access the relevant transformations with the
+    `VolumeLocator.get_world_to_local_coords_transform()` and
+    `VolumeLocator.get_local_to_world_coords_transform()`
+    functions.
+
+    Example coordinate conversion:
+        - For a "trivial" volume with `voxel_size = 1.`,
+          `volume_translation=[0., 0., 0.]`, and the spatial size of
+          `DxHxW = 5x5x5`, the point `x_world = (-2, 0, 2)` gets mapped
+          to `x_local=(-1, 0, 1)`.
+        - For a "trivial" volume `v` with `voxel_size = 1.`,
+          `volume_translation=[0., 0., 0.]`, the following holds:
+          ```
+          torch.nn.functional.grid_sample(
+              v.densities(),
+              v.get_coord_grid(world_coordinates=False),
+              align_corners=True,
+          ) == v.densities(),
+          ```
+          i.e. sampling the volume at trivial local coordinates
+          (no scaling with `voxel_size` or shift with `volume_translation`)
+          results in the same volume.
+ """ + + def __init__( + self, + batch_size: int, + grid_sizes: Union[ + torch.LongTensor, Tuple[int, int, int], List[torch.LongTensor] + ], + device: torch.device, + voxel_size: _VoxelSize = 1.0, + volume_translation: _Translation = (0.0, 0.0, 0.0), + ): + """ + **batch_size** : Batch size of the underlaying grids + **grid_sizes** : Represents the resolutions of different grids in the batch. Can be + a) tuple of form (H, W, D) + b) list/tuple of length batch_size of lists/tuples of form (H, W, D) + c) torch.Tensor of shape (batch_size, H, W, D) + H, W, D are height, width, depth respectively. If `grid_sizes` is a tuple than + all the grids in the batch have the same resolution. + **voxel_size**: Denotes the size of each volume voxel in world units. + Has to be one of: + a) A scalar (square voxels) + b) 3-tuple or a 3-list of scalars + c) a Tensor of shape (3,) + d) a Tensor of shape (minibatch, 3) + e) a Tensor of shape (minibatch, 1) + f) a Tensor of shape (1,) (square voxels) + **volume_translation**: Denotes the 3D translation of the center + of the volume in world units. Has to be one of: + a) 3-tuple or a 3-list of scalars + b) a Tensor of shape (3,) + c) a Tensor of shape (minibatch, 3) + d) a Tensor of shape (1,) (square voxels) + """ + self.device = device + self._batch_size = batch_size + self._grid_sizes = self._convert_grid_sizes2tensor(grid_sizes) + self._resolution = tuple(torch.max(self._grid_sizes.cpu(), dim=0).values) + + # set the local_to_world transform + self._set_local_to_world_transform( + voxel_size=voxel_size, volume_translation=volume_translation + ) + + def _convert_grid_sizes2tensor( + self, x: Union[torch.LongTensor, List[torch.LongTensor], Tuple[int, int, int]] + ) -> torch.LongTensor: + """ + Handle the grid_sizes argument to the constructor. + """ + if isinstance(x, (list, tuple)): + if isinstance(x[0], (torch.LongTensor, list, tuple)): + if self._batch_size != len(x): + raise ValueError("x should have a batch size of 'batch_size'") + # pyre-ignore[6] + if any(len(x_) != 3 for x_ in x): + raise ValueError( + "`grid_sizes` has to be a list of 3-dim tensors of shape: " + "(height, width, depth)" + ) + x_shapes = torch.stack( + [ + torch.tensor( + # pyre-ignore[6] + list(x_), + dtype=torch.long, + device=self.device, + ) + for x_ in x + ], + dim=0, + ) + elif isinstance(x[0], int): + x_shapes = torch.stack( + [ + torch.tensor(list(x), dtype=torch.long, device=self.device) + for _ in range(self._batch_size) + ], + dim=0, + ) + else: + raise ValueError( + "`grid_sizes` can be a list/tuple of int or torch.Tensor not of " + + "{type(x[0])}." + ) + + elif torch.is_tensor(x): + if x.ndim != 2: + raise ValueError( + "`grid_sizes` has to be a 2-dim tensor of shape: (minibatch, 3)" + ) + x_shapes = x.to(self.device) + else: + raise ValueError( + "grid_sizes must be either a list of tensors with shape (H, W, D), tensor with" + "shape (batch_size, H, W, D) or a tuple of (H, W, D)." + ) + # pyre-ignore[7] + return x_shapes + + def _voxel_size_translation_to_transform( + self, + voxel_size: torch.Tensor, + volume_translation: torch.Tensor, + batch_size: int, + ) -> Transform3d: + """ + Converts the `voxel_size` and `volume_translation` constructor arguments + to the internal `Transform3d` object `local_to_world_transform`. 
+ """ + volume_size_zyx = self.get_grid_sizes().float() + volume_size_xyz = volume_size_zyx[:, [2, 1, 0]] + + # x_local = ( + # (x_world + volume_translation) / (0.5 * voxel_size) + # ) / (volume_size - 1) + + # x_world = ( + # x_local * (volume_size - 1) * 0.5 * voxel_size + # ) - volume_translation + + local_to_world_transform = Scale( + (volume_size_xyz - 1) * voxel_size * 0.5, device=self.device + ).translate(-volume_translation) + + return local_to_world_transform + + def get_coord_grid(self, world_coordinates: bool = True) -> torch.Tensor: + """ + Return the 3D coordinate grid of the volumetric grid + in local (`world_coordinates=False`) or world coordinates + (`world_coordinates=True`). + + The grid records location of each center of the corresponding volume voxel. + + Local coordinates are scaled s.t. the values along one side of the + volume are in range [-1, 1]. + + Args: + **world_coordinates**: if `True`, the method + returns the grid in the world coordinates, + otherwise, in local coordinates. + + Returns: + **coordinate_grid**: The grid of coordinates of shape + `(minibatch, depth, height, width, 3)`, where `minibatch`, + `height`, `width` and `depth` are the batch size, height, width + and depth of the volume `features` or `densities`. + """ + # TODO(dnovotny): Implement caching of the coordinate grid. + return self._calculate_coordinate_grid(world_coordinates=world_coordinates) + + def _calculate_coordinate_grid( + self, world_coordinates: bool = True + ) -> torch.Tensor: + """ + Calculate the 3D coordinate grid of the volumetric grid either + in local (`world_coordinates=False`) or + world coordinates (`world_coordinates=True`) . + """ + + ba, (de, he, wi) = self._batch_size, self._resolution + grid_sizes = self.get_grid_sizes() + + # generate coordinate axes + vol_axes = [ + torch.linspace(-1.0, 1.0, r, dtype=torch.float32, device=self.device) + for r in (de, he, wi) + ] + + # generate per-coord meshgrids + Z, Y, X = meshgrid_ij(vol_axes) + + # stack the coord grids ... this order matches the coordinate convention + # of torch.nn.grid_sample + vol_coords_local = torch.stack((X, Y, Z), dim=3)[None].repeat(ba, 1, 1, 1, 1) + + # get grid sizes relative to the maximal volume size + grid_sizes_relative = ( + torch.tensor([[de, he, wi]], device=grid_sizes.device, dtype=torch.float32) + - 1 + ) / (grid_sizes - 1).float() + + if (grid_sizes_relative != 1.0).any(): + # if any of the relative sizes != 1.0, adjust the grid + grid_sizes_relative_reshape = grid_sizes_relative[:, [2, 1, 0]][ + :, None, None, None + ] + vol_coords_local *= grid_sizes_relative_reshape + vol_coords_local += grid_sizes_relative_reshape - 1 + + if world_coordinates: + vol_coords = self.local_to_world_coords(vol_coords_local) + else: + vol_coords = vol_coords_local + + return vol_coords + + def get_local_to_world_coords_transform(self) -> Transform3d: + """ + Return a Transform3d object that converts points in the + the local coordinate frame of the volume to world coordinates. + Local volume coordinates are scaled s.t. the coordinates along one + side of the volume are in range [-1, 1]. + + Returns: + **local_to_world_transform**: A Transform3d object converting + points from local coordinates to the world coordinates. + """ + return self._local_to_world_transform + + def get_world_to_local_coords_transform(self) -> Transform3d: + """ + Return a Transform3d object that converts points in the + world coordinates to the local coordinate frame of the volume. + Local volume coordinates are scaled s.t. 
the coordinates along one
+        side of the volume are in range [-1, 1].
+
+        Returns:
+            **world_to_local_transform**: A Transform3d object converting
+                points from world coordinates to local coordinates.
+        """
+        return self.get_local_to_world_coords_transform().inverse()
+
+    def world_to_local_coords(self, points_3d_world: torch.Tensor) -> torch.Tensor:
+        """
+        Convert a batch of 3D point coordinates `points_3d_world` of shape
+        (minibatch, ..., dim) in the world coordinates to
+        the local coordinate frame of the volume. Local volume
+        coordinates are scaled s.t. the coordinates along one side of the volume
+        are in range [-1, 1].
+
+        Args:
+            **points_3d_world**: A tensor of shape `(minibatch, ..., 3)`
+                containing the 3D coordinates of a set of points that will
+                be converted from the world coordinates (defined by the
+                `voxel_size` and `volume_translation` parameters of the volume)
+                to the local volume coordinates (ranging within [-1, 1]).
+
+        Returns:
+            **points_3d_local**: `points_3d_world` converted to the local
+                volume coordinates of shape `(minibatch, ..., 3)`.
+        """
+        pts_shape = points_3d_world.shape
+        return (
+            self.get_world_to_local_coords_transform()
+            .transform_points(points_3d_world.view(pts_shape[0], -1, 3))
+            .view(pts_shape)
+        )
+
+    def local_to_world_coords(self, points_3d_local: torch.Tensor) -> torch.Tensor:
+        """
+        Convert a batch of 3D point coordinates `points_3d_local` of shape
+        (minibatch, ..., dim) in the local coordinate frame of the volume
+        to the world coordinates.
+
+        Args:
+            **points_3d_local**: A tensor of shape `(minibatch, ..., 3)`
+                containing the 3D coordinates of a set of points that will
+                be converted from the local volume coordinates (ranging
+                within [-1, 1]) to the world coordinates
+                defined by the `voxel_size` and `volume_translation` parameters.
+
+        Returns:
+            **points_3d_world**: `points_3d_local` converted to the world
+                coordinates of the volume of shape `(minibatch, ..., 3)`.
+        """
+        pts_shape = points_3d_local.shape
+        return (
+            self.get_local_to_world_coords_transform()
+            .transform_points(points_3d_local.view(pts_shape[0], -1, 3))
+            .view(pts_shape)
+        )
+
+    def get_grid_sizes(self) -> torch.LongTensor:
+        """
+        Returns the sizes of individual volumetric grids in the structure.
+
+        Returns:
+            **grid_sizes**: Tensor of spatial sizes of each of the volumes
+                of size (batchsize, 3), where i-th row holds (D_i, H_i, W_i).
+        """
+        return self._grid_sizes
+
     def _set_local_to_world_transform(
         self,
         voxel_size: _VoxelSize = 1.0,
@@ -690,17 +959,104 @@ def _set_local_to_world_transform(
             voxel_size, volume_translation, len(self)
         )
 
-    def clone(self) -> "Volumes":
+    def _copy_transform_and_sizes(
+        self,
+        other: "VolumeLocator",
+        device: Optional[torch.device] = None,
+        index: Optional[
+            Union[int, List[int], Tuple[int], slice, torch.Tensor]
+        ] = _ALL_CONTENT,
+    ) -> None:
         """
-        Deep copy of Volumes object. All internal tensors are cloned
-        individually.
+        Copies the local-to-world transform and grid sizes to the other
+        VolumeLocator object and moves them to the specified device.
+        Operates in place on `other`.
 
-        Returns:
-            new Volumes object.
+        Args:
+            other: VolumeLocator object to copy to.
+            device: torch.device on which to put the result; defaults to self.device.
+            index: Specifies which parts to copy.
+                Can be an int, slice, list of ints, or a boolean or long tensor.
+                Defaults to all items (`:`).
""" - return copy.deepcopy(self) + device = device if device is not None else self.device + other._grid_sizes = self._grid_sizes[index].to(device) + other._local_to_world_transform = self.get_local_to_world_coords_transform()[ + index + ].to(device) - def to(self, device: Device, copy: bool = False) -> "Volumes": + def _handle_voxel_size( + self, voxel_size: _VoxelSize, batch_size: int + ) -> torch.Tensor: + """ + Handle the `voxel_size` argument to the `VolumeLocator` constructor. + """ + err_msg = ( + "voxel_size has to be either a 3-tuple of scalars, or a scalar, or" + " a torch.Tensor of shape (3,) or (1,) or (minibatch, 3) or (minibatch, 1)." + ) + if isinstance(voxel_size, (float, int)): + # convert a scalar to a 3-element tensor + voxel_size = torch.full( + (1, 3), voxel_size, device=self.device, dtype=torch.float32 + ) + elif isinstance(voxel_size, torch.Tensor): + if voxel_size.numel() == 1: + # convert a single-element tensor to a 3-element one + voxel_size = voxel_size.view(-1).repeat(3) + elif len(voxel_size.shape) == 2 and ( + voxel_size.shape[0] == batch_size and voxel_size.shape[1] == 1 + ): + voxel_size = voxel_size.repeat(1, 3) + return self._convert_volume_property_to_tensor(voxel_size, batch_size, err_msg) + + def _handle_volume_translation( + self, translation: _Translation, batch_size: int + ) -> torch.Tensor: + """ + Handle the `volume_translation` argument to the `VolumeLocator` constructor. + """ + err_msg = ( + "`volume_translation` has to be either a 3-tuple of scalars, or" + " a Tensor of shape (1,3) or (minibatch, 3) or (3,)`." + ) + return self._convert_volume_property_to_tensor(translation, batch_size, err_msg) + + def __len__(self) -> int: + return self._batch_size + + def _convert_volume_property_to_tensor( + self, x: _Vector, batch_size: int, err_msg: str + ) -> torch.Tensor: + """ + Handle the `volume_translation` or `voxel_size` argument to + the VolumeLocator constructor. + Return a tensor of shape (N, 3) where N is the batch_size. + """ + if isinstance(x, (list, tuple)): + if len(x) != 3: + raise ValueError(err_msg) + x = torch.tensor(x, device=self.device, dtype=torch.float32)[None] + x = x.repeat((batch_size, 1)) + elif isinstance(x, torch.Tensor): + ok = ( + (x.shape[0] == 1 and x.shape[1] == 3) + or (x.shape[0] == 3 and len(x.shape) == 1) + or (x.shape[0] == batch_size and x.shape[1] == 3) + ) + if not ok: + raise ValueError(err_msg) + if x.device != self.device: + x = x.to(self.device) + if x.shape[0] == 3 and len(x.shape) == 1: + x = x[None] + if x.shape[0] == 1: + x = x.repeat((batch_size, 1)) + else: + raise ValueError(err_msg) + + return x + + def to(self, device: Device, copy: bool = False) -> "VolumeLocator": """ Match the functionality of torch.Tensor.to() If copy = True or the self Tensor is on a different device, the @@ -713,7 +1069,7 @@ def to(self, device: Device, copy: bool = False) -> "Volumes": copy: Boolean indicator whether or not to clone self. Default False. Returns: - Volumes object. + VolumeLocator object. """ device_ = make_device(device) if not copy and self.device == device_: @@ -724,18 +1080,24 @@ def to(self, device: Device, copy: bool = False) -> "Volumes": return other other.device = device_ - other._densities = self._densities.to(device_) - if self._features is not None: - # pyre-fixme[16]: `Optional` has no attribute `to`. 
-        other._features = self.features().to(device_)
+        other._grid_sizes = self._grid_sizes.to(device_)
         other._local_to_world_transform = self.get_local_to_world_coords_transform().to(
-            device_
+            device
         )
-        other._grid_sizes = self._grid_sizes.to(device_)
         return other
 
-    def cpu(self) -> "Volumes":
+    def clone(self) -> "VolumeLocator":
+        """
+        Deep copy of VolumeLocator object. All internal tensors are cloned
+        individually.
+
+        Returns:
+            new VolumeLocator object.
+        """
+        return copy.deepcopy(self)
+
+    def cpu(self) -> "VolumeLocator":
         return self.to("cpu")
 
-    def cuda(self) -> "Volumes":
+    def cuda(self) -> "VolumeLocator":
         return self.to("cuda")
diff --git a/tests/test_volumes.py b/tests/test_volumes.py
index 9384d0a52..fc6c7a401 100644
--- a/tests/test_volumes.py
+++ b/tests/test_volumes.py
@@ -11,7 +11,7 @@
 import numpy as np
 import torch
 
-from pytorch3d.structures.volumes import Volumes
+from pytorch3d.structures.volumes import VolumeLocator, Volumes
 from pytorch3d.transforms import Scale
 
 from .common_testing import TestCaseMixin
@@ -53,8 +53,8 @@ def _check_indexed_volumes(self, v, selected, indices):
         for selectedIdx, index in indices:
             self.assertClose(selected.densities()[selectedIdx], v.densities()[index])
             self.assertClose(
-                v._local_to_world_transform.get_matrix()[index],
-                selected._local_to_world_transform.get_matrix()[selectedIdx],
+                v.locator._local_to_world_transform.get_matrix()[index],
+                selected.locator._local_to_world_transform.get_matrix()[selectedIdx],
             )
             if selected.features() is not None:
                 self.assertClose(selected.features()[selectedIdx], v.features()[index])
@@ -149,10 +149,55 @@ def test_get_item(
             with self.assertRaises(IndexError):
                 v_selected = v[index]
 
+    def test_locator_init(self, batch_size=9, resolution=(3, 5, 7)):
+        with self.subTest("VolumeLocator init with all sizes equal"):
+            grid_sizes = [resolution for _ in range(batch_size)]
+            locator_tuple = VolumeLocator(
+                batch_size=batch_size, grid_sizes=resolution, device=torch.device("cpu")
+            )
+            locator_list = VolumeLocator(
+                batch_size=batch_size, grid_sizes=grid_sizes, device=torch.device("cpu")
+            )
+            locator_tensor = VolumeLocator(
+                batch_size=batch_size,
+                grid_sizes=torch.tensor(grid_sizes),
+                device=torch.device("cpu"),
+            )
+            expected_grid_sizes = torch.tensor(grid_sizes)
+            expected_resolution = resolution
+            assert torch.allclose(expected_grid_sizes, locator_tuple._grid_sizes)
+            assert torch.allclose(expected_grid_sizes, locator_list._grid_sizes)
+            assert torch.allclose(expected_grid_sizes, locator_tensor._grid_sizes)
+            self.assertEqual(expected_resolution, locator_tuple._resolution)
+            self.assertEqual(expected_resolution, locator_list._resolution)
+            self.assertEqual(expected_resolution, locator_tensor._resolution)
+
+        with self.subTest("VolumeLocator with different sizes in different grids"):
+            grid_sizes_list = [
+                torch.randint(low=1, high=42, size=(3,)) for _ in range(batch_size)
+            ]
+            grid_sizes_tensor = torch.cat([el[None] for el in grid_sizes_list])
+            locator_list = VolumeLocator(
+                batch_size=batch_size,
+                grid_sizes=grid_sizes_list,
+                device=torch.device("cpu"),
+            )
+            locator_tensor = VolumeLocator(
+                batch_size=batch_size,
+                grid_sizes=grid_sizes_tensor,
+                device=torch.device("cpu"),
+            )
+            expected_grid_sizes = grid_sizes_tensor
+            expected_resolution = tuple(torch.max(expected_grid_sizes, dim=0).values)
+            assert torch.allclose(expected_grid_sizes, locator_list._grid_sizes)
+            assert torch.allclose(expected_grid_sizes, locator_tensor._grid_sizes)
+            self.assertEqual(expected_resolution,
locator_list._resolution) + self.assertEqual(expected_resolution, locator_tensor._resolution) + def test_coord_transforms(self, num_volumes=3, num_channels=4, dtype=torch.float32): """ Test the correctness of the conversion between the internal - Transform3D Volumes._local_to_world_transform and the initialization + Transform3D Volumes.VolumeLocator._local_to_world_transform and the initialization from the translation and voxel_size. """ @@ -440,7 +485,10 @@ def _check_vars_on_device(self, v, desired_device): for var_name, var in vars(v).items(): if var_name != "device": if var is not None: - self.assertTrue(var.device.type == desired_device.type) + self.assertTrue( + var.device.type == desired_device.type, + (var_name, var.device, desired_device), + ) else: self.assertTrue(var.type == desired_device.type) @@ -456,60 +504,74 @@ def test_to( ) densities = torch.rand(size=[num_volumes, 1, *size], dtype=dtype) volumes = Volumes(densities=densities, features=features) + locator = VolumeLocator( + batch_size=5, grid_sizes=(3, 5, 7), device=volumes.device + ) - # Test support for str and torch.device - cpu_device = torch.device("cpu") + for name, obj in (("VolumeLocator", locator), ("Volumes", volumes)): + with self.subTest(f"Moving {name} from/to gpu and cpu"): + # Test support for str and torch.device + cpu_device = torch.device("cpu") - converted_volumes = volumes.to("cpu") - self.assertEqual(cpu_device, converted_volumes.device) - self.assertEqual(cpu_device, volumes.device) - self.assertIs(volumes, converted_volumes) + converted_obj = obj.to("cpu") + self.assertEqual(cpu_device, converted_obj.device) + self.assertEqual(cpu_device, obj.device) + self.assertIs(obj, converted_obj) - converted_volumes = volumes.to(cpu_device) - self.assertEqual(cpu_device, converted_volumes.device) - self.assertEqual(cpu_device, volumes.device) - self.assertIs(volumes, converted_volumes) + converted_obj = obj.to(cpu_device) + self.assertEqual(cpu_device, converted_obj.device) + self.assertEqual(cpu_device, obj.device) + self.assertIs(obj, converted_obj) - cuda_device = torch.device("cuda:0") + cuda_device = torch.device("cuda:0") - converted_volumes = volumes.to("cuda:0") - self.assertEqual(cuda_device, converted_volumes.device) - self.assertEqual(cpu_device, volumes.device) - self.assertIsNot(volumes, converted_volumes) + converted_obj = obj.to("cuda:0") + self.assertEqual(cuda_device, converted_obj.device) + self.assertEqual(cpu_device, obj.device) + self.assertIsNot(obj, converted_obj) - converted_volumes = volumes.to(cuda_device) - self.assertEqual(cuda_device, converted_volumes.device) - self.assertEqual(cpu_device, volumes.device) - self.assertIsNot(volumes, converted_volumes) + converted_obj = obj.to(cuda_device) + self.assertEqual(cuda_device, converted_obj.device) + self.assertEqual(cpu_device, obj.device) + self.assertIsNot(obj, converted_obj) - # Test device placement of internal tensors - features = features.to(cuda_device) - densities = features.to(cuda_device) + with self.subTest("Test device placement of internal tensors of Volumes"): + features = features.to(cuda_device) + densities = features.to(cuda_device) - for features_ in (features, None): - volumes = Volumes(densities=densities, features=features_) + for features_ in (features, None): + volumes = Volumes(densities=densities, features=features_) - cpu_volumes = volumes.cpu() - cuda_volumes = cpu_volumes.cuda() - cuda_volumes2 = cuda_volumes.cuda() - cpu_volumes2 = cuda_volumes2.cpu() + cpu_volumes = volumes.cpu() + cuda_volumes = 
cpu_volumes.cuda() + cuda_volumes2 = cuda_volumes.cuda() + cpu_volumes2 = cuda_volumes2.cpu() - for volumes1, volumes2 in itertools.combinations( - (volumes, cpu_volumes, cpu_volumes2, cuda_volumes, cuda_volumes2), 2 - ): - if volumes1 is cuda_volumes and volumes2 is cuda_volumes2: - # checks that we do not copy if the devices stay the same - assert_fun = self.assertIs - else: - assert_fun = self.assertSeparate - assert_fun(volumes1._densities, volumes2._densities) - if features_ is not None: - assert_fun(volumes1._features, volumes2._features) - for volumes_ in (volumes1, volumes2): - if volumes_ in (cpu_volumes, cpu_volumes2): - self._check_vars_on_device(volumes_, cpu_device) + for volumes1, volumes2 in itertools.combinations( + (volumes, cpu_volumes, cpu_volumes2, cuda_volumes, cuda_volumes2), 2 + ): + if volumes1 is cuda_volumes and volumes2 is cuda_volumes2: + # checks that we do not copy if the devices stay the same + assert_fun = self.assertIs else: - self._check_vars_on_device(volumes_, cuda_device) + assert_fun = self.assertSeparate + assert_fun(volumes1._densities, volumes2._densities) + if features_ is not None: + assert_fun(volumes1._features, volumes2._features) + for volumes_ in (volumes1, volumes2): + if volumes_ in (cpu_volumes, cpu_volumes2): + self._check_vars_on_device(volumes_, cpu_device) + else: + self._check_vars_on_device(volumes_, cuda_device) + + with self.subTest("Test device placement of internal tensors of VolumeLocator"): + for device1, device2 in itertools.combinations( + (torch.device("cpu"), torch.device("cuda:0")), 2 + ): + locator = locator.to(device1) + locator = locator.to(device2) + self.assertEqual(locator._grid_sizes.device, device2) + self.assertEqual(locator._local_to_world_transform.device, device2) def _check_padded(self, x_pad, x_list, grid_sizes): """
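
Note (not part of the patch): a minimal usage sketch of the `VolumeLocator` API introduced above, mirroring the trivial 5x5x5 example from the class docstring. The constructor arguments and method names follow the code added in this diff; the surrounding script is illustrative only.

```python
# Illustrative sketch, assuming the VolumeLocator signature added in this patch:
# a trivial 1x5x5x5 volume with voxel_size=1 and zero translation.
import torch
from pytorch3d.structures.volumes import VolumeLocator, Volumes

locator = VolumeLocator(
    batch_size=1,
    grid_sizes=(5, 5, 5),
    device=torch.device("cpu"),
    voxel_size=1.0,
    volume_translation=(0.0, 0.0, 0.0),
)

# Per the docstring, x_world = (-2, 0, 2) should map to x_local = (-1, 0, 1) and back.
x_world = torch.tensor([[[-2.0, 0.0, 2.0]]])
x_local = locator.world_to_local_coords(x_world)
assert torch.allclose(x_local, torch.tensor([[[-1.0, 0.0, 1.0]]]))
assert torch.allclose(locator.local_to_world_coords(x_local), x_world)

# Volumes now delegates the same coordinate conversions to its internal `locator`.
v = Volumes(densities=torch.rand(1, 1, 5, 5, 5), voxel_size=1.0)
assert torch.allclose(v.world_to_local_coords(x_world), x_local)
```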