This is the neural-network version of makemore for trigrams: the model predicts the next character from the two characters that precede it.

Load all the words.

In [None]:
# Read the whole dataset and split it into one entry per line.
# The encoding is given explicitly so the result does not depend on the
# platform default, and `words` is only bound once the read has succeeded
# (no empty-list placeholder that later cells could silently run on).
with open("names.txt", encoding="utf-8") as file:
    words = file.read().splitlines()
words[:3]  # notebook display: peek at the first three entries

Encode all the words with integer values for the letters.

In [None]:
# Alphabet of the model: "." (index 0) is the word-boundary marker, followed
# by the 26 lowercase letters (indices 1-26).
chars = ".abcdefghijklmnopqrstuvwxyz"
char_to_index = {char: index for index, char in enumerate(chars)}
index_to_char = {index: char for char, index in char_to_index.items()}

# Encode every word as a list of character indices, padded with two leading
# "." (so the first real letter already has a two-character context) and one
# trailing "." (the end-of-word marker).
# A character outside `chars` raises KeyError here.
all_words_encoded = [
    [char_to_index[letter] for letter in f"..{word}."] for word in words
]

def btoi(chars):
    """Return the integer index of a two-character context (bigram).

    The index is ``27 * char_to_index[first] + char_to_index[second]``,
    i.e. the two character codes read as a two-digit base-27 number.

    Args:
        chars: A sequence of exactly two characters, each a key of
            ``char_to_index`` (a lowercase letter or ".").

    Returns:
        The bigram index as an int.

    Raises:
        ValueError: If ``chars`` does not contain exactly two characters.
        KeyError: If either character is not in ``char_to_index``.
    """
    if len(chars) != 2:
        raise ValueError(
            "btoi accepts only strings of two characters where both are either letters or \".\""
        )
    first, second = chars
    return char_to_index[first] * 27 + char_to_index[second]
#"".join([index_to_char[char] for char in encoded_words[0]])

Our neural network has a single layer with one input for each possible two-character combination (27 × 27 = 729) and 27 outputs, one per character. For a given two-character context, the outputs are turned into the probability of each character coming next. We initialize the weights with random, normally distributed values.

In [None]:
import torch

# Number of distinct two-character contexts: 27 symbols for each position.
n_bigrams_possible = 27 ** 2

# Wider lines make printed tensors easier to read in the notebook.
torch.set_printoptions(linewidth=132)

# Weight matrix of the single layer: one row per two-character context and
# one column per possible next character, drawn from a standard normal
# distribution. Gradients are tracked so it can be trained.
W = torch.randn((n_bigrams_possible, 27), requires_grad=True)

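Split all the words into trigrams. Represent the input (the first two characters) as a single integer: 27 × (index of the first character) + (index of the second character). The third character is the target. Then convert all of the inputs to one-hot encodings.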

In [None]:
import torch.nn.functional as F

# Every run of three consecutive character indices in every encoded word.
trigrams = [
    trigram
    for encoded in all_words_encoded
    for trigram in zip(encoded, encoded[1:], encoded[2:])
]

# Input: the first two characters packed into one integer (first * 27 + second).
# Target: the third character.
xs = [(first * 27) + second for first, second, _ in trigrams]
ys = [third for _, _, third in trigrams]

# One row per trigram, with a single 1.0 in the column of its context index.
inputs = F.one_hot(torch.tensor(xs), num_classes=n_bigrams_possible).float()

n_observations = len(xs)
n_observations

In [None]:
# Single forward pass, written out as a manual softmax:
# linear layer -> exponentiate -> normalise each row so it sums to 1.
logits = torch.matmul(inputs, W)
counts = torch.exp(logits)
probs = counts / counts.sum(dim=1, keepdim=True)
logits.shape

In [None]:
# Row indices and target classes do not change between iterations, so build
# them once outside the loop.
rows = torch.arange(n_observations)
targets = torch.as_tensor(ys)

for i in range(100):
    # Forward pass: manual softmax over the 27 possible next characters.
    logits = inputs @ W
    counts = logits.exp()
    probs = counts / counts.sum(1, keepdims=True)

    # Loss = average negative log-likelihood of the *true* next character
    # (probs[rows, targets] picks one probability per observation) plus a
    # small L2 penalty on the weights. Indexing with the row indices alone
    # would average over all 27 classes and ignore the targets entirely.
    loss = -probs[rows, targets].log().mean() + 0.01 * (W**2).mean()
    print(i, loss.item())

    # Backward pass: clear the old gradient, then recompute it.
    W.grad = None
    loss.backward()

    # Plain gradient-descent step with learning rate 50. The update is done
    # under no_grad so that it is not recorded by autograd.
    with torch.no_grad():
        W -= 50 * W.grad

In [None]:
# Sample 20 words from the model, one character at a time.
for i in range(20):
    print(i)
    output = []
    # Index of the current two-character context. 0 encodes "..", i.e. the
    # two word-start markers that precede the first letter of every word.
    context = 0

    while True:
        # Distribution over the next character given the current context
        # (same manual softmax as in training; no gradients are needed).
        with torch.no_grad():
            encoded = F.one_hot(torch.tensor([context]), num_classes=n_bigrams_possible).float()
            logits = encoded @ W
            counts = logits.exp()
            probs = counts / counts.sum(1, keepdims=True)
        next_index = torch.multinomial(probs, 1, replacement=True).item()

        next_char = index_to_char[next_index]
        output.append(next_char)
        if next_char == ".":
            # "." is the end-of-word marker; it is kept in the printed output.
            break

        # Slide the context window: context = first * 27 + second, so
        # context % 27 is the second character, which now becomes the first,
        # and the character just sampled becomes the second.
        context = (context % 27) * 27 + next_index
    print("".join(output))

Let's count how many trigrams actually come up