# convert_graph_to_onnx.py
from argparse import ArgumentParser
from os import listdir, makedirs
from os.path import abspath, dirname, exists
from typing import Dict, List, Optional, Tuple
from transformers import is_tf_available, is_torch_available
from transformers.pipelines import Pipeline, pipeline
from transformers.tokenization_utils import BatchEncoding


class OnnxConverterArgumentParser(ArgumentParser):
"""
Wraps all the script arguments supported to export transformers models to ONNX IR
"""
def __init__(self):
        super().__init__("ONNX Converter")
self.add_argument("--model", type=str, required=True, help="Model's id or path (ex: bert-base-cased)")
self.add_argument("--tokenizer", type=str, help="Tokenizer's id or path (ex: bert-base-cased)")
self.add_argument("--framework", type=str, choices=["pt", "tf"], help="Framework for loading the model")
self.add_argument("--opset", type=int, default=11, help="ONNX opset to use")
self.add_argument("--check-loading", action="store_true", help="Check ONNX is able to load the model")
self.add_argument("--use-external-format", action="store_true", help="Allow exporting model >= than 2Gb")
self.add_argument("output")
def ensure_valid_input(model, tokens, input_names):
"""
    Ensure inputs are presented in the correct order, without any None values
Args:
model: The model used to forward the input data
tokens: BatchEncoding holding the input data
input_names: The name of the inputs
    Returns: Tuple of the ordered input names and a tuple of the corresponding tensors
"""
model_args_name = model.forward.__code__.co_varnames
ordered_input_names = []
model_args = []
for arg_name in model_args_name[1:]: # start at index 1 to skip "self" argument
if arg_name in input_names:
ordered_input_names.append(arg_name)
model_args.append(tokens[arg_name])
else:
break
return ordered_input_names, tuple(model_args)
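
# Example of what ensure_valid_input does, assuming a hypothetical BERT-like model
# whose forward() signature starts with (input_ids, attention_mask, token_type_ids, ...):
#
#   names, args = ensure_valid_input(nlp.model, tokens, ["input_ids", "token_type_ids", "attention_mask"])
#   # names -> ["input_ids", "attention_mask", "token_type_ids"]  (forward() order, not tokenizer order)
#   # args  -> the matching tensors from `tokens`, in that same order
#
# The loop breaks at the first forward() argument that the tokenizer did not produce,
# so trailing optional arguments (position_ids, head_mask, ...) are never passed as None.
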
def infer_shapes(nlp: Pipeline, framework: str) -> Tuple[List[str], List[str], Dict, BatchEncoding]:
    """
    Run the pipeline on a sample input to infer the input/output names and the dynamic axes of the graph
    Args:
        nlp: The pipeline holding the model and tokenizer
        framework: The framework the model is implemented in ("pt" or "tf")
    Returns: Tuple (input names, output names, dynamic axes mapping, sample BatchEncoding)
    """
def build_shape_dict(tensor, is_input: bool, seq_len: int):
if isinstance(tensor, (tuple, list)):
return [build_shape_dict(t, is_input, seq_len) for t in tensor]
else:
            # Let's assume batch is the first axis with only 1 element (this might not always be true ...)
axes = {[axis for axis, numel in enumerate(tensor.shape) if numel == 1][0]: "batch"}
if is_input:
if len(tensor.shape) == 2:
axes[1] = "sequence"
else:
raise ValueError("Unable to infer tensor axes ({})".format(len(tensor.shape)))
else:
seq_axes = [dim for dim, shape in enumerate(tensor.shape) if shape == seq_len]
axes.update({dim: "sequence" for dim in seq_axes})
return axes
tokens = nlp.tokenizer.encode_plus("This is a sample output", return_tensors=framework)
seq_len = tokens.input_ids.shape[-1]
outputs = nlp.model(**tokens) if framework == "pt" else nlp.model(tokens)
if not isinstance(outputs, (list, tuple)):
outputs = (outputs,)
# Generate input names & axes
input_vars = list(tokens.keys())
input_dynamic_axes = {k: build_shape_dict(v, True, seq_len) for k, v in tokens.items()}
# flatten potentially grouped outputs (past for gpt2, attentions)
outputs_flat = []
for output in outputs:
if isinstance(output, (tuple, list)):
outputs_flat.extend(output)
else:
outputs_flat.append(output)
# Generate output names & axes
output_names = ["output_{}".format(i) for i in range(len(outputs_flat))]
output_dynamic_axes = {k: build_shape_dict(v, False, seq_len) for k, v in zip(output_names, outputs_flat)}
# Create the aggregated axes representation
dynamic_axes = dict(input_dynamic_axes, **output_dynamic_axes)
return input_vars, output_names, dynamic_axes, tokens
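
# Sketch of the values infer_shapes returns for a BERT-like feature-extraction pipeline
# (names and axes are illustrative, they depend on the actual tokenizer and model):
#
#   input_names, output_names, dynamic_axes, tokens = infer_shapes(nlp, "pt")
#   # input_names  -> ["input_ids", "token_type_ids", "attention_mask"]
#   # output_names -> ["output_0", "output_1"]
#   # dynamic_axes -> {"input_ids": {0: "batch", 1: "sequence"}, ...,
#   #                  "output_0": {0: "batch", 1: "sequence"}, "output_1": {0: "batch"}}
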
def load_graph_from_args(framework: str, model: str, tokenizer: Optional[str] = None) -> Pipeline:
    """
    Allocate a feature-extraction pipeline for the given model, defaulting the tokenizer to the model
    """
# If no tokenizer provided
if tokenizer is None:
tokenizer = model
print("Loading pipeline (model: {}, tokenizer: {})".format(model, tokenizer))
# Allocate tokenizer and model
return pipeline("feature-extraction", model=model, tokenizer=tokenizer, framework=framework)
def convert_pytorch(nlp: Pipeline, opset: int, output: str, use_external_format: bool):
    """
    Export the pipeline's PyTorch model to ONNX through torch.onnx.export
    """
if not is_torch_available():
raise Exception("Cannot convert because PyTorch is not installed. Please install torch first.")
import torch
from torch.onnx import export
print("PyTorch: {}".format(torch.__version__))
with torch.no_grad():
input_names, output_names, dynamic_axes, tokens = infer_shapes(nlp, "pt")
ordered_input_names, model_args = ensure_valid_input(nlp.model, tokens, input_names)
export(
nlp.model,
model_args,
f=output,
input_names=ordered_input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
do_constant_folding=True,
use_external_data_format=use_external_format,
enable_onnx_checker=True,
opset_version=opset,
)
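
# Minimal sketch of a direct call (pipeline, opset and path are only examples):
#
#   convert_pytorch(load_graph_from_args("pt", "bert-base-cased"), opset=11,
#                   output="onnx/bert-base-cased.onnx", use_external_format=False)
#
# use_external_format=True makes torch.onnx.export store the weights outside the
# .onnx file, which is required for models larger than 2GB (--use-external-format above).
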
def convert_tensorflow(nlp: Pipeline, opset: int, output: str):
    """
    Export the pipeline's TensorFlow (Keras) model to ONNX through keras2onnx
    """
    if not is_tf_available():
        raise Exception("Cannot convert because TensorFlow is not installed. Please install tensorflow first.")
    print("/!\\ Please note TensorFlow doesn't support exporting models larger than 2GB /!\\")
try:
import tensorflow as tf
from keras2onnx import convert_keras, save_model, __version__ as k2ov
print("TensorFlow: {}, keras2onnx: {}".format(tf.version.VERSION, k2ov))
# Build
input_names, output_names, dynamic_axes, tokens = infer_shapes(nlp, "tf")
# Forward
nlp.model.predict(tokens.data)
onnx_model = convert_keras(nlp.model, nlp.model.name, target_opset=opset)
save_model(onnx_model, output)
except ImportError as e:
raise Exception(
"Cannot import {} required to convert TF model to ONNX. Please install {} first.".format(e.name, e.name)
)
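
# Minimal sketch of a direct call (pipeline, opset and path are only examples):
#
#   convert_tensorflow(load_graph_from_args("tf", "bert-base-cased"), opset=11,
#                      output="onnx/bert-base-cased.onnx")
#
# keras2onnx builds the ONNX graph from the traced Keras model: the dynamic axes
# computed by infer_shapes are not forwarded, only the sample `tokens` are used to
# run a forward pass before conversion.
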
def convert(
framework: str,
model: str,
output: str,
opset: int,
tokenizer: Optional[str] = None,
use_external_format: bool = False,
):
    """
    Load the pipeline, check the output folder and export the underlying model to ONNX
    """
print("ONNX opset version set to: {}".format(opset))
# Load the pipeline
nlp = load_graph_from_args(framework, model, tokenizer)
parent = dirname(output)
if not exists(parent):
print("Creating folder {}".format(parent))
makedirs(parent)
elif len(listdir(parent)) > 0:
raise Exception("Folder {} is not empty, aborting conversion".format(parent))
# Export the graph
if framework == "pt":
convert_pytorch(nlp, opset, output, use_external_format)
else:
convert_tensorflow(nlp, opset, output)
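
# Typical programmatic use, mirroring the CLI entry point below (values are examples):
#
#   convert(framework="pt", model="bert-base-cased",
#           output="onnx/bert-base-cased.onnx", opset=11)
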
def verify(path: str):
    """
    Try to load the exported ONNX model with onnxruntime to check that the graph is valid
    """
from onnxruntime import InferenceSession, SessionOptions
from onnxruntime.capi.onnxruntime_pybind11_state import RuntimeException
print("Checking ONNX model loading from: {}".format(path))
try:
onnx_options = SessionOptions()
_ = InferenceSession(path, onnx_options, providers=["CPUExecutionProvider"])
print("Model correctly loaded")
except RuntimeException as re:
print("Error while loading the model: {}".format(re))
if __name__ == "__main__":
parser = OnnxConverterArgumentParser()
args = parser.parse_args()
# Make sure output is absolute path
args.output = abspath(args.output)
try:
# Convert
convert(args.framework, args.model, args.output, args.opset, args.tokenizer, args.use_external_format)
# And verify
if args.check_loading:
verify(args.output)
except Exception as e:
print("Error while converting the model: {}".format(e))
exit(1)