-
-
Notifications
You must be signed in to change notification settings - Fork 422
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #523 from bethgelab/genattacktests
full rewrite of gen_attack_utils tests, fixes #522
- Loading branch information
Showing
2 changed files
with
30 additions
and
61 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,81 +1,42 @@ | ||
import eagerpy as ep | ||
import numpy as np | ||
import pytest | ||
from typing import Any | ||
|
||
from foolbox.attacks.gen_attack_utils import rescale_images | ||
|
||
|
||
def test_pytorch_numpy_compatibility() -> None: | ||
import numpy as np | ||
import torch | ||
def test_rescale_axis(request: Any, dummy: ep.Tensor) -> None: | ||
backend = request.config.option.backend | ||
if backend == "numpy": | ||
pytest.skip() | ||
|
||
x_np = np.random.uniform(0.0, 1.0, size=(16, 3, 64, 64)) | ||
x_torch = torch.from_numpy(x_np) | ||
|
||
x_np_ep = ep.astensor(x_np) | ||
x_torch_ep = ep.astensor(x_torch) | ||
|
||
x_up_np_ep = rescale_images(x_np_ep, (16, 3, 128, 128), 1) | ||
x_up_torch_ep = rescale_images(x_torch_ep, (16, 3, 128, 128), 1) | ||
|
||
x_up_np = x_up_np_ep.numpy() | ||
x_up_torch = x_up_torch_ep.numpy() | ||
|
||
assert np.allclose(x_up_np, x_up_torch) | ||
x = ep.from_numpy(dummy, x_np) | ||
x_ep = ep.astensor(x) | ||
x_up_ep = rescale_images(x_ep, (16, 3, 128, 128), 1) | ||
x_up = x_up_ep.numpy() | ||
|
||
assert np.allclose(x_up_np, x_up) | ||
|
||
def test_pytorch_numpy_compatibility_different_axis() -> None: | ||
import numpy as np | ||
import torch | ||
|
||
x_np = np.random.uniform(0.0, 1.0, size=(16, 64, 64, 3)) | ||
x_torch = torch.from_numpy(x_np) | ||
def test_rescale_axis_nhwc(request: Any, dummy: ep.Tensor) -> None: | ||
backend = request.config.option.backend | ||
if backend == "numpy": | ||
pytest.skip() | ||
|
||
x_np = np.random.uniform(0.0, 1.0, size=(16, 64, 64, 3)) | ||
x_np_ep = ep.astensor(x_np) | ||
x_torch_ep = ep.astensor(x_torch) | ||
|
||
x_up_np_ep = rescale_images(x_np_ep, (16, 128, 128, 3), -1) | ||
x_up_torch_ep = rescale_images(x_torch_ep, (16, 128, 128, 3), -1) | ||
|
||
x_up_np = x_up_np_ep.numpy() | ||
x_up_torch = x_up_torch_ep.numpy() | ||
|
||
assert np.allclose(x_up_np, x_up_torch) | ||
|
||
|
||
def test_pytorch_tensorflow_compatibility() -> None: | ||
import numpy as np | ||
import torch | ||
import tensorflow as tf | ||
|
||
x_np = np.random.uniform(0.0, 1.0, size=(16, 3, 64, 64)) | ||
x_torch = torch.from_numpy(x_np) | ||
x_tf = tf.convert_to_tensor(x_np) | ||
|
||
x_tf_ep = ep.astensor(x_tf) | ||
x_torch_ep = ep.astensor(x_torch) | ||
|
||
x_up_tf_ep = rescale_images(x_tf_ep, (16, 3, 128, 128), 1) | ||
x_up_torch_ep = rescale_images(x_torch_ep, (16, 3, 128, 128), 1) | ||
|
||
x_up_tf = x_up_tf_ep.numpy() | ||
x_up_torch = x_up_torch_ep.numpy() | ||
|
||
assert np.allclose(x_up_tf, x_up_torch) | ||
|
||
|
||
def test_jax_numpy_compatibility() -> None: | ||
import numpy as np | ||
import jax.numpy as jnp | ||
|
||
x_np = np.random.uniform(0.0, 1.0, size=(16, 3, 64, 64)) | ||
x_jax = jnp.array(x_np) | ||
|
||
x_np_ep = ep.astensor(x_np) | ||
x_jax_ep = ep.astensor(x_jax) | ||
|
||
x_up_np_ep = rescale_images(x_np_ep, (16, 3, 128, 128), 1) | ||
x_up_jax_ep = rescale_images(x_jax_ep, (16, 3, 128, 128), 1) | ||
|
||
x_up_np = x_up_np_ep.numpy() | ||
x_up_jax = x_up_jax_ep.numpy() | ||
x = ep.from_numpy(dummy, x_np) | ||
x_ep = ep.astensor(x) | ||
x_up_ep = rescale_images(x_ep, (16, 128, 128, 3), -1) | ||
x_up = x_up_ep.numpy() | ||
|
||
assert np.allclose(x_up_np, x_up_jax) | ||
assert np.allclose(x_up_np, x_up) |