/
convert_generation.py
3124 lines (2635 loc) · 122 KB
/
convert_generation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# -------------------------------------------------------------------------
"""
This converts a GPT-2, T5 or MT5 model to ONNX with a beam search, greedy search or sampling operator.
Example 1: convert gpt2 model with beam search:
python convert_generation.py -m gpt2 --output gpt2_beam_search.onnx
Example 2: convert gpt2 model with beam search containing specific cuda optimizations:
python convert_generation.py -m gpt2 --output gpt2_beam_search.onnx --use_gpu \
--past_present_share_buffer --use_decoder_masked_attention
Example 3: convert gpt2 model with beam search with mixed precision and enable SkipLayerNorm strict mode:
python convert_generation.py -m gpt2 --output gpt2_beam_search.onnx --use_gpu -p fp16 --use_sln_strict_mode
Example 4: convert T5 model with beam search in two steps:
cd ./models/t5
python convert_to_onnx.py -m t5-small
cd ../..
python convert_generation.py -m t5-small --model_type t5 \
--decoder_onnx ./models/t5/onnx_models/t5-small_decoder.onnx \
--encoder_decoder_init_onnx ./models/t5/onnx_models/t5-small_encoder_decoder_init.onnx \
--output ./models/t5/onnx_models/t5_small_beam_search.onnx
Example 5: convert T5 model with beam search. All in one step:
python convert_generation.py -m t5-small --model_type t5 --output ./models/t5/onnx_models/t5_small_beam_search.onnx
Example 6: convert T5 model with beam search containing specific cuda optimizations. All in one step:
python convert_generation.py -m t5-small --model_type t5 --output ./models/t5/onnx_models/t5_small_beam_search.onnx \
--use_gpu --past_present_share_buffer --use_decoder_masked_attention
Example 7: convert MT5 model with external data file like mt5-base-beamsearch.onnx.data in below example.
python convert_generation.py -m google/mt5-base --model_type mt5 --output mt5-base-beamsearch.onnx -e
Example 8: convert gpt2 model with greedy search:
python convert_generation.py -m gpt2 --output gpt2_greedy_search.onnx --num_beams 1 --num_return_sequences 1
Example 9: convert gpt2 model with sampling:
python convert_generation.py -m gpt2 --output gpt2_sampling.onnx --num_beams 1 --num_return_sequences 1 --top_p 0.6
"""
import argparse
import logging
import math
import os
import time
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import numpy as np
import onnx
import torch
from benchmark_helper import Precision, setup_logger
from fusion_utils import NumpyHelper
from onnx import GraphProto, ModelProto, TensorProto
from onnx_model import OnnxModel
from transformers import (
GPT2Config,
GPT2LMHeadModel,
GPT2Tokenizer,
MT5Config,
MT5ForConditionalGeneration,
T5Config,
T5ForConditionalGeneration,
T5Tokenizer,
)
from onnxruntime import GraphOptimizationLevel, InferenceSession, SessionOptions, get_available_providers
from onnxruntime.transformers.models.gpt2.convert_to_onnx import main as convert_gpt2_to_onnx
from onnxruntime.transformers.models.gpt2.gpt2_helper import PRETRAINED_GPT2_MODELS
from onnxruntime.transformers.models.t5.convert_to_onnx import export_onnx_models as export_t5_onnx_models
from onnxruntime.transformers.models.t5.t5_helper import PRETRAINED_MT5_MODELS, PRETRAINED_T5_MODELS
# Module-wide logger. Calling getLogger() without a name yields the root logger,
# i.e. the very same object that getLogger("") returns.
logger = logging.getLogger()
class GenerationType(Enum):
    """Decoding strategy of the generated ONNX model."""

    BEAMSEARCH = "beam_search"
    GREEDYSEARCH = "greedy_search"
    SAMPLING = "sampling"

    def __str__(self):
        # Show the member as its plain string value (e.g. "beam_search").
        text = self.value
        return text
def parse_arguments(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse command line arguments.

    Args:
        argv (Optional[List[str]], optional): Arguments to parse, without the program name.
            Defaults to None, in which case argparse reads ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: Parsed arguments.
    """
    parser = argparse.ArgumentParser()

    input_group = parser.add_argument_group("Input options")

    input_group.add_argument(
        "-m",
        "--model_name_or_path",
        required=True,
        type=str,
        help="Pytorch model checkpoint path, or pretrained model name in the list: "
        + ", ".join(PRETRAINED_GPT2_MODELS + PRETRAINED_T5_MODELS + PRETRAINED_MT5_MODELS),
    )

    input_group.add_argument(
        "--model_type",
        required=False,
        type=str,
        default="gpt2",
        choices=["gpt2", "t5", "mt5"],
        help="Model type (default is gpt2) in the list: " + ", ".join(["gpt2", "t5", "mt5"]),
    )

    input_group.add_argument(
        "--cache_dir",
        required=False,
        type=str,
        default=os.path.join(".", "cache_models"),
        help="Directory to cache pre-trained models",
    )

    input_group.add_argument(
        "--decoder_onnx",
        required=False,
        type=str,
        default="",
        help="Path of onnx model for decoder. Specify it when you have exported the model.",
    )

    input_group.add_argument(
        "--encoder_decoder_init_onnx",
        required=False,
        type=str,
        default="",
        help="Path of ONNX model for encoder and decoder initialization. Specify it when you have exported the model.",
    )

    parser.add_argument(
        "--verbose",
        required=False,
        action="store_true",
        help="Print more information",
    )
    parser.set_defaults(verbose=False)

    output_group = parser.add_argument_group("Output options")

    output_group.add_argument(
        "--output",
        required=True,
        type=str,
        help="Output path for onnx model with beam search.",
    )

    output_group.add_argument(
        "-p",
        "--precision",
        required=False,
        type=Precision,
        default=Precision.FLOAT32,
        choices=[Precision.FLOAT32, Precision.FLOAT16],
        help="Precision of model to run. fp32 for full precision, fp16 for half or mixed precision",
    )

    output_group.add_argument(
        "-b",
        "--op_block_list",
        required=False,
        nargs="*",
        default=["auto"],
        # The two adjacent literals are concatenated; the trailing space keeps "default value" apart.
        help="Disable certain onnx operators when exporting model to onnx format. When using default "
        'value for gpt2 type of model fp16 precision, it will be set to ["Add", "LayerNormalization",'
        ' "SkipLayerNormalization", "FastGelu"]. Other situation, it will be set to []',
    )

    output_group.add_argument(
        "-e",
        "--use_external_data_format",
        required=False,
        action="store_true",
        help="save external data for model > 2G",
    )
    output_group.set_defaults(use_external_data_format=False)

    output_group.add_argument(
        "-s", "--run_shape_inference", required=False, action="store_true", help="run shape inference"
    )
    output_group.set_defaults(run_shape_inference=False)

    output_group.add_argument(
        "-dpvs",
        "--disable_pad_vocab_size",
        required=False,
        action="store_true",
        help="Do not pad logits MatMul weight to be a multiple of 8 along the dimension where dim value is"
        " the vocab size. The logits MatMul may hence be of poor performance for fp16 precision.",
    )
    output_group.set_defaults(disable_pad_vocab_size=False)

    output_group.add_argument(
        "-dsgd",
        "--disable_separate_gpt2_decoder_for_init_run",
        required=False,
        action="store_true",
        help="Do not create separate decoder subgraphs for initial and remaining runs. This does not allow "
        "for optimizations based on sequence lengths in each subgraph",
    )
    output_group.set_defaults(disable_separate_gpt2_decoder_for_init_run=False)

    output_group.add_argument(
        "-i",
        "--disable_shared_initializers",
        required=False,
        action="store_true",
        help="do not share initializers in encoder and decoder for T5 or in the init decoder and decoder for "
        "GPT2. It will increase memory usage of t5/mt5/gpt2 models.",
    )
    output_group.set_defaults(disable_shared_initializers=False)

    model_group = parser.add_argument_group("Beam search parameters that stored in the output model")

    model_group.add_argument(
        "--output_sequences_scores",
        required=False,
        action="store_true",
        help="output sequences scores",
    )
    model_group.set_defaults(output_sequences_scores=False)

    model_group.add_argument(
        "--output_token_scores",
        required=False,
        action="store_true",
        help="output token scores",
    )
    model_group.set_defaults(output_token_scores=False)

    model_group.add_argument("--early_stopping", required=False, action="store_true")
    model_group.set_defaults(early_stopping=False)

    model_group.add_argument(
        "--no_repeat_ngram_size",
        type=int,
        required=False,
        default=0,
        help="No repeat ngram size",
    )

    model_group.add_argument(
        "--vocab_mask",
        required=False,
        action="store_true",
        help="Enable vocab_mask. This mask applies only to every generated token to filter some bad words.",
    )
    model_group.set_defaults(vocab_mask=False)

    model_group.add_argument(
        "--past_present_share_buffer",
        required=False,
        action="store_true",
        help="Use shared buffer for past and present, currently work for gpt2 greedy/sampling search.",
    )
    model_group.set_defaults(past_present_share_buffer=False)

    model_group.add_argument(
        "--use_decoder_masked_attention",
        required=False,
        action="store_true",
        help="Uses `DecoderMaskedSelfAttention` or `DecoderMaskedMultiHeadAttention` to optimize the decoding Attention computation. "
        "Must be used with `past_present_share_buffer`. Currently, only Attention head sizes of 32, 64 and 128 are supported.",
    )
    model_group.set_defaults(use_decoder_masked_attention=False)

    model_group.add_argument(
        "--prefix_vocab_mask",
        required=False,
        action="store_true",
        help="Enable prefix_vocab_mask. This mask can be used to filter bad words in the first generated token only",
    )
    model_group.set_defaults(prefix_vocab_mask=False)

    model_group.add_argument(
        "--custom_attention_mask",
        required=False,
        action="store_true",
        help="Enable custom_attention_mask. This mask can be used to replace default encoder attention mask",
    )
    model_group.set_defaults(custom_attention_mask=False)

    model_group.add_argument(
        "--presence_mask",
        required=False,
        action="store_true",
        help="Presence mask for custom sampling",
    )
    model_group.set_defaults(presence_mask=False)

    model_group.add_argument(
        "--seed",
        required=False,
        action="store_true",
        help="Random seed for sampling op",
    )
    model_group.set_defaults(seed=False)

    beam_parameters_group = parser.add_argument_group(
        "Beam search parameters not stored in the output model, for testing parity and performance"
    )

    beam_parameters_group.add_argument("--min_length", type=int, required=False, default=1, help="Min sequence length")

    beam_parameters_group.add_argument("--max_length", type=int, required=False, default=50, help="Max sequence length")

    beam_parameters_group.add_argument("--num_beams", type=int, required=False, default=4, help="Beam size")

    beam_parameters_group.add_argument(
        "--num_return_sequences",
        type=int,
        required=False,
        default=1,
        help="Number of return sequence <= num_beams",
    )

    # argparse applies `type` only to string defaults, so float defaults are spelled as
    # float literals to keep args.length_penalty / args.repetition_penalty consistently float.
    beam_parameters_group.add_argument(
        "--length_penalty",
        type=float,
        required=False,
        default=1.0,
        help="Positive. >1 to penalize and <1 to encourage short sentence.",
    )

    beam_parameters_group.add_argument(
        "--repetition_penalty",
        type=float,
        required=False,
        default=1.0,
        help="Positive. >1 to penalize and <1 to encourage.",
    )

    beam_parameters_group.add_argument(
        "--temperature",
        type=float,
        required=False,
        default=1.0,
        help="The value used to modulate the next token probabilities.",
    )

    beam_parameters_group.add_argument(
        "--top_p",
        type=float,
        required=False,
        default=1.0,
        help="Top P for sampling",
    )

    beam_parameters_group.add_argument(
        "--filter_value",
        type=float,
        required=False,
        default=-float("Inf"),
        help="Filter value for Top P sampling",
    )

    beam_parameters_group.add_argument(
        "--min_tokens_to_keep",
        type=int,
        required=False,
        default=1,
        help="Minimum number of tokens we keep per batch example in the output.",
    )

    beam_parameters_group.add_argument(
        "--presence_penalty",
        type=float,
        required=False,
        default=0.0,
        help="presence penalty for custom sampling.",
    )

    beam_parameters_group.add_argument(
        "--custom",
        type=int,
        required=False,
        default=0,
        help="If 1 customized top P logic is applied",
    )

    beam_parameters_group.add_argument(
        "--vocab_size",
        type=int,
        required=False,
        default=-1,
        help="Vocab_size of the underlying model used to decide the shape of vocab mask",
    )

    beam_parameters_group.add_argument(
        "--eos_token_id",
        type=int,
        required=False,
        default=-1,
        help="custom eos_token_id for generating model with existing onnx encoder/decoder",
    )

    beam_parameters_group.add_argument(
        "--pad_token_id",
        type=int,
        required=False,
        default=-1,
        help="custom pad_token_id for generating model with existing onnx encoder/decoder",
    )

    test_group = parser.add_argument_group("Other options for testing parity and performance")

    test_group.add_argument(
        "--use_sln_strict_mode",
        required=False,
        action="store_true",
        help="Enable strict mode for SLN in CUDA provider. This ensures a better accuracy but will be slower.",
    )
    test_group.set_defaults(use_sln_strict_mode=False)

    test_group.add_argument(
        "--use_gpu", required=False, action="store_true", help="use GPU for inference. Required for fp16."
    )
    test_group.set_defaults(use_gpu=False)

    test_group.add_argument(
        "--disable_parity",
        required=False,
        action="store_true",
        help="do not run parity test",
    )
    test_group.set_defaults(disable_parity=False)

    test_group.add_argument(
        "--disable_perf_test",
        required=False,
        action="store_true",
        help="do not run perf test",
    )
    test_group.set_defaults(disable_perf_test=False)

    test_group.add_argument(
        "--torch_performance",
        required=False,
        action="store_true",
        help="test PyTorch performance",
    )
    test_group.set_defaults(torch_performance=False)

    test_group.add_argument(
        "--total_runs",
        required=False,
        type=int,
        default=1,
        help="Number of times of inference for latency measurement",
    )

    test_group.add_argument(
        "--save_test_data",
        required=False,
        action="store_true",
        help="save test data for onnxruntime_perf_test tool",
    )
    test_group.set_defaults(save_test_data=False)

    args = parser.parse_args(argv)

    return args
def gpt2_to_onnx(args: argparse.Namespace):
    """Export the GPT-2 decoder to ONNX by delegating to the GPT-2 ``convert_to_onnx`` entry point.

    The exported model is written to ``args.decoder_onnx``. Optimization is always requested,
    with one test run of 10 test cases, and an existing output file is overwritten.

    Args:
        args (argparse.Namespace): arguments parsed from command line

    Raises:
        AssertionError: fp16 precision was requested without ``--use_gpu``.
    """
    precision_name = "fp32" if args.precision == Precision.FLOAT32 else "fp16"
    cli_args = [
        "--model_name_or_path", args.model_name_or_path,
        "--output", args.decoder_onnx,
        "--optimize_onnx",
        "--precision", precision_name,
        "--test_runs", "1",
        "--test_cases", "10",
        "--overwrite",  # Overwrite onnx file if existed
    ]

    if args.cache_dir:
        cli_args += ["--cache_dir", args.cache_dir]

    # Boolean switches are forwarded verbatim, in a fixed order, when they are set.
    for enabled, flag in (
        (args.use_gpu, "--use_gpu"),
        (args.use_external_data_format, "--use_external_data_format"),
    ):
        if enabled:
            cli_args.append(flag)

    if len(args.op_block_list) > 0:
        cli_args += ["--op_block_list", *args.op_block_list]

    if args.precision == Precision.FLOAT16:
        assert args.use_gpu, "fp16 or mixed precision model cannot run in CPU. Please add --use_gpu"
        # TODO(tianleiwu): Use auto mixed precision for fp16 conversion: arguments.append('--auto_mixed_precision')
        #       Need change cuda kernel to support a combination of fp32 logits and fp16 past state.
        #       Currently logits and past state shall be same data type.

    if args.verbose:
        logger.info(f"arguments for convert_to_onnx:{cli_args}")

    convert_gpt2_to_onnx(argv=cli_args)
def t5_to_onnx(args: argparse.Namespace):
    """Export a T5/MT5 model to ONNX and record the produced model paths on ``args``.

    The models are written next to ``args.output`` (its parent directory). Encoder and
    decoder-init are merged into one model. ONNX optimization is enabled only when the
    requested precision is not fp16.

    Side effects:
        Sets ``args.encoder_decoder_init_onnx`` and ``args.decoder_onnx`` to the
        first and second exported paths respectively.

    Args:
        args (argparse.Namespace): arguments parsed from command line
    """
    output_dir = Path(args.output).parent
    exported = export_t5_onnx_models(
        args.model_name_or_path,
        args.cache_dir,
        output_dir,
        model_type=args.model_type,
        precision=args.precision,
        use_gpu=args.use_gpu,
        use_external_data_format=args.use_external_data_format,
        optimize_onnx=(args.precision != Precision.FLOAT16),
        merge_encoder_and_decoder_init=True,
        use_decoder_start_token=False,
        use_int32_inputs=True,
        disable_auto_mixed_precision=False,
        overwrite=True,
        verbose=False,
    )
    encoder_decoder_init_path = exported[0]
    decoder_path = exported[1]

    logger.debug(f"onnx model for encoder: {encoder_decoder_init_path}")
    logger.debug(f"onnx model for decoder: {decoder_path}")

    args.encoder_decoder_init_onnx = encoder_decoder_init_path
    args.decoder_onnx = decoder_path
def shape_inference(onnx_path: str, use_external_data_format: bool = True):
    """Run symbolic shape inference on an ONNX file and overwrite the file with the result.

    If inference returns nothing, a warning is logged and the file is left untouched.

    Args:
        onnx_path (str): Path of onnx model (read, then overwritten on success).
        use_external_data_format (bool): output tensors to external data or not.
    """
    # Symbolic shape inference is used to work around an ORT shape inference issue for subgraphs.
    from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference

    source_model = onnx.load_model(onnx_path, load_external_data=True)
    inferred_model = SymbolicShapeInference.infer_shapes(source_model, auto_merge=True, guess_output_rank=False)

    if not inferred_model:
        logger.warning("Failed to run symbolic shape inference on the model.")
        return

    OnnxModel.save(inferred_model, onnx_path, save_as_external_data=use_external_data_format)
def pad_weights_of_logits_matmul(onnx_path: str, use_external_data_format: bool = True) -> bool:
    """Pad the logits MatMul weight in the provided decoder model, which will be overwritten.

    The node producing the first graph output must be a MatMul whose weight (input 1) is
    either an fp16 2-D initializer, or an fp16 2-D initializer fed through a Transpose.
    The weight is zero-padded along its vocab axis up to the next multiple of 8.

    Args:
        onnx_path (str): Path of onnx model
        use_external_data_format (bool): output tensors to external data or not.

    Returns:
        bool: True if the vocab dimension already was a multiple of 8, or if it was padded
        and the model was saved. False if the graph does not match the expected pattern
        (nothing is written in that case).
    """
    decoder_model_proto = onnx.load_model(onnx_path, load_external_data=True)

    logits_output_name = decoder_model_proto.graph.output[0].name

    decoder_model = OnnxModel(decoder_model_proto)

    output_name_to_node = decoder_model.output_name_to_node()
    assert logits_output_name in output_name_to_node

    matmul_node = output_name_to_node[logits_output_name]
    # Sanity check - the logits need to be produced by a MatMul node
    if matmul_node.op_type != "MatMul":
        return False

    # The logits MatMul weight MUST be an initializer (or) it MUST be flowing through
    # a Transpose whose input is an initializer.
    # - Direct initializer: it is MatMul's B operand, so the vocab (output) dim is axis 1.
    # - Through Transpose: the stored initializer has the axes swapped, so vocab is axis 0
    #   (assumes the Transpose uses the default/swap perm for a 2-D tensor).
    vocab_axis = 1
    logits_weight = decoder_model.get_initializer(matmul_node.input[1])
    if logits_weight is None:
        transpose_before_matmul = decoder_model.match_parent(matmul_node, "Transpose", 1)
        if transpose_before_matmul is None:
            return False

        logits_weight = decoder_model.get_initializer(transpose_before_matmul.input[0])
        if logits_weight is None:
            return False

        vocab_axis = 0

    # The logits MatMul weight MUST be fp16
    if logits_weight.data_type != TensorProto.DataType.FLOAT16:
        return False

    # The logits MatMul weight MUST be 2-dimensional
    if len(logits_weight.dims) != 2:
        return False

    # Read the vocab size from the same axis that will be padded. (Previously dims[1] was
    # used unconditionally, which is the hidden size in the Transpose case.)
    actual_vocab_size = logits_weight.dims[vocab_axis]
    if actual_vocab_size % 8 == 0:
        # Already "padded"
        return True

    # TODO(hasesh): Handle cases where the fp16 data is stored in the
    # non-raw data field
    if not logits_weight.raw_data:
        return False

    padded_vocab_size = (actual_vocab_size + 7) // 8 * 8

    # Zero block matching the weight on the non-vocab axis and filling the gap on the vocab axis.
    padding_shape = list(logits_weight.dims)
    padding_shape[vocab_axis] = padded_vocab_size - actual_vocab_size
    padding_data = np.zeros(padding_shape, dtype=np.float16)

    # Convert to numpy BEFORE updating dims, so the raw bytes are read with the original shape.
    weight_with_padding = np.concatenate(
        (NumpyHelper.to_array(logits_weight), padding_data), axis=vocab_axis
    )
    logits_weight.dims[vocab_axis] = padded_vocab_size
    logits_weight.raw_data = weight_with_padding.tobytes()

    # Save the model
    OnnxModel.save(decoder_model_proto, onnx_path, save_as_external_data=use_external_data_format)
    return True
def create_ort_session(model_path: str, use_gpu: bool, use_sln_strict_mode: bool) -> InferenceSession:
    """Create an OnnxRuntime session with all graph optimizations disabled.

    Args:
        model_path (str): onnx model path
        use_gpu (bool): put CUDAExecutionProvider ahead of CPUExecutionProvider when True
        use_sln_strict_mode (bool): enable skip layer norm strict mode on the CUDA provider
            (only takes effect together with ``use_gpu``)

    Raises:
        RuntimeError: CUDAExecutionProvider is not available when --use_gpu is specified.

    Returns:
        onnxruntime.InferenceSession: The created session.
    """
    sess_options = SessionOptions()
    sess_options.graph_optimization_level = GraphOptimizationLevel.ORT_DISABLE_ALL

    if not use_gpu:
        providers = ["CPUExecutionProvider"]
    else:
        if "CUDAExecutionProvider" not in get_available_providers():
            raise RuntimeError("CUDAExecutionProvider is not available for --use_gpu!")
        logger.info("use CUDAExecutionProvider")

        # A (name, options) tuple attaches provider options; a bare name uses defaults.
        cuda_provider = "CUDAExecutionProvider"
        if use_sln_strict_mode:
            cuda_provider = ("CUDAExecutionProvider", {"enable_skip_layer_norm_strict_mode": True})
        providers = [cuda_provider, "CPUExecutionProvider"]

    return InferenceSession(model_path, sess_options, providers=providers)
def verify_gpt2_subgraph(graph: onnx.GraphProto, precision: Precision):
    """Verify GPT-2 subgraph

    Expected inputs: ``input_ids``, ``position_ids``, ``attention_mask`` (int32), followed by
    ``past_{i}`` per layer (float16 or float32 depending on ``precision``).
    Expected outputs: ``logits`` followed by ``present_{i}`` per layer, all of the float type.

    Args:
        graph (onnx.GraphProto): onnx graph of GPT-2
        precision (Precision): Precision (FLOAT16 or FLOAT32) of the model.

    Raises:
        AssertionError: The graph has fewer than 4 inputs (no past state).
        ValueError: Number of inputs not expected.
        ValueError: Input name is not expected.
        ValueError: Input data type is not expected.
        ValueError: Number of outputs not expected.
        ValueError: Output name is not expected.
        ValueError: Output data type is not expected.
    """
    is_float16 = precision == Precision.FLOAT16
    float_type = TensorProto.FLOAT16 if is_float16 else TensorProto.FLOAT

    input_count = len(graph.input)
    layer_count = input_count - 3
    assert layer_count >= 1

    expected_inputs = ["input_ids", "position_ids", "attention_mask"] + [f"past_{i}" for i in range(layer_count)]
    if len(graph.input) != len(expected_inputs):
        raise ValueError(f"Number of inputs expected to be {len(expected_inputs)}. Got {len(graph.input)}")

    for i, expected_input in enumerate(expected_inputs):
        if graph.input[i].name != expected_input:
            raise ValueError(f"Input {i} is expected to be {expected_input}. Got {graph.input[i].name}")

        # The first three inputs are int32 ids/masks; the rest are past states.
        expected_type = TensorProto.INT32 if i < 3 else float_type
        input_type = graph.input[i].type.tensor_type.elem_type
        if input_type != expected_type:
            raise ValueError(f"Input {i} is expected to have onnx data type {expected_type}. Got {input_type}")
    logger.info("Verifying GPT-2 graph inputs: name and data type are good.")

    expected_outputs = ["logits"] + [f"present_{i}" for i in range(layer_count)]
    if len(graph.output) != len(expected_outputs):
        raise ValueError(f"Number of outputs expected to be {len(expected_outputs)}. Got {len(graph.output)}")

    for i, expected_output in enumerate(expected_outputs):
        if graph.output[i].name != expected_output:
            raise ValueError(f"Output {i} is expected to be {expected_output}. Got {graph.output[i].name}")

        output_type = graph.output[i].type.tensor_type.elem_type
        if output_type != float_type:
            # This check is on an output, so the message must say "Output" (previously "Input").
            raise ValueError(f"Output {i} is expected to have onnx data type {float_type}. Got {output_type}")
    logger.info("Verifying GPT-2 graph outputs: name and data type are good.")

    # TODO(tianleiwu): verify shapes of inputs and outputs.
    return
def verify_t5_decoder_subgraph(graph: onnx.GraphProto, precision: Precision):
    """Check names and element types of the T5 decoder subgraph's inputs and outputs.

    Expected inputs (in order):
        input_ids: int32 (B, 1)
        encoder_attention_mask: int32 (B, encode_sequence_length)
        past_key_self_{i}, past_value_self_{i} for every layer
        past_key_cross_{i}, past_value_cross_{i} for every layer
    Expected outputs (in order):
        logits
        present_key_self_{i}, present_value_self_{i} for every layer
    All past/present tensors and logits use float16 when ``precision`` is FLOAT16,
    float32 otherwise.

    Args:
        graph (onnx.GraphProto): onnx graph of T5 decoder
        precision (Precision): Precision (FLOAT16 or FLOAT32) of the model.

    Raises:
        AssertionError: Fewer inputs than one layer of past state requires.
        ValueError: Number of inputs not expected.
        ValueError: Input name is not expected.
        ValueError: Input data type is not expected.
        ValueError: Number of outputs not expected.
        ValueError: Output name is not expected.
        ValueError: Output data type is not expected.
    """
    float_type = TensorProto.FLOAT16 if precision == Precision.FLOAT16 else TensorProto.FLOAT

    # Two leading int32 inputs, then 4 past tensors (self key/value, cross key/value) per layer.
    layer_count = (len(graph.input) - 2) // 4
    assert layer_count >= 1

    # TODO: encoder_hidden_states is optional
    expected_inputs = ["input_ids", "encoder_attention_mask"]
    for attention_kind in ("self", "cross"):
        for layer in range(layer_count):
            expected_inputs += [f"past_key_{attention_kind}_{layer}", f"past_value_{attention_kind}_{layer}"]

    if len(graph.input) != len(expected_inputs):
        raise ValueError(f"Number of inputs expected to be {len(expected_inputs)}. Got {len(graph.input)}")

    for idx, (expected_input, graph_input) in enumerate(zip(expected_inputs, graph.input)):
        if graph_input.name != expected_input:
            raise ValueError(f"Input {idx} is expected to be {expected_input}. Got {graph_input.name}")
        wanted_type = float_type if idx >= 2 else TensorProto.INT32
        actual_type = graph_input.type.tensor_type.elem_type
        if actual_type != wanted_type:
            raise ValueError(f"Input {idx} is expected to have onnx data type {wanted_type}. Got {actual_type}")

    expected_outputs = ["logits"] + [
        name
        for layer in range(layer_count)
        for name in (f"present_key_self_{layer}", f"present_value_self_{layer}")
    ]

    if len(graph.output) != len(expected_outputs):
        raise ValueError(f"Number of outputs expected to be {len(expected_outputs)}. Got {len(graph.output)}")

    for idx, (expected_output, graph_output) in enumerate(zip(expected_outputs, graph.output)):
        if graph_output.name != expected_output:
            raise ValueError(f"Output {idx} is expected to be {expected_output}. Got {graph_output.name}")
        actual_type = graph_output.type.tensor_type.elem_type
        if actual_type != float_type:
            raise ValueError(f"Output {idx} is expected to have onnx data type {float_type}. Got {actual_type}")
def verify_t5_encoder_decoder_init_subgraph(graph: onnx.GraphProto, precision: Precision):
    """Verify the T5 encoder + decoder-init subgraph.

    Expected inputs (all int32, in order):
        encoder_input_ids (B, encode_sequence_length)
        encoder_attention_mask (B, encode_sequence_length)
        decoder_input_ids (B, 1)
    Expected outputs (in order, float16 when ``precision`` is FLOAT16, else float32):
        logits, encoder_hidden_states,
        present_key_self_{i}, present_value_self_{i} for every layer,
        present_key_cross_{i}, present_value_cross_{i} for every layer

    Args:
        graph (onnx.GraphProto): onnx graph of the T5 encoder and decoder initialization
        precision (Precision): Precision (FLOAT16 or FLOAT32) of the model.

    Raises:
        AssertionError: Fewer outputs than one layer of present state requires.
        ValueError: Number of inputs not expected.
        ValueError: Input name is not expected.
        ValueError: Input data type is not expected.
        ValueError: Number of outputs not expected.
        ValueError: Output name is not expected.
        ValueError: Output data type is not expected.
    """
    float_type = TensorProto.FLOAT16 if precision == Precision.FLOAT16 else TensorProto.FLOAT

    # Two leading outputs, then 4 present tensors (self key/value, cross key/value) per layer.
    layer_count = (len(graph.output) - 2) // 4
    assert layer_count >= 1

    expected_inputs = ["encoder_input_ids", "encoder_attention_mask", "decoder_input_ids"]
    if len(graph.input) != len(expected_inputs):
        raise ValueError(f"Number of inputs expected to be {len(expected_inputs)}. Got {len(graph.input)}")

    int_type = TensorProto.INT32
    for idx, (expected_input, graph_input) in enumerate(zip(expected_inputs, graph.input)):
        if graph_input.name != expected_input:
            raise ValueError(f"Input {idx} is expected to be {expected_input}. Got {graph_input.name}")
        actual_type = graph_input.type.tensor_type.elem_type
        if actual_type != int_type:
            raise ValueError(f"Input {idx} is expected to have onnx data type {int_type}. Got {actual_type}")

    expected_outputs = ["logits", "encoder_hidden_states"]
    for attention_kind in ("self", "cross"):
        for layer in range(layer_count):
            expected_outputs += [f"present_key_{attention_kind}_{layer}", f"present_value_{attention_kind}_{layer}"]

    if len(graph.output) != len(expected_outputs):
        raise ValueError(f"Number of outputs expected to be {len(expected_outputs)}. Got {len(graph.output)}")

    for idx, (expected_output, graph_output) in enumerate(zip(expected_outputs, graph.output)):
        if graph_output.name != expected_output:
            raise ValueError(f"Output {idx} is expected to be {expected_output}. Got {graph_output.name}")
        actual_type = graph_output.type.tensor_type.elem_type
        if actual_type != float_type:
            raise ValueError(f"Output {idx} is expected to have onnx data type {float_type}. Got {actual_type}")

    logger.info("T5 encoder graph verified: name and data type of inputs and outputs are good.")
def remove_shared_initializers(
    graph1: GraphProto,
    graph2: GraphProto,
    shared_prefix: str = "shared_",
    min_elements: int = 1024,
    signature_cache1: Optional[dict] = None,
    signature_cache2: Optional[dict] = None,
):
    """Remove initializers with same value from two graphs.

    Each sufficiently large initializer of graph1 is compared (via OnnxModel.has_same_value) against
    the sufficiently large initializers of graph2. For every match, the initializer is removed from
    both graphs, node inputs and value_info entries in both graphs are renamed to
    ``shared_prefix + <name of the matching graph2 initializer>``, and the graph2 copy is renamed to
    that shared name and returned. A value_info for each returned initializer is appended to both
    graphs.

    Args:
        graph1 (GraphProto): the first graph to process
        graph2 (GraphProto): the second graph to process
        shared_prefix (str): add prefix to the shared initializers among two graphs
        min_elements (int, optional): minimal number of elements for initializers to be considered. Defaults to 1024.
        signature_cache1 (dict): Optional dictionary to store data signatures of tensors in graph1 in order to speed up comparison
        signature_cache2 (dict): Optional dictionary to store data signatures of tensors in graph2 in order to speed up comparison

    Returns:
        list: the initializers removed from graph2, renamed to their shared names.

    Raises:
        RuntimeError: if a shared name is already used as a node input in graph1 or graph2.
            This check runs before either graph is modified.
    """
    from math import prod

    def _is_candidate(initializer) -> bool:
        # Scalars (empty dims) are never shared. Otherwise compare the element count
        # (product of dims, not their sum) against min_elements, as documented above.
        return bool(initializer.dims) and prod(initializer.dims) >= min_elements

    mapping_initializers_1 = {}
    mapping_initializers_2 = {}
    shared_initializers_1 = []
    shared_initializers_2 = []
    shared_initializers_names = []

    # graph2 is not modified during matching, so its candidate list can be built once.
    candidates_2 = [initializer2 for initializer2 in graph2.initializer if _is_candidate(initializer2)]

    for initializer1 in graph1.initializer:
        if not _is_candidate(initializer1):
            continue
        for initializer2 in candidates_2:
            if OnnxModel.has_same_value(initializer1, initializer2, signature_cache1, signature_cache2):
                mapping_initializers_1[initializer1.name] = shared_prefix + initializer2.name
                shared_initializers_1.append(initializer1)
                # Several graph1 initializers may match the same graph2 initializer; register it once.
                if initializer2.name not in mapping_initializers_2:
                    shared_name = shared_prefix + initializer2.name
                    mapping_initializers_2[initializer2.name] = shared_name
                    shared_initializers_2.append(initializer2)
                    shared_initializers_names.append(shared_name)
                break

    logger.debug(f"shared initializers:{shared_initializers_names}")

    # Validate both graphs before mutating either, so a failure leaves them untouched.
    shared_name_set = set(shared_initializers_names)
    # Make sure new name does not exist in graph 1
    for node in graph1.node:
        for input_name in node.input:
            if input_name in shared_name_set:
                raise RuntimeError(f"name is found in graph 1: {input_name}")
    # Make sure new name does not exist in graph 2
    for node in graph2.node:
        for input_name in node.input:
            if input_name in shared_name_set:
                raise RuntimeError(f"name is found in graph 2: {input_name}")

    def _detach_and_redirect(graph, initializers, mapping, graph_index):
        # Remove the shared initializers from `graph`, then rename value_info entries and
        # node inputs that referenced their old names to the shared names.
        for initializer in initializers:
            graph.initializer.remove(initializer)
        for value_info in graph.value_info:
            if value_info.name in mapping:
                value_info.name = mapping[value_info.name]
        for node in graph.node:
            for j in range(len(node.input)):
                if node.input[j] in mapping:
                    new_name = mapping[node.input[j]]
                    logger.debug(
                        f"graph {graph_index} rename node {node.name} input {j} from {node.input[j]} to {new_name}"
                    )
                    node.input[j] = new_name

    # Same order as before: graph 2 first, then graph 1.
    _detach_and_redirect(graph2, shared_initializers_2, mapping_initializers_2, 2)
    _detach_and_redirect(graph1, shared_initializers_1, mapping_initializers_1, 1)

    # Rename shared initializers in graph 2 and declare their value_info in both graphs.
    for initializer in shared_initializers_2:
        initializer.name = mapping_initializers_2[initializer.name]
        # `dims` already holds the shape, so there is no need to materialize the tensor data.
        value_info = onnx.helper.make_tensor_value_info(
            initializer.name, initializer.data_type, list(initializer.dims)
        )
        # Need add value_info for initializers moved to parent graph. Otherwise, ORT will fail.
        graph1.value_info.append(value_info)
        graph2.value_info.append(value_info)

    return shared_initializers_2
def get_shared_initializers(encoder_model: ModelProto, decoder_model: ModelProto):
encoder = OnnxModel(encoder_model)
decoder = OnnxModel(decoder_model)
encoder.add_prefix_to_names("e_")
decoder.add_prefix_to_names("d_")
signature_cache1, signature_cache2 = {}, {}
encoder.remove_duplicated_initializer(signature_cache1)
decoder.remove_duplicated_initializer(signature_cache2)
initializers = remove_shared_initializers(
decoder.model.graph,
encoder.model.graph,
shared_prefix="s_",
signature_cache1=signature_cache1,
signature_cache2=signature_cache2,