In [1]:
%cd ..

/home/dinh.trong.huy/nmt-data-envija/transformer-translator-pytorch


In [3]:
import torch
import sentencepiece as spm
import numpy as np
from constants import *
from data_structure import *
from tqdm import tqdm

In [5]:
def process_src(text_list):
    tokenized_list = []
    for text in tqdm(text_list):
        tokenized = src_sp.EncodeAsIds(text.strip())
        tokenized_list.append(len(tokenized)+1)

    return tokenized_list

In [5]:
def pad_or_truncate(tokenized_text):
    if len(tokenized_text) < seq_len:
        left = seq_len - len(tokenized_text)
        padding = [pad_id] * left
        tokenized_text += padding
    else:
        tokenized_text = tokenized_text[:seq_len]

    return tokenized_text

In [4]:
src_sp = spm.SentencePieceProcessor()
trg_sp = spm.SentencePieceProcessor()
src_sp.load(f"{SP_DIR}/{src_model_prefix}.model")
trg_sp.load(f"{SP_DIR}/{trg_model_prefix}.model")

True

In [6]:
input_sentence = "Hello my friend"
tokenized = src_sp.EncodeAsIds(input_sentence)
src = torch.LongTensor(pad_or_truncate(tokenized + [eos_id])).unsqueeze(0).to(device) # (1, L)
e_mask = (src != pad_id).unsqueeze(1).to(device) # (1, 1, L)


In [None]:
def convert_encoder(model, src, save_path):
    print(src.shape)
    e_output = model.encoder(src, e_mask) # (1, L, d_model)
    torch.onnx.export(model.encoder, 
                      (src, e_mask), 
                      save_path, 
                      export_params=True, 
                      opset_version=16, 
                      do_constant_folding=True, 
                      input_names=['src', 'e_mask'], 
                      output_names=['e_output']) 
    return e_output
    
e_output = convert_encoder(model, src, os.path.join(ONNX_DIR, "encoder.onnx"))
e_output

In [8]:
last_words = torch.LongTensor([pad_id] * seq_len).to(device) # (L)
last_words[0] = sos_id # (L)
cur_len = 1

d_mask = (last_words.unsqueeze(0) != pad_id).unsqueeze(1).to(device) # (1, 1, L)
nopeak_mask = torch.ones([1, seq_len, seq_len], dtype=torch.bool).to(device)  # (1, L, L)
nopeak_mask = torch.tril(nopeak_mask)  # (1, L, L) to triangular shape
d_mask = d_mask & nopeak_mask  # (1, L, L) padding false

last_words_u = last_words.unsqueeze(0)

def convert_decoder(model, last_words_u, e_output, e_mask, d_mask, save_path): 
    
    decoder_output = model.decoder(
                last_words_u,
                e_output,
                e_mask,
                d_mask
            ) # (1, L, d_model) 
    torch.onnx.export(model.decoder, 
                      (last_words_u, e_output, e_mask, d_mask), 
                      save_path, 
                      export_params=True, 
                      opset_version=16, 
                      do_constant_folding=True, 
                      input_names=['last_words_u', 'e_output', 'e_mask', 'd_mask'], 
                      output_names=['decoder_output']) 
    return decoder_output

decoder_output = convert_decoder(model, last_words_u, e_output, e_mask, d_mask, os.path.join(ONNX_DIR, "encoder.onnx"))
decoder_output

verbose: False, log level: Level.ERROR



tensor([[[-1.7998e+01, -1.9108e+01, -7.0844e+00,  ..., -1.9106e+01,
          -1.9105e+01, -1.9106e+01],
         [-1.1921e-07, -3.3358e+01, -2.0526e+01,  ..., -3.3356e+01,
          -3.3356e+01, -3.3357e+01],
         [-1.1921e-07, -3.3367e+01, -2.0561e+01,  ..., -3.3365e+01,
          -3.3365e+01, -3.3366e+01],
         ...,
         [-2.3842e-07, -3.3306e+01, -2.0390e+01,  ..., -3.3304e+01,
          -3.3304e+01, -3.3305e+01],
         [-2.3842e-07, -3.3319e+01, -2.0412e+01,  ..., -3.3318e+01,
          -3.3317e+01, -3.3318e+01],
         [-2.3842e-07, -3.3333e+01, -2.0431e+01,  ..., -3.3331e+01,
          -3.3331e+01, -3.3332e+01]]], grad_fn=<LogSoftmaxBackward0>)

## Load and check model

In [6]:
import onnx
import os


encoder = onnx.load(os.path.join(ONNX_DIR, "encoder.onnx"))
decoder = onnx.load(os.path.join(ONNX_DIR, "decoder.onnx"))


In [7]:

onnx.checker.check_model(encoder)
onnx.checker.check_model(decoder)


In [None]:
onnx.helper.printable_graph(encoder.graph)

In [9]:
import onnxruntime


encoder = onnxruntime.InferenceSession(os.path.join(ONNX_DIR, "encoder.onnx"))
decoder = onnxruntime.InferenceSession(os.path.join(ONNX_DIR, "decoder.onnx"))



In [30]:
import copy

class Translator():
    def __init__(self, session) -> None:
        self.encoder, self.decoder = session
        self.src_sp = spm.SentencePieceProcessor()
        self.trg_sp = spm.SentencePieceProcessor()
        self.src_sp.load(f"{SP_DIR}/{src_model_prefix}.model")
        self.trg_sp.load(f"{SP_DIR}/{trg_model_prefix}.model")
        

    def translate(self, input_sentence, method="greedy"):
        tokenized = self.src_sp.EncodeAsIds(input_sentence)
        src = np.expand_dims(pad_or_truncate(tokenized + [eos_id]), axis=0).astype('int64') # (1, L)
        e_mask = np.expand_dims((src != pad_id), axis=1) # (1, 1, L)
       
        e_output_input = { self.encoder.get_inputs()[0].name: src, self.encoder.get_inputs()[1].name: e_mask}
        e_output = self.encoder.run(None, e_output_input)[0]
        
        if method == "greedy":
            result = self.greedy_search(e_output, e_mask)
        elif method == 'beam':
        # print("Beam search selected.")
            result = self.beam_search(e_output, e_mask)
        else:
            raise ValueError("Method unsupported. Only support 'greedy' and 'beam' search")

        return result
        
    def greedy_search(self, e_output, e_mask):
        last_words = [pad_id] * seq_len
        last_words[0] = sos_id
        cur_len = 1

        for i in range(seq_len):
            lw_expand = np.expand_dims(last_words, axis=0)
            d_mask = np.expand_dims((lw_expand != pad_id), axis=1) # (1, 1, L)
            nopeak_mask = np.ones((1, seq_len, seq_len)).astype('bool')
            nopeak_mask = np.tril(nopeak_mask) # (1, L, L) to triangular shape
            d_mask = d_mask & nopeak_mask # (1, L, L) padding false

            decoder_input = {self.decoder.get_inputs()[0].name: lw_expand,
                                    self.decoder.get_inputs()[1].name: e_output,
                                    self.decoder.get_inputs()[2].name: e_mask,
                                    self.decoder.get_inputs()[3].name: d_mask}
            decoder_output = self.decoder.run(None, decoder_input)[0] # (1, L, trg_vocab_size)

            output = np.argmax(decoder_output, axis=-1)
            last_word_id = output[0][i].item()

            if i < seq_len-1:
                last_words[i+1] = last_word_id
                cur_len += 1
            
            if last_word_id == eos_id:
                break

        if last_words[-1] == pad_id:
            decoded_output = last_words[1:cur_len]
        else:
            decoded_output = last_words[1:]
        decoded_output = self.trg_sp.decode_ids(decoded_output)
        
        return decoded_output

    def beam_search(self, e_output, e_mask):
        cur_queue = PriorityQueue()
        for k in range(beam_size):
            cur_queue.put(BeamNode(sos_id, -0.0, [sos_id]))
        
        finished_count = 0
        
        for pos in range(seq_len):
            new_queue = PriorityQueue()
            for k in range(beam_size):
                node = cur_queue.get()
                if node.is_finished:
                    new_queue.put(node)
                else:
                    trg_input = node.decoded + [pad_id] * (seq_len - len(node.decoded)) # (L)
                    trg_input_expand = np.expand_dims(trg_input, axis=0) # (1, L)
                    d_mask = np.expand_dims((trg_input_expand != pad_id), axis=1) # (1, 1, L)
                    nopeak_mask = np.ones((1, seq_len, seq_len)).astype('bool')
                    nopeak_mask = np.tril(nopeak_mask) # (1, L, L) to triangular shape
                    d_mask = d_mask & nopeak_mask # (1, L, L) padding false
                    
                    decoder_input = {self.decoder.get_inputs()[0].name: trg_input_expand,
                                    self.decoder.get_inputs()[1].name: e_output,
                                    self.decoder.get_inputs()[2].name: e_mask,
                                    self.decoder.get_inputs()[3].name: d_mask}
                    output = self.decoder.run(None, decoder_input)[0] # (1, L, trg_vocab_size)

                    # output = self.model.decoder(
                    #     trg_input_expand,
                    #     e_output,
                    #     e_mask,
                    #     d_mask
                    # ) # (1, L, trg_vocab_size)
                    
                    output_prob, output_ind = self.topk(output[0][pos], k=beam_size, axis=-1)
                    last_word_ids = output_ind.tolist() # (k)
                    last_word_prob = output_prob.tolist() # (k)
                    
                    for i, idx in enumerate(last_word_ids):
                        new_node = BeamNode(idx, -(-node.prob + last_word_prob[i]), node.decoded + [idx])
                        if idx == eos_id:
                            new_node.prob = new_node.prob / float(len(new_node.decoded))
                            new_node.is_finished = True
                            finished_count += 1
                        new_queue.put(new_node)
            
            cur_queue = copy.deepcopy(new_queue)
            
            if finished_count == beam_size:
                break
        
        decoded_output = cur_queue.get().decoded
        
        if decoded_output[-1] == eos_id:
            decoded_output = decoded_output[1:-1]
        else:
            decoded_output = decoded_output[1:]
            
        return self.trg_sp.decode_ids(decoded_output)

    def pad_or_truncate(self, tokenized_text):
        if len(tokenized_text) < seq_len:
            left = seq_len - len(tokenized_text)
            padding = [pad_id] * left
            tokenized_text += padding
        else:
            tokenized_text = tokenized_text[:seq_len]

        return tokenized_text

    def topk(self, array, k, axis=-1, sorted=True):
        # Use np.argpartition is faster than np.argsort, but do not return the values in order
        # We use array.take because you can specify the axis
        partitioned_ind = (
            np.argpartition(array, -k, axis=axis)
            .take(indices=range(-k, 0), axis=axis)
        )
        # We use the newly selected indices to find the score of the top-k values
        partitioned_scores = np.take_along_axis(array, partitioned_ind, axis=axis)
        
        if sorted:
            # Since our top-k indices are not correctly ordered, we can sort them with argsort
            # only if sorted=True (otherwise we keep it in an arbitrary order)
            sorted_trunc_ind = np.flip(
                np.argsort(partitioned_scores, axis=axis), axis=axis
            )
            
            # We again use np.take_along_axis as we have an array of indices that we use to
            # decide which values to select
            ind = np.take_along_axis(partitioned_ind, sorted_trunc_ind, axis=axis)
            scores = np.take_along_axis(partitioned_scores, sorted_trunc_ind, axis=axis)
        else:
            ind = partitioned_ind
            scores = partitioned_scores
        
        return scores, ind

    
    

In [31]:
session = (encoder, decoder)
translator = Translator(session)
translator.translate("Hello my name is John", method="beam")

'Xin chào tên tôi là John'