This is the neural-network version of makemore for trigrams

Load all the words.

In [1]:
words = []
with open("names.txt") as file:
    words = file.read().splitlines()
words[:3]

['emma', 'olivia', 'ava']

Encode all the words with integer values for the letters.

In [2]:
chars = ".abcdefghijklmnopqrstuvwxyz"
char_to_index = {}
index = 0
for char in list(chars):
    char_to_index[char] = index
    index += 1

index_to_char = {index:char for char,index in char_to_index.items()}
all_words_encoded = []
for word in words:
    letters = [".", "."] + list(word) + ["."]
    all_words_encoded.append([char_to_index[letter] for letter in letters])

def btoi(chars):
    if len(chars) != 2:
        raise Exception("bigram_to_index accepts only strings of two characters where both are either letters or \".\"")
    return chars[0] * 27 + chars[1]
#"".join([index_to_char[char] for char in encoded_words[0]])

Split all the words into trigrams. Represent the input (the first two characters) as the second + 27 * the first. Convert those all to one-hot encodings.

In [3]:
import torch
n_bigrams_possible = 27 * 27

xs, ys = [], []

for word_encoded in all_words_encoded[:]:
    for index1, index2, index3 in zip(word_encoded, word_encoded[1:], word_encoded[2:]):
        x = btoi([index1, index2])
        y = index3
        xs.append(x)
        ys.append(y)

import torch.nn.functional as F
inputs = F.one_hot(torch.tensor(xs), num_classes=n_bigrams_possible).float()

n_observations = len(xs)
print(inputs.shape)

torch.Size([228146, 729])


Our neural network will have one layer, with one node for each possible two-character combination. Each will output the probability for each character that it will come next. We initialize it with random, normally distributed weights.

In [4]:

torch.set_printoptions(linewidth=132)
W = torch.randn(n_bigrams_possible, 27, requires_grad=True)
W.shape

torch.Size([729, 27])

In [5]:
# Training loop
for i in range(10000):
    logits = inputs @ W
    counts = logits.exp()
    probs = counts / counts.sum(1, keepdims=True)
    loss = -probs[torch.arange(n_observations), ys].log().mean() + 0.01*(W**2).mean()
    print(i, loss.item())
    W.grad = None
    loss.backward()
    W.data += -50 * W.grad

0 3.7895710468292236
1 3.6840734481811523
2 3.6045732498168945
3 3.537385940551758
4 3.4780349731445312
5 3.4242959022521973
6 3.375133752822876
7 3.330003261566162
8 3.2885236740112305
9 3.250383138656616
10 3.2153022289276123
11 3.183016300201416
12 3.1532669067382812
13 3.1257970333099365
14 3.100360631942749
15 3.076728343963623
16 3.054694414138794
17 3.0340800285339355
18 3.0147314071655273
19 2.9965195655822754
20 2.9793314933776855
21 2.9630720615386963
22 2.9476561546325684
23 2.933012008666992
24 2.919074773788452
25 2.9057865142822266
26 2.8930952548980713
27 2.880955934524536
28 2.869326591491699
29 2.8581697940826416
30 2.84745192527771
31 2.8371429443359375
32 2.8272149562835693
33 2.817643404006958
34 2.8084053993225098
35 2.79948091506958
36 2.79085111618042
37 2.782498598098755
38 2.7744083404541016
39 2.766566276550293
40 2.7589590549468994
41 2.751574993133545
42 2.7444026470184326
43 2.7374322414398193
44 2.73065447807312
45 2.7240607738494873
46 2.717641830444336
4

KeyboardInterrupt: 

In [6]:
print("starting char probabilities", W[0].exp() / W[0].exp().sum())
P = W.exp() / W[:].exp().sum(0)
P

starting char probabilities tensor([5.7215e-05, 1.3766e-01, 4.0765e-02, 4.8131e-02, 5.2751e-02, 4.7788e-02, 1.3021e-02, 2.0884e-02, 2.7282e-02, 1.8450e-02,
        7.5599e-02, 9.2487e-02, 4.9067e-02, 7.9220e-02, 3.5771e-02, 1.2303e-02, 1.6078e-02, 2.8856e-03, 5.1159e-02, 6.4144e-02,
        4.0827e-02, 2.4497e-03, 1.1741e-02, 9.5888e-03, 4.1941e-03, 1.6702e-02, 2.8998e-02], grad_fn=<DivBackward0>)


tensor([[4.5593e-07, 5.6014e-04, 2.7034e-03,  ..., 3.2596e-04, 2.8524e-04, 2.0993e-03],
        [3.4872e-06, 2.4777e-04, 3.7065e-03,  ..., 6.2266e-04, 8.6908e-04, 3.2369e-03],
        [2.7187e-05, 4.3864e-03, 2.3732e-04,  ..., 3.0861e-04, 3.6327e-04, 2.8632e-04],
        ...,
        [3.3728e-04, 4.3293e-05, 3.8165e-04,  ..., 4.5870e-03, 3.4772e-04, 1.4735e-03],
        [2.7171e-03, 1.0963e-03, 5.2146e-04,  ..., 5.7771e-04, 1.2835e-04, 7.5108e-04],
        [7.7594e-04, 1.4448e-03, 6.7936e-04,  ..., 6.8217e-04, 3.0417e-03, 7.7982e-04]], grad_fn=<DivBackward0>)

In [34]:
from functools import reduce
all_output = []
for i in range(20000):
  output = []
  last_chars = [0, 0] # the bigram representing two word-starts characters.
  while True:
    last_bigram = btoi(last_chars)
    bigram_encoded = F.one_hot(torch.tensor([last_bigram]), num_classes=n_bigrams_possible).float()
    logits = bigram_encoded @ W
    counts = logits.exp()
    probs = counts / counts.sum(1, keepdims=True)
    next_char = torch.multinomial(probs, 1, replacement=True).item()
    if next_char == 0: 
      break
    output.append(next_char)
    last_chars[0] = last_chars[1]
    last_chars[1] = next_char
  
  all_output.append("".join(map(lambda char: index_to_char[char], output)))
print(all_output)
reduce(lambda x,y:x+y, map(len, all_output)) / len(all_output)

['olukton', 'adreon', 'royce', 'domen', 'rhyaaria', 'raft', 'oldaivon', 'luinaosperiahdzpbubrabiel', 'kalyn', 'jon', 'yontlyukmogabrey', 'shelaur', 'ya', 'ka', 'davann', 'isletson', 'katine', 'mah', 'ra', 'kare', 'relyn', 'ibhava', 'adomirzailleriahlournevahsaili', 'ba', 'mar', 'myla', 'sa', 'zoimpwah', 'lyna', 'cyce', 'gion', 'milavana', 'sah', 'thannistgeznfxon', 'jous', 'ja', 'cbieni', 'rysela', 'keiraya', 'en', 'oish', 'guritie', 'en', 'kaya', 'an', 'manalexlilin', 'ver', 'ka', 'shayzfchostis', 'remiylaylekshewesdjruz', 'kanidazlyn', 'kamarih', 'ka', 'leindrehnavik', 'nie', 'na', 'jarickilyn', 'rharia', 'jreanneth', 'jamakian', 'jann', 'arlyn', 'belaylah', 'as', 'phtleovarin', 'lenzahjknorian', 'idelrqfndippe', 'paw', 'jerain', 'ina', 'trefevalein', 'bric', 'aranoramorebairen', 'ala', 'aan', 'slen', 'mi', 'aleen', 'alans', 'anioveemircel', 'banien', 'jezrie', 'mord', 'polib', 'luan', 'lani', 'clanionneth', 'adinsle', 'dwina', 'bena', 'ny', 'javenoleth', 'mazizlen', 'dalon', 'leiggr

6.25635

Compute the average size of the words in the dataset

In [35]:
reduce(lambda x, y: x + y, map(len, words)) / len(words)

6.122217712983486