In [None]:
import torch
import matplotlib.pyplot as plt

In [None]:
# Load the training corpus: one name per line.
# `with` guarantees the file handle is closed once it has been read.
with open('../data/names.txt', 'r') as f:
  words = f.read().splitlines()
words[:10]

In [None]:
# Vocabulary: each distinct character in the corpus gets an index from 1 upward
# (in sorted order); index 0 is reserved for '.', the start/end token.
chars = sorted(set(''.join(words)))
ctoi = dict(zip(chars, range(1, len(chars) + 1)))
ctoi['.'] = 0
# Inverse mapping: index -> character.
itoc = dict((ix, ch) for ch, ix in ctoi.items())

In [None]:
# Trigram counts: counts[i1, i2, i3] = number of times char i3 follows the
# two-character context (i1, i2) across all names.
num_chars = len(ctoi)  # vocabulary size derived from the data (27 for names.txt: 26 letters + '.')
counts = torch.zeros((num_chars, num_chars, num_chars), dtype=torch.int32)
# d1, d2: two prev chars; d3: next char

for word in words:
  # Two leading '.' give the first letter a full two-char start context.
  # NOTE: the second trailing '.' additionally records a (last_char, '.', '.')
  # trigram; sampling stops at the first '.', so those context rows are never
  # sampled from.
  padded = ['.'] + ['.'] + list(word) + ['.'] + ['.']
  for ch1, ch2, ch3 in zip(padded, padded[1:], padded[2:]):
    counts[ctoi[ch1], ctoi[ch2], ctoi[ch3]] += 1

In [None]:
import torch.nn.functional as F
# Add-one smoothing so no trigram gets zero probability. `.float()` already
# returns a new tensor, so no extra clone is needed.
P = (counts + 1).float()
# Normalize along the last axis so that for every context (i1, i2) the entries
# P[i1, i2, :] form a probability distribution summing to 1.
P = F.normalize(P, p=1, dim=2)

# Sanity check over *all* (i1, i2) contexts (not just the diagonal i1 == i2).
assert torch.allclose(P.sum(dim=2), torch.ones(P.shape[:2]))

In [None]:
# Sampling (forward mode): generate names one character at a time, drawing each
# next character from P conditioned on the two previous characters.
g = torch.Generator().manual_seed(2147483647)  # fixed seed -> reproducible samples
num_samples = 5

for _ in range(num_samples):
  generated = []
  ctx = (0, 0)  # start context: two '.' tokens
  while True:
    nxt = torch.multinomial(P[ctx], num_samples=1, replacement=True, generator=g).item()
    generated.append(itoc[nxt])
    if nxt == 0:  # '.' ends the name (and is kept in the printed output)
      break
    ctx = (ctx[1], nxt)  # slide the two-char window forward
  print(''.join(generated))

In [None]:
def compute_nll(word: str):
  """Return the negative log-likelihood of `word` under the trigram model `P`.

  The word is wrapped as ['.', '.'] + chars + ['.'], mirroring sampling:
  generation starts from the context ('.', '.') and stops at the first '.'.
  Only one end token is used, so no (last_char, '.', '.') transition is scored
  (the sampler never takes that step).

  Returns the sum of -log P[next | prev two] over every predicted character,
  including the final '.', as a 0-dim tensor. Raises KeyError if `word`
  contains a character missing from `ctoi`.
  """
  padded = ['.', '.'] + list(word) + ['.']
  llh = 0
  for ch1, ch2, ch3 in zip(padded, padded[1:], padded[2:]):
    llh += torch.log(P[ctoi[ch1], ctoi[ch2], ctoi[ch3]])
  return -llh

compute_nll('max')