In [None]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

In [73]:
# Construct the dataset: one training example per adjacent character pair
# (bigram) of every name, with `special` marking both the start and the end.
with open("names.txt", "r", encoding="utf-8") as infile:
    # Skip blank lines: an empty "name" would contribute a bogus
    # (special, special) start->end pair to the training set.
    words = [name for name in (line.strip() for line in infile) if name]

special = "."
chars = sorted(set("".join(words)))
# Index 0 is reserved for `special`. If it also occurred inside a name, its
# entry would be overwritten below and leave a hole in the index range.
assert special not in chars, f"{special!r} occurs in the data; choose another marker"
stoi = {s: i + 1 for i, s in enumerate(chars)}
stoi[special] = 0
itos = {i: s for s, i in stoi.items()}

xs, ys = list(), list()
for word in words:
    # Use a separate name so the vocabulary kept in `chars` is not clobbered.
    tokens = [special] + list(word) + [special]
    for ch1, ch2 in zip(tokens, tokens[1:]):
        xs.append(stoi[ch1])  # input: current character index
        ys.append(stoi[ch2])  # target: index of the character that follows

# Explicit integer dtype so the tensors stay valid indices even when empty.
xs = torch.tensor(xs, dtype=torch.long)
ys = torch.tensor(ys, dtype=torch.long)
num_examples = xs.nelement()
print(f"Number of examples: {num_examples}")

# Initialize the network: a single square weight matrix with one row and one
# column per vocabulary entry, drawn from a seeded generator.
gen = torch.Generator().manual_seed(2147483647)
W = torch.randn((len(stoi), len(stoi)), generator=gen, requires_grad=True)

Number of examples: 228146


In [74]:
# Gradient descent
regularization = 0.01  # weight of the mean-squared-weight penalty on W
learning_rate = 50

# The inputs never change, so build the one-hot encoding once instead of
# once per epoch.
xenc = F.one_hot(xs, num_classes=len(stoi)).float()

for epoch in range(1000):
    # One-hot row times W selects the row of W belonging to each input index.
    logits = xenc @ W
    # cross_entropy is the mean negative log of softmax(logits)[i, ys[i]],
    # computed through log-softmax so a large logit cannot overflow exp().
    loss = F.cross_entropy(logits, ys) + regularization * W.pow(2).mean()

    if epoch % 100 == 0:
        print(f"Epoch {epoch} - loss {loss.item():.4f}")
    # Set gradient to zero and backprop
    W.grad = None
    loss.backward()

    # Update the weights based on new gradients. no_grad() keeps the in-place
    # step out of the autograd graph without reaching through `.data`.
    with torch.no_grad():
        W -= learning_rate * W.grad

Epoch 0 - loss 3.7686
Epoch 100 - loss 2.4900
Epoch 200 - loss 2.4830
Epoch 300 - loss 2.4815
Epoch 400 - loss 2.4810
Epoch 500 - loss 2.4807
Epoch 600 - loss 2.4806
Epoch 700 - loss 2.4805
Epoch 800 - loss 2.4805
Epoch 900 - loss 2.4805


In [None]:
# Sample from the model
# W is fixed while sampling, so the per-row softmax is computed once up front.
# A one-hot vector times W just selects row `ix` of W, so row `ix` of `probs`
# is the next-index distribution given `ix`. no_grad() avoids building an
# autograd graph for what is pure inference.
with torch.no_grad():
    probs = torch.softmax(W, dim=1)

for _ in range(5):
    output = list()
    ix = 0  # index 0 is the start/end marker (see `stoi[special] = 0`)

    while True:
        # Next character: draw one index from the row for the current index.
        # The slice keeps the (1, vocab) shape that multinomial was given before.
        ix = torch.multinomial(probs[ix : ix + 1], num_samples=1, replacement=True, generator=gen).item()
        output.append(itos[ix])
        if ix == 0:
            # Drawing the marker again ends the sample.
            break

    print("".join(output))

jigua.
sadryrolyiniydavesole.
rish.
be.
ka.
