In [22]:
from __future__ import annotations
import random
import torch
import torch.nn.functional as F

%matplotlib inline

In [23]:
words: list[str] = open("names.txt", "r").read().splitlines()

In [24]:
# Gets all the characters, a-z
chars: list[str] = sorted(list(set("".join(words))))

# Maps each character to an integer
stoi: dict[str, int] = {s: i + 1 for i, s in enumerate(chars)}
stoi["."] = 0

# Maps each integer to a character
itos: dict[int, str] = {i: s for s, i in stoi.items()}

In [25]:
# Training set of trigrams (x, y, z)
xs, ys = [], []

for w in words:
    if len(w) < 3:
        continue
    chs = ["."] + list(w) + ["."]
    for ch1, ch2, ch3 in zip(chs, chs[1:], chs[2:]):
        ix1 = stoi[ch1]
        ix2 = stoi[ch2]
        ix3 = stoi[ch3]
        xs.append([ix1, ix2])
        ys.append(ix3)

xs = torch.tensor(xs)
ys = torch.tensor(ys)
num = xs.nelement()
print('number of examples: ', num)

number of examples:  391966


In [26]:
# Weights
W = torch.randn((27*2, 27), requires_grad=True)

In [27]:
# Gradient descent
for k in range(100):
    # Forward pass
    xenc = F.one_hot(xs, num_classes=27).float() # Neural net input: one-hot encoding
    logits = xenc.view(-1, 27*2) @ W # Predict log-counts
    counts = logits.exp() # Counts, equivalent to N
    probs = counts / counts.sum(dim=1, keepdim=True) # Probabilities for next character
    # Negative log likelihood  
    loss = -probs[torch.arange(ys.shape[0]), ys].log().mean() + 0.01*(W**2).mean() # Regularization
    print(loss.item())

    # Backward pass
    W.grad = None
    loss.backward()

    # Update
    W.data += -50 * W.grad

4.420506000518799
3.5385472774505615
3.1578071117401123
2.9506165981292725
2.820193290710449
2.729976177215576
2.6633028984069824
2.612173557281494
2.571805238723755
2.539106845855713
2.5120575428009033
2.489321231842041
2.469984292984009
2.453383684158325
2.4390153884887695
2.426478624343872
2.4154534339904785
2.4056825637817383
2.3969597816467285
2.389120578765869
2.3820321559906006
2.37558650970459
2.369697332382202
2.364292860031128
2.359313726425171
2.3547098636627197
2.3504397869110107
2.346468687057495
2.3427658081054688
2.3393051624298096
2.3360636234283447
2.333021879196167
2.330162286758423
2.3274691104888916
2.3249289989471436
2.3225302696228027
2.320261001586914
2.3181116580963135
2.316073179244995
2.3141379356384277
2.3122987747192383
2.3105485439300537
2.3088815212249756
2.3072917461395264
2.3057749271392822
2.3043253421783447
2.3029394149780273
2.3016128540039062
2.3003427982330322
2.2991247177124023
2.2979562282562256
2.2968344688415527
2.2957568168640137
2.294720649719

In [28]:
# Generate a word based from the neural network
for i in range(10):
    ix = 0
    iy = random.randint(1, 26)
    word = ""
    while True:
        xenc = F.one_hot(torch.tensor([ix, iy]), num_classes=27).float()
        logits = xenc.view(-1, 27*2) @ W
        counts = logits.exp()
        p = counts / counts.sum(1, keepdim=True)
        
        ix = torch.multinomial(p, num_samples=1, replacement=True).item()
        if ix == 0:
            break

        word += itos[ix]

    print(word)

uiiea
vde
fnrprsnxprcbctawn
aeiaaiiaaiajeoeeiaiaraieiaobiawxliii
a
aya
lnvnanas
osmiq
uuuhnxxxsiyun
ami
