In [11]:
import torch
import glob
import os
import unicodedata
import string
import torch.nn as nn
import random

### Preprocessing


In [12]:
# Alphabet the model understands: ASCII letters plus a few punctuation marks.
all_letters = string.ascii_letters + " .,;'"
# Width of every one-hot letter vector.
n_letters = len(all_letters)

def unicode2ascii(name):
    """Reduce a Unicode string to the characters found in ``all_letters``.

    The input is NFD-decomposed so that accents become separate combining
    marks (category ``Mn``); those marks are dropped, as is every other
    character that is not part of ``all_letters``.
    """
    kept = []
    for ch in unicodedata.normalize('NFD', name):
        if unicodedata.category(ch) == 'Mn':
            continue  # combining accent split off by the NFD decomposition
        if ch in all_letters:
            kept.append(ch)
    return ''.join(kept)

In [13]:
# Maps each language (file name without its extension) to the list of
# ASCII-folded names read from that file.
langs_names = dict()
# glob returns paths in arbitrary order; sorting keeps the language order,
# and therefore the class indices derived from it, stable between runs.
for path in sorted(glob.glob('data/names/*.txt')):
    # Context manager closes the file; explicit UTF-8 avoids depending on
    # the platform's default encoding for the accented names.
    with open(path, encoding='utf-8') as names_file:
        # One name per line: splitting on lines (not on any whitespace)
        # keeps multi-word names such as "De la Fuente" intact.
        raw_names = names_file.read().strip().splitlines()
    names = [unicode2ascii(raw_name).strip() for raw_name in raw_names]
    # Drop entries that fold to nothing; an empty name cannot be encoded.
    names = [name for name in names if name]
    # splitext removes only the extension, whereas str.strip('.txt') strips
    # any leading/trailing '.', 't' or 'x' characters from the language name.
    langs_names[os.path.splitext(os.path.basename(path))[0]] = names

n_langs = len(langs_names)
all_langs = list(langs_names.keys())

### Neural network

In [14]:
class RNN(nn.Module):
    """Single-step recurrent cell that emits log-probabilities.

    Each call concatenates the current input with the previous hidden state
    and feeds the result through two linear layers: one produces the next
    hidden state, the other produces the class scores, which are passed
    through a log-softmax over dimension 1.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size

        # Both layers consume the concatenated (input, hidden) vector.
        joint_size = input_size + hidden_size
        self.i2h = nn.Linear(joint_size, hidden_size)
        self.i2o = nn.Linear(joint_size, output_size)

        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden):
        """Advance one step; return ``(log_probs, next_hidden)``."""
        joined = torch.cat([input, hidden], dim=1)
        next_hidden = self.i2h(joined)
        log_probs = self.softmax(self.i2o(joined))
        return log_probs, next_hidden

    def init_hidden(self):
        """Return an all-zero hidden state of shape ``(1, hidden_size)``."""
        return torch.zeros((1, self.hidden_size))

### Helpers

In [15]:
def letter2tensor(letter, alphabet=None):
    """One-hot encode a single letter as a ``(1, len(alphabet))`` tensor.

    Args:
        letter: the character to encode.
        alphabet: string of known characters; defaults to ``all_letters``.

    Returns:
        A tensor of zeros with a single 1 at the letter's position in
        ``alphabet``. A letter that is not in ``alphabet`` yields an all-zero
        row.
    """
    if alphabet is None:
        alphabet = all_letters
    tensor = torch.zeros(1, len(alphabet))
    index = alphabet.find(letter)
    # str.find signals a miss with -1; used as an index that would silently
    # mark the LAST alphabet entry, so unknown letters are left as zeros.
    if index >= 0:
        tensor[0][index] = 1
    return tensor

In [16]:
def name2tensor(name, alphabet=None):
    """Encode a name as a ``(len(name), 1, len(alphabet))`` one-hot tensor.

    Args:
        name: the string to encode, one row per character.
        alphabet: string of known characters; defaults to ``all_letters``.

    Returns:
        A tensor of zeros in which row ``i`` has a single 1 at the position
        of ``name[i]`` in ``alphabet``. Characters that are not in
        ``alphabet`` produce an all-zero row.
    """
    if alphabet is None:
        alphabet = all_letters
    tensor = torch.zeros(len(name), 1, len(alphabet))
    for i, letter in enumerate(name):
        index = alphabet.find(letter)
        # str.find signals a miss with -1; used as an index that would
        # silently mark the LAST alphabet entry, so skip unknown letters.
        if index >= 0:
            tensor[i][0][index] = 1
    return tensor

In [17]:
def random_train_examples():
    """Draw a random language, then a random name of that language.

    Returns:
        ``(lang, name)``: a key of ``langs_names`` and one entry of its list.
    """
    picked_lang = random.choice(all_langs)
    candidates = langs_names[picked_lang]
    return picked_lang, random.choice(candidates)

### Training

In [18]:
# Negative log-likelihood loss; the RNN already outputs log-probabilities
# (LogSoftmax), which is the input NLLLoss expects.
criterion = nn.NLLLoss()
# Step size of the manual gradient-descent update performed in train().
learning_rate = 5e-3

In [19]:
def train(lang_tensor, name_tensor):
    """Run one gradient-descent step on a single (language, name) example.

    The name is fed to the global ``rnn`` one slice at a time along
    dimension 0; the loss is computed from the output after the last slice.

    Args:
        lang_tensor: target passed to ``criterion`` alongside the output.
        name_tensor: encoded name; ``name_tensor[i]`` is the input for step i.

    Returns:
        ``(output, loss)``: the network output after the final step and the
        loss value as a Python float.

    Raises:
        ValueError: if ``name_tensor`` has no steps (there would be no
            output to compute a loss from).
    """
    if name_tensor.size(0) == 0:
        raise ValueError('name_tensor must contain at least one letter')

    hidden = rnn.init_hidden()

    rnn.zero_grad()

    for i in range(name_tensor.size(0)):
        output, hidden = rnn(name_tensor[i], hidden)

    loss = criterion(output, lang_tensor)
    loss.backward()

    # Manual SGD update. Done under no_grad so the in-place change is not
    # tracked by autograd, and with the keyword `alpha` form of add_ (the
    # positional add_(scalar, tensor) overload is deprecated).
    with torch.no_grad():
        for p in rnn.parameters():
            if p.grad is not None:
                p.add_(p.grad, alpha=-learning_rate)

    return output, loss.item()
    

In [68]:
# Training configuration.
n_iters = 100000      # number of single-example training steps
print_every = 5000    # progress message interval, in steps
n_hidden = 256        # size of the RNN hidden state

rnn = RNN(n_letters, n_hidden, n_langs)

# The counter is called `step` rather than `iter` so the builtin iter() is
# not shadowed for every later cell of the notebook.
for step in range(1, n_iters + 1):
    lang, name = random_train_examples()
    # NLLLoss target: the class index of the language, as a 1-element tensor.
    lang_tensor = torch.tensor([all_langs.index(lang)], dtype=torch.long)
    name_tensor = name2tensor(name)
    output, loss = train(lang_tensor, name_tensor)

    if step % print_every == 0:
        print('%d iter' % step)

5000 iter
10000 iter
15000 iter
20000 iter
25000 iter
30000 iter
35000 iter
40000 iter
45000 iter
50000 iter
55000 iter
60000 iter
65000 iter
70000 iter
75000 iter
80000 iter
85000 iter
90000 iter
95000 iter
100000 iter


In [69]:
def evaluate(name_tensor):
    """Run the global ``rnn`` over an encoded name without tracking gradients.

    Args:
        name_tensor: encoded name; ``name_tensor[i]`` is the input for step i.

    Returns:
        The network output after the final step.

    Raises:
        ValueError: if ``name_tensor`` has no steps (there would be no
            output to return).
    """
    if name_tensor.size(0) == 0:
        raise ValueError('name_tensor must contain at least one letter')

    hidden = rnn.init_hidden()

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        for i in range(name_tensor.size(0)):
            output, hidden = rnn(name_tensor[i], hidden)

    return output

In [73]:
# Number of sampled names; also the divisor of the accuracy below, so the
# two can no longer drift apart.
n_eval = 10000
correct = 0
wrong = 0
# NOTE: samples come from random_train_examples(), the same pool used for
# training, so this measures accuracy on training data.
with torch.no_grad():
    for i in range(n_eval):
        lang, name = random_train_examples()
        name_tensor = name2tensor(name)
        output = evaluate(name_tensor)
        # Index of the highest-scoring language.
        topv, topi = output.topk(1, 1, True)
        lang_i = topi[0].item()
        if lang_i == all_langs.index(lang):
            correct += 1
        else:
            wrong += 1

# Accuracy in percent (last expression, so the notebook displays it).
correct * 100 / n_eval

60.41

In [71]:
def predict(input_name, n_predictions=3):
    """Print and return the most likely languages for ``input_name``.

    The name is first folded with ``unicode2ascii`` (the same normalisation
    applied to the training names) so accented input is encoded the way the
    model saw it during training.

    Args:
        input_name: the name to classify.
        n_predictions: how many top guesses to report; capped at the number
            of output classes.

    Returns:
        A list of ``[value, language]`` pairs, best guess first, where
        ``value`` is the network's output score for that language.

    Raises:
        ValueError: if no character of ``input_name`` survives normalisation.
    """
    print('\n> %s' % input_name)
    ascii_name = unicode2ascii(input_name)
    if not ascii_name:
        raise ValueError(
            'input_name has no characters from all_letters: %r' % input_name)

    predictions = []
    with torch.no_grad():
        output = evaluate(name2tensor(ascii_name))

        # topk raises if asked for more entries than there are classes.
        k = min(n_predictions, output.size(1))
        topv, topi = output.topk(k, 1, True)

        for i in range(k):
            value = topv[0][i].item()
            lang_index = topi[0][i].item()
            print('(%.2f) %s' % (value, all_langs[lang_index]))
            predictions.append([value, all_langs[lang_index]])

    # The list was previously built and then discarded; hand it to the caller.
    return predictions

In [76]:
# Show the top three language guesses (predict's default) for a sample name.
predict('Satoshi')


> Satoshi
(-1.43) Arabic
(-1.51) Japanese
(-2.04) Polish
