-
Notifications
You must be signed in to change notification settings - Fork 334
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add build_models() / build_evaluator() / build_log_processor()
- Loading branch information
1 parent
a483dba
commit 1032279
Showing
5 changed files
with
139 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,10 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from .builder import build_evaluator | ||
from .evaluator import Evaluator | ||
from .metric import BaseMetric, DumpResults | ||
from .utils import get_metric_value | ||
|
||
__all__ = ['BaseMetric', 'Evaluator', 'get_metric_value', 'DumpResults'] | ||
__all__ = [ | ||
'BaseMetric', 'Evaluator', 'get_metric_value', 'DumpResults', | ||
'build_evaluator' | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from typing import Dict, List, Union | ||
|
||
from mmengine.registry import EVALUATOR | ||
from .evaluator import Evaluator | ||
|
||
|
||
def build_evaluator(evaluator: Union[Dict, List, Evaluator]) -> Evaluator: | ||
"""Build evaluator. | ||
Examples of ``evaluator``:: | ||
# evaluator could be a built Evaluator instance | ||
evaluator = Evaluator(metrics=[ToyMetric()]) | ||
# evaluator can also be a list of dict | ||
evaluator = [ | ||
dict(type='ToyMetric1'), | ||
dict(type='ToyEvaluator2') | ||
] | ||
# evaluator can also be a list of built metric | ||
evaluator = [ToyMetric1(), ToyMetric2()] | ||
# evaluator can also be a dict with key metrics | ||
evaluator = dict(metrics=ToyMetric()) | ||
# metric is a list | ||
evaluator = dict(metrics=[ToyMetric()]) | ||
Args: | ||
evaluator (Evaluator or dict or list): An Evaluator object or a | ||
config dict or list of config dict used to build an Evaluator. | ||
Returns: | ||
Evaluator: Evaluator build from ``evaluator``. | ||
""" | ||
if isinstance(evaluator, Evaluator): | ||
return evaluator | ||
elif isinstance(evaluator, dict): | ||
# if `metrics` in dict keys, it means to build customized evalutor | ||
if 'metrics' in evaluator: | ||
evaluator.setdefault('type', 'Evaluator') | ||
return EVALUATOR.build(evaluator) | ||
# otherwise, default evalutor will be built | ||
else: | ||
return Evaluator(evaluator) # type: ignore | ||
elif isinstance(evaluator, list): | ||
# use the default `Evaluator` | ||
return Evaluator(evaluator) # type: ignore | ||
else: | ||
raise TypeError( | ||
'evaluator should be one of dict, list of dict, and Evaluator' | ||
f', but got {evaluator}') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from typing import Dict, Union | ||
|
||
import torch.nn as nn | ||
|
||
from mmengine.registry import MODELS | ||
|
||
|
||
def build_model(model: Union[nn.Module, Dict]) -> nn.Module: | ||
"""Build function of Model. | ||
If ``model`` is a dict, it will be used to build a nn.Module object. | ||
Else, if ``model`` is a nn.Module object it will be returned directly. | ||
An example of ``model``:: | ||
model = dict(type='ResNet') | ||
Args: | ||
model (nn.Module or dict): A ``nn.Module`` object or a dict to | ||
build nn.Module object. If ``model`` is a nn.Module object, | ||
just returns itself. | ||
Note: | ||
The returned model must implement ``train_step``, ``test_step`` | ||
if ``runner.train`` or ``runner.test`` will be called. If | ||
``runner.val`` will be called or ``val_cfg`` is configured, | ||
model must implement `val_step`. | ||
Returns: | ||
nn.Module: Model build from ``model``. | ||
""" | ||
if isinstance(model, nn.Module): | ||
return model | ||
elif isinstance(model, dict): | ||
model = MODELS.build(model) | ||
return model # type: ignore | ||
else: | ||
raise TypeError('model should be a nn.Module object or dict, ' | ||
f'but got {model}') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters