
Update the example of exporting Bart + BeamSearch to ONNX module to resolve comments. #14310

Merged: 17 commits, Dec 6, 2021
9 changes: 6 additions & 3 deletions examples/onnx/pytorch/translation/README.md
@@ -13,14 +13,17 @@ See the License for the specific language governing permissions and
limitations under the License.
-->

# Usage
# Bart + Beam Search to ONNX

This folder contains an example of exporting Bart + Beam Search generation(BartForConditionalGeneration) to ONNX file.

## Usage

This folder contains an example of exporting Bart + Beam Search generation (`BartForConditionalGeneration`) to ONNX.

Beam Search contains a for-loop workflow, so we need to make it TorchScript-compatible for exporting to ONNX. This example shows how to make a Bart model TorchScript-compatible by wrapping it in a new model. In addition, some changes were made to the beam_search() function to make it TorchScript-compatible.


# How to run the example
## How to run the example

To make sure you can successfully run the latest versions of the example scripts, you have to **install the library from source** and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:

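The wrapping approach the README describes can be sketched as follows. This is a toy stand-in, not the actual Bart wrapper from the PR: `ToyLM` and `GreedyWrapper` are hypothetical names, and a simple greedy loop stands in for beam search. The point is the pattern — move the generation loop into `forward()` with tensor-only inputs and outputs so `torch.jit.script` can compile it.

```python
import torch


class ToyLM(torch.nn.Module):
    """Stand-in decoder: maps token ids to vocabulary logits (hypothetical)."""

    def __init__(self, vocab_size: int = 16, hidden: int = 8):
        super().__init__()
        self.embed = torch.nn.Embedding(vocab_size, hidden)
        self.head = torch.nn.Linear(hidden, vocab_size)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.head(self.embed(input_ids))  # (batch, seq, vocab)


class GreedyWrapper(torch.nn.Module):
    """Wraps a model so the generation loop lives inside forward(),
    with tensor-only I/O -- the property TorchScript export needs."""

    def __init__(self, model: torch.nn.Module, max_length: int):
        super().__init__()
        self.model = model
        self.max_length = max_length

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        # The for-loop is inside the module, so scripting captures it.
        for _ in range(self.max_length - input_ids.shape[1]):
            logits = self.model(input_ids)
            next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
            input_ids = torch.cat([input_ids, next_token], dim=1)
        return input_ids


wrapper = torch.jit.script(GreedyWrapper(ToyLM(), max_length=5))
out = wrapper(torch.zeros((1, 1), dtype=torch.long))
print(out.shape)  # torch.Size([1, 5])
```

A scripted module like `wrapper` can then be passed to `torch.onnx.export`; the real example applies the same idea to `BartForConditionalGeneration` with beam search instead of greedy decoding.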
11 changes: 7 additions & 4 deletions examples/onnx/pytorch/translation/bart_onnx/generation_onnx.py
@@ -13,14 +13,17 @@ def _convert_past_list_to_tuple(past):
"""
The type of past_key_values is tuple(tuple(torch.FloatTensor)) which is not TorchScript-compatible.
This will convert past values from a list to tuple(tuple(torch.FloatTensor)) for the inner decoder.
"""

According to the definition of past_key_values, each inner tuple(torch.FloatTensor) has 4 tensors,
so we convert every 4 elements in the list as a tuple(torch.FloatTensor).
"""
count_of_each_inner_tuple = 4
results = ()
temp_result = ()
count_n = len(past) // 4
count_n = len(past) // count_of_each_inner_tuple
for idx in range(count_n):
real_idx = idx * 4
temp_result = tuple(past[real_idx : real_idx + 4])
real_idx = idx * count_of_each_inner_tuple
temp_result = tuple(past[real_idx : real_idx + count_of_each_inner_tuple])
results += ((temp_result),)

return results
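The grouping logic in `_convert_past_list_to_tuple` can be illustrated in isolation. This simplified sketch uses plain integers in place of the `torch.FloatTensor` values, and the function name mirrors (but is not) the helper in `generation_onnx.py`:

```python
def convert_past_list_to_tuple(past, group_size=4):
    """Group a flat list into tuples of `group_size` elements,
    mirroring how past_key_values are rebuilt for the inner decoder."""
    results = ()
    for idx in range(len(past) // group_size):
        start = idx * group_size
        results += (tuple(past[start : start + group_size]),)
    return results


flat = list(range(8))  # stand-ins for 8 past_key_value tensors
print(convert_past_list_to_tuple(flat))
# ((0, 1, 2, 3), (4, 5, 6, 7))
```

Naming the group size (rather than hard-coding `4`) is exactly what the diff above changes, since each inner `tuple(torch.FloatTensor)` of `past_key_values` holds 4 tensors.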
3 changes: 1 addition & 2 deletions examples/onnx/pytorch/translation/run_onnx_exporter.py
@@ -20,7 +20,6 @@
import logging
import os
import sys
from datetime import datetime

import numpy as np
import torch
@@ -196,7 +195,7 @@ def main():
if args.output_file_path:
output_name = args.output_file_path
else:
output_name = "onnx_model_{}.onnx".format(datetime.now().utcnow().microsecond)
output_name = "onnx_model.onnx"

logger.info("Exporting model to ONNX")
export_and_validate_model(model, tokenizer, output_name, num_beams, max_length)