-
Notifications
You must be signed in to change notification settings - Fork 460
Add RNN support for Pytorch #850
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
pre-commit.ci autofix |
The tests fail with:
|
All test failures the last time around seemed to be related to issues with the tests themselves, which I have mostly fixed. The only change I made was to add missing includes to some Quartus templates to fix compiliation errors when There are currently still some remaining test failures with the case when activations are used in their |
The tests here finally pass again and I think it is basically ready to merge. One thing to note is that this includes a change to the pytorch config interface giving more options on how to do the channels_last conversion so that it can be either |
@@ -92,6 +93,7 @@ def format(self, node): | |||
params['config_mult_h'] = f'config{node.index}_h_mult' | |||
params['act_t'] = '{}_config{}'.format(node.get_attr('activation'), str(node.index) + '_act') | |||
params['act_recurrent_t'] = '{}_config{}'.format(node.get_attr('recurrent_activation'), str(node.index) + '_rec_act') | |||
params['pytorch'] = 'true' if "pytorch" in node.attributes.keys() else 'false' |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very minor point: hls4ml seems to use single quotes more often (except for doc strings), so I would change "pytorch" to 'pytorch'. There are double quotes used in other places. But truthfully I am not convinced whether we need to change it and run the tests again, since this is so minor.
@@ -301,5 +306,9 @@ def __init__(self): | |||
|
|||
def format(self, node): | |||
params = self._default_function_params(node) | |||
params['weights'] = 'w{0}, wr{0}, b{0}'.format(str(node.index)) | |||
if "pytorch" in node.attributes.keys(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is it better to check that pytorch is in the keys or that it is defined and True? What if it is defined and set False? I wonder if using node.get_attr('pytorch')
(returns None
if not found) or node.get_attr('pytorch', False)
is better.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Right now it will only be defined if a recurring layer is parsed in pytorch, and I can't really envision a situation where we would have code that sets this key that isn't part of the pytorch parser. But I still think you are right that this should be implemented in a more future-proof way and node.get_attr('pytorch', False)
is probably the most stringent solution. I will implement it.
Adds support for RNN layers (GRU, LSTM, RNN) to the pytorch parser.
Caveat: We currently lack implementation for
getitem
operations, so we can currently not return the hidden state after the calculationsCaveat 2: We currently only support a single recurrent layers, whereas multiple within the same RNN instance are supported by pytorch
Caveat 3: We currently don't support the passing of non-zero initial values for the hidden states to the RNN
So this implementation is slightly hacky at the moment, but might serve as a starting point for discussion, and can be used by interested parties if they can life with the current limitations.
Also, this contains parts of #848 because I was inattentive.
Type of change
For a new feature or function, please create an issue first to discuss it
with us before submitting a pull request.
Note: Please delete options that are not relevant.
Tests
Added pytests to confirm that the layers work.
Checklist
pre-commit
on the files I edited or added.