# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Main class for performing graph optimization with ONNX Runtime."""

import gc
import os
from pathlib import Path
from typing import TYPE_CHECKING, Dict, List, Optional, Union

import onnx
from onnx import load_model
from transformers.models.auto.configuration_auto import AutoConfig

from onnxruntime.transformers.onnx_model_bert import BertOnnxModel
from onnxruntime.transformers.optimizer import optimize_model

from ..onnx.utils import check_model_uses_external_data
from ..utils import CONFIG_NAME, NormalizedConfigManager, logging
from ..utils.save_utils import maybe_save_preprocessors
from .configuration import OptimizationConfig, ORTConfig
from .modeling_decoder import ORTModelForCausalLM
from .modeling_ort import ORTModel
from .modeling_seq2seq import ORTModelForConditionalGeneration
from .utils import ONNX_WEIGHTS_NAME, ORTConfigManager


if TYPE_CHECKING:
    from transformers import PretrainedConfig

logger = logging.get_logger()

class ORTOptimizer:
    """
    Handles the ONNX Runtime optimization process for models shared on huggingface.co/models.
    """

    def __init__(self, onnx_model_path: List[os.PathLike], config: "PretrainedConfig", from_ortmodel: bool = False):
        """
        Args:
            onnx_model_path (`List[os.PathLike]`):
                The paths of the ONNX models to optimize.
            config ([`~transformers.PretrainedConfig`]):
                An instance of the configuration associated with the model to optimize.
            from_ortmodel (`bool`, defaults to `False`):
                Whether the model being optimized is already loaded as an `ORTModel`, as opposed to being loaded
                from disk.
        """
        super().__init__()
        self.onnx_model_path = onnx_model_path
        self.config = config
        self.model_type = self.config.model_type
        self.from_ortmodel = from_ortmodel
        try:
            self.normalized_config = NormalizedConfigManager.get_normalized_config_class(self.model_type)(self.config)
        except KeyError:
            raise NotImplementedError(
                f"Tried to use ORTOptimizer for the model type {self.model_type}, but it is not available yet. Please open an issue"
                " or submit a PR at https://github.com/huggingface/optimum."
            )
    @classmethod
    def from_pretrained(
        cls, model_or_path: Union[str, os.PathLike, ORTModel], file_names: Optional[List[str]] = None
    ) -> "ORTOptimizer":
        """
        Args:
            model_or_path (`Union[str, os.PathLike, ORTModel]`):
                The model to optimize. Can be either:
                    - A path to a local *directory* containing the model to optimize.
                    - An instance of [`~optimum.onnxruntime.ORTModel`].
            file_names (`Optional[List[str]]`, defaults to `None`):
                The list of file names of the models to optimize.
        """
        onnx_model_path = []
        config = None
        if isinstance(model_or_path, ORTModel):
            from_ortmodel = True
            if isinstance(model_or_path, ORTModelForConditionalGeneration):
                onnx_model_path += [
                    model_or_path.encoder_model_path,
                    model_or_path.decoder_model_path,
                ]
                # Add the decoder with past key/values if present
                if model_or_path.use_cache:
                    onnx_model_path.append(model_or_path.decoder_with_past_model_path)
            elif isinstance(model_or_path, ORTModelForCausalLM) and model_or_path.use_merged:
                raise NotImplementedError(
                    "ORTOptimizer does not support ORTModelForCausalLM models when the without/with past models are merged. "
                    "Please re-export your model. This can be done by using the optimum-cli ONNX export tool or "
                    "`ORTModelForCausalLM.from_pretrained(..., export=True, use_merged=False)`."
                )
            else:
                onnx_model_path.append(model_or_path.model_path)
            config = model_or_path.config
        elif os.path.isdir(model_or_path):
            from_ortmodel = False
            file_names = [ONNX_WEIGHTS_NAME] if file_names is None else file_names
            model_or_path = Path(model_or_path)
            if CONFIG_NAME not in os.listdir(model_or_path):
                raise ValueError(f"The local directory does not contain the configuration file {CONFIG_NAME}.")
            config = AutoConfig.from_pretrained(model_or_path)
            for file_name in file_names:
                onnx_model_path.append(model_or_path.joinpath(file_name))
        else:
            raise ValueError(f"Unable to load the model from {model_or_path}.")
        return cls(onnx_model_path, config=config, from_ortmodel=from_ortmodel)
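
    # Illustrative usage sketch (an addition, not part of the original file): creating an
    # ORTOptimizer either from a local export directory or from an already loaded ORTModel.
    # The directory path below is hypothetical.
    #
    #     from optimum.onnxruntime import ORTModelForSequenceClassification, ORTOptimizer
    #
    #     # From a local directory containing config.json and model.onnx:
    #     optimizer = ORTOptimizer.from_pretrained("path/to/onnx_model_dir")
    #
    #     # Or from an ORTModel instance exported on the fly:
    #     model = ORTModelForSequenceClassification.from_pretrained(
    #         "distilbert-base-uncased-finetuned-sst-2-english", export=True
    #     )
    #     optimizer = ORTOptimizer.from_pretrained(model)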
    def optimize(
        self,
        optimization_config: OptimizationConfig,
        save_dir: Union[str, os.PathLike],
        file_suffix: Optional[str] = "optimized",
        use_external_data_format: Optional[bool] = None,
        one_external_file: bool = True,
    ):
        """
        Optimizes a model given the optimization specifications defined in `optimization_config`.

        Args:
            optimization_config ([`~optimum.onnxruntime.OptimizationConfig`]):
                The configuration containing the parameters related to optimization.
            save_dir (`Union[str, os.PathLike]`):
                The path used to save the optimized model.
            file_suffix (`str`, defaults to `"optimized"`):
                The file suffix used to save the optimized model.
            use_external_data_format (`Optional[bool]`, defaults to `None`):
                Whether to use the external data format to store models whose size is >= 2 GB. This argument is
                deprecated.
            one_external_file (`bool`, defaults to `True`):
                When `use_external_data_format=True`, whether to save all tensors to one external file.
                If `False`, each tensor is saved to a separate file named after the tensor.
        """
        if use_external_data_format is not None:
            logger.warning(
                "The argument use_external_data_format in the ORTOptimizer.optimize() method is deprecated and will"
                " be removed in optimum 2.0."
            )

        save_dir = Path(save_dir)
        save_dir.mkdir(parents=True, exist_ok=True)
        ORTConfigManager.check_optimization_supported_model(self.model_type, optimization_config)
        self.config.save_pretrained(save_dir)
        maybe_save_preprocessors(self.onnx_model_path[0].parent, save_dir)

        model_type = ORTConfigManager.get_model_ort_type(self.config.model_type)
        optimization_options = optimization_config.create_fusion_options(model_type)
        logger.info("Optimizing model...")

        # TODO: this is quite inefficient as we load in memory if models are <2GB without external data
        model_uses_external_data = False
        for model_path in self.onnx_model_path:
            # Check whether the model was exported with external data
            onnx_model = onnx.load(str(model_path), load_external_data=False)
            if check_model_uses_external_data(onnx_model) is True:
                model_uses_external_data = True
                break
        del onnx_model
        gc.collect()

        # Create and save the configuration summarizing all the parameters related to optimization
        ort_config = ORTConfig(
            optimization=optimization_config,
            use_external_data_format=model_uses_external_data,
            one_external_file=one_external_file,
        )

        for model_path in self.onnx_model_path:
            suffix = f"_{file_suffix}" if file_suffix else ""
            output_path = save_dir.joinpath(f"{model_path.stem}{suffix}").with_suffix(model_path.suffix)
            try:
                optimizer = optimize_model(
                    model_path.as_posix(),
                    model_type,
                    self.normalized_config.num_attention_heads,
                    self.normalized_config.hidden_size,
                    opt_level=optimization_config.optimization_level,
                    optimization_options=optimization_options,
                    use_gpu=optimization_config.optimize_for_gpu,
                    only_onnxruntime=not optimization_config.enable_transformers_specific_optimizations,
                )

                if optimization_config.fp16:
                    if model_uses_external_data:
                        # Refer to https://github.com/microsoft/onnxruntime/blob/v1.15.0/onnxruntime/python/tools/transformers/float16.py#L204
                        # The ONNX infer_shapes_path method should be used instead of infer_shapes
                        # for models >= 2 GB, and it expects a model written to disk.
                        # Note that convert_float_to_float16 then overwrites optimizer.model with the
                        # new ModelProto.
                        optimizer.save_model_to_file(
                            output_path.as_posix(),
                            use_external_data_format=model_uses_external_data,
                            all_tensors_to_one_file=one_external_file,
                        )
                        optimizer.model = output_path.as_posix()

                    # keep_io_types=True keeps the model inputs/outputs as float32
                    optimizer.convert_float_to_float16(
                        use_symbolic_shape_infer=not optimization_config.disable_shape_inference, keep_io_types=True
                    )
            except Exception as e:
                if "Incomplete symbolic shape inference" in str(e):
                    err = RuntimeError(
                        f"{str(e)}. Try to set `disable_shape_inference=True` in your optimization configuration."
                    )
                    raise err from e
                raise

            # TODO: ORT save_model_to_file will save as `.data` although we save as `.onnx_data` in the export
            optimizer.save_model_to_file(
                output_path.as_posix(),
                use_external_data_format=model_uses_external_data,
                all_tensors_to_one_file=one_external_file,
            )

            # If loading from disk and saving in the same repository, remove the previous external data
            if Path(model_path.as_posix() + "_data").is_file() and self.from_ortmodel is False:
                os.remove(model_path.as_posix() + "_data")

        # Save the model configuration
        self.config.save_pretrained(save_dir)
        ort_config.save_pretrained(save_dir)

        logger.info(
            f"Optimized model saved at: {save_dir} (external data format: "
            f"{model_uses_external_data}; saved all tensors to one file: "
            f"{one_external_file})"
        )
        return Path(save_dir)
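
    # Illustrative usage sketch (an addition, not part of the original file): running graph
    # optimization with a basic OptimizationConfig. The optimization_level value and output
    # directory are hypothetical; see OptimizationConfig for the available parameters.
    #
    #     from optimum.onnxruntime.configuration import OptimizationConfig
    #
    #     optimization_config = OptimizationConfig(optimization_level=2)
    #     optimizer.optimize(
    #         optimization_config=optimization_config,
    #         save_dir="optimized_model_dir",
    #     )
    #     # The optimized ONNX file(s) are written to save_dir with the "_optimized" suffix.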
    @staticmethod
    def get_fused_operators(onnx_model_path: Union[str, os.PathLike]) -> Dict[str, int]:
        """
        Computes the dictionary mapping the name of each fused operator to its number of occurrences in the model.

        Args:
            onnx_model_path (`Union[str, os.PathLike]`):
                Path of the ONNX model.

        Returns:
            The dictionary mapping the name of each fused operator to its number of occurrences in the model.
        """
        onnx_optimized_model = BertOnnxModel(load_model(onnx_model_path))
        fused_operator = onnx_optimized_model.get_fused_operator_statistics()
        logger.info(
            f"The following operators were fused: {', '.join([k for k, v in fused_operator.items() if v > 0])}"
        )
        return {k: v for k, v in fused_operator.items() if v > 0}
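
    # Illustrative usage sketch (an addition, not part of the original file): inspecting which
    # fusions were applied to an optimized model. The file name below is hypothetical, and the
    # result shown in the comment is only an example of the output's shape.
    #
    #     fused = ORTOptimizer.get_fused_operators("optimized_model_dir/model_optimized.onnx")
    #     # e.g. {"Attention": 12, "LayerNormalization": 25, "Gelu": 12}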
    @staticmethod
    def get_nodes_number_difference(
        onnx_model_path: Union[str, os.PathLike], onnx_optimized_model_path: Union[str, os.PathLike]
    ) -> int:
        """
        Computes the difference in the number of nodes between the original and the optimized model.

        Args:
            onnx_model_path (`Union[str, os.PathLike]`):
                Path of the ONNX model.
            onnx_optimized_model_path (`Union[str, os.PathLike]`):
                Path of the optimized ONNX model.

        Returns:
            The difference in the number of nodes between the original and the optimized model.
        """
        onnx_model = BertOnnxModel(load_model(onnx_model_path))
        onnx_optimized_model = BertOnnxModel(load_model(onnx_optimized_model_path))
        # Measure the decrease in the number of nodes resulting from the optimization
        nodes_number_onnx_model = len(onnx_model.nodes())
        nodes_number_onnx_optimized_model = len(onnx_optimized_model.nodes())
        difference_nodes_number = nodes_number_onnx_model - nodes_number_onnx_optimized_model
        logger.info(
            f"There are {nodes_number_onnx_model} nodes before optimization and {nodes_number_onnx_optimized_model}"
            f" nodes after. The number of nodes removed is {difference_nodes_number}."
        )
        return difference_nodes_number
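
    # Illustrative usage sketch (an addition, not part of the original file): comparing the node
    # count of a model before and after optimization. Both file names are hypothetical.
    #
    #     removed = ORTOptimizer.get_nodes_number_difference(
    #         "onnx_model_dir/model.onnx", "optimized_model_dir/model_optimized.onnx"
    #     )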
    @staticmethod
    def get_operators_difference(
        onnx_model_path: Union[str, os.PathLike], onnx_optimized_model_path: Union[str, os.PathLike]
    ) -> Dict[str, int]:
        """
        Computes the dictionary mapping each operator name to the difference in the number of corresponding nodes
        between the original and the optimized model.

        Args:
            onnx_model_path (`Union[str, os.PathLike]`):
                Path of the ONNX model.
            onnx_optimized_model_path (`Union[str, os.PathLike]`):
                Path of the optimized ONNX model.

        Returns:
            The dictionary mapping each operator name to the difference in the number of corresponding nodes between
            the original and the optimized model.
        """
        onnx_model = BertOnnxModel(load_model(onnx_model_path))
        onnx_optimized_model = BertOnnxModel(load_model(onnx_optimized_model_path))

        def nodes_difference_given_type(op_type):
            onnx_model_nodes_with_op_type = len(onnx_model.get_nodes_by_op_type(op_type))
            onnx_optimized_model_nodes_with_op_type = len(onnx_optimized_model.get_nodes_by_op_type(op_type))
            return onnx_model_nodes_with_op_type - onnx_optimized_model_nodes_with_op_type

        # Compute the per-operator difference between the original and the optimized models
        op_types = set()
        for model in [onnx_model, onnx_optimized_model]:
            for node in model.nodes():
                op_types.add(node.op_type)
        operators_difference = {op_type: nodes_difference_given_type(op_type) for op_type in op_types}
        return {k: v for k, v in operators_difference.items() if v != 0}
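
    # Illustrative usage sketch (an addition, not part of the original file): a per-operator view
    # of what the optimization changed. Both file names are hypothetical. Since the difference is
    # computed as original minus optimized, a positive value means nodes of that type were removed,
    # while a negative value means they were added (e.g. fused operators introduced by optimization).
    #
    #     diff = ORTOptimizer.get_operators_difference(
    #         "onnx_model_dir/model.onnx", "optimized_model_dir/model_optimized.onnx"
    #     )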