diff --git a/README.rst b/README.rst
index 5e1e0e71..441468f8 100644
--- a/README.rst
+++ b/README.rst
@@ -38,19 +38,21 @@ Example

     import foolbox
    import keras
-    from keras.applications.resnet50 import ResNet50, preprocess_input
+    from keras.applications.resnet50 import ResNet50
+    import numpy as np

     # instantiate model
     keras.backend.set_learning_phase(0)
     kmodel = ResNet50(weights='imagenet')
-    fmodel = foolbox.models.KerasModel(kmodel, bounds=(0, 255), preprocess_fn=preprocess_input)
+    preprocessing = (np.array([104, 116, 123])[None, None], 1)
+    fmodel = foolbox.models.KerasModel(kmodel, bounds=(0, 255), preprocessing=preprocessing)

     # get source image and label
     image, label = foolbox.utils.imagenet_example()

     # apply attack on source image
     attack = foolbox.attacks.FGSM(fmodel)
-    adversarial = attack(image, label)
+    adversarial = attack(image[:, :, ::-1], label)

 Interfaces for a range of other deep learning packages such as
 TensorFlow, PyTorch, Theano, Lasagne and MXNet are available, e.g.
diff --git a/foolbox/models/base.py b/foolbox/models/base.py
index ecc4b5f3..9ec70a6f 100644
--- a/foolbox/models/base.py
+++ b/foolbox/models/base.py
@@ -26,14 +26,19 @@ class Model(ABC):
         (0, 1) or (0, 255).
     channel_axis : int
         The index of the axis that represents color channels.
+    preprocessing : 2-element tuple with floats or numpy arrays
+        Elementwise preprocessing of input; we first subtract the first
+        element of preprocessing from the input and then divide the input by
+        the second element.

     """

-    def __init__(self, bounds, channel_axis):
+    def __init__(self, bounds, channel_axis, preprocessing=(0, 1)):
         assert len(bounds) == 2
         self._bounds = bounds
         assert channel_axis in [0, 1, 2, 3]
         self._channel_axis = channel_axis
+        self._preprocessing = preprocessing

     def __enter__(self):
         return self
@@ -47,6 +52,12 @@ def bounds(self):
         return self._bounds

     def channel_axis(self):
         return self._channel_axis

+    def _process_input(self, input):
+        return (input - self._preprocessing[0]) / self._preprocessing[1]
+
+    def _process_gradient(self, gradient):
+        return gradient / self._preprocessing[1]
+
     @abstractmethod
     def batch_predictions(self, images):
         """Calculates predictions for a batch of images.
diff --git a/foolbox/models/keras.py b/foolbox/models/keras.py
index 7fc3c7ae..df977ac4 100644
--- a/foolbox/models/keras.py
+++ b/foolbox/models/keras.py
@@ -17,11 +17,13 @@ class KerasModel(DifferentiableModel):
         (0, 1) or (0, 255).
     channel_axis : int
         The index of the axis that represents color channels.
+    preprocessing : 2-element tuple with floats or numpy arrays
+        Elementwise preprocessing of input; we first subtract the first
+        element of preprocessing from the input and then divide the input by
+        the second element.
     predicts : str
         Specifies whether the `Keras` model predicts logits or probabilities.
         Logits are preferred, but probabilities are the default.
-    preprocess_fn : function
-        Will be called with the images before model predictions are calculated.
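Note on the two helpers added to ``Model`` above: since ``_process_input`` computes ``(input - preprocessing[0]) / preprocessing[1]``, the chain rule implies that a gradient computed with respect to the processed input must be divided by ``preprocessing[1]`` to become a gradient with respect to the raw input, which is exactly what ``_process_gradient`` does. A minimal numpy sketch (the names ``mean`` and ``scale`` are illustrative, not part of the diff)::

    import numpy as np

    mean = np.array([104., 116., 123.])[None, None]  # broadcasts over H and W
    scale = 2.0

    def process_input(x):
        return (x - mean) / scale

    def process_gradient(g):
        # d(processed)/dx == 1 / scale, so a gradient w.r.t. the processed
        # input is divided by scale to become a gradient w.r.t. the raw input
        return g / scale

    # check: for loss(x) = process_input(x).sum(), the true gradient w.r.t.
    # the raw x is 1 / scale everywhere
    x = np.random.uniform(0, 255, size=(5, 5, 3))
    g_processed = np.ones_like(x)  # gradient of sum() w.r.t. its argument
    assert np.allclose(process_gradient(g_processed), 1.0 / scale)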
""" @@ -30,11 +32,12 @@ def __init__( model, bounds, channel_axis=3, - predicts='probabilities', - preprocess_fn=None): + preprocessing=(0, 1), + predicts='probabilities'): super(KerasModel, self).__init__(bounds=bounds, - channel_axis=channel_axis) + channel_axis=channel_axis, + preprocessing=preprocessing) from keras import backend as K @@ -67,6 +70,9 @@ def __init__( grads = K.gradients(loss, images_input) grad = grads[0] + self._loss_fn = K.function( + [images_input, label_input], + [loss]) self._batch_pred_fn = K.function( [images_input], [predictions]) self._pred_grad_fn = K.function( @@ -75,11 +81,6 @@ def __init__( self._predictions_are_logits = predictions_are_logits - if preprocess_fn is not None: - self.preprocessing_fn = lambda x: preprocess_fn(x.copy()) - else: - self.preprocessing_fn = lambda x: x - def _as_logits(self, predictions): assert predictions.ndim in [1, 2] if self._predictions_are_logits: @@ -93,7 +94,7 @@ def num_classes(self): return self._num_classes def batch_predictions(self, images): - predictions = self._batch_pred_fn([self.preprocessing_fn(images)]) + predictions = self._batch_pred_fn([self._process_input(images)]) assert len(predictions) == 1 predictions = predictions[0] assert predictions.shape == (images.shape[0], self.num_classes()) @@ -102,11 +103,12 @@ def batch_predictions(self, images): def predictions_and_gradient(self, image, label): predictions, gradient = self._pred_grad_fn([ - self.preprocessing_fn(image[np.newaxis]), + self._process_input(image[np.newaxis]), np.array([label])]) predictions = np.squeeze(predictions, axis=0) predictions = self._as_logits(predictions) gradient = np.squeeze(gradient, axis=0) + gradient = self._process_gradient(gradient) assert predictions.shape == (self.num_classes(),) assert gradient.shape == image.shape return predictions, gradient diff --git a/foolbox/models/lasagne.py b/foolbox/models/lasagne.py index 27afff49..e08f9670 100644 --- a/foolbox/models/lasagne.py +++ b/foolbox/models/lasagne.py @@ -19,6 +19,10 @@ class LasagneModel(DifferentiableModel): (0, 1) or (0, 255). channel_axis : int The index of the axis that represents color channels. + preprocessing: 2-element tuple with floats or numpy arrays + Elementwises preprocessing of input; we first subtract the first + element of preprocessing from the input and then divide the input by + the second element. """ @@ -27,10 +31,12 @@ def __init__( input_layer, logits_layer, bounds, - channel_axis=1): + channel_axis=1, + preprocessing=(0, 1)): super(LasagneModel, self).__init__(bounds=bounds, - channel_axis=channel_axis) + channel_axis=channel_axis, + preprocessing=preprocessing) # delay import until class is instantiated import theano as th @@ -56,6 +62,7 @@ def __init__( self._predictions_and_gradient_fn = th.function( [images, labels], [logits, gradient]) self._gradient_fn = th.function([images, labels], gradient) + self._loss_fn = th.function([images, labels], loss) def batch_predictions(self, images): predictions = self._batch_prediction_fn(images) diff --git a/foolbox/models/mxnet.py b/foolbox/models/mxnet.py index 449f737e..46f6fb1b 100644 --- a/foolbox/models/mxnet.py +++ b/foolbox/models/mxnet.py @@ -25,6 +25,10 @@ class MXNetModel(DifferentiableModel): (0, 1) or (0, 255). channel_axis : int The index of the axis that represents color channels. 
diff --git a/foolbox/models/lasagne.py b/foolbox/models/lasagne.py
index 27afff49..e08f9670 100644
--- a/foolbox/models/lasagne.py
+++ b/foolbox/models/lasagne.py
@@ -19,6 +19,10 @@ class LasagneModel(DifferentiableModel):
         (0, 1) or (0, 255).
     channel_axis : int
         The index of the axis that represents color channels.
+    preprocessing : 2-element tuple with floats or numpy arrays
+        Elementwise preprocessing of input; we first subtract the first
+        element of preprocessing from the input and then divide the input by
+        the second element.

     """

@@ -27,10 +31,12 @@ def __init__(
             input_layer,
             logits_layer,
             bounds,
-            channel_axis=1):
+            channel_axis=1,
+            preprocessing=(0, 1)):

         super(LasagneModel, self).__init__(bounds=bounds,
-                                           channel_axis=channel_axis)
+                                           channel_axis=channel_axis,
+                                           preprocessing=preprocessing)

         # delay import until class is instantiated
         import theano as th
@@ -56,6 +62,7 @@ def __init__(
         self._predictions_and_gradient_fn = th.function(
             [images, labels], [logits, gradient])
         self._gradient_fn = th.function([images, labels], gradient)
+        self._loss_fn = th.function([images, labels], loss)

     def batch_predictions(self, images):
         predictions = self._batch_prediction_fn(images)
diff --git a/foolbox/models/mxnet.py b/foolbox/models/mxnet.py
index 449f737e..46f6fb1b 100644
--- a/foolbox/models/mxnet.py
+++ b/foolbox/models/mxnet.py
@@ -25,6 +25,10 @@ class MXNetModel(DifferentiableModel):
         (0, 1) or (0, 255).
     channel_axis : int
         The index of the axis that represents color channels.
+    preprocessing : 2-element tuple with floats or numpy arrays
+        Elementwise preprocessing of input; we first subtract the first
+        element of preprocessing from the input and then divide the input by
+        the second element.

     """

@@ -36,10 +40,13 @@ def __init__(
             device,
             num_classes,
             bounds,
-            channel_axis=1):
+            channel_axis=1,
+            preprocessing=(0, 1)):

         super(MXNetModel, self).__init__(
-            bounds=bounds, channel_axis=channel_axis)
+            bounds=bounds,
+            channel_axis=channel_axis,
+            preprocessing=preprocessing)

         import mxnet as mx
@@ -65,6 +72,7 @@ def num_classes(self):

     def batch_predictions(self, images):
         import mxnet as mx
+        images = self._process_input(images)
         data_array = mx.nd.array(images, ctx=self._device)
         self._args_map[self._data_sym.name] = data_array
         model = self._batch_logits_sym.bind(
@@ -77,6 +85,7 @@ def batch_predictions(self, images):
     def predictions_and_gradient(self, image, label):
         import mxnet as mx
         label = np.asarray(label)
+        image = self._process_input(image)
         data_array = mx.nd.array(image[np.newaxis], ctx=self._device)
         label_array = mx.nd.array(label[np.newaxis], ctx=self._device)
         self._args_map[self._data_sym.name] = data_array
@@ -99,4 +108,19 @@ def predictions_and_gradient(self, image, label):
         ])
         logits = logits_array.asnumpy()
         gradient = grad_array.asnumpy()
+        gradient = self._process_gradient(gradient)
         return np.squeeze(logits, axis=0), np.squeeze(gradient, axis=0)
+
+    def _loss_fn(self, image, label):
+        import mxnet as mx
+        image = self._process_input(image)
+        data_array = mx.nd.array(image[np.newaxis], ctx=self._device)
+        label_array = mx.nd.array(np.array([label]), ctx=self._device)
+        self._args_map[self._data_sym.name] = data_array
+        self._args_map[self._label_sym.name] = label_array
+        model = self._loss_sym.bind(
+            ctx=self._device, args=self._args_map, grad_req='null')
+        model.forward(is_train=False)
+        loss_array = model.outputs[0]
+        loss = loss_array.asnumpy()[0]
+        return loss
""" @@ -31,27 +33,24 @@ def __init__( num_classes, channel_axis=1, cuda=True, - preprocess_fn=None): + preprocessing=(0, 1)): super(PyTorchModel, self).__init__(bounds=bounds, - channel_axis=channel_axis) + channel_axis=channel_axis, + preprocessing=preprocessing) self._num_classes = num_classes self._model = model self.cuda = cuda - if preprocess_fn is not None: - self.preprocessing_fn = lambda x: preprocess_fn(x.copy()) - else: - self.preprocessing_fn = lambda x: x - def batch_predictions(self, images): # lazy import import torch from torch.autograd import Variable + images = self._process_input(images) n = len(images) - images = torch.from_numpy(self.preprocessing_fn(images)) + images = torch.from_numpy(images) if self.cuda: # pragma: no cover images = images.cuda() images = Variable(images, volatile=True) @@ -73,6 +72,7 @@ def predictions_and_gradient(self, image, label): import torch.nn as nn from torch.autograd import Variable + image = self._process_input(image) target = np.array([label]) target = torch.from_numpy(target) if self.cuda: # pragma: no cover @@ -81,7 +81,7 @@ def predictions_and_gradient(self, image, label): assert image.ndim == 3 images = image[np.newaxis] - images = torch.from_numpy(self.preprocessing_fn(images)) + images = torch.from_numpy(images) if self.cuda: # pragma: no cover images = images.cuda() images = Variable(images, requires_grad=True) @@ -104,7 +104,34 @@ def predictions_and_gradient(self, image, label): if self.cuda: # pragma: no cover grad = grad.cpu() grad = grad.numpy() + grad = self._process_gradient(grad) grad = np.squeeze(grad, axis=0) assert grad.shape == image.shape return predictions, grad + + def _loss_fn(self, image, label): + # lazy import + import torch + import torch.nn as nn + from torch.autograd import Variable + + image = self._process_input(image) + target = np.array([label]) + target = torch.from_numpy(target) + if self.cuda: # pragma: no cover + target = target.cuda() + target = Variable(target) + + images = torch.from_numpy(image[None]) + if self.cuda: # pragma: no cover + images = images.cuda() + images = Variable(images, volatile=True) + predictions = self._model(images) + ce = nn.CrossEntropyLoss() + loss = ce(predictions, target) + loss = loss.data + if self.cuda: # pragma: no cover + loss = loss.cpu() + loss = loss.numpy() + return loss diff --git a/foolbox/models/tensorflow.py b/foolbox/models/tensorflow.py index bb97ad29..02ff44c0 100644 --- a/foolbox/models/tensorflow.py +++ b/foolbox/models/tensorflow.py @@ -19,6 +19,10 @@ class TensorFlowModel(DifferentiableModel): (0, 1) or (0, 255). channel_axis : int The index of the axis that represents color channels. + preprocessing: 2-element tuple with floats or numpy arrays + Elementwises preprocessing of input; we first subtract the first + element of preprocessing from the input and then divide the input by + the second element. 
""" @@ -27,10 +31,12 @@ def __init__( images, logits, bounds, - channel_axis=3): + channel_axis=3, + preprocessing=(0, 1)): super(TensorFlowModel, self).__init__(bounds=bounds, - channel_axis=channel_axis) + channel_axis=channel_axis, + preprocessing=preprocessing) # delay import until class is instantiated import tensorflow as tf @@ -52,7 +58,7 @@ def __init__( loss = tf.nn.sparse_softmax_cross_entropy_with_logits( labels=self._label[tf.newaxis], logits=self._logits[tf.newaxis]) - loss = tf.squeeze(loss, axis=0) + self._loss = tf.squeeze(loss, axis=0) gradients = tf.gradients(loss, images) assert len(gradients) == 1 self._gradient = tf.squeeze(gradients[0], axis=0) @@ -67,23 +73,37 @@ def num_classes(self): return n def batch_predictions(self, images): + images = self._process_input(images) predictions = self._session.run( self._batch_logits, feed_dict={self._images: images}) return predictions def predictions_and_gradient(self, image, label): + image = self._process_input(image) predictions, gradient = self._session.run( [self._logits, self._gradient], feed_dict={ self._images: image[np.newaxis], self._label: label}) + gradient = self._process_gradient(gradient) return predictions, gradient def gradient(self, image, label): + image = self._process_input(image) g = self._session.run( self._gradient, feed_dict={ self._images: image[np.newaxis], self._label: label}) + g = self._process_gradient(g) return g + + def _loss_fn(self, image, label): + image = self._process_input(image) + loss = self._session.run( + self._loss, + feed_dict={ + self._images: image[np.newaxis], + self._label: label}) + return loss diff --git a/foolbox/models/theano.py b/foolbox/models/theano.py index c7a36152..09483024 100644 --- a/foolbox/models/theano.py +++ b/foolbox/models/theano.py @@ -19,6 +19,10 @@ class TheanoModel(DifferentiableModel): Number of classes for which the model will output predictions. channel_axis : int The index of the axis that represents color channels. + preprocessing: 2-element tuple with floats or numpy arrays + Elementwises preprocessing of input; we first subtract the first + element of preprocessing from the input and then divide the input by + the second element. 
""" @@ -28,10 +32,12 @@ def __init__( logits, bounds, num_classes, - channel_axis=1): + channel_axis=1, + preprocessing=[0, 1]): super(TheanoModel, self).__init__(bounds=bounds, - channel_axis=channel_axis) + channel_axis=channel_axis, + preprocessing=preprocessing) self._num_classes = num_classes @@ -50,25 +56,31 @@ def __init__( self._predictions_and_gradient_fn = th.function( [images, labels], [logits, gradient]) self._gradient_fn = th.function([images, labels], gradient) + self._loss_fn = th.function([images, labels], loss) def batch_predictions(self, images): + images = self._process_input(images) predictions = self._batch_prediction_fn(images) assert predictions.shape == (images.shape[0], self.num_classes()) return predictions def predictions_and_gradient(self, image, label): + image = self._process_input(image) label = np.array(label, dtype=np.int32) predictions, gradient = self._predictions_and_gradient_fn( image[np.newaxis], label[np.newaxis]) predictions = np.squeeze(predictions, axis=0) + gradient = self._process_gradient(gradient) gradient = np.squeeze(gradient, axis=0) assert predictions.shape == (self.num_classes(),) assert gradient.shape == image.shape return predictions, gradient def gradient(self, image, label): + image = self._process_input(image) label = np.array(label, dtype=np.int32) gradient = self._gradient_fn(image[np.newaxis], label[np.newaxis]) + gradient = self._process_gradient(gradient) gradient = np.squeeze(gradient, axis=0) assert gradient.shape == image.shape return gradient diff --git a/foolbox/tests/test_models_keras.py b/foolbox/tests/test_models_keras.py index 0e27cd9a..30ca3890 100644 --- a/foolbox/tests/test_models_keras.py +++ b/foolbox/tests/test_models_keras.py @@ -1,4 +1,5 @@ import pytest + import numpy as np from keras.layers import GlobalAveragePooling2D from keras.layers import Activation @@ -102,10 +103,8 @@ def test_keras_model_preprocess(): logits = GlobalAveragePooling2D( data_format='channels_last')(inputs) - def preprocess_fn(x): - # modify x in-place - x /= 2 - return x + preprocessing = (np.arange(num_classes)[None, None], + np.random.uniform(size=(5, 5, channels)) + 1) model1 = KerasModel( Model(inputs=inputs, outputs=logits), @@ -116,7 +115,7 @@ def preprocess_fn(x): Model(inputs=inputs, outputs=logits), bounds=bounds, predicts='logits', - preprocess_fn=preprocess_fn) + preprocessing=preprocessing) model3 = KerasModel( Model(inputs=inputs, outputs=logits), @@ -142,3 +141,41 @@ def preprocess_fn(x): p1 - p1.max(), p3 - p3.max(), decimal=5) + + +def test_keras_model_gradients(): + num_classes = 1000 + bounds = (0, 255) + channels = num_classes + + inputs = Input(shape=(5, 5, channels)) + logits = GlobalAveragePooling2D( + data_format='channels_last')(inputs) + + preprocessing = (np.arange(num_classes)[None, None], + np.random.uniform(size=(5, 5, channels)) + 1) + + model = KerasModel( + Model(inputs=inputs, outputs=logits), + bounds=bounds, + predicts='logits', + preprocessing=preprocessing) + + eps = 1e-3 + + np.random.seed(22) + test_image = np.random.rand(5, 5, channels).astype(np.float32) + test_label = 7 + + _, g1 = model.predictions_and_gradient(test_image, test_label) + + l1 = model._loss_fn([test_image[None] - eps / 2 * g1, [test_label]])[0][0] + l2 = model._loss_fn([test_image[None] + eps / 2 * g1, [test_label]])[0][0] + + assert 1e5 * (l2 - l1) > 1 + + # make sure that gradient is numerically correct + np.testing.assert_array_almost_equal( + 1e5 * (l2 - l1), + 1e5 * eps * np.linalg.norm(g1)**2, + decimal=1) diff --git 
diff --git a/foolbox/tests/test_models_lasagne.py b/foolbox/tests/test_models_lasagne.py
index d47ddef5..3b50341f 100644
--- a/foolbox/tests/test_models_lasagne.py
+++ b/foolbox/tests/test_models_lasagne.py
@@ -45,3 +45,43 @@ def mean_brightness_net(images):
         test_gradient)

     assert model.num_classes() == num_classes
+
+
+@pytest.mark.parametrize('num_classes', [10, 1000])
+def test_lasagne_gradient(num_classes):
+    bounds = (0, 255)
+    channels = num_classes
+
+    def mean_brightness_net(images):
+        logits = GlobalPoolLayer(images)
+        return logits
+
+    images_var = T.tensor4('images')
+    images = InputLayer((None, channels, 5, 5), images_var)
+    logits = mean_brightness_net(images)
+
+    preprocessing = (np.arange(num_classes)[None, None],
+                     np.random.uniform(size=(5, 5, channels)) + 1)
+
+    model = LasagneModel(
+        images,
+        logits,
+        preprocessing=preprocessing,
+        bounds=bounds)
+
+    epsilon = 1e-2
+
+    np.random.seed(23)
+    test_image = np.random.rand(channels, 5, 5).astype(np.float32)
+    test_label = 7
+
+    _, g1 = model.predictions_and_gradient(test_image, test_label)
+
+    l1 = model._loss_fn(test_image[None] - epsilon / 2 * g1, [test_label])[0]
+    l2 = model._loss_fn(test_image[None] + epsilon / 2 * g1, [test_label])[0]
+
+    # make sure that the gradient is numerically correct
+    np.testing.assert_array_almost_equal(
+        1e4 * (l2 - l1),
+        1e4 * epsilon * np.linalg.norm(g1)**2,
+        decimal=1)
diff --git a/foolbox/tests/test_models_mxnet.py b/foolbox/tests/test_models_mxnet.py
index 11dab225..448d5a38 100644
--- a/foolbox/tests/test_models_mxnet.py
+++ b/foolbox/tests/test_models_mxnet.py
@@ -47,3 +47,46 @@ def mean_brightness_net(images):
         test_gradient)

     assert model.num_classes() == num_classes
+
+
+@pytest.mark.parametrize('num_classes', [10, 1000])
+def test_model_gradient(num_classes):
+    bounds = (0, 255)
+    channels = num_classes
+
+    def mean_brightness_net(images):
+        logits = mx.symbol.mean(images, axis=(2, 3))
+        return logits
+
+    images = mx.symbol.Variable('images')
+    logits = mean_brightness_net(images)
+
+    preprocessing = (np.arange(num_classes)[:, None, None],
+                     np.random.uniform(size=(channels, 5, 5)) + 1)
+
+    model = MXNetModel(
+        images,
+        logits,
+        {},
+        device=mx.cpu(),
+        num_classes=num_classes,
+        bounds=bounds,
+        preprocessing=preprocessing,
+        channel_axis=1)
+
+    test_images = np.random.rand(2, channels, 5, 5).astype(np.float32)
+    test_image = test_images[0]
+    test_label = 7
+
+    epsilon = 1e-2
+    _, g1 = model.predictions_and_gradient(test_image, test_label)
+    l1 = model._loss_fn(test_image - epsilon / 2 * g1, test_label)
+    l2 = model._loss_fn(test_image + epsilon / 2 * g1, test_label)
+
+    assert 1e4 * (l2 - l1) > 1
+
+    # make sure that the gradient is numerically correct
+    np.testing.assert_array_almost_equal(
+        1e4 * (l2 - l1),
+        1e4 * epsilon * np.linalg.norm(g1)**2,
+        decimal=1)
diff --git a/foolbox/tests/test_models_pytorch.py b/foolbox/tests/test_models_pytorch.py
index f9c3cacd..145ef2e4 100644
--- a/foolbox/tests/test_models_pytorch.py
+++ b/foolbox/tests/test_models_pytorch.py
@@ -73,11 +73,8 @@ def forward(self, x):
         return logits

     model = Net()
-
-    def preprocess_fn(x):
-        # modify x in-place
-        x /= 2
-        return x
+    preprocessing = (np.arange(num_classes)[:, None, None],
+                     np.random.uniform(size=(channels, 5, 5)) + 1)

     model1 = PyTorchModel(
         model,
@@ -90,7 +87,7 @@ def preprocess_fn(x):
         bounds=bounds,
         num_classes=num_classes,
         cuda=False,
-        preprocess_fn=preprocess_fn)
+        preprocessing=preprocessing)

     model3 = PyTorchModel(
         model,
@@ -117,3 +114,52 @@ def preprocess_fn(x):
         p1 - p1.max(),
         p3 - p3.max(),
         decimal=5)
+
+
+def test_pytorch_model_gradient():
+    num_classes = 1000
+    bounds = (0, 255)
+    channels = num_classes
+
+    class Net(nn.Module):
+
+        def __init__(self):
+            super(Net, self).__init__()
+
+        def forward(self, x):
+            x = torch.mean(x, 3)
+            x = torch.squeeze(x, dim=3)
+            x = torch.mean(x, 2)
+            x = torch.squeeze(x, dim=2)
+            logits = x
+            return logits
+
+    model = Net()
+    preprocessing = (np.arange(num_classes)[:, None, None],
+                     np.random.uniform(size=(channels, 5, 5)) + 1)
+
+    model = PyTorchModel(
+        model,
+        bounds=bounds,
+        num_classes=num_classes,
+        cuda=False,
+        preprocessing=preprocessing)
+
+    epsilon = 1e-2
+
+    np.random.seed(23)
+    test_image = np.random.rand(channels, 5, 5).astype(np.float32)
+    test_label = 7
+
+    _, g1 = model.predictions_and_gradient(test_image, test_label)
+
+    l1 = model._loss_fn(test_image - epsilon / 2 * g1, test_label)
+    l2 = model._loss_fn(test_image + epsilon / 2 * g1, test_label)
+
+    assert 1e4 * (l2 - l1) > 1
+
+    # make sure that the gradient is numerically correct
+    np.testing.assert_array_almost_equal(
+        1e4 * (l2 - l1),
+        1e4 * epsilon * np.linalg.norm(g1)**2,
+        decimal=1)
diff --git a/foolbox/tests/test_models_tensorflow.py b/foolbox/tests/test_models_tensorflow.py
index 5eff4544..153027fd 100644
--- a/foolbox/tests/test_models_tensorflow.py
+++ b/foolbox/tests/test_models_tensorflow.py
@@ -83,3 +83,85 @@ def mean_brightness_net(images):
         test_gradient)

     assert model.num_classes() == num_classes
+
+
+@pytest.mark.parametrize('num_classes', [10, 1000])
+def test_tensorflow_preprocessing(num_classes):
+    bounds = (0, 255)
+    channels = num_classes
+
+    def mean_brightness_net(images):
+        logits = tf.reduce_mean(images, axis=(1, 2))
+        return logits
+
+    q = (np.arange(num_classes)[None, None],
+         np.random.uniform(size=(5, 5, channels)) + 1)
+
+    g = tf.Graph()
+    with g.as_default():
+        images = tf.placeholder(tf.float32, (None, 5, 5, channels))
+        logits = mean_brightness_net(images)
+
+    with TensorFlowModel(images, logits, bounds=bounds,
+                         preprocessing=q) as model:
+
+        test_images = np.random.rand(2, 5, 5, channels).astype(np.float32)
+        test_label = 7
+
+        assert model.batch_predictions(test_images).shape \
+            == (2, num_classes)
+
+        test_logits = model.predictions(test_images[0])
+        assert test_logits.shape == (num_classes,)
+
+        test_gradient = model.gradient(test_images[0], test_label)
+        assert test_gradient.shape == test_images[0].shape
+
+        np.testing.assert_almost_equal(
+            model.predictions_and_gradient(test_images[0], test_label)[0],
+            test_logits)
+        np.testing.assert_almost_equal(
+            model.predictions_and_gradient(test_images[0], test_label)[1],
+            test_gradient)
+
+        assert model.num_classes() == num_classes
+
+
+@pytest.mark.parametrize('num_classes', [10, 1000])
+def test_tensorflow_gradient(num_classes):
+    bounds = (0, 255)
+    channels = num_classes
+
+    def mean_brightness_net(images):
+        logits = tf.reduce_mean(images, axis=(1, 2))
+        return logits
+
+    q = (np.arange(num_classes)[None, None],
+         np.random.uniform(size=(5, 5, channels)) + 1)
+
+    g = tf.Graph()
+    with g.as_default():
+        images = tf.placeholder(tf.float32, (None, 5, 5, channels))
+        logits = mean_brightness_net(images)
+
+    with TensorFlowModel(images, logits, bounds=bounds,
+                         preprocessing=q) as model:
+
+        epsilon = 1e-2
+
+        np.random.seed(23)
+        test_image = np.random.rand(5, 5, channels).astype(np.float32)
+        test_label = 7
+
+        _, g1 = model.predictions_and_gradient(test_image, test_label)
+
+        l1 = model._loss_fn(test_image - epsilon / 2 * g1, test_label)
+        l2 = model._loss_fn(test_image + epsilon / 2 * g1, test_label)
+
+        assert 1e4 * (l2 - l1) > 1
+
+        # make sure that the gradient is numerically correct
+        np.testing.assert_array_almost_equal(
+            1e4 * (l2 - l1),
+            1e4 * epsilon * np.linalg.norm(g1)**2,
+            decimal=1)
diff --git a/foolbox/tests/test_models_theano.py b/foolbox/tests/test_models_theano.py
index 750ceda8..60d4ca3d 100644
--- a/foolbox/tests/test_models_theano.py
+++ b/foolbox/tests/test_models_theano.py
@@ -43,3 +43,43 @@ def mean_brightness_net(images):
         test_gradient)

     assert model.num_classes() == num_classes
+
+
+@pytest.mark.parametrize('num_classes', [10, 1000])
+def test_theano_gradient(num_classes):
+    bounds = (0, 255)
+    channels = num_classes
+
+    def mean_brightness_net(images):
+        logits = T.mean(images, axis=(2, 3))
+        return logits
+
+    images = T.tensor4('images')
+    logits = mean_brightness_net(images)
+
+    preprocessing = (np.arange(num_classes)[:, None, None],
+                     np.random.uniform(size=(channels, 5, 5)) + 1)
+
+    model = TheanoModel(
+        images,
+        logits,
+        num_classes=num_classes,
+        preprocessing=preprocessing,
+        bounds=bounds)
+
+    epsilon = 1e-3
+
+    np.random.seed(23)
+    test_image = np.random.rand(channels, 5, 5).astype(np.float32)
+    test_label = 7
+
+    _, g1 = model.predictions_and_gradient(test_image, test_label)
+
+    l1 = model._loss_fn(test_image[None] - epsilon / 2 * g1, [test_label])[0]
+    l2 = model._loss_fn(test_image[None] + epsilon / 2 * g1, [test_label])[0]
+
+    # make sure that the gradient is numerically correct
+    np.testing.assert_array_almost_equal(
+        1e5 * (l2 - l1),
+        1e5 * epsilon * np.linalg.norm(g1)**2,
+        decimal=1)
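A final note on shapes in these tests: ``_process_input`` relies on numpy broadcasting, so the preprocessing arrays are shaped to match each backend's channel axis, ``(C, 1, 1)`` via ``[:, None, None]`` for the channels-first backends (MXNet, PyTorch, Theano) and ``(1, 1, C)`` via ``[None, None]`` for the channels-last ones (Keras, TensorFlow)::

    import numpy as np

    C = 3
    image_chw = np.random.rand(C, 5, 5)  # channels first
    image_hwc = np.random.rand(5, 5, C)  # channels last
    mean = np.arange(C, dtype=np.float64)

    # (C, 1, 1) broadcasts over height and width
    assert (image_chw - mean[:, None, None]).shape == image_chw.shape
    # (1, 1, C) broadcasts over height and width
    assert (image_hwc - mean[None, None]).shape == image_hwc.shape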