In [1]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline

In [3]:
# Load the dataset: one name per line. The context manager closes the file handle
# deterministically (a bare open(...).read() leaves that to the garbage collector),
# and an explicit encoding avoids depending on the platform default.
with open('names.txt', 'r', encoding='utf-8') as f:
    words = f.read().splitlines()
words[:10]

['emma',
 'olivia',
 'ava',
 'isabella',
 'sophia',
 'charlotte',
 'mia',
 'amelia',
 'harper',
 'evelyn']

In [4]:
# Vocabulary: every distinct character in the dataset, sorted, numbered from 1;
# index 0 is reserved for the '.' start/end token (inserted last, so itos lists it last).
chars = sorted(set(''.join(words)))
stoi = dict(zip(chars, range(1, len(chars) + 1)))
stoi['.'] = 0
itos = {ix: ch for ch, ix in stoi.items()}
print(itos)

{1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i', 10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q', 18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y', 26: 'z', 0: '.'}


In [5]:
block_size = 3  # number of previous characters used to predict the next one
X = []
Y = []
for w in words[:3]:
    print(w)
    # Start every word with a window made entirely of '.' padding tokens.
    context = [stoi['.']] * block_size
    for ch in w + '.':
        ix = stoi[ch]
        X.append(context)
        Y.append(ix)
        print(''.join(map(itos.get, context)), '----->', itos[ix])
        # Slide the window: drop the oldest character, append the one just predicted.
        context = [*context[1:], ix]

X = torch.tensor(X)
Y = torch.tensor(Y)

emma
... -----> e
..e -----> m
.em -----> m
emm -----> a
mma -----> .
olivia
... -----> o
..o -----> l
.ol -----> i
oli -----> v
liv -----> i
ivi -----> a
via -----> .
ava
... -----> a
..a -----> v
.av -----> a
ava -----> .


In [6]:
C = torch.randn(27, 2)  # lookup table: each of the 27 characters -> a 2-d embedding

In [7]:
F.embedding(X, C)  # same result as C[X]: row lookup of every index in X

tensor([[[ 0.2989, -1.2410],
         [ 0.2989, -1.2410],
         [ 0.2989, -1.2410]],

        [[ 0.2989, -1.2410],
         [ 0.2989, -1.2410],
         [-1.1907, -0.0646]],

        [[ 0.2989, -1.2410],
         [-1.1907, -0.0646],
         [ 0.6483, -0.4550]],

        [[-1.1907, -0.0646],
         [ 0.6483, -0.4550],
         [ 0.6483, -0.4550]],

        [[ 0.6483, -0.4550],
         [ 0.6483, -0.4550],
         [-1.0422, -0.7855]],

        [[ 0.2989, -1.2410],
         [ 0.2989, -1.2410],
         [ 0.2989, -1.2410]],

        [[ 0.2989, -1.2410],
         [ 0.2989, -1.2410],
         [-0.2528, -0.8239]],

        [[ 0.2989, -1.2410],
         [-0.2528, -0.8239],
         [-0.2224,  1.0154]],

        [[-0.2528, -0.8239],
         [-0.2224,  1.0154],
         [-0.3499, -0.5857]],

        [[-0.2224,  1.0154],
         [-0.3499, -0.5857],
         [ 0.2954, -0.5359]],

        [[-0.3499, -0.5857],
         [ 0.2954, -0.5359],
         [-0.3499, -0.5857]],

        [[ 0.2954, -0

In [8]:
C[X].size()  # (examples, block_size, embedding dim)

torch.Size([16, 3, 2])

In [9]:
emb = F.embedding(X, C)  # equivalent to C[X]

In [10]:
# Hidden layer: 6 inputs (3 context chars x 2-d embedding) -> 100 units.
W1 = torch.randn(6, 100)
b1 = torch.randn(100)

In [12]:
# Split emb along the context axis and glue the pieces side by side; this gives the
# same layout as a view/reshape to (examples, 6) but allocates a new tensor.
torch.cat(emb.unbind(dim=1), dim=1).size()

torch.Size([16, 6])

In [17]:
# Flatten each example's (block_size, embedding dim) block into a single row.
# The row width is derived from emb instead of hard-coding 6, so this keeps working
# if block_size or the embedding size changes (W1's first dimension must still match).
h = torch.tanh(emb.view(emb.shape[0], -1) @ W1 + b1)

In [16]:
h.size()  # (examples, hidden units)

torch.Size([16, 100])

In [18]:
# Output layer: 100 hidden units -> 27 logits (one per character).
W2 = torch.randn(100, 27)
b2 = torch.randn(27)

In [19]:
logits = torch.matmul(h, W2) + b2  # unnormalised score for every character

In [20]:
logits.size()  # (examples, vocabulary size)

torch.Size([16, 27])

In [21]:
# Raw, unnormalised scores: one row per example, one column per character of the 27-symbol vocabulary.
logits

tensor([[ -6.6745,  -7.4754,   5.6886,   2.5965,  -1.2465,  -6.5902,  -8.2576,
          12.4431,   2.9072,   5.5546,  -8.8877,  -6.7650,   1.0405,  -3.1388,
          14.1150,   3.0673,  -4.5765,   6.2176,  -6.1191,   6.0405,   0.5132,
          -5.8822, -23.4996,  -6.3595, -10.3019,  -6.8006,  -0.0592],
        [-10.2342,  -5.5114,  14.0960,  -1.1482,  -8.9194,  -7.4169, -12.3618,
           6.6457,   4.8402,  -5.7327,  -1.2169, -10.3200,   0.6088,  -5.3768,
          -3.5432,  10.3747,  -4.7056,   0.4381,  -6.3505,  -6.3350,  -1.2659,
         -10.9178,  -4.9719,   2.7948,  -9.6452,  -2.1447,   0.2916],
        [  1.9207,  -5.6206,  -0.8630,  -0.5479,  -3.4970, -16.7600,  -3.1978,
           6.0159,   3.0366,  -3.2691,  -6.2295,  -2.1267,   3.3850,   8.8305,
          -0.0874,  -8.0117,  -4.6370,   8.0025,  -8.0551,   6.7281,   3.9872,
           2.7688, -15.2153,  -6.9070,   2.4371,  -8.2294,  -0.1412],
        [ -1.3998,   6.1825,  -3.3062,   0.4939,  -2.0678,  -2.1343,   7.4536,


In [22]:
# Softmax by hand: exponentiate the logits into positive "counts",
# then divide each row by its total so every row sums to 1.
counts = logits.exp()
probs = counts / counts.sum(1, keepdim=True)
probs

tensor([[7.8717e-10, 3.5336e-10, 1.8419e-04, 8.3633e-06, 1.7923e-07, 8.5634e-10,
         1.6163e-10, 1.5802e-01, 1.1412e-05, 1.6109e-04, 8.6073e-11, 7.1900e-10,
         1.7646e-06, 2.7015e-08, 8.4103e-01, 1.3392e-05, 6.4151e-09, 3.1262e-04,
         1.3717e-09, 2.6188e-04, 1.0415e-06, 1.7383e-09, 3.8816e-17, 1.0785e-09,
         2.0925e-11, 6.9386e-10, 5.8753e-07],
        [2.6476e-11, 2.9780e-09, 9.7571e-01, 2.3380e-07, 9.8592e-11, 4.4297e-10,
         3.1537e-12, 5.6710e-04, 9.3234e-05, 2.3867e-09, 2.1827e-07, 2.4297e-11,
         1.3548e-06, 3.4070e-09, 2.1315e-08, 2.3614e-02, 6.6663e-09, 1.1423e-06,
         1.2867e-09, 1.3068e-09, 2.0784e-07, 1.3364e-11, 5.1074e-09, 1.2058e-05,
         4.7710e-11, 8.6308e-08, 9.8660e-07],
        [6.0864e-04, 3.2301e-07, 3.7621e-05, 5.1551e-05, 2.7006e-06, 4.6927e-12,
         3.6426e-06, 3.6550e-02, 1.8578e-03, 3.3919e-06, 1.7570e-07, 1.0632e-05,
         2.6321e-03, 6.0987e-01, 8.1708e-05, 2.9564e-08, 8.6376e-07, 2.6646e-01,
         2.8309e-

In [24]:
# Probability the network assigns to the actual next character of each example.
# The row count comes from Y instead of a hard-coded 16, so this keeps working
# when the dataset is built from more (or fewer) words.
probs[torch.arange(Y.shape[0]), Y]

tensor([8.5634e-10, 3.4070e-09, 6.0987e-01, 8.3214e-03, 3.3377e-13, 1.3392e-05,
        3.1312e-06, 9.5065e-07, 3.7751e-06, 9.1179e-14, 2.0739e-06, 3.0734e-10,
        3.5336e-10, 5.0988e-13, 2.5240e-09, 2.9926e-09])

In [25]:
# Negative log-likelihood of the correct characters, averaged over all examples.
# Uses Y's length rather than a hard-coded 16 so it matches any dataset size.
loss = -probs[torch.arange(Y.shape[0]), Y].log().mean()
loss

tensor(17.4465)

Let's bring everything together before we train our network.

### Training

In [26]:
(X.size(), Y.size())  # inputs: (examples, block_size); targets: (examples,)

(torch.Size([16, 3]), torch.Size([16]))

In [30]:
# Re-create every parameter from one seeded generator so this cell, and the loss
# computed from it, are reproducible after a kernel restart.
g = torch.Generator().manual_seed(42)
C = torch.randn(27, 2, generator=g)     # embedding table: 27 characters -> 2-d vectors
W1 = torch.randn(6, 100, generator=g)   # hidden weights: 3 chars x 2 dims = 6 inputs
b1 = torch.randn(100, generator=g)
W2 = torch.randn(100, 27, generator=g)  # output weights: 100 hidden -> 27 logits
b2 = torch.randn(27, generator=g)
parameters = [C, W1, b1, W2, b2]

In [31]:
# Total number of trainable scalars across all parameter tensors.
sum(map(torch.numel, parameters))

3481

In [33]:
# Forward pass over every example in X using the seeded parameters.
emb = C[X]  # (N, block_size, 2): N examples (contexts), each context char looked up in C
h = torch.tanh(emb.view(emb.shape[0], -1) @ W1 + b1)  # (N, 100); row width derived from emb, not hard-coded
logits = h @ W2 + b2  # (N, 27): one unnormalised score per character
# Kept for inspection; F.softmax is computed in a numerically stable way.
probs = F.softmax(logits, dim=1)
# F.cross_entropy fuses log-softmax with the negative log-likelihood of the targets.
# Unlike the manual logits.exp() -> normalise -> log version, it stays finite for large
# logits (exp overflowing to inf, or a probability underflowing to 0 so log gives -inf),
# and it needs no hard-coded batch size.
loss = F.cross_entropy(logits, Y)
loss.item()

18.162921905517578