In [3]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline

if False and torch.backends.mps.is_available(): # mps turned out to be much slower for some reason
    pt_device = torch.device("mps")
    print("torch using mps")
else:
    pt_device="cpu"
    
torch.set_default_device(pt_device)    

In [4]:
words = open("names.txt").read().splitlines()
print(len(words))
words[:5]

32033


['emma', 'olivia', 'ava', 'isabella', 'sophia']

In [5]:
chars = sorted(list(set(''.join(words))))
stoi = {ch: i+1 for i, ch in enumerate(chars)}
stoi['.'] = 0   
itos = {i: ch for ch, i in stoi.items()}
print(stoi)
print(itos)

{'a': 1, 'b': 2, 'c': 3, 'd': 4, 'e': 5, 'f': 6, 'g': 7, 'h': 8, 'i': 9, 'j': 10, 'k': 11, 'l': 12, 'm': 13, 'n': 14, 'o': 15, 'p': 16, 'q': 17, 'r': 18, 's': 19, 't': 20, 'u': 21, 'v': 22, 'w': 23, 'x': 24, 'y': 25, 'z': 26, '.': 0}
{1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i', 10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q', 18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y', 26: 'z', 0: '.'}


In [6]:
def build_data(words, block_size):
    X, Y = [], [] 
    for w in words:
        context = [0] * block_size
        for c in w + '.':
            ix = stoi[c]
            X.append(context)
            Y.append(ix)
            context = context[1:] + [ix]
    return torch.tensor(X), torch.tensor(Y)

import random

random.shuffle(words)
n1 = int(len(words) * 0.8)
n2 = int(len(words) * 0.9)

block_size = 3

X_train, Y_train = build_data(words[:n1], block_size)
X_dev, Y_dev = build_data(words[n1:n2], block_size)
X_test, Y_test = build_data(words[n2:], block_size)

print(X_train.shape, X_dev.shape, X_test.shape)
print(Y_train.shape, Y_dev.shape, Y_test.shape)

torch.Size([182315, 3]) torch.Size([22918, 3]) torch.Size([22913, 3])
torch.Size([182315]) torch.Size([22918]) torch.Size([22913])


In [7]:
C = torch.rand((27,10))
W1 = torch.rand((C.shape[1] * 3, 200))
b1 = torch.rand((200,))
W2 = torch.randn(200, 27)
b2 = torch.randn(27)

parameters = [C, W1, b1, W2, b2]

for p in parameters:
    p.requires_grad = True

sum(p.numel() for p in parameters)

11897

In [15]:
%%time

max_steps = 200000
batch_size = 32
lossi = []

for i in range(200000):
    # mini batch, run the whole fordward back ward update in just a small batch
    ix = torch.randint(0, X_train.shape[0], (batch_size,))
    # forward pas
    emb = C[X_train[ix]] 
    h = torch.tanh(emb.view(-1, C.shape[1] * 3) @ W1 + b1) 
    logits = h @ W2 + b2 # (32, 27)
    # counts = logits.exp()
    # prob = counts / counts.sum(1, keepdim=True)
    # loss = -prob[torch.arange(32), Y].log().mean()
    loss = F.cross_entropy(logits, Y_train[ix]) # exactly the same as the above, just much more efficient
    # backward pass
    for p in parameters:
        p.grad = None
    loss.backward()

    # update
    learning_rate = 0.1 if i < max_steps/2 else 0.01

    for p in parameters:
        p.data += -learning_rate * p.grad

    if i % 10000 == 0:
        print(f"{i:7d}/{max_steps:7d} loss: {loss.item():.4f}")
        
    lossi.append(loss.item())


loss for single batch 3.528170347213745
CPU times: user 35.6 s, sys: 1min 46s, total: 2min 21s
Wall time: 17.5 s


In [10]:
emb = C[X_train] # (32, 3, 2)
h = torch.tanh(emb.view(-1,C.shape[1] * 3) @ W1 + b1) 
logits = h @ W2 + b2 # (32, 27)
loss = F.cross_entropy(logits, Y_train) 
print("loss for the training dataset", loss.item())

loss for the training dataset 3.11991810798645


In [11]:
emb = C[X_dev] # (32, 3, 2)
h = torch.tanh(emb.view(-1,C.shape[1] * 3) @ W1 + b1) 
logits = h @ W2 + b2 # (32, 27)
loss = F.cross_entropy(logits, Y_dev) 
print("loss for the dev dataset", loss.item())

loss for the dev dataset 3.113191843032837


In [13]:
# sampling from the model

def sample_next(context, temperature=1):
    emb = C[context]
    h = torch.tanh(emb.view(-1, C.shape[1] * 3) @ W1 + b1) 
    logits = h @ W2 + b2 # (32, 27)
    probs = F.softmax(logits / temperature, 1)
    return torch.multinomial(probs, 1).item()

def sample(init_context, temperature=1):
    res = init_context
    while True:
        next = sample_next(res[-3:], temperature)
        if next == 0: break
        res.append(next)
    
    return ''.join([itos[i] for i in res])

for _ in range(10):
    print(sample([0, 0, 0], temperature=1)[3:])

oeereeeaoeemoetaavuoiexeuoehe
raoaesenne
ciqneeievsyeevleeeaernieddepleepkriileetmn
oynaioeiyommolaanelfndliealeueeeesakeheraube
eeieeieeeinudageeeeyilor
ytlocavaeeeedmvedevexarednaeyyyaeiesieeue
enoteoxlvee
idjemceieeev
unnoelyeeeeeresnoeteoesteeeeeeoennieheekeeeeebingioenikkaheyoenreyecieyeri
ynleeeddivednnlyinesmaneaixeglreisiacaehayatamrgeliamertsomjeeinsmidedleaeaeoeoakedeelaeieeoteeeeeoeasxiherdjmevesieobtaealeeiduverexe
