In [1]:
from typing import List
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import onnx
import onnxruntime as ort
import numpy as np
import torch

## Explore Model Input&Output

In [6]:
onnx_model = onnx.load("../model_onnx/encoder_model.onnx")

In [9]:
onnx_model.graph.input

[name: "input_ids"
type {
  tensor_type {
    elem_type: 7
    shape {
      dim {
        dim_param: "batch_size"
      }
      dim {
        dim_param: "encoder_sequence_length"
      }
    }
  }
}
, name: "attention_mask"
type {
  tensor_type {
    elem_type: 7
    shape {
      dim {
        dim_param: "batch_size"
      }
      dim {
        dim_param: "encoder_sequence_length"
      }
    }
  }
}
]

In [10]:
onnx_model.graph.output

[name: "last_hidden_state"
type {
  tensor_type {
    elem_type: 1
    shape {
      dim {
        dim_param: "batch_size"
      }
      dim {
        dim_param: "encoder_sequence_length"
      }
      dim {
        dim_value: 768
      }
    }
  }
}
]

In [11]:
decoder_model = onnx.load("../model_onnx/decoder_model.onnx")

In [12]:
decoder_model.graph.input

[name: "encoder_attention_mask"
type {
  tensor_type {
    elem_type: 7
    shape {
      dim {
        dim_param: "batch_size"
      }
      dim {
        dim_param: "encoder_sequence_length"
      }
    }
  }
}
, name: "input_ids"
type {
  tensor_type {
    elem_type: 7
    shape {
      dim {
        dim_param: "batch_size"
      }
      dim {
        dim_param: "decoder_sequence_length"
      }
    }
  }
}
, name: "encoder_hidden_states"
type {
  tensor_type {
    elem_type: 1
    shape {
      dim {
        dim_param: "batch_size"
      }
      dim {
        dim_param: "encoder_sequence_length"
      }
      dim {
        dim_value: 768
      }
    }
  }
}
]

In [13]:
decoder_model.graph.output

[name: "logits"
type {
  tensor_type {
    elem_type: 1
    shape {
      dim {
        dim_param: "batch_size"
      }
      dim {
        dim_param: "decoder_sequence_length"
      }
      dim {
        dim_value: 32128
      }
    }
  }
}
, name: "present.0.decoder.key"
type {
  tensor_type {
    elem_type: 1
    shape {
      dim {
        dim_param: "batch_size"
      }
      dim {
        dim_value: 12
      }
      dim {
        dim_param: "past_decoder_sequence_length + 1"
      }
      dim {
        dim_value: 64
      }
    }
  }
}
, name: "present.0.decoder.value"
type {
  tensor_type {
    elem_type: 1
    shape {
      dim {
        dim_param: "batch_size"
      }
      dim {
        dim_value: 12
      }
      dim {
        dim_param: "past_decoder_sequence_length + 1"
      }
      dim {
        dim_value: 64
      }
    }
  }
}
, name: "present.0.encoder.key"
type {
  tensor_type {
    elem_type: 1
    shape {
      dim {
        dim_param: "batch_size"
      }
      di

## Inference

In [2]:
tokenizer = AutoTokenizer.from_pretrained("juierror/flan-t5-text2sql-with-schema")

In [3]:
def prepare_input(question: str, table: List[str]):
    table_prefix = "table:"
    question_prefix = "question:"
    join_table = ",".join(table)
    inputs = f"{question_prefix} {question} {table_prefix} {join_table}"
    output = tokenizer(inputs, max_length=512, return_tensors="pt")
    return output

In [4]:
input = prepare_input(question="get people name with age equal 25", table=["id", "name", "age"])

Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


In [5]:
input

{'input_ids': tensor([[ 822,   10,  129,  151,  564,   28, 1246, 4081,  944,  953,   10,    3,
           23,   26,    6, 4350,    6,  545,    1]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}

In [9]:
encoder_ort_sess = ort.InferenceSession("../model_onnx/encoder_model.onnx")
encoder_outputs = encoder_ort_sess.run(None, {'input_ids': input["input_ids"].numpy(), "attention_mask": input["attention_mask"].numpy()})

In [11]:
encoder_outputs[0].shape

(1, 19, 768)

In [12]:
decoder_ort_sess = ort.InferenceSession("../model_onnx/decoder_model.onnx")
decoder_outputs = decoder_ort_sess.run(
    None, 
    {
        "input_ids": input["input_ids"].numpy(), 
        "encoder_attention_mask": input["attention_mask"].numpy(),
        "encoder_hidden_states": outputs[0]
    })

In [14]:
logits = np.argmax(decoder_outputs[0][0], axis=-1)

In [15]:
logits

array([1525,    1,    1,    1,    1, 1246, 4081,  944,    1,    1,    1,
         23,   26,    1,  564, 3274,  545, 3274,    1])

In [16]:
tokenizer.decode(token_ids=logits)

2023-07-25 22:35:26.773430: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-07-25 22:35:26.799439: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


'answer</s></s></s></s> age equal 25</s></s></s>id</s> name =age =</s>'