In [106]:
import torch
import torch.nn.functional as F
import matplotlib
%matplotlib inline

In [26]:
# Get the vocabulary for the mappings of to and from intergers

words = open("names.txt",'r').read().splitlines()
chars = sorted(list(set(''.join(words))))
stoi = {s:i+1 for i,s in enumerate(chars)}
stoi['.'] = 0
itos = {v:k for k,v in stoi.items()}


In [121]:
# Build the dataset
X,Y = [],[]
context_length = 3

for w in words:
    context = [0] * context_length
    #print(w)
    for ch in w + '.':
        ix = stoi[ch]
        X.append(context)
        Y.append(ix)
        #print(f"{''.join(itos[i] for i in context)} ---> {itos[ix]}")
        context = context[1:] + [ix]

X = torch.tensor(X)
Y = torch.tensor(Y)

In [122]:
X.shape, X.dtype, Y.shape, Y.dtype


(torch.Size([228146, 3]), torch.int64, torch.Size([228146]), torch.int64)

In [64]:
C = torch.randn((27,2))
C

tensor([[ 1.4140, -0.9164],
        [ 0.6345, -0.9186],
        [-0.4419,  0.8661],
        [-0.1640, -0.3767],
        [ 0.1695,  0.7452],
        [-2.2785, -0.4671],
        [ 0.1070,  0.6598],
        [-0.5363,  1.1308],
        [ 1.4783,  0.9955],
        [-0.0713,  0.4893],
        [ 0.6374,  0.9106],
        [ 1.3740, -2.3133],
        [ 1.2617, -1.1571],
        [ 1.6125,  0.0337],
        [ 0.1694, -1.0704],
        [ 1.4233,  0.2604],
        [-1.0649, -0.5052],
        [ 0.3267,  0.1601],
        [ 1.0696, -0.2487],
        [ 0.4330, -1.8828],
        [ 0.0865,  0.5502],
        [ 0.4356,  0.5052],
        [ 0.5720,  0.0388],
        [-0.8659, -0.2791],
        [-2.6366,  0.6593],
        [-1.1495, -0.4954],
        [ 0.0128, -0.3272]])

In [67]:
# Create the embeddings
emb = C[X]
emb.shape

torch.Size([32, 3, 2])

In [82]:
# Create the hidden layer
W1 = torch.randn(6,100)
b1 = torch.randn(100)
h = torch.tanh(emb.view(emb.shape[0],6) @ W1 + b1)

torch.Size([32, 100])

In [84]:
# Create the output layer
W2 = torch.randn(100,27)
b2 = torch.randn(27)

logits = h @ W2 + b2


In [91]:
counts = logits.exp()
probs = counts / counts.sum(1,keepdim=True)
loss = -probs[torch.arange(32),Y].log().mean()
loss

tensor(12.2805)

In [123]:
#----- more respectable
X.shape, Y.shape


(torch.Size([228146, 3]), torch.Size([228146]))

In [127]:
g = torch.Generator().manual_seed(2147483647)
C = torch.randn((27,2), generator=g)
W1 = torch.rand((6,100), generator=g)
b1 = torch.randn(100, generator=g)
W2 = torch.randn(100,27, generator=g)
b2 = torch.randn(27, generator=g)
parameters = [C , W1, b1, W2, b2]

In [125]:
sum(p.nelement() for p in parameters)

3481

In [128]:
# emb = C[X]
# h = torch.tanh(emb.view(-1,6) @ W1 + b1)
# logits = h @ W2 + b2 
# counts = logits.exp()
# prob = counts / counts.sum(1, keepdims=True)
# loss = -prob[torch.arange(32), Y].log().mean()
# loss

In [129]:
for p in parameters:
    p.requires_grad = True


In [130]:
for _ in range(10):
    # forward pass
    emb = C[X]
    h = torch.tanh(emb.view(-1,6) @ W1 + b1)
    logits = h @ W2 + b2 
    counts = logits.exp()
    loss = F.cross_entropy(logits, Y)
    print(f"loss: {loss.item()}")
    # backward pass
    for p in parameters:
        p.grad = None
    loss.backward()
    # update
    for p in parameters:
        p.data += -0.1 * p.grad



loss: 18.71084976196289
loss: 15.05393123626709
loss: 12.727892875671387
loss: 11.155088424682617
loss: 10.05884838104248
loss: 9.131641387939453
loss: 8.403254508972168
loss: 7.905227184295654
loss: 7.516906261444092
loss: 7.18239688873291
