Skip to content

Commit

Permalink
Check that image pixels are within [0.0, 1.0] bounds at the beginning of PGD attack (#498)
Browse files Browse the repository at this point in the history

* Check that image pixels are within [0.0, 1.0] bounds at the beginning of PGD attack

* Fix PEP8 indentation

* Revert "Fix PEP8 indentation"

This reverts commit a85fe56.

* Correctly fix pep8

* Only validate inputs when using default project_perturbation

* Add quick test that SPSA attack works correctly
  • Loading branch information
nottombrown committed Aug 22, 2018
1 parent a6e4e7b commit e309fc4
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 6 deletions.
15 changes: 15 additions & 0 deletions cleverhans/attacks.py
Expand Up @@ -1732,6 +1732,21 @@ class SPSA(Attack):

def __init__(self, model, back='tf', sess=None, dtypestr='float32'):
    """Set up the SPSA attack and declare the keyword arguments it accepts.

    All four arguments are forwarded unchanged to the parent constructor.
    Afterwards two attributes are populated:

    * `feedable_kwargs` maps the names 'epsilon', 'y' and 'y_target' to
      `self.np_dtype`.
    * `structural_kwargs` lists the names of the remaining keyword arguments.

    NOTE(review): both attributes are presumably consumed by the base
    class when building graphs for numpy inputs -- confirm against
    `Attack`.
    """
    super(SPSA, self).__init__(model, back, sess, dtypestr)

    # Every feedable argument shares the attack's numpy dtype.
    feed_dtype = self.np_dtype
    self.feedable_kwargs = {
        name: feed_dtype for name in ('epsilon', 'y', 'y_target')
    }
    self.structural_kwargs = [
        'num_steps', 'batch_size', 'spsa_iters',
        'early_stop_loss_threshold', 'is_debug', 'is_targeted',
    ]

    # The attack only supports models implementing the Model interface.
    assert isinstance(self.model, Model)

def generate(self,
Expand Down
21 changes: 16 additions & 5 deletions cleverhans/attacks_tf.py
Expand Up @@ -1793,9 +1793,16 @@ def cond(i, _):

def _project_perturbation(perturbation, epsilon, input_image):
    """Project `perturbation` onto L-infinity ball of radius `epsilon`.

    The perturbation is first clipped elementwise to [-epsilon, epsilon].
    It is then adjusted so that `input_image + perturbation` stays inside
    the valid pixel range [0, 1].

    Args:
      perturbation: tensor holding the candidate perturbation.
      epsilon: radius of the L-infinity ball.
      input_image: tensor holding the clean input; every element must lie
        in [0, 1].

    Returns:
      The projected perturbation, i.e. the clipped adversarial image minus
      `input_image`.

    Raises:
      tf.errors.InvalidArgumentError: when the returned tensor is evaluated
        and `input_image` has an element outside [0, 1].
    """
    # Ensure inputs are in the correct range. The projection below assumes
    # pixels live in [0, 1]; if they do not, clipping the perturbed image to
    # that range would silently produce a perturbation larger than epsilon.
    range_checks = [
        tf.assert_less_equal(
            input_image, 1.0,
            message="input_image must have no pixel greater than 1.0"),
        tf.assert_greater_equal(
            input_image, 0.0,
            message="input_image must have no pixel less than 0.0"),
    ]
    with tf.control_dependencies(range_checks):
        clipped_perturbation = tf.clip_by_value(
            perturbation, -epsilon, epsilon)
        new_image = tf.clip_by_value(
            input_image + clipped_perturbation, 0., 1.)
        return new_image - input_image


def pgd_attack(loss_fn,
Expand Down Expand Up @@ -1883,11 +1890,15 @@ def cond(i, *_):
loop_vars=[tf.constant(0.), init_perturbation, flat_init_optim_state],
parallel_iterations=1,
back_prop=False)

if project_perturbation == _project_perturbation:
check_diff = tf.assert_less_equal(final_perturbation, epsilon * 1.1)
perturbation_max = epsilon * 1.1
check_diff = tf.assert_less_equal(
final_perturbation, perturbation_max,
message="final_perturbation must change no pixel by more than "
"%s" % perturbation_max)
else:
check_diff = tf.no_op()

with tf.control_dependencies([check_diff]):
adversarial_image = input_image + final_perturbation
return tf.stop_gradient(adversarial_image)
Expand Down
43 changes: 42 additions & 1 deletion tests_tf/test_attacks.py
Expand Up @@ -11,7 +11,7 @@
import numpy as np

from cleverhans.devtools.checks import CleverHansTest
from cleverhans.attacks import Attack
from cleverhans.attacks import Attack, SPSA
from cleverhans.attacks import FastGradientMethod
from cleverhans.attacks import BasicIterativeMethod
from cleverhans.attacks import MomentumIterativeMethod
Expand Down Expand Up @@ -259,6 +259,47 @@ def fn(*x, **y):
tf.gradients = old_grads


class TestSPSA(CleverHansTest):
    """Quick check that the SPSA attack changes a simple model's labels."""

    def setUp(self):
        super(TestSPSA, self).setUp()

        self.sess = tf.Session()
        # Release the session's resources even when the test fails.
        self.addCleanup(self.sess.close)
        self.model = SimpleModel()
        self.attack = SPSA(self.model, sess=self.sess)

    def test_attack_strength(self):
        """After the attack, fewer than 10% of the labels are unchanged."""
        # This uses the existing input structure for SPSA (one example per
        # session run through `generate`); getting `generate_np` to work
        # was attempted without success when this test was written.

        n_samples = 10
        x_val = np.random.rand(n_samples, 2).astype(np.float32)

        # The SPSA attack currently uses non-one-hot labels
        # TODO: change this to use standard cleverhans label conventions
        feed_labs = np.random.randint(0, 2, n_samples)

        # The graph is built once for a single example and fed repeatedly.
        x_input = tf.placeholder(tf.float32, shape=(1, 2))
        y_label = tf.placeholder(tf.int32, shape=(1,))

        x_adv_op = self.attack.generate(
            x_input, y=y_label,
            epsilon=.5, num_steps=100, batch_size=64, spsa_iters=1,
        )

        all_x_adv = []
        for x_row, label in zip(x_val, feed_labs):
            x_adv_np = self.sess.run(x_adv_op, feed_dict={
                x_input: np.expand_dims(x_row, axis=0),
                y_label: np.expand_dims(label, axis=0),
            })
            all_x_adv.append(x_adv_np[0])

        x_adv = np.vstack(all_x_adv)
        new_labs = np.argmax(self.sess.run(self.model(x_adv)), axis=1)
        # assertLess reports the measured agreement rate when the check
        # fails, unlike assertTrue on a pre-evaluated comparison.
        self.assertLess(np.mean(feed_labs == new_labs), 0.1)


class TestBasicIterativeMethod(TestFastGradientMethod):
def setUp(self):
super(TestBasicIterativeMethod, self).setUp()
Expand Down

0 comments on commit e309fc4

Please sign in to comment.