Skip to content

Commit

Permalink
Skip tests for high-dimensional inputs for some attacks (#660)
Browse files Browse the repository at this point in the history
* Skip tests for high-dimensional inputs for some attacks

* Replace pre-trained imagenet models with faster ones and reduce batch size

* Fix typo

* Make target test for HSJA more robust for small batch sizes

* List duration of unit tests

* Remove debugging commands

* Use correct datatypes

* Skip binarization test for imagenet models since they have low accuracy

* Use correct preprocessing for mobilenetv3

* Reformat

* Fix small bug in GenAttack triggered when batch size is smaller than population size

* Make tests less likely to fail due to randomness
  • Loading branch information
zimmerrol committed Feb 1, 2022
1 parent cf4e42b commit e7d3aa9
Show file tree
Hide file tree
Showing 14 changed files with 196 additions and 110 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ jobs:
mypy tests/
- name: Test with pytest (backend ${{ matrix.backend }})
run: |
pytest --cov-report term-missing --cov=foolbox --verbose --backend ${{ matrix.backend }}
pytest --durations=0 --cov-report term-missing --cov=foolbox --verbose --backend ${{ matrix.backend }}
- name: Codecov
continue-on-error: true
env:
Expand Down
17 changes: 14 additions & 3 deletions foolbox/attacks/dataset_attack.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,22 @@ def run(
result = x
found = criterion(x, model(x))

dataset_size = len(self.inputs)
batch_size = len(x)

while not found.all():
indices = np.random.randint(0, dataset_size, size=(batch_size,))
# for every sample try every other sample
index_pools: List[List[int]] = []
for i in range(batch_size):
indices = list(range(batch_size))
indices.remove(i)
indices = list(indices)
np.random.shuffle(indices)
index_pools.append(indices)

for i in range(batch_size - 1):
if found.all():
break

indices = np.array([pool[i] for pool in index_pools])

xp = self.inputs[indices]
yp = self.outputs[indices]
Expand Down
29 changes: 16 additions & 13 deletions foolbox/attacks/gen_attack.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,18 +195,8 @@ def calculate_fitness(logits: ep.TensorType) -> ep.TensorType:
1,
)

mutations = [
ep.uniform(
x,
noise_shape,
-mutation_range[i].item() * epsilon,
mutation_range[i].item() * epsilon,
)
for i in range(N)
]

new_noise_pops = [elite_noise]
for i in range(0, self.population - 1):
for i in range(self.population - 1):
parents_1 = noise_pops[range(N), parents_idxs[2 * i]]
parents_2 = noise_pops[range(N), parents_idxs[2 * i + 1]]

Expand All @@ -222,11 +212,24 @@ def calculate_fitness(logits: ep.TensorType) -> ep.TensorType:
children = ep.where(crossover_mask, parents_1, parents_2)

# calculate mutation
mutations = ep.stack(
[
ep.uniform(
x,
noise_shape,
-mutation_range[i].item() * epsilon,
mutation_range[i].item() * epsilon,
)
for i in range(N)
],
0,
)

mutation_mask = ep.uniform(children, children.shape)
mutation_mask = mutation_mask <= atleast_kd(
mutation_probability, children.ndim
)
children = ep.where(mutation_mask, children + mutations[i], children)
children = ep.where(mutation_mask, children + mutations, children)

# project back to epsilon range
children = ep.clip(children, -epsilon, epsilon)
Expand All @@ -253,7 +256,7 @@ def calculate_fitness(logits: ep.TensorType) -> ep.TensorType:
)
mutation_range = ep.maximum(
self.min_mutation_range,
0.5 * ep.exp(math.log(0.9) * ep.ones_like(num_plateaus) * num_plateaus),
0.4 * ep.exp(math.log(0.9) * ep.ones_like(num_plateaus) * num_plateaus),
)

return restore_type(self.apply_noise(x, elite_noise, epsilon, channel_axis))
2 changes: 1 addition & 1 deletion foolbox/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def _samples(
labels.append(label)

images_ = np.stack(images)
labels_ = np.array(labels)
labels_ = np.array(labels).astype(np.int64)

if bounds != (0, 255):
images_ = images_ / 255 * (bounds[1] - bounds[0]) + bounds[0]
Expand Down
68 changes: 34 additions & 34 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,20 @@
import foolbox as fbn

ModelAndData = Tuple[fbn.Model, ep.Tensor, ep.Tensor]
CallableModelAndDescription = NamedTuple(
"CallableModelAndDescription",
CallableModelAndDataAndDescription = NamedTuple(
"CallableModelAndDataAndDescription",
[
("model_fn", Callable[..., ModelAndData]),
("real", bool),
("low_dimensional_input", bool),
],
)
ModelDescriptionAndData = NamedTuple(
"ModelDescriptionAndData",
ModeAndDataAndDescription = NamedTuple(
"ModeAndDataAndDescription",
[("model_and_data", ModelAndData), ("real", bool), ("low_dimensional_input", bool)],
)

models: Dict[str, CallableModelAndDescription] = {}
models: Dict[str, CallableModelAndDataAndDescription] = {}
models_for_attacks: List[str] = []


Expand Down Expand Up @@ -55,7 +55,7 @@ def model(request: Any) -> ModelAndData:
global models
global real_models

models[model.__name__] = CallableModelAndDescription(
models[model.__name__] = CallableModelAndDataAndDescription(
model_fn=model, real=real, low_dimensional_input=low_dimensional_input
)
if attack:
Expand All @@ -82,23 +82,23 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
model, bounds=bounds, device=device, preprocessing=preprocessing
)

x, _ = fbn.samples(fmodel, dataset="imagenet", batchsize=16)
x, _ = fbn.samples(fmodel, dataset="imagenet", batchsize=8)
x = ep.astensor(x)
y = fmodel(x).argmax(axis=-1)
return fmodel, x, y


@register("pytorch")
@register("pytorch", low_dimensional_input=True)
def pytorch_simple_model_default(request: Any) -> ModelAndData:
return pytorch_simple_model()


@register("pytorch")
@register("pytorch", low_dimensional_input=True)
def pytorch_simple_model_default_flip(request: Any) -> ModelAndData:
return pytorch_simple_model(preprocessing=dict(flip_axis=-3))


@register("pytorch", attack=False)
@register("pytorch", attack=False, low_dimensional_input=True)
def pytorch_simple_model_default_cpu_native_tensor(request: Any) -> ModelAndData:
import torch

Expand All @@ -107,19 +107,19 @@ def pytorch_simple_model_default_cpu_native_tensor(request: Any) -> ModelAndData
return pytorch_simple_model("cpu", preprocessing=dict(mean=mean, std=std, axis=-3))


@register("pytorch", attack=False)
@register("pytorch", attack=False, low_dimensional_input=True)
def pytorch_simple_model_default_cpu_eagerpy_tensor(request: Any) -> ModelAndData:
mean = 0.05 * ep.torch.arange(3).float32()
std = ep.torch.ones(3) * 2
return pytorch_simple_model("cpu", preprocessing=dict(mean=mean, std=std, axis=-3))


@register("pytorch", attack=False)
@register("pytorch", attack=False, low_dimensional_input=True)
def pytorch_simple_model_string(request: Any) -> ModelAndData:
return pytorch_simple_model("cpu")


@register("pytorch", attack=False)
@register("pytorch", attack=False, low_dimensional_input=True)
def pytorch_simple_model_object(request: Any) -> ModelAndData:
import torch

Expand All @@ -131,24 +131,24 @@ def pytorch_mnist(request: Any) -> ModelAndData:
fmodel = fbn.zoo.ModelLoader.get().load(
"examples/zoo/mnist/", module_name="foolbox_model"
)
x, y = fbn.samples(fmodel, dataset="mnist", batchsize=16)
x, y = fbn.samples(fmodel, dataset="mnist", batchsize=8)
x = ep.astensor(x)
y = ep.astensor(y)
return fmodel, x, y


@register("pytorch", real=True)
def pytorch_resnet18(request: Any) -> ModelAndData:
def pytorch_shufflenetv2(request: Any) -> ModelAndData:
if request.config.option.skipslow:
pytest.skip()

import torchvision.models as models

model = models.resnet18(pretrained=True).eval()
model = models.shufflenet_v2_x0_5(pretrained=True).eval()
preprocessing = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], axis=-3)
fmodel = fbn.PyTorchModel(model, bounds=(0, 1), preprocessing=preprocessing)

x, y = fbn.samples(fmodel, dataset="imagenet", batchsize=16)
x, y = fbn.samples(fmodel, dataset="imagenet", batchsize=8)
x = ep.astensor(x)
y = ep.astensor(y)
return fmodel, x, y
Expand All @@ -167,7 +167,7 @@ def tensorflow_simple_sequential(
model, bounds=bounds, device=device, preprocessing=preprocessing
)

x, _ = fbn.samples(fmodel, dataset="cifar10", batchsize=16)
x, _ = fbn.samples(fmodel, dataset="cifar10", batchsize=8)
x = ep.astensor(x)
y = fmodel(x).argmax(axis=-1)
return fmodel, x, y
Expand Down Expand Up @@ -211,7 +211,7 @@ def call(self, x: tf.Tensor) -> tf.Tensor: # type: ignore
bounds = (0, 1)
fmodel = fbn.TensorFlowModel(model, bounds=bounds)

x, _ = fbn.samples(fmodel, dataset="cifar10", batchsize=16)
x, _ = fbn.samples(fmodel, dataset="cifar10", batchsize=8)
x = ep.astensor(x)
y = fmodel(x).argmax(axis=-1)
return fmodel, x, y
Expand All @@ -231,25 +231,25 @@ def tensorflow_simple_functional(request: Any) -> ModelAndData:
bounds = (0, 1)
fmodel = fbn.TensorFlowModel(model, bounds=bounds)

x, _ = fbn.samples(fmodel, dataset="imagenet", batchsize=16)
x, _ = fbn.samples(fmodel, dataset="imagenet", batchsize=8)
x = ep.astensor(x)
y = fmodel(x).argmax(axis=-1)
return fmodel, x, y


@register("tensorflow", real=True)
def tensorflow_mobilenetv2(request: Any) -> ModelAndData:
def tensorflow_mobilenetv3(request: Any) -> ModelAndData:
if request.config.option.skipslow:
pytest.skip()

import tensorflow as tf

model = tf.keras.applications.MobileNetV2(weights="imagenet")
fmodel = fbn.TensorFlowModel(
model, bounds=(0, 255), preprocessing=dict(mean=127.5, std=127.5)
model = tf.keras.applications.MobileNetV3Small(
weights="imagenet", minimalistic=True
)
fmodel = fbn.TensorFlowModel(model, bounds=(0, 255), preprocessing=None,)

x, y = fbn.samples(fmodel, dataset="imagenet", batchsize=16)
x, y = fbn.samples(fmodel, dataset="imagenet", batchsize=8)
x = ep.astensor(x)
y = ep.astensor(y)
return fmodel, x, y
Expand All @@ -269,7 +269,7 @@ def tensorflow_resnet50(request: Any) -> ModelAndData:
preprocessing = dict(flip_axis=-1, mean=[104.0, 116.0, 123.0]) # RGB to BGR
fmodel = fbn.TensorFlowModel(model, bounds=(0, 255), preprocessing=preprocessing)

x, y = fbn.samples(fmodel, dataset="imagenet", batchsize=16)
x, y = fbn.samples(fmodel, dataset="imagenet", batchsize=8)
x = ep.astensor(x)
y = ep.astensor(y)
return fmodel, x, y
Expand All @@ -286,7 +286,7 @@ def model(x: Any) -> Any:
fmodel = fbn.JAXModel(model, bounds=bounds)

x, _ = fbn.samples(
fmodel, dataset="cifar10", batchsize=16, data_format="channels_last"
fmodel, dataset="cifar10", batchsize=8, data_format="channels_last"
)
x = ep.astensor(x)
y = fmodel(x).argmax(axis=-1)
Expand All @@ -305,33 +305,33 @@ def __call__(self, inputs: Any) -> Any:

fmodel = fbn.NumPyModel(model, bounds=(0, 1))
with pytest.raises(ValueError, match="data_format"):
x, _ = fbn.samples(fmodel, dataset="imagenet", batchsize=16)
x, _ = fbn.samples(fmodel, dataset="imagenet", batchsize=8)

fmodel = fbn.NumPyModel(model, bounds=(0, 1), data_format="channels_first")
with pytest.warns(UserWarning, match="returning NumPy arrays"):
x, _ = fbn.samples(fmodel, dataset="imagenet", batchsize=16)
x, _ = fbn.samples(fmodel, dataset="imagenet", batchsize=8)

x = ep.astensor(x)
y = fmodel(x).argmax(axis=-1)
return fmodel, x, y


@pytest.fixture(scope="session", params=list(models.keys()))
def fmodel_and_data_ext(request: Any) -> ModelDescriptionAndData:
def fmodel_and_data_ext(request: Any) -> ModeAndDataAndDescription:
global models
model_description = models[request.param]
model_and_data = model_description.model_fn(request)
return ModelDescriptionAndData(model_and_data, *model_description[1:])
return ModeAndDataAndDescription(model_and_data, *model_description[1:])


@pytest.fixture(scope="session", params=models_for_attacks)
def fmodel_and_data_ext_for_attacks(request: Any) -> ModelDescriptionAndData:
def fmodel_and_data_ext_for_attacks(request: Any) -> ModeAndDataAndDescription:
global models
model_description = models[request.param]
model_and_data = model_description.model_fn(request)
return ModelDescriptionAndData(model_and_data, *model_description[1:])
return ModeAndDataAndDescription(model_and_data, *model_description[1:])


@pytest.fixture(scope="session")
def fmodel_and_data(fmodel_and_data_ext: ModelDescriptionAndData) -> ModelAndData:
def fmodel_and_data(fmodel_and_data_ext: ModeAndDataAndDescription) -> ModelAndData:
return fmodel_and_data_ext.model_and_data

0 comments on commit e7d3aa9

Please sign in to comment.