Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for comparing multiple baseline images #161

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 98 additions & 51 deletions pytest_mpl/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

import io
import os
import glob
import json
import shutil
import hashlib
Expand Down Expand Up @@ -370,40 +371,64 @@ def _download_file(self, baseline, filename):
tmpfile.write(content)
return Path(filename)

def obtain_baseline_image(self, item, target_dir):
def obtain_baseline_images(self, item, target_dir):
"""
Copy the baseline image to our working directory.
Copy the baseline image(s) to our working directory.

If the image is remote it is downloaded, if it is local it is copied to
ensure it is kept in the event of a test failure.
"""
compare = self.get_compare(item)
multi = compare.kwargs.get('multi', False)
filename = self.generate_filename(item)
baseline_dir = self.get_baseline_directory(item)
baseline_remote = (isinstance(baseline_dir, str) and # noqa
baseline_dir.startswith(('http://', 'https://')))
if baseline_remote:
if multi:
pytest.fail('Multi-baseline testing only works with local baselines.',
pytrace=False)
# baseline_dir can be a list of URLs when remote, so we have to
# pass base and filename to download
baseline_image = self._download_file(baseline_dir, filename)
baseline_images = [self._download_file(baseline_dir, filename)]
elif not multi:
baseline_images = [(baseline_dir / filename).absolute()]
else:
baseline_image = (baseline_dir / filename).absolute()
dirname, ext = os.path.splitext(filename)
baseline_images = glob.glob(
os.path.join(baseline_dir.absolute(), dirname, '**', '*' + ext),
recursive=True)

return baseline_images

def obtain_baseline_image(self, item, target_dir):
"""
Backwards-Compatible wrapper for obtain_baseline_images.

return baseline_image
Always returns the first found baseline image.
"""
return self.obtain_baseline_images(item, target_dir)[0]

def generate_baseline_image(self, item, fig):
"""
Generate reference figures.
"""
compare = self.get_compare(item)
savefig_kwargs = compare.kwargs.get('savefig_kwargs', {})
multi = compare.kwargs.get('multi', False)

if not os.path.exists(self.generate_dir):
os.makedirs(self.generate_dir)

baseline_filename = self.generate_filename(item)
baseline_path = (self.generate_dir / baseline_filename).absolute()
fig.savefig(str(baseline_path), **savefig_kwargs)
if multi:
raw_name, ext = os.path.splitext(str(baseline_path))
if not os.path.exists(raw_name):
os.makedirs(raw_name)
baseline_path = os.path.join(raw_name, "generated" + ext)

fig.savefig(str(baseline_path), **savefig_kwargs)
close_mpl_figure(fig)

return baseline_path
Expand Down Expand Up @@ -440,13 +465,14 @@ def compare_image_to_baseline(self, item, fig, result_dir, summary=None):
tolerance = compare.kwargs.get('tolerance', 2)
savefig_kwargs = compare.kwargs.get('savefig_kwargs', {})

baseline_image_ref = self.obtain_baseline_image(item, result_dir)
baseline_image_refs = self.obtain_baseline_images(item, result_dir)
baseline_image_refs = [p for p in baseline_image_refs if os.path.exists(p)]

test_image = (result_dir / "result.png").absolute()
fig.savefig(str(test_image), **savefig_kwargs)
summary['result_image'] = test_image.relative_to(self.results_dir).as_posix()

if not os.path.exists(baseline_image_ref):
if len(baseline_image_refs) == 0:
summary['status'] = 'failed'
summary['image_status'] = 'missing'
error_message = ("Image file not found for comparison test in: \n\t"
Expand All @@ -457,49 +483,70 @@ def compare_image_to_baseline(self, item, fig, result_dir, summary=None):
summary['status_msg'] = error_message
return error_message

# setuptools may put the baseline images in non-accessible places,
# copy to our tmpdir to be sure to keep them in case of failure
baseline_image = (result_dir / "baseline.png").absolute()
shutil.copyfile(baseline_image_ref, baseline_image)
summary['baseline_image'] = baseline_image.relative_to(self.results_dir).as_posix()

# Compare image size ourselves since the Matplotlib
# exception is a bit cryptic in this case and doesn't show
# the filenames
expected_shape = imread(str(baseline_image)).shape[:2]
actual_shape = imread(str(test_image)).shape[:2]
if expected_shape != actual_shape:
summary['status'] = 'failed'
summary['image_status'] = 'diff'
error_message = SHAPE_MISMATCH_ERROR.format(expected_path=baseline_image,
expected_shape=expected_shape,
actual_path=test_image,
actual_shape=actual_shape)
summary['status_msg'] = error_message
return error_message

results = compare_images(str(baseline_image), str(test_image), tol=tolerance, in_decorator=True)
summary['tolerance'] = tolerance
if results is None:
summary['status'] = 'passed'
summary['image_status'] = 'match'
summary['status_msg'] = 'Image comparison passed.'
return None
else:
summary['status'] = 'failed'
summary['image_status'] = 'diff'
summary['rms'] = results['rms']
diff_image = (result_dir / 'result-failed-diff.png').absolute()
summary['diff_image'] = diff_image.relative_to(self.results_dir).as_posix()
template = ['Error: Image files did not match.',
'RMS Value: {rms}',
'Expected: \n {expected}',
'Actual: \n {actual}',
'Difference:\n {diff}',
'Tolerance: \n {tol}', ]
error_message = '\n '.join([line.format(**results) for line in template])
summary['status_msg'] = error_message
return error_message
cur_summ = {}
best_rms = float('inf')
all_msgs = ''
i = -1

for baseline_image_ref in baseline_image_refs:
# setuptools may put the baseline images in non-accessible places,
# copy to our tmpdir to be sure to keep them in case of failure
i += 1
baseline_file = f"baseline-{i}.png" if i else "baseline.png"
baseline_image = (result_dir / baseline_file).absolute()
rel_baseline_image = baseline_image.relative_to(self.results_dir).as_posix()
shutil.copyfile(baseline_image_ref, baseline_image)

# Compare image size ourselves since the Matplotlib
# exception is a bit cryptic in this case and doesn't show
# the filenames
expected_shape = imread(str(baseline_image)).shape[:2]
actual_shape = imread(str(test_image)).shape[:2]
if expected_shape != actual_shape:
best_rms = float('-inf')
cur_summ = {}
cur_summ['baseline_image'] = rel_baseline_image
cur_summ['status'] = 'failed'
cur_summ['image_status'] = 'diff'
error_message = SHAPE_MISMATCH_ERROR.format(expected_path=baseline_image,
expected_shape=expected_shape,
actual_path=test_image,
actual_shape=actual_shape)
cur_summ['status_msg'] = error_message
all_msgs += error_message + '\n\n'
continue

results = compare_images(str(baseline_image), str(test_image), tol=tolerance, in_decorator=True)
if results is None:
summary['baseline_image'] = rel_baseline_image
summary['tolerance'] = tolerance
summary['status'] = 'passed'
summary['image_status'] = 'match'
summary['status_msg'] = 'Image comparison passed.'
return None
else:
template = ['Error: Image files did not match.',
'RMS Value: {rms}',
'Expected: \n {expected}',
'Actual: \n {actual}',
'Difference:\n {diff}',
'Tolerance: \n {tol}', ]
error_message = '\n '.join([line.format(**results) for line in template])
all_msgs += error_message + '\n\n'
if results['rms'] < best_rms:
best_rms = results['rms']
cur_summ = {}
cur_summ['baseline_image'] = rel_baseline_image
cur_summ['tolerance'] = tolerance
cur_summ['status'] = 'failed'
cur_summ['image_status'] = 'diff'
cur_summ['rms'] = results['rms']
diff_image = (result_dir / 'result-failed-diff.png').absolute()
cur_summ['diff_image'] = diff_image.relative_to(self.results_dir).as_posix()
cur_summ['status_msg'] = error_message

summary.update(cur_summ)
return all_msgs.strip()

def load_hash_library(self, library_path):
with open(str(library_path)) as fp:
Expand Down