# Classifying Names with a Character-Level RNN
This code is based on [this](http://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.html) tutorial by Sean Robertson.
First download the data from [here](https://download.pytorch.org/tutorial/data.zip) and unzip it into `$HOME/Download` (or `$HOME/Downloads`), so that the name files end up in `$HOME/Download/data/names/*.txt`.

## Imports and Definitions

In [None]:
import os
import string
import unicodedata
from collections import namedtuple

import math
import numpy as np
import pandas as pd

from glob import glob
from os.path import splitext, basename, exists

import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader

import time
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from IPython.display import display, clear_output

# Move tensors and the model to the GPU whenever one is usable.
cuda: bool = torch.cuda.is_available()

# The model's alphabet: the 26 lowercase ASCII letters. Other characters are
# dropped by uni2ascii or encoded as all-zero rows by tensor_from_string.
all_letters: str = string.ascii_lowercase
n_letters: int = len(string.ascii_lowercase)  # width of every one-hot letter vector

## Helper Functions

In [None]:
# Reduce a Unicode string to lowercase ASCII letters (approach from http://stackoverflow.com/a/518232/2809427).
def uni2ascii(s):
    """Return *s* lowercased and restricted to the characters in ``all_letters``.

    NFD normalization splits accented characters into a base letter followed
    by combining marks (Unicode category 'Mn'). The marks, and any character
    not in ``all_letters`` (digits, spaces, apostrophes, ...), are discarded.
    """
    decomposed = unicodedata.normalize('NFD', s).lower()
    kept = []
    for ch in decomposed:
        if unicodedata.category(ch) == 'Mn':
            continue  # combining accent left over from NFD
        if ch not in all_letters:
            continue  # outside the model's alphabet
        kept.append(ch)
    return ''.join(kept)
print(uni2ascii('Ślusàrski'))

In [None]:
# Read a surname file (one name per line) and normalize each name to ASCII.
def read_surnames(filename):
    """Return the surnames in *filename*, normalized with ``uni2ascii``.

    The file is read as UTF-8, one surname per line. Names that are empty
    after normalization are skipped. An empty name would become a zero-length
    sequence, which ``pack_padded_sequence`` rejects and ``RNN.single``
    cannot index.
    """
    # Use a context manager so the file handle is closed deterministically.
    with open(filename, encoding='utf-8') as handle:
        raw_lines = handle.read().strip().split('\n')
    normalized = (uni2ascii(sur) for sur in raw_lines)
    return [sur for sur in normalized if sur]

In [None]:
def timeSince(since):
    """Format the wall-clock time elapsed since the timestamp *since* as 'Mm Ss'.

    *since* is a value previously returned by ``time.time()``. Minutes and
    seconds are truncated to whole numbers.
    """
    minutes, seconds = divmod(time.time() - since, 60)
    return '%dm %ds' % (minutes, seconds)

In [None]:
# Map every letter of all_letters to its position, e.g. "a" -> 0.
index_by_letter = dict(zip(all_letters, range(len(all_letters))))


def letterToIndex(letter):
    """Return the position of *letter* in ``all_letters`` (``str.find`` semantics: -1 if absent)."""
    return all_letters.find(letter)

# Demonstration helper: one-hot encode a single letter as a <1 x n_letters> tensor.
def tensor_from_letter(letter):
    """Return a (1, n_letters) one-hot tensor for *letter*, ignoring case.

    Raises KeyError if the lowercased letter is not in ``index_by_letter``.
    """
    one_hot = torch.zeros(1, n_letters)
    position = index_by_letter[letter.lower()]
    one_hot[0, position] = 1
    return one_hot

# Encode a string as a <len(string) x 1 x n_letters> stack of one-hot letter vectors.
def tensor_from_string(string):
    """One-hot encode *string* (case-insensitive) along its first axis.

    Characters missing from ``index_by_letter`` (e.g. the spaces used for
    right-padding) leave their row all zeros.
    """
    encoded = torch.zeros(len(string), 1, n_letters)
    for pos, ch in enumerate(string.lower()):
        col = index_by_letter.get(ch)
        if col is not None:
            encoded[pos, 0, col] = 1
    return encoded

# j = tensor_from_letter('J')
# jones = tensor_from_string('Jones')
# jones

In [None]:
def language_from_proba(proba, languages=None):
    """Return ``(language_name, index)`` for the top-scoring class of the first row of *proba*.

    Args:
        proba: 2-D scores/log-probabilities; only row 0 is inspected.
        languages: index -> name lookup; defaults to the global ``all_languages``.

    Returns:
        The winning language name and its index as a plain Python ``int``
        (``max_idx[0][0]`` is a 0-dim tensor on current PyTorch).
    """
    if languages is None:
        languages = all_languages
    # topk(1) along the last dim; [0][0] = best class of the first sample.
    _, top_idx = proba.data.topk(1)
    lang = int(top_idx[0][0])
    return languages[lang], lang

## Dataset

In [None]:
_example = namedtuple('example', 'surname language')

class Example(_example):
    """An immutable (surname, language) pair with helpers that build model inputs/targets.

    ``language`` defaults to None so unlabeled surnames can be wrapped too.
    """

    def __new__(cls, surname, language=None):
        return super().__new__(cls, surname, language)

    def features(self, rpad=0):
        """One-hot encode the surname, right-padded with *rpad* spaces.

        Spaces are not in the alphabet, so padding rows are all zeros. The
        tensor is moved to the GPU when ``cuda`` is set.
        """
        res = tensor_from_string(self.surname + ' ' * rpad)
        return res.cuda() if cuda else res

    @property
    def target(self):
        """Index of ``language`` in ``all_languages`` as a 1-element LongTensor.

        Raises:
            ValueError: if the example has no language label (an explicit
                check, unlike ``assert``, survives ``python -O``).
        """
        if self.language is None:
            raise ValueError('Example %r has no language label' % (self.surname,))
        res = torch.LongTensor([all_languages.index(self.language)])
        return res.cuda() if cuda else res

    def __len__(self):
        """Number of characters in the surname (the sequence length)."""
        return len(self.surname)

    def to_dict(self):
        return {'features': Variable(self.features()), 'target': self.target}

    @classmethod
    def from_namedtuple(cls, s):
        """Build an Example from any object exposing ``.surname`` and ``.language``."""
        return cls(s.surname, s.language)

    def _repr_html_(self):
        """Jupyter HTML repr; values are escaped because the result is rendered as HTML."""
        import html
        disp = html.escape(self.surname)
        if self.language:
            disp += " (language: %s)" % html.escape(self.language)
        return disp

In [None]:
class SurnameDataset(Dataset):
    """Surname dataset by language/country.

    Items are drawn at random: first a language uniformly from the languages
    present, then a surname uniformly within that language. This balances
    classes regardless of how many surnames each language has, so the ``idx``
    passed to ``__getitem__`` is ignored. ``len`` is the total number of
    surnames, so one DataLoader pass draws that many random examples.
    """

    def __init__(self, data):
        """Build the dataset from a ``{language: [surname, ...]}`` mapping."""
        series = pd.concat({
            target: pd.Series(features)
            for target, features in data.items()
        })
        df = pd.DataFrame(series)
        df.columns = ['surname']
        df.index.names = ['language', 'number']
        df.reset_index(inplace=True)
        del df['number']
        df['example'] = df.apply(Example.from_namedtuple, axis=1)
        self.languages = np.unique(df.language)
        self.df = df
        self.surnames_by_language = data

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # idx is deliberately ignored; see the class docstring.
        return self.random_example()
#         return self.df.example.iloc[idx]

    def _repr_html_(self):
        return self.df._repr_html_()

    def rand(self, n):
        """Return a uniform random integer in ``[0, n)``."""
        # numpy's `high` is exclusive: `n` (not `n - 1`) is needed to reach the
        # last item, and avoids a ValueError when n == 1.
        return np.random.randint(0, n)

    def _random_pair(self):
        """Draw ``(language, surname)``: a uniform language, then a uniform surname in it."""
        lang = self.languages[self.rand(len(self.languages))]
        surnames = self.surnames_by_language[lang]
        return lang, surnames[self.rand(len(surnames))]

    def random_index(self):
        """Return the ``self.df`` row label of a randomly drawn (language, surname)."""
        lang, surn = self._random_pair()
        df = self.df  # the original referenced an undefined global `df`
        matches = df[(df.language == lang) & (df.surname == surn)]
        return matches.index[0]

    def random_example(self):
        """Return a randomly drawn, language-balanced ``Example``."""
        lang, surn = self._random_pair()
        return Example(surn, lang)

In [None]:
def collate(example_list):
    """DataLoader ``collate_fn``: turn a list of Examples into a packed batch.

    Returns ``{'features': PackedSequence, 'target': LongTensor}``. Examples
    are sorted longest-first, as ``pack_padded_sequence`` expects by default.
    Shorter ones are right-padded with all-zero rows up to the longest
    length. Targets follow the same sorted order.
    """
    batch_examples = sorted(example_list, key=len, reverse=True)
    seq_lengths = [len(ex) for ex in batch_examples]
    longest = seq_lengths[0]
    padded = [ex.features(longest - n) for ex, n in zip(batch_examples, seq_lengths)]
    # features() gives (len, 1, n_letters); stacking gives (batch, len, 1, n_letters),
    # so squeeze axis 2 to get the batch_first (batch, len, n_letters) layout.
    stacked = Variable(torch.stack(padded)).squeeze(2)
    return {
        'features': torch.nn.utils.rnn.pack_padded_sequence(stacked, seq_lengths, batch_first=True),
        'target': torch.stack([ex.target for ex in batch_examples]),
    }

In [None]:
class MySampler(torch.utils.data.sampler.Sampler):
    """Sampler that yields ``len(data_source)`` indices from ``data_source.random_index()``.

    Each index is drawn independently, so repeats are possible.
    """

    def __init__(self, data_source):
        self.data_source = data_source

    def __len__(self):
        return len(self.data_source)

    def __iter__(self):
        for _ in range(len(self)):
            yield self.data_source.random_index()

In [None]:
class RandomSampler(torch.utils.data.sampler.Sampler):
    """Samples elements randomly, without replacement.

    Arguments:
        data_source (Dataset): dataset to sample from
    """

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        # tolist() yields plain Python ints. Iterating the tensor directly would
        # yield 0-dim tensors, which are poor dataset indices.
        return iter(torch.randperm(len(self.data_source)).tolist())

    def __len__(self):
        return len(self.data_source)

## Network

In [None]:
class RNN(nn.Module):
    """Single-layer ``nn.RNN`` over letter vectors, followed by a linear layer and log-softmax.

    Args:
        input_size: size of each input vector (the alphabet size).
        hidden_size: number of hidden units.
        output_size: number of output classes (languages).
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
        self.linear = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=-1)  # log-probabilities, paired with NLLLoss

    def forward(self, X, hidden):
        """Classify a packed batch; returns (batch, output_size) log-probabilities.

        X must be a PackedSequence (it is unpacked with ``pad_packed_sequence``)
        and batch_first. ``hidden`` is the initial state, e.g. from ``init_hidden``.
        Each sequence is classified from its last non-padded output.
        """
        packed_out, _ = self.rnn(X, hidden)
        padded, lengths = torch.nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)
        # The second value is per-sequence lengths; step length-1 is the last real letter.
        last_steps = [seq[n - 1] for seq, n in zip(padded, lengths)]
        return self.softmax(self.linear(torch.stack(last_steps)))

    def single(self, x, hidden):
        """Classify one unpacked sequence from its final time step.

        x.shape = torch.size([1, letters_in_word, num_letters])
        """
        outputs, _ = self.rnn(x, hidden)
        return self.softmax(self.linear(outputs[:, -1]))

    def init_hidden(self, batch_size=1):
        """Zero initial state of shape (1, batch_size, hidden_size), on the GPU if ``cuda``."""
        h0 = Variable(torch.zeros(1, batch_size, self.hidden_size))
        if cuda:
            h0 = h0.cuda()
        return h0

## Training the net

In [None]:
# Locate the folder containing the unzipped tutorial data (data/names/*.txt).
# Both ~/Download and ~/Downloads are accepted. expanduser also works where
# $HOME is unset (e.g. Windows), and `candidate` avoids shadowing builtin `dir`.
for candidate in ('Download', 'Downloads'):
    path = os.path.join(os.path.expanduser('~'), candidate)
    if os.path.isdir(os.path.join(path, 'data', 'names')):
        break
else:
    # Fail loudly here instead of silently globbing an empty/nonexistent folder.
    raise FileNotFoundError(
        'Could not find data/names under ~/Download or ~/Downloads; '
        'download https://download.pytorch.org/tutorial/data.zip and unzip it there.'
    )

In [None]:
# Map each language (the file's base name, e.g. "Arabic") to its normalized surnames.
# glob() returns files in filesystem-dependent order. Sorting makes
# all_languages, and therefore every language's class index, reproducible
# across machines and runs.
surnames_by_language = {
    splitext(basename(filename))[0]: read_surnames(filename)
    for filename in sorted(glob(os.path.join(path, 'data/names/*.txt')))
}

all_languages = list(surnames_by_language.keys())
n_languages = len(all_languages)
n_languages

In [None]:
# Classifier: one-hot letters in, one log-probability per language out.
n_hidden = 128
rnn = RNN(n_letters, n_hidden, n_languages)
rnn = rnn.cuda() if cuda else rnn

# Dataset that serves random, language-balanced examples.
dset = SurnameDataset(surnames_by_language)

In [None]:
log_likelihood = nn.NLLLoss()
learning_rate = 0.01
def train(example):
    """Run one manual SGD step on the global ``rnn``; return ``(log_probs, loss_value)``.

    *example* is a dict with ``'features'`` (a PackedSequence, as built by
    ``collate``) and ``'target'`` (class indices with a trailing singleton
    axis, as built by ``collate``).
    """
    features, target = example['features'], Variable(example['target'])
    batch_size = example['target'].shape[0]
    rnn.zero_grad()
    logits = rnn(features, rnn.init_hidden(batch_size))
    loss = log_likelihood(logits, target.squeeze(1))
    loss.backward()
    # NLLLoss already averages over the batch (reduction='mean'), so the step is
    # not divided by batch_size again. This matches the SGD optimizer used in
    # the training loop.
    for p in rnn.parameters():
        if p.grad is not None:
            p.data.add_(p.grad.data, alpha=-learning_rate)
    # loss is a 0-dim tensor: .item() replaces `loss.data[0]`, which fails on current PyTorch.
    return logits, loss.item()

In [None]:
# Training schedule.
epochs = 50
epoch = 0
# Not read by the batched training loop (which reports once per epoch).
print_every = 5000
plot_every = 1000

# Loss accumulators; all_losses collects one value per epoch for plotting.
current_loss = 0
loss = 0
all_losses = []

In [None]:
# Mini-batches of randomly drawn surnames, packed by `collate`.
# MySampler(dset) or RandomSampler(dset) can be supplied via `sampler=` instead.
batch_size = 16
dataloader = DataLoader(
    dset,
    batch_size=batch_size,
    num_workers=0,
    collate_fn=collate,
)

In [None]:
# Plain SGD over every network parameter, using the module-level learning_rate.
opt = torch.optim.SGD(params=rnn.parameters(), lr=learning_rate)

In [None]:
start = time.time()

for epoch in range(1, epochs + 1):
    # Sample-weighted loss total for this epoch. It is kept as a plain float:
    # accumulating the loss tensor itself would keep every batch's autograd
    # graph alive for the whole epoch.
    current_loss = 0.0
    n_seen = 0
    for example in dataloader:
        if cuda:
            example['features'] = example['features'].cuda()
            example['target'] = example['target'].cuda()
        this_batch_size = example['target'].shape[0]
        opt.zero_grad()
        logits = rnn(example['features'], rnn.init_hidden(this_batch_size))
        loss = log_likelihood(logits, Variable(example['target']).squeeze(1))
        loss.backward()
        opt.step()

        # loss is a batch mean; weight by the real batch size (the last batch may be short).
        current_loss += loss.item() * this_batch_size
        n_seen += this_batch_size

    # Print epoch number, last batch loss, and a prediction for one random surname.
    E = dset.random_example()
    with torch.no_grad():
        # features() is (len, 1, n_letters); transpose to batch_first (1, len, n_letters).
        proba = rnn.single(Variable(E.features().transpose(0, 1)), rnn.init_hidden())
    guess, _ = language_from_proba(proba)
    correct = '✓' if guess == E.language else '✗ (%s)' % E.language
    display('%d %d%% (%s) %.4f %s / %s %s' % (
        epoch, epoch / epochs * 100, timeSince(start), float(loss), E.surname, guess, correct))

    # Mean per-sample loss over everything seen this epoch (guarded against an empty loader).
    all_losses.append(current_loss / max(n_seen, 1))

In [None]:
# iterations = 100000
# print_every = 5000
# plot_every = 1000

# # Keep track of losses for plotting
# current_loss = 0
# all_losses = []
# start = time.time()

# for epoch in range(1, iterations + 1):
#     example = dset.random_example()
#     output, loss = train(example.to_dict())
#     current_loss += loss

#     # Print iter number, loss, name and guess
#     if epoch % print_every == 0:
#         if epoch % (print_every*5) == 0:
#             clear_output()
#         guess, guess_i = language_from_proba(output)
#         correct = '✓' if guess == example.language else '✗ (%s)' % example.language
#         display('%d %d%% (%s) %.4f %s / %s %s' % (
#             epoch, epoch / iterations * 100, timeSince(start), loss, example.surname, guess, correct))

#     # Add current loss avg to list of losses
#     if epoch % plot_every == 0:
#         all_losses.append(current_loss / plot_every)
#         current_loss = 0

In [None]:
# Mean training loss recorded once per epoch.
fig, ax = plt.subplots()
ax.plot(all_losses, 'ko-')
ax.set_xlabel("Epoch")
ax.set_ylabel("loss")
plt.show()

In [None]:
# Estimate a confusion matrix from randomly drawn surnames:
# confusion[truth, guess] counts how often language `truth` was predicted as `guess`.
confusion = torch.zeros(n_languages, n_languages)
num_samples = 10000

# Inference only: no_grad avoids building an autograd graph for each forward pass.
with torch.no_grad():
    for i in range(num_samples):
        if (i % 1000) == 0:
            print(i)
        example = dset.random_example()
        # features() is (len, 1, n_letters); transpose to batch_first (1, len, n_letters).
        var = Variable(example.features().transpose(0, 1))
        var = var.cuda() if cuda else var
        proba = rnn.single(var, rnn.init_hidden())

        _, guess = language_from_proba(proba)
        truth = all_languages.index(example.language)
        confusion[truth, guess] += 1

# Normalize every row to sum to 1. The epsilon leaves never-sampled rows at 0 instead of NaN.
confusion = confusion / (confusion.sum(dim=1, keepdim=True) + 10**-16)

# Set up plot
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111)
cax = ax.matshow(confusion.numpy())
fig.colorbar(cax)

# Put one tick on every matrix cell and label it with its language. Fixing the
# tick positions explicitly keeps labels aligned; the old `[''] + labels` plus
# MultipleLocator trick depended on where the locator's first tick landed, and
# newer matplotlib warns about FixedFormatter with a non-fixed locator.
ax.set_xticks(range(n_languages))
ax.set_yticks(range(n_languages))
ax.set_xticklabels(all_languages, rotation=90)
ax.set_yticklabels(all_languages)

# sphinx_gallery_thumbnail_number = 2
plt.show()