-
Notifications
You must be signed in to change notification settings - Fork 390
/
modeling_decoder.py
697 lines (610 loc) · 30.2 KB
/
modeling_decoder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Classes handling causal-lm related architectures in ONNX Runtime."""
import logging
import shutil
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import torch
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError
from transformers import AutoModelForCausalLM, GenerationConfig
from transformers.file_utils import add_start_docstrings_to_model_forward
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
import onnxruntime
from ..exporters.onnx import main_export
from ..onnx.utils import _get_external_data_paths
from ..utils import check_if_transformers_greater
from ..utils.file_utils import validate_file_exists
from ..utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors
from .base import ORTDecoder
from .constants import DECODER_MERGED_ONNX_FILE_PATTERN, DECODER_ONNX_FILE_PATTERN, DECODER_WITH_PAST_ONNX_FILE_PATTERN
from .modeling_ort import ORTModel
from .utils import (
ONNX_DECODER_NAME,
ONNX_DECODER_WITH_PAST_NAME,
get_provider_for_device,
parse_device,
validate_provider_availability,
)
# `PretrainedConfig` is only referenced in (string) type annotations in this module, so it is imported for
# static type checkers only and never at runtime.
if TYPE_CHECKING:
    from transformers import PretrainedConfig

# `GenerationMixin` is imported from `transformers.generation` when the installed transformers is newer than
# 4.25.0 (per `check_if_transformers_greater`), and from the older `transformers.generation_utils` location otherwise.
if check_if_transformers_greater("4.25.0"):
    from transformers.generation import GenerationMixin
else:
    from transformers.generation_utils import GenerationMixin

# Module-level logger used for deprecation warnings and file-name warnings below.
logger = logging.getLogger(__name__)
# Docstring fragment describing the inputs of a decoder forward pass.
DECODER_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor`):
            Indices of decoder input sequence tokens in the vocabulary of shape `(batch_size, sequence_length)`.
        attention_mask (`torch.LongTensor`, *optional*):
            Mask to avoid performing attention on padding token indices, of shape
            `(batch_size, sequence_length)`. Mask values selected in `[0, 1]`.
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, defaults to `None`):
            Contains the precomputed key and value hidden states of the attention blocks used to speed up decoding.
            The tuple is of length `config.n_layers` with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`.
"""

# Docstring fragment prepended to `ORTModelForCausalLM.forward`. It contains no `{}` placeholders, so the
# `.format(...)` call applied to it at the decoration site leaves it unchanged.
CAUSALLM_ONNX_MODEL_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor`):
            Indices of decoder input sequence tokens in the vocabulary of shape `(batch_size, sequence_length)`.
        attention_mask (`torch.LongTensor`):
            Mask to avoid performing attention on padding token indices, of shape
            `(batch_size, sequence_length)`. Mask values selected in `[0, 1]`.
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, defaults to `None`):
            Contains the precomputed key and value hidden states of the attention blocks used to speed up decoding.
            The tuple is of length `config.n_layers` with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`.
"""

# Processor class name substituted into the usage examples below.
_TOKENIZER_FOR_DOC = "AutoTokenizer"

# Usage-example template; `{processor_class}`, `{model_class}` and `{checkpoint}` are filled in with
# `str.format` at the decoration site, so any literal brace added here must be doubled.
TEXT_GENERATION_EXAMPLE = r"""
    Example of text generation:

    ```python
    >>> from transformers import {processor_class}
    >>> from optimum.onnxruntime import {model_class}
    >>> import torch
    >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
    >>> model = {model_class}.from_pretrained("{checkpoint}")
    >>> inputs = tokenizer("My name is Arthur and I live in", return_tensors="pt")
    >>> gen_tokens = model.generate(**inputs, do_sample=True, temperature=0.9, min_length=20, max_length=20)
    >>> tokenizer.batch_decode(gen_tokens)  # doctest: +IGNORE_RESULT
    ```

    Example using `transformers.pipelines`:

    ```python
    >>> from transformers import {processor_class}, pipeline
    >>> from optimum.onnxruntime import {model_class}
    >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
    >>> model = {model_class}.from_pretrained("{checkpoint}")
    >>> onnx_gen = pipeline("text-generation", model=model, tokenizer=tokenizer)
    >>> text = "My name is Arthur and I live in"
    >>> gen = onnx_gen(text)
    ```
"""
class ORTModelDecoder(ORTModel):
    """
    Base class for implementing models with a causal language modeling head using ONNX Runtime inference.
    """

    def __init__(
        self,
        decoder_session: onnxruntime.InferenceSession,
        config: "PretrainedConfig",
        onnx_paths: List[str],
        decoder_with_past_session: Optional[onnxruntime.InferenceSession] = None,
        use_cache: bool = True,
        use_io_binding: Optional[bool] = None,
        model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
        preprocessors: Optional[List] = None,
        generation_config: Optional[GenerationConfig] = None,
        **kwargs,
    ):
        """
        Args:
            decoder_session (`onnxruntime.InferenceSession`):
                The ONNX Runtime inference session associated to the decoder. If its inputs contain
                `use_cache_branch`, it is treated as a merged (without past / with past) decoder.
            config ([`~transformers.PretrainedConfig`]):
                An instance of the configuration associated to the model. Initializing with a config file does
                not load the weights associated with the model, only the configuration.
            onnx_paths (`List[str]`):
                Path to ONNX files associated with the model.
            decoder_with_past_session (`Optional[onnxruntime.InferenceSession]`, defaults to `None`):
                The ONNX Runtime inference session associated to the decoder with past key values. This argument
                must not be set when `decoder_session` is a merged decoder.
            use_cache (`bool`, defaults to `True`):
                Whether or not past key/values cache should be used.
            use_io_binding (`Optional[bool]`, defaults to `None`):
                Whether use IOBinding during inference to avoid memory copy between the host and devices. Defaults to
                `True` if the first provider of `decoder_session` is `CUDAExecutionProvider`, otherwise `False`.
            model_save_dir (`Optional[Union[str, Path, TemporaryDirectory]]`, defaults to `None`):
                The directory under which the model exported to ONNX was saved.
            preprocessors (`Optional[List]`, defaults to `None`):
                The list of the preprocessors (tokenizer, processor, feature_extractor) to save alongside the ORTModel.
            generation_config (`Optional[GenerationConfig]`, defaults to `None`):
                The generation configuration used by default when calling `generate()`. When `None`, it is derived
                from `config` with `GenerationConfig.from_model_config`.
                Refer to https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationMixin.generate.
            kwargs:
                Only the deprecated `last_decoder_model_name` / `last_decoder_with_past_model_name` are accepted
                (and ignored with a warning).

        Raises:
            ValueError: on unknown keyword arguments, or on an inconsistent combination of `use_cache`,
                a merged `decoder_session`, `decoder_with_past_session` and `use_io_binding`.
        """
        if use_io_binding is None:
            # IOBinding is enabled by default only when CUDA is the first-priority provider.
            use_io_binding = decoder_session.get_providers()[0] == "CUDAExecutionProvider"

        self.shared_attributes_init(
            decoder_session,
            use_io_binding=use_io_binding,
            model_save_dir=model_save_dir,
        )
        # NOTE(review): `preprocessors` is documented and passed by `_from_pretrained`, but it is not used in this
        # method; confirm whether it should be forwarded to `shared_attributes_init` or stored on the instance.
        self.config = config

        # TODO: remove at version 2.0
        def show_deprecated_argument(arg_name):
            if kwargs.pop(arg_name, None) is not None:
                logger.warning(
                    f"The {arg_name} argument to create an {self.__class__.__name__} is deprecated, and not used "
                    "anymore."
                )

        show_deprecated_argument("last_decoder_model_name")
        show_deprecated_argument("last_decoder_with_past_model_name")
        if kwargs:
            raise ValueError(
                f"{self.__class__.__name__} received {', '.join(kwargs.keys())}, but do not accept those arguments."
            )

        if use_cache is True:
            # Auto-detect whether the provided session is a merged non-past / with-past or not
            # TODO: make __init__ private and pass `use_merged` as an argument
            use_merged = "use_cache_branch" in [inp.name for inp in decoder_session.get_inputs()]

            if use_merged is True and decoder_with_past_session is not None:
                raise ValueError(
                    "Detected a merged decoder, but decoder_with_past_session was provided."
                    " Please only set decoder_session, or provide a non-merged decoder_session."
                )
            if use_merged is False and decoder_with_past_session is None:
                raise ValueError(
                    "The parameter use_cache was set as True, but neither decoder_with_past_session was passed"
                    " nor a use_cache branch can be found in the decoder_session."
                    " Please pass a decoder_with_past_session or set use_cache=False."
                )
        else:
            use_merged = False

            if decoder_with_past_session is not None:
                raise ValueError(
                    "The parameter decoder_with_past_session was passed, although use_cache is False."
                    " Please pass use_cache=True for decoder_with_past_session to be used."
                )

        if use_cache is False and use_io_binding is True:
            raise ValueError(
                "When using CUDAExecutionProvider, the parameters combination use_cache=False, use_io_binding=True"
                " is not supported. Please either pass use_cache=True, use_io_binding=True (default),"
                " or use_cache=False, use_io_binding=False."
            )

        self.onnx_paths = onnx_paths
        self.use_cache = use_cache
        self.use_merged = use_merged
        self.decoder = ORTDecoder(decoder_session, self)
        self.decoder_model_path = Path(decoder_session._model_path)
        self.decoder_model_name = self.decoder_model_path.name

        # A separate with-past decoder only exists for a non-merged model that uses the cache.
        self.decoder_with_past = None
        self.decoder_with_past_model_path = None
        self.decoder_with_past_model_name = None
        if self.use_cache is True and self.use_merged is False:
            self.decoder_with_past = ORTDecoder(decoder_with_past_session, self)
            self.decoder_with_past_model_path = Path(decoder_with_past_session._model_path)
            self.decoder_with_past_model_name = self.decoder_with_past_model_path.name

        if generation_config is None:
            generation_config = GenerationConfig.from_model_config(config)
        self.generation_config = generation_config

    @staticmethod
    def _generate_regular_names_for_filename(filename: str):
        """
        Returns the file names optimum.onnxruntime conventionally produces from `filename`: the name itself and its
        `_quantized`, `_optimized` and `_merged` variants (suffix inserted before the last extension).
        `filename` must contain a `.`.
        """
        name, extension = filename.rsplit(".", maxsplit=1)
        return [
            filename,
            f"{name}_quantized.{extension}",
            f"{name}_optimized.{extension}",
            f"{name}_merged.{extension}",
        ]

    @staticmethod
    def load_model(
        decoder_path: Union[str, Path],
        decoder_with_past_path: Optional[Union[str, Path]] = None,
        provider: str = "CPUExecutionProvider",
        session_options: Optional[onnxruntime.SessionOptions] = None,
        provider_options: Optional[Dict] = None,
    ):
        """
        Creates the ONNX Runtime inference sessions for the decoder and, optionally, the decoder with past key values.
        The default provider is `CPUExecutionProvider` to match the default behaviour in PyTorch/TensorFlow/JAX.

        Args:
            decoder_path (`str` or `Path`):
                The path of the decoder ONNX model.
            decoder_with_past_path (`str` or `Path`, *optional*):
                The path of the decoder with past key values ONNX model. When `None`, no second session is created.
            provider(`str`, *optional*, defaults to `"CPUExecutionProvider"`):
                The ONNX Runtime provider to use for loading the model.
            session_options (`Optional[onnxruntime.SessionOptions]`, *optional*):
                ONNX Runtime session options to use for loading the model.
            provider_options (`Optional[Dict]`, *optional*):
                Provider option dictionary corresponding to the provider used. See available options
                for each provider: https://onnxruntime.ai/docs/api/c/group___global.html.

        Returns:
            `Tuple`: `(decoder_session, decoder_with_past_session)`, the second element being `None` when
            `decoder_with_past_path` is `None`.
        """
        decoder_session = ORTModel.load_model(decoder_path, provider, session_options, provider_options)

        decoder_with_past_session = None
        # If a decoder_with_past_path is provided, an inference session for the decoder with past key/values as inputs
        # will be enabled
        if decoder_with_past_path is not None:
            decoder_with_past_session = ORTModel.load_model(
                decoder_with_past_path, provider, session_options, provider_options
            )

        return decoder_session, decoder_with_past_session

    def _save_pretrained(self, save_directory: Union[str, Path]):
        """
        Saves the ONNX files listed in `self.onnx_paths` (plus any external data files attached to them) and the
        generation configuration to a directory, so that it can be re-loaded using the
        [`~optimum.onnxruntime.modeling_causal.ORTModelDecoder.from_pretrained`] class method.

        Args:
            save_directory (`str` or `Path`):
                The directory where to save the model files.
        """
        save_directory = Path(save_directory)
        src_paths = [Path(path) for path in self.onnx_paths]
        dst_paths = [save_directory / path.name for path in src_paths]

        # add external data paths in case of large models
        src_paths, dst_paths = _get_external_data_paths(src_paths, dst_paths)

        for src_path, dst_path in zip(src_paths, dst_paths):
            shutil.copyfile(src_path, dst_path)

        self.generation_config.save_pretrained(save_directory)

    @classmethod
    def _from_pretrained(
        cls,
        model_id: Union[str, Path],
        config: "PretrainedConfig",
        use_auth_token: Optional[Union[bool, str]] = None,
        revision: Optional[str] = None,
        force_download: bool = False,
        cache_dir: Optional[str] = None,
        decoder_file_name: str = ONNX_DECODER_NAME,
        decoder_with_past_file_name: str = ONNX_DECODER_WITH_PAST_NAME,
        subfolder: str = "",
        local_files_only: bool = False,
        use_cache: bool = True,
        use_merged: Optional[bool] = None,
        provider: str = "CPUExecutionProvider",
        session_options: Optional[onnxruntime.SessionOptions] = None,
        provider_options: Optional[Dict[str, Any]] = None,
        use_io_binding: Optional[bool] = None,
        model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
        **kwargs,
    ):
        """
        Locates the decoder ONNX file(s) in a local directory or a Hub repository, creates the inference sessions
        and returns an instance of `cls`.

        Unless `use_cache=False` or `use_merged=False`, a merged decoder file is looked up first and, if found, used
        on its own. Otherwise `decoder_file_name` (or a file matching the decoder pattern) is used and, when
        `use_cache=True`, `decoder_with_past_file_name` (or a file matching the decoder-with-past pattern).
        For a non-local `model_id`, the files and their optional `<file name>_data` companions are fetched with
        `hf_hub_download`.

        Raises:
            ValueError: if `use_cache=False` is combined with `use_merged=True`.
            FileNotFoundError: if `use_merged=True` and no merged decoder is found, or if `use_cache=True`,
                no merged decoder is used and no decoder-with-past file is found.
        """
        model_path = Path(model_id)

        # We do not implement the logic for use_cache=False, use_merged=True
        if use_cache is False:
            if use_merged is True:
                raise ValueError(
                    "The parameters combination use_cache=False, use_merged=True is not supported."
                    " To use a merged decoder, past key values must be used."
                )
            use_merged = False

        decoder_merged_path = None
        # We use `is not False` here to include two cases: use_merged = None (in which case we auto-detect it),
        # and use_merged = True (explicitly specified by the user)
        if use_merged is not False:
            try:
                decoder_merged_path = ORTModelDecoder.infer_onnx_filename(
                    model_id,
                    [DECODER_MERGED_ONNX_FILE_PATTERN],
                    argument_name=None,
                    subfolder=subfolder,
                    use_auth_token=use_auth_token,
                    revision=revision,
                )
            except FileNotFoundError as e:
                if use_merged is True:
                    raise FileNotFoundError(
                        f"The parameter `use_merged=True` was passed to {cls.__name__}.from_pretrained()"
                        " but no ONNX file for a merged decoder could be found in"
                        f" {str(Path(model_id, subfolder))}, with the error: {e}"
                    ) from e
                # Auto-detection found no merged decoder: fall back to the non-merged files.
                use_merged = False
            else:
                use_merged = True
                decoder_path = decoder_merged_path

        decoder_without_past_path = None
        decoder_with_past_path = None
        if use_merged is False:
            if not validate_file_exists(model_id, decoder_file_name, subfolder=subfolder, revision=revision):
                decoder_without_past_path = ORTModelDecoder.infer_onnx_filename(
                    model_id,
                    [DECODER_ONNX_FILE_PATTERN],
                    "decoder_file_name",
                    subfolder=subfolder,
                    use_auth_token=use_auth_token,
                    revision=revision,
                )
            else:
                decoder_without_past_path = model_path / subfolder / decoder_file_name

            decoder_path = decoder_without_past_path

            decoder_regular_onnx_filenames = ORTModelDecoder._generate_regular_names_for_filename(ONNX_DECODER_NAME)
            if decoder_path.name not in decoder_regular_onnx_filenames:
                logger.warning(
                    f"The ONNX file {decoder_path.name} is not a regular name used in optimum.onnxruntime that are {decoder_regular_onnx_filenames}, the "
                    f"{cls.__name__} might not behave as expected."
                )

            # If the decoder without / with past has been merged, we do not need to look for any additional file
            if use_cache is True:
                if not validate_file_exists(
                    model_id, decoder_with_past_file_name, subfolder=subfolder, revision=revision
                ):
                    try:
                        decoder_with_past_path = ORTModelDecoder.infer_onnx_filename(
                            model_id,
                            [DECODER_WITH_PAST_ONNX_FILE_PATTERN],
                            "decoder_with_past_file_name",
                            subfolder=subfolder,
                            use_auth_token=use_auth_token,
                            revision=revision,
                        )
                    except FileNotFoundError as e:
                        raise FileNotFoundError(
                            f"The parameter `use_cache=True` was passed to {cls.__name__}.from_pretrained()"
                            " but no ONNX file using past key values could be found in"
                            f" {str(Path(model_id, subfolder))}, with the error: {e}"
                        ) from e
                else:
                    decoder_with_past_path = model_path / subfolder / decoder_with_past_file_name

                decoder_with_past_regular_onnx_filenames = ORTModelDecoder._generate_regular_names_for_filename(
                    ONNX_DECODER_WITH_PAST_NAME
                )

                if decoder_with_past_path.name not in decoder_with_past_regular_onnx_filenames:
                    logger.warning(
                        f"The ONNX file {decoder_with_past_path.name} is not a regular name used in optimum.onnxruntime that are {decoder_with_past_regular_onnx_filenames}, "
                        f"the {cls.__name__} might not behave as expected."
                    )

        preprocessors = None
        if model_path.is_dir():
            new_model_save_dir = model_path
            preprocessors = maybe_load_preprocessors(model_id)
        else:
            # Only the files actually needed for the resolved configuration are downloaded.
            attribute_name_to_filename = {
                "last_decoder_model_name": decoder_path.name if use_merged is False else None,
                "last_decoder_with_past_model_name": decoder_with_past_path.name
                if (use_cache is True and use_merged is False)
                else None,
                "last_decoder_merged_name": decoder_merged_path.name if use_merged is True else None,
            }
            paths = {}
            for attr_name, filename in attribute_name_to_filename.items():
                if filename is None:
                    continue
                model_cache_path = hf_hub_download(
                    repo_id=model_id,
                    subfolder=subfolder,
                    filename=filename,
                    use_auth_token=use_auth_token,
                    revision=revision,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    local_files_only=local_files_only,
                )

                # try download external data
                try:
                    hf_hub_download(
                        repo_id=model_id,
                        subfolder=subfolder,
                        filename=filename + "_data",
                        use_auth_token=use_auth_token,
                        revision=revision,
                        cache_dir=cache_dir,
                        force_download=force_download,
                        local_files_only=local_files_only,
                    )
                except EntryNotFoundError:
                    # model doesn't use external data
                    pass

                paths[attr_name] = Path(model_cache_path).name
            new_model_save_dir = Path(model_cache_path).parent
            preprocessors = maybe_load_preprocessors(model_id, subfolder=subfolder)

            # Re-point every resolved path at the downloaded copy.
            if use_merged is True:
                decoder_path = new_model_save_dir / paths["last_decoder_merged_name"]
                decoder_merged_path = new_model_save_dir / paths["last_decoder_merged_name"]
            else:
                decoder_path = new_model_save_dir / paths["last_decoder_model_name"]
                decoder_without_past_path = new_model_save_dir / paths["last_decoder_model_name"]

                if use_cache is True:
                    decoder_with_past_path = new_model_save_dir / paths["last_decoder_with_past_model_name"]

        ort_inference_sessions = cls.load_model(
            decoder_path=decoder_path,
            decoder_with_past_path=None if use_merged is True or use_cache is False else decoder_with_past_path,
            provider=provider,
            session_options=session_options,
            provider_options=provider_options,
        )

        if model_save_dir is None:
            model_save_dir = new_model_save_dir

        generation_config = None
        try:
            generation_config = GenerationConfig.from_pretrained(
                model_id,
                cache_dir=cache_dir,
                force_download=force_download,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
                revision=revision,
                subfolder=subfolder,
            )
        except OSError:
            logger.info("Generation config file not found, using a generation config created from the model config.")

        onnx_paths = []
        if use_merged is False:
            onnx_paths.append(decoder_without_past_path)
            if use_cache is True:
                onnx_paths.append(decoder_with_past_path)
        else:
            onnx_paths.append(decoder_merged_path)

        return cls(
            ort_inference_sessions[0],
            config,
            decoder_with_past_session=ort_inference_sessions[1],
            use_cache=use_cache,
            use_io_binding=use_io_binding,
            model_save_dir=model_save_dir,
            preprocessors=preprocessors,
            generation_config=generation_config,
            onnx_paths=onnx_paths,
        )

    @classmethod
    def _from_transformers(
        cls,
        model_id: str,
        config: "PretrainedConfig",
        use_auth_token: Optional[Union[bool, str]] = None,
        revision: str = "main",
        force_download: bool = True,
        cache_dir: Optional[str] = None,
        subfolder: str = "",
        local_files_only: bool = False,
        trust_remote_code: bool = False,
        use_cache: bool = True,
        use_merged: bool = False,
        provider: str = "CPUExecutionProvider",
        session_options: Optional[onnxruntime.SessionOptions] = None,
        provider_options: Optional[Dict[str, Any]] = None,
        use_io_binding: Optional[bool] = None,
        task: Optional[str] = None,
    ) -> "ORTModelDecoder":
        """
        Exports a transformers checkpoint to ONNX in a temporary directory with `main_export`, saves the config and
        any preprocessors next to it, then loads the result with `_from_pretrained`.

        When `task` is `None` it is derived from `cls.auto_model_class`, with `-with-past` appended if `use_cache=True`.
        ONNX post-processing (decoder merging) runs only when `use_merged=True`.

        Raises:
            ValueError: if `use_cache=False` is combined with `use_merged=True`.
        """
        if task is None:
            task = cls._auto_model_to_task(cls.auto_model_class)

            if use_cache is True:
                task = task + "-with-past"

        if use_cache is False and use_merged is True:
            raise ValueError(
                f"The incompatible arguments use_cache=False, use_merged=True were passed to {cls.__name__}.from_pretrained()."
                " Please pass either use_cache=False, use_merged=False to disable past key value caching, or use_cache=True, use_merged=False"
                " to disable the merging of the decoder not using / using past key and value."
            )

        # The TemporaryDirectory object is handed to the model as `model_save_dir`, which keeps it referenced by the
        # instance after this method returns.
        save_dir = TemporaryDirectory()
        save_dir_path = Path(save_dir.name)

        main_export(
            model_name_or_path=model_id,
            output=save_dir_path,
            task=task,
            do_validation=False,
            no_post_process=not use_merged,
            subfolder=subfolder,
            revision=revision,
            cache_dir=cache_dir,
            use_auth_token=use_auth_token,
            local_files_only=local_files_only,
            force_download=force_download,
            trust_remote_code=trust_remote_code,
        )

        config.save_pretrained(save_dir_path)
        maybe_save_preprocessors(model_id, save_dir_path, src_subfolder=subfolder)

        return cls._from_pretrained(
            save_dir_path,
            config,
            use_cache=use_cache,
            use_merged=use_merged,
            provider=provider,
            session_options=session_options,
            provider_options=provider_options,
            use_io_binding=use_io_binding,
            model_save_dir=save_dir,
        )

    def to(self, device: Union[torch.device, str, int]):
        """
        Changes the ONNX Runtime provider according to the device.

        Args:
            device (`Union[torch.device, str, int]`):
                Device ordinal for CPU/GPU supports. Setting this to -1 will leverage CPU, a positive will run
                the model on the associated CUDA device id. You can pass native `torch.device` or a `str` too.

        Returns:
            `ORTModel`: the model placed on the requested device.
        """
        device, provider_options = parse_device(device)

        # A TensorRT session already runs on CUDA; switching providers would drop TensorRT, so keep it as is.
        if device.type == "cuda" and self.providers[0] == "TensorrtExecutionProvider":
            return self

        provider = get_provider_for_device(device)
        validate_provider_availability(provider)  # raise error if the provider is not available
        self.device = device

        self.decoder.session.set_providers([provider], provider_options=[provider_options])
        if self.decoder_with_past is not None:
            self.decoder_with_past.session.set_providers([provider], provider_options=[provider_options])

        self.providers = self.decoder.session.get_providers()
        return self
class ORTModelForCausalLM(ORTModelDecoder, GenerationMixin):
    """
    ONNX model with a causal language modeling head for ONNX Runtime inference.
    """

    auto_model_class = AutoModelForCausalLM
    main_input_name = "input_ids"

    @add_start_docstrings_to_model_forward(
        CAUSALLM_ONNX_MODEL_DOCSTRING.format("batch_size, sequence_length")
        + TEXT_GENERATION_EXAMPLE.format(
            processor_class=_TOKENIZER_FOR_DOC,
            model_class="ORTModelForCausalLM",
            checkpoint="optimum/gpt2",
        )
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        labels: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> CausalLMOutputWithCrossAttentions:
        # The public docstring is injected by the decorator above, so only comments are used here.
        # Dispatch:
        #   - no past (first step) or cache disabled -> `self.decoder` on the full `input_ids`;
        #   - merged decoder with past                -> `self.decoder` on the last token only;
        #   - separate with-past decoder              -> `self.decoder_with_past` on the last token only.
        # `labels` is forwarded on every path so that a loss, when produced, is available on all of them.
        if past_key_values is None or self.use_cache is False:
            outputs = self.decoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
                labels=labels,
            )
        elif self.use_merged is True:
            outputs = self.decoder(
                input_ids=input_ids[:, -1:],
                past_key_values=past_key_values,
                attention_mask=attention_mask,
                labels=labels,
            )
        else:
            outputs = self.decoder_with_past(
                input_ids=input_ids[:, -1:],
                past_key_values=past_key_values,
                attention_mask=attention_mask,
                labels=labels,
            )

        return CausalLMOutputWithCrossAttentions(
            loss=outputs.get("loss", None), logits=outputs.logits, past_key_values=outputs.past_key_values
        )

    # Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation
    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
        """
        Builds the keyword arguments passed to `forward` at each generation step. The full `input_ids` is passed
        through unchanged; trimming to the last token when a cache is present happens in `forward`.
        `position_ids` is always `None`.
        """
        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
        attention_mask = kwargs.get("attention_mask", None)  # input_ids.new_ones(input_ids.shape)
        use_cache = kwargs.get("use_cache", None)

        return {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "use_cache": use_cache,
            "position_ids": None,
            "attention_mask": attention_mask,
        }

    # Copied from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel._reorder_cache
    @staticmethod
    def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
        """
        Reorders every cached tensor along dimension 0 according to `beam_idx` (used by beam search), moving
        `beam_idx` to each tensor's device first.
        """
        return tuple(
            tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
            for layer_past in past
        )

    def can_generate(self):
        """Returns True to validate the check that the model using `GenerationMixin.generate()` can indeed generate."""
        return True