# Configuration

In [1]:
# Configuration
VOCAB_SIZE = 4096  # Size of the vocabulary
CONTEXT_LENGTH = 512  # Fixed context length for chunks
EMBEDDING_DIM = 512  # Dimension of the token embeddings
NUM_HEADS = 8  # Number of attention heads
NUM_LAYERS = 6  # Number of transformer layers
HIDDEN_DIM = 2048  # Dimension of the hidden layers in the transformer
BATCH_SIZE = 64  # Batch size for training
EPOCHS = 20 # Number of epochs to train
TOKENIZER_FILE = "./data/tinystories-tokenizer"
CHUNK_FILE = "./data/chunked_stories"
SAMPLE_LIMIT = None  # Set to None to process the entire dataset
DICT_LABEL = 'seq'

In [2]:
%load_ext autoreload
%autoreload 2

import os
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import mlx.data as dx

from datasets import load_dataset

from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace

from llm.modules import SmallLanguageModel, loss_fn, create_causal_mask_triu, count_parameters, generate_story
from llm.data import chunk_story, data_to_array_of_dict

# Merge dataset to txt file

In [3]:
# from datasets import load_dataset
# ds = load_dataset("roneneldan/TinyStories")
# ds

In [4]:
# train_data = ds['train']
# print(f'Train data shape: {train_data.shape}')
# valid_data = ds['validation']
# print(f'Validation data shape: {valid_data.shape}')
# print('\n---------------------------------\n')
# print('This is a sample from the training data:\n')
# print(train_data[0]['text'])

In [5]:
# text_file_path = './data/tinystories_data.txt'

# with open(text_file_path, "w", encoding="utf-8") as f:
#     for example in train_data:
#         f.write(example['text'] + "\n")

# Tokenizer

In [6]:
if os.path.exists(f'{TOKENIZER_FILE}_{VOCAB_SIZE}.json'):
    print(f"Tokenizer file {TOKENIZER_FILE}_{VOCAB_SIZE}.json already exists. Skipping training.")
else:
    # Initialize a BPE tokenizer
    tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
    tokenizer.pre_tokenizer = Whitespace()

    # Configure the trainer with a vocabulary size and special tokens
    trainer = BpeTrainer(vocab_size=VOCAB_SIZE, special_tokens=["[UNK]", "[PAD]", "[SOS]", "[EOS]"])

    # Train the tokenizer on our text file
    print("Training tokenizer...")
    tokenizer.train(['./data/tinystories_data.txt'], trainer)
    print("Training complete.")

    # --- Save and Test the Tokenizer ---
    tokenizer_path = f'{TOKENIZER_FILE}_{VOCAB_SIZE}.json'
    tokenizer.save(tokenizer_path)
    print(f"Tokenizer saved to {tokenizer_path}")

    # Load it back and test
    tokenizer = Tokenizer.from_file(tokenizer_path)
    encoded = tokenizer.encode("Once upon a time, there was a little fox.")

    print("\n--- Testing the Tokenizer ---")
    print("Tokens:", encoded.tokens)
    print("IDs:", encoded.ids)

Tokenizer file ./data/tinystories-tokenizer_4096.json already exists. Skipping training.


# Chunking

In [7]:
if os.path.exists(f'{CHUNK_FILE}_{VOCAB_SIZE}_{CONTEXT_LENGTH}.npz'):
    print(f"Chunk file {CHUNK_FILE}_{VOCAB_SIZE}_{CONTEXT_LENGTH}.npz already exists. Skipping chunking.")
else:
    # Load the dataset (use a subset for testing)
    dataset = load_dataset("roneneldan/TinyStories", split="train")
    if SAMPLE_LIMIT:
        dataset = dataset.select(range(min(SAMPLE_LIMIT, len(dataset))))

    # Load the tokenizer
    tokenizer = Tokenizer.from_file(f'{TOKENIZER_FILE}_{VOCAB_SIZE}.json')

    # Process all stories and collect chunks
    all_chunks = []
    for story in tqdm(dataset["text"], desc="Chunking stories"):
        story_chunks = chunk_story(story, tokenizer, '[EOS]', '[PAD]', CONTEXT_LENGTH)
        all_chunks.extend(story_chunks)

    # Convert list to numpy array for efficient storage
    chunks_array = np.array(all_chunks, dtype=np.int32)

    # Print statistics
    print(f"Created {len(all_chunks)} chunks of length {CONTEXT_LENGTH}")
    print(f"Total tokens: {len(all_chunks) * CONTEXT_LENGTH:,}")
    print(f"Array shape: {chunks_array.shape}")

    # Save the chunks to a compressed file
    print(f"Saving chunks to {CHUNK_FILE}_{VOCAB_SIZE}_{CONTEXT_LENGTH}.npz...")
    np.savez_compressed(f'{CHUNK_FILE}_{VOCAB_SIZE}_{CONTEXT_LENGTH}.npz', chunks=chunks_array)
    print(f"Saved successfully! File size: {os.path.getsize(f'{CHUNK_FILE}_{VOCAB_SIZE}_{CONTEXT_LENGTH}.npz') / (1024 * 1024):.2f} MB")

Chunk file ./data/chunked_stories_4096_512.npz already exists. Skipping chunking.


# Data Pipeline

In [8]:
data = np.load(f'{CHUNK_FILE}_{VOCAB_SIZE}_{CONTEXT_LENGTH}.npz')
dicts = data_to_array_of_dict(data['chunks'], name=DICT_LABEL)

assert type(dicts) == list
assert type(dicts[0]) == dict
assert type(dicts[0][DICT_LABEL]) == np.ndarray

buffer = dx.buffer_from_vector(dicts)
stream = buffer.to_stream().batch(BATCH_SIZE).shuffle(buffer_size=BATCH_SIZE*1000).prefetch(8,4)

In [9]:
for x in stream:
    print(x[DICT_LABEL].shape)
    print(type(x[DICT_LABEL]))
    break  # Just to test the first batch

(64, 512)
<class 'numpy.ndarray'>


# Model

In [10]:
model = SmallLanguageModel(vocab_dim=VOCAB_SIZE, embed_dim=EMBEDDING_DIM, n_head=NUM_HEADS, num_layers=NUM_LAYERS, mlp_dim=HIDDEN_DIM, max_len=CONTEXT_LENGTH)
x = mx.random.uniform(high=VOCAB_SIZE, shape=(32, 4)).astype(mx.int32)
# create mask to prevent attention to future tokens
mask = create_causal_mask_triu(x.shape[1])
output = model(x, mask)  # Forward pass
# check number of parameters
print(f"Number of parameters in the model: {count_parameters(model.parameters()):,}")

Number of parameters in the model: 2,359,296


In [11]:

# search for existing model weights with same vocab size and context length but wildcard epoch number
# load existing model weights if they exist and record the epoch number

matching_paths = list(Path('./data').glob(f'model_weights_{VOCAB_SIZE}_{CONTEXT_LENGTH}_*.npz'))
if len(matching_paths) == 0:
    print("No existing model weights found. Starting training from scratch.")
    last_epoch = None
    last_batch = None
elif len(matching_paths) > 1:
    raise ValueError(f"Multiple model weight files found for vocab size {VOCAB_SIZE} and context length {CONTEXT_LENGTH}. Please ensure only one exists.")
else:
    path = matching_paths[0]
    print(f"Found existing model weights: {path.name}")
    # Load the model weights
    model.load_weights(str(path))
    # Extract epoch number from filename
    last_epoch = int(path.stem.split('_')[-2])
    last_batch = int(path.stem.split('_')[-1])
    print(f"Loaded model weights from epoch {last_epoch}, batch {last_batch}.")

Found existing model weights: model_weights_4096_512_1_3800.npz
Loaded model weights from epoch 1, batch 3800.


# Training

In [12]:
tokenizer = Tokenizer.from_file(f'{TOKENIZER_FILE}_{VOCAB_SIZE}.json')
pad_token_id = tokenizer.token_to_id('[PAD]')
eos_token_id = tokenizer.token_to_id('[EOS]')
optimizer = optim.AdamW(learning_rate=0.0005, betas=[0.9, 0.95], weight_decay=0.1)
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
print(f'MLX current default device: {mx.default_device()}')

MLX current default device: Device(gpu, 0)


In [None]:
save_freq = 200
for epoch in range(EPOCHS):
    losses = []
    if epoch + 1 < last_epoch:
        continue
    for i, seq in enumerate(stream):
        if i + 1 <= last_batch:
            continue
        mx_seq = mx.array(seq[DICT_LABEL])
        input_seq = mx_seq[:, :-1]  # Exclude the last token for input
        target_seq = mx_seq[:, 1:]  # Exclude the first token for target
        loss, grads = loss_and_grad_fn(model, input_seq, target_seq, pad_token_id)
        optimizer.update(model, grads)
        mx.eval(model.parameters(), optimizer.state)
        print(f"Batch {i + 1}, Loss: {loss:.4f}")
        if (i+1) % save_freq == 0:
            generate_story(model, tokenizer, "Once upon a time", max_length=CONTEXT_LENGTH, eos_token_id=eos_token_id, temp=1.0)
            model.save_weights(f'./data/model_weights_{VOCAB_SIZE}_{CONTEXT_LENGTH}_{epoch+1}_{i+1}.npz')
            if i+1 != save_freq: os.remove(f'./data/model_weights_{VOCAB_SIZE}_{CONTEXT_LENGTH}_{epoch+1}_{i+1-save_freq}.npz') if i > 0 else None
            print('-'*20)
            print(f"Active memory: {mx.get_active_memory() / 1024**3:.2f} GB")
            print(f"Cache memory: {mx.get_cache_memory() / 1024**3:.2f} GB")
            print(f"Peak memory: {mx.get_peak_memory() / 1024**3:.2f} GB")
            mx.clear_cache()
            print('-'*20)
        losses.append(loss)
        stream.reset()
    avg_loss = mx.array(losses).mean()

Batch 3801, Loss: 2.7806
Batch 3802, Loss: 5.0575
Batch 3803, Loss: 2.9617
Batch 3804, Loss: 3.0233
Batch 3805, Loss: 3.5754
Batch 3806, Loss: 4.2433
Batch 3807, Loss: 2.7067
Batch 3808, Loss: 3.0185
Batch 3809, Loss: 4.2101
Batch 3810, Loss: 4.9735
Batch 3811, Loss: 3.0092
Batch 3812, Loss: 4.3553
Batch 3813, Loss: 5.3027
Batch 3814, Loss: 2.9390
Batch 3815, Loss: 3.2042
Batch 3816, Loss: 3.2187
Batch 3817, Loss: 3.0645
Batch 3818, Loss: 3.8630
Batch 3819, Loss: 2.8343
Batch 3820, Loss: 4.3199
Batch 3821, Loss: 3.3476
Batch 3822, Loss: 3.7546
Batch 3823, Loss: 2.9124
Batch 3824, Loss: 4.4151
Batch 3825, Loss: 3.1127
Batch 3826, Loss: 3.1227
Batch 3827, Loss: 4.2164
Batch 3828, Loss: 2.8742
Batch 3829, Loss: 2.9414
Batch 3830, Loss: 2.8198
Batch 3831, Loss: 3.1163
Batch 3832, Loss: 3.0606
Batch 3833, Loss: 2.8104
Batch 3834, Loss: 2.8589
Batch 3835, Loss: 3.0172
Batch 3836, Loss: 3.2655
Batch 3837, Loss: 4.8513
Batch 3838, Loss: 3.1041
Batch 3839, Loss: 4.3159
Batch 3840, Loss: 3.1917
