In [9]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline

from typing import List, Union

In [10]:
# Load the corpus: one name per line, with line terminators stripped.
# The context manager closes the file handle as soon as the read is done
# instead of leaving an open handle for the garbage collector.
with open("names.txt", "r") as names_file:
    words = names_file.read().splitlines()
print(len(words))                  # number of names
print(max(len(w) for w in words))  # length of the longest name
print(words[:8])                   # first few names, as a sanity check

32033
15
['emma', 'olivia', 'ava', 'isabella', 'sophia', 'charlotte', 'mia', 'amelia']


In [11]:
# Build the character vocabulary.
# "." is the reserved boundary token and is pinned to index 0: build_dataset
# appends "." to every word and pads its context window with index 0.
# Excluding "." from the sorted corpus characters keeps it from appearing
# twice in `chars` (which would move its index away from 0 and drop key 0
# from idx_to_char) should the corpus ever contain a literal ".".
chars = ["."] + sorted(set("".join(words)) - {"."})
char_to_idx = {char: idx for idx, char in enumerate(chars)}
# Inverse lookup, used to turn indices back into characters.
idx_to_char = {idx: char for char, idx in char_to_idx.items()}

In [12]:
def build_dataset(words, context_length=5, vocab=None):
    """Turn words into (context, next-character) training examples.

    Each word is terminated with "." and scanned left to right with a
    sliding window of ``context_length`` character indices. The window
    starts filled with the index of "." (padding) and, for every character,
    one example is emitted: the current window as input and the character's
    index as target.

    Args:
        words: Iterable of strings to encode.
        context_length: Number of preceding characters in each input row.
            Must be at least 1. Defaults to 5.
        vocab: Mapping from character to integer index; it must contain
            "." and every character that occurs in ``words``. When omitted,
            the notebook-level ``char_to_idx`` mapping is used.

    Returns:
        A tuple ``(X, y)`` of int64 tensors where ``X`` has shape
        ``(num_examples, context_length)`` and ``y`` has shape
        ``(num_examples,)``. With no words, the shapes are
        ``(0, context_length)`` and ``(0,)``.

    Raises:
        ValueError: If ``context_length`` is smaller than 1.
        KeyError: If a character is missing from the vocabulary.
    """
    if context_length < 1:
        raise ValueError(f"context_length must be >= 1, got {context_length}")
    if vocab is None:
        vocab = char_to_idx  # fall back to the notebook-level vocabulary

    pad_idx = vocab["."]
    X_list: List[List[int]] = []
    y_list: List[int] = []
    for word in words:
        window = [pad_idx] * context_length
        for char in word + ".":
            idx = vocab[char]
            X_list.append(window)
            y_list.append(idx)
            # Slicing + concatenation builds a new list, so the row that was
            # just appended to X_list is never mutated afterwards.
            window = window[1:] + [idx]

    # Explicit dtype and shape keep the result well-formed even when there
    # are no examples (torch.tensor([]) alone would be a 1-D float tensor).
    X = torch.tensor(X_list, dtype=torch.long).reshape(len(X_list), context_length)
    y = torch.tensor(y_list, dtype=torch.long)

    return X, y

import random

# Shuffle the names with a fixed seed, then split them 80% / 10% / 10% at
# the word level into train / dev / test.
# NOTE: random.shuffle mutates `words` in place, so re-running this cell
# without first reloading `words` reshuffles an already-shuffled list and
# generally produces a different split.
random.seed(42)
random.shuffle(words)
n1 = int(0.8 * len(words))
n2 = int(0.9 * len(words))

(X_train, y_train), (X_dev, y_dev), (X_test, y_test) = (
    build_dataset(subset)
    for subset in (words[:n1], words[n1:n2], words[n2:])  # train, dev, test
)

# Number of examples in each split.
tuple(split.shape[0] for split in (X_train, X_dev, X_test))

(182625, 22655, 22866)