Skip to content

Commit

Permalink
Allow a None batch dimension in SPSA for compatibility with generate_…
Browse files Browse the repository at this point in the history
…np()
  • Loading branch information
Jonathan Uesato committed Aug 25, 2018
1 parent a2bfa20 commit 784db5a
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 5 deletions.
9 changes: 7 additions & 2 deletions cleverhans/attacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1735,8 +1735,8 @@ def __init__(self, model, back='tf', sess=None, dtypestr='float32'):

self.feedable_kwargs = {
'epsilon': self.np_dtype,
'y': self.np_dtype,
'y_target': self.np_dtype,
'y': np.int32,
'y_target': np.int32,
}
self.structural_kwargs = [
'num_steps',
Expand Down Expand Up @@ -1821,3 +1821,8 @@ def loss_fn(x, label):
is_debug=is_debug,
)
return adv_x

def generate_np(self, x_val, **kwargs):
# Add shape check for batch size=1, then call parent class generate_np
assert x_val.shape[0] == 1, 'x_val should be a batch of a single image'
return super(SPSA, self).generate_np(x_val, **kwargs)
8 changes: 8 additions & 0 deletions cleverhans/attacks_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -1843,6 +1843,14 @@ def pgd_attack(loss_fn,
methods. The method uses a tf.while_loop to optimize a loss function in
a single sess.run() call.
"""
# If batch dimension is None, reshape to explicit batch size 1
static_x_shape = input_image.get_shape().as_list()
static_y_shape = label.get_shape().as_list()
if static_x_shape[0] is None or static_y_shape[0] is None:
assert static_x_shape[0] is None and static_y_shape[0] is None
input_image = tf.reshape(input_image, [1] + static_x_shape[1:])
label = tf.reshape(label, [1] + static_y_shape[1:])

if is_debug:
with tf.device("/cpu:0"):
input_image = tf.Print(
Expand Down
25 changes: 22 additions & 3 deletions tests_tf/test_attacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,9 +268,6 @@ def setUp(self):
self.attack = SPSA(self.model, sess=self.sess)

def test_attack_strength(self):
# This uses the existing input structure for SPSA. Tom tried for ~40
# minutes to get generate_np to work correctly but could not.

n_samples = 10
x_val = np.random.rand(n_samples, 2)
x_val = np.array(x_val, dtype=np.float32)
Expand Down Expand Up @@ -299,6 +296,28 @@ def test_attack_strength(self):
new_labs = np.argmax(self.sess.run(self.model(x_adv)), axis=1)
self.assertTrue(np.mean(feed_labs == new_labs) < 0.1)

def test_attack_strength_np(self):
# Same test as test_attack_strength, but uses generate_np interface
n_samples = 10
x_val = np.random.rand(n_samples, 2)
x_val = np.array(x_val, dtype=np.float32)

feed_labs = np.random.randint(0, 2, n_samples)

all_x_adv = []
for i in range(n_samples):
x_adv_np = self.attack.generate_np(
np.expand_dims(x_val[i], axis=0),
y=np.expand_dims(feed_labs[i], axis=0),
epsilon=.5, num_steps=100, batch_size=64, spsa_iters=1,
)
all_x_adv.append(x_adv_np[0])

x_adv = np.vstack(all_x_adv)
new_labs = np.argmax(self.sess.run(self.model(x_adv)), axis=1)
self.assertTrue(np.mean(feed_labs == new_labs) < 0.1)



class TestBasicIterativeMethod(TestFastGradientMethod):
def setUp(self):
Expand Down

0 comments on commit 784db5a

Please sign in to comment.