In [1]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline

Matplotlib is building the font cache; this may take a moment.


In [2]:
# Load the training corpus: one name per line, without trailing newlines.
# The context manager closes the file handle as soon as the read finishes
# instead of leaving it open until the garbage collector gets to it.
with open("names.txt", "r") as names_file:
    words = names_file.read().splitlines()
words[:8]

['emma', 'olivia', 'ava', 'isabella', 'sophia', 'charlotte', 'mia', 'amelia']

In [3]:
# Number of names loaded from the corpus.
len(words)

32033

In [4]:
# Character vocabulary and the two lookup tables between characters and ids.
# The sorted distinct characters of the corpus get ids 1..N; id 0 is reserved
# for '.', the marker used for the start/end of a name.
chars = sorted(set(''.join(words)))
stoi = dict(zip(chars, range(1, len(chars) + 1)))
stoi['.'] = 0
itos = dict((idx, ch) for ch, idx in stoi.items())
print(itos)

{1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i', 10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q', 18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y', 26: 'z', 0: '.'}


In [49]:
# Build the dataset: every character of every name (plus the closing '.')
# becomes one training example whose input is the block_size characters
# that precede it, left-padded with index 0.

block_size = 3  # context length: how many characters are used to predict the next one
contexts, targets = [], []

for word in words:
    encoded = [stoi[ch] for ch in word + '.']
    # Prepending block_size zeros lets every context be taken as a plain slice.
    padded = [0] * block_size + encoded
    for pos, target in enumerate(encoded):
        contexts.append(padded[pos:pos + block_size])
        targets.append(target)

X = torch.tensor(contexts)
Y = torch.tensor(targets)

In [15]:
# Sanity-check the dataset tensors: shape and dtype of the inputs X and targets Y.
X.shape, X.dtype, Y.shape, Y.dtype

(torch.Size([32, 3]), torch.int64, torch.Size([32]), torch.int64)

In [12]:
# Embedding lookup table: 27 rows (one per character index), each a
# randomly initialised 2-dimensional vector.
C = torch.randn(27, 2)
C

tensor([[ 1.2565,  1.2789],
        [-1.9737, -1.1920],
        [ 0.7217,  2.2084],
        [-1.0437, -0.6410],
        [-2.3968,  0.1482],
        [ 0.8471, -1.1352],
        [ 1.5126,  0.6362],
        [ 1.6220, -1.2815],
        [-0.1168, -0.4418],
        [ 1.0136, -1.5135],
        [-2.3728,  0.7900],
        [ 1.9052,  1.1632],
        [-0.4115, -1.5138],
        [-0.3030,  0.4848],
        [ 0.3223,  1.6248],
        [ 0.3525,  0.6096],
        [ 1.1807,  2.2449],
        [ 1.0616,  0.2388],
        [ 1.9691, -2.6342],
        [ 0.5457, -0.3260],
        [ 0.3052,  0.4998],
        [-0.9172,  0.2843],
        [ 1.9726,  0.5146],
        [-1.4049, -0.5832],
        [ 0.8239,  0.0492],
        [ 0.7660,  0.4976],
        [ 3.0066,  0.1700]])

In [14]:
# Multiplying a one-hot row vector by C picks out a single row of C, so this
# is the matrix-multiply equivalent of the plain lookup C[5].
F.one_hot(torch.tensor(5), num_classes=27).float() @ C

tensor([ 0.8471, -1.1352])

In [17]:
# Index C with the whole integer tensor X at once: every entry of X is
# replaced by the corresponding row of C, which adds one trailing dimension.
emb = C[X]
emb.shape

torch.Size([32, 3, 2])

In [31]:
# Hidden layer parameters: 6 inputs feeding 100 units, plus one bias per unit.
W1 = torch.randn(6, 100)
b1 = torch.randn((100,))

In [33]:
# Hidden activations: flatten each example's embeddings into one row of 6
# values, apply the affine layer (b1 broadcasts across rows), squash with tanh.
h = (emb.view(-1, 6) @ W1 + b1).tanh()
h.shape

torch.Size([32, 100])

In [34]:
# Output layer parameters: 100 hidden units mapped to 27 scores, plus biases.
W2 = torch.randn(100, 27)
b2 = torch.randn((27,))

In [35]:
# Output scores: one row of 27 unnormalised values per example.
logits = b2 + h @ W2
logits.shape

torch.Size([32, 27])

In [36]:
# Hand-rolled softmax over each row of logits.
# The row maximum is subtracted before exponentiating: softmax is unchanged by
# a per-row shift, and without it a large positive logit makes exp() overflow
# to inf, which turns the normalised probabilities into nan.
counts = (logits - logits.max(1, keepdim=True).values).exp()
probs = counts / counts.sum(1, keepdim=True)  # each row now sums to 1
probs.shape

torch.Size([32, 27])

In [40]:
# Mean negative log-likelihood. For every example, pick the probability the
# model assigned to its target index in Y (the character that actually follows
# the context), then average the negative logs.
# The row index is sized from Y rather than a hard-coded 32, so the cell keeps
# working when the dataset is rebuilt with a different number of examples.
loss = -probs[torch.arange(Y.shape[0]), Y].log().mean()
loss

tensor(16.7904)

In [50]:
# Re-create every parameter from a single seeded generator so that the
# initial weights are identical on every run. Draw order matters: embedding
# table, hidden weights, hidden bias, output weights, output bias.
g = torch.Generator()
g.manual_seed(2147483647)
C, W1, b1, W2, b2 = (
    torch.randn(shape, generator=g)
    for shape in ((27, 2), (6, 100), (100,), (100, 27), (27,))
)
parameters = [C, W1, b1, W2, b2]

In [51]:
# Total number of scalar values across all parameter tensors.
sum(map(lambda tensor: tensor.nelement(), parameters))

3481

In [52]:
# Switch on gradient tracking for every parameter so backward() fills .grad.
for tensor in parameters:
    tensor.requires_grad = True

In [71]:
# Mini-batch stochastic gradient descent.
num_steps = 100
batch_size = 32
learning_rate = 0.1

for _ in range(num_steps):
    # Minibatch construct: sample row indices from the seeded generator g so
    # that the sequence of batches (and so the loss) is repeatable across runs.
    ix = torch.randint(0, X.shape[0], (batch_size,), generator=g)

    # Forward pass
    emb = C[X[ix]]  # (batch_size, 3, 2)
    h = torch.tanh(emb.view(-1, 6) @ W1 + b1)  # (batch_size, 100)
    logits = h @ W2 + b2  # (batch_size, 27)
    # cross_entropy replaces the hand-rolled mean negative log-likelihood:
    # each row of logits holds the scores per character index, and Y[ix]
    # holds the target indices.
    loss = F.cross_entropy(logits, Y[ix])

    # Backward pass: clear stale gradients, then backpropagate.
    for p in parameters:
        p.grad = None
    loss.backward()

    # Update: the in-place step runs under no_grad so it is not recorded by
    # autograd, without going through the .data attribute.
    with torch.no_grad():
        for p in parameters:
            p -= learning_rate * p.grad

# Loss of the final minibatch only, not of the full dataset.
print(loss.item())

2.6483726501464844


tensor([ 59115, 222589, 145519,  20121, 201918, 128631, 118488, 202662, 178759,
         23517, 193945, 226003,  91349,  12780, 180741, 195416, 223110, 118160,
        110582,  59531,  90645,  78288,  30326, 155760, 128907, 160310, 155998,
        191140,  98608,  41666,  16300,   8891])

In [81]:
# Spot-check tanh at a single input value.
torch.tanh(torch.tensor(-0.5))

tensor(-0.4621)