In [3]:
import torch
from transformers import T5ForConditionalGeneration
from tokenization_enc_dec import EncDecTokenizer

In [2]:
model = T5ForConditionalGeneration.from_pretrained('./torch_eva/')

In [4]:
tokenizer = EncDecTokenizer('./EVA/src/bpe_dialog_new/vocab.txt')

In [19]:
input_ids = torch.LongTensor([
    tokenizer.encode('你好') + [tokenizer.sep_id, tokenizer.get_sentinel_id(0)]]
)
mask = torch.ones_like(input_ids)

In [20]:
!rm -rf onnx_eva_encoder
!mkdir -p onnx_eva_encoder

In [21]:
encoder_outputs = model.encoder(input_ids, mask)

In [22]:
encoder_outputs.last_hidden_state.shape

torch.Size([1, 3, 2048])

In [29]:
torch.onnx._export(
    model.encoder,
    (input_ids, mask),
    './onnx_eva_encoder/encoder.onnx',
    input_names=["input_ids", "attention_mask"],
    output_names=["last_hidden_state"],
    dynamic_axes={
        "input_ids": {0: "batch", 1: "sequence"},
        "attention_mask": {0: "batch", 1: "sequence"},
        "last_hidden_state": {0: "batch", 1: "sequence"},
    },
    opset_version=13,
    use_external_data_format=True
)

BaseModelOutputWithPastAndCrossAttentions(last_hidden_state=tensor([[[-0.6673, -0.4716,  0.2209,  ...,  0.7204, -1.2785,  0.0203],
         [ 0.0112,  0.0476,  0.1098,  ...,  0.0424,  0.0538,  0.0340],
         [-0.0718,  0.0491, -0.0719,  ..., -0.0109, -0.0191,  0.0133]]],
       grad_fn=<MulBackward0>), past_key_values=None, hidden_states=None, attentions=None, cross_attentions=None)

In [24]:
decoder_input_ids = torch.LongTensor([
    [tokenizer.get_sentinel_id(0)]]
)
decoder_mask = torch.ones_like(decoder_input_ids)

In [25]:
decoder_outputs = model.decoder(
    decoder_input_ids, decoder_mask, encoder_outputs.last_hidden_state, mask)

In [27]:
decoder_outputs.last_hidden_state

tensor([[[-0.0450, -2.1971, -0.1809,  ..., -1.4545, -0.2110, -2.0464]]],
       grad_fn=<MulBackward0>)

In [28]:
!rm -rf onnx_eva_decoder
!mkdir -p onnx_eva_decoder

In [30]:
torch.onnx._export(
    model.decoder,
    (decoder_input_ids, decoder_mask, encoder_outputs.last_hidden_state, mask),
    './onnx_eva_decoder/decoder.onnx',
    input_names=[
        "decoder_input_ids", "decoder_attention_mask",
        'encoder_hidden_states', 'encoder_attention_mask'
    ],
    output_names=["last_hidden_state"],
    dynamic_axes={
        "decoder_input_ids": {0: "batch", 1: "sequence"},
        "decoder_attention_mask": {0: "batch", 1: "sequence"},
        "encoder_hidden_states": {0: "batch", 1: "sequence"},
        "encoder_attention_mask": {0: "batch", 1: "sequence"},
        "last_hidden_state": {0: "batch", 1: "sequence"},
    },
    opset_version=13,
    use_external_data_format=True
)

  if causal_mask.shape[1] < attention_mask.shape[1]:


BaseModelOutputWithPastAndCrossAttentions(last_hidden_state=tensor([[[-0.0450, -2.1971, -0.1809,  ..., -1.4545, -0.2110, -2.0464]]],
       grad_fn=<MulBackward0>), past_key_values=((tensor([[[[ 0.0484,  0.1978,  0.0890,  ...,  0.1586, -0.0168, -0.0992]],

         [[ 0.0249,  0.0582,  0.0362,  ...,  0.0371,  0.0187, -0.0925]],

         [[ 0.1331,  0.2285,  0.4245,  ..., -0.1750, -0.0947, -0.1823]],

         ...,

         [[ 0.4579, -0.0117, -0.2745,  ...,  0.1693, -0.0921, -0.4806]],

         [[-0.2088, -0.0517, -0.1821,  ...,  0.0679,  0.1341,  0.1587]],

         [[-0.5319, -0.1163,  0.1201,  ...,  0.7124, -0.1576,  0.0966]]]],
       grad_fn=<TransposeBackward0>), tensor([[[[ 0.0164, -0.0238, -0.0072,  ...,  0.0394,  0.0220,  0.0107]],

         [[-0.0886,  0.0120,  0.0244,  ...,  0.0879, -0.0707, -0.1356]],

         [[-0.1064,  0.0839,  0.0040,  ..., -0.0759,  0.0831, -0.0011]],

         ...,

         [[ 0.0754, -0.0920, -0.0767,  ...,  0.0713, -0.0562, -0.0536]],

        

In [37]:
!rm -rf onnx_eva
!mkdir -p onnx_eva

In [38]:
torch.onnx._export(
    model.lm_head,
    (decoder_outputs.last_hidden_state,),
    './onnx_eva/lm.onnx',
    input_names=[
        'decoder_hidden_states',
    ],
    output_names=["logits"],
    dynamic_axes={
        "decoder_hidden_states": {0: "batch", 1: "sequence"},
        "logits": {0: "batch", 1: "sequence"},
    },
    opset_version=13,
    use_external_data_format=True
)

tensor([[[-0.2692, -9.3293, -9.2776,  ..., -9.3171, -9.2244, -9.1933]]],
       grad_fn=<UnsafeViewBackward>)

In [39]:
import onnx
from onnxruntime.quantization import quantize_dynamic, QuantType

In [40]:
!rm -rf onnx_eva_q
!mkdir -p onnx_eva_q

In [41]:
quantized_model = quantize_dynamic(
    './onnx_eva_encoder/encoder.onnx',
    './onnx_eva_q/encoder.onnx',
    weight_type=QuantType.QUInt8
)

quantized_model = quantize_dynamic(
    './onnx_eva_decoder/decoder.onnx',
    './onnx_eva_q/decoder.onnx',
    weight_type=QuantType.QUInt8
)

quantized_model = quantize_dynamic(
    './onnx_eva/lm.onnx',
    './onnx_eva_q/lm.onnx',
    weight_type=QuantType.QUInt8
)