Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] Skip adding vis_backends when save_dir is not set #1289

Merged
merged 8 commits into from
Aug 3, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
78 changes: 48 additions & 30 deletions mmengine/visualization/visualizer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
import inspect
import os.path as osp
import warnings
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
Expand All @@ -15,14 +16,17 @@
from mmengine.dist import master_only
from mmengine.registry import VISBACKENDS, VISUALIZERS
from mmengine.structures import BaseDataElement
from mmengine.utils import ManagerMixin
from mmengine.utils import ManagerMixin, is_seq_of
from mmengine.visualization.utils import (check_type, check_type_and_length,
color_str2rgb, color_val_matplotlib,
convert_overlay_heatmap,
img_from_canvas, tensor2ndarray,
value2list, wait_continue)
from mmengine.visualization.vis_backend import BaseVisBackend

VisBackendsType = Union[List[Union[List, BaseDataElement]], BaseDataElement,
dict, None]


@VISUALIZERS.register_module()
class Visualizer(ManagerMixin):
Expand Down Expand Up @@ -153,42 +157,56 @@ def __init__(
self,
name='visualizer',
image: Optional[np.ndarray] = None,
vis_backends: Optional[List[Dict]] = None,
vis_backends: VisBackendsType = None,
save_dir: Optional[str] = None,
fig_save_cfg=dict(frameon=False),
fig_show_cfg=dict(frameon=False)
) -> None:
super().__init__(name)
self._dataset_meta: Optional[dict] = None
self._vis_backends: Union[Dict, Dict[str, 'BaseVisBackend']] = dict()

if vis_backends is not None:
assert len(vis_backends) > 0, 'empty list'
names = [vis_backend.get('name') for vis_backend in vis_backends]
if None in names:
if len(set(names)) > 1:
raise RuntimeError(
'If one of them has a name attribute, '
'all backends must use the name attribute')
else:
type_names = [
vis_backend['type'] for vis_backend in vis_backends
]
if len(set(type_names)) != len(type_names):
raise RuntimeError(
'The same vis backend cannot exist in '
'`vis_backend` config. '
'Please specify the name field.')

if None not in names and len(set(names)) != len(names):
raise RuntimeError('The name fields cannot be the same')

if save_dir is not None:
save_dir = osp.join(save_dir, 'vis_data')
for vis_backend in vis_backends:
name = vis_backend.pop('name', vis_backend['type'])
self._vis_backends: Dict[str, BaseVisBackend] = {}

if vis_backends is None:
vis_backends = []

if isinstance(vis_backends, (dict, BaseVisBackend)):
vis_backends = [vis_backends] # type: ignore

if not is_seq_of(vis_backends, (dict, BaseVisBackend)):
raise TypeError('vis_backends must be a list of dicts or a list '
'of BaseBackend instances')
if save_dir is not None:
save_dir = osp.join(save_dir, 'vis_data')

for vis_backend in vis_backends: # type: ignore
name = None
if isinstance(vis_backend, dict):
name = vis_backend.pop('name', None)
vis_backend.setdefault('save_dir', save_dir)
self._vis_backends[name] = VISBACKENDS.build(vis_backend)
# Using `get` could be a better chioce here to get the
# signature of `VisBackend__init__`. However, we use
# `Registry.build` to build the visbackend first to avoid the
# risk of scope error when using `Registry.get`
vis_backend = VISBACKENDS.build(vis_backend)

# If vis_backend requires save_dir but it's not provided
# (the value is None), then don't add this vis_backend to
# the visualizer.
zhouzaida marked this conversation as resolved.
Show resolved Hide resolved
save_dir_arg = inspect.signature(
vis_backend.__class__.__init__).parameters.get('save_dir')
if (save_dir_arg is not None
and (save_dir_arg.default is save_dir_arg.empty
and save_dir is None)):
warnings.warn(f'Failed to add {vis_backend.__class__}, '
'please provide the `save_dir` argument.')
continue

type_name = vis_backend.__class__.__name__
name = name or type_name

if name in self._vis_backends:
raise RuntimeError(f'vis_backend name {name} already exists')
self._vis_backends[name] = vis_backend # type: ignore

self.fig_save = None
self.fig_save_cfg = fig_save_cfg
Expand Down
42 changes: 26 additions & 16 deletions tests/test_visualizer/test_visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
@VISBACKENDS.register_module()
class MockVisBackend:

def __init__(self, save_dir: str):
def __init__(self, save_dir: str = 'none'):
self._save_dir = save_dir
self._close = False

Expand Down Expand Up @@ -78,21 +78,6 @@ def test_init(self):
assert isinstance(visualizer.get_backend('mock1'), MockVisBackend)
assert len(visualizer._vis_backends) == 2

# test empty list
with pytest.raises(AssertionError):
Visualizer(vis_backends=[], save_dir='temp_dir')

# test name
# If one of them has a name attribute, all backends must
# use the name attribute
with pytest.raises(RuntimeError):
Visualizer(
vis_backends=[
dict(type='MockVisBackend'),
dict(type='MockVisBackend', name='mock2')
],
save_dir='temp_dir')

# The name fields cannot be the same
with pytest.raises(RuntimeError):
Visualizer(
Expand Down Expand Up @@ -120,6 +105,31 @@ def test_init(self):
visualizer_any = Visualizer.get_instance(instance_name)
assert visualizer_any == visualizer

# local backend will not be built without `save_dir` argument
@VISBACKENDS.register_module()
class CustomLocalVisBackend:

def __init__(self, save_dir: str) -> None:
...

with pytest.warns(UserWarning):
visualizer = Visualizer.get_instance(
'test_save_dir',
vis_backends=[dict(type='CustomLocalVisBackend')])
assert not visualizer._vis_backends

VISBACKENDS.module_dict.pop('CustomLocalVisBackend')

visualizer = Visualizer.get_instance(
'test_save_dir',
vis_backends=dict(type='CustomLocalVisBackend', save_dir='tmp'))

visualizer = Visualizer.get_instance(
'test_save_dir', vis_backends=[CustomLocalVisBackend('tmp')])

visualizer = Visualizer.get_instance(
'test_save_dir', vis_backends=CustomLocalVisBackend('tmp'))

def test_set_image(self):
visualizer = Visualizer()
visualizer.set_image(self.image)
Expand Down