Skip to content

Commit

Permalink
Merge pull request Trusted-AI#39 from Benjamin-Edwards/PGD
Browse files Browse the repository at this point in the history
Projected Gradient Descent
  • Loading branch information
MARIA NICOLAE authored and GitHub Enterprise committed Jul 17, 2018
2 parents 2744a03 + 94036d4 commit 98b29b5
Show file tree
Hide file tree
Showing 5 changed files with 198 additions and 106 deletions.
18 changes: 13 additions & 5 deletions art/attacks/fast_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np

from art.attacks.attack import Attack
from art.utils import get_labels_np_array
from art.utils import get_labels_np_array, random_sphere


class FastGradientMethod(Attack):
Expand All @@ -14,7 +14,7 @@ class FastGradientMethod(Attack):
"""
attack_params = Attack.attack_params + ['norm', 'eps', 'targeted']

def __init__(self, classifier, norm=np.inf, eps=.3, targeted=False):
def __init__(self, classifier, norm=np.inf, eps=.3, targeted=False, random_init=False):
"""
Create a :class:`FastGradientMethod` instance.
Expand All @@ -26,12 +26,15 @@ def __init__(self, classifier, norm=np.inf, eps=.3, targeted=False):
:type eps: `float`
:param targeted: Should the attack target one specific class
:type targeted: `bool`
:param random_init: Whether to start at the original input or a random point within the epsilon ball
:type random_init: `bool`
"""
super(FastGradientMethod, self).__init__(classifier)

self.norm = norm
self.eps = eps
self.targeted = targeted
self.random_init = random_init

def _minimal_perturbation(self, x, y, eps_step=0.1, eps_max=1., **kwargs):
"""Iteratively compute the minimal perturbation necessary to make the class prediction change. Stop when the
Expand Down Expand Up @@ -117,7 +120,7 @@ def generate(self, x, **kwargs):
if 'minimal' in params_cpy and params_cpy[str('minimal')]:
return self._minimal_perturbation(x, y, **params_cpy)

return self._compute(x, y, self.eps)
return self._compute(x, y, self.eps, self.random_init)

def set_params(self, **kwargs):
"""
Expand Down Expand Up @@ -165,8 +168,13 @@ def _apply_perturbation(self, batch, perturbation, eps):
clip_min, clip_max = self.classifier.clip_values
return np.clip(batch + eps * perturbation, clip_min, clip_max)

def _compute(self, x, y, eps):
adv_x = x.copy()
def _compute(self, x, y, eps, random_init):
if random_init:
n = x.shape[0]
m = np.prod(x.shape[1:])
adv_x = x.copy() + random_sphere(n, m, eps, self.norm).reshape(x.shape)
else:
adv_x = x.copy()

# Compute perturbation with implicit batching
batch_size = 128
Expand Down
30 changes: 22 additions & 8 deletions art/attacks/iterative_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np

from art.attacks import FastGradientMethod
from art.utils import to_categorical
from art.utils import to_categorical, get_labels_np_array


class BasicIterativeMethod(FastGradientMethod):
Expand All @@ -14,7 +14,7 @@ class BasicIterativeMethod(FastGradientMethod):
"""
attack_params = FastGradientMethod.attack_params + ['eps_step', 'max_iter']

def __init__(self, classifier, norm=np.inf, eps=.3, eps_step=0.1, max_iter=20):
def __init__(self, classifier, norm=np.inf, eps=.3, eps_step=0.1, max_iter=20, targeted=False, random_init=False):
"""
Create a :class:`BasicIterativeMethod` instance.
Expand All @@ -26,8 +26,12 @@ def __init__(self, classifier, norm=np.inf, eps=.3, eps_step=0.1, max_iter=20):
:type eps: `float`
:param eps_step: Attack step size (input variation) at each iteration.
:type eps_step: `float`
:param targeted: Should the attack target one specific class
:type targeted: `bool`
:param random_init: Whether to start at the original input or a random point within the epsilon ball
:type random_init: `bool`
"""
super(BasicIterativeMethod, self).__init__(classifier, norm=norm, eps=eps, targeted=True)
super(BasicIterativeMethod, self).__init__(classifier, norm=norm, eps=eps, targeted=targeted,random_init=random_init)

if eps_step > eps:
raise ValueError('The iteration step `eps_step` has to be smaller than the total attack `eps`.')
Expand Down Expand Up @@ -57,19 +61,29 @@ def generate(self, x, **kwargs):

# Choose least likely class as target prediction for the attack
adv_x = x.copy()
targets = to_categorical(np.argmin(self.classifier.predict(x), axis=1), nb_classes=self.classifier.nb_classes)
if 'y' not in kwargs or kwargs[str('y')] is None:
# Throw error if attack is targeted, but no targets are provided
if self.targeted:
raise ValueError('Target labels `y` need to be provided for a targeted attack.')

# Use model predictions as correct outputs
y = self.classifier.predict(x)
else:
y = kwargs[str('y')]
y = y / np.sum(y, axis=1, keepdims=True)

targets = to_categorical(y, nb_classes=self.classifier.nb_classes)
active_indices = range(len(adv_x))

for _ in range(self.max_iter):
# Adversarial crafting
adv_x[active_indices] = self._compute(adv_x[active_indices], targets[active_indices], self.eps_step)
adv_x[active_indices] = self._compute(adv_x[active_indices], targets[active_indices], self.eps_step, self.random_init)
noise = projection(adv_x[active_indices] - x[active_indices], self.eps, self.norm)
adv_x[active_indices] = x[active_indices] + noise
adv_preds = self.classifier.predict(adv_x)
adv_preds = self.classifier.predict(adv_x[active_indices])

# Update active indices
active_indices = np.where(targets[active_indices] != np.argmax(adv_preds, axis=1))[0]

active_indices = np.where(np.argmax(targets[active_indices],axis=1) != np.argmax(adv_preds, axis=1))[0]
# Stop if no more indices left to explore
if len(active_indices) == 0:
break
Expand Down
96 changes: 3 additions & 93 deletions art/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from functools import reduce

from art.attacks.fast_gradient import FastGradientMethod
from art.utils import random_sphere

# TODO add all other implemented attacks
supported_methods = {
Expand Down Expand Up @@ -258,7 +259,7 @@ def clever_t(classifier, x, target_class, n_b, n_s, r, norm, c_init=1, pool_fact
shape.extend(x.shape)

# Generate a pool of samples
rand_pool = np.reshape(_random_sphere(m=pool_factor * n_s, n=dim, r=r, norm=norm), shape)
rand_pool = np.reshape(random_sphere(m=pool_factor * n_s, n=dim, r=r, norm=norm), shape)
rand_pool += np.repeat(np.array([x]), pool_factor * n_s, 0)
np.clip(rand_pool, classifier.clip_values[0], classifier.clip_values[1], out=rand_pool)

Expand Down Expand Up @@ -295,95 +296,4 @@ def clever_t(classifier, x, target_class, n_b, n_s, r, norm, c_init=1, pool_fact
# Compute scores
s = np.min([-value[0] / loc, r])

return s


def _random_sphere(m, n, r, norm):
"""
Generate randomly `m x n`-dimension points with radius `r` and centered around 0.
:param m: Number of random data points
:type m: `int`
:param n: Dimension
:type n: `int`
:param r: Radius
:type r: `float`
:param norm: Current support: 1, 2, np.inf
:type norm: `int`
:return: The generated random sphere
:rtype: `np.ndarray`
"""
if norm == 1:
res = _l1_random(m, n, r)
elif norm == 2:
res = _l2_random(m, n, r)
elif norm == np.inf:
res = _linf_random(m, n, r)
else:
raise NotImplementedError("Norm {} not supported".format(norm))

return res


def _l2_random(m, n, r):
"""
Generate randomly `m x n`-dimension points with radius `r` in norm 2 and centered around 0.
:param m: Number of random data points
:type m: `int`
:param n: Dimension
:type n: `int`
:param r: Radius
:type r: `float`
:return: The generated random sphere
:rtype: `np.ndarray`
"""
a = np.random.randn(m, n)
s2 = np.sum(a**2, axis=1)
base = gammainc(n/2.0, s2/2.0)**(1/n) * r / np.sqrt(s2)
a = a * (np.tile(base, (n, 1))).T

return a


def _l1_random(m, n, r):
"""
Generate randomly `m x n`-dimension points with radius `r` in norm 1 and centered around 0.
:param m: Number of random data points
:type m: `int`
:param n: Dimension
:type n: `int`
:param r: Radius
:type r: `float`
:return: The generated random sphere
:rtype: `np.ndarray`
"""
A = np.zeros(shape=(m, n+1))
A[:, -1] = np.sqrt(np.random.uniform(0, r**2, m))

for i in range(m):
A[i, 1:-1] = np.sort(np.random.uniform(0, A[i, -1], n-1))

X = (A[:, 1:] - A[:, :-1]) * np.random.choice([-1, 1], (m, n))

return X


def _linf_random(m, n, r):
"""
Generate randomly `m x n`-dimension points with radius `r` in inf norm and centered around 0.
:param m: Number of random data points
:type m: `int`
:param n: Dimension
:type n: `int`
:param r: Radius
:type r: `float`
:return: The generated random sphere
:rtype: `np.ndarray`
"""
return np.random.uniform(float(-r), float(r), (m, n))



return s
36 changes: 36 additions & 0 deletions art/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

import numpy as np

from scipy.special import gammainc


def projection(v, eps, p):
"""
Expand Down Expand Up @@ -39,6 +41,40 @@ def projection(v, eps, p):
v = v_.reshape(v.shape)
return v

def random_sphere(m, n, r, norm):
"""
Generate randomly `m x n`-dimension points with radius `r` and centered around 0.
:param m: Number of random data points
:type m: `int`
:param n: Dimension
:type n: `int`
:param r: Radius
:type r: `float`
:param norm: Current support: 1, 2, np.inf
:type norm: `int`
:return: The generated random sphere
:rtype: `np.ndarray`
"""
if norm == 1:
A = np.zeros(shape=(m, n+1))
A[:, -1] = np.sqrt(np.random.uniform(0, r**2, m))

for i in range(m):
A[i, 1:-1] = np.sort(np.random.uniform(0, A[i, -1], n-1))

res = (A[:, 1:] - A[:, :-1]) * np.random.choice([-1, 1], (m, n))
elif norm == 2:
a = np.random.randn(m, n)
s2 = np.sum(a**2, axis=1)
base = gammainc(n/2.0, s2/2.0)**(1/n) * r / np.sqrt(s2)
res = a * (np.tile(base, (n, 1))).T
elif norm == np.inf:
res= np.random.uniform(float(-r), float(r), (m, n))
else:
raise NotImplementedError("Norm {} not supported".format(norm))

return res

def to_categorical(labels, nb_classes=None):
"""
Expand Down
Loading

0 comments on commit 98b29b5

Please sign in to comment.