Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
wanghonglie committed Sep 16, 2022
1 parent 836ee92 commit 417bf0c
Showing 1 changed file with 15 additions and 16 deletions.
31 changes: 15 additions & 16 deletions tools/analysis_tools/test_robustness.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ def parse_args():
parser = argparse.ArgumentParser(description='MMDet test detector')
parser.add_argument('config', help='test config file path')
parser.add_argument('checkpoint', help='checkpoint file')
parser.add_argument('--out', help='output result file')
parser.add_argument(
'--dump',
type=str,
help='dump predictions to a pickle file for offline evaluation')
parser.add_argument(
'--corruptions',
type=str,
Expand Down Expand Up @@ -115,12 +118,9 @@ def main():
# do not init the default scope here because it will be init in the runner
register_all_modules(init_default_scope=False)

assert args.out or args.show or args.show_dir, \
assert args.dump or args.show or args.show_dir, \
('Please specify at least one operation (save or show the results) '
'with the argument "--out", "--show" or "show-dir"')

if args.out is not None and not args.out.endswith(('.pkl', '.pickle')):
raise ValueError('The output file must be a pkl file.')
'with the argument "--dump", "--show" or "show-dir"')

# load config
cfg = Config.fromfile(args.config)
Expand All @@ -137,23 +137,22 @@ def main():
cfg.work_dir = osp.join('./work_dirs',
osp.splitext(osp.basename(args.config))[0])

if args.dump is not None:
assert args.dump.endswith(('.pkl', '.pickle')), \
'The dump file must be a pkl file.'
dump_metric = dict(type='DumpResults', out_file_path=args.dump)
if isinstance(cfg.test_evaluator, (list, tuple)):
cfg.test_evaluator = list(cfg.test_evaluator).append(dump_metric)
else:
cfg.test_evaluator = [cfg.test_evaluator, dump_metric]

cfg.model.backbone.init_cfg.type = None
cfg.test_dataloader.dataset.test_mode = True

cfg.load_from = args.checkpoint
if args.show or args.show_dir:
cfg = trigger_visualization_hook(cfg, args)

if args.out:
test_evaluator = dict(
type='DumpResults',
out_file_path='robust.pkl',
)
if isinstance(cfg.test_evaluator, dict):
cfg.test_evaluator = [cfg.test_evaluator, test_evaluator]
elif isinstance(cfg.test_evaluator, list):
cfg.test_evaluator = cfg.test_evaluator.append(test_evaluator)

# build the runner from config
if 'runner_type' not in cfg:
# build the default runner
Expand Down

0 comments on commit 417bf0c

Please sign in to comment.