In [1]:
%cd ..
import sentencepiece
import torch
from constants import *
from custom_data import *
from transformer import *
from data_structure import *

/home/dinh.trong.huy/nmt-data-envija/transformer-translator-pytorch


In [2]:
model = Transformer(64000, 64000)
# checkpoint = torch.load("saved_model\\best_ckpt.tar", map_location=torch.device('cpu'))
model.load_state_dict(torch.load("saved_model/ckpt1.pt", map_location=torch.device('cpu')))

<All keys matched successfully>

In [None]:
model.eval()

In [2]:
src_sp = spm.SentencePieceProcessor()
trg_sp = spm.SentencePieceProcessor()
src_sp.load(f"{SP_DIR}/{src_model_prefix}.model")
trg_sp.load(f"{SP_DIR}/{trg_model_prefix}.model")

True

In [16]:
input_sentence = "Hello my friend"
tokenized = src_sp.EncodeAsIds(input_sentence)
src = torch.LongTensor(pad_or_truncate(tokenized)).unsqueeze(0).to(device) # (1, L)
e_mask = (src != pad_id).unsqueeze(1).to(device) # (1, 1, L)


In [21]:
def convert_embedding_part(model, src, save_path): 
    print(src.shape)
    src_data = model.src_embedding(src) 
    torch.onnx.export(model.src_embedding, 
                      src, 
                      save_path, 
                      export_params=True, 
                      opset_version=16, 
                      do_constant_folding=True, 
                      input_names=['src'], 
                      output_names=['src_data']) 
    return src_data.detach().numpy().shape

src_data = convert_embedding_part(model, src, "onnx_test\\src_embedding.onnx")
src_data

torch.Size([1, 256])


(1, 256, 512)

In [20]:


def convert_pe_part(model, src, save_path): 
    src_data2 = model.positional_encoder(src) 
    torch.onnx.export(model.positional_encoder, 
                      src, 
                      save_path, 
                      export_params=True, 
                      opset_version=16, 
                      do_constant_folding=True, 
                      input_names=['src_data'], 
                      output_names=['src_data2']) 
    return src_data2

src_data2 = convert_pe_part(model, src_data, "onnx_test\\positional_encoder.onnx")

verbose: False, log level: Level.ERROR



In [28]:
def convert_encoder_part(model, src, src_mask, save_path): 
    e_output = model.encoder(src, src_mask) 
    torch.onnx.export(model.encoder, 
                      (src, src_mask), 
                      save_path, 
                      export_params=True, 
                      opset_version=16, 
                      do_constant_folding=True, 
                      input_names=['src_data2', 'e_mask'], 
                      output_names=['e_output']) 
    return e_output

e_output = convert_encoder_part(model, src_data2, e_mask, "onnx_test\\encoder.onnx")

verbose: False, log level: Level.ERROR



In [29]:
last_words = torch.LongTensor([pad_id] * seq_len) # (L)
last_words[0] = sos_id # (L)
cur_len = 1

trg = last_words.unsqueeze(0)
        
d_mask = (last_words.unsqueeze(0) != pad_id).unsqueeze(1) # (1, 1, L)
nopeak_mask = torch.ones([1, seq_len, seq_len], dtype=torch.bool)  # (1, L, L)
nopeak_mask = torch.tril(nopeak_mask)  # (1, L, L) to triangular shape
d_mask = d_mask & nopeak_mask

def convert_embedding2_part(model, trg, save_path): 
    trg_embedded = model.trg_embedding(trg) 
    torch.onnx.export(model.trg_embedding, 
                      trg, 
                      save_path, 
                      export_params=True, 
                      opset_version=16, 
                      do_constant_folding=True, 
                      input_names=['trg'], 
                      output_names=['trg_embedded']) 
    return trg_embedded

trg_embedded = convert_embedding2_part(model, trg, "onnx_test\\trg_embedding.onnx")

verbose: False, log level: Level.ERROR



In [25]:
trg_positional_encoded = model.positional_encoder(trg_embedded) 

In [30]:
def convert_decoder_part(model, trg_positional_encoded, e_output, e_mask, d_mask, save_path): 
    decoder_output = model.decoder(
                trg_positional_encoded,
                e_output,
                e_mask,
                d_mask
            ) # (1, L, d_model) 
    torch.onnx.export(model.decoder, 
                      (trg_positional_encoded, e_output, e_mask, d_mask), 
                      save_path, 
                      export_params=True, 
                      opset_version=16, 
                      do_constant_folding=True, 
                      input_names=['trg_positional_encoded', 'e_output', 'e_mask', 'd_mask'], 
                      output_names=['decoder_output']) 
    return decoder_output

decoder_output = convert_decoder_part(model, trg_positional_encoded, e_output, e_mask, d_mask, "onnx_test\\decoder.onnx")

verbose: False, log level: Level.ERROR



In [5]:
def convert_linear_part(model, d_output, save_path): 
    l_output = model.output_linear(d_output) 
    torch.onnx.export(model.positional_encoder, 
                      d_output, 
                      save_path, 
                      export_params=True, 
                      opset_version=16, 
                      do_constant_folding=True, 
                      input_names=['d_output'], 
                      output_names=['l_output']) 
    return l_output

d_output = torch.randn(1, seq_len, d_model).to("cpu")
l_output = convert_linear_part(model, d_output, "onnx_test/output_linear.onnx")

verbose: False, log level: Level.ERROR



In [11]:
model.softmax

LogSoftmax(dim=-1)

In [12]:
def convert_softmax_part(model, l_output, save_path): 
    output = model.softmax(l_output) 
    torch.onnx.export(model.softmax, 
                      l_output, 
                      save_path, 
                      export_params=True, 
                      opset_version=16, 
                      do_constant_folding=True, 
                      input_names=['l_output'], 
                      output_names=['output']) 
    return output


output = convert_softmax_part(model, l_output, "onnx_test/softmax.onnx")

verbose: False, log level: Level.ERROR



## Load and check model

In [3]:
import onnx
import os

embed1 = onnx.load(os.path.join(ONNX_DIR, "src_embedding.onnx"))
embed2 = onnx.load(os.path.join(ONNX_DIR, "trg_embedding.onnx"))
pe = onnx.load(os.path.join(ONNX_DIR, "positional_encoder.onnx"))
encoder = onnx.load(os.path.join(ONNX_DIR, "encoder.onnx"))
decoder = onnx.load(os.path.join(ONNX_DIR, "decoder.onnx"))
linear = onnx.load(os.path.join(ONNX_DIR, "output_linear.onnx"))
softmax = onnx.load(os.path.join(ONNX_DIR, "softmax.onnx"))

In [4]:
onnx.checker.check_model(embed1)
onnx.checker.check_model(embed2)
onnx.checker.check_model(pe)
onnx.checker.check_model(encoder)
onnx.checker.check_model(decoder)
onnx.checker.check_model(linear)
onnx.checker.check_model(softmax)

In [None]:
onnx.helper.printable_graph(encoder.graph)

In [6]:
import onnxruntime

src_embed = onnxruntime.InferenceSession(os.path.join(ONNX_DIR, "src_embedding.onnx"))
trg_embed = onnxruntime.InferenceSession(os.path.join(ONNX_DIR, "trg_embedding.onnx"))
pe = onnxruntime.InferenceSession(os.path.join(ONNX_DIR, "positional_encoder.onnx"))
encoder = onnxruntime.InferenceSession(os.path.join(ONNX_DIR, "encoder.onnx"))
decoder = onnxruntime.InferenceSession(os.path.join(ONNX_DIR, "decoder.onnx"))
linear = onnxruntime.InferenceSession(os.path.join(ONNX_DIR, "output_linear.onnx"))
softmax = onnxruntime.InferenceSession(os.path.join(ONNX_DIR, "softmax.onnx"))


In [14]:
import copy

class Translator():
    def __init__(self, session) -> None:
        self.src_embed, self.trg_embed, self.pe, self.encoder, self.decoder, self.linear, self.softmax = session
        self.src_sp = spm.SentencePieceProcessor()
        self.trg_sp = spm.SentencePieceProcessor()
        self.src_sp.load(f"{SP_DIR}/{src_model_prefix}.model")
        self.trg_sp.load(f"{SP_DIR}/{trg_model_prefix}.model")
        

    def translate(self, input_sentence, method="greedy"):
        tokenized = self.src_sp.EncodeAsIds(input_sentence)
        src = np.expand_dims(pad_or_truncate(tokenized), axis=0).astype('int64') # (1, L)
        e_mask = np.expand_dims((src != pad_id), axis=1) # (1, 1, L)
        print(src.shape)
        
        src_embed_input = { self.src_embed.get_inputs()[0].name: src }
        src_embed = self.src_embed.run(None, src_embed_input)[0]
        print(np.array(src_embed).shape)
        
        src_pe_input = { self.pe.get_inputs()[0].name: src_embed }
        src_pe = self.pe.run(None, src_pe_input)[0]
        print(np.array(src_pe).shape)
        
        e_output_input = { self.encoder.get_inputs()[0].name: src_pe, self.encoder.get_inputs()[1].name: e_mask}
        e_output = self.encoder.run(None, e_output_input)[0]
        print(np.array(e_output).shape)
        
        if method == 'greedy':
            print("Greedy decoding selected.")
            result = self.greedy_search(e_output, e_mask, trg_sp)

        return result
        
    def greedy_search(self, e_output, e_mask, trg_sp):
        last_words = [pad_id] * seq_len
        last_words[0] = sos_id
        cur_len = 1

        for i in range(seq_len):
            lw_expand = np.expand_dims(last_words, axis=0)
            d_mask = np.expand_dims((lw_expand != pad_id), axis=1) # (1, 1, L)
            nopeak_mask = np.ones((1, seq_len, seq_len)).astype('bool')
            nopeak_mask = np.tril(nopeak_mask) # (1, L, L) to triangular shape
            d_mask = d_mask & nopeak_mask # (1, L, L) padding false

            trg_embed_input = { self.trg_embed.get_inputs()[0].name: lw_expand}
            trg_embed = self.trg_embed.run(None, trg_embed_input)[0]

            trg_pe_input = { self.pe.get_inputs()[0].name: trg_embed }
            trg_pe = self.pe.run(None, trg_pe_input)[0]

            decoder_input = {self.decoder.get_inputs()[0].name: trg_pe,
                                    self.decoder.get_inputs()[1].name: e_output,
                                    self.decoder.get_inputs()[2].name: e_mask,
                                    self.decoder.get_inputs()[3].name: d_mask}
            decoder_output = self.decoder.run(None, decoder_input)[0]

            linear_input = { self.linear.get_inputs()[0].name: decoder_output}
            linear_output = self.linear.run(None, linear_input)[0]

            softmax_input = { self.linear.get_inputs()[0].name: linear_output}
            softmax_output = self.linear.run(None, softmax_input)[0]

            output = np.argmax(softmax_output, axis=-1)
            last_word_id = output[0][i].item()

            if i < seq_len-1:
                last_words[i+1] = last_word_id
                cur_len += 1
            
            if last_word_id == eos_id:
                break

        if last_words[-1] == pad_id:
            decoded_output = last_words[1:cur_len]
        else:
            decoded_output = last_words[1:]
        decoded_output = trg_sp.decode_ids(decoded_output)
        
        return decoded_output

    def beam_search(self, e_output, e_mask, trg_sp):
        cur_queue = PriorityQueue()
        for k in range(beam_size):
            cur_queue.put(BeamNode(sos_id, -0.0, [sos_id]))
        
        finished_count = 0
        
        for pos in range(seq_len):
            new_queue = PriorityQueue()
            for k in range(beam_size):
                node = cur_queue.get()
                if node.is_finished:
                    new_queue.put(node)
                else:
                    trg_input = torch.LongTensor(node.decoded + [pad_id] * (seq_len - len(node.decoded))).to(device) # (L)
                    d_mask = (trg_input.unsqueeze(0) != pad_id).unsqueeze(1).to(device) # (1, 1, L)
                    nopeak_mask = torch.ones([1, seq_len, seq_len], dtype=torch.bool).to(device)
                    nopeak_mask = torch.tril(nopeak_mask) # (1, L, L) to triangular shape
                    d_mask = d_mask & nopeak_mask # (1, L, L) padding false
                    
                    output = self.model.decoder(
                        trg_input.unsqueeze(0),
                        e_output,
                        e_mask,
                        d_mask
                    ) # (1, L, trg_vocab_size)
                    
                    output = torch.topk(output[0][pos], dim=-1, k=beam_size)
                    last_word_ids = output.indices.tolist() # (k)
                    last_word_prob = output.values.tolist() # (k)
                    
                    for i, idx in enumerate(last_word_ids):
                        new_node = BeamNode(idx, -(-node.prob + last_word_prob[i]), node.decoded + [idx])
                        if idx == eos_id:
                            new_node.prob = new_node.prob / float(len(new_node.decoded))
                            new_node.is_finished = True
                            finished_count += 1
                        new_queue.put(new_node)
            
            cur_queue = copy.deepcopy(new_queue)
            
            if finished_count == beam_size:
                break
        
        decoded_output = cur_queue.get().decoded
        
        if decoded_output[-1] == eos_id:
            decoded_output = decoded_output[1:-1]
        else:
            decoded_output = decoded_output[1:]
            
        return trg_sp.decode_ids(decoded_output)

    
    

In [15]:
session = (src_embed, trg_embed, pe, encoder, decoder, linear, softmax)
translator = Translator(session)
translator.translate("Hello")

(1, 256)
(1, 256, 512)
(1, 256, 512)
(1, 256, 512)
Greedy decoding selected.


'nóiợp nóiợp nóiợp nói nóiợp nóiợp nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói nói n

In [None]:
def translate_onnx(img, session, max_seq_length=128, sos_token=1, eos_token=2):
    """data: BxCxHxW"""
    cnn_session, encoder_session, decoder_session = session
    
    # create cnn input
    cnn_input = {cnn_session.get_inputs()[0].name: img}
    src = cnn_session.run(None, cnn_input)
    
    # create encoder input
    encoder_input = {encoder_session.get_inputs()[0].name: src[0]}
    encoder_outputs, hidden = encoder_session.run(None, encoder_input)
    translated_sentence = [[sos_token] * len(img)]
    max_length = 0

    while max_length <= max_seq_length and not all(
        np.any(np.asarray(translated_sentence).T == eos_token, axis=1)
    ):
        tgt_inp = translated_sentence
        decoder_input = {decoder_session.get_inputs()[0].name: tgt_inp[-1], decoder_session.get_inputs()[1].name: hidden, decoder_session.get_inputs()[2].name: encoder_outputs}

        output, hidden, _ = decoder_session.run(None, decoder_input)
        output = np.expand_dims(output, axis=1)
        output = torch.Tensor(output)

        values, indices = torch.topk(output, 1)
        indices = indices[:, -1, 0]
        indices = indices.tolist()

        translated_sentence.append(indices)
        max_length += 1

        del output

    translated_sentence = np.asarray(translated_sentence).T

    return translated_sentence

In [16]:
def topk(array, k, axis=-1, sorted=True):
    # Use np.argpartition is faster than np.argsort, but do not return the values in order
    # We use array.take because you can specify the axis
    partitioned_ind = (
        np.argpartition(array, -k, axis=axis)
        .take(indices=range(-k, 0), axis=axis)
    )
    # We use the newly selected indices to find the score of the top-k values
    partitioned_scores = np.take_along_axis(array, partitioned_ind, axis=axis)
    
    if sorted:
        # Since our top-k indices are not correctly ordered, we can sort them with argsort
        # only if sorted=True (otherwise we keep it in an arbitrary order)
        sorted_trunc_ind = np.flip(
            np.argsort(partitioned_scores, axis=axis), axis=axis
        )
        
        # We again use np.take_along_axis as we have an array of indices that we use to
        # decide which values to select
        ind = np.take_along_axis(partitioned_ind, sorted_trunc_ind, axis=axis)
        scores = np.take_along_axis(partitioned_scores, sorted_trunc_ind, axis=axis)
    else:
        ind = partitioned_ind
        scores = partitioned_scores
    
    return scores, ind

import torch
import numpy as np

x = np.random.randn(50, 50, 10, 10)

axis = 2  # Change this to any axis and it'll be fine

val_np, ind_np = topk(x, k=10, axis=axis)

val_pt, ind_pt = torch.topk(torch.tensor(x), k=10, dim=axis)

print("Values are same:", np.all(val_np == val_pt.numpy()))
print("Indices are same:", np.all(ind_np == ind_pt.numpy()))

Values are same: True
Indices are same: True
