In [None]:
import torch
import matplotlib.pyplot as plt

In [None]:
# Load the training corpus, one entry per line. The `with` block guarantees the
# file handle is closed (a bare `open(...).read()` leaves closing to the GC).
with open('../data/names.txt', 'r', encoding='utf-8') as f:
  words = f.read().splitlines()
words[:10]

In [None]:
# Build the character vocabulary. Letters get indices 1..N in sorted order;
# '.' (the word start/end marker) is assigned index 0 afterwards.
chars = sorted(set(''.join(words)))
ctoi = dict(zip(chars, range(1, len(chars) + 1)))
ctoi['.'] = 0
# Inverse mapping: index -> character.
itoc = {idx: ch for ch, idx in ctoi.items()}

In [None]:
# Bigram count matrix: counts[r, c] is the number of times character c follows
# character r, with '.' (index 0) padding both the start and end of each word.
num_chars = len(ctoi)  # vocabulary size, derived from the data rather than hard-coded
counts = torch.zeros((num_chars, num_chars), dtype=torch.int32)
# row: prev char, col: next char

for w in words:
  # Use a separate name instead of overwriting the string `w` with a list.
  chs = ['.'] + list(w) + ['.']
  for ch1, ch2 in zip(chs, chs[1:]):
    counts[ctoi[ch1], ctoi[ch2]] += 1

In [None]:
%matplotlib inline

plt.figure(figsize=(16,16))
plt.imshow(counts, cmap='Blues')
for i in range(27):
    for j in range(27):
        chstr = itoc[i] + itoc[j]
        plt.text(j, i, chstr, ha="center", va="bottom", color='gray')
        plt.text(j, i, counts[i, j].item(), ha="center", va="top", color='gray')
plt.axis('off');

In [None]:
# Bigram probability matrix with add-one (Laplace) smoothing: every count is
# incremented by 1 so no bigram has zero probability (log-prob never hits -inf).
# Each row must be normalised by its *smoothed* sum. Dividing by the raw counts'
# row sums would make every row sum to more than 1, which inflates all NLL values.
P = (counts + 1).float()
P /= P.sum(dim=1, keepdim=True)

In [None]:
# Sampling (forward mode): start at '.' (index 0), repeatedly draw the next
# character from the current row of P, and stop as soon as '.' is drawn again.
# The seeded generator makes the printed samples reproducible.
g = torch.Generator().manual_seed(2147483647)
num_samples = 5

for _ in range(num_samples):
  generated = []
  cur = 0
  while True:
    cur = torch.multinomial(P[cur], num_samples=1, replacement=True, generator=g).item()
    generated.append(itoc[cur])
    if cur == 0:
      break

  print(''.join(generated))

In [None]:
def compute_nll(word: str, probs=None, stoi=None):
  """Return the negative log-likelihood of ``word`` under a bigram model.

  The word is wrapped in '.' start/end markers, and the log-probabilities of
  every consecutive character pair are summed and negated.

  Args:
    word: the word to score; every character must be a key of ``stoi``.
    probs: bigram probability matrix indexed as ``probs[prev, next]``.
      Defaults to the notebook-level ``P``.
    stoi: character -> index mapping. Defaults to the notebook-level ``ctoi``.

  Returns:
    A 0-dim tensor holding the summed negative log-likelihood.

  Raises:
    KeyError: if ``word`` contains a character missing from ``stoi``.
  """
  # Resolve defaults at call time so re-running an earlier cell (e.g. the
  # smoothing cell) is picked up without redefining this function.
  if probs is None:
    probs = P
  if stoi is None:
    stoi = ctoi
  chs = ['.'] + list(word) + ['.']
  llh = 0
  for ch1, ch2 in zip(chs, chs[1:]):
    llh += torch.log(probs[stoi[ch1], stoi[ch2]])
  return -llh

# Score one example name. Lower NLL means the bigram model assigns the name a higher likelihood.
compute_nll('max')