diff --git a/mmdet/utils/__init__.py b/mmdet/utils/__init__.py index 350452a95d6..03d7ea0d467 100644 --- a/mmdet/utils/__init__.py +++ b/mmdet/utils/__init__.py @@ -3,6 +3,7 @@ from .compat_config import compat_cfg from .logger import get_caller_name, get_root_logger, log_img_scale from .misc import find_latest_checkpoint, update_data_root +from .replace_cfg_vals import replace_cfg_vals from .setup_env import setup_multi_processes from .split_batch import split_batch from .util_distribution import build_ddp, build_dp, get_device @@ -11,5 +12,5 @@ 'get_root_logger', 'collect_env', 'find_latest_checkpoint', 'update_data_root', 'setup_multi_processes', 'get_caller_name', 'log_img_scale', 'compat_cfg', 'split_batch', 'build_ddp', 'build_dp', - 'get_device' + 'get_device', 'replace_cfg_vals' ] diff --git a/mmdet/utils/replace_cfg_vals.py b/mmdet/utils/replace_cfg_vals.py new file mode 100644 index 00000000000..6ca301dc937 --- /dev/null +++ b/mmdet/utils/replace_cfg_vals.py @@ -0,0 +1,70 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import re + +from mmcv.utils import Config + + +def replace_cfg_vals(ori_cfg): + """Replace the string "${key}" with the corresponding value. + + Replace the "${key}" with the value of ori_cfg.key in the config. And + support replacing the chained ${key}. Such as, replace "${key0.key1}" + with the value of cfg.key0.key1. Code is modified from `vars.py + < https://github.com/microsoft/SoftTeacher/blob/main/ssod/utils/vars.py>`_ # noqa: E501 + + Args: + ori_cfg (mmcv.utils.config.Config): + The origin config with "${key}" generated from a file. + + Returns: + updated_cfg [mmcv.utils.config.Config]: + The config with "${key}" replaced by the corresponding value. + """ + + def get_value(cfg, key): + for k in key.split('.'): + cfg = cfg[k] + return cfg + + def replace_value(cfg): + if isinstance(cfg, dict): + return {key: replace_value(value) for key, value in cfg.items()} + elif isinstance(cfg, list): + return [replace_value(item) for item in cfg] + elif isinstance(cfg, tuple): + return tuple([replace_value(item) for item in cfg]) + elif isinstance(cfg, str): + # the format of string cfg may be: + # 1) "${key}", which will be replaced with cfg.key directly + # 2) "xxx${key}xxx" or "xxx${key1}xxx${key2}xxx", + # which will be replaced with the string of the cfg.key + keys = pattern_key.findall(cfg) + values = [get_value(ori_cfg, key[2:-1]) for key in keys] + if len(keys) == 1 and keys[0] == cfg: + # the format of string cfg is "${key}" + cfg = values[0] + else: + for key, value in zip(keys, values): + # the format of string cfg is + # "xxx${key}xxx" or "xxx${key1}xxx${key2}xxx" + assert not isinstance(value, (dict, list, tuple)), \ + f'for the format of string cfg is ' \ + f"'xxxxx${key}xxxxx' or 'xxx${key}xxx${key}xxx', " \ + f"the type of the value of '${key}' " \ + f'can not be dict, list, or tuple' \ + f'but you input {type(value)} in {cfg}' + cfg = cfg.replace(key, str(value)) + return cfg + else: + return cfg + + # the pattern of string "${key}" + pattern_key = re.compile(r'\$\{[a-zA-Z\d_.]*\}') + # the type of ori_cfg._cfg_dict is mmcv.utils.config.ConfigDict + updated_cfg = Config( + replace_value(ori_cfg._cfg_dict), filename=ori_cfg.filename) + # replace the model with model_wrapper + if updated_cfg.get('model_wrapper', None) is not None: + updated_cfg.model = updated_cfg.model_wrapper + updated_cfg.pop('model_wrapper') + return updated_cfg diff --git a/tests/test_utils/test_replace_cfg_vals.py b/tests/test_utils/test_replace_cfg_vals.py new file mode 100644 index 00000000000..85d9d0e2fa0 --- /dev/null +++ b/tests/test_utils/test_replace_cfg_vals.py @@ -0,0 +1,83 @@ +import os.path as osp +import tempfile +from copy import deepcopy + +import pytest +from mmcv.utils import Config + +from mmdet.utils import replace_cfg_vals + + +def test_replace_cfg_vals(): + temp_file = tempfile.NamedTemporaryFile() + cfg_path = f'{temp_file.name}.py' + with open(cfg_path, 'w') as f: + f.write('configs') + + ori_cfg_dict = dict() + ori_cfg_dict['cfg_name'] = osp.basename(temp_file.name) + ori_cfg_dict['work_dir'] = 'work_dirs/${cfg_name}/${percent}/${fold}' + ori_cfg_dict['percent'] = 5 + ori_cfg_dict['fold'] = 1 + ori_cfg_dict['model_wrapper'] = dict( + type='SoftTeacher', detector='${model}') + ori_cfg_dict['model'] = dict( + type='FasterRCNN', + backbone=dict(type='ResNet'), + neck=dict(type='FPN'), + rpn_head=dict(type='RPNHead'), + roi_head=dict(type='StandardRoIHead'), + train_cfg=dict( + rpn=dict( + assigner=dict(type='MaxIoUAssigner'), + sampler=dict(type='RandomSampler'), + ), + rpn_proposal=dict(nms=dict(type='nms', iou_threshold=0.7)), + rcnn=dict( + assigner=dict(type='MaxIoUAssigner'), + sampler=dict(type='RandomSampler'), + ), + ), + test_cfg=dict( + rpn=dict(nms=dict(type='nms', iou_threshold=0.7)), + rcnn=dict(nms=dict(type='nms', iou_threshold=0.5)), + ), + ) + ori_cfg_dict['iou_threshold'] = dict( + rpn_proposal_nms='${model.train_cfg.rpn_proposal.nms.iou_threshold}', + test_rpn_nms='${model.test_cfg.rpn.nms.iou_threshold}', + test_rcnn_nms='${model.test_cfg.rcnn.nms.iou_threshold}', + ) + + ori_cfg_dict['str'] = 'Hello, world!' + ori_cfg_dict['dict'] = {'Hello': 'world!'} + ori_cfg_dict['list'] = [ + 'Hello, world!', + ] + ori_cfg_dict['tuple'] = ('Hello, world!', ) + ori_cfg_dict['test_str'] = 'xxx${str}xxx' + + ori_cfg = Config(ori_cfg_dict, filename=cfg_path) + updated_cfg = replace_cfg_vals(deepcopy(ori_cfg)) + + assert updated_cfg.work_dir \ + == f'work_dirs/{osp.basename(temp_file.name)}/5/1' + assert updated_cfg.model.detector == ori_cfg.model + assert updated_cfg.iou_threshold.rpn_proposal_nms \ + == ori_cfg.model.train_cfg.rpn_proposal.nms.iou_threshold + assert updated_cfg.test_str == 'xxxHello, world!xxx' + ori_cfg_dict['test_dict'] = 'xxx${dict}xxx' + ori_cfg_dict['test_list'] = 'xxx${list}xxx' + ori_cfg_dict['test_tuple'] = 'xxx${tuple}xxx' + with pytest.raises(AssertionError): + cfg = deepcopy(ori_cfg) + cfg['test_dict'] = 'xxx${dict}xxx' + updated_cfg = replace_cfg_vals(cfg) + with pytest.raises(AssertionError): + cfg = deepcopy(ori_cfg) + cfg['test_list'] = 'xxx${list}xxx' + updated_cfg = replace_cfg_vals(cfg) + with pytest.raises(AssertionError): + cfg = deepcopy(ori_cfg) + cfg['test_tuple'] = 'xxx${tuple}xxx' + updated_cfg = replace_cfg_vals(cfg) diff --git a/tools/analysis_tools/analyze_results.py b/tools/analysis_tools/analyze_results.py index 15db07e41c7..916a6b2200f 100644 --- a/tools/analysis_tools/analyze_results.py +++ b/tools/analysis_tools/analyze_results.py @@ -9,7 +9,7 @@ from mmdet.core.evaluation import eval_map from mmdet.core.visualization import imshow_gt_det_bboxes from mmdet.datasets import build_dataset, get_loading_pipeline -from mmdet.utils import update_data_root +from mmdet.utils import replace_cfg_vals, update_data_root def bbox_map_eval(det_result, annotation): @@ -188,6 +188,9 @@ def main(): cfg = Config.fromfile(args.config) + # replace the ${key} with the value of cfg.key + cfg = replace_cfg_vals(cfg) + # update data root according to MMDET_DATASETS update_data_root(cfg) diff --git a/tools/analysis_tools/benchmark.py b/tools/analysis_tools/benchmark.py index 2be2d14d7b4..c956968beed 100644 --- a/tools/analysis_tools/benchmark.py +++ b/tools/analysis_tools/benchmark.py @@ -13,7 +13,7 @@ from mmdet.datasets import (build_dataloader, build_dataset, replace_ImageToTensor) from mmdet.models import build_detector -from mmdet.utils import update_data_root +from mmdet.utils import replace_cfg_vals, update_data_root def parse_args(): @@ -172,6 +172,9 @@ def main(): cfg = Config.fromfile(args.config) + # replace the ${key} with the value of cfg.key + cfg = replace_cfg_vals(cfg) + # update data root according to MMDET_DATASETS update_data_root(cfg) diff --git a/tools/analysis_tools/confusion_matrix.py b/tools/analysis_tools/confusion_matrix.py index 1c2ceb47acf..5b52ea4c0ff 100644 --- a/tools/analysis_tools/confusion_matrix.py +++ b/tools/analysis_tools/confusion_matrix.py @@ -10,7 +10,7 @@ from mmdet.core.evaluation.bbox_overlaps import bbox_overlaps from mmdet.datasets import build_dataset -from mmdet.utils import update_data_root +from mmdet.utils import replace_cfg_vals, update_data_root def parse_args(): @@ -232,6 +232,9 @@ def main(): cfg = Config.fromfile(args.config) + # replace the ${key} with the value of cfg.key + cfg = replace_cfg_vals(cfg) + # update data root according to MMDET_DATASETS update_data_root(cfg) diff --git a/tools/analysis_tools/eval_metric.py b/tools/analysis_tools/eval_metric.py index a074c9e1850..7caafe99df0 100644 --- a/tools/analysis_tools/eval_metric.py +++ b/tools/analysis_tools/eval_metric.py @@ -5,7 +5,7 @@ from mmcv import Config, DictAction from mmdet.datasets import build_dataset -from mmdet.utils import update_data_root +from mmdet.utils import replace_cfg_vals, update_data_root def parse_args(): @@ -50,6 +50,9 @@ def main(): cfg = Config.fromfile(args.config) + # replace the ${key} with the value of cfg.key + cfg = replace_cfg_vals(cfg) + # update data root according to MMDET_DATASETS update_data_root(cfg) diff --git a/tools/analysis_tools/optimize_anchors.py b/tools/analysis_tools/optimize_anchors.py index acf72acb26c..421998f945d 100644 --- a/tools/analysis_tools/optimize_anchors.py +++ b/tools/analysis_tools/optimize_anchors.py @@ -29,7 +29,7 @@ from mmdet.core import bbox_cxcywh_to_xyxy, bbox_overlaps, bbox_xyxy_to_cxcywh from mmdet.datasets import build_dataset -from mmdet.utils import get_root_logger, update_data_root +from mmdet.utils import get_root_logger, replace_cfg_vals, update_data_root def parse_args(): @@ -325,6 +325,9 @@ def main(): cfg = args.config cfg = Config.fromfile(cfg) + # replace the ${key} with the value of cfg.key + cfg = replace_cfg_vals(cfg) + # update data root according to MMDET_DATASETS update_data_root(cfg) diff --git a/tools/misc/browse_dataset.py b/tools/misc/browse_dataset.py index 14db64ee050..d9fb2851220 100644 --- a/tools/misc/browse_dataset.py +++ b/tools/misc/browse_dataset.py @@ -11,7 +11,7 @@ from mmdet.core.utils import mask2ndarray from mmdet.core.visualization import imshow_det_bboxes from mmdet.datasets.builder import build_dataset -from mmdet.utils import update_data_root +from mmdet.utils import replace_cfg_vals, update_data_root def parse_args(): @@ -57,6 +57,9 @@ def skip_pipeline_steps(config): cfg = Config.fromfile(config_path) + # replace the ${key} with the value of cfg.key + cfg = replace_cfg_vals(cfg) + # update data root according to MMDET_DATASETS update_data_root(cfg) diff --git a/tools/misc/print_config.py b/tools/misc/print_config.py index 7bb20fa60de..f10f5384a6a 100644 --- a/tools/misc/print_config.py +++ b/tools/misc/print_config.py @@ -4,7 +4,7 @@ from mmcv import Config, DictAction -from mmdet.utils import update_data_root +from mmdet.utils import replace_cfg_vals, update_data_root def parse_args(): @@ -45,6 +45,9 @@ def main(): cfg = Config.fromfile(args.config) + # replace the ${key} with the value of cfg.key + cfg = replace_cfg_vals(cfg) + # update data root according to MMDET_DATASETS update_data_root(cfg) diff --git a/tools/test.py b/tools/test.py index a6cd9ecbb89..884e0a4f6c6 100644 --- a/tools/test.py +++ b/tools/test.py @@ -17,7 +17,8 @@ replace_ImageToTensor) from mmdet.models import build_detector from mmdet.utils import (build_ddp, build_dp, compat_cfg, get_device, - setup_multi_processes, update_data_root) + replace_cfg_vals, setup_multi_processes, + update_data_root) def parse_args(): @@ -134,6 +135,9 @@ def main(): cfg = Config.fromfile(args.config) + # replace the ${key} with the value of cfg.key + cfg = replace_cfg_vals(cfg) + # update data root according to MMDET_DATASETS update_data_root(cfg) diff --git a/tools/train.py b/tools/train.py index fd29cae1311..cff19f037e1 100644 --- a/tools/train.py +++ b/tools/train.py @@ -18,7 +18,8 @@ from mmdet.datasets import build_dataset from mmdet.models import build_detector from mmdet.utils import (collect_env, get_device, get_root_logger, - setup_multi_processes, update_data_root) + replace_cfg_vals, setup_multi_processes, + update_data_root) def parse_args(): @@ -109,6 +110,9 @@ def main(): cfg = Config.fromfile(args.config) + # replace the ${key} with the value of cfg.key + cfg = replace_cfg_vals(cfg) + # update data root according to MMDET_DATASETS update_data_root(cfg) @@ -142,6 +146,7 @@ def main(): # use config filename as default work_dir if cfg.work_dir is None cfg.work_dir = osp.join('./work_dirs', osp.splitext(osp.basename(args.config))[0]) + if args.resume_from is not None: cfg.resume_from = args.resume_from cfg.auto_resume = args.auto_resume