From 1382cee02af13ee81b02047c04e4445811aa8538 Mon Sep 17 00:00:00 2001 From: Michael Droettboom Date: Mon, 25 Jan 2016 15:51:08 -0500 Subject: [PATCH] Merge pull request #5809 from anntzer/cleanup-generative-tests Support generative tests in @cleanup. --- lib/matplotlib/testing/decorators.py | 36 +++++++++++++++++++--------- lib/matplotlib/tests/test_axes.py | 2 +- 2 files changed, 26 insertions(+), 12 deletions(-) diff --git a/lib/matplotlib/testing/decorators.py b/lib/matplotlib/testing/decorators.py index ad463f3e5285..5cf1bf030448 100644 --- a/lib/matplotlib/testing/decorators.py +++ b/lib/matplotlib/testing/decorators.py @@ -5,6 +5,7 @@ import functools import gc +import inspect import os import sys import shutil @@ -129,18 +130,31 @@ def cleanup(style=None): # writing a decorator with optional arguments. def make_cleanup(func): - @functools.wraps(func) - def wrapped_function(*args, **kwargs): - original_units_registry = matplotlib.units.registry.copy() - original_settings = mpl.rcParams.copy() - matplotlib.style.use(style) - try: - func(*args, **kwargs) - finally: - _do_cleanup(original_units_registry, - original_settings) + if inspect.isgenerator(func): + @functools.wraps(func) + def wrapped_callable(*args, **kwargs): + original_units_registry = matplotlib.units.registry.copy() + original_settings = mpl.rcParams.copy() + matplotlib.style.use(style) + try: + for yielded in func(*args, **kwargs): + yield yielded + finally: + _do_cleanup(original_units_registry, + original_settings) + else: + @functools.wraps(func) + def wrapped_callable(*args, **kwargs): + original_units_registry = matplotlib.units.registry.copy() + original_settings = mpl.rcParams.copy() + matplotlib.style.use(style) + try: + func(*args, **kwargs) + finally: + _do_cleanup(original_units_registry, + original_settings) - return wrapped_function + return wrapped_callable if isinstance(style, six.string_types): return make_cleanup diff --git a/lib/matplotlib/tests/test_axes.py b/lib/matplotlib/tests/test_axes.py index 0c489a1b707b..3348a404d5d1 100644 --- a/lib/matplotlib/tests/test_axes.py +++ b/lib/matplotlib/tests/test_axes.py @@ -4282,7 +4282,7 @@ def _helper_y(ax): orig_xlim = ax_lst[0][1].get_xlim() ax.remove() ax.set_xlim(0, 5) - assert assert_array_equal(ax_lst[0][1].get_xlim(), orig_xlim) + assert_array_equal(ax_lst[0][1].get_xlim(), orig_xlim) @cleanup