diff --git a/foolbox/attacks/__init__.py b/foolbox/attacks/__init__.py
index e16af123..3d02e026 100644
--- a/foolbox/attacks/__init__.py
+++ b/foolbox/attacks/__init__.py
@@ -47,6 +47,7 @@
     L2BrendelBethgeAttack,
     LinfinityBrendelBethgeAttack,
 )
+from .gen_attack import GenAttack  # noqa: F401
 
 # from .blended_noise import LinearSearchBlendedUniformNoiseAttack  # noqa: F401
 # from .brendel_bethge import (  # noqa: F401
diff --git a/foolbox/attacks/gen_attack.py b/foolbox/attacks/gen_attack.py
new file mode 100644
index 00000000..633ba5af
--- /dev/null
+++ b/foolbox/attacks/gen_attack.py
@@ -0,0 +1,292 @@
+from typing import Optional, Any, Tuple, Union
+import numpy as np
+import eagerpy as ep
+
+from ..devutils import atleast_kd
+
+from ..models import Model
+
+from ..criteria import TargetedMisclassification
+
+from ..distances import linf
+
+from .base import FixedEpsilonAttack
+from .base import T
+from .base import get_channel_axis
+from .base import raise_if_kwargs
+import math
+
+from .gen_attack_utils import rescale_images
+
+
+class GenAttack(FixedEpsilonAttack):
+    """A black-box algorithm for L-infinity adversarials. [#Alz18]_
+
+    This attack performs a genetic search in order to find an adversarial
+    perturbation in a black-box scenario in as few queries as possible.
+
+    Args:
+        steps : Number of generations of the genetic search.
+        population : Number of candidate perturbations kept per input.
+        mutation_probability : Lower bound of the mutation probability.
+        mutation_range : Lower bound of the mutation magnitude, as a
+            fraction of epsilon.
+        sampling_temperature : Temperature of the softmax used to turn
+            fitness values into parent sampling probabilities.
+        channel_axis : Index of the channel axis; inferred from the model
+            if None. Only used together with reduced_dims.
+        reduced_dims : Spatial size of a lower-resolution search space; the
+            noise is upscaled to the input size before it is applied.
+
+    References:
+        .. [#Alz18] Moustafa Alzantot, Yash Sharma, Supriyo Chakraborty, Huan Zhang,
+           Cho-Jui Hsieh, Mani Srivastava,
+           "GenAttack: Practical Black-box Attacks with Gradient-Free
+           Optimization",
+           https://arxiv.org/abs/1805.11090
+
+    """
+
+    def __init__(
+        self,
+        *,
+        steps: int = 1000,
+        population: int = 10,
+        mutation_probability: float = 0.10,
+        mutation_range: float = 0.15,
+        sampling_temperature: float = 0.3,
+        channel_axis: Optional[int] = None,
+        reduced_dims: Optional[Tuple[int, int]] = None,
+    ):
+        self.steps = steps
+        self.population = population
+        self.min_mutation_probability = mutation_probability
+        self.min_mutation_range = mutation_range
+        self.sampling_temperature = sampling_temperature
+        self.channel_axis = channel_axis
+        self.reduced_dims = reduced_dims
+
+    distance = linf
+
+    def apply_noise(
+        self,
+        x: ep.TensorType,
+        noise: ep.TensorType,
+        epsilon: float,
+        channel_axis: Optional[int],
+        bounds: Optional[Tuple[float, float]] = None,
+    ) -> ep.TensorType:
+        """Upscale the noise if needed and add it to x.
+
+        The noise (not the perturbed input) is clipped to [-epsilon, epsilon];
+        if bounds is given, the result is additionally clipped to it.
+        """
+        if noise.shape != x.shape and channel_axis is not None:
+            # upscale noise
+            noise = rescale_images(noise, x.shape, channel_axis)
+
+        # the L-infinity constraint applies to the perturbation itself
+        noise = ep.clip(noise, -epsilon, +epsilon)
+        perturbed = x + noise
+
+        if bounds is not None:
+            perturbed = ep.clip(perturbed, bounds[0], bounds[1])
+
+        return perturbed
+
+    def choice(
+        self, a: int, size: Union[int, ep.TensorType], replace: bool, p: ep.TensorType
+    ) -> Any:
+        """Sample size indices from range(a) with probabilities p (NumPy RNG)."""
+        probabilities = np.asarray(p.numpy(), dtype=np.float64)
+        # renormalize in float64 so that rounding errors of a lower-precision
+        # softmax cannot trip NumPy's "probabilities do not sum to 1" check
+        probabilities = probabilities / probabilities.sum()
+        return np.random.choice(a, size, replace, probabilities)
+
+    def run(
+        self,
+        model: Model,
+        inputs: T,
+        criterion: TargetedMisclassification,
+        *,
+        epsilon: float,
+        **kwargs: Any,
+    ) -> T:
+        """Run the genetic search and return the perturbed inputs.
+
+        Raises:
+            ValueError: If the criterion is not a TargetedMisclassification,
+                the target classes do not match the batch size, or the
+                channel axis cannot be inferred or is neither 1 nor 3.
+            NotImplementedError: If reduced_dims is set and the inputs do
+                not have exactly four dimensions.
+        """
+        raise_if_kwargs(kwargs)
+        x, restore_type = ep.astensor_(inputs)
+        del inputs, kwargs
+
+        N = len(x)
+        bounds = model.bounds
+
+        if isinstance(criterion, TargetedMisclassification):
+            classes = criterion.target_classes
+        else:
+            raise ValueError("unsupported criterion")
+
+        if classes.shape != (N,):
+            raise ValueError(
+                f"expected target_classes to have shape ({N},), got {classes.shape}"
+            )
+
+        noise_shape: Union[Tuple[int, int, int, int], Tuple[int, ...]]
+        channel_axis: Optional[int] = None
+        if self.reduced_dims is not None:
+            if x.ndim != 4:
+                raise NotImplementedError(
+                    "only implemented for inputs with two spatial dimensions"
+                    " (and one channel and one batch dimension)"
+                )
+
+            if self.channel_axis is None:
+                maybe_axis = get_channel_axis(model, x.ndim)
+                if maybe_axis is None:
+                    raise ValueError(
+                        "cannot infer the data_format from the model, please"
+                        " specify channel_axis when initializing the attack"
+                    )
+                else:
+                    channel_axis = maybe_axis
+            else:
+                channel_axis = self.channel_axis % x.ndim
+
+            if channel_axis == 1:
+                noise_shape = (x.shape[1], *self.reduced_dims)
+            elif channel_axis == 3:
+                noise_shape = (*self.reduced_dims, x.shape[3])
+            else:
+                raise ValueError(
+                    f"expected 'channel_axis' to be 1 or 3, got {channel_axis}"
+                )
+        else:
+            noise_shape = x.shape[1:]  # pragma: no cover
+
+        def is_adversarial(logits: ep.TensorType) -> ep.TensorType:
+            return ep.argmax(logits, 1) == classes
+
+        num_plateaus = ep.zeros(x, len(x))
+        mutation_probability = (
+            ep.ones_like(num_plateaus) * self.min_mutation_probability
+        )
+        mutation_range = ep.ones_like(num_plateaus) * self.min_mutation_range
+
+        noise_pops = ep.uniform(
+            x, (N, self.population, *noise_shape), -epsilon, epsilon
+        )
+
+        def calculate_fitness(logits: ep.TensorType) -> ep.TensorType:
+            # log-odds of the target class versus all other classes:
+            # logit_t - log(sum_{j != t} exp(logit_j))
+            first = logits[range(N), classes]
+            second = ep.log(ep.exp(logits).sum(1) - ep.exp(first))
+
+            return first - second
+
+        n_its_wo_change = ep.zeros(x, (N,))
+        # defined up front so that the final return also works for steps == 0
+        elite_noise = noise_pops[:, 0]
+        for _ in range(self.steps):
+            fitness_l, is_adv_l = [], []
+
+            for i in range(self.population):
+                it = self.apply_noise(
+                    x, noise_pops[:, i], epsilon, channel_axis, bounds
+                )
+                logits = model(it)
+                f = calculate_fitness(logits)
+                a = is_adversarial(logits)
+                fitness_l.append(f)
+                is_adv_l.append(a)
+
+            fitness = ep.stack(fitness_l)
+            is_adv = ep.stack(is_adv_l, 1)
+            elite_idxs = ep.argmax(fitness, 0)
+
+            elite_noise = noise_pops[range(N), elite_idxs]
+            is_adv = is_adv[range(N), elite_idxs]
+
+            # early stopping
+            if is_adv.all():
+                return restore_type(  # pragma: no cover
+                    self.apply_noise(x, elite_noise, epsilon, channel_axis, bounds)
+                )
+
+            probs = ep.softmax(fitness / self.sampling_temperature, 0)
+            parents_idxs = np.stack(
+                [
+                    self.choice(
+                        self.population,
+                        2 * self.population - 2,
+                        replace=True,
+                        p=probs[:, i],
+                    )
+                    for i in range(N)
+                ],
+                1,
+            )
+
+            new_noise_pops = [elite_noise]
+            for i in range(0, self.population - 1):
+                parents_1 = noise_pops[range(N), parents_idxs[2 * i]]
+                parents_2 = noise_pops[range(N), parents_idxs[2 * i + 1]]
+
+                # calculate crossover
+                p = probs[parents_idxs[2 * i], range(N)] / (
+                    probs[parents_idxs[2 * i], range(N)]
+                    + probs[parents_idxs[2 * i + 1], range(N)]
+                )
+                p = atleast_kd(p, x.ndim)
+                p = ep.tile(p, (1, *noise_shape))
+
+                crossover_mask = ep.uniform(p, p.shape, 0, 1) < p
+                children = ep.where(crossover_mask, parents_1, parents_2)
+
+                # calculate mutation: draw fresh noise for every child and
+                # scale it with the per-sample (not per-child) mutation range
+                mutation_mask = ep.uniform(children, children.shape)
+                mutation_mask = mutation_mask <= atleast_kd(
+                    mutation_probability, children.ndim
+                )
+                mutation = ep.uniform(children, children.shape, -1.0, 1.0)
+                mutation = mutation * atleast_kd(
+                    mutation_range * epsilon, children.ndim
+                )
+                children = ep.where(mutation_mask, children + mutation, children)
+
+                # project back to epsilon range
+                children = ep.clip(children, -epsilon, epsilon)
+
+                new_noise_pops.append(children)
+
+            noise_pops = ep.stack(new_noise_pops, 1)
+
+            # increase num_plateaus if fitness does not improve
+            # for 100 consecutive steps
+            n_its_wo_change = ep.where(
+                elite_idxs == 0, n_its_wo_change + 1, ep.zeros_like(n_its_wo_change)
+            )
+            num_plateaus = ep.where(
+                n_its_wo_change >= 100, num_plateaus + 1, num_plateaus
+            )
+            n_its_wo_change = ep.where(
+                n_its_wo_change >= 100, ep.zeros_like(n_its_wo_change), n_its_wo_change
+            )
+
+            # both quantities decay with the number of plateaus, down to their minimum
+            decay = 0.5 * ep.exp(math.log(0.9) * num_plateaus)
+            mutation_probability = ep.maximum(self.min_mutation_probability, decay)
+            mutation_range = ep.maximum(self.min_mutation_range, decay)
+
+        return restore_type(
+            self.apply_noise(x, elite_noise, epsilon, channel_axis, bounds)
+        )
diff --git a/foolbox/attacks/gen_attack_utils.py b/foolbox/attacks/gen_attack_utils.py
new file mode 100644
index 00000000..6cf9c3d7
--- /dev/null
+++ b/foolbox/attacks/gen_attack_utils.py
@@ -0,0 +1,158 @@
+from typing import Any, Union, List, Tuple
+import eagerpy as ep
+
+
+def _rescale_bilinear(xp: Any, img: Any, target_shape: List[int]) -> Any:
+    """Bilinearly resize a channels-last array to the target spatial size.
+
+    Shared implementation of the NumPy and the JAX backend.
+
+    Args:
+        xp : NumPy-compatible array module (numpy or jax.numpy).
+        img : Raw array whose last three axes are rows, columns, channels.
+        target_shape : Full target shape; entries 1 and 2 are used as the
+            new number of rows and columns.
+    """
+    # based on http://stackoverflow.com/a/12729229, modified according to
+    # https://github.com/google/jax/issues/862
+    nrows, ncols = img.shape[-3:-1]
+
+    # take the output size directly from the target shape; deriving it from
+    # a floating point resize rate can be off by one after truncation
+    new_rows, new_cols = int(target_shape[1]), int(target_shape[2])
+
+    # sample the source image at the centers of the target pixels
+    delta_rows = 0.5 * nrows / new_rows
+    delta_cols = 0.5 * ncols / new_cols
+    rows = xp.linspace(delta_rows, nrows - delta_rows, new_rows)
+    cols = xp.linspace(delta_cols, ncols - delta_cols, new_cols)
+    rows_grid, cols_grid = xp.meshgrid(rows - 0.5, cols - 0.5, indexing="ij")
+    rows_flat = rows_grid.flatten()
+    cols_flat = cols_grid.flatten()
+
+    col_lo = xp.floor(cols_flat).astype(int)
+    col_hi = col_lo + 1
+    row_lo = xp.floor(rows_flat).astype(int)
+    row_hi = row_lo + 1
+
+    def cclip(c: Any) -> Any:
+        return xp.clip(c, 0, ncols - 1)
+
+    def rclip(r: Any) -> Any:
+        return xp.clip(r, 0, nrows - 1)
+
+    Ia = img[..., rclip(row_lo), cclip(col_lo), :]
+    Ib = img[..., rclip(row_hi), cclip(col_lo), :]
+    Ic = img[..., rclip(row_lo), cclip(col_hi), :]
+    Id = img[..., rclip(row_hi), cclip(col_hi), :]
+
+    wa = xp.expand_dims((col_hi - cols_flat) * (row_hi - rows_flat), -1)
+    wb = xp.expand_dims((col_hi - cols_flat) * (rows_flat - row_lo), -1)
+    wc = xp.expand_dims((cols_flat - col_lo) * (row_hi - rows_flat), -1)
+    wd = xp.expand_dims((cols_flat - col_lo) * (rows_flat - row_lo), -1)
+
+    resized = wa * Ia + wb * Ib + wc * Ic + wd * Id
+    return resized.reshape(
+        img.shape[:-3] + (new_rows, new_cols) + img.shape[-1:]
+    )
+
+
+def rescale_jax(x: ep.JAXTensor, target_shape: List[int]) -> ep.JAXTensor:
+    """Bilinearly resize a channels-last JAX tensor to target_shape."""
+    # img must be in channel_last format
+    import jax.numpy as jnp
+
+    return ep.JAXTensor(_rescale_bilinear(jnp, x.raw, target_shape))
+
+
+def rescale_numpy(x: ep.NumPyTensor, target_shape: List[int]) -> ep.NumPyTensor:
+    """Bilinearly resize a channels-last NumPy tensor to target_shape."""
+    import numpy as np
+
+    return ep.NumPyTensor(_rescale_bilinear(np, x.raw, target_shape))
+
+
+def rescale_tensorflow(
+    x: ep.TensorFlowTensor, target_shape: List[int]
+) -> ep.TensorFlowTensor:
+    """Resize a channels-last TensorFlow tensor to target_shape."""
+    import tensorflow
+
+    img = x.raw
+
+    img_resized = tensorflow.image.resize(img, size=target_shape[1:-1])
+
+    return ep.TensorFlowTensor(img_resized)
+
+
+def rescale_pytorch(x: ep.PyTorchTensor, target_shape: List[int]) -> ep.PyTorchTensor:
+    """Bilinearly resize a channels-first PyTorch tensor to target_shape."""
+    import torch
+
+    img = x.raw
+
+    img_resized = torch.nn.functional.interpolate(
+        img, size=target_shape[2:], mode="bilinear", align_corners=False
+    )
+
+    return ep.PyTorchTensor(img_resized)
+
+
+def swap_axes(x: ep.TensorType, dim0: int, dim1: int) -> ep.TensorType:
+    """Return x with the axes dim0 and dim1 exchanged."""
+    assert dim0 < x.ndim
+    assert dim1 < x.ndim
+
+    axes = list(range(x.ndim))
+    axes[dim0] = dim1
+    axes[dim1] = dim0
+
+    return ep.transpose(x, tuple(axes))
+
+
+def rescale_images(
+    x: ep.TensorType, target_shape: Union[Tuple[int, ...], List[int]], channel_axis: int
+) -> ep.TensorType:
+    """Resize a batch of images to the spatial size given by target_shape.
+
+    Args:
+        x : Batch of images.
+        target_shape : Full shape of the result, in the layout of x.
+        channel_axis : Index of the channel axis of x; may be negative.
+
+    Returns:
+        The resized batch; x is returned unchanged for unknown backends.
+    """
+    target_shape = list(target_shape)
+
+    if channel_axis < 0:
+        # normalize, e.g. -1 -> x.ndim - 1
+        channel_axis = x.ndim + channel_axis
+
+    # channel position expected by the backend-specific implementation
+    rescale: Any
+    if isinstance(x, ep.PyTorchTensor):
+        rescale, native_axis = rescale_pytorch, 1
+    elif isinstance(x, ep.TensorFlowTensor):
+        rescale, native_axis = rescale_tensorflow, x.ndim - 1
+    elif isinstance(x, ep.NumPyTensor):
+        rescale, native_axis = rescale_numpy, x.ndim - 1
+    elif isinstance(x, ep.JAXTensor):
+        rescale, native_axis = rescale_jax, x.ndim - 1
+    else:
+        return x
+
+    needs_swap = channel_axis != native_axis
+    if needs_swap:
+        x = swap_axes(x, channel_axis, native_axis)  # type: ignore
+        target_shape[channel_axis], target_shape[native_axis] = (
+            target_shape[native_axis],
+            target_shape[channel_axis],
+        )
+
+    x = rescale(x, target_shape)
+
+    if needs_swap:
+        x = swap_axes(x, channel_axis, native_axis)  # type: ignore
+
+    return x
diff --git a/tests/test_attacks.py b/tests/test_attacks.py
index 37c139a6..16ed244c 100644
--- a/tests/test_attacks.py
+++ b/tests/test_attacks.py
@@ -155,6 +155,7 @@ def test_untargeted_attacks(
         True,
         False,
     ),
+    (fa.GenAttack(steps=100, population=6, reduced_dims=(14, 14)), 0.3, False, True),
 ]
 
 
diff --git a/tests/test_attacks_raise.py b/tests/test_attacks_raise.py
index ea4e7528..8d66e4cb 100644
--- a/tests/test_attacks_raise.py
+++ b/tests/test_attacks_raise.py
@@ -12,6 +12,33 @@ def test_ead_init_raises() -> None:
         fbn.attacks.EADAttack(binary_search_steps=3, steps=20, decision_rule="invalid")  # type: ignore
 
 
+def test_genattack_numpy(request: Any) -> None:
+    class Model:
+        def __call__(self, inputs: Any) -> Any:
+            return inputs.mean(axis=(2, 3))
+
+    model = Model()
+    with pytest.raises(ValueError):
+        fbn.NumPyModel(model, bounds=(0, 1), data_format="foo")
+
+    fmodel = fbn.NumPyModel(model, bounds=(0, 1))
+    x, y = ep.astensors(
+        *fbn.samples(
+            fmodel, dataset="imagenet", batchsize=16, data_format="channels_first"
+        )
+    )
+
+    with pytest.raises(ValueError, match="data_format"):
+        fbn.attacks.GenAttack(reduced_dims=(2, 2)).run(
+            fmodel, x, fbn.TargetedMisclassification(y), epsilon=0.3
+        )
+
+    with pytest.raises(ValueError, match="channel_axis"):
+        fbn.attacks.GenAttack(channel_axis=2, reduced_dims=(2, 2)).run(
+            fmodel, x, fbn.TargetedMisclassification(y), epsilon=0.3
+        )
+
+
 def test_deepfool_run_raises(
     fmodel_and_data_ext_for_attacks: Tuple[Tuple[fbn.Model, ep.Tensor, ep.Tensor], bool]
 ) -> None:
@@ -153,6 +180,7 @@ def test_dataset_attack_raises(
     (fbn.attacks.EADAttack(), True),
     (fbn.attacks.DDNAttack(), True),
     (fbn.attacks.L2CarliniWagnerAttack(), True),
+    (fbn.attacks.GenAttack(), False),
 ]
 
 
diff --git a/tests/test_gen_attack_utils.py b/tests/test_gen_attack_utils.py
new file mode 100644
index 00000000..64849279
--- /dev/null
+++ b/tests/test_gen_attack_utils.py
@@ -0,0 +1,82 @@
+import eagerpy as ep
+
+from foolbox.attacks.gen_attack_utils import rescale_images
+
+
+def test_pytorch_numpy_compatibility() -> None:
+    import numpy as np
+    import torch
+
+    x_np = np.random.uniform(0.0, 1.0, size=(16, 3, 64, 64))
+    x_torch = torch.from_numpy(x_np)
+
+    x_np_ep = ep.astensor(x_np)
+    x_torch_ep = ep.astensor(x_torch)
+
+    x_up_np_ep = rescale_images(x_np_ep, (16, 3, 128, 128), 1)
+    x_up_torch_ep = rescale_images(x_torch_ep, (16, 3, 128, 128), 1)
+
+    x_up_np = x_up_np_ep.numpy()
+    x_up_torch = x_up_torch_ep.numpy()
+
+    assert np.allclose(x_up_np, x_up_torch)
+
+
+def test_pytorch_numpy_compatibility_different_axis() -> None:
+    import numpy as np
+    import torch
+
+    x_np = np.random.uniform(0.0, 1.0, size=(16, 64, 64, 3))
+    x_torch = torch.from_numpy(x_np)
+
+    x_np_ep = ep.astensor(x_np)
+    x_torch_ep = ep.astensor(x_torch)
+
+    x_up_np_ep = rescale_images(x_np_ep, (16, 128, 128, 3), -1)
+    x_up_torch_ep = rescale_images(x_torch_ep, (16, 128, 128, 3), -1)
+
+    x_up_np = x_up_np_ep.numpy()
+    x_up_torch = x_up_torch_ep.numpy()
+
+    assert x_up_np.shape == (16, 128, 128, 3)
+    assert np.allclose(x_up_np, x_up_torch)
+
+
+def test_pytorch_tensorflow_compatibility() -> None:
+    import numpy as np
+    import torch
+    import tensorflow as tf
+
+    x_np = np.random.uniform(0.0, 1.0, size=(16, 3, 64, 64))
+    x_torch = torch.from_numpy(x_np)
+    x_tf = tf.convert_to_tensor(x_np)
+
+    x_tf_ep = ep.astensor(x_tf)
+    x_torch_ep = ep.astensor(x_torch)
+
+    x_up_tf_ep = rescale_images(x_tf_ep, (16, 3, 128, 128), 1)
+    x_up_torch_ep = rescale_images(x_torch_ep, (16, 3, 128, 128), 1)
+
+    x_up_tf = x_up_tf_ep.numpy()
+    x_up_torch = x_up_torch_ep.numpy()
+
+    assert np.allclose(x_up_tf, x_up_torch)
+
+
+def test_jax_numpy_compatibility() -> None:
+    import numpy as np
+    import jax.numpy as jnp
+
+    x_np = np.random.uniform(0.0, 1.0, size=(16, 3, 64, 64))
+    x_jax = jnp.array(x_np)
+
+    x_np_ep = ep.astensor(x_np)
+    x_jax_ep = ep.astensor(x_jax)
+
+    x_up_np_ep = rescale_images(x_np_ep, (16, 3, 128, 128), 1)
+    x_up_jax_ep = rescale_images(x_jax_ep, (16, 3, 128, 128), 1)
+
+    x_up_np = x_up_np_ep.numpy()
+    x_up_jax = x_up_jax_ep.numpy()
+
+    assert np.allclose(x_up_np, x_up_jax)