In [1]:
import os

In [2]:
HOME_PATH = './'

WEIGHT_PATH = os.path.join(HOME_PATH, 'weight')
if not os.path.isdir(WEIGHT_PATH):
    os.makedirs(WEIGHT_PATH)

## Model

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [4]:
from tqdm import tqdm

In [5]:
class Model(nn.Module):
    def __init__(self, vocab_size, embeding_size, hidden_size):
        super(Model, self).__init__()
        self.embeding = nn.Embedding(vocab_size, embeding_size)
        self.lstm = nn.LSTM(embeding_size, hidden_size, num_layers=2, batch_first=True, dropout=0.1)
        self.linear = nn.Linear(hidden_size, vocab_size)
        
    def forward(self, x, hidden_state=None):
        # x: BxS
        x = self.embeding(x) # BxSxE
        if hidden_state is None:
            x, hidden_state = self.lstm(x) # BxSx2H
        else:
            x, hidden_state = self.lstm(x, hidden_state) # BxSx2H
        x = F.relu(x)
        x = self.linear(x) # BxSxV
        return x, hidden_state
    
    def predict(self, x, hidden_state=None):
        x, hidden_state = self.forward(x, hidden_state)
        x = F.softmax(x, dim=-1) # BxSxV
        return x, hidden_state

In [15]:
tokenizer = torch.load(os.path.join(WEIGHT_PATH, 'vocab.h5'))

In [24]:
vocab_size = len(tokenizer.word_index) + 1
embedding_size = 200
hidden_size = 256

In [18]:
model = Model(vocab_size, embedding_size, hidden_size)

In [27]:
weight_param = torch.load(os.path.join(WEIGHT_PATH, 'model.h5'))
model.load_state_dict(weight_param)
model.eval()

Model(
  (embeding): Embedding(90, 200)
  (lstm): LSTM(200, 256, num_layers=2, batch_first=True, dropout=0.1)
  (linear): Linear(in_features=256, out_features=90, bias=True)
)

In [25]:
def update_topk(topk_seq, topk_score, new_hs, new_cs, probs, k, ix):
    # topk_seq: kxS
    # topk_score: kx1
    # topk_hs, topk_cs: ?xkx?
    # probs: kx1xV
    topk_probs, topk_ix = probs[:,-1,:].topk(k)
    topk_log = topk_probs.log()
    new_scores = topk_log + topk_score
    
    k_probs, k_ix = new_scores.view(-1).topk(k)
    row = k_ix // k
    col = k_ix % k
    
    topk_seq[:, :ix] = topk_seq[row, :ix]
    topk_seq[:, ix] = topk_ix[row, col]
    
    topk_hs = new_hs[:,row,:]
    topk_cs = new_cs[:,row,:]
    
    topk_score = k_probs.unsqueeze(1)
    return topk_seq, topk_score, topk_hs, topk_cs

In [65]:
def beam_search(start, model, k=3, max_len=30, device=-1):
    seq = tokenizer.texts_to_sequences([start]) # 1xS
    seq = torch.tensor(seq).long()
    topk_seq = torch.zeros(k, max_len).long()
#     topk_seq[:,:len(start)] = seq[0,:]
    prob, (hs, cs) = model.predict(seq) # 1xSxV, hs
    prob_k, idx_k = prob[:,-1,:].topk(k=k, dim=-1)
#     cix = len(start)
    cix = 0
    topk_seq[:,cix] = idx_k[0]
    cix+=1
    topk_score = torch.zeros(k, 1)
    topk_hs = torch.zeros(hs.size(0), k, hs.size(2))
    topk_cs = torch.zeros(cs.size(0), k, cs.size(2))
    topk_hs[:,:,:] = hs[:,-1,:].unsqueeze(1)
    topk_cs[:,:,:] = cs[:,-1,:].unsqueeze(1)
    
    res_seq = []
    res_score = []
    
    for i in range(cix, max_len):
        probs, (hs, cs) = model.predict(topk_seq[:,i-1].unsqueeze(-1), (topk_hs, topk_cs)) # 1xSxV, hs
        topk_seq, topk_score, topk_hs, topk_cs = update_topk(topk_seq, topk_score, hs, cs, probs, k, i)

        eos_tok = tokenizer.word_index[' ']
        eos_ix = (topk_seq==eos_tok).nonzero()
        
        seq_end_ix = eos_ix[:,0].numpy()
        seq_notend_ix = [t for t in range(k) if t not in seq_end_ix]
        
        if len(seq_end_ix) > 0:
            _seq = topk_seq[seq_end_ix]
            for s in _seq:
                res_seq.append(s.numpy())
            
            _score = topk_score[seq_end_ix]
            for s in _score:
                res_score.append(s.numpy()[0])
            
            topk_score = topk_score[seq_notend_ix].contiguous()
            topk_seq = topk_seq[seq_notend_ix].contiguous()
            topk_hs = topk_hs[:,seq_notend_ix,:].contiguous()
            topk_cs = topk_cs[:,seq_notend_ix,:].contiguous()
            
            k -= len(seq_end_ix)
            
        if k==0:
            break
       
    
    return res_seq, res_score

In [70]:
start = '^nguyễn tr'
with torch.no_grad():
    
    seq, sc = beam_search(start, model, k=10)

for s in convert_to_text(seq):print(start+s)

^nguyễn trà 
^nguyễn trí 
^nguyễn trần 
^nguyễn trờng 
^nguyễn trọng 
^nguyễn trình 
^nguyễn trung 
^nguyễn trường 
^nguyễn trungng 
^nguyễn trrường 


ần 
ờng 
ọng 
ường 
ungng 


In [57]:
sc

[-0.008477796, -0.7744576, -0.027882854, -0.24207884, -1.5635265]

In [45]:
def seq2text(seq):
    text = ''
    for s in seq:
        if s==0:
            break
        text += tokenizer.index_word[s]
    return text

In [48]:
def convert_to_text(seqs):
    res = []
    for s in seqs:
        t = seq2text(s)
        res.append(t)
    return res

In [None]:
def simple(start='^', max_len=25, end_char=' '):
    seq = tokenizer.texts_to_sequences([start])
    seq = torch.tensor(seq)
    prob, hs = model.predict(seq)
    name = start
    ix = prob[:,-1,:].argmax(dim=-1).item()
    name += tokenizer.index_word[ix]
    end_token = tokenizer.word_index[end_char]
    for i in range(len(start), max_len):
        seq = torch.tensor([[ix]])
        prob, hs = model.predict(seq, hs)
        ix = prob[:,-1,:].argmax(dim=-1).item()
        if ix == end_token:
            break
        name += tokenizer.index_word[ix]

    return name

In [None]:
def gen_name():
    name = ''
    seq = np.zeros((1,20,1), dtype=np.float)
    pred = model.predict(seq)
    ix = np.random.choice(list(range(NUM_CHAR)), p=pred.ravel())
    name = name + ix_to_char[ix]
    while len(name) < 50:
        seq = preprocess(name)
        pred = model.predict(seq)
        ix = np.random.choice(list(range(NUM_CHAR)), p=pred.ravel())
        name = name + ix_to_char[ix]
        if ix == 0:
            break
    return name 

In [None]:
simple('^nguyễn vă')

'^nguyễn văn'

In [None]:
def predict_next_char(inp, k=5):
    seq = tokenizer.texts_to_sequences([inp])
    seq = torch.tensor(seq)
    prob, hs = model.predict(seq)
    kprob, kix = prob[0,-1,:].topk(k)
    res = {tokenizer.index_word[i]:p for p,i in zip(kprob.detach().numpy(), kix.detach().numpy())}
    return res


In [None]:
def predict_next_char(s, k=10):
    seq = tokenizer.texts_to_sequences([s])
    seq = torch.tensor(seq)
    prob, hs = model.predict(seq)
    kprob, kix = prob[0,-1,:].topk(k)
    for p, i in zip(kprob.detach().numpy(), kix.detach().numpy()):
        print(tokenizer.index_word[i], '%.4f'%(p))

In [None]:
predict_next_char('^')

n 0.6883
u 0.0304
t 0.0278
m 0.0231
d 0.0207
h 0.0201
p 0.0184
y 0.0183
l 0.0156
o 0.0123


In [None]:
VOCAB_PATH = os.path.join(HOME_PATH, 'vocab.h5')

torch.save(tokenizer, VOCAB_PATH)

In [None]:
state = torch.load(os.path.join(WEIGHT_PATH, 'epoch_99.h5'))

In [None]:
model.load_state_dict(state['model'])

IncompatibleKeys(missing_keys=[], unexpected_keys=[])

In [None]:
state.keys()

dict_keys(['model', 'optim'])

In [None]:
t = state['model']