Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add model registry #760

Merged
merged 23 commits into from Apr 10, 2021
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
83 changes: 71 additions & 12 deletions docs/registry.md
Expand Up @@ -12,39 +12,40 @@ One typical example is the config systems in most OpenMMLab projects, which use

To manage your modules in the codebase by `Registry`, there are three steps as below.

1. Create an registry
2. Create a build method
3. Use this registry to manage the modules
1. Create a build method (or use default one).
xvjiarui marked this conversation as resolved.
Show resolved Hide resolved
2. Create a registry.
3. Use this registry to manage the modules.

### A Simple Example

Here we show a simple example of using registry to manage modules in a package.
You can find more practical examples in OpenMMLab projects.

Assuming we want to implement a series of Dataset Converter for converting different formats of data to the expected data format.
We create directory as a package named `converters`.
We create a directory as a package named `converters`.
In the package, we first create a file to implement builders, named `converters/builder.py`, as below
xvjiarui marked this conversation as resolved.
Show resolved Hide resolved

```python
from mmcv.utils import Registry

# create a registry for converters
CONVERTERS = Registry('converter')


# create a build function
def build_converter(cfg, *args, **kwargs):
def build_converter(cfg, registry, *args, **kwargs):
xvjiarui marked this conversation as resolved.
Show resolved Hide resolved
cfg_ = cfg.copy()
converter_type = cfg_.pop('type')
if converter_type not in CONVERTERS:
if converter_type not in registry:
raise KeyError(f'Unrecognized task type {converter_type}')
else:
converter_cls = CONVERTERS.get(converter_type)
converter_cls = registry.get(converter_type)

converter = converter_cls(*args, **kwargs, **cfg_)
return converter

# create a registry for converters
CONVERTERS = Registry('converter', build_func=build_converter)
```

*Note: similar functions like `build_from_cfg` and `build_model_from_cfg` is already implemented, you may directly use them instead of implementing by yourself.*

Then we can implement different converters in the package. For example, implement `Converter1` in `converters/converter1.py`

```python
Expand All @@ -71,5 +72,63 @@ If the module is successfully registered, you can use this converter through con

```python
converter_cfg = dict(type='Converter1', a=a_value, b=b_value)
converter = build_converter(converter_cfg)
converter = CONVERTERS.build(converter_cfg)
```

## Hierarchy Registry
xvjiarui marked this conversation as resolved.
Show resolved Hide resolved

Hierarchy structure is used for similar registries from different packages.
For example, both [MMDetection](https://github.com/open-mmlab/mmdetection) and [MMClassification](https://github.com/open-mmlab/mmclassification) have `MODEL` registry define as followed:
xvjiarui marked this conversation as resolved.
Show resolved Hide resolved
In MMDetection:

```python
from mmcv.utils import Registry
from mmcv.cnn import MODELS as MMCV_MODELS
MODELS = Registry('model', parent=MMCV_MODELS)

@MODELS.register_module()
class NetA(nn.Module):
def forward(self, x):
return x
```

In MMClassification:

```python
from mmcv.utils import Registry
from mmcv.cnn import MODELS as MMCV_MODELS
MODELS = Registry('model', parent=MMCV_MODELS)

@MODELS.register_module()
class NetB(nn.Module):
def forward(self, x):
return x + 1
```

We could build two net in either MMDetection or MMClassification by:

```python
import mmcls # import mmcls to register models
xvjiarui marked this conversation as resolved.
Show resolved Hide resolved
from mmdet.models import MODELS
net_a = MODELS.build(cfg=dict(type='NetA'))
net_b = MODELS.build(cfg=dict(type='mmcls.NetB'))
```

or

```python
import mmdet # import mmcls to register models
from mmcls.models import MODELS
net_a = MODELS.build(cfg=dict(type='mmdet.NetA'))
net_b = MODELS.build(cfg=dict(type='NetB'))
```

Build them by shared `MODELS` registry in MMCV is also feasible:

```python
from mmcv.cnn import MODELS as MMCV_MODELS
net_a = MMCV_MODELS.build(cfg=dict(type='mmdet.NetA'))
ZwwWayne marked this conversation as resolved.
Show resolved Hide resolved
net_b = MMCV_MODELS.build(cfg=dict(type='mmcls.NetB'))
```

Note that: we may avoid unused import by adding `custom_imports = dict(imports=['mmdet', 'mmcls'])` in config file.
4 changes: 3 additions & 1 deletion mmcv/cnn/__init__.py
Expand Up @@ -11,6 +11,7 @@
build_activation_layer, build_conv_layer,
build_norm_layer, build_padding_layer, build_plugin_layer,
build_upsample_layer, conv_ws_2d, is_norm)
from .builder import MODELS, build_model_from_cfg
# yapf: enable
from .resnet import ResNet, make_res_layer
from .utils import (INITIALIZERS, ConstantInit, KaimingInit, NormalInit,
Expand All @@ -33,5 +34,6 @@
'ConvAWS2d', 'ConvWS2d', 'fuse_conv_bn', 'DepthwiseSeparableConvModule',
'Linear', 'Conv2d', 'ConvTranspose2d', 'MaxPool2d', 'ConvTranspose3d',
'MaxPool3d', 'Conv3d', 'initialize', 'INITIALIZERS', 'ConstantInit',
'XavierInit', 'NormalInit', 'UniformInit', 'KaimingInit', 'PretrainedInit'
'XavierInit', 'NormalInit', 'UniformInit', 'KaimingInit', 'PretrainedInit',
'MODELS', 'build_model_from_cfg'
]
29 changes: 29 additions & 0 deletions mmcv/cnn/builder.py
@@ -0,0 +1,29 @@
import torch.nn as nn

from ..utils import Registry, build_from_cfg


def build_model_from_cfg(cfg, registry, default_args=None):
"""Build a module from config dict(s).
xvjiarui marked this conversation as resolved.
Show resolved Hide resolved

Args:
cfg (dict, list[dict]): The config of modules, is is either a config
dict or a list of config dicts. If cfg is a list, a
the built modules will be wrapped with ``nn.Sequential``.
registry (:obj:`Registry`): A registry the module belongs to.
default_args (dict, optional): Default arguments to build the module.
Defaults to None.

Returns:
nn.Module: A built nn module.
"""
if isinstance(cfg, list):
modules = [
build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
]
return nn.Sequential(*modules)
else:
return build_from_cfg(cfg, registry, default_args)


MODELS = Registry('model', build_func=build_model_from_cfg)