In [None]:
import torch
import matplotlib.pyplot as plt

In [None]:
# Load the dataset: one name per line. The context manager guarantees the
# file handle is closed once the contents have been read.
with open('../data/names.txt', 'r') as names_file:
  words = names_file.read().splitlines()
words[:10]

In [None]:
# Build the character vocabulary. Index 0 is reserved for '.', the special
# token used to mark the boundaries of a name; every character found in the
# data gets an index from 1 upwards, in sorted order.
chars = sorted(set(''.join(words)))
ctoi = {c: i+1 for i, c in enumerate(chars)}
ctoi['.'] = 0
itoc = {i: c for c, i in ctoi.items()}

# Vocabulary size is derived from the data instead of being hard-coded to 27,
# so the one-hot width and the weight matrix stay valid for any character set.
# The +1 accounts for the '.' token at index 0.
num_chars = len(chars) + 1

In [None]:
# Build the bigram training set: each input is a character index and its
# target is the index of the character that immediately follows it.
xs, ys = [], []

for word in words:
  # Pad with '.' on both sides so the first and last bigrams are included.
  padded = ['.', *word, '.']
  for prev_ch, next_ch in zip(padded, padded[1:]):
    xs.append(ctoi[prev_ch])
    ys.append(ctoi[next_ch])

xs = torch.tensor(xs)
ys = torch.tensor(ys)

# Total number of (input, target) pairs.
num_ex = xs.numel()
num_ex

In [None]:
import torch.nn.functional as F
# One-hot encode the input indices and cast to float32 so the encoding can be
# matrix-multiplied with the float weight matrix.
xenc = F.one_hot(xs, num_classes=num_chars).to(torch.float32)
xenc.shape

In [None]:
# Initialise the network: a single (num_chars x num_chars) weight matrix drawn
# from a standard normal distribution with a fixed seed for reproducibility.
g = torch.Generator()
g.manual_seed(2147483647)
W = torch.randn(num_chars, num_chars, generator=g, requires_grad=True)

# Training hyperparameters: number of gradient steps and learning rate.
num_iters, lr = 100, 10

In [None]:
for k in range(num_iters):
  # NN forward pass
  # xenc is one-hot, so the matrix mult simply plucks out row i of W for input i
  logits = xenc @ W             # log counts -> only thing that will change in Transformers
  counts = logits.exp()         # exponentiated logits play the role of bigram counts
  P = counts / counts.sum(dim=1, keepdims=True)
  # last 2 lines: softmax

  # loss: mean negative log-likelihood of the probabilities assigned to the
  # true labels, plus a small L2 penalty on the weights
  loss = -P[torch.arange(num_ex), ys].log().mean() + 0.01*(W**2).mean()

  ## NN backward pass
  W.grad = None       # reset so backward() does not accumulate across iterations
  loss.backward()

  print(loss.item())

  # Gradient-descent step. Updating W in place under no_grad (rather than via
  # W.data) keeps the step out of the autograd graph without bypassing
  # autograd's bookkeeping on the tensor.
  with torch.no_grad():
    W -= lr * W.grad

In [None]:
# Sampling: generate names one character at a time from the trained weights.
g = torch.Generator().manual_seed(2147483647)
num_samples = 5

# Inference only: no_grad stops autograd from recording a graph through W
# (which has requires_grad=True) for every sampled character.
with torch.no_grad():
  for i in range(num_samples):
    sample = []
    i1 = 0                      # start every name from index 0
    while True:
      # Dedicated names are used for the single-row tensors so the training
      # set `xenc` (and the training `logits`/`counts`) are not overwritten;
      # otherwise re-running the training cell after this one would fail.
      x_row = F.one_hot(torch.tensor([i1]), num_classes=num_chars).float()
      row_logits = x_row @ W
      row_counts = row_logits.exp()
      p = row_counts / row_counts.sum(1, keepdims=True)

      i2 = torch.multinomial(p, num_samples=1, replacement=True, generator=g).item()
      sample.append(itoc[i2])
      if i2 == 0:               # sampled index 0 again: the name is finished
        break
      i1 = i2

    print(''.join(sample))

    # Exactly the same results as explicit model!