In [1]:
import string
import os
import random
import numpy as np
from unidecode import unidecode
from torch.utils.data import Dataset, DataLoader, BatchSampler
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data.dataset import random_split
from collections import Counter


In [2]:
NAMES_PATH = 'data/NameClassifyingPytorch/names'
ASCII = string.ascii_letters
ASCII_COUNT = len(ASCII)
COUNTRIES = os.listdir(NAMES_PATH)
BATCH_SIZE = 64

def char_onehot(c):
    if c not in ASCII:
        return []
    
    result = [0.0] * ASCII_COUNT    
    result[ASCII.index(c)] = 1.0
    return result

In [3]:
def ascii_name_onehot(name):
    name_onehot = []

    for c in name:
        c = char_onehot(c)
        if c:
            name_onehot.append(c)

    return name_onehot

In [4]:
# From path to a whole Dataset
class NameDataset(Dataset):
    def __init__(self, name_original, name_unicode, name_tensor, label):
        self.name_original = name_original  # Each row is a original name
        self.name_unicode = name_unicode    # Each row is a unicoded name
        self.name_tensor = name_tensor      # Each row is a name tensor
        self.label = label                  # Each row is a country label

    def __len__(self):
        return len(self.label)
    
    def __getitem__(self, idx):
        if isinstance(idx, list):
            return [self.name_original[i] for i in idx],\
                [self.name_unicode[i] for i in idx],\
                [self.name_tensor[i] for i in idx],\
                [self.label[i] for i in idx]
        
        return self.name_original[idx],\
                self.name_unicode[idx],\
                self.name_tensor[idx],\
                self.label[idx]

In [5]:
class NameRNN(nn.Module):
    def __init__(self, input_dim, output_dim, hidden_dim=256):
        super().__init__()
        self.rnn = nn.RNN(input_size=input_dim, hidden_size=hidden_dim, num_layers=1, nonlinearity='tanh', bias=True, batch_first=True)
        self.fcc = nn.Linear(in_features=hidden_dim, out_features=output_dim, bias=True)
        self.softmax = nn.LogSoftmax(dim=1)
    
    def forward(self, x):
        rnn_out, _ = self.rnn(x)
        rnn_out = rnn_out[:,-1,:] # batch_size, layer, dimensions
        rnn_logits = self.fcc(rnn_out)
        rnn_softmax = self.softmax(rnn_logits)
        return rnn_softmax

In [14]:
class NameGRU(nn.Module):
    def __init__(self, input_dim, output_dim, hidden_dim=256):
        super().__init__()
        self.gru = nn.GRU(input_size=input_dim, hidden_size=hidden_dim, num_layers=1, bias=True, batch_first=True)
        self.fcc = nn.Linear(in_features=hidden_dim, out_features=output_dim, bias=True)
        self.softmax = nn.LogSoftmax(dim=1)
    
    def forward(self, x):
        gru_out, _ = self.gru(x)
        gru_out = gru_out[:,-1,:] # batch_size, layer, dimensions
        gru_logits = self.fcc(gru_out)
        gru_softmax = self.softmax(gru_logits)
        return gru_softmax

In [7]:
def get_dataset_dict(path):
    dataset_dict = {} # name_length to dataset
    
    for filename in os.listdir(path):
        with open(path + "/" + filename, encoding='utf-8') as f:
            for name in f.readlines():
                name = name.strip()
                nameunicode = unidecode(name.replace(" ", ""))
                name_onehot = ascii_name_onehot(nameunicode)
                if name_onehot:
                    name_len = len(name_onehot)
                    if name_len not in dataset_dict:
                        dataset_dict[len(name_onehot)] = [[], [], [], []]
                    
                    dataset_dict[name_len][0].append(name)
                    dataset_dict[name_len][1].append(nameunicode)
                    dataset_dict[name_len][2].append(name_onehot)
                    dataset_dict[name_len][3].append(COUNTRIES.index(filename))

    result = {}
    for length, ds in dataset_dict.items():
        ds[2] = torch.tensor(ds[2], dtype=torch.float)
        ds[3] = torch.tensor(ds[3], dtype=torch.long)
        result[length] = NameDataset(ds[0], ds[1], ds[2], ds[3])
    
    return result
# train_set, test_set = torch.utils.data.random_split(alldata, [.85, .15], 

In [8]:
dataset_dict = get_dataset_dict(NAMES_PATH)
dataloader_dict = {i: DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True) for i, ds in dataset_dict.items()}

In [9]:
def train(model):
    epochs = 200
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    for epoch in range(epochs):
        epoch_loss = 0
        for _, dataloader in dataloader_dict.items():
            dl_loss = 0
            for dl in dataloader:
                origin_name, unicode_name, onehot_name, country_label = dl            
                x_out = model(onehot_name)
                loss = F.cross_entropy(x_out, country_label, reduction='mean')
                
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                dl_loss += loss.item()
            epoch_loss += dl_loss / len(dataloader)
        if epoch % 10 == 0:
            print(f"Epoch={epoch}: Loss={epoch_loss}")

In [10]:
def make_prediction(model, name):
    model.eval()
    name = unidecode(name)
    name_onehot = ascii_name_onehot(name)
    name_onehot = torch.tensor(name_onehot)
    name_onehot = name_onehot.unsqueeze(0)  # A batch of one datapoint
    output = model(name_onehot)
    predicted_idx = output.argmax(dim=1)
    return COUNTRIES[predicted_idx.item()]

In [11]:
def eval(model):
    for file_name in os.listdir(NAMES_PATH):
        pred_countries = []
        with open(NAMES_PATH + "/" + file_name) as f:
            for line in f.readlines():
                name = line.strip().replace(" ", "")
                predicted_country = make_prediction(model, name)
                pred_countries.append(predicted_country)
        freq_dict = dict(Counter(pred_countries))
        if file_name == "German.txt":
            print(freq_dict)
        acc = freq_dict.get(file_name, 0) / len(pred_countries) * 100
        print(f"{file_name}: {acc:.2f}%")

In [12]:
model = NameRNN(input_dim=ASCII_COUNT, output_dim=len(COUNTRIES))
train(model)
eval(model)

Epoch=0: Loss=29.466061256787633
Epoch=10: Loss=12.59384837373893
Epoch=20: Loss=7.039466823320813
Epoch=30: Loss=4.759995821461458
Epoch=40: Loss=2.7148205555454643
Epoch=50: Loss=2.2154061262300995
Epoch=60: Loss=1.8845877567073925
Epoch=70: Loss=1.8495852750742126
Epoch=80: Loss=1.4358441896432825
Epoch=90: Loss=1.4376517851140058
Epoch=100: Loss=1.483886699791049
Epoch=110: Loss=1.8674178134771795
Epoch=120: Loss=1.187523109188684
Epoch=130: Loss=6.488517623916673
Epoch=140: Loss=1.3353331479924058
Epoch=150: Loss=1.6194514183626527
Epoch=160: Loss=1.2919664841359806
Epoch=170: Loss=1.6317933917027188
Epoch=180: Loss=1.9819259822478283
Epoch=190: Loss=2.0955305625179914
Czech.txt: 76.30%
{'German.txt': 611, 'Russian.txt': 30, 'Dutch.txt': 6, 'English.txt': 57, 'French.txt': 8, 'Czech.txt': 8, 'Italian.txt': 1, 'Arabic.txt': 1, 'Greek.txt': 1, 'Portuguese.txt': 1}
German.txt: 84.39%
Arabic.txt: 86.80%
Japanese.txt: 93.54%
Chinese.txt: 85.45%
Vietnamese.txt: 82.19%
Russian.txt: 97.72

In [15]:
model = NameGRU(input_dim=ASCII_COUNT, output_dim=len(COUNTRIES))
train(model)
eval(model)

Epoch=0: Loss=26.738926375940018
Epoch=10: Loss=6.128811904692457
Epoch=20: Loss=3.1435142928273456
Epoch=30: Loss=2.0501439740400214
Epoch=40: Loss=1.5239607063984502
Epoch=50: Loss=1.2511156156344683
Epoch=60: Loss=1.139568912266375
Epoch=70: Loss=1.0275797698468845
Epoch=80: Loss=1.0744358611053677
Epoch=90: Loss=1.1332311138819713
Epoch=100: Loss=0.9249245838845859
Epoch=110: Loss=0.9096556757481913
Epoch=120: Loss=0.9099932961688206
Epoch=130: Loss=0.8713164116664696
Epoch=140: Loss=0.9617373589619271
Epoch=150: Loss=0.8914289252196089
Epoch=160: Loss=0.8614266479861493
Epoch=170: Loss=0.8676715695721071
Epoch=180: Loss=0.8046213090564028
Epoch=190: Loss=0.871670348164656
Czech.txt: 95.95%
{'German.txt': 670, 'French.txt': 4, 'Russian.txt': 4, 'Vietnamese.txt': 1, 'English.txt': 21, 'Polish.txt': 1, 'Korean.txt': 1, 'Arabic.txt': 1, 'Dutch.txt': 15, 'Chinese.txt': 2, 'Spanish.txt': 1, 'Czech.txt': 3}
German.txt: 92.54%
Arabic.txt: 100.00%
Japanese.txt: 99.50%
Chinese.txt: 94.40%
V