In [94]:
# Following along with Andrej Karpathy's "makemore part 2: MLP" lecture:
# https://www.youtube.com/watch?v=TCH_1BHY58I&list=PLAqhIrjkxbuWI23v9cThsA9GvCAUhRvKZ&index=3
# https://github.com/karpathy/nn-zero-to-hero/blob/master/lectures/makemore/makemore_part2_mlp.ipynb
import torch
import random

# for making figures
%matplotlib inline

In [95]:
# Read the full name list (one name per line) and draw a random working subset.
SEED = 6_6_1978
# Seed *before* sampling: otherwise `words` (and every tensor built from it)
# is a different random subset on each run of the notebook.
random.seed(SEED)

# `with` closes the file handle once the names are read.
with open("words.txt", "r", encoding="utf-8") as f:
    # Skip blank lines (e.g. a trailing empty line) so no empty "word" sneaks in.
    all_words = [line for line in f.read().splitlines() if line]

MAX_WORDS = min(10_000, len(all_words))  # never ask for more words than exist
words = random.sample(all_words, MAX_WORDS)

In [96]:
# Re-seed the global RNG so this 20-name preview -- and every later draw from
# `random` (e.g. the debug samples further down) -- comes out the same each run.
PREVIEW_SEED = 6_6_1978
random.seed(PREVIEW_SEED)
random.sample(all_words, 20)

['florencia',
 'draysen',
 'simrin',
 'yasna',
 'lathan',
 'lilymarie',
 'maryah',
 'ara',
 'pheonix',
 'muir',
 'aubriegh',
 'maryruth',
 'feroz',
 'abdiel',
 'anabiya',
 'kristin',
 'dashon',
 'harlei',
 'valery',
 'janani']

In [97]:
# Build the character vocabulary and the mappings to/from integer ids.
# Id 0 is reserved for ".", the start-padding / end-of-word token, so the
# real characters are numbered 1..len(chars).
chars = sorted(set("".join(words)))
# If "." occurred inside a word it would be silently merged with the
# terminator (and leave an unused id), corrupting the training data.
assert "." not in chars, "'.' is reserved as the start/end token"
stoi = {s: i + 1 for i, s in enumerate(chars)}
stoi["."] = 0
itos = {i: s for s, i in stoi.items()}


def string_to_index(s):
    """Return the integer token id for the single character `s`.

    Looks `s` up in `stoi`; a character outside the vocabulary raises KeyError.
    """
    token_id = stoi[s]
    return token_id


def index_to_string(s):
    """Return the character for integer token id `s`.

    Despite its name, `s` is an integer id (the parameter name is kept for
    callers). Looks it up in `itos`; an unknown id raises KeyError.
    """
    character = itos[s]
    return character


def to_word(t):
    """Decode a sequence of token ids into a string.

    `t` may be a 1-D tensor (its elements are 0-d tensors) or any iterable of
    plain Python ints, e.g. a list of sampled ids. `int()` accepts both forms,
    whereas the previous `.item()` call only worked on tensor elements.
    """
    return "".join(itos[int(i)] for i in t)

In [98]:
# Build the dataset: every (context window -> next character) pair from `words`.

block_size = 4  # context length: how many characters do we take to predict the next one?


def build_dataset(word_list, context_len):
    """Turn words into (context, next-character) training pairs.

    Each word starts from a context of `context_len` "." tokens and is
    terminated by a single "."; for every character of `word + "."` we record
    the preceding `context_len` token ids and the id of that character.

    Returns:
        (X, Y): X is a long tensor of shape (N, context_len) of contexts and
        Y is a long tensor of shape (N,) of target ids, where N is the total
        number of characters (including terminators) across `word_list`.
    """
    pad = string_to_index(".")
    contexts, targets = [], []
    for word in word_list:
        context = [pad] * context_len  # start from '....' (all padding)
        for char in word + ".":
            ix = string_to_index(char)
            contexts.append(context)
            targets.append(ix)
            context = context[1:] + [ix]  # slide the window: drop oldest, append newest
    # Explicit dtype and shape keep X 2-D and integer even for an empty
    # word_list (torch.tensor([]) alone would be a 1-D float tensor).
    X = torch.tensor(contexts, dtype=torch.long).reshape(len(contexts), context_len)
    Y = torch.tensor(targets, dtype=torch.long)
    return X, Y


X, Y = build_dataset(words, block_size)
inputs, expected_output = X, Y

In [99]:
def debug_samples(k=20, xs=None, ys=None):
    """Print `k` random (context => target) training pairs, decoded and raw.

    Args:
        k: number of pairs to show. Capped at the dataset size, so a small
            dataset no longer makes `random.sample` raise ValueError.
        xs, ys: context / target tensors; default to the notebook-level
            `inputs` and `expected_output`.
    """
    xs = inputs if xs is None else xs
    ys = expected_output if ys is None else ys
    n = min(len(xs), len(ys))  # same pairing length `zip` would give
    # Sample row indices instead of materialising list(zip(xs, ys)).
    # random.sample picks by position, so for the same n and k this yields
    # the same rows without building an N-element list of tensor pairs.
    for row in random.sample(range(n), min(k, n)):
        inp, out = xs[row], ys[row]
        print(f"{to_word(inp)}=>{index_to_string(out.item())} == {inp}, {out} ")


debug_samples()

....=>k == tensor([0, 0, 0, 0]), 11 
heod=>o == tensor([ 8,  5, 15,  4]), 15 
alys=>s == tensor([ 1, 12, 25, 19]), 19 
kash=>m == tensor([11,  1, 19,  8]), 13 
.ale=>i == tensor([ 0,  1, 12,  5]), 9 
jado=>r == tensor([10,  1,  4, 15]), 18 
mane=>h == tensor([13,  1, 14,  5]), 8 
elod=>i == tensor([ 5, 12, 15,  4]), 9 
..ra=>r == tensor([ 0,  0, 18,  1]), 18 
..ma=>s == tensor([ 0,  0, 13,  1]), 19 
...t=>y == tensor([ 0,  0,  0, 20]), 25 
tanv=>i == tensor([20,  1, 14, 22]), 9 
...s=>h == tensor([ 0,  0,  0, 19]), 8 
iele=>n == tensor([ 9,  5, 12,  5]), 14 
.ari=>s == tensor([ 0,  1, 18,  9]), 19 
...e=>v == tensor([0, 0, 0, 5]), 22 
arbo=>r == tensor([ 1, 18,  2, 15]), 18 
.joz=>i == tensor([ 0, 10, 15, 26]), 9 
....=>z == tensor([0, 0, 0, 0]), 26 
....=>m == tensor([0, 0, 0, 0]), 13 
