diff --git a/mmseg/models/builder.py b/mmseg/models/builder.py index c487dcdd32..9b68ff888c 100644 --- a/mmseg/models/builder.py +++ b/mmseg/models/builder.py @@ -1,56 +1,35 @@ import warnings -from mmcv.utils import Registry, build_from_cfg -from torch import nn +from mmcv.cnn import MODELS as MMCV_MODELS +from mmcv.utils import Registry -BACKBONES = Registry('backbone') -NECKS = Registry('neck') -HEADS = Registry('head') -LOSSES = Registry('loss') -SEGMENTORS = Registry('segmentor') +MODELS = Registry('models', parent=MMCV_MODELS) - -def build(cfg, registry, default_args=None): - """Build a module. - - Args: - cfg (dict, list[dict]): The config of modules, is is either a dict - or a list of configs. - registry (:obj:`Registry`): A registry the module belongs to. - default_args (dict, optional): Default arguments to build the module. - Defaults to None. - - Returns: - nn.Module: A built nn module. - """ - - if isinstance(cfg, list): - modules = [ - build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg - ] - return nn.Sequential(*modules) - else: - return build_from_cfg(cfg, registry, default_args) +BACKBONES = MODELS +NECKS = MODELS +HEADS = MODELS +LOSSES = MODELS +SEGMENTORS = MODELS def build_backbone(cfg): """Build backbone.""" - return build(cfg, BACKBONES) + return BACKBONES.build(cfg) def build_neck(cfg): """Build neck.""" - return build(cfg, NECKS) + return NECKS.build(cfg) def build_head(cfg): """Build head.""" - return build(cfg, HEADS) + return HEADS.build(cfg) def build_loss(cfg): """Build loss.""" - return build(cfg, LOSSES) + return LOSSES.build(cfg) def build_segmentor(cfg, train_cfg=None, test_cfg=None): @@ -63,4 +42,5 @@ def build_segmentor(cfg, train_cfg=None, test_cfg=None): 'train_cfg specified in both outer field and model field ' assert cfg.get('test_cfg') is None or test_cfg is None, \ 'test_cfg specified in both outer field and model field ' - return build(cfg, SEGMENTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg)) + return SEGMENTORS.build( + cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg))