Skip to content

Commit

Permalink
Rename flag
Browse files Browse the repository at this point in the history
  • Loading branch information
zimmerrol committed Jul 6, 2021
1 parent a73542d commit de48aca
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 15 deletions.
20 changes: 10 additions & 10 deletions tests/conftest.py
Expand Up @@ -9,11 +9,11 @@
ModelAndData = Tuple[fbn.Model, ep.Tensor, ep.Tensor]
CallableModelAndDescription = NamedTuple(
"CallableModelAndDescription",
[("model_fn", Callable[..., ModelAndData]), ("real", bool), ("small", bool)],
[("model_fn", Callable[..., ModelAndData]), ("real", bool), ("small_input", bool)],
)
ModelDescriptionAndData = NamedTuple(
"ModelDescriptionAndData",
[("model_and_data", ModelAndData), ("real", bool), ("small", bool)],
[("model_and_data", ModelAndData), ("real", bool), ("small_input", bool)],
)

models: Dict[str, CallableModelAndDescription] = {}
Expand All @@ -35,7 +35,7 @@ def dummy(request: Any) -> ep.Tensor:


def register(
backend: str, *, real: bool = False, small: bool = False, attack: bool = True
backend: str, *, real: bool = False, small_input: bool = False, attack: bool = True
) -> Callable[[Callable], Callable]:
def decorator(f: Callable[[Any], ModelAndData]) -> Callable[[Any], ModelAndData]:
@functools.wraps(f)
Expand All @@ -48,7 +48,7 @@ def model(request: Any) -> ModelAndData:
global real_models

models[model.__name__] = CallableModelAndDescription(
model_fn=model, real=real, small=small
model_fn=model, real=real, small_input=small_input
)
if attack:
models_for_attacks.append(model.__name__)
Expand Down Expand Up @@ -118,7 +118,7 @@ def pytorch_simple_model_object(request: Any) -> ModelAndData:
return pytorch_simple_model(torch.device("cpu"))


@register("pytorch", real=True, small=True)
@register("pytorch", real=True, small_input=True)
def pytorch_mnist(request: Any) -> ModelAndData:
fmodel = fbn.zoo.ModelLoader.get().load(
"examples/zoo/mnist/", module_name="foolbox_model"
Expand Down Expand Up @@ -165,12 +165,12 @@ def tensorflow_simple_sequential(
return fmodel, x, y


@register("tensorflow", small=True)
@register("tensorflow", small_input=True)
def tensorflow_simple_sequential_cpu(request: Any) -> ModelAndData:
return tensorflow_simple_sequential("cpu", None)


@register("tensorflow", small=True)
@register("tensorflow", small_input=True)
def tensorflow_simple_sequential_native_tensors(request: Any) -> ModelAndData:
import tensorflow as tf

Expand All @@ -179,14 +179,14 @@ def tensorflow_simple_sequential_native_tensors(request: Any) -> ModelAndData:
return tensorflow_simple_sequential("cpu", dict(mean=mean, std=std))


@register("tensorflow", small=True)
@register("tensorflow", small_input=True)
def tensorflow_simple_sequential_eagerpy_tensors(request: Any) -> ModelAndData:
mean = ep.tensorflow.zeros(1)
std = ep.tensorflow.ones(1) * 255.0
return tensorflow_simple_sequential("cpu", dict(mean=mean, std=std))


@register("tensorflow", small=True)
@register("tensorflow", small_input=True)
def tensorflow_simple_subclassing(request: Any) -> ModelAndData:
import tensorflow as tf

Expand Down Expand Up @@ -267,7 +267,7 @@ def tensorflow_resnet50(request: Any) -> ModelAndData:
return fmodel, x, y


@register("jax", small=True)
@register("jax", small_input=True)
def jax_simple_model(request: Any) -> ModelAndData:
import jax

Expand Down
10 changes: 5 additions & 5 deletions tests/test_attacks.py
Expand Up @@ -22,7 +22,7 @@ class AttackTestTarget(NamedTuple):
epsilon: Optional[float] = None
uses_grad: Optional[bool] = False
requires_real_model: Optional[bool] = False
requires_small_model: Optional[bool] = False
requires_small_input: Optional[bool] = False


def get_attack_id(x: AttackTestTarget) -> str:
Expand Down Expand Up @@ -127,10 +127,10 @@ def test_untargeted_attacks(
attack_test_target: AttackTestTarget,
) -> None:

(fmodel, x, y), real, small = fmodel_and_data_ext_for_attacks
(fmodel, x, y), real, small_input = fmodel_and_data_ext_for_attacks
if attack_test_target.requires_real_model and not real:
pytest.skip()
if attack_test_target.requires_small_model and not small:
if attack_test_target.requires_small_input and not small_input:
pytest.skip()
if isinstance(x, ep.NumPyTensor) and attack_test_target.uses_grad:
pytest.skip()
Expand Down Expand Up @@ -186,10 +186,10 @@ def test_targeted_attacks(
attack_test_target: AttackTestTarget,
) -> None:

(fmodel, x, y), real, small = fmodel_and_data_ext_for_attacks
(fmodel, x, y), real, small_input = fmodel_and_data_ext_for_attacks
if attack_test_target.requires_real_model and not real:
pytest.skip()
if attack_test_target.requires_small_model and not small:
if attack_test_target.requires_small_input and not small_input:
pytest.skip()

if isinstance(x, ep.NumPyTensor) and attack_test_target.uses_grad:
Expand Down

0 comments on commit de48aca

Please sign in to comment.