diff --git a/docs/modules/attacks.rst b/docs/modules/attacks.rst index fe52c2b9..295819a4 100644 --- a/docs/modules/attacks.rst +++ b/docs/modules/attacks.rst @@ -62,13 +62,14 @@ :nosignatures: BoundaryAttack + SpatialAttack + PointwiseAttack GaussianBlurAttack ContrastReductionAttack AdditiveUniformNoiseAttack AdditiveGaussianNoiseAttack SaltAndPepperNoiseAttack BlendedUniformNoiseAttack - PointwiseAttack .. rubric:: :doc:`attacks/other` diff --git a/docs/modules/attacks/decision.rst b/docs/modules/attacks/decision.rst index 2f8f66f4..a0b10c2c 100644 --- a/docs/modules/attacks/decision.rst +++ b/docs/modules/attacks/decision.rst @@ -7,6 +7,14 @@ Decision-based attacks :members: :special-members: +.. autoclass:: SpatialAttack + :members: + :special-members: + +.. autoclass:: PointwiseAttack + :members: + :special-members: + .. autoclass:: GaussianBlurAttack :members: :special-members: @@ -32,7 +40,3 @@ Decision-based attacks .. autoclass:: BlendedUniformNoiseAttack :members: :special-members: - -.. autoclass:: PointwiseAttack - :members: - :special-members: diff --git a/foolbox/attacks/__init__.py b/foolbox/attacks/__init__.py index 9a5533b8..6eaabcd5 100644 --- a/foolbox/attacks/__init__.py +++ b/foolbox/attacks/__init__.py @@ -19,6 +19,7 @@ from .binarization import BinarizationRefinementAttack from .newtonfool import NewtonFoolAttack from .adef_attack import ADefAttack +from .spatial import SpatialAttack from .carlini_wagner import CarliniWagnerL2Attack from .iterative_projected_gradient import LinfinityBasicIterativeAttack, BasicIterativeMethod, BIM diff --git a/foolbox/attacks/spatial.py b/foolbox/attacks/spatial.py new file mode 100644 index 00000000..cb83cf16 --- /dev/null +++ b/foolbox/attacks/spatial.py @@ -0,0 +1,130 @@ +# -*- coding: utf-8 -*- +from __future__ import division + +import numpy as np +from itertools import product +from scipy.ndimage import rotate, shift +import operator + +from .base import Attack +from .base import call_decorator + + +class SpatialAttack(Attack): + """Adversarially chosen rotations and translations [1]. This implementation + is based on the reference implementation by Madry et al. + https://github.com/MadryLab/adversarial_spatial + + References + ---------- + .. [1] Logan Engstrom*, Brandon Tran*, Dimitris Tsipras*, + Ludwig Schmidt, Aleksander MÄ…dry: "A Rotation and a + Translation Suffice: Fooling CNNs with Simple Transformations", + http://arxiv.org/abs/1712.02779 + """ + + @call_decorator + def __call__(self, input_or_adv, label=None, unpack=True, + do_rotations=True, do_translations=True, + x_shift_limits=(-5, 5), y_shift_limits=(-5, 5), + angular_limits=(-5, 5), granularity=10, + random_sampling=False, abort_early=True): + + """Adversarially chosen rotations and translations. + + Parameters + ---------- + input_or_adv : `numpy.ndarray` or :class:`Adversarial` + The original, unperturbed input as a `numpy.ndarray` or + an :class:`Adversarial` instance. + label : int + The reference label of the original input. Must be passed + if `a` is a `numpy.ndarray`, must not be passed if `a` is + an :class:`Adversarial` instance. + unpack : bool + If true, returns the adversarial input, otherwise returns + the Adversarial object. + do_rotations : bool + If False no rotations will be applied to the image. + do_translations : bool + If False no translations will be applied to the image. + x_shift_limits : int or (int, int) + Limits for horizontal translations in pixels. If one integer is + provided the limits will be (-x_shift_limits, x_shift_limits). + y_shift_limits : int or (int, int) + Limits for vertical translations in pixels. If one integer is + provided the limits will be (-y_shift_limits, y_shift_limits). + angular_limits : int or (int, int) + Limits for rotations in degrees. If one integer is + provided the limits will be [-angular_limits, angular_limits]. + granularity : int + Density of sampling within limits for each dimension. + random_sampling : bool + If True we sample translations/rotations randomly within limits, + otherwise we use a regular grid. + abort_early : bool + If True, the attack stops as soon as it finds an adversarial. + """ + + a = input_or_adv + del input_or_adv + del label + del unpack + + min_, max_ = a.bounds() + channel_axis = a.channel_axis(batch=False) + + def get_samples(limits, num, do_flag): + # get regularly spaced or random samples within limits + lb, up = (-limits, limits) if isinstance(limits, int) else limits + + if not do_flag: + return [0] + elif random_sampling: + return np.random.uniform(lb, up, num) + else: + return np.linspace(lb, up, num) + + def crop_center(img): + # crop center of the image (of the size of the original image) + start = tuple(map(lambda a, da: (a - da) // 2, img.shape, + a.original_image.shape)) + end = tuple(map(operator.add, start, a.original_image.shape)) + slices = tuple(map(slice, start, end)) + return img[slices] + + x_shifts = get_samples(x_shift_limits, granularity, do_translations) + y_shifts = get_samples(y_shift_limits, granularity, do_translations) + rotations = get_samples(angular_limits, granularity, do_rotations) + + transformations = product(x_shifts, y_shifts, rotations) + + for x_shift, y_shift, angle in transformations: + if channel_axis == 0: + xy_shift = (0, x_shift, y_shift) + axes = (1, 2) + elif channel_axis == 2: + xy_shift = (x_shift, y_shift, 0) + axes = (0, 1) + else: + raise ValueError('SpatialAttack only supports models ' + 'and inputs with NCHW or NHWC format') + + # rotate image (increases size) + x = a.original_image + x = rotate(x, angle=angle, axes=axes, reshape=True, order=1) + + # translate image + x = shift(x, shift=xy_shift, mode='constant') + + # crop center + x = crop_center(x) + + # ensure values are in range + x = np.clip(x, min_, max_) + + # test image + _, is_adv = a.predictions(x) + + if abort_early and is_adv: + break diff --git a/foolbox/tests/test_attacks_spatial.py b/foolbox/tests/test_attacks_spatial.py new file mode 100644 index 00000000..1de8dc28 --- /dev/null +++ b/foolbox/tests/test_attacks_spatial.py @@ -0,0 +1,59 @@ +import numpy as np + +from foolbox.attacks import SpatialAttack as Attack + + +def test_attack_pytorch(bn_adversarial_pytorch): + adv = bn_adversarial_pytorch + attack = Attack() + attack(adv) + assert adv.image is not None + assert adv.distance.value < np.inf + + +def test_attack(bn_adversarial): + adv = bn_adversarial + attack = Attack() + attack(adv) + assert adv.image is not None + assert adv.distance.value < np.inf + + +def test_attack_rnd(bn_adversarial): + adv = bn_adversarial + attack = Attack() + attack(adv, random_sampling=True) + assert adv.image is not None + assert adv.distance.value < np.inf + + +def test_attack_norot(bn_adversarial): + adv = bn_adversarial + attack = Attack() + attack(adv, do_rotations=False) + assert adv.image is not None + assert adv.distance.value < np.inf + + +def test_attack_notrans(bn_adversarial): + adv = bn_adversarial + attack = Attack() + attack(adv, do_translations=False) + assert adv.image is not None + assert adv.distance.value < np.inf + + +def test_attack_notrans_norot(bn_adversarial): + adv = bn_adversarial + attack = Attack() + attack(adv, do_translations=False, do_rotations=False) + assert adv.image is None + assert adv.distance.value == np.inf + + +def test_attack_gl(gl_bn_adversarial): + adv = gl_bn_adversarial + attack = Attack() + attack(adv) + assert adv.image is not None + assert adv.distance.value < np.inf