-
-
Notifications
You must be signed in to change notification settings - Fork 422
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
351 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,245 @@ | ||
from typing import Union, Any, Optional, Tuple, Callable, List | ||
import eagerpy as ep | ||
import numpy as np | ||
import logging | ||
|
||
from ..criteria import Criterion | ||
|
||
from .base import FlexibleDistanceMinimizationAttack | ||
from .saltandpepper import SaltAndPepperNoiseAttack | ||
|
||
from ..devutils import flatten | ||
from .base import Model | ||
from .base import MinimizationAttack | ||
from .base import get_is_adversarial | ||
from .base import get_criterion | ||
from .base import T | ||
from .base import raise_if_kwargs | ||
|
||
|
||
class PointwiseAttack(FlexibleDistanceMinimizationAttack):
    """Starts with an adversarial and performs a binary search between
    the adversarial and the original for each dimension of the input
    individually. [#Sch18]_

    The attack greedily resets perturbed input dimensions back to their
    clean values while the example stays adversarial (minimizing L0), and
    optionally runs a per-dimension binary search afterwards to further
    reduce the L2 distance.

    Args:
        init_attack: Attack used to find a starting adversarial if no
            ``starting_points`` are passed to :meth:`run`. Defaults to
            :class:`SaltAndPepperNoiseAttack`.
        l2_binary_search: If True, refine each still-perturbed dimension
            with a binary search between the clean and adversarial value.

    References:
        .. [#Sch18] Lukas Schott, Jonas Rauber, Matthias Bethge, Wieland Brendel,
            "Towards the first adversarially robust neural network model on MNIST",
            https://arxiv.org/abs/1805.09190
    """

    def __init__(
        self,
        init_attack: Optional[MinimizationAttack] = None,
        l2_binary_search: bool = True,
    ) -> None:
        # NOTE: both attributes are read only in run(); no validation is done here.
        self.init_attack = init_attack
        self.l2_binary_search = l2_binary_search

    def run(
        self,
        model: Model,
        inputs: T,
        criterion: Union[Criterion, Any] = None,
        *,
        starting_points: Optional[ep.Tensor] = None,
        early_stop: Optional[float] = None,
        **kwargs: Any,
    ) -> T:
        """Run the attack and return adversarial examples of the same type
        as ``inputs``.

        Args:
            model: The model to attack.
            inputs: Clean inputs (native tensors; converted via eagerpy).
            criterion: Criterion (or labels) defining adversariality.
            starting_points: Optional adversarial starting points; if None,
                ``init_attack`` (or a default) is used to find them.
            early_stop: Accepted for interface compatibility; not used by
                this attack's loop.
            **kwargs: Must be empty; raises otherwise.
        """
        raise_if_kwargs(kwargs)
        del kwargs

        # Convert to an eagerpy tensor and remember how to restore the
        # caller's native tensor type at the end.
        x, restore_type = ep.astensor_(inputs)
        del inputs

        criterion_ = get_criterion(criterion)
        del criterion
        is_adversarial = get_is_adversarial(criterion_, model)

        if starting_points is None:
            init_attack: MinimizationAttack
            if self.init_attack is None:
                init_attack = SaltAndPepperNoiseAttack()
                logging.info(
                    f"Neither starting_points nor init_attack given. Falling"
                    f" back to {init_attack!r} for initialization."
                )
            else:
                init_attack = self.init_attack
            # TODO: use call and support all types of attacks (once early_stop is
            # possible in __call__)
            starting_points = init_attack.run(model, x, criterion_)

        x_adv = ep.astensor(starting_points)
        # The algorithm requires a valid adversarial to start from; it only
        # ever moves values back towards the clean input.
        assert is_adversarial(x_adv).all()

        original_shape = x.shape
        N = len(x)  # batch size

        # Work on flattened views; indices below address (sample, feature).
        x_flat = flatten(x)
        x_adv_flat = flatten(x_adv)

        # was there a pixel left in the samples to manipulate,
        # i.e. reset to the clean version?
        found_index_to_manipulate = ep.from_numpy(x, np.ones(N, dtype=bool))

        # Phase 1 (L0 minimization): repeatedly sweep over all still-perturbed
        # dimensions in random order and reset each one to the clean value if
        # the example remains adversarial. Stop once a full sweep makes no reset.
        while ep.any(found_index_to_manipulate):
            # per-sample indices of dimensions that still differ from the clean input
            diff_mask = (ep.abs(x_flat - x_adv_flat) > 1e-8).numpy()
            diff_idxs = [z.nonzero()[0] for z in diff_mask]
            untouched_indices = [z.tolist() for z in diff_idxs]
            untouched_indices = [
                np.random.permutation(it).tolist() for it in untouched_indices
            ]

            found_index_to_manipulate = ep.from_numpy(x, np.zeros(N, dtype=bool))

            # since the number of pixels still left to manipulate might differ
            # across different samples we track each of them separately and
            # and manipulate the images until there is no pixel left for
            # any of the samples. to not update already finished samples, we mask
            # the updates such that only samples that still have pixels left to manipulate
            # will be updated
            i = 0
            while i < max([len(it) for it in untouched_indices]):
                # mask all samples that still have pixels to manipulate left
                relevant_mask = [len(it) > i for it in untouched_indices]
                relevant_mask = np.array(relevant_mask, dtype=bool)
                relevant_mask_index = np.flatnonzero(relevant_mask)

                # for each image get the index of the next pixel we try out
                relevant_indices = [it[i] for it in untouched_indices if len(it) > i]

                # tentatively reset the chosen dimension to its clean value
                old_values = x_adv_flat[relevant_mask_index, relevant_indices]
                new_values = x_flat[relevant_mask_index, relevant_indices]
                x_adv_flat = ep.index_update(
                    x_adv_flat, (relevant_mask_index, relevant_indices), new_values
                )

                # check if still adversarial
                is_adv = is_adversarial(x_adv_flat.reshape(original_shape))
                found_index_to_manipulate = ep.index_update(
                    found_index_to_manipulate,
                    relevant_mask_index,
                    ep.logical_or(found_index_to_manipulate, is_adv)[relevant_mask],
                )

                # if not, undo change
                new_or_old_values = ep.where(
                    is_adv[relevant_mask], new_values, old_values
                )
                x_adv_flat = ep.index_update(
                    x_adv_flat,
                    (relevant_mask_index, relevant_indices),
                    new_or_old_values,
                )

                i += 1

            # no dimension could be reset in this sweep for any sample -> done
            if not ep.any(found_index_to_manipulate):
                break

        # Phase 2 (optional L2 refinement): binary-search each remaining
        # perturbed dimension towards the clean value while staying adversarial.
        if self.l2_binary_search:
            while True:
                # smaller threshold here: the binary search produces values much
                # closer to the clean input than phase 1 does
                diff_mask = (ep.abs(x_flat - x_adv_flat) > 1e-12).numpy()
                diff_idxs = [z.nonzero()[0] for z in diff_mask]
                untouched_indices = [z.tolist() for z in diff_idxs]
                # draw random shuffling of all indices for all samples
                untouched_indices = [
                    np.random.permutation(it).tolist() for it in untouched_indices
                ]

                # whether that run through all values made any improvement
                improved = ep.from_numpy(x, np.zeros(N, dtype=bool)).astype(bool)

                logging.info("Starting new loop through all values")

                # use the same logic as above
                i = 0
                while i < max([len(it) for it in untouched_indices]):
                    # mask all samples that still have pixels to manipulate left
                    relevant_mask = [len(it) > i for it in untouched_indices]
                    relevant_mask = np.array(relevant_mask, dtype=bool)
                    relevant_mask_index = np.flatnonzero(relevant_mask)

                    # for each image get the index of the next pixel we try out
                    relevant_indices = [
                        it[i] for it in untouched_indices if len(it) > i
                    ]

                    old_values = x_adv_flat[relevant_mask_index, relevant_indices]
                    new_values = x_flat[relevant_mask_index, relevant_indices]

                    # first try resetting to the clean value outright
                    x_adv_flat = ep.index_update(
                        x_adv_flat, (relevant_mask_index, relevant_indices), new_values
                    )

                    # check if still adversarial
                    is_adv = is_adversarial(x_adv_flat.reshape(original_shape))

                    improved = ep.index_update(
                        improved,
                        relevant_mask_index,
                        ep.logical_or(improved, is_adv)[relevant_mask],
                    )

                    if not ep.all(is_adv):
                        # run binary search for examples that became non-adversarial
                        updated_new_values = self._binary_search(
                            x_adv_flat,
                            relevant_mask,
                            relevant_mask_index,
                            relevant_indices,
                            old_values,
                            new_values,
                            (-1, *original_shape[1:]),
                            is_adversarial,
                        )
                        # keep the clean value where the outright reset stayed
                        # adversarial; otherwise use the binary-search result
                        x_adv_flat = ep.index_update(
                            x_adv_flat,
                            (relevant_mask_index, relevant_indices),
                            ep.where(
                                is_adv[relevant_mask], new_values, updated_new_values
                            ),
                        )

                        improved = ep.index_update(
                            improved,
                            relevant_mask_index,
                            ep.logical_or(
                                old_values != updated_new_values,
                                improved[relevant_mask],
                            ),
                        )

                    i += 1

                if not ep.any(improved):
                    # no improvement for any of the indices
                    break

        x_adv = x_adv_flat.reshape(original_shape)

        return restore_type(x_adv)

    def _binary_search(
        self,
        x_adv_flat: ep.Tensor,
        mask: Union[ep.Tensor, List[bool]],
        mask_indices: ep.Tensor,
        indices: Union[ep.Tensor, List[int]],
        adv_values: ep.Tensor,
        non_adv_values: ep.Tensor,
        original_shape: Tuple,
        is_adversarial: Callable,
    ) -> ep.Tensor:
        """Binary-search per-dimension values between a known-adversarial
        value and a known-non-adversarial value (10 fixed iterations).

        Args:
            x_adv_flat: Flattened batch; modified locally per probe.
            mask: Boolean mask of samples participating in the search.
            mask_indices: Integer sample indices corresponding to ``mask``.
            indices: Per-sample feature index being searched.
            adv_values: Values known to keep the sample adversarial.
            non_adv_values: Values known to make the sample non-adversarial
                (typically the clean values).
            original_shape: Shape to restore before querying the model.
            is_adversarial: Predicate evaluating the full batch.

        Returns:
            The closest values to ``non_adv_values`` found to still be
            adversarial.
        """
        for i in range(10):
            # probe the midpoint and shrink the bracket on the adversarial side
            next_values = (adv_values + non_adv_values) / 2
            x_adv_flat = ep.index_update(
                x_adv_flat, (mask_indices, indices), next_values
            )
            is_adv = is_adversarial(x_adv_flat.reshape(original_shape))[mask]

            adv_values = ep.where(is_adv, next_values, adv_values)
            non_adv_values = ep.where(is_adv, non_adv_values, next_values)

        return adv_values
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
from typing import List, Any | ||
import eagerpy as ep | ||
|
||
import foolbox as fbn | ||
import foolbox.attacks as fa | ||
from foolbox.devutils import flatten | ||
import pytest | ||
|
||
from conftest import ModeAndDataAndDescription | ||
|
||
|
||
def get_attack_id(x: fa.Attack) -> str:
    """Return the attack's repr, used as a readable pytest parameter id."""
    return f"{x!r}"
|
||
|
||
# Attack configurations under test: default (with the final L2 binary-search
# refinement) and with the refinement disabled.
attacks: List[fa.Attack] = [
    fa.PointwiseAttack(),
    fa.PointwiseAttack(l2_binary_search=False),
]
|
||
|
||
@pytest.mark.parametrize("attack", attacks, ids=get_attack_id)
def test_pointwise_untargeted_attack(
    request: Any,
    fmodel_and_data_ext_for_attacks: ModeAndDataAndDescription,
    attack: fa.PointwiseAttack,
) -> None:
    """The untargeted PointwiseAttack should keep (or improve) the fooling
    rate of its starting points while shrinking both the L0 and L2 norms of
    the perturbation for at least some samples."""
    (fmodel, x, y), real, low_dimensional_input = fmodel_and_data_ext_for_attacks

    # only meaningful on real models with low-dimensional inputs
    if not (low_dimensional_input and real):
        pytest.skip()

    # normalize inputs to [0, 1] and rescale the model bounds to match
    bounds = fmodel.bounds
    x = (x - bounds.lower) / (bounds.upper - bounds.lower)
    fmodel = fmodel.transform_bounds((0, 1))

    # obtain adversarial starting points for the minimization attack
    init_attack = fa.SaltAndPepperNoiseAttack(steps=50)
    init_advs = init_attack.run(fmodel, x, y)

    advs = attack.run(fmodel, x, y, starting_points=init_advs)

    # perturbations before and after the attack, flattened per sample
    init_deltas = flatten(init_advs - x)
    deltas = flatten(advs - x)

    is_smaller_l0 = (
        ep.norms.lp(deltas, p=0, axis=-1) < ep.norms.lp(init_deltas, p=0, axis=-1)
    )
    is_smaller_l2 = (
        ep.norms.lp(deltas, p=2, axis=-1) < ep.norms.lp(init_deltas, p=2, axis=-1)
    )

    assert fbn.accuracy(fmodel, advs, y) < fbn.accuracy(fmodel, x, y)
    assert fbn.accuracy(fmodel, advs, y) <= fbn.accuracy(fmodel, init_advs, y)
    assert is_smaller_l2.any()
    assert is_smaller_l0.any()
|
||
|
||
@pytest.mark.parametrize("attack", attacks, ids=get_attack_id)
def test_pointwise_targeted_attack(
    request: Any,
    fmodel_and_data_ext_for_attacks: ModeAndDataAndDescription,
    attack: fa.PointwiseAttack,
) -> None:
    """The targeted PointwiseAttack should keep (or improve) targeted success
    relative to its starting points while shrinking both the L0 and L2 norms
    of the perturbation for at least some samples."""
    (fmodel, x, y), real, low_dimensional_input = fmodel_and_data_ext_for_attacks

    # only meaningful on real models with low-dimensional inputs
    if not (low_dimensional_input and real):
        pytest.skip()

    # normalize inputs to [0, 1] and rescale the model bounds to match
    bounds = fmodel.bounds
    x = (x - bounds.lower) / (bounds.upper - bounds.lower)
    fmodel = fmodel.transform_bounds((0, 1))

    # obtain adversarial starting points for the minimization attack
    init_attack = fa.SaltAndPepperNoiseAttack(steps=50)
    init_advs = init_attack.run(fmodel, x, y)

    # derive target classes from the starting points' predictions, shifting
    # any prediction that coincides with the true label to the next class
    logits = fmodel(init_advs)
    num_classes = logits.shape[-1]
    predicted = logits.argmax(-1)
    target_classes = ep.where(
        predicted == y, (predicted + 1) % num_classes, predicted
    )
    criterion = fbn.TargetedMisclassification(target_classes)

    advs = attack.run(fmodel, x, criterion, starting_points=init_advs)

    # perturbations before and after the attack, flattened per sample
    init_deltas = flatten(init_advs - x)
    deltas = flatten(advs - x)

    is_smaller_l0 = (
        ep.norms.lp(deltas, p=0, axis=-1) < ep.norms.lp(init_deltas, p=0, axis=-1)
    )
    is_smaller_l2 = (
        ep.norms.lp(deltas, p=2, axis=-1) < ep.norms.lp(init_deltas, p=2, axis=-1)
    )

    assert fbn.accuracy(fmodel, advs, y) < fbn.accuracy(fmodel, x, y)
    assert fbn.accuracy(fmodel, advs, y) <= fbn.accuracy(fmodel, init_advs, y)
    assert fbn.accuracy(fmodel, advs, target_classes) > fbn.accuracy(
        fmodel, x, target_classes
    )
    assert fbn.accuracy(fmodel, advs, target_classes) >= fbn.accuracy(
        fmodel, init_advs, target_classes
    )
    assert is_smaller_l2.any()
    assert is_smaller_l0.any()