Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
wanghonglie committed Sep 19, 2022
1 parent 6fb9251 commit fd87b66
Showing 1 changed file with 8 additions and 30 deletions.
38 changes: 8 additions & 30 deletions tools/analysis_tools/test_robustness.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@
from mmengine.fileio import dump
from mmengine.runner import Runner

from mmdet.engine.hooks.utils import trigger_visualization_hook
from mmdet.registry import RUNNERS
from mmdet.utils import register_all_modules
from mmdet.utils import add_dump_metric, register_all_modules
from tools.analysis_tools.robustness_eval import get_results


Expand All @@ -19,7 +20,7 @@ def parse_args():
parser.add_argument('config', help='test config file path')
parser.add_argument('checkpoint', help='checkpoint file')
parser.add_argument(
'--dump',
'--out',
type=str,
help='dump predictions to a pickle file for offline evaluation')
parser.add_argument(
Expand Down Expand Up @@ -91,34 +92,14 @@ def parse_args():
return args


def trigger_visualization_hook(cfg, args):
default_hooks = cfg.default_hooks
if 'visualization' in default_hooks:
visualization_hook = default_hooks['visualization']
# Turn on visualization
visualization_hook['draw'] = True
if args.show:
visualization_hook['show'] = True
visualization_hook['wait_time'] = args.wait_time
if args.show_dir:
visualization_hook['test_out_dir'] = args.show_dir
else:
raise RuntimeError(
'VisualizationHook must be included in default_hooks.'
'refer to usage '
'"visualization=dict(type=\'VisualizationHook\')"')

return cfg


def main():
args = parse_args()

# register all modules in mmdet into the registries
# do not init the default scope here because it will be init in the runner
register_all_modules(init_default_scope=False)

assert args.dump or args.show or args.show_dir, \
assert args.out or args.show or args.show_dir, \
('Please specify at least one operation (save or show the results) '
'with the argument "--dump", "--show" or "show-dir"')

Expand All @@ -137,14 +118,11 @@ def main():
cfg.work_dir = osp.join('./work_dirs',
osp.splitext(osp.basename(args.config))[0])

if args.dump is not None:
assert args.dump.endswith(('.pkl', '.pickle')), \
# Dump predictions
if args.out is not None:
assert args.out.endswith(('.pkl', '.pickle')), \
'The dump file must be a pkl file.'
dump_metric = dict(type='DumpResults', out_file_path=args.dump)
if isinstance(cfg.test_evaluator, (list, tuple)):
cfg.test_evaluator = list(cfg.test_evaluator).append(dump_metric)
else:
cfg.test_evaluator = [cfg.test_evaluator, dump_metric]
add_dump_metric(args, cfg)

cfg.model.backbone.init_cfg.type = None
cfg.test_dataloader.dataset.test_mode = True
Expand Down

0 comments on commit fd87b66

Please sign in to comment.