In [None]:
import torch
import matplotlib.pyplot as plt

In [None]:
# Load the dataset: each line of the file becomes one entry of `words`.
# `with` closes the file handle; an explicit encoding avoids locale-dependent decoding.
with open('../data/names.txt', 'r', encoding='utf-8') as f:
  words = f.read().splitlines()
words[:10]

In [None]:
# Vocabulary: index 0 is reserved for the '.' boundary token; every distinct
# character in the data gets an index 1..len(chars) in sorted order.
chars = sorted(set(''.join(words)))
assert '.' not in chars, "'.' is reserved as the boundary token and must not appear in the data"
ctoi = {c: i+1 for i, c in enumerate(chars)}
ctoi['.'] = 0
itoc = {i: c for c, i in ctoi.items()}

# Derive the vocabulary size from the data instead of hard-coding it, so the
# embedding table and output layer always cover every index in ctoi.
num_chars = len(ctoi)

In [None]:
# Create train set: each example maps a context of the two previous characters
# to the character that follows them.
# Words are left-padded with two '.' tokens so the first characters have a full
# context, and terminated by a single '.' so the model learns where a name ends.
# (A second trailing '.' would add a (last_char, '.') -> '.' example that sampling,
# which stops at the first '.', never uses.)
xs, ys = [], []

for w in words:
  padded = ['.', '.'] + list(w) + ['.']
  for ch1, ch2, ch3 in zip(padded, padded[1:], padded[2:]):
    xs.append([ctoi[ch1], ctoi[ch2]])
    ys.append(ctoi[ch3])

xs = torch.tensor(xs)   # (num_ex, 2) context indices
ys = torch.tensor(ys)   # (num_ex,) target indices

num_ex = ys.nelement()

xs.shape

In [None]:
# init NN: a character-embedding table C and a single linear layer W (no bias).
embed_dim = 20
context_len = 2
# The train-set cell builds contexts of a fixed width; make sure it matches.
assert xs.shape[1] == context_len, 'train-set context width must match context_len'
g = torch.Generator().manual_seed(2147483647)
C = torch.randn((num_chars, embed_dim), generator=g, requires_grad=True)
# C[xs].shape = (num_data_pairs, context_len, embed_dim)
W = torch.randn((context_len*embed_dim, num_chars), generator=g, requires_grad=True)

# Trainable parameters (both already created with requires_grad=True).
params = [C, W]

In [None]:
# Full-batch gradient descent over every trainable parameter (embeddings C and weights W).
num_iters=500
lr = 0.5

for k in range(num_iters):
  # NN forward pass
  xenc = C[xs].view(-1, context_len*embed_dim)   # (num_ex, context_len*embed_dim)
  logits = xenc @ W             # log counts -> only thing that will change in Transformers
  # log-softmax == log(exp(logits) / exp(logits).sum(1)), computed without overflowing exp()
  logprobs = logits.log_softmax(dim=1)

  # loss: negative llh of probs corresponding to true labels, plus an L2 penalty on W
  loss = -logprobs[torch.arange(num_ex), ys].mean() + 0.01*(W**2).mean()

  ## NN backward pass
  for p in params:
    p.grad = None       # reset grads so they don't accumulate across iterations
  loss.backward()
  if k%10 == 0:
    print(f'Iter {k}, loss {loss.item()}')

  # update C as well as W, in place and outside of autograd
  with torch.no_grad():
    for p in params:
      p -= lr * p.grad

In [None]:
# Sampling: generate names one character at a time from the trained model.
g = torch.Generator().manual_seed(2147483647)
num_samples = 5

with torch.no_grad():   # inference only: don't build an autograd graph through C and W
  for _ in range(num_samples):
    sample = []
    context = [0] * context_len   # start from an all-'.' context (index 0 is '.')
    while True:
      xenc = C[torch.tensor([context])].view(1, -1)   # (1, context_len*embed_dim)
      logits = xenc @ W
      p = logits.softmax(dim=1)

      ix = torch.multinomial(p, num_samples=1, replacement=True, generator=g).item()
      sample.append(itoc[ix])
      if ix == 0:   # '.' terminates the name
        break
      context = context[1:] + [ix]   # slide the context window forward by one character

    print(''.join(sample))
