-
Notifications
You must be signed in to change notification settings - Fork 415
/
modeling_seq2seq.py
2111 lines (1833 loc) · 94 KB
/
modeling_seq2seq.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
ORTModelForXXX classes related to seq2seq, allowing to run ONNX Models with ONNX Runtime using the same API as
Transformers.
"""
import copy
import logging
import shutil
import warnings
from abc import ABC, abstractmethod
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from transformers import (
AutoModelForSeq2SeqLM,
AutoModelForSpeechSeq2Seq,
AutoModelForVision2Seq,
GenerationConfig,
Pix2StructForConditionalGeneration, # Pix2struct does not support AutoModel
)
from transformers.file_utils import add_end_docstrings, add_start_docstrings_to_model_forward
from transformers.generation.logits_process import WhisperTimeStampLogitsProcessor
from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
from transformers.models.auto.modeling_auto import MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES
from transformers.models.whisper.tokenization_whisper import TASK_IDS, TO_LANGUAGE_CODE
import onnxruntime as ort
from ..exporters.onnx import main_export
from ..onnx.utils import _get_external_data_paths
from ..utils import check_if_transformers_greater
from ..utils.file_utils import validate_file_exists
from ..utils.normalized_config import NormalizedConfigManager
from ..utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors
from .base import ORTDecoderForSeq2Seq, ORTEncoder
from .constants import (
DECODER_MERGED_ONNX_FILE_PATTERN,
DECODER_ONNX_FILE_PATTERN,
DECODER_WITH_PAST_ONNX_FILE_PATTERN,
ENCODER_ONNX_FILE_PATTERN,
)
from .modeling_ort import ONNX_MODEL_END_DOCSTRING, ORTModel
from .utils import (
ONNX_DECODER_NAME,
ONNX_DECODER_WITH_PAST_NAME,
ONNX_ENCODER_NAME,
get_provider_for_device,
parse_device,
validate_provider_availability,
)
if check_if_transformers_greater("4.25.0"):
from transformers.generation import GenerationMixin
else:
from transformers.generation_utils import GenerationMixin
from huggingface_hub.utils import EntryNotFoundError
if TYPE_CHECKING:
from transformers import PretrainedConfig
# Module-level logger for this file.
logger = logging.getLogger(__name__)
# Docstring fragments describing the inputs of encoder / decoder `forward` methods. The speech, vision and
# Pix2Struct variants are attached with `add_start_docstrings_to_model_forward` on the corresponding encoder
# classes in this file; SEQ2SEQ_ENCODER_INPUTS_DOCSTRING and DECODER_INPUTS_DOCSTRING are presumably consumed
# by classes defined elsewhere (not visible here).
SEQ2SEQ_ENCODER_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor`):
Indices of input sequence tokens in the vocabulary of shape `(batch_size, encoder_sequence_length)`.
attention_mask (`torch.LongTensor`):
Mask to avoid performing attention on padding token indices, of shape
`(batch_size, encoder_sequence_length)`. Mask values selected in `[0, 1]`.
"""
SPEECH_ENCODER_INPUTS_DOCSTRING = r"""
Args:
input_features (`torch.FloatTensor`):
Mel / fbank features extracted from the raw speech waveform. `(batch_size, feature_size, encoder_sequence_length)`.
"""
VISION_ENCODER_INPUTS_DOCSTRING = r"""
Args:
pixel_values (`torch.FloatTensor`):
Features extracted from an Image. This tensor should be of shape `(batch_size, num_channels, height, width)`.
"""
PIX2STRUCT_INPUTS_DOCSTRING = r"""
Args:
flattened_patches (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_channels x patch_height x patch_width)`):
Flattened and padded pixel values.
attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
Mask to avoid performing attention on padding pixel values.
"""
DECODER_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor`):
Indices of decoder input sequence tokens in the vocabulary of shape `(batch_size, decoder_sequence_length)`.
encoder_hidden_states (`torch.FloatTensor`):
The encoder `last_hidden_state` of shape `(batch_size, encoder_sequence_length, hidden_size)`.
encoder_attention_mask (`torch.LongTensor`, *optional*):
Mask to avoid performing cross-attention on padding tokens indices of encoder `input_ids`.
past_key_values (`tuple(tuple(torch.FloatTensor), *optional*, defaults to `None`)`
Contains the precomputed key and value hidden states of the attention blocks used to speed up decoding.
The tuple is of length `config.n_layers` with each tuple having 2 tensors of shape
`(batch_size, num_heads, decoder_sequence_length, embed_size_per_head)` and 2 additional tensors of shape
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
"""
# Docstring fragments describing the inputs of the full encoder-decoder model `forward`, one per input modality
# (text, speech, vision, Pix2Struct). Their attachment sites are not visible in this chunk.
SEQ2SEQ_ONNX_MODEL_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor`):
Indices of input sequence tokens in the vocabulary of shape `(batch_size, encoder_sequence_length)`.
attention_mask (`torch.LongTensor`):
Mask to avoid performing attention on padding token indices, of shape
`(batch_size, encoder_sequence_length)`. Mask values selected in `[0, 1]`.
decoder_input_ids (`torch.LongTensor`):
Indices of decoder input sequence tokens in the vocabulary of shape `(batch_size, decoder_sequence_length)`.
encoder_outputs (`torch.FloatTensor`):
The encoder `last_hidden_state` of shape `(batch_size, encoder_sequence_length, hidden_size)`.
past_key_values (`tuple(tuple(torch.FloatTensor), *optional*, defaults to `None`)`
Contains the precomputed key and value hidden states of the attention blocks used to speed up decoding.
The tuple is of length `config.n_layers` with each tuple having 2 tensors of shape
`(batch_size, num_heads, decoder_sequence_length, embed_size_per_head)` and 2 additional tensors of shape
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
"""
SPEECH_SEQ2SEQ_ONNX_MODEL_DOCSTRING = r"""
Args:
input_features (`torch.FloatTensor`):
Mel features extracted from the raw speech waveform.
`(batch_size, feature_size, encoder_sequence_length)`.
decoder_input_ids (`torch.LongTensor`):
Indices of decoder input sequence tokens in the vocabulary of shape `(batch_size, decoder_sequence_length)`.
encoder_outputs (`torch.FloatTensor`):
The encoder `last_hidden_state` of shape `(batch_size, encoder_sequence_length, hidden_size)`.
past_key_values (`tuple(tuple(torch.FloatTensor), *optional*, defaults to `None`)`
Contains the precomputed key and value hidden states of the attention blocks used to speed up decoding.
The tuple is of length `config.n_layers` with each tuple having 2 tensors of shape
`(batch_size, num_heads, decoder_sequence_length, embed_size_per_head)` and 2 additional tensors of shape
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
"""
VISION_ENCODER_DECODER_SEQ2SEQ_ONNX_MODEL_DOCSTRING = r"""
Args:
pixel_values (`torch.FloatTensor`):
Features extracted from an Image. This tensor should be of shape
`(batch_size, num_channels, height, width)`.
decoder_input_ids (`torch.LongTensor`):
Indices of decoder input sequence tokens in the vocabulary of shape `(batch_size, decoder_sequence_length)`.
encoder_outputs (`torch.FloatTensor`):
The encoder `last_hidden_state` of shape `(batch_size, encoder_sequence_length, hidden_size)`.
past_key_values (`tuple(tuple(torch.FloatTensor), *optional*, defaults to `None`)`
Contains the precomputed key and value hidden states of the attention blocks used to speed up decoding.
The tuple is of length `config.n_layers` with each tuple having 2 tensors of shape
`(batch_size, num_heads, decoder_sequence_length, embed_size_per_head)` and 2 additional tensors of shape
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
"""
PIX2STRUCT_ONNX_MODEL_DOCSTRING = r"""
Args:
flattened_patches (`torch.FloatTensor` of shape `(batch_size, seq_length, hidden_size)`):
Flattened pixel patches. the `hidden_size` is obtained by the following formula: `hidden_size` =
`num_channels` * `patch_size` * `patch_size`
The process of flattening the pixel patches is done by `Pix2StructProcessor`.
attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices.
decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
Indices of decoder input sequence tokens in the vocabulary.
Pix2StructText uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If
`past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
`past_key_values`).
decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
be used by default.
encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
Tuple consists of (`last_hidden_state`, `optional`: *hidden_states*, `optional`: *attentions*)
`last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of hidden states at
the output of the last layer of the encoder. Used in the cross-attention of the decoder.
past_key_values (`tuple(tuple(torch.FloatTensor), *optional*, defaults to `None`)`
Contains the precomputed key and value hidden states of the attention blocks used to speed up decoding.
The tuple is of length `config.n_layers` with each tuple having 2 tensors of shape
`(batch_size, num_heads, decoder_sequence_length, embed_size_per_head)` and 2 additional tensors of shape
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
"""
# Default preprocessor class names for the usage-example templates below. The call sites that substitute them
# are not visible in this chunk.
_TOKENIZER_FOR_DOC = "AutoTokenizer"
_PROCESSOR_FOR_DOC = "AutoProcessor"
_IMAGE_PROCESSER_FOR_DOC = "AutoImageProcessor"
# Usage-example templates appended to model docstrings. The `{processor_class}`, `{tokenizer_class}`,
# `{model_class}` and `{checkpoint}` placeholders are presumably filled with `str.format` by the model classes
# (not visible here) — confirm before renaming any placeholder.
TRANSLATION_EXAMPLE = r"""
Example of text generation:
```python
>>> from transformers import {processor_class}
>>> from optimum.onnxruntime import {model_class}
>>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> inputs = tokenizer("My name is Eustache and I like to", return_tensors="pt")
>>> gen_tokens = model.generate(**inputs)
>>> outputs = tokenizer.batch_decode(gen_tokens)
```
Example using `transformers.pipeline`:
```python
>>> from transformers import {processor_class}, pipeline
>>> from optimum.onnxruntime import {model_class}
>>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> onnx_translation = pipeline("translation_en_to_de", model=model, tokenizer=tokenizer)
>>> text = "My name is Eustache."
>>> pred = onnx_translation(text)
```
"""
AUTOMATIC_SPEECH_RECOGNITION_EXAMPLE = r"""
Example of text generation:
```python
>>> from transformers import {processor_class}
>>> from optimum.onnxruntime import {model_class}
>>> from datasets import load_dataset
>>> processor = {processor_class}.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> inputs = processor.feature_extractor(ds[0]["audio"]["array"], return_tensors="pt")
>>> gen_tokens = model.generate(inputs=inputs.input_features)
>>> outputs = processor.tokenizer.batch_decode(gen_tokens)
```
Example using `transformers.pipeline`:
```python
>>> from transformers import {processor_class}, pipeline
>>> from optimum.onnxruntime import {model_class}
>>> from datasets import load_dataset
>>> processor = {processor_class}.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> speech_recognition = pipeline("automatic-speech-recognition", model=model, tokenizer=processor.tokenizer, feature_extractor=processor.feature_extractor)
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> pred = speech_recognition(ds[0]["audio"]["array"])
```
"""
IMAGE_TO_TEXT_EXAMPLE = r"""
Example of text generation:
```python
>>> from transformers import {processor_class}, {tokenizer_class}
>>> from optimum.onnxruntime import {model_class}
>>> from PIL import Image
>>> import requests
>>> processor = {processor_class}.from_pretrained("{checkpoint}")
>>> tokenizer = {tokenizer_class}.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}", export=True)
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor(image, return_tensors="pt")
>>> gen_tokens = model.generate(**inputs)
>>> outputs = tokenizer.batch_decode(gen_tokens, skip_special_tokens=True)
```
Example using `transformers.pipeline`:
```python
>>> from transformers import {processor_class}, {tokenizer_class}, pipeline
>>> from optimum.onnxruntime import {model_class}
>>> from PIL import Image
>>> import requests
>>> processor = {processor_class}.from_pretrained("{checkpoint}")
>>> tokenizer = {tokenizer_class}.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}", export=True)
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> image_to_text = pipeline("image-to-text", model=model, tokenizer=tokenizer, feature_extractor=processor, image_processor=processor)
>>> pred = image_to_text(image)
```
"""
PIX2STRUCT_EXAMPLE = r"""
Example of pix2struct:
```python
>>> from transformers import {processor_class}
>>> from optimum.onnxruntime import {model_class}
>>> from PIL import Image
>>> import requests
>>> processor = {processor_class}.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}", export=True, use_io_binding=True)
>>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/ai2d-demo.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> question = "What does the label 15 represent? (1) lava (2) core (3) tunnel (4) ash cloud"
>>> inputs = processor(images=image, text=question, return_tensors="pt")
>>> gen_tokens = model.generate(**inputs)
>>> outputs = processor.batch_decode(gen_tokens, skip_special_tokens=True)
```
"""
class ORTEncoderForSpeech(ORTEncoder):
    """
    Encoder model for ONNX Runtime inference for Whisper model.

    Args:
        session (`ort.InferenceSession`):
            The ONNX Runtime inference session associated to the encoder.
    """

    @add_start_docstrings_to_model_forward(SPEECH_ENCODER_INPUTS_DOCSTRING)
    def forward(
        self,
        input_features: torch.FloatTensor,
        attention_mask: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> BaseModelOutput:
        """
        Returns:
            `BaseModelOutput` holding the encoder `last_hidden_state`, as a torch tensor when `input_features` is a
            torch tensor and as the raw ONNX Runtime output otherwise.

        Raises:
            ValueError: if the ONNX encoder declares an `attention_mask` input but `attention_mask` is `None`.
        """
        use_torch = isinstance(input_features, torch.Tensor)
        self.parent_model.raise_on_numpy_input_io_binding(use_torch)

        # The mask is only fed to the graph when the exported encoder actually declares it as an input.
        expects_attention_mask = "attention_mask" in self.input_names
        if expects_attention_mask and attention_mask is None:
            raise ValueError(
                "The ONNX encoder expects an `attention_mask` input, but `attention_mask` is None."
            )

        if self.parent_model.device.type == "cuda" and self.parent_model.use_io_binding:
            model_inputs = [input_features, attention_mask] if expects_attention_mask else [input_features]
            io_binding, output_shapes, output_buffers = self.parent_model._prepare_io_binding(
                self.session,
                *model_inputs,
                ordered_input_names=self._ordered_input_names,
            )

            io_binding.synchronize_inputs()
            self.session.run_with_iobinding(io_binding)
            io_binding.synchronize_outputs()

            last_hidden_state = output_buffers["last_hidden_state"].view(output_shapes["last_hidden_state"])
        else:
            if use_torch:
                onnx_inputs = {"input_features": input_features.cpu().detach().numpy()}
                if expects_attention_mask:
                    onnx_inputs["attention_mask"] = attention_mask.cpu().detach().numpy()
            else:
                onnx_inputs = {"input_features": input_features}
                if expects_attention_mask:
                    onnx_inputs["attention_mask"] = attention_mask

            # TODO: Replace with a better solution
            # attention_mask is exported with int64 datatype and tokenizer produces int32 input
            # for speech2text model. Hence, the input is type casted for inference.
            if expects_attention_mask:
                # Look the graph input up by name: the mask is not guaranteed to be the second input.
                mask_input_type = next(
                    (
                        graph_input.type
                        for graph_input in self.session.get_inputs()
                        if graph_input.name == "attention_mask"
                    ),
                    None,
                )
                if mask_input_type == "tensor(int64)":
                    onnx_inputs["attention_mask"] = onnx_inputs["attention_mask"].astype(np.int64)

            outputs = self.session.run(None, onnx_inputs)

            last_hidden_state = outputs[self.output_names["last_hidden_state"]]
            if use_torch:
                last_hidden_state = torch.from_numpy(last_hidden_state).to(self.device)

        return BaseModelOutput(last_hidden_state=last_hidden_state)
class ORTEncoderForVisionEncoderDecoder(ORTEncoder):
    """
    Encoder model for ONNX Runtime inference for VisionEncoderDecoder models.

    Args:
        session (`ort.InferenceSession`):
            The ONNX Runtime inference session associated to the encoder.
    """

    @add_start_docstrings_to_model_forward(VISION_ENCODER_INPUTS_DOCSTRING)
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        **kwargs,
    ) -> BaseModelOutput:
        """
        Returns:
            `BaseModelOutput` holding the encoder `last_hidden_state`, as a torch tensor when `pixel_values` is a
            torch tensor and as the raw ONNX Runtime output otherwise.
        """
        use_torch = isinstance(pixel_values, torch.Tensor)
        self.parent_model.raise_on_numpy_input_io_binding(use_torch)

        if self.parent_model.device.type == "cuda" and self.parent_model.use_io_binding:
            known_output_shapes = self.compute_encoder_known_output_shapes(pixel_values)

            io_binding, output_shapes, output_buffers = self.parent_model._prepare_io_binding(
                self.session,
                pixel_values,
                known_output_shapes=known_output_shapes,
                ordered_input_names=self._ordered_input_names,
            )

            io_binding.synchronize_inputs()
            self.session.run_with_iobinding(io_binding)
            io_binding.synchronize_outputs()

            last_hidden_state = output_buffers["last_hidden_state"].view(output_shapes["last_hidden_state"])
        else:
            if use_torch:
                onnx_inputs = {"pixel_values": pixel_values.cpu().detach().numpy()}
            else:
                onnx_inputs = {"pixel_values": pixel_values}

            outputs = self.session.run(None, onnx_inputs)

            last_hidden_state = outputs[self.output_names["last_hidden_state"]]
            if use_torch:
                last_hidden_state = torch.from_numpy(last_hidden_state).to(self.device)

        return BaseModelOutput(last_hidden_state=last_hidden_state)

    def compute_encoder_known_output_shapes(
        self, pixel_values: torch.FloatTensor
    ) -> Optional[Dict[str, List[int]]]:
        """
        Pre-compute the shape of the encoder `last_hidden_state` so IOBinding can allocate its output buffer.

        Args:
            pixel_values (`torch.FloatTensor`):
                Encoder input of shape `(batch_size, num_channels, height, width)`.

        Returns:
            `Optional[Dict[str, List[int]]]`: `{"last_hidden_state": [batch_size, encoder_sequence_length,
            hidden_size]}` for Donut-Swin encoders, or `None` for ViT / DeiT encoders (no shape is pre-computed).

        Raises:
            ValueError: if the encoder model type is not supported with IOBinding.
        """
        config = self.normalized_config.config
        if config.model_type == "donut-swin":
            # TODO: kind of weird to export to ONNX with dynamic output shape if it is in fact static...
            # The previous `image_size[0] * image_size[1] // hidden_size` formula only matched donut-base by
            # coincidence; derive the token count from the patch grid of the actual input instead.
            encoder_sequence_length = self._donut_swin_sequence_length(
                height=int(pixel_values.shape[-2]),
                width=int(pixel_values.shape[-1]),
                patch_size=config.patch_size,
                num_patch_merges=len(config.depths) - 1,
            )
        elif config.model_type in ["vit", "deit"]:
            return None
        else:
            raise ValueError(
                f"Unsupported encoder model type {config.model_type} for ORTForVisionSeq2Seq with IOBinding. "
                "Currently supported models are vit, donut-swin and deit. "
                "Please submit a PR to add support for this model type."
            )

        return {
            "last_hidden_state": [
                pixel_values.shape[0],  # batch size
                encoder_sequence_length,
                config.hidden_size,
            ]
        }

    @staticmethod
    def _donut_swin_sequence_length(height: int, width: int, patch_size, num_patch_merges: int) -> int:
        """
        Number of tokens produced by a Donut-Swin encoder for a `height` x `width` input.

        Patch embedding pads the image up to a multiple of `patch_size` (an int or a `(height, width)` pair), and
        each of the `num_patch_merges` patch-merging layers halves both grid dimensions after padding odd sizes,
        so every step is a ceiling division.
        """
        if isinstance(patch_size, (list, tuple)):
            patch_height, patch_width = patch_size
        else:
            patch_height = patch_width = patch_size

        # -(-a // b) is ceil(a / b) for positive integers, without going through floats.
        grid_height = -(-height // patch_height)
        grid_width = -(-width // patch_width)
        for _ in range(num_patch_merges):
            grid_height = -(-grid_height // 2)
            grid_width = -(-grid_width // 2)
        return grid_height * grid_width
class ORTEncoderForPix2Struct(ORTEncoder):
    """
    Encoder model for ONNX Runtime inference for Pix2Struct.

    Args:
        session (`ort.InferenceSession`):
            The ONNX Runtime inference session associated to the encoder.
    """

    @add_start_docstrings_to_model_forward(PIX2STRUCT_INPUTS_DOCSTRING)
    def forward(
        self,
        flattened_patches: torch.FloatTensor,
        attention_mask: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> BaseModelOutput:
        """
        Returns:
            `BaseModelOutput` holding the encoder `last_hidden_state`, as a torch tensor when `flattened_patches`
            is a torch tensor and as the raw ONNX Runtime output otherwise.

        Raises:
            ValueError: if the ONNX encoder declares an `attention_mask` input but `attention_mask` is `None`.
        """
        use_torch = isinstance(flattened_patches, torch.Tensor)
        self.parent_model.raise_on_numpy_input_io_binding(use_torch)

        # The mask is only fed to the graph when the exported encoder actually declares it as an input.
        expects_attention_mask = "attention_mask" in self.input_names
        if expects_attention_mask and attention_mask is None:
            raise ValueError(
                "The ONNX encoder expects an `attention_mask` input, but `attention_mask` is None."
            )

        if self.parent_model.device.type == "cuda" and self.parent_model.use_io_binding:
            model_inputs = [flattened_patches, attention_mask] if expects_attention_mask else [flattened_patches]
            io_binding, output_shapes, output_buffers = self.parent_model._prepare_io_binding(
                self.session,
                *model_inputs,
                ordered_input_names=self._ordered_input_names,
            )

            io_binding.synchronize_inputs()
            self.session.run_with_iobinding(io_binding)
            io_binding.synchronize_outputs()

            last_hidden_state = output_buffers["last_hidden_state"].view(output_shapes["last_hidden_state"])
        else:
            if use_torch:
                onnx_inputs = {"flattened_patches": flattened_patches.cpu().detach().numpy()}
                if expects_attention_mask:
                    onnx_inputs["attention_mask"] = attention_mask.cpu().detach().numpy()
            else:
                onnx_inputs = {"flattened_patches": flattened_patches}
                if expects_attention_mask:
                    onnx_inputs["attention_mask"] = attention_mask

            if expects_attention_mask:
                # Cast the mask when the graph was exported with an int64 mask. Look the graph input up by
                # name: the mask is not guaranteed to be the second input.
                mask_input_type = next(
                    (
                        graph_input.type
                        for graph_input in self.session.get_inputs()
                        if graph_input.name == "attention_mask"
                    ),
                    None,
                )
                if mask_input_type == "tensor(int64)":
                    onnx_inputs["attention_mask"] = onnx_inputs["attention_mask"].astype(np.int64)

            outputs = self.session.run(None, onnx_inputs)

            last_hidden_state = outputs[self.output_names["last_hidden_state"]]
            if use_torch:
                last_hidden_state = torch.from_numpy(last_hidden_state).to(self.device)

        return BaseModelOutput(last_hidden_state=last_hidden_state)
class ORTModelForConditionalGeneration(ORTModel, ABC):
"""
Sequence-to-sequence model with a language modeling head for ONNX Runtime inference.
Important attributes:
config ([`PretrainedConfig`]):
Instance of the configuration associated to the model. Initializing with a config file does
not load the weights associated with the model, only the configuration.
use_io_binding (`Optional[bool]`, defaults to `None`):
Whether use IOBinding during inference to avoid memory copy between the host and devices. Defaults to `True`
if the device is CUDA, otherwise defaults to `False`.
use_cache (`bool`):
Whether or not past key/values cache should be used. It is determined by whether an InferenceSession for
that was provided or not.
providers (`List[str`]):
The list of execution providers the model is running on.
encoder (`ORTEncoder`):
The encoder model.
decoder (`ORTDecoderForSeq2Seq`):
The decoder model.
decoder_with_past (`Optional[ORTDecoderForSeq2Seq]`):
The decoder model handling the past key/values if `use_cache=True`, else `None`.
Other attributes:
encoder_file_name (`str`, defaults to `optimum.onnxruntime.utils.ONNX_ENCODER_NAME`):
The name of the ONNX file containing the encoder part of the model.
decoder_file_name (`str`, defaults to `optimum.onnxruntime.utils.ONNX_DECODER_NAME`):
The name of the ONNX file containing the decoder part of the model.
decoder_file_with_past_name (`str`, defaults to `optimum.onnxruntime.utils.ONNX_DECODER_WITH_PAST_NAME`):
The name of the ONNX file containing the decoder with past key/values part of the model.
model_save_dir (`str`, defaults to `""`):
The directory under which the model exported to ONNX was saved.
"""
# Used in from_transformers to export model to ONNX
base_model_prefix = "onnx_model"
_supports_cache_class = False
def __init__(
    self,
    encoder_session: ort.InferenceSession,
    decoder_session: ort.InferenceSession,
    config: "PretrainedConfig",
    onnx_paths: List[str],
    decoder_with_past_session: Optional[ort.InferenceSession] = None,
    use_cache: bool = True,
    use_io_binding: Optional[bool] = None,
    model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
    preprocessors: Optional[List] = None,
    generation_config: Optional[GenerationConfig] = None,
    **kwargs,
):
    """
    Args:
        encoder_session (`ort.InferenceSession`):
            The ONNX Runtime inference session associated to the encoder.
        decoder_session (`ort.InferenceSession`):
            The ONNX Runtime inference session associated to the decoder. If it exposes a `use_cache_branch`
            input it is treated as a merged (without past / with past) decoder.
        config ([`PretrainedConfig`]):
            `config` is an instance of the configuration associated to the model. Initializing with a config file
            does not load the weights associated with the model, only the configuration.
        onnx_paths (`List[str]`):
            Path to ONNX files associated with the model.
        decoder_with_past_session (`Optional[ort.InferenceSession]`, *optional*):
            The ONNX Runtime inference session associated to the decoder with past key values. Must be omitted
            when `decoder_session` is a merged decoder or when `use_cache=False`.
        use_cache (`bool`, defaults to `True`):
            Whether or not past key/values cache should be used. When `True`, either `decoder_session` must be a
            merged decoder or `decoder_with_past_session` must be provided.
        use_io_binding (`bool`, *optional*, defaults to `None`):
            Whether use IOBinding during inference to avoid memory copy between the host and devices. Defaults to
            `True` when the decoder session's first execution provider is `CUDAExecutionProvider`, otherwise
            defaults to `False`.
        model_save_dir (`Optional[Union[str, Path, TemporaryDirectory]]`, defaults to `None`):
            The directory under which the model exported to ONNX was saved.
        preprocessors (`Optional[List]`, defaults to `None`):
            The list of the preprocessors (tokenizer, processor, feature_extractor) to save alongside the ORTModel.
        generation_config (`Optional[GenerationConfig]`, defaults to `None`):
            The generation configuration used by default when calling `generate()`. When `None`, it is derived
            from `config` with `GenerationConfig.from_model_config`.
            Refer to https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationMixin.generate.

    Raises:
        ValueError: if unknown keyword arguments are passed; if `use_cache=True` and a merged decoder is combined
            with `decoder_with_past_session`, or a non-merged decoder comes without `decoder_with_past_session`;
            if `use_cache=False` and `decoder_with_past_session` is provided; or if `use_cache=False` is combined
            with `use_io_binding=True`.
    """

    # TODO: remove at version 2.0
    def show_deprecated_argument(arg_name):
        if kwargs.pop(arg_name, None) is not None:
            logger.warning(
                f"The {arg_name} argument to create an {self.__class__.__name__} is deprecated, and not used "
                "anymore."
            )

    show_deprecated_argument("last_encoder_model_name")
    show_deprecated_argument("last_decoder_model_name")
    show_deprecated_argument("last_decoder_with_past_model_name")
    if kwargs:
        raise ValueError(
            f"{self.__class__.__name__} received {', '.join(kwargs.keys())}, but do not accept those arguments."
        )

    ABC.__init__(self)

    if use_io_binding is None:
        # Enable IOBinding by default only when the decoder runs on CUDA.
        use_io_binding = decoder_session.get_providers()[0] == "CUDAExecutionProvider"

    self.shared_attributes_init(
        encoder_session,
        use_io_binding=use_io_binding,
        model_save_dir=model_save_dir,
        preprocessors=preprocessors,
    )
    self.config = config
    self.name_or_path = config.name_or_path

    self.onnx_paths = onnx_paths
    self.use_cache = use_cache

    if use_cache is True:
        # Auto-detect whether the provided session is a merged non-past / with-past or not
        # TODO: make __init__ private and pass `use_merged` as an argument
        use_merged = "use_cache_branch" in [inp.name for inp in decoder_session.get_inputs()]

        if use_merged is True and decoder_with_past_session is not None:
            raise ValueError(
                "Detected a merged decoder, but decoder_with_past_session was provided. "
                "Please only set decoder_session, or provide a non-merged decoder_session."
            )
        if use_merged is False and decoder_with_past_session is None:
            raise ValueError(
                "The parameter use_cache was set as True, but neither decoder_with_past_session was passed"
                " nor a use_cache branch can be found in the decoder_session."
                " Please pass a decoder_with_past_session or set use_cache=False."
            )
    else:
        use_merged = False

        if decoder_with_past_session is not None:
            raise ValueError(
                "The parameter decoder_with_past_session was passed, although use_cache is False. "
                "Please pass use_cache=True for decoder_with_past_session to be used."
            )

    if use_cache is False and use_io_binding is True:
        raise ValueError(
            "When using CUDAExecutionProvider, the parameters combination use_cache=False, use_io_binding=True"
            " is not supported. Please either pass use_cache=True, use_io_binding=True (default),"
            " or use_cache=False, use_io_binding=False."
        )

    self.use_merged = use_merged

    self.encoder = self._initialize_encoder(encoder_session)
    self.encoder_model_path = Path(encoder_session._model_path)
    self.encoder_model_name = self.encoder_model_path.name

    self.decoder = ORTDecoderForSeq2Seq(decoder_session, self)
    self.decoder_model_path = Path(decoder_session._model_path)
    self.decoder_model_name = self.decoder_model_path.name

    # A separate decoder-with-past wrapper is only needed when caching is on and the decoder is not merged;
    # the validation above guarantees decoder_with_past_session is set in that case.
    self.decoder_with_past = None
    self.decoder_with_past_model_path = None
    self.decoder_with_past_model_name = None
    if self.use_cache is True and self.use_merged is False:
        self.decoder_with_past = ORTDecoderForSeq2Seq(decoder_with_past_session, self)
        self.decoder_with_past_model_path = Path(decoder_with_past_session._model_path)
        self.decoder_with_past_model_name = self.decoder_with_past_model_path.name

    if generation_config is None:
        generation_config = GenerationConfig.from_model_config(config)
    self.generation_config = generation_config
@abstractmethod
def _initialize_encoder(self, session: ort.InferenceSession) -> ORTEncoder:
    """
    Wrap the encoder inference session in the modality-specific `ORTEncoder` subclass.

    Concrete models must override this hook; `__init__` stores its return value as `self.encoder`.
    """
@staticmethod
def load_model(
    encoder_path: Union[str, Path],
    decoder_path: Union[str, Path],
    decoder_with_past_path: Optional[Union[str, Path]] = None,
    provider: str = "CPUExecutionProvider",
    session_options: Optional[ort.SessionOptions] = None,
    provider_options: Optional[Dict] = None,
):
    """
    Loads the ONNX Runtime inference sessions backing an encoder-decoder model.

    A session is always created for the encoder and for the decoder. A third session, for the decoder taking past
    key values as inputs, is created only when `decoder_with_past_path` is given. Every session is built through
    `ORTModel.load_model` with the same provider, session options and provider options. The default provider is
    `CPUExecutionProvider` to match the default behaviour in PyTorch/TensorFlow/JAX.

    Args:
        encoder_path (`Union[str, Path]`):
            The path of the encoder ONNX model.
        decoder_path (`Union[str, Path]`):
            The path of the decoder ONNX model.
        decoder_with_past_path (`Optional[Union[str, Path]]`, *optional*):
            The path of the decoder with past key values ONNX model. When `None`, no session is created for it.
        provider (`str`, *optional*, defaults to `"CPUExecutionProvider"`):
            ONNX Runtime provider to use for loading the model. See https://onnxruntime.ai/docs/execution-providers/
            for possible providers.
        session_options (`Optional[ort.SessionOptions]`, *optional*):
            ONNX Runtime session options to use for loading the model. Defaults to `None`.
        provider_options (`Optional[Dict]`, *optional*):
            Provider option dictionary corresponding to the provider used. See available options
            for each provider: https://onnxruntime.ai/docs/api/c/group___global.html . Defaults to `None`.

    Returns:
        `tuple`: `(encoder_session, decoder_session, decoder_with_past_session)`. The last element is `None` when
        `decoder_with_past_path` is `None`.
    """

    def _open_session(onnx_path):
        # All sub-models share one execution provider configuration.
        return ORTModel.load_model(onnx_path, provider, session_options, provider_options)

    # Sessions are created in a fixed order: encoder, decoder, then the optional decoder with past.
    encoder_session = _open_session(encoder_path)
    decoder_session = _open_session(decoder_path)
    decoder_with_past_session = (
        _open_session(decoder_with_past_path) if decoder_with_past_path is not None else None
    )
    return encoder_session, decoder_session, decoder_with_past_session
def _save_pretrained(self, save_directory: Union[str, Path]):
    """
    Saves the model encoder, decoder and decoder with past key values as well as its configuration file to a
    directory, so that it can be re-loaded using the
    [`~optimum.onnxruntime.modeling_seq2seq.ORTModelForSeq2SeqLM.from_pretrained`] class method.

    Every path in `self.onnx_paths` is copied into `save_directory` under its original file name. The external
    data paths added by `_get_external_data_paths` for large models are copied too. The generation config is then
    saved to the same directory.

    Some files may already be the destination files, for example when saving back into the directory the model
    was loaded from. `shutil.copyfile` would raise `shutil.SameFileError` for those, so they are left untouched.

    Args:
        save_directory (`Union[str, Path]`):
            The directory where to save the model files.
    """
    save_directory = Path(save_directory)
    src_paths = [Path(path) for path in self.onnx_paths]
    dst_paths = [save_directory / path.name for path in src_paths]

    # add external data paths in case of large models
    src_paths, dst_paths = _get_external_data_paths(src_paths, dst_paths)

    for src_path, dst_path in zip(src_paths, dst_paths):
        try:
            shutil.copyfile(src_path, dst_path)
        except shutil.SameFileError:
            # Source and destination are the same file: it is already where it has to be saved.
            continue

    self.generation_config.save_pretrained(save_directory)
@classmethod
def _from_pretrained(
cls,
model_id: Union[str, Path],
config: "PretrainedConfig",
use_auth_token: Optional[Union[bool, str]] = None,
token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
force_download: bool = False,
cache_dir: str = HUGGINGFACE_HUB_CACHE,
encoder_file_name: str = ONNX_ENCODER_NAME,
decoder_file_name: str = ONNX_DECODER_NAME,
decoder_with_past_file_name: str = ONNX_DECODER_WITH_PAST_NAME,
subfolder: str = "",
local_files_only: bool = False,
use_cache: bool = True,
use_merged: Optional[bool] = None,
provider: str = "CPUExecutionProvider",
session_options: Optional[ort.SessionOptions] = None,
provider_options: Optional[Dict[str, Any]] = None,
use_io_binding: Optional[bool] = None,
model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
**kwargs,
):
if use_auth_token is not None:
warnings.warn(
"The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.",
FutureWarning,
)
if token is not None:
raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.")
token = use_auth_token
model_path = Path(model_id)
# We do not implement the logic for use_cache=False, use_merged=True
if use_cache is False:
if use_merged is True:
raise ValueError(
"The parameters combination use_cache=False, use_merged=True is not supported."
" To use a merged decoder, past key values must be used."
)
use_merged = False
decoder_merged_path = None
# We use `is not False` here to include two cases: use_merged = None (in which case we auto-detect it),
# and use_merged = True (explicitely specified by the user)
if use_merged is not False:
try:
decoder_merged_path = ORTModelForConditionalGeneration.infer_onnx_filename(
model_id,
[DECODER_MERGED_ONNX_FILE_PATTERN],
argument_name=None,
subfolder=subfolder,
token=token,
revision=revision,
)
use_merged = True
decoder_path = decoder_merged_path
except FileNotFoundError as e:
if use_merged is True:
raise FileNotFoundError(
"The parameter `use_merged=True` was passed to ORTModelForCausalLM.from_pretrained()"
" but no ONNX file for a merged decoder could be found in"
f" {str(Path(model_id, subfolder))}, with the error: {e}"
)
use_merged = False
decoder_without_past_path = None
decoder_with_past_path = None
if use_merged is False:
if not validate_file_exists(model_id, decoder_file_name, subfolder=subfolder, revision=revision):
decoder_without_past_path = ORTModelForConditionalGeneration.infer_onnx_filename(
model_id,
[DECODER_ONNX_FILE_PATTERN],
"decoder_file_name",
subfolder=subfolder,
token=token,
revision=revision,
)
else:
decoder_without_past_path = model_path / subfolder / decoder_file_name
decoder_path = decoder_without_past_path
decoder_regular_onnx_filenames = ORTModelForConditionalGeneration._generate_regular_names_for_filename(
ONNX_DECODER_NAME
)
if decoder_path.name not in decoder_regular_onnx_filenames:
logger.warning(
f"The ONNX file {decoder_path.name} is not a regular name used in optimum.onnxruntime that are {decoder_regular_onnx_filenames}, the "
f"{cls.__name__} might not behave as expected."
)
# If the decoder without / with past has been merged, we do not need to look for any additional file
if use_cache is True and use_merged is False:
if not validate_file_exists(
model_id, decoder_with_past_file_name, subfolder=subfolder, revision=revision
):
try:
decoder_with_past_path = ORTModelForConditionalGeneration.infer_onnx_filename(
model_id,
[DECODER_WITH_PAST_ONNX_FILE_PATTERN],
"decoder_with_past_file_name",
subfolder=subfolder,
token=token,
revision=revision,
)
except FileNotFoundError as e:
raise FileNotFoundError(
"The parameter `use_cache=True` was passed to ORTModelForCausalLM.from_pretrained()"
" but no ONNX file using past key values could be found in"
f" {str(Path(model_id, subfolder))}, with the error: {e}"
)
else:
decoder_with_past_path = model_path / subfolder / decoder_with_past_file_name
decoder_path = decoder_without_past_path
decoder_with_past_regular_onnx_filenames = (
ORTModelForConditionalGeneration._generate_regular_names_for_filename(ONNX_DECODER_WITH_PAST_NAME)
)
if decoder_with_past_path.name not in decoder_with_past_regular_onnx_filenames:
logger.warning(
f"The ONNX file {decoder_with_past_path.name} is not a regular name used in optimum.onnxruntime that are {decoder_with_past_regular_onnx_filenames}, "
f"the {cls.__name__} might not behave as expected."
)
if not validate_file_exists(model_id, encoder_file_name, subfolder=subfolder, revision=revision):
encoder_path = ORTModelForConditionalGeneration.infer_onnx_filename(
model_id,
[ENCODER_ONNX_FILE_PATTERN],
"encoder_file_name",
subfolder=subfolder,
token=token,
revision=revision,
)
else:
encoder_path = model_path / subfolder / encoder_file_name
encoder_regular_onnx_filenames = ORTModelForConditionalGeneration._generate_regular_names_for_filename(
ONNX_ENCODER_NAME
)
if encoder_path.name not in encoder_regular_onnx_filenames:
logger.warning(
f"The ONNX file {encoder_path.name} is not a regular name used in optimum.onnxruntime, the "
"ORTModelForConditionalGeneration might not behave as expected."
)
preprocessors = None
if model_path.is_dir():
new_model_save_dir = model_path
preprocessors = maybe_load_preprocessors(model_id)
else:
attribute_name_to_filename = {
"last_encoder_model_name": encoder_path.name,
"last_decoder_model_name": decoder_path.name if use_merged is False else None,
"last_decoder_with_past_model_name": (
decoder_with_past_path.name if (use_merged is False and use_cache is True) else None
),
"last_decoder_merged_name": decoder_merged_path.name if use_merged is True else None,
}
paths = {}
for attr_name, filename in attribute_name_to_filename.items():
if filename is None:
continue
model_cache_path = hf_hub_download(
repo_id=model_id,
subfolder=subfolder,
filename=filename,
token=token,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
local_files_only=local_files_only,
)
# try download external data
try:
hf_hub_download(
repo_id=model_id,
subfolder=subfolder,
filename=filename + "_data",
token=token,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
local_files_only=local_files_only,
)
except EntryNotFoundError:
# model doesn't use external data
pass
paths[attr_name] = Path(model_cache_path).name
new_model_save_dir = Path(model_cache_path).parent
preprocessors = maybe_load_preprocessors(model_id, subfolder=subfolder)
if use_merged is True:
decoder_path = new_model_save_dir / paths["last_decoder_merged_name"]
decoder_merged_path = new_model_save_dir / paths["last_decoder_merged_name"]
else:
decoder_path = new_model_save_dir / paths["last_decoder_model_name"]
decoder_without_past_path = new_model_save_dir / paths["last_decoder_model_name"]
if use_cache is True:
decoder_with_past_path = new_model_save_dir / paths["last_decoder_with_past_model_name"]
encoder_path = new_model_save_dir / paths["last_encoder_model_name"]
ort_inference_sessions = cls.load_model(
encoder_path=encoder_path,
decoder_path=decoder_path,
decoder_with_past_path=None if use_merged is True or use_cache is False else decoder_with_past_path,
provider=provider,
session_options=session_options,
provider_options=provider_options,
)
if model_save_dir is None:
model_save_dir = new_model_save_dir
generation_config = None
try:
generation_config = GenerationConfig.from_pretrained(
model_id,
cache_dir=cache_dir,