Skip to content

Commit

Permalink
TST: Make ignore_matplotlibrc a pytest fixture
Browse files Browse the repository at this point in the history
This allows it to be composed with pytest marks (e.g.
`pytest.mark.skipif`). See astropy#9991.
  • Loading branch information
lpsinger committed Mar 17, 2020
1 parent 24dc50a commit 72c28ae
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 49 deletions.
11 changes: 11 additions & 0 deletions astropy/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
except ImportError:
PYTEST_HEADER_MODULES = {}

import pytest

from astropy.tests.helper import enable_deprecations_as_exceptions

try:
Expand All @@ -36,6 +38,15 @@
matplotlibrc_cache = {}


@pytest.fixture
def ignore_matplotlibrc():
# This is a fixture for tests that use matplotlib but not pytest-mpl
# (which already handles rcParams)
from matplotlib import pyplot as plt
with plt.style.context({}, after_reset=True):
yield


def pytest_configure(config):
builtins._pytest_running = True
# do not assign to matplotlibrc_cache in function scope
Expand Down
10 changes: 0 additions & 10 deletions astropy/tests/image_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,3 @@
ROOT = "http://{server}/testing/astropy/2019-08-02T11:38:58.288466/{mpl_version}/"
IMAGE_REFERENCE_DIR = (ROOT.format(server='data.astropy.org', mpl_version=MPL_VERSION[:3] + '.x') + ',' +
ROOT.format(server='www.astropy.org/astropy-data', mpl_version=MPL_VERSION[:3] + '.x'))


def ignore_matplotlibrc(func):
# This is a decorator for tests that use matplotlib but not pytest-mpl
# (which already handles rcParams)
@wraps(func)
def wrapper(*args, **kwargs):
with plt.style.context({}, after_reset=True):
return func(*args, **kwargs)
return wrapper
16 changes: 5 additions & 11 deletions astropy/visualization/wcsaxes/tests/test_coordinate_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

from astropy.visualization.wcsaxes.core import WCSAxes
from astropy import units as u
from astropy.tests.image_tests import ignore_matplotlibrc

ROOT = os.path.join(os.path.dirname(__file__))
MSX_HEADER = fits.Header.fromtextfile(os.path.join(ROOT, 'data', 'msx_header'))
Expand All @@ -20,8 +19,7 @@ def teardown_function(function):
plt.close('all')


@ignore_matplotlibrc
def test_getaxislabel():
def test_getaxislabel(ignore_matplotlibrc):

fig = plt.figure()
ax = WCSAxes(fig, [0.1, 0.1, 0.8, 0.8], aspect='equal')
Expand Down Expand Up @@ -53,22 +51,19 @@ def assert_label_draw(ax, x_label, y_label):
assert pos2.call_count == y_label


@ignore_matplotlibrc
def test_label_visibility_rules_default(ax):
def test_label_visibility_rules_default(ignore_matplotlibrc, ax):
assert_label_draw(ax, True, True)


@ignore_matplotlibrc
def test_label_visibility_rules_label(ax):
def test_label_visibility_rules_label(ignore_matplotlibrc, ax):

ax.coords[0].set_ticklabel_visible(False)
ax.coords[1].set_ticks(values=[-9999]*u.one)

assert_label_draw(ax, False, False)


@ignore_matplotlibrc
def test_label_visibility_rules_ticks(ax):
def test_label_visibility_rules_ticks(ignore_matplotlibrc, ax):

ax.coords[0].set_axislabel_visibility_rule('ticks')
ax.coords[1].set_axislabel_visibility_rule('ticks')
Expand All @@ -79,8 +74,7 @@ def test_label_visibility_rules_ticks(ax):
assert_label_draw(ax, True, False)


@ignore_matplotlibrc
def test_label_visibility_rules_always(ax):
def test_label_visibility_rules_always(ignore_matplotlibrc, ax):

ax.coords[0].set_axislabel_visibility_rule('always')
ax.coords[1].set_axislabel_visibility_rule('always')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import astropy.units as u
from astropy.coordinates import FK5, SkyCoord
from astropy.io import fits
from astropy.tests.image_tests import ignore_matplotlibrc
from astropy.time import Time
from astropy.utils.data import get_pkg_data_filename
from astropy.visualization.wcsaxes.core import WCSAxes
Expand All @@ -23,8 +22,7 @@ class TestDisplayWorldCoordinate(BaseImageTests):
def teardown_method(self, method):
plt.close('all')

@ignore_matplotlibrc
def test_overlay_coords(self, tmpdir):
def test_overlay_coords(self, ignore_matplotlibrc, tmpdir):
wcs = WCS(self.msx_header)

fig = plt.figure(figsize=(4, 4))
Expand Down Expand Up @@ -104,8 +102,7 @@ def test_overlay_coords(self, tmpdir):

assert string_world5 == '267.652\xb0 -28\xb046\'23" (world, overlay 3)'

@ignore_matplotlibrc
def test_cube_coords(self, tmpdir):
def test_cube_coords(self, ignore_matplotlibrc, tmpdir):
wcs = WCS(self.cube_header)

fig = plt.figure(figsize=(4, 4))
Expand All @@ -128,8 +125,7 @@ def test_cube_coords(self, tmpdir):
string_pixel = ax._display_world_coords(0.523412, 0.523412)
assert string_pixel == "0.523412 0.523412 (pixel)"

@ignore_matplotlibrc
def test_cube_coords_uncorr_slicing(self, tmpdir):
def test_cube_coords_uncorr_slicing(self, ignore_matplotlibrc, tmpdir):

# Regression test for a bug that occurred with coordinate formatting if
# some dimensions were uncorrelated and sliced out.
Expand Down
31 changes: 10 additions & 21 deletions astropy/visualization/wcsaxes/tests/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from astropy.io import fits
from astropy.coordinates import SkyCoord
from astropy.tests.helper import catch_warnings
from astropy.tests.image_tests import ignore_matplotlibrc

from astropy.visualization.wcsaxes.core import WCSAxes
from astropy.visualization.wcsaxes.frame import RectangularFrame, RectangularFrame1D
Expand All @@ -31,17 +30,15 @@ def teardown_function(function):
plt.close('all')


@ignore_matplotlibrc
def test_grid_regression():
def test_grid_regression(ignore_matplotlibrc):
# Regression test for a bug that meant that if the rc parameter
# axes.grid was set to True, WCSAxes would crash upon initalization.
plt.rc('axes', grid=True)
fig = plt.figure(figsize=(3, 3))
WCSAxes(fig, [0.1, 0.1, 0.8, 0.8])


@ignore_matplotlibrc
def test_format_coord_regression(tmpdir):
def test_format_coord_regression(ignore_matplotlibrc, tmpdir):
# Regression test for a bug that meant that if format_coord was called by
# Matplotlib before the axes were drawn, an error occurred.
fig = plt.figure(figsize=(3, 3))
Expand Down Expand Up @@ -74,9 +71,8 @@ def test_format_coord_regression(tmpdir):
""", sep='\n')


@ignore_matplotlibrc
@pytest.mark.parametrize('grid_type', ['lines', 'contours'])
def test_no_numpy_warnings(tmpdir, grid_type):
def test_no_numpy_warnings(ignore_matplotlibrc, tmpdir, grid_type):

# Make sure that no warnings are raised if some pixels are outside WCS
# (since this is normal)
Expand All @@ -95,8 +91,7 @@ def test_no_numpy_warnings(tmpdir, grid_type):
assert len(ws) == 0


@ignore_matplotlibrc
def test_invalid_frame_overlay():
def test_invalid_frame_overlay(ignore_matplotlibrc):

# Make sure a nice error is returned if a frame doesn't exist
ax = plt.subplot(1, 1, 1, projection=WCS(TARGET_HEADER))
Expand All @@ -109,8 +104,7 @@ def test_invalid_frame_overlay():
assert exc.value.args[0] == 'Unknown frame: banana'


@ignore_matplotlibrc
def test_plot_coord_transform():
def test_plot_coord_transform(ignore_matplotlibrc):

twoMASS_k_header = os.path.join(DATA, '2MASS_k_header')
twoMASS_k_header = fits.Header.fromtextfile(twoMASS_k_header)
Expand All @@ -126,8 +120,7 @@ def test_plot_coord_transform():
ax.plot_coord(c, 'o', transform=ax.get_transform('galactic'))


@ignore_matplotlibrc
def test_set_label_properties():
def test_set_label_properties(ignore_matplotlibrc):

# Regression test to make sure that arguments passed to
# set_xlabel/set_ylabel are passed to the underlying coordinate helpers
Expand Down Expand Up @@ -172,8 +165,7 @@ def test_set_label_properties():
""", sep='\n')


@ignore_matplotlibrc
def test_slicing_warnings(tmpdir):
def test_slicing_warnings(ignore_matplotlibrc, tmpdir):

# Regression test to make sure that no warnings are emitted by the tick
# locator for the sliced axis when slicing a cube.
Expand Down Expand Up @@ -325,8 +317,7 @@ def test_contour_empty():
ax.contour(np.zeros((4, 4)), transform=ax.get_transform('world'))


@ignore_matplotlibrc
def test_iterate_coords(tmpdir):
def test_iterate_coords(ignore_matplotlibrc, tmpdir):

# Regression test for a bug that caused ax.coords to return too few axes

Expand All @@ -342,8 +333,7 @@ def test_iterate_coords(tmpdir):
x, y, z = ax.coords


@ignore_matplotlibrc
def test_invalid_slices_errors():
def test_invalid_slices_errors(ignore_matplotlibrc):

# Make sure that users get a clear message when specifying a WCS with
# >2 dimensions without giving the 'slices' argument, or if the 'slices'
Expand Down Expand Up @@ -413,8 +403,7 @@ def test_invalid_slices_errors():
""".strip()


@ignore_matplotlibrc
def test_repr():
def test_repr(ignore_matplotlibrc):

# Unit test to make sure __repr__ looks as expected

Expand Down

0 comments on commit 72c28ae

Please sign in to comment.