Skip to content

Commit

Permalink
Merge branch 'master' of github.com:bethgelab/foolbox
Browse files Browse the repository at this point in the history
  • Loading branch information
Jonas Rauber committed Jul 13, 2017
2 parents 7bbbecc + 0236724 commit e767c1b
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 4 deletions.
10 changes: 6 additions & 4 deletions foolbox/attacks/lbfgs.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def _optimize(self, a, target_class, epsilon, maxiter, verbose):

# store the shape for later and operate on the flattened image
shape = image.shape
# dtype = image.dtype
dtype = image.dtype
image = image.flatten().astype(np.float64)

n = len(image)
Expand All @@ -134,9 +134,10 @@ def crossentropy(x):
return ce

def loss(x, c):
x = x.astype(dtype)
v1 = distance(x)
v2 = crossentropy(x)
return v1 + c * v2
return np.float64(v1 + c * v2)

else:

Expand All @@ -152,13 +153,14 @@ def crossentropy(x):
return ce, gradient

def loss(x, c):
x = x.astype(dtype)
v1, g1 = distance(x)
v2, g2 = crossentropy(x)
v = v1 + c * v2
g = g1 + c * g2

a = 1e10
return a * v, a * g
return np.float64(a * v), np.float64(a * g)

def lbfgsb(c):
approx_grad_eps = (max_ - min_) / 100
Expand All @@ -174,7 +176,7 @@ def lbfgsb(c):

logging.info(d)

_, is_adversarial = a.predictions(x.reshape(shape))
_, is_adversarial = a.predictions(x.reshape(shape).astype(dtype))
return is_adversarial

# finding initial c
Expand Down
10 changes: 10 additions & 0 deletions foolbox/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def bn_model_pytorch():
class Net(nn.Module):

def forward(self, x):
assert isinstance(x.data, torch.FloatTensor)
x = torch.mean(x, 3)
x = torch.squeeze(x, dim=3)
x = torch.mean(x, 2)
Expand Down Expand Up @@ -238,3 +239,12 @@ def bn_adversarial_pytorch():
image = bn_image_pytorch()
label = bn_label()
return Adversarial(model, criterion, image, label)


@pytest.fixture
def bn_targeted_adversarial_pytorch():
model = bn_model_pytorch()
criterion = bn_targeted_criterion()
image = bn_image_pytorch()
label = bn_label()
return Adversarial(model, criterion, image, label)
8 changes: 8 additions & 0 deletions foolbox/tests/test_attacks_lbfgs.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,11 @@ def test_attack_pytorch(bn_adversarial_pytorch):
attack(adv, verbose=True, num_random_targets=2)
assert adv.image is not None
assert adv.distance.value < np.inf


def test_targeted_attack_pytorch(bn_targeted_adversarial_pytorch):
adv = bn_targeted_adversarial_pytorch
attack = Attack()
attack(adv, verbose=True, num_random_targets=2)
assert adv.image is not None
assert adv.distance.value < np.inf

0 comments on commit e767c1b

Please sign in to comment.