Allow a None batch dimension in SPSA for compatibility with generate_np()
Jonathan Uesato committed Aug 25, 2018
1 parent a2bfa20 commit 166f7a8
Showing 3 changed files with 36 additions and 6 deletions.
9 changes: 7 additions & 2 deletions cleverhans/attacks.py
@@ -1735,8 +1735,8 @@ def __init__(self, model, back='tf', sess=None, dtypestr='float32'):
 
         self.feedable_kwargs = {
             'epsilon': self.np_dtype,
-            'y': self.np_dtype,
-            'y_target': self.np_dtype,
+            'y': np.int32,
+            'y_target': np.int32,
         }
         self.structural_kwargs = [
             'num_steps',
@@ -1821,3 +1821,8 @@ def loss_fn(x, label):
             is_debug=is_debug,
         )
         return adv_x
+
+    def generate_np(self, x_val, **kwargs):
+        # Add shape check for batch size=1, then call parent class generate_np
+        assert x_val.shape[0] == 1, 'x_val should be a batch of a single image'
+        return super(SPSA, self).generate_np(x_val, **kwargs)
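For reference, a minimal usage sketch of the new generate_np path, modeled on the test added further down in this commit. The model and session construction is assumed (hypothetical `model` and `sess`), and the attack parameters simply mirror the test's values:

import numpy as np
from cleverhans.attacks import SPSA

# Assumed setup: `model` is a cleverhans Model and `sess` a tf.Session,
# constructed elsewhere (e.g. as in the test suite's setUp).
attack = SPSA(model, sess=sess)

# generate_np now expects a batch containing exactly one example, and
# integer class labels to match the np.int32 feedable kwargs above.
x_val = np.random.rand(1, 2).astype(np.float32)
y_val = np.random.randint(0, 2, 1)

x_adv = attack.generate_np(
    x_val,
    y=y_val,
    epsilon=.5, num_steps=100, batch_size=64, spsa_iters=1,
)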
8 changes: 7 additions & 1 deletion cleverhans/attacks_tf.py
@@ -1760,7 +1760,13 @@ def _get_delta(self, x, delta):
     def _compute_gradients(self, loss_fn, x, unused_optim_state):
         """Compute gradient estimates using SPSA."""
         # Assumes `x` is a list, containing a [1, H, W, C] image
-        assert len(x) == 1 and x[0].get_shape().as_list()[0] == 1
+        # If static batch dimension is None, tf.reshape to batch size 1
+        # so that static shape can be inferred
+        assert len(x) == 1
+        static_x_shape = x[0].get_shape().as_list()
+        if static_x_shape[0] is None:
+            x[0] = tf.reshape(x[0], [1] + static_x_shape[1:])
+        assert x[0].get_shape().as_list()[0] == 1
         x = x[0]
         x_shape = x.get_shape().as_list()

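A minimal sketch of the static-shape behavior this hunk works around, assuming TensorFlow 1.x and illustrative image dimensions (not taken from the commit):

import tensorflow as tf

# A placeholder built with a None batch dimension, as happens when inputs
# are fed in through generate_np.
x_ph = tf.placeholder(tf.float32, shape=[None, 28, 28, 1])
print(x_ph.get_shape().as_list())     # [None, 28, 28, 1] -- batch size unknown statically

# Reshaping to an explicit batch of one restores a fully static shape,
# which the rest of _compute_gradients reads via get_shape().as_list().
static_shape = x_ph.get_shape().as_list()
x_fixed = tf.reshape(x_ph, [1] + static_shape[1:])
print(x_fixed.get_shape().as_list())  # [1, 28, 28, 1]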
25 changes: 22 additions & 3 deletions tests_tf/test_attacks.py
@@ -268,9 +268,6 @@ def setUp(self):
         self.attack = SPSA(self.model, sess=self.sess)
 
     def test_attack_strength(self):
-        # This uses the existing input structure for SPSA. Tom tried for ~40
-        # minutes to get generate_np to work correctly but could not.
-
         n_samples = 10
         x_val = np.random.rand(n_samples, 2)
         x_val = np.array(x_val, dtype=np.float32)
@@ -299,6 +296,28 @@ def test_attack_strength(self):
         new_labs = np.argmax(self.sess.run(self.model(x_adv)), axis=1)
         self.assertTrue(np.mean(feed_labs == new_labs) < 0.1)
 
+    def test_attack_strength_np(self):
+        # Same test as test_attack_strength, but uses generate_np interface
+        n_samples = 10
+        x_val = np.random.rand(n_samples, 2)
+        x_val = np.array(x_val, dtype=np.float32)
+
+        feed_labs = np.random.randint(0, 2, n_samples)
+
+        all_x_adv = []
+        for i in range(n_samples):
+            x_adv_np = self.attack.generate_np(
+                np.expand_dims(x_val[i], axis=0),
+                y=np.expand_dims(feed_labs[i], axis=0),
+                epsilon=.5, num_steps=100, batch_size=64, spsa_iters=1,
+            )
+            all_x_adv.append(x_adv_np[0])
+
+        x_adv = np.vstack(all_x_adv)
+        new_labs = np.argmax(self.sess.run(self.model(x_adv)), axis=1)
+        self.assertTrue(np.mean(feed_labs == new_labs) < 0.1)
+
+
 
 class TestBasicIterativeMethod(TestFastGradientMethod):
     def setUp(self):
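Because the generate_np override asserts a single-example batch, callers with larger batches need a per-example loop like the one in test_attack_strength_np above. A hypothetical convenience wrapper (not part of this commit) could package that loop:

import numpy as np

def spsa_generate_np_batched(attack, x_batch, y_batch, **spsa_kwargs):
    # Hypothetical helper mirroring the loop in test_attack_strength_np:
    # run SPSA's generate_np one example at a time, then stack the
    # resulting adversarial examples back into a single batch.
    all_x_adv = []
    for x_i, y_i in zip(x_batch, y_batch):
        x_adv_i = attack.generate_np(
            np.expand_dims(x_i, axis=0),
            y=np.expand_dims(y_i, axis=0),
            **spsa_kwargs)
        all_x_adv.append(x_adv_i[0])
    return np.stack(all_x_adv, axis=0)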
