In [1]:
%cd ..
import sentencepiece
import torch
from constants import *
from custom_data import *
from transformer import *
from data_structure import *

d:\transformer-translator-pytorch


In [13]:
model = Transformer(64000, 64000)
checkpoint = torch.load("saved_model\\best_ckpt.tar", map_location=torch.device('cpu'))
model.load_state_dict(checkpoint['model_state_dict'])

<All keys matched successfully>

In [14]:
model.eval()

Transformer(
  (src_embedding): Embedding(64000, 512)
  (trg_embedding): Embedding(64000, 512)
  (positional_encoder): PositionalEncoder()
  (encoder): Encoder(
    (layers): ModuleList(
      (0-2): 3 x EncoderLayer(
        (layer_norm_1): LayerNormalization(
          (layer): LayerNorm((512,), eps=1e-06, elementwise_affine=True)
        )
        (multihead_attention): MultiheadAttention(
          (w_q): Linear(in_features=512, out_features=512, bias=True)
          (w_k): Linear(in_features=512, out_features=512, bias=True)
          (w_v): Linear(in_features=512, out_features=512, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (attn_softmax): Softmax(dim=-1)
          (w_0): Linear(in_features=512, out_features=512, bias=True)
        )
        (drop_out_1): Dropout(p=0.1, inplace=False)
        (layer_norm_2): LayerNormalization(
          (layer): LayerNorm((512,), eps=1e-06, elementwise_affine=True)
        )
        (feed_forward): FeedFowardLayer(
 

In [7]:
src_sp = spm.SentencePieceProcessor()
trg_sp = spm.SentencePieceProcessor()
src_sp.load(f"{SP_DIR}/{src_model_prefix}.model")
trg_sp.load(f"{SP_DIR}/{trg_model_prefix}.model")

True

In [16]:
input_sentence = "Hello my friend"
tokenized = src_sp.EncodeAsIds(input_sentence)
src = torch.LongTensor(pad_or_truncate(tokenized)).unsqueeze(0).to(device) # (1, L)
e_mask = (src != pad_id).unsqueeze(1).to(device) # (1, 1, L)


In [21]:
def convert_embedding_part(model, src, save_path): 
    print(src.shape)
    src_data = model.src_embedding(src) 
    # torch.onnx.export(model.src_embedding, 
    #                   src, 
    #                   save_path, 
    #                   export_params=True, 
    #                   opset_version=16, 
    #                   do_constant_folding=True, 
    #                   input_names=['src'], 
    #                   output_names=['src_data']) 
    return src_data.detach().numpy().shape

src_data = convert_embedding_part(model, src, "onnx_test\\src_embedding.onnx")
src_data

torch.Size([1, 256])


(1, 256, 512)

In [20]:


def convert_pe_part(model, src, save_path): 
    src_data2 = model.positional_encoder(src) 
    torch.onnx.export(model.positional_encoder, 
                      src, 
                      save_path, 
                      export_params=True, 
                      opset_version=16, 
                      do_constant_folding=True, 
                      input_names=['src_data'], 
                      output_names=['src_data2']) 
    return src_data2

src_data2 = convert_pe_part(model, src_data, "onnx_test\\positional_encoder.onnx")

verbose: False, log level: Level.ERROR



In [28]:
def convert_encoder_part(model, src, src_mask, save_path): 
    e_output = model.encoder(src, src_mask) 
    torch.onnx.export(model.encoder, 
                      (src, src_mask), 
                      save_path, 
                      export_params=True, 
                      opset_version=16, 
                      do_constant_folding=True, 
                      input_names=['src_data2', 'e_mask'], 
                      output_names=['e_output']) 
    return e_output

e_output = convert_encoder_part(model, src_data2, e_mask, "onnx_test\\encoder.onnx")

verbose: False, log level: Level.ERROR



In [29]:
last_words = torch.LongTensor([pad_id] * seq_len) # (L)
last_words[0] = sos_id # (L)
cur_len = 1

trg = last_words.unsqueeze(0)
        
d_mask = (last_words.unsqueeze(0) != pad_id).unsqueeze(1) # (1, 1, L)
nopeak_mask = torch.ones([1, seq_len, seq_len], dtype=torch.bool)  # (1, L, L)
nopeak_mask = torch.tril(nopeak_mask)  # (1, L, L) to triangular shape
d_mask = d_mask & nopeak_mask

def convert_embedding2_part(model, trg, save_path): 
    trg_embedded = model.trg_embedding(trg) 
    torch.onnx.export(model.trg_embedding, 
                      trg, 
                      save_path, 
                      export_params=True, 
                      opset_version=16, 
                      do_constant_folding=True, 
                      input_names=['trg'], 
                      output_names=['trg_embedded']) 
    return trg_embedded

trg_embedded = convert_embedding2_part(model, trg, "onnx_test\\trg_embedding.onnx")

verbose: False, log level: Level.ERROR



In [25]:
trg_positional_encoded = model.positional_encoder(trg_embedded) 

In [30]:
def convert_decoder_part(model, trg_positional_encoded, e_output, e_mask, d_mask, save_path): 
    decoder_output = model.decoder(
                trg_positional_encoded,
                e_output,
                e_mask,
                d_mask
            ) # (1, L, d_model) 
    torch.onnx.export(model.decoder, 
                      (trg_positional_encoded, e_output, e_mask, d_mask), 
                      save_path, 
                      export_params=True, 
                      opset_version=16, 
                      do_constant_folding=True, 
                      input_names=['trg_positional_encoded', 'e_output', 'e_mask', 'd_mask'], 
                      output_names=['decoder_output']) 
    return decoder_output

decoder_output = convert_decoder_part(model, trg_positional_encoded, e_output, e_mask, d_mask, "onnx_test\\decoder.onnx")

verbose: False, log level: Level.ERROR



## Load and check model

In [31]:
import onnx

embed1 = onnx.load("onnx_test\\src_embedding.onnx")
embed2 = onnx.load("onnx_test\\trg_embedding.onnx")
pe = onnx.load("onnx_test\\positional_encoder.onnx")
encoder = onnx.load("onnx_test\\encoder.onnx")
decoder = onnx.load("onnx_test\\decoder.onnx")

In [32]:
onnx.checker.check_model(embed1)
onnx.checker.check_model(embed2)
onnx.checker.check_model(pe)
onnx.checker.check_model(encoder)
onnx.checker.check_model(decoder)

In [33]:
onnx.helper.printable_graph(encoder.graph)

'graph torch_jit (\n  %src_data2[FLOAT, 1x256x512]\n  %e_mask[BOOL, 1x1x256]\n) initializers (\n  %layers.0.layer_norm_1.layer.weight[FLOAT, 512]\n  %layers.0.layer_norm_1.layer.bias[FLOAT, 512]\n  %layers.0.multihead_attention.w_q.bias[FLOAT, 512]\n  %layers.0.multihead_attention.w_k.bias[FLOAT, 512]\n  %layers.0.multihead_attention.w_v.bias[FLOAT, 512]\n  %layers.0.multihead_attention.w_0.bias[FLOAT, 512]\n  %layers.0.layer_norm_2.layer.weight[FLOAT, 512]\n  %layers.0.layer_norm_2.layer.bias[FLOAT, 512]\n  %layers.0.feed_forward.linear_1.bias[FLOAT, 2048]\n  %layers.0.feed_forward.linear_2.bias[FLOAT, 512]\n  %layers.1.layer_norm_1.layer.weight[FLOAT, 512]\n  %layers.1.layer_norm_1.layer.bias[FLOAT, 512]\n  %layers.1.multihead_attention.w_q.bias[FLOAT, 512]\n  %layers.1.multihead_attention.w_k.bias[FLOAT, 512]\n  %layers.1.multihead_attention.w_v.bias[FLOAT, 512]\n  %layers.1.multihead_attention.w_0.bias[FLOAT, 512]\n  %layers.1.layer_norm_2.layer.weight[FLOAT, 512]\n  %layers.1.laye

In [2]:
import onnxruntime

src_embed = onnxruntime.InferenceSession("onnx_test\\src_embedding.onnx")
trg_embed = onnxruntime.InferenceSession("onnx_test\\trg_embedding.onnx")
pe = onnxruntime.InferenceSession("onnx_test\\positional_encoder.onnx")
encoder = onnxruntime.InferenceSession("onnx_test\\encoder.onnx")
decoder = onnxruntime.InferenceSession("onnx_test\\decoder.onnx")


In [8]:
class Translator():
    def __init__(self, session) -> None:
        self.src_embed, self.trg_embed, self.pe, self.encoder, self.decoder = session
        self.src_sp = spm.SentencePieceProcessor()
        self.trg_sp = spm.SentencePieceProcessor()
        self.src_sp.load(f"{SP_DIR}/{src_model_prefix}.model")
        self.trg_sp.load(f"{SP_DIR}/{trg_model_prefix}.model")
        

    def translate(self, input_sentence, session, method="greedy"):
        tokenized = self.src_sp.EncodeAsIds(input_sentence)
        src = np.expand_dims(pad_or_truncate(tokenized), axis=0).astype('int64') # (1, L)
        e_mask = np.expand_dims((src != pad_id), axis=1) # (1, 1, L)
        print(src.shape)
        
        src_data_embed_input = { self.src_embed.get_inputs()[0].name: src }
        src_data_embed = self.src_embed.run(None, src_data_embed_input)[0]
        print(np.array(src_data_embed).shape)
        
        src_data_pe_input = { self.pe.get_inputs()[0].name: src_data_embed }
        src_data_pe = self.pe.run(None, src_data_pe_input)[0]
        print(np.array(src_data_pe).shape)
        
        e_output_input = { self.encoder.get_inputs()[0].name: src_data_pe, self.encoder.get_inputs()[1].name: e_mask}
        e_output = self.encoder.run(None, e_output_input)[0]
        print(np.array(e_output).shape)
        
        if method == 'greedy':
            print("Greedy decoding selected.")
            result = self.greedy_search(e_output, e_mask, trg_sp)
            
    def greedy_search(self, e_output, e_mask, trg_sp):
        last_words = torch.LongTensor([pad_id] * seq_len).to(device) # (L)
        last_words[0] = sos_id # (L)
        cur_len = 1

        for i in range(seq_len):
            d_mask = (last_words.unsqueeze(0) != pad_id).unsqueeze(1).to(device) # (1, 1, L)
            nopeak_mask = torch.ones([1, seq_len, seq_len], dtype=torch.bool).to(device)  # (1, L, L)
            nopeak_mask = torch.tril(nopeak_mask)  # (1, L, L) to triangular shape
            d_mask = d_mask & nopeak_mask  # (1, L, L) padding false

            trg_embedded = self.model.trg_embedding(last_words.unsqueeze(0))
            trg_positional_encoded = self.model.positional_encoder(trg_embedded)
            decoder_output = self.model.decoder(
                trg_positional_encoded,
                e_output,
                e_mask,
                d_mask
            ) # (1, L, d_model)

            output = self.model.softmax(
                self.model.output_linear(decoder_output)
            ) # (1, L, trg_vocab_size)

            output = torch.argmax(output, dim=-1) # (1, L)
            last_word_id = output[0][i].item()
            
            if i < seq_len-1:
                last_words[i+1] = last_word_id
                cur_len += 1
            
            if last_word_id == eos_id:
                break

        if last_words[-1].item() == pad_id:
            decoded_output = last_words[1:cur_len].tolist()
        else:
            decoded_output = last_words[1:].tolist()
        decoded_output = trg_sp.decode_ids(decoded_output)
        
        return decoded_output
    
    

(1, 256)
(1, 256, 512)
(1, 256, 512)
(1, 256, 512)


In [None]:
def translate_onnx(img, session, max_seq_length=128, sos_token=1, eos_token=2):
    """data: BxCxHxW"""
    cnn_session, encoder_session, decoder_session = session
    
    # create cnn input
    cnn_input = {cnn_session.get_inputs()[0].name: img}
    src = cnn_session.run(None, cnn_input)
    
    # create encoder input
    encoder_input = {encoder_session.get_inputs()[0].name: src[0]}
    encoder_outputs, hidden = encoder_session.run(None, encoder_input)
    translated_sentence = [[sos_token] * len(img)]
    max_length = 0

    while max_length <= max_seq_length and not all(
        np.any(np.asarray(translated_sentence).T == eos_token, axis=1)
    ):
        tgt_inp = translated_sentence
        decoder_input = {decoder_session.get_inputs()[0].name: tgt_inp[-1], decoder_session.get_inputs()[1].name: hidden, decoder_session.get_inputs()[2].name: encoder_outputs}

        output, hidden, _ = decoder_session.run(None, decoder_input)
        output = np.expand_dims(output, axis=1)
        output = torch.Tensor(output)

        values, indices = torch.topk(output, 1)
        indices = indices[:, -1, 0]
        indices = indices.tolist()

        translated_sentence.append(indices)
        max_length += 1

        del output

    translated_sentence = np.asarray(translated_sentence).T

    return translated_sentence