# Configuration

In [1]:
# Notebook-wide settings: every tunable value lives in this single cell.

# --- Tokenizer and data preparation ---
VOCAB_SIZE = 1024        # vocabulary size requested from the BPE trainer
CONTEXT_LENGTH = 512     # fixed number of tokens in every stored chunk
SAMPLE_LIMIT = 100000    # number of stories to keep; None processes the entire dataset
DICT_LABEL = 'seq'       # dictionary key under which each chunk is stored in the data pipeline

# --- Model architecture ---
EMBEDDING_DIM = 256      # width of the token embeddings
NUM_LAYERS = 2           # number of transformer layers
NUM_HEADS = 8            # attention heads per layer
QK_HEAD_DIM = 16         # per-head dimension of queries and keys
V_HEAD_DIM = 32          # per-head dimension of values
MLP_DIM = 512            # hidden width of the feed-forward layers

# --- Optimisation ---
BATCH_SIZE = 128         # sequences per training batch
EPOCHS = 20              # number of passes over the chunked data

# --- File locations (suffixes are appended when the paths are built) ---
TOKENIZER_FILE = "./data/tinystories-tokenizer"
CHUNK_FILE = "./data/chunked_stories"

In [2]:
%load_ext autoreload
%autoreload 2

import os
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import mlx.data as dx

from datasets import load_dataset

from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Metaspace
from tokenizers.decoders import Metaspace as MetaspaceDecoder

from llm.modules import SmallLanguageModel, loss_fn, count_parameters, generate_story
from llm.data import batch_iterator, chunk_story, data_to_array_of_dict
from llm.gpt import GPT

# Dataset

In [3]:
# Fetch the TinyStories training split; when SAMPLE_LIMIT is truthy only the
# first SAMPLE_LIMIT stories (capped at the dataset length) are kept.
dataset = load_dataset("roneneldan/TinyStories", split="train")
if SAMPLE_LIMIT:
    n_keep = min(SAMPLE_LIMIT, len(dataset))
    dataset = dataset.select(range(n_keep))

# The validation split is loaded in full.
validation = load_dataset("roneneldan/TinyStories", split="validation")

print(f"Dataset size: {len(dataset)}")

Dataset size: 100000


# Tokenizer

In [4]:
# Train a BPE tokenizer once and cache it on disk; later runs reuse the cached file.
tokenizer_path = f'{TOKENIZER_FILE}_{VOCAB_SIZE}.json'
if os.path.exists(tokenizer_path):
    print(f"Tokenizer file {tokenizer_path} already exists. Skipping training.")
else:
    # Create the output directory before training so that a missing folder
    # cannot make the save fail after the whole training pass has been paid for.
    os.makedirs(os.path.dirname(tokenizer_path) or '.', exist_ok=True)

    # Initialize a BPE tokenizer
    tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
    tokenizer.pre_tokenizer = Metaspace(replacement=" ")
    tokenizer.decoder = MetaspaceDecoder(replacement=" ")

    # Configure the trainer with a vocabulary size and special tokens.
    # Note that "\n" is registered as a special token alongside the markers.
    trainer = BpeTrainer(vocab_size=VOCAB_SIZE, special_tokens=["[UNK]", "[PAD]", "[SOS]", "[EOS]", "\n"])

    # Train the tokenizer directly from the in-memory dataset
    print("Training tokenizer...")
    tokenizer.train_from_iterator(batch_iterator(dataset), trainer=trainer, length=len(dataset))
    print("Training complete.")

    # --- Save the Tokenizer ---
    tokenizer.save(tokenizer_path)
    print(f"Tokenizer saved to {tokenizer_path}")

# Always reload from disk so both branches end up with the same tokenizer object.
tokenizer = Tokenizer.from_file(tokenizer_path)
encoded = tokenizer.encode("Once upon a time, there was a little fox.\nIt lived in a forest and loved to explore.")

# --- Test the Tokenizer ---
# Because "\n" is a special token, decoding with skip_special_tokens=True
# drops the newline from the round-tripped text.
print("\n--- Testing the Tokenizer ---")
print("Tokens:", encoded.tokens)
print("IDs:", encoded.ids)
print("Decoded:", tokenizer.decode(encoded.ids, skip_special_tokens=True))

Tokenizer file ./data/tinystories-tokenizer_1024.json already exists. Skipping training.

--- Testing the Tokenizer ---
Tokens: [' Once', ' upon', ' a', ' time,', ' there', ' was', ' a', ' little', ' fo', 'x', '.', '\n', ' It', ' lived', ' in', ' a', ' forest', ' and', ' loved', ' to', ' expl', 'ore.']
IDs: [286, 302, 116, 337, 257, 137, 116, 256, 683, 86, 19, 4, 269, 794, 176, 116, 966, 122, 367, 123, 631, 771]
Decoded: Once upon a time, there was a little fox. It lived in a forest and loved to explore.


# Chunking

In [5]:
# Tokenise the stories into fixed-length chunks once and cache them as a compressed .npz.
chunk_path = f'{CHUNK_FILE}_{VOCAB_SIZE}_{CONTEXT_LENGTH}.npz'
if not os.path.exists(chunk_path):
    # Load the tokenizer
    tokenizer = Tokenizer.from_file(f'{TOKENIZER_FILE}_{VOCAB_SIZE}.json')

    # Walk over every story; chunk_story hands back the finished chunks, the
    # carry-over that is passed into the next call, and a token count.
    all_chunks = []
    unfinished_chunk = []
    num_non_special_tokens = 0
    for story in tqdm(dataset["text"], desc="Chunking stories"):
        story_chunks, unfinished_chunk, count = chunk_story(
            story, tokenizer, '[SOS]', '[EOS]', CONTEXT_LENGTH, unfinished_chunk
        )
        all_chunks += story_chunks
        num_non_special_tokens += count

    # A single int32 array is compact to store and quick to reload
    chunks_array = np.array(all_chunks, dtype=np.int32)

    # Summary statistics
    print(f"Created {len(all_chunks)} chunks of length {CONTEXT_LENGTH}")
    print(f"Total non-special tokens: {num_non_special_tokens}")
    print(f"Array shape: {chunks_array.shape}")

    # Persist the chunks and report the size of the file on disk
    print(f"Saving chunks to {chunk_path}...")
    np.savez_compressed(chunk_path, chunks=chunks_array)
    size_mb = os.path.getsize(chunk_path) / (1024 * 1024)
    print(f"Saved successfully! File size: {size_mb:.2f} MB")
else:
    print(f"Chunk file {chunk_path} already exists. Skipping chunking.")

Chunk file ./data/chunked_stories_1024_512.npz already exists. Skipping chunking.


# Data Pipeline

In [6]:
# Load the cached chunks and wrap each one in a {DICT_LABEL: array} sample dict.
data = np.load(f'{CHUNK_FILE}_{VOCAB_SIZE}_{CONTEXT_LENGTH}.npz')
dicts = data_to_array_of_dict(data['chunks'], name=DICT_LABEL)

# Sanity-check the structure the mlx.data buffer is built from.
assert isinstance(dicts, list)
assert isinstance(dicts[0], dict)
assert isinstance(dicts[0][DICT_LABEL], np.ndarray)

buffer = dx.buffer_from_vector(dicts)
# Shuffle individual samples BEFORE batching. Batching first only permutes
# whole batches, so every batch would always contain the same sequences.
# The shuffle buffer holds up to 1000 batches' worth of samples.
stream = (
    buffer.to_stream()
    .shuffle(buffer_size=BATCH_SIZE * 1000)
    .batch(BATCH_SIZE)
    .prefetch(8, 4)
)

In [7]:
# Peek at a single batch to confirm its shape and container type.
for x in stream:
    print(x[DICT_LABEL].shape)
    print(type(x[DICT_LABEL]))
    break  # Just to test the first batch

# Rewind the shared stream: the peek above consumed a batch, and without a
# reset the first training epoch would start part-way through the data.
stream.reset()

(128, 512)
<class 'numpy.ndarray'>


# Model

In [21]:
# Build the GPT model with its default constructor arguments and report its size.
model = GPT()
gpt_param_count = count_parameters(model.parameters())
print(f"Number of parameters in the model: {gpt_param_count:,}")

Number of parameters in the model: 15,499,264


In [31]:
# Build the small model from the configuration constants (this rebinds `model`).
model = SmallLanguageModel(
    vocab_dim=VOCAB_SIZE,
    embed_dim=EMBEDDING_DIM,
    n_head=NUM_HEADS,
    num_layers=NUM_LAYERS,
    qk_head_dim=QK_HEAD_DIM,
    v_head_dim=V_HEAD_DIM,
    mlp_dim=MLP_DIM,
    max_len=CONTEXT_LENGTH,
)
# Report how many parameters the model has
slm_param_count = count_parameters(model.parameters())
print(f"Number of parameters in the model: {slm_param_count:,}")

Number of parameters in the model: 1,576,960


In [9]:
# Look in ./data for a checkpoint saved with the current vocab size and context
# length. Checkpoint names end in "_<epoch>_<batch>.npz"; at most one may exist.
# When one is found its weights are loaded and the epoch/batch are recorded.

weight_pattern = f'model_weights_{VOCAB_SIZE}_{CONTEXT_LENGTH}_*.npz'
matching_paths = list(Path('./data').glob(weight_pattern))

if len(matching_paths) > 1:
    raise ValueError(f"Multiple model weight files found for vocab size {VOCAB_SIZE} and context length {CONTEXT_LENGTH}. Please ensure only one exists.")

if matching_paths:
    path = matching_paths[0]
    print(f"Found existing model weights: {path.name}")
    # Restore the weights into the already-constructed model
    model.load_weights(str(path))
    # The last two underscore-separated fields of the stem are epoch and batch
    *_, epoch_field, batch_field = path.stem.split('_')
    last_epoch = int(epoch_field)
    last_batch = int(batch_field)
    print(f"Loaded model weights from epoch {last_epoch}, batch {last_batch}.")
else:
    print("No existing model weights found. Starting training from scratch.")
    last_epoch = None
    last_batch = None

No existing model weights found. Starting training from scratch.


# Training

In [32]:
# Reload the tokenizer and look up the ids of the special marker tokens.
tokenizer = Tokenizer.from_file(f'{TOKENIZER_FILE}_{VOCAB_SIZE}.json')
sos_token_id, eos_token_id, pad_token_id = (
    tokenizer.token_to_id(marker) for marker in ('[SOS]', '[EOS]', '[PAD]')
)

# AdamW optimizer with explicit betas and weight decay
optimizer = optim.AdamW(
    learning_rate=0.001,
    betas=[0.9, 0.99],
    weight_decay=0.01,
)

# Callable that returns both the loss value and the gradients for `model`
loss_and_grad = nn.value_and_grad(model, loss_fn)

print(f'MLX current default device: {mx.default_device()}')

MLX current default device: Device(gpu, 0)


In [26]:
def loss_fn(model, x, y):
    """Mean cross-entropy between the model's predictions for `x` and targets `y`.

    The model output is unpacked as three dimensions; the first two are merged
    so that every position is scored as one independent classification, and
    `y` is flattened to the matching length.

    Args:
        model: callable mapping `x` to a 3-D array of logits.
        x: input passed straight to `model`.
        y: targets; reshaped to one entry per (first-dim, second-dim) position.

    Returns:
        The value of `nn.losses.cross_entropy` with `reduction='mean'`.
    """
    logits = model(x)
    n_batch, n_steps, n_classes = logits.shape
    n_positions = n_batch * n_steps
    flat_logits = logits.reshape(n_positions, n_classes)
    flat_targets = y.reshape(n_positions)
    return nn.losses.cross_entropy(flat_logits, flat_targets, reduction='mean')

# Leftover toggle for switching the model under training to GPT (disabled):
# model = GPT()
# Evaluate the model parameters now (MLX arrays are computed lazily — presumably
# this materialises freshly initialised or loaded weights; TODO confirm).
mx.eval(model.parameters())
# Rebind loss_and_grad so it uses the loss_fn defined in this cell.
loss_and_grad = nn.value_and_grad(model, loss_fn)
# NOTE(review): this rebinds `optimizer` without the explicit betas/weight_decay
# passed in the training-setup cell; when cells run top to bottom this one wins.
# Confirm which optimizer configuration is intended.
optimizer = optim.AdamW(learning_rate=0.001)

In [33]:
# Training loop.
# Reads from earlier cells: model, stream, tokenizer, optimizer, loss_and_grad,
# eos_token_id, and last_epoch / last_batch (None when starting from scratch).
# Every `save_freq` batches a story is sampled and a checkpoint named
# "model_weights_<vocab>_<context>_<epoch>_<batch>.npz" is written to ./data.
save_freq = 200
ckpt_dir = Path('./data')
ckpt_stem = f'model_weights_{VOCAB_SIZE}_{CONTEXT_LENGTH}'

for epoch in range(EPOCHS):
    # Epochs that finished before the loaded checkpoint are skipped entirely.
    if last_epoch and epoch + 1 < last_epoch:
        continue

    # Batches are skipped only inside the epoch the checkpoint was taken in;
    # every later epoch must see the full data again.
    resuming_here = bool(last_epoch) and epoch + 1 == last_epoch
    skip_until = last_batch if (resuming_here and last_batch) else 0

    # Rewind at the START of each epoch so that a previously interrupted run
    # (or any cell that peeked at the stream) cannot leave it mid-pass.
    stream.reset()

    losses = []
    for i, seq in enumerate(stream):
        if i + 1 <= skip_until:
            continue
        mx_seq = mx.array(seq[DICT_LABEL])
        input_seq = mx_seq[:, :-1]  # Exclude the last token for input
        target_seq = mx_seq[:, 1:]  # Exclude the first token for target
        loss, grads = loss_and_grad(model, input_seq, target_seq)
        optimizer.update(model, grads)
        mx.eval(model.parameters(), optimizer.state)

        # Keep a plain Python float so the history does not hold on to arrays.
        loss_value = loss.item()
        losses.append(loss_value)
        print(f"Batch {i + 1}, Loss: {loss_value:.4f}")

        if (i + 1) % save_freq == 0:
            generate_story(model, tokenizer, "[SOS]", max_length=CONTEXT_LENGTH, eos_token_id=eos_token_id, temp=1.0)
            ckpt_path = ckpt_dir / f'{ckpt_stem}_{epoch + 1}_{i + 1}.npz'
            model.save_weights(str(ckpt_path))
            # Keep exactly one checkpoint on disk (the resume cell refuses to
            # choose between several), including ones left by earlier epochs.
            for stale in ckpt_dir.glob(f'{ckpt_stem}_*.npz'):
                if stale.name != ckpt_path.name:
                    stale.unlink()
            print('-'*20)
            print(f"Active memory: {mx.get_active_memory() / 1024**3:.2f} GB")
            print(f"Cache memory: {mx.get_cache_memory() / 1024**3:.2f} GB")
            print(f"Peak memory: {mx.get_peak_memory() / 1024**3:.2f} GB")
            mx.clear_cache()
            print('-'*20)

    # An epoch in which every batch was skipped has no losses to average.
    avg_loss = sum(losses) / len(losses) if losses else float('nan')
    print(f"Epoch {epoch + 1}, Average Loss: {avg_loss:.4f}")
    generate_story(model, tokenizer, "[SOS]", max_length=CONTEXT_LENGTH, eos_token_id=eos_token_id, temp=1.0)

# Final weights go to a separate file name that the checkpoint glob does not match.
model.save_weights(f'./data/trained_model_weights_{VOCAB_SIZE}_{CONTEXT_LENGTH}_{EPOCHS}.npz')

Batch 1, Loss: 7.0772
Batch 2, Loss: 6.6403
Batch 3, Loss: 6.3506
Batch 4, Loss: 6.2419
Batch 5, Loss: 6.1409
Batch 6, Loss: 6.0743
Batch 7, Loss: 6.0181
Batch 8, Loss: 6.0436
Batch 9, Loss: 6.0351
Batch 10, Loss: 6.0835
Batch 11, Loss: 6.0853
Batch 12, Loss: 6.0115
Batch 13, Loss: 6.0087
Batch 14, Loss: 5.9957
Batch 15, Loss: 6.0321
Batch 16, Loss: 5.9987
Batch 17, Loss: 5.9565
Batch 18, Loss: 6.0541
Batch 19, Loss: 6.0295
Batch 20, Loss: 6.0813
Batch 21, Loss: 6.0128
Batch 22, Loss: 6.0534
Batch 23, Loss: 6.0104
Batch 24, Loss: 6.0066
Batch 25, Loss: 6.0257
Batch 26, Loss: 5.9715
Batch 27, Loss: 6.0085
Batch 28, Loss: 6.0130
Batch 29, Loss: 5.9657
Batch 30, Loss: 6.0174
Batch 31, Loss: 6.0384
Batch 32, Loss: 5.9279
Batch 33, Loss: 5.9858
Batch 34, Loss: 5.9576
Batch 35, Loss: 5.9803
Batch 36, Loss: 6.0071
Batch 37, Loss: 6.0454
Batch 38, Loss: 5.9718
Batch 39, Loss: 5.9523
Batch 40, Loss: 5.9742
Batch 41, Loss: 5.9748
Batch 42, Loss: 5.9606
Batch 43, Loss: 5.9872
Batch 44, Loss: 5.97

KeyboardInterrupt: 