Merge pull request #137 from meissnereric/mxnet_gluon
Add MXNet Gluon model functionality.
jonasrauber committed May 25, 2018
2 parents b0d24c4 + 7bcae18 commit 6f3a637
Showing 5 changed files with 177 additions and 2 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
@@ -39,7 +39,7 @@ before_install:
 
 - travis_wait travis_retry pip install --upgrade keras>=2.1.5
 - python -c 'import keras; print(keras.__version__)'
-- travis_wait travis_retry pip install --upgrade mxnet==0.10.0
+- travis_wait travis_retry pip install --upgrade mxnet>=1.1.0
 - python -c 'import mxnet; print(mxnet.__version__)'
 
 #install open mpi for cntk
1 change: 1 addition & 0 deletions foolbox/models/__init__.py
Original file line number Diff line number Diff line change
@@ -17,3 +17,4 @@
 from .theano import TheanoModel  # noqa: F401
 from .lasagne import LasagneModel  # noqa: F401
 from .mxnet import MXNetModel  # noqa: F401
+from .mxnet_gluon import MXNetGluonModel  # noqa: F401
8 changes: 7 additions & 1 deletion foolbox/models/mxnet.py
@@ -63,7 +63,13 @@ def __init__(
         label = mx.symbol.Variable('label')
         self._label_sym = label
 
-        loss = mx.symbol.softmax_cross_entropy(logits, label)
+        # workaround for https://github.com/apache/incubator-mxnet/issues/6874
+        log_softmax = mx.sym.log_softmax(logits)
+
+        loss = -mx.sym.sum(
+            mx.sym.one_hot(indices=label, depth=num_classes) * log_softmax)
+
+        # loss = mx.symbol.softmax_cross_entropy(logits, label)
         self._loss_sym = loss
 
         self._args_map = args.copy()
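
The workaround simply writes out the summed softmax cross-entropy by hand, so it should agree numerically with the op it replaces. A quick sanity check along these lines (not part of the commit; the logits and labels below are made-up values, and mxnet>=1.1.0 is assumed):

import mxnet as mx
import numpy as np

# made-up example batch: 2 samples, 3 classes
logits = mx.nd.array([[2.0, -1.0, 0.5], [0.1, 0.2, 0.3]])
label = mx.nd.array([0, 2])
num_classes = 3

# built-in op: cross-entropy summed over the batch
builtin = mx.nd.softmax_cross_entropy(logits, label)

# manual version used as the workaround
log_softmax = mx.nd.log_softmax(logits)
manual = -mx.nd.sum(mx.nd.one_hot(label, depth=num_classes) * log_softmax)

np.testing.assert_allclose(builtin.asnumpy(), manual.asnumpy(), rtol=1e-5)
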
89 changes: 89 additions & 0 deletions foolbox/models/mxnet_gluon.py
@@ -0,0 +1,89 @@
from __future__ import absolute_import

from .base import DifferentiableModel

import numpy as np


class MXNetGluonModel(DifferentiableModel):
"""Creates a :class:`Model` instance from an existing `MXNet Gluon` Block.
Parameters
----------
block : `mxnet.gluon.Block`
The Gluon Block representing the model to be run.
ctx : `mxnet.context.Context`
The device, e.g. mxnet.cpu() or mxnet.gpu().
num_classes : int
The number of classes.
bounds : tuple
Tuple of lower and upper bound for the pixel values, usually
(0, 1) or (0, 255).
channel_axis : int
The index of the axis that represents color channels.
preprocessing: 2-element tuple with floats or numpy arrays
Elementwises preprocessing of input; we first subtract the first
element of preprocessing from the input and then divide the input by
the second element.
"""

    def __init__(
            self,
            block,
            bounds,
            num_classes,
            ctx=None,
            channel_axis=1,
            preprocessing=(0, 1)):
        import mxnet as mx
        self._num_classes = num_classes

        if ctx is None:
            ctx = mx.cpu()

        super(MXNetGluonModel, self).__init__(
            bounds=bounds,
            channel_axis=channel_axis,
            preprocessing=preprocessing)

        self._device = ctx
        self._block = block

    def num_classes(self):
        return self._num_classes

    def batch_predictions(self, images):
        import mxnet as mx
        images = self._process_input(images)
        data_array = mx.nd.array(images, ctx=self._device)
        data_array.attach_grad()
        with mx.autograd.record(train_mode=False):
            L = self._block(data_array)
        return L.asnumpy()

    def predictions_and_gradient(self, image, label):
        import mxnet as mx
        image = self._process_input(image)
        label = mx.nd.array([label])
        data_array = mx.nd.array(image[np.newaxis], ctx=self._device)
        data_array.attach_grad()
        with mx.autograd.record(train_mode=False):
            L = self._block(data_array)
            loss = mx.nd.softmax_cross_entropy(L, label)
            loss.backward()
        return np.squeeze(L.asnumpy(), axis=0), \
            np.squeeze(self._process_gradient(data_array.grad.asnumpy()),
                       axis=0)

    def _loss_fn(self, image, label):
        import mxnet as mx
        image = self._process_input(image)
        label = mx.nd.array([label])
        data_array = mx.nd.array(image[np.newaxis], ctx=self._device)
        data_array.attach_grad()
        with mx.autograd.record(train_mode=False):
            L = self._block(data_array)
            loss = mx.nd.softmax_cross_entropy(L, label)
            loss.backward()
        return loss.asnumpy()
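
For orientation, a minimal usage sketch for the new wrapper (not part of the commit; it assumes foolbox and mxnet>=1.1.0 are installed and uses an untrained gluon Dense block purely as a stand-in for a real network):

import mxnet as mx
import numpy as np
from foolbox.models import MXNetGluonModel

# stand-in network: a single dense layer over the flattened input
block = mx.gluon.nn.Dense(10)
block.initialize()
block(mx.nd.zeros((1, 3, 8, 8)))  # warm-up forward to materialize deferred parameters

model = MXNetGluonModel(
    block, bounds=(0, 255), num_classes=10, ctx=mx.cpu(),
    preprocessing=(127.5, 127.5))  # subtract 127.5, then divide by 127.5

image = np.random.uniform(0, 255, size=(3, 8, 8)).astype(np.float32)
logits, grad = model.predictions_and_gradient(image, label=3)
assert logits.shape == (10,)
assert grad.shape == image.shape

The returned gradient is taken with respect to the input image, not the block parameters, which is what foolbox's gradient-based attacks consume.
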
79 changes: 79 additions & 0 deletions foolbox/tests/test_models_mxnet_gluon.py
@@ -0,0 +1,79 @@
import pytest
import mxnet as mx
import numpy as np

from foolbox.models import MXNetGluonModel
from mxnet.gluon import HybridBlock


class MeanBrightnessNet(HybridBlock):
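    """Toy block: the per-channel mean brightness acts as the logits."""
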
    def hybrid_forward(self, F, x, *args, **kwargs):
        return mx.nd.mean(x, axis=(2, 3))


@pytest.mark.parametrize('num_classes', [10, 1000])
def test_model(num_classes):
    bounds = (0, 255)
    channels = num_classes

    block = MeanBrightnessNet()

    model = MXNetGluonModel(
        block,
        num_classes=num_classes,
        bounds=bounds,
        channel_axis=1)

    test_images = np.random.rand(2, channels, 5, 5).astype(np.float32)
    test_label = 7

    # Tests
    assert model.batch_predictions(test_images).shape \
        == (2, num_classes)

    test_logits = model.predictions(test_images[0])
    assert test_logits.shape == (num_classes,)

    test_gradient = model.gradient(test_images[0], test_label)
    assert test_gradient.shape == test_images[0].shape

    np.testing.assert_almost_equal(
        model.predictions_and_gradient(test_images[0], test_label)[0],
        test_logits)
    np.testing.assert_almost_equal(
        model.predictions_and_gradient(test_images[0], test_label)[1],
        test_gradient)

    assert model.num_classes() == num_classes


@pytest.mark.parametrize('num_classes', [10, 1000])
def test_model_gradient(num_classes):
    bounds = (0, 255)
    channels = num_classes

    block = MeanBrightnessNet()

    model = MXNetGluonModel(
        block,
        ctx=mx.cpu(),
        num_classes=num_classes,
        bounds=bounds,
        channel_axis=1)

    test_images = np.random.rand(2, channels, 5, 5).astype(np.float32)
    test_image = test_images[0]
    test_label = 7

    epsilon = 1e-2
    _, g1 = model.predictions_and_gradient(test_image, test_label)
    l1 = model._loss_fn(test_image - epsilon / 2 * g1, test_label)
    l2 = model._loss_fn(test_image + epsilon / 2 * g1, test_label)

    assert 1e4 * (l2 - l1) > 1

    # make sure that gradient is numerically correct
    np.testing.assert_array_almost_equal(
        1e4 * (l2 - l1),
        1e4 * epsilon * np.linalg.norm(g1)**2,
        decimal=1)
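
The last assertion is a symmetric finite-difference check of the returned gradient. Writing g = ∇L(x) for the gradient returned by predictions_and_gradient, a first-order Taylor expansion gives

L(x + (ε/2)·g) − L(x − (ε/2)·g) ≈ ε·(∇L(x)·g) = ε·‖g‖²,

so l2 − l1 should match epsilon * np.linalg.norm(g1)**2 up to higher-order terms; the 1e4 factor only rescales both sides into a range where agreement to one decimal place is a meaningful test.
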
