In [1]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline

In [3]:
words = open('names.txt', 'r').read().splitlines()
words[:10]

['emma',
 'olivia',
 'ava',
 'isabella',
 'sophia',
 'charlotte',
 'mia',
 'amelia',
 'harper',
 'evelyn']

In [6]:
chars = sorted(list(set(''.join(words))))
stoi = {s: i + 1 for i, s in enumerate(chars)}
stoi['.'] = 0
itos = {i: s for s, i in stoi.items()}

In [126]:
block_size = 3
X, Y = [], []
for w in words:
#     print(w)
    
    context = [0] * block_size
    for ch in w + '.':
        ix = stoi[ch]
        X.append(context)
        Y.append(ix)
        #print(''.join(itos[i] for i in context), '--->', itos[ix])
        context = context[1:] + [ix]
        
X = torch.tensor(X)
Y = torch.tensor(Y)

In [127]:
X.shape, X.dtype, Y.shape, Y.dtype

(torch.Size([228146, 3]), torch.int64, torch.Size([228146]), torch.int64)

In [128]:
g = torch.Generator().manual_seed(2147483647)
C = torch.randn((27, 2), generator = g)

W1 = torch.randn((6, 100), generator = g)
b1 = torch.randn(100, generator = g)
W2 = torch.randn((100, 27), generator = g)
b2 = torch.randn(27, generator = g)

parameters = [C, W1, b1, W2, b2]

In [129]:
sum(p.nelement() for p in parameters)

3481

In [130]:
for p in parameters:
    p.requires_grad = True

In [None]:
torch.linspace(0.001, 1, 1000)

In [125]:
for _ in range(1000):
    
    #minibatch construction
    ix = torch.randint(0, X.shape[0], (32,))
    emb = C[X[ix]]
    h = torch.tanh(emb.view(-1, 6) @ W1 + b1)
    logits = h @ W2 + b2
    loss = F.cross_entropy(logits, Y[ix])
    print(loss.item())
    
#backward pass
    for p in parameters:
        p.grad = None
    loss.backward()
    
# update
    for p in parameters:
        p.data += -1 * p.grad
        



19.15500259399414
16.2569637298584
17.037494659423828
16.71219825744629
11.619129180908203
9.487751960754395
13.345929145812988
7.4726152420043945
14.483611106872559
15.14730453491211
8.7661714553833
13.561622619628906
13.876729011535645
10.567667961120605
9.91977596282959
14.097320556640625
9.744333267211914
7.217655658721924
8.805779457092285
8.863473892211914
9.774338722229004
10.851609230041504
9.111623764038086
9.268775939941406
10.219837188720703
11.281505584716797
11.370672225952148
9.972352981567383
10.107983589172363
8.837499618530273
9.912198066711426
7.552181720733643
9.306863784790039
11.44161605834961
9.154603004455566
9.474291801452637
7.82865571975708
7.162656307220459
7.577486515045166
7.43464994430542
9.265030860900879
8.606184959411621
9.559286117553711
7.40189790725708
7.036202907562256
6.085901737213135
7.066850662231445
8.872003555297852
10.050360679626465
7.086748123168945
7.915227890014648
7.226171970367432
7.927050590515137
5.326209545135498
8.729974746704102
6.

8.913022994995117
7.946148872375488
6.722776889801025
5.449456214904785
6.235199928283691
5.934115886688232
5.753729820251465
5.302563667297363
6.000153064727783
5.813592910766602
4.90034818649292
5.2444844245910645
4.962517261505127
5.787181377410889
4.910055160522461
4.3695502281188965
3.872957706451416
4.682260036468506
4.139753818511963
6.747422695159912
5.106368541717529
4.880311489105225
3.573035717010498
5.077587604522705
4.897144794464111
5.890421390533447
5.37209415435791
4.430594444274902
3.3876678943634033
4.4494853019714355
4.725004196166992
4.496461391448975
4.305785179138184
5.579298496246338
4.501203536987305
4.0660247802734375
4.49676513671875
5.54590368270874
5.445439338684082
5.364492893218994
4.657905101776123
4.415987014770508
4.138499736785889
4.12486457824707
4.900765895843506
4.419750213623047
3.527433395385742
5.083294868469238
4.730360984802246
4.918431282043457
4.96579122543335
7.501095771789551
4.602318286895752
5.347006320953369
4.0631103515625
3.96812677383

In [120]:
emb = C[X]
h = torch.tanh(emb.view(-1, 6) @ W1 + b1)
logits = h @ W2 + b2
loss = F.cross_entropy(logits, Y)
loss.item()

2.548656940460205

tensor([ 5, 13, 13,  1,  0, 15, 12,  9, 22,  9,  1,  0,  1, 22,  1,  0,  9, 19,
         1,  2,  5, 12, 12,  1,  0, 19, 15, 16,  8,  9,  1,  0])