From fd87b664c32f58f218f7816876d254079bdd1069 Mon Sep 17 00:00:00 2001 From: wanghonglie Date: Mon, 19 Sep 2022 18:03:49 +0800 Subject: [PATCH] update --- tools/analysis_tools/test_robustness.py | 38 ++++++------------------- 1 file changed, 8 insertions(+), 30 deletions(-) diff --git a/tools/analysis_tools/test_robustness.py b/tools/analysis_tools/test_robustness.py index 83b7f3de9ed..2a6499439ee 100644 --- a/tools/analysis_tools/test_robustness.py +++ b/tools/analysis_tools/test_robustness.py @@ -9,8 +9,9 @@ from mmengine.fileio import dump from mmengine.runner import Runner +from mmdet.engine.hooks.utils import trigger_visualization_hook from mmdet.registry import RUNNERS -from mmdet.utils import register_all_modules +from mmdet.utils import add_dump_metric, register_all_modules from tools.analysis_tools.robustness_eval import get_results @@ -19,7 +20,7 @@ def parse_args(): parser.add_argument('config', help='test config file path') parser.add_argument('checkpoint', help='checkpoint file') parser.add_argument( - '--dump', + '--out', type=str, help='dump predictions to a pickle file for offline evaluation') parser.add_argument( @@ -91,26 +92,6 @@ def parse_args(): return args -def trigger_visualization_hook(cfg, args): - default_hooks = cfg.default_hooks - if 'visualization' in default_hooks: - visualization_hook = default_hooks['visualization'] - # Turn on visualization - visualization_hook['draw'] = True - if args.show: - visualization_hook['show'] = True - visualization_hook['wait_time'] = args.wait_time - if args.show_dir: - visualization_hook['test_out_dir'] = args.show_dir - else: - raise RuntimeError( - 'VisualizationHook must be included in default_hooks.' - 'refer to usage ' - '"visualization=dict(type=\'VisualizationHook\')"') - - return cfg - - def main(): args = parse_args() @@ -118,7 +99,7 @@ def main(): # do not init the default scope here because it will be init in the runner register_all_modules(init_default_scope=False) - assert args.dump or args.show or args.show_dir, \ + assert args.out or args.show or args.show_dir, \ ('Please specify at least one operation (save or show the results) ' 'with the argument "--dump", "--show" or "show-dir"') @@ -137,14 +118,11 @@ def main(): cfg.work_dir = osp.join('./work_dirs', osp.splitext(osp.basename(args.config))[0]) - if args.dump is not None: - assert args.dump.endswith(('.pkl', '.pickle')), \ + # Dump predictions + if args.out is not None: + assert args.out.endswith(('.pkl', '.pickle')), \ 'The dump file must be a pkl file.' - dump_metric = dict(type='DumpResults', out_file_path=args.dump) - if isinstance(cfg.test_evaluator, (list, tuple)): - cfg.test_evaluator = list(cfg.test_evaluator).append(dump_metric) - else: - cfg.test_evaluator = [cfg.test_evaluator, dump_metric] + add_dump_metric(args, cfg) cfg.model.backbone.init_cfg.type = None cfg.test_dataloader.dataset.test_mode = True