Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add model registry #760

Merged
merged 23 commits into from Apr 10, 2021
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
63 changes: 51 additions & 12 deletions docs/registry.md
Expand Up @@ -12,39 +12,40 @@ One typical example is the config systems in most OpenMMLab projects, which use

To manage your modules in the codebase by `Registry`, there are three steps as below.

1. Create an registry
2. Create a build method
3. Use this registry to manage the modules
1. Create a build method (or use default one).
xvjiarui marked this conversation as resolved.
Show resolved Hide resolved
2. Create a registry.
3. Use this registry to manage the modules.

### A Simple Example

Here we show a simple example of using registry to manage modules in a package.
You can find more practical examples in OpenMMLab projects.

Assuming we want to implement a series of Dataset Converter for converting different formats of data to the expected data format.
We create directory as a package named `converters`.
We create a directory as a package named `converters`.
In the package, we first create a file to implement builders, named `converters/builder.py`, as below
xvjiarui marked this conversation as resolved.
Show resolved Hide resolved

```python
from mmcv.utils import Registry

# create a registry for converters
CONVERTERS = Registry('converter')


# create a build function
def build_converter(cfg, *args, **kwargs):
def build_converter(cfg, registry, *args, **kwargs):
xvjiarui marked this conversation as resolved.
Show resolved Hide resolved
cfg_ = cfg.copy()
converter_type = cfg_.pop('type')
if converter_type not in CONVERTERS:
if converter_type not in registry:
raise KeyError(f'Unrecognized task type {converter_type}')
else:
converter_cls = CONVERTERS.get(converter_type)
converter_cls = registry.get(converter_type)

converter = converter_cls(*args, **kwargs, **cfg_)
return converter

# create a registry for converters
CONVERTERS = Registry('converter', build_func=build_converter)
```

*Note: similar functions like `build_from_cfg` and `build_model_from_cfg` is already implemented, you may directly use them instead of implementing by yourself.*

Then we can implement different converters in the package. For example, implement `Converter1` in `converters/converter1.py`

```python
Expand All @@ -66,5 +67,43 @@ If the module is successfully registered, you can use this converter through con

```python
converter_cfg = dict(type='Converter1', a=a_value, b=b_value)
converter = build_converter(converter_cfg)
converter = CONVERTERS.build(converter_cfg)
```

## Hierarchy Registry
xvjiarui marked this conversation as resolved.
Show resolved Hide resolved

Hierarchy structure is used for similar registries from different packages.
For example, both [MMDetection](https://github.com/open-mmlab/mmdetection) and [MMClassification](https://github.com/open-mmlab/mmclassification) have `MODEL` registry define as followed:
xvjiarui marked this conversation as resolved.
Show resolved Hide resolved
In MMDetection:

```python
from mmcv.utils import Registry
from mmcv.cnn import MODELS as MMCV_MODELS
MODELS = Registry('model', parent=MMCV_MODELS)

@MODELS.register_module()
class NetA(nn.Module):
def forward(self, x):
return x
```

In MMClassification:

```python
from mmcv.utils import Registry
from mmcv.cnn import MODELS as MMCV_MODELS
MODELS = Registry('model', parent=MMCV_MODELS)

@MODELS.register_module()
class NetB(nn.Module):
def forward(self, x):
return x + 1
```

We could build either `NetA` or `NetB` by:

```python
from mmcv.cnn import MODELS as MMCV_MODELS
net_a = MMCV_MODELS.build(cfg=dict(type='mmdet.NetA'))
ZwwWayne marked this conversation as resolved.
Show resolved Hide resolved
net_b = MMCV_MODELS.build(cfg=dict(type='mmcls.NetB'))
```
3 changes: 2 additions & 1 deletion mmcv/cnn/__init__.py
Expand Up @@ -11,6 +11,7 @@
build_activation_layer, build_conv_layer,
build_norm_layer, build_padding_layer, build_plugin_layer,
build_upsample_layer, conv_ws_2d, is_norm)
from .builder import MODELS, build_model_from_cfg
# yapf: enable
from .resnet import ResNet, make_res_layer
from .utils import (bias_init_with_prob, caffe2_xavier_init, constant_init,
Expand All @@ -30,5 +31,5 @@
'PLUGIN_LAYERS', 'Scale', 'get_model_complexity_info', 'conv_ws_2d',
'ConvAWS2d', 'ConvWS2d', 'fuse_conv_bn', 'DepthwiseSeparableConvModule',
'Linear', 'Conv2d', 'ConvTranspose2d', 'MaxPool2d', 'ConvTranspose3d',
'MaxPool3d', 'Conv3d'
'MaxPool3d', 'Conv3d', 'MODELS', 'build_model_from_cfg'
]
28 changes: 28 additions & 0 deletions mmcv/cnn/builder.py
@@ -0,0 +1,28 @@
import torch.nn as nn

from ..utils import Registry, build_from_cfg


def build_model_from_cfg(cfg, registry, default_args=None):
"""Build a module.
xvjiarui marked this conversation as resolved.
Show resolved Hide resolved

Args:
cfg (dict, list[dict]): The config of modules, is is either a dict
or a list of configs.
xvjiarui marked this conversation as resolved.
Show resolved Hide resolved
registry (:obj:`Registry`): A registry the module belongs to.
default_args (dict, optional): Default arguments to build the module.
Defaults to None.

Returns:
nn.Module: A built nn module.
"""
if isinstance(cfg, list):
modules = [
build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
]
return nn.Sequential(*modules)
else:
return build_from_cfg(cfg, registry, default_args)


MODELS = Registry('model', build_func=build_model_from_cfg)
160 changes: 112 additions & 48 deletions mmcv/utils/registry.py
Expand Up @@ -5,16 +5,84 @@
from .misc import is_str


def build_from_cfg(cfg, registry, default_args=None):
"""Build a module from config dict.

Args:
cfg (dict): Config dict. It should at least contain the key "type".
registry (:obj:`Registry`): The registry to search the type from.
default_args (dict, optional): Default initialization arguments.

Returns:
object: The constructed object.
"""
if not isinstance(cfg, dict):
raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
if 'type' not in cfg:
if default_args is None or 'type' not in default_args:
raise KeyError(
'`cfg` or `default_args` must contain the key "type", '
f'but got {cfg}\n{default_args}')
if not isinstance(registry, Registry):
raise TypeError('registry must be an mmcv.Registry object, '
f'but got {type(registry)}')
if not (isinstance(default_args, dict) or default_args is None):
raise TypeError('default_args must be a dict or None, '
f'but got {type(default_args)}')

args = cfg.copy()

if default_args is not None:
for name, value in default_args.items():
args.setdefault(name, value)

obj_type = args.pop('type')
if is_str(obj_type):
obj_cls = registry.get(obj_type)
if obj_cls is None:
raise KeyError(
f'{obj_type} is not in the {registry.name} registry')
elif inspect.isclass(obj_type):
obj_cls = obj_type
else:
raise TypeError(
f'type must be a str or valid type, but got {type(obj_type)}')

return obj_cls(**args)


class Registry:
"""A registry to map strings to classes.
xvjiarui marked this conversation as resolved.
Show resolved Hide resolved

Args:
name (str): Registry name.
build_func(func, optional): Build function to construct instance from
Registry, ``func:build_from_cfg`` is used if neither ``parent`` or
xvjiarui marked this conversation as resolved.
Show resolved Hide resolved
``build_func`` is specified. If ``parent`` is specified and
``build_func`` is not given, ``build_func`` will be inherited
from ``parent``. Default: None.
parent (Registry, optional): Parent registry. The class registered in
children registry could be built from parent. Default: None.
scope (str, optional): The scope of registry. If not specified, scope
xvjiarui marked this conversation as resolved.
Show resolved Hide resolved
will be the name of the package where class is defined.
Default: None.
"""

def __init__(self, name):
def __init__(self, name, build_func=None, parent=None, scope=None):
self._name = name
self._module_dict = dict()
self._children = dict()
self._scope = self.infer_scope() if scope is None else scope
if parent is not None:
assert isinstance(parent, Registry)
parent._add_children(self)
xvjiarui marked this conversation as resolved.
Show resolved Hide resolved
if build_func is None:
xvjiarui marked this conversation as resolved.
Show resolved Hide resolved
if parent is not None:
self.build_func = parent.build_func
else:
self.build_func = build_from_cfg
else:
self.build_func = build_func

def __len__(self):
return len(self._module_dict)
Expand All @@ -28,14 +96,36 @@ def __repr__(self):
f'items={self._module_dict})'
return format_str

@staticmethod
def infer_scope():
xvjiarui marked this conversation as resolved.
Show resolved Hide resolved
filename = inspect.getmodule(inspect.stack()[2][0]).__name__
xvjiarui marked this conversation as resolved.
Show resolved Hide resolved
split_filename = filename.split('.')
return split_filename[0]

@staticmethod
def split_scope_key(key):
split_index = key.find('.')
hellock marked this conversation as resolved.
Show resolved Hide resolved
if split_index != -1:
return key[:split_index], key[split_index + 1:]
else:
return None, key

@property
def name(self):
return self._name

@property
def scope(self):
return self._scope

@property
def module_dict(self):
return self._module_dict

@property
def children(self):
return self._children

def get(self, key):
"""Get the registry record.

Expand All @@ -45,7 +135,27 @@ def get(self, key):
Returns:
class: The corresponding class.
"""
return self._module_dict.get(key, None)
scope, real_key = self.split_scope_key(key)
if scope is not None:
return self._children[scope].get(real_key)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this satisfy the desired use case?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep. The object could be built with inheritance in a recursive way.

else:
# get from self
if real_key in self._module_dict:
return self._module_dict[real_key]
else:
# get from children
for registry in self._children.values():
result = registry.get(real_key)
if result is not None:
return result

def build(self, *args, **kwargs):
return self.build_func(*args, **kwargs, registry=self)

def _add_children(self, registry):
xvjiarui marked this conversation as resolved.
Show resolved Hide resolved
assert isinstance(registry, Registry)
assert registry.scope is not None
self.children[registry.scope] = registry

def _register_module(self, module_class, module_name=None, force=False):
if not inspect.isclass(module_class):
Expand Down Expand Up @@ -123,49 +233,3 @@ def _register(cls):
return cls

return _register


def build_from_cfg(cfg, registry, default_args=None):
"""Build a module from config dict.

Args:
cfg (dict): Config dict. It should at least contain the key "type".
registry (:obj:`Registry`): The registry to search the type from.
default_args (dict, optional): Default initialization arguments.

Returns:
object: The constructed object.
"""
if not isinstance(cfg, dict):
raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
if 'type' not in cfg:
if default_args is None or 'type' not in default_args:
raise KeyError(
'`cfg` or `default_args` must contain the key "type", '
f'but got {cfg}\n{default_args}')
if not isinstance(registry, Registry):
raise TypeError('registry must be an mmcv.Registry object, '
f'but got {type(registry)}')
if not (isinstance(default_args, dict) or default_args is None):
raise TypeError('default_args must be a dict or None, '
f'but got {type(default_args)}')

args = cfg.copy()

if default_args is not None:
for name, value in default_args.items():
args.setdefault(name, value)

obj_type = args.pop('type')
if is_str(obj_type):
obj_cls = registry.get(obj_type)
if obj_cls is None:
raise KeyError(
f'{obj_type} is not in the {registry.name} registry')
elif inspect.isclass(obj_type):
obj_cls = obj_type
else:
raise TypeError(
f'type must be a str or valid type, but got {type(obj_type)}')

return obj_cls(**args)