adopt torch.testing.assert_close #1031

Merged
merged 5 commits on Jun 27, 2021
45 changes: 44 additions & 1 deletion kornia/testing/__init__.py
@@ -1,14 +1,16 @@
"""
The testing package contains testing-specific utilities.
"""
import contextlib
import importlib
from abc import ABC, abstractmethod
from copy import deepcopy
from itertools import product
from typing import Any, Optional

import torch

__all__ = ['tensor_to_gradcheck_var', 'create_eye_batch', 'xla_is_available']
__all__ = ['tensor_to_gradcheck_var', 'create_eye_batch', 'xla_is_available', 'assert_close']


def xla_is_available() -> bool:
@@ -128,3 +130,44 @@ def _get_precision_by_name(
return tol_val

return tol_val_default


try:
# torch.testing.assert_close is only available for torch>=1.9
from torch.testing import assert_close as _assert_close
from torch.testing._core import _get_default_tolerance

def assert_close(
actual: torch.Tensor,
expected: torch.Tensor,
*,
rtol: Optional[float] = None,
atol: Optional[float] = None,
**kwargs: Any,
) -> None:
if rtol is None and atol is None:
with contextlib.suppress(Exception):
rtol, atol = _get_default_tolerance(actual, expected)

return _assert_close(actual, expected, rtol=rtol, atol=atol, check_stride=False, equal_nan=True, **kwargs)


except ImportError:
# Partial backport of torch.testing.assert_close for torch<1.9
from torch.testing import assert_allclose as _assert_allclose

class UsageError(Exception):
pass

def assert_close(
actual: torch.Tensor,
expected: torch.Tensor,
*,
rtol: Optional[float] = None,
atol: Optional[float] = None,
**kwargs: Any,
) -> None:
try:
return _assert_allclose(actual, expected, rtol=rtol, atol=atol, **kwargs)
except ValueError as error:
raise UsageError(str(error)) from error
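
For orientation, a minimal usage sketch of the shim above (not part of the diff; the tensors, values, and tolerances are illustrative):

import torch

from kornia.testing import assert_close

actual = torch.tensor([1.0, 2.0, 3.0])
expected = torch.tensor([1.0, 2.0, 3.0 + 1e-7])

# With rtol/atol omitted, the torch>=1.9 branch falls back to per-dtype
# defaults via torch.testing's internal _get_default_tolerance, so
# float32 rounding noise on this scale passes.
assert_close(actual, expected)

# Tolerances can still be pinned explicitly, as the updated tests do.
assert_close(actual, expected, rtol=1e-6, atol=1e-4)

# NaNs compare equal here: the wrapper passes equal_nan=True, and the
# torch<1.9 fallback, torch.testing.assert_allclose, also defaults to
# equal_nan=True.
assert_close(torch.tensor([float("nan")]), torch.tensor([float("nan")]))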
303 changes: 151 additions & 152 deletions test/augmentation/test_augmentation.py

Large diffs are not rendered by default.

108 changes: 54 additions & 54 deletions test/augmentation/test_augmentation_3d.py
@@ -4,7 +4,6 @@
import torch
import torch.nn as nn
from torch.autograd import gradcheck
from torch.testing import assert_allclose

import kornia
import kornia.testing as utils # test utils
@@ -16,6 +15,7 @@
RandomRotation3D,
RandomVerticalFlip3D,
)
from kornia.testing import assert_close


class TestRandomHorizontalFlip3D:
@@ -192,12 +192,12 @@ def test_random_vflip(self, device, dtype):
dtype=dtype,
) # 1 x 4 x 4

assert_allclose(f(input)[0], expected)
assert_allclose(f(input)[1], expected_transform)
assert_allclose(f1(input)[0], input)
assert_allclose(f1(input)[1], identity)
assert_allclose(f2(input), expected)
assert_allclose(f3(input), input)
assert_close(f(input)[0], expected)
assert_close(f(input)[1], expected_transform)
assert_close(f1(input)[0], input)
assert_close(f1(input)[1], identity)
assert_close(f2(input), expected)
assert_close(f3(input), input)

def test_batch_random_vflip(self, device):

@@ -225,10 +225,10 @@ def test_batch_random_vflip(self, device):
expected_transform = expected_transform.repeat(5, 1, 1) # 5 x 4 x 4
identity = identity.repeat(5, 1, 1) # 5 x 4 x 4

assert_allclose(f(input)[0], expected)
assert_allclose(f(input)[1], expected_transform)
assert_allclose(f1(input)[0], input)
assert_allclose(f1(input)[1], identity)
assert_close(f(input)[0], expected)
assert_close(f(input)[1], expected_transform)
assert_close(f1(input)[0], input)
assert_close(f1(input)[1], identity)

def test_same_on_batch(self, device):
f = RandomVerticalFlip3D(p=0.5, same_on_batch=True)
@@ -253,10 +253,10 @@ def test_sequential(self, device):

expected_transform_1 = expected_transform @ expected_transform

assert_allclose(f(input)[0], input)
assert_allclose(f(input)[1], expected_transform_1)
assert_allclose(f1(input)[0], input)
assert_allclose(f1(input)[1], expected_transform)
assert_close(f(input)[0], input)
assert_close(f(input)[1], expected_transform_1)
assert_close(f1(input)[0], input)
assert_close(f1(input)[1], expected_transform)

def test_gradcheck(self, device):
input = torch.rand((1, 3, 3)).to(device) # 1 x 3 x 3
@@ -319,12 +319,12 @@ def test_random_dflip(self, device, dtype):
dtype=dtype,
) # 4 x 4

assert_allclose(f(input)[0], expected)
assert_allclose(f(input)[1], expected_transform)
assert_allclose(f1(input)[0], input)
assert_allclose(f1(input)[1], identity)
assert_allclose(f2(input), expected)
assert_allclose(f3(input), input)
assert_close(f(input)[0], expected)
assert_close(f(input)[1], expected_transform)
assert_close(f1(input)[0], input)
assert_close(f1(input)[1], identity)
assert_close(f2(input), expected)
assert_close(f3(input), input)

def test_batch_random_dflip(self, device):

@@ -363,10 +363,10 @@ def test_batch_random_dflip(self, device):
expected_transform = expected_transform.repeat(5, 1, 1) # 5 x 4 x 4
identity = identity.repeat(5, 1, 1) # 5 x 4 x 4

assert_allclose(f(input)[0], expected)
assert_allclose(f(input)[1], expected_transform)
assert_allclose(f1(input)[0], input)
assert_allclose(f1(input)[1], identity)
assert_close(f(input)[0], expected)
assert_close(f(input)[1], expected_transform)
assert_close(f1(input)[0], input)
assert_close(f1(input)[1], identity)

def test_same_on_batch(self, device):
f = RandomDepthicalFlip3D(p=0.5, same_on_batch=True)
@@ -400,10 +400,10 @@ def test_sequential(self, device):

expected_transform_1 = expected_transform @ expected_transform

assert_allclose(f(input)[0], input)
assert_allclose(f(input)[1], expected_transform_1)
assert_allclose(f1(input)[0], input)
assert_allclose(f1(input)[1], expected_transform)
assert_close(f(input)[0], input)
assert_close(f(input)[1], expected_transform_1)
assert_close(f1(input)[0], input)
assert_close(f1(input)[1], expected_transform)

def test_gradcheck(self, device):
input = torch.rand((1, 3, 3)).to(device) # 1 x 3 x 3
@@ -488,11 +488,11 @@ def test_random_rotation(self, device, dtype):
)

out, mat = f(input)
assert_allclose(out, expected, rtol=1e-6, atol=1e-4)
assert_allclose(mat, expected_transform, rtol=1e-6, atol=1e-4)
assert_close(out, expected, rtol=1e-6, atol=1e-4)
assert_close(mat, expected_transform, rtol=1e-6, atol=1e-4)

torch.manual_seed(0) # for random reproducibility
assert_allclose(f1(input), expected, rtol=1e-6, atol=1e-4)
assert_close(f1(input), expected, rtol=1e-6, atol=1e-4)

def test_batch_random_rotation(self, device, dtype):

@@ -585,8 +585,8 @@ def test_batch_random_rotation(self, device, dtype):
input = input.repeat(2, 1, 1, 1, 1) # 5 x 4 x 4 x 3

out, mat = f(input)
assert_allclose(out, expected, rtol=1e-6, atol=1e-4)
assert_allclose(mat, expected_transform, rtol=1e-6, atol=1e-4)
assert_close(out, expected, rtol=1e-6, atol=1e-4)
assert_close(mat, expected_transform, rtol=1e-6, atol=1e-4)

def test_same_on_batch(self, device, dtype):
f = RandomRotation3D(degrees=40, same_on_batch=True)
@@ -671,9 +671,9 @@ def test_sequential(self, device, dtype):

out, mat = f(input)
_, mat_2 = f1(input)
assert_allclose(out, expected, rtol=1e-6, atol=1e-4)
assert_allclose(mat, expected_transform, rtol=1e-6, atol=1e-4)
assert_allclose(mat_2, expected_transform_2, rtol=1e-6, atol=1e-4)
assert_close(out, expected, rtol=1e-6, atol=1e-4)
assert_close(mat, expected_transform, rtol=1e-6, atol=1e-4)
assert_close(mat_2, expected_transform_2, rtol=1e-6, atol=1e-4)

def test_gradcheck(self, device):

@@ -759,7 +759,7 @@ def test_no_padding(self, batch_size, device, dtype):
dtype=dtype,
)

assert_allclose(out, expected, atol=1e-4, rtol=1e-4)
assert_close(out, expected, atol=1e-4, rtol=1e-4)

def test_same_on_batch(self, device, dtype):
f = RandomCrop3D(size=(2, 3, 4), padding=None, align_corners=True, p=1.0, same_on_batch=True)
@@ -795,7 +795,7 @@ def test_padding_batch(self, padding, device, dtype):
f = RandomCrop3D(size=(2, 3, 4), fill=10.0, padding=padding, align_corners=True, p=1.0)
out = f(inp)

assert_allclose(out, expected, atol=1e-4, rtol=1e-4)
assert_close(out, expected, atol=1e-4, rtol=1e-4)

def test_pad_if_needed(self, device, dtype):
torch.manual_seed(42)
@@ -815,7 +815,7 @@ def test_pad_if_needed(self, device, dtype):
rc = RandomCrop3D(size=(2, 3, 4), pad_if_needed=True, fill=9, align_corners=True, p=1.0)
out = rc(inp)

assert_allclose(out, expected, atol=1e-4, rtol=1e-4)
assert_close(out, expected, atol=1e-4, rtol=1e-4)

def test_gradcheck(self, device, dtype):
torch.manual_seed(0) # for random reproducibility
@@ -832,7 +832,7 @@ def test_jit(self, device, dtype):

actual = op_script(img)
expected = kornia.center_crop3d(img)
assert_allclose(actual, expected)
assert_close(actual, expected)

@pytest.mark.skip("Need to fix Union type")
def test_jit_trace(self, device, dtype):
@@ -850,7 +850,7 @@ def test_jit_trace(self, device, dtype):
# 3. Evaluate
actual = op_trace(img)
expected = op(img)
assert_allclose(actual, expected)
assert_close(actual, expected)


class TestCenterCrop3D:
@@ -905,12 +905,12 @@ def test_random_equalize(self, device, dtype):

identity = kornia.eye_like(4, expected)

assert_allclose(f(inputs3d)[0], expected, rtol=1e-4, atol=1e-4)
assert_allclose(f(inputs3d)[1], identity, rtol=1e-4, atol=1e-4)
assert_allclose(f1(inputs3d)[0], inputs3d, rtol=1e-4, atol=1e-4)
assert_allclose(f1(inputs3d)[1], identity, rtol=1e-4, atol=1e-4)
assert_allclose(f2(inputs3d), expected, rtol=1e-4, atol=1e-4)
assert_allclose(f3(inputs3d), inputs3d, rtol=1e-4, atol=1e-4)
assert_close(f(inputs3d)[0], expected, rtol=1e-4, atol=1e-4)
assert_close(f(inputs3d)[1], identity, rtol=1e-4, atol=1e-4)
assert_close(f1(inputs3d)[0], inputs3d, rtol=1e-4, atol=1e-4)
assert_close(f1(inputs3d)[1], identity, rtol=1e-4, atol=1e-4)
assert_close(f2(inputs3d), expected, rtol=1e-4, atol=1e-4)
assert_close(f3(inputs3d), inputs3d, rtol=1e-4, atol=1e-4)

def test_batch_random_equalize(self, device, dtype):
f = RandomEqualize3D(p=1.0, return_transform=True)
@@ -927,12 +927,12 @@ def test_batch_random_equalize(self, device, dtype):

identity = kornia.eye_like(4, expected) # 2 x 4 x 4

assert_allclose(f(inputs3d)[0], expected, rtol=1e-4, atol=1e-4)
assert_allclose(f(inputs3d)[1], identity, rtol=1e-4, atol=1e-4)
assert_allclose(f1(inputs3d)[0], inputs3d, rtol=1e-4, atol=1e-4)
assert_allclose(f1(inputs3d)[1], identity, rtol=1e-4, atol=1e-4)
assert_allclose(f2(inputs3d), expected, rtol=1e-4, atol=1e-4)
assert_allclose(f3(inputs3d), inputs3d, rtol=1e-4, atol=1e-4)
assert_close(f(inputs3d)[0], expected, rtol=1e-4, atol=1e-4)
assert_close(f(inputs3d)[1], identity, rtol=1e-4, atol=1e-4)
assert_close(f1(inputs3d)[0], inputs3d, rtol=1e-4, atol=1e-4)
assert_close(f1(inputs3d)[1], identity, rtol=1e-4, atol=1e-4)
assert_close(f2(inputs3d), expected, rtol=1e-4, atol=1e-4)
assert_close(f3(inputs3d), inputs3d, rtol=1e-4, atol=1e-4)

def test_same_on_batch(self, device, dtype):
f = RandomEqualize3D(p=0.5, same_on_batch=True)