Skip to content

Commit

Permalink
(breaking) image_size-agnostic GridRaySampler
Browse files Browse the repository at this point in the history
Summary:
As suggested in #802. By not persisting the _xy_grid buffer, we can allow (in some cases) a model with one image_size to be loaded from a saved model which was trained at a different resolution.

Also avoid persisting _frequencies in HarmonicEmbedding for similar reasons.

BC-break: This will cause load_state_dict, in strict mode, to complain if you try to load an old model with the new code.

Reviewed By: patricklabatut

Differential Revision: D30349234

fbshipit-source-id: d6061d1e51c9f79a78d61a9f732c9a5dfadbbb47
  • Loading branch information
bottler authored and facebook-github-bot committed Aug 31, 2021
1 parent 1251446 commit 1b8d86a
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 4 deletions.
6 changes: 3 additions & 3 deletions projects/nerf/nerf/harmonic_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def __init__(
omega0: float = 1.0,
logspace: bool = True,
include_input: bool = True,
):
) -> None:
"""
Given an input tensor `x` of shape [minibatch, ... , dim],
the harmonic embedding layer converts each feature
Expand Down Expand Up @@ -69,10 +69,10 @@ def __init__(
dtype=torch.float32,
)

self.register_buffer("_frequencies", omega0 * frequencies)
self.register_buffer("_frequencies", omega0 * frequencies, persistent=False)
self.include_input = include_input

def forward(self, x: torch.Tensor):
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: tensor of shape [..., dim]
Expand Down
2 changes: 1 addition & 1 deletion pytorch3d/renderer/implicit/raysampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def __init__(
),
dim=-1,
)
self.register_buffer("_xy_grid", _xy_grid)
self.register_buffer("_xy_grid", _xy_grid, persistent=False)

def forward(self, cameras: CamerasBase, **kwargs) -> RayBundle:
"""
Expand Down
20 changes: 20 additions & 0 deletions tests/test_raysampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,3 +425,23 @@ def _check_raysampler_ray_directions(self, cameras, raysampler, ray_bundle):
ray_bundle_camera_fix_seed.directions.view(batch_size, -1, 3),
atol=1e-5,
)

def test_load_state(self):
# check that we can load the state of one ray sampler into
# another with different image size.
module1 = NDCGridRaysampler(
image_width=20,
image_height=30,
n_pts_per_ray=40,
min_depth=1.2,
max_depth=2.3,
)
module2 = NDCGridRaysampler(
image_width=22,
image_height=32,
n_pts_per_ray=42,
min_depth=1.2,
max_depth=2.3,
)
state = module1.state_dict()
module2.load_state_dict(state)

0 comments on commit 1b8d86a

Please sign in to comment.