# Configuration

In [1]:
# Configuration
VOCAB_SIZE = 10000  # Size of the vocabulary
CONTEXT_LENGTH = 256  # Fixed context length for chunks
BATCH_SIZE = 32  # Batch size for training
TOKENIZER_FILE = "./data/tinystories-tokenizer.json"
OUTPUT_FILE = "./data/chunked_stories.npz"
SAMPLE_LIMIT = None  # Set to None to process the entire dataset
DICT_LABEL = 'seq'

In [2]:
%load_ext autoreload
%autoreload 2

import os

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

import mlx.core as mx
import mlx.nn as nn
import mlx.data as dx

from datasets import load_dataset

from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace

from llm.modules import TransformerEncoder
from llm.data import chunk_story, data_to_array_of_dict

# Model

In [3]:
# create small language model transformer encoder

class SmallLanguageModel(nn.Module):
    def __init__(self, vocab_dim: int, embed_dim, n_head, num_layers, max_len=512, mlp_dim=2048):
        super(SmallLanguageModel, self).__init__()
        self.vocab_dim = vocab_dim
        self.embed_dim = embed_dim
        self.n_head = n_head
        self.num_layers = num_layers
        self.mlp_dim = mlp_dim
        self.max_len = max_len

        # self.softmax = nn.Softmax(axis=-1)
        # self.transformer_encoder = TransformerEncoder(vocab_dim, embed_dim, n_head, num_layers, max_len, kq_dim, mlp_dim)
        self.embedding = nn.Embedding(vocab_dim, embed_dim)
        self.transformer_layer = TransformerEncoder(embed_dim, n_head, num_layers, max_len, mlp_dim)
        self.output_proj = nn.Sequential(
            nn.Linear(embed_dim, vocab_dim, bias=False),  # Output projection layer
            nn.Softmax()  # Softmax to convert logits to probabilities
        )

    def __call__(self, x, mask: mx.array = None):
        # x is of shape (batch_size, seq_len)
        # print(f'Input x.shape: {x.shape}')
        # Convert input indices to embeddings
        x = self.embedding(x)  # Shape: (batch_size, seq_len, embed
        # print(f'After embedding, x.shape: {x.shape}')
        # Pass through the transformer encoder layer
        x = self.transformer_layer(x, mask)
        # print(f'After transformer layer, x.shape: {x.shape}')
        # Pass through the output projection layer
        x = self.output_proj(x)  # Shape: (batch_size, seq_len, vocab_dim)
        # print(f'After output projection, x.shape: {x.shape}')
        return x

In [4]:
def create_causal_mask_triu(L: int):
    # Create a boolean matrix where the upper triangle (excluding the diagonal) is True
    mask = mx.triu(mx.ones((L, L)), k=1).astype(mx.bool_)
    return mx.where(mask, -1e9, 0.0)[None, None, :, :]  # Add batch and head dimensions

In [5]:
model = SmallLanguageModel(vocab_dim=VOCAB_SIZE, embed_dim=4, n_head=2, num_layers=1)
x = mx.random.uniform(high=VOCAB_SIZE, shape=(32, 4)).astype(mx.int32)
# create mask to prevent attention to future tokens
mask = create_causal_mask_triu(x.shape[1])
output = model(x, mask)  # Forward pass

# Tokenizer

In [6]:
# from datasets import load_dataset
# ds = load_dataset("roneneldan/TinyStories")
# ds

In [7]:
# train_data = ds['train']
# print(f'Train data shape: {train_data.shape}')
# valid_data = ds['validation']
# print(f'Validation data shape: {valid_data.shape}')
# print('\n---------------------------------\n')
# print('This is a sample from the training data:\n')
# print(train_data[0]['text'])

In [8]:
# text_file_path = './data/tinystories_data.txt'

# with open(text_file_path, "w", encoding="utf-8") as f:
#     for example in train_data:
#         f.write(example['text'] + "\n")

In [9]:
# # Initialize a BPE tokenizer
# tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
# tokenizer.pre_tokenizer = Whitespace()

# # Configure the trainer with a vocabulary size and special tokens
# trainer = BpeTrainer(vocab_size=VOCAB_SIZE, special_tokens=["[UNK]", "[PAD]", "[SOS]", "[EOS]"])

# # Train the tokenizer on our text file
# print("Training tokenizer...")
# tokenizer.train(['./data/tinystories_data.txt'], trainer)
# print("Training complete.")

# # --- Save and Test the Tokenizer ---
# tokenizer_path = "./data/tinystories-tokenizer.json"
# tokenizer.save(tokenizer_path)
# print(f"Tokenizer saved to {tokenizer_path}")

# # Load it back and test
# tokenizer = Tokenizer.from_file(tokenizer_path)
# encoded = tokenizer.encode("Once upon a time, there was a little fox.")

# print("\n--- Testing the Tokenizer ---")
# print("Tokens:", encoded.tokens)
# print("IDs:", encoded.ids)

# Chunking

In [10]:
# # Load the tokenizer
# tokenizer = Tokenizer.from_file(TOKENIZER_FILE)

# # Load the dataset (use a subset for testing)
# dataset = load_dataset("roneneldan/TinyStories", split="train")
# if SAMPLE_LIMIT:
#     dataset = dataset.select(range(min(SAMPLE_LIMIT, len(dataset))))

# # Process all stories and collect chunks
# all_chunks = []
# for story in tqdm(dataset["text"], desc="Chunking stories"):
#     story_chunks = chunk_story(story, tokenizer, '[EOS]', '[PAD]', CONTEXT_LENGTH)
#     all_chunks.extend(story_chunks)

# # Convert list to numpy array for efficient storage
# chunks_array = np.array(all_chunks, dtype=np.int32)

# # Print statistics
# print(f"Created {len(all_chunks)} chunks of length {CONTEXT_LENGTH}")
# print(f"Total tokens: {len(all_chunks) * CONTEXT_LENGTH:,}")
# print(f"Array shape: {chunks_array.shape}")

# # Save the chunks to a compressed file
# print(f"Saving chunks to {OUTPUT_FILE}...")
# np.savez_compressed(OUTPUT_FILE, chunks=chunks_array)
# print(f"Saved successfully! File size: {os.path.getsize(OUTPUT_FILE) / (1024 * 1024):.2f} MB")

# Data Pipeline

In [11]:
data = np.load(OUTPUT_FILE)
dicts = data_to_array_of_dict(data['chunks'], name=DICT_LABEL)

assert type(dicts) == list
assert type(dicts[0]) == dict
assert type(dicts[0][DICT_LABEL]) == np.ndarray

buffer = dx.buffer_from_vector(dicts)
stream = buffer.shuffle().to_stream().batch(32).prefetch(8,4)

In [12]:
for x in stream:
    print(x[DICT_LABEL].shape)
    print(type(x[DICT_LABEL]))
    break  # Just to test the first batch

(32, 256)
<class 'numpy.ndarray'>


# Training

In [13]:
tokenizer = Tokenizer.from_file(TOKENIZER_FILE)
pad_token_id = tokenizer.token_to_id('[PAD]')

In [14]:
for epoch in range(2):
    losses = []
    for i, seq in tqdm(enumerate(stream), desc="Processing batches"):
        mx_seq = mx.array(seq[DICT_LABEL])
        input_seq = mx_seq[:, :-1]  # Exclude the last token for input
        target_seq = mx_seq[:, 1:]  # Exclude the first token for target
        # Create mask to prevent attention to future tokens
        mask = create_causal_mask_triu(input_seq.shape[1])
        # Forward pass through the model
        logits = model(input_seq, mask)
        # compute loss
        loss = nn.losses.cross_entropy(logits, target_seq, reduction='none')  # Assuming 0 is the padding index
        padding_mask = (target_seq != pad_token_id).astype(mx.float32)
        loss = (loss * padding_mask).sum(axis=-1) / padding_mask.sum(axis=-1)
        loss = loss.mean()  # Average loss over the batch
        losses.append(loss)
        print(f"Batch {i + 1}, Loss: {loss:.4f}")
    # print average loss for the epoch with tqdm progress bar
    avg_loss = mx.array(losses).mean()
    tqdm.write(f"Epoch {epoch + 1}/{10}, Average Loss: {avg_loss:.4f}")

Processing batches: 0it [00:00, ?it/s]

Batch 1, Loss: 9.2103
Batch 2, Loss: 9.2103
Batch 3, Loss: nan
Batch 4, Loss: 9.2103
Batch 5, Loss: 9.2103
Batch 6, Loss: 9.2103
Batch 7, Loss: 9.2103
Batch 8, Loss: 9.2103
Batch 9, Loss: 9.2103
Batch 10, Loss: 9.2103
Batch 11, Loss: 9.2103
Batch 12, Loss: 9.2103
Batch 13, Loss: 9.2103
Batch 14, Loss: 9.2103
Batch 15, Loss: 9.2103
Batch 16, Loss: 9.2103
Batch 17, Loss: 9.2103
Batch 18, Loss: nan
Batch 19, Loss: 9.2103
Batch 20, Loss: 9.2103
Batch 21, Loss: 9.2103
Batch 22, Loss: 9.2103
Batch 23, Loss: 9.2103
Batch 24, Loss: 9.2103
Batch 25, Loss: 9.2103
Batch 26, Loss: 9.2103
Batch 27, Loss: nan
Batch 28, Loss: 9.2103
Batch 29, Loss: 9.2103
Batch 30, Loss: 9.2103
Batch 31, Loss: 9.2103
Batch 32, Loss: 9.2103
Batch 33, Loss: 9.2103
Batch 34, Loss: 9.2103
Batch 35, Loss: 9.2103
Batch 36, Loss: 9.2103
Batch 37, Loss: 9.2103
Batch 38, Loss: 9.2103
Batch 39, Loss: 9.2103
Batch 40, Loss: 9.2103
Batch 41, Loss: 9.2103
Batch 42, Loss: 9.2103
Batch 43, Loss: 9.2103
Batch 44, Loss: 9.2103
Batch 

KeyboardInterrupt: 

In [None]:
print(f'MLX current default device: {mx.default_device()}')

MLX current default device: Device(gpu, 0)
