In [1]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
# Load one name per line; `with` closes the file handle once it has been read.
with open('names.txt', encoding='utf-8') as f:
	words = f.read().splitlines()
words[:8], len(words)

(['emma', 'olivia', 'ava', 'isabella', 'sophia', 'charlotte', 'mia', 'amelia'],
 32033)

In [3]:
# Vocabulary: '.' is reserved at index 0; every distinct character found in
# `words` follows in sorted order starting at index 1.
chars = sorted(set(''.join(words)))
s_to_i = {'.': 0}
s_to_i.update((ch, pos) for pos, ch in enumerate(chars, start=1))
# inverse lookup: index -> character
i_to_s = dict(zip(s_to_i.values(), s_to_i.keys()))
print(i_to_s)


{0: '.', 1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i', 10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q', 18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y', 26: 'z'}


In [30]:
block_size = 3  # context length: number of previous characters used to predict the next


def build_dataset(words, block_size, stoi):
	"""Turn a list of words into (context, next-character) training pairs.

	Each word starts with a context of `block_size` zeros (index 0, the '.'
	token in `stoi`) and has '.' appended, so every character -- including the
	terminating '.' -- becomes a target predicted from the `block_size`
	indices before it.

	Args:
		words: iterable of strings whose characters are all keys of `stoi`.
		block_size: number of context positions per example.
		stoi: mapping from character to integer index.

	Returns:
		X: int64 tensor of shape (num_examples, block_size) with context indices.
		Y: int64 tensor of shape (num_examples,) with target indices.
	"""
	contexts, targets = [], []
	for w in words:
		context = [0] * block_size
		for ch in w + '.':
			ix = stoi[ch]
			contexts.append(context)
			targets.append(ix)
			context = context[1:] + [ix]  # slide the window by one character
	# reshape keeps the (num_examples, block_size) shape even when there are no examples
	X = torch.tensor(contexts, dtype=torch.long).reshape(len(contexts), block_size)
	Y = torch.tensor(targets, dtype=torch.long)
	return X, Y


X, Y = build_dataset(words, block_size, s_to_i)

In [5]:
tuple(t.shape for t in (X, Y))

(torch.Size([32, 3]), torch.Size([32]))

In [6]:
C = torch.randn(27, 2)  # lookup table: one 2-d embedding row per vocabulary index

In [7]:
# Row lookup C[5] and the product one_hot(5) @ C select the same embedding row.
C[5], torch.matmul(F.one_hot(torch.tensor(5), num_classes=27).float(), C)

(tensor([ 0.1370, -0.6434]), tensor([ 0.1370, -0.6434]))

In [8]:
# Index C with the whole context tensor: the result has X's shape plus C's last dim.
emb = C[X]
emb.size()

torch.Size([32, 3, 2])

In [9]:
# Hidden layer: 6 inputs (3 context positions x 2-d embeddings) -> 100 units.
W1 = torch.randn(6, 100)
b1 = torch.randn(100)

In [10]:
# Flatten each (3, 2) context embedding into one row of 6 features.
# -1 lets torch infer the row count, so this works for any number of examples
# (a hard-coded 32 fails as soon as X holds the full dataset).
emb.reshape(-1, 6).shape

torch.Size([32, 6])

In [11]:
# Same flattening with `view`: it never copies and requires compatible strides,
# which holds for the freshly indexed `emb`. -1 infers the row count instead of
# hard-coding 32, which breaks on the full dataset.
emb.view(-1, 6).shape

torch.Size([32, 6])

In [12]:
# Hidden activations: flatten contexts to 6 features, apply the affine map
# (b1 broadcasts across rows), then squash with tanh.
h = (emb.view(-1, 6) @ W1 + b1).tanh()
(h.shape, h)

(torch.Size([32, 100]),
 tensor([[-0.9992, -0.9909, -0.1217,  ...,  0.9998,  0.8689,  0.9986],
         [-0.9969, -0.5829,  0.2650,  ...,  0.9993,  0.9875,  0.9975],
         [-0.9972, -0.7683,  0.8966,  ...,  0.9995,  0.1327,  0.9995],
         ...,
         [-0.8809,  0.9715, -0.9811,  ..., -0.8039,  0.9991,  0.9652],
         [-0.9957,  0.7402,  0.7765,  ...,  0.8604, -0.9358,  1.0000],
         [ 0.9931, -0.8954, -0.8893,  ..., -0.9192,  0.8982, -0.9839]]))

In [13]:
# Output layer: 100 hidden units -> one logit per vocabulary index (27).
W2 = torch.randn(100, 27)
b2 = torch.randn(27)

In [14]:
logits = h.matmul(W2) + b2  # unnormalised scores, one row per example
(logits.shape, logits[:3])

(torch.Size([32, 27]),
 tensor([[  1.1418,   5.8713,  -1.2363,  10.7133,   5.8398,  11.0654,  -1.8630,
           -2.7214,  -7.0748,  -8.8627,  -4.8878,  -0.0305,   2.3964,  -8.1046,
           -8.3177,  -6.0157,   7.8448,   6.1693,  -6.7935,  -0.4248, -11.1370,
           15.7846,   3.0052,  11.1916,  -0.9997,  -0.2221,   5.7844],
         [  3.8092,   5.2330,   5.7965,  12.9424,   6.4163,  10.1635,  -3.2183,
           -3.2833,  -7.4532, -11.1888,  -5.3446,  -3.2703,   2.3193,  -0.2964,
          -14.9323, -13.5390,   2.4193,   1.6517, -10.6673,  -0.6407, -10.0308,
           14.3934,   0.9869,  11.1982,   1.5264,   3.6307,   8.5398],
         [ -5.5460,  -2.6480,  -3.2495,  12.1250,   4.4452,   9.7048,   5.2828,
           -7.9087,  -1.5418,  -5.5074,  -7.9509,   4.6267,   7.4867,  -3.2394,
           -5.5023,   2.1763,   3.9411,   9.7304,  -1.4274,  -0.8148,  -7.3482,
            5.7320,   1.5683,  -0.7392,  -7.1953,  -7.6585,   7.5809]]))

In [15]:
counts = torch.exp(logits)  # exponentiate logits into positive "counts"

In [16]:
prob = counts.div(counts.sum(1, keepdim=True))  # normalise each row to sum to 1

In [17]:
prob.size()

torch.Size([32, 27])

In [18]:
# Probability assigned to the correct next character, one per example.
# Use the actual row count instead of a hard-coded 32 so this works for any X/Y.
prob[torch.arange(prob.shape[0]), Y]

tensor([8.6969e-03, 3.2250e-07, 1.7642e-07, 4.6025e-09, 2.0334e-03, 3.3199e-10,
        2.2368e-06, 9.2463e-07, 4.8404e-08, 5.1299e-10, 6.9875e-11, 4.1287e-10,
        4.8262e-05, 3.9057e-06, 3.2595e-11, 1.5766e-04, 1.9262e-11, 5.0398e-11,
        1.1131e-06, 4.9489e-04, 6.9987e-05, 5.8272e-07, 8.8380e-11, 7.2624e-11,
        5.4820e-05, 8.8972e-08, 1.0621e-10, 7.9534e-10, 4.9964e-09, 7.7046e-18,
        7.3553e-08, 2.7033e-07])

In [19]:
# Average negative log-likelihood of the targets. Since prob is a row-wise softmax
# of logits, this equals F.cross_entropy(logits, Y). The row count comes from prob
# instead of a hard-coded 32.
loss = -prob[torch.arange(prob.shape[0]), Y].log().mean()
loss.item()

17.126298904418945

In [31]:
# ----- Putting it all together: full dataset, seeded init, minibatch training -----

In [32]:
X.size(), Y.size()

(torch.Size([228146, 3]), torch.Size([228146]))

In [33]:
# Seeded generator so parameter initialisation is reproducible.
g = torch.Generator().manual_seed(2147483647)

# Model hyperparameters (named instead of magic numbers; values unchanged).
vocab_size = len(s_to_i)  # 27 here: '.' plus the 26 letters
n_embd = 2                # embedding dimension per character
n_hidden = 100            # tanh units in the hidden layer

# Draw order and shapes match the original cell, so the seeded values are identical.
C = torch.randn((vocab_size, n_embd), generator=g)
W1 = torch.randn((block_size * n_embd, n_hidden), generator=g)
b1 = torch.randn((n_hidden,), generator=g)
W2 = torch.randn((n_hidden, vocab_size), generator=g)
b2 = torch.randn((vocab_size,), generator=g)
parameters = [C, W1, b1, W2, b2]

In [34]:
sum(param.numel() for param in parameters)  # total number of scalar parameters

3481

In [35]:
# Turn on gradient tracking for every parameter tensor (in-place setter).
for param in parameters:
	param.requires_grad_(True)

In [36]:
# Candidate learning rates for an lr sweep: 1000 exponents evenly spaced in
# [-3, 0], i.e. rates log-spaced from 1e-3 to 1.
lre = torch.linspace(-3, 0, 1000)
lrs = torch.pow(10, lre)

In [37]:
lri = []    # lr exponents tried (filled only during an lr sweep; see the update step)
lossi = []  # minibatch training loss at every step
batch_size = 32

for i in range(10000):

	# minibatch construct -- sampled from the seeded generator `g` so runs are reproducible
	minibatch = torch.randint(0, X.shape[0], (batch_size,), generator=g)

	# forward pass
	emb = C[X[minibatch]] # (batch_size, block_size, 2)
	h = torch.tanh(emb.view(-1, 6) @ W1 + b1) # (batch_size, 100)
	logits = h @ W2 + b2 # (batch_size, 27)
	loss = F.cross_entropy(logits, Y[minibatch])

	# backward pass
	for p in parameters:
		p.grad = None
	loss.backward()

	# update with a fixed lr (presumably chosen from the sweep over `lrs`).
	# To redo the sweep: loop over range(len(lrs)), use lr = lrs[i] and
	# record lri.append(lre[i]) so lri lines up with lossi.
	lr = 0.1
	for p in parameters:
		p.data += -lr * p.grad

	# track stats
	lossi.append(loss.item())

In [38]:
# plt.plot(lri, lossi)  # lr-sweep plot: requires lri to hold lre[i] for each step, matching lossi

In [39]:
# Loss over every example in X/Y. These are the same examples the minibatches
# were drawn from, so this is a training loss. It is evaluation only, so skip
# building an autograd graph over the whole dataset.
with torch.no_grad():
	emb = C[X] # (num_examples, block_size, 2)
	h = torch.tanh(emb.view(-1, 6) @ W1 + b1) # (num_examples, 100)
	logits = h @ W2 + b2 # (num_examples, 27)
	loss = F.cross_entropy(logits, Y)
loss.item()

2.4824843406677246

In [40]:
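# Next: split the examples into training, dev/validation, and test sets (80% / 10% / 10%).
# Fit parameters on the training split, tune hyperparameters on the dev split, and
# evaluate on the test split only sparingly so it stays an unbiased estimate.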