Merge pull request #510 from bethgelab/sparse_l1

implemented the SparseL1DescentAttack (previous impl. was wrong)

jonasrauber committed Mar 22, 2020
commit fa82566, 2 parents: 2312628 + fed9add

Showing 4 changed files with 106 additions and 5 deletions.
1 change: 1 addition & 0 deletions foolbox/attacks/__init__.py
@@ -27,6 +27,7 @@
     L2RepeatedAdditiveUniformNoiseAttack,
     LinfRepeatedAdditiveUniformNoiseAttack,
 )
+from .sparse_l1_descent_attack import SparseL1DescentAttack  # noqa: F401

 # MinimizationAttack subclasses
 from .inversion import InversionAttack  # noqa: F401
20 changes: 15 additions & 5 deletions foolbox/attacks/gradient_descent_base.py
@@ -5,6 +5,8 @@
 from ..devutils import flatten
 from ..devutils import atleast_kd

+from ..types import Bounds
+
 from ..models.base import Model

 from ..criteria import Misclassification
@@ -84,7 +86,7 @@ def run(

         for _ in range(self.steps):
             _, gradients = self.value_and_grad(loss_fn, x)
-            gradients = self.normalize(gradients)
+            gradients = self.normalize(gradients, x=x, bounds=model.bounds)
             x = x + stepsize * gradients
             x = self.project(x, x0, epsilon)
             x = ep.clip(x, *model.bounds)
@@ -96,7 +98,9 @@ def get_random_start(self, x0: ep.Tensor, epsilon: float) -> ep.Tensor:
         ...

     @abstractmethod
-    def normalize(self, gradients: ep.Tensor) -> ep.Tensor:
+    def normalize(
+        self, gradients: ep.Tensor, *, x: ep.Tensor, bounds: Bounds
+    ) -> ep.Tensor:
         ...

     @abstractmethod
@@ -163,7 +167,9 @@ def get_random_start(self, x0: ep.Tensor, epsilon: float) -> ep.Tensor:
         r = uniform_l1_n_balls(x0, batch_size, n).reshape(x0.shape)
         return x0 + epsilon * r

-    def normalize(self, gradients: ep.Tensor) -> ep.Tensor:
+    def normalize(
+        self, gradients: ep.Tensor, *, x: ep.Tensor, bounds: Bounds
+    ) -> ep.Tensor:
         return normalize_lp_norms(gradients, p=1)

     def project(self, x: ep.Tensor, x0: ep.Tensor, epsilon: float) -> ep.Tensor:
@@ -178,7 +184,9 @@ def get_random_start(self, x0: ep.Tensor, epsilon: float) -> ep.Tensor:
         r = uniform_l2_n_balls(x0, batch_size, n).reshape(x0.shape)
         return x0 + epsilon * r

-    def normalize(self, gradients: ep.Tensor) -> ep.Tensor:
+    def normalize(
+        self, gradients: ep.Tensor, *, x: ep.Tensor, bounds: Bounds
+    ) -> ep.Tensor:
         return normalize_lp_norms(gradients, p=2)

     def project(self, x: ep.Tensor, x0: ep.Tensor, epsilon: float) -> ep.Tensor:
@@ -191,7 +199,9 @@ class LinfBaseGradientDescent(BaseGradientDescent):
     def get_random_start(self, x0: ep.Tensor, epsilon: float) -> ep.Tensor:
         return x0 + ep.uniform(x0, x0.shape, -epsilon, epsilon)

-    def normalize(self, gradients: ep.Tensor) -> ep.Tensor:
+    def normalize(
+        self, gradients: ep.Tensor, *, x: ep.Tensor, bounds: Bounds
+    ) -> ep.Tensor:
         return gradients.sign()

     def project(self, x: ep.Tensor, x0: ep.Tensor, epsilon: float) -> ep.Tensor:
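The normalize hook now receives the current inputs and the model bounds so that subclasses (like the new SparseL1DescentAttack below) can take the valid pixel range into account. The L1 and L2 base classes still delegate to normalize_lp_norms (Linf simply takes the sign). The helper's internals are not part of this diff; the following minimal NumPy sketch (hypothetical name normalize_lp_norms_sketch) shows the per-sample Lp rescaling it is assumed to perform, based on how it is used here:

import numpy as np

def normalize_lp_norms_sketch(gradients: np.ndarray, p: int) -> np.ndarray:
    # rescale every sample in the batch to unit Lp norm
    flat = gradients.reshape(len(gradients), -1)
    norms = np.linalg.norm(flat, ord=p, axis=-1)
    factors = 1.0 / np.maximum(norms, 1e-12)  # guard against all-zero gradients
    return gradients * factors.reshape(-1, *([1] * (gradients.ndim - 1)))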
89 changes: 89 additions & 0 deletions foolbox/attacks/sparse_l1_descent_attack.py
@@ -0,0 +1,89 @@
from typing import Optional
import eagerpy as ep
import numpy as np

from ..devutils import flatten
from ..devutils import atleast_kd

from ..types import Bounds

from .gradient_descent_base import L1BaseGradientDescent
from .gradient_descent_base import normalize_lp_norms


class SparseL1DescentAttack(L1BaseGradientDescent):
    """Sparse L1 Descent Attack [#Tra19]_.

    Args:
        quantile: Quantile of the per-sample gradient magnitudes below which
            coordinates are zeroed out, controlling the sparsity of each step.
        rel_stepsize: Stepsize relative to epsilon.
        abs_stepsize: If given, it takes precedence over rel_stepsize.
        steps: Number of update steps.
        random_start: Controls whether to randomly start within the allowed
            epsilon ball.

    References:
        .. [#Tra19] Florian Tramèr, Dan Boneh, "Adversarial Training and
            Robustness for Multiple Perturbations"
            https://arxiv.org/abs/1904.13000
    """

    def normalize(
        self, gradients: ep.Tensor, *, x: ep.Tensor, bounds: Bounds
    ) -> ep.Tensor:
        # zero the gradient for pixels that already sit at a bound and whose
        # gradient would push them further outside the valid range
        bad_pos = ep.logical_or(
            ep.logical_and(x == bounds.lower, gradients < 0),
            ep.logical_and(x == bounds.upper, gradients > 0),
        )
        gradients = ep.where(bad_pos, ep.zeros_like(gradients), gradients)

        abs_gradients = gradients.abs()
        quantiles = np.quantile(
            flatten(abs_gradients).numpy(), q=self.quantile, axis=-1
        )
        keep = abs_gradients >= atleast_kd(
            ep.from_numpy(gradients, quantiles), gradients.ndim
        )
        e = ep.where(keep, gradients.sign(), ep.zeros_like(gradients))
        return normalize_lp_norms(e, p=1)

    def project(self, x: ep.Tensor, x0: ep.Tensor, epsilon: float) -> ep.Tensor:
        # based on https://github.com/ftramer/MultiRobustness/blob/ad41b63235d13b1b2a177c5f270ab9afa74eee69/pgd_attack.py#L110
        delta = flatten(x - x0)
        norms = delta.norms.l1(axis=-1)
        if (norms <= epsilon).all():
            return x

        n, d = delta.shape
        abs_delta = abs(delta)
        mu = -ep.sort(-abs_delta, axis=-1)
        cumsums = mu.cumsum(axis=-1)
        js = 1.0 / ep.arange(x, 1, d + 1).astype(x.dtype)
        temp = mu - js * (cumsums - epsilon)
        guarantee_first = ep.arange(x, d).astype(x.dtype) / d
        # guarantee_first are small values (< 1) that we add to the boolean
        # tensor (only 0 and 1) to break the ties and always return the first
        # argmin, i.e. the first value where the boolean tensor is 0
        # (otherwise, this is not guaranteed on GPUs, see e.g. PyTorch);
        # subtracting 1 then yields the last index where temp is positive
        rho = ep.argmin((temp > 0).astype(x.dtype) + guarantee_first, axis=-1) - 1
        theta = 1.0 / (1 + rho.astype(x.dtype)) * (cumsums[range(n), rho] - epsilon)
        delta = delta.sign() * ep.maximum(abs_delta - theta[..., ep.newaxis], 0)
        delta = delta.reshape(x.shape)
        return x0 + delta

    def __init__(
        self,
        *,
        quantile: float = 0.99,
        rel_stepsize: float = 0.2,
        abs_stepsize: Optional[float] = None,
        steps: int = 10,
        random_start: bool = False,
    ):
        super().__init__(
            rel_stepsize=rel_stepsize,
            abs_stepsize=abs_stepsize,
            steps=steps,
            random_start=random_start,
        )
        if not 0 <= quantile <= 1:
            raise ValueError(f"quantile needs to be between 0 and 1, got {quantile}")
        self.quantile = quantile
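For intuition, normalize keeps only the strongest gradient coordinates: everything below the per-sample quantile of |gradient| is zeroed, the survivors are reduced to their signs, and the result is rescaled to unit L1 norm. A small self-contained NumPy example of the same thresholding (toy values, not part of the diff):

import numpy as np

g = np.array([[0.05, -0.8, 0.1, 0.02, 0.3]])  # one sample, toy gradient
threshold = np.quantile(np.abs(g), 0.6, axis=-1, keepdims=True)
e = np.where(np.abs(g) >= threshold, np.sign(g), 0.0)  # sparse sign vector
e /= np.abs(e).sum(axis=-1, keepdims=True)  # unit L1 norm per sample
print(e)  # [[ 0.  -0.5  0.   0.   0.5]] -- only the two largest coordinates move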
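project is the batched, sort-based Euclidean projection onto the L1 ball (cf. Duchi et al., 2008, as used in the referenced MultiRobustness code): soft-threshold the magnitudes at a level theta chosen so the result has L1 norm epsilon. A single-vector NumPy sketch of the same algorithm (hypothetical helper name; the foolbox code obtains the same rho index via the argmin trick commented above):

import numpy as np

def project_l1_ball(delta: np.ndarray, epsilon: float) -> np.ndarray:
    # Euclidean projection of one flattened perturbation onto the
    # L1 ball of radius epsilon (soft-thresholding at level theta)
    if np.abs(delta).sum() <= epsilon:
        return delta
    mu = np.sort(np.abs(delta))[::-1]  # magnitudes in descending order
    cumsums = np.cumsum(mu)
    js = np.arange(1, delta.size + 1)
    cond = mu - (cumsums - epsilon) / js > 0
    rho = np.max(np.nonzero(cond)[0])  # last index where the condition holds
    theta = (cumsums[rho] - epsilon) / (rho + 1)
    return np.sign(delta) * np.maximum(np.abs(delta) - theta, 0.0)

print(project_l1_ball(np.array([4.0, 1.0]), 2.0))  # -> [2. 0.]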
1 change: 1 addition & 0 deletions tests/test_attacks.py
@@ -75,6 +75,7 @@ def get_attack_id(x: Tuple[fbn.Attack, bool, bool]) -> str:
     (fa.LinfBasicIterativeAttack(abs_stepsize=0.2), Linf(1.0), True, False),
     (fa.L2BasicIterativeAttack(), L2(50.0), True, False),
     (fa.L1BasicIterativeAttack(), 5000.0, True, False),
+    (fa.SparseL1DescentAttack(), 5000.0, True, False),
     (fa.FGSM(), Linf(100.0), True, False),
     (FGSM_GE(), Linf(100.0), False, False),
     (fa.FGM(), L2(100.0), True, False),
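For completeness, the new attack follows the standard foolbox 3 attack interface, so a typical invocation might look like the sketch below (model, images, and labels are assumed to exist and are not part of this commit):

import foolbox as fb

# fmodel wraps a trained model; images/labels are a small evaluation batch
fmodel = fb.PyTorchModel(model, bounds=(0, 1))
attack = fb.attacks.SparseL1DescentAttack(quantile=0.99, steps=10)
raw, clipped, is_adv = attack(fmodel, images, labels, epsilons=5000.0)
print(is_adv.float().mean().item())  # fraction of successful adversarials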
