In [1]:
import math
import torch
import torch.nn.functional as F

In [2]:
# Dataset: one name per line. The context manager closes the file handle as
# soon as the contents are read instead of leaving it to the garbage collector.
with open('names.txt', 'r') as names_file:
    words = names_file.read().splitlines()
words[:10]

['emma',
 'olivia',
 'ava',
 'isabella',
 'sophia',
 'charlotte',
 'mia',
 'amelia',
 'harper',
 'evelyn']

In [3]:
################################## Counting approach ##################################
# Tally (a) how often each letter opens a name and (b) how often each run of
# three consecutive symbols occurs when a name is wrapped in a single '.' on
# each side.
first_char = {}
trigrams = {}
for word in words:
    chs = ['.', *word, '.']
    initial = word[0]
    first_char[initial] = first_char.get(initial, 0) + 1
    for trigram in zip(chs, chs[1:], chs[2:]):
        trigrams[trigram] = trigrams.get(trigram, 0) + 1
# Most common first letters first; ties keep their insertion order
sorted(first_char.items(), key=lambda kv: kv[1], reverse=True)

[('a', 4410),
 ('k', 2963),
 ('m', 2538),
 ('j', 2422),
 ('s', 2055),
 ('d', 1690),
 ('r', 1639),
 ('l', 1572),
 ('c', 1542),
 ('e', 1531),
 ('t', 1308),
 ('b', 1306),
 ('n', 1146),
 ('z', 929),
 ('h', 874),
 ('g', 669),
 ('i', 591),
 ('y', 535),
 ('p', 515),
 ('f', 417),
 ('o', 394),
 ('v', 376),
 ('w', 307),
 ('x', 134),
 ('q', 92),
 ('u', 78)]

In [4]:
# Ten most frequent trigrams, highest count first
sorted(trigrams.items(), key=lambda kv: kv[1], reverse=True)[:10]

[(('a', 'h', '.'), 1714),
 (('n', 'a', '.'), 1673),
 (('a', 'n', '.'), 1509),
 (('o', 'n', '.'), 1503),
 (('.', 'm', 'a'), 1453),
 (('.', 'j', 'a'), 1255),
 (('.', 'k', 'a'), 1254),
 (('e', 'n', '.'), 1217),
 (('l', 'y', 'n'), 976),
 (('y', 'n', '.'), 953)]

In [5]:
# Trigram count tensor, 27 symbols per axis (26 letters plus '.'), all zeros to start
N = torch.zeros(27, 27, 27, dtype=torch.int32)

In [6]:
# Vocabulary: every distinct character in the dataset, in sorted order
chars = sorted(set(''.join(words)))
# Characters to indexes; letters start at 1 because index 0 is reserved for '.'
stoi = {ch: idx for idx, ch in enumerate(chars, start=1)}
stoi['.'] = 0
# Indexes to characters (inverse of stoi)
itos = {idx: ch for ch, idx in stoi.items()}
stoi

{'a': 1,
 'b': 2,
 'c': 3,
 'd': 4,
 'e': 5,
 'f': 6,
 'g': 7,
 'h': 8,
 'i': 9,
 'j': 10,
 'k': 11,
 'l': 12,
 'm': 13,
 'n': 14,
 'o': 15,
 'p': 16,
 'q': 17,
 'r': 18,
 's': 19,
 't': 20,
 'u': 21,
 'v': 22,
 'w': 23,
 'x': 24,
 'y': 25,
 'z': 26,
 '.': 0}

In [7]:
# Fill the trigram count tensor N.
# N is rebuilt from zeros first, so re-running this cell gives the same counts
# instead of adding every trigram a second time.
N = torch.zeros((27, 27, 27), dtype=torch.int32)
for w in words:
    # Two leading '.' so the first letter is also counted under a two-symbol context
    chs = ['.', '.'] + list(w) + ['.']
    for ch1, ch2, ch3 in zip(chs, chs[1:], chs[2:]):
        # N[ix1, ix2, ix3]: how many times ch3 follows the pair (ch1, ch2)
        N[stoi[ch1], stoi[ch2], stoi[ch3]] += 1

In [8]:
# Add-one smoothing, then normalise along the last axis so every
# two-symbol context holds a distribution over the next symbol.
smoothed = (N + 1).float()
NP = smoothed / smoothed.sum(dim=2, keepdim=True)  # Trigram probability distribution

In [9]:
# Sample five names from the trigram probability distribution NP
g = torch.Generator().manual_seed(2147483647)
for _ in range(5):
    out = []
    prev2, prev1 = 0, 0  # every name starts from the '..' context
    finished = False
    while not finished:
        p = NP[prev2, prev1]
        # Draw the next symbol given the two preceding ones
        nxt = torch.multinomial(p, num_samples=1, replacement=True, generator=g).item()
        out.append(itos[nxt])
        prev2, prev1 = prev1, nxt
        finished = nxt == 0  # index 0 is '.', which ends the name
    print(''.join(out))

junide.
jakasid.
prelay.
adin.
kairritoper.


In [10]:
# Negative log likelihood of the dataset under the trigram distribution NP
log_likelihood = 0.0
n = 0

for w in words:
    padded = ['.', '.', *w, '.']
    for ctx_a, ctx_b, target in zip(padded, padded[1:], padded[2:]):
        prob = NP[stoi[ctx_a], stoi[ctx_b], stoi[target]]
        log_likelihood += prob.log()  # accumulate log prob of the observed next character
        n += 1

nll = -log_likelihood
print(f'{log_likelihood=}')
print(f'{nll=}')
# Average over the n trigrams seen
print(f'{nll/n}')

log_likelihood=tensor(-504653.)
nll=tensor(504653.)
2.2119739055633545


In [11]:
################################# Neural Net approach #################################
# Build the training set: each example pairs a two-symbol context with the
# symbol that follows it.
contexts, targets = [], []

for w in words:
    chs = ['.', '.'] + list(w) + ['.']
    for first, second, third in zip(chs, chs[1:], chs[2:]):
        # Fold the two context indexes into one id in [0, 729)
        contexts.append(stoi[first] + 27 * stoi[second])
        targets.append(stoi[third])

X = torch.tensor(contexts)
y = torch.tensor(targets)
num = X.numel()  # number of training examples

In [12]:
# One-hot encode every context id as a float32 row of width 729
Xenc = F.one_hot(X, num_classes=729).to(torch.float32)
Xenc.shape

torch.Size([228146, 729])

In [13]:
# Weights: one row per context id (729), one column per next-symbol logit (27).
# Drawn from the shared generator g, so the values depend on how far g has
# already been advanced by earlier cells.
W = torch.randn(729, 27, generator=g, requires_grad=True)

In [16]:
# Gradient descent
learning_rate = 150
for epoch in range(100):

    # Forward pass
    logits = Xenc @ W

    # Softmax + mean negative log likelihood in a single call.
    # F.cross_entropy works in log space, so large logits cannot overflow
    # the way a bare exp() followed by a division can.
    loss = F.cross_entropy(logits, y)

    # Backward pass
    W.grad = None
    loss.backward()

    # Update: step in place with autograd disabled rather than writing through W.data
    with torch.no_grad():
        W -= learning_rate * W.grad
    if (epoch == 0 or ((epoch+1)%10) == 0):
        print(f'Epoch: {epoch+1}, loss: {loss}')

Epoch: 1, loss: 2.3535947799682617
Epoch: 10, loss: 2.321342945098877
Epoch: 20, loss: 2.3175909519195557
Epoch: 30, loss: 2.304429292678833
Epoch: 40, loss: 2.3028669357299805
Epoch: 50, loss: 2.291486978530884
Epoch: 60, loss: 2.2913997173309326
Epoch: 70, loss: 2.281254291534424
Epoch: 80, loss: 2.282212734222412
Epoch: 90, loss: 2.2729592323303223
Epoch: 100, loss: 2.274686813354492


In [17]:
# Sample ten names from the neural net model
for _ in range(10):
    out = []
    prev2, prev1 = 0, 0  # start from the '..' context
    finished = False
    while not finished:
        # Encode the current two-symbol context the same way as the training inputs
        context = torch.tensor([prev2 + 27 * prev1])
        xenc = F.one_hot(context, num_classes=729).float()

        # Softmax over the logits for this context
        counts = (xenc @ W).exp()
        p = counts / counts.sum(1, keepdims=True)

        nxt = torch.multinomial(p, num_samples=1, replacement=True, generator=g).item()
        out.append(itos[nxt])
        prev2, prev1 = prev1, nxt
        finished = nxt == 0  # index 0 is '.', which ends the name
    print(''.join(out))

lijamerele.
silahlson.
ra.
zala.
casi.
ra.
mar.
ra.
zon.
janna.
