In [40]:
from __future__ import annotations
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F

%matplotlib inline

In [41]:
words: list[str] = open("names.txt", "r").read().splitlines()

In [42]:
# Gets all the characters, a-z
chars: list[str] = sorted(list(set("".join(words))))

# Maps each character to an integer
stoi: dict[str, int] = {s: i + 1 for i, s in enumerate(chars)}
stoi["."] = 0

# Maps each integer to a character
itos: dict[int, str] = {i: s for s, i in stoi.items()}

In [43]:
# Training set of trigrams (x, y, z)
xs, ys, zs = [], [], []

for w in words:
    if len(w) < 3:
        continue
    chs = ["."] + list(w) + ["."]
    for ch1, ch2, ch3 in zip(chs, chs[1:], chs[2:]):
        ix1 = stoi[ch1]
        ix2 = stoi[ch2]
        ix3 = stoi[ch3]
        xs.append(ix1)
        ys.append(ix2)
        zs.append(ix3)

xs = torch.tensor(xs)
ys = torch.tensor(ys)
zs = torch.tensor(zs)
num = xs.nelement()
print('number of examples: ', num)

number of examples:  195983


In [44]:
# Weights
W = torch.randn((27, 27), requires_grad=True)

In [45]:
# Gradient descent
for k in range(100):
    # Forward pass
    xenc = F.one_hot(xs, num_classes=27).float() # Neural net input: one-hot encoding
    logits = xenc @ W # Predict log-counts
    counts = logits.exp() # Counts, equivalent to N
    probs = counts / counts.sum(dim=1, keepdim=True) # Probabilities for next character
    # Negative log likelihood   
    loss = -probs[range(len(ys)), ys].log().mean() - probs[range(len(zs)), zs].log().mean() + 0.01*(W**2).mean() # Regularization
    print(loss.item())

    # Backward pass
    W.grad = None
    loss.backward()

    # Update
    W.data += -50 * W.grad

7.596063137054443
6.584152698516846
6.289413928985596
5.982585430145264
5.933457374572754
5.7625732421875
5.7820940017700195
5.676336765289307
5.7417073249816895
5.602336406707764
5.63689661026001
5.602008819580078
5.718888759613037
5.5240254402160645
5.515870571136475
5.518462181091309
5.5594401359558105
5.550570011138916
5.684530735015869
5.479833126068115
5.473168849945068
5.468958854675293
5.465601921081543
5.464122772216797
5.470057964324951
5.489652156829834
5.581336498260498
5.47454309463501
5.539151191711426
5.492151260375977
5.605845928192139
5.449794292449951
5.469024658203125
5.495965957641602
5.630456447601318
5.437384128570557
5.435172080993652
5.438477993011475
5.457881927490234
5.486152172088623
5.618582248687744
5.430708408355713
5.430464744567871
5.438803195953369
5.478694915771484
5.489065170288086
5.633249759674072
5.425116539001465
5.4227142333984375
5.423880577087402
5.4298858642578125
5.458856105804443
5.483203887939453
5.626572132110596
5.4208807945251465
5.41830

In [46]:
# Generate a word based from the neural network
for i in range(10):
    ix = 0
    word = ""
    while True:
        xenc = F.one_hot(torch.tensor([ix]), num_classes=27).float()
        logits = xenc @ W
        counts = logits.exp()
        p = counts / counts.sum(1, keepdim=True)
        
        ix = torch.multinomial(p, num_samples=1, replacement=True).item()
        if ix == 0:
            break

        word += itos[ix]

    print(word)

te
aerinriselngema
dyanloelna
asisilalaryzavir
rzivralnaeidi
cezarabdeboioeneekiaenavahlsanayn
kaihiphasnuhalna
avise
amcoi
ileihrl
