Update the example of exporting Bart + BeamSearch to ONNX module to resolve comments. #14310

Merged · 17 commits · Dec 6, 2021
44 changes: 44 additions & 0 deletions examples/onnx/pytorch/translation/README.md
@@ -0,0 +1,44 @@
<!---
Copyright 2021 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->

# Bart + Beam Search to ONNX


## Usage

This folder contains an example of exporting Bart + Beam Search generation (`BartForConditionalGeneration`) to ONNX.

Beam Search contains a for-loop workflow, so it needs to be made TorchScript-compatible before it can be exported to ONNX. This example shows how to make a Bart model TorchScript-compatible by wrapping it in a new module, as sketched below. In addition, some changes were made to the `beam_search()` function to make it TorchScript-compatible.
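A minimal sketch of the wrapping idea (the class name and returned value are illustrative, not the exact code in `bart_onnx/generation_onnx.py`):

```python
import torch

class GenerationWrapper(torch.nn.Module):
    """Illustrative wrapper: exposes the model through plain tensor
    inputs/outputs, which is what torch.jit.trace can record."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor):
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        return outputs.logits  # return a tensor, not a ModelOutput
```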


## How to run the example

To make sure you can successfully run the latest versions of the example scripts, you have to **install the library from source** and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:

```bash
git clone https://github.com/huggingface/transformers
cd transformers
pip install .
```
Then cd into this example folder and run:
```bash
pip install -r requirements.txt
```

Now you can run the example command below to get the example ONNX file:

```bash
python run_onnx_exporter.py --model_name_or_path facebook/bart-base
```
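Once the export succeeds, you can sanity-check the produced file with the `onnx` package (a sketch; the file name below is an assumption, `run_onnx_exporter.py` prints the actual output path):

```python
import onnx

# "BART.onnx" is a hypothetical name -- use the path printed by the exporter.
model = onnx.load("BART.onnx")
onnx.checker.check_model(model)
print("graph inputs:", [i.name for i in model.graph.input])
```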
135 changes: 36 additions & 99 deletions examples/onnx/pytorch/translation/bart_onnx/generation_onnx.py
@@ -1,4 +1,5 @@
import copy
+import itertools
from typing import List, Optional, Tuple

import torch
@@ -8,23 +9,21 @@
from transformers.generation_utils import GenerationMixin


-def flatten_list(past):
-    values = []
-    if past is not None:
-        for i, p in enumerate(past):
-            for j, q in enumerate(p):
-                values.append(q)
-
-    return values
-
-
-def list_to_tuple(past):
+def _convert_past_list_to_tuple(past):
+    """
+    The type of past_key_values is tuple(tuple(torch.FloatTensor)) which is not TorchScript-compatible.
+    This will convert past values from a list to tuple(tuple(torch.FloatTensor)) for the inner decoder.
+
+    According to the definition of past_key_values, each inner tuple(torch.FloatTensor) has 4 tensors,
+    so we convert every 4 elements in the list as a tuple(torch.FloatTensor).
+    """
+    count_of_each_inner_tuple = 4
    results = ()
    temp_result = ()
-    count_n = len(past) // 4
+    count_n = len(past) // count_of_each_inner_tuple
    for idx in range(count_n):
-        real_idx = idx * 4
-        temp_result = tuple(past[real_idx : real_idx + 4])
+        real_idx = idx * count_of_each_inner_tuple
+        temp_result = tuple(past[real_idx : real_idx + count_of_each_inner_tuple])
        results += ((temp_result),)

    return results
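As a quick illustration of the grouping this helper performs (toy tensor shapes, purely hypothetical):

```python
import torch

# A flat list standing in for 2 decoder layers x 4 tensors each
# (self-attention key/value plus cross-attention key/value).
flat_past = [torch.zeros(1, 12, 5, 64) for _ in range(8)]

past = _convert_past_list_to_tuple(flat_past)
assert len(past) == 2      # one inner tuple per layer
assert len(past[0]) == 4   # four tensors per inner tuple
```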
@@ -51,7 +50,7 @@ def __init__(self, decoder):
    def forward(self, input_ids, encoder_state, attention_mask, past=None):
        all_results = None
        if past is not None:
-            all_results = list_to_tuple(past)
+            all_results = _convert_past_list_to_tuple(past)
            input_ids = input_ids[:, -1:]

        last_hidden_state, past_key_values = self.decoder(
@@ -68,28 +67,33 @@ def forward(self, input_ids, encoder_state, attention_mask, past=None):
        return last_hidden_state, past_values


-def create_traced_encoder(encoder, input_ids, attention_mask):
+def _create_traced_encoder(encoder, input_ids, attention_mask):
    encoder_c = copy.deepcopy(encoder)
    encoder_for_onnx = EncoderForONNX(encoder_c)

-    # return torch.jit.trace(encoder, (input_ids, attention_mask))
    return torch.jit.trace(encoder_for_onnx, (input_ids, attention_mask))


-def create_traced_decoder(decoder, input_ids, encoder_state, attention_mask, past=None):
+def _create_traced_decoder(decoder, input_ids, encoder_state, attention_mask, past=None):
    decoder_c = copy.deepcopy(decoder)
    decoder_for_onnx = DecoderForONNX(decoder_c)
-    past_values = flatten_list(past)
+    past_values = list(itertools.chain.from_iterable(past or ()))

    # Do this twice so we got 2 different decoders for further work.
-    if past_values is None or len(past_values) == 0:
-        return torch.jit.trace(decoder_for_onnx, (input_ids, encoder_state, attention_mask))
-    else:
+    if past_values:
        return torch.jit.trace(decoder_for_onnx, (input_ids, encoder_state, attention_mask, past_values))
+    else:
+        return torch.jit.trace(decoder_for_onnx, (input_ids, encoder_state, attention_mask))
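For context, the `itertools.chain.from_iterable` one-liner that replaces `flatten_list` behaves like this on toy data:

```python
import itertools

past = ((1, 2), (3, 4))
assert list(itertools.chain.from_iterable(past or ())) == [1, 2, 3, 4]

past = None
assert list(itertools.chain.from_iterable(past or ())) == []  # the past=None case
```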


class BartConfigTS(BartConfig, torch.nn.Module):
-    def init_module(self):
+    """
+    BartConfigTS is a TorchScript-compatible transformers.models.bart.configuration_bart.BartConfig.
+    TorchScript only supports sub-classes of torch.nn.Module.
+    """
+
+    def __init__(self, config):
+        BartConfig.__init__(self, config)
        torch.nn.Module.__init__(self)
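The dual inheritance matters because TorchScript only accepts module attributes that are themselves `torch.nn.Module` instances, while the object must still work as a config. A quick hypothetical check (assuming `model` is a loaded Bart model):

```python
config_ts = BartConfigTS(model.config)
assert isinstance(config_ts, BartConfig)       # still usable as a config
assert isinstance(config_ts, torch.nn.Module)  # now acceptable to TorchScript
```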


@@ -127,7 +131,6 @@ class BARTGenerator(torch.nn.Module, GenerationMixin):
    def __init__(self, model):
        super().__init__()
        self.config = BartConfigTS(model.config)
-        self.config.init_module()
        self.config.force_bos_token_to_be_generated = False
        self._trace_modules(model)
        self.logits_processor = MinLengthLogitsProcessorTS(self.config.min_length, self.config.eos_token_id)
@@ -136,7 +139,6 @@ def __init__(self, model):
        self.decoder_layers = model.config.decoder_layers

    def _trace_modules(self, model):
-        # Be aware of the last one 2 should be kept.
        input_ids = torch.tensor(
            [
                [
@@ -200,89 +202,25 @@ def _trace_modules(self, model):
                    57,
                    8629,
                    5,
-                    2,
+                    model.config.eos_token_id,
                ]
            ],
            device=model.device,
            dtype=torch.long,
        )
        attention_mask = torch.tensor(
-            [
-                [
-                    True,
-                    (... 61 more repeated `True,` lines elided ...)
-                ]
-            ],
+            [[True] * input_ids.shape[-1]],
            device=model.device,
            dtype=torch.bool,
        )
-        self.encoder = create_traced_encoder(model.get_encoder(), input_ids, attention_mask)
+        self.encoder = _create_traced_encoder(model.get_encoder(), input_ids, attention_mask)
        encoder_outputs = model.get_encoder()(input_ids, attention_mask=attention_mask, return_dict=True)
        decoder = model.model.decoder
        decoder_outputs = decoder(input_ids, attention_mask, encoder_outputs["last_hidden_state"], None, None, None)
-        self.decoder_no_past = create_traced_decoder(
+        self.decoder_no_past = _create_traced_decoder(
            model.model.decoder, input_ids, encoder_outputs["last_hidden_state"], attention_mask
        )
-        self.decoder_with_past = create_traced_decoder(
+        self.decoder_with_past = _create_traced_decoder(
            model.model.decoder, input_ids, encoder_outputs["last_hidden_state"], attention_mask, decoder_outputs[1]
        )

@@ -414,8 +352,8 @@ def __init__(self):
        self._beam_hyps_count = torch.zeros(self.batch_size, dtype=torch.long)
        self._beam_hyps_worst_scores = torch.zeros(self.batch_size) + 1e9
        self._beam_hyps_max_length: int = self.max_length - 1
-        self._beam_hyps: List[torch.Tensor] = [torch.zeros(2)]  # placeholder for TorchScript compatible
-        self._beam_scores: List[torch.Tensor] = [torch.zeros(2)]  # placeholder for TorchScript compatible
+        self._beam_hyps: List[torch.Tensor] = [torch.zeros(2)]  # placeholder for TorchScript compatibility
+        self._beam_scores: List[torch.Tensor] = [torch.zeros(2)]  # placeholder for TorchScript compatibility

    def is_done(self) -> torch.Tensor:
        return self._done.all()
@@ -474,11 +412,11 @@ def hypo_add(self, hyp: torch.Tensor, sum_logprobs: float, hypo_idx: int):
        score = sum_logprobs / (hyp.shape[-1] ** self.length_penalty)
        hyps_count = self.hypo_len(hypo_idx)
        if hyps_count < self.num_beams or score > self._beam_hyps_worst_scores[hypo_idx]:
-            # NOTE: work around difference of torch.sum(empty_tensor) = 0, while error in onnx.
+            # NOTE: work around difference of torch.sum(empty_tensor) == 0, while error in onnx.
            # Bug: https://msdata.visualstudio.com/Vienna/_workitems/edit/1486599
            beam_idx = (
                torch.sum(self._beam_hyps_count[:hypo_idx]) if hypo_idx != 0 else torch.tensor(0, dtype=torch.long)
            )
-            # beam_idx = torch.sum(_beam_hyps_count[:hypo_idx])
            self._beam_scores.insert(beam_idx, torch.tensor([score]))
            self._beam_hyps.insert(beam_idx, hyp)
            if hyps_count + 1 > self.num_beams:
@@ -605,7 +543,7 @@ def finalize(
                self.hypo_add(final_tokens, final_score, batch_idx)

        # select the best hypotheses
-        # NOTE: new is not scriptable
+        # NOTE: torch.Tensor.new_zeros() is not scriptable
        sent_lengths = torch.zeros(batch_size * self.num_beam_hyps_to_keep, dtype=torch.long)
        best = []
        best_scores = torch.zeros(
@@ -782,7 +720,6 @@ def forward(self, input_ids, attention_mask, num_beams, max_length, decoder_start_token_id):
            bos_token_id=bos_token_id,
        )

-        # from generation_utils.py
        batch_size = input_ids.shape[0]

        length_penalty = self.config.length_penalty
33 changes: 21 additions & 12 deletions examples/onnx/pytorch/translation/bart_onnx/reduce_onnx_size.py
@@ -1,11 +1,15 @@
"""
Code to remove duplicate initializers to reduce ONNX model size.
fatcat-z marked this conversation as resolved.
Show resolved Hide resolved
"""

import os

import numpy

import onnx


-def is_equal_tensor_proto(a, b):
+def _is_equal_tensor_proto(a, b):
    name_a = a.name
    name_b = b.name

@@ -20,25 +24,25 @@ def is_equal_tensor_proto(a, b):
    return res


-def node_replace_input_with(node_proto, name, new_name):
+def _node_replace_input_with(node_proto, name, new_name):
    for i, input_name in enumerate(node_proto.input):
        if input_name == name:
            node_proto.input.insert(i, new_name)
            node_proto.input.pop(i + 1)

    if node_proto.op_type == "If":
-        graph_replace_input_with(node_proto.attribute[0].g, name, new_name)
-        graph_replace_input_with(node_proto.attribute[1].g, name, new_name)
+        _graph_replace_input_with(node_proto.attribute[0].g, name, new_name)
+        _graph_replace_input_with(node_proto.attribute[1].g, name, new_name)
    if node_proto.op_type == "Loop":
-        graph_replace_input_with(node_proto.attribute[0].g, name, new_name)
+        _graph_replace_input_with(node_proto.attribute[0].g, name, new_name)


-def graph_replace_input_with(graph_proto, name, new_name):
+def _graph_replace_input_with(graph_proto, name, new_name):
    for n in graph_proto.node:
-        node_replace_input_with(n, name, new_name)
+        _node_replace_input_with(n, name, new_name)


-def remove_dup_initializers_from_model(model, model_without_ext, ind_to_replace):
+def _remove_dup_initializers_from_model(model, model_without_ext, ind_to_replace):
    inits_with_data = [i for i in model.graph.initializer]
    inits = [i for i in model_without_ext.graph.initializer]
    for i, ref_i in ind_to_replace:
@@ -52,10 +56,15 @@ def remove_dup_initializers_from_model(model, model_without_ext, ind_to_replace)
        model_without_ext.graph.initializer.remove(inits[i])

        # for n in model.graph.node:
-        graph_replace_input_with(model_without_ext.graph, name_i, name_ref)
+        _graph_replace_input_with(model_without_ext.graph, name_i, name_ref)


def remove_dup_initializers(onnx_file_path):
+    """
+    Removes duplicate initializers from the model to reduce its size.
+    Writes a new file in the same directory as onnx_file_path and returns the path to that file.
+    """
+
    model_file_folder = os.path.dirname(onnx_file_path)
    model_file_name = os.path.basename(onnx_file_path)

@@ -76,7 +85,7 @@ def remove_dup_initializers(onnx_file_path):
        for j in range(i + 1, len(inits)):
            if j in dup_set:
                continue
-            if is_equal_tensor_proto(inits[i], inits[j]):
+            if _is_equal_tensor_proto(inits[i], inits[j]):
                dup_set.add(i)
                dup_set.add(j)

@@ -103,8 +112,8 @@

print("total reduced size: ", total_reduced_size / 1024 / 1024 / 1024, "GB")

ind_to_replace = sorted(ind_to_replace, key=lambda x: x[0])
remove_dup_initializers_from_model(model, model, ind_to_replace)
ind_to_replace = sorted(ind_to_replace)
_remove_dup_initializers_from_model(model, model, ind_to_replace)

optimized_model_file_name = "optimized_" + model_file_name
new_model = os.path.join(model_file_folder, optimized_model_file_name)
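A hedged usage sketch for the public entry point (the input file name is hypothetical, and the import path assumes you run from this example folder):

```python
from bart_onnx.reduce_onnx_size import remove_dup_initializers

# Writes "optimized_<name>" next to the input file and returns its path.
optimized_path = remove_dup_initializers("BART.onnx")  # hypothetical file
print(optimized_path)
```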