Skip to content

Commit

Permalink
Fixed MXNet interface (#42)
Browse files Browse the repository at this point in the history
* added aux_states to mxnet

* fix coverage

* fixed overintendation
  • Loading branch information
wielandbrendel authored and jonasrauber committed Jul 22, 2017
1 parent 60dd526 commit 2e9e589
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 14 deletions.
39 changes: 27 additions & 12 deletions foolbox/models/mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ class MXNetModel(DifferentiableModel):
The input to the model.
logits : `mxnet.symbol.Symbol`
The predictions of the model, before the softmax.
weights : `dictionary mapping str to mxnet.nd.array`
The weights of the model.
device : `mxnet.context.Context`
args : `dictionary mapping str to mxnet.nd.array`
The parameters of the model.
ctx : `mxnet.context.Context`
The device, e.g. mxnet.cpu() or mxnet.gpu().
num_classes : int
The number of classes.
Expand All @@ -25,6 +25,8 @@ class MXNetModel(DifferentiableModel):
(0, 1) or (0, 255).
channel_axis : int
The index of the axis that represents color channels.
aux_states : `dictionary mapping str to mxnet.nd.array`
The states of auxiliary parameters of the model.
preprocessing: 2-element tuple with floats or numpy arrays
Elementwises preprocessing of input; we first subtract the first
element of preprocessing from the input and then divide the input by
Expand All @@ -36,11 +38,12 @@ def __init__(
self,
data,
logits,
weights,
device,
args,
ctx,
num_classes,
bounds,
channel_axis=1,
aux_states=None,
preprocessing=(0, 1)):

super(MXNetModel, self).__init__(
Expand All @@ -52,7 +55,7 @@ def __init__(

self._num_classes = num_classes

self._device = device
self._device = ctx

self._data_sym = data
self._batch_logits_sym = logits
Expand All @@ -63,9 +66,18 @@ def __init__(
loss = mx.symbol.softmax_cross_entropy(logits, label)
self._loss_sym = loss

weight_names = list(weights.keys())
weight_arrays = [weights[name] for name in weight_names]
self._args_map = dict(zip(weight_names, weight_arrays))
self._args_map = args.copy()
self._aux_map = aux_states.copy() if aux_states is not None else None

# move all parameters to correct device
for k in self._args_map.keys():
self._args_map[k] = \
self._args_map[k].as_in_context(ctx) # pragma: no cover

if aux_states is not None:
for k in self._aux_map.keys(): # pragma: no cover
self._aux_map[k] = \
self._aux_map[k].as_in_context(ctx) # pragma: no cover

def num_classes(self):
return self._num_classes
Expand All @@ -76,7 +88,8 @@ def batch_predictions(self, images):
data_array = mx.nd.array(images, ctx=self._device)
self._args_map[self._data_sym.name] = data_array
model = self._batch_logits_sym.bind(
ctx=self._device, args=self._args_map, grad_req='null')
ctx=self._device, args=self._args_map, grad_req='null',
aux_states=self._aux_map)
model.forward(is_train=False)
logits_array = model.outputs[0]
logits = logits_array.asnumpy()
Expand All @@ -99,7 +112,8 @@ def predictions_and_gradient(self, image, label):
ctx=self._device,
args=self._args_map,
args_grad=grad_map,
grad_req='write')
grad_req='write',
aux_states=self._aux_map)
model.forward(is_train=True)
logits_array = model.outputs[0]
model.backward([
Expand All @@ -119,7 +133,8 @@ def _loss_fn(self, image, label):
self._args_map[self._data_sym.name] = data_array
self._args_map[self._label_sym.name] = label_array
model = self._loss_sym.bind(
ctx=self._device, args=self._args_map, grad_req='null')
ctx=self._device, args=self._args_map, grad_req='null',
aux_states=self._aux_map)
model.forward(is_train=False)
loss_array = model.outputs[0]
loss = loss_array.asnumpy()[0]
Expand Down
4 changes: 2 additions & 2 deletions foolbox/tests/test_models_mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def mean_brightness_net(images):
images,
logits,
{},
device=mx.cpu(),
ctx=mx.cpu(),
num_classes=num_classes,
bounds=bounds,
channel_axis=1)
Expand Down Expand Up @@ -68,7 +68,7 @@ def mean_brightness_net(images):
images,
logits,
{},
device=mx.cpu(),
ctx=mx.cpu(),
num_classes=num_classes,
bounds=bounds,
preprocessing=preprocessing,
Expand Down

0 comments on commit 2e9e589

Please sign in to comment.