Exercises:

1. Train a trigram language model, i.e. take two characters as an input to predict the 3rd one. Feel free to use either counting or a neural net. Evaluate the loss; Did it improve over a bigram model?

It did better (as expected). Table description:
- training loss,
- no regularisation/smoothing,
- 200 epochs for 2/3-gram models,
- 600 epochs for 4-gram (probably needs better hyperparameters).

| model type    | bigram loss | trigram loss | 4-gram loss  |
|---------------|-------------|--------------|--------------|
| counting      |       2.454 |        1.942 |        1.471 |
| backprop (nn) |       2.460 |        2.029 |        1.780 |


In [25]:
import torch
from tqdm import tqdm

## Data

In [2]:
data_fpath = './data/names.txt'

In [3]:
with open(data_fpath, 'r') as f:
    words = f.read().splitlines()
words[:5]

['emma', 'olivia', 'ava', 'isabella', 'sophia']

In [4]:
len(words)

32033

In [5]:
word_lens = [len(word) for word in words]
print(f'min len: {min(word_lens)}; max len: {max(word_lens)}')

min len: 2; max len: 15


## Ngram model as an array with counts
using generalized version for ngrams given arbitrary n

In [45]:
n = 4

In [46]:
SEP_TOK = '.'

In [47]:
ngrams_dict = {}
for word in words:
    chars = [SEP_TOK]*(n-1) + list(word) + [SEP_TOK]*(n-1)
    ngram_chars = [chars[i:] for i in range(n)]
    for ngram in zip(*ngram_chars):
        ngrams_dict[ngram] = ngrams_dict.get(ngram, 0) + 1
ngrams_dict = sorted(ngrams_dict.items(), key=lambda kv: kv[1], reverse=True)
ngrams_dict[:10]

[(('n', '.', '.', '.'), 6763),
 (('a', '.', '.', '.'), 6640),
 (('.', '.', '.', 'a'), 4410),
 (('e', '.', '.', '.'), 3983),
 (('.', '.', '.', 'k'), 2963),
 (('.', '.', '.', 'm'), 2538),
 (('i', '.', '.', '.'), 2489),
 (('.', '.', '.', 'j'), 2422),
 (('h', '.', '.', '.'), 2409),
 (('.', '.', '.', 's'), 2055)]

In [48]:
vocab = [SEP_TOK] + sorted(list(set(''.join(words))))
stoi = {s: i for i, s in enumerate(vocab)}
itos = {i: s for i, s in enumerate(vocab)}

In [49]:
N = torch.zeros([len(vocab) for _ in range(n)], dtype=torch.int32)
for word in tqdm(words):
    chars = [SEP_TOK]*(n-1) + list(word) + [SEP_TOK]*(n-1)
    ngram_chars = [chars[i:] for i in range(n)]
    for ngram in zip(*ngram_chars):
        ixs = tuple(stoi[ch] for ch in ngram)
        N[ixs] += 1

  0%|          | 51/32033 [00:00<01:03, 505.30it/s]

100%|██████████| 32033/32033 [01:02<00:00, 513.36it/s]


In [50]:
base_count = 0 # smooths the probabilities
P = (N+base_count).float()
P = P / P.sum(axis=(n-1), keepdim=True)

### Sampling from the model

In [51]:
from collections import deque
n_samples = 20
g = torch.Generator().manual_seed(2147483647)
for _ in range(n_samples):
    ixs = deque([stoi[SEP_TOK]] * (n-1))
    out = []
    while True:
        prob_distr = P[tuple(ixs)]
        ix = torch.multinomial(prob_distr, num_samples=1, replacement=True, generator=g).item()
        if ix == stoi[SEP_TOK]:
            break
        ixs.popleft()
        ixs.append(ix)
        out.append(itos[ix])
    print(''.join(out))

juniba
jakasir
presar
adria
jira
tolomas
ter
kalania
yanilena
jededaileti
tayse
siely
artez
noud
than
demmerceyn
lena
jaylie
reanae
ocely


### Evaluating the performance

In [52]:
log_likelihood = 0.0
count = 0
for word in tqdm(words, 'Evaluating'):
    chars = [SEP_TOK]*(n-1) + list(word) + [SEP_TOK]*(n-1)
    ngram_chars = [chars[i:] for i in range(n)]
    for ngram in zip(*ngram_chars):
        ixs = tuple(stoi[ch] for ch in ngram)
        prob = P[ixs]
        logprob = torch.log(prob)
        log_likelihood += logprob
        count += 1

print(f'{log_likelihood=}')
nll = -log_likelihood
print(f'{nll=}')
print(f'{nll/count=}')

Evaluating:   0%|          | 138/32033 [00:00<00:47, 672.83it/s]

Evaluating: 100%|██████████| 32033/32033 [00:46<00:00, 692.80it/s]

log_likelihood=tensor(-429832.4062)
nll=tensor(429832.4062)
nll/count=tensor(1.4710)





## Ngram model as neural net 

In [53]:
# creating the training set of bigrams
xs, ys = [], []
for word in tqdm(words, f'Creating {n}-gram samples'):
    chars = [SEP_TOK]*(n-1) + list(word) + [SEP_TOK]*(n-1)
    ngram_chars = [chars[i:] for i in range(n)]
    for ngram in zip(*ngram_chars):
        ixs = [stoi[ch] for ch in ngram]
        xs.append(ixs[:-1])
        ys.append(ixs[-1])

xs = torch.tensor(xs)
ys = torch.tensor(ys)

print(f'Number of training examples: {xs.shape[0]}')

Creating 4-gram samples:   0%|          | 0/32033 [00:00<?, ?it/s]

Creating 4-gram samples: 100%|██████████| 32033/32033 [00:03<00:00, 9515.39it/s] 


Number of training examples: 292212


### Training loop

In [61]:
def calc_loss(xs, ys, W, weight_decay=1e-4):
    logits = W[[x for x in xs.T]]
    counts = logits.exp()
    probs = counts / counts.sum(1, keepdim=True)
    # loss = average negative log likelihood
    loss = -probs[torch.arange(len(ys)), ys].log().mean() + weight_decay*(W**2).mean()
    return loss

In [64]:
# initializing the "model"
g = torch.Generator().manual_seed(2147483647)
W = torch.randn(tuple(len(vocab) for _ in range(n)), generator=g, requires_grad=True)

In [None]:
for ep in range(600):
    # forward pass
    tr_loss = calc_loss(xs, ys, W, 0)

    # backward pass
    W.grad = None
    tr_loss.backward()

    # update
    W.data += -100 * W.grad

    if ep % 10 == 9:
        print(f'{ep+1:>3}th epoch, tr_loss={tr_loss.item():.3f}')

### Sampling from the network

In [66]:
from collections import deque
n_samples = 20
g = torch.Generator().manual_seed(2147483647)
for _ in range(n_samples):
    ixs = deque([stoi[SEP_TOK]] * (n-1))
    out = []
    while True:
        logits = W[tuple(ixs)]
        counts = logits.exp()
        probs = counts / counts.sum()
        ix = torch.multinomial(probs, num_samples=1, replacement=True, generator=g).item()
        if ix == stoi[SEP_TOK]:
            break
        ixs.popleft()
        ixs.append(ix)
        out.append(itos[ix])
    print(''.join(out))

junjdedianaqidouxutnypaxnuq
jimrltozsogjatqzvugignaduwjbuldvhajzdbiminrwimpadsvzywcfxvbryn
farmumtkyf
demmerponnsleigh
ani
cora
yaehocpkqjyked
webdmeiibwyaftwtiansnhspoluwaspphfdgosfmxtpqcixz
repahfmtydt
jayrslu
isa
dyfj
mjluuj
mahvupwyilpvhecgiagr
jenhwvdxtta
malyn
brey
aui
lavlpocq
themilana
