From 417bf0c67680406acdd2b751538c2472ef313157 Mon Sep 17 00:00:00 2001 From: wanghonglie Date: Fri, 16 Sep 2022 18:06:30 +0800 Subject: [PATCH] update --- tools/analysis_tools/test_robustness.py | 31 ++++++++++++------------- 1 file changed, 15 insertions(+), 16 deletions(-) diff --git a/tools/analysis_tools/test_robustness.py b/tools/analysis_tools/test_robustness.py index 20f77e8d266..83b7f3de9ed 100644 --- a/tools/analysis_tools/test_robustness.py +++ b/tools/analysis_tools/test_robustness.py @@ -18,7 +18,10 @@ def parse_args(): parser = argparse.ArgumentParser(description='MMDet test detector') parser.add_argument('config', help='test config file path') parser.add_argument('checkpoint', help='checkpoint file') - parser.add_argument('--out', help='output result file') + parser.add_argument( + '--dump', + type=str, + help='dump predictions to a pickle file for offline evaluation') parser.add_argument( '--corruptions', type=str, @@ -115,12 +118,9 @@ def main(): # do not init the default scope here because it will be init in the runner register_all_modules(init_default_scope=False) - assert args.out or args.show or args.show_dir, \ + assert args.dump or args.show or args.show_dir, \ ('Please specify at least one operation (save or show the results) ' - 'with the argument "--out", "--show" or "show-dir"') - - if args.out is not None and not args.out.endswith(('.pkl', '.pickle')): - raise ValueError('The output file must be a pkl file.') + 'with the argument "--dump", "--show" or "show-dir"') # load config cfg = Config.fromfile(args.config) @@ -137,6 +137,15 @@ def main(): cfg.work_dir = osp.join('./work_dirs', osp.splitext(osp.basename(args.config))[0]) + if args.dump is not None: + assert args.dump.endswith(('.pkl', '.pickle')), \ + 'The dump file must be a pkl file.' + dump_metric = dict(type='DumpResults', out_file_path=args.dump) + if isinstance(cfg.test_evaluator, (list, tuple)): + cfg.test_evaluator = list(cfg.test_evaluator).append(dump_metric) + else: + cfg.test_evaluator = [cfg.test_evaluator, dump_metric] + cfg.model.backbone.init_cfg.type = None cfg.test_dataloader.dataset.test_mode = True @@ -144,16 +153,6 @@ def main(): if args.show or args.show_dir: cfg = trigger_visualization_hook(cfg, args) - if args.out: - test_evaluator = dict( - type='DumpResults', - out_file_path='robust.pkl', - ) - if isinstance(cfg.test_evaluator, dict): - cfg.test_evaluator = [cfg.test_evaluator, test_evaluator] - elif isinstance(cfg.test_evaluator, list): - cfg.test_evaluator = cfg.test_evaluator.append(test_evaluator) - # build the runner from config if 'runner_type' not in cfg: # build the default runner