add build_models() / build_evaluator() / build_log_processor() #1310

Open
wants to merge 1 commit into base: main

6 changes: 5 additions & 1 deletion mmengine/evaluator/__init__.py
@@ -1,6 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .builder import build_evaluator
from .evaluator import Evaluator
from .metric import BaseMetric, DumpResults
from .utils import get_metric_value

__all__ = ['BaseMetric', 'Evaluator', 'get_metric_value', 'DumpResults']
__all__ = [
'BaseMetric', 'Evaluator', 'get_metric_value', 'DumpResults',
'build_evaluator'
]
53 changes: 53 additions & 0 deletions mmengine/evaluator/builder.py
@@ -0,0 +1,53 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Union

from mmengine.registry import EVALUATOR
from .evaluator import Evaluator


def build_evaluator(evaluator: Union[Dict, List, Evaluator]) -> Evaluator:
    """Build an evaluator.

    Examples of ``evaluator``::

        # evaluator could be a built Evaluator instance
        evaluator = Evaluator(metrics=[ToyMetric()])

        # evaluator can also be a list of dicts
        evaluator = [
            dict(type='ToyMetric1'),
            dict(type='ToyMetric2')
        ]

        # evaluator can also be a list of built metrics
        evaluator = [ToyMetric1(), ToyMetric2()]

        # evaluator can also be a dict with the key ``metrics``
        evaluator = dict(metrics=ToyMetric())

        # ``metrics`` can also be a list
        evaluator = dict(metrics=[ToyMetric()])

    Args:
        evaluator (Evaluator or dict or list): An Evaluator object, or a
            config dict or list of config dicts used to build an Evaluator.

    Returns:
        Evaluator: Evaluator built from ``evaluator``.
    """
    if isinstance(evaluator, Evaluator):
        return evaluator
    elif isinstance(evaluator, dict):
        # if `metrics` is among the dict keys, a customized evaluator is built
        if 'metrics' in evaluator:
            evaluator.setdefault('type', 'Evaluator')
            return EVALUATOR.build(evaluator)
        # otherwise, the default evaluator is built
        else:
            return Evaluator(evaluator)  # type: ignore
    elif isinstance(evaluator, list):
        # use the default `Evaluator`
        return Evaluator(evaluator)  # type: ignore
    else:
        raise TypeError(
            'evaluator should be a dict, a list of dicts, or an Evaluator, '
            f'but got {evaluator}')
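
A minimal usage sketch of the new standalone builder (illustrative only, not part of the diff; it assumes a metric class registered under the hypothetical name 'ToyMetric' in the METRICS registry):

    from mmengine.evaluator import Evaluator, build_evaluator

    # a list of metric configs is wrapped in the default Evaluator
    evaluator = build_evaluator([dict(type='ToyMetric')])

    # a dict containing a `metrics` key builds a (possibly customized) Evaluator
    evaluator = build_evaluator(
        dict(type='Evaluator', metrics=[dict(type='ToyMetric')]))

    # an already-built Evaluator instance is returned unchanged
    assert build_evaluator(evaluator) is evaluator
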
3 changes: 2 additions & 1 deletion mmengine/model/__init__.py
@@ -5,6 +5,7 @@
MomentumAnnealingEMA, StochasticWeightAverage)
from .base_model import BaseDataPreprocessor, BaseModel, ImgDataPreprocessor
from .base_module import BaseModule, ModuleDict, ModuleList, Sequential
from .builder import build_model
from .test_time_aug import BaseTTAModel
from .utils import (convert_sync_batchnorm, detect_anomalous_params,
merge_dict, revert_sync_batchnorm, stack_batch)
@@ -30,7 +31,7 @@
'bias_init_with_prob', 'BaseInit', 'ConstantInit', 'XavierInit',
'NormalInit', 'TruncNormalInit', 'UniformInit', 'KaimingInit',
'Caffe2XavierInit', 'PretrainedInit', 'initialize',
'convert_sync_batchnorm', 'BaseTTAModel'
'convert_sync_batchnorm', 'BaseTTAModel', 'build_model'
]

if digit_version(TORCH_VERSION) >= digit_version('2.0.0'):
40 changes: 40 additions & 0 deletions mmengine/model/builder.py
@@ -0,0 +1,40 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Union

import torch.nn as nn

from mmengine.registry import MODELS


def build_model(model: Union[nn.Module, Dict]) -> nn.Module:
    """Build function of Model.

    If ``model`` is a dict, it will be used to build an nn.Module object.
    Otherwise, if ``model`` is an nn.Module object, it will be returned
    directly.

    An example of ``model``::

        model = dict(type='ResNet')

    Args:
        model (nn.Module or dict): An ``nn.Module`` object or a dict used to
            build an nn.Module object. If ``model`` is an nn.Module object,
            it is returned as-is.

    Note:
        The returned model must implement ``train_step`` and ``test_step``
        if ``runner.train`` or ``runner.test`` will be called. If
        ``runner.val`` will be called or ``val_cfg`` is configured, the
        model must also implement ``val_step``.

    Returns:
        nn.Module: Model built from ``model``.
    """
    if isinstance(model, nn.Module):
        return model
    elif isinstance(model, dict):
        model = MODELS.build(model)
        return model  # type: ignore
    else:
        raise TypeError('model should be an nn.Module object or dict, '
                        f'but got {model}')
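
A minimal usage sketch (illustrative only, not part of the diff; 'ResNet' stands for any model class registered in the MODELS registry):

    import torch.nn as nn

    from mmengine.model import build_model

    model = build_model(dict(type='ResNet'))  # built through the MODELS registry
    assert isinstance(model, nn.Module)
    assert build_model(model) is model  # an nn.Module is returned as-is
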
40 changes: 39 additions & 1 deletion mmengine/runner/log_processor.py
Expand Up @@ -4,7 +4,7 @@
import re
from collections import OrderedDict
from itertools import chain
from typing import List, Optional, Tuple
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch
@@ -575,3 +575,41 @@ def _get_dataloader_size(self, runner, mode) -> int:
int: The dataloader size of current loop.
"""
return len(self._get_cur_loop(runner=runner, mode=mode).dataloader)


def build_log_processor(
        log_processor: Union[LogProcessor, Dict]) -> LogProcessor:
    """Build a log processor.

    Examples of ``log_processor``::

        # the default `LogProcessor` will be used
        log_processor = dict()

        # a custom log processor
        log_processor = dict(type='CustomLogProcessor')

    Args:
        log_processor (LogProcessor or dict): A log processor or a dict
            used to build a log processor. If ``log_processor`` is a log
            processor object, it is returned as-is.

    Returns:
        :obj:`LogProcessor`: Log processor object built from
        ``log_processor``.
    """
    if isinstance(log_processor, LogProcessor):
        return log_processor
    elif not isinstance(log_processor, dict):
        raise TypeError(
            'log processor should be a LogProcessor object or dict, '
            f'but got {log_processor}')

    log_processor_cfg = copy.deepcopy(log_processor)  # type: ignore

    if 'type' in log_processor_cfg:
        log_processor = LOG_PROCESSORS.build(log_processor_cfg)
    else:
        log_processor = LogProcessor(**log_processor_cfg)  # type: ignore

    return log_processor  # type: ignore
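
A minimal usage sketch (illustrative only, not part of the diff; ``window_size`` is assumed to be a valid ``LogProcessor`` argument):

    from mmengine.runner.log_processor import LogProcessor, build_log_processor

    lp = build_log_processor(dict())                     # default LogProcessor
    lp = build_log_processor(dict(window_size=20))       # kwargs forwarded to LogProcessor
    lp = build_log_processor(dict(type='LogProcessor'))  # built via the LOG_PROCESSORS registry
    assert isinstance(lp, LogProcessor)
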
84 changes: 16 additions & 68 deletions mmengine/runner/runner.py
@@ -23,29 +23,29 @@
from mmengine.device import get_device
from mmengine.dist import (broadcast, get_dist_info, get_rank, init_dist,
is_distributed, master_only)
from mmengine.evaluator import Evaluator
from mmengine.evaluator import Evaluator, build_evaluator
from mmengine.fileio import FileClient, join_path
from mmengine.hooks import Hook
from mmengine.logging import MessageHub, MMLogger, print_log
from mmengine.model import (MMDistributedDataParallel, convert_sync_batchnorm,
is_model_wrapper, revert_sync_batchnorm)
from mmengine.model import (MMDistributedDataParallel, build_model,
convert_sync_batchnorm, is_model_wrapper,
revert_sync_batchnorm)
from mmengine.model.efficient_conv_bn_eval import \
turn_on_efficient_conv_bn_eval
from mmengine.optim import (OptimWrapper, OptimWrapperDict, _ParamScheduler,
build_optim_wrapper)
from mmengine.registry import (DATA_SAMPLERS, DATASETS, EVALUATOR, FUNCTIONS,
HOOKS, LOG_PROCESSORS, LOOPS, MODEL_WRAPPERS,
MODELS, OPTIM_WRAPPERS, PARAM_SCHEDULERS,
RUNNERS, VISUALIZERS, DefaultScope)
from mmengine.registry import (DATA_SAMPLERS, DATASETS, FUNCTIONS, HOOKS,
LOOPS, MODEL_WRAPPERS, OPTIM_WRAPPERS,
PARAM_SCHEDULERS, RUNNERS, DefaultScope)
from mmengine.utils import apply_to, digit_version, get_git_hash, is_seq_of
from mmengine.utils.dl_utils import (TORCH_VERSION, collect_env,
set_multi_processing)
from mmengine.visualization import Visualizer
from mmengine.visualization import Visualizer, build_visualizer
from .base_loop import BaseLoop
from .checkpoint import (_load_checkpoint, _load_checkpoint_to_model,
find_latest_checkpoint, save_checkpoint,
weights_to_cpu)
from .log_processor import LogProcessor
from .log_processor import LogProcessor, build_log_processor
from .loops import EpochBasedTrainLoop, IterBasedTrainLoop, TestLoop, ValLoop
from .priority import Priority, get_priority
from .utils import set_random_seed
@@ -769,25 +769,10 @@ def build_visualizer(
Returns:
Visualizer: A Visualizer object built from ``visualizer``.
"""
if visualizer is None:
visualizer = dict(
name=self._experiment_name,
vis_backends=[dict(type='LocalVisBackend')],
save_dir=self._log_dir)
return Visualizer.get_instance(**visualizer)

if isinstance(visualizer, Visualizer):
return visualizer

if isinstance(visualizer, dict):
# ensure visualizer containing name key
visualizer.setdefault('name', self._experiment_name)
visualizer.setdefault('save_dir', self._log_dir)
return VISUALIZERS.build(visualizer)
else:
raise TypeError(
'visualizer should be Visualizer object, a dict or None, '
f'but got {visualizer}')
return build_visualizer(
visualizer=visualizer,
experiment_name=self._experiment_name,
log_dir=self._log_dir)

def build_model(self, model: Union[nn.Module, Dict]) -> nn.Module:
"""Build model.
@@ -813,14 +798,7 @@ def build_model(self, model: Union[nn.Module, Dict]) -> nn.Module:
Returns:
nn.Module: Model built from ``model``.
"""
if isinstance(model, nn.Module):
return model
elif isinstance(model, dict):
model = MODELS.build(model)
return model # type: ignore
else:
raise TypeError('model should be a nn.Module object or dict, '
f'but got {model}')
return build_model(model=model)

def wrap_model(
self, model_wrapper_cfg: Optional[Dict],
@@ -1289,23 +1267,7 @@ def build_evaluator(self, evaluator: Union[Dict, List,
Returns:
Evaluator: Evaluator built from ``evaluator``.
"""
if isinstance(evaluator, Evaluator):
return evaluator
elif isinstance(evaluator, dict):
# if `metrics` in dict keys, it means to build customized evalutor
if 'metrics' in evaluator:
evaluator.setdefault('type', 'Evaluator')
return EVALUATOR.build(evaluator)
# otherwise, default evalutor will be built
else:
return Evaluator(evaluator) # type: ignore
elif isinstance(evaluator, list):
# use the default `Evaluator`
return Evaluator(evaluator) # type: ignore
else:
raise TypeError(
'evaluator should be one of dict, list of dict, and Evaluator'
f', but got {evaluator}')
return build_evaluator(evaluator=evaluator)

@staticmethod
def build_dataloader(dataloader: Union[DataLoader, Dict],
@@ -1612,21 +1574,7 @@ def build_log_processor(
:obj:`LogProcessor`: Log processor object built from
``log_processor``.
"""
if isinstance(log_processor, LogProcessor):
return log_processor
elif not isinstance(log_processor, dict):
raise TypeError(
'log processor should be a LogProcessor object or dict, but'
f'got {log_processor}')

log_processor_cfg = copy.deepcopy(log_processor) # type: ignore

if 'type' in log_processor_cfg:
log_processor = LOG_PROCESSORS.build(log_processor_cfg)
else:
log_processor = LogProcessor(**log_processor_cfg) # type: ignore

return log_processor # type: ignore
return build_log_processor(log_processor=log_processor)

def get_hooks_info(self) -> str:
# Get hooks info in each stage
4 changes: 3 additions & 1 deletion mmengine/visualization/__init__.py
@@ -1,10 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .builder import build_visualizer
from .vis_backend import (BaseVisBackend, ClearMLVisBackend, LocalVisBackend,
MLflowVisBackend, TensorboardVisBackend,
WandbVisBackend)
from .visualizer import Visualizer

__all__ = [
'Visualizer', 'BaseVisBackend', 'LocalVisBackend', 'WandbVisBackend',
'TensorboardVisBackend', 'MLflowVisBackend', 'ClearMLVisBackend'
'TensorboardVisBackend', 'MLflowVisBackend', 'ClearMLVisBackend',
'build_visualizer'
]
44 changes: 44 additions & 0 deletions mmengine/visualization/builder.py
@@ -0,0 +1,44 @@
# Copyright (c) OpenMMLab. All rights reserved.
from datetime import datetime
from typing import Dict, Optional, Union

from mmengine.registry import VISUALIZERS
from .visualizer import Visualizer


def build_visualizer(
        visualizer: Optional[Union[Visualizer, Dict]] = None,
        experiment_name: str = datetime.now().strftime(r'%Y%m%d_%H%M%S'),
        log_dir: str = 'work_dirs') -> Visualizer:
    """Build a globally accessible Visualizer.

    Args:
        visualizer (Visualizer or dict, optional): A Visualizer object
            or a dict used to build a Visualizer object. If ``visualizer``
            is a Visualizer object, it is returned as-is. If not specified,
            the default config will be used to build the Visualizer object.
            Defaults to None.
        experiment_name (str): Used as the default visualizer name.
            Defaults to the current timestamp.
        log_dir (str): Used as the default ``save_dir``.
            Defaults to 'work_dirs'.

    Returns:
        Visualizer: A Visualizer object built from ``visualizer``.
    """
    if visualizer is None:
        visualizer = dict(
            name=experiment_name,
            vis_backends=[dict(type='LocalVisBackend')],
            save_dir=log_dir)
        return Visualizer.get_instance(**visualizer)

    if isinstance(visualizer, Visualizer):
        return visualizer

    if isinstance(visualizer, dict):
        # ensure the visualizer config contains the `name` key
        visualizer.setdefault('name', experiment_name)
        visualizer.setdefault('save_dir', log_dir)
        return VISUALIZERS.build(visualizer)
    else:
        raise TypeError(
            'visualizer should be a Visualizer object, a dict or None, '
            f'but got {visualizer}')
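
A minimal usage sketch (illustrative only, not part of the diff; the experiment name and log directory are arbitrary placeholders):

    from mmengine.visualization import Visualizer, build_visualizer

    # None -> a default Visualizer with a LocalVisBackend, made globally accessible
    vis = build_visualizer(
        None, experiment_name='demo_exp', log_dir='work_dirs/demo_exp')
    assert isinstance(vis, Visualizer)

    # an already-built Visualizer instance is returned unchanged
    assert build_visualizer(vis) is vis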