In [1]:
%cd ..

/home/dinh.trong.huy/nmt-data-envija/transformer-translator-pytorch


In [11]:
from unicodedata import name
from tqdm import tqdm
from constants import *
from custom_data import *
from transformer import *
from data_structure import *
from torch import nn

In [12]:
model = Transformer(src_vocab_size=sp_src_vocab_size, trg_vocab_size=sp_trg_vocab_size, d_model=d_model).to(device)
checkpoint = torch.load(f"/home/dinh.trong.huy/nmt-data-envija/transformer-translator-pytorch/sq128_8_6_1_512_1024/javi2_mix_aug223k/ckpt_18_javi2.tar")
model.load_state_dict(checkpoint['model_state_dict'])

<All keys matched successfully>

In [13]:
def process_src(text_list):
    tokenized_list = []
    for text in tqdm(text_list):
        tokenized = src_sp.EncodeAsIds(text.strip())
        tokenized_list.append(len(tokenized)+1)

    return tokenized_list

In [14]:
def pad_or_truncate(tokenized_text):
    if len(tokenized_text) < seq_len:
        left = seq_len - len(tokenized_text)
        padding = [pad_id] * left
        tokenized_text += padding
    else:
        tokenized_text = tokenized_text[:seq_len]

    return tokenized_text

In [15]:
src_sp = spm.SentencePieceProcessor()
trg_sp = spm.SentencePieceProcessor()
src_sp.load(f"{SP_DIR}/{src_model_prefix}.model")
trg_sp.load(f"{SP_DIR}/{trg_model_prefix}.model")

True

In [16]:
input_sentence = "Hello my friend"
tokenized = src_sp.EncodeAsIds(input_sentence)
src = torch.LongTensor(pad_or_truncate(tokenized + [eos_id])).unsqueeze(0).to(device) # (1, L)
e_mask = (src != pad_id).unsqueeze(1).to(device) # (1, 1, L)


In [25]:
def convert_encoder(model, src, save_path):
    print(src.shape)
    e_output = model.encoder(src, e_mask) # (1, L, d_model)
    torch.onnx.export(model.encoder, 
                      (src, e_mask), 
                      save_path, 
                      export_params=True, 
                      opset_version=16, 
                      do_constant_folding=True, 
                      input_names=['src', 'e_mask'], 
                      output_names=['e_output']) 
    return e_output
    
e_output = convert_encoder(model, src, os.path.join(ONNX_DIR, "javi/encoder.onnx"))
e_output

torch.Size([1, 128])
verbose: False, log level: Level.ERROR



tensor([[[ 0.3900,  0.6810, -0.3820,  ...,  1.1953, -0.2641, -0.1094],
         [ 0.4117, -0.2298, -0.4929,  ...,  0.4729, -1.2949,  0.6192],
         [ 0.5429, -0.2892, -0.1731,  ...,  0.3445, -1.1674,  1.1886],
         ...,
         [-0.2749, -0.0511, -0.4211,  ...,  0.1454, -0.3272,  0.5160],
         [-0.1789, -0.1648, -0.6204,  ...,  0.2616, -0.2825,  0.4078],
         [-0.2815, -0.2188, -0.2826,  ...,  0.2891, -0.3639,  0.5546]]],
       device='cuda:1', grad_fn=<NativeLayerNormBackward0>)

In [26]:
last_words = torch.LongTensor([pad_id] * seq_len).to(device) # (L)
last_words[0] = sos_id # (L)
cur_len = 1

d_mask = (last_words.unsqueeze(0) != pad_id).unsqueeze(1).to(device) # (1, 1, L)
nopeak_mask = torch.ones([1, seq_len, seq_len], dtype=torch.bool).to(device)  # (1, L, L)
nopeak_mask = torch.tril(nopeak_mask)  # (1, L, L) to triangular shape
d_mask = d_mask & nopeak_mask  # (1, L, L) padding false

last_words_u = last_words.unsqueeze(0)

def convert_decoder(model, last_words_u, e_output, e_mask, d_mask, save_path): 
    
    decoder_output = model.decoder(
                last_words_u,
                e_output,
                e_mask,
                d_mask
            ) # (1, L, d_model) 
    torch.onnx.export(model.decoder, 
                      (last_words_u, e_output, e_mask, d_mask), 
                      save_path, 
                      export_params=True, 
                      opset_version=16, 
                      do_constant_folding=True, 
                      input_names=['last_words_u', 'e_output', 'e_mask', 'd_mask'], 
                      output_names=['decoder_output']) 
    return decoder_output

decoder_output = convert_decoder(model, last_words_u, e_output, e_mask, d_mask, os.path.join(ONNX_DIR, "javi/decoder.onnx"))
decoder_output

verbose: False, log level: Level.ERROR



tensor([[[-15.0229, -18.0776,  -8.5911,  ..., -18.0765, -18.0762, -18.0772],
         [  0.0000, -37.5369, -26.9227,  ..., -37.5363, -37.5363, -37.5370],
         [  0.0000, -37.6758, -27.1081,  ..., -37.6755, -37.6747, -37.6762],
         ...,
         [  0.0000, -37.4011, -26.6607,  ..., -37.4001, -37.4001, -37.4003],
         [  0.0000, -37.3517, -26.5201,  ..., -37.3515, -37.3512, -37.3516],
         [  0.0000, -37.3702, -26.6146,  ..., -37.3701, -37.3693, -37.3701]]],
       device='cuda:1', grad_fn=<LogSoftmaxBackward0>)

## Load and check model

In [27]:
import onnx
import os


encoder = onnx.load(os.path.join(ONNX_DIR, "javi/encoder.onnx"))
decoder = onnx.load(os.path.join(ONNX_DIR, "javi/decoder.onnx"))


In [28]:

onnx.checker.check_model(encoder)
onnx.checker.check_model(decoder)


In [29]:
onnx.helper.printable_graph(encoder.graph)

'graph torch_jit (\n  %src[INT64, 1x128]\n  %e_mask[BOOL, 1x1x128]\n) initializers (\n  %src_embedding.weight[FLOAT, 64000x512]\n  %layers.0.layer_norm_1.layer.weight[FLOAT, 512]\n  %layers.0.layer_norm_1.layer.bias[FLOAT, 512]\n  %layers.0.multihead_attention.w_q.bias[FLOAT, 512]\n  %layers.0.multihead_attention.w_k.bias[FLOAT, 512]\n  %layers.0.multihead_attention.w_v.bias[FLOAT, 512]\n  %layers.0.multihead_attention.w_0.bias[FLOAT, 512]\n  %layers.0.layer_norm_2.layer.weight[FLOAT, 512]\n  %layers.0.layer_norm_2.layer.bias[FLOAT, 512]\n  %layers.0.feed_forward.linear_1.bias[FLOAT, 1024]\n  %layers.0.feed_forward.linear_2.bias[FLOAT, 512]\n  %layers.1.layer_norm_1.layer.weight[FLOAT, 512]\n  %layers.1.layer_norm_1.layer.bias[FLOAT, 512]\n  %layers.1.multihead_attention.w_q.bias[FLOAT, 512]\n  %layers.1.multihead_attention.w_k.bias[FLOAT, 512]\n  %layers.1.multihead_attention.w_v.bias[FLOAT, 512]\n  %layers.1.multihead_attention.w_0.bias[FLOAT, 512]\n  %layers.1.layer_norm_2.layer.wei

In [30]:
import onnxruntime


encoder = onnxruntime.InferenceSession(os.path.join(ONNX_DIR, "javi/encoder.onnx"), providers=['CPUExecutionProvider'])
decoder = onnxruntime.InferenceSession(os.path.join(ONNX_DIR, "javi/decoder.onnx"), providers=['CPUExecutionProvider'])



In [31]:
import copy

class Translator():
    def __init__(self, session) -> None:
        self.encoder, self.decoder = session
        self.src_sp = spm.SentencePieceProcessor()
        self.trg_sp = spm.SentencePieceProcessor()
        self.src_sp.load(f"{SP_DIR}/{src_model_prefix}.model")
        self.trg_sp.load(f"{SP_DIR}/{trg_model_prefix}.model")
        

    def translate(self, input_sentence, method="greedy"):
        tokenized = self.src_sp.EncodeAsIds(input_sentence)
        src = np.expand_dims(pad_or_truncate(tokenized + [eos_id]), axis=0).astype('int64') # (1, L)
        e_mask = np.expand_dims((src != pad_id), axis=1) # (1, 1, L)
       
        e_output_input = { self.encoder.get_inputs()[0].name: src, self.encoder.get_inputs()[1].name: e_mask}
        e_output = self.encoder.run(None, e_output_input)[0]
        
        if method == "greedy":
            result = self.greedy_search(e_output, e_mask)
        elif method == 'beam':
        # print("Beam search selected.")
            result = self.beam_search(e_output, e_mask)
        else:
            raise ValueError("Method unsupported. Only support 'greedy' and 'beam' search")

        return result
        
    def greedy_search(self, e_output, e_mask):
        last_words = [pad_id] * seq_len
        last_words[0] = sos_id
        cur_len = 1

        for i in range(seq_len):
            lw_expand = np.expand_dims(last_words, axis=0)
            d_mask = np.expand_dims((lw_expand != pad_id), axis=1) # (1, 1, L)
            nopeak_mask = np.ones((1, seq_len, seq_len)).astype('bool')
            nopeak_mask = np.tril(nopeak_mask) # (1, L, L) to triangular shape
            d_mask = d_mask & nopeak_mask # (1, L, L) padding false

            decoder_input = {self.decoder.get_inputs()[0].name: lw_expand,
                                    self.decoder.get_inputs()[1].name: e_output,
                                    self.decoder.get_inputs()[2].name: e_mask,
                                    self.decoder.get_inputs()[3].name: d_mask}
            decoder_output = self.decoder.run(None, decoder_input)[0] # (1, L, trg_vocab_size)

            output = np.argmax(decoder_output, axis=-1)
            last_word_id = output[0][i].item()

            if i < seq_len-1:
                last_words[i+1] = last_word_id
                cur_len += 1
            
            if last_word_id == eos_id:
                break

        if last_words[-1] == pad_id:
            decoded_output = last_words[1:cur_len]
        else:
            decoded_output = last_words[1:]
        decoded_output = self.trg_sp.decode_ids(decoded_output)
        
        return decoded_output

    def beam_search(self, e_output, e_mask):
        cur_queue = PriorityQueue()
        for k in range(beam_size):
            cur_queue.put(BeamNode(sos_id, -0.0, [sos_id]))
        
        finished_count = 0
        
        for pos in range(seq_len):
            new_queue = PriorityQueue()
            for k in range(beam_size):
                node = cur_queue.get()
                if node.is_finished:
                    new_queue.put(node)
                else:
                    trg_input = node.decoded + [pad_id] * (seq_len - len(node.decoded)) # (L)
                    trg_input_expand = np.expand_dims(trg_input, axis=0) # (1, L)
                    d_mask = np.expand_dims((trg_input_expand != pad_id), axis=1) # (1, 1, L)
                    nopeak_mask = np.ones((1, seq_len, seq_len)).astype('bool')
                    nopeak_mask = np.tril(nopeak_mask) # (1, L, L) to triangular shape
                    d_mask = d_mask & nopeak_mask # (1, L, L) padding false
                    
                    decoder_input = {self.decoder.get_inputs()[0].name: trg_input_expand,
                                    self.decoder.get_inputs()[1].name: e_output,
                                    self.decoder.get_inputs()[2].name: e_mask,
                                    self.decoder.get_inputs()[3].name: d_mask}
                    output = self.decoder.run(None, decoder_input)[0] # (1, L, trg_vocab_size)

                    # output = self.model.decoder(
                    #     trg_input_expand,
                    #     e_output,
                    #     e_mask,
                    #     d_mask
                    # ) # (1, L, trg_vocab_size)
                    
                    output_prob, output_ind = self.topk(output[0][pos], k=beam_size, axis=-1)
                    last_word_ids = output_ind.tolist() # (k)
                    last_word_prob = output_prob.tolist() # (k)
                    
                    for i, idx in enumerate(last_word_ids):
                        new_node = BeamNode(idx, -(-node.prob + last_word_prob[i]), node.decoded + [idx])
                        if idx == eos_id:
                            new_node.prob = new_node.prob / float(len(new_node.decoded))
                            new_node.is_finished = True
                            finished_count += 1
                        new_queue.put(new_node)
            
            cur_queue = copy.deepcopy(new_queue)
            
            if finished_count == beam_size:
                break
        
        decoded_output = cur_queue.get().decoded
        
        if decoded_output[-1] == eos_id:
            decoded_output = decoded_output[1:-1]
        else:
            decoded_output = decoded_output[1:]
            
        return self.trg_sp.decode_ids(decoded_output)

    def pad_or_truncate(self, tokenized_text):
        if len(tokenized_text) < seq_len:
            left = seq_len - len(tokenized_text)
            padding = [pad_id] * left
            tokenized_text += padding
        else:
            tokenized_text = tokenized_text[:seq_len]

        return tokenized_text

    def topk(self, array, k, axis=-1, sorted=True):
        # Use np.argpartition is faster than np.argsort, but do not return the values in order
        # We use array.take because you can specify the axis
        partitioned_ind = (
            np.argpartition(array, -k, axis=axis)
            .take(indices=range(-k, 0), axis=axis)
        )
        # We use the newly selected indices to find the score of the top-k values
        partitioned_scores = np.take_along_axis(array, partitioned_ind, axis=axis)
        
        if sorted:
            # Since our top-k indices are not correctly ordered, we can sort them with argsort
            # only if sorted=True (otherwise we keep it in an arbitrary order)
            sorted_trunc_ind = np.flip(
                np.argsort(partitioned_scores, axis=axis), axis=axis
            )
            
            # We again use np.take_along_axis as we have an array of indices that we use to
            # decide which values to select
            ind = np.take_along_axis(partitioned_ind, sorted_trunc_ind, axis=axis)
            scores = np.take_along_axis(partitioned_scores, sorted_trunc_ind, axis=axis)
        else:
            ind = partitioned_ind
            scores = partitioned_scores
        
        return scores, ind

    
    

In [32]:
session = (encoder, decoder)
translator = Translator(session)
translator.translate("私の名前はYamada", method="beam")

'Tên tôi là yamada'