/
base.py
506 lines (416 loc) · 22.9 KB
/
base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defines the base classes that are used to perform inference with ONNX Runtime of Transformers models."""
from abc import abstractmethod
from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple, Union
import numpy as np
import torch
from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
from onnxruntime import InferenceSession
from ..utils import NormalizedConfigManager
from ..utils.logging import warn_once
from .utils import get_ordered_input_names, logging
# Module-level logger, used for the deprecation warnings emitted in this file.
logger = logging.get_logger(__name__)

if TYPE_CHECKING:
    # Only needed for type annotations written as strings (e.g. `parent_model: "ORTModel"`); not imported at runtime.
    from .modeling_ort import ORTModel
class ORTModelPart:
    """
    One part of a multi-file ONNX model, such as the encoder or the decoder of an encoder-decoder model.

    Each part owns its own `onnxruntime.InferenceSession` and can run a forward pass through it, while the
    configuration, main input name and device are taken from the parent `ORTModel`.
    """

    def __init__(
        self,
        session: InferenceSession,
        parent_model: "ORTModel",
    ):
        self.session = session
        self.parent_model = parent_model

        # Normalized view of the parent configuration, resolved from its `model_type`.
        config = self.parent_model.config
        normalized_config_class = NormalizedConfigManager.get_normalized_config_class(config.model_type)
        self.normalized_config = normalized_config_class(config)

        self.main_input_name = self.parent_model.main_input_name

        # Map every graph input / output name to its position in the ONNX session.
        self.input_names = {}
        for position, graph_input in enumerate(self.session.get_inputs()):
            self.input_names[graph_input.name] = position
        self.output_names = {}
        for position, graph_output in enumerate(self.session.get_outputs()):
            self.output_names[graph_output.name] = position

        # Input names ordered by `get_ordered_input_names`, which inspects `self.forward`.
        self._ordered_input_names = get_ordered_input_names(self.input_names.keys(), func=self.forward)

    @property
    def device(self):
        """The device of the parent `ORTModel`."""
        return self.parent_model.device

    # NOTE: the class does not use `abc.ABCMeta`, so `@abstractmethod` is not enforced at instantiation time.
    @abstractmethod
    def forward(self, *args, **kwargs):
        """Run a forward pass through this model part; implemented by subclasses."""

    def __call__(self, *args, **kwargs):
        """Alias of `forward`, so that a model part can be called like a function."""
        return self.forward(*args, **kwargs)
class ORTEncoder(ORTModelPart):
    """
    Encoder part of the encoder-decoder model for ONNX Runtime inference.
    """

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: torch.LongTensor,
        **kwargs,
    ) -> BaseModelOutput:
        """
        Encode `input_ids` with the ONNX encoder session.

        Args:
            input_ids: input token ids, either a `torch.Tensor` or a NumPy array.
            attention_mask: attention mask; only fed to the session when the graph has an `attention_mask` input.
            **kwargs: ignored.

        Returns:
            `BaseModelOutput` holding `last_hidden_state`. On the `session.run` path it is a torch tensor on
            `self.device` for torch inputs, and the array returned by the session for NumPy inputs.
        """
        use_torch = isinstance(input_ids, torch.Tensor)
        self.parent_model.raise_on_numpy_input_io_binding(use_torch)

        needs_attention_mask = "attention_mask" in self.input_names

        if self.device.type == "cuda" and self.parent_model.use_io_binding:
            model_inputs = [input_ids, attention_mask] if needs_attention_mask else [input_ids]
            io_binding, output_shapes, output_buffers = self.parent_model._prepare_io_binding(
                self.session,
                *model_inputs,
                ordered_input_names=self._ordered_input_names,
            )

            io_binding.synchronize_inputs()
            self.session.run_with_iobinding(io_binding)
            io_binding.synchronize_outputs()

            hidden_state_shape = output_shapes["last_hidden_state"]
            last_hidden_state = output_buffers["last_hidden_state"].view(hidden_state_shape)
            return BaseModelOutput(last_hidden_state=last_hidden_state)

        def as_onnx_input(value):
            # Torch tensors are detached and copied to host NumPy arrays before `session.run`.
            return value.cpu().detach().numpy() if use_torch else value

        onnx_inputs = {"input_ids": as_onnx_input(input_ids)}
        if needs_attention_mask:
            onnx_inputs["attention_mask"] = as_onnx_input(attention_mask)

        # Run inference
        outputs = self.session.run(None, onnx_inputs)

        last_hidden_state = outputs[self.output_names["last_hidden_state"]]
        if use_torch:
            last_hidden_state = torch.from_numpy(last_hidden_state).to(self.device)

        return BaseModelOutput(last_hidden_state=last_hidden_state)
class ORTDecoderForSeq2Seq(ORTModelPart):
    """
    Decoder model with a language modeling head on top for ONNX Runtime inference.

    Supports separate decoders (without past key values / with past key values) as well as merged decoders, whose
    `use_cache_branch` input selects between the two branches.
    """

    def __init__(
        self,
        session: InferenceSession,
        parent_model: "ORTModel",
    ):
        super().__init__(session, parent_model)

        # TODO: make this less hacky.
        self.key_value_input_names = [key for key in self.input_names if (".key" in key) or (".value" in key)]
        self.key_value_output_names = [key for key in self.output_names if (".key" in key) or (".value" in key)]

        # To handle the old case when past_key_values were following the format: past_key_values_{idx}
        if len(self.key_value_input_names) == 0:
            self.key_value_input_names = [key for key in self.input_names if "key_values" in key]
        if len(self.key_value_output_names) == 0:
            self.key_value_output_names = [key for key in self.output_names if "key_values" in key]

        if self.parent_model.use_cache is True and len(self.key_value_output_names) == 0:
            raise RuntimeError("Could not find the past key values in the provided model.")

        self.use_past_in_outputs = len(self.key_value_output_names) > 0
        self.use_past_in_inputs = len(self.key_value_input_names) > 0

        # Selects the dtype of the dummy past key values built in `prepare_inputs_for_merged`.
        self.use_fp16 = any(
            "past_key_values" in inp.name and inp.type == "tensor(float16)" for inp in session.get_inputs()
        )

        # We may use ORTDecoderForSeq2Seq for vision-encoder-decoder models, where models as gpt2
        # can be used but do not support KV caching for the cross-attention key/values, see:
        # https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/models/gpt2/modeling_gpt2.py#L302-L311
        # This attribute is used to avoid returning cross-attention KV-cache in this case.
        self.no_cross_attention_cache = getattr(self.parent_model, "no_cross_attention_cache", False)

        # Number of past key/value tensors output per decoder layer.
        if (not self.parent_model.use_merged and self.use_past_in_inputs) or self.no_cross_attention_cache:
            # Self-attention key/value only.
            self.num_pkv = 2
        else:
            # When using a merged model, we always have the same number of output whether we use past key values or
            # not, and in the case past key values are used, empty tensors are given as cross-attention past key
            # values as they are constants
            self.num_pkv = 4

        self.past_key_values_cross_attention_output_names = {
            output_name
            for output_name in self.output_names
            if output_name.startswith("present") and "encoder" in output_name
        }

        # Non-merged decoders with past that still output cross-attention key/values: deprecated export layout
        # (see the warning in `_group_past_key_values`).
        self.use_legacy_outputs = (
            self.parent_model.use_merged is False and len(self.past_key_values_cross_attention_output_names) > 0
        )

    def compute_past_key_values_output_shapes(
        self,
        input_ids: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        use_cache_branch: Optional[bool],
        past_key_values: Optional[Tuple[torch.FloatTensor, ...]] = None,
    ) -> Dict[str, Tuple[int, ...]]:
        """
        Compute the shape of every present key/value output, used to allocate the IO Binding output buffers.

        Args:
            input_ids: decoder input ids; provides the batch size and the new sequence length.
            encoder_hidden_states: encoder output; provides the cross-attention sequence length.
            use_cache_branch: None for separate decoders, otherwise the value of the merged decoder's branch flag.
            past_key_values: the flattened past key/values fed to the session, if any.

        Returns:
            A mapping from present output name to its shape.
        """
        batch_size = input_ids.size(0)
        num_attention_heads = self.normalized_config.num_attention_heads
        embed_size_per_head = self.normalized_config.hidden_size // num_attention_heads

        sequence_length = input_ids.size(1)
        encoder_sequence_length = encoder_hidden_states.size(1)
        if past_key_values is not None and use_cache_branch is not False:
            # Here, use_cache_branch may be None in the case of separate decoder without/with past, or True if the
            # with past branch of a merged decoder is used
            sequence_length += past_key_values[0].size(2)

        self_attn_shape = (batch_size, num_attention_heads, sequence_length, embed_size_per_head)

        if past_key_values is not None and use_cache_branch is True:
            # Placeholder: these outputs are not bound (see `get_outputs_not_to_bind`).
            cross_attn_shape = (0, num_attention_heads, 1, embed_size_per_head)
        else:
            cross_attn_shape = (batch_size, num_attention_heads, encoder_sequence_length, embed_size_per_head)

        past_key_values_shapes = {}
        for idx, name in enumerate(self.key_value_output_names):
            is_self_attn = idx % 4 < 2
            # decoder with past does not output cross attention key/values as they are constants
            past_key_values_shapes[name] = self_attn_shape if (is_self_attn or self.num_pkv == 2) else cross_attn_shape
        return past_key_values_shapes

    def get_outputs_not_to_bind(self, use_merged_cache: bool) -> Set[str]:
        """
        Return the names of the outputs that must not be bound with IO Binding.

        These are every output that is neither a `present*` key/value nor `loss`/`logits`. When the cache branch
        of a merged decoder is used, the cross-attention present outputs are added as well.
        """
        result = {
            output_name
            for output_name in self.output_names
            if (not output_name.startswith("present") and output_name not in {"loss", "logits"})
        }
        if use_merged_cache is True:
            # When using a merged decoder and the use cache branch, we output 0-dim tensors that IO Binding do not
            # support. Therefore, we do not bind them.
            result = result.union(self.past_key_values_cross_attention_output_names)

        return result

    def forward(
        self,
        input_ids: torch.LongTensor,
        encoder_hidden_states: torch.FloatTensor,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        encoder_attention_mask: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache_branch: None = None,
    ) -> Seq2SeqLMOutput:
        """
        Run one decoding step through the ONNX decoder session.

        Args:
            input_ids: decoder input ids, either a `torch.Tensor` or a NumPy array.
            encoder_hidden_states: encoder output; fed only when the graph has an `encoder_hidden_states` input.
            decoder_attention_mask: fed only when the graph has a `decoder_attention_mask` input.
            encoder_attention_mask: fed only when the graph has an `encoder_attention_mask` input.
            past_key_values: per-layer tuples of cached key/values; flattened before being fed to the session.
            labels: fed only when the graph has a `labels` input.
            use_cache_branch: unused. It is kept in the signature as a hack for IO Binding: `_ordered_input_names`
                is built with `get_ordered_input_names(..., func=self.forward)`.

        Returns:
            `Seq2SeqLMOutput` with `logits`, `loss` (None unless the graph has a `loss` output) and `past_key_values`
            regrouped per decoder layer (None when the graph outputs no past key values). On the `session.run`
            path, outputs are torch tensors on `self.device` for torch inputs and NumPy arrays for NumPy inputs.
        """
        use_torch = isinstance(input_ids, torch.Tensor)
        self.parent_model.raise_on_numpy_input_io_binding(use_torch)

        # Flatten the past_key_values: one tuple per layer -> a single flat tuple.
        if past_key_values is not None:
            past_key_values = tuple(
                past_key_value for pkv_per_layer in past_key_values for past_key_value in pkv_per_layer
            )

        # no-ops if merged decoder is not used
        use_merged_no_cache = past_key_values is None and self.parent_model.use_merged
        use_merged_cache = past_key_values is not None and self.parent_model.use_merged
        use_cache_branch_tensor, past_key_values = self.prepare_inputs_for_merged(
            input_ids, past_key_values, use_torch=use_torch
        )

        if self.parent_model.use_io_binding:
            known_output_shapes = self.compute_past_key_values_output_shapes(
                input_ids,
                encoder_hidden_states,
                use_cache_branch=use_cache_branch_tensor.item() if use_cache_branch_tensor is not None else None,
                past_key_values=past_key_values,
            )

            outputs_to_not_bind = self.get_outputs_not_to_bind(use_merged_cache)

            model_inputs = [input_ids]
            if "encoder_hidden_states" in self.input_names:
                model_inputs.append(encoder_hidden_states)
            if "decoder_attention_mask" in self.input_names:
                model_inputs.append(decoder_attention_mask)
            if "encoder_attention_mask" in self.input_names:
                model_inputs.append(encoder_attention_mask)
            if past_key_values is not None:
                model_inputs += past_key_values
            if "labels" in self.input_names:
                model_inputs.append(labels)
                known_output_shapes.update({"loss": []})
            if use_cache_branch_tensor is not None:
                model_inputs.append(use_cache_branch_tensor)

            io_binding, output_shapes, output_buffers = self.parent_model._prepare_io_binding(
                self.session,
                *model_inputs,
                known_output_shapes=known_output_shapes,
                ordered_input_names=self._ordered_input_names,
                outputs_to_not_bind=outputs_to_not_bind,
            )

            # Set -1 for sequence_length as it could be larger than the real sequence_length
            for name, shape in output_shapes.items():
                if name in self.key_value_output_names:
                    output_shapes[name] = shape[:2] + (-1,) + shape[3:]

            io_binding.synchronize_inputs()
            self.session.run_with_iobinding(io_binding)
            io_binding.synchronize_outputs()

            # Flat tuple of present key/values. With the cache branch of a merged decoder the cross-attention
            # outputs were not bound (see `get_outputs_not_to_bind`), so they are left out here.
            out_past_key_values = tuple(
                output_buffers[name].view(output_shapes[name])
                for name in self.key_value_output_names
                if not (use_merged_cache and name in self.past_key_values_cross_attention_output_names)
            )
            cross_attention_outputs_skipped = use_merged_cache

            logits = output_buffers["logits"].view(output_shapes["logits"])

            loss = None
            if "loss" in self.output_names:
                loss = output_buffers["loss"].view(output_shapes["loss"])
        else:

            def to_onnx(value):
                # Torch tensors are detached and copied to host NumPy arrays before `session.run`.
                return value.cpu().detach().numpy() if use_torch else value

            onnx_inputs = {"input_ids": to_onnx(input_ids)}
            # Add the optional inputs when the graph expects them
            for input_name, value in (
                ("encoder_hidden_states", encoder_hidden_states),
                ("decoder_attention_mask", decoder_attention_mask),
                ("encoder_attention_mask", encoder_attention_mask),
            ):
                if input_name in self.input_names:
                    onnx_inputs[input_name] = to_onnx(value)

            if past_key_values is not None:
                # Add the past_key_values to the decoder inputs
                for input_name, past_key_value in zip(self.key_value_input_names, past_key_values):
                    onnx_inputs[input_name] = to_onnx(past_key_value)

            if "labels" in self.input_names:
                # TODO: Any preprocessing like `self._shift_right(labels)`?
                onnx_inputs["labels"] = to_onnx(labels)

            if self.parent_model.use_merged is True:
                onnx_inputs["use_cache_branch"] = to_onnx(use_cache_branch_tensor)

            # Run inference
            outputs = self.session.run(None, onnx_inputs)

            def from_onnx(array):
                # Only convert back to torch when the caller gave torch inputs, consistently for every output.
                return torch.from_numpy(array).to(self.device) if use_torch else array

            # Flat tuple of length: number of layers * number of past key/values output per layer.
            out_past_key_values = tuple(
                from_onnx(outputs[self.output_names[key]]) for key in self.key_value_output_names
            )
            cross_attention_outputs_skipped = False

            logits = from_onnx(outputs[self.output_names["logits"]])

            loss = None
            if "loss" in self.output_names:
                loss = from_onnx(outputs[self.output_names["loss"]])

        out_past_key_values = self._group_past_key_values(
            out_past_key_values,
            past_key_values,
            use_merged_no_cache=use_merged_no_cache,
            cross_attention_outputs_skipped=cross_attention_outputs_skipped,
        )

        return Seq2SeqLMOutput(loss=loss, logits=logits, past_key_values=out_past_key_values)

    def _group_past_key_values(
        self,
        out_past_key_values: Tuple,
        past_key_values: Optional[Tuple],
        use_merged_no_cache: bool,
        cross_attention_outputs_skipped: bool,
    ) -> Optional[Tuple[Tuple]]:
        """
        Regroup the flat tuple of present key/values into one tuple per decoder layer.

        Shared by the IO Binding and the `session.run` paths of `forward`, so that both return the same structure.

        Args:
            out_past_key_values: present key/values read from the session, in `self.key_value_output_names` order
                (without the cross-attention outputs when `cross_attention_outputs_skipped` is True).
            past_key_values: the flattened past key/values that were fed to the session, if any.
            use_merged_no_cache: whether a merged decoder ran without real past key values.
            cross_attention_outputs_skipped: whether the cross-attention outputs were left out of
                `out_past_key_values`.

        Returns:
            None when the model outputs no past key values. Otherwise a tuple with one entry per layer: either the
            layer's outputs grouped by `self.num_pkv`, or its self-attention key/value followed by the constant
            cross-attention key/value taken back from `past_key_values`.

        Raises:
            ValueError: if `self.num_pkv` is neither 2 nor 4.
        """
        if not self.use_past_in_outputs:
            return None

        def chunk(values, size):
            return tuple(values[i : i + size] for i in range(0, len(values), size))

        if not self.use_past_in_inputs or use_merged_no_cache or self.no_cross_attention_cache:
            return chunk(out_past_key_values, self.num_pkv)

        if self.use_legacy_outputs is True:
            msg = (
                "For the decoder with past, using ONNX models outputting cross attention past key values"
                " is deprecated and the support will be removed in optimum 2.0. We recommend exporting again the model"
                " with optimum>=1.7.3."
            )
            warn_once(logger, msg=msg)
            return chunk(out_past_key_values, self.num_pkv)

        if self.num_pkv not in (2, 4):
            raise ValueError("Unsupported num_pkv")

        # Only the self-attention key/values are present per layer when the decoder with past is used (num_pkv == 2)
        # or when the cross-attention outputs were skipped; otherwise all 4 are present. In both cases, grab the
        # cross attention key/values from the inputs, which hold 4 tensors per layer.
        outputs_per_layer = 2 if (self.num_pkv == 2 or cross_attention_outputs_skipped) else 4
        return tuple(
            out_past_key_values[start : start + 2] + past_key_values[4 * layer + 2 : 4 * layer + 4]
            for layer, start in enumerate(range(0, len(out_past_key_values), outputs_per_layer))
        )

    def prepare_inputs_for_merged(
        self,
        input_ids: Union[None, torch.LongTensor, np.ndarray],
        past_key_values: Union[None, Tuple[torch.FloatTensor], Tuple[np.ndarray]],
        use_torch: bool,
    ):
        """
        Build the extra inputs required by a merged decoder.

        Returns:
            `(use_cache_branch, past_key_values)`. For a merged decoder, `use_cache_branch` is a one-element tensor
            (or array) filled with whether real past key values were given. When none were given, `past_key_values`
            is replaced by zero-filled dummies of sequence length 1. For separate decoders, `use_cache_branch` is
            None and `past_key_values` is returned unchanged.
        """
        if self.parent_model.use_merged:
            constructor = torch if use_torch is True else np
            # Uses without/with branch of a merged decoder depending on whether real past key values are passed
            use_cache_branch = constructor.full((1,), past_key_values is not None)
        else:
            # Uses separate decoders
            use_cache_branch = None

        if use_torch and use_cache_branch is not None:
            use_cache_branch = use_cache_branch.to(self.device)

        # Generate dummy past for the first forward if uses a merged decoder
        if self.parent_model.use_merged and past_key_values is None:
            batch_size = input_ids.shape[0]
            num_attention_heads = self.normalized_config.num_attention_heads
            embed_size_per_head = self.normalized_config.hidden_size // num_attention_heads
            dtype = constructor.float16 if self.use_fp16 else constructor.float32
            shape = (batch_size, num_attention_heads, 1, embed_size_per_head)
            key_or_value = constructor.zeros(shape, dtype=dtype)

            if use_torch is True:
                key_or_value = key_or_value.to(self.device)

            past_key_values = tuple(key_or_value for _ in range(len(self.key_value_input_names)))

        return use_cache_branch, past_key_values
class ORTDecoder(ORTDecoderForSeq2Seq):
    """Deprecated alias of `ORTDecoderForSeq2Seq`; logs a deprecation warning on construction."""

    def __init__(self, *args, **kwargs):
        deprecation_message = (
            "The class `ORTDecoder` is deprecated and will be removed in optimum v1.15.0, "
            "please use `ORTDecoderForSeq2Seq` instead."
        )
        logger.warning(deprecation_message)
        super().__init__(*args, **kwargs)