In [1]:
%matplotlib inline

In [2]:
import glob
from os.path import expanduser, basename, splitext

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

In [4]:
torch.manual_seed(1)

<torch._C.Generator at 0x7fa9869f1970>

In [6]:
!wget https://download.pytorch.org/tutorial/data.zip -O ~/data.zip
!unzip ~/data.zip -d ~
!mkdir ~/data/names/names
!mv ~/data/names/*.txt ~/data/names/names
!mv ~/data/eng-fra.txt ~/data/names/eng-fra.txt
!rm ~/data.zip

--2018-08-20 21:45:35--  https://download.pytorch.org/tutorial/data.zip
Resolving download.pytorch.org (download.pytorch.org)... 205.251.219.72, 205.251.219.66, 205.251.219.114, ...
Connecting to download.pytorch.org (download.pytorch.org)|205.251.219.72|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2882130 (2.7M) [application/zip]
Saving to: ‘/home/ck/data.zip’


2018-08-20 21:45:36 (4.72 MB/s) - ‘/home/ck/data.zip’ saved [2882130/2882130]

Archive:  /home/ck/data.zip
  inflating: /home/ck/data/eng-fra.txt  
   creating: /home/ck/data/names/
  inflating: /home/ck/data/names/Arabic.txt  
  inflating: /home/ck/data/names/Chinese.txt  
  inflating: /home/ck/data/names/Czech.txt  
  inflating: /home/ck/data/names/Dutch.txt  
  inflating: /home/ck/data/names/English.txt  
  inflating: /home/ck/data/names/French.txt  
  inflating: /home/ck/data/names/German.txt  
  inflating: /home/ck/data/names/Greek.txt  
  inflating: /home/ck/data/names/Irish.txt  
  inflating

In [7]:
PATH = expanduser('~/data/names/names')

In [8]:
!ls $PATH

Arabic.txt   English.txt  Irish.txt	Polish.txt	Spanish.txt
Chinese.txt  French.txt   Italian.txt	Portuguese.txt	Vietnamese.txt
Czech.txt    German.txt   Japanese.txt	Russian.txt
Dutch.txt    Greek.txt	  Korean.txt	Scottish.txt


In [9]:
import unicodedata
import string

In [26]:
all_letters = string.ascii_letters + " .,;'"
n_letters = len(all_letters)

In [27]:
def unicode_to_ascii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
        and c in all_letters)

In [28]:
unicode_to_ascii('Ślusàrski')

'Slusarski'

In [29]:
def read_lines(filename):
    lines = open(filename, encoding='utf-8').read().strip().split('\n')
    return [unicode_to_ascii(line) for line in lines]

In [30]:
category_files = {}
all_categories = []

for filename in glob.glob(PATH + '/*.txt'):
    print(f'reading file {filename}')
    category, _ = splitext(basename(filename))
    all_categories.append(category)
    lines = read_lines(filename)
    category_files[category] = lines
    
n_categories = len(all_categories)

reading file /home/ck/data/names/names/Chinese.txt
reading file /home/ck/data/names/names/English.txt
reading file /home/ck/data/names/names/Korean.txt
reading file /home/ck/data/names/names/Portuguese.txt
reading file /home/ck/data/names/names/German.txt
reading file /home/ck/data/names/names/Dutch.txt
reading file /home/ck/data/names/names/Arabic.txt
reading file /home/ck/data/names/names/Spanish.txt
reading file /home/ck/data/names/names/Scottish.txt
reading file /home/ck/data/names/names/French.txt
reading file /home/ck/data/names/names/Italian.txt
reading file /home/ck/data/names/names/Irish.txt
reading file /home/ck/data/names/names/Russian.txt
reading file /home/ck/data/names/names/Vietnamese.txt
reading file /home/ck/data/names/names/Czech.txt
reading file /home/ck/data/names/names/Greek.txt
reading file /home/ck/data/names/names/Japanese.txt
reading file /home/ck/data/names/names/Polish.txt


In [31]:
def letter_to_index(letter):
    return all_letters.find(letter)

In [32]:
def letter_to_tensor(letter):
    tensor = torch.zeros(1, n_letters)
    tensor[0][letter_to_index(letter)] = 1
    return tensor

In [33]:
def line_to_tensor(line):
    tensor = torch.zeros(len(line), 1, n_letters)
    for index, letter in enumerate(line):
        tensor[index][0][letter_to_index(letter)] = 1
    return tensor

In [34]:
letter_to_tensor('J')

tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0.]])

In [35]:
line_to_tensor('Jones').size()

torch.Size([5, 1, 57])

In [36]:
class RNN(nn.Module):
    
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.i2h = nn.Linear(input_size + hidden_size, hidden_size)
        self.i2o = nn.Linear(input_size + hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)
        
    def forward(self, i, h):
        combined = torch.cat((i, h), 1)
        hidden = self.i2h(combined)
        output = self.i2o(combined)
        output = self.softmax(output)
        return output, hidden
    
    def init_hidden(self):
        return torch.zeros(1, self.hidden_size)

In [37]:
n_hidden = 128

In [38]:
rnn = RNN(n_letters, n_hidden, n_categories)

In [None]:
input_tensor = letter_to_tensor('A')
hidden = rnn.init_hidden()
output, next_hidden = rnn(input_tensor, hidden)