Skip to content

Commit

Permalink
Merge pull request #577 from bethgelab/clipping_aware
Browse files Browse the repository at this point in the history
added clipping-aware Gaussian and uniform noise attacks
  • Loading branch information
jonasrauber committed Aug 14, 2020
2 parents 7346233 + 3f5c9ef commit 7e68eb6
Show file tree
Hide file tree
Showing 9 changed files with 210 additions and 19 deletions.
8 changes: 8 additions & 0 deletions docs/modules/attacks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,13 @@

L2AdditiveGaussianNoiseAttack
L2AdditiveUniformNoiseAttack
L2ClippingAwareAdditiveGaussianNoiseAttack
L2ClippingAwareAdditiveUniformNoiseAttack
LinfAdditiveUniformNoiseAttack
L2RepeatedAdditiveGaussianNoiseAttack
L2RepeatedAdditiveUniformNoiseAttack
L2ClippingAwareRepeatedAdditiveGaussianNoiseAttack
L2ClippingAwareRepeatedAdditiveUniformNoiseAttack
LinfRepeatedAdditiveUniformNoiseAttack
InversionAttack
BinarySearchContrastReductionAttack
Expand Down Expand Up @@ -62,9 +66,13 @@

.. autoclass:: L2AdditiveGaussianNoiseAttack
.. autoclass:: L2AdditiveUniformNoiseAttack
.. autoclass:: L2ClippingAwareAdditiveGaussianNoiseAttack
.. autoclass:: L2ClippingAwareAdditiveUniformNoiseAttack
.. autoclass:: LinfAdditiveUniformNoiseAttack
.. autoclass:: L2RepeatedAdditiveGaussianNoiseAttack
.. autoclass:: L2RepeatedAdditiveUniformNoiseAttack
.. autoclass:: L2ClippingAwareRepeatedAdditiveGaussianNoiseAttack
.. autoclass:: L2ClippingAwareRepeatedAdditiveUniformNoiseAttack
.. autoclass:: LinfRepeatedAdditiveUniformNoiseAttack
.. autoclass:: InversionAttack
.. autoclass:: BinarySearchContrastReductionAttack
Expand Down
4 changes: 4 additions & 0 deletions foolbox/attacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,13 @@
from .additive_noise import ( # noqa: F401
L2AdditiveGaussianNoiseAttack,
L2AdditiveUniformNoiseAttack,
L2ClippingAwareAdditiveGaussianNoiseAttack,
L2ClippingAwareAdditiveUniformNoiseAttack,
LinfAdditiveUniformNoiseAttack,
L2RepeatedAdditiveGaussianNoiseAttack,
L2RepeatedAdditiveUniformNoiseAttack,
L2ClippingAwareRepeatedAdditiveGaussianNoiseAttack,
L2ClippingAwareRepeatedAdditiveUniformNoiseAttack,
LinfRepeatedAdditiveUniformNoiseAttack,
)
from .sparse_l1_descent_attack import SparseL1DescentAttack # noqa: F401
Expand Down
131 changes: 113 additions & 18 deletions foolbox/attacks/additive_noise.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Union, Any
from typing import Union, Any, cast
from abc import ABC
from abc import abstractmethod
import eagerpy as ep
Expand All @@ -16,6 +16,8 @@
from .base import get_is_adversarial
from .base import raise_if_kwargs

from ..external.clipping_aware_rescaling import l2_clipping_aware_rescaling


class BaseAdditiveNoiseAttack(FixedEpsilonAttack, ABC):
def run(
Expand All @@ -33,9 +35,8 @@ def run(

min_, max_ = model.bounds
p = self.sample_noise(x)
norms = self.get_norms(p)
p = p / atleast_kd(norms, p.ndim)
x = x + epsilon * p
epsilons = self.get_epsilons(x, p, epsilon, min_=min_, max_=max_)
x = x + epsilons * p
x = x.clip(min_, max_)

return restore_type(x)
Expand All @@ -45,22 +46,41 @@ def sample_noise(self, x: ep.Tensor) -> ep.Tensor:
raise NotImplementedError

@abstractmethod
def get_norms(self, p: ep.Tensor) -> ep.Tensor:
def get_epsilons(
self, x: ep.Tensor, p: ep.Tensor, epsilon: float, min_: float, max_: float
) -> ep.Tensor:
raise NotImplementedError


class L2Mixin:
distance = l2

def get_norms(self, p: ep.Tensor) -> ep.Tensor:
return flatten(p).norms.l2(axis=-1)
def get_epsilons(
self, x: ep.Tensor, p: ep.Tensor, epsilon: float, min_: float, max_: float
) -> ep.Tensor:
norms = flatten(p).norms.l2(axis=-1)
return epsilon / atleast_kd(norms, p.ndim)


class L2ClippingAwareMixin:
distance = l2

def get_epsilons(
self, x: ep.Tensor, p: ep.Tensor, epsilon: float, min_: float, max_: float
) -> ep.Tensor:
return cast(
ep.Tensor, l2_clipping_aware_rescaling(x, p, epsilon, a=min_, b=max_)
)


class LinfMixin:
distance = linf

def get_norms(self, p: ep.Tensor) -> ep.Tensor:
return flatten(p).max(axis=-1)
def get_epsilons(
self, x: ep.Tensor, p: ep.Tensor, epsilon: float, min_: float, max_: float
) -> ep.Tensor:
norms = flatten(p).max(axis=-1)
return epsilon / atleast_kd(norms, p.ndim)


class GaussianMixin:
Expand All @@ -74,13 +94,47 @@ def sample_noise(self, x: ep.Tensor) -> ep.Tensor:


class L2AdditiveGaussianNoiseAttack(L2Mixin, GaussianMixin, BaseAdditiveNoiseAttack):
"""Samples Gaussian noise with a fixed L2 size"""
"""Samples Gaussian noise with a fixed L2 size."""

pass


class L2AdditiveUniformNoiseAttack(L2Mixin, UniformMixin, BaseAdditiveNoiseAttack):
"""Samples uniform noise with a fixed L2 size"""
"""Samples uniform noise with a fixed L2 size."""

pass


class L2ClippingAwareAdditiveGaussianNoiseAttack(
L2ClippingAwareMixin, GaussianMixin, BaseAdditiveNoiseAttack
):
"""Samples Gaussian noise with a fixed L2 size after clipping.
The implementation is based on [#Rauber20]_.
References:
.. [#Rauber20] Jonas Rauber, Matthias Bethge
"Fast Differentiable Clipping-Aware Normalization and Rescaling"
https://arxiv.org/abs/2007.07677
"""

pass


class L2ClippingAwareAdditiveUniformNoiseAttack(
L2ClippingAwareMixin, UniformMixin, BaseAdditiveNoiseAttack
):
"""Samples uniform noise with a fixed L2 size after clipping.
The implementation is based on [#Rauber20]_.
References:
.. [#Rauber20] Jonas Rauber, Matthias Bethge
"Fast Differentiable Clipping-Aware Normalization and Rescaling"
https://arxiv.org/abs/2007.07677
"""

pass

Expand Down Expand Up @@ -125,9 +179,8 @@ def run(
break

p = self.sample_noise(x0)
norms = self.get_norms(p)
p = p / atleast_kd(norms, p.ndim)
x = x0 + epsilon * p
epsilons = self.get_epsilons(x0, p, epsilon, min_=min_, max_=max_)
x = x0 + epsilons * p
x = x.clip(min_, max_)
is_adv = is_adversarial(x)
is_new_adv = ep.logical_and(is_adv, ep.logical_not(found))
Expand All @@ -141,14 +194,16 @@ def sample_noise(self, x: ep.Tensor) -> ep.Tensor:
raise NotImplementedError

@abstractmethod
def get_norms(self, p: ep.Tensor) -> ep.Tensor:
def get_epsilons(
self, x: ep.Tensor, p: ep.Tensor, epsilon: float, min_: float, max_: float
) -> ep.Tensor:
raise NotImplementedError


class L2RepeatedAdditiveGaussianNoiseAttack(
L2Mixin, GaussianMixin, BaseRepeatedAdditiveNoiseAttack
):
"""Repeatedly samples Gaussian noise with a fixed L2 size
"""Repeatedly samples Gaussian noise with a fixed L2 size.
Args:
repeats : How often to sample random noise.
Expand All @@ -161,7 +216,47 @@ class L2RepeatedAdditiveGaussianNoiseAttack(
class L2RepeatedAdditiveUniformNoiseAttack(
L2Mixin, UniformMixin, BaseRepeatedAdditiveNoiseAttack
):
"""Repeatedly samples uniform noise with a fixed L2 size
"""Repeatedly samples uniform noise with a fixed L2 size.
Args:
repeats : How often to sample random noise.
check_trivial : Check whether original sample is already adversarial.
"""

pass


class L2ClippingAwareRepeatedAdditiveGaussianNoiseAttack(
L2ClippingAwareMixin, GaussianMixin, BaseRepeatedAdditiveNoiseAttack
):
"""Repeatedly samples Gaussian noise with a fixed L2 size after clipping.
The implementation is based on [#Rauber20]_.
References:
.. [#Rauber20] Jonas Rauber, Matthias Bethge
"Fast Differentiable Clipping-Aware Normalization and Rescaling"
https://arxiv.org/abs/2007.07677
Args:
repeats : How often to sample random noise.
check_trivial : Check whether original sample is already adversarial.
"""

pass


class L2ClippingAwareRepeatedAdditiveUniformNoiseAttack(
L2ClippingAwareMixin, UniformMixin, BaseRepeatedAdditiveNoiseAttack
):
"""Repeatedly samples uniform noise with a fixed L2 size after clipping.
The implementation is based on [#Rauber20]_.
References:
.. [#Rauber20] Jonas Rauber, Matthias Bethge
"Fast Differentiable Clipping-Aware Normalization and Rescaling"
https://arxiv.org/abs/2007.07677
Args:
repeats : How often to sample random noise.
Expand All @@ -174,7 +269,7 @@ class L2RepeatedAdditiveUniformNoiseAttack(
class LinfRepeatedAdditiveUniformNoiseAttack(
LinfMixin, UniformMixin, BaseRepeatedAdditiveNoiseAttack
):
"""Repeatedly samples uniform noise with a fixed L-infinity size
"""Repeatedly samples uniform noise with a fixed L-infinity size.
Args:
repeats : How often to sample random noise.
Expand Down
1 change: 1 addition & 0 deletions foolbox/external/LICENSE
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
The code in this subfolder might be under a different license than the rest of the project.
9 changes: 9 additions & 0 deletions foolbox/external/README.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
License
-------

The code in this subfolder might be under a different license than the rest of the project.

Sources
-------

* `clipping_aware_rescaling.py <https://github.com/jonasrauber/clipping-aware-rescaling>`_
1 change: 1 addition & 0 deletions foolbox/external/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from . import clipping_aware_rescaling # noqa: F401
64 changes: 64 additions & 0 deletions foolbox/external/clipping_aware_rescaling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Copyright (c) 2020, Jonas Rauber
#
# Licensed under the BSD 3-Clause License
#
# Last changed:
# * 2020-07-15
# * 2020-01-08
# * 2019-04-18

import eagerpy as ep


def l2_clipping_aware_rescaling(x, delta, eps: float, a: float = 0.0, b: float = 1.0): # type: ignore
"""Calculates eta such that norm(clip(x + eta * delta, a, b) - x) == eps.
Assumes x and delta have a batch dimension and eps, a, b, and p are
scalars. If the equation cannot be solved because eps is too large, the
left hand side is maximized.
Args:
x: A batch of inputs (PyTorch Tensor, TensorFlow Eager Tensor, NumPy
Array, JAX Array, or EagerPy Tensor).
delta: A batch of perturbation directions (same shape and type as x).
eps: The target norm (non-negative float).
a: The lower bound of the data domain (float).
b: The upper bound of the data domain (float).
Returns:
eta: A batch of scales with the same number of dimensions as x but all
axis == 1 except for the batch dimension.
"""
(x, delta), restore_fn = ep.astensors_(x, delta)
N = x.shape[0]
assert delta.shape[0] == N
rows = ep.arange(x, N)

delta2 = delta.square().reshape((N, -1))
space = ep.where(delta >= 0, b - x, x - a).reshape((N, -1))
f2 = space.square() / ep.maximum(delta2, 1e-20)
ks = ep.argsort(f2, axis=-1)
f2_sorted = f2[rows[:, ep.newaxis], ks]
m = ep.cumsum(delta2[rows[:, ep.newaxis], ks.flip(axis=1)], axis=-1).flip(axis=1)
dx = f2_sorted[:, 1:] - f2_sorted[:, :-1]
dx = ep.concatenate((f2_sorted[:, :1], dx), axis=-1)
dy = m * dx
y = ep.cumsum(dy, axis=-1)
c = y >= eps ** 2

# work-around to get first nonzero element in each row
f = ep.arange(x, c.shape[-1], 0, -1)
j = ep.argmax(c.astype(f.dtype) * f, axis=-1)

eta2 = f2_sorted[rows, j] - (y[rows, j] - eps ** 2) / m[rows, j]
# it can happen that for certain rows even the largest j is not large enough
# (i.e. c[:, -1] is False), then we will just use it (without any correction) as it's
# the best we can do (this should also be the only cases where m[j] can be
# 0 and they are thus not a problem)
eta2 = ep.where(c[:, -1], eta2, f2_sorted[:, -1])
eta = ep.sqrt(eta2)
eta = eta.reshape((-1,) + (1,) * (x.ndim - 1))

# xp = ep.clip(x + eta * delta, a, b)
# l2 = (xp - x).reshape((N, -1)).square().sum(axis=-1).sqrt()
return restore_fn(eta)
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
"numpy",
"scipy",
"setuptools",
"eagerpy==0.27.0",
"eagerpy==0.29.0",
"GitPython>=3.0.7",
"typing-extensions>=3.7.4.1",
"requests>=2.24.0",
Expand Down
9 changes: 9 additions & 0 deletions tests/test_attacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,15 +100,24 @@ def get_attack_id(x: Tuple[fbn.Attack, bool, bool]) -> str:
(fa.SaltAndPepperNoiseAttack(steps=50, channel_axis=1), None, True, False),
(fa.LinearSearchBlendedUniformNoiseAttack(steps=50), None, False, False),
(fa.L2AdditiveGaussianNoiseAttack(), 2500.0, False, False),
(fa.L2ClippingAwareAdditiveGaussianNoiseAttack(), 500.0, False, False),
(fa.LinfAdditiveUniformNoiseAttack(), 10.0, False, False),
(
fa.L2RepeatedAdditiveGaussianNoiseAttack(check_trivial=False),
1000.0,
False,
False,
),
(
fa.L2ClippingAwareRepeatedAdditiveGaussianNoiseAttack(check_trivial=False),
200.0,
False,
False,
),
(fa.L2RepeatedAdditiveGaussianNoiseAttack(), 1000.0, False, False),
(fa.L2ClippingAwareRepeatedAdditiveGaussianNoiseAttack(), 200.0, False, False),
(fa.L2RepeatedAdditiveUniformNoiseAttack(), 1000.0, False, False),
(fa.L2ClippingAwareRepeatedAdditiveUniformNoiseAttack(), 200.0, False, False),
(fa.LinfRepeatedAdditiveUniformNoiseAttack(), 3.0, False, False),
]

Expand Down

0 comments on commit 7e68eb6

Please sign in to comment.