Based on the tutorial by Sean Robertson <a href="https://pytorch.org/tutorials/intermediate/char_rnn_generation_tutorial.html">here</a>.

In [11]:
import glob
import os
import random
import string
import unicodedata
from io import open

import torch
import torch.nn as nn

In [2]:
# Root directory of the dataset; the name lists are globbed from
# '<DATA_DIR>/names/*.txt' when the categories are loaded.
DATA_DIR = '../../data'

In [3]:
# Characters the model can represent: the ASCII letters plus a few punctuation
# marks. Every character is listed exactly once so that each index below the
# EOS slot corresponds to a distinct character (a repeated apostrophe would
# add a vocabulary slot that can never be used).
all_letters = string.ascii_letters + " .,;'-"
n_letters = len(all_letters) + 1 # + 1 for EOS

In [4]:
def find_files(path):
    """Return the list of paths matching the glob pattern ``path``."""
    matches = glob.glob(path)
    return matches

In [5]:
def unicode_to_ascii(s):
    """Reduce ``s`` to the characters contained in ``all_letters``.

    The string is NFD-normalized first, so an accented character is split
    into its base character followed by combining marks. Combining marks
    (Unicode category ``Mn``) are dropped, as is any character that is not
    in ``all_letters``.
    """
    kept = []
    for ch in unicodedata.normalize('NFD', s):
        if unicodedata.category(ch) == 'Mn':
            continue
        if ch not in all_letters:
            continue
        kept.append(ch)
    return ''.join(kept)

In [6]:
def read_lines(filename):
    """Read a UTF-8 text file and return its lines converted to ASCII.

    Args:
        filename: Path of the file to read.

    Returns:
        A list with one string per line of the file (after stripping leading
        and trailing whitespace from the whole content), each passed through
        ``unicode_to_ascii``.
    """
    # The context manager closes the handle as soon as the content is read
    # instead of leaving it open until garbage collection.
    with open(filename, encoding='utf-8') as f:
        content = f.read()
    return [unicode_to_ascii(line) for line in content.strip().split('\n')]

In [7]:
# Maps each category name to the list of ASCII-converted lines from its file.
category_lines = {}
# Category names, in the order the files were returned by find_files.
all_categories = []
for filename in find_files('%s/names/*.txt' % DATA_DIR):
    # The category name is the file name without directory and extension.
    category = os.path.splitext(os.path.basename(filename))[0]
    all_categories.append(category)
    lines = read_lines(filename)
    category_lines[category] = lines
n_categories = len(all_categories)
# Fail fast when the glob matched no files at all.
if n_categories == 0:
    raise RuntimeError('Data not found')
print(f'# categories: {n_categories} {all_categories}')
# Quick check of the ASCII conversion on a name with accents.
print(unicode_to_ascii("O'Néàl"))

# categories: 18 ['Czech', 'German', 'Arabic', 'Japanese', 'Chinese', 'Vietnamese', 'Russian', 'French', 'Irish', 'English', 'Spanish', 'Greek', 'Italian', 'Portuguese', 'Scottish', 'Dutch', 'Korean', 'Polish']
O'Neal


In [9]:
# Dropout probability handed to nn.Dropout inside the RNN module.
DROPOUT = 0.1

In [10]:
class RNN(nn.Module):
    """Category-conditioned recurrent cell built from linear layers.

    On every step the category tensor, the current input tensor and the
    previous hidden state are concatenated; from that combined tensor the
    next hidden state and log-probabilities over ``output_size`` classes are
    computed.

    The width of the category tensor is taken from the module-level
    ``n_categories`` and the dropout probability from ``DROPOUT``.
    """

    def __init__(self, input_size, hidden_size, output_size):
        """Create the layers.

        Args:
            input_size: Width of the per-step input tensor.
            hidden_size: Width of the hidden state.
            output_size: Width of the output (number of classes).
        """
        super().__init__()
        self.hidden_size = hidden_size
        combined_size = n_categories + input_size + hidden_size
        # Combined (category, input, hidden) -> next hidden state.
        self.i2h = nn.Linear(combined_size, hidden_size)
        # Combined (category, input, hidden) -> preliminary output.
        self.i2o = nn.Linear(combined_size, output_size)
        # (hidden, preliminary output) -> final output scores.
        self.o2o = nn.Linear(hidden_size + output_size, output_size)
        self.dropout = nn.Dropout(DROPOUT)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, category, input, hidden):
        """Run a single step of the network.

        Args:
            category: 2-D tensor whose second dimension is ``n_categories``.
            input: 2-D tensor whose second dimension is ``input_size``.
            hidden: 2-D tensor whose second dimension is ``hidden_size``
                (see ``init_hidden`` for the initial value).

        Returns:
            A tuple ``(output, hidden)``: log-probabilities along dimension 1
            with width ``output_size``, and the next hidden state.
        """
        input_combined = torch.cat((category, input, hidden), 1)
        hidden = self.i2h(input_combined)
        # The preliminary output must come from i2o; using i2h here would
        # yield a hidden_size-wide tensor, which does not match the
        # hidden_size + output_size input that o2o expects.
        output = self.i2o(input_combined)
        output_combined = torch.cat((hidden, output), 1)
        output = self.o2o(output_combined)
        output = self.dropout(output)
        output = self.softmax(output)
        return output, hidden

    def init_hidden(self):
        """Return an all-zero initial hidden state of shape (1, hidden_size)."""
        return torch.zeros(1, self.hidden_size)

In [12]:
def random_choice(lst):
    """Return one element of ``lst`` picked uniformly at random.

    Raises:
        ValueError: If ``lst`` is empty.
    """
    position = random.randrange(len(lst))
    return lst[position]

In [13]:
# Draw a random (category, line) training example
def random_training_pair():
    """Pick a random category, then a random line belonging to it.

    Returns:
        A ``(category, line)`` tuple.
    """
    chosen_category = random_choice(all_categories)
    candidates = category_lines[chosen_category]
    return chosen_category, random_choice(candidates)