In [84]:
words = open('names.txt', 'r').read().splitlines()

In [85]:
# count training set occurrences for each trigram
t = {}
for w in words:
  chs = ['.'] + list(w) + ['.']
  for (ch1, ch2, ch3) in zip(chs, chs[1:], chs[2:]):
    trigram = (ch1+ch2, ch3)
    t[trigram] = t.get(trigram, 0) + 1

In [86]:
# convertion utils for the 'next char'. 27 elements.
chars = sorted(list(set(''.join(words))))
stoi = {s:i+1 for i,s in enumerate(chars)}
stoi['.'] = 0
itos = {i:s for s,i in stoi.items()}

In [87]:
# convertion utils for 'previous bigram'. 729 unique elements 
chars.append('.')
import itertools as itt
bigrams = [ch1+ch2 for ch1, ch2 in list(set(itt.product(chars, chars)))]
bigramtoi = {cp:i for i,cp in enumerate(bigrams)}
itobigram = {i:cp for cp,i in bigramtoi.items()}

In [88]:
# Part 1: we use the counts from the training set to calculate probabilities table P and the loss function.

In [89]:
import torch
N = torch.zeros((27, 27, 27), dtype=torch.int32)

for w in words:
  chs = ['.'] + list(w) + ['.']
  for (ch1, ch2, ch3) in zip(chs, chs[1:], chs[2:]):
    ix1 = stoi[ch1]
    ix2 = stoi[ch2]
    ix3 = stoi[ch3]
    N[ix1, ix2, ix3] += 1

In [90]:
P = (N+1).float()
P /= P.sum((1, 2), keepdim=True)

In [91]:
# We can now sample using the table P
g = torch.Generator().manual_seed(2147483647)

for i in range(5):
  ix1 = 2
  ix2 = 4
  out = [itos[ix1] + itos[ix2]]
  while True:
    p = P[ix1][ix2]
    ix3 = torch.multinomial(p, num_samples=1, replacement=True, generator=g).item()
    out.append(itos[ix3])
    if itos[ix3] == '.':
      break
    ix1 = ix2
    ix2 = ix3
  print(''.join(out)) 

bduuwjde.
bdianasid.
bdulexay.
bdo.
bdin.


In [92]:
import numpy as np

# calculate loss function on the training set
log_likelihood = 0.0
n = 0

for w in words:
#for w in ["andrejq"]:
  chs = ['.'] + list(w) + ['.']
  for (ch1, ch2, ch3) in zip(chs, chs[1:], chs[2:]):
    ix1 = stoi[ch1]
    ix2 = stoi[ch2]
    ix3 = stoi[ch3]
    prob = P[ix1, ix2, ix3]
    logprob = torch.log(prob)
    log_likelihood += logprob
    n += 1
    #print(f'{ch1}{ch2}{ch3}: {prob:.4f} {logprob:.4f}')

print(f'{log_likelihood=}')
nll = -log_likelihood
print(f'{nll=}')
print(f'{nll/n}') 

log_likelihood=tensor(-894872.8750)
nll=tensor(894872.8750)
4.563047409057617


In [93]:
# Part 2: we won't use the probabilities P anymore, we will calculate the equivalent W using an ANN

In [94]:
# create the training set of trigrams in the form ('xy','z')
xs, ys = [], []

for w in words:
  chs = ['.'] + list(w) + ['.']
  for (ch1, ch2, ch3) in zip(chs, chs[1:], chs[2:]):
    trigram = (ch1+ch2, ch3)
    ix1 = bigramtoi[trigram[0]]
    ix2 = stoi[ch3]
    #print(trigram)
    xs.append(ix1)
    ys.append(ix2)
    
xs = torch.tensor(xs)
ys = torch.tensor(ys)

In [95]:
import torch.nn.functional as F
xenc = F.one_hot(xs, num_classes=729).float()
xenc.shape

torch.Size([196113, 729])

In [96]:
# initialize W with some random values
g = torch.Generator().manual_seed(2147483647)
W = torch.randn((729, 27), generator=g, requires_grad=True)

In [97]:
# adjusts W values until we equal the loss function result from the training set above.
xenc = F.one_hot(xs, num_classes=729).float()
loss = 1000
  
while loss >= 2.092747449874878:
    
  # forward pass
  logits = xenc @ W
  counts = logits.exp()
  probs = counts / counts.sum(1, keepdims=True)
  loss = -probs[torch.arange(xs.nelement()), ys].log().mean() + 0.01*(W**2).mean()
  
  # backward pass
  W.grad = None
  loss.backward()
  
  # update
  W.data += -300 * W.grad

  print(loss.item(), end="\r")

2.0927419662475586

In [98]:
# sampling from NN
g = torch.Generator().manual_seed(2147483647)

for i in range(5):
  ix1 = 45
  last_bigram = itobigram[ix1]
  out = [last_bigram]
    
  while True:  
    
    xenc = F.one_hot(torch.tensor([ix1]), num_classes=729).float()
    logits = xenc @ W
    counts = logits.exp()
    p = counts / counts.sum(1, keepdims=True) # probabilities for next character
    
    ix2 = torch.multinomial(p, num_samples=1, replacement=True, generator=g).item()
    out.append(itos[ix2])
    last_bigram = last_bigram[1] + itos[ix2]
    ix1 = bigramtoi[last_bigram]
      
    if itos[ix2] == '.':
      break

  print(''.join(out))

zmaunide.
zmilyasid.
zma.
zmalay.
zmacin.
