In [1]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
# Load the dataset: names.txt holds one name per line.
# The context manager closes the file handle as soon as the text has been read,
# instead of leaving it open until the garbage collector gets to it.
with open('names.txt') as f:
    words = f.read().splitlines()

# Peek at the first few names and the total count.
words[:8], len(words)

(['emma', 'olivia', 'ava', 'isabella', 'sophia', 'charlotte', 'mia', 'amelia'],
 32033)

In [3]:
# Vocabulary: every distinct character that occurs in the names, in sorted order.
chars = sorted(set(''.join(words)))

# Character -> integer index. Index 0 is reserved for '.', which is used both as
# left padding and as the end-of-name marker; the real characters start at 1.
s_to_i = {'.': 0}
s_to_i.update((ch, pos) for pos, ch in enumerate(chars, start=1))

# Inverse mapping (integer index -> character), used for decoding.
i_to_s = {idx: ch for ch, idx in s_to_i.items()}
print(i_to_s)


{0: '.', 1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i', 10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q', 18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y', 26: 'z'}


In [4]:
# Build the training examples: each input is the block_size previous character
# indices, each label is the index of the character that follows them.
block_size = 3  # context length: how many characters are used to predict the next one
X, Y = [], []

for name in words[:5]:  # only the first five names, to keep the printed trace short
    print(name)
    context = [0] * block_size  # start from an all-'.' (index 0) context
    for ch in name + '.':  # the trailing '.' marks the end of the name
        target = s_to_i[ch]
        X.append(context)
        Y.append(target)
        shown = ''.join(i_to_s[i] for i in context)
        print(shown, '--->', i_to_s[target])
        # Slide the window one step. This rebinds `context` to a new list, so the
        # list that was just appended to X is not modified.
        context = [*context[1:], target]

X, Y = torch.tensor(X), torch.tensor(Y)

emma
... ---> e
..e ---> m
.em ---> m
emm ---> a
mma ---> .
olivia
... ---> o
..o ---> l
.ol ---> i
oli ---> v
liv ---> i
ivi ---> a
via ---> .
ava
... ---> a
..a ---> v
.av ---> a
ava ---> .
isabella
... ---> i
..i ---> s
.is ---> a
isa ---> b
sab ---> e
abe ---> l
bel ---> l
ell ---> a
lla ---> .
sophia
... ---> s
..s ---> o
.so ---> p
sop ---> h
oph ---> i
phi ---> a
hia ---> .


In [5]:
# X holds one row of block_size context indices per example; Y holds the matching
# next-character index, one per example.
X.shape, Y.shape

(torch.Size([32, 3]), torch.Size([32]))

In [6]:
# Embedding table: one randomly initialised 2-dimensional vector per vocabulary
# entry (27 rows: '.' plus the 26 letters).
C = torch.randn(27, 2)

In [7]:
# Sanity check: indexing row 5 of C directly and multiplying the one-hot vector for
# class 5 by C pick out the same row, so the two displayed tensors should match.
C[5], F.one_hot(torch.tensor(5), num_classes=27).float() @ C

(tensor([2.8155, 1.5945]), tensor([2.8155, 1.5945]))

In [8]:
# Index the embedding table with the whole integer tensor X at once: every index in
# X is replaced by its row of C, so the result gains a trailing dimension of size 2.
emb = C[X]
emb.shape

torch.Size([32, 3, 2])

In [9]:
# Hidden layer parameters: 6 inputs (3 context positions x 2 embedding dims)
# feeding 100 hidden units.
W1 = torch.randn(6, 100)
b1 = torch.randn(100)

In [10]:
# Flatten each example's three 2-d embeddings into a single 6-vector.
# -1 lets PyTorch infer the number of rows, so this keeps working when the dataset
# is no longer exactly 32 examples.
emb.reshape(-1, 6).shape

torch.Size([32, 6])

In [11]:
# Same flattening with .view, which reinterprets the existing storage rather than
# copying it. -1 lets PyTorch infer the number of rows, so nothing here is tied to
# the current 32-example dataset.
emb.view(-1, 6).shape

torch.Size([32, 6])

In [14]:
# Hidden layer: flatten the embeddings to 6 values per example, apply the affine map
# (W1, b1) and squash with tanh, so every activation lies in [-1, 1].
# b1 has shape (100,) and is broadcast across the rows.
h = (emb.view(-1, 6) @ W1 + b1).tanh()
h.shape, h

(torch.Size([32, 100]),
 tensor([[-0.9224,  0.9999,  0.8731,  ..., -0.9581,  0.9835, -0.9931],
         [-0.9666,  0.9993,  0.9902,  ..., -0.9998,  0.4777,  0.9451],
         [-0.9937, -0.5445, -1.0000,  ...,  0.8577, -0.8274,  0.9999],
         ...,
         [ 0.7670,  0.9916,  0.9501,  ..., -0.7192,  0.1115, -0.9619],
         [-0.8185,  0.8511, -0.9986,  ...,  0.8134,  0.0461, -0.3466],
         [ 0.9213,  0.9992,  0.2497,  ..., -0.9982,  0.9988, -0.9987]]))

In [16]:
# Output layer parameters: 100 hidden units mapped to 27 scores, one per vocabulary entry.
W2 = torch.randn(100, 27)
b2 = torch.randn(27)

In [18]:
# Output layer: one unnormalised score (logit) per vocabulary entry for each example.
# b2 is broadcast across the rows.
logits = b2 + h @ W2
logits.shape, logits[:3]

(torch.Size([32, 27]),
 tensor([[  5.9441,  10.4885, -18.0245,   7.0498,   1.2182, -12.3557,   1.7396,
          -10.0321,  -6.4100,  -4.6139, -20.7665,  -6.7166,  21.8115,  -0.7176,
           -1.5933, -15.3465,  -4.0633,  -1.3293,  -7.9317,   2.6023, -11.3017,
           22.6629, -15.0545,  -7.5700,  10.1388,  -4.4695,  -3.9472],
         [  5.3665,  -5.2898,  12.2408,  21.9479,  -0.9399, -11.9085,   4.9820,
          -15.0463,   4.3398,   1.8547,  11.7773,   8.3810,   4.1092,   7.4955,
            6.2921,  -0.2723,   2.5665,  -4.8547,  14.6974,   0.4837,  12.6616,
           -3.9040, -17.2736,   5.1584,   0.2884,  19.4615,  -0.8154],
         [ -4.6229,   3.5987,  11.8771,   3.7929,   1.8401,   2.7031,  10.6394,
          -13.7094,   2.5693,   1.3119,   2.4846,   5.4684,  -3.4860,   0.9073,
            0.3943,   8.9392, -17.5498,   0.2026,  -4.8799, -14.0416,  -4.1208,
           -0.7761,  -1.3735,  -4.1658, -14.4030,  -8.2522,  -4.2282]]))

In [21]:
# Exponentiate the logits so every entry is positive (unnormalised "counts").
counts = torch.exp(logits)

In [22]:
# Normalise each row to sum to 1. keepdim=True keeps the row sums as a column so
# the division broadcasts across each row.
prob = counts / counts.sum(dim=1, keepdim=True)

In [24]:
# One row of probabilities per example, one column per vocabulary entry.
prob.shape

torch.Size([32, 27])

In [25]:
# For every example, pick the probability the model assigned to the true next
# character: row i, column Y[i]. The row indices are derived from Y instead of a
# hard-coded 32, so this still works when the dataset size changes.
prob[torch.arange(Y.shape[0]), Y]

tensor([4.3378e-16, 4.8793e-07, 1.2792e-05, 4.5731e-08, 1.5852e-06, 2.1794e-17,
        6.4711e-03, 2.6212e-10, 1.1342e-12, 1.1731e-09, 1.3225e-04, 5.2188e-03,
        3.6171e-06, 1.5611e-18, 3.4729e-06, 4.0677e-09, 9.9881e-13, 1.3811e-09,
        7.9321e-06, 1.1480e-08, 4.3963e-01, 8.6420e-12, 1.0852e-10, 1.0187e-09,
        7.4000e-01, 1.3597e-09, 9.3721e-10, 6.4610e-10, 5.3600e-03, 3.8632e-07,
        2.2253e-06, 1.9601e-04])

In [27]:
# Mean negative log-likelihood of the true next characters.
# The row indices are derived from Y instead of a hard-coded 32, so the loss is
# computed over every example whatever the dataset size.
loss = -prob[torch.arange(Y.shape[0]), Y].log().mean()
loss.item()

17.39974594116211

In [28]:
# ----- Summary: the full model and its forward pass, rebuilt in one place -----

In [29]:
# Re-check the dataset dimensions: one row of context indices per example in X,
# one target index per example in Y.
X.shape, Y.shape

(torch.Size([32, 3]), torch.Size([32]))

In [30]:
# Re-create every parameter from a dedicated, seeded generator so that the
# initialisation is reproducible.
g = torch.Generator()
g.manual_seed(2147483647)

# The tensors are drawn in this exact order (C, W1, b1, W2, b2); the order matters
# because each draw advances the generator state.
#   C  (27, 2)    embedding table
#   W1 (6, 100)   hidden-layer weights      b1 (100,)  hidden-layer bias
#   W2 (100, 27)  output-layer weights      b2 (27,)   output-layer bias
parameters = [
    torch.randn(shape, generator=g)
    for shape in [(27, 2), (6, 100), (100,), (100, 27), (27,)]
]
C, W1, b1, W2, b2 = parameters

In [31]:
# Total number of scalar parameters across all tensors of the model.
sum(map(torch.Tensor.nelement, parameters))

3481

In [32]:
# Forward pass of the whole network. N is the number of examples in X.
emb = C[X]                                            # (N, 3, 2): one embedding per context position
# Flatten per example; the row count and flattened width are taken from emb itself
# rather than hard-coded, so changing the dataset size does not break this cell.
h = torch.tanh(emb.view(emb.shape[0], -1) @ W1 + b1)  # (N, 100)
logits = h @ W2 + b2                                  # (N, 27)

# Explicit softmax, kept so the per-class probabilities remain available for inspection.
counts = logits.exp()
prob = counts / counts.sum(1, keepdim=True)

# Mean negative log-likelihood of the true next characters.
# F.cross_entropy yields the same quantity as -prob[arange(N), Y].log().mean() (up to
# floating-point rounding), but works from the logits via log-softmax. It therefore
# stays finite when exp(logits) would overflow or a probability would underflow to 0,
# and it does not depend on a hard-coded batch size.
loss = F.cross_entropy(logits, Y)
loss.item()

17.769710540771484