Add integrated positional encoding based on the MIP-NeRF implementation.
Summary: Add support for integrated positional encoding in the implicit module, based on [MIP-NeRF](https://arxiv.org/abs/2103.13415).

Reviewed By: shapovalov

Differential Revision: D46352730

fbshipit-source-id: c6a56134c975d80052b3a11f5e92fd7d95cbff1e
EmGarr authored and facebook-github-bot committed Jul 6, 2023
1 parent 29b8ebd commit ccf860f
Showing 6 changed files with 283 additions and 29 deletions.
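
The core idea from MIP-NeRF: instead of embedding point samples along each ray, embed a Gaussian approximation of each conical ray segment, so high-frequency components are attenuated where the Gaussian is wide. A minimal self-contained sketch of integrated positional encoding, independent of the pytorch3d API (all names here are illustrative):

```python
import torch

def integrated_pos_enc(
    means: torch.Tensor,     # (..., dim) Gaussian means
    diag_cov: torch.Tensor,  # (..., dim) diagonal covariances
    n_freqs: int = 6,
) -> torch.Tensor:
    # Frequencies f_k = 2**k, as in the harmonic embedding below.
    freqs = 2.0 ** torch.arange(n_freqs, dtype=torch.float32)
    angles = means[..., None] * freqs  # (..., dim, n_freqs)
    # E[sin(f x)] under x ~ N(mu, sigma^2) is sin(f mu) * exp(-0.5 f^2 sigma^2);
    # the same damping factor applies to cos.
    damping = torch.exp(-0.5 * diag_cov[..., None] * freqs ** 2)
    feats = torch.cat([angles.sin() * damping, angles.cos() * damping], dim=-1)
    return feats.reshape(*means.shape[:-1], -1)

# Wide Gaussians (large covariances) suppress the high-frequency features.
mu = torch.randn(4, 3)
narrow = integrated_pos_enc(mu, torch.full((4, 3), 1e-4))
wide = integrated_pos_enc(mu, torch.full((4, 3), 10.0))
print(narrow.abs().mean() > wide.abs().mean())  # True
```
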
6 changes: 6 additions & 0 deletions projects/implicitron_trainer/tests/experiment.yaml
@@ -361,6 +361,7 @@ model_factory_ImplicitronModelFactory_args:
n_hidden_neurons_dir: 128
input_xyz: true
xyz_ray_dir_in_camera_coords: false
use_integrated_positional_encoding: false
transformer_dim_down_factor: 2.0
n_hidden_neurons_xyz: 80
n_layers_xyz: 2
@@ -372,6 +373,7 @@ model_factory_ImplicitronModelFactory_args:
n_hidden_neurons_dir: 128
input_xyz: true
xyz_ray_dir_in_camera_coords: false
use_integrated_positional_encoding: false
transformer_dim_down_factor: 1.0
n_hidden_neurons_xyz: 256
n_layers_xyz: 8
@@ -741,6 +743,7 @@ model_factory_ImplicitronModelFactory_args:
n_hidden_neurons_dir: 128
input_xyz: true
xyz_ray_dir_in_camera_coords: false
use_integrated_positional_encoding: false
transformer_dim_down_factor: 2.0
n_hidden_neurons_xyz: 80
n_layers_xyz: 2
@@ -752,6 +755,7 @@ model_factory_ImplicitronModelFactory_args:
n_hidden_neurons_dir: 128
input_xyz: true
xyz_ray_dir_in_camera_coords: false
use_integrated_positional_encoding: false
transformer_dim_down_factor: 1.0
n_hidden_neurons_xyz: 256
n_layers_xyz: 8
@@ -979,6 +983,7 @@ model_factory_ImplicitronModelFactory_args:
n_hidden_neurons_dir: 128
input_xyz: true
xyz_ray_dir_in_camera_coords: false
use_integrated_positional_encoding: false
transformer_dim_down_factor: 2.0
n_hidden_neurons_xyz: 80
n_layers_xyz: 2
@@ -990,6 +995,7 @@ model_factory_ImplicitronModelFactory_args:
n_hidden_neurons_dir: 128
input_xyz: true
xyz_ray_dir_in_camera_coords: false
use_integrated_positional_encoding: false
transformer_dim_down_factor: 1.0
n_hidden_neurons_xyz: 256
n_layers_xyz: 8
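The new `use_integrated_positional_encoding` flag defaults to `false` in all of the generated configs above, so existing experiments keep the classical harmonic embedding. A minimal sketch of opting in when building the implicit function directly (mirroring the unit tests at the bottom of this commit):

```python
from pytorch3d.implicitron.models.implicit_function.neural_radiance_field import (
    NeuralRadianceFieldImplicitFunction,
)

# Requires the ray sampler to populate `bins` on the ray bundle
# (see the ValueError added in forward() below).
implicit_fn = NeuralRadianceFieldImplicitFunction(
    use_integrated_positional_encoding=True,
)
```
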
pytorch3d/implicitron/models/implicit_function/neural_radiance_field.py
@@ -9,11 +9,14 @@

import torch
from pytorch3d.common.linear_with_repeat import LinearWithRepeat
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
from pytorch3d.implicitron.models.renderer.base import (
conical_frustum_to_gaussian,
ImplicitronRayBundle,
)
from pytorch3d.implicitron.tools.config import expand_args_fields, registry
from pytorch3d.renderer import ray_bundle_to_ray_points
from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.renderer.implicit import HarmonicEmbedding
from pytorch3d.renderer.implicit.utils import ray_bundle_to_ray_points

from .base import ImplicitFunctionBase

@@ -36,6 +39,7 @@ class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module):
input_xyz: bool = True
xyz_ray_dir_in_camera_coords: bool = False
color_dim: int = 3
use_integrated_positional_encoding: bool = False
"""
Args:
n_harmonic_functions_xyz: The number of harmonic functions
@@ -53,6 +57,10 @@ class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module):
n_layers_xyz: The number of layers of the MLP that outputs the
occupancy field.
append_xyz: The list of indices of the skip layers of the occupancy MLP.
use_integrated_positional_encoding: If True, use integrated positional encoding
as defined in `MIP-NeRF <https://arxiv.org/abs/2103.13415>`_.
If False, use the classical harmonic embedding
defined in `NeRF <https://arxiv.org/abs/2003.08934>`_.
"""

def __post_init__(self):
@@ -149,6 +157,10 @@ def forward(
containing the direction vectors of sampling rays in world coords.
lengths: A tensor of shape `(minibatch, ..., num_points_per_ray)`
containing the lengths at which the rays are sampled.
bins: An optional tensor of shape `(minibatch, ..., num_points_per_ray + 1)`
containing the bin edges at which the rays are sampled. In this case,
`lengths` is equal to the midpoints of `bins`.
fun_viewpool: an optional callback with the signature
fun_viewpool(points) -> pooled_features
where points is a [N_TGT x N x 3] tensor of world coords,
@@ -160,11 +172,22 @@
denoting the opacity of each ray point.
rays_colors: A tensor of shape `(minibatch, ..., num_points_per_ray, 3)`
denoting the color of each ray point.
Raises:
ValueError: If `use_integrated_positional_encoding` is True and
`ray_bundle.bins` is None.
"""
# We first convert the ray parametrizations to world
# coordinates with `ray_bundle_to_ray_points`.
# pyre-ignore[6]
rays_points_world = ray_bundle_to_ray_points(ray_bundle)
if self.use_integrated_positional_encoding and ray_bundle.bins is None:
raise ValueError(
"When use_integrated_positional_encoding is True, ray_bundle.bins must be set."
"Have you set to True `AbstractMaskRaySampler.use_bins_for_ray_sampling`?"
)

rays_points_world, diag_cov = (
conical_frustum_to_gaussian(ray_bundle)
if self.use_integrated_positional_encoding
else (ray_bundle_to_ray_points(ray_bundle), None) # pyre-ignore
)
# rays_points_world.shape = [minibatch x ... x pts_per_ray x 3]

embeds = create_embeddings_for_implicit_function(
@@ -177,6 +200,7 @@ def forward(
fun_viewpool=fun_viewpool,
xyz_in_camera_coords=self.xyz_ray_dir_in_camera_coords,
camera=camera,
diag_cov=diag_cov,
)

# embeds.shape = [minibatch x n_src x n_rays x n_pts x self.n_harmonic_functions*6+3]
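As the new docstring above spells out, `bins` has one more entry per ray than `lengths`, and `lengths` holds the midpoints of consecutive bin edges. A small sketch of that relationship, using the same shapes as the unit tests below:

```python
import torch

bins = torch.sort(torch.randn(2, 4, 4, 6 + 1), dim=-1).values  # 7 edges per ray
lengths = 0.5 * (bins[..., :-1] + bins[..., 1:])                # 6 midpoints
assert lengths.shape == (2, 4, 4, 6)
```
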
7 changes: 4 additions & 3 deletions pytorch3d/implicitron/models/implicit_function/utils.py
@@ -36,6 +36,7 @@ def create_embeddings_for_implicit_function(
camera: Optional[CamerasBase],
fun_viewpool: Optional[Callable],
xyz_embedding_function: Optional[Callable],
diag_cov: Optional[torch.Tensor] = None,
) -> torch.Tensor:

bs, *spatial_size, pts_per_ray, _ = xyz_world.shape
@@ -59,11 +60,11 @@
prod(spatial_size),
pts_per_ray,
0,
dtype=xyz_world.dtype,
device=xyz_world.device,
)
else:
embeds = xyz_embedding_function(ray_points_for_embed).reshape(

embeds = xyz_embedding_function(ray_points_for_embed, diag_cov=diag_cov)
embeds = embeds.reshape(
bs,
1,
prod(spatial_size),
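Note that after this change any `xyz_embedding_function` reaching this branch must accept a `diag_cov` keyword argument; the updated `HarmonicEmbedding` below satisfies that contract. A short sketch with illustrative shapes:

```python
import torch
from pytorch3d.renderer.implicit import HarmonicEmbedding

embed_fn = HarmonicEmbedding(n_harmonic_functions=10)
pts = torch.randn(2, 1024, 64, 3)   # hypothetical ray points, (..., dim=3)
out = embed_fn(pts, diag_cov=None)  # classical path, no covariances
assert out.shape == (2, 1024, 64, 10 * 2 * 3 + 3)
```
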
83 changes: 69 additions & 14 deletions pytorch3d/renderer/implicit/harmonic_embedding.py
@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

import torch


@@ -16,8 +18,18 @@ def __init__(
append_input: bool = True,
) -> None:
"""
The harmonic embedding layer supports the classical
NeRF positional encoding described in
`NeRF <https://arxiv.org/abs/2003.08934>`_
and the integrated positional encoding described in
`MIP-NeRF <https://arxiv.org/abs/2103.13415>`_.
During inference, you can provide the extra argument `diag_cov`.
If `diag_cov is None`, it converts
rays parametrized with a `ray_bundle` to 3D points by
extending each ray according to the corresponding length.
Then it converts each feature
(i.e. vector along the last dimension) in `x`
into a series of harmonic features `embedding`,
where for each i in range(dim) the following are present
@@ -38,6 +50,31 @@
where N corresponds to `n_harmonic_functions-1`, and f_i is a scalar
denoting the i-th frequency of the harmonic embedding.
If `diag_cov is not None`, it approximates
the conical frustums along a ray bundle as Gaussians,
with means `x` and diagonal covariances `diag_cov`.
Then it converts each gaussian
into a series of harmonic features `embedding`,
where for each i in range(dim) the following are present
in embedding[...]::
[
sin(f_1*x[..., i]) * exp(-0.5 * f_1**2 * diag_cov[..., i]),
sin(f_2*x[..., i]) * exp(-0.5 * f_2**2 * diag_cov[..., i]),
...
sin(f_N*x[..., i]) * exp(-0.5 * f_N**2 * diag_cov[..., i]),
cos(f_1*x[..., i]) * exp(-0.5 * f_1**2 * diag_cov[..., i]),
cos(f_2*x[..., i]) * exp(-0.5 * f_2**2 * diag_cov[..., i]),
...
cos(f_N*x[..., i]) * exp(-0.5 * f_N**2 * diag_cov[..., i]),
x[..., i], # only present if append_input is True.
]
where N equals `n_harmonic_functions-1`, and f_i is a scalar
denoting the i-th frequency of the harmonic embedding.
If `logspace==True`, the frequencies `[f_1, ..., f_N]` are
powers of 2:
`f_1, ..., f_N = 2**torch.arange(n_harmonic_functions)`
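
For reference, the damping factor in the expansion above comes from the standard expectation of a sinusoid under a Gaussian, which is what MIP-NeRF's integrated encoding computes in closed form:

```latex
x \sim \mathcal{N}(\mu, \sigma^{2}) \;\Longrightarrow\;
\mathbb{E}\!\left[\sin(f x)\right] = \sin(f\mu)\, e^{-\frac{1}{2} f^{2} \sigma^{2}},
\qquad
\mathbb{E}\!\left[\cos(f x)\right] = \cos(f\mu)\, e^{-\frac{1}{2} f^{2} \sigma^{2}}
```
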
@@ -59,8 +96,7 @@
logspace or linear space
append_input: bool, whether to concat the original
input to the harmonic embedding. If true the
output is of the form (embed.sin(), embed.cos(), x)
"""
super().__init__()

@@ -78,23 +114,42 @@
)

self.register_buffer("_frequencies", frequencies * omega_0, persistent=False)
self.register_buffer(
"_zero_half_pi", torch.tensor([0.0, 0.5 * torch.pi]), persistent=False
)
self.append_input = append_input

def forward(self, x: torch.Tensor) -> torch.Tensor:
def forward(
self, x: torch.Tensor, diag_cov: Optional[torch.Tensor] = None, **kwargs
) -> torch.Tensor:
"""
Args:
x: tensor of shape [..., dim]
diag_cov: An optional tensor of shape `(..., dim)`
holding the diagonal covariances of the Gaussians
whose means are given by `x`.
Returns:
embedding: a harmonic embedding of `x` of shape
[..., (n_harmonic_functions * 2 + int(append_input)) * dim]
"""
embed = (x[..., None] * self._frequencies).reshape(*x.shape[:-1], -1)
embed = torch.cat(
(embed.sin(), embed.cos(), x)
if self.append_input
else (embed.sin(), embed.cos()),
dim=-1,
)
# [..., dim, n_harmonic_functions]
embed = x[..., None] * self._frequencies
# [..., 1, dim, n_harmonic_functions] + [2, 1, 1] => [..., 2, dim, n_harmonic_functions]
embed = embed[..., None, :, :] + self._zero_half_pi[..., None, None]
# Use the trig identity cos(x) = sin(x + pi/2)
# and do one vectorized call to sin([x, x+pi/2]) instead of (sin(x), cos(x)).
embed = embed.sin()
if diag_cov is not None:
x_var = diag_cov[..., None] * torch.pow(self._frequencies, 2)
exp_var = torch.exp(-0.5 * x_var)
# [..., 2, dim, n_harmonic_functions]
embed = embed * exp_var[..., None, :, :]

embed = embed.reshape(*x.shape[:-1], -1)

if self.append_input:
return torch.cat([embed, x], dim=-1)
return embed

@staticmethod
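A quick usage sketch of the updated layer, exercising both paths (values illustrative). The `_zero_half_pi` buffer registered in `__init__` implements the cos(x) = sin(x + pi/2) identity, so a single vectorized `sin` call produces both halves of the embedding:

```python
import torch
from pytorch3d.renderer.implicit import HarmonicEmbedding

embed = HarmonicEmbedding(n_harmonic_functions=6, append_input=True)
x = torch.randn(2, 8, 3)           # means, e.g. points along rays
diag_cov = torch.rand(2, 8, 3)     # matching diagonal covariances

plain = embed(x)                   # classical NeRF harmonic embedding
ipe = embed(x, diag_cov=diag_cov)  # MIP-NeRF integrated encoding
assert plain.shape == ipe.shape == (2, 8, 6 * 2 * 3 + 3)
# The Gaussian damping only shrinks the harmonic part; the appended input
# (last 3 channels) is identical in both outputs.
assert (ipe[..., :-3].abs() <= plain[..., :-3].abs() + 1e-6).all()
```
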
66 changes: 66 additions & 0 deletions tests/implicitron/test_implicit_function_neural_radiance_field.py
@@ -0,0 +1,66 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from pytorch3d.implicitron.models.implicit_function.base import ImplicitronRayBundle
from pytorch3d.implicitron.models.implicit_function.neural_radiance_field import (
NeuralRadianceFieldImplicitFunction,
)


class TestNeuralRadianceFieldImplicitFunction(unittest.TestCase):
def setUp(self):
torch.manual_seed(42)

def test_forward_with_integrated_positional_embedding(self):
shape = [2, 4, 4]
ray_bundle = ImplicitronRayBundle(
origins=torch.randn(*shape, 3),
directions=torch.randn(*shape, 3),
bins=torch.randn(*shape, 6 + 1),
lengths=torch.randn(*shape, 6),
pixel_radii_2d=torch.randn(*shape, 1),
xys=None,
)
model = NeuralRadianceFieldImplicitFunction(
n_hidden_neurons_dir=32, use_integrated_positional_encoding=True
)
raw_densities, ray_colors, _ = model(ray_bundle=ray_bundle)

self.assertEqual(raw_densities.shape, (*shape, ray_bundle.lengths.shape[-1], 1))
self.assertEqual(ray_colors.shape, (*shape, ray_bundle.lengths.shape[-1], 3))

def test_forward_with_integrated_positional_embedding_raise_exception(self):
shape = [2, 4, 4]
ray_bundle = ImplicitronRayBundle(
origins=torch.randn(*shape, 3),
directions=torch.randn(*shape, 3),
bins=None,
lengths=torch.randn(*shape, 6),
pixel_radii_2d=torch.randn(*shape, 1),
xys=None,
)
model = NeuralRadianceFieldImplicitFunction(
n_hidden_neurons_dir=32, use_integrated_positional_encoding=True
)
with self.assertRaises(ValueError):
_ = model(ray_bundle=ray_bundle)

def test_forward(self):
shape = [2, 4, 4]
ray_bundle = ImplicitronRayBundle(
origins=torch.randn(*shape, 3),
directions=torch.randn(*shape, 3),
lengths=torch.randn(*shape, 6),
pixel_radii_2d=torch.randn(*shape, 1),
xys=None,
)
model = NeuralRadianceFieldImplicitFunction(n_hidden_neurons_dir=32)
raw_densities, ray_colors, _ = model(ray_bundle=ray_bundle)
self.assertEqual(raw_densities.shape, (*shape, 6, 1))
self.assertEqual(ray_colors.shape, (*shape, 6, 3))