Add RNN support for Pytorch #850


Merged: 21 commits into fastmachinelearning:main on Jul 23, 2024
Conversation

@JanFSchulte (Contributor)

Adds support for RNN layers (GRU, LSTM, RNN) to the PyTorch parser.

Caveat 1: We currently lack an implementation for getitem operations, so we cannot yet return the hidden state after the calculation.

Caveat 2: We currently support only a single recurrent layer, whereas PyTorch supports multiple layers within the same RNN instance.

Caveat 3: We currently don't support passing non-zero initial values for the hidden states to the RNN.

So this implementation is slightly hacky at the moment, but it might serve as a starting point for discussion and can be used by interested parties if they can live with the current limitations.
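
For illustration, a minimal sketch of a model that fits these constraints (the module name and layer sizes are made up, not taken from this PR):

```python
import torch.nn as nn


class SingleLSTM(nn.Module):
    def __init__(self):
        super().__init__()
        # A single recurrent layer (num_layers=1), per caveat 2.
        self.lstm = nn.LSTM(input_size=10, hidden_size=20, num_layers=1, batch_first=True)

    def forward(self, x):
        # No initial hidden state is passed in (caveat 3), and only the
        # output sequence is returned, not the hidden state (caveat 1).
        output, _ = self.lstm(x)
        return output
```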

Also, this contains parts of #848 because I was inattentive.

Type of change


  • New feature (non-breaking change which adds functionality)

Tests

Added pytests to confirm that the layers work.

Checklist

  • I have read the guidelines for contributing.
  • I have commented my code, particularly in hard-to-understand areas.
  • I have made corresponding changes to the documentation.
  • My changes generate no new warnings.
  • I have installed and run pre-commit on the files I edited or added.
  • I have added tests that prove my fix is effective or that my feature works.

@JanFSchulte JanFSchulte marked this pull request as ready for review August 17, 2023 14:15
@vloncar vloncar added the please test Trigger testing by creating local PR branch label Aug 17, 2023
@vloncar (Contributor) commented Aug 17, 2023

pre-commit.ci autofix

@jmitrevs jmitrevs added please test Trigger testing by creating local PR branch and removed please test Trigger testing by creating local PR branch labels Oct 20, 2023
@jmitrevs jmitrevs added this to the v1.0.0 milestone Oct 20, 2023
@jmitrevs (Contributor)

The tests fail with:

FAILED test_pytorch_api.py::test_skipped_layers[io_parallel-Vivado] - TypeError: config_from_pytorch_model() got an unexpected keyword argument 'inputs_channel_last'
FAILED test_pytorch_api.py::test_skipped_layers[io_parallel-Quartus] - TypeError: config_from_pytorch_model() got an unexpected keyword argument 'inputs_channel_last'
FAILED test_pytorch_api.py::test_skipped_layers[io_stream-Vivado] - TypeError: config_from_pytorch_model() got an unexpected keyword argument 'inputs_channel_last'
FAILED test_pytorch_api.py::test_skipped_layers[io_stream-Quartus] - TypeError: config_from_pytorch_model() got an unexpected keyword argument 'inputs_channel_last'

@jmitrevs jmitrevs added please test Trigger testing by creating local PR branch and removed please test Trigger testing by creating local PR branch labels May 3, 2024
@JanFSchulte (Contributor, Author)

All of the test failures last time around seemed to be related to issues with the tests themselves, which I have mostly fixed. The only change I made was to add missing includes to some Quartus templates to fix compilation errors when uint_8 was used.

There are still some remaining test failures in the case where activations are used via their nn.functional implementations instead of as classes. I can't reproduce these failures in a standalone file; the exact same code that fails in the pytest works fine when run in standalone Python. I have not figured out how to debug it under those circumstances.
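
For context, these are the two styles of defining an activation that the failures involve (a generic PyTorch illustration, not code from this PR):

```python
import torch.nn as nn
import torch.nn.functional as F


class ClassActivation(nn.Module):
    """Activation as a module: appears as a named child of the model."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 8)
        self.act = nn.ReLU()

    def forward(self, x):
        return self.act(self.fc(x))


class FunctionalActivation(nn.Module):
    """Activation as a function call: only visible in the traced graph."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 8)

    def forward(self, x):
        return F.relu(self.fc(x))
```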

@jmitrevs jmitrevs added please test Trigger testing by creating local PR branch and removed please test Trigger testing by creating local PR branch labels May 9, 2024
@jmitrevs jmitrevs added please test Trigger testing by creating local PR branch and removed please test Trigger testing by creating local PR branch labels May 31, 2024
@JanFSchulte JanFSchulte added please test Trigger testing by creating local PR branch and removed please test Trigger testing by creating local PR branch labels Jul 16, 2024
@JanFSchulte (Contributor, Author)

The tests here finally pass again and I think this is basically ready to merge. One thing to note: this includes a change to the PyTorch config interface that gives more options for the channels_last conversion. It can be 'full' (transposing both inputs and internal layers), 'internal' (assuming the inputs are already transposed and transposing only internal layers), or 'off'. I developed this at some point, possibly based on a discussion with Vladimir, and included it here somewhat by accident. Let me know if that's desired or should be removed.
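
As a sketch of what that interface change enables (the keyword name and shapes here are my assumption about the new API, not necessarily its final form):

```python
import torch.nn as nn

import hls4ml

model = nn.Sequential(nn.Conv1d(8, 16, kernel_size=3), nn.ReLU())

# Assumed keyword: 'full' transposes inputs and internal layers, 'internal'
# assumes inputs are already channels-last and transposes only internal
# layers, 'off' disables the conversion entirely.
config = hls4ml.utils.config_from_pytorch_model(
    model,
    input_shape=(8, 32),  # (channels, width); assumed signature
    channels_last_conversion='internal',
)
```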

@@ -92,6 +93,7 @@ def format(self, node):
params['config_mult_h'] = f'config{node.index}_h_mult'
params['act_t'] = '{}_config{}'.format(node.get_attr('activation'), str(node.index) + '_act')
params['act_recurrent_t'] = '{}_config{}'.format(node.get_attr('recurrent_activation'), str(node.index) + '_rec_act')
params['pytorch'] = 'true' if "pytorch" in node.attributes.keys() else 'false'
Contributor

Very minor point: hls4ml seems to use single quotes more often (except for docstrings), so I would change "pytorch" to 'pytorch'. There are double quotes used in other places too. But truthfully I am not convinced we need to change it and run the tests again, since this is so minor.

@@ -301,5 +306,9 @@ def __init__(self):

def format(self, node):
params = self._default_function_params(node)
params['weights'] = 'w{0}, wr{0}, b{0}'.format(str(node.index))
if "pytorch" in node.attributes.keys():
Contributor

Is it better to check that pytorch is in the keys, or that it is defined and True? What if it is defined but set to False? I wonder if using node.get_attr('pytorch') (which returns None if not found) or node.get_attr('pytorch', False) would be better.
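
A standalone illustration of the difference, using a minimal stand-in for the node object rather than hls4ml's actual class:

```python
class FakeNode:
    """Minimal stand-in for an hls4ml graph node, for illustration only."""

    def __init__(self, attributes):
        self.attributes = attributes

    def get_attr(self, key, default=None):
        return self.attributes.get(key, default)


node = FakeNode({'pytorch': False})

# Membership check is truthy whenever the key exists, even when set to False:
print('pytorch' in node.attributes.keys())  # True

# Attribute lookup with a default respects the explicit False value and
# falls back to False when the key is absent:
print(node.get_attr('pytorch', False))  # False
```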

Contributor Author

Right now it will only be defined if a recurrent layer is parsed from PyTorch, and I can't really envision a situation where code outside the PyTorch parser would set this key. But I still think you are right that this should be implemented in a more future-proof way, and node.get_attr('pytorch', False) is probably the most robust solution. I will implement it.

@JanFSchulte JanFSchulte added please test Trigger testing by creating local PR branch and removed please test Trigger testing by creating local PR branch labels Jul 23, 2024
@jmitrevs jmitrevs merged commit 75b0b0d into fastmachinelearning:main Jul 23, 2024
7 checks passed
Labels: please test (Trigger testing by creating local PR branch)

3 participants