In [1]:
import math
import torch
import torch.nn.functional as F

In [2]:
# Dataset: names.txt holds one name per line.
# Read it inside a context manager so the file handle is closed as soon as
# the contents are in memory instead of lingering until garbage collection.
with open('names.txt', 'r') as f:
    words = f.read().splitlines()
words[:10]

['emma',
 'olivia',
 'ava',
 'isabella',
 'sophia',
 'charlotte',
 'mia',
 'amelia',
 'harper',
 'evelyn']

In [3]:
############################################################################## Counting approach ##############################################################################
# Two plain-dict tallies built in a single pass over the names:
#   first_char: how many names start with each letter
#   trigrams:   how often each 3-character window occurs, with '.' padded on
#               both ends of the name as the start/end marker
first_char = {}
trigrams = {}
for word in words:
    padded = ['.', *word, '.']
    initial = word[0]
    first_char[initial] = first_char.get(initial, 0) + 1
    # Slide a width-3 window across the padded name.
    for start in range(len(padded) - 2):
        window = tuple(padded[start:start + 3])
        trigrams[window] = trigrams.get(window, 0) + 1
# Most frequent first letters first (the sort is stable, so ties keep insertion order).
sorted(first_char.items(), key=lambda item: item[1], reverse=True)

[('a', 4410),
 ('k', 2963),
 ('m', 2538),
 ('j', 2422),
 ('s', 2055),
 ('d', 1690),
 ('r', 1639),
 ('l', 1572),
 ('c', 1542),
 ('e', 1531),
 ('t', 1308),
 ('b', 1306),
 ('n', 1146),
 ('z', 929),
 ('h', 874),
 ('g', 669),
 ('i', 591),
 ('y', 535),
 ('p', 515),
 ('f', 417),
 ('o', 394),
 ('v', 376),
 ('w', 307),
 ('x', 134),
 ('q', 92),
 ('u', 78)]

In [4]:
# Ten most frequent trigrams, highest count first.
sorted(trigrams.items(), key=lambda item: item[1], reverse=True)[:10]

[(('a', 'h', '.'), 1714),
 (('n', 'a', '.'), 1673),
 (('a', 'n', '.'), 1509),
 (('o', 'n', '.'), 1503),
 (('.', 'm', 'a'), 1453),
 (('.', 'j', 'a'), 1255),
 (('.', 'k', 'a'), 1254),
 (('e', 'n', '.'), 1217),
 (('l', 'y', 'n'), 976),
 (('y', 'n', '.'), 953)]

In [5]:
# Trigram count matrix, indexed as N[first, second - 1, third].
# The middle axis has only 26 slots: the middle character of a trigram is
# always a letter (index 1..26), never the '.' marker, so it is stored
# shifted down by one.
N = torch.zeros(27, 26, 27, dtype=torch.int32)
# First character count vector, one slot per symbol index.
M = torch.zeros(27, dtype=torch.int32)

In [6]:
# Vocabulary: every distinct character that appears in the names, sorted.
chars = sorted(set(''.join(words)))
# Characters to indexes; real characters start at 1 so that index 0 is free
# for the '.' start/end marker.
stoi = {ch: idx for idx, ch in enumerate(chars, start=1)}
stoi['.'] = 0
# Indexes to characters (inverse of stoi).
itos = {idx: ch for ch, idx in stoi.items()}
stoi

{'a': 1,
 'b': 2,
 'c': 3,
 'd': 4,
 'e': 5,
 'f': 6,
 'g': 7,
 'h': 8,
 'i': 9,
 'j': 10,
 'k': 11,
 'l': 12,
 'm': 13,
 'n': 14,
 'o': 15,
 'p': 16,
 'q': 17,
 'r': 18,
 's': 19,
 't': 20,
 'u': 21,
 'v': 22,
 'w': 23,
 'x': 24,
 'y': 25,
 'z': 26,
 '.': 0}

In [7]:
# Filling in N and M matrices
# Both tensors are updated in place, so clear them first: without this,
# re-running the cell would add the counts a second time and silently change
# the add-one-smoothed distribution derived from N.
N.zero_()
M.zero_()
for w in words:
    chs = ['.'] + list(w) + ['.']
    # M counts how often each character starts a name.
    ix = stoi[w[0]]
    M[ix] += 1
    for ch1, ch2, ch3 in zip(chs, chs[1:], chs[2:]):
        ix1 = stoi[ch1]
        ix2 = stoi[ch2]
        ix3 = stoi[ch3]
        # The middle character is always a letter (index 1..26), so it is
        # shifted down by one to fit the 26-wide middle axis of N.
        N[ix1, ix2-1, ix3] += 1

In [8]:
# Display the first-character counts, indexed by stoi position. The filling
# loop only increments M at stoi[w[0]], which is a letter index, so slot 0
# (the '.' marker) is never incremented.
M

tensor([   0, 4410, 1306, 1542, 1690, 1531,  417,  669,  874,  591, 2422, 2963,
        1572, 2538, 1146,  394,  515,   92, 1639, 2055, 1308,   78,  376,  307,
         134,  535,  929], dtype=torch.int32)

In [9]:
# First character probability distribution: convert the integer counts to
# float, then divide by their total so the entries sum to 1.
MP = M.to(torch.float32)
MP = MP.div(MP.sum())

In [10]:
# Trigram probability distribution.
# Add-one smoothing gives every trigram a non-zero count (so its log
# probability is finite); each (first, second) context is then normalised
# over the last axis into a distribution over the third character.
NP = N.add(1).float()
NP = NP / NP.sum(dim=2, keepdim=True)

In [11]:
# Sample from MP and NP probability distributions
g = torch.Generator().manual_seed(2147483647)


def _draw(dist):
    """Draw a single index from the probability vector `dist` using generator g."""
    return torch.multinomial(dist, num_samples=1, replacement=True, generator=g).item()


for _ in range(5):
    prev_ix = 0                # every name starts from the '.' marker
    cur_ix = _draw(MP)         # Sample first character from MP
    out = [itos[cur_ix]]
    while True:
        # The middle index is stored shifted by one in NP (see N's layout).
        p = NP[prev_ix, cur_ix - 1]
        prev_ix = cur_ix
        cur_ix = _draw(p)      # Sample next character from NP
        out.append(itos[cur_ix])
        if cur_ix == 0:        # drew the end marker: the name is complete
            break
    print(''.join(out))

junide.
jakasid.
prelay.
adin.
kairritoper.


In [12]:
# Negative log likelihood loss
# Each name contributes one term for its first letter (scored by MP) and one
# term per trigram (scored by NP). n counts the terms so the last line can
# report the average negative log likelihood per predicted character.
log_likelihood = 0.0
n = 0

for w in words:
    padded = ['.', *w, '.']
    log_likelihood += torch.log(MP[stoi[w[0]]])  # add log prob of first character
    n += 1
    for start in range(len(padded) - 2):
        first, second, third = (stoi[c] for c in padded[start:start + 3])
        # add log prob of next character (middle index is stored shifted by one)
        log_likelihood += torch.log(NP[first, second - 1, third])
        n += 1
nll = -log_likelihood
print(f'{log_likelihood=}')
print(f'{nll=}')
print(f'{nll/n}')

log_likelihood=tensor(-504657.3750)
nll=tensor(504657.3750)
2.2119929790496826


In [13]:
# Neural Net approach #############################################################################
# Create training set of trigrams ((x, y), z)
# The two context characters are folded into one id, first + 27 * second
# (at most 26 + 27 * 26 = 728, so 729 classes suffice); the third character
# is the prediction target.
contexts, targets = [], []

for w in words:
    ids = [stoi[c] for c in ['.', *w, '.']]
    # Pair every position with its successor; ids[1:-1] stops the zip two
    # short of the end, so each pair has a third character to predict.
    contexts.extend(first + 27 * second for first, second in zip(ids, ids[1:-1]))  # 729 possibilities
    targets.extend(ids[2:])
X = torch.tensor(contexts)
y = torch.tensor(targets)
num = X.nelement()

In [14]:
# One-hot encode each context id into a 729-wide float row, the form the
# training loop multiplies against the weight matrix.
Xenc = F.one_hot(X, num_classes=729).to(torch.float32)
Xenc.shape

torch.Size([196113, 729])

In [15]:
# Weights of the single linear layer: one row of 27 logits per context id.
# Drawn from the shared generator g, so the initial values depend on how
# much of g's stream was consumed before this cell runs.
W = torch.randn(729, 27, generator=g, requires_grad=True)

In [23]:
# Gradient descent
# Full-batch training of the single linear layer W on the one-hot contexts.
# Note: W is updated in place, so re-running this cell continues training
# from the current weights rather than starting over.
learning_rate = 100
for k in range(50):
    # Forward pass
    logits = Xenc @ W
    # log_softmax uses the log-sum-exp trick internally, so large logits
    # cannot overflow exp() to inf (which would turn the loss into NaN) and
    # tiny probabilities cannot underflow to log(0).
    log_probs = F.log_softmax(logits, dim=1)
    # Mean negative log likelihood of the correct next character.
    loss = -log_probs[torch.arange(num), y].mean()

    # Backward pass
    W.grad = None
    loss.backward()

    # Update
    # Step under no_grad instead of poking W.data: the in-place update is
    # still excluded from the graph, but autograd's safety checks stay on.
    with torch.no_grad():
        W -= learning_rate * W.grad
    print(f'Iteration: {k}, loss: {loss.item()}')

Iteration: 0, loss: 2.3061764240264893
Iteration: 1, loss: 2.3038339614868164
Iteration: 2, loss: 2.301539897918701
Iteration: 3, loss: 2.2992923259735107
Iteration: 4, loss: 2.2970898151397705
Iteration: 5, loss: 2.294931411743164
Iteration: 6, loss: 2.2928149700164795
Iteration: 7, loss: 2.2907397747039795
Iteration: 8, loss: 2.2887046337127686
Iteration: 9, loss: 2.286707878112793
Iteration: 10, loss: 2.2847487926483154
Iteration: 11, loss: 2.2828264236450195
Iteration: 12, loss: 2.2809388637542725
Iteration: 13, loss: 2.2790863513946533
Iteration: 14, loss: 2.2772672176361084
Iteration: 15, loss: 2.275480031967163
Iteration: 16, loss: 2.2737247943878174
Iteration: 17, loss: 2.272000312805176
Iteration: 18, loss: 2.270305633544922
Iteration: 19, loss: 2.2686400413513184
Iteration: 20, loss: 2.267002582550049
Iteration: 21, loss: 2.265393018722534
Iteration: 22, loss: 2.2638099193573
Iteration: 23, loss: 2.2622528076171875
Iteration: 24, loss: 2.260721445083618
Iteration: 25, loss: 2