In [1]:
import torch
import torch.nn as nn

In [2]:
filepath = 'vanznames.txt'

In [3]:
with open(filepath, 'r') as f:
    names_raw = list(map(lambda x: x.replace('\n', ''), f.readlines()))
    
all_letters = list(set("".join(names_raw)))
n_letters = len(all_letters) + 1 # including EOS

def encode_name(name):
    return [all_letters.index(s) for s in name]

names_enc = [encode_name(name) for name in names_raw]

In [4]:
print(n_letters)
print(names_enc[0])

159
[51, 129, 28, 67, 30, 44, 19, 4, 96, 1, 4, 57, 4, 41, 57, 51, 4, 44, 141, 11]


In [5]:
# define network
class VanzNet(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, dor=0.15, bid=True):
        super(VanzNet, self).__init__()
        self.hidden_size = hidden_size
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers=3
                            , dropout=dor, bidirectional=bid)
        if bid:
            input_mult = 2
        else:
            input_mult = 1
        self.lstm2o = nn.Linear(hidden_size*input_mult, output_size)
        self.softmax = nn.LogSoftmax(dim=2)
    
    def forward(self, input, h, c):
        if h is None:            
            output_lstm, (h_out, c_out) = self.lstm(input)
        else:
            output_lstm, (h_out, c_out)  = self.lstm(input, (h, c))
        output_fc = self.lstm2o(output_lstm)
        output = self.softmax(output_fc)
        return output, h_out, c_out

In [6]:
# one-hot matrix for characters
def name2input(name_enc):
    """    
    dim: [len_name * 1 * n_letters]
    e.g. KASPAROV -> [OneHot(K), OneHot(A), ..., OneHot(V)]
    """
    tensor = torch.zeros(len(name_enc), 1, n_letters)
    for i, n in enumerate(name_enc):
        tensor[i][0][n] = 1
    return tensor

def name2target(name_enc):
    """
    dim: [len_name]
    e.g. target(KASPAROV) -> ASPAROV<EOS> -> [Idx(A), Idx(S), ..., Idx(EOS)]
    """
    return torch.LongTensor(name_enc[1:] + [n_letters - 1])

In [7]:
hidden_size = 128
rnn = VanzNet(n_letters, hidden_size, n_letters)

In [8]:
lr = 0.001
epochs = 200
print_every = 2
max_training_size = -1
lr_decay_step = 5
lr_decay_rate = 0.88

if max_training_size > 0:
    names_train = names_enc[:max_training_size]
else:
    names_train = names_enc

criterion = nn.NLLLoss()
optim = torch.optim.Adam(rnn.parameters(), lr=lr)

scheduler = torch.optim.lr_scheduler.StepLR(optim
                                            , lr_decay_step
                                            , lr_decay_rate)
losses = []
for epoch in range(epochs):
    loss_epoch = 0    
    for name in names_train:
        input_tensor = name2input(name)        
        target_tensor = name2target(name)
        target_tensor.unsqueeze_(-1)
        optim.zero_grad()
        loss = 0
        h, c = None, None
        for i in range(input_tensor.size(0)):
            output, h, c = rnn(input_tensor[i].view(1,1, -1), h, c)
            # print(output.size(), target_tensor[i].size())
            loss += criterion(output.view(1, -1), target_tensor[i])
            loss_epoch += loss
        loss_epoch /= input_tensor.size(0)
        loss.backward()
        optim.step()
    
    scheduler.step()
    losses.append(loss_epoch / len(names_train))
    
    if (epoch + 1) % print_every == 0:
        print("%d/%d: Loss %f" % (epoch+1, epochs, loss_epoch))
        # print(output.topk(1)[1])
        # print(target_tensor)
        # print("----------------")

2/200: Loss 1009775.000000
4/200: Loss 861443.687500
6/200: Loss 757746.687500
8/200: Loss 663345.562500
10/200: Loss 583207.250000
12/200: Loss 507072.718750
14/200: Loss 451895.968750
16/200: Loss 402872.937500
18/200: Loss 356772.750000
20/200: Loss 326195.437500
22/200: Loss 298238.875000


KeyboardInterrupt: 

In [16]:
def reconstruct_char(output):
    topv, topi = output.topk(1)
    idx = topi[0][0][0].item()
    # print(idx)
    if idx == n_letters - 1:
        # return all_letters[topi[0][0][1].item()]
        return 'EOS'
    else:
        return all_letters[idx]

def sample_name(start_char):
    name_sample = start_char
    if not (start_char in all_letters):
        return "Invalid start character!"
    else:
        start_char_enc = [all_letters.index(start_char)]
        input_tensor = name2input(start_char_enc)
        out_char = ""
        h, c = None, None
        
        while True:
            output, h, c = rnn(input_tensor[0].view(1,1,-1), h, c)
            out_char = reconstruct_char(output)
            if out_char == 'EOS':
                break
            name_sample += out_char
            input_tensor = name2input([all_letters.index(out_char)])
    
    return name_sample

In [17]:
start_letters = ['ก', 'ค', 'ม', 'จ', 'ว', 'ร', 'บ']
# start_letters = all_letters

with torch.no_grad():
    for sp in start_letters:
        print(sample_name(sp))

กิตติศักดิ์ กันทะวัง
คน บ้า
มินิ กองฟาง
จั้ก บุญจอง
วัน ทูกแป้ว
รักกันเมื่อยังหาย ใจ
บอย ก็กี
