[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/baggiponte/makemore/blob/main/notebooks/02-mlp.ipynb)

# Setup

In [None]:
# Import the dataset helper. If the makemore package is missing (e.g. on a fresh
# Colab runtime), install it with pip and retry the import. Keep this before the
# other imports so that, when the install is needed, it runs before anything else
# is imported.
try:
    from makemore.datasets import fetch_names
except ModuleNotFoundError:
    !pip install --quiet -- makemore
    from makemore.datasets import fetch_names

import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

# Seed PyTorch's global RNG so that parameter initialization and the minibatch
# sampling done with torch.randint later on are reproducible across runs.
torch.manual_seed(21474483647)

# get the data (presumably a deterministic shuffle of the names dataset with
# seed=42 -- see makemore.datasets.fetch_names)
names = fetch_names(shuffle=True, seed=42)

# Things you can play with

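Here are some special parameters, called "hyperparameters", that you can tweak manually. Note that `EPOCHS` is the number of optimization steps (each on one random minibatch), not the number of full passes over the training set. `EPSILON` holds the learning rate used before and after step `EPSILON_CUTOFF`: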

In [None]:
# --- model shape ---
CONTEXT_SIZE = 4      # how many preceding characters the model sees per prediction
EMBEDDING_DIMS = 10   # length of each character's embedding vector
INNER_SIZE = 200      # width of the hidden (tanh) layer

# --- optimisation ---
MINIBATCH_SIZE = 64        # examples drawn per training step
EPOCHS = 150000            # number of SGD steps (one minibatch each), not full passes over the data
EPSILON = 0.1, 0.01        # learning rate before / from EPSILON_CUTOFF onwards
EPSILON_CUTOFF = 100000    # step index at which the learning rate switches

# Let's create the neural network

In [None]:
class EmbeddingMLP(nn.Module):
    """Character-level MLP that predicts the next token from a fixed-size context.

    Each of the ``context_size`` input indices is looked up in a learned embedding
    table. The embeddings are concatenated and passed through one tanh hidden
    layer, producing one logit per vocabulary entry.

    Args:
        context_size: Number of input indices per example.
        embedding_dimensions: Length of each embedding vector.
        hidden_size: Width of the hidden layer.
        vocabulary_size: Number of distinct token indices, i.e. rows of the
            embedding table and number of output logits. Defaults to 27
            (presumably 26 letters plus the end-of-name token 0).
    """

    def __init__(self, context_size, embedding_dimensions, hidden_size, vocabulary_size=27):
        super().__init__()
        self.vocabulary_size = vocabulary_size
        self.context_size = context_size
        self.embedding_dimensions = embedding_dimensions

        self.embeddings = nn.Embedding(self.vocabulary_size, self.embedding_dimensions)

        self.stack = nn.Sequential(
            nn.Linear(context_size * embedding_dimensions, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, self.vocabulary_size),
        )

    def forward(self, x):
        """Return next-token logits for a batch of contexts.

        Args:
            x: Integer index tensor holding ``context_size`` indices per example,
                e.g. shape ``(batch, context_size)``. A single context of shape
                ``(context_size,)`` is treated as a batch of one.

        Returns:
            Logits of shape ``(batch, vocabulary_size)``.
        """
        # Concatenate each example's context embeddings into one flat feature vector.
        embeddings = self.embeddings(x).view(-1, self.context_size * self.embedding_dimensions)
        return self.stack(embeddings)
        
# Build the network from the hyperparameters defined above.
model = EmbeddingMLP(
    context_size=CONTEXT_SIZE,
    embedding_dimensions=EMBEDDING_DIMS,
    hidden_size=INNER_SIZE,
)

# Prepare the data for training

In [None]:
# Context windows (CONTEXT_SIZE indices each) and their labels.
context, labels = names.get_ngrams(CONTEXT_SIZE, as_tensor=True)

# Cumulative cut points, as fractions of all examples: train is [0, TRAIN_SIZE),
# validation is [TRAIN_SIZE, TEST_SIZE) and test is [TEST_SIZE, 1) -- an 80/10/10 split.
# Note that TEST_SIZE is where the test split *starts*, not its size.
TRAIN_SIZE = 0.8
TEST_SIZE = 0.9

n_examples = len(context)
training_index = int(TRAIN_SIZE * n_examples)
test_index = int(TEST_SIZE * n_examples)

train_part = slice(None, training_index)
validation_part = slice(training_index, test_index)
test_part = slice(test_index, None)

X_train, y_train = context[train_part], labels[train_part]
X_validation, y_validation = context[validation_part], labels[validation_part]
X_test, y_test = context[test_part], labels[test_part]

for line in (
    f"Train set:\tX: {len(X_train)}\ty:{len(y_train)}",
    f"Validation set:\tX: {len(X_validation)}\ty:{len(y_validation)}",
    f"Test set:\tX: {len(X_test)}\ty:{len(y_test)}",
):
    print(line)

# Train the model

In [None]:
%%time
lri = []    # learning rate used at each step
lossi = []  # log10 of the minibatch loss at each step
stepi = []  # step index, to plot lossi against

for k in range(EPOCHS):

    # EPOCHS counts optimization steps: each iteration draws one random minibatch
    # of MINIBATCH_SIZE training indices (drawn independently, so repeats are possible).
    ix = torch.randint(0, X_train.shape[0], (MINIBATCH_SIZE,))

    # forward pass
    batch = X_train[ix]
    logits = model(batch)

    # F.cross_entropy fuses softmax and negative log-likelihood: faster and more
    # numerically stable than computing them separately.
    loss = F.cross_entropy(logits, y_train[ix])

    # backward pass
    model.zero_grad()
    loss.backward()

    # plain SGD update with a two-stage (step) learning rate
    lr = EPSILON[0] if k < EPSILON_CUTOFF else EPSILON[1]
    # Update parameters in place outside of autograd. torch.no_grad() is the
    # supported way to do this; mutating `.data` bypasses autograd's safety checks.
    with torch.no_grad():
        for p in model.parameters():
            p -= lr * p.grad

    if k % 10000 == 0:  # print every once in a while
        print(f'{k:7d}/{EPOCHS:7d}: {loss.item():.4f}')

    stepi.append(k)
    lossi.append(loss.log10().item())
    lri.append(lr)

print(f"\nLast batch loss: {loss.item():.5f}")

In [None]:
# lossi holds the log10 of each step's minibatch loss; label the axes to say so.
_ = plt.plot(stepi, lossi)
_ = plt.xlabel("step")
_ = plt.ylabel("log10(minibatch loss)")

## Loss on the training and validation sets

In [None]:
from typing import Literal

@torch.no_grad()
def evaluate_loss(X, y, net=None) -> float:
    """Return the mean cross-entropy loss of a model over the whole dataset ``(X, y)``.

    Runs under ``torch.no_grad()`` because evaluation needs no gradients.

    Args:
        X: Inputs accepted by the model (in this notebook, batches of context indices).
        y: Target class indices, one per example in ``X``.
        net: Model to evaluate. Defaults to the notebook-level ``model``.

    Returns:
        The loss as a Python float.
    """
    if net is None:
        net = model
    logits = net(X)
    return F.cross_entropy(logits, y).item()

# Compare the loss on the data the model was trained on with the held-out validation data.
train_loss = evaluate_loss(X_train, y_train)
validation_loss = evaluate_loss(X_validation, y_validation)
print(f"Train loss:\t\t{train_loss:.5f}")
print(f"Validation loss:\t{validation_loss:.5f}")

# Generate names

In [None]:
from makemore.utils import int_to_character

# A dedicated, fixed-seed generator makes sampling reproducible regardless of how
# many random numbers training drew from the global RNG.
g = torch.Generator().manual_seed(21474483647 + 10)

# Sampling needs no gradients; no_grad avoids building an autograd graph for every
# forward pass.
with torch.no_grad():
    for _ in range(20):

        out = []
        # Sliding window over the last CONTEXT_SIZE sampled indices, initially all 0.
        # Named `window` so the dataset tensor `context` defined earlier is not clobbered.
        window = [0] * CONTEXT_SIZE
        while True:
            logits = model(torch.tensor([window]))
            probs = F.softmax(logits, dim=1)
            ix = torch.multinomial(probs, num_samples=1, generator=g).item()
            window = window[1:] + [ix]
            out.append(ix)

            if ix == 0:  # index 0 terminates the name
                break

        # drop the terminating 0 before decoding to characters
        print("".join(int_to_character(i) for i in out[:-1]))