-
Notifications
You must be signed in to change notification settings - Fork 328
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add build_models() / build_evaluator() / build_log_processor()
- Loading branch information
1 parent
a483dba
commit df41a74
Showing
8 changed files
with
202 additions
and
72 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,10 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from .builder import build_evaluator | ||
from .evaluator import Evaluator | ||
from .metric import BaseMetric, DumpResults | ||
from .utils import get_metric_value | ||
|
||
__all__ = ['BaseMetric', 'Evaluator', 'get_metric_value', 'DumpResults'] | ||
__all__ = [ | ||
'BaseMetric', 'Evaluator', 'get_metric_value', 'DumpResults', | ||
'build_evaluator' | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from typing import Dict, List, Union | ||
|
||
from mmengine.registry import EVALUATOR | ||
from .evaluator import Evaluator | ||
|
||
|
||
def build_evaluator(evaluator: Union[Dict, List, Evaluator]) -> Evaluator: | ||
"""Build evaluator. | ||
Examples of ``evaluator``:: | ||
# evaluator could be a built Evaluator instance | ||
evaluator = Evaluator(metrics=[ToyMetric()]) | ||
# evaluator can also be a list of dict | ||
evaluator = [ | ||
dict(type='ToyMetric1'), | ||
dict(type='ToyEvaluator2') | ||
] | ||
# evaluator can also be a list of built metric | ||
evaluator = [ToyMetric1(), ToyMetric2()] | ||
# evaluator can also be a dict with key metrics | ||
evaluator = dict(metrics=ToyMetric()) | ||
# metric is a list | ||
evaluator = dict(metrics=[ToyMetric()]) | ||
Args: | ||
evaluator (Evaluator or dict or list): An Evaluator object or a | ||
config dict or list of config dict used to build an Evaluator. | ||
Returns: | ||
Evaluator: Evaluator build from ``evaluator``. | ||
""" | ||
if isinstance(evaluator, Evaluator): | ||
return evaluator | ||
elif isinstance(evaluator, dict): | ||
# if `metrics` in dict keys, it means to build customized evalutor | ||
if 'metrics' in evaluator: | ||
evaluator.setdefault('type', 'Evaluator') | ||
return EVALUATOR.build(evaluator) | ||
# otherwise, default evalutor will be built | ||
else: | ||
return Evaluator(evaluator) # type: ignore | ||
elif isinstance(evaluator, list): | ||
# use the default `Evaluator` | ||
return Evaluator(evaluator) # type: ignore | ||
else: | ||
raise TypeError( | ||
'evaluator should be one of dict, list of dict, and Evaluator' | ||
f', but got {evaluator}') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from typing import Dict, Union | ||
|
||
import torch.nn as nn | ||
|
||
from mmengine.registry import MODELS | ||
|
||
|
||
def build_model(model: Union[nn.Module, Dict]) -> nn.Module: | ||
"""Build function of Model. | ||
If ``model`` is a dict, it will be used to build a nn.Module object. | ||
Else, if ``model`` is a nn.Module object it will be returned directly. | ||
An example of ``model``:: | ||
model = dict(type='ResNet') | ||
Args: | ||
model (nn.Module or dict): A ``nn.Module`` object or a dict to | ||
build nn.Module object. If ``model`` is a nn.Module object, | ||
just returns itself. | ||
Note: | ||
The returned model must implement ``train_step``, ``test_step`` | ||
if ``runner.train`` or ``runner.test`` will be called. If | ||
``runner.val`` will be called or ``val_cfg`` is configured, | ||
model must implement `val_step`. | ||
Returns: | ||
nn.Module: Model build from ``model``. | ||
""" | ||
if isinstance(model, nn.Module): | ||
return model | ||
elif isinstance(model, dict): | ||
model = MODELS.build(model) | ||
return model # type: ignore | ||
else: | ||
raise TypeError('model should be a nn.Module object or dict, ' | ||
f'but got {model}') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,12 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from .builder import build_visualizer | ||
from .vis_backend import (BaseVisBackend, ClearMLVisBackend, LocalVisBackend, | ||
MLflowVisBackend, TensorboardVisBackend, | ||
WandbVisBackend) | ||
from .visualizer import Visualizer | ||
|
||
__all__ = [ | ||
'Visualizer', 'BaseVisBackend', 'LocalVisBackend', 'WandbVisBackend', | ||
'TensorboardVisBackend', 'MLflowVisBackend', 'ClearMLVisBackend' | ||
'TensorboardVisBackend', 'MLflowVisBackend', 'ClearMLVisBackend', | ||
'build_visualizer' | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from typing import Dict, Optional, Union | ||
|
||
from mmengine.registry import VISUALIZERS | ||
from .visualizer import Visualizer | ||
from datetime import datetime | ||
|
||
def build_visualizer( | ||
visualizer: Optional[Union[Visualizer, | ||
Dict]] = None, | ||
experiment_name: str = datetime.now().strftime(r'%Y%m%d_%H%M%S'), | ||
log_dir: str = 'work_dirs' | ||
) -> Visualizer: | ||
"""Build a global asscessable Visualizer. | ||
Args: | ||
visualizer (Visualizer or dict, optional): A Visualizer object | ||
or a dict to build Visualizer object. If ``visualizer`` is a | ||
Visualizer object, just returns itself. If not specified, | ||
default config will be used to build Visualizer object. | ||
Defaults to None. | ||
Returns: | ||
Visualizer: A Visualizer object build from ``visualizer``. | ||
""" | ||
if visualizer is None: | ||
visualizer = dict( | ||
name=experiment_name, | ||
vis_backends=[dict(type='LocalVisBackend')], | ||
save_dir=log_dir) | ||
return Visualizer.get_instance(**visualizer) | ||
|
||
if isinstance(visualizer, Visualizer): | ||
return visualizer | ||
|
||
if isinstance(visualizer, dict): | ||
# ensure visualizer containing name key | ||
visualizer.setdefault('name', experiment_name) | ||
visualizer.setdefault('save_dir', log_dir) | ||
return VISUALIZERS.build(visualizer) | ||
else: | ||
raise TypeError( | ||
'visualizer should be Visualizer object, a dict or None, ' | ||
f'but got {visualizer}') |