This is the neural-network version of makemore for trigrams: it predicts each character of a name from the two characters that precede it.

Load all the words.

In [137]:
# One name per line; splitlines() drops the newline characters.
# An explicit encoding keeps the result independent of the platform default.
with open("names.txt", encoding="utf-8") as file:
    words = file.read().splitlines()
words[:3]

['emma', 'olivia', 'ava']

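Next we map each character to an integer index and encode every word. Each word gets two leading `.` start markers and one trailing `.` end marker, so every character, including the end marker, has a full two-character context. These encodings are split into trigrams, which form the training set, further below.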

In [138]:
chars = ".abcdefghijklmnopqrstuvwxyz"
# "." doubles as the start/end marker and gets index 0; the letters follow in order.
char_to_index = {char: index for index, char in enumerate(chars)}
index_to_char = {index: char for char, index in char_to_index.items()}

# Pad each word with two leading "." (so the first letter has a full two-character
# context) and one trailing "." (so the model also learns where a name ends).
encodings = [
    [char_to_index[letter] for letter in [".", "."] + list(word) + ["."]]
    for word in words
]

        

KeyError: '..'

Our neural network is a single linear layer with no bias. Its input is a one-hot vector with one slot for each of the $27 \times 27$ possible two-character contexts. It produces 27 logits, one per possible next character, and a softmax turns these into probabilities. We initialize the weights randomly from a standard normal distribution.

In [None]:
# One input slot per two-character context (27 possible characters each).
n_bigrams_possible = 27 * 27

import torch

# Seed the global RNG so the initial weights (and the later torch.multinomial
# sampling, which also draws from the global RNG) are reproducible.
torch.manual_seed(2147483647)
# Rows: two-character context id; columns: logit for each of the 27 next characters.
W = torch.randn(n_bigrams_possible, 27, requires_grad=True)

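Split all the words into trigrams. Represent the input (the first two characters) as a single index $27 \cdot c_1 + c_2$, where $c_1$ and $c_2$ are the indices of the first and second characters. The third character of each trigram is the target. Convert all the inputs to one-hot encodings.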

In [None]:
import torch.nn.functional as F

# Slide a window of three over every encoded word: (context char 1, context char 2, target).
trigrams = [
    window
    for encoding in encodings
    for window in zip(encoding, encoding[1:], encoding[2:])
]
# Flatten each two-character context into one id: first * 27 + second.
xs = [first * 27 + second for first, second, _ in trigrams]
ys = [target for _, _, target in trigrams]

inputs = F.one_hot(torch.tensor(xs), num_classes=n_bigrams_possible).float()
n_observations = len(xs)

In [None]:
# Targets must be a tensor so they can pick one entry per row of `probs`.
targets = torch.tensor(ys)
learning_rate = 0.5

for step in range(100):
    # Forward pass: each one-hot context selects a row of W as logits; softmax -> probabilities.
    logits = inputs @ W
    counts = logits.exp()
    probs = counts / counts.sum(1, keepdims=True)
    # Negative log-likelihood of the character that actually came next in each trigram.
    loss = -probs[torch.arange(n_observations), targets].log().mean()
    if step % 10 == 0:
        print(step, loss.item())

    # Backward pass and plain gradient-descent update (outside autograd tracking).
    W.grad = None
    loss.backward()
    with torch.no_grad():
        W -= learning_rate * W.grad

print("final loss:", loss.item())

In [None]:
# Preview the one-hot rows for context ids 0..26, i.e. first character "." followed by each character.
context_ids = torch.arange(27)
F.one_hot(context_ids, num_classes=n_bigrams_possible)

In [None]:
# Sampling only reads W, so skip building autograd graphs.
with torch.no_grad():
    for _ in range(20):
        output = []
        # Every name starts from the padded context "..", whose id is 0 * 27 + 0.
        context = 0
        while True:
            # A batch of one keeps `probs` 2-D so the row-wise normalization applies.
            encoded = F.one_hot(torch.tensor([context]), num_classes=n_bigrams_possible).float()
            logits = encoded @ W
            counts = logits.exp()
            probs = counts / counts.sum(1, keepdims=True)
            next_index = torch.multinomial(probs, 1, replacement=True).item()
            next_char = index_to_char[next_index]
            output.append(next_char)
            if next_char == ".":
                break
            # Slide the window: keep the newer context character, append the sampled one.
            context = (context % 27) * 27 + next_index
        print("".join(output))