In [None]:
from typing import Final, Tuple
import torch


In [2]:
words: Final[list[str]] = open("names.txt").read().splitlines()
words[:10]

['emma',
 'olivia',
 'ava',
 'isabella',
 'sophia',
 'charlotte',
 'mia',
 'amelia',
 'harper',
 'evelyn']

In [3]:
char_set: Final[set[str]] = set(list("".join(words)))
char_set

{'a',
 'b',
 'c',
 'd',
 'e',
 'f',
 'g',
 'h',
 'i',
 'j',
 'k',
 'l',
 'm',
 'n',
 'o',
 'p',
 'q',
 'r',
 's',
 't',
 'u',
 'v',
 'w',
 'x',
 'y',
 'z'}

In [5]:
stoi: Final[dict[str, int]] = {
    ch: i + 1
    for i, ch in enumerate(char_set)
}
stoi['.'] = 0
itos: Final[dict[int, str]] = {i: ch for ch, i in stoi.items()}
char_count: Final[int] = len(stoi)
print(stoi)
print(itos)

{'n': 1, 'e': 2, 't': 3, 'f': 4, 'o': 5, 'r': 6, 'b': 7, 'y': 8, 'v': 9, 'i': 10, 'c': 11, 'p': 12, 'h': 13, 'u': 14, 'z': 15, 'l': 16, 'x': 17, 'j': 18, 's': 19, 'q': 20, 'w': 21, 'd': 22, 'm': 23, 'g': 24, 'a': 25, 'k': 26, '.': 0}
{1: 'n', 2: 'e', 3: 't', 4: 'f', 5: 'o', 6: 'r', 7: 'b', 8: 'y', 9: 'v', 10: 'i', 11: 'c', 12: 'p', 13: 'h', 14: 'u', 15: 'z', 16: 'l', 17: 'x', 18: 'j', 19: 's', 20: 'q', 21: 'w', 22: 'd', 23: 'm', 24: 'g', 25: 'a', 26: 'k', 0: '.'}


In [14]:
xs: Final[list[Tuple[int, int]]] = []
ys: Final[list[int]] = []
for w in words[:1]:
    chs: list[str] = ['.'] + list(w) + ['.']
    for ch1, ch2, ch3 in zip(chs, chs[1:], chs[2:]):
        xs.append((stoi[ch1], stoi[ch2]))
        ys.append(stoi[ch3])
xs_ts = torch.tensor(xs)
ys_ts = torch.tensor(ys)


In [15]:
print(xs_ts)
print(ys_ts)

tensor([[ 0,  2],
        [ 2, 23],
        [23, 23],
        [23, 25]])
tensor([23, 23, 25,  0])


In [17]:
import torch.nn.functional as F

x_enc = F.one_hot(xs_ts, num_classes=char_count).float()
y_enc = F.one_hot(ys_ts, num_classes=char_count).float()

print(x_enc)
print(y_enc)

tensor([[[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0.,

In [19]:
print(x_enc.shape)
x_enc_flat = x_enc.view((x_enc.size(0), -1))
print(x_enc_flat.shape)

torch.Size([4, 2, 27])
torch.Size([4, 54])


In [22]:
W = torch.randn((char_count *2, char_count))
logits = x_enc_flat @ W
counts = torch.exp(logits)
probs = counts / counts.sum(dim=1, keepdim=True)
print(probs)
print(probs[0].sum())

tensor([[0.0174, 0.0244, 0.0078, 0.0303, 0.0099, 0.0992, 0.1343, 0.0068, 0.0183,
         0.0008, 0.0175, 0.0097, 0.0048, 0.0054, 0.0155, 0.0023, 0.0011, 0.1384,
         0.0472, 0.0266, 0.0775, 0.0043, 0.0173, 0.0640, 0.0043, 0.0027, 0.2122],
        [0.0159, 0.2638, 0.0629, 0.0407, 0.0341, 0.0183, 0.0028, 0.0165, 0.0111,
         0.0019, 0.0074, 0.0155, 0.0325, 0.0733, 0.0204, 0.0318, 0.0496, 0.0065,
         0.0415, 0.0031, 0.0138, 0.1600, 0.0325, 0.0018, 0.0184, 0.0191, 0.0048],
        [0.0235, 0.1786, 0.1370, 0.0020, 0.0672, 0.0097, 0.0093, 0.0085, 0.0394,
         0.0121, 0.0033, 0.0068, 0.0047, 0.0010, 0.0049, 0.0090, 0.0177, 0.0152,
         0.0531, 0.0401, 0.2340, 0.0095, 0.0107, 0.0118, 0.0718, 0.0102, 0.0092],
        [0.1023, 0.0314, 0.1457, 0.0059, 0.0201, 0.0233, 0.0585, 0.0613, 0.0431,
         0.0245, 0.0104, 0.0018, 0.0350, 0.0013, 0.0098, 0.0197, 0.0164, 0.0188,
         0.0084, 0.0307, 0.2293, 0.0358, 0.0042, 0.0093, 0.0166, 0.0136, 0.0226]])
tensor(1.0000)


In [24]:
xs: Final[list[Tuple[int, int]]] = []
ys: Final[list[int]] = []
for w in words:
    chs: list[str] = ['.'] + list(w) + ['.']
    for ch1, ch2, ch3 in zip(chs, chs[1:], chs[2:]):
        xs.append((stoi[ch1], stoi[ch2]))
        ys.append(stoi[ch3])
xs_ts = torch.tensor(xs)
ys_ts = torch.tensor(ys)
n_examples = len(xs_ts)
n_examples

196113

In [42]:
g: torch.Generator = torch.Generator().manual_seed(2147483647)
W: torch.Tensor = torch.randn((char_count * 2, char_count), generator=g, requires_grad=True)


In [43]:
# Gradient Descent
n_epochs: Final[int] = 200
for epoch in range(n_epochs):
    # Forward
    x_enc = F.one_hot(xs_ts, num_classes=char_count).float()
    x_enc_flat = x_enc.view((x_enc.size(0), -1))
    logits = x_enc_flat @ W
    counts = torch.exp(logits)
    probs = counts / counts.sum(dim=1, keepdim=True)
    loss = -torch.log(probs[torch.arange(n_examples), ys_ts]).mean() + 0.01 * (W ** 2).sum()
    print(f"Epoch {epoch}: Loss: {loss.item()}")

    # Backward
    W.grad = None
    loss.backward()

    # Update weights
    assert W.grad is not None
    W.data -= 15.0 * (0.99 ** epoch) * W.grad


Epoch 0: Loss: 18.2454833984375
Epoch 1: Loss: 10.044971466064453
Epoch 2: Loss: 6.321364402770996
Epoch 3: Loss: 4.6018595695495605
Epoch 4: Loss: 3.7959489822387695
Epoch 5: Loss: 3.4130616188049316
Epoch 6: Loss: 3.228872060775757
Epoch 7: Loss: 3.1392369270324707
Epoch 8: Loss: 3.0951387882232666
Epoch 9: Loss: 3.0732176303863525
Epoch 10: Loss: 3.062211275100708
Epoch 11: Loss: 3.0566320419311523
Epoch 12: Loss: 3.0537774562835693
Epoch 13: Loss: 3.0523040294647217
Epoch 14: Loss: 3.0515365600585938
Epoch 15: Loss: 3.051133394241333
Epoch 16: Loss: 3.050920009613037
Epoch 17: Loss: 3.0508060455322266
Epoch 18: Loss: 3.0507447719573975
Epoch 19: Loss: 3.0507116317749023
Epoch 20: Loss: 3.0506932735443115
Epoch 21: Loss: 3.05068302154541
Epoch 22: Loss: 3.050677537918091
Epoch 23: Loss: 3.0506744384765625
Epoch 24: Loss: 3.0506725311279297
Epoch 25: Loss: 3.0506715774536133
Epoch 26: Loss: 3.050671339035034
Epoch 27: Loss: 3.050671100616455
Epoch 28: Loss: 3.050670623779297
Epoch 29

In [44]:
print(xs_ts.shape)

torch.Size([196113, 2])


In [45]:
# Counting method.

N = torch.zeros((char_count, char_count, char_count), dtype=torch.int32)
for ch1, ch2, ch3 in zip(xs_ts[:, 0], xs_ts[:, 1], ys_ts):
    N[ch1, ch2, ch3] += 1

P = (N + 1).float()
P /= P.sum(dim=2, keepdim=True)
P[0, 0].sum()

tensor(1.)

In [46]:
log_likelihood = 0.0
n = 0
for w in words:
    chs: list[str] = ['.'] + list(w) + ['.']
    for ch1, ch2, ch3 in zip(chs, chs[1:], chs[2:]):
        ix1, ix2, ix3 = stoi[ch1], stoi[ch2], stoi[ch3]
        prob = P[ix1, ix2, ix3]
        log_prob = torch.log(prob)
        log_likelihood += log_prob.item()
        n += 1
nll = -log_likelihood / n
print(f"Negative Log Likelihood: {nll:.4f}")

Negative Log Likelihood: 2.0931


In [47]:
bigram_N = torch.zeros((char_count, char_count), dtype=torch.int32)
for w in words:
    chs: list[str] = ['.'] + list(w) + ['.']
    for ch1, ch2 in zip(chs, chs[1:]):
        ix1, ix2 = stoi[ch1], stoi[ch2]
        bigram_N[ix1, ix2] += 1
bigram_P = (bigram_N + 1).float()
bigram_P /= bigram_P.sum(dim=1, keepdim=True)
bigram_P[0].sum()

tensor(1.0000)

In [49]:
g = torch.Generator().manual_seed(2147483647 + 1)
for i in range(20):
    ix1 = 0
    ix2 = int(torch.multinomial(bigram_P[ix1], num_samples=1, replacement=True, generator=g).item())
    out: list[str] = []
    while True:
        p = P[ix1, ix2]
        ix3 = int(torch.multinomial(p, num_samples=1, replacement=True, generator=g).item())
        out.append(itos[ix3])
        if ix3 == 0:
            break
        ix1, ix2 = ix2, ix3
    print("".join(out))


alin.
iahspan.
rie.
a.
aiya.
ierleige.
indael.
ssihailyn.
ai.
aylathaahemai.
ssi.
yna.
arn.
aiyitzaraelya.
raeileis.
a.
al.
wzxzair.
een.
a.


In [50]:
# Sampling from trigram NN model
g = torch.Generator().manual_seed(2147483647 + 1)
n_samples: Final[int] = 10
for i in range(n_samples):
    ix1 = 0
    ix2 = int(torch.multinomial(bigram_P[ix1], num_samples=1, replacement=True, generator=g).item())
    out: list[str] = []
    while True:
        x_enc = F.one_hot(torch.tensor([[ix1, ix2]]), num_classes=char_count).float()
        x_enc_flat = x_enc.view((x_enc.size(0), -1))
        logits = x_enc_flat @ W
        counts = torch.exp(logits)
        probs = counts / counts.sum(dim=1, keepdim=True)
        ix3 = int(torch.multinomial(probs, num_samples=1, replacement=True, generator=g).item())
        out.append(itos[ix3])
        if ix3 == 0:
            break
        ix1, ix2 = ix2, ix3
    print("".join(out))


aligyyzjhspmn.
artunrovfbvlpldttlenij.
egdtedtesceh.
ajyqtmi.
lykjtbdafumaf.
qwiwwvotfwoan.
abyizbxggeczangdzqizeps.
g.
gw.
wzxzvsxvlqmowymqfpjxgqk.
