In [1]:
import torch
import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
words = open('../data/names.txt', 'r').read().splitlines()
words[:10]

['emma',
 'olivia',
 'ava',
 'isabella',
 'sophia',
 'charlotte',
 'mia',
 'amelia',
 'harper',
 'evelyn']

In [4]:
chars = sorted(list(set(''.join(words))))
ctoi = {c: i+1 for i, c in enumerate(chars)}
ctoi['.'] = 0
itoc = {i: c for c, i in ctoi.items()}

num_chars = 27

In [27]:
# Create train set
xs, ys = [], []

for w in words:
  w = ['.'] + ['.'] + list(w) + ['.'] + ['.']
  for ch1, ch2, ch3 in zip(w, w[1:], w[2:]):
    i1, i2, i3= ctoi[ch1], ctoi[ch2], ctoi[ch3]
    xs.append((i1, i2))
    ys.append(i3)

xs = torch.tensor(xs)
ys = torch.tensor(ys)

num_ex = ys.nelement()

xs.shape

torch.Size([260179, 2])

In [28]:
import torch.nn.functional as F
xenc = F.one_hot(xs, num_classes=num_chars).float()
xenc.shape

torch.Size([260179, 2, 27])

In [29]:
# init NN
g = torch.Generator().manual_seed(2147483647)
W = torch.randn((num_chars, num_chars, num_chars), generator=g, requires_grad=True)
num_iters = 100
lr = 10

In [38]:
for k in range(num_iters):
  # NN forward pass
  # logits = xenc @ W             # log counts -> only thing that will change in Transformers
  logits = W[xs[:, 0], xs[:, 1]] 
  # print(logits.shape)
  # matrix mult actually plugs out the i-th row if label is i
  counts = logits.exp()         # equivalent to counts
  P = counts / counts.sum(dim=1, keepdims=True)
  # last 2 lines: softmax
  
  # loss: negative llh of probs corresponding to true labels
  loss = -P[torch.arange(num_ex), ys].log().mean() + 0.01*(W**2).mean()


  ## NN backward pass
  W.grad = None       # set grad to 0
  loss.backward()

  print(loss.item())

  W.data += -lr*W.grad

torch.Size([260179, 27])
2.47672963142395
torch.Size([260179, 27])
2.4752602577209473
torch.Size([260179, 27])
2.4738001823425293
torch.Size([260179, 27])
2.4723501205444336
torch.Size([260179, 27])
2.470909833908081
torch.Size([260179, 27])
2.4694790840148926
torch.Size([260179, 27])
2.468057632446289
torch.Size([260179, 27])
2.4666452407836914
torch.Size([260179, 27])
2.4652421474456787
torch.Size([260179, 27])
2.46384859085083
torch.Size([260179, 27])
2.462463855743408
torch.Size([260179, 27])
2.461087703704834
torch.Size([260179, 27])
2.459721088409424
torch.Size([260179, 27])
2.4583628177642822
torch.Size([260179, 27])
2.4570131301879883
torch.Size([260179, 27])
2.455672025680542
torch.Size([260179, 27])
2.4543392658233643
torch.Size([260179, 27])
2.4530153274536133
torch.Size([260179, 27])
2.4516994953155518
torch.Size([260179, 27])
2.4503917694091797
torch.Size([260179, 27])
2.449092388153076
torch.Size([260179, 27])
2.447801113128662
torch.Size([260179, 27])
2.4465174674987793


In [37]:
# Sampling
g = torch.Generator().manual_seed(2147483647)
num_samples = 5

for i in range(num_samples):
  sample = []
  i1, i2 = 0, 0
  while True:
    # xenc = F.one_hot(torch.tensor([i1]), num_classes=num_chars).float()
    # logits = xenc @ W
    logits = W[i1, i2].reshape((1, -1)) 
    counts = logits.exp()
    p = counts / counts.sum(1, keepdims=True)

    i3 = torch.multinomial(p, num_samples=1, replacement=True, generator=g).item()
    sample.append(itoc[i3])
    if i3 == 0:
      break
    i1, i2 = i2, i3

  print(''.join(sample))


mnvc.
uyallinyqxuznlwckvkovakvfid.
maryiuzeqml.
odfsfpen.
daespeldkqtqkwjkrmgjywov.
