In [96]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt # for making figures
%matplotlib inline

In [97]:
# read in all the words
words = open('names.txt', 'r').read().splitlines()
words[:8]

['emma', 'olivia', 'ava', 'isabella', 'sophia', 'charlotte', 'mia', 'amelia']

In [98]:
len(words)

32033

In [99]:
# build the vocabulary of characters and mappings to/from integers
chars = sorted(list(set(''.join(words))))
stoi = {s:i+1 for i,s in enumerate(chars)}
stoi['.'] = 0
itos = {i:s for s,i in stoi.items()}
print(itos)

{1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i', 10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q', 18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y', 26: 'z', 0: '.'}


In [100]:
print(stoi)

{'a': 1, 'b': 2, 'c': 3, 'd': 4, 'e': 5, 'f': 6, 'g': 7, 'h': 8, 'i': 9, 'j': 10, 'k': 11, 'l': 12, 'm': 13, 'n': 14, 'o': 15, 'p': 16, 'q': 17, 'r': 18, 's': 19, 't': 20, 'u': 21, 'v': 22, 'w': 23, 'x': 24, 'y': 25, 'z': 26, '.': 0}


In [166]:
block_size = 3
X, Y = [], []
for w in words:
    context = [0] * block_size
    for ch in w + ".":
        # print(context)
        ix = stoi[ch]
        X.append(context)
        Y.append(ix)
        print(''.join(itos[i] for i in context), '--->', itos[ix])
        context = context[1:] + [ix]

X = torch.tensor(X)
Y = torch.tensor(Y)

... ---> y
..y ---> u
.yu ---> h
yuh ---> e
uhe ---> n
hen ---> g
eng ---> .
... ---> d
..d ---> i
.di ---> o
dio ---> n
ion ---> d
ond ---> r
ndr ---> e
dre ---> .
... ---> x
..x ---> a
.xa ---> v
xav ---> i
avi ---> e
vie ---> n
ien ---> .
... ---> j
..j ---> o
.jo ---> r
jor ---> i
ori ---> .
... ---> j
..j ---> u
.ju ---> a
jua ---> n
uan ---> l
anl ---> u
nlu ---> i
lui ---> s
uis ---> .
... ---> e
..e ---> r
.er ---> a
era ---> n
ran ---> d
and ---> i
ndi ---> .
... ---> p
..p ---> h
.ph ---> i
phi ---> a
hia ---> .
... ---> s
..s ---> a
.sa ---> m
sam ---> a
ama ---> t
mat ---> h
ath ---> a
tha ---> .
... ---> p
..p ---> h
.ph ---> o
pho ---> e
hoe ---> n
oen ---> i
eni ---> x
nix ---> .
... ---> e
..e ---> m
.em ---> m
emm ---> e
mme ---> l
mel ---> y
ely ---> n
lyn ---> n
ynn ---> .
... ---> h
..h ---> o
.ho ---> l
hol ---> l
oll ---> a
lla ---> n
lan ---> .
... ---> h
..h ---> o
.ho ---> l
hol ---> l
oll ---> i
lli ---> s
lis ---> .
... ---> c
..c ---> a
.ca ---> l
cal ---> l

In [167]:
X.shape, X.dtype, Y.shape, Y.dtype

(torch.Size([228146, 3]), torch.int64, torch.Size([228146]), torch.int64)

In [109]:
C = torch.rand((27, 2))
C

tensor([[0.6418, 0.4417],
        [0.2725, 0.4714],
        [0.5786, 0.1731],
        [0.7663, 0.9310],
        [0.0048, 0.4282],
        [0.0737, 0.2680],
        [0.2207, 0.7131],
        [0.1909, 0.0114],
        [0.9758, 0.1366],
        [0.5438, 0.8642],
        [0.8584, 0.1882],
        [0.1164, 0.4327],
        [0.1716, 0.0503],
        [0.6738, 0.8158],
        [0.7585, 0.8612],
        [0.6956, 0.1967],
        [0.6590, 0.6322],
        [0.6148, 0.7177],
        [0.5058, 0.1621],
        [0.5712, 0.7908],
        [0.4739, 0.1961],
        [0.3386, 0.2074],
        [0.9146, 0.9684],
        [0.0426, 0.6525],
        [0.2510, 0.0060],
        [0.6082, 0.5551],
        [0.5861, 0.0919]])

In [118]:
emb = C[X]
emb

tensor([[[0.6418, 0.4417],
         [0.6418, 0.4417],
         [0.6418, 0.4417]],

        [[0.6418, 0.4417],
         [0.6418, 0.4417],
         [0.0737, 0.2680]],

        [[0.6418, 0.4417],
         [0.0737, 0.2680],
         [0.6738, 0.8158]],

        [[0.0737, 0.2680],
         [0.6738, 0.8158],
         [0.6738, 0.8158]],

        [[0.6738, 0.8158],
         [0.6738, 0.8158],
         [0.2725, 0.4714]],

        [[0.6418, 0.4417],
         [0.6418, 0.4417],
         [0.6418, 0.4417]],

        [[0.6418, 0.4417],
         [0.6418, 0.4417],
         [0.6956, 0.1967]],

        [[0.6418, 0.4417],
         [0.6956, 0.1967],
         [0.1716, 0.0503]],

        [[0.6956, 0.1967],
         [0.1716, 0.0503],
         [0.5438, 0.8642]],

        [[0.1716, 0.0503],
         [0.5438, 0.8642],
         [0.9146, 0.9684]],

        [[0.5438, 0.8642],
         [0.9146, 0.9684],
         [0.5438, 0.8642]],

        [[0.9146, 0.9684],
         [0.5438, 0.8642],
         [0.2725, 0.4714]],

    

In [135]:
W1 = torch.randn((6, 100))
b1 = torch.randn(100)

In [136]:
h = torch.tanh(emb.view(emb.shape[0], 6) @ W1 + b1)

tensor([[-0.9947,  0.8291, -0.2506,  ..., -0.4744, -0.1291, -0.9193],
        [-0.9718,  0.6794, -0.1593,  ..., -0.8157,  0.5700, -0.8612],
        [-0.9962,  0.7342, -0.2820,  ..., -0.9665, -0.8618, -0.7311],
        ...,
        [-0.9963,  0.7294,  0.3833,  ...,  0.4273, -0.3928, -0.9425],
        [-0.9974,  0.9654, -0.7795,  ...,  0.0967,  0.3673, -0.9390],
        [-0.9938,  0.6962,  0.5809,  ..., -0.7004, -0.1440, -0.7141]])

In [138]:
h.shape

torch.Size([32, 100])

In [141]:
W2 = torch.randn((100, 27))
b2 = torch.randn(27)

In [142]:
logits = h @ W2 + b2

In [143]:
logits

tensor([[ 2.8444e+00,  7.4334e+00,  6.5262e+00,  7.5898e+00, -8.9104e+00,
         -4.6839e+00, -9.6686e-01,  6.1252e+00,  2.3346e-01,  5.2188e+00,
          3.8626e+00, -3.8649e+00, -1.2328e+00,  5.6783e-01, -1.4389e+01,
          3.7009e+00,  4.7408e+00, -5.7185e+00,  1.0631e+01,  1.8843e+00,
          8.7152e-01,  8.0906e-02, -5.1483e+00,  1.0532e+01, -1.3594e+01,
          4.2618e+00, -1.1771e+00],
        [ 1.9218e+00,  4.3485e+00,  1.6505e+00,  5.7870e+00, -9.4042e+00,
         -1.7585e+00, -2.0727e+00,  2.8138e+00,  3.9755e+00,  3.4465e+00,
          2.8568e+00, -2.6628e+00, -3.6580e+00, -3.7647e+00, -1.4347e+01,
          5.7077e+00,  3.0375e+00, -5.3602e+00,  7.9805e+00, -1.4592e+00,
          5.0242e+00,  2.1983e+00, -1.9260e-01,  7.4173e+00, -1.0228e+01,
          7.4686e+00, -5.5358e+00],
        [ 7.9594e-01,  8.2626e+00,  5.7637e+00,  1.3642e+01, -5.5422e+00,
         -3.5816e+00,  4.0692e-01,  7.8700e+00, -1.0133e+00,  3.3286e+00,
          4.4234e+00, -1.3031e+00,  1.71

In [146]:
counts = logits.exp()

In [147]:
probs = counts / counts.sum(1, keepdims=True)

In [148]:
probs.shape

torch.Size([32, 27])

In [149]:
probs[0].shape

torch.Size([27])

In [151]:
loss = -probs[torch.arange(32), Y].log().mean()
loss

tensor(9.1071)

(torch.Size([182441, 3]), torch.Size([182441]))

In [210]:
# build the dataset
block_size = 3 # context length: how many characters do we take to predict the next one?

def build_dataset(words):  
  X, Y = [], []
  for w in words:

    #print(w)
    context = [0] * block_size
    for ch in w + '.':
      ix = stoi[ch]
      X.append(context)
      Y.append(ix)
      #print(''.join(itos[i] for i in context), '--->', itos[ix])
      context = context[1:] + [ix] # crop and append

  X = torch.tensor(X)
  Y = torch.tensor(Y)
  print(X.shape, Y.shape)
  return X, Y

import random
random.seed(42)
random.shuffle(words)
n1 = int(0.8*len(words))
n2 = int(0.9*len(words))
print("n1 : " + str(n1))
print("n2 : " + str(n2))
Xtr, Ytr = build_dataset(words[:n1])
Xdev, Ydev = build_dataset(words[n1:n2])
Xte, Yte = build_dataset(words[n2:])

n1 : 25626
n2 : 28829
torch.Size([182512, 3]) torch.Size([182512])
torch.Size([22860, 3]) torch.Size([22860])
torch.Size([22774, 3]) torch.Size([22774])


In [211]:
g = torch.Generator().manual_seed(2147483647) # for reproducibility
C = torch.randn((27, 10), generator=g)
W1 = torch.randn((30, 200), generator=g)
b1 = torch.randn(200, generator=g)
W2 = torch.randn((200, 27), generator=g)
b2 = torch.randn(27, generator=g)
parameters = [C, W1, b1, W2, b2]

In [212]:
for p in parameters:
    p.requires_grad = True

In [213]:
sum(p.nelement() for p in parameters) # number of parameters in total

11897

In [224]:
for _ in range(50000):
    ix = torch.randint(0, Xtr.shape[0], (32,))
    emb = C[Xtr[ix]]
    h = torch.tanh(emb.view(-1, 30) @ W1 + b1)
    logits = h @ W2 + b2
    loss = F.cross_entropy(logits, Ytr[ix])

    for p in parameters:
        p.grad = None
    loss.backward()

    lr = 0.01
    for p in parameters:
        p.data += -lr * p.grad
    
print(loss.item()) 

2.204967737197876


In [225]:
loss

tensor(2.2050, grad_fn=<NllLossBackward0>)

In [226]:
emb = C[Xdev]
h = torch.tanh(emb.view(-1, 30) @ W1 + b1)
logits = h @ W2 + b2
loss = F.cross_entropy(logits, Ydev)
loss

tensor(2.1993, grad_fn=<NllLossBackward0>)

In [229]:
emb = C[Xte]
h = torch.tanh(emb.view(-1, 30) @ W1 + b1)
logits = h @ W2 + b2
loss = F.cross_entropy(logits, Yte)
loss

tensor(2.2006, grad_fn=<NllLossBackward0>)

# 