Skip to content

Commit

Permalink
Hard population of registry system with pre_expand
Browse files Browse the repository at this point in the history
Summary: Provide an extension point, `pre_expand`, which lets a configurable class A ensure that another class B is registered before A is expanded. This reduces the need for top-level imports.

Reviewed By: bottler

Differential Revision: D44504122

fbshipit-source-id: c418bebbe6d33862d239be592d9751378eee3a62
  • Loading branch information
Dejan Kovachev authored and facebook-github-bot committed Mar 31, 2023
1 parent 813e941 commit c759fc5
Show file tree
Hide file tree
Showing 5 changed files with 117 additions and 27 deletions.
25 changes: 20 additions & 5 deletions pytorch3d/implicitron/dataset/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,8 @@
)
from pytorch3d.renderer.cameras import CamerasBase

from .blender_dataset_map_provider import BlenderDatasetMapProvider # noqa
from .data_loader_map_provider import DataLoaderMap, DataLoaderMapProviderBase
from .dataset_map_provider import DatasetMap, DatasetMapProviderBase
from .json_index_dataset_map_provider import JsonIndexDatasetMapProvider # noqa
from .json_index_dataset_map_provider_v2 import JsonIndexDatasetMapProviderV2 # noqa
from .llff_dataset_map_provider import LlffDatasetMapProvider # noqa
from .rendered_mesh_dataset_map_provider import RenderedMeshDatasetMapProvider # noqa


class DataSourceBase(ReplaceableBase):
Expand Down Expand Up @@ -60,6 +55,26 @@ class ImplicitronDataSource(DataSourceBase): # pyre-ignore[13]
data_loader_map_provider: DataLoaderMapProviderBase
data_loader_map_provider_class_type: str = "SequenceDataLoaderMapProvider"

@classmethod
def pre_expand(cls) -> None:
    """
    Register all pluggable DatasetMapProvider implementations before this
    class is expanded, so that ``dataset_map_provider_class_type`` can name
    any of them without requiring top-level imports of these modules.
    """
    # use try/finally to bypass cinder's lazy imports
    # (a plain import statement could be deferred by the lazy-import
    # machinery; wrapping it in try/finally forces eager execution).
    try:
        from .blender_dataset_map_provider import (  # noqa: F401
            BlenderDatasetMapProvider,
        )
        from .json_index_dataset_map_provider import (  # noqa: F401
            JsonIndexDatasetMapProvider,
        )
        from .json_index_dataset_map_provider_v2 import (  # noqa: F401
            JsonIndexDatasetMapProviderV2,
        )
        from .llff_dataset_map_provider import LlffDatasetMapProvider  # noqa: F401
        from .rendered_mesh_dataset_map_provider import (  # noqa: F401
            RenderedMeshDatasetMapProvider,
        )
    finally:
        pass

def __post_init__(self):
run_auto_creation(self)
self._all_train_cameras_cache: Optional[Tuple[Optional[CamerasBase]]] = None
Expand Down
53 changes: 31 additions & 22 deletions pytorch3d/implicitron/models/generic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,23 +20,8 @@
ImplicitronRender,
)
from pytorch3d.implicitron.models.feature_extractor import FeatureExtractorBase
from pytorch3d.implicitron.models.feature_extractor.resnet_feature_extractor import ( # noqa
ResNetFeatureExtractor,
)
from pytorch3d.implicitron.models.global_encoder.global_encoder import GlobalEncoderBase
from pytorch3d.implicitron.models.implicit_function.base import ImplicitFunctionBase
from pytorch3d.implicitron.models.implicit_function.idr_feature_field import ( # noqa
IdrFeatureField,
)
from pytorch3d.implicitron.models.implicit_function.neural_radiance_field import ( # noqa
NeRFormerImplicitFunction,
)
from pytorch3d.implicitron.models.implicit_function.scene_representation_networks import ( # noqa
SRNHyperNetImplicitFunction,
)
from pytorch3d.implicitron.models.implicit_function.voxel_grid_implicit_function import ( # noqa
VoxelGridImplicitFunction,
)
from pytorch3d.implicitron.models.metrics import (
RegularizationMetricsBase,
ViewMetricsBase,
Expand All @@ -50,14 +35,7 @@
RendererOutput,
RenderSamplingMode,
)
from pytorch3d.implicitron.models.renderer.lstm_renderer import LSTMRenderer # noqa
from pytorch3d.implicitron.models.renderer.multipass_ea import ( # noqa
MultiPassEmissionAbsorptionRenderer,
)
from pytorch3d.implicitron.models.renderer.ray_sampler import RaySamplerBase
from pytorch3d.implicitron.models.renderer.sdf_renderer import ( # noqa
SignedDistanceFunctionRenderer,
)

from pytorch3d.implicitron.models.utils import (
apply_chunked,
Expand Down Expand Up @@ -315,6 +293,37 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13
]
)

@classmethod
def pre_expand(cls) -> None:
    """
    Register the pluggable feature extractor, implicit function and
    renderer implementations before this class is expanded, so that the
    corresponding ``*_class_type`` fields can name any of them without
    requiring top-level imports of these modules.
    """
    # use try/finally to bypass cinder's lazy imports
    # (a plain import statement could be deferred by the lazy-import
    # machinery; wrapping it in try/finally forces eager execution).
    try:
        from pytorch3d.implicitron.models.feature_extractor.resnet_feature_extractor import (  # noqa: F401, B950
            ResNetFeatureExtractor,
        )
        from pytorch3d.implicitron.models.implicit_function.idr_feature_field import (  # noqa: F401, B950
            IdrFeatureField,
        )
        from pytorch3d.implicitron.models.implicit_function.neural_radiance_field import (  # noqa: F401, B950
            NeRFormerImplicitFunction,
        )
        from pytorch3d.implicitron.models.implicit_function.scene_representation_networks import (  # noqa: F401, B950
            SRNHyperNetImplicitFunction,
        )
        from pytorch3d.implicitron.models.implicit_function.voxel_grid_implicit_function import (  # noqa: F401, B950
            VoxelGridImplicitFunction,
        )
        from pytorch3d.implicitron.models.renderer.lstm_renderer import (  # noqa: F401
            LSTMRenderer,
        )
        from pytorch3d.implicitron.models.renderer.multipass_ea import (  # noqa: F401
            MultiPassEmissionAbsorptionRenderer,
        )
        from pytorch3d.implicitron.models.renderer.sdf_renderer import (  # noqa: F401
            SignedDistanceFunctionRenderer,
        )
    finally:
        pass

def __post_init__(self):
if self.view_pooler_enabled:
if self.image_feature_extractor_class_type is None:
Expand Down
25 changes: 25 additions & 0 deletions pytorch3d/implicitron/models/overfit_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,31 @@ class OverfitModel(ImplicitronModelBase): # pyre-ignore: 13
]
)

@classmethod
def pre_expand(cls) -> None:
    """
    Register the pluggable implicit function and renderer implementations
    before this class is expanded, so that the corresponding
    ``*_class_type`` fields can name any of them without requiring
    top-level imports of these modules.
    """
    # use try/finally to bypass cinder's lazy imports
    # (a plain import statement could be deferred by the lazy-import
    # machinery; wrapping it in try/finally forces eager execution).
    try:
        from pytorch3d.implicitron.models.implicit_function.idr_feature_field import (  # noqa: F401, B950
            IdrFeatureField,
        )
        from pytorch3d.implicitron.models.implicit_function.neural_radiance_field import (  # noqa: F401, B950
            NeuralRadianceFieldImplicitFunction,
        )
        from pytorch3d.implicitron.models.implicit_function.scene_representation_networks import (  # noqa: F401, B950
            SRNImplicitFunction,
        )
        from pytorch3d.implicitron.models.renderer.lstm_renderer import (  # noqa: F401
            LSTMRenderer,
        )
        from pytorch3d.implicitron.models.renderer.multipass_ea import (  # noqa: F401
            MultiPassEmissionAbsorptionRenderer,
        )
        from pytorch3d.implicitron.models.renderer.sdf_renderer import (  # noqa: F401
            SignedDistanceFunctionRenderer,
        )
    finally:
        pass

def __post_init__(self):
# The attribute will be filled by run_auto_creation
run_auto_creation(self)
Expand Down
7 changes: 7 additions & 0 deletions pytorch3d/implicitron/tools/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ def __post_init__(self):
IMPL_SUFFIX: str = "_impl"
TWEAK_SUFFIX: str = "_tweak_args"
_DATACLASS_INIT: str = "__dataclass_own_init__"
PRE_EXPAND_NAME: str = "pre_expand"


class ReplaceableBase:
Expand Down Expand Up @@ -838,6 +839,9 @@ def x_tweak_args(cls, member_type: Type, args: DictConfig) -> None
In addition, if the class inherits torch.nn.Module, the generated __init__ will
call torch.nn.Module's __init__ before doing anything else.
Before any transformation of the class, if the class has a classmethod called
`pre_expand`, it will be called with no arguments.
Note that although the *_args members are intended to have type DictConfig, they
are actually internally annotated as dicts. OmegaConf is happy to see a DictConfig
in place of a dict, but not vice-versa. Allowing dict lets a class user specify
Expand All @@ -858,6 +862,9 @@ def x_tweak_args(cls, member_type: Type, args: DictConfig) -> None
if _is_actually_dataclass(some_class):
return some_class

if hasattr(some_class, PRE_EXPAND_NAME):
getattr(some_class, PRE_EXPAND_NAME)()

# The functions this class's run_auto_creation will run.
creation_functions: List[str] = []
# The classes which this type knows about from the registry
Expand Down
34 changes: 34 additions & 0 deletions tests/implicitron/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from dataclasses import dataclass, field, is_dataclass
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple
from unittest.mock import Mock

from omegaconf import DictConfig, ListConfig, OmegaConf, ValidationError
from pytorch3d.implicitron.tools.config import (
Expand Down Expand Up @@ -805,6 +806,39 @@ def __post_init__(self):

self.assertEqual(control_args, ["Orange", "Orange", True, True])

def test_pre_expand(self):
    # Check that the pre_expand classmethod of a class is called
    # when expand_args_fields is called on that class.

    class A(Configurable):
        n: int = 9

        @classmethod
        def pre_expand(cls):
            pass

    # Replace the classmethod with a Mock so the call can be observed.
    A.pre_expand = Mock()
    expand_args_fields(A)
    A.pre_expand.assert_called()

def test_pre_expand_replaceable(self):
    # Check that a pre_expand classmethod inherited from a replaceable
    # base class is called when expand_args_fields is called on a
    # subclass of that base.

    class A(ReplaceableBase):
        @classmethod
        def pre_expand(cls):
            pass

    class A1(A):
        # Fixed: was `n: 9`, which annotates the field with the literal
        # 9 as its "type"; use a real type and default as in test_pre_expand.
        n: int = 9

    # Replace the classmethod with a Mock so the call can be observed.
    A.pre_expand = Mock()
    expand_args_fields(A1)
    A.pre_expand.assert_called()


@dataclass(eq=False)
class MockDataclass:
Expand Down

0 comments on commit c759fc5

Please sign in to comment.