In [None]:
from architecture.transformer.transformer import create_transformer
from architecture.transformer import hparams
from helper import utils

from pythainlp.tokenize import word_tokenize
import torch
import numpy as np
import json
import collections

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Helper functions

In [None]:
def load_vocab(vocab_path, unk_idx=3):
    """Load the source/target vocabularies from a JSON file.

    The JSON must provide at least ``vocab['src']['stoi']``,
    ``vocab['src']['itos']`` and ``vocab['trg']['itos']`` (the keys read here).

    Args:
        vocab_path: Path to the vocab JSON file.
        unk_idx: Id returned for source tokens missing from
            ``vocab['src']['stoi']`` (presumably the ``<UNK>`` id -- TODO confirm).

    Returns:
        A tuple ``(vocab, src_size, trg_size)``:
        - ``vocab`` is the parsed dict, with ``vocab['src']['stoi']`` wrapped in a
          ``defaultdict``.
        - ``src_size`` and ``trg_size`` are the lengths of the source and target
          ``itos`` tables.
    """
    # The vocab may hold non-ASCII (e.g. Thai) tokens. Decode explicitly as UTF-8
    # instead of relying on the platform default encoding (cp1252 on Windows).
    with open(vocab_path, 'r', encoding='utf-8') as f:
        vocab = json.load(f)
    # Unknown source tokens map to `unk_idx` instead of raising KeyError.
    vocab['src']['stoi'] = collections.defaultdict(lambda: unk_idx, vocab['src']['stoi'])
    return vocab, len(vocab['src']['itos']), len(vocab['trg']['itos'])

def preprocess_input(query, stoi, device=None):
    """Tokenize raw text and encode it as a batch containing one id sequence.

    Args:
        query: Raw input string. It is split into words with pythainlp's
            ``word_tokenize``.
        stoi: Token-to-id mapping. It must contain ``'<SOS>'`` and ``'<EOS>'``.
            Other tokens are looked up with ``stoi[tok]``, so a ``defaultdict``
            (as built by ``load_vocab``) maps unknown tokens to its default id.
        device: Optional device for the returned tensor. ``None`` keeps it on
            the CPU, as before.

    Returns:
        A ``torch.long`` tensor of shape ``(1, n_tokens + 2)``: ``<SOS>``, the
        token ids, then ``<EOS>``.
    """
    tokens = word_tokenize(query)
    ids = [stoi['<SOS>']] + [stoi[tok] for tok in tokens] + [stoi['<EOS>']]
    # Build the tensor directly with an explicit integer dtype. Going through
    # np.array() picks a platform-dependent default int (int32 on Windows with
    # NumPy < 2), which would silently change the tensor dtype.
    return torch.tensor([ids], dtype=torch.long, device=device)

def gen_output(y_hat, itos=None, eos_idx=2, special_ids=(0, 1, 2)):
    """Decode the first sequence of a batch of predicted ids into a string.

    Decoding stops at the first ``eos_idx``. Ids in ``special_ids`` are
    dropped. The remaining tokens are concatenated with no separator.

    Args:
        y_hat: Indexable batch of predicted ids whose elements support
            ``.item()``. Only ``y_hat[0]`` is decoded (presumably shape
            ``(batch, seq)`` from ``model.predict`` -- TODO confirm).
        itos: Mapping from ``str(id)`` to a token string. Defaults to the
            notebook-level ``vocab['trg']['itos']`` for backward compatibility.
        eos_idx: Id that ends decoding (default 2, presumably ``<EOS>``).
        special_ids: Ids skipped in the output (default 0, 1, 2; presumably
            padding/start/end markers -- TODO confirm against the vocab).

    Returns:
        The decoded string.
    """
    if itos is None:
        # Fall back to the global vocab loaded elsewhere in the notebook.
        itos = vocab['trg']['itos']
    skip = set(special_ids)
    pieces = []
    for tok in y_hat[0]:
        idx = tok.item()  # convert once per token instead of once per check
        if idx == eos_idx:
            break
        if idx not in skip:
            pieces.append(itos[str(idx)])
    return ''.join(pieces)

# Model and vocab paths

In [None]:
# Both paths are relative, so they resolve against the notebook's working directory.
# The checkpoint name presumably encodes d_model-N-h-d_ff (512-6-8-2048).
# TODO confirm it matches architecture.transformer.hparams.
VOCAB_PATH, MODEL_PATH = (
    'data/dataset/vocab.json',
    'model/512-6-8-2048.pt',
)

# Load model and vocab

In [None]:
# Read the vocabularies; their sizes fix the model's input/output dimensions.
vocab, src_size, tgt_size = load_vocab(VOCAB_PATH)

# Architecture hyper-parameters come from the shared hparams module.
model_config = {
    'input_dim': src_size,
    'output_dim': tgt_size,
    'd_model': hparams.d_model,
    'N': hparams.N,
    'h': hparams.h,
    'd_ff': hparams.d_ff,
    # Presumably irrelevant at inference time once eval() is set -- TODO confirm.
    'dropout': 0.1,
    'device': device,
}
model = create_transformer(**model_config)

# Put the model in inference mode, then restore the trained weights onto `device`.
model.eval()
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
# Newer PyTorch versions accept weights_only=True for plain state dicts.
state_dict = torch.load(MODEL_PATH, map_location=device)
model.load_state_dict(state_dict)

# Generate text

In [None]:
# Input text for the model (placeholder -- replace before running).
text = '........'
# Passed to model.predict as max_len; presumably caps the number of generated tokens.
MAX_LEN = 20

# preprocess_input builds the tensor on the CPU, while the model was created on
# `device`. Move the input there so a CUDA run does not mix devices.
x = preprocess_input(text, vocab['src']['stoi']).to(device)

# Inference only: skip autograd bookkeeping to save memory and time.
with torch.no_grad():
    y_hat = model.predict(x, max_len=MAX_LEN)

gen_text = gen_output(y_hat)

print(gen_text)