Fix parameters not wrapped with nn.Parameter, antialiasing compatibility
Summary: Some things fail if a parameter is not wrapped in nn.Parameter; in particular, it prevented other tensors from moving to the GPU.

Reviewed By: bottler

Differential Revision: D40819932

fbshipit-source-id: a23b38ceacd7f0dc131cb0355fef1178e3e2f7fd
shapovalov authored and facebook-github-bot committed Oct 31, 2022
1 parent 88620b6 commit f711c4b
Showing 3 changed files with 47 additions and 28 deletions.
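The summary describes the underlying issue: values stored in a torch.nn.ParameterDict without being wrapped in nn.Parameter are not registered with the module, so Module.to(...) leaves them behind when the model is moved to the GPU. Below is a minimal sketch of the wrap-and-filter pattern the diff applies in voxel_grid.py; the Grid class and its field names are invented for illustration and are not part of the commit.

```python
import torch


class Grid(torch.nn.Module):
    def __init__(self, values: dict):
        super().__init__()
        # The pattern introduced by this commit: drop None entries and wrap
        # every remaining tensor explicitly in nn.Parameter so that it is
        # registered with the module.
        self.params = torch.nn.ParameterDict(
            {k: torch.nn.Parameter(v) for k, v in values.items() if v is not None}
        )


grid = Grid({"density": torch.zeros(1, 8, 8, 8), "color": None})
if torch.cuda.is_available():
    grid = grid.to("cuda")
# Registered parameters follow the module to whatever device it is moved to.
print({k: v.device for k, v in grid.params.items()})
```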
@@ -75,7 +75,7 @@ class ElementwiseDecoder(DecoderFunctionBase):
shift: a scalar which is added to the scaled input before performing
the operation. Defaults to 0.
operation: which operation to perform on the transformed input. Options are:
-`relu`, `softplus`, `sigmoid` and `identity`. Defaults to `identity`.
+`RELU`, `SOFTPLUS`, `SIGMOID` or `IDENTITY`. Defaults to `IDENTITY`.
"""

scale: float = 1
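Reading the docstring context above, ElementwiseDecoder appears to apply the chosen operation to the scaled and shifted input, i.e. roughly operation(scale * x + shift). A tiny illustration with arbitrary values (not taken from the implementation), using softplus as the operation:

```python
import torch
import torch.nn.functional as F

# Assumed semantics from the docstring: y = operation(scale * x + shift).
x = torch.tensor([-1.0, 0.0, 2.0])
scale, shift = 2.0, 0.5
y = F.softplus(scale * x + shift)  # operation = SOFTPLUS
print(y)
```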
@@ -91,7 +91,7 @@ def __post_init__(self):
DecoderActivation.IDENTITY,
]:
raise ValueError(
"`operation` can only be `relu`, `softplus`, `sigmoid` or identity."
"`operation` can only be `RELU`, `SOFTPLUS`, `SIGMOID` or `IDENTITY`."
)

def forward(
@@ -165,22 +165,18 @@ class MLPWithInputSkips(Configurable, torch.nn.Module):
def __post_init__(self):
super().__init__()

-if self.last_activation not in [
-    DecoderActivation.RELU,
-    DecoderActivation.SOFTPLUS,
-    DecoderActivation.SIGMOID,
-    DecoderActivation.IDENTITY,
-]:
+try:
+    last_activation = {
+        DecoderActivation.RELU: torch.nn.ReLU(True),
+        DecoderActivation.SOFTPLUS: torch.nn.Softplus(),
+        DecoderActivation.SIGMOID: torch.nn.Sigmoid(),
+        DecoderActivation.IDENTITY: torch.nn.Identity(),
+    }[self.last_activation]
+except KeyError as e:
     raise ValueError(
-        "`last_activation` can only be `relu`,"
-        " `softplus`, `sigmoid` or identity."
-    )
-last_activation = {
-    DecoderActivation.RELU: torch.nn.ReLU(True),
-    DecoderActivation.SOFTPLUS: torch.nn.Softplus(),
-    DecoderActivation.SIGMOID: torch.nn.Sigmoid(),
-    DecoderActivation.IDENTITY: torch.nn.Identity(),
-}[self.last_activation]
+        "`last_activation` can only be `RELU`,"
+        " `SOFTPLUS`, `SIGMOID` or `IDENTITY`."
+    ) from e

layers = []
skip_affine_layers = []
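The reworked __post_init__ above replaces a membership check followed by a dict lookup with a single lookup guarded by KeyError. A self-contained sketch of the same pattern follows; the DecoderActivation enum here is a local stand-in (its member names are taken from the diff, everything else is assumed):

```python
from enum import Enum

import torch


class DecoderActivation(Enum):  # stand-in for the implicitron enum
    RELU = "relu"
    SOFTPLUS = "softplus"
    SIGMOID = "sigmoid"
    IDENTITY = "identity"


def resolve_activation(activation: DecoderActivation) -> torch.nn.Module:
    # One lookup: an unknown member raises KeyError, which is re-raised as
    # a ValueError, instead of maintaining a separate membership check.
    try:
        return {
            DecoderActivation.RELU: torch.nn.ReLU(True),
            DecoderActivation.SOFTPLUS: torch.nn.Softplus(),
            DecoderActivation.SIGMOID: torch.nn.Sigmoid(),
            DecoderActivation.IDENTITY: torch.nn.Identity(),
        }[activation]
    except KeyError as e:
        raise ValueError(
            "`last_activation` can only be `RELU`, `SOFTPLUS`, `SIGMOID` or `IDENTITY`."
        ) from e


print(resolve_activation(DecoderActivation.SOFTPLUS))
```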
33 changes: 28 additions & 5 deletions pytorch3d/implicitron/models/implicit_function/voxel_grid.py
@@ -15,9 +15,12 @@
"""

import warnings
from collections.abc import Mapping
from dataclasses import dataclass, field
-from typing import Callable, ClassVar, Dict, Iterator, List, Optional, Tuple, Type

+from distutils.version import LooseVersion
+from typing import Any, Callable, ClassVar, Dict, Iterator, List, Optional, Tuple, Type

import torch
from omegaconf import DictConfig
@@ -67,7 +70,9 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
padding: str = "zeros"
mode: str = "bilinear"
n_features: int = 1
-resolution_changes: Dict[int, List[int]] = field(
+# return the line below once we drop OmegaConf 2.1 support
+# resolution_changes: Dict[int, List[int]] = field(
+resolution_changes: Dict[int, Any] = field(
default_factory=lambda: {0: [128, 128, 128]}
)

@@ -212,6 +217,13 @@ def change_resolution(
+ "| 'bicubic' | 'linear' | 'area' | 'nearest-exact'"
)

+interpolate_has_antialias = LooseVersion(torch.__version__) >= "1.11"

+if antialias and not interpolate_has_antialias:
+    warnings.warn("Antialiased interpolation requires PyTorch 1.11+; ignoring")

+interp_kwargs = {"antialias": antialias} if interpolate_has_antialias else {}

def change_individual_resolution(tensor, wanted_resolution):
if mode == "linear":
n_dim = len(wanted_resolution)
@@ -223,8 +235,8 @@ def change_individual_resolution(tensor, wanted_resolution):
size=wanted_resolution,
mode=new_mode,
align_corners=align_corners,
-antialias=antialias,
recompute_scale_factor=False,
+**interp_kwargs,
)

if epoch is not None:
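The two hunks above gate antialiasing on the running PyTorch version: torch.nn.functional.interpolate only accepts the antialias keyword from 1.11 onwards, so the flag is forwarded through **interp_kwargs only when available, with a warning otherwise. A minimal sketch of the same gating with made-up input sizes:

```python
from distutils.version import LooseVersion

import torch
import torch.nn.functional as F

# Forward `antialias` only when the installed PyTorch understands it.
interpolate_has_antialias = LooseVersion(torch.__version__) >= "1.11"
interp_kwargs = {"antialias": True} if interpolate_has_antialias else {}

x = torch.rand(1, 3, 64, 64)
y = F.interpolate(
    x,
    size=(32, 32),
    mode="bilinear",  # antialias only takes effect for bilinear/bicubic
    align_corners=False,
    recompute_scale_factor=False,
    **interp_kwargs,
)
print(y.shape)  # torch.Size([1, 3, 32, 32])
```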
@@ -880,7 +892,14 @@ def set_voxel_grid_parameters(self, params: VoxelGridValuesBase) -> None:
"""
if self.hold_voxel_grid_as_parameters:
# pyre-ignore [16]
-self.params = torch.nn.ParameterDict(vars(params))
+# Nones are converted to empty tensors by Parameter()
+self.params = torch.nn.ParameterDict(
+    {
+        k: torch.nn.Parameter(val)
+        for k, val in vars(params).items()
+        if val is not None
+    }
+)
else:
# Torch Module to hold parameters since they can only be registered
# at object level.
@@ -1011,7 +1030,11 @@ def crop_self(self, min_point: torch.Tensor, max_point: torch.Tensor) -> None:
)
# pyre-ignore [16]
self.params = torch.nn.ParameterDict(
-{k: v for k, v in vars(grid_values).items()}
+{
+    k: torch.nn.Parameter(val)
+    for k, val in vars(grid_values).items()
+    if val is not None
+}
)
# New center of voxel grid is the middle point between max and min points.
self.translation = tuple((max_point + min_point) / 2)
@@ -527,10 +527,10 @@ def _get_scaffold(self, epoch: int) -> bool:
return False

@classmethod
-def decoder_density_tweak_args(cls, type, args: DictConfig) -> None:
+def decoder_density_tweak_args(cls, type_, args: DictConfig) -> None:
args.pop("input_dim", None)

-def create_decoder_density_impl(self, type, args: DictConfig) -> None:
+def create_decoder_density_impl(self, type_, args: DictConfig) -> None:
"""
Decoding functions come after harmonic embedding and voxel grid. In order to not
calculate the input dimension of the decoder in the config file this function
@@ -548,18 +548,18 @@ def create_decoder_density_impl(self, type, args: DictConfig) -> None:
embedder_args["append_input"],
)

-cls = registry.get(DecoderFunctionBase, type)
+cls = registry.get(DecoderFunctionBase, type_)
need_input_dim = any(field.name == "input_dim" for field in fields(cls))
if need_input_dim:
self.decoder_density = cls(input_dim=input_dim, **args)
else:
self.decoder_density = cls(**args)

@classmethod
-def decoder_color_tweak_args(cls, type, args: DictConfig) -> None:
+def decoder_color_tweak_args(cls, type_, args: DictConfig) -> None:
args.pop("input_dim", None)

-def create_decoder_color_impl(self, type, args: DictConfig) -> None:
+def create_decoder_color_impl(self, type_, args: DictConfig) -> None:
"""
Decoding functions come after harmonic embedding and voxel grid. In order to not
calculate the input dimension of the decoder in the config file this function
@@ -587,7 +587,7 @@ def create_decoder_color_impl(self, type, args: DictConfig) -> None:

input_dim = input_dim0 + input_dim1

-cls = registry.get(DecoderFunctionBase, type)
+cls = registry.get(DecoderFunctionBase, type_)
need_input_dim = any(field.name == "input_dim" for field in fields(cls))
if need_input_dim:
self.decoder_color = cls(input_dim=input_dim, **args)
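The docstrings in this file explain that the decoder's input_dim is computed here rather than set in the config. The sketch below isolates that idea: inspect the target dataclass's fields and pass input_dim only if the class declares it. MLPDecoder and IdentityDecoder are made-up stand-ins, not PyTorch3D classes.

```python
from dataclasses import dataclass, fields


@dataclass
class MLPDecoder:  # stand-in decoder that needs the computed input dimension
    input_dim: int
    hidden_dim: int = 64


@dataclass
class IdentityDecoder:  # stand-in decoder that takes no input_dim
    scale: float = 1.0


def build_decoder(cls, computed_input_dim: int, **args):
    # Same check as create_decoder_*_impl: only pass input_dim if declared.
    need_input_dim = any(f.name == "input_dim" for f in fields(cls))
    if need_input_dim:
        return cls(input_dim=computed_input_dim, **args)
    return cls(**args)


print(build_decoder(MLPDecoder, 39, hidden_dim=128))
print(build_decoder(IdentityDecoder, 39))
```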
