In [1]:
"""
Implements a simple n-gram language model in PyTorch.
Acts as the correctness reference for all the other versions.
"""
import math
import torch
import torch.nn as nn
from torch.nn import functional as F

In [2]:
from common import RNG

## Hyperparameters

In [16]:
# --- model shape ---
# number of previous tokens fed to the model; 3 inputs predicting the 4th
# token makes this a 4-gram model
context_length = 3
embedding_size = 64   # width of each token's embedding vector
hidden_size = 512     # width of the MLP hidden layer

# --- optimization ---
learning_rate = 1e-3  # starting (maximum) learning rate
batch_size = 64       # n-grams per batch
num_steps = 50_000    # total number of optimizer steps

### (Optional) Hyperparameter optimization

In [None]:
# pip install optunahub

In [19]:
import optuna

def t(train_model, evaluate_model, X_val, y_val):
    """
    Build an Optuna objective function for tuning the model hyperparameters.

    Args:
        train_model: callable taking (embedding_size, learning_rate, hidden_size)
            and returning a trained model.
        evaluate_model: callable taking (model, X_val, y_val) and returning the
            validation loss for that model.
        X_val, y_val: validation data, passed through unchanged to evaluate_model.

    Returns:
        objective(trial): samples one hyperparameter configuration from the
        trial, trains a model with it and returns its validation loss.
    """
    def objective(trial):
        # sample one candidate configuration (order of the suggest calls is kept
        # stable: embedding_size, learning_rate, hidden_size)
        embedding_size = trial.suggest_int('embedding_size', 50, 300)
        # suggest_loguniform is deprecated in Optuna; suggest_float(..., log=True)
        # is its documented replacement and samples the same log-uniform range
        learning_rate = trial.suggest_float('learning_rate', 1e-5, 1e-3, log=True)
        hidden_size = trial.suggest_int('hidden_size', 128, 512)

        # train with the sampled configuration and score it on the validation data
        model = train_model(embedding_size, learning_rate, hidden_size)
        return evaluate_model(model, X_val, y_val)

    return objective



## MLP

In [4]:
class MLP(nn.Module):
    """
    Feed-forward n-gram language model: each of the previous n tokens is looked
    up in an embedding table, the resulting vectors are concatenated, and an MLP
    maps the concatenation to next-token logits.

    Reference:
    Bengio et al. 2003 https://www.jmlr.org/papers/volume3/bengio03a/bengio03a.pdf
    """

    def __init__(self, vocab_size, context_length, embedding_size, hidden_size):
        super().__init__()
        # token embedding table: one embedding_size vector per vocabulary entry
        self.wte = nn.Embedding(vocab_size, embedding_size)
        # the MLP consumes all context embeddings concatenated side by side
        flat_width = context_length * embedding_size
        layers = [
            nn.Linear(flat_width, hidden_size),
            nn.Tanh(),
            # NOTE(review): GELU is applied directly on the Tanh output, i.e. two
            # stacked nonlinearities -- confirm this is intentional
            nn.GELU(),
            nn.Linear(hidden_size, vocab_size),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, idx, targets=None):
        """
        idx:     (B, T) integer tensor of input tokens
        targets: optional (B, ) integer tensor of target tokens
        Returns (logits, loss); loss is None when no targets are given.
        """
        # unpacking enforces that idx is 2-dimensional
        num_rows, _num_tokens = idx.size()
        # (B, T) -> (B, T, embedding_size) via the lookup table
        token_vectors = self.wte(idx)
        # concatenate the T vectors of each row -> (B, T * embedding_size)
        flat = token_vectors.view(num_rows, -1)
        logits = self.mlp(flat)
        if targets is None:
            return logits, None
        # targets supplied: also report the cross-entropy loss
        return logits, F.cross_entropy(logits, targets)

In [5]:
# -----------------------------------------------------------------------------
# simple DataLoader that iterates over all the n-grams

def dataloader(tokens, context_length, batch_size):
    """
    Infinite generator over batches of n-grams taken from `tokens`.

    A window of context_length + 1 tokens slides over `tokens` one position at
    a time; the first context_length tokens of a window are the inputs and the
    last one is the target. After the last full window the position wraps back
    to 0, so the generator never ends. A partially filled batch is carried
    across the wrap-around.

    Args:
        tokens: sequence of integer tokens (sliceable, supports len()).
        context_length: number of input tokens per example (T).
        batch_size: number of examples per emitted batch (B).

    Yields:
        (inputs, targets): torch Tensors of shape (B, T) and (B, ).

    Raises:
        ValueError: if batch_size < 1 (a batch would never fill, so the loop
            would spin forever) or if tokens has no full window, i.e.
            len(tokens) <= context_length (windows would be silently truncated).
            Because this is a generator, the checks run on the first next().
    """
    n = len(tokens)
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if n <= context_length:
        raise ValueError(
            f"need more than context_length={context_length} tokens "
            f"to form an n-gram, got {n}"
        )
    # number of valid window start positions: 0 .. n - context_length - 1
    num_windows = n - context_length
    inputs, targets = [], []
    pos = 0
    while True:
        # simple sliding window over the tokens, of size context_length + 1
        window = tokens[pos:pos + context_length + 1]
        inputs.append(window[:-1])
        targets.append(window[-1])
        # once we've collected a batch, emit it
        if len(inputs) == batch_size:
            yield (torch.tensor(inputs), torch.tensor(targets))
            inputs, targets = [], []
        # advance the position and wrap around after the last full window
        pos = (pos + 1) % num_windows

In [6]:
# -----------------------------------------------------------------------------
# evaluation function

def eval_split(model, tokens, max_batches=None):
    """
    Compute the mean loss of `model` over batches drawn from `tokens`.

    Reads the module-level `batch_size` and `context_length` and pulls batches
    from `dataloader`. The model is switched to eval mode and left there;
    callers that continue training must call model.train() afterwards.

    Args:
        model: callable as model(inputs, targets) -> (logits, loss).
        tokens: token sequence of the split to evaluate.
        max_batches: optional cap on the number of batches; by default
            len(tokens) // batch_size batches are evaluated.

    Returns:
        Mean of the per-batch losses as a Python float.

    Raises:
        ValueError: if there is no batch to evaluate (fewer than batch_size
            tokens, or max_batches < 1), instead of an opaque ZeroDivisionError.
    """
    num_batches = len(tokens) // batch_size
    if max_batches is not None:
        num_batches = min(num_batches, max_batches)
    if num_batches < 1:
        raise ValueError(
            f"nothing to evaluate: {len(tokens)} tokens, batch_size={batch_size}, "
            f"max_batches={max_batches}"
        )
    model.eval()
    data_iter = dataloader(tokens, context_length, batch_size)
    total_loss = 0.0
    # evaluation needs no gradients: skip building the autograd graph
    with torch.no_grad():
        for _ in range(num_batches):
            inputs, targets = next(data_iter)
            _, loss = model(inputs, targets)
            total_loss += loss.item()
    return total_loss / num_batches

## Training

In [7]:
random = RNG(1337)
# TODO: actually use this rng for the model initialization

# "train" the Tokenizer, so we're able to map between characters and tokens
# (read through a context manager so the file handle is closed right away)
with open('data/train.txt', 'r') as train_file:
    train_text = train_file.read()
# the tokenizer below assumes the text is only lowercase a-z plus newlines
assert all(c == '\n' or ('a' <= c <= 'z') for c in train_text), \
    "data/train.txt may only contain the characters a-z and newline"

In [9]:
# vocabulary: every distinct character of the training text, in sorted order
uchars = sorted(set(train_text))
uchars

['\n',
 'a',
 'b',
 'c',
 'd',
 'e',
 'f',
 'g',
 'h',
 'i',
 'j',
 'k',
 'l',
 'm',
 'n',
 'o',
 'p',
 'q',
 'r',
 's',
 't',
 'u',
 'v',
 'w',
 'x',
 'y',
 'z']

In [11]:
# vocabulary size = number of unique characters collected in uchars
vocab_size = len(uchars)
vocab_size

27

In [13]:
# -----------------------------------------------------------------------------
# let's train!

# lookup tables in both directions; a character's token id is its index in uchars
token_to_char = dict(enumerate(uchars))
char_to_token = {ch: tok for tok, ch in token_to_char.items()}
print(char_to_token)
print(token_to_char)

{'\n': 0, 'a': 1, 'b': 2, 'c': 3, 'd': 4, 'e': 5, 'f': 6, 'g': 7, 'h': 8, 'i': 9, 'j': 10, 'k': 11, 'l': 12, 'm': 13, 'n': 14, 'o': 15, 'p': 16, 'q': 17, 'r': 18, 's': 19, 't': 20, 'u': 21, 'v': 22, 'w': 23, 'x': 24, 'y': 25, 'z': 26}
{0: '\n', 1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i', 10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q', 18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y', 26: 'z'}


In [14]:
EOT_TOKEN = char_to_token['\n'] # designate \n as the delimiting <|endoftext|> token


def _load_tokens(path):
    """Read the text file at `path` and return its characters as a list of token ids."""
    # the context manager closes the file handle once the text is read;
    # a character that is not a key of char_to_token raises KeyError
    with open(path, 'r') as split_file:
        return [char_to_token[c] for c in split_file.read()]


# pre-tokenize all the splits one time up here
test_tokens = _load_tokens('data/test.txt')
val_tokens = _load_tokens('data/val.txt')
train_tokens = _load_tokens('data/train.txt')

In [15]:

# instantiate the n-gram MLP from the notebook's hyperparameters
model = MLP(
    vocab_size=vocab_size,
    context_length=context_length,
    embedding_size=embedding_size,
    hidden_size=hidden_size,
)

# AdamW over all model parameters, with a small weight decay
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=learning_rate,
    weight_decay=1e-4,
)

In [17]:

# training loop: num_steps optimizer steps over batches from the train split
# num_epochs is approximated as examples consumed (num_steps * batch_size)
# divided by the number of training tokens
print(f'num_steps {num_steps}, num_epochs {num_steps * batch_size / len(train_tokens):.2f}')
# the dataloader is an endless generator, so one iterator serves the whole run
train_data_iter = dataloader(train_tokens, context_length, batch_size)
for step in range(num_steps):
    # cosine learning rate schedule: starts at learning_rate (step 0) and
    # decays towards 0 as step approaches num_steps
    lr = learning_rate * 0.5 * (1 + math.cos(math.pi * step / num_steps))
    # write the scheduled rate into every optimizer parameter group
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    # every 200 steps, and on the final step, report train and validation loss
    last_step = step == num_steps - 1
    if step % 200 == 0 or last_step:
        # train loss is estimated on at most 20 batches; the validation loss
        # uses eval_split's default batch count for the val split
        train_loss = eval_split(model, train_tokens, max_batches=20)
        val_loss = eval_split(model, val_tokens)
        print(f'step {step} | train_loss {train_loss:.4f} | val_loss {val_loss:.4f} | lr {lr:e}')
    # ensure the model is in training mode (eval_split switches it to eval mode)
    model.train()
    # get the next batch of training data
    inputs, targets = next(train_data_iter)
    # forward through the model; loss is the cross-entropy against targets
    logits, loss = model(inputs, targets)
    # backpropagate to compute the gradients
    loss.backward()
    # apply the update, then clear the gradients for the next iteration
    optimizer.step()
    optimizer.zero_grad()


num_steps 50000, num_epochs 14.97
step 0 | train_loss 3.3127 | val_loss 3.3217 | lr 1.000000e-03
step 200 | train_loss 2.4165 | val_loss 2.4280 | lr 9.999605e-04
step 400 | train_loss 2.3531 | val_loss 2.3637 | lr 9.998421e-04
step 600 | train_loss 2.3161 | val_loss 2.3258 | lr 9.996447e-04
step 800 | train_loss 2.3040 | val_loss 2.3013 | lr 9.993685e-04
step 1000 | train_loss 2.2985 | val_loss 2.2921 | lr 9.990134e-04
step 1200 | train_loss 2.2708 | val_loss 2.2793 | lr 9.985795e-04
step 1400 | train_loss 2.2513 | val_loss 2.2580 | lr 9.980668e-04
step 1600 | train_loss 2.2518 | val_loss 2.2558 | lr 9.974755e-04
step 1800 | train_loss 2.2489 | val_loss 2.2439 | lr 9.968057e-04
step 2000 | train_loss 2.2324 | val_loss 2.2297 | lr 9.960574e-04
step 2200 | train_loss 2.2277 | val_loss 2.2293 | lr 9.952307e-04
step 2400 | train_loss 2.2246 | val_loss 2.2174 | lr 9.943259e-04
step 2600 | train_loss 2.2141 | val_loss 2.2127 | lr 9.933430e-04
step 2800 | train_loss 2.2119 | val_loss 2.2143 |