Support for additional tensorflow operations.
Sven Gowal committed Jan 23, 2019
1 parent 15340d3 commit 5fa09e7
Showing 13 changed files with 552 additions and 120 deletions.
5 changes: 4 additions & 1 deletion README.md
@@ -8,7 +8,10 @@ This is not an official Google product
 ## Installation
 
 IBP can be installed with the following command:
-`python setup.py install`
+
+```bash
+pip install git+https://github.com/deepmind/interval-bound-propagation
+```
 
 IBP will work with both the CPU and GPU version of tensorflow and dm-sonnet, but
 to allow for that it does not list Tensorflow as a requirement, so you need to
8 changes: 5 additions & 3 deletions examples/train.py
@@ -151,7 +151,7 @@ def main(unused_args):
   predictor = ibp.VerifiableModelWrapper(predictor)
 
   # Training.
-  train_losses, train_loss = ibp.create_classification_losses(
+  train_losses, train_loss, _ = ibp.create_classification_losses(
       step,
       data.image,
       data.label,
@@ -193,7 +193,7 @@ def body(i, metrics):
       tf.maximum(test_data.image - FLAGS.epsilon, input_bounds[0]),
       tf.minimum(test_data.image + FLAGS.epsilon, input_bounds[1]))
   predictor.propagate_bounds(input_interval_bounds)
-  test_specification = ibp.build_classification_specification(
+  test_specification = ibp.ClassificationSpecification(
       test_data.label, num_classes)
   test_attack = attack_builder(predictor, test_specification, FLAGS.epsilon,
                                input_bounds=input_bounds,
@@ -230,7 +230,9 @@ def body(i, metrics):
   test_writer = tf.summary.FileWriter(os.path.join(FLAGS.output_dir, 'test'))
 
   # Run everything.
-  with tf.train.MonitoredSession() as sess:
+  tf_config = tf.ConfigProto()
+  tf_config.gpu_options.allow_growth = True
+  with tf.train.SingularMonitoredSession(config=tf_config) as sess:
     for _ in xrange(FLAGS.steps):
       iteration, loss_value, _ = sess.run(
           [step, train_losses.scalar_losses.nominal_cross_entropy, train_op])
4 changes: 3 additions & 1 deletion interval_bound_propagation/__init__.py
@@ -27,6 +27,7 @@
 from interval_bound_propagation.src.attacks import UnrolledAdam
 from interval_bound_propagation.src.attacks import UnrolledGradientDescent
 from interval_bound_propagation.src.attacks import UntargetedPGDAttack
+from interval_bound_propagation.src.bounds import AbstractBounds
 from interval_bound_propagation.src.bounds import IntervalBounds
 from interval_bound_propagation.src.layers import BatchNorm
 from interval_bound_propagation.src.layers import ImageNorm
@@ -35,14 +36,15 @@
 from interval_bound_propagation.src.loss import ScalarMetrics
 from interval_bound_propagation.src.model import DNN
 from interval_bound_propagation.src.model import VerifiableModelWrapper
+from interval_bound_propagation.src.specification import ClassificationSpecification
 from interval_bound_propagation.src.specification import LinearSpecification
 from interval_bound_propagation.src.utils import add_image_normalization
-from interval_bound_propagation.src.utils import build_classification_specification
 from interval_bound_propagation.src.utils import build_dataset
 from interval_bound_propagation.src.utils import create_classification_losses
 from interval_bound_propagation.src.utils import linear_schedule
 from interval_bound_propagation.src.verifiable_wrapper import BatchFlattenWrapper
 from interval_bound_propagation.src.verifiable_wrapper import BatchNormWrapper
+from interval_bound_propagation.src.verifiable_wrapper import ImageNormWrapper
 from interval_bound_propagation.src.verifiable_wrapper import LinearConv2dWrapper
 from interval_bound_propagation.src.verifiable_wrapper import LinearFCWrapper
 from interval_bound_propagation.src.verifiable_wrapper import MonotonicWrapper
28 changes: 14 additions & 14 deletions interval_bound_propagation/src/attacks.py
@@ -255,7 +255,7 @@ class UntargetedPGDAttack(PGDAttack):
"""Defines an untargeted PGD attack."""

def _build(self, labels):
batch_size = tf.shape(self._specification.c)[0]
batch_size = tf.shape(self._predictor.inputs)[0]
input_shape = list(self._predictor.inputs.shape.as_list()[1:])
duplicated_inputs = tf.expand_dims(self._predictor.inputs, axis=0)
# Shape is [num_restarts, batch_size, ...]
@@ -272,10 +272,9 @@ def objective_fn(x):
       model_logits = eval_fn(x)  # [restarts * batch_size, output].
       model_logits = tf.reshape(
           model_logits, [self._num_restarts, batch_size, -1])
-      # c has shape [batch_size, num_specs, output].
-      obj = tf.einsum('rbo,bso->rsb', model_logits, self._specification.c)
+      obj = self._specification.evaluate(model_logits)
       # Output has dimension [num_restarts, batch_size].
-      return tf.reduce_max(obj, axis=1)
+      return tf.reduce_max(obj, axis=-1)
 
     def reduced_loss_fn(x):
       # Pick the worst attack, output has shape [num_restarts, batch_size].
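The einsum removed above is the best clue to the contract of the new `ClassificationSpecification.evaluate`. As a reading aid only (a hedged sketch, not the library's implementation; `c` is the specification tensor from the deleted lines, of shape `[batch_size, num_specs, output]`), the untargeted call site is consistent with:

```python
import tensorflow as tf

def evaluate_sketch(c, logits):
  """Sketch of specification evaluation, inferred from the removed einsum.

  Args:
    c: specification tensor of shape [batch_size, num_specs, output].
    logits: [batch_size, output] or [num_restarts, batch_size, output].

  Returns:
    Values with shape [..., batch_size, num_specs]; a positive value means
    the corresponding specification is violated.
  """
  if logits.shape.ndims == 2:
    return tf.einsum('bo,bso->bs', logits, c)
  return tf.einsum('rbo,bso->rbs', logits, c)
```

The targeted attack below feeds logits with an extra per-specification axis, so its `evaluate` call handles that case analogously.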
@@ -297,8 +296,9 @@ def reduced_loss_fn(x):
     ij = tf.stack([i, j], axis=1)
     self._attack = tf.gather_nd(adversarial_input, ij)
     self._logits = eval_fn(self._attack)
-    correct_examples = tf.equal(labels, tf.argmax(self._logits, 1))
-    self._accuracy = tf.reduce_mean(tf.cast(correct_examples, tf.float32))
+    # Count the fraction of examples that do not violate any specification.
+    bounds = tf.reduce_max(self._specification.evaluate(self._logits), axis=1)
+    self._accuracy = tf.reduce_mean(tf.cast(bounds <= 0, tf.float32))
     return self._attack
 
   @property
@@ -321,8 +321,8 @@ class TargetedPGDAttack(PGDAttack):
"""Runs targeted attacks for each specification."""

def _build(self, labels):
batch_size = tf.shape(self._specification.c)[0]
num_specs = tf.shape(self._specification.c)[1]
batch_size = tf.shape(self._predictor.inputs)[0]
num_specs = self._specification.num_specifications
input_shape = list(self._predictor.inputs.shape.as_list()[1:])
duplicated_inputs = tf.expand_dims(self._predictor.inputs, axis=0)
# Shape is [num_restarts * num_specifications, batch_size, ...]
@@ -339,10 +339,8 @@ def objective_fn(x):
       model_logits = eval_fn(x)  # [restarts * num_specs * batch_size, output].
       model_logits = tf.reshape(
           model_logits, [self._num_restarts, num_specs, batch_size, -1])
-      # c has shape [batch_size, num_specs, output].
-      obj = tf.einsum('rsbo,bso->rsb', model_logits, self._specification.c)
-      # Output has dimension [num_restarts, num_objectives, batch_size]
-      return obj
+      # Output has shape [num_restarts, batch_size, num_specs].
+      return self._specification.evaluate(model_logits)
 
     def reduced_loss_fn(x):
       # Negate as we minimize.
@@ -356,6 +354,7 @@ def reduced_loss_fn(x):
         image_bounds=self._input_bounds, random_init=True, optimizer=optimizer)
     # Get best attack.
     adversarial_objective = objective_fn(adversarial_input)
+    adversarial_objective = tf.transpose(adversarial_objective, [0, 2, 1])
     adversarial_objective = tf.reshape(adversarial_objective, [-1, batch_size])
     adversarial_input = tf.reshape(adversarial_input,
                                    [-1, batch_size] + input_shape)
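A quick shape trace of the three lines above, using the shapes stated in the diff's own comments (the transpose is what lets the flatten pair every restart with every specification):

```python
import tensorflow as tf

# Illustrative sizes: 4 restarts, 9 specifications, batch of 8.
obj = tf.zeros([4, 8, 9])           # evaluate(): [restarts, batch, specs]
obj = tf.transpose(obj, [0, 2, 1])  # [restarts, specs, batch]
obj = tf.reshape(obj, [-1, 8])      # [restarts * specs, batch]
```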
@@ -364,8 +363,9 @@
     ij = tf.stack([i, j], axis=1)
     self._attack = tf.gather_nd(adversarial_input, ij)
     self._logits = eval_fn(self._attack)
-    correct_examples = tf.equal(labels, tf.argmax(self._logits, 1))
-    self._accuracy = tf.reduce_mean(tf.cast(correct_examples, tf.float32))
+    # Count the fraction of examples that do not violate any specification.
+    bounds = tf.reduce_max(self._specification.evaluate(self._logits), axis=1)
+    self._accuracy = tf.reduce_mean(tf.cast(bounds <= 0, tf.float32))
     return self._attack
 
   @property
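Both attack classes now share the same accuracy definition: an example counts as robust only if none of its specifications is violated (every evaluated value is at most zero), instead of comparing the argmax against the label. A small NumPy illustration of the reduction:

```python
import numpy as np

# Specification values for two examples with two specifications each;
# a positive value means the specification is violated.
spec_values = np.array([[-1.2, -0.3],   # no violations -> accurate
                        [-0.5, 0.7]])   # one violation -> inaccurate
bounds = spec_values.max(axis=1)        # [-0.3, 0.7]
accuracy = (bounds <= 0).mean()         # 0.5
```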
45 changes: 8 additions & 37 deletions interval_bound_propagation/src/bounds.py
@@ -31,11 +31,12 @@ class AbstractBounds(object):

   __metaclass__ = abc.ABCMeta
 
-  def propagate_through(self, wrapper):
+  def propagate_through(self, wrapper, *args):
     """Propagates bounds through a verifiable wrapper.
 
     Args:
       wrapper: `verifiable_wrapper.VerifiableWrapper`
+      *args: Additional arguments passed down to downstream callbacks.
 
     Returns:
       New bounds.
@@ -52,7 +53,7 @@ def propagate_through(self, wrapper):
       strides = module.stride[1:-1]
       return self._conv2d(w, b, padding, strides)
     elif isinstance(wrapper, verifiable_wrapper.MonotonicWrapper):
-      return self._monotonic_fn(module)
+      return self._monotonic_fn(module, *args)
     elif isinstance(wrapper, verifiable_wrapper.BatchNormWrapper):
       return self._batch_norm(module.mean, module.variance, module.scale,
                               module.bias, module.epsilon)
@@ -62,10 +63,6 @@
       raise NotImplementedError('{} not supported.'.format(
           wrapper.__class__.__name__))
 
-  @abc.abstractmethod
-  def combine_with(self, bounds):
-    """Produces new bounds that keep track of multiple input bounds."""
-
   def _raise_not_implemented(self, name):
     raise NotImplementedError(
         '{} modules are not supported by "{}".'.format(
@@ -77,7 +74,7 @@ def _linear(self, w, b):  # pylint: disable=unused-argument
   def _conv2d(self, w, b, padding, strides):  # pylint: disable=unused-argument
     self._raise_not_implemented('snt.Conv2D')
 
-  def _monotonic_fn(self, fn):
+  def _monotonic_fn(self, fn, *args):  # pylint: disable=unused-argument
     self._raise_not_implemented(fn.__name__)
 
   def _batch_norm(self, mean, variance, scale, bias, epsilon):  # pylint: disable=unused-argument
@@ -102,26 +99,7 @@ def lower(self):
   def upper(self):
     return self._upper
 
-  def combine_with(self, bounds):
-    if not isinstance(bounds, IntervalBounds):
-      raise NotImplementedError('Cannot combine IntervalBounds with '
-                                '{}'.format(bounds))
-    bounds._ensure_singleton()  # pylint: disable=protected-access
-    if not isinstance(self._lower, tuple):
-      self._ensure_singleton()
-      lower = (self._lower, bounds.lower)
-      upper = (self._upper, bounds.upper)
-    else:
-      lower = self._lower + (bounds.lower,)
-      upper = self._upper + (bounds.upper,)
-    return IntervalBounds(lower, upper)
-
-  def _ensure_singleton(self):
-    if isinstance(self._lower, tuple) or isinstance(self._upper, tuple):
-      raise ValueError('Cannot proceed with multiple inputs.')
-
   def _linear(self, w, b):
-    self._ensure_singleton()
     c = (self.lower + self.upper) / 2.
     r = (self.upper - self.lower) / 2.
     c = tf.matmul(c, w)
@@ -131,7 +109,6 @@ def _linear(self, w, b):
     return IntervalBounds(c - r, c + r)
 
   def _conv2d(self, w, b, padding, strides):
-    self._ensure_singleton()
     c = (self.lower + self.upper) / 2.
     r = (self.upper - self.lower) / 2.
     c = tf.nn.convolution(c, w, padding=padding, strides=strides)
@@ -140,17 +117,12 @@
     r = tf.nn.convolution(r, tf.abs(w), padding=padding, strides=strides)
     return IntervalBounds(c - r, c + r)
 
-  def _monotonic_fn(self, fn):
-    if isinstance(self._lower, tuple):
-      assert isinstance(self._upper, tuple)
-      return IntervalBounds(fn(*self.lower),
-                            fn(*self.upper))
-    self._ensure_singleton()
-    return IntervalBounds(fn(self.lower),
-                          fn(self.upper))
+  def _monotonic_fn(self, fn, *args):
+    args_lower = [self.lower] + [a.lower for a in args]
+    args_upper = [self.upper] + [a.upper for a in args]
+    return IntervalBounds(fn(*args_lower), fn(*args_upper))
 
   def _batch_norm(self, mean, variance, scale, bias, epsilon):
-    self._ensure_singleton()
     # Element-wise multiplier.
     multiplier = tf.rsqrt(variance + epsilon)
     if scale is not None:
@@ -170,6 +142,5 @@ def _batch_norm(self, mean, variance, scale, bias, epsilon):
     return IntervalBounds(c - r, c + r)
 
   def _batch_flatten(self):
-    self._ensure_singleton()
     return IntervalBounds(snt.BatchFlatten()(self.lower),
                           snt.BatchFlatten()(self.upper))
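With `combine_with` and `_ensure_singleton` gone, multi-input monotonic operations are handled by passing each extra input's bounds explicitly through `propagate_through(wrapper, *args)` down to `_monotonic_fn`. A hedged sketch of the behaviour this enables, calling the private method directly for illustration (`tf.add` is monotonically increasing in both arguments; in real use the call goes through a `MonotonicWrapper`):

```python
import tensorflow as tf
import interval_bound_propagation as ibp

lhs = ibp.IntervalBounds(tf.constant([0., 1.]), tf.constant([1., 2.]))
rhs = ibp.IntervalBounds(tf.constant([2., 3.]), tf.constant([3., 4.]))
# Lower bounds combine with lower bounds, upper with upper:
#   lower = [0+2, 1+3] = [2, 4],  upper = [1+3, 2+4] = [4, 6].
summed = lhs._monotonic_fn(tf.add, rhs)  # pylint: disable=protected-access
```

Note this is only sound for functions that are monotonic in every argument, which is the contract `MonotonicWrapper` implies.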
63 changes: 57 additions & 6 deletions interval_bound_propagation/src/loss.py
@@ -25,6 +25,10 @@
 import tensorflow as tf
 
 
+# Used to pick the least violated specification.
+_BIG_NUMBER = 1e25
+
+
 ScalarMetrics = collections.namedtuple('ScalarMetrics', [
     'nominal_accuracy',
     'verified_accuracy',
@@ -47,10 +51,32 @@ def __init__(self, predictor, specification=None, pgd_attack=None,
     self._predictor = predictor
     self._specification = specification
     self._attack = pgd_attack
-    if interval_bounds_loss_type not in ('xent', 'hinge'):
-      raise ValueError('interval_bounds_loss_type must be either "xent" or '
-                       '"hinge".')
-    self._interval_bounds_loss_type = interval_bounds_loss_type
+    # Loss type can be any combination of:
+    #   xent: cross-entropy loss
+    #   hinge: hinge loss
+    #   softplus: softplus loss
+    # with
+    #   all: using all specifications.
+    #   most: using only the specification that is the most violated.
+    #   least: using only the specification that is the least violated.
+    #   random_n: using a random subset of the specifications.
+    # E.g.: "xent_most" or "hinge_random_3".
+    tokens = interval_bounds_loss_type.split('_', 1)
+    if len(tokens) == 1:
+      loss_type, loss_mode = tokens[0], 'all'
+    else:
+      loss_type, loss_mode = tokens
+    if loss_mode.startswith('random'):
+      loss_mode, num_samples = loss_mode.split('_', 1)
+      self._interval_bounds_loss_n = int(num_samples)
+    if loss_type not in ('xent', 'hinge', 'softplus'):
+      raise ValueError('interval_bounds_loss_type must be either "xent", '
+                       '"hinge" or "softplus".')
+    if loss_mode not in ('all', 'most', 'random', 'least'):
+      raise ValueError('interval_bounds_loss_type must be followed by either '
+                       '"all", "most", "random_N" or "least".')
+    self._interval_bounds_loss_type = loss_type
+    self._interval_bounds_loss_mode = loss_mode
     self._interval_bounds_hinge_margin = interval_bounds_hinge_margin
 
   def _build(self, labels):
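A few illustrative parses of the `interval_bounds_loss_type` string under the scheme documented above:

```python
# 'xent'          -> loss_type='xent',  loss_mode='all' (the default)
# 'hinge_most'    -> loss_type='hinge', loss_mode='most'
# 'xent_random_5' -> loss_type='xent',  loss_mode='random', n=5
tokens = 'xent_random_5'.split('_', 1)            # ['xent', 'random_5']
loss_type, loss_mode = tokens
loss_mode, num_samples = loss_mode.split('_', 1)  # 'random', '5'
n = int(num_samples)
```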
@@ -69,6 +95,28 @@ def _build(self, labels):
       v = tf.reduce_max(bounds, axis=1)
       self._interval_bounds_accuracy = tf.reduce_mean(
           tf.cast(v <= 0., tf.float32))
+      # Select specifications.
+      if self._interval_bounds_loss_mode == 'all':
+        pass  # Keep bounds as they are.
+      elif self._interval_bounds_loss_mode == 'most':
+        bounds = tf.reduce_max(bounds, axis=1, keepdims=True)
+      elif self._interval_bounds_loss_mode == 'random':
+        idx = tf.random.uniform(
+            [tf.shape(bounds)[0], self._interval_bounds_loss_n],
+            0, tf.shape(bounds)[1], dtype=tf.int32)
+        bounds = tf.batch_gather(bounds, idx)
+      else:
+        assert self._interval_bounds_loss_mode == 'least'
+        # This picks the least violated constraint.
+        mask = tf.cast(bounds < 0., tf.float32)
+        smallest_violation = tf.reduce_min(
+            bounds + mask * _BIG_NUMBER, axis=1, keepdims=True)
+        has_violations = tf.less(
+            tf.reduce_sum(mask, axis=1, keepdims=True) + .5,
+            tf.cast(tf.shape(bounds)[1], tf.float32))
+        largest_bounds = tf.reduce_max(bounds, axis=1, keepdims=True)
+        bounds = tf.where(has_violations, smallest_violation, largest_bounds)
+
       if self._interval_bounds_loss_type == 'xent':
         v = tf.concat(
             [bounds, tf.zeros([tf.shape(bounds)[0], 1], dtype=bounds.dtype)],
@@ -80,10 +128,13 @@
         self._verified_loss = tf.reduce_mean(
             tf.nn.softmax_cross_entropy_with_logits_v2(
                 labels=tf.stop_gradient(l), logits=v))
+      elif self._interval_bounds_loss_type == 'softplus':
+        self._verified_loss = tf.reduce_mean(
+            tf.nn.softplus(bounds + self._interval_bounds_hinge_margin))
       else:
         assert self._interval_bounds_loss_type == 'hinge'
-        self._verified_loss = tf.maximum(v, -self._interval_bounds_hinge_margin)
-
+        self._verified_loss = tf.reduce_mean(
+            tf.maximum(bounds, -self._interval_bounds_hinge_margin))
     else:
       self._verified_loss = tf.constant(0.)
       self._interval_bounds_accuracy = tf.constant(0.)
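The 'least' branch is the subtlest of the new selection modes: per example it picks the smallest violated (positive) bound when any violation exists, and otherwise falls back to the largest, still-negative bound. The same arithmetic in NumPy:

```python
import numpy as np

_BIG_NUMBER = 1e25
bounds = np.array([[-2., -1.],    # example 0: no violated specification
                   [-0.5, 3.]])   # example 1: one violated specification
mask = (bounds < 0.).astype(np.float64)
# Satisfied specifications are pushed out of reach of the min below.
smallest_violation = (bounds + mask * _BIG_NUMBER).min(axis=1, keepdims=True)
has_violations = (mask.sum(axis=1, keepdims=True) + .5) < bounds.shape[1]
largest_bounds = bounds.max(axis=1, keepdims=True)
selected = np.where(has_violations, smallest_violation, largest_bounds)
# selected == [[-1.], [3.]]
```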
[Diffs for the remaining 7 changed files are not shown.]