New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add MXNet Gluon model functionality. #137
Conversation
Hi @meissnereric, thanks a lot for your pull-request. It's great to have support for Gluon models! I am aware of the I am looking forward to merging this. Before that, we need to fix the code formatting and update the tests. Here are some code-formatting issues that need to be fixed:
Also, you made some changes to adversarial.py. Could you remove them from the pull-request? Finally, it would be great if you could upgrade MXNet in the travis config to the latest MXNet version so that we can see if your workaround passes the tests; see here: https://github.com/bethgelab/foolbox/pull/133/files It would be great if you could to these changes, otherwise I can also do them myself as soon as I find time. |
ce8d2b8
to
0b2a5b8
Compare
Cool, I fixed up the formatting, removed the adversarial.py change, and upgraded the travis file. Thanks, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In the end we will need tests like foolbox/tests/test_models_mxnet.py
…
foolbox/models/mxnet.py
Outdated
@@ -63,7 +63,11 @@ def __init__( | |||
label = mx.symbol.Variable('label') | |||
self._label_sym = label | |||
|
|||
loss = mx.symbol.softmax_cross_entropy(logits, label) | |||
log_softmax = mx.sym.log_softmax(logits) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can you add a comment pointing out that this is a workaround for apache/mxnet#6874
foolbox/models/mxnet_gluon.py
Outdated
|
||
|
||
class MXNetGluonModel(DifferentiableModel): | ||
"""Creates a :class:`Model` instance from existing `MXNet` symbols and weights. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can you update the documentation and the parameter list
foolbox/models/mxnet_gluon.py
Outdated
loss = mx.nd.softmax_cross_entropy(L, label) | ||
loss.backward() | ||
return np.squeeze(L.asnumpy(), axis=0), | ||
self._process_gradient(data_array.grad.asnumpy()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
indentation missing
foolbox/models/mxnet_gluon.py
Outdated
L = self._block(data_array) | ||
return np.squeeze(L.asnumpy(), axis=0) | ||
|
||
def predictions(self, image): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
no need to implement predictions
if batch_predictions
is implemented (the base class will handle this)
foolbox/models/mxnet_gluon.py
Outdated
def num_classes(self): | ||
return self._num_classes | ||
|
||
def batch_predictions(self, image): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
batch_predictions gets a batch of inputs and should be able to calculate predictions for all of them; this implementation only works for single inputs
Thanks Eric! |
6f9a9b0
to
4928b65
Compare
Hey all, Took a few weeks but I've added tests and addressed the feedback for the CR. Let me know if there's anything else! Thanks, |
Any updates? |
foolbox/models/mxnet_gluon.py
Outdated
|
||
|
||
class MXNetGluonModel(DifferentiableModel): | ||
"""Creates a :class:`Model` instance from existing `MXNet` symbols and weights. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This line of the documentation still needs to be updated
Thanks a lot, @meissnereric! Looking forward to merging this as soon as that last small change in the documentation of the |
Should be good to go now! |
Great work! Thanks a lot! |
We have some models that are built using MXNet's Gluon interface (http://gluon.mxnet.io/), so I added another model class to handle that type of network.
Also, there is a bug in the mxnet codebase with using the mx.symbol.softmax_cross_entropy and it hasn't worked in a while so I fixed that in the existing MXNet code as well.
Just contributing back if others want to use MXNet models :) Feedback welcome.
Thanks,
Eric