Skip to content

Commit

Permalink
Fix a bug where SSIM fails on non-square images
Browse files Browse the repository at this point in the history
This PR also adds unit tests that probe non-square images.

PiperOrigin-RevId: 401296829
  • Loading branch information
jonbarron authored and Copybara-Service committed Oct 6, 2021
1 parent c2593d9 commit 4887a55
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
2 changes: 1 addition & 1 deletion dm_pix/_src/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def filt_fn_y(z):
def filt_fn_x(z):
z_flat = jnp.moveaxis(z, -2, -1).reshape((-1, z.shape[-2]))
z_filt_shape = ((z.shape[-4],) if z.ndim == 4 else
()) + (z.shape[-2], z.shape[-1], -1)
()) + (z.shape[-3], z.shape[-1], -1)
return jnp.moveaxis(filt_fn_vmap(z_flat).reshape(z_filt_shape), -1, -2)

# Apply the blur in both x and y.
Expand Down
10 changes: 7 additions & 3 deletions dm_pix/_src/metrics_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def test_ssim_golden(self):
"""Test that the SSIM implementation matches the Tensorflow version."""

key = jax.random.PRNGKey(0)
for shape in (2, 12, 12, 3), (12, 12, 3):
for shape in ((2, 12, 12, 3), (12, 12, 3), (2, 12, 15, 3), (17, 12, 3)):
for _ in range(4):
(max_val_key, img0_key, img1_key, filter_size_key, filter_sigma_key,
k1_key, k2_key, key) = jax.random.split(key, 8)
Expand Down Expand Up @@ -99,9 +99,13 @@ def test_ssim_golden(self):
))
ssim = ssim_fn(img0, img1)
if not return_map:
self.assertAllClose(ssim, ssim_gt)
self.assertAllClose(ssim, ssim_gt, atol=1e-5, rtol=1e-5)
else:
self.assertAllClose(np.mean(ssim, list(range(-3, 0))), ssim_gt)
self.assertAllClose(
np.mean(ssim, list(range(-3, 0))),
ssim_gt,
atol=1e-5,
rtol=1e-5)
self.assertLessEqual(np.max(ssim), 1.)
self.assertGreaterEqual(np.min(ssim), -1.)

Expand Down

0 comments on commit 4887a55

Please sign in to comment.