In [1]:
import numpy as np
import onnxruntime
from tokenizers import Tokenizer


In [2]:
# Location of the exported ONNX translator; inference is pinned to the CPU provider.
model_path = "./models/translator_transformer_v4_2_layers.onnx"

ort_session = onnxruntime.InferenceSession(
    model_path,
    providers=["CPUExecutionProvider"],
)

In [27]:
# Inspect the inputs the session declares: their names and shapes are needed
# further down when feeding the model.
ort_inputs_info = ort_session.get_inputs()
input_names = [node.name for node in ort_inputs_info]
print(f"Model expects {len(ort_inputs_info)} inputs with names {input_names}")

# NOTE: "if" in the messages below is a typo for "of"; the text is kept as-is
# so it still matches the output recorded under this cell.
print(f"Shape if the 1st input: {ort_inputs_info[0].shape}")
print(f"Shape if the 2nd input: {ort_inputs_info[1].shape}")

Model expects 2 inputs with names ['l_src_', 'l_tgt_']
Shape if the 1st input: [1, 128]
Shape if the 2nd input: [1, 128]


In [19]:
# Rather than hardcoding the sequence lengths, read them from the ONNX input
# metadata: index 1 of each declared shape (the shapes printed above are [1, 128]).
src_seq_length, tgt_seq_length = (
    ort_inputs_info[0].shape[1],
    ort_inputs_info[1].shape[1],
)

In [28]:
# Load the tokenizers the original model was trained with:
# "de_tokenizer" for the source side, "en_tokenizer" for the target side.
src_lang, tgt_lang = (
    Tokenizer.from_file("de_tokenizer"),
    Tokenizer.from_file("en_tokenizer"),
)

In [29]:
def translate(model: "onnxruntime.InferenceSession",
              src_sentence: str,
              src_lang: "Tokenizer",
              tgt_lang: "Tokenizer",
              max_tgt_length: int,
              src_seq_length: int = 128,
              tgt_seq_length: int = 128,
              clean: bool = False) -> str:
    """
    Translate a sentence using ONNX model (greedy decoding).

    We generate autoregressively: at every step the model is fed the
    tokens produced so far and we keep the best prediction for the
    current position only.

    Args:
        model: session exposing ``run``; it is fed the inputs
            ``"l_src_"`` and ``"l_tgt_"``.
        src_sentence: text to translate.
        src_lang: tokenizer used to encode the source sentence.
        tgt_lang: tokenizer used to decode the generated ids.
        max_tgt_length: upper bound on the number of generated tokens.
            It is capped at ``tgt_seq_length`` because the target input
            cannot hold more positions than that.
        src_seq_length: fixed length the source ids are padded to.
            Longer sources are truncated (a ``UserWarning`` is emitted).
        tgt_seq_length: fixed length the target ids are padded to.
        clean: forwarded to ``decode`` as ``skip_special_tokens``.

    Returns:
        The decoded target sentence.
    """
    import warnings

    src_ids = src_lang.encode(src_sentence).ids
    if len(src_ids) > src_seq_length:
        # np.pad rejects negative pad widths, so an over-long source used to crash
        warnings.warn(
            f"Source has {len(src_ids)} tokens, truncating to {src_seq_length}.",
            stacklevel=2,
        )
        src_ids = src_ids[:src_seq_length]

    # explicit integer dtype: an empty id list would otherwise become float64
    input_ids = np.pad(np.asarray(src_ids, dtype=int),
                       (0, src_seq_length - len(src_ids)),
                       constant_values=src_lang.token_to_id("<PAD>"))
    input_ids = np.reshape(input_ids, (1, -1))

    # special-token lookups do not change between steps, so do them once
    tgt_pad_id = tgt_lang.token_to_id("<PAD>")
    eos_id = tgt_lang.token_to_id("<EOS>")

    tgt_indices = [tgt_lang.token_to_id("<BOS>")]  # type: list[int]

    # At step t the target input holds t + 1 tokens and we read output
    # position t, so both must stay within tgt_seq_length.
    num_steps = min(max_tgt_length, tgt_seq_length)

    for t in range(num_steps):
        tgt_ids = np.pad(np.asarray(tgt_indices, dtype=int),
                         (0, tgt_seq_length - len(tgt_indices)),
                         constant_values=tgt_pad_id)
        tgt_ids = np.reshape(tgt_ids, (1, -1))

        # expected shape (batch_size, seq_length, tgt_vocab_size), e.g. (1, 128, 8000)
        model_outputs = model.run(None, {"l_src_": input_ids,
                                         "l_tgt_": tgt_ids})[0]

        # we take the best prediction for step t only (generating autoregressively);
        # argmax returns a NumPy scalar, so convert it to a plain Python int
        prediction = int(model_outputs[0, t].argmax())  # type: int
        tgt_indices.append(prediction)

        if prediction == eos_id:
            break

    sentence = tgt_lang.decode(tgt_indices, skip_special_tokens=clean)

    return sentence

In [32]:
# Demo: translate one sentence; the call is the cell's last expression,
# so its return value is displayed as the cell output.
input_sentence = "Hallo Welt, wie geht es?"

translate(
    ort_session,
    input_sentence,
    src_lang,
    tgt_lang,
    max_tgt_length=42,  # upper bound on the number of generated tokens
    src_seq_length=src_seq_length,
    tgt_seq_length=tgt_seq_length,
    clean=True,  # skip special tokens when decoding
)

'hello world, how is it going?'