
Keeping hidden state vector of RNN layers inside the model at online operation #11085

Closed
okankop opened this issue Apr 1, 2022 · 13 comments
Labels: feature request (request for unsupported feature or enhancement)

okankop commented Apr 1, 2022

When running an ONNX model that contains an RNN layer in real time (i.e. frame by frame), we need to feed and receive the hidden state vector of the RNN explicitly at each inference with onnxruntime. An example inference can be seen here: https://github.com/microsoft/AEC-Challenge/blob/a66686c0ed22bc757551edb16b40964068adeaa5/baseline/icassp2022/enhance.py#L85

It would be much better to cache the hidden state vector inside the model. That way, we could use the cached vector at each inference and update it within the model at the end of the run, so we would not need to deal with the hassle of handling this vector ourselves.
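The feed-and-receive loop described above can be sketched as follows. Here `step()` is a toy stand-in for a single `InferenceSession.run` call (the function, weight, and sizes are illustrative assumptions, not the real API):

```python
import numpy as np

# Toy stand-in for one onnxruntime InferenceSession.run() call on an RNN
# model: consumes the current frame plus hidden state, and returns the
# output frame plus the updated hidden state.
def step(frame, hidden, w=0.5):
    hidden = np.tanh(frame + w * hidden)
    return hidden.copy(), hidden

hidden = np.zeros(4)               # initial_h starts as zeros
outputs = []
for frame in np.ones((3, 4)):      # 3 frames, processed one at a time
    out, hidden = step(frame, hidden)  # caller must carry `hidden` across calls
    outputs.append(out)
```

This is the "hassle" being described: the caller owns the state and must thread it through every call.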

Possible solution
PyTorch solution: I have created an nn.Parameter in the model, in which I cache the RNN hidden state vector and update it at every inference. This solution works perfectly in PyTorch.

However, when I export the PyTorch model to ONNX and use it with the onnxruntime Python API, I noticed that the parameter in the model is not updated at each inference. Hence this solution does not work with onnxruntime.
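For reference, the PyTorch approach described above might look roughly like this (a sketch with illustrative names and sizes, not the exact model from the issue):

```python
import torch
import torch.nn as nn

class StatefulGRU(nn.Module):
    def __init__(self, features=8):
        super().__init__()
        self.gru = nn.GRU(features, features)
        # cached hidden state, stored as a non-trainable parameter
        self.hidden = nn.Parameter(torch.zeros(1, 1, features),
                                   requires_grad=False)

    def forward(self, x):
        y, h_n = self.gru(x, self.hidden)
        # In eager PyTorch this in-place update persists across calls;
        # the exported ONNX graph has no equivalent mutation, which is
        # why the trick stops working under onnxruntime.
        self.hidden.data = h_n.detach()
        return y

m = StatefulGRU()
x = torch.randn(1, 1, 8)
m(x)
state_after_first = m.hidden.clone()
m(x)  # second call starts from the cached state, so the state moves again
```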

I would appreciate any possible solution for this feature.

System information

  • ONNX Runtime version: 1.10.0
@ytaous added the feature request label Apr 1, 2022
chausner commented Apr 2, 2022

Agree, it would be nice to have an easier solution for this than manually feeding back the state into the network. I've seen many users asking about this because it is not documented and there are no samples.

skottmckay commented:

ORT is stateless so that concurrent inferencing is supported. Whilst it may be more convenient for your scenario to internally store the state, that change would have significant performance and implementation implications. For ORT you will need to make the hidden state a graph output to capture it, and a graph input so it can be passed back in.

Probably best to ask on the PyTorch GitHub page about how the export of nn.Parameter to ONNX works, and whether anything could/should be done to automatically setup this sort of pass-through mechanism if the parameter is being used to cache state.

skottmckay commented Apr 6, 2022

As it's somewhat non-trivial to update the model correctly, here's a helper script that should work. It has had only limited testing, so tweak as needed. @chausner, hopefully this helps convert a model so that the state can easily be fed back in.

https://gist.github.com/skottmckay/ff23c03752dfb9873eb15888e5892c78

For all the RNN/LSTM/GRU nodes in the main graph it will

  • add a graph output for the Y_h output so it can be retrieved
  • add a graph input for the initial_h input so the previous state in the Y_h output can be passed back in
  • if possible, adds an initializer for the default value of initial_h
    • this makes the graph input for initial_h optional, so on the first run or if you don't want/need to break up your sequence into multiple calls to InferenceSession::Run you don't have to pass in the zeros for initial_h
    • this can only be done if the batch dimension is a fixed value
      • need a fixed shape for an initializer

It does not look at subgraphs in Scan/Loop/If nodes. That could be done but is more complicated as the Y_h output needs to be wired back through all levels of the ancestor graphs. The initial_h input just needs to be added to the main graph though.

Example output from running against this model which has 5 RNN layers: https://github.com/bedilbek/onnx_models/tree/master/text/machine_comprehension/bidirectional_attention_flow

Processing bidaf-9.onnx...
Updated LSTM node with name 'OptimizedRNNStack105580'. Optional graph input named LSTM_0_initial_h was added. Graph output named LSTM_0_Y_h was added
Updated LSTM node with name 'OptimizedRNNStack103980'. Optional graph input named LSTM_1_initial_h was added. Graph output named LSTM_1_Y_h was added
Updated LSTM node with name 'OptimizedRNNStack111910'. Optional graph input named LSTM_2_initial_h was added. Graph output named LSTM_2_Y_h was added
Updated LSTM node with name 'OptimizedRNNStack112030'. Optional graph input named LSTM_3_initial_h was added. Graph output named LSTM_3_Y_h was added
Updated LSTM node with name 'OptimizedRNNStack115740'. Optional graph input named LSTM_4_initial_h was added. Graph output named LSTM_4_Y_h was added
Writing updated model to bidaf-9.updated.onnx

If people test the script out and can validate it works I can make it more official.

okankop commented Apr 6, 2022

@skottmckay Sure, I will test it and let you know the outcome. Thanks for your effort.

@skottmckay self-assigned this Apr 6, 2022
skottmckay commented:

There was a bug when creating an initializer with the default value as it didn't take into account the layout. That's fixed, and it now also handles an existing initializer providing the default value.

okankop commented Apr 9, 2022

@skottmckay I have tried the code you provided on my ONNX model exported from PyTorch. Let me share my observations:

  1. First of all, at ONNX export time, if I do not provide initial_h in the forward call (i.e. Y, Y_h = self.rnn(input)), initial_h appears neither in the graph inputs nor in the initializers. Therefore, your script does not perform the binding.
  2. To work around the issue above, I provided a fixed initial_h at inference time, so that initial_h gets into the initializers list.
  3. Once I got the binding, I tried to run the model, but the hidden state in the RNN is not updated as expected, so the output was corrupted.

What I would really appreciate is guidance on how to do the binding for the dummy ONNX model created by the script I am sharing.
export_dummy.zip

skottmckay commented:

Can you share the output from when you ran the script on your model?

As there might be multiple RNN/LSTM/GRU nodes the script tries to use a unique name. This is where it adds an initial_h input to the node if one doesn't already exist:

https://gist.github.com/skottmckay/ff23c03752dfb9873eb15888e5892c78#file-make_rnn_state_graph_input-py-L76-L84

            if not initial_h:
                # create name for new initial_h input and add to node. TODO: add checks to ensure new name is unique
                initial_h = f'{node.op_type}_{rnn_idx}_initial_h'

                # add any missing optional inputs so that we're guaranteed to have 6 or more inputs
                while len(node.input) < 6:
                    node.input.append('')

                node.input[5] = initial_h

This is where it adds the graph input

https://gist.github.com/skottmckay/ff23c03752dfb9873eb15888e5892c78#file-make_rnn_state_graph_input-py-L112-L115

            # add graph input. create matching ValueInfo based on the graph output
            input_vi = copy.copy(output_vi)
            input_vi.name = initial_h
            m.graph.input.append(input_vi)

This is where it adds the initializer

https://gist.github.com/skottmckay/ff23c03752dfb9873eb15888e5892c78#file-make_rnn_state_graph_input-py-L122-L127

            if initial_h not in initializers and batch_dim.HasField('dim_value'):
                dims = [directions, batch_dim.dim_value, hidden_size] if layout == 0 \
                    else [batch_dim.dim_value, directions, hidden_size]
                vals = [float(0)] * (batch_dim.dim_value * directions * hidden_size)
                default_vals = onnx.helper.make_tensor(initial_h, output_vi.type.tensor_type.elem_type, dims, vals)
                m.graph.initializer.append(default_vals)

If you dump the onnx model to text you should see these things.

import onnx
model = 'model.onnx'
m = onnx.load(model)
with open(model + '.txt', "w", encoding="utf-8") as ofile:
    ofile.write(onnx.helper.printable_graph(m.graph))

skottmckay commented:

Note: I updated the script to also handle the initial_c/Y_c input/output of LSTM nodes, so the line numbers above no longer match. It still makes the state a graph output/input in the same way. This doesn't matter for your model, as it has a GRU node.

skottmckay commented Apr 11, 2022

The reason your dummy model can't be updated is that the input for initial_h is coming from another node.

Processing model_dummy.onnx...
Skipping GRU node with name 'GRU_9' as the initial_h input is provided by another node.
Model was not updated.

[image: visualization of the exported graph, showing the initial_h input fed by another node]

It looks like the PyTorch exporter is adding logic to broadcast the initial state using the batch size, and that output is used as the initial_h input. That means it's not a simple case of setting up graph outputs/inputs to pass state through, with an initializer directly providing the initial_h value.

It may be simpler to define the initial states as inputs of the model prior to exporting to ONNX. I assume this is possible, given the PyTorch exporter output mentions doing so:

site-packages\torch\onnx\symbolic_opset9.py:2255: UserWarning: Exporting a model to ONNX with a batch_size other than 1, with a variable length with GRU can cause an error when running the ONNX model with a different batch size. Make sure to save the model with a batch size of 1, or define the initial states (h0/c0) as inputs of the model. or define the initial states (h0/c0) as inputs of the model. ")

I've added some more logic to the script to allow for the initial state input to come from an initializer being broadcast via an Expand node. The graph input needs to match the name of the initializer that is the input to the Expand node. The graph input can have a different shape though.

For example, say the input to the Expand node is an initializer called 'x' with shape {1, 1, 128} containing the default value. It is broadcast to the batch size by the Expand node at runtime.

If you want to pass in the state, the graph input will also be called 'x' but should contain entries for all items in the batch. Say the batch size is 3: the shape of the input provided would be {1, 3, 128}. As that matches the output shape of the Expand node, there is nothing left to expand.
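The broadcasting behaviour can be sketched in numpy, which follows the same multidirectional broadcasting rules as the ONNX Expand operator (shapes taken from the example above):

```python
import numpy as np

# Initializer 'x' holds the default state with shape (1, 1, 128); Expand
# broadcasts it across the batch dimension at runtime.
default_state = np.zeros((1, 1, 128), dtype=np.float32)
batched = np.broadcast_to(default_state, (1, 3, 128))  # what Expand produces for batch=3

# A caller-supplied state already shaped (1, 3, 128) matches the Expand
# output shape, so broadcasting it is a no-op (numpy returns a view).
caller_state = np.random.random((1, 3, 128)).astype(np.float32)
expanded = np.broadcast_to(caller_state, (1, 3, 128))
```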

With those changes I can execute the dummy model and pass in the state via an input called 'onnx::Expand_5'

import numpy as np
import onnxruntime as ort

so = ort.SessionOptions()
s0 = ort.InferenceSession('model_dummy.onnx', so, providers=['CPUExecutionProvider'])
s1 = ort.InferenceSession('model_dummy.updated.onnx', so, providers=['CPUExecutionProvider'])

# Run the models using the same input to prove they produce the same output
i = np.random.random(size=[1, 1, 128]).astype(np.float32)
initial_h = np.zeros([1, 1, 128]).astype(np.float32)

inputs0 = {'inputs': i}

# run orig and new model the current way
outputs0 = s0.run(['outputs'], inputs0)
outputs1 = s1.run(['outputs'], inputs0)
assert(np.all(outputs0[0] == outputs1[0]))

# run new version with initial state passed in with default value
inputs1 = {'inputs': i, 'onnx::Expand_5': initial_h}
outputs2 = s1.run(['outputs', '74'], inputs1)
assert(np.all(outputs0[0] == outputs2[0]))

# now run with state passed from previous output passed in
inputs2 = {'inputs': i, 'onnx::Expand_5': outputs2[1]}
output2 = s1.run(['outputs'], inputs2)

# print(output2[0])

okankop commented Apr 12, 2022

Thank you very much @skottmckay. Your guide is super useful and I have reproduced all your steps. With my trained model, I even got the desired outputs from the modified ONNX model. However, in order to get the desired output, I have to provide the previously produced hidden state vector as a second input to the inference run, as you did in the example above:

inputs2 = {'inputs': i, 'onnx::Expand_5': outputs2[1]}
output2 = s1.run(['outputs'], inputs2)

That said, what I want to achieve is to provide only the RNN input to the model and have the hidden state updated automatically within the model, so that when I call inference a second time, the model uses the previously produced hidden state in the RNN. I do not want to carry the RNN hidden state back and forth at each inference. Is it possible to achieve this?

skottmckay commented:

Not currently. ORT is intentionally stateless so it can be called concurrently.

skottmckay commented:

FWIW I have another script that can wrap all the state variables into a single opaque Sequence value. More convenient as the number of state variables grows. Slight perf cost with some extra copies to pack/unpack that opaque value.

https://gist.github.com/skottmckay/f9df4e2bdb526ca3895340de7a8dff86

Example usage for test model with 2 LSTM state variables:

# with individual state vars
inputs = {'input': i0}
output1 = s1.run(['output', 'LSTM_0_Y_h', 'LSTM_0_Y_c'], inputs)
inputs = {'input': i1, 'LSTM_0_initial_h': output1[1], 'LSTM_0_initial_c': output1[2]}
...
# with aggregated state var
inputs2 = {'input': i0, 'internal_state_in': []}  # internal_state_in is array for empty Sequence<Tensor>
output2 = s2.run(['output', 'internal_state_out'], inputs2)
inputs2_with_state = {'input': i1, 'internal_state_in': output2[1]}
...

okankop commented Apr 13, 2022

Thank you very much @skottmckay. I am closing the issue since I got all my answers.

@okankop closed this as completed Apr 13, 2022