In [12]:
import string
import unicodedata
import glob
import os

all_letters = string.ascii_letters
n_letters = len(all_letters)


def unicode_to_ascii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
        and c in all_letters
    )
print(unicode_to_ascii('Ślusàrski'))

Slusarski


In [21]:
def read_file(filename):
    lines = open(filename, encoding='utf-8').read().strip('\n').split('\n')
    return [unicode_to_ascii(line) for line in lines]

all_names_by_category = dict()

for filename in glob.glob('data/names/*.txt'):
    category = os.path.splitext(os.path.basename(filename))[0]
    lines = read_file(filename)
    all_names_by_category[category] = lines

categories = list(all_names_by_category.keys())
print(categories)

['Arabic', 'Chinese', 'Czech', 'Dutch', 'English', 'French', 'German', 'Greek', 'Irish', 'Italian', 'Japanese', 'Korean', 'Polish', 'Portuguese', 'Russian', 'Scottish', 'Spanish', 'Vietnamese']


In [24]:
import torch

def line_to_tensor(s):
    tensor = torch.zeros(len(s), 1, n_letters)
    for i, l in enumerate(s):
        index = all_letters.find(l)
        tensor[i][0][index] = 1
    return tensor

print(line_to_tensor('yi'))

tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0.]]])


In [30]:
import random

def randomChoice(l):
    return l[random.randint(0, len(l) - 1)]

def randomTrainingExample():
    category = randomChoice(categories)
    line = randomChoice(all_names_by_category[category])
    category_tensor = torch.tensor([categories.index(category)], dtype=torch.long)
    line_tensor = line_to_tensor(line)
    return category, line, category_tensor, line_tensor

for i in range(10):
    category, line, category_tensor, line_tensor = randomTrainingExample()
    print('category =', category, category_tensor, '/ line =', line)

category = Greek tensor([7]) / line = Milionis
category = Vietnamese tensor([17]) / line = Chau
category = Czech tensor([2]) / line = Spoerl
category = Irish tensor([8]) / line = Naoimhin
category = Greek tensor([7]) / line = Frangopoulos
category = Greek tensor([7]) / line = Close
category = German tensor([6]) / line = Bosch
category = Scottish tensor([15]) / line = Craig
category = Vietnamese tensor([17]) / line = Do
category = German tensor([6]) / line = Bergfalk


In [47]:
import torch.nn as nn


class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(RNN, self).__init__()
        
        self.hidden_size = hidden_size
        
        self.i2o = nn.Linear(input_size + hidden_size, output_size)
        self.i2h = nn.Linear(input_size + hidden_size, hidden_size)
        self.softmax = nn.LogSoftmax(dim=1)
    
    
    def forward(self, input, hidden):
        combined = torch.cat([input, hidden], 1)
        hidden = self.i2h(combined)
        output = self.softmax(self.i2o(combined))
        return output, hidden
    
    
    def init_hidden(self):
        return torch.zeros(1, self.hidden_size)
        
    
n_hidden = 128

rnn = RNN(n_letters, n_hidden, len(categories))

In [48]:
def category_from_output(output):
    _, top_i = output.topk(1)
    category_i = top_i[0].item()
    return categories[category_i], category_i


criterion = nn.NLLLoss()

learning_rate = 0.005

def train(category_tensor, line_tensor):
    hidden = rnn.init_hidden()
    
    rnn.zero_grad()
    
    for i in range(line_tensor.size()[0]):
        output, hidden = rnn(line_tensor[i], hidden)
    loss = criterion(output, category_tensor)
    loss.backward()
    
    for p in rnn.parameters():
        p.data.add_(-learning_rate, p.grad.data)
    
    return output, loss.item()

In [50]:
n_iters = 10000

for i in range(n_iters):
    category, line, category_tensor, line_tensor = randomTrainingExample()
    output, loss = train(category_tensor, line_tensor)
    print('%d iter, loss: %.4f'%(i, loss))

0 iter, loss: 2.9450
1 iter, loss: 2.8237
2 iter, loss: 2.7890
3 iter, loss: 2.9668
4 iter, loss: 2.7806
5 iter, loss: 2.8125
6 iter, loss: 2.8195
7 iter, loss: 2.7318
8 iter, loss: 2.9876
9 iter, loss: 3.0210
10 iter, loss: 2.9178
11 iter, loss: 2.8876
12 iter, loss: 2.8714
13 iter, loss: 2.7773
14 iter, loss: 2.8660
15 iter, loss: 2.8765
16 iter, loss: 2.7263
17 iter, loss: 2.9376
18 iter, loss: 2.9079
19 iter, loss: 2.8840
20 iter, loss: 2.6945
21 iter, loss: 2.8488
22 iter, loss: 2.7229
23 iter, loss: 2.8029
24 iter, loss: 2.9055
25 iter, loss: 2.6770
26 iter, loss: 2.6101
27 iter, loss: 2.6811
28 iter, loss: 2.8014
29 iter, loss: 2.7634
30 iter, loss: 2.7175
31 iter, loss: 3.0154
32 iter, loss: 2.7267
33 iter, loss: 2.9607
34 iter, loss: 2.8197
35 iter, loss: 2.9204
36 iter, loss: 2.8517
37 iter, loss: 2.6495
38 iter, loss: 2.8855
39 iter, loss: 2.7943
40 iter, loss: 2.7718
41 iter, loss: 2.8202
42 iter, loss: 2.7454
43 iter, loss: 2.8531
44 iter, loss: 2.7904
45 iter, loss: 2.904

389 iter, loss: 2.7101
390 iter, loss: 2.6097
391 iter, loss: 2.9482
392 iter, loss: 2.8252
393 iter, loss: 2.8723
394 iter, loss: 2.7364
395 iter, loss: 2.8880
396 iter, loss: 2.8788
397 iter, loss: 2.7502
398 iter, loss: 2.6087
399 iter, loss: 2.5980
400 iter, loss: 2.8832
401 iter, loss: 2.7789
402 iter, loss: 2.6685
403 iter, loss: 2.5569
404 iter, loss: 2.8592
405 iter, loss: 2.7707
406 iter, loss: 2.6897
407 iter, loss: 2.6190
408 iter, loss: 2.8344
409 iter, loss: 2.8227
410 iter, loss: 2.8722
411 iter, loss: 2.8077
412 iter, loss: 2.7941
413 iter, loss: 2.9050
414 iter, loss: 2.7809
415 iter, loss: 2.7564
416 iter, loss: 2.9292
417 iter, loss: 2.8884
418 iter, loss: 2.8502
419 iter, loss: 2.9897
420 iter, loss: 2.9140
421 iter, loss: 2.9058
422 iter, loss: 2.9558
423 iter, loss: 2.5962
424 iter, loss: 2.6712
425 iter, loss: 2.7179
426 iter, loss: 2.8458
427 iter, loss: 2.7364
428 iter, loss: 2.8134
429 iter, loss: 2.7597
430 iter, loss: 2.8433
431 iter, loss: 2.8677
432 iter, l

762 iter, loss: 2.7510
763 iter, loss: 2.7954
764 iter, loss: 2.8981
765 iter, loss: 2.9316
766 iter, loss: 2.9348
767 iter, loss: 3.0005
768 iter, loss: 2.8279
769 iter, loss: 2.8984
770 iter, loss: 2.9161
771 iter, loss: 2.7417
772 iter, loss: 2.7070
773 iter, loss: 2.6972
774 iter, loss: 2.8679
775 iter, loss: 2.8770
776 iter, loss: 2.4672
777 iter, loss: 2.7715
778 iter, loss: 2.5939
779 iter, loss: 2.8213
780 iter, loss: 2.5860
781 iter, loss: 2.9344
782 iter, loss: 2.7424
783 iter, loss: 2.7469
784 iter, loss: 2.8586
785 iter, loss: 2.8952
786 iter, loss: 2.7948
787 iter, loss: 2.9965
788 iter, loss: 2.7484
789 iter, loss: 2.7964
790 iter, loss: 2.8850
791 iter, loss: 2.8446
792 iter, loss: 2.8816
793 iter, loss: 2.7181
794 iter, loss: 2.8027
795 iter, loss: 2.9796
796 iter, loss: 2.9722
797 iter, loss: 2.7965
798 iter, loss: 2.9214
799 iter, loss: 2.9049
800 iter, loss: 2.4263
801 iter, loss: 2.5507
802 iter, loss: 2.7513
803 iter, loss: 2.8431
804 iter, loss: 2.7313
805 iter, l

1128 iter, loss: 2.8016
1129 iter, loss: 2.7141
1130 iter, loss: 2.7377
1131 iter, loss: 2.9817
1132 iter, loss: 3.0456
1133 iter, loss: 2.8860
1134 iter, loss: 2.8626
1135 iter, loss: 2.2483
1136 iter, loss: 2.8581
1137 iter, loss: 2.7259
1138 iter, loss: 3.0089
1139 iter, loss: 2.8578
1140 iter, loss: 2.5032
1141 iter, loss: 3.0431
1142 iter, loss: 2.8480
1143 iter, loss: 2.6145
1144 iter, loss: 2.8213
1145 iter, loss: 3.0737
1146 iter, loss: 2.8110
1147 iter, loss: 3.0338
1148 iter, loss: 2.7108
1149 iter, loss: 2.7715
1150 iter, loss: 2.8500
1151 iter, loss: 2.9577
1152 iter, loss: 2.7661
1153 iter, loss: 2.9199
1154 iter, loss: 2.7527
1155 iter, loss: 2.8044
1156 iter, loss: 2.7096
1157 iter, loss: 2.8358
1158 iter, loss: 2.6447
1159 iter, loss: 2.7603
1160 iter, loss: 2.7782
1161 iter, loss: 2.9896
1162 iter, loss: 2.6231
1163 iter, loss: 2.8067
1164 iter, loss: 2.6362
1165 iter, loss: 2.9107
1166 iter, loss: 2.6996
1167 iter, loss: 2.7429
1168 iter, loss: 2.9620
1169 iter, loss:

1527 iter, loss: 2.9439
1528 iter, loss: 2.8570
1529 iter, loss: 2.9505
1530 iter, loss: 2.8398
1531 iter, loss: 2.8113
1532 iter, loss: 2.7680
1533 iter, loss: 2.8754
1534 iter, loss: 2.5677
1535 iter, loss: 2.8536
1536 iter, loss: 2.7210
1537 iter, loss: 2.7404
1538 iter, loss: 2.9824
1539 iter, loss: 2.8297
1540 iter, loss: 2.7778
1541 iter, loss: 2.6310
1542 iter, loss: 2.8371
1543 iter, loss: 2.8074
1544 iter, loss: 2.6272
1545 iter, loss: 3.0149
1546 iter, loss: 2.9975
1547 iter, loss: 2.7987
1548 iter, loss: 2.8577
1549 iter, loss: 2.8096
1550 iter, loss: 2.7034
1551 iter, loss: 2.3468
1552 iter, loss: 2.4243
1553 iter, loss: 2.6458
1554 iter, loss: 2.9390
1555 iter, loss: 2.7178
1556 iter, loss: 2.5679
1557 iter, loss: 2.7210
1558 iter, loss: 2.9600
1559 iter, loss: 2.3700
1560 iter, loss: 2.6733
1561 iter, loss: 2.5612
1562 iter, loss: 2.5845
1563 iter, loss: 2.8407
1564 iter, loss: 2.9946
1565 iter, loss: 2.8607
1566 iter, loss: 2.8555
1567 iter, loss: 2.7599
1568 iter, loss:

1894 iter, loss: 2.8425
1895 iter, loss: 2.7175
1896 iter, loss: 2.7081
1897 iter, loss: 2.8697
1898 iter, loss: 3.0343
1899 iter, loss: 2.6964
1900 iter, loss: 2.7745
1901 iter, loss: 2.8382
1902 iter, loss: 2.5843
1903 iter, loss: 2.7781
1904 iter, loss: 2.8410
1905 iter, loss: 2.7331
1906 iter, loss: 2.7263
1907 iter, loss: 2.9391
1908 iter, loss: 2.8747
1909 iter, loss: 2.9249
1910 iter, loss: 2.6873
1911 iter, loss: 2.7722
1912 iter, loss: 2.7893
1913 iter, loss: 2.6794
1914 iter, loss: 2.8730
1915 iter, loss: 2.6940
1916 iter, loss: 2.7817
1917 iter, loss: 2.9277
1918 iter, loss: 2.8714
1919 iter, loss: 2.6985
1920 iter, loss: 2.8699
1921 iter, loss: 2.2170
1922 iter, loss: 2.8340
1923 iter, loss: 2.9456
1924 iter, loss: 2.6814
1925 iter, loss: 2.8690
1926 iter, loss: 3.0090
1927 iter, loss: 2.7009
1928 iter, loss: 2.6083
1929 iter, loss: 2.8193
1930 iter, loss: 2.9136
1931 iter, loss: 2.7771
1932 iter, loss: 2.7577
1933 iter, loss: 2.5527
1934 iter, loss: 2.8023
1935 iter, loss:

2287 iter, loss: 2.5099
2288 iter, loss: 2.7677
2289 iter, loss: 2.6180
2290 iter, loss: 2.8535
2291 iter, loss: 2.5879
2292 iter, loss: 2.9061
2293 iter, loss: 2.7677
2294 iter, loss: 2.8805
2295 iter, loss: 2.7094
2296 iter, loss: 2.7980
2297 iter, loss: 2.8320
2298 iter, loss: 2.1543
2299 iter, loss: 2.5903
2300 iter, loss: 2.8283
2301 iter, loss: 3.0281
2302 iter, loss: 2.9567
2303 iter, loss: 2.8235
2304 iter, loss: 2.6716
2305 iter, loss: 2.9019
2306 iter, loss: 2.6816
2307 iter, loss: 2.9057
2308 iter, loss: 2.7977
2309 iter, loss: 2.7553
2310 iter, loss: 2.5848
2311 iter, loss: 2.8370
2312 iter, loss: 2.9612
2313 iter, loss: 2.8684
2314 iter, loss: 2.4708
2315 iter, loss: 2.9107
2316 iter, loss: 2.8383
2317 iter, loss: 2.7804
2318 iter, loss: 2.8091
2319 iter, loss: 2.8618
2320 iter, loss: 2.7638
2321 iter, loss: 2.9158
2322 iter, loss: 3.0836
2323 iter, loss: 2.8671
2324 iter, loss: 2.6729
2325 iter, loss: 2.8141
2326 iter, loss: 2.6609
2327 iter, loss: 2.7162
2328 iter, loss:

2680 iter, loss: 2.7574
2681 iter, loss: 2.7294
2682 iter, loss: 2.8572
2683 iter, loss: 2.8505
2684 iter, loss: 2.8007
2685 iter, loss: 2.5174
2686 iter, loss: 2.1927
2687 iter, loss: 2.9064
2688 iter, loss: 2.8113
2689 iter, loss: 2.7130
2690 iter, loss: 2.7291
2691 iter, loss: 2.8651
2692 iter, loss: 2.7949
2693 iter, loss: 2.8493
2694 iter, loss: 2.6609
2695 iter, loss: 2.7611
2696 iter, loss: 2.8370
2697 iter, loss: 2.8077
2698 iter, loss: 2.7292
2699 iter, loss: 2.7005
2700 iter, loss: 3.0327
2701 iter, loss: 2.9462
2702 iter, loss: 2.7544
2703 iter, loss: 2.5926
2704 iter, loss: 2.6081
2705 iter, loss: 2.8409
2706 iter, loss: 2.8351
2707 iter, loss: 2.4905
2708 iter, loss: 2.7684
2709 iter, loss: 2.8297
2710 iter, loss: 2.8652
2711 iter, loss: 3.1251
2712 iter, loss: 2.9704
2713 iter, loss: 1.9518
2714 iter, loss: 2.8735
2715 iter, loss: 2.8695
2716 iter, loss: 2.6840
2717 iter, loss: 2.8026
2718 iter, loss: 2.8526
2719 iter, loss: 2.8976
2720 iter, loss: 2.8030
2721 iter, loss:

3040 iter, loss: 2.6369
3041 iter, loss: 2.9492
3042 iter, loss: 2.8047
3043 iter, loss: 2.5136
3044 iter, loss: 2.8086
3045 iter, loss: 2.7795
3046 iter, loss: 2.7881
3047 iter, loss: 2.7871
3048 iter, loss: 2.5622
3049 iter, loss: 2.6417
3050 iter, loss: 1.8619
3051 iter, loss: 2.7545
3052 iter, loss: 2.2726
3053 iter, loss: 1.8437
3054 iter, loss: 2.4401
3055 iter, loss: 2.4624
3056 iter, loss: 3.0087
3057 iter, loss: 2.3871
3058 iter, loss: 2.8977
3059 iter, loss: 2.7357
3060 iter, loss: 2.7269
3061 iter, loss: 2.8301
3062 iter, loss: 3.0314
3063 iter, loss: 2.3871
3064 iter, loss: 3.0273
3065 iter, loss: 2.8083
3066 iter, loss: 2.6605
3067 iter, loss: 2.4810
3068 iter, loss: 2.5838
3069 iter, loss: 3.0136
3070 iter, loss: 2.1662
3071 iter, loss: 2.7714
3072 iter, loss: 1.8935
3073 iter, loss: 1.9312
3074 iter, loss: 2.4327
3075 iter, loss: 2.9017
3076 iter, loss: 2.3504
3077 iter, loss: 2.7148
3078 iter, loss: 3.1316
3079 iter, loss: 2.6841
3080 iter, loss: 2.8102
3081 iter, loss:

3401 iter, loss: 2.6406
3402 iter, loss: 2.8633
3403 iter, loss: 2.0372
3404 iter, loss: 2.5947
3405 iter, loss: 2.7764
3406 iter, loss: 2.6712
3407 iter, loss: 2.8328
3408 iter, loss: 2.4089
3409 iter, loss: 2.7801
3410 iter, loss: 2.8611
3411 iter, loss: 2.5081
3412 iter, loss: 2.9013
3413 iter, loss: 2.8569
3414 iter, loss: 2.8683
3415 iter, loss: 2.6658
3416 iter, loss: 2.5087
3417 iter, loss: 1.9906
3418 iter, loss: 2.6245
3419 iter, loss: 2.9663
3420 iter, loss: 2.8511
3421 iter, loss: 2.3014
3422 iter, loss: 2.7416
3423 iter, loss: 2.7321
3424 iter, loss: 1.8099
3425 iter, loss: 2.4846
3426 iter, loss: 2.6408
3427 iter, loss: 2.4975
3428 iter, loss: 2.2449
3429 iter, loss: 2.6728
3430 iter, loss: 2.9955
3431 iter, loss: 3.0255
3432 iter, loss: 2.5954
3433 iter, loss: 3.0195
3434 iter, loss: 2.8076
3435 iter, loss: 2.5031
3436 iter, loss: 2.5381
3437 iter, loss: 2.9730
3438 iter, loss: 2.7708
3439 iter, loss: 2.4539
3440 iter, loss: 2.6646
3441 iter, loss: 2.7058
3442 iter, loss:

3761 iter, loss: 2.8495
3762 iter, loss: 2.9222
3763 iter, loss: 2.5555
3764 iter, loss: 2.7936
3765 iter, loss: 2.9428
3766 iter, loss: 2.7988
3767 iter, loss: 2.5986
3768 iter, loss: 2.8120
3769 iter, loss: 3.1254
3770 iter, loss: 1.6843
3771 iter, loss: 2.2613
3772 iter, loss: 2.7362
3773 iter, loss: 1.4895
3774 iter, loss: 2.9647
3775 iter, loss: 2.9292
3776 iter, loss: 2.9270
3777 iter, loss: 2.5858
3778 iter, loss: 2.8092
3779 iter, loss: 2.6456
3780 iter, loss: 2.7004
3781 iter, loss: 2.9806
3782 iter, loss: 2.8473
3783 iter, loss: 2.8033
3784 iter, loss: 2.6372
3785 iter, loss: 2.2991
3786 iter, loss: 1.3008
3787 iter, loss: 2.8226
3788 iter, loss: 2.6129
3789 iter, loss: 2.4231
3790 iter, loss: 2.4859
3791 iter, loss: 1.2080
3792 iter, loss: 2.7717
3793 iter, loss: 2.3179
3794 iter, loss: 2.7933
3795 iter, loss: 2.3909
3796 iter, loss: 2.5584
3797 iter, loss: 2.9386
3798 iter, loss: 2.5908
3799 iter, loss: 2.7672
3800 iter, loss: 2.5725
3801 iter, loss: 2.6468
3802 iter, loss:

4122 iter, loss: 3.2125
4123 iter, loss: 2.5038
4124 iter, loss: 2.6239
4125 iter, loss: 2.6440
4126 iter, loss: 2.2989
4127 iter, loss: 3.0619
4128 iter, loss: 2.8058
4129 iter, loss: 2.6710
4130 iter, loss: 2.8562
4131 iter, loss: 2.3377
4132 iter, loss: 2.7842
4133 iter, loss: 2.7979
4134 iter, loss: 3.0805
4135 iter, loss: 2.9313
4136 iter, loss: 2.5913
4137 iter, loss: 2.8957
4138 iter, loss: 2.0745
4139 iter, loss: 1.6468
4140 iter, loss: 2.7334
4141 iter, loss: 2.7881
4142 iter, loss: 2.8829
4143 iter, loss: 2.8820
4144 iter, loss: 2.5406
4145 iter, loss: 2.4739
4146 iter, loss: 2.6292
4147 iter, loss: 2.8406
4148 iter, loss: 2.6556
4149 iter, loss: 2.8037
4150 iter, loss: 2.7499
4151 iter, loss: 1.6724
4152 iter, loss: 2.7767
4153 iter, loss: 2.6595
4154 iter, loss: 2.6946
4155 iter, loss: 1.6146
4156 iter, loss: 2.8900
4157 iter, loss: 2.5220
4158 iter, loss: 2.7803
4159 iter, loss: 2.6172
4160 iter, loss: 2.8791
4161 iter, loss: 2.5115
4162 iter, loss: 2.7094
4163 iter, loss:

4501 iter, loss: 2.2985
4502 iter, loss: 2.4968
4503 iter, loss: 2.7210
4504 iter, loss: 2.9911
4505 iter, loss: 2.5583
4506 iter, loss: 1.9200
4507 iter, loss: 2.4975
4508 iter, loss: 2.7182
4509 iter, loss: 2.2953
4510 iter, loss: 2.8094
4511 iter, loss: 2.4534
4512 iter, loss: 2.4723
4513 iter, loss: 2.8648
4514 iter, loss: 2.9955
4515 iter, loss: 2.4831
4516 iter, loss: 2.2837
4517 iter, loss: 2.8831
4518 iter, loss: 2.8822
4519 iter, loss: 1.4412
4520 iter, loss: 2.7579
4521 iter, loss: 3.0002
4522 iter, loss: 2.6960
4523 iter, loss: 1.4761
4524 iter, loss: 2.1046
4525 iter, loss: 2.5113
4526 iter, loss: 2.7242
4527 iter, loss: 2.6147
4528 iter, loss: 2.3504
4529 iter, loss: 1.1941
4530 iter, loss: 2.7414
4531 iter, loss: 2.5560
4532 iter, loss: 3.3942
4533 iter, loss: 2.4865
4534 iter, loss: 2.0328
4535 iter, loss: 2.5539
4536 iter, loss: 3.2425
4537 iter, loss: 2.5462
4538 iter, loss: 2.3343
4539 iter, loss: 3.0260
4540 iter, loss: 3.0802
4541 iter, loss: 2.3167
4542 iter, loss:

4868 iter, loss: 1.3633
4869 iter, loss: 2.3867
4870 iter, loss: 2.6093
4871 iter, loss: 2.4054
4872 iter, loss: 2.4684
4873 iter, loss: 2.1726
4874 iter, loss: 2.6674
4875 iter, loss: 2.2579
4876 iter, loss: 2.2777
4877 iter, loss: 1.4793
4878 iter, loss: 3.2806
4879 iter, loss: 3.0120
4880 iter, loss: 2.0698
4881 iter, loss: 2.1850
4882 iter, loss: 3.2868
4883 iter, loss: 2.6716
4884 iter, loss: 2.9558
4885 iter, loss: 2.6459
4886 iter, loss: 1.9060
4887 iter, loss: 1.5760
4888 iter, loss: 2.7776
4889 iter, loss: 2.0049
4890 iter, loss: 2.4704
4891 iter, loss: 2.8591
4892 iter, loss: 2.1674
4893 iter, loss: 2.9946
4894 iter, loss: 2.7381
4895 iter, loss: 2.6738
4896 iter, loss: 2.9646
4897 iter, loss: 2.4881
4898 iter, loss: 2.6863
4899 iter, loss: 2.1354
4900 iter, loss: 1.3431
4901 iter, loss: 2.5840
4902 iter, loss: 2.6449
4903 iter, loss: 2.6446
4904 iter, loss: 2.1867
4905 iter, loss: 2.9271
4906 iter, loss: 2.9869
4907 iter, loss: 3.0136
4908 iter, loss: 1.1983
4909 iter, loss:

5228 iter, loss: 2.2280
5229 iter, loss: 2.5478
5230 iter, loss: 2.6158
5231 iter, loss: 2.1149
5232 iter, loss: 2.4850
5233 iter, loss: 1.1946
5234 iter, loss: 2.7327
5235 iter, loss: 2.4704
5236 iter, loss: 2.4512
5237 iter, loss: 2.4085
5238 iter, loss: 2.0492
5239 iter, loss: 2.6453
5240 iter, loss: 1.9867
5241 iter, loss: 2.8502
5242 iter, loss: 3.0486
5243 iter, loss: 2.5888
5244 iter, loss: 2.7980
5245 iter, loss: 2.8270
5246 iter, loss: 2.5367
5247 iter, loss: 1.8670
5248 iter, loss: 2.3923
5249 iter, loss: 2.5106
5250 iter, loss: 3.0182
5251 iter, loss: 2.8150
5252 iter, loss: 2.5261
5253 iter, loss: 2.4511
5254 iter, loss: 0.9004
5255 iter, loss: 2.2527
5256 iter, loss: 2.4319
5257 iter, loss: 2.6051
5258 iter, loss: 2.9547
5259 iter, loss: 2.3901
5260 iter, loss: 3.1140
5261 iter, loss: 2.1371
5262 iter, loss: 3.0208
5263 iter, loss: 3.1288
5264 iter, loss: 2.3919
5265 iter, loss: 2.2483
5266 iter, loss: 1.9960
5267 iter, loss: 2.1417
5268 iter, loss: 2.4137
5269 iter, loss:

5613 iter, loss: 0.3408
5614 iter, loss: 2.9341
5615 iter, loss: 1.8172
5616 iter, loss: 2.8033
5617 iter, loss: 0.9907
5618 iter, loss: 1.6361
5619 iter, loss: 1.1582
5620 iter, loss: 2.4245
5621 iter, loss: 0.6867
5622 iter, loss: 2.0665
5623 iter, loss: 0.7325
5624 iter, loss: 2.6924
5625 iter, loss: 2.3576
5626 iter, loss: 3.0553
5627 iter, loss: 3.3641
5628 iter, loss: 2.0403
5629 iter, loss: 2.5425
5630 iter, loss: 3.2489
5631 iter, loss: 1.7894
5632 iter, loss: 2.7620
5633 iter, loss: 2.1576
5634 iter, loss: 2.3952
5635 iter, loss: 1.7699
5636 iter, loss: 2.9597
5637 iter, loss: 2.7733
5638 iter, loss: 2.4691
5639 iter, loss: 2.5196
5640 iter, loss: 1.9944
5641 iter, loss: 2.3953
5642 iter, loss: 1.8451
5643 iter, loss: 2.4575
5644 iter, loss: 2.8446
5645 iter, loss: 2.0315
5646 iter, loss: 3.7846
5647 iter, loss: 2.3188
5648 iter, loss: 2.3874
5649 iter, loss: 3.9362
5650 iter, loss: 1.1927
5651 iter, loss: 3.2464
5652 iter, loss: 2.2161
5653 iter, loss: 2.1922
5654 iter, loss:

5971 iter, loss: 2.4900
5972 iter, loss: 1.9073
5973 iter, loss: 1.8566
5974 iter, loss: 2.3809
5975 iter, loss: 2.5128
5976 iter, loss: 2.3832
5977 iter, loss: 3.4646
5978 iter, loss: 2.0835
5979 iter, loss: 2.4775
5980 iter, loss: 3.0240
5981 iter, loss: 3.5985
5982 iter, loss: 2.7062
5983 iter, loss: 1.2130
5984 iter, loss: 2.6345
5985 iter, loss: 1.3459
5986 iter, loss: 1.5373
5987 iter, loss: 2.7706
5988 iter, loss: 1.8153
5989 iter, loss: 2.2091
5990 iter, loss: 0.7059
5991 iter, loss: 2.2369
5992 iter, loss: 0.4084
5993 iter, loss: 2.5306
5994 iter, loss: 3.1187
5995 iter, loss: 2.6671
5996 iter, loss: 2.4335
5997 iter, loss: 2.1858
5998 iter, loss: 1.8286
5999 iter, loss: 2.1426
6000 iter, loss: 1.8324
6001 iter, loss: 1.6604
6002 iter, loss: 2.5114
6003 iter, loss: 3.1884
6004 iter, loss: 2.6654
6005 iter, loss: 3.3002
6006 iter, loss: 3.2022
6007 iter, loss: 3.0137
6008 iter, loss: 2.1427
6009 iter, loss: 2.2617
6010 iter, loss: 2.6890
6011 iter, loss: 3.3596
6012 iter, loss:

6331 iter, loss: 1.8977
6332 iter, loss: 1.9139
6333 iter, loss: 2.7106
6334 iter, loss: 1.3703
6335 iter, loss: 3.0239
6336 iter, loss: 1.0093
6337 iter, loss: 3.1967
6338 iter, loss: 2.2118
6339 iter, loss: 2.6482
6340 iter, loss: 2.4338
6341 iter, loss: 2.6128
6342 iter, loss: 2.7343
6343 iter, loss: 1.7298
6344 iter, loss: 3.5688
6345 iter, loss: 2.0427
6346 iter, loss: 2.7402
6347 iter, loss: 2.5272
6348 iter, loss: 2.1691
6349 iter, loss: 2.2537
6350 iter, loss: 2.7201
6351 iter, loss: 2.5108
6352 iter, loss: 2.2208
6353 iter, loss: 2.5839
6354 iter, loss: 2.1650
6355 iter, loss: 2.7769
6356 iter, loss: 2.8766
6357 iter, loss: 2.3859
6358 iter, loss: 2.6064
6359 iter, loss: 2.3686
6360 iter, loss: 1.6901
6361 iter, loss: 2.7285
6362 iter, loss: 1.4832
6363 iter, loss: 0.7162
6364 iter, loss: 1.6928
6365 iter, loss: 2.7950
6366 iter, loss: 2.0653
6367 iter, loss: 2.1472
6368 iter, loss: 1.2001
6369 iter, loss: 2.0418
6370 iter, loss: 1.7139
6371 iter, loss: 2.0206
6372 iter, loss:

6696 iter, loss: 2.3707
6697 iter, loss: 3.2830
6698 iter, loss: 3.7565
6699 iter, loss: 2.2522
6700 iter, loss: 1.7157
6701 iter, loss: 2.6362
6702 iter, loss: 3.1463
6703 iter, loss: 1.9905
6704 iter, loss: 1.9154
6705 iter, loss: 1.2121
6706 iter, loss: 1.8038
6707 iter, loss: 2.5847
6708 iter, loss: 1.5381
6709 iter, loss: 2.0384
6710 iter, loss: 2.3153
6711 iter, loss: 1.7129
6712 iter, loss: 2.5518
6713 iter, loss: 2.8918
6714 iter, loss: 2.7844
6715 iter, loss: 2.9923
6716 iter, loss: 2.5762
6717 iter, loss: 2.1456
6718 iter, loss: 2.4605
6719 iter, loss: 3.5919
6720 iter, loss: 2.3754
6721 iter, loss: 2.3861
6722 iter, loss: 2.1982
6723 iter, loss: 1.3922
6724 iter, loss: 2.9984
6725 iter, loss: 2.3325
6726 iter, loss: 2.7963
6727 iter, loss: 3.0741
6728 iter, loss: 1.6050
6729 iter, loss: 1.9596
6730 iter, loss: 2.4435
6731 iter, loss: 2.0737
6732 iter, loss: 2.8933
6733 iter, loss: 1.1498
6734 iter, loss: 2.6120
6735 iter, loss: 3.2389
6736 iter, loss: 1.3511
6737 iter, loss:

7056 iter, loss: 1.7578
7057 iter, loss: 2.5902
7058 iter, loss: 2.3354
7059 iter, loss: 1.9089
7060 iter, loss: 3.1527
7061 iter, loss: 2.0391
7062 iter, loss: 3.3248
7063 iter, loss: 1.9384
7064 iter, loss: 2.0893
7065 iter, loss: 3.2536
7066 iter, loss: 2.1451
7067 iter, loss: 2.3594
7068 iter, loss: 2.7650
7069 iter, loss: 1.8310
7070 iter, loss: 1.4248
7071 iter, loss: 1.7438
7072 iter, loss: 0.7175
7073 iter, loss: 2.2751
7074 iter, loss: 2.3593
7075 iter, loss: 2.8733
7076 iter, loss: 2.7568
7077 iter, loss: 2.6342
7078 iter, loss: 1.4467
7079 iter, loss: 3.7214
7080 iter, loss: 2.9366
7081 iter, loss: 2.2992
7082 iter, loss: 3.2531
7083 iter, loss: 2.1306
7084 iter, loss: 1.8172
7085 iter, loss: 2.2714
7086 iter, loss: 2.8093
7087 iter, loss: 2.4291
7088 iter, loss: 1.8234
7089 iter, loss: 2.4877
7090 iter, loss: 0.7219
7091 iter, loss: 3.0954
7092 iter, loss: 1.8398
7093 iter, loss: 2.5423
7094 iter, loss: 1.5931
7095 iter, loss: 1.2175
7096 iter, loss: 2.1984
7097 iter, loss:

7424 iter, loss: 4.3184
7425 iter, loss: 4.0714
7426 iter, loss: 2.2991
7427 iter, loss: 1.8096
7428 iter, loss: 2.9219
7429 iter, loss: 1.0570
7430 iter, loss: 2.1847
7431 iter, loss: 2.6537
7432 iter, loss: 2.4497
7433 iter, loss: 3.4813
7434 iter, loss: 1.8111
7435 iter, loss: 2.3914
7436 iter, loss: 2.0149
7437 iter, loss: 2.7317
7438 iter, loss: 1.6536
7439 iter, loss: 2.6801
7440 iter, loss: 2.6666
7441 iter, loss: 1.3479
7442 iter, loss: 2.0547
7443 iter, loss: 2.8468
7444 iter, loss: 2.3313
7445 iter, loss: 2.2710
7446 iter, loss: 1.3894
7447 iter, loss: 2.3854
7448 iter, loss: 1.8654
7449 iter, loss: 1.9167
7450 iter, loss: 2.7951
7451 iter, loss: 2.4460
7452 iter, loss: 2.7366
7453 iter, loss: 1.6157
7454 iter, loss: 3.5733
7455 iter, loss: 2.5583
7456 iter, loss: 2.2619
7457 iter, loss: 1.1971
7458 iter, loss: 2.7154
7459 iter, loss: 2.4448
7460 iter, loss: 2.0537
7461 iter, loss: 3.0801
7462 iter, loss: 2.7979
7463 iter, loss: 2.8694
7464 iter, loss: 1.3618
7465 iter, loss:

7793 iter, loss: 2.0365
7794 iter, loss: 2.3519
7795 iter, loss: 2.6332
7796 iter, loss: 1.9535
7797 iter, loss: 3.0688
7798 iter, loss: 0.8679
7799 iter, loss: 0.6371
7800 iter, loss: 2.5314
7801 iter, loss: 2.1310
7802 iter, loss: 2.6017
7803 iter, loss: 1.6963
7804 iter, loss: 2.2415
7805 iter, loss: 3.6114
7806 iter, loss: 1.9591
7807 iter, loss: 1.1024
7808 iter, loss: 2.7700
7809 iter, loss: 1.6274
7810 iter, loss: 1.7503
7811 iter, loss: 1.7052
7812 iter, loss: 3.2662
7813 iter, loss: 2.1260
7814 iter, loss: 2.6980
7815 iter, loss: 1.7205
7816 iter, loss: 2.3442
7817 iter, loss: 2.5400
7818 iter, loss: 0.8003
7819 iter, loss: 0.6887
7820 iter, loss: 1.9358
7821 iter, loss: 2.0705
7822 iter, loss: 2.9435
7823 iter, loss: 2.5551
7824 iter, loss: 1.5035
7825 iter, loss: 3.4909
7826 iter, loss: 2.4492
7827 iter, loss: 2.7235
7828 iter, loss: 2.0189
7829 iter, loss: 2.2562
7830 iter, loss: 2.0026
7831 iter, loss: 2.9722
7832 iter, loss: 2.7375
7833 iter, loss: 2.8091
7834 iter, loss:

8163 iter, loss: 2.1794
8164 iter, loss: 1.6995
8165 iter, loss: 2.9808
8166 iter, loss: 1.5359
8167 iter, loss: 3.7036
8168 iter, loss: 2.0241
8169 iter, loss: 2.9385
8170 iter, loss: 0.3741
8171 iter, loss: 2.8020
8172 iter, loss: 2.5846
8173 iter, loss: 1.5149
8174 iter, loss: 4.4298
8175 iter, loss: 2.1793
8176 iter, loss: 2.0172
8177 iter, loss: 2.1329
8178 iter, loss: 2.7339
8179 iter, loss: 2.0314
8180 iter, loss: 1.5920
8181 iter, loss: 1.9372
8182 iter, loss: 2.7735
8183 iter, loss: 2.3975
8184 iter, loss: 0.4786
8185 iter, loss: 1.9173
8186 iter, loss: 1.7286
8187 iter, loss: 2.7144
8188 iter, loss: 1.8261
8189 iter, loss: 2.3102
8190 iter, loss: 2.7792
8191 iter, loss: 2.2428
8192 iter, loss: 1.0839
8193 iter, loss: 1.5897
8194 iter, loss: 2.0392
8195 iter, loss: 2.4229
8196 iter, loss: 1.6834
8197 iter, loss: 1.1201
8198 iter, loss: 2.7601
8199 iter, loss: 2.5036
8200 iter, loss: 3.3274
8201 iter, loss: 2.3295
8202 iter, loss: 1.7842
8203 iter, loss: 1.5500
8204 iter, loss:

8573 iter, loss: 2.3423
8574 iter, loss: 1.7731
8575 iter, loss: 1.1143
8576 iter, loss: 2.2555
8577 iter, loss: 1.4036
8578 iter, loss: 2.2682
8579 iter, loss: 2.9427
8580 iter, loss: 2.3025
8581 iter, loss: 2.7488
8582 iter, loss: 1.4831
8583 iter, loss: 1.8926
8584 iter, loss: 1.9539
8585 iter, loss: 1.5814
8586 iter, loss: 2.1572
8587 iter, loss: 2.3259
8588 iter, loss: 2.2722
8589 iter, loss: 1.5609
8590 iter, loss: 1.5438
8591 iter, loss: 2.3741
8592 iter, loss: 1.0249
8593 iter, loss: 1.5140
8594 iter, loss: 3.4738
8595 iter, loss: 2.3199
8596 iter, loss: 2.1106
8597 iter, loss: 2.4274
8598 iter, loss: 1.9303
8599 iter, loss: 2.0792
8600 iter, loss: 3.8722
8601 iter, loss: 2.3484
8602 iter, loss: 2.2010
8603 iter, loss: 0.7183
8604 iter, loss: 2.9987
8605 iter, loss: 3.7054
8606 iter, loss: 2.2869
8607 iter, loss: 1.3329
8608 iter, loss: 2.3472
8609 iter, loss: 2.5083
8610 iter, loss: 0.9969
8611 iter, loss: 2.2654
8612 iter, loss: 1.8783
8613 iter, loss: 1.9382
8614 iter, loss:

8931 iter, loss: 1.7407
8932 iter, loss: 1.0803
8933 iter, loss: 2.9542
8934 iter, loss: 2.5425
8935 iter, loss: 1.6886
8936 iter, loss: 1.9413
8937 iter, loss: 2.0008
8938 iter, loss: 2.1398
8939 iter, loss: 2.2773
8940 iter, loss: 2.2102
8941 iter, loss: 3.4534
8942 iter, loss: 3.2189
8943 iter, loss: 0.9039
8944 iter, loss: 0.2848
8945 iter, loss: 2.7925
8946 iter, loss: 1.4342
8947 iter, loss: 1.6613
8948 iter, loss: 2.2307
8949 iter, loss: 2.3854
8950 iter, loss: 1.9279
8951 iter, loss: 1.3749
8952 iter, loss: 1.8745
8953 iter, loss: 1.6843
8954 iter, loss: 2.2491
8955 iter, loss: 1.6009
8956 iter, loss: 2.4173
8957 iter, loss: 2.5197
8958 iter, loss: 2.9257
8959 iter, loss: 2.7544
8960 iter, loss: 1.9598
8961 iter, loss: 0.9376
8962 iter, loss: 2.8813
8963 iter, loss: 3.0410
8964 iter, loss: 0.9340
8965 iter, loss: 2.6463
8966 iter, loss: 2.4084
8967 iter, loss: 2.1241
8968 iter, loss: 1.6768
8969 iter, loss: 1.3941
8970 iter, loss: 1.0888
8971 iter, loss: 4.4174
8972 iter, loss:

9307 iter, loss: 1.6636
9308 iter, loss: 3.2410
9309 iter, loss: 1.1069
9310 iter, loss: 2.3306
9311 iter, loss: 2.8544
9312 iter, loss: 2.1263
9313 iter, loss: 1.1138
9314 iter, loss: 2.2429
9315 iter, loss: 1.4517
9316 iter, loss: 3.2473
9317 iter, loss: 2.2065
9318 iter, loss: 1.5461
9319 iter, loss: 2.1540
9320 iter, loss: 3.3977
9321 iter, loss: 1.4217
9322 iter, loss: 2.3962
9323 iter, loss: 2.9376
9324 iter, loss: 1.5554
9325 iter, loss: 1.3401
9326 iter, loss: 1.8112
9327 iter, loss: 3.0626
9328 iter, loss: 1.4070
9329 iter, loss: 2.1487
9330 iter, loss: 1.2709
9331 iter, loss: 3.4694
9332 iter, loss: 2.2824
9333 iter, loss: 2.7013
9334 iter, loss: 1.8184
9335 iter, loss: 2.4280
9336 iter, loss: 2.1178
9337 iter, loss: 1.3881
9338 iter, loss: 3.1262
9339 iter, loss: 1.9713
9340 iter, loss: 1.0813
9341 iter, loss: 1.9869
9342 iter, loss: 1.6960
9343 iter, loss: 1.0297
9344 iter, loss: 1.7566
9345 iter, loss: 1.9191
9346 iter, loss: 2.1427
9347 iter, loss: 1.0347
9348 iter, loss:

9653 iter, loss: 2.8868
9654 iter, loss: 2.2830
9655 iter, loss: 2.9852
9656 iter, loss: 0.8757
9657 iter, loss: 2.1906
9658 iter, loss: 1.1381
9659 iter, loss: 2.5028
9660 iter, loss: 1.4725
9661 iter, loss: 3.1546
9662 iter, loss: 2.0641
9663 iter, loss: 3.4042
9664 iter, loss: 1.5155
9665 iter, loss: 2.2591
9666 iter, loss: 2.9412
9667 iter, loss: 2.0185
9668 iter, loss: 1.5570
9669 iter, loss: 3.2539
9670 iter, loss: 2.7403
9671 iter, loss: 1.6353
9672 iter, loss: 1.0747
9673 iter, loss: 3.1223
9674 iter, loss: 2.7958
9675 iter, loss: 0.3116
9676 iter, loss: 1.7341
9677 iter, loss: 1.5042
9678 iter, loss: 1.2080
9679 iter, loss: 1.8040
9680 iter, loss: 1.7747
9681 iter, loss: 4.2309
9682 iter, loss: 1.9480
9683 iter, loss: 2.7484
9684 iter, loss: 2.3456
9685 iter, loss: 2.1570
9686 iter, loss: 2.7547
9687 iter, loss: 2.2861
9688 iter, loss: 2.8269
9689 iter, loss: 1.1110
9690 iter, loss: 3.0067
9691 iter, loss: 2.9775
9692 iter, loss: 2.1972
9693 iter, loss: 1.8890
9694 iter, loss:

In [51]:
def predict(input_name, n_prediction=3):
    print('\n> %s'%(input_name))
    with torch.no_grad():
        hidden = rnn.init_hidden()
        line_tensor = line_to_tensor(input_name)
        for i in range(len(input_name)):
            output, hidden = rnn(line_tensor[i], hidden)
        topv, topi = output.topk(n_prediction, 1, True)
        
        for i in range(n_prediction):
            value = topv[0][i].item()
            category_index = topi[0][i].item()
            print('%.2f %s'%(value, categories[category_index]))

In [66]:
predict('jane')


> jane
-1.69 Chinese
-1.74 Vietnamese
-1.77 Korean
