Skip to content

Commit

Permalink
VoxelGridModule
Browse files Browse the repository at this point in the history
Summary: Simple wrapper around voxel grids to make them a module

Reviewed By: bottler

Differential Revision: D38829762

fbshipit-source-id: dfee85088fa3c65e396cc7d3bf7ebaaffaadb646
  • Loading branch information
Darijan Gudelj authored and facebook-github-bot committed Aug 25, 2022
1 parent 6653f44 commit 24f5f4a
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 2 deletions.
88 changes: 86 additions & 2 deletions pytorch3d/implicitron/models/implicit_function/voxel_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,23 @@
This file contains classes that implement Voxel grids, both in their full resolution
as in the factorized form. There are two factorized forms implemented, Tensor rank decomposition
or CANDECOMP/PARAFAC (here CP) and Vector Matrix (here VM) factorization from the
https://arxiv.org/abs/2203.09517.
TensoRF (https://arxiv.org/abs/2203.09517) paper.
In addition, the module VoxelGridModule implements a trainable instance of one of
these classes.
"""

from dataclasses import dataclass
from typing import ClassVar, Dict, Optional, Tuple, Type

import torch
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
from pytorch3d.implicitron.tools.config import (
Configurable,
registry,
ReplaceableBase,
run_auto_creation,
)
from pytorch3d.structures.volumes import VolumeLocator

from .utils import interpolate_line, interpolate_plane, interpolate_volume
Expand Down Expand Up @@ -426,3 +435,78 @@ def get_shapes(self) -> Dict[str, Tuple]:
)

return shape_dict


class VoxelGridModule(Configurable, torch.nn.Module):
"""
A wrapper torch.nn.Module for the VoxelGrid classes, which
contains parameters that are needed to train the VoxelGrid classes.
Members:
voxel_grid_class_type: The name of the class to use for voxel_grid,
which must be available in the registry. Default FullResolutionVoxelGrid.
voxel_grid: An instance of `VoxelGridBase`. This is the object which
this class wraps.
extents: 3-tuple of a form (width, height, depth), denotes the size of the grid
in world units.
translation: 3-tuple of float. The center of the volume in world units as (x, y, z).
init_std: Parameters are initialized using the gaussian distribution
with mean=init_mean and std=init_std. Default 0.1
init_mean: Parameters are initialized using the gaussian distribution
with mean=init_mean and std=init_std. Default 0.
"""

voxel_grid_class_type: str = "FullResolutionVoxelGrid"
voxel_grid: VoxelGridBase

extents: Tuple[float, float, float] = 1.0
translation: Tuple[float, float, float] = (0.0, 0.0, 0.0)

init_std: float = 0.1
init_mean: float = 0

def __post_init__(self):
super().__init__()
run_auto_creation(self)
n_grids = 1 # Voxel grid objects are batched. We need only a single grid.
shapes = self.voxel_grid.get_shapes()
params = {
name: torch.normal(
mean=torch.zeros((n_grids, *shape)) + self.init_mean,
std=self.init_std,
)
for name, shape in shapes.items()
}
self.params = torch.nn.ParameterDict(params)

def forward(self, points: torch.Tensor) -> torch.Tensor:
"""
Evaluates points in the world coordinate frame on the voxel_grid.
Args:
points (torch.Tensor): tensor of points that you want to query
of a form (n_points, 3)
Returns:
torch.Tensor of shape (n_points, n_features)
"""
locator = VolumeLocator(
batch_size=1,
# The resolution of the voxel grid does not need to be known
# to the locator object. It is easiest to fix the resolution of the locator.
# In particular we fix it to (2,2,2) so that there is exactly one voxel of the
# desired size. The locator object uses (z, y, x) convention for the grid_size,
# and this module uses (x, y, z) convention so the order has to be reversed
# (irrelevant in this case since they are all equal).
# It is (2, 2, 2) because the VolumeLocator object behaves like
# align_corners=True, which means that the points are in the corners of
# the volume. So in the grid of (2, 2, 2) there is only one voxel.
grid_sizes=(2, 2, 2),
# The locator object uses (x, y, z) convention for the
# voxel size and translation.
voxel_size=self.extents,
volume_translation=self.translation,
device=next(self.params.values()).device,
)
grid_values = self.voxel_grid.values_type(**self.params)
# voxel grids operate with extra n_grids dimension, which we fix to one
return self.voxel_grid.evaluate_world(points[None], grid_values, locator)[0]
26 changes: 26 additions & 0 deletions tests/implicitron/test_voxel_grids.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
CPFactorizedVoxelGrid,
FullResolutionVoxelGrid,
VMFactorizedVoxelGrid,
VoxelGridModule,
)

from pytorch3d.implicitron.tools.config import expand_args_fields
Expand Down Expand Up @@ -198,6 +199,7 @@ def setUp(self):
expand_args_fields(FullResolutionVoxelGrid)
expand_args_fields(CPFactorizedVoxelGrid)
expand_args_fields(VMFactorizedVoxelGrid)
expand_args_fields(VoxelGridModule)

def _interpolate_1D(
self, points: torch.Tensor, vectors: torch.Tensor
Expand Down Expand Up @@ -585,3 +587,27 @@ def test(cls, **kwargs):
n_features=10,
n_components=3,
)

def test_voxel_grid_module_location(self, n_times=10):
"""
This checks the module uses locator correctly etc..
If we know that voxel grids work for (x, y, z) in local coordinates
to test if the VoxelGridModule does not have permuted dimensions we
create local coordinates, pass them through verified voxelgrids and
compare the result with the result that we get when we convert
coordinates to world and pass them through the VoxelGridModule
"""
for _ in range(n_times):
extents = tuple(torch.randint(1, 50, size=(3,)).tolist())

grid = VoxelGridModule(extents=extents)
local_point = torch.rand(1, 3) * 2 - 1
world_point = local_point * torch.tensor(extents) / 2
grid_values = grid.voxel_grid.values_type(**grid.params)

assert torch.allclose(
grid(world_point)[0, 0],
grid.voxel_grid.evaluate_local(local_point[None], grid_values)[0, 0, 0],
rtol=0.0001,
)

0 comments on commit 24f5f4a

Please sign in to comment.