Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] Skip adding vis_backends when save_dir is not set #1289

Merged
merged 8 commits into from
Aug 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
74 changes: 44 additions & 30 deletions mmengine/visualization/visualizer.py
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
import inspect
import os.path as osp
import warnings
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
Expand All @@ -15,14 +16,17 @@
from mmengine.dist import master_only
from mmengine.registry import VISBACKENDS, VISUALIZERS
from mmengine.structures import BaseDataElement
from mmengine.utils import ManagerMixin
from mmengine.utils import ManagerMixin, is_seq_of
from mmengine.visualization.utils import (check_type, check_type_and_length,
color_str2rgb, color_val_matplotlib,
convert_overlay_heatmap,
img_from_canvas, tensor2ndarray,
value2list, wait_continue)
from mmengine.visualization.vis_backend import BaseVisBackend

VisBackendsType = Union[List[Union[List, BaseDataElement]], BaseDataElement,
dict, None]


@VISUALIZERS.register_module()
class Visualizer(ManagerMixin):
Expand Down Expand Up @@ -153,42 +157,52 @@ def __init__(
self,
name='visualizer',
image: Optional[np.ndarray] = None,
vis_backends: Optional[List[Dict]] = None,
vis_backends: VisBackendsType = None,
save_dir: Optional[str] = None,
fig_save_cfg=dict(frameon=False),
fig_show_cfg=dict(frameon=False)
) -> None:
super().__init__(name)
self._dataset_meta: Optional[dict] = None
self._vis_backends: Union[Dict, Dict[str, 'BaseVisBackend']] = dict()

if vis_backends is not None:
assert len(vis_backends) > 0, 'empty list'
names = [vis_backend.get('name') for vis_backend in vis_backends]
if None in names:
if len(set(names)) > 1:
raise RuntimeError(
'If one of them has a name attribute, '
'all backends must use the name attribute')
else:
type_names = [
vis_backend['type'] for vis_backend in vis_backends
]
if len(set(type_names)) != len(type_names):
raise RuntimeError(
'The same vis backend cannot exist in '
'`vis_backend` config. '
'Please specify the name field.')

if None not in names and len(set(names)) != len(names):
raise RuntimeError('The name fields cannot be the same')

if save_dir is not None:
save_dir = osp.join(save_dir, 'vis_data')
for vis_backend in vis_backends:
name = vis_backend.pop('name', vis_backend['type'])
self._vis_backends: Dict[str, BaseVisBackend] = {}

if vis_backends is None:
vis_backends = []

if isinstance(vis_backends, (dict, BaseVisBackend)):
vis_backends = [vis_backends] # type: ignore

if not is_seq_of(vis_backends, (dict, BaseVisBackend)):
raise TypeError('vis_backends must be a list of dicts or a list '
'of BaseBackend instances')
if save_dir is not None:
save_dir = osp.join(save_dir, 'vis_data')

for vis_backend in vis_backends: # type: ignore
name = None
if isinstance(vis_backend, dict):
name = vis_backend.pop('name', None)
vis_backend.setdefault('save_dir', save_dir)
self._vis_backends[name] = VISBACKENDS.build(vis_backend)
vis_backend = VISBACKENDS.build(vis_backend)

# If vis_backend requires `save_dir` (with no default value)
# but is initialized with None, then don't add this
# vis_backend to the visualizer.
save_dir_arg = inspect.signature(
vis_backend.__class__.__init__).parameters.get('save_dir')
if (save_dir_arg is not None
and save_dir_arg.default is save_dir_arg.empty
and getattr(vis_backend, '_save_dir') is None):
warnings.warn(f'Failed to add {vis_backend.__class__}, '
'please provide the `save_dir` argument.')
continue

type_name = vis_backend.__class__.__name__
name = name or type_name

if name in self._vis_backends:
raise RuntimeError(f'vis_backend name {name} already exists')
self._vis_backends[name] = vis_backend # type: ignore

self.fig_save = None
self.fig_save_cfg = fig_save_cfg
Expand Down
42 changes: 26 additions & 16 deletions tests/test_visualizer/test_visualizer.py
Expand Up @@ -17,7 +17,7 @@
@VISBACKENDS.register_module()
class MockVisBackend:

def __init__(self, save_dir: str):
def __init__(self, save_dir: str = 'none'):
self._save_dir = save_dir
self._close = False

Expand Down Expand Up @@ -78,21 +78,6 @@ def test_init(self):
assert isinstance(visualizer.get_backend('mock1'), MockVisBackend)
assert len(visualizer._vis_backends) == 2

# test empty list
with pytest.raises(AssertionError):
Visualizer(vis_backends=[], save_dir='temp_dir')

# test name
# If one of them has a name attribute, all backends must
# use the name attribute
with pytest.raises(RuntimeError):
Visualizer(
vis_backends=[
dict(type='MockVisBackend'),
dict(type='MockVisBackend', name='mock2')
],
save_dir='temp_dir')

# The name fields cannot be the same
with pytest.raises(RuntimeError):
Visualizer(
Expand Down Expand Up @@ -120,6 +105,31 @@ def test_init(self):
visualizer_any = Visualizer.get_instance(instance_name)
assert visualizer_any == visualizer

# local backend will not be built without `save_dir` argument
@VISBACKENDS.register_module()
class CustomLocalVisBackend:

def __init__(self, save_dir: str) -> None:
self._save_dir = save_dir

with pytest.warns(UserWarning):
visualizer = Visualizer.get_instance(
'test_save_dir',
vis_backends=[dict(type='CustomLocalVisBackend')])
assert not visualizer._vis_backends

VISBACKENDS.module_dict.pop('CustomLocalVisBackend')

visualizer = Visualizer.get_instance(
'test_save_dir',
vis_backends=dict(type='CustomLocalVisBackend', save_dir='tmp'))

visualizer = Visualizer.get_instance(
'test_save_dir', vis_backends=[CustomLocalVisBackend('tmp')])

visualizer = Visualizer.get_instance(
'test_save_dir', vis_backends=CustomLocalVisBackend('tmp'))

def test_set_image(self):
visualizer = Visualizer()
visualizer.set_image(self.image)
Expand Down