**E01:** Train a trigram language model, i.e. take two characters as input to predict the third one. Feel free to use either counting or a neural net. Evaluate the loss. Did it improve over a bigram model?

### First step: how do I train a trigram model? I first need to understand how to actually produce trigrams
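A trigram here is just three consecutive characters of a name padded with `.` markers. Below is a minimal sketch. `trigrams` is a hypothetical helper, and `n_start` (the number of leading `.` markers) is an illustrative parameter. With one leading `.`, the first letter of a name is never a prediction target. With two, it is.

```python
def trigrams(word, n_start=1):
    """Return the consecutive (ch1, ch2, ch3) triples of a padded word.

    n_start is the number of leading '.' markers (hypothetical parameter);
    with n_start=1 the first letter of the word is never the third element.
    """
    chs = ['.'] * n_start + list(word) + ['.']
    return list(zip(chs, chs[1:], chs[2:]))

print(trigrams('emma'))
# [('.', 'e', 'm'), ('e', 'm', 'm'), ('m', 'm', 'a'), ('m', 'a', '.')]
print(trigrams('emma', n_start=2))
# [('.', '.', 'e'), ('.', 'e', 'm'), ('e', 'm', 'm'), ('m', 'm', 'a'), ('m', 'a', '.')]
```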

In [26]:
import torch
import torch.nn.functional as F

import matplotlib.pyplot as plt

In [27]:
words = open('names.txt', 'r').read().split()

Now we need to set up the characters so that `stoi` and `itos` are configured properly. The model has to predict the next letter from the two letters before it, and I'm not yet sure how best to set that up.

I think we just need two sets of dictionaries to map everything together: one for single characters (the targets) and one for two-character pairs (the inputs), so that every pair is referenced by its own unique index.
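As an alternative to a second dictionary, each pair can be computed arithmetically. This sketch assumes a 27-entry vocabulary with `.` at index 0 and `a` to `z` at 1 to 26, as in the next cell. A pair of indices `(i, j)` maps to the unique integer `27*i + j`, which covers all 729 pairs including `..`. `pair_index` and `pair_chars` are hypothetical names.

```python
import string

chars = ['.'] + list(string.ascii_lowercase)   # '.' -> 0, 'a'..'z' -> 1..26
stoi = {s: i for i, s in enumerate(chars)}
itos = {i: s for s, i in stoi.items()}
V = len(chars)                                  # 27

def pair_index(ch1, ch2):
    """Map a two-character context to a unique integer in [0, V*V)."""
    return stoi[ch1] * V + stoi[ch2]

def pair_chars(ix):
    """Invert pair_index using integer division and remainder."""
    return itos[ix // V] + itos[ix % V]

print(pair_index('.', '.'), pair_index('e', 'm'), pair_index('z', 'z'))  # 0 148 728
print(pair_chars(148))                                                   # em
```

A design note: this removes the need for `double_stoi`/`double_itos` and keeps `..` available as a start-of-name context, at the cost of a 729-row (rather than 728-row) weight matrix.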

In [81]:
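single_chr = sorted(set(''.join(words)))
single_stoi = {s: i + 1 for i, s in enumerate(single_chr)}  # reserve index 0 for the dot
single_stoi['.'] = 0
single_itos = {i: s for s, i in single_stoi.items()}  # flip integers and strings around
# All 27*27 two-character contexts; because '.' was inserted last, the order is
# 'aa', 'ab', ..., 'z.', '.a', ..., '..', so '..' lands at the last index (728)
double_chr = [f'{ch1}{ch2}' for ch1 in single_itos.values() for ch2 in single_itos.values()]
double_stoi = {s: i for i, s in enumerate(double_chr)}
# '..' never occurs as a context (names get only one leading '.'), so drop it;
# since it was last, the remaining indices stay contiguous: 0..727
del double_stoi['..']
double_itos = {i: s for s, i in double_stoi.items()}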

In [82]:
# Build the dataset: each example maps a two-character context to the next character
xs, ys = [], []
for w in words:
    # only one leading '.', so the first letter of a name is never a target
    chs = ['.'] + list(w) + ['.']
    for ch1, ch2, ch3 in zip(chs, chs[1:], chs[2:]):
        ix_d = double_stoi[ch1 + ch2]  # two-character context index
        ix_s = single_stoi[ch3]        # next-character index
        xs.append(ix_d)
        ys.append(ix_s)

# turn these into tensors, since we'll build the network with PyTorch
xs = torch.tensor(xs)
ys = torch.tensor(ys)
print(f'number of examples: {xs.nelement()}')

number of examples: 196113
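The exercise allows counting instead of a neural net, and counting gives a direct way to answer "did it improve over a bigram model?". Below is a minimal sketch. `ngram_count_nll` is a hypothetical helper. It assumes add-k smoothing and `n - 1` leading `.` markers, so, unlike the dataset cell above, the first letter of each name is also a target. Like the training loop below, it reports the loss on the same words it was fit on.

```python
import torch

def ngram_count_nll(words, n=3, smoothing=1.0):
    """Fit a counting n-gram model on `words` and return its average negative
    log likelihood on those same words.

    Each name is padded with n-1 leading '.' markers and one trailing '.',
    so every letter (including the first) and the end marker are targets.
    `smoothing` is add-k smoothing applied to every count.
    """
    chars = sorted(set(''.join(words)))
    stoi = {s: i + 1 for i, s in enumerate(chars)}
    stoi['.'] = 0
    V = len(stoi)

    # Collect every n-tuple of character indices: (context..., next char)
    tuples = []
    for w in words:
        chs = ['.'] * (n - 1) + list(w) + ['.']
        for gram in zip(*(chs[k:] for k in range(n))):
            tuples.append([stoi[c] for c in gram])
    ix = torch.tensor(tuples)                 # shape (num_examples, n)
    cols = tuple(ix[:, k] for k in range(n))

    # Count table with one axis per position; accumulate=True sums repeats
    N = torch.zeros((V,) * n)
    N.index_put_(cols, torch.ones(len(ix)), accumulate=True)

    # Normalize over the last axis to get P(next char | context)
    P = N + smoothing
    P = P / P.sum(dim=-1, keepdim=True)

    return -P[cols].log().mean().item()
```

For example, `ngram_count_nll(words, n=2)` and `ngram_count_nll(words, n=3)` give a like-for-like bigram vs. trigram comparison on `names.txt`. These numbers are not directly comparable to the neural-net loss below. That loop uses one leading `.` (one fewer target per name) and adds an L2 penalty.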


In [83]:
train_steps = 50
learning_rate = 50  # step size for gradient descent
num_inputs = 27*27 - 1  # all two-character contexts except '..'
g = torch.Generator().manual_seed(2147483647)
W = torch.randn((num_inputs, 27), generator=g, requires_grad=True)  # remember to set requires_grad
xenc = F.one_hot(xs, num_classes=num_inputs).float()  # inputs never change, so encode once
for k in range(train_steps):
    # Forward pass: softmax over the 27 possible next characters
    logits = xenc @ W
    counts = logits.exp()
    probs = counts / counts.sum(1, keepdim=True)
    # average negative log likelihood plus a small L2 penalty on W
    anll = -probs[torch.arange(len(ys)), ys].log().mean() + 0.01*(W**2).mean()

    # Backward pass
    W.grad = None
    anll.backward()

    # Update
    W.data += -learning_rate * W.grad
    print(f'avg neg log likelihood loss in step {k+1} of {train_steps}: {anll:.4f}')

print(anll)

avg neg log likelihood loss in step 1 of 50: 3.7883
avg neg log likelihood loss in step 2 of 50: 3.7083
avg neg log likelihood loss in step 3 of 50: 3.6333
avg neg log likelihood loss in step 4 of 50: 3.5629
avg neg log likelihood loss in step 5 of 50: 3.4971
avg neg log likelihood loss in step 6 of 50: 3.4358
avg neg log likelihood loss in step 7 of 50: 3.3790
avg neg log likelihood loss in step 8 of 50: 3.3265
avg neg log likelihood loss in step 9 of 50: 3.2782
avg neg log likelihood loss in step 10 of 50: 3.2339
avg neg log likelihood loss in step 11 of 50: 3.1933
avg neg log likelihood loss in step 12 of 50: 3.1560
avg neg log likelihood loss in step 13 of 50: 3.1217
avg neg log likelihood loss in step 14 of 50: 3.0899
avg neg log likelihood loss in step 15 of 50: 3.0605
avg neg log likelihood loss in step 16 of 50: 3.0330
avg neg log likelihood loss in step 17 of 50: 3.0074
avg neg log likelihood loss in step 18 of 50: 2.9833
avg neg log likelihood loss in step 19 of 50: 2.9606
av
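Two observations may speed up the loop above. They are offered as a sketch, not a change to the run shown. First, multiplying a one-hot matrix by `W` just selects rows, so `W[xs]` gives the same logits without building a 196113 x 728 matrix. Second, the exp / normalize / `-log` / mean sequence is what `F.cross_entropy` computes, in a more numerically stable way. The shapes mirror the training cell, and the random `xs`, `ys` are toy stand-ins. The L2 term `0.01*(W**2).mean()` would still be added separately.

```python
import torch
import torch.nn.functional as F

g = torch.Generator().manual_seed(0)
num_inputs, V = 728, 27                        # same shapes as the training cell
W = torch.randn((num_inputs, V), generator=g)
xs = torch.randint(0, num_inputs, (10,), generator=g)   # toy context indices
ys = torch.randint(0, V, (10,), generator=g)            # toy target indices

# One-hot encoding followed by a matmul just selects rows of W ...
logits_onehot = F.one_hot(xs, num_classes=num_inputs).float() @ W
logits_index = W[xs]
print(torch.allclose(logits_onehot, logits_index))      # True

# ... and exp / normalize / -log / mean is what F.cross_entropy computes
probs = logits_index.exp() / logits_index.exp().sum(1, keepdim=True)
manual = -probs[torch.arange(len(ys)), ys].log().mean()
print(torch.allclose(manual, F.cross_entropy(logits_index, ys)))  # True
```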

In [87]:
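g = torch.Generator().manual_seed(2147483647)

for i in range(20):
    out = []
    ix = 0  # used as a double_stoi index, so the first context is 'aa'
    while True:
        # NOTE: after the first step ix holds a sampled *single*-character index (0..26),
        # yet it is one-hot encoded below as a two-character context index
        xenc = F.one_hot(torch.tensor([ix]), num_classes=num_inputs).float()
        logits = xenc @ W
        counts = logits.exp()
        probs = counts / counts.sum(1, keepdim=True)

        ix = torch.multinomial(probs, num_samples=1, replacement=True, generator=g).item()
        out.append(single_itos[ix])
        if ix == 0:
            break
    print(''.join(out))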

ouwjdldjawcqid.
pxlfbywednw.
.
oh.
nwtozshsh.
g.
zvbmaxnpauydbbleviajsdbyuinrwipblasnjyinbyt.
rtbcffrmumtsyfodtumjmnpytszwjqrsaed.
roreayg.
zpcejajaaedlwtdfmiiibwyfinwtg.
psnhsvfihsuszphddg.
nfbptpariluir.
paufbtkit.
r.
nbmri.
isuyuytr.
nmeaujibkivuywtdlpch.
.
ywpg.
.
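One likely reason these samples look like random letters is the way the loop feeds its own output back. The loop starts from context index 0, which in `double_stoi` is `'aa'`, not a start-of-name context. After that, `ix` is a single-character index from 0 to 26, and it is one-hot encoded as a two-character context index. So every later context is one of `'aa'` through `'a.'` (double indices 0 to 26). Sampling from a trigram model instead needs a sliding two-character window. Below is a minimal sketch over a probability table `P` of shape `(V, V, V)`, where `P[i, j]` is the next-character distribution given context `(i, j)`. It assumes `P` was fit with two leading `.` markers, so `(0, 0)` means "start of a name". The notebook's `W` never saw that context, so it would need retraining with two leading dots first. `sample_names` is a hypothetical helper.

```python
import torch

def sample_names(P, itos, num_names=5, seed=2147483647):
    """Sample names from a trigram table P of shape (V, V, V).

    Assumes index 0 is '.' and that P[0, 0] is the start-of-name
    distribution (i.e. P was fit with two leading '.' markers).
    """
    g = torch.Generator().manual_seed(seed)
    names = []
    for _ in range(num_names):
        ix1, ix2 = 0, 0                  # start context '..'
        out = []
        while True:
            p = P[ix1, ix2]              # distribution over the next character
            ix = torch.multinomial(p, num_samples=1, replacement=True, generator=g).item()
            if ix == 0:                  # end-of-name marker
                break
            out.append(itos[ix])
            ix1, ix2 = ix2, ix           # slide the two-character window by one
        names.append(''.join(out))
    return names

# Toy usage: a table that only knows the name 'ab', so sampling is deterministic
itos = {0: '.', 1: 'a', 2: 'b'}
P = torch.zeros((3, 3, 3))
P[0, 0, 1] = 1.0   # '..' -> 'a'
P[0, 1, 2] = 1.0   # '.a' -> 'b'
P[1, 2, 0] = 1.0   # 'ab' -> '.'
print(sample_names(P, itos, num_names=2))   # ['ab', 'ab']
```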
