Skip to content

Commit

Permalink
[Refactor] Use new API of matplotlib to handle blocking input in visu…
Browse files Browse the repository at this point in the history
…alization. (open-mmlab#568)

* [Refactor] Use new API of matplotlib to handle blocking input in
visualization.

* Modify unit tests
  • Loading branch information
mzr1996 committed Dec 2, 2021
1 parent b962805 commit 114ac6f
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 71 deletions.
88 changes: 52 additions & 36 deletions mmcls/core/visualization/image.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
from threading import Timer

import matplotlib
import matplotlib.pyplot as plt
import mmcv
import numpy as np
from matplotlib.backend_bases import CloseEvent
from matplotlib.blocking_input import BlockingInput

# A small value
EPS = 1e-2
Expand Down Expand Up @@ -41,7 +37,7 @@ class BaseFigureContextManager:
"""

def __init__(self, axis=False, fig_save_cfg={}, fig_show_cfg={}) -> None:
self.is_inline = 'inline' in matplotlib.get_backend()
self.is_inline = 'inline' in plt.get_backend()

# Because save and show need different figure size
# We set two figure and axes to handle save and show
Expand All @@ -52,7 +48,6 @@ def __init__(self, axis=False, fig_save_cfg={}, fig_show_cfg={}) -> None:
self.fig_show: plt.Figure = None
self.fig_show_cfg = fig_show_cfg
self.ax_show: plt.Axes = None
self.blocking_input: BlockingInput = None

self.axis = axis

Expand Down Expand Up @@ -83,8 +78,6 @@ def _initialize_fig_show(self):
fig.subplots_adjust(left=0, right=1, bottom=0, top=1)

self.fig_show, self.ax_show = fig, ax
self.blocking_input = BlockingInput(
self.fig_show, eventslist=('key_press_event', 'close_event'))

def __exit__(self, exc_type, exc_value, traceback):
if self.is_inline:
Expand All @@ -95,14 +88,6 @@ def __exit__(self, exc_type, exc_value, traceback):
plt.close(self.fig_save)
plt.close(self.fig_show)

try:
# In matplotlib>=3.4.0, with TkAgg, plt.close will destroy
# window after idle, need to update manually.
# Refers to https://github.com/matplotlib/matplotlib/blob/v3.4.x/lib/matplotlib/backends/_backend_tk.py#L470 # noqa: E501
self.fig_show.canvas.manager.window.update()
except AttributeError:
pass

def prepare(self):
if self.is_inline:
# if use inline backend, just rebuild the fig_save.
Expand All @@ -121,29 +106,59 @@ def prepare(self):
self.ax_show.cla()
self.ax_show.axis(self.axis)

def wait_continue(self, timeout=0):
def wait_continue(self, timeout=0, continue_key=' ') -> int:
"""Show the image and wait for the user's input.
This implementation refers to
https://github.com/matplotlib/matplotlib/blob/v3.5.x/lib/matplotlib/_blocking_input.py
Args:
timeout (int): If positive, continue after ``timeout`` seconds.
Defaults to 0.
continue_key (str): The key for users to continue. Defaults to
the space key.
Returns:
int: If zero, means time out or the user pressed ``continue_key``,
and if one, means the user closed the show figure.
""" # noqa: E501
if self.is_inline:
# If use inline backend, interactive input and timeout is no use.
return

# In matplotlib==3.4.x, with TkAgg, official timeout api of
# start_event_loop cannot work properly. Use a Timer to directly stop
# event loop.
if timeout > 0:
timer = Timer(timeout, self.fig_show.canvas.stop_event_loop)
timer.start()
if self.fig_show.canvas.manager:
# Ensure that the figure is shown
self.fig_show.show()

while True:
# Disable matplotlib default hotkey to close figure.
with plt.rc_context({'keymap.quit': []}):
key_press = self.blocking_input(n=1, timeout=0)

# Timeout or figure is closed or press space or press 'q'
if len(key_press) == 0 or isinstance(
key_press[0],
CloseEvent) or key_press[0].key in ['q', ' ']:
break
if timeout > 0:
timer.cancel()
# Connect the events to the handler function call.
event = None

def handler(ev):
# Set external event variable
nonlocal event
# Qt backend may fire two events at the same time,
# use a condition to avoid missing close event.
event = ev if not isinstance(event, CloseEvent) else event
self.fig_show.canvas.stop_event_loop()

cids = [
self.fig_show.canvas.mpl_connect(name, handler)
for name in ('key_press_event', 'close_event')
]

try:
self.fig_show.canvas.start_event_loop(timeout)
finally: # Run even on exception like ctrl-c.
# Disconnect the callbacks.
for cid in cids:
self.fig_show.canvas.mpl_disconnect(cid)

if isinstance(event, CloseEvent):
return 1 # Quit for close.
elif event is None or event.key == continue_key:
return 0 # Quit for continue.


class ImshowInfosContextManager(BaseFigureContextManager):
Expand Down Expand Up @@ -259,6 +274,7 @@ def put_img_infos(self,
if out_file is not None:
mmcv.imwrite(img_save, out_file)

ret = 0
if show and not self.is_inline:
# Reserve some space for the tip.
self.ax_show.set_title(win_name)
Expand All @@ -274,13 +290,13 @@ def put_img_infos(self,
# Refresh canvas, necessary for Qt5 backend.
self.fig_show.canvas.draw()

self.wait_continue(timeout=wait_time)
ret = self.wait_continue(timeout=wait_time)
elif (not show) and self.is_inline:
# If use inline backend, we use fig_save to show the image
# So we need to close it if users don't want to show.
plt.close(self.fig_save)

return img_save
return ret, img_save


def imshow_infos(img,
Expand Down Expand Up @@ -313,7 +329,7 @@ def imshow_infos(img,
np.ndarray: The image with extra infomations.
"""
with ImshowInfosContextManager(fig_size=fig_size) as manager:
img = manager.put_img_infos(
_, img = manager.put_img_infos(
img,
infos,
text_color=text_color,
Expand Down
62 changes: 28 additions & 34 deletions tests/test_utils/test_visualization.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
# Copyright (c) Open-MMLab. All rights reserved.
import os
import os.path as osp
import shutil
import tempfile
from unittest.mock import Mock, patch
from unittest.mock import MagicMock

import matplotlib.pyplot as plt
import mmcv
Expand Down Expand Up @@ -52,30 +51,8 @@ def test_imshow_infos():
assert image.shape == out_image.shape[:2]
os.remove(tmp_filename)

# test show=True
image = np.ones((10, 10, 3), np.uint8)
result = {'pred_label': 1, 'pred_class': 'bird', 'pred_score': 0.98}

def mock_blocking_input(self, n=1, timeout=30):
keypress = Mock()
keypress.key = ' '
out_path = osp.join(tmp_dir, '_'.join([str(n), str(timeout)]))
with open(out_path, 'w') as f:
f.write('test')
return [keypress]

with patch('matplotlib.blocking_input.BlockingInput.__call__',
mock_blocking_input):
vis.imshow_infos(image, result, show=True, wait_time=5)
assert osp.exists(osp.join(tmp_dir, '1_0'))

shutil.rmtree(tmp_dir)


@patch(
'matplotlib.blocking_input.BlockingInput.__call__',
return_value=[Mock(key=' ')])
def test_context_manager(mock_blocking_input):
def test_figure_context_manager():
# test show multiple images with the same figure.
images = [
np.random.randint(0, 255, (100, 100, 3), np.uint8) for _ in range(5)
Expand All @@ -85,22 +62,39 @@ def test_context_manager(mock_blocking_input):
with vis.ImshowInfosContextManager() as manager:
fig_show = manager.fig_show
fig_save = manager.fig_save

# Test time out
fig_show.canvas.start_event_loop = MagicMock()
fig_show.canvas.end_event_loop = MagicMock()
for image in images:
out_image = manager.put_img_infos(image, result, show=True)
ret, out_image = manager.put_img_infos(image, result, show=True)
assert ret == 0
assert image.shape == out_image.shape
assert not np.allclose(image, out_image)
assert fig_show is manager.fig_show
assert fig_save is manager.fig_save

# test rebuild figure if user destroyed it.
with vis.ImshowInfosContextManager() as manager:
fig_save = manager.fig_save
# Test continue key
fig_show.canvas.start_event_loop = (
lambda _: fig_show.canvas.key_press_event(' '))
for image in images:
fig_show = manager.fig_show
plt.close(manager.fig_show)

out_image = manager.put_img_infos(image, result, show=True)
ret, out_image = manager.put_img_infos(image, result, show=True)
assert ret == 0
assert image.shape == out_image.shape
assert not np.allclose(image, out_image)
assert not (fig_show is manager.fig_show)
assert fig_show is manager.fig_show
assert fig_save is manager.fig_save

# Test close figure manually
fig_show = manager.fig_show

def destroy(*_, **__):
fig_show.canvas.close_event()
plt.close(fig_show)

fig_show.canvas.start_event_loop = destroy
ret, out_image = manager.put_img_infos(images[0], result, show=True)
assert ret == 1
assert image.shape == out_image.shape
assert not np.allclose(image, out_image)
assert fig_save is manager.fig_save
6 changes: 5 additions & 1 deletion tools/visualizations/vis_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def main():

infos = dict(label=CLASSES[item['gt_label']])

manager.put_img_infos(
ret, _ = manager.put_img_infos(
image,
infos,
font_size=20,
Expand All @@ -248,6 +248,10 @@ def main():

progressBar.update()

if ret == 1:
print('\nMannualy interrupted.')
break


if __name__ == '__main__':
main()

0 comments on commit 114ac6f

Please sign in to comment.