In [None]:
import math

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F


%matplotlib inline

In [None]:
words = open('names.txt', 'r').read().splitlines()

In [None]:
words[:10]

In [None]:
t = {}
for w in words:
    chs = ['.'] + list(w) + ['.']
    for ch1, ch2, ch3 in zip(chs, chs[1:], chs[2:]):
        trigram = (ch1+ch2, ch3)
        t[trigram] = t.get(trigram, 0) + 1

In [None]:
chars = ['.'] + sorted(list(set(''.join(words))))
stoi = {s:i for i, s in enumerate(chars)}
# stoi['.'] = 0
itos = {i:s for s, i in stoi.items()} # could use enum(chars) but <.> will be missed

bichars = [x+y for x in chars for y in chars]
btoi = {b:i for i, b in enumerate(bichars)}
itob = {i:b for b, i in btoi.items()}

In [None]:
# N = torch.zeros((27, 27, 27), dtype=torch.int32)
# for w in words:
#     chs = ['.'] + list(w) + ['.']
#     for ch1, ch2, ch3 in zip(chs, chs[1:], chs[2:]):
#         ix1 = stoi[ch1]
#         ix2 = stoi[ch2]
#         ix3 = stoi[ch3]
#         N[ix1, ix2, ix3] += 1

N = torch.zeros((729, 27), dtype=torch.int32)
for w in words:
    chs = ['.'] + list(w) + ['.']
    for ch1, ch2, ch3 in zip(chs, chs[1:], chs[2:]):
        ix1 = btoi[ch1+ch2]
        ix2 = stoi[ch3]
        N[ix1, ix2] += 1

In [None]:
stoi['m']

In [None]:
btoi['mh'] # 359
27*stoi['m'] # 338
27*stoi['m']+stoi['h'] # 359

In [None]:
plt.figure(figsize=(16, 16))
plt.imshow(N[:27], cmap='Oranges')
for i in range(27):
    for j in range(27):
        chstr = itob[i]+itos[j]
        plt.text(j, i, chstr, ha='center', va='bottom', color='black')
        plt.text(j, i, N[i, j].item(), ha='center', va='top', color='black')
plt.axis('off');

In [None]:
# sanity check
za = 0
for w in words:
    za+=1 if w[:2] == 'za' else 0
za

In [None]:
P = (N+1).float() # why 1? some counts are 0 so probab will be 0 i.e. inf loglikelihood
P /= P.sum(1, keepdim=True)

In [333]:
g = torch.Generator().manual_seed(2147483647)

for i in range(10):
    out = []
    start_p = P[:26].view(-1)
    start_s = torch.multinomial(start_p, num_samples=1, replacement=True, generator=g).item()
    b, s = divmod(start_s, 27)
    b_next = itob[b][1]+itos[s]
    ix = btoi[b_next]
    out.append(b_next) 
    while True:
        
        p = P[ix]
        s = torch.multinomial(p, num_samples=1, replacement=True, generator=g).item()
        b_next = b_next[1]+itos[s]
        ix = btoi[b_next]
        if b_next[1] != '.':
            out.append(b_next[1])
        else:
            break

    print(''.join(out))

iounide
oalianah
rhyliah
ya
runa
qui
reltoper
my
gele
andannaaryanileniassibduinrwin


In [312]:
P[0]

tensor([0.0370, 0.0370, 0.0370, 0.0370, 0.0370, 0.0370, 0.0370, 0.0370, 0.0370,
        0.0370, 0.0370, 0.0370, 0.0370, 0.0370, 0.0370, 0.0370, 0.0370, 0.0370,
        0.0370, 0.0370, 0.0370, 0.0370, 0.0370, 0.0370, 0.0370, 0.0370, 0.0370])

In [320]:
p = P[:26].view(-1)
start = torch.multinomial(p, num_samples=1, replacement=True, generator=g).item()
start

507

In [321]:
# btoi['mh'] # 359
# 27*stoi['m'] # 338
# 27*stoi['m']+stoi['h'] # 359

In [322]:
b, s = divmod(start, 27)
b, s
b, s, itob[b], itos[s]

(18, 21, '.r', 'u')

In [323]:
b_next = itob[b][1]+itos[s]
ix = btoi[b_next]
b_next, ix

('ru', 507)

In [324]:
P[ix]

tensor([0.0609, 0.0251, 0.0645, 0.0358, 0.0573, 0.0645, 0.0143, 0.0072, 0.0358,
        0.0215, 0.0036, 0.0394, 0.0251, 0.0717, 0.0609, 0.0036, 0.0072, 0.0287,
        0.0036, 0.1685, 0.0896, 0.0072, 0.0323, 0.0143, 0.0108, 0.0036, 0.0430])

In [325]:
p = P[ix]
s = torch.multinomial(p, num_samples=1, replacement=True, generator=g).item()
s

14

In [326]:
itos[s]

'n'

In [327]:
b_next = b_next[1]+itos[s]
ix = btoi[b_next]
b_next, ix

('un', 581)

In [None]:
btoi['rl']

In [None]:
p = P[ix]
ix = torch.multinomial(p, num_samples=1, replacement=True, generator=g).item()
ix

In [None]:
itos[ix]