In [174]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt # for making figures
%matplotlib inline

In [175]:
# read in all the words
words = open('names.txt', 'r').read().splitlines()
words[:8]

['emma', 'olivia', 'ava', 'isabella', 'sophia', 'charlotte', 'mia', 'amelia']

In [176]:
len(words)

32033

In [177]:
# build the vocabulary of characters and mappings to/from integers
chars = sorted(list(set(''.join(words))))
stoi = {s:i+1 for i,s in enumerate(chars)}
stoi['.'] = 0
itos = {i:s for s,i in stoi.items()}
print(itos)

{1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i', 10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q', 18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y', 26: 'z', 0: '.'}


In [178]:
# build the dataset

block_size = 3 # context length: how many characters do we take to predict the next one?
X, Y = [], []
for w in words[:5]:
    print(w)
    context = [0] * block_size
    for ch in w + '.':
        ix = stoi[ch]
        X.append(context)
        Y.append(ix)
        print(''.join(itos[i] for i in context), '--->', itos[ix])
        context = context[1:] + [ix] # crop and append
  
X = torch.tensor(X)
Y = torch.tensor(Y)
X

emma
... ---> e
..e ---> m
.em ---> m
emm ---> a
mma ---> .
olivia
... ---> o
..o ---> l
.ol ---> i
oli ---> v
liv ---> i
ivi ---> a
via ---> .
ava
... ---> a
..a ---> v
.av ---> a
ava ---> .
isabella
... ---> i
..i ---> s
.is ---> a
isa ---> b
sab ---> e
abe ---> l
bel ---> l
ell ---> a
lla ---> .
sophia
... ---> s
..s ---> o
.so ---> p
sop ---> h
oph ---> i
phi ---> a
hia ---> .


tensor([[ 0,  0,  0],
        [ 0,  0,  5],
        [ 0,  5, 13],
        [ 5, 13, 13],
        [13, 13,  1],
        [ 0,  0,  0],
        [ 0,  0, 15],
        [ 0, 15, 12],
        [15, 12,  9],
        [12,  9, 22],
        [ 9, 22,  9],
        [22,  9,  1],
        [ 0,  0,  0],
        [ 0,  0,  1],
        [ 0,  1, 22],
        [ 1, 22,  1],
        [ 0,  0,  0],
        [ 0,  0,  9],
        [ 0,  9, 19],
        [ 9, 19,  1],
        [19,  1,  2],
        [ 1,  2,  5],
        [ 2,  5, 12],
        [ 5, 12, 12],
        [12, 12,  1],
        [ 0,  0,  0],
        [ 0,  0, 19],
        [ 0, 19, 15],
        [19, 15, 16],
        [15, 16,  8],
        [16,  8,  9],
        [ 8,  9,  1]])

In [179]:
X.shape, X.dtype, Y.shape, Y.dtype

(torch.Size([32, 3]), torch.int64, torch.Size([32]), torch.int64)

In [180]:
C = torch.randn((27, 2))


In [181]:
C[5]

tensor([0.4292, 0.2281])

In [182]:
F.one_hot(torch.tensor(5), num_classes=27).float() @ C

tensor([0.4292, 0.2281])

In [183]:
C[X].shape

torch.Size([32, 3, 2])

In [184]:
X[13,2]

tensor(1)

In [185]:
C[X][13,2]

tensor([ 1.8568, -0.4740])

In [186]:
C[1]

tensor([ 1.8568, -0.4740])

In [187]:
emb = C[X]
emb.shape
emb

tensor([[[-0.1605, -0.8230],
         [-0.1605, -0.8230],
         [-0.1605, -0.8230]],

        [[-0.1605, -0.8230],
         [-0.1605, -0.8230],
         [ 0.4292,  0.2281]],

        [[-0.1605, -0.8230],
         [ 0.4292,  0.2281],
         [ 0.1847, -0.4149]],

        [[ 0.4292,  0.2281],
         [ 0.1847, -0.4149],
         [ 0.1847, -0.4149]],

        [[ 0.1847, -0.4149],
         [ 0.1847, -0.4149],
         [ 1.8568, -0.4740]],

        [[-0.1605, -0.8230],
         [-0.1605, -0.8230],
         [-0.1605, -0.8230]],

        [[-0.1605, -0.8230],
         [-0.1605, -0.8230],
         [-0.2392,  0.8671]],

        [[-0.1605, -0.8230],
         [-0.2392,  0.8671],
         [ 1.9192,  0.5768]],

        [[-0.2392,  0.8671],
         [ 1.9192,  0.5768],
         [ 0.5699, -0.2114]],

        [[ 1.9192,  0.5768],
         [ 0.5699, -0.2114],
         [-0.6608, -1.0366]],

        [[ 0.5699, -0.2114],
         [-0.6608, -1.0366],
         [ 0.5699, -0.2114]],

        [[-0.6608, -1

In [188]:
W1 = torch.randn((6, 100))
b1 = torch.randn(100)
# emb @ W1 + b1

In [189]:
# F = torch.cat([emb[:, 0, :], emb[:, 1, :], emb[:, 2, :]], 1)

In [190]:
unbinded = torch.unbind(emb, 1)

In [191]:
# Optimize tensor view manipulation

In [192]:
a = torch.arange(18)

In [193]:
# no memory is being changed/copied/created etc. when .view is being called
a.view(3,3,2)

tensor([[[ 0,  1],
         [ 2,  3],
         [ 4,  5]],

        [[ 6,  7],
         [ 8,  9],
         [10, 11]],

        [[12, 13],
         [14, 15],
         [16, 17]]])

In [194]:
# always a 1-dimensional vector
a.storage()

 0
 1
 2
 3
 4
 5
 6
 7
 8
 9
 10
 11
 12
 13
 14
 15
 16
 17
[torch.storage._TypedStorage(dtype=torch.int64, device=cpu) of size 18]

In [195]:
emb.shape

torch.Size([32, 3, 2])

In [196]:
# more efficient and easier than bind/unbind
emb.view(32, 6) == torch.cat(torch.unbind(emb, 1), 1)

tensor([[True, True, True, True, True, True],
        [True, True, True, True, True, True],
        [True, True, True, True, True, True],
        [True, True, True, True, True, True],
        [True, True, True, True, True, True],
        [True, True, True, True, True, True],
        [True, True, True, True, True, True],
        [True, True, True, True, True, True],
        [True, True, True, True, True, True],
        [True, True, True, True, True, True],
        [True, True, True, True, True, True],
        [True, True, True, True, True, True],
        [True, True, True, True, True, True],
        [True, True, True, True, True, True],
        [True, True, True, True, True, True],
        [True, True, True, True, True, True],
        [True, True, True, True, True, True],
        [True, True, True, True, True, True],
        [True, True, True, True, True, True],
        [True, True, True, True, True, True],
        [True, True, True, True, True, True],
        [True, True, True, True, T

In [197]:
W1 = torch.randn((6, 100))
b1 = torch.randn(100)
# pytorch can infer missing shape of matrix with -1
h = torch.tanh(emb.view(-1, 6) @ W1 + b1)
h.shape # 100 activations for each of our 32 examples

torch.Size([32, 100])

In [198]:
W2 = torch.randn((100, 27))
b2 = torch.randn(27)

In [199]:
logits = h @ W2 + b2

In [200]:
counts = logits.exp()

In [201]:
prob = counts / counts.sum(1, keepdims=True)

In [202]:
# index into the rows of prob, pluck out the probability assigned to the correct character, and calculate loss
loss = -prob[torch.arange(32), Y].log().mean()
loss

tensor(18.5355)

In [203]:
# ---- clearer sequence -----

In [204]:
X.shape, Y.shape # dataset

(torch.Size([32, 3]), torch.Size([32]))

In [205]:
g = torch.Generator().manual_seed(2147483647) # for reproducibility
C = torch.randn((27, 2), generator=g)
W1 = torch.randn((6, 100), generator=g)
b1 = torch.randn(100, generator=g)
W2 = torch.randn((100, 27), generator=g)
b2 = torch.randn(27, generator=g)
parameters = [C, W1, b1, W2, b2] # cluster parameters into single list

In [206]:
sum(p.nelement() for p in parameters) # number of parameters in total

3481

In [207]:
emb = C[X] # (32, 3, 2)
h = torch.tanh(emb.view(-1, 6) @ W1 + b1)
logits = h @ W2 + b2
# counts = logits.exp()
# prob = counts / counts.sum(1, keepdims=True)
# loss = -prob[torch.arange(32), Y].log().mean()
# Above 3 commented lines can be done (more efficiently) with the following line
F.cross_entropy(logits, Y) # how well neural network currently works with this set of parameters

tensor(17.7697)

In [208]:
for p in parameters:
    p.requires_grad = True

In [218]:
for _ in range(1000):
    # forward pass
    emb = C[X] # (32, 3, 2)
    h = torch.tanh(emb.view(-1, 6) @ W1 + b1) # (32, 100)
    logits = h @ W2 + b2
    loss = F.cross_entropy(logits, Y)
    print(loss.item())
    # backward pass
    for p in parameters:
        p.grad = None
    loss.backward()
    # update
    for p in parameters:
        p.data += -0.1 * p.grad

0.2534579634666443
0.25345855951309204
0.2534559667110443
0.25345659255981445
0.25345394015312195
0.2534545361995697
0.253451943397522
0.2534525692462921
0.253449946641922
0.25345051288604736
0.253447949886322
0.2534485459327698
0.25344598293304443
0.2534465491771698
0.25344398617744446
0.2534445822238922
0.25344201922416687
0.25344258546829224
0.2534400224685669
0.25344058871269226
0.2534380257129669
0.2534386217594147
0.25343605875968933
0.2534366548061371
0.25343406200408936
0.2534346878528595
0.25343212485313416
0.2534327507019043
0.25343015789985657
0.2534307539463043
0.253428190946579
0.25342875719070435
0.2534262537956238
0.25342684984207153
0.2534243166446686
0.25342488288879395
0.2534223198890686
0.25342291593551636
0.253420352935791
0.25342094898223877
0.2534184455871582
0.2534189820289612
0.2534164786338806
0.25341710448265076
0.2534145712852478
0.25341516733169556
0.25341272354125977
0.25341325998306274
0.2534107565879822
0.25341135263442993
0.25340884923934937
0.2534094154

0.25306618213653564
0.25306645035743713
0.2530648708343506
0.25306516885757446
0.2530635595321655
0.2530638575553894
0.25306224822998047
0.25306257605552673
0.2530609369277954
0.2530612349510193
0.25305965542793274
0.253059983253479
0.25305837392807007
0.25305867195129395
0.253057062625885
0.2530573904514313
0.25305578112602234
0.2530561089515686
0.25305449962615967
0.25305476784706116
0.253053218126297
0.2530534863471985
0.2530519366264343
0.2530522346496582
0.25305065512657166
0.2530509829521179
0.253049373626709
0.25304967164993286
0.2530481219291687
0.2530483603477478
0.25304678082466125
0.25304707884788513
0.25304555892944336
0.25304585695266724
0.2530442774295807
0.25304457545280457
0.2530430555343628
0.2530432939529419
0.2530417740345001
0.253042072057724
0.25304052233695984
0.2530408203601837
0.25303927063941956
0.25303956866264343
0.2530380189418793
0.25303831696510315
0.2530367374420166
0.2530370354652405
0.2530355155467987
0.2530357837677002
0.25303423404693604
0.25303456187

0.2528199553489685
0.25282013416290283
0.2528190314769745
0.2528192102909088
0.25281810760498047
0.2528182864189148
0.25281715393066406
0.2528173625469208
0.25281625986099243
0.25281643867492676
0.25281527638435364
0.25281551480293274
0.252814382314682
0.2528145909309387
0.2528134286403656
0.2528136670589447
0.25281253457069397
0.2528127133846283
0.25281158089637756
0.2528117895126343
0.25281065702438354
0.25281086564064026
0.2528097629547119
0.25280994176864624
0.2528088390827179
0.2528090178966522
0.2528079152107239
0.2528080940246582
0.25280702114105225
0.2528071999549866
0.25280603766441345
0.25280627608299255
0.2528051435947418
0.25280535221099854
0.2528042495250702
0.2528044879436493
0.25280332565307617
0.2528035640716553
0.25280246138572693
0.25280261039733887
0.2528015375137329
0.25280171632766724
0.2528006136417389
0.2528007924556732
0.25279971957206726
0.2527998983860016
0.25279879570007324
0.25279897451400757
0.252797931432724
0.25279805064201355
0.25279700756073
0.252797156

In [217]:
logits.max(1)

torch.return_types.max(
values=tensor([13.9407, 19.0261, 21.1539, 21.4368, 17.6025, 13.9407, 16.8368, 14.8818,
        16.6236, 19.3744, 16.8883, 21.8235, 13.9407, 18.1736, 18.1035, 21.0729,
        13.9407, 17.4160, 16.2523, 18.0917, 19.3502, 16.9605, 11.7630, 11.4669,
        16.1268, 13.9407, 16.9908, 17.7825, 13.3590, 16.8897, 20.1144, 17.2401],
       grad_fn=<MaxBackward0>),
indices=tensor([ 9, 13, 13,  1,  0,  9, 12,  9, 22,  9,  1,  0,  9, 22,  1,  0,  9, 19,
         1,  2,  5, 12, 12,  1,  0,  9, 15, 16,  8,  9,  1,  0]))