In [1]:
import string
import os
import random
import numpy as np
from unidecode import unidecode
from torch.utils.data import Dataset, DataLoader, BatchSampler
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data.dataset import random_split


In [2]:
NAMES_PATH = 'data/NameClassifyingPytorch/names'
ASCII = string.ascii_letters
ASCII_COUNT = len(ASCII)
COUNTRIES = os.listdir(NAMES_PATH)
BATCH_SIZE = 64

def char_onehot(c):
    if c not in ASCII:
        return []
    
    result = [0.0] * ASCII_COUNT    
    result[ASCII.index(c)] = 1.0
    return result

In [3]:
def ascii_name_onehot(name):
    name_onehot = []

    for c in name:
        c = char_onehot(c)
        if c:
            name_onehot.append(c)

    return name_onehot

In [4]:
# From path to a whole Dataset
class NameDataset(Dataset):
    def __init__(self, name_original, name_unicode, name_tensor, label):
        self.name_original = name_original  # Each row is a original name
        self.name_unicode = name_unicode    # Each row is a unicoded name
        self.name_tensor = name_tensor      # Each row is a name tensor
        self.label = label                  # Each row is a country label

    def __len__(self):
        return len(self.label)
    
    def __getitem__(self, idx):
        if isinstance(idx, list):
            return [self.name_original[i] for i in idx],\
                [self.name_unicode[i] for i in idx],\
                [self.name_tensor[i] for i in idx],\
                [self.label[i] for i in idx]
        
        return self.name_original[idx],\
                self.name_unicode[idx],\
                self.name_tensor[idx],\
                self.label[idx]

In [54]:
class NameRNN(nn.Module):
    def __init__(self, input_dim, output_dim, hidden_dim=256):
        super().__init__()
        self.rnn = nn.RNN(input_size=input_dim, hidden_size=hidden_dim, num_layers=1, nonlinearity='tanh', bias=True, batch_first=True)
        self.fcc = nn.Linear(in_features=hidden_dim, out_features=output_dim, bias=True)
        self.softmax = nn.LogSoftmax(dim=1)
    
    def forward(self, x):
        rnn_out, _ = self.rnn(x)
        rnn_out = rnn_out[:,-1,:] # batch_size, layer, dimensions
        rnn_logits = self.fcc(rnn_out)
        rnn_softmax = self.softmax(rnn_logits)
        return rnn_softmax

In [49]:
def get_dataset_dict(path):
    dataset_dict = {} # name_length to dataset
    
    for filename in os.listdir(path):
        with open(path + "/" + filename, encoding='utf-8') as f:
            for name in f.readlines():
                name = name.strip()
                nameunicode = unidecode(name.replace(" ", ""))
                name_onehot = ascii_name_onehot(nameunicode)
                if name_onehot:
                    name_len = len(name_onehot)
                    if name_len not in dataset_dict:
                        dataset_dict[len(name_onehot)] = [[], [], [], []]
                    
                    dataset_dict[name_len][0].append(name)
                    dataset_dict[name_len][1].append(nameunicode)
                    dataset_dict[name_len][2].append(name_onehot)
                    dataset_dict[name_len][3].append(COUNTRIES.index(filename))

    result = {}
    for length, ds in dataset_dict.items():
        ds[2] = torch.tensor(ds[2], dtype=torch.float)
        ds[3] = torch.tensor(ds[3], dtype=torch.long)
        result[length] = NameDataset(ds[0], ds[1], ds[2], ds[3])
    
    return result
# train_set, test_set = torch.utils.data.random_split(alldata, [.85, .15], 

In [50]:
dataset_dict = get_dataset_dict(NAMES_PATH)
dataloader_dict = {i: DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True) for i, ds in dataset_dict.items()}

In [55]:
epochs = 200
model = NameRNN(input_dim=ASCII_COUNT, output_dim=len(COUNTRIES))
optimizer = optim.Adam(model.parameters(), lr=1e-3)
for epoch in range(epochs):
    epoch_loss = 0
    for _, dataloader in dataloader_dict.items():
        dl_loss = 0
        for dl in dataloader:
            origin_name, unicode_name, onehot_name, country_label = dl            
            x_out = model(onehot_name)
            loss = F.cross_entropy(x_out, country_label, reduction='mean')
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            dl_loss += loss.item()
        epoch_loss += dl_loss / len(dataloader)
    if epoch % 10 == 0:
        print(f"Epoch={epoch}: Loss={epoch_loss}")

Epoch=0: Loss=26.920059372061147
Epoch=10: Loss=13.290016304908846
Epoch=20: Loss=6.065706416333093
Epoch=30: Loss=4.065374626176362
Epoch=40: Loss=4.545919353867724
Epoch=50: Loss=2.415995747396278
Epoch=60: Loss=5.273899072025392
Epoch=70: Loss=1.8179100722297281
Epoch=80: Loss=1.6721103832068798
Epoch=90: Loss=1.5358747446631724
Epoch=100: Loss=2.326447373587122
Epoch=110: Loss=1.4832171308222915
Epoch=120: Loss=1.3813077446240822
Epoch=130: Loss=1.2674762396364718
Epoch=140: Loss=1.2410983292330868
Epoch=150: Loss=1.5091312807385435
Epoch=160: Loss=1.2892449915147932
Epoch=170: Loss=1.3659690868831642
Epoch=180: Loss=1.9065759705607468
Epoch=190: Loss=2.2338998299310244


In [52]:
def make_prediction(name):
    model.eval()
    name = unidecode(name)
    name_onehot = ascii_name_onehot(name)
    name_onehot = torch.tensor(name_onehot)
    name_onehot = name_onehot.unsqueeze(0)  # A batch of one datapoint
    output = model(name_onehot)
    predicted_idx = output.argmax(dim=1)
    return COUNTRIES[predicted_idx.item()]

In [57]:
from collections import Counter

for file_name in os.listdir(NAMES_PATH):
    pred_countries = []
    with open(NAMES_PATH + "/" + file_name) as f:
        for line in f.readlines():
            name = line.strip().replace(" ", "")
            predicted_country = make_prediction(name)
            pred_countries.append(predicted_country)
    freq_dict = dict(Counter(pred_countries))
    if file_name == "German.txt":
        print(freq_dict)
    acc = freq_dict.get(file_name, 0) / len(pred_countries) * 100
    print(f"{file_name}: {acc:.2f}%")

Czech.txt: 96.53%
{'German.txt': 651, 'English.txt': 22, 'Irish.txt': 3, 'French.txt': 15, 'Vietnamese.txt': 1, 'Russian.txt': 5, 'Dutch.txt': 14, 'Czech.txt': 10, 'Arabic.txt': 1, 'Spanish.txt': 1, 'Korean.txt': 1}
German.txt: 89.92%
Arabic.txt: 97.65%
Japanese.txt: 99.29%
Chinese.txt: 88.43%
Vietnamese.txt: 82.19%
Russian.txt: 99.61%
French.txt: 90.61%
Irish.txt: 88.79%
English.txt: 95.28%
Spanish.txt: 79.87%
Greek.txt: 95.57%
Italian.txt: 98.03%
Portuguese.txt: 77.03%
Scottish.txt: 48.00%
Dutch.txt: 91.58%
Korean.txt: 89.36%
Polish.txt: 93.53%
