# Configuration

In [None]:
# ---------------------------------------------------------------------------
# Model architecture
# ---------------------------------------------------------------------------
CONTEXT_LENGTH = 512   # Length of every training chunk, in tokens (fixed)
EMBEDDING_DIM = 256    # Size of each token embedding vector
NUM_HEADS = 8          # Attention heads per layer
NUM_LAYERS = 3         # Transformer layers stacked in the model
QK_HEAD_DIM = 16       # Per-head size of the query/key projections
V_HEAD_DIM = 32        # Per-head size of the value projection
MLP_DIM = 1024         # Hidden size of the transformer's MLP layers
DROPOUT_RATE = 0.1     # Dropout probability used for regularization

# ---------------------------------------------------------------------------
# Data preparation
# ---------------------------------------------------------------------------
VOCAB_SIZE = 512       # Number of entries in the tokenizer vocabulary
PADDING = True         # Pad sequences when building chunks
PACKING = True         # Pack several stories into one training chunk

# ---------------------------------------------------------------------------
# Optimisation
# ---------------------------------------------------------------------------
SEED = 42              # Seed for reproducible runs
BATCH_SIZE = 128       # Sequences per training batch
EPOCHS = 10            # Passes over the training data
SAMPLE_LIMIT = 200_000 # Number of stories to use; None processes the entire dataset
LR = 1e-3              # Optimizer learning rate
WEIGHT_DECAY = 1e-2    # Optimizer weight decay
BETA1 = 0.9            # Adam beta1
BETA2 = 0.999          # Adam beta2

# ---------------------------------------------------------------------------
# Paths and labels
# ---------------------------------------------------------------------------
TOKENIZER_FILE = "./data/tinystories-tokenizer"
CHUNK_FILE = "./data/chunked_stories"
# None starts a fresh, timestamped run directory; to continue an earlier run,
# set this to that run's directory instead, e.g.
# LOG_DIR = './runs/2025-08-26_17-09-10'
LOG_DIR = None
DICT_LABEL = 'seq'     # Key under which each sample's token array is stored

In [None]:
%load_ext autoreload
%autoreload 2

import os
from datetime import datetime
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import mlx.data as dx
from mlx.utils import tree_flatten, tree_unflatten

from datasets import load_dataset
from tokenizers import Tokenizer

from tensorboardX import SummaryWriter

from models.mlx import SmallLanguageModel, loss_fn, count_parameters, generate_story
from data.utils import train_tokenizer, chunk_story, data_to_array_of_dict, create_dict_parameters, encode_story, pack_stories, pretty_json

In [None]:
# Snapshot the notebook namespace into a parameter dict *before* LOG_DIR is
# resolved below, so the snapshot reflects the values set in the config cell.
params = create_dict_parameters(locals())

# No explicit run directory configured: create a new one named after the current time.
if LOG_DIR is None:
    LOG_DIR = f'runs/{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}'
writer = SummaryWriter(log_dir=LOG_DIR)

# Only write the parameter text when the run directory holds exactly one
# TensorBoard event file.
event_files = list(Path(LOG_DIR).glob('events.out.tfevents.*'))
if len(event_files) == 1:
    print("Logging parameters")
    writer.add_text('Parameters', pretty_json(params))

# Tokenizer

In [None]:
# The tokenizer is cached on disk, keyed by vocabulary size and sample limit.
# Train and save it only when no cached file exists for this configuration.
tokenizer_path = f'{TOKENIZER_FILE}_{VOCAB_SIZE}_{SAMPLE_LIMIT}.json'
if not os.path.exists(tokenizer_path):
    dataset = load_dataset("roneneldan/TinyStories", split="train")
    if SAMPLE_LIMIT:
        # Keep at most SAMPLE_LIMIT stories (fewer if the split is smaller).
        dataset = dataset.select(range(min(SAMPLE_LIMIT, len(dataset))))
    tokenizer = train_tokenizer(
        dataset,
        vocab_size=VOCAB_SIZE,
        special_tokens=["[UNK]", "[PAD]", "[SOS]", "[EOS]", "\n"],
    )
    tokenizer.save(tokenizer_path)
    print(f"Tokenizer saved to {tokenizer_path}")
else:
    tokenizer = Tokenizer.from_file(tokenizer_path)
    print(f"Tokenizer file {tokenizer_path} already exists. Skipping training.")

# Ids of the special tokens that later cells pass to the loss and to generation.
sos_token_id, eos_token_id, pad_token_id = (
    tokenizer.token_to_id(special) for special in ('[SOS]', '[EOS]', '[PAD]')
)

# Round-trip a short two-line sample through the saved tokenizer file as a smoke test.
tokenizer = Tokenizer.from_file(tokenizer_path)
encoded = tokenizer.encode(
    "Once upon a time, there was a little fox.\nIt lived in a forest and loved to explore."
)

print("\n--- Testing the Tokenizer ---")
print("Tokens:", encoded.tokens)
print("IDs:", encoded.ids)
print("Decoded:", tokenizer.decode(encoded.ids, skip_special_tokens=True))

# Chunking

In [None]:
# Build every artefact name from one shared tag so the chunk file and its
# statistics file always carry the same "_padding" / "_packing" suffix.
variant_suffix = ('_padding' if PADDING else '') + ('_packing' if PACKING else '')
run_tag = f'{VOCAB_SIZE}_{CONTEXT_LENGTH}_{SAMPLE_LIMIT}{variant_suffix}'
chunk_file_path = f'{CHUNK_FILE}_{run_tag}.npz'
text_info_path = f'./data/chunk_info_{run_tag}.txt'
figure_path = f'./figures/histogram_{VOCAB_SIZE}_{SAMPLE_LIMIT}.png'

if os.path.exists(chunk_file_path):
    print(f"Chunk file {chunk_file_path} already exists. Skipping chunking.")

    # Display the existing histogram. The figure is optional: a cached chunk
    # file without its picture must not stop the notebook.
    if os.path.exists(figure_path):
        plt.imshow(plt.imread(figure_path))
        plt.axis('off')
    else:
        print(f"Histogram {figure_path} not found. Skipping display.")
else:
    # Load the dataset (reuse the one loaded by the tokenizer cell if present)
    if 'dataset' not in locals():
        dataset = load_dataset("roneneldan/TinyStories", split="train")
        if SAMPLE_LIMIT:
            dataset = dataset.select(range(min(SAMPLE_LIMIT, len(dataset))))

    # Load the tokenizer
    tokenizer = Tokenizer.from_file(tokenizer_path)

    # Process all stories and collect chunks
    all_chunks = []
    unfinished_chunk = []  # carried from one chunk_story call to the next (non-packing path)
    num_non_special_tokens = []  # one entry per story, used for the histogram
    for story in tqdm(dataset["text"], desc="Chunking stories"):
        if PACKING:
            # Packing: keep each encoded story whole; pack_stories groups them afterwards.
            story_chunks, non_special_token_count = encode_story(story, tokenizer, '[SOS]', '[EOS]')
            all_chunks.append(story_chunks)
        else:
            story_chunks, unfinished_chunk, non_special_token_count = chunk_story(
                story, tokenizer, '[SOS]', '[EOS]', CONTEXT_LENGTH,
                unfinished_chunk=unfinished_chunk, padding=PADDING, pad_token='[PAD]')
            all_chunks.extend(story_chunks)
        num_non_special_tokens.append(non_special_token_count)

    # Convert list to numpy array for efficient storage
    if PACKING:
        chunks_array = np.array(pack_stories(all_chunks, CONTEXT_LENGTH, tokenizer.token_to_id('[PAD]')), dtype=np.int32)
    else:
        chunks_array = np.array(all_chunks, dtype=np.int32)
    unique_tokens, counts = np.unique(chunks_array, return_counts=True)

    # Count special tokens by *id*, not by position in `unique_tokens`:
    # np.unique only lists ids that actually occur, so the first three entries
    # are not guaranteed to be [PAD]/[SOS]/[EOS] (e.g. when [UNK] occurs or no
    # padding was produced).
    special_ids = [tokenizer.token_to_id(tok) for tok in ('[PAD]', '[SOS]', '[EOS]')]
    special_ids = [tok_id for tok_id in special_ids if tok_id is not None]
    is_special = np.isin(unique_tokens, special_ids)
    special_token_total = int(counts[is_special].sum())
    non_special_token_total = int(counts[~is_special].sum())
    total_tokens = CONTEXT_LENGTH * chunks_array.shape[0]

    # Print statistics
    print(f"Total tokens: {total_tokens:,}")
    print(f"Total non-special tokens: {non_special_token_total:,}")
    print(f"Number of special tokens: {special_token_total:,}")
    print(f"Array shape: {chunks_array.shape}")

    # Save the chunks to a compressed file
    print(f"Saving chunks to {chunk_file_path}...")
    np.savez_compressed(chunk_file_path, chunks=chunks_array)
    print(f"Saved successfully! File size: {os.path.getsize(chunk_file_path) / (1024 * 1024):.2f} MB")

    plt.hist(num_non_special_tokens, bins=50, color='blue')
    plt.title("Distribution of Story Lengths")
    plt.xlabel("Length (number of tokens)")
    plt.ylabel("Frequency")
    # Make sure the output folder exists so the figure can actually be written.
    os.makedirs(os.path.dirname(figure_path), exist_ok=True)
    plt.savefig(figure_path)

    # SAMPLE_LIMIT may be None (whole dataset); the ',' format spec only works on numbers.
    sample_limit_text = f"{SAMPLE_LIMIT:,}" if SAMPLE_LIMIT is not None else "None"
    with open(text_info_path, 'w') as f:
        f.write(f"Sample limit: {sample_limit_text}\n")
        f.write(f"Vocabulary Size: {VOCAB_SIZE:,}\n")
        f.write(f"Context length: {CONTEXT_LENGTH:,}\n")
        f.write(f"Number of chunks: {chunks_array.shape[0]:,}\n")
        f.write(f"Number of tokens: {total_tokens:,}\n")
        f.write(f"Number of non-special tokens: {non_special_token_total:,}\n")
        f.write(f"Number of special tokens: {special_token_total:,}\n")
        f.write(f"Padding used: {PADDING}\n")
        f.write(f"Packing used: {PACKING}\n")

# Data Pipeline

In [None]:
# Seed both random number generators used in this notebook.
mx.random.seed(SEED)
np.random.seed(SEED)

# Load the saved chunks and wrap each one in a dict keyed by DICT_LABEL.
data = np.load(chunk_file_path)
dicts = data_to_array_of_dict(data['chunks'], name=DICT_LABEL)

# Sanity-check the structure handed to mlx-data: a list of dicts of numpy arrays.
assert type(dicts) is list
assert type(dicts[0]) is dict
assert type(dicts[0][DICT_LABEL]) is np.ndarray

buffer = dx.buffer_from_vector(dicts)
# For mlx-data 0.0.2 the seed only works with 1 thread
stream = (
    buffer.to_stream()
    .batch(BATCH_SIZE)
    .shuffle(buffer_size=BATCH_SIZE * 100)
    .prefetch(8, 1)
)
num_batches = len(dicts) // BATCH_SIZE

In [None]:
# Peek at the first batch only: print its shape and type, then decode its
# first and last sequence (special tokens kept) in lines of 30 words.
for batch in stream:
    print(batch[DICT_LABEL].shape)
    print(type(batch[DICT_LABEL]))
    for row in (0, BATCH_SIZE - 1):
        words = tokenizer.decode(batch[DICT_LABEL][row], skip_special_tokens=False).split(' ')
        for start in range(0, len(words), 30):
            print(' '.join(words[start:start + 30]))
    break  # Just to test the first batch
# Rewind the stream so the cells that follow start from the beginning again.
stream.reset()

# Model

In [None]:
# Instantiate the language model from the values set in the configuration cell.
model = SmallLanguageModel(
    vocab_dim=VOCAB_SIZE,
    embed_dim=EMBEDDING_DIM,
    n_head=NUM_HEADS,
    num_layers=NUM_LAYERS,
    qk_head_dim=QK_HEAD_DIM,
    v_head_dim=V_HEAD_DIM,
    mlp_dim=MLP_DIM,
    max_len=CONTEXT_LENGTH,
)

# Report the parameter count with thousands separators.
num_parameters = count_parameters(model.parameters())
print(f"Number of parameters in the model: {num_parameters:,}")

In [None]:
# Resume support: look in LOG_DIR for saved model weights with the same vocab
# size and context length. Two filename layouts are written by the training loop:
#   model_weights_<vocab>_<context>_<epoch>.npz          -> saved at the end of <epoch>
#   model_weights_<vocab>_<context>_<epoch>_<batch>.npz  -> saved mid-epoch after <batch>
# `last_epoch` / `last_batch` (both 1-based) tell the training loop what to skip;
# both are None when training starts from scratch.

matching_paths = list(Path(LOG_DIR).glob(f'model_weights_{VOCAB_SIZE}_{CONTEXT_LENGTH}_*.npz'))
if len(matching_paths) == 0:
    print("No existing model weights found. Starting training from scratch.")
    last_epoch = None
    last_batch = None
elif len(matching_paths) > 1:
    raise ValueError(f"Multiple model weight files found for vocab size {VOCAB_SIZE} and context length {CONTEXT_LENGTH}. Please ensure only one exists.")
else:
    path = matching_paths[0]
    print(f"Found existing model weights: {path.name}")
    # Load the model weights
    model.load_weights(str(path))
    # Extract epoch (and optionally batch) number from the filename
    weight_name = path.stem.split('_')
    if len(weight_name) == 5:
        # End-of-epoch checkpoint: that epoch is complete, start from the next one.
        last_epoch = int(weight_name[-1]) + 1
        last_batch = None
    elif len(weight_name) == 6:
        # Mid-epoch checkpoint: re-enter the same epoch after the saved batch.
        last_epoch = int(weight_name[-2])
        last_batch = int(weight_name[-1])
    else:
        raise ValueError(f"Unexpected filename format: {path.name}")
    print(f"Loaded model weights from epoch {last_epoch}, batch {last_batch}.")

optimizer = optim.AdamW(learning_rate=LR, betas=[BETA1, BETA2], weight_decay=WEIGHT_DECAY)
# The optimizer state is stored next to the weights in LOG_DIR; load it from
# that same path (not from the current working directory).
optimizer_state_path = f'{LOG_DIR}/optimizer.safetensors'
if os.path.exists(optimizer_state_path):
    print(f"Loading optimizer state from {optimizer_state_path}")
    # mx.load returns a flat {name: array} mapping; rebuild the nested state
    # tree with mlx.utils.tree_unflatten (mlx.core itself has no `utils`).
    state = tree_unflatten(list(mx.load(optimizer_state_path).items()))
    optimizer.state = state

# Training

In [None]:
# Show which device MLX will run on.
print(f'MLX current default device: {mx.default_device()}')

# Wrap the loss so one call yields both the loss value and the gradients
# (unpacked as `loss, grads` in the training loop).
loss_and_grad = nn.value_and_grad(model, loss_fn)

In [None]:
# Training loop with checkpoint rotation and resume support.
#   * Every `save_freq` batches: sample a story, save weights + optimizer state,
#     and delete the previous mid-epoch checkpoint of the same epoch.
#   * At the end of every epoch: save an epoch checkpoint, then delete the
#     mid-epoch checkpoint(s) of that epoch and the previous epoch checkpoint,
#     so exactly one weight file remains in LOG_DIR.
#   * `last_epoch` / `last_batch` (set by the checkpoint-loading cell) make the
#     loop skip work that was already done before the restart.
save_freq = 200
avg_loss = None  # stays None if no epoch processes a batch (e.g. resuming a finished run)
optimizer_state_path = f'{LOG_DIR}/optimizer.safetensors'
model.train()
for epoch in range(EPOCHS):
    losses = []
    if last_epoch and epoch + 1 < last_epoch:
        # Epoch already completed in a previous run.
        stream.reset()
        continue
    for i, seq in enumerate(stream):
        if last_batch and epoch + 1 == last_epoch and i + 1 <= last_batch:
            # Batch already trained on before the restart.
            continue
        mx_seq = mx.array(seq[DICT_LABEL])
        input_seq = mx_seq[:, :-1]  # Exclude the last token for input
        target_seq = mx_seq[:, 1:]  # Exclude the first token for target
        if PADDING:
            loss, grads = loss_and_grad(model, input_seq, target_seq, pad_token_id=pad_token_id)
        else:
            loss, grads = loss_and_grad(model, input_seq, target_seq)
        optimizer.update(model, grads)
        mx.eval(model.parameters(), optimizer.state)
        # Pull the scalar out once and reuse it for printing, logging and averaging.
        loss_value = loss.item()
        print(f"Batch {i + 1}, Loss: {loss_value:.4f}")
        writer.add_scalar('Loss/train', loss_value, epoch * num_batches + i)
        if (i+1) % save_freq == 0:
            generate_story(model, tokenizer, "[SOS]", max_length=CONTEXT_LENGTH, eos_token_id=eos_token_id, temp=1.0)
            model.save_weights(f'{LOG_DIR}/model_weights_{VOCAB_SIZE}_{CONTEXT_LENGTH}_{epoch+1}_{i+1}.npz')
            state = tree_flatten(optimizer.state, destination={})
            mx.save_safetensors(optimizer_state_path, state)
            if i+1 != save_freq:
                prev_save_path = f'{LOG_DIR}/model_weights_{VOCAB_SIZE}_{CONTEXT_LENGTH}_{epoch+1}_{i+1-save_freq}.npz'
                if os.path.exists(prev_save_path):
                    os.remove(prev_save_path)
            print('-'*20)
            print(f"Active memory: {mx.get_active_memory() / 1024**3:.2f} GB")
            print(f"Cache memory: {mx.get_cache_memory() / 1024**3:.2f} GB")
            print(f"Peak memory: {mx.get_peak_memory() / 1024**3:.2f} GB")
            mx.clear_cache()
            print('-'*20)
        losses.append(loss_value)

    # Rewind the stream once the epoch is consumed, so the next epoch
    # iterates over the data again instead of an exhausted stream.
    stream.reset()

    if losses:
        avg_loss = sum(losses) / len(losses)
        print(f"Epoch {epoch + 1}, Average Loss: {avg_loss:.4f}")
        writer.add_scalar('Loss/epoch_train', avg_loss, epoch)
    else:
        # Nothing left to train in this epoch (e.g. resumed after its last batch);
        # skip the average instead of logging NaN, but still rotate checkpoints.
        print(f"Epoch {epoch + 1}: no batches processed, skipping average loss.")
    generate_story(model, tokenizer, "[SOS]", max_length=CONTEXT_LENGTH, eos_token_id=eos_token_id, temp=1.0)

    # Save the new epoch checkpoint (and matching optimizer state) first, and
    # only then delete older files, so a crash in between never leaves LOG_DIR
    # without any weights.
    model.save_weights(f'{LOG_DIR}/model_weights_{VOCAB_SIZE}_{CONTEXT_LENGTH}_{epoch+1}.npz')
    mx.save_safetensors(optimizer_state_path, tree_flatten(optimizer.state, destination={}))
    # Remove every mid-epoch checkpoint of this epoch; a leftover file would make
    # the checkpoint-loading cell fail with "Multiple model weight files found".
    for stale_path in Path(LOG_DIR).glob(f'model_weights_{VOCAB_SIZE}_{CONTEXT_LENGTH}_{epoch+1}_*.npz'):
        os.remove(stale_path)
    if epoch + 1 > 1:
        prev_epoch_path = f'{LOG_DIR}/model_weights_{VOCAB_SIZE}_{CONTEXT_LENGTH}_{epoch}.npz'
        if os.path.exists(prev_epoch_path):
            os.remove(prev_epoch_path)

if avg_loss is not None:
    writer.add_hparams(params, {'hparam/last_loss': avg_loss})
else:
    print("No batches were trained in this run. Skipping hyper-parameter logging.")

In [None]:
# Final sample from the model, drawn at temperature 0.2 (the samples printed
# during training use 1.0).
generate_story(
    model,
    tokenizer,
    "[SOS]",
    max_length=CONTEXT_LENGTH,
    eos_token_id=eos_token_id,
    temp=0.2,
)