diff --git a/skan/test/test_vendored_correlate.py b/skan/test/test_vendored_correlate.py index ced96ac5..ec66b200 100644 --- a/skan/test/test_vendored_correlate.py +++ b/skan/test/test_vendored_correlate.py @@ -1,6 +1,9 @@ from time import time +from functools import reduce import numpy as np from skan.vendored import thresholding as th +from skimage.transform import integral_image +from scipy import ndimage as ndi class Timer: @@ -25,3 +28,16 @@ def test_fast_sauvola(): with Timer() as t1: th.threshold_sauvola(image, window_size=w1) assert t1.interval < 2 * t0.interval + + +def test_reference_correlation(): + ndim = 4 + shape = np.random.randint(0, 20, size=ndim) + x = np.random.random(shape) + kern = reduce(np.outer, [[-1, 0, 0, 1]] * ndim).reshape((4,) * ndim) + px = np.pad(x, (2, 1), mode='reflect') + pxi = integral_image(px) + mean_fast = th.correlate_nonzeros(pxi, kern / 3 ** ndim) + mean_ref = ndi.correlate(x, np.ones((3,) * ndim) / 3 ** ndim, + mode='mirror') + np.testing.assert_allclose(mean_fast, mean_ref)