Skip to content

Commit

Permalink
add build_models() / build_evaluator() / build_log_processor()
Browse files Browse the repository at this point in the history
  • Loading branch information
V100 authored and GuoPingPan committed Aug 21, 2023
1 parent a483dba commit 1032279
Show file tree
Hide file tree
Showing 5 changed files with 139 additions and 3 deletions.
6 changes: 5 additions & 1 deletion mmengine/evaluator/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .builder import build_evaluator
from .evaluator import Evaluator
from .metric import BaseMetric, DumpResults
from .utils import get_metric_value

__all__ = ['BaseMetric', 'Evaluator', 'get_metric_value', 'DumpResults']
__all__ = [
'BaseMetric', 'Evaluator', 'get_metric_value', 'DumpResults',
'build_evaluator'
]
53 changes: 53 additions & 0 deletions mmengine/evaluator/builder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Union

from mmengine.registry import EVALUATOR
from .evaluator import Evaluator


def build_evaluator(evaluator: Union[Dict, List, Evaluator]) -> Evaluator:
"""Build evaluator.
Examples of ``evaluator``::
# evaluator could be a built Evaluator instance
evaluator = Evaluator(metrics=[ToyMetric()])
# evaluator can also be a list of dict
evaluator = [
dict(type='ToyMetric1'),
dict(type='ToyEvaluator2')
]
# evaluator can also be a list of built metric
evaluator = [ToyMetric1(), ToyMetric2()]
# evaluator can also be a dict with key metrics
evaluator = dict(metrics=ToyMetric())
# metric is a list
evaluator = dict(metrics=[ToyMetric()])
Args:
evaluator (Evaluator or dict or list): An Evaluator object or a
config dict or list of config dict used to build an Evaluator.
Returns:
Evaluator: Evaluator build from ``evaluator``.
"""
if isinstance(evaluator, Evaluator):
return evaluator
elif isinstance(evaluator, dict):
# if `metrics` in dict keys, it means to build customized evalutor
if 'metrics' in evaluator:
evaluator.setdefault('type', 'Evaluator')
return EVALUATOR.build(evaluator)
# otherwise, default evalutor will be built
else:
return Evaluator(evaluator) # type: ignore
elif isinstance(evaluator, list):
# use the default `Evaluator`
return Evaluator(evaluator) # type: ignore
else:
raise TypeError(
'evaluator should be one of dict, list of dict, and Evaluator'
f', but got {evaluator}')
3 changes: 2 additions & 1 deletion mmengine/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
MomentumAnnealingEMA, StochasticWeightAverage)
from .base_model import BaseDataPreprocessor, BaseModel, ImgDataPreprocessor
from .base_module import BaseModule, ModuleDict, ModuleList, Sequential
from .builder import build_model
from .test_time_aug import BaseTTAModel
from .utils import (convert_sync_batchnorm, detect_anomalous_params,
merge_dict, revert_sync_batchnorm, stack_batch)
Expand All @@ -30,7 +31,7 @@
'bias_init_with_prob', 'BaseInit', 'ConstantInit', 'XavierInit',
'NormalInit', 'TruncNormalInit', 'UniformInit', 'KaimingInit',
'Caffe2XavierInit', 'PretrainedInit', 'initialize',
'convert_sync_batchnorm', 'BaseTTAModel'
'convert_sync_batchnorm', 'BaseTTAModel', 'build_model'
]

if digit_version(TORCH_VERSION) >= digit_version('2.0.0'):
Expand Down
40 changes: 40 additions & 0 deletions mmengine/model/builder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Union

import torch.nn as nn

from mmengine.registry import MODELS


def build_model(model: Union[nn.Module, Dict]) -> nn.Module:
"""Build function of Model.
If ``model`` is a dict, it will be used to build a nn.Module object.
Else, if ``model`` is a nn.Module object it will be returned directly.
An example of ``model``::
model = dict(type='ResNet')
Args:
model (nn.Module or dict): A ``nn.Module`` object or a dict to
build nn.Module object. If ``model`` is a nn.Module object,
just returns itself.
Note:
The returned model must implement ``train_step``, ``test_step``
if ``runner.train`` or ``runner.test`` will be called. If
``runner.val`` will be called or ``val_cfg`` is configured,
model must implement `val_step`.
Returns:
nn.Module: Model build from ``model``.
"""
if isinstance(model, nn.Module):
return model
elif isinstance(model, dict):
model = MODELS.build(model)
return model # type: ignore
else:
raise TypeError('model should be a nn.Module object or dict, '
f'but got {model}')
40 changes: 39 additions & 1 deletion mmengine/runner/log_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import re
from collections import OrderedDict
from itertools import chain
from typing import List, Optional, Tuple
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch
Expand Down Expand Up @@ -575,3 +575,41 @@ def _get_dataloader_size(self, runner, mode) -> int:
int: The dataloader size of current loop.
"""
return len(self._get_cur_loop(runner=runner, mode=mode).dataloader)


def build_log_processor(
log_processor: Union[LogProcessor, Dict]) -> LogProcessor:
"""Build test log_processor.
Examples of ``log_processor``:
# `LogProcessor` will be used
log_processor = dict()
# custom log_processor
log_processor = dict(type='CustomLogProcessor')
Args:
log_processor (LogProcessor or dict): A log processor or a dict
to build log processor. If ``log_processor`` is a log processor
object, just returns itself.
Returns:
:obj:`LogProcessor`: Log processor object build from
``log_processor_cfg``.
"""
if isinstance(log_processor, LogProcessor):
return log_processor
elif not isinstance(log_processor, dict):
raise TypeError(
'log processor should be a LogProcessor object or dict, but'
f'got {log_processor}')

log_processor_cfg = copy.deepcopy(log_processor) # type: ignore

if 'type' in log_processor_cfg:
log_processor = LOG_PROCESSORS.build(log_processor_cfg)
else:
log_processor = LogProcessor(**log_processor_cfg) # type: ignore

return log_processor # type: ignore

0 comments on commit 1032279

Please sign in to comment.