Make Module.__init__ automatic
Summary: If a configurable class inherits torch.nn.Module and is instantiated, automatically call `torch.nn.Module.__init__` on it before doing anything else.

Reviewed By: shapovalov

Differential Revision: D42760349

fbshipit-source-id: 409894911a4252b7987e1fd218ee9ecefbec8e62
bottler authored and facebook-github-bot committed Jan 27, 2023
1 parent 97f8f9b commit 9540c29
Showing 29 changed files with 36 additions and 87 deletions.
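In practice, the change lets configurable classes that inherit torch.nn.Module drop the boilerplate `super().__init__()` call from `__post_init__`. A minimal before/after sketch (the class `MyRenderer` and its `_mlp` layer are hypothetical, modeled on the `XRayRenderer` example in the README diff below):

    # Before this commit: Module.__init__ had to be called by hand before
    # registering any submodules, parameters or buffers.
    class MyRenderer(BaseRenderer, torch.nn.Module):
        n_pts_per_ray: int = 64

        def __post_init__(self):
            super().__init__()  # sets up torch.nn.Module internals
            self._mlp = torch.nn.Linear(3, self.n_pts_per_ray)

    # After this commit: torch.nn.Module.__init__ is called automatically on
    # instantiation, before __post_init__ runs.
    class MyRenderer(BaseRenderer, torch.nn.Module):
        n_pts_per_ray: int = 64

        def __post_init__(self):
            self._mlp = torch.nn.Linear(3, self.n_pts_per_ray)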
2 changes: 0 additions & 2 deletions projects/implicitron_trainer/README.md
@@ -212,9 +212,7 @@ from pytorch3d.implicitron.tools.config import registry
 class XRayRenderer(BaseRenderer, torch.nn.Module):
     n_pts_per_ray: int = 64

-    # if there are other base classes, make sure to call `super().__init__()` explicitly
     def __post_init__(self):
-        super().__init__()
         # custom initialization

     def forward(
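Why the automatic call has to happen "before doing anything else": `torch.nn.Module.__init__` creates the `_parameters`/`_buffers`/`_modules` registries that `Module.__setattr__` consults, so assigning any submodule or parameter before it runs raises an error. A standalone illustration in plain PyTorch (the class `Broken` is hypothetical):

    import torch

    class Broken(torch.nn.Module):
        def __init__(self):
            # Module.__init__ deliberately not called first, so the next line
            # fails with:
            # AttributeError: cannot assign module before Module.__init__() call
            self.layer = torch.nn.Linear(2, 2)

    Broken()  # raises the AttributeError above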
2 changes: 1 addition & 1 deletion pytorch3d/implicitron/eval_demo.py
@@ -130,7 +130,7 @@ def evaluate_dbir_for_category(
         raise ValueError("Image size should be set in the dataset")

     # init the simple DBIR model
-    model = ModelDBIR(  # pyre-ignore[28]: c’tor implicitly overridden
+    model = ModelDBIR(
         render_image_width=image_size,
         render_image_height=image_size,
         bg_color=bg_color,
3 changes: 0 additions & 3 deletions pytorch3d/implicitron/models/base_model.py
@@ -49,9 +49,6 @@ class ImplicitronModelBase(ReplaceableBase, torch.nn.Module):
     # the training loop.
     log_vars: List[str] = field(default_factory=lambda: ["objective"])

-    def __init__(self) -> None:
-        super().__init__()
-
     def forward(
         self,
         *,  # force keyword-only arguments
@@ -15,9 +15,6 @@ class FeatureExtractorBase(ReplaceableBase, torch.nn.Module):
     Base class for an extractor of a set of features from images.
     """

-    def __init__(self):
-        super().__init__()
-
     def get_feat_dims(self) -> int:
         """
         Returns:
@@ -78,7 +78,6 @@ class ResNetFeatureExtractor(FeatureExtractorBase):
     feature_rescale: float = 1.0

     def __post_init__(self):
-        super().__init__()
         if self.normalize_image:
             # register buffers needed to normalize the image
             for k, v in (("_resnet_mean", _RESNET_MEAN), ("_resnet_std", _RESNET_STD)):
2 changes: 0 additions & 2 deletions pytorch3d/implicitron/models/generic_model.py
@@ -304,8 +304,6 @@ class GenericModel(ImplicitronModelBase):  # pyre-ignore: 13
     )

     def __post_init__(self):
-        super().__init__()
-
         if self.view_pooler_enabled:
             if self.image_feature_extractor_class_type is None:
                 raise ValueError(
2 changes: 0 additions & 2 deletions pytorch3d/implicitron/models/global_encoder/autodecoder.py
@@ -29,8 +29,6 @@ class Autodecoder(Configurable, torch.nn.Module):
     ignore_input: bool = False

     def __post_init__(self):
-        super().__init__()
-
         if self.n_instances <= 0:
             raise ValueError(f"Invalid n_instances {self.n_instances}")

5 changes: 0 additions & 5 deletions pytorch3d/implicitron/models/global_encoder/global_encoder.py
@@ -26,9 +26,6 @@ class GlobalEncoderBase(ReplaceableBase):
     (`SequenceAutodecoder`).
     """

-    def __init__(self) -> None:
-        super().__init__()
-
     def get_encoding_dim(self):
         """
         Returns the dimensionality of the returned encoding.
@@ -69,7 +66,6 @@ class SequenceAutodecoder(GlobalEncoderBase, torch.nn.Module):  # pyre-ignore: 1
     autodecoder: Autodecoder

     def __post_init__(self):
-        super().__init__()
         run_auto_creation(self)

     def get_encoding_dim(self):
@@ -103,7 +99,6 @@ class HarmonicTimeEncoder(GlobalEncoderBase, torch.nn.Module):
     time_divisor: float = 1.0

     def __post_init__(self):
-        super().__init__()
         self._harmonic_embedding = HarmonicEmbedding(
             n_harmonic_functions=self.n_harmonic_functions,
             append_input=self.append_input,
3 changes: 0 additions & 3 deletions pytorch3d/implicitron/models/implicit_function/base.py
@@ -14,9 +14,6 @@


 class ImplicitFunctionBase(ABC, ReplaceableBase):
-    def __init__(self):
-        super().__init__()
-
     @abstractmethod
     def forward(
         self,
@@ -45,9 +45,6 @@ class DecoderFunctionBase(ReplaceableBase, torch.nn.Module):
     space and transforms it into the required quantity (for example density and color).
     """

-    def __post_init__(self):
-        super().__init__()
-
     def forward(
         self, features: torch.Tensor, z: Optional[torch.Tensor] = None
     ) -> torch.Tensor:
@@ -83,7 +80,6 @@ class ElementwiseDecoder(DecoderFunctionBase):
     operation: DecoderActivation = DecoderActivation.IDENTITY

     def __post_init__(self):
-        super().__post_init__()
         if self.operation not in [
             DecoderActivation.RELU,
             DecoderActivation.SOFTPLUS,
@@ -163,8 +159,6 @@ class MLPWithInputSkips(Configurable, torch.nn.Module):
     use_xavier_init: bool = True

     def __post_init__(self):
-        super().__init__()
-
         try:
             last_activation = {
                 DecoderActivation.RELU: torch.nn.ReLU(True),
@@ -284,7 +278,6 @@ class MLPDecoder(DecoderFunctionBase):
     network: MLPWithInputSkips

     def __post_init__(self):
-        super().__post_init__()
         run_auto_creation(self)

     def forward(
@@ -66,8 +66,6 @@ class IdrFeatureField(ImplicitFunctionBase, torch.nn.Module):
     encoding_dim: int = 0

     def __post_init__(self):
-        super().__init__()
-
         dims = [self.d_in] + list(self.dims) + [self.d_out + self.feature_vector_size]

         self.embed_fn = None
@@ -56,7 +56,6 @@ class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module):
     """

     def __post_init__(self):
-        super().__init__()
         # The harmonic embedding layer converts input 3D coordinates
         # to a representation that is more suitable for
         # processing with a deep neural network.
@@ -44,7 +44,6 @@ class SRNRaymarchFunction(Configurable, torch.nn.Module):
     raymarch_function: Any = None

     def __post_init__(self):
-        super().__init__()
         self._harmonic_embedding = HarmonicEmbedding(
             self.n_harmonic_functions, append_input=True
         )
@@ -135,7 +134,6 @@ class SRNPixelGenerator(Configurable, torch.nn.Module):
     ray_dir_in_camera_coords: bool = False

     def __post_init__(self):
-        super().__init__()
         self._harmonic_embedding = HarmonicEmbedding(
             self.n_harmonic_functions, append_input=True
         )
@@ -249,7 +247,6 @@ class SRNRaymarchHyperNet(Configurable, torch.nn.Module):
     xyz_in_camera_coords: bool = False

     def __post_init__(self):
-        super().__init__()
         raymarch_input_embedding_dim = (
             HarmonicEmbedding.get_output_dim_static(
                 self.in_features,
@@ -335,7 +332,6 @@ class SRNImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
     pixel_generator: SRNPixelGenerator

     def __post_init__(self):
-        super().__init__()
         run_auto_creation(self)

     def create_raymarch_function(self) -> None:
@@ -393,7 +389,6 @@ class SRNHyperNetImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
     pixel_generator: SRNPixelGenerator

     def __post_init__(self):
-        super().__init__()
         run_auto_creation(self)

     def create_hypernet(self) -> None:
2 changes: 0 additions & 2 deletions pytorch3d/implicitron/models/implicit_function/voxel_grid.py
@@ -81,7 +81,6 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
     )

     def __post_init__(self):
-        super().__init__()
         if 0 not in self.resolution_changes:
             raise ValueError("There has to be key `0` in `resolution_changes`.")

@@ -857,7 +856,6 @@ class VoxelGridModule(Configurable, torch.nn.Module):
     param_groups: Dict[str, str] = field(default_factory=lambda: {})

     def __post_init__(self):
-        super().__init__()
         run_auto_creation(self)
         n_grids = 1  # Voxel grid objects are batched. We need only a single grid.
         shapes = self.voxel_grid.get_shapes(epoch=0)
@@ -186,7 +186,6 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
     volume_cropping_epochs: Tuple[int, ...] = ()

     def __post_init__(self) -> None:
-        super().__init__()
         run_auto_creation(self)
         # pyre-ignore[16]
         self.voxel_grid_scaffold = self._create_voxel_grid_scaffold()
6 changes: 0 additions & 6 deletions pytorch3d/implicitron/models/metrics.py
@@ -25,9 +25,6 @@ class RegularizationMetricsBase(ReplaceableBase, torch.nn.Module):
     depend on the model's parameters.
     """

-    def __post_init__(self) -> None:
-        super().__init__()
-
     def forward(
         self, model: Any, keys_prefix: str = "loss_", **kwargs
     ) -> Dict[str, Any]:
@@ -56,9 +53,6 @@ class ViewMetricsBase(ReplaceableBase, torch.nn.Module):
     `forward()` method produces losses and other metrics.
     """

-    def __post_init__(self) -> None:
-        super().__init__()
-
     def forward(
         self,
         raymarched: RendererOutput,
3 changes: 0 additions & 3 deletions pytorch3d/implicitron/models/model_dbir.py
@@ -41,9 +41,6 @@ class ModelDBIR(ImplicitronModelBase):
     bg_color: Tuple[float, float, float] = (0.0, 0.0, 0.0)
     max_points: int = -1

-    def __post_init__(self):
-        super().__init__()
-
     def forward(
         self,
         *,  # force keyword-only arguments
3 changes: 0 additions & 3 deletions pytorch3d/implicitron/models/renderer/base.py
@@ -141,9 +141,6 @@ class BaseRenderer(ABC, ReplaceableBase):
     Base class for all Renderer implementations.
     """

-    def __init__(self) -> None:
-        super().__init__()
-
     def requires_object_mask(self) -> bool:
         """
         Whether `forward` needs the object_mask.
1 change: 0 additions & 1 deletion pytorch3d/implicitron/models/renderer/lstm_renderer.py
@@ -57,7 +57,6 @@ class LSTMRenderer(BaseRenderer, torch.nn.Module):
     verbose: bool = False

     def __post_init__(self):
-        super().__init__()
         self._lstm = torch.nn.LSTMCell(
             input_size=self.n_feature_channels,
             hidden_size=self.hidden_size,
1 change: 0 additions & 1 deletion pytorch3d/implicitron/models/renderer/multipass_ea.py
@@ -90,7 +90,6 @@ class MultiPassEmissionAbsorptionRenderer(  # pyre-ignore: 13
     return_weights: bool = False

     def __post_init__(self):
-        super().__init__()
         self._refiners = {
             EvaluationMode.TRAINING: RayPointRefiner(
                 n_pts_per_ray=self.n_pts_per_ray_fine_training,
3 changes: 0 additions & 3 deletions pytorch3d/implicitron/models/renderer/ray_point_refiner.py
@@ -38,9 +38,6 @@ class RayPointRefiner(Configurable, torch.nn.Module):
     random_sampling: bool
     add_input_samples: bool = True

-    def __post_init__(self) -> None:
-        super().__init__()
-
     def forward(
         self,
         input_ray_bundle: ImplicitronRayBundle,
5 changes: 0 additions & 5 deletions pytorch3d/implicitron/models/renderer/ray_sampler.py
@@ -20,9 +20,6 @@ class RaySamplerBase(ReplaceableBase):
     Base class for ray samplers.
     """

-    def __init__(self):
-        super().__init__()
-
     def forward(
         self,
         cameras: CamerasBase,
@@ -102,8 +99,6 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
     stratified_point_sampling_evaluation: bool = False

     def __post_init__(self):
-        super().__init__()
-
         if (self.n_rays_per_image_sampled_from_mask is not None) and (
             self.n_rays_total_training is not None
         ):
3 changes: 0 additions & 3 deletions pytorch3d/implicitron/models/renderer/ray_tracing.py
@@ -43,9 +43,6 @@ class RayTracing(Configurable, nn.Module):
     n_steps: int = 100
     n_secant_steps: int = 8

-    def __post_init__(self):
-        super().__init__()
-
     def forward(
         self,
         sdf: Callable[[torch.Tensor], torch.Tensor],
5 changes: 0 additions & 5 deletions pytorch3d/implicitron/models/renderer/raymarcher.py
@@ -22,9 +22,6 @@ class RaymarcherBase(ReplaceableBase):
     and marching along them in order to generate a feature render.
     """

-    def __init__(self):
-        super().__init__()
-
     def forward(
         self,
         rays_densities: torch.Tensor,
@@ -98,8 +95,6 @@ def __post_init__(self):
             surface_thickness: Denotes the overlap between the absorption
                 function and the density function.
         """
-        super().__init__()
-
         bg_color = torch.tensor(self.bg_color)
         if bg_color.ndim != 1:
             raise ValueError(f"bg_color (shape {bg_color.shape}) should be a 1D tensor")
1 change: 0 additions & 1 deletion pytorch3d/implicitron/models/renderer/sdf_renderer.py
@@ -35,7 +35,6 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module):  # pyre-ign
     def __post_init__(
         self,
     ):
-        super().__init__()
         render_features_dimensions = self.render_features_dimensions
         if len(self.bg_color) not in [1, render_features_dimensions]:
             raise ValueError(
12 changes: 0 additions & 12 deletions pytorch3d/implicitron/models/view_pooler/feature_aggregator.py
@@ -118,9 +118,6 @@ class IdentityFeatureAggregator(torch.nn.Module, FeatureAggregatorBase):
     the outputs.
     """

-    def __post_init__(self):
-        super().__init__()
-
     def get_aggregated_feature_dim(
         self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int]
     ):
@@ -181,9 +178,6 @@ class ReductionFeatureAggregator(torch.nn.Module, FeatureAggregatorBase):
         ReductionFunction.STD,
     )

-    def __post_init__(self):
-        super().__init__()
-
     def get_aggregated_feature_dim(
         self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int]
     ):
@@ -275,9 +269,6 @@ class AngleWeightedReductionFeatureAggregator(torch.nn.Module, FeatureAggregator
     weight_by_ray_angle_gamma: float = 1.0
     min_ray_angle_weight: float = 0.1

-    def __post_init__(self):
-        super().__init__()
-
     def get_aggregated_feature_dim(
         self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int]
     ):
@@ -377,9 +368,6 @@ class AngleWeightedIdentityFeatureAggregator(torch.nn.Module, FeatureAggregatorB
     weight_by_ray_angle_gamma: float = 1.0
     min_ray_angle_weight: float = 0.1

-    def __post_init__(self):
-        super().__init__()
-
     def get_aggregated_feature_dim(
         self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int]
     ):
1 change: 0 additions & 1 deletion pytorch3d/implicitron/models/view_pooler/view_pooler.py
@@ -38,7 +38,6 @@ class ViewPooler(Configurable, torch.nn.Module):
     feature_aggregator: FeatureAggregatorBase

     def __post_init__(self):
-        super().__init__()
         run_auto_creation(self)

     def get_aggregated_feature_dim(self, feats: Union[Dict[str, torch.Tensor], int]):
3 changes: 0 additions & 3 deletions pytorch3d/implicitron/models/view_pooler/view_sampler.py
@@ -29,9 +29,6 @@ class ViewSampler(Configurable, torch.nn.Module):
     masked_sampling: bool = False
     sampling_mode: str = "bilinear"

-    def __post_init__(self):
-        super().__init__()
-
     def forward(
         self,
         *,  # force kw args
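The mechanism itself (the change to `pytorch3d.implicitron.tools.config`) is not among the hunks shown above. As a rough illustration only — an assumption about the approach, not the actual Implicitron implementation — a dataclass-based config system can call `torch.nn.Module.__init__` before anything else by wrapping the generated `__init__`:

    import dataclasses
    import torch

    def configurable(cls):
        # Hypothetical decorator, standing in for Implicitron's real machinery.
        cls = dataclasses.dataclass(cls)
        dataclass_init = cls.__init__

        def __init__(self, *args, **kwargs):
            if isinstance(self, torch.nn.Module):
                # Set up _parameters/_buffers/_modules before any field is assigned.
                torch.nn.Module.__init__(self)
            dataclass_init(self, *args, **kwargs)  # assigns fields, runs __post_init__

        cls.__init__ = __init__
        return cls

    @configurable
    class MySampler(torch.nn.Module):
        n_rays: int = 128

        def __post_init__(self):
            # No super().__init__() needed: Module is already initialized.
            self.embed = torch.nn.Linear(3, self.n_rays)

    s = MySampler(n_rays=64)
    print(s.n_rays, s.embed)  # 64 Linear(in_features=3, out_features=64, bias=True)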