Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add MXNet Gluon model functionality. #137

Merged
merged 4 commits into from May 25, 2018

Conversation

meissnereric
Copy link
Contributor

We have some models that are built using MXNet's Gluon interface (http://gluon.mxnet.io/), so I added another model class to handle that type of network.

Also, there is a bug in the mxnet codebase with using the mx.symbol.softmax_cross_entropy and it hasn't worked in a while so I fixed that in the existing MXNet code as well.

Just contributing back if others want to use MXNet models :) Feedback welcome.

Thanks,
Eric

@jonasrauber
Copy link
Member

Hi @meissnereric, thanks a lot for your pull-request. It's great to have support for Gluon models! I am aware of the softmax_cross_entropy bug in MXNet; just last week I wanted to finally upgrade our tests to the latest MXNet and ran into it again: apache/mxnet#6874 so I really appreciate your workaround.

I am looking forward to merging this. Before that, we need to fix the code formatting and update the tests.

Here are some code-formatting issues that need to be fixed:

flake8 --ignore E402,E741 .
./foolbox/models/mxnet_gluon.py:77:14: E225 missing whitespace around operator
./foolbox/models/mxnet_gluon.py:84:80: E501 line too long (97 > 79 characters)
./foolbox/models/mxnet_gluon.py:89:14: E225 missing whitespace around operator
./foolbox/models/mxnet.py:7:1: E302 expected 2 blank lines, found 1
./foolbox/models/mxnet.py:66:80: E501 line too long (89 > 79 characters)
./foolbox/models/mxnet.py:68:9: E265 block comment should start with '# '

Also, you made some changes to adversarial.py. Could you remove them from the pull-request?

Finally, it would be great if you could upgrade MXNet in the travis config to the latest MXNet version so that we can see if your workaround passes the tests; see here: https://github.com/bethgelab/foolbox/pull/133/files

It would be great if you could to these changes, otherwise I can also do them myself as soon as I find time.

@meissnereric
Copy link
Contributor Author

Cool, I fixed up the formatting, removed the adversarial.py change, and upgraded the travis file.

Thanks,
Eric

@coveralls
Copy link

coveralls commented Apr 16, 2018

Coverage Status

Coverage decreased (-4.8%) to 95.247% when pulling 7bcae18 on meissnereric:mxnet_gluon into 701742e on bethgelab:master.

Copy link
Member

@jonasrauber jonasrauber left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the end we will need tests like foolbox/tests/test_models_mxnet.py

@@ -63,7 +63,11 @@ def __init__(
label = mx.symbol.Variable('label')
self._label_sym = label

loss = mx.symbol.softmax_cross_entropy(logits, label)
log_softmax = mx.sym.log_softmax(logits)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you add a comment pointing out that this is a workaround for apache/mxnet#6874



class MXNetGluonModel(DifferentiableModel):
"""Creates a :class:`Model` instance from existing `MXNet` symbols and weights.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you update the documentation and the parameter list

loss = mx.nd.softmax_cross_entropy(L, label)
loss.backward()
return np.squeeze(L.asnumpy(), axis=0),
self._process_gradient(data_array.grad.asnumpy())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

indentation missing

L = self._block(data_array)
return np.squeeze(L.asnumpy(), axis=0)

def predictions(self, image):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no need to implement predictions if batch_predictions is implemented (the base class will handle this)

def num_classes(self):
return self._num_classes

def batch_predictions(self, image):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

batch_predictions gets a batch of inputs and should be able to calculate predictions for all of them; this implementation only works for single inputs

@jonasrauber
Copy link
Member

Thanks Eric!
I reviewed the code and listed some changes that are still necessary. In the end, we also need a small mean_brightness_net example for the tests, similar to foolbox/tests/test_models_mxnet.py.

@meissnereric meissnereric force-pushed the mxnet_gluon branch 5 times, most recently from 6f9a9b0 to 4928b65 Compare May 1, 2018 20:31
@meissnereric
Copy link
Contributor Author

Hey all,

Took a few weeks but I've added tests and addressed the feedback for the CR. Let me know if there's anything else!

Thanks,
Eric

@meissnereric
Copy link
Contributor Author

Any updates?



class MXNetGluonModel(DifferentiableModel):
"""Creates a :class:`Model` instance from existing `MXNet` symbols and weights.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line of the documentation still needs to be updated

@jonasrauber
Copy link
Member

Thanks a lot, @meissnereric! Looking forward to merging this as soon as that last small change in the documentation of the MXNetGluonModel documentation is fixed.

@meissnereric
Copy link
Contributor Author

Should be good to go now!

@jonasrauber jonasrauber merged commit 6f3a637 into bethgelab:master May 25, 2018
@jonasrauber
Copy link
Member

Great work! Thanks a lot!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants