# Configuration

In [None]:
# Configuration
# Tunable values for the notebook, collected in one cell so that a run can be
# reconfigured in one place.
VOCAB_SIZE = 2048  # Size of the vocabulary
CONTEXT_LENGTH = 512  # Fixed context length for chunks
EMBEDDING_DIM = 512  # Dimension of the token embeddings
NUM_HEADS = 8  # Number of attention heads
NUM_LAYERS = 6  # Number of transformer layers
HIDDEN_DIM = 2048  # Dimension of the hidden layers in the transformer
BATCH_SIZE = 128  # Batch size for training
EPOCHS = 20 # Number of epochs to train
# Path prefix of the saved tokenizer; cells append "_{VOCAB_SIZE}.json".
TOKENIZER_FILE = "./data/tinystories-tokenizer"
# Path prefix of the cached chunks; cells append "_{VOCAB_SIZE}_{CONTEXT_LENGTH}.npz".
CHUNK_FILE = "./data/chunked_stories"
SAMPLE_LIMIT = None  # Set to None to process the entire dataset
# Dictionary key under which each chunk array is stored in the data pipeline samples.
DICT_LABEL = 'seq'

In [None]:
%load_ext autoreload
%autoreload 2

import os

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import mlx.data as dx

from datasets import load_dataset

from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace

from llm.modules import SmallLanguageModel, loss_fn, create_causal_mask_triu, count_parameters, generate_story
from llm.data import chunk_story, data_to_array_of_dict

# Merge dataset to txt file

In [None]:
# from datasets import load_dataset
# ds = load_dataset("roneneldan/TinyStories")
# ds

In [None]:
# train_data = ds['train']
# print(f'Train data shape: {train_data.shape}')
# valid_data = ds['validation']
# print(f'Validation data shape: {valid_data.shape}')
# print('\n---------------------------------\n')
# print('This is a sample from the training data:\n')
# print(train_data[0]['text'])

In [None]:
# text_file_path = './data/tinystories_data.txt'

# with open(text_file_path, "w", encoding="utf-8") as f:
#     for example in train_data:
#         f.write(example['text'] + "\n")

# Tokenizer

In [None]:
# Train the BPE tokenizer once and cache it on disk; later runs reuse the saved file.
tokenizer_path = f'{TOKENIZER_FILE}_{VOCAB_SIZE}.json'
corpus_path = './data/tinystories_data.txt'

if os.path.exists(tokenizer_path):
    print(f"Tokenizer file {tokenizer_path} already exists. Skipping training.")
else:
    # The trainer reads a plain-text corpus. Build it from the TinyStories train
    # split when it is missing, so that a fresh "Restart & Run All" does not
    # depend on a file that no active cell creates.
    if not os.path.exists(corpus_path):
        print(f"Corpus file {corpus_path} not found. Writing it from the TinyStories train split...")
        os.makedirs(os.path.dirname(corpus_path) or '.', exist_ok=True)
        train_split = load_dataset("roneneldan/TinyStories", split="train")
        # Write to a temporary file first so an interrupted run never leaves a
        # truncated corpus behind that a later run would silently train on.
        partial_path = corpus_path + '.tmp'
        with open(partial_path, "w", encoding="utf-8") as corpus_file:
            for example in train_split:
                corpus_file.write(example['text'] + "\n")
        os.replace(partial_path, corpus_path)

    # Initialize a BPE tokenizer
    tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
    tokenizer.pre_tokenizer = Whitespace()

    # Configure the trainer with a vocabulary size and special tokens
    trainer = BpeTrainer(vocab_size=VOCAB_SIZE, special_tokens=["[UNK]", "[PAD]", "[SOS]", "[EOS]"])

    # Train the tokenizer on our text file
    print("Training tokenizer...")
    tokenizer.train([corpus_path], trainer)
    print("Training complete.")

    # --- Save and Test the Tokenizer ---
    os.makedirs(os.path.dirname(tokenizer_path) or '.', exist_ok=True)
    tokenizer.save(tokenizer_path)
    print(f"Tokenizer saved to {tokenizer_path}")

    # Load it back and test, to confirm the saved file round-trips
    tokenizer = Tokenizer.from_file(tokenizer_path)
    encoded = tokenizer.encode("Once upon a time, there was a little fox.")

    print("\n--- Testing the Tokenizer ---")
    print("Tokens:", encoded.tokens)
    print("IDs:", encoded.ids)

# Chunking

In [None]:
# Build the fixed-length chunks once and cache them as a compressed .npz
# archive; when the archive already exists this cell only reports that.
chunk_path = f'{CHUNK_FILE}_{VOCAB_SIZE}_{CONTEXT_LENGTH}.npz'

if not os.path.exists(chunk_path):
    # Load the training split, optionally truncated to the first SAMPLE_LIMIT stories
    dataset = load_dataset("roneneldan/TinyStories", split="train")
    if SAMPLE_LIMIT:
        n_stories = min(SAMPLE_LIMIT, len(dataset))
        dataset = dataset.select(range(n_stories))

    # Load the tokenizer saved by the tokenizer cell
    tokenizer = Tokenizer.from_file(f'{TOKENIZER_FILE}_{VOCAB_SIZE}.json')

    # Chunk every story and flatten the per-story results into one list
    all_chunks = [
        chunk
        for story in tqdm(dataset["text"], desc="Chunking stories")
        for chunk in chunk_story(story, tokenizer, '[EOS]', '[PAD]', CONTEXT_LENGTH)
    ]

    # Store as an int32 numpy array for efficient storage
    chunks_array = np.array(all_chunks, dtype=np.int32)

    # Report statistics
    n_chunks = len(all_chunks)
    print(f"Created {n_chunks} chunks of length {CONTEXT_LENGTH}")
    print(f"Total tokens: {n_chunks * CONTEXT_LENGTH:,}")
    print(f"Array shape: {chunks_array.shape}")

    # Persist the chunks and report the size of the file on disk
    print(f"Saving chunks to {chunk_path}...")
    np.savez_compressed(chunk_path, chunks=chunks_array)
    size_mb = os.path.getsize(chunk_path) / (1024 * 1024)
    print(f"Saved successfully! File size: {size_mb:.2f} MB")
else:
    print(f"Chunk file {chunk_path} already exists. Skipping chunking.")

# Data Pipeline

In [None]:
# Load the cached chunks and convert them into the list-of-dicts sample format
# that is handed to the mlx.data buffer below.
data = np.load(f'{CHUNK_FILE}_{VOCAB_SIZE}_{CONTEXT_LENGTH}.npz')
dicts = data_to_array_of_dict(data['chunks'], name=DICT_LABEL)

# Sanity-check the sample format before building the pipeline.
assert isinstance(dicts, list) and dicts, "expected a non-empty list of samples"
assert isinstance(dicts[0], dict), "each sample must be a dict"
assert isinstance(dicts[0][DICT_LABEL], np.ndarray), f"sample['{DICT_LABEL}'] must be a numpy array"

buffer = dx.buffer_from_vector(dicts)
# Batch with the configured BATCH_SIZE rather than a hard-coded literal, so the
# configuration cell actually controls the training batch size.
stream = buffer.shuffle().to_stream().batch(BATCH_SIZE).prefetch(8, 4)

In [None]:
# Peek at a single batch to confirm the shape and type the pipeline produces.
for batch in stream:
    print(batch[DICT_LABEL].shape)
    print(type(batch[DICT_LABEL]))
    break  # Just to test the first batch

# Iterating consumed data from the shared stream; rewind it so that training
# starts from the beginning instead of silently skipping the peeked batch.
stream.reset()

# Model

In [None]:
# Build the model from the configuration constants.
model = SmallLanguageModel(
    vocab_dim=VOCAB_SIZE,
    embed_dim=EMBEDDING_DIM,
    n_head=NUM_HEADS,
    num_layers=NUM_LAYERS,
    mlp_dim=HIDDEN_DIM,
    max_len=CONTEXT_LENGTH,
)

# Smoke-test the forward pass on a small batch of random token ids in [0, VOCAB_SIZE).
x = mx.random.randint(0, VOCAB_SIZE, shape=(32, 4))
# create mask to prevent attention to future tokens
mask = create_causal_mask_triu(x.shape[1])
output = model(x, mask)  # Forward pass
# MLX builds computations lazily: without an explicit eval the forward pass
# above is never executed, so the smoke test would not catch any error.
mx.eval(output)
# check number of parameters
print(f"Number of parameters in the model: {count_parameters(model.parameters()):,}")

# Training

In [None]:
# Reload the saved tokenizer and look up the ids of the special tokens that the
# training loop passes to the loss function and to story generation.
tokenizer = Tokenizer.from_file(f'{TOKENIZER_FILE}_{VOCAB_SIZE}.json')
pad_token_id, eos_token_id = (
    tokenizer.token_to_id(special_token) for special_token in ('[PAD]', '[EOS]')
)

# AdamW optimizer, plus a function returning both the loss and its gradients
# with respect to the model.
optimizer = optim.AdamW(
    learning_rate=5e-4,
    betas=[0.9, 0.95],
    weight_decay=0.1,
)
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)

In [None]:
# Report the device MLX uses by default for computation.
default_device = mx.default_device()
print(f'MLX current default device: {default_device}')

In [None]:
# Train for EPOCHS passes over the stream, logging the loss of every batch and
# the average loss of every epoch.
for epoch in range(EPOCHS):
    # A stream is consumed by iteration; rewind it so that every epoch (and the
    # first one, after any earlier peek at the stream) sees the full dataset.
    stream.reset()
    losses = []
    for i, seq in enumerate(stream):
        mx_seq = mx.array(seq[DICT_LABEL])
        input_seq = mx_seq[:, :-1]  # Exclude the last token for input
        target_seq = mx_seq[:, 1:]  # Exclude the first token for target
        loss, grads = loss_and_grad_fn(model, input_seq, target_seq, pad_token_id)
        optimizer.update(model, grads)
        mx.eval(model.parameters(), optimizer.state)
        # Convert to a Python float once: it is both logged and accumulated for
        # the per-epoch average.
        loss_value = loss.item()
        losses.append(loss_value)
        print(f"Batch {i + 1}, Loss: {loss_value:.4f}")
        # Periodically sample a story to monitor generation quality.
        if (i + 1) % 300 == 0:
            generate_story(model, tokenizer, "Once upon a time", max_length=CONTEXT_LENGTH, eos_token_id=eos_token_id, temp=1.0)
    if losses:
        avg_loss = sum(losses) / len(losses)
        print(f"Epoch {epoch + 1}/{EPOCHS}, Average Loss: {avg_loss:.4f}")
    else:
        # Guard against an empty stream instead of averaging an empty list.
        print(f"Epoch {epoch + 1}/{EPOCHS}: no batches were produced by the stream.")