Add support for classes with pytest 7 #164

Merged · 5 commits · Jun 14, 2022
4 changes: 4 additions & 0 deletions .github/workflows/test_and_publish.yml
@@ -33,6 +33,10 @@ jobs:
- linux: py38-test-mpl33
- linux: py39-test-mpl34
- linux: py310-test-mpl35
# Test different versions of pytest
- linux: py310-test-mpl35-pytestdev
- linux: py310-test-mpl35-pytest62
- linux: py38-test-mpl35-pytest54
coverage: 'codecov'

publish:
239 changes: 108 additions & 131 deletions pytest_mpl/plugin.py
@@ -33,13 +33,11 @@
import json
import shutil
import hashlib
import inspect
import logging
import tempfile
import warnings
import contextlib
from pathlib import Path
from functools import wraps
from urllib.request import urlopen

import pytest
@@ -83,6 +81,14 @@ def pathify(path):
return Path(path + ext)


def _pytest_pyfunc_call(obj, pyfuncitem):
testfunction = pyfuncitem.obj
funcargs = pyfuncitem.funcargs
testargs = {arg: funcargs[arg] for arg in pyfuncitem._fixtureinfo.argnames}
obj.result = testfunction(**testargs)
return True


def pytest_report_header(config, startdir):
import matplotlib
import matplotlib.ft2font
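
For orientation: the new `_pytest_pyfunc_call` helper reimplements pytest's default python-function call hook so that the figure a test returns can be stashed on the plugin object. A minimal self-contained sketch of the same pattern, with an invented plugin class name:

```python
import pytest


class CaptureReturn:
    """Hypothetical plugin that records each test's return value."""

    result = None

    @pytest.hookimpl(tryfirst=True)
    def pytest_pyfunc_call(self, pyfuncitem):
        testfunction = pyfuncitem.obj
        # Pass only the arguments pytest resolved as fixtures for this test.
        testargs = {arg: pyfuncitem.funcargs[arg]
                    for arg in pyfuncitem._fixtureinfo.argnames}
        self.result = testfunction(**testargs)
        return True  # a non-None result stops pytest's default call runner
```

Returning `True` matters because `pytest_pyfunc_call` is a firstresult hook: the first implementation that returns a non-None value handles the call, so pytest does not run the test a second time.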
@@ -211,13 +217,11 @@ def close_mpl_figure(fig):
plt.close(fig)


def get_marker(item, marker_name):
if hasattr(item, 'get_closest_marker'):
return item.get_closest_marker(marker_name)
else:
# "item.keywords.get" was deprecated in pytest 3.6
# See https://docs.pytest.org/en/latest/mark.html#updating-code
return item.keywords.get(marker_name)
Review comment on lines -218 to -220 (Member, Author):

I don't think we need to support this branch anymore.

def get_compare(item):
"""
Return the mpl_image_compare marker for the given item.
"""
return item.get_closest_marker("mpl_image_compare")


def path_is_not_none(apath):
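
Since `get_closest_marker` has been the supported lookup API from pytest 3.6 onwards, the old `get_marker` fallback (and the `self.get_compare` method removed below) collapse into this one module-level function. A hedged sketch of marker lookup, with an invented test and hook:

```python
import pytest


@pytest.mark.mpl_image_compare(tolerance=5)
def test_example():
    ...


def pytest_runtest_setup(item):
    marker = item.get_closest_marker("mpl_image_compare")
    if marker is not None:
        tolerance = marker.kwargs.get("tolerance", 2)  # per-test override
```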
@@ -278,20 +282,14 @@ def __init__(self,
logging.basicConfig(level=level)
self.logger = logging.getLogger('pytest-mpl')

def get_compare(self, item):
"""
Return the mpl_image_compare marker for the given item.
"""
return get_marker(item, 'mpl_image_compare')

def generate_filename(self, item):
"""
Given a pytest item, generate the figure filename.
"""
if self.config.getini('mpl-use-full-test-name'):
filename = self.generate_test_name(item) + '.png'
else:
compare = self.get_compare(item)
compare = get_compare(item)
# Find test name to use as plot name
filename = compare.kwargs.get('filename', None)
if filename is None:
@@ -304,7 +302,11 @@ def generate_test_name(self, item):
"""
Generate a unique name for the hash for this test.
"""
return f"{item.module.__name__}.{item.name}"
if item.cls is not None:
name = f"{item.module.__name__}.{item.cls.__name__}.{item.name}"
else:
name = f"{item.module.__name__}.{item.name}"
return name

def make_test_results_dir(self, item):
"""
@@ -319,7 +321,7 @@ def baseline_directory_specified(self, item):
"""
Returns `True` if a non-default baseline directory is specified.
"""
compare = self.get_compare(item)
compare = get_compare(item)
item_baseline_dir = compare.kwargs.get('baseline_dir', None)
return item_baseline_dir or self.baseline_dir or self.baseline_relative_dir

@@ -330,7 +332,7 @@ def get_baseline_directory(self, item):
Using the global and per-test configuration return the absolute
baseline dir, if the baseline file is local else return base URL.
"""
compare = self.get_compare(item)
compare = get_compare(item)
baseline_dir = compare.kwargs.get('baseline_dir', None)
if baseline_dir is None:
if self.baseline_dir is None:
@@ -394,7 +396,7 @@ def generate_baseline_image(self, item, fig):
"""
Generate reference figures.
"""
compare = self.get_compare(item)
compare = get_compare(item)
savefig_kwargs = compare.kwargs.get('savefig_kwargs', {})

if not os.path.exists(self.generate_dir):
@@ -413,7 +415,7 @@ def generate_image_hash(self, item, fig):
For a `matplotlib.figure.Figure`, returns the SHA256 hash as a hexadecimal
string.
"""
compare = self.get_compare(item)
compare = get_compare(item)
savefig_kwargs = compare.kwargs.get('savefig_kwargs', {})

imgdata = io.BytesIO()
@@ -436,7 +438,7 @@ def compare_image_to_baseline(self, item, fig, result_dir, summary=None):
if summary is None:
summary = {}

compare = self.get_compare(item)
compare = get_compare(item)
tolerance = compare.kwargs.get('tolerance', 2)
savefig_kwargs = compare.kwargs.get('savefig_kwargs', {})

@@ -510,7 +512,7 @@ def compare_image_to_hash_library(self, item, fig, result_dir, summary=None):
if summary is None:
summary = {}

compare = self.get_compare(item)
compare = get_compare(item)
savefig_kwargs = compare.kwargs.get('savefig_kwargs', {})

if not self.results_hash_library_name:
@@ -582,11 +584,13 @@ def compare_image_to_hash_library(self, item, fig, result_dir, summary=None):
return
return summary['status_msg']

def pytest_runtest_setup(self, item): # noqa
@pytest.hookimpl(hookwrapper=True)
def pytest_runtest_call(self, item): # noqa

compare = self.get_compare(item)
compare = get_compare(item)

if compare is None:
yield
return

import matplotlib.pyplot as plt
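
The setup-time function wrapping is gone: a hookwrapper around `pytest_runtest_call` runs its pre-`yield` code before the test body and its post-`yield` code after the test has returned, which works identically for functions and methods. A minimal sketch of the pattern (class name invented):

```python
import pytest


class WrapperSketch:
    @pytest.hookimpl(hookwrapper=True)
    def pytest_runtest_call(self, item):
        # code here runs before the test body executes
        yield  # the test (and any inner hook implementations) run here
        # code here runs after the test body has returned
```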
@@ -600,95 +604,82 @@ def pytest_runtest_setup(self, item): # noqa
remove_text = compare.kwargs.get('remove_text', False)
backend = compare.kwargs.get('backend', 'agg')

original = item.function

@wraps(item.function)
def item_function_wrapper(*args, **kwargs):

with plt.style.context(style, after_reset=True), switch_backend(backend):

# Run test and get figure object
if inspect.ismethod(original): # method
# In some cases, for example if setup_method is used,
# original appears to belong to an instance of the test
# class that is not the same as args[0], and args[0] is the
# one that has the correct attributes set up from setup_method
# so we ignore original.__self__ and use args[0] instead.
fig = original.__func__(*args, **kwargs)
else: # function
fig = original(*args, **kwargs)

if remove_text:
remove_ticks_and_titles(fig)

test_name = self.generate_test_name(item)
result_dir = self.make_test_results_dir(item)

summary = {
'status': None,
'image_status': None,
'hash_status': None,
'status_msg': None,
'baseline_image': None,
'diff_image': None,
'rms': None,
'tolerance': None,
'result_image': None,
'baseline_hash': None,
'result_hash': None,
}

# What we do now depends on whether we are generating the
# reference images or simply running the test.
if self.generate_dir is not None:
summary['status'] = 'skipped'
summary['image_status'] = 'generated'
summary['status_msg'] = 'Skipped test, since generating image.'
generate_image = self.generate_baseline_image(item, fig)
if self.results_always: # Make baseline image available in HTML
result_image = (result_dir / "baseline.png").absolute()
shutil.copy(generate_image, result_image)
summary['baseline_image'] = \
result_image.relative_to(self.results_dir).as_posix()

if self.generate_hash_library is not None:
summary['hash_status'] = 'generated'
image_hash = self.generate_image_hash(item, fig)
self._generated_hash_library[test_name] = image_hash
summary['baseline_hash'] = image_hash

# Only test figures if not generating images
if self.generate_dir is None:
# Compare to hash library
if self.hash_library or compare.kwargs.get('hash_library', None):
msg = self.compare_image_to_hash_library(item, fig, result_dir, summary=summary)

# Compare against a baseline if specified
else:
msg = self.compare_image_to_baseline(item, fig, result_dir, summary=summary)

close_mpl_figure(fig)

if msg is None:
if not self.results_always:
shutil.rmtree(result_dir)
for image_type in ['baseline_image', 'diff_image', 'result_image']:
summary[image_type] = None # image no longer exists
else:
self._test_results[test_name] = summary
pytest.fail(msg, pytrace=False)
with plt.style.context(style, after_reset=True), switch_backend(backend):

# Run test and get figure object
yield
fig = self.result

if remove_text:
remove_ticks_and_titles(fig)

test_name = self.generate_test_name(item)
result_dir = self.make_test_results_dir(item)

summary = {
'status': None,
'image_status': None,
'hash_status': None,
'status_msg': None,
'baseline_image': None,
'diff_image': None,
'rms': None,
'tolerance': None,
'result_image': None,
'baseline_hash': None,
'result_hash': None,
}

# What we do now depends on whether we are generating the
# reference images or simply running the test.
if self.generate_dir is not None:
summary['status'] = 'skipped'
summary['image_status'] = 'generated'
summary['status_msg'] = 'Skipped test, since generating image.'
generate_image = self.generate_baseline_image(item, fig)
if self.results_always: # Make baseline image available in HTML
result_image = (result_dir / "baseline.png").absolute()
shutil.copy(generate_image, result_image)
summary['baseline_image'] = \
result_image.relative_to(self.results_dir).as_posix()

if self.generate_hash_library is not None:
summary['hash_status'] = 'generated'
image_hash = self.generate_image_hash(item, fig)
self._generated_hash_library[test_name] = image_hash
summary['baseline_hash'] = image_hash

# Only test figures if not generating images
if self.generate_dir is None:
# Compare to hash library
if self.hash_library or compare.kwargs.get('hash_library', None):
msg = self.compare_image_to_hash_library(item, fig, result_dir, summary=summary)

# Compare against a baseline if specified
else:
msg = self.compare_image_to_baseline(item, fig, result_dir, summary=summary)

close_mpl_figure(fig)

self._test_results[test_name] = summary
if msg is None:
if not self.results_always:
shutil.rmtree(result_dir)
for image_type in ['baseline_image', 'diff_image', 'result_image']:
summary[image_type] = None # image no longer exists
else:
self._test_results[test_name] = summary
pytest.fail(msg, pytrace=False)

if summary['status'] == 'skipped':
pytest.skip(summary['status_msg'])
close_mpl_figure(fig)

if item.cls is not None:
setattr(item.cls, item.function.__name__, item_function_wrapper)
else:
item.obj = item_function_wrapper
self._test_results[test_name] = summary

if summary['status'] == 'skipped':
pytest.skip(summary['status_msg'])

@pytest.hookimpl(tryfirst=True)
def pytest_pyfunc_call(self, pyfuncitem):
return _pytest_pyfunc_call(self, pyfuncitem)

def generate_summary_json(self):
json_file = self.results_dir / 'results.json'
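
The net effect is that `mpl_image_compare` now works on methods of test classes, including ones relying on `setup_method`. A hedged sketch of the kind of test this PR enables (all names invented):

```python
import matplotlib.pyplot as plt
import pytest


class TestPlots:
    def setup_method(self, method):
        self.data = [1, 2, 3]

    @pytest.mark.mpl_image_compare
    def test_line(self):
        fig, ax = plt.subplots()
        ax.plot(self.data)
        return fig  # captured by the plugin via pytest_pyfunc_call
```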
@@ -742,26 +733,12 @@ class FigureCloser:
def __init__(self, config):
self.config = config

def pytest_runtest_setup(self, item):

compare = get_marker(item, 'mpl_image_compare')

if compare is None:
return

original = item.function

@wraps(item.function)
def item_function_wrapper(*args, **kwargs):

if inspect.ismethod(original): # method
fig = original.__func__(*args, **kwargs)
else: # function
fig = original(*args, **kwargs)

close_mpl_figure(fig)
@pytest.hookimpl(hookwrapper=True)
def pytest_runtest_call(self, item):
yield
if get_compare(item) is not None:
close_mpl_figure(self.result)

if item.cls is not None:
setattr(item.cls, item.function.__name__, item_function_wrapper)
else:
item.obj = item_function_wrapper
@pytest.hookimpl(tryfirst=True)
def pytest_pyfunc_call(self, pyfuncitem):
return _pytest_pyfunc_call(self, pyfuncitem)
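
A closing note on how the two classes coexist: `pytest_pyfunc_call` is a firstresult hook, so the first tryfirst implementation to return `True` handles the call and pytest's default runner is skipped. As far as the rest of the plugin (not shown in this diff) goes, `ImageComparison` and `FigureCloser` are registered in mutually exclusive configurations (with and without `--mpl`, respectively), so exactly one of them captures `self.result`; `FigureCloser` simply closes the returned figure so unchecked tests do not leak open figures.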