Skip to content

Commit

Permalink
[Enhance] Enhance error information in build function (#1088)
Browse files Browse the repository at this point in the history
  • Loading branch information
HAOCHENYE committed Jul 28, 2023
1 parent e56d6ed commit 237aee3
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 57 deletions.
84 changes: 33 additions & 51 deletions mmengine/registry/build_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def build_from_cfg(
obj_cls = registry.get(obj_type)
if obj_cls is None:
raise KeyError(
f'{obj_type} is not in the {registry.name} registry. '
f'{obj_type} is not in the {registry.scope}::{registry.name} registry. ' # noqa: E501
f'Please check whether the value of `{obj_type}` is '
'correct or it was registered as expected. More details '
'can be found at '
Expand All @@ -111,39 +111,30 @@ def build_from_cfg(
raise TypeError(
f'type must be a str or valid type, but got {type(obj_type)}')

try:
# If `obj_cls` inherits from `ManagerMixin`, it should be
# instantiated by `ManagerMixin.get_instance` to ensure that it
# can be accessed globally.
if inspect.isclass(obj_cls) and \
issubclass(obj_cls, ManagerMixin): # type: ignore
obj = obj_cls.get_instance(**args) # type: ignore
else:
obj = obj_cls(**args) # type: ignore

if (inspect.isclass(obj_cls) or inspect.isfunction(obj_cls)
or inspect.ismethod(obj_cls)):
print_log(
f'An `{obj_cls.__name__}` instance is built from ' # type: ignore # noqa: E501
'registry, and its implementation can be found in '
f'{obj_cls.__module__}', # type: ignore
logger='current',
level=logging.DEBUG)
else:
print_log(
'An instance is built from registry, and its constructor '
f'is {obj_cls}',
logger='current',
level=logging.DEBUG)
return obj

except Exception as e:
# Normal TypeError does not print class name.
cls_location = '/'.join(
obj_cls.__module__.split('.')) # type: ignore
raise type(e)(
f'class `{obj_cls.__name__}` in ' # type: ignore
f'{cls_location}.py: {e}')
# If `obj_cls` inherits from `ManagerMixin`, it should be
# instantiated by `ManagerMixin.get_instance` to ensure that it
# can be accessed globally.
if inspect.isclass(obj_cls) and \
issubclass(obj_cls, ManagerMixin): # type: ignore
obj = obj_cls.get_instance(**args) # type: ignore
else:
obj = obj_cls(**args) # type: ignore

if (inspect.isclass(obj_cls) or inspect.isfunction(obj_cls)
or inspect.ismethod(obj_cls)):
print_log(
f'An `{obj_cls.__name__}` instance is built from ' # type: ignore # noqa: E501
'registry, and its implementation can be found in '
f'{obj_cls.__module__}', # type: ignore
logger='current',
level=logging.DEBUG)
else:
print_log(
'An instance is built from registry, and its constructor '
f'is {obj_cls}',
logger='current',
level=logging.DEBUG)
return obj


def build_runner_from_cfg(cfg: Union[dict, ConfigDict, Config],
Expand Down Expand Up @@ -202,23 +193,14 @@ def build_runner_from_cfg(cfg: Union[dict, ConfigDict, Config],
raise TypeError(
f'type must be a str or valid type, but got {type(obj_type)}')

try:
runner = runner_cls.from_cfg(args) # type: ignore
print_log(
f'An `{runner_cls.__name__}` instance is built from ' # type: ignore # noqa: E501
'registry, its implementation can be found in'
f'{runner_cls.__module__}', # type: ignore
logger='current',
level=logging.DEBUG)
return runner

except Exception as e:
# Normal TypeError does not print class name.
cls_location = '/'.join(
runner_cls.__module__.split('.')) # type: ignore
raise type(e)(
f'class `{runner_cls.__name__}` in ' # type: ignore
f'{cls_location}.py: {e}')
runner = runner_cls.from_cfg(args) # type: ignore
print_log(
f'An `{runner_cls.__name__}` instance is built from ' # type: ignore # noqa: E501
'registry, its implementation can be found in'
f'{runner_cls.__module__}', # type: ignore
logger='current',
level=logging.DEBUG)
return runner


def build_model_from_cfg(
Expand Down
5 changes: 4 additions & 1 deletion tests/test_registry/test_build_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,10 @@ def __init__(self, depth, stages=4):
assert model.depth == 50 and model.stages == 4

# non-registered class
with pytest.raises(KeyError, match='VGG is not in the backbone registry'):
with pytest.raises(
KeyError,
match='VGG is not in the test_build_functions::backbone registry',
):
cfg = cfg_type(dict(type='VGG'))
model = build_from_cfg(cfg, BACKBONES)

Expand Down
5 changes: 0 additions & 5 deletions tests/test_registry/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,11 +578,6 @@ def __init__(self, depth, stages=4):
assert isinstance(model, ResNet)
assert model.depth == 50 and model.stages == 4

# non-registered class
with pytest.raises(KeyError, match='VGG is not in the backbone registry'):
cfg = cfg_type(dict(type='VGG'))
model = build_from_cfg(cfg, BACKBONES)

# `cfg` contains unexpected arguments
with pytest.raises(TypeError):
cfg = cfg_type(dict(type='ResNet', non_existing_arg=50))
Expand Down

0 comments on commit 237aee3

Please sign in to comment.