-
Notifications
You must be signed in to change notification settings - Fork 405
/
model_configs.py
2241 lines (1805 loc) · 90.1 KB
/
model_configs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# coding=utf-8
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Model specific ONNX configurations."""
import random
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Union
from packaging import version
from transformers.utils import is_tf_available
from ...onnx import merge_decoders
from ...utils import (
DEFAULT_DUMMY_SHAPES,
BloomDummyPastKeyValuesGenerator,
DummyAudioInputGenerator,
DummyCodegenDecoderTextInputGenerator,
DummyDecoderTextInputGenerator,
DummyEncodecInputGenerator,
DummyInputGenerator,
DummyIntGenerator,
DummyPastKeyValuesGenerator,
DummyPix2StructInputGenerator,
DummyPointsGenerator,
DummySeq2SeqDecoderTextInputGenerator,
DummySeq2SeqPastKeyValuesGenerator,
DummySpeechT5InputGenerator,
DummyTextInputGenerator,
DummyTimestepInputGenerator,
DummyVisionEmbeddingsGenerator,
DummyVisionEncoderDecoderPastKeyValuesGenerator,
DummyVisionInputGenerator,
DummyXPathSeqInputGenerator,
FalconDummyPastKeyValuesGenerator,
GemmaDummyPastKeyValuesGenerator,
GPTBigCodeDummyPastKeyValuesGenerator,
MistralDummyPastKeyValuesGenerator,
NormalizedConfig,
NormalizedEncoderDecoderConfig,
NormalizedSeq2SeqConfig,
NormalizedTextAndVisionConfig,
NormalizedTextConfig,
NormalizedTextConfigWithGQA,
NormalizedVisionConfig,
is_diffusers_available,
logging,
)
from ...utils.normalized_config import NormalizedConfigManager
from .base import ConfigBehavior, OnnxConfig, OnnxConfigWithPast, OnnxSeq2SeqConfigWithPast
from .config import (
AudioOnnxConfig,
AudioToTextOnnxConfig,
EncoderDecoderBaseOnnxConfig,
TextAndVisionOnnxConfig,
TextDecoderOnnxConfig,
TextDecoderWithPositionIdsOnnxConfig,
TextEncoderOnnxConfig,
TextSeq2SeqOnnxConfig,
VisionOnnxConfig,
)
from .constants import ONNX_DECODER_MERGED_NAME, ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME
from .model_patcher import (
FalconModelPatcher,
MusicgenModelPatcher,
SAMModelPatcher,
SentenceTransformersCLIPPatcher,
SentenceTransformersTransformerPatcher,
SpeechT5ModelPatcher,
VisionEncoderDecoderPatcher,
WavLMModelPatcher,
)
if TYPE_CHECKING:
from transformers import PretrainedConfig
from transformers.modeling_utils import PreTrainedModel
from .model_patcher import ModelPatcher
if is_tf_available():
from transformers.modeling_tf_utils import TFPreTrainedModel
if is_diffusers_available():
from diffusers import ModelMixin
logger = logging.get_logger(__name__)
class BertOnnxConfig(TextEncoderOnnxConfig):
    """ONNX export configuration for BERT-style text encoders.

    Exposes `input_ids`, `attention_mask` and `token_type_ids` as inputs; an
    extra `num_choices` axis is declared when the task is "multiple-choice".
    """

    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
    ATOL_FOR_VALIDATION = 1e-4
    # now uses F.scaled_dot_product_attention by default for torch>=2.1.1.
    DEFAULT_ONNX_OPSET = 14

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        """Return the dynamic axes (axis index -> axis name) of every model input."""
        axes = (
            {0: "batch_size", 1: "num_choices", 2: "sequence_length"}
            if self.task == "multiple-choice"
            else {0: "batch_size", 1: "sequence_length"}
        )
        # All inputs deliberately share the same axes mapping object.
        return dict.fromkeys(("input_ids", "attention_mask", "token_type_ids"), axes)
class AlbertOnnxConfig(BertOnnxConfig):
    """ONNX configuration for ALBERT: same inputs as `BertOnnxConfig`, default opset 11."""

    DEFAULT_ONNX_OPSET = 11
class ConvBertOnnxConfig(BertOnnxConfig):
    """ONNX configuration for ConvBERT: same inputs as `BertOnnxConfig`, default opset 11."""

    DEFAULT_ONNX_OPSET = 11
class ElectraOnnxConfig(BertOnnxConfig):
    """ONNX configuration for ELECTRA: same inputs as `BertOnnxConfig`, default opset 11."""

    DEFAULT_ONNX_OPSET = 11
class RoFormerOnnxConfig(BertOnnxConfig):
    """ONNX configuration for RoFormer: same inputs as `BertOnnxConfig`, default opset 11."""

    DEFAULT_ONNX_OPSET = 11
class SqueezeBertOnnxConfig(BertOnnxConfig):
    """ONNX configuration for SqueezeBERT: same inputs as `BertOnnxConfig`, default opset 11."""

    DEFAULT_ONNX_OPSET = 11
class MobileBertOnnxConfig(BertOnnxConfig):
    """ONNX configuration for MobileBERT: same inputs as `BertOnnxConfig`, default opset 11."""

    DEFAULT_ONNX_OPSET = 11
class NystromformerOnnxConfig(BertOnnxConfig):
    """ONNX configuration for Nystromformer: same inputs as `BertOnnxConfig`, default opset 11."""

    DEFAULT_ONNX_OPSET = 11
class XLMOnnxConfig(BertOnnxConfig):
    """ONNX configuration for XLM: same inputs as `BertOnnxConfig`, default opset 11."""

    DEFAULT_ONNX_OPSET = 11
class SplinterOnnxConfig(BertOnnxConfig):
    """ONNX configuration for Splinter: same inputs as `BertOnnxConfig`, default opset 11."""

    DEFAULT_ONNX_OPSET = 11
class DistilBertOnnxConfig(BertOnnxConfig):
    """ONNX configuration for DistilBERT-style encoders.

    Unlike `BertOnnxConfig`, no `token_type_ids` input is declared.
    """

    DEFAULT_ONNX_OPSET = 11

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        """Return the dynamic axes of `input_ids` and `attention_mask`."""
        axes = (
            {0: "batch_size", 1: "num_choices", 2: "sequence_length"}
            if self.task == "multiple-choice"
            else {0: "batch_size", 1: "sequence_length"}
        )
        # Both inputs share the same axes mapping object.
        return dict.fromkeys(("input_ids", "attention_mask"), axes)
class MPNetOnnxConfig(DistilBertOnnxConfig):
    """ONNX configuration for MPNet: same inputs as `DistilBertOnnxConfig`, default opset 12."""

    # For lower opsets, results in: Type 'tensor(int64)' of input parameter (/0/auto_model/encoder/Add_1_output_0) of operator (Min) in node (/0/auto_model/encoder/Min) is invalid.
    DEFAULT_ONNX_OPSET = 12
class RobertaOnnxConfig(DistilBertOnnxConfig):
    """ONNX configuration for RoBERTa; identical to `DistilBertOnnxConfig`."""
class CamembertOnnxConfig(DistilBertOnnxConfig):
    """ONNX configuration for CamemBERT; identical to `DistilBertOnnxConfig`."""
class FlaubertOnnxConfig(BertOnnxConfig):
    """ONNX configuration for FlauBERT: same inputs as `BertOnnxConfig`, default opset 11."""

    DEFAULT_ONNX_OPSET = 11
class IBertOnnxConfig(DistilBertOnnxConfig):
    """ONNX configuration for I-BERT; identical to `DistilBertOnnxConfig`."""
class XLMRobertaOnnxConfig(DistilBertOnnxConfig):
    """ONNX configuration for XLM-RoBERTa; identical to `DistilBertOnnxConfig`."""
class DebertaOnnxConfig(BertOnnxConfig):
    """ONNX configuration for DeBERTa.

    Reuses the BERT inputs but drops `token_type_ids` when the model config has
    `type_vocab_size == 0`.
    """

    DEFAULT_ONNX_OPSET = 12

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        """Return the BERT inputs, minus `token_type_ids` when the model has no token types."""
        axes_by_input = super().inputs
        if self._config.type_vocab_size == 0:
            del axes_by_input["token_type_ids"]
        return axes_by_input
class MarkupLMOnnxConfig(BertOnnxConfig):
    """ONNX configuration for MarkupLM.

    Adds the `xpath_subs_seq` and `xpath_tags_seq` inputs, which carry an extra
    `max_depth` axis, to the three BERT text inputs.
    """

    DEFAULT_ONNX_OPSET = 11
    DUMMY_INPUT_GENERATOR_CLASSES = (
        DummyTextInputGenerator,
        DummyXPathSeqInputGenerator,
    )

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        """Return the dynamic axes of the text inputs and of the two xpath inputs."""
        text_axes = {0: "batch_size", 1: "sequence_length"}
        xpath_axes = {0: "batch_size", 1: "sequence_length", 2: "max_depth"}
        # Inputs of the same kind share one axes mapping object.
        axes_by_input = dict.fromkeys(("input_ids", "attention_mask", "token_type_ids"), text_axes)
        axes_by_input.update(dict.fromkeys(("xpath_subs_seq", "xpath_tags_seq"), xpath_axes))
        return axes_by_input
class DebertaV2OnnxConfig(DebertaOnnxConfig):
    """ONNX configuration for DeBERTa-v2; identical to `DebertaOnnxConfig`."""
class EsmOnnxConfig(TextEncoderOnnxConfig):
    """ONNX configuration for ESM encoders taking `input_ids` and `attention_mask`."""

    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
    ATOL_FOR_VALIDATION = 1e-4
    DEFAULT_ONNX_OPSET = 12

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        """Return the dynamic axes of `input_ids` and `attention_mask`."""
        axes = {0: "batch_size", 1: "sequence_length"}
        # Both inputs share the same axes mapping object.
        return dict.fromkeys(("input_ids", "attention_mask"), axes)
class GPT2OnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
    """ONNX configuration for GPT-2 decoders.

    The normalized config reads the layer and head counts from the `n_layer`
    and `n_head` attributes of the model config.
    """

    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_layers="n_layer", num_attention_heads="n_head")
    DEFAULT_ONNX_OPSET = 13
class GPTJOnnxConfig(GPT2OnnxConfig):
    """ONNX configuration for GPT-J; identical to `GPT2OnnxConfig`."""
class CodeGenOnnxConfig(GPT2OnnxConfig):
    """ONNX configuration for CodeGen; identical to `GPT2OnnxConfig`."""
class ImageGPTOnnxConfig(GPT2OnnxConfig):
    """ONNX configuration for ImageGPT; identical to `GPT2OnnxConfig`."""
class GPTNeoOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
    """ONNX configuration for GPT-Neo; the head count is read from `num_heads`."""

    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_attention_heads="num_heads")
    DEFAULT_ONNX_OPSET = 13
class GPTNeoXOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
    """ONNX configuration for GPT-NeoX decoders (default opset 13)."""

    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
    DEFAULT_ONNX_OPSET = 13
class OPTOnnxConfig(TextDecoderOnnxConfig):
    """ONNX configuration for OPT decoders (default opset 13)."""

    # OPT does not require position_ids input, hence the plain TextDecoderOnnxConfig base.
    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
    DEFAULT_ONNX_OPSET = 13
class LlamaOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
    """ONNX configuration for Llama decoders.

    Dummy past key/values are produced by `MistralDummyPastKeyValuesGenerator`.
    """

    # Llama now uses F.scaled_dot_product_attention by default for torch>=2.1.1.
    DEFAULT_ONNX_OPSET = 14
    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
    DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
    DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, MistralDummyPastKeyValuesGenerator)
class Qwen2OnnxConfig(LlamaOnnxConfig):
    """ONNX configuration for Qwen2; identical to `LlamaOnnxConfig`."""
class GemmaOnnxConfig(LlamaOnnxConfig):
    """ONNX configuration for Gemma.

    Same as `LlamaOnnxConfig`, except that dummy past key/values come from
    `GemmaDummyPastKeyValuesGenerator`.
    """

    DUMMY_PKV_GENERATOR_CLASS = GemmaDummyPastKeyValuesGenerator
    DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, GemmaDummyPastKeyValuesGenerator)
class PhiOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
    """ONNX configuration for Phi decoders (default opset 14)."""

    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
    # Phi now uses F.scaled_dot_product_attention by default for torch>=2.1.1.
    DEFAULT_ONNX_OPSET = 14
class Phi3OnnxConfig(PhiOnnxConfig):
    """ONNX configuration for Phi-3.

    Uses `NormalizedTextConfigWithGQA` and puts `MistralDummyPastKeyValuesGenerator`
    ahead of the default text-decoder dummy input generators.
    """

    NORMALIZED_CONFIG_CLASS = NormalizedTextConfigWithGQA
    DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
    DUMMY_INPUT_GENERATOR_CLASSES = (
        MistralDummyPastKeyValuesGenerator,
        *TextDecoderOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES,
    )
class MistralOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
    """ONNX configuration for Mistral decoders.

    Puts `MistralDummyPastKeyValuesGenerator` ahead of the default text-decoder
    dummy input generators and exposes `num_key_value_heads` on the normalized config.
    """

    # This is because of the patching of torch.triu in AttentionMaskConverter, that exists from transformers>=4.35
    MIN_TRANSFORMERS_VERSION = version.parse("4.34.99")

    # The ONNX export of this architecture needs the Trilu operator support, available since opset 14
    DEFAULT_ONNX_OPSET = 14

    DUMMY_INPUT_GENERATOR_CLASSES = (
        MistralDummyPastKeyValuesGenerator,
        *TextDecoderOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES,
    )
    DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_key_value_heads="num_key_value_heads", allow_new=True)
class MPTOnnxConfig(TextDecoderOnnxConfig):
    """ONNX configuration for MPT decoders.

    The normalized config reads `n_heads`, `d_model` and `n_layers` from the
    model config.
    """

    # MPT does not require position_ids input, hence the plain TextDecoderOnnxConfig base.
    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(
        num_attention_heads="n_heads",
        hidden_size="d_model",
        num_layers="n_layers",
    )
    DEFAULT_ONNX_OPSET = 13
class BloomOnnxConfig(TextDecoderOnnxConfig):
    """ONNX configuration for BLOOM decoders.

    The past key/values use a fused "batch_size x num_heads" leading axis, and
    the sequence axis sits at a different index for keys (2) and values (1).
    """

    # Bloom does not require position_ids input.
    DUMMY_INPUT_GENERATOR_CLASSES = (
        BloomDummyPastKeyValuesGenerator,
        *TextDecoderOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES,
    )
    DUMMY_PKV_GENERATOR_CLASS = BloomDummyPastKeyValuesGenerator
    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_layers="n_layer", num_attention_heads="n_head")

    def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str):
        """Register the dynamic axes of the per-layer key/value cache, in place.

        Args:
            inputs_or_outputs: Mapping of tensor name to dynamic axes; it is
                extended with one `.key` and one `.value` entry per layer.
            direction: "inputs" (entries named `past_key_values.*`) or
                "outputs" (entries named `present.*`).

        Raises:
            ValueError: If `direction` is neither "inputs" nor "outputs".
        """
        if direction not in ["inputs", "outputs"]:
            raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given')

        prefix, sequence_axis_name = (
            ("past_key_values", "past_sequence_length")
            if direction == "inputs"
            else ("present", "past_sequence_length + 1")
        )

        for layer_idx in range(self._normalized_config.num_layers):
            # Keys carry the sequence on axis 2, values on axis 1.
            inputs_or_outputs[f"{prefix}.{layer_idx}.key"] = {
                0: "batch_size x num_heads",
                2: sequence_axis_name,
            }
            inputs_or_outputs[f"{prefix}.{layer_idx}.value"] = {
                0: "batch_size x num_heads",
                1: sequence_axis_name,
            }
class GPTBigCodeOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
    """ONNX configuration for GPT BigCode decoders.

    Each layer exposes a single fused `key_value` cache tensor instead of a
    separate key and value.
    """

    DUMMY_INPUT_GENERATOR_CLASSES = (
        GPTBigCodeDummyPastKeyValuesGenerator,
        *TextDecoderOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES,
    )
    # GPT BigCode now uses F.scaled_dot_product_attention by default for torch>=2.1.1.
    DEFAULT_ONNX_OPSET = 14
    DUMMY_PKV_GENERATOR_CLASS = GPTBigCodeDummyPastKeyValuesGenerator
    NORMALIZED_CONFIG_CLASS = NormalizedConfigManager.get_normalized_config_class("gpt_bigcode")

    def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str):
        """Register the dynamic axes of the per-layer fused key/value cache, in place.

        Args:
            inputs_or_outputs: Mapping of tensor name to dynamic axes; it is
                extended with one `.key_value` entry per layer.
            direction: "inputs" (entries named `past_key_values.*`) or
                "outputs" (entries named `present.*`).

        Raises:
            ValueError: If `direction` is neither "inputs" nor "outputs".
        """
        if direction not in ["inputs", "outputs"]:
            raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given')

        prefix, sequence_axis_name = (
            ("past_key_values", "past_sequence_length")
            if direction == "inputs"
            else ("present", "past_sequence_length + 1")
        )

        for layer_idx in range(self._normalized_config.num_layers):
            # No dim for `n_head` when using multi-query attention
            inputs_or_outputs[f"{prefix}.{layer_idx}.key_value"] = {
                0: "batch_size",
                1: sequence_axis_name,
            }

    def flatten_past_key_values(self, flattened_output, name, idx, t):
        """Store the fused cache tensor `t` of layer `idx` under `{name}.{idx}.key_value`."""
        flattened_output[f"{name}.{idx}.key_value"] = t
class FalconOnnxConfig(TextDecoderOnnxConfig):
    """ONNX configuration for Falcon decoders.

    Overrides the normalized `num_kv_heads`, adds a `position_ids` input when
    the model does not use alibi, and exports through `FalconModelPatcher`.
    """

    # This is due to the cache refactoring for Falcon in 4.36
    MIN_TRANSFORMERS_VERSION = version.parse("4.35.99")
    DUMMY_INPUT_GENERATOR_CLASSES = (
        FalconDummyPastKeyValuesGenerator,
        *TextDecoderOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES,
    )
    # Falcon uses aten::triu that requires opset>=14, and F.scaled_dot_product_attention
    DEFAULT_ONNX_OPSET = 14
    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
    DUMMY_PKV_GENERATOR_CLASS = FalconDummyPastKeyValuesGenerator

    def __init__(
        self,
        config: "PretrainedConfig",
        task: str = "feature-extraction",
        int_dtype: str = "int64",
        float_dtype: str = "fp32",
        use_past: bool = False,
        use_past_in_inputs: bool = False,
        preprocessors: Optional[List[Any]] = None,
        legacy: bool = False,
    ):
        """Forward all arguments to the parent constructor, then fix up `num_kv_heads`."""
        super().__init__(
            config=config,
            task=task,
            int_dtype=int_dtype,
            float_dtype=float_dtype,
            use_past=use_past,
            use_past_in_inputs=use_past_in_inputs,
            preprocessors=preprocessors,
            legacy=legacy,
        )
        # For some reason Falcon config.num_kv_heads can not be trusted, see in Transformers:
        # https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/models/falcon/modeling_falcon.py#L337
        normalized = self._normalized_config
        trust_config_value = normalized.new_decoder_architecture or not normalized.multi_query
        normalized.num_kv_heads = normalized.num_kv_heads if trust_config_value else 1

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        """Return the parent inputs, plus `position_ids` for non-legacy, non-alibi exports."""
        axes_by_input = super().inputs

        needs_position_ids = not (self.legacy or self._config.alibi) and self.task in [
            "text-generation",
            "feature-extraction",
        ]
        if needs_position_ids:
            # When alibi is used, position_ids are not used in Falcon.
            # Reference: https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/models/falcon/modeling_falcon.py#L1116
            axes_by_input["position_ids"] = {0: "batch_size", 1: "sequence_length"}

        return axes_by_input

    # NOTE(review): the original comment here said output_attentions=True is set to avoid
    # torch.nn.functional.scaled_dot_product_attention, which conflicts with the opset note
    # above; what the patcher actually does is not visible from this class - confirm.
    def patch_model_for_export(
        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
    ) -> "ModelPatcher":
        """Return the `FalconModelPatcher` wrapping `model` for the export."""
        return FalconModelPatcher(self, model, model_kwargs=model_kwargs)
class T5DummySeq2SeqPastKeyValuesGenerator(DummySeq2SeqPastKeyValuesGenerator):
    """Dummy past key/values generator that sizes the last axis with `key_value_dim`."""

    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
        """Build one 4-tuple of random float tensors per decoder layer.

        Each tuple holds two tensors of the decoder shape followed by two
        tensors of the encoder shape. `input_name` and `int_dtype` are not used.
        """
        config = self.normalized_config
        encoder_shape = (
            self.batch_size,
            config.encoder_num_attention_heads,
            self.encoder_sequence_length,
            config.key_value_dim,
        )
        decoder_shape = (
            self.batch_size,
            config.decoder_num_attention_heads,
            self.sequence_length,
            config.key_value_dim,
        )
        # Order matters: decoder tensors first, encoder tensors last.
        shapes_per_layer = (decoder_shape, decoder_shape, encoder_shape, encoder_shape)

        past_key_values = []
        for _ in range(config.decoder_num_layers):
            layer_tensors = tuple(
                self.random_float_tensor(shape, framework=framework, dtype=float_dtype) for shape in shapes_per_layer
            )
            past_key_values.append(layer_tensors)
        return past_key_values
class T5OnnxConfig(TextSeq2SeqOnnxConfig):
    """ONNX configuration for T5 encoder-decoder models.

    Swaps the last default seq2seq dummy input generator for
    `T5DummySeq2SeqPastKeyValuesGenerator` and maps the T5 config attribute
    names onto the normalized ones.
    """

    DEFAULT_ONNX_OPSET = 13
    DUMMY_INPUT_GENERATOR_CLASSES = (
        *TextSeq2SeqOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES[:-1],
        T5DummySeq2SeqPastKeyValuesGenerator,
    )
    DUMMY_PKV_GENERATOR_CLASS = T5DummySeq2SeqPastKeyValuesGenerator
    NORMALIZED_CONFIG_CLASS = NormalizedSeq2SeqConfig.with_args(
        hidden_size="d_model",
        num_attention_heads="num_heads",
        encoder_num_layers="num_layers",
        decoder_num_layers="num_decoder_layers",
        key_value_dim="d_kv",
        allow_new=True,
    )

    def generate_dummy_inputs_for_validation(
        self, reference_model_inputs: Dict[str, Any], onnx_input_names: Optional[List[str]] = None
    ) -> Dict[str, Any]:
        """Rename the decoder reference inputs, then defer to the parent implementation.

        Args:
            reference_model_inputs: Inputs of the reference model; modified in place.
            onnx_input_names: Input names of the exported ONNX model, when known.

        Returns:
            Whatever the parent `generate_dummy_inputs_for_validation` returns.
        """
        if self._behavior is ConfigBehavior.DECODER:
            reference_model_inputs["input_ids"] = reference_model_inputs.pop("decoder_input_ids")

            if onnx_input_names is None:
                # TODO: remove this branch in optimum 2.0 and make onnx_input_names a required argument
                # T5 requires encoder_hidden_states as an input for both the without/with past models,
                # which is different than other architectures that require it only for the without past case
                reference_model_inputs["encoder_hidden_states"] = reference_model_inputs.pop("encoder_outputs")[0]
            elif "encoder_outputs" in reference_model_inputs:
                # The encoder outputs are always removed; they are only forwarded as
                # `encoder_hidden_states` when the ONNX model declares that input.
                encoder_outputs = reference_model_inputs.pop("encoder_outputs")
                if "encoder_hidden_states" in onnx_input_names:
                    reference_model_inputs["encoder_hidden_states"] = encoder_outputs[0]

        return super().generate_dummy_inputs_for_validation(reference_model_inputs)
class MT5OnnxConfig(T5OnnxConfig):
    """ONNX configuration for mT5: `T5OnnxConfig` with a validation tolerance of 1e-4."""

    ATOL_FOR_VALIDATION = 1e-4
class LongT5OnnxConfig(T5OnnxConfig):
    """ONNX configuration for LongT5: `T5OnnxConfig` with default opset 14."""

    DEFAULT_ONNX_OPSET = 14
class BartDummyTextInputGenerator(DummyTextInputGenerator):
    """Dummy text input generator that can force an EOS token into each sequence.

    For the "text-classification" task, every generated `input_ids` row that
    does not already contain `eos_token_id` gets it written at a random position.
    """

    def __init__(
        self,
        task: str,
        normalized_config: NormalizedSeq2SeqConfig,
        batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
        sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
        num_choices: int = DEFAULT_DUMMY_SHAPES["num_choices"],
        random_batch_size_range: Optional[Tuple[int, int]] = None,
        random_sequence_length_range: Optional[Tuple[int, int]] = None,
        random_num_choices_range: Optional[Tuple[int, int]] = None,
        force_eos_token_id_presence: bool = True,
        **kwargs,
    ):
        """Initialise the parent generator and remember the EOS settings.

        Args:
            force_eos_token_id_presence: Whether `generate` must make sure each
                row contains the EOS token.
            **kwargs: Accepted for signature compatibility and ignored.

        The remaining arguments are forwarded unchanged to the parent constructor.
        """
        super().__init__(
            task=task,
            normalized_config=normalized_config,
            batch_size=batch_size,
            sequence_length=sequence_length,
            num_choices=num_choices,
            random_batch_size_range=random_batch_size_range,
            random_sequence_length_range=random_sequence_length_range,
            random_num_choices_range=random_num_choices_range,
        )
        self.force_eos_token_id_presence = force_eos_token_id_presence
        self.eos_token_id = normalized_config.eos_token_id

    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
        """Generate the dummy tensor for `input_name`, inserting EOS tokens when required.

        Returns:
            The tensor produced by the parent generator, modified in place for
            `input_ids` inputs of the "text-classification" task.
        """
        int_tensor = super().generate(input_name, framework=framework, int_dtype=int_dtype, float_dtype=float_dtype)

        # This inserts EOS_TOKEN_ID at random locations along the sequence length dimension.
        if self.force_eos_token_id_presence and "input_ids" in input_name and self.task == "text-classification":
            for idx in range(self.batch_size):
                if self.eos_token_id in int_tensor[idx]:
                    continue
                if self.sequence_length > 1:
                    # Position 0 is never picked when there is more than one token.
                    random_idx = random.randint(1, self.sequence_length - 1)
                else:
                    # random.randint(1, 0) would raise ValueError (empty range): with a
                    # single token, the only available slot is index 0.
                    random_idx = 0
                int_tensor[idx][random_idx] = self.eos_token_id

        return int_tensor
class M2M100OnnxConfig(TextSeq2SeqOnnxConfig):
    """ONNX configuration for M2M100-style encoder-decoder models.

    The inputs, outputs and dummy input generators depend on the task:
    "feature-extraction" and "text2text-generation" use the seq2seq layout of
    the parent class, "text-generation" uses a decoder-only (causal LM) layout,
    and every other task only takes `input_ids` and `attention_mask`.
    """

    NORMALIZED_CONFIG_CLASS = NormalizedSeq2SeqConfig.with_args(
        encoder_num_layers="encoder_layers",
        decoder_num_layers="decoder_layers",
        num_layers="decoder_layers",  # Used for the text-generation task past key values input generation.
        encoder_num_attention_heads="encoder_attention_heads",
        decoder_num_attention_heads="decoder_attention_heads",
        eos_token_id="eos_token_id",
    )
    # Entry 0 is used for every task; entries 1 and 2 are selected per task.
    DUMMY_INPUT_GENERATOR_CLASSES = (
        BartDummyTextInputGenerator,
        {
            "feature-extraction": DummySeq2SeqDecoderTextInputGenerator,
            "text-generation": DummyDecoderTextInputGenerator,
        },
        {
            "feature-extraction": DummySeq2SeqPastKeyValuesGenerator,
            "text-generation": DummyPastKeyValuesGenerator,
        },
    )

    def _create_dummy_input_generator_classes(self, **kwargs) -> List["DummyInputGenerator"]:
        """Instantiate the text, decoder-text and past-key-values dummy generators for the task."""
        dummy_text_input_generator = self.DUMMY_INPUT_GENERATOR_CLASSES[0](
            self.task, self._normalized_config, **kwargs
        )

        # Every task other than "text-generation" uses the "feature-extraction" generators.
        task = "feature-extraction" if self.task != "text-generation" else "text-generation"
        dummy_decoder_text_input_generator = self.DUMMY_INPUT_GENERATOR_CLASSES[1][task](
            self.task, self._normalized_config, **kwargs
        )

        if self.task != "text-generation":
            # Reuse the sequence length chosen by the text generator as the encoder length.
            kwargs["encoder_sequence_length"] = dummy_text_input_generator.sequence_length

        dummy_seq2seq_past_key_values_generator = self.DUMMY_INPUT_GENERATOR_CLASSES[2][task](
            self.task, self._normalized_config, **kwargs
        )

        return [
            dummy_text_input_generator,
            dummy_decoder_text_input_generator,
            dummy_seq2seq_past_key_values_generator,
        ]

    @property
    def inputs_for_default_and_seq2seq_lm(self):
        """Inputs of the seq2seq layout, as defined by the parent class."""
        return super().inputs

    @property
    def inputs_for_causal_lm(self):
        """Inputs of the decoder-only layout, with per-layer past key/values when they are inputs."""
        if not self.use_past_in_inputs:
            return {
                "input_ids": {0: "batch_size", 1: "sequence_length"},
                "attention_mask": {0: "batch_size", 1: "sequence_length"},
            }

        common_inputs = {
            "input_ids": {0: "batch_size"},
            "attention_mask": {0: "batch_size", 1: "past_sequence_length + 1"},
        }
        for i in range(self._normalized_config.decoder_num_layers):
            common_inputs[f"past_key_values.{i}.key"] = {
                0: "batch_size",
                2: "past_sequence_length",
            }
            common_inputs[f"past_key_values.{i}.value"] = {
                0: "batch_size",
                2: "past_sequence_length",
            }
        return common_inputs

    @property
    def inputs_for_other_tasks(self):
        """Inputs used by every task without a dedicated layout."""
        return {
            "input_ids": {0: "batch_size", 1: "sequence_length"},
            "attention_mask": {0: "batch_size", 1: "sequence_length"},
        }

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        """Return the dynamic axes of the inputs for the current task.

        Only the layout matching the task is computed.
        """
        if self.task in ("feature-extraction", "text2text-generation"):
            return self.inputs_for_default_and_seq2seq_lm
        if self.task == "text-generation":
            return self.inputs_for_causal_lm
        return self.inputs_for_other_tasks

    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        """Return the dynamic axes of the outputs for the current task."""
        if self.task in ["feature-extraction", "text2text-generation"]:
            return super().outputs

        # Resolve `outputs` past OnnxConfigWithPast in the MRO, i.e. without the
        # with-past/seq2seq output handling; the cache outputs are added below.
        common_outputs = super(OnnxConfigWithPast, self).outputs
        if self.use_past:
            # When exporting decoder models with use_cache=True, both the decoder without past and with past have the KV cache as an output.
            # For text-generation the layer count must match the decoder layers used
            # by `inputs_for_causal_lm`; other tasks keep the encoder count.
            num_cache_layers = (
                self._normalized_config.decoder_num_layers
                if self.task == "text-generation"
                else self._normalized_config.encoder_num_layers
            )
            for i in range(num_cache_layers):
                common_outputs[f"present.{i}.key"] = {0: "batch_size", 2: "past_sequence_length + sequence_length"}
                common_outputs[f"present.{i}.value"] = {
                    0: "batch_size",
                    2: "past_sequence_length + sequence_length",
                }
        return common_outputs

    def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
        """Generate dummy inputs, padding the attention mask to the past for text-generation.

        `PAD_ATTENTION_MASK_TO_PAST` is set back to False afterwards, even when
        the parent call raises.
        """
        # This will handle the attention mask padding when Bart is used for text-generation.
        if self.task == "text-generation":
            self.PAD_ATTENTION_MASK_TO_PAST = True

        try:
            return super().generate_dummy_inputs(framework=framework, **kwargs)
        finally:
            # Setting it back to the default version.
            self.PAD_ATTENTION_MASK_TO_PAST = False

    def flatten_past_key_values(self, flattened_output, name, idx, t):
        """Flatten the cache entry `t` of layer `idx` into `flattened_output`, per task layout."""
        if self.task in ["feature-extraction", "text2text-generation"]:
            super().flatten_past_key_values(flattened_output, name, idx, t)
        else:
            # Resolve the method past OnnxSeq2SeqConfigWithPast in the MRO for the decoder-only layout.
            super(OnnxSeq2SeqConfigWithPast, self).flatten_past_key_values(flattened_output, name, idx, t)
class BartOnnxConfig(M2M100OnnxConfig):
    """ONNX configuration for BART: `M2M100OnnxConfig` with opset 14 and torch>=2.1.2."""

    # Bart now uses F.scaled_dot_product_attention by default for torch>=2.1.1.
    DEFAULT_ONNX_OPSET = 14
    MIN_TORCH_VERSION = version.parse("2.1.2")
class MBartOnnxConfig(BartOnnxConfig):
    """ONNX configuration for mBART; identical to `BartOnnxConfig`."""
class BlenderbotOnnxConfig(BartOnnxConfig):
    """ONNX configuration for Blenderbot; identical to `BartOnnxConfig`."""
class BlenderbotSmallOnnxConfig(BartOnnxConfig):
    """ONNX configuration for BlenderbotSmall; identical to `BartOnnxConfig`."""
# big_bird and bigbird_pegasus are unsupported for now as block sparse attention is written in pure python and numpy in transformers.
# Thus, the case attention_type == "block_sparse" is unusable.
# Even when rewriting this part in pure PyTorch, torch.onnx.export is then prohibitively slow.
# References: https://github.com/pytorch/pytorch/issues/63734 & https://github.com/pytorch/pytorch/issues/94821
"""
class BigBirdOnnxConfig(DistilBertOnnxConfig):
pass
class BigBirdPegasusOnnxConfig(BartOnnxConfig):
def generate_dummy_inputs_for_validation(self, reference_model_inputs: Dict[str, Any]) -> Dict[str, Any]:
if self._behavior is ConfigBehavior.ENCODER:
# TODO: check why the attention mask is not present in the exported model
reference_model_inputs.pop("attention_mask")
return super().generate_dummy_inputs_for_validation(reference_model_inputs)
"""
class PegasusOnnxConfig(BartOnnxConfig):
    """ONNX configuration for Pegasus; identical to `BartOnnxConfig`."""
class MarianOnnxConfig(BartOnnxConfig):
    """ONNX configuration for Marian; identical to `BartOnnxConfig`."""
class ViTOnnxConfig(VisionOnnxConfig):
    """ONNX configuration for ViT-style vision models taking a single `pixel_values` input."""

    NORMALIZED_CONFIG_CLASS = NormalizedVisionConfig
    MIN_TORCH_VERSION = version.parse("1.11")
    # now uses F.scaled_dot_product_attention by default for torch>=2.1.1.
    DEFAULT_ONNX_OPSET = 14

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        """Return the dynamic axes of `pixel_values` (all four axes are dynamic)."""
        pixel_axes = {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"}
        return {"pixel_values": pixel_axes}

    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        """Return the parent outputs; for feature extraction only the batch axis of
        `last_hidden_state` is declared dynamic."""
        axes_by_output = super().outputs
        if self.task == "feature-extraction":
            axes_by_output["last_hidden_state"] = {0: "batch_size"}
        return axes_by_output
class CvTOnnxConfig(ViTOnnxConfig):
    """ONNX configuration for CvT: `ViTOnnxConfig` with opset 13 and a 1e-2 validation tolerance."""

    ATOL_FOR_VALIDATION = 1e-2
    DEFAULT_ONNX_OPSET = 13
class LevitOnnxConfig(ViTOnnxConfig):
    """ONNX configuration for LeViT: `ViTOnnxConfig` with default opset 11."""

    DEFAULT_ONNX_OPSET = 11
class DeiTOnnxConfig(ViTOnnxConfig):
    """ONNX configuration for DeiT: `ViTOnnxConfig` with an explicit default opset of 14."""

    # now uses F.scaled_dot_product_attention by default for torch>=2.1.1.
    DEFAULT_ONNX_OPSET = 14
class BeitOnnxConfig(ViTOnnxConfig):
    """ONNX configuration for BEiT: `ViTOnnxConfig` with default opset 11."""

    DEFAULT_ONNX_OPSET = 11
class ConvNextOnnxConfig(ViTOnnxConfig):
    """ONNX configuration for ConvNeXT: `ViTOnnxConfig` with default opset 11."""

    DEFAULT_ONNX_OPSET = 11
class ConvNextV2OnnxConfig(ViTOnnxConfig):
    """ONNX configuration for ConvNeXT V2: `ViTOnnxConfig` with default opset 11."""

    DEFAULT_ONNX_OPSET = 11
class MobileViTOnnxConfig(ViTOnnxConfig):
    """ONNX configuration for MobileViT: `ViTOnnxConfig` with opset 11 and a 1e-4 validation tolerance."""

    DEFAULT_ONNX_OPSET = 11
    ATOL_FOR_VALIDATION = 1e-4
class RegNetOnnxConfig(ViTOnnxConfig):
    """ONNX configuration for RegNet: `ViTOnnxConfig` with default opset 11."""

    # This config has the same inputs as ViTOnnxConfig
    DEFAULT_ONNX_OPSET = 11
class ResNetOnnxConfig(ViTOnnxConfig):
    """ONNX configuration for ResNet: `ViTOnnxConfig` with opset 11 and a 1e-3 validation tolerance."""

    DEFAULT_ONNX_OPSET = 11
    ATOL_FOR_VALIDATION = 1e-3
class DetrOnnxConfig(ViTOnnxConfig):
    """ONNX configuration for DETR, with dedicated outputs for image segmentation."""

    DEFAULT_ONNX_OPSET = 12

    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        """Return `logits`/`pred_masks` axes for "image-segmentation", the parent outputs otherwise."""
        if self.task != "image-segmentation":
            return super().outputs
        return {
            "logits": {0: "batch_size", 1: "num_queries"},
            "pred_masks": {0: "batch_size", 1: "num_queries"},
        }
class TableTransformerOnnxConfig(DetrOnnxConfig):
    """ONNX configuration for Table Transformer; identical to `DetrOnnxConfig`."""
class YolosOnnxConfig(ViTOnnxConfig):
    """ONNX configuration for YOLOS: `ViTOnnxConfig` with an explicit default opset of 14."""

    # now uses F.scaled_dot_product_attention by default for torch>=2.1.1.
    DEFAULT_ONNX_OPSET = 14
class SwinOnnxConfig(ViTOnnxConfig):
    """ONNX configuration for Swin: `ViTOnnxConfig` with default opset 11."""

    DEFAULT_ONNX_OPSET = 11
class Swin2srOnnxConfig(SwinOnnxConfig):
    """ONNX configuration for Swin2SR; identical to `SwinOnnxConfig`."""
class DptOnnxConfig(ViTOnnxConfig):
    """ONNX configuration for DPT: `ViTOnnxConfig` with default opset 11."""

    DEFAULT_ONNX_OPSET = 11
class GlpnOnnxConfig(ViTOnnxConfig):
    """ONNX configuration for GLPN: `ViTOnnxConfig` with default opset 11."""

    DEFAULT_ONNX_OPSET = 11
class PoolFormerOnnxConfig(ViTOnnxConfig):
    """ONNX configuration for PoolFormer: `ViTOnnxConfig` with opset 11 and a 2e-3 validation tolerance."""

    DEFAULT_ONNX_OPSET = 11
    ATOL_FOR_VALIDATION = 2e-3
    NORMALIZED_CONFIG_CLASS = NormalizedVisionConfig
class SegformerOnnxConfig(YolosOnnxConfig):
    """ONNX configuration for SegFormer; identical to `YolosOnnxConfig`."""
class MobileNetV1OnnxConfig(ViTOnnxConfig):
    """ONNX configuration for MobileNetV1; only the batch axis of `pixel_values` is dynamic."""

    DEFAULT_ONNX_OPSET = 11
    ATOL_FOR_VALIDATION = 1e-4

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        """Return the dynamic axes of `pixel_values` (batch axis only)."""
        batch_only_axes = {0: "batch_size"}
        return {"pixel_values": batch_only_axes}
class MobileNetV2OnnxConfig(MobileNetV1OnnxConfig):
    """ONNX configuration for MobileNetV2; identical to `MobileNetV1OnnxConfig`."""
class DonutSwinOnnxConfig(ViTOnnxConfig):
    """ONNX configuration for DonutSwin: `ViTOnnxConfig` with default opset 11."""

    DEFAULT_ONNX_OPSET = 11
class TimmDefaultOnnxConfig(ViTOnnxConfig):
    """ONNX configuration for timm models, whose forward signature names the image input `x`."""

    DEFAULT_ONNX_OPSET = 12
    ATOL_FOR_VALIDATION = 1e-3

    def rename_ambiguous_inputs(self, inputs):
        """Return a new mapping that passes `inputs["pixel_values"]` under the name `x`."""
        # The input name in the model signature is `x`, hence the export input name is updated.
        return {"x": inputs["pixel_values"]}

    @property
    def torch_to_onnx_input_map(self) -> Dict[str, str]:
        """Map the torch input name `x` to the ONNX input name `pixel_values`."""
        return {"x": "pixel_values"}
class SentenceTransformersTransformerOnnxConfig(TextEncoderOnnxConfig):
    """ONNX configuration for sentence-transformers Transformer models.

    Exposes `token_embeddings` and `sentence_embedding` as outputs and exports
    through `SentenceTransformersTransformerPatcher`.
    """

    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
    # Some bottleneck transformers models require a specific ONNX opset to be successfully exported.
    # We put a rather high opset here for the export to work for all architectures.
    DEFAULT_ONNX_OPSET = 14

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        """Return the dynamic axes of `input_ids` and `attention_mask`."""
        return {
            input_name: {0: "batch_size", 1: "sequence_length"} for input_name in ("input_ids", "attention_mask")
        }

    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        """Return the dynamic axes of the token-level and sentence-level embeddings."""
        token_axes = {0: "batch_size", 1: "sequence_length"}
        sentence_axes = {0: "batch_size"}
        return {"token_embeddings": token_axes, "sentence_embedding": sentence_axes}

    # The model is wrapped in a dedicated patcher for the duration of the export.
    def patch_model_for_export(
        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
    ) -> "ModelPatcher":
        """Return the `SentenceTransformersTransformerPatcher` wrapping `model` for the export."""
        return SentenceTransformersTransformerPatcher(self, model, model_kwargs=model_kwargs)
class CLIPNormalizedConfig(NormalizedTextAndVisionConfig):
    """Normalized config for CLIP, naming the text and vision sub-config attributes."""

    VISION_CONFIG = "vision_config"
    TEXT_CONFIG = "text_config"
class CLIPOnnxConfig(TextAndVisionOnnxConfig):
    """ONNX configuration for CLIP, with independent text and image batch axes."""

    NORMALIZED_CONFIG_CLASS = CLIPNormalizedConfig
    DEFAULT_ONNX_OPSET = 14

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        """Return the dynamic axes of the text and image inputs."""
        axes_by_input = {}
        axes_by_input["input_ids"] = {0: "text_batch_size", 1: "sequence_length"}
        axes_by_input["pixel_values"] = {0: "image_batch_size", 1: "num_channels", 2: "height", 3: "width"}
        axes_by_input["attention_mask"] = {0: "text_batch_size", 1: "sequence_length"}
        return axes_by_input

    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        """Return the dynamic axes of the similarity logits and of both embeddings."""
        axes_by_output = {}
        axes_by_output["logits_per_image"] = {0: "image_batch_size", 1: "text_batch_size"}
        axes_by_output["logits_per_text"] = {0: "text_batch_size", 1: "image_batch_size"}
        axes_by_output["text_embeds"] = {0: "text_batch_size"}
        axes_by_output["image_embeds"] = {0: "image_batch_size"}
        return axes_by_output
class SentenceTransformersCLIPOnnxConfig(CLIPOnnxConfig):
    """ONNX configuration for sentence-transformers CLIP models.

    Only the text and image embeddings are exposed as outputs, and the export
    goes through `SentenceTransformersCLIPPatcher`.
    """

    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        """Return the dynamic axes of `text_embeds` and `image_embeds`."""
        axes_by_output = {}
        axes_by_output["text_embeds"] = {0: "text_batch_size"}
        axes_by_output["image_embeds"] = {0: "image_batch_size"}
        return axes_by_output

    def patch_model_for_export(
        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
    ) -> "ModelPatcher":
        """Return the `SentenceTransformersCLIPPatcher` wrapping `model` for the export."""
        return SentenceTransformersCLIPPatcher(self, model, model_kwargs=model_kwargs)
class CLIPTextWithProjectionOnnxConfig(TextEncoderOnnxConfig):
    """ONNX configuration for the CLIP text encoder with projection.

    Outputs `text_embeds` and `last_hidden_state`, plus one `hidden_states.{i}`
    entry per layer (and one extra) when the config outputs hidden states.
    """

    ATOL_FOR_VALIDATION = 1e-3
    # The ONNX export of this architecture needs the Trilu operator support, available since opset 14
    DEFAULT_ONNX_OPSET = 14

    NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args(
        vocab_size="vocab_size",
        sequence_length="max_position_embeddings",
        num_layers="num_hidden_layers",
        allow_new=True,
    )

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        """Return the dynamic axes of the single `input_ids` input."""
        return {"input_ids": {0: "batch_size", 1: "sequence_length"}}

    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        """Return the dynamic axes of the outputs, including optional hidden states."""
        axes_by_output = {
            "text_embeds": {0: "batch_size", 1: "sequence_length"},
            "last_hidden_state": {0: "batch_size", 1: "sequence_length"},
        }
        if self._normalized_config.output_hidden_states:
            # One entry per layer plus one (indices 0..num_layers).
            axes_by_output.update(
                {
                    f"hidden_states.{i}": {0: "batch_size", 1: "sequence_length"}
                    for i in range(self._normalized_config.num_layers + 1)
                }
            )
        return axes_by_output
class CLIPTextOnnxConfig(CLIPTextWithProjectionOnnxConfig):
    """ONNX configuration for the CLIP text encoder without projection.

    Outputs `last_hidden_state` and `pooler_output`, and feeds int32
    `input_ids` when dummy inputs are generated for PyTorch.
    """

    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        """Return the dynamic axes of the outputs, including optional hidden states."""
        axes_by_output = {
            "last_hidden_state": {0: "batch_size", 1: "sequence_length"},
            "pooler_output": {0: "batch_size"},
        }
        if self._normalized_config.output_hidden_states:
            # One entry per layer plus one (indices 0..num_layers).
            axes_by_output.update(
                {
                    f"hidden_states.{i}": {0: "batch_size", 1: "sequence_length"}
                    for i in range(self._normalized_config.num_layers + 1)
                }
            )
        return axes_by_output

    def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
        """Generate the parent dummy inputs, casting `input_ids` to int32 for PyTorch."""
        dummy_inputs = super().generate_dummy_inputs(framework=framework, **kwargs)
        if framework != "pt":
            return dummy_inputs

        # Imported lazily so that torch is only needed for the PyTorch framework.
        import torch

        dummy_inputs["input_ids"] = dummy_inputs["input_ids"].to(dtype=torch.int32)
        return dummy_inputs
class UNetOnnxConfig(VisionOnnxConfig):
ATOL_FOR_VALIDATION = 1e-3
# The ONNX export of a CLIPText architecture, an other Stable Diffusion component, needs the Trilu
# operator support, available since opset 14
DEFAULT_ONNX_OPSET = 14
NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args(
image_size="sample_size",
num_channels="in_channels",
hidden_size="cross_attention_dim",
vocab_size="norm_num_groups",
allow_new=True,
)
DUMMY_INPUT_GENERATOR_CLASSES = (
DummyVisionInputGenerator,
DummyTimestepInputGenerator,
DummySeq2SeqDecoderTextInputGenerator,
)
@property
def inputs(self) -> Dict[str, Dict[int, str]]: