-
-
Notifications
You must be signed in to change notification settings - Fork 422
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #227 from wielandbrendel/spatial
Implementation of Spatial Attack (WIP)
- Loading branch information
Showing
6 changed files
with
220 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
# -*- coding: utf-8 -*- | ||
from __future__ import division | ||
|
||
import numpy as np | ||
from itertools import product | ||
from scipy.ndimage import rotate, shift | ||
import operator | ||
|
||
from .base import Attack | ||
from .base import call_decorator | ||
|
||
|
||
class SpatialAttack(Attack): | ||
"""Adversarially chosen rotations and translations [1]. | ||
This implementation is based on the reference implementation by | ||
Madry et al.: https://github.com/MadryLab/adversarial_spatial | ||
References | ||
---------- | ||
.. [1] Logan Engstrom*, Brandon Tran*, Dimitris Tsipras*, | ||
Ludwig Schmidt, Aleksander Mądry: "A Rotation and a | ||
Translation Suffice: Fooling CNNs with Simple Transformations", | ||
http://arxiv.org/abs/1712.02779 | ||
""" | ||
|
||
@call_decorator | ||
def __call__(self, input_or_adv, label=None, unpack=True, | ||
do_rotations=True, do_translations=True, | ||
x_shift_limits=(-5, 5), y_shift_limits=(-5, 5), | ||
angular_limits=(-5, 5), granularity=10, | ||
random_sampling=False, abort_early=True): | ||
|
||
"""Adversarially chosen rotations and translations. | ||
Parameters | ||
---------- | ||
input_or_adv : `numpy.ndarray` or :class:`Adversarial` | ||
The original, unperturbed input as a `numpy.ndarray` or | ||
an :class:`Adversarial` instance. | ||
label : int | ||
The reference label of the original input. Must be passed | ||
if `a` is a `numpy.ndarray`, must not be passed if `a` is | ||
an :class:`Adversarial` instance. | ||
unpack : bool | ||
If true, returns the adversarial input, otherwise returns | ||
the Adversarial object. | ||
do_rotations : bool | ||
If False no rotations will be applied to the image. | ||
do_translations : bool | ||
If False no translations will be applied to the image. | ||
x_shift_limits : int or (int, int) | ||
Limits for horizontal translations in pixels. If one integer is | ||
provided the limits will be (-x_shift_limits, x_shift_limits). | ||
y_shift_limits : int or (int, int) | ||
Limits for vertical translations in pixels. If one integer is | ||
provided the limits will be (-y_shift_limits, y_shift_limits). | ||
angular_limits : int or (int, int) | ||
Limits for rotations in degrees. If one integer is | ||
provided the limits will be [-angular_limits, angular_limits]. | ||
granularity : int | ||
Density of sampling within limits for each dimension. | ||
random_sampling : bool | ||
If True we sample translations/rotations randomly within limits, | ||
otherwise we use a regular grid. | ||
abort_early : bool | ||
If True, the attack stops as soon as it finds an adversarial. | ||
""" | ||
|
||
a = input_or_adv | ||
del input_or_adv | ||
del label | ||
del unpack | ||
|
||
min_, max_ = a.bounds() | ||
channel_axis = a.channel_axis(batch=False) | ||
|
||
def get_samples(limits, num, do_flag): | ||
# get regularly spaced or random samples within limits | ||
lb, up = (-limits, limits) if isinstance(limits, int) else limits | ||
|
||
if not do_flag: | ||
return [0] | ||
elif random_sampling: | ||
return np.random.uniform(lb, up, num) | ||
else: | ||
return np.linspace(lb, up, num) | ||
|
||
def crop_center(img): | ||
# crop center of the image (of the size of the original image) | ||
start = tuple(map(lambda a, da: (a - da) // 2, img.shape, | ||
a.original_image.shape)) | ||
end = tuple(map(operator.add, start, a.original_image.shape)) | ||
slices = tuple(map(slice, start, end)) | ||
return img[slices] | ||
|
||
x_shifts = get_samples(x_shift_limits, granularity, do_translations) | ||
y_shifts = get_samples(y_shift_limits, granularity, do_translations) | ||
rotations = get_samples(angular_limits, granularity, do_rotations) | ||
|
||
transformations = product(x_shifts, y_shifts, rotations) | ||
|
||
for x_shift, y_shift, angle in transformations: | ||
if channel_axis == 0: | ||
xy_shift = (0, x_shift, y_shift) | ||
axes = (1, 2) | ||
elif channel_axis == 2: | ||
xy_shift = (x_shift, y_shift, 0) | ||
axes = (0, 1) | ||
else: # pragma: no cover | ||
raise ValueError('SpatialAttack only supports models ' | ||
'and inputs with NCHW or NHWC format') | ||
|
||
# rotate image (increases size) | ||
x = a.original_image | ||
x = rotate(x, angle=angle, axes=axes, reshape=True, order=1) | ||
|
||
# translate image | ||
x = shift(x, shift=xy_shift, mode='constant') | ||
|
||
# crop center | ||
x = crop_center(x) | ||
|
||
# ensure values are in range | ||
x = np.clip(x, min_, max_) | ||
|
||
# test image | ||
_, is_adv = a.predictions(x) | ||
|
||
if abort_early and is_adv: | ||
break |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
import numpy as np | ||
|
||
from foolbox.attacks import SpatialAttack as Attack | ||
|
||
|
||
def test_attack_pytorch(bn_adversarial_pytorch): | ||
adv = bn_adversarial_pytorch | ||
attack = Attack() | ||
attack(adv) | ||
assert adv.image is not None | ||
assert adv.distance.value < np.inf | ||
|
||
|
||
def test_attack(bn_adversarial): | ||
adv = bn_adversarial | ||
attack = Attack() | ||
attack(adv) | ||
assert adv.image is not None | ||
assert adv.distance.value < np.inf | ||
|
||
|
||
def test_attack_rnd(bn_adversarial): | ||
adv = bn_adversarial | ||
attack = Attack() | ||
attack(adv, random_sampling=True) | ||
assert adv.image is not None | ||
assert adv.distance.value < np.inf | ||
|
||
|
||
def test_attack_norot(bn_adversarial): | ||
adv = bn_adversarial | ||
attack = Attack() | ||
attack(adv, do_rotations=False) | ||
assert adv.image is not None | ||
assert adv.distance.value < np.inf | ||
|
||
|
||
def test_attack_notrans(bn_adversarial): | ||
adv = bn_adversarial | ||
attack = Attack() | ||
attack(adv, do_translations=False) | ||
assert adv.image is not None | ||
assert adv.distance.value < np.inf | ||
|
||
|
||
def test_attack_notrans_norot(bn_adversarial): | ||
adv = bn_adversarial | ||
attack = Attack() | ||
attack(adv, do_translations=False, do_rotations=False) | ||
assert adv.image is None | ||
assert adv.distance.value == np.inf | ||
|
||
|
||
def test_attack_gl(gl_bn_adversarial): | ||
adv = gl_bn_adversarial | ||
attack = Attack() | ||
attack(adv) | ||
assert adv.image is not None | ||
assert adv.distance.value < np.inf |