From 2a1de3b610b8f2e95b3aeda6de36805a0baa0e9d Mon Sep 17 00:00:00 2001 From: Jeremy Reizenstein Date: Mon, 14 Feb 2022 04:51:02 -0800 Subject: [PATCH] move LinearWithRepeat to pytorch3d Summary: Move this simple layer from the NeRF project into pytorch3d. Reviewed By: shapovalov Differential Revision: D34126972 fbshipit-source-id: a9c6d6c3c1b662c1b844ea5d1b982007d4df83e6 --- projects/nerf/nerf/implicit_function.py | 3 +- .../common}/linear_with_repeat.py | 38 ++++++++++++++++++- pytorch3d/renderer/__init__.py | 2 +- pytorch3d/renderer/utils.py | 2 +- tests/test_common_linear_with_repeat.py | 32 ++++++++++++++++ tests/test_rendering_utils.py | 6 +-- 6 files changed, 75 insertions(+), 8 deletions(-) rename {projects/nerf/nerf => pytorch3d/common}/linear_with_repeat.py (55%) create mode 100644 tests/test_common_linear_with_repeat.py diff --git a/projects/nerf/nerf/implicit_function.py b/projects/nerf/nerf/implicit_function.py index 7a1ad60f1..472a4a35b 100644 --- a/projects/nerf/nerf/implicit_function.py +++ b/projects/nerf/nerf/implicit_function.py @@ -7,10 +7,9 @@ from typing import Tuple import torch +from pytorch3d.common.linear_with_repeat import LinearWithRepeat from pytorch3d.renderer import HarmonicEmbedding, RayBundle, ray_bundle_to_ray_points -from .linear_with_repeat import LinearWithRepeat - def _xavier_init(linear): """ diff --git a/projects/nerf/nerf/linear_with_repeat.py b/pytorch3d/common/linear_with_repeat.py similarity index 55% rename from projects/nerf/nerf/linear_with_repeat.py rename to pytorch3d/common/linear_with_repeat.py index 3a53db564..c9f62355e 100644 --- a/projects/nerf/nerf/linear_with_repeat.py +++ b/pytorch3d/common/linear_with_repeat.py @@ -4,13 +4,15 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import math from typing import Tuple import torch import torch.nn.functional as F +from torch.nn import Parameter, init -class LinearWithRepeat(torch.nn.Linear): +class LinearWithRepeat(torch.nn.Module): """ if x has shape (..., k, n1) and y has shape (..., n2) @@ -50,6 +52,40 @@ class LinearWithRepeat(torch.nn.Linear): and sent that through the Linear. """ + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + device=None, + dtype=None, + ) -> None: + """ + Copied from torch.nn.Linear. + """ + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.weight = Parameter( + torch.empty((out_features, in_features), **factory_kwargs) + ) + if bias: + self.bias = Parameter(torch.empty(out_features, **factory_kwargs)) + else: + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self) -> None: + """ + Copied from torch.nn.Linear. + """ + init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + if self.bias is not None: + fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + init.uniform_(self.bias, -bound, bound) + def forward(self, input: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: n1 = input[0].shape[-1] output1 = F.linear(input[0], self.weight[:, :n1], self.bias) diff --git a/pytorch3d/renderer/__init__.py b/pytorch3d/renderer/__init__.py index c0a2d13c3..7fca11941 100644 --- a/pytorch3d/renderer/__init__.py +++ b/pytorch3d/renderer/__init__.py @@ -73,8 +73,8 @@ from .utils import ( TensorProperties, convert_to_tensors_and_broadcast, - ndc_to_grid_sample_coords, ndc_grid_sample, + ndc_to_grid_sample_coords, ) diff --git a/pytorch3d/renderer/utils.py b/pytorch3d/renderer/utils.py index 3984cf141..e0af59e87 100644 --- a/pytorch3d/renderer/utils.py +++ b/pytorch3d/renderer/utils.py @@ -8,7 +8,7 @@ import copy import inspect import warnings -from typing import Any, Optional, Union, Tuple +from typing import Any, Optional, Tuple, Union import numpy as np import torch diff --git a/tests/test_common_linear_with_repeat.py b/tests/test_common_linear_with_repeat.py new file mode 100644 index 000000000..dc8ec07e2 --- /dev/null +++ b/tests/test_common_linear_with_repeat.py @@ -0,0 +1,32 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +import torch +from common_testing import TestCaseMixin +from pytorch3d.common.linear_with_repeat import LinearWithRepeat + + +class TestLinearWithRepeat(TestCaseMixin, unittest.TestCase): + def setUp(self) -> None: + super().setUp() + torch.manual_seed(42) + + def test_simple(self): + x = torch.rand(4, 6, 7, 3) + y = torch.rand(4, 6, 4) + + linear = torch.nn.Linear(7, 8) + torch.nn.init.xavier_uniform_(linear.weight.data) + linear.bias.data.uniform_() + equivalent = torch.cat([x, y.unsqueeze(-2).expand(4, 6, 7, 4)], dim=-1) + expected = linear.forward(equivalent) + + linear_with_repeat = LinearWithRepeat(7, 8) + linear_with_repeat.load_state_dict(linear.state_dict()) + actual = linear_with_repeat.forward((x, y)) + self.assertClose(actual, expected, rtol=1e-4) diff --git a/tests/test_rendering_utils.py b/tests/test_rendering_utils.py index 6037a6598..7589b2447 100644 --- a/tests/test_rendering_utils.py +++ b/tests/test_rendering_utils.py @@ -12,16 +12,16 @@ from common_testing import TestCaseMixin from pytorch3d.ops import eyes from pytorch3d.renderer import ( - PerspectiveCameras, AlphaCompositor, - PointsRenderer, + PerspectiveCameras, PointsRasterizationSettings, PointsRasterizer, + PointsRenderer, ) from pytorch3d.renderer.utils import ( TensorProperties, - ndc_to_grid_sample_coords, ndc_grid_sample, + ndc_to_grid_sample_coords, ) from pytorch3d.structures import Pointclouds