Adding preprocessing to all models (#39)
* added preprocessing_fn to pytorch model

* fixed tabs

* use mxnet 0.10.0 for now, because 0.10.1 is buggy

(apache/mxnet#6874)

* fixed pep8 violations

* added test for pytorch preprocessing support

* fixed test

* Added elementwise preprocessing to all models

* Removed leftover copies from mxnet

* fixed hanging indent

* Updated toy example
wielandbrendel authored and jonasrauber committed Jul 21, 2017
1 parent 2c4d327 commit 659534c
Showing 14 changed files with 439 additions and 47 deletions.
7 changes: 4 additions & 3 deletions README.rst
@@ -38,19 +38,20 @@ Example
import foolbox
import keras
from keras.applications.resnet50 import ResNet50, preprocess_input
from keras.applications.resnet50 import ResNet50
# instantiate model
keras.backend.set_learning_phase(0)
kmodel = ResNet50(weights='imagenet')
fmodel = foolbox.models.KerasModel(kmodel, bounds=(0, 255), preprocess_fn=preprocess_input)
preprocessing = (np.array([104, 116, 123])[None, None], 1)
fmodel = foolbox.models.KerasModel(kmodel, bounds=(0, 255), preprocessing=preprocessing)
# get source image and label
image, label = foolbox.utils.imagenet_example()
# apply attack on source image
attack = foolbox.attacks.FGSM(fmodel)
adversarial = attack(image, label)
adversarial = attack(image[:,:,::-1], label)
Interfaces for a range of other deep learning packages such as TensorFlow,
PyTorch, Theano, Lasagne and MXNet are available, e.g.
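For illustration, a minimal sketch of how the new preprocessing argument looks with the PyTorch wrapper from this commit (the torchvision ResNet-18 and the per-channel statistics below are placeholders for illustration, not part of the diff):

import numpy as np
import torchvision.models as models
import foolbox

# hypothetical torchvision model; any module returning logits works
model = models.resnet18(pretrained=True).eval()

# elementwise preprocessing: subtract the first element, divide by the second
mean = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape((3, 1, 1))
std = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape((3, 1, 1))

fmodel = foolbox.models.PyTorchModel(
    model, bounds=(0, 1), num_classes=1000,
    channel_axis=1, cuda=False, preprocessing=(mean, std))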
13 changes: 12 additions & 1 deletion foolbox/models/base.py
@@ -26,14 +26,19 @@ class Model(ABC):
(0, 1) or (0, 255).
channel_axis : int
The index of the axis that represents color channels.
preprocessing : 2-element tuple with floats or numpy arrays
Elementwise preprocessing of the input; we first subtract the first
element of preprocessing from the input and then divide the result by
the second element.
"""

def __init__(self, bounds, channel_axis):
def __init__(self, bounds, channel_axis, preprocessing=(0, 1)):
assert len(bounds) == 2
self._bounds = bounds
assert channel_axis in [0, 1, 2, 3]
self._channel_axis = channel_axis
self._preprocessing = preprocessing

def __enter__(self):
return self
@@ -47,6 +52,12 @@ def bounds(self):
def channel_axis(self):
return self._channel_axis

def _process_input(self, input):
return (input - self._preprocessing[0]) / self._preprocessing[1]

def _process_gradient(self, gradient):
return gradient / self._preprocessing[1]

@abstractmethod
def batch_predictions(self, images):
"""Calculates predictions for a batch of images.
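As a standalone sanity check of the arithmetic above (not part of the diff): if the model sees (input - a) / b, then a gradient computed with respect to the preprocessed input must be divided by b to become a gradient with respect to the original input, which is exactly what _process_gradient does. A small numpy sketch:

import numpy as np

def process_input(x, preprocessing=(0, 1)):
    a, b = preprocessing
    return (x - a) / b

def process_gradient(grad, preprocessing=(0, 1)):
    _, b = preprocessing
    return grad / b

# toy model: f(x) = sum(w * process_input(x))
x = np.array([10., 20., 30.])
w = np.array([1., 2., 3.])
preprocessing = (np.array([5., 5., 5.]), 2.)

# gradient of f w.r.t. the preprocessed input is w;
# mapping it back to the original input space divides by b
grad_wrt_x = process_gradient(w, preprocessing)

# numerical check via finite differences
eps = 1e-6
f = lambda z: np.sum(w * process_input(z, preprocessing))
numeric = np.array([(f(x + eps * np.eye(3)[i]) - f(x)) / eps for i in range(3)])
assert np.allclose(grad_wrt_x, numeric, atol=1e-4)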
26 changes: 14 additions & 12 deletions foolbox/models/keras.py
@@ -17,11 +17,13 @@ class KerasModel(DifferentiableModel):
(0, 1) or (0, 255).
channel_axis : int
The index of the axis that represents color channels.
preprocessing : 2-element tuple with floats or numpy arrays
Elementwise preprocessing of the input; we first subtract the first
element of preprocessing from the input and then divide the result by
the second element.
predicts : str
Specifies whether the `Keras` model predicts logits or probabilities.
Logits are preferred, but probabilities are the default.
preprocess_fn : function
Will be called with the images before model predictions are calculated.
"""

@@ -30,11 +32,12 @@ def __init__(
model,
bounds,
channel_axis=3,
predicts='probabilities',
preprocess_fn=None):
preprocessing=(0, 1),
predicts='probabilities'):

super(KerasModel, self).__init__(bounds=bounds,
channel_axis=channel_axis)
channel_axis=channel_axis,
preprocessing=preprocessing)

from keras import backend as K

@@ -67,6 +70,9 @@ def __init__(
grads = K.gradients(loss, images_input)
grad = grads[0]

self._loss_fn = K.function(
[images_input, label_input],
[loss])
self._batch_pred_fn = K.function(
[images_input], [predictions])
self._pred_grad_fn = K.function(
@@ -75,11 +81,6 @@ def __init__(

self._predictions_are_logits = predictions_are_logits

if preprocess_fn is not None:
self.preprocessing_fn = lambda x: preprocess_fn(x.copy())
else:
self.preprocessing_fn = lambda x: x

def _as_logits(self, predictions):
assert predictions.ndim in [1, 2]
if self._predictions_are_logits:
@@ -93,7 +94,7 @@ def num_classes(self):
return self._num_classes

def batch_predictions(self, images):
predictions = self._batch_pred_fn([self.preprocessing_fn(images)])
predictions = self._batch_pred_fn([self._process_input(images)])
assert len(predictions) == 1
predictions = predictions[0]
assert predictions.shape == (images.shape[0], self.num_classes())
@@ -102,11 +103,12 @@ def batch_predictions(self, images):

def predictions_and_gradient(self, image, label):
predictions, gradient = self._pred_grad_fn([
self.preprocessing_fn(image[np.newaxis]),
self._process_input(image[np.newaxis]),
np.array([label])])
predictions = np.squeeze(predictions, axis=0)
predictions = self._as_logits(predictions)
gradient = np.squeeze(gradient, axis=0)
gradient = self._process_gradient(gradient)
assert predictions.shape == (self.num_classes(),)
assert gradient.shape == image.shape
return predictions, gradient
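A short usage sketch of the updated wrapper, mirroring the README example above and additionally exercising the gradient method defined in this file (the channel flip stands in for the RGB-to-BGR conversion that preprocess_input used to perform, which an elementwise tuple cannot express; treat the exact mean values as illustrative):

import numpy as np
import keras
from keras.applications.resnet50 import ResNet50
import foolbox

keras.backend.set_learning_phase(0)
kmodel = ResNet50(weights='imagenet')
preprocessing = (np.array([104, 116, 123])[None, None], 1)
fmodel = foolbox.models.KerasModel(kmodel, bounds=(0, 255),
                                   preprocessing=preprocessing)

image, label = foolbox.utils.imagenet_example()
# logits and the gradient w.r.t. the unpreprocessed (flipped) input
predictions, gradient = fmodel.predictions_and_gradient(image[:, :, ::-1], label)
assert predictions.shape == (fmodel.num_classes(),)
assert gradient.shape == image.shape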
11 changes: 9 additions & 2 deletions foolbox/models/lasagne.py
@@ -19,6 +19,10 @@ class LasagneModel(DifferentiableModel):
(0, 1) or (0, 255).
channel_axis : int
The index of the axis that represents color channels.
preprocessing : 2-element tuple with floats or numpy arrays
Elementwise preprocessing of the input; we first subtract the first
element of preprocessing from the input and then divide the result by
the second element.
"""

@@ -27,10 +31,12 @@ def __init__(
input_layer,
logits_layer,
bounds,
channel_axis=1):
channel_axis=1,
preprocessing=(0, 1)):

super(LasagneModel, self).__init__(bounds=bounds,
channel_axis=channel_axis)
channel_axis=channel_axis,
preprocessing=preprocessing)

# delay import until class is instantiated
import theano as th
@@ -56,6 +62,7 @@ def __init__(
self._predictions_and_gradient_fn = th.function(
[images, labels], [logits, gradient])
self._gradient_fn = th.function([images, labels], gradient)
self._loss_fn = th.function([images, labels], loss)

def batch_predictions(self, images):
predictions = self._batch_prediction_fn(images)
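A minimal instantiation sketch for the Lasagne wrapper with the new argument (the toy network below is made up for illustration and assumes the input layer's default input variable is used):

import numpy as np
import lasagne
import foolbox

input_layer = lasagne.layers.InputLayer(shape=(None, 3, 32, 32))
logits_layer = lasagne.layers.DenseLayer(input_layer, num_units=10,
                                         nonlinearity=None)

preprocessing = (np.full((3, 1, 1), 127.5, dtype=np.float32), 255.)
fmodel = foolbox.models.LasagneModel(
    input_layer, logits_layer, bounds=(0, 255),
    channel_axis=1, preprocessing=preprocessing)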
28 changes: 26 additions & 2 deletions foolbox/models/mxnet.py
@@ -25,6 +25,10 @@ class MXNetModel(DifferentiableModel):
(0, 1) or (0, 255).
channel_axis : int
The index of the axis that represents color channels.
preprocessing : 2-element tuple with floats or numpy arrays
Elementwise preprocessing of the input; we first subtract the first
element of preprocessing from the input and then divide the result by
the second element.
"""

@@ -36,10 +40,13 @@ def __init__(
device,
num_classes,
bounds,
channel_axis=1):
channel_axis=1,
preprocessing=(0, 1)):

super(MXNetModel, self).__init__(
bounds=bounds, channel_axis=channel_axis)
bounds=bounds,
channel_axis=channel_axis,
preprocessing=preprocessing)

import mxnet as mx

@@ -65,6 +72,7 @@ def num_classes(self):

def batch_predictions(self, images):
import mxnet as mx
images = self._process_input(images)
data_array = mx.nd.array(images, ctx=self._device)
self._args_map[self._data_sym.name] = data_array
model = self._batch_logits_sym.bind(
@@ -77,6 +85,7 @@ def batch_predictions(self, images):
def predictions_and_gradient(self, image, label):
import mxnet as mx
label = np.asarray(label)
image = self._process_input(image)
data_array = mx.nd.array(image[np.newaxis], ctx=self._device)
label_array = mx.nd.array(label[np.newaxis], ctx=self._device)
self._args_map[self._data_sym.name] = data_array
@@ -99,4 +108,19 @@ def predictions_and_gradient(self, image, label):
])
logits = logits_array.asnumpy()
gradient = grad_array.asnumpy()
gradient = self._process_gradient(gradient)
return np.squeeze(logits, axis=0), np.squeeze(gradient, axis=0)

def _loss_fn(self, image, label):
import mxnet as mx
image = self._process_input(image)
data_array = mx.nd.array(image[np.newaxis], ctx=self._device)
label_array = mx.nd.array(np.array([label]), ctx=self._device)
self._args_map[self._data_sym.name] = data_array
self._args_map[self._label_sym.name] = label_array
model = self._loss_sym.bind(
ctx=self._device, args=self._args_map, grad_req='null')
model.forward(is_train=False)
loss_array = model.outputs[0]
loss = loss_array.asnumpy()[0]
return loss
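For reference, the _loss_fn hooks added across the models in this commit expose a scalar classification loss for a single preprocessed image and label (cross-entropy in the PyTorch version via nn.CrossEntropyLoss). A framework-agnostic numpy sketch of such a loss, not part of the diff:

import numpy as np

def cross_entropy(logits, label):
    # numerically stable softmax cross-entropy for a single example
    shifted = logits - np.max(logits)
    log_probs = shifted - np.log(np.sum(np.exp(shifted)))
    return -log_probs[label]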
49 changes: 38 additions & 11 deletions foolbox/models/pytorch.py
@@ -19,8 +19,10 @@ class PyTorchModel(DifferentiableModel):
The index of the axis that represents color channels.
cuda : bool
A boolean specifying whether the model uses CUDA.
preprocess_fn : function
Will be called with the images before model predictions are calculated.
preprocessing : 2-element tuple with floats or numpy arrays
Elementwise preprocessing of the input; we first subtract the first
element of preprocessing from the input and then divide the result by
the second element.
"""

@@ -31,27 +33,24 @@ def __init__(
num_classes,
channel_axis=1,
cuda=True,
preprocess_fn=None):
preprocessing=(0, 1)):

super(PyTorchModel, self).__init__(bounds=bounds,
channel_axis=channel_axis)
channel_axis=channel_axis,
preprocessing=preprocessing)

self._num_classes = num_classes
self._model = model
self.cuda = cuda

if preprocess_fn is not None:
self.preprocessing_fn = lambda x: preprocess_fn(x.copy())
else:
self.preprocessing_fn = lambda x: x

def batch_predictions(self, images):
# lazy import
import torch
from torch.autograd import Variable

images = self._process_input(images)
n = len(images)
images = torch.from_numpy(self.preprocessing_fn(images))
images = torch.from_numpy(images)
if self.cuda: # pragma: no cover
images = images.cuda()
images = Variable(images, volatile=True)
@@ -73,6 +72,7 @@ def predictions_and_gradient(self, image, label):
import torch.nn as nn
from torch.autograd import Variable

image = self._process_input(image)
target = np.array([label])
target = torch.from_numpy(target)
if self.cuda: # pragma: no cover
@@ -81,7 +81,7 @@

assert image.ndim == 3
images = image[np.newaxis]
images = torch.from_numpy(self.preprocessing_fn(images))
images = torch.from_numpy(images)
if self.cuda: # pragma: no cover
images = images.cuda()
images = Variable(images, requires_grad=True)
@@ -104,7 +104,34 @@ def predictions_and_gradient(self, image, label):
if self.cuda: # pragma: no cover
grad = grad.cpu()
grad = grad.numpy()
grad = self._process_gradient(grad)
grad = np.squeeze(grad, axis=0)
assert grad.shape == image.shape

return predictions, grad

def _loss_fn(self, image, label):
# lazy import
import torch
import torch.nn as nn
from torch.autograd import Variable

image = self._process_input(image)
target = np.array([label])
target = torch.from_numpy(target)
if self.cuda: # pragma: no cover
target = target.cuda()
target = Variable(target)

images = torch.from_numpy(image[None])
if self.cuda: # pragma: no cover
images = images.cuda()
images = Variable(images, volatile=True)
predictions = self._model(images)
ce = nn.CrossEntropyLoss()
loss = ce(predictions, target)
loss = loss.data
if self.cuda: # pragma: no cover
loss = loss.cpu()
loss = loss.numpy()
return loss
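The commit message mentions a test for PyTorch preprocessing support; a sketch of the kind of equivalence such a test could check (the tiny network, shapes and constants below are illustrative only, not the actual test):

import numpy as np
import torch.nn as nn
import foolbox

class TinyNet(nn.Module):
    def __init__(self):
        super(TinyNet, self).__init__()
        self.fc = nn.Linear(3 * 8 * 8, 10)

    def forward(self, x):
        return self.fc(x.view(x.size(0), -1))

model = TinyNet()
mean, std = 0.5, 0.25
raw = np.random.rand(1, 3, 8, 8).astype(np.float32)

# the wrapper applies the preprocessing ...
fmodel_a = foolbox.models.PyTorchModel(
    model, bounds=(0, 1), num_classes=10, cuda=False,
    preprocessing=(mean, std))
# ... versus preprocessing by hand and keeping the default (0, 1)
fmodel_b = foolbox.models.PyTorchModel(
    model, bounds=(0, 1), num_classes=10, cuda=False)

pa = fmodel_a.batch_predictions(raw)
pb = fmodel_b.batch_predictions((raw - mean) / std)
assert np.allclose(pa, pb, atol=1e-5)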
