From df41a74f5fa29cf1914d061898cac8aaa6a7b12f Mon Sep 17 00:00:00 2001
From: V100
Date: Wed, 16 Aug 2023 18:20:53 +0800
Subject: [PATCH] add build_model() / build_evaluator() / build_visualizer() /
 build_log_processor()

---
 mmengine/evaluator/__init__.py     |  6 ++-
 mmengine/evaluator/builder.py      | 53 +++++++++++++++++++
 mmengine/model/__init__.py         |  3 +-
 mmengine/model/builder.py          | 40 ++++++++++++++
 mmengine/runner/log_processor.py   | 40 +++++++++++++-
 mmengine/runner/runner.py          | 84 ++++++------------------------
 mmengine/visualization/__init__.py |  4 +-
 mmengine/visualization/builder.py  | 44 ++++++++++++++++
 8 files changed, 202 insertions(+), 72 deletions(-)
 create mode 100644 mmengine/evaluator/builder.py
 create mode 100644 mmengine/model/builder.py
 create mode 100644 mmengine/visualization/builder.py

diff --git a/mmengine/evaluator/__init__.py b/mmengine/evaluator/__init__.py
index e6bc78425e..00a93be3f1 100644
--- a/mmengine/evaluator/__init__.py
+++ b/mmengine/evaluator/__init__.py
@@ -1,6 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from .builder import build_evaluator
 from .evaluator import Evaluator
 from .metric import BaseMetric, DumpResults
 from .utils import get_metric_value
 
-__all__ = ['BaseMetric', 'Evaluator', 'get_metric_value', 'DumpResults']
+__all__ = [
+    'BaseMetric', 'Evaluator', 'get_metric_value', 'DumpResults',
+    'build_evaluator'
+]
diff --git a/mmengine/evaluator/builder.py b/mmengine/evaluator/builder.py
new file mode 100644
index 0000000000..22e50e2a3a
--- /dev/null
+++ b/mmengine/evaluator/builder.py
@@ -0,0 +1,53 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, List, Union
+
+from mmengine.registry import EVALUATOR
+from .evaluator import Evaluator
+
+
+def build_evaluator(evaluator: Union[Dict, List, Evaluator]) -> Evaluator:
+    """Build evaluator.
+
+    Examples of ``evaluator``::
+
+        # evaluator could be a built Evaluator instance
+        evaluator = Evaluator(metrics=[ToyMetric()])
+
+        # evaluator can also be a list of config dicts
+        evaluator = [
+            dict(type='ToyMetric1'),
+            dict(type='ToyMetric2')
+        ]
+
+        # evaluator can also be a list of built metrics
+        evaluator = [ToyMetric1(), ToyMetric2()]
+
+        # evaluator can also be a dict with the key 'metrics'
+        evaluator = dict(metrics=ToyMetric())
+        # 'metrics' can also be a list
+        evaluator = dict(metrics=[ToyMetric()])
+
+    Args:
+        evaluator (Evaluator or dict or list): An Evaluator object, a
+            config dict or a list of config dicts used to build an Evaluator.
+
+    Returns:
+        Evaluator: Evaluator built from ``evaluator``.
+ """ + if isinstance(evaluator, Evaluator): + return evaluator + elif isinstance(evaluator, dict): + # if `metrics` in dict keys, it means to build customized evalutor + if 'metrics' in evaluator: + evaluator.setdefault('type', 'Evaluator') + return EVALUATOR.build(evaluator) + # otherwise, default evalutor will be built + else: + return Evaluator(evaluator) # type: ignore + elif isinstance(evaluator, list): + # use the default `Evaluator` + return Evaluator(evaluator) # type: ignore + else: + raise TypeError( + 'evaluator should be one of dict, list of dict, and Evaluator' + f', but got {evaluator}') diff --git a/mmengine/model/__init__.py b/mmengine/model/__init__.py index 033512a985..b8c9aaa80d 100644 --- a/mmengine/model/__init__.py +++ b/mmengine/model/__init__.py @@ -5,6 +5,7 @@ MomentumAnnealingEMA, StochasticWeightAverage) from .base_model import BaseDataPreprocessor, BaseModel, ImgDataPreprocessor from .base_module import BaseModule, ModuleDict, ModuleList, Sequential +from .builder import build_model from .test_time_aug import BaseTTAModel from .utils import (convert_sync_batchnorm, detect_anomalous_params, merge_dict, revert_sync_batchnorm, stack_batch) @@ -30,7 +31,7 @@ 'bias_init_with_prob', 'BaseInit', 'ConstantInit', 'XavierInit', 'NormalInit', 'TruncNormalInit', 'UniformInit', 'KaimingInit', 'Caffe2XavierInit', 'PretrainedInit', 'initialize', - 'convert_sync_batchnorm', 'BaseTTAModel' + 'convert_sync_batchnorm', 'BaseTTAModel', 'build_model' ] if digit_version(TORCH_VERSION) >= digit_version('2.0.0'): diff --git a/mmengine/model/builder.py b/mmengine/model/builder.py new file mode 100644 index 0000000000..eeb638fedd --- /dev/null +++ b/mmengine/model/builder.py @@ -0,0 +1,40 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, Union + +import torch.nn as nn + +from mmengine.registry import MODELS + + +def build_model(model: Union[nn.Module, Dict]) -> nn.Module: + """Build function of Model. + + If ``model`` is a dict, it will be used to build a nn.Module object. + Else, if ``model`` is a nn.Module object it will be returned directly. + + An example of ``model``:: + + model = dict(type='ResNet') + + Args: + model (nn.Module or dict): A ``nn.Module`` object or a dict to + build nn.Module object. If ``model`` is a nn.Module object, + just returns itself. + + Note: + The returned model must implement ``train_step``, ``test_step`` + if ``runner.train`` or ``runner.test`` will be called. If + ``runner.val`` will be called or ``val_cfg`` is configured, + model must implement `val_step`. + + Returns: + nn.Module: Model build from ``model``. + """ + if isinstance(model, nn.Module): + return model + elif isinstance(model, dict): + model = MODELS.build(model) + return model # type: ignore + else: + raise TypeError('model should be a nn.Module object or dict, ' + f'but got {model}') diff --git a/mmengine/runner/log_processor.py b/mmengine/runner/log_processor.py index d3f9d95714..d001ac646f 100644 --- a/mmengine/runner/log_processor.py +++ b/mmengine/runner/log_processor.py @@ -4,7 +4,7 @@ import re from collections import OrderedDict from itertools import chain -from typing import List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Union import numpy as np import torch @@ -575,3 +575,41 @@ def _get_dataloader_size(self, runner, mode) -> int: int: The dataloader size of current loop. 
""" return len(self._get_cur_loop(runner=runner, mode=mode).dataloader) + + +def build_log_processor( + log_processor: Union[LogProcessor, Dict]) -> LogProcessor: + """Build test log_processor. + + Examples of ``log_processor``: + + # `LogProcessor` will be used + log_processor = dict() + + # custom log_processor + log_processor = dict(type='CustomLogProcessor') + + Args: + log_processor (LogProcessor or dict): A log processor or a dict + to build log processor. If ``log_processor`` is a log processor + object, just returns itself. + + Returns: + :obj:`LogProcessor`: Log processor object build from + ``log_processor_cfg``. + """ + if isinstance(log_processor, LogProcessor): + return log_processor + elif not isinstance(log_processor, dict): + raise TypeError( + 'log processor should be a LogProcessor object or dict, but' + f'got {log_processor}') + + log_processor_cfg = copy.deepcopy(log_processor) # type: ignore + + if 'type' in log_processor_cfg: + log_processor = LOG_PROCESSORS.build(log_processor_cfg) + else: + log_processor = LogProcessor(**log_processor_cfg) # type: ignore + + return log_processor # type: ignore diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py index 12830cf4ad..64cff79166 100644 --- a/mmengine/runner/runner.py +++ b/mmengine/runner/runner.py @@ -23,29 +23,29 @@ from mmengine.device import get_device from mmengine.dist import (broadcast, get_dist_info, get_rank, init_dist, is_distributed, master_only) -from mmengine.evaluator import Evaluator +from mmengine.evaluator import Evaluator, build_evaluator from mmengine.fileio import FileClient, join_path from mmengine.hooks import Hook from mmengine.logging import MessageHub, MMLogger, print_log -from mmengine.model import (MMDistributedDataParallel, convert_sync_batchnorm, - is_model_wrapper, revert_sync_batchnorm) +from mmengine.model import (MMDistributedDataParallel, build_model, + convert_sync_batchnorm, is_model_wrapper, + revert_sync_batchnorm) from mmengine.model.efficient_conv_bn_eval import \ turn_on_efficient_conv_bn_eval from mmengine.optim import (OptimWrapper, OptimWrapperDict, _ParamScheduler, build_optim_wrapper) -from mmengine.registry import (DATA_SAMPLERS, DATASETS, EVALUATOR, FUNCTIONS, - HOOKS, LOG_PROCESSORS, LOOPS, MODEL_WRAPPERS, - MODELS, OPTIM_WRAPPERS, PARAM_SCHEDULERS, - RUNNERS, VISUALIZERS, DefaultScope) +from mmengine.registry import (DATA_SAMPLERS, DATASETS, FUNCTIONS, HOOKS, + LOOPS, MODEL_WRAPPERS, OPTIM_WRAPPERS, + PARAM_SCHEDULERS, RUNNERS, DefaultScope) from mmengine.utils import apply_to, digit_version, get_git_hash, is_seq_of from mmengine.utils.dl_utils import (TORCH_VERSION, collect_env, set_multi_processing) -from mmengine.visualization import Visualizer +from mmengine.visualization import Visualizer, build_visualizer from .base_loop import BaseLoop from .checkpoint import (_load_checkpoint, _load_checkpoint_to_model, find_latest_checkpoint, save_checkpoint, weights_to_cpu) -from .log_processor import LogProcessor +from .log_processor import LogProcessor, build_log_processor from .loops import EpochBasedTrainLoop, IterBasedTrainLoop, TestLoop, ValLoop from .priority import Priority, get_priority from .utils import set_random_seed @@ -769,25 +769,10 @@ def build_visualizer( Returns: Visualizer: A Visualizer object build from ``visualizer``. 
""" - if visualizer is None: - visualizer = dict( - name=self._experiment_name, - vis_backends=[dict(type='LocalVisBackend')], - save_dir=self._log_dir) - return Visualizer.get_instance(**visualizer) - - if isinstance(visualizer, Visualizer): - return visualizer - - if isinstance(visualizer, dict): - # ensure visualizer containing name key - visualizer.setdefault('name', self._experiment_name) - visualizer.setdefault('save_dir', self._log_dir) - return VISUALIZERS.build(visualizer) - else: - raise TypeError( - 'visualizer should be Visualizer object, a dict or None, ' - f'but got {visualizer}') + return build_visualizer( + visualizer=visualizer, + experiment_name=self._experiment_name, + log_dir=self._log_dir) def build_model(self, model: Union[nn.Module, Dict]) -> nn.Module: """Build model. @@ -813,14 +798,7 @@ def build_model(self, model: Union[nn.Module, Dict]) -> nn.Module: Returns: nn.Module: Model build from ``model``. """ - if isinstance(model, nn.Module): - return model - elif isinstance(model, dict): - model = MODELS.build(model) - return model # type: ignore - else: - raise TypeError('model should be a nn.Module object or dict, ' - f'but got {model}') + return build_model(model=model) def wrap_model( self, model_wrapper_cfg: Optional[Dict], @@ -1289,23 +1267,7 @@ def build_evaluator(self, evaluator: Union[Dict, List, Returns: Evaluator: Evaluator build from ``evaluator``. """ - if isinstance(evaluator, Evaluator): - return evaluator - elif isinstance(evaluator, dict): - # if `metrics` in dict keys, it means to build customized evalutor - if 'metrics' in evaluator: - evaluator.setdefault('type', 'Evaluator') - return EVALUATOR.build(evaluator) - # otherwise, default evalutor will be built - else: - return Evaluator(evaluator) # type: ignore - elif isinstance(evaluator, list): - # use the default `Evaluator` - return Evaluator(evaluator) # type: ignore - else: - raise TypeError( - 'evaluator should be one of dict, list of dict, and Evaluator' - f', but got {evaluator}') + return build_evaluator(evaluator=evaluator) @staticmethod def build_dataloader(dataloader: Union[DataLoader, Dict], @@ -1612,21 +1574,7 @@ def build_log_processor( :obj:`LogProcessor`: Log processor object build from ``log_processor_cfg``. """ - if isinstance(log_processor, LogProcessor): - return log_processor - elif not isinstance(log_processor, dict): - raise TypeError( - 'log processor should be a LogProcessor object or dict, but' - f'got {log_processor}') - - log_processor_cfg = copy.deepcopy(log_processor) # type: ignore - - if 'type' in log_processor_cfg: - log_processor = LOG_PROCESSORS.build(log_processor_cfg) - else: - log_processor = LogProcessor(**log_processor_cfg) # type: ignore - - return log_processor # type: ignore + return build_log_processor(log_processor=log_processor) def get_hooks_info(self) -> str: # Get hooks info in each stage diff --git a/mmengine/visualization/__init__.py b/mmengine/visualization/__init__.py index a0a518e675..db79022f87 100644 --- a/mmengine/visualization/__init__.py +++ b/mmengine/visualization/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. 
+from .builder import build_visualizer
 from .vis_backend import (BaseVisBackend, ClearMLVisBackend, LocalVisBackend,
                           MLflowVisBackend, TensorboardVisBackend,
                           WandbVisBackend)
@@ -6,5 +7,6 @@
 
 __all__ = [
     'Visualizer', 'BaseVisBackend', 'LocalVisBackend', 'WandbVisBackend',
-    'TensorboardVisBackend', 'MLflowVisBackend', 'ClearMLVisBackend'
+    'TensorboardVisBackend', 'MLflowVisBackend', 'ClearMLVisBackend',
+    'build_visualizer'
 ]
diff --git a/mmengine/visualization/builder.py b/mmengine/visualization/builder.py
new file mode 100644
index 0000000000..00dc73dc76
--- /dev/null
+++ b/mmengine/visualization/builder.py
@@ -0,0 +1,44 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from datetime import datetime
+from typing import Dict, Optional, Union
+
+from mmengine.registry import VISUALIZERS
+from .visualizer import Visualizer
+
+def build_visualizer(
+        visualizer: Optional[Union[Visualizer,
+                                   Dict]] = None,
+        experiment_name: str = datetime.now().strftime(r'%Y%m%d_%H%M%S'),
+        log_dir: str = 'work_dirs'
+    ) -> Visualizer:
+    """Build a globally accessible Visualizer.
+
+    Args:
+        visualizer (Visualizer or dict, optional): A Visualizer object
+            or a dict to build a Visualizer object. If ``visualizer`` is a
+            Visualizer object, it is returned directly. If not specified,
+            the default config will be used to build a Visualizer object.
+            Defaults to None.
+
+    Returns:
+        Visualizer: A Visualizer object built from ``visualizer``.
+    """
+    if visualizer is None:
+        visualizer = dict(
+            name=experiment_name,
+            vis_backends=[dict(type='LocalVisBackend')],
+            save_dir=log_dir)
+        return Visualizer.get_instance(**visualizer)
+
+    if isinstance(visualizer, Visualizer):
+        return visualizer
+
+    if isinstance(visualizer, dict):
+        # ensure the visualizer config contains a name key
+        visualizer.setdefault('name', experiment_name)
+        visualizer.setdefault('save_dir', log_dir)
+        return VISUALIZERS.build(visualizer)
+    else:
+        raise TypeError(
+            'visualizer should be a Visualizer object, a dict or None, '
+            f'but got {visualizer}')
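
Usage sketch (illustrative, not part of the diff): how the extracted builders
could be called on their own once this patch is applied. The 'ResNet' and
'ToyMetric' type names below are hypothetical placeholders that would have to
be registered in the MODELS / METRICS registries; the imports and call
signatures follow the functions added above.

    from mmengine.evaluator import build_evaluator
    from mmengine.model import build_model
    from mmengine.runner.log_processor import build_log_processor
    from mmengine.visualization import build_visualizer

    # A config dict is built through the MODELS registry; an existing
    # nn.Module would be returned unchanged. 'ResNet' is a placeholder type.
    model = build_model(dict(type='ResNet'))

    # A list of metric configs is wrapped in the default Evaluator.
    # 'ToyMetric' is a placeholder metric type.
    evaluator = build_evaluator([dict(type='ToyMetric')])

    # A plain keyword dict falls back to the default LogProcessor.
    log_processor = build_log_processor(dict(window_size=20))

    # None builds the default visualizer backed by a LocalVisBackend.
    visualizer = build_visualizer(
        None, experiment_name='demo', log_dir='work_dirs/demo')

The standalone functions keep the branching of the former Runner methods, so
Runner.build_model(), Runner.build_evaluator(), Runner.build_visualizer() and
Runner.build_log_processor() behave as before and simply delegate to them.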