In [63]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline

In [64]:
#read in all the words 
words = open('names.txt','r').read().splitlines()
words[:8]

['emma', 'olivia', 'ava', 'isabella', 'sophia', 'charlotte', 'mia', 'amelia']

In [65]:
len(words)

32033

In [66]:
#build the vocabulary of characters and mappings to/from integers
chars = sorted(list(set(''.join(words))))
stoi = {s:i+1 for i,s in enumerate(chars)}
stoi['.'] = 0
itos = {i:s for s,i in stoi.items()}
print(itos)

{1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i', 10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q', 18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y', 26: 'z', 0: '.'}


In [67]:
#build the dataset

block_size = 3 # context length: how many characters do we take to predict the next one?
X, Y = [], []
for w in words[:5]:
    print(w) 
    context = [0] * block_size # pad the beginning of the word with zeros
    for ch in w + '.': # 
        ix = stoi[ch] # convert the character to an integer
        X.append(context) # add the context to the input
        Y.append(ix) # add the character to the output
        print(''.join(itos[i] for i in context), '---->', itos[ix]) # print the context and the character
        context = context[1:] + [ix] # move the context one character forward
    
X = torch.tensor(X)
Y = torch.tensor(Y)

emma
... ----> e
..e ----> m
.em ----> m
emm ----> a
mma ----> .
olivia
... ----> o
..o ----> l
.ol ----> i
oli ----> v
liv ----> i
ivi ----> a
via ----> .
ava
... ----> a
..a ----> v
.av ----> a
ava ----> .
isabella
... ----> i
..i ----> s
.is ----> a
isa ----> b
sab ----> e
abe ----> l
bel ----> l
ell ----> a
lla ----> .
sophia
... ----> s
..s ----> o
.so ----> p
sop ----> h
oph ----> i
phi ----> a
hia ----> .


In [68]:
# dataset description
X.shape, X.dtype, Y.shape, Y.dtype

(torch.Size([32, 3]), torch.int64, torch.Size([32]), torch.int64)

In [69]:
X

tensor([[ 0,  0,  0],
        [ 0,  0,  5],
        [ 0,  5, 13],
        [ 5, 13, 13],
        [13, 13,  1],
        [ 0,  0,  0],
        [ 0,  0, 15],
        [ 0, 15, 12],
        [15, 12,  9],
        [12,  9, 22],
        [ 9, 22,  9],
        [22,  9,  1],
        [ 0,  0,  0],
        [ 0,  0,  1],
        [ 0,  1, 22],
        [ 1, 22,  1],
        [ 0,  0,  0],
        [ 0,  0,  9],
        [ 0,  9, 19],
        [ 9, 19,  1],
        [19,  1,  2],
        [ 1,  2,  5],
        [ 2,  5, 12],
        [ 5, 12, 12],
        [12, 12,  1],
        [ 0,  0,  0],
        [ 0,  0, 19],
        [ 0, 19, 15],
        [19, 15, 16],
        [15, 16,  8],
        [16,  8,  9],
        [ 8,  9,  1]])

In [70]:
Y

tensor([ 5, 13, 13,  1,  0, 15, 12,  9, 22,  9,  1,  0,  1, 22,  1,  0,  9, 19,
         1,  2,  5, 12, 12,  1,  0, 19, 15, 16,  8,  9,  1,  0])

In [71]:
C = torch.randn((27,2)) #lookup table for the characters

In [72]:
C[5]

tensor([-3.2411,  0.8856])

In [73]:
# F.one_hot(5,num_classes=27) , this one doesn't work coz input needs to be a pytorch tensor
F.one_hot(torch.tensor(5),num_classes=27)

tensor([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0])

In [74]:
#check the shape of the one hot encoding
F.one_hot(torch.tensor([5,6,7]),num_classes=27).shape

torch.Size([3, 27])

In [75]:
C[5]

tensor([-3.2411,  0.8856])

In [76]:
# F.one_hot(torch.tensor(5),num_classes=27).float() @ C  - if you use this, you'll get an error coz by default the one hot encoding is a long tensor and the lookup table is a float tensor
# So explicitly convert the one hot encoding to a float tensor
# F.one_hot(torch.tensor(5),num_classes=27).float() @ C

In [77]:
emb = C[X] #embedding layer
emb.shape

torch.Size([32, 3, 2])

In [78]:
#let's construct the hidden layer from the paper
W1 = torch.randn((6,100))
b1 = torch.randn((100))

In [79]:
# emb @ W1 + b1  - this is what we wanna do but we've stacked layers so this won't work


In [80]:
# convert emb shape from 32,3,2 to 32,6
# can be done using two of the following methods : -
# 1. torch.cat((emb[:,0,:],emb[:,1,:],emb[:,2,:]),dim=1)
# 2. torch.cat(torch.unbind(emb,dim=1),dim=1)
# 3. emb.view(emb.shape[0],-1)

# torch.cat(torch.unbind(emb,dim=1),dim=1).shape  #think about why is this inefficient (hint: - think about the number of operations)
# emb.view(emb.shape[0],-1).shape
# compare the above two outputs
#torch.cat(torch.unbind(emb,dim=1),dim=1) == emb.view(emb.shape[0],-1)


In [87]:
# hidden layer
h = torch.tanh(emb.view(emb.shape[0],-1) @ W1 + b1)

In [88]:
h

tensor([[-0.6996,  0.8050, -0.9999,  ...,  0.8831, -0.9996, -0.1585],
        [-0.9986,  1.0000, -1.0000,  ...,  0.2626, -1.0000,  0.7621],
        [-0.8292, -1.0000, -0.9988,  ...,  0.9999, -0.9939, -0.9366],
        ...,
        [-0.3547,  0.9908,  0.9952,  ..., -0.9852, -1.0000,  0.9457],
        [ 0.8583, -0.3122,  0.2531,  ...,  0.5100, -0.9516, -0.9667],
        [ 0.9793, -0.9882, -0.9655,  ...,  0.9982, -0.9377, -0.4408]])

In [89]:
h.shape

torch.Size([32, 100])

In [90]:
#output layer
W2 = torch.randn((100,27)) # 27 is the number of characters in the vocabulary
b2 = torch.randn((27)) 


In [91]:
logits = h @ W2 + b2 

In [92]:
logits

tensor([[-1.2670e+00, -4.1335e+00,  1.1560e+01, -3.8736e+00, -2.1251e+00,
          7.4626e+00,  1.8475e+01, -7.9225e-01, -1.3214e+00, -6.8801e+00,
         -5.2634e+00, -2.2411e+01,  1.4866e+01, -4.7668e+00,  5.7661e+00,
         -1.0663e+01, -9.4162e+00, -2.2631e+00, -6.0322e+00,  1.2263e+01,
         -1.4836e+01, -2.9581e+00, -2.6624e+00, -8.9862e+00, -7.4875e+00,
          9.9756e+00,  2.7997e+00],
        [ 7.8278e+00, -2.0445e+00,  1.1542e+01,  1.3643e+00, -8.5967e+00,
          3.2137e+00,  1.6943e+01,  1.3340e+00, -2.5883e+00, -8.6674e+00,
         -4.6562e+00, -1.8361e+01,  1.6021e+01, -1.3632e+01,  2.4298e+00,
         -1.4999e+01, -2.2976e+01, -1.0450e+01, -5.0823e+00,  8.2503e+00,
         -1.2388e+01, -1.9898e+00,  8.0736e+00, -1.7837e+01, -1.1363e+01,
          1.8423e+01, -2.9066e+00],
        [-1.4235e+01, -7.2348e+00, -3.0220e+00, -1.2276e+00,  1.3468e+00,
          3.0062e+00,  1.3418e+01, -9.6748e-02,  1.2857e+00,  4.3598e+00,
          7.7742e+00, -1.1214e+01,  1.24

In [93]:
logits.shape

torch.Size([32, 27])

In [94]:
counts = logits.exp()

In [95]:
prob = counts / counts.sum(dim=1,keepdim=True)

In [96]:
prob.shape

torch.Size([32, 27])

In [97]:
prob[0].sum()

tensor(1.0000)

In [98]:
torch.arange(32)

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31])

In [99]:
prob[torch.arange(32),Y] #probabilities of the correct characters

tensor([1.6012e-05, 9.0799e-15, 8.7358e-09, 1.8481e-12, 5.6068e-08, 2.1512e-13,
        3.9111e-01, 3.5415e-12, 8.9883e-09, 1.8946e-05, 5.1223e-16, 4.2348e-11,
        1.4733e-10, 9.5592e-18, 8.5673e-09, 6.7198e-08, 9.4505e-12, 1.2948e-03,
        7.3911e-12, 1.4363e-05, 2.2173e-07, 1.7416e-04, 9.9981e-01, 1.7899e-14,
        8.1261e-17, 1.9467e-03, 1.0247e-11, 1.1051e-04, 1.2154e-01, 2.5999e-08,
        1.3046e-07, 3.3286e-09])

In [100]:
prob[torch.arange(32),Y].log() #log probabilities of the correct characters

tensor([-1.1042e+01, -3.2333e+01, -1.8556e+01, -2.7017e+01, -1.6697e+01,
        -2.9168e+01, -9.3877e-01, -2.6366e+01, -1.8527e+01, -1.0874e+01,
        -3.5208e+01, -2.3885e+01, -2.2638e+01, -3.9189e+01, -1.8575e+01,
        -1.6516e+01, -2.5385e+01, -6.6494e+00, -2.5631e+01, -1.1151e+01,
        -1.5322e+01, -8.6555e+00, -1.8861e-04, -3.1654e+01, -3.7049e+01,
        -6.2416e+00, -2.5304e+01, -9.1104e+00, -2.1075e+00, -1.7465e+01,
        -1.5852e+01, -1.9521e+01])

In [101]:
loss = -prob[torch.arange(32),Y].log().mean() #negative log likelihood loss (this is the loss function we're trying to minimize)

In [102]:
loss

tensor(18.8946)

In [103]:
#### let's make it more organized
X.shape, Y.shape #dataset

(torch.Size([32, 3]), torch.Size([32]))

In [104]:
g = torch.Generator().manual_seed(2147483647) #set the seed for the random number generator
C = torch.randn((27,2),generator=g) #lookup table for the characters
W1 = torch.randn((6,100),generator=g) #weights for the hidden layer, 
b1 = torch.randn((100),generator=g) #bias for the hidden layer 
W2 = torch.randn((100,27),generator=g) #weights for the output layer
b2 = torch.randn((27),generator=g) #bias for the output layer
parameters = [C,W1,b1,W2,b2] #list of all the parameters

In [105]:
sum(p.numel() for p in parameters) #total number of parameters

3481

In [112]:
for p in parameters:
    p.requires_grad = True #set the requires_grad attribute to True for all the parameters

In [113]:
for _ in range(1000): #number of epochs
    #forward pass
    emb = C[X] #embedding layer (32,3,2)
    h = torch.tanh(emb.view(emb.shape[0],-1) @ W1 + b1) #hidden layer (32,100) 
    logits = h @ W2 + b2 #output layer (32,27) 
    # counts = logits.exp() #exponentiate the logits (32,27)
    # prob = counts / counts.sum(dim=1,keepdim=True) #softmax (32,27)
    # loss = -prob[torch.arange(32),Y].log().mean() #negative log likelihood loss (this is the loss function we're trying to minimize)
    loss = F.cross_entropy(logits,Y) #negative log likelihood loss (this is the loss function we're trying to minimize), 
            # *note here that we're not using the softmax function, this is because the cross entropy function already has the softmax function built in and has a more numerically stable implementation
    #backward pass
    print(loss.item()) #print the loss
    for p in parameters:
        p.grad = None #zero out the gradients
    loss.backward() #compute the gradients
    #update the parameters
    for p in parameters:
        p.data += -0.1 * p.grad #update the parameters

17.76971435546875
13.656402587890625
11.298768997192383
9.452457427978516
7.984261989593506
6.891321182250977
6.100014686584473
5.452036380767822
4.898152828216553
4.414664268493652
3.985849380493164
3.6028308868408203
3.262141704559326
2.961380958557129
2.6982972621917725
2.469712972640991
2.271660327911377
2.101283550262451
1.9571771621704102
1.8374855518341064
1.7380967140197754
1.6535115242004395
1.579089879989624
1.5117665529251099
1.4496047496795654
1.3913118839263916
1.335992455482483
1.283052682876587
1.2321909666061401
1.1833813190460205
1.1367988586425781
1.092664122581482
1.0510923862457275
1.0120269060134888
0.9752704501152039
0.9405564665794373
0.9076125025749207
0.8761923313140869
0.8460890054702759
0.8171356916427612
0.7891990542411804
0.7621744275093079
0.7359812259674072
0.7105579972267151
0.6858609914779663
0.661865234375
0.6385655999183655
0.6159817576408386
0.5941659808158875
0.5732104182243347
0.5532562136650085
0.534488320350647
0.5171167254447937
0.50133132934570