Skip to content

Commit

Permalink
[Feature]Implement dump results (#8814)
Browse files Browse the repository at this point in the history
* Implement dump results

* support xxxevaluator

* update

* reuse func
  • Loading branch information
wanghonglie authored and ZwwWayne committed Sep 26, 2022
1 parent d18ec25 commit 7a649f9
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 24 deletions.
3 changes: 2 additions & 1 deletion mmdet/engine/hooks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@
from .num_class_check_hook import NumClassCheckHook
from .set_epoch_info_hook import SetEpochInfoHook
from .sync_norm_hook import SyncNormHook
from .utils import trigger_visualization_hook
from .visualization_hook import DetVisualizationHook
from .yolox_mode_switch_hook import YOLOXModeSwitchHook

__all__ = [
'YOLOXModeSwitchHook', 'SyncNormHook', 'CheckInvalidLossHook',
'SetEpochInfoHook', 'MemoryProfilerHook', 'DetVisualizationHook',
'NumClassCheckHook', 'MeanTeacherHook'
'NumClassCheckHook', 'MeanTeacherHook', 'trigger_visualization_hook'
]
19 changes: 19 additions & 0 deletions mmdet/engine/hooks/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Copyright (c) OpenMMLab. All rights reserved.
def trigger_visualization_hook(cfg, args):
default_hooks = cfg.default_hooks
if 'visualization' in default_hooks:
visualization_hook = default_hooks['visualization']
# Turn on visualization
visualization_hook['draw'] = True
if args.show:
visualization_hook['show'] = True
visualization_hook['wait_time'] = args.wait_time
if args.show_dir:
visualization_hook['test_out_dir'] = args.show_dir
else:
raise RuntimeError(
'VisualizationHook must be included in default_hooks.'
'refer to usage '
'"visualization=dict(type=\'VisualizationHook\')"')

return cfg
4 changes: 2 additions & 2 deletions mmdet/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
sync_random_seed)
from .logger import get_caller_name, log_img_scale
from .memory import AvoidCUDAOOM, AvoidOOM
from .misc import find_latest_checkpoint, update_data_root
from .misc import add_dump_metric, find_latest_checkpoint, update_data_root
from .replace_cfg_vals import replace_cfg_vals
from .setup_env import register_all_modules, setup_multi_processes
from .split_batch import split_batch
Expand All @@ -20,5 +20,5 @@
'AvoidCUDAOOM', 'all_reduce_dict', 'allreduce_grads', 'reduce_mean',
'sync_random_seed', 'ConfigType', 'InstanceList', 'MultiConfig',
'OptConfigType', 'OptInstanceList', 'OptMultiConfig', 'OptPixelList',
'PixelList', 'RangeType'
'PixelList', 'RangeType', 'add_dump_metric'
]
16 changes: 16 additions & 0 deletions mmdet/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,19 @@ def update(cfg, src_str, dst_str):

update(cfg.data, cfg.data_root, dst_root)
cfg.data_root = dst_root


def add_dump_metric(args, cfg):
dump_metric = dict(type='DumpResults', out_file_path=args.out)
if isinstance(cfg.test_evaluator, (list, tuple)):
cfg.test_evaluator = list(cfg.test_evaluator).append(dump_metric)
elif isinstance(cfg.test_evaluator, dict):
if isinstance(cfg.test_evaluator.metric, str):
cfg.test_evaluator = [cfg.test_evaluator, dump_metric]
elif isinstance(cfg.test_evaluator.metric, (list, tuple)):
cfg.test_evaluator.metric = list(
cfg.test_evaluator.metric).append(dump_metric)
else:
cfg.test_evaluator.metric = [
cfg.test_evaluator.metric, dump_metric
]
33 changes: 12 additions & 21 deletions tools/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
from mmengine.config import Config, DictAction
from mmengine.runner import Runner

from mmdet.engine.hooks.utils import trigger_visualization_hook
from mmdet.registry import RUNNERS
from mmdet.utils import register_all_modules
from mmdet.utils import add_dump_metric, register_all_modules


# TODO: support fuse_conv_bn and format_only
Expand All @@ -19,6 +20,10 @@ def parse_args():
parser.add_argument(
'--work-dir',
help='the directory to save the file containing evaluation metrics')
parser.add_argument(
'--out',
type=str,
help='dump predictions to a pickle file for offline evaluation')
parser.add_argument(
'--show', action='store_true', help='show prediction results')
parser.add_argument(
Expand Down Expand Up @@ -50,26 +55,6 @@ def parse_args():
return args


def trigger_visualization_hook(cfg, args):
default_hooks = cfg.default_hooks
if 'visualization' in default_hooks:
visualization_hook = default_hooks['visualization']
# Turn on visualization
visualization_hook['draw'] = True
if args.show:
visualization_hook['show'] = True
visualization_hook['wait_time'] = args.wait_time
if args.show_dir:
visualization_hook['test_out_dir'] = args.show_dir
else:
raise RuntimeError(
'VisualizationHook must be included in default_hooks.'
'refer to usage '
'"visualization=dict(type=\'VisualizationHook\')"')

return cfg


def main():
args = parse_args()

Expand Down Expand Up @@ -97,6 +82,12 @@ def main():
if args.show or args.show_dir:
cfg = trigger_visualization_hook(cfg, args)

# Dump predictions
if args.out is not None:
assert args.out.endswith(('.pkl', '.pickle')), \
'The dump file must be a pkl file.'
add_dump_metric(args, cfg)

# build the runner from config
if 'runner_type' not in cfg:
# build the default runner
Expand Down

0 comments on commit 7a649f9

Please sign in to comment.