In [1]:
words = open('names.txt', 'r').read().splitlines()

In [2]:
words[:10]

['emma',
 'olivia',
 'ava',
 'isabella',
 'sophia',
 'charlotte',
 'mia',
 'amelia',
 'harper',
 'evelyn']

In [20]:
b = {}

for w in words:
    chs = ['.', '.'] + list(w) + ['.']
    for ch1, ch2, ch3 in zip(chs, chs[1:], chs[2:]):
        bigram = (ch1, ch2, ch3)
        b[bigram] = b.get(bigram, 0) + 1

In [21]:
sorted(b.items(), key=lambda kv: kv[1], reverse=True)

[(('.', '.', 'a'), 4410),
 (('.', '.', 'k'), 2963),
 (('.', '.', 'm'), 2538),
 (('.', '.', 'j'), 2422),
 (('.', '.', 's'), 2055),
 (('a', 'h', '.'), 1714),
 (('.', '.', 'd'), 1690),
 (('n', 'a', '.'), 1673),
 (('.', '.', 'r'), 1639),
 (('.', '.', 'l'), 1572),
 (('.', '.', 'c'), 1542),
 (('.', '.', 'e'), 1531),
 (('a', 'n', '.'), 1509),
 (('o', 'n', '.'), 1503),
 (('.', 'm', 'a'), 1453),
 (('.', '.', 't'), 1308),
 (('.', '.', 'b'), 1306),
 (('.', 'j', 'a'), 1255),
 (('.', 'k', 'a'), 1254),
 (('e', 'n', '.'), 1217),
 (('.', '.', 'n'), 1146),
 (('l', 'y', 'n'), 976),
 (('y', 'n', '.'), 953),
 (('a', 'r', 'i'), 950),
 (('.', '.', 'z'), 929),
 (('i', 'a', '.'), 903),
 (('.', '.', 'h'), 874),
 (('i', 'e', '.'), 858),
 (('a', 'n', 'n'), 825),
 (('e', 'l', 'l'), 822),
 (('a', 'n', 'a'), 804),
 (('i', 'a', 'n'), 790),
 (('m', 'a', 'r'), 776),
 (('i', 'n', '.'), 766),
 (('e', 'l', '.'), 727),
 (('y', 'a', '.'), 716),
 (('a', 'n', 'i'), 703),
 (('.', 'd', 'a'), 700),
 (('l', 'a', '.'), 684),
 (('

In [8]:
import torch

In [9]:
if torch.backends.mps.is_available():
    mps_device = torch.device("mps")
    x = torch.ones(1, device=mps_device)
    print (x)
else:
    print ("MPS device not found.")

tensor([1.], device='mps:0')


In [13]:
chars = list('abcdefghijklmnopqrstuvwxyz')
stoi = {s:i+1 for i,s in enumerate(chars)}
stoi['.'] = 0
itos = {i:s for s,i in stoi.items()}

In [27]:
N = torch.zeros((27, 27, 27), dtype=torch.int32)

In [28]:
for w in words:
    chs = ['.', '.'] + list(w) + ['.']
    for ch1, ch2, ch3 in zip(chs, chs[1:], chs[2:]):
        ix1 = stoi[ch1]
        ix2 = stoi[ch2]
        ix3 = stoi[ch3]
        N[ix1,ix2,ix3] += 1

In [30]:
# validate counts are same as above
print("..a", N[0,0,1].item())
print("ah.", N[1,8,0].item())
print("na.", N[14,1,0].item())

..a 4410
ah. 1714
na. 1673


In [31]:
g = torch.Generator().manual_seed(2147483647)

In [46]:
P = (N+1).float()
P /= P.sum(2, keepdim=True)
P[0,0].sum()

tensor(1.)

In [50]:
for i in range(10):
    ix1 = 0
    ix2 = 0
    out = []
    
    while True:
        p = P[ix1, ix2]
    
        ix1 = ix2
        ix2 = torch.multinomial(p, num_samples=1, replacement=True, generator=g).item()
        out.append(itos[ix2])
        
        if ix2 == 0:
            break
    print("".join(out))

makimarityharlonimittain.
luwak.
ka.
da.
samiyah.
javer.
gotai.
moriellavojkwuthda.
kaley.
maside.


In [51]:
log_likelihood = 0.0
n = 0

for w in words:
    chs = ['.', '.'] + list(w) + ['.']
    for ch1, ch2, ch3 in zip(chs, chs[1:], chs[2:]):
        ix1 = stoi[ch1]
        ix2 = stoi[ch2]
        ix3 = stoi[ch3]
        prob = P[ix1, ix2, ix3]
        logprob = torch.log(prob)
        log_likelihood += logprob
        n += 1
        # print(f'{ch1}{ch2}: {prob:.4f} {logprob:.4f}')

print(f'{log_likelihood=}')
nll = -log_likelihood/n
print(f'{nll=}')

log_likelihood=tensor(-504653.)
nll=tensor(2.2120)


# Trigram NN Optimization

In [54]:
import torch.nn.functional as F

In [86]:
# create training set
xs, ys = [], []

for w in words:
    chs = ['.', '.'] + list(w) + ['.']
    for ch1, ch2, ch3 in zip(chs, chs[1:], chs[2:]):
        ix1 = stoi[ch1]
        ix2 = stoi[ch2]
        ix3 = stoi[ch3]
        
        xs.append([ix1, ix2])
        ys.append(ix3)

xs = torch.tensor(xs)
ys = torch.tensor(ys)
num = xs.nelement() // 2
print('number of examples:', num)

# initialize network
g = torch.Generator().manual_seed(2147483647)
W = torch.randn((54,27), generator=g, requires_grad=True)

number of examples: 228146


In [90]:
for k in range(100):

    # forward pass
    xtemp = F.one_hot(xs, num_classes=27).float()
    xlen, ylen, zlen = xtemp.shape
    xenc = xtemp.reshape((xlen, ylen*zlen))
    
    logits = xenc @ W
    counts = logits.exp()
    probs = counts / counts.sum(1, keepdim=True)
    loss = -probs[torch.arange(num), ys].log().mean()
    print('loss:', loss.item())

    # backward pass
    W.grad = None
    loss.backward()

    # update
    W.data += -50 * W.grad


loss: 3.203230381011963
loss: 2.8220293521881104
loss: 2.5432095527648926
loss: 2.44171142578125
loss: 2.411390781402588
loss: 2.3871679306030273
loss: 2.3668408393859863
loss: 2.351937770843506
loss: 2.3434951305389404
loss: 2.3400886058807373
loss: 2.3388891220092773
loss: 2.338777780532837
loss: 2.3400020599365234
loss: 2.3433945178985596
loss: 2.3566362857818604
loss: 2.367482900619507
loss: 2.4272377490997314
loss: 2.358185052871704
loss: 2.3793535232543945
loss: 2.370321750640869
loss: 2.4235622882843018
loss: 2.358818292617798
loss: 2.379650592803955
loss: 2.369866132736206
loss: 2.4215571880340576
loss: 2.35932993888855
loss: 2.38138484954834
loss: 2.3696324825286865
loss: 2.420348644256592
loss: 2.3596229553222656
loss: 2.3824758529663086
loss: 2.3694777488708496
loss: 2.419612169265747
loss: 2.3598012924194336
loss: 2.3831684589385986
loss: 2.3693742752075195
loss: 2.4191513061523438
loss: 2.3599119186401367
loss: 2.383615732192993
loss: 2.36930513381958
loss: 2.4188613891601