## Loading and Processing Data

In [None]:
!pip install tiktoken
!pip install datasets

Collecting tiktoken
  Downloading tiktoken-0.11.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.7 kB)
Downloading tiktoken-0.11.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.2/1.2 MB[0m [31m14.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: tiktoken
Successfully installed tiktoken-0.11.0
Collecting datasets
  Downloading datasets-4.0.0-py3-none-any.whl.metadata (19 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py312-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2025.3.0,>=2023.1.0 (from fsspec[http]<=2025.3.0,>=2023.1.0->datasets)
  Downloading fsspec-2025.3.0-py3-none-a

In [None]:
import tiktoken
import re

import jax.numpy as jnp
import jax

import optax
import orbax.checkpoint as ocp


import torch
from torch.utils.data import Dataset, DataLoader

from typing import TypedDict

from datasets import load_dataset

In [None]:
class TinyStoriesDataset(Dataset):
    """Sliding-window next-token-prediction dataset over a corpus of stories.

    Every example's ``'text'`` field is tokenized, an end-of-text token is
    appended, and all stories are concatenated into one token stream. The
    stream is then cut into windows of ``max_length`` tokens taken every
    ``stride`` tokens; the target of each window is the same window shifted
    one token to the right.

    Args:
        hf_dataset: Iterable of mappings that each contain a ``'text'`` string.
        tokenizer: Object exposing ``encode(text) -> list[int]`` and an
            ``eot_token`` integer attribute.
        max_length: Number of tokens per input/target window.
        stride: Step, in tokens, between the starts of consecutive windows.
    """

    def __init__(self, hf_dataset, tokenizer, max_length, stride):
        self.tokenizer = tokenizer
        self.input_ids = []
        self.target_ids = []

        # Concatenate all stories into one long list of token IDs, separating
        # consecutive stories with the tokenizer's end-of-text token.
        all_token_ids = []
        for example in hf_dataset:
            all_token_ids.extend(tokenizer.encode(example['text']))
            all_token_ids.append(tokenizer.eot_token)

        # A window starting at `i` consumes tokens i .. i + max_length
        # (inclusive, because the target is shifted by one), so the last valid
        # start is len - max_length - 1. `range` excludes its stop value,
        # hence the `+ 1`; without it the final full window is dropped.
        last_start = len(all_token_ids) - max_length - 1
        for i in range(0, last_start + 1, stride):
            input_chunk = all_token_ids[i : i + max_length]
            target_chunk = all_token_ids[i + 1 : i + max_length + 1]
            self.input_ids.append(torch.tensor(input_chunk))
            self.target_ids.append(torch.tensor(target_chunk))

    def __len__(self):
        """Return the number of (input, target) windows."""
        return len(self.input_ids)

    def __getitem__(self, idx):
        """Return the ``(input_ids, target_ids)`` tensor pair at ``idx``."""
        return self.input_ids[idx], self.target_ids[idx]

In [None]:
def create_dataloader(hf_dataset_split, batch_size=4, max_length=256,
                         stride=128, shuffle=True, drop_last=True,
                         num_workers=0):
    """Tokenize a dataset split with the GPT-2 encoding and wrap it in a DataLoader.

    Args:
        hf_dataset_split: Iterable of examples handed to `TinyStoriesDataset`.
        batch_size: Number of windows per batch.
        max_length: Tokens per window.
        stride: Token offset between consecutive windows.
        shuffle: Forwarded to `DataLoader`.
        drop_last: Forwarded to `DataLoader`.
        num_workers: Forwarded to `DataLoader`.

    Returns:
        A `DataLoader` over a `TinyStoriesDataset` built from the split.
    """
    loader_options = dict(
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=drop_last,
        num_workers=num_workers,
    )
    gpt2_encoding = tiktoken.get_encoding('gpt2')
    windows = TinyStoriesDataset(hf_dataset_split, gpt2_encoding, max_length, stride)
    return DataLoader(windows, **loader_options)

In [None]:
class AttentionParams(TypedDict):
    """Weight matrices read by `multi_head_attention` for a single layer.

    `init_params` creates each of them with shape (d_model, d_model).
    """
    # Right-multiplied onto the layer input to produce the queries.
    W_query: jax.Array
    # Right-multiplied onto the layer input to produce the keys.
    W_key: jax.Array
    # Right-multiplied onto the layer input to produce the values.
    W_value: jax.Array
    # Applied to the re-assembled per-head context vectors as the final step.
    W_out: jax.Array

class FeedForwardParams(TypedDict):
    """Weights of the position-wise feed-forward sub-layer.

    The key names match the ones `init_params` creates and
    `transformer_block` reads (`'W_ff_l1'` / `'W_ff_l2'`).
    """
    # First projection, (d_model, d_ff); followed by a ReLU.
    W_ff_l1: jax.Array
    # Second projection, (d_ff, d_model), back to the model width.
    W_ff_l2: jax.Array

class LayerNormParams(TypedDict):
    """Learnable parameters read by `layer_norm`.

    `init_params` creates both with shape (1, d_model).
    """
    # Multiplies the normalized activations (scale).
    W_gamma: jax.Array
    # Added after scaling (shift).
    W_beta: jax.Array

class TransformerBlockParams(TypedDict):
    """All parameters of one transformer block, as consumed by `transformer_block`."""
    # Weights of the multi-head self-attention sub-layer.
    attention: AttentionParams
    # Normalization applied to (input + attention output).
    layer_norm1: LayerNormParams
    # Weights of the feed-forward sub-layer.
    feed_forward: FeedForwardParams
    # Normalization applied to (first sub-layer result + feed-forward output).
    layer_norm2: LayerNormParams

class ModelParams(TypedDict):
    """Top-level parameter tree returned by `init_params`."""
    # Token embedding table, (vocab_size, d_model). `transformer_forward_pass`
    # also multiplies by its transpose to produce the output logits.
    embedding: jax.Array
    # One entry per transformer block, applied in list order.
    layers: list[TransformerBlockParams]

In [None]:
def init_params(prng_key, vocab_size, d_model, num_layers):
    """Create a freshly initialised parameter tree for the transformer.

    Args:
        prng_key: JAX PRNG key used to derive every random weight.
        vocab_size: Number of rows of the token embedding table.
        d_model: Model (embedding) width.
        num_layers: Number of transformer blocks to create.

    Returns:
        A dict with an ``'embedding'`` array of shape (vocab_size, d_model)
        and a ``'layers'`` list holding one parameter dict per block. The
        feed-forward hidden width is ``4 * d_model``.
    """
    d_ff = d_model * 4
    initializer = jax.nn.initializers.glorot_normal()

    prng_key, embed_key = jax.random.split(prng_key, 2)
    all_params = {
        'embedding': initializer(embed_key, (vocab_size, d_model)),
        'layers': [],
    }

    for _ in range(num_layers):
        # The 4-way split is kept (the last key is simply unused) so the
        # attention / feed-forward weights for a given seed do not change.
        prng_key, attn_key, ff_key, _unused_key = jax.random.split(prng_key, 4)
        wq_key, wk_key, wv_key, wo_key = jax.random.split(attn_key, 4)
        ff_l1_key, ff_l2_key = jax.random.split(ff_key, 2)

        layer_params: TransformerBlockParams = {
            'attention': {
                'W_query': initializer(wq_key, (d_model, d_model)),
                'W_key': initializer(wk_key, (d_model, d_model)),
                'W_value': initializer(wv_key, (d_model, d_model)),
                'W_out': initializer(wo_key, (d_model, d_model)),
            },
            # Layer norm starts as the identity affine map (scale 1, shift 0).
            # Drawing these from Glorot noise would randomly shrink and
            # sign-flip every feature right after normalization.
            'layer_norm1': {
                'W_gamma': jnp.ones((1, d_model)),
                'W_beta': jnp.zeros((1, d_model)),
            },
            'feed_forward': {
                'W_ff_l1': initializer(ff_l1_key, (d_model, d_ff)),
                'W_ff_l2': initializer(ff_l2_key, (d_ff, d_model)),
            },
            'layer_norm2': {
                'W_gamma': jnp.ones((1, d_model)),
                'W_beta': jnp.zeros((1, d_model)),
            },
        }
        all_params['layers'].append(layer_params)

    return all_params

In [None]:
def get_positional_embeddings(seq_len, d_model):
    """Generate sinusoidal positional embeddings.

    Even feature columns hold ``sin(pos * w_k)`` and odd columns hold
    ``cos(pos * w_k)`` with ``w_k = 10000 ** (-2k / d_model)``.

    Args:
        seq_len: Number of positions (rows).
        d_model: Embedding width (columns); may be even or odd.

    Returns:
        Array of shape (seq_len, d_model).
    """
    positions = jnp.arange(seq_len)[:, jnp.newaxis]
    div_term = jnp.exp(jnp.arange(0, d_model, 2) * -(jnp.log(10000.0) / d_model))
    angles = positions * div_term  # (seq_len, ceil(d_model / 2))

    pe = jnp.zeros((seq_len, d_model))
    pe = pe.at[:, 0::2].set(jnp.sin(angles))
    # For odd d_model there is one fewer odd column than even columns, so the
    # cosine part uses only the first d_model // 2 frequencies. For even
    # d_model the slice is a no-op.
    pe = pe.at[:, 1::2].set(jnp.cos(angles[:, : d_model // 2]))

    return pe

In [None]:
def dropout(key, rate, x, training=True):
    """Inverted dropout.

    Args:
        key: JAX PRNG key used to draw the keep/drop mask.
        rate: Probability of zeroing an element. Must be a concrete Python
            number (not a traced value) because it drives Python branching.
        x: Input array.
        training: When False the input is returned unchanged.

    Returns:
        ``x`` itself when dropout is inactive; otherwise an array of the same
        shape where kept elements are scaled by ``1 / (1 - rate)`` and dropped
        elements are zero.
    """
    if not training or rate == 0.0:
        return x

    # Dropping everything: return zeros directly. Going through the general
    # path would divide by keep_prob == 0 and poison gradients with NaNs.
    if rate >= 1.0:
        return jnp.zeros_like(x)

    keep_prob = 1.0 - rate

    # Sample which elements survive, then rescale the survivors so the
    # expected value of the output matches the input.
    mask = jax.random.bernoulli(key, keep_prob, x.shape)
    return jnp.where(mask, x / keep_prob, 0)

In [None]:
def layer_norm(x, params: "LayerNormParams", epsilon=1e-5):
    """
    Applies Layer Normalization over the last axis of `x`.

    Args:
        x (jax.Array): The input array; statistics are computed over its last
            axis, independently for every other index.
        params (LayerNormParams): Mapping with the learnable scale
            ``'W_gamma'`` and shift ``'W_beta'``; both must broadcast against
            `x` along the last axis.
        epsilon (float): A small value added to the variance for numerical
            stability.

    Returns:
        jax.Array: ``gamma * (x - mean) / sqrt(var + epsilon) + beta``, with
        the same shape as `x`.
    """
    # Normalize over the feature (last) axis.
    normalization_axis = -1

    gamma = params['W_gamma']
    beta = params['W_beta']

    mean = jnp.mean(x, axis=normalization_axis, keepdims=True)
    # Layer norm uses the population (biased, ddof=0) variance. A Bessel
    # correction would rescale activations by sqrt((n - 1) / n) and produce
    # NaN when the feature axis has a single element.
    var = jnp.var(x, axis=normalization_axis, keepdims=True)

    x_norm = (x - mean) / jnp.sqrt(var + epsilon)

    # Apply the learnable scale and shift.
    return gamma * x_norm + beta

In [None]:
def multi_head_attention(x, params: AttentionParams, num_heads, head_dim,
                        drop_key, drop_rate, training):
    """
    Multi-head causal self-attention over a batch of sequences.

    `x` is unpacked as (batch, seq_len, d_model) and the result has the same
    shape. The reshapes below require num_heads * head_dim == d_model.

    params: Per-layer weights under 'W_query', 'W_key', 'W_value', 'W_out'.
    drop_key / drop_rate / training: forwarded to `dropout`, which is applied
    to the attention weights.
    """
    batch, seq_len, d_model = x.shape

    def project_to_heads(weight):
        # (b, t, d) --matmul/reshape--> (b, t, nh, hd) --transpose--> (b, nh, t, hd)
        # Batch and head axes lead so the matmuls below run per (batch, head).
        projected = (x @ weight).reshape((batch, seq_len, num_heads, head_dim))
        return jnp.transpose(projected, (0, 2, 1, 3))

    queries = project_to_heads(params['W_query'])
    keys = project_to_heads(params['W_key'])
    values = project_to_heads(params['W_value'])

    # (b, nh, t, hd) @ (b, nh, hd, t) -> (b, nh, t, t): one score per
    # (query position, key position) pair.
    scores = queries @ jnp.swapaxes(keys, -1, -2)

    # Strict upper triangle marks "future" key positions; a query must not
    # attend to them.
    is_future = jnp.triu(jnp.ones((seq_len, seq_len)), k=1).astype(bool)
    scores = jnp.where(is_future, -jnp.inf, scores)

    # Each row (one query) becomes a probability distribution over keys.
    weights = jax.nn.softmax(scores / jnp.sqrt(head_dim), axis=-1)
    weights = dropout(drop_key, drop_rate, weights, training)

    # (b, nh, t, t) @ (b, nh, t, hd) -> (b, nh, t, hd); move heads back next
    # to head_dim and merge them into the model dimension again.
    per_head_context = weights @ values
    context = jnp.transpose(per_head_context, (0, 2, 1, 3)).reshape(
        (batch, seq_len, d_model)
    )

    return context @ params['W_out']

In [None]:
def transformer_block(x, layer_params: TransformerBlockParams, num_heads,\
                      head_dim, key, drop_rate, training):
    """One transformer block: self-attention then feed-forward, each followed
    by a residual connection and layer normalization."""
    # The key is split four ways; only the second and third sub-keys are used
    # (attention dropout and feed-forward dropout respectively).
    subkeys = jax.random.split(key, 4)
    attn_key, ffn_key = subkeys[1], subkeys[2]

    # Sub-layer 1: attention, residual add, normalize.
    attn_out = multi_head_attention(
        x, layer_params['attention'], num_heads, head_dim,
        attn_key, drop_rate, training
    )
    hidden = layer_norm(x + attn_out, layer_params['layer_norm1'])

    # Sub-layer 2: two-layer ReLU network with dropout on its output,
    # residual add, normalize.
    ffn_weights = layer_params['feed_forward']
    ffn_out = jax.nn.relu(hidden @ ffn_weights['W_ff_l1']) @ ffn_weights['W_ff_l2']
    ffn_out = dropout(ffn_key, drop_rate, ffn_out, training)

    return layer_norm(hidden + ffn_out, layer_params['layer_norm2'])


In [None]:
def transformer_forward_pass(token_ids: jax.Array, params: ModelParams,\
                             num_heads, drop_key, drop_rate, training):
    """Map a batch of token ids to next-token logits.

    `token_ids` is indexed as (batch, seq_len). The embedding table is used
    both to look up the input vectors and, transposed, to project the final
    hidden states onto the vocabulary.
    """
    embedding_table = params['embedding']
    blocks = params['layers']

    seq_len = token_ids.shape[1]
    d_model = embedding_table.shape[1]
    head_dim = d_model // num_heads

    # Token embeddings plus sinusoidal position information.
    x = embedding_table[token_ids] + get_positional_embeddings(seq_len, d_model)

    # One sub-key for the embedding dropout; the remainder is fanned out into
    # one key per block.
    drop_key, embed_key = jax.random.split(drop_key)
    x = dropout(embed_key, drop_rate, x, training)
    layer_keys = jax.random.split(drop_key, len(blocks))

    for layer_key, block in zip(layer_keys, blocks):
        x = transformer_block(x, block, num_heads, head_dim, layer_key,
                              drop_rate, training)

    return x @ embedding_table.T

In [None]:
def generate_text(forward_pass_fn, params, key, start_tokens, max_new_tokens, context_size, num_heads, dropout_rate, temperature=1.0):
    """
    Generates tokens autoregressively by sampling from the model.

    Args:
        forward_pass_fn: Callable invoked as
            ``forward_pass_fn(tokens, params, num_heads, key, dropout_rate,
            training=False)`` and returning logits indexed as
            (batch, seq_len, vocab).
        params: Model parameters, passed through to `forward_pass_fn`.
        key: JAX PRNG key driving the sampling.
        start_tokens: Prompt token ids, shape (num_tokens,) or
            (batch, num_tokens).
        max_new_tokens: Number of tokens to append.
        context_size: Only the last `context_size` tokens are fed to the model.
        num_heads: Passed through to `forward_pass_fn`.
        dropout_rate: Passed through to `forward_pass_fn`.
        temperature: Logits are divided by this value before sampling; values
            below 1 sharpen the distribution. ``temperature <= 0`` selects
            greedy (argmax) decoding.

    Returns:
        Array of shape (batch, num_tokens + max_new_tokens) containing the
        prompt followed by the generated tokens.
    """
    # Ensure tokens is a 2D array: (batch_size, num_tokens)
    tokens = start_tokens
    if tokens.ndim == 1:
        tokens = tokens[jnp.newaxis, :]

    for _ in range(max_new_tokens):
        # One fresh key per step. It is handed to the forward pass (which runs
        # with training=False, so dropout does not consume it) and then used
        # for sampling.
        key, step_key = jax.random.split(key)

        # Crop context if it's too long
        idx_cond = tokens[:, -context_size:]

        logits = forward_pass_fn(
            idx_cond, params, num_heads, step_key, dropout_rate, training=False
        )

        # Only the prediction for the last position is needed.
        last_token_logits = logits[:, -1, :]

        if temperature > 0:
            # Sample from the temperature-scaled distribution.
            next_token_id = jax.random.categorical(
                step_key, last_token_logits / temperature
            )
        else:
            # Zero temperature is the greedy limit.
            next_token_id = jnp.argmax(last_token_logits, axis=-1)

        tokens = jnp.concatenate(
            [tokens, next_token_id[:, jnp.newaxis]], axis=1
        )

    return tokens

In [None]:
def test_pass(inputs=None, targets=None, tokenizer=None, max_new_tokens=4):
    """Smoke-test an untrained one-layer model: forward pass plus generation.

    The function is self-contained so it runs on a fresh kernel: when no batch
    is supplied, a single short prompt is tokenized and used as input, with
    the same tokens shifted by one as the target.

    Args:
        inputs: Optional batch of token ids indexed as (batch, seq_len).
        targets: Optional batch of target token ids with the same layout.
        tokenizer: Optional tokenizer exposing ``encode``, ``decode`` and
            ``n_vocab``; defaults to the tiktoken GPT-2 encoding.
        max_new_tokens: Number of tokens to generate after each input row.
    """
    d_model = 256
    num_heads = 8
    drop_rate = 0.1
    context_size = 256

    if tokenizer is None:
        tokenizer = tiktoken.get_encoding('gpt2')

    if inputs is None or targets is None:
        token_ids = tokenizer.encode(
            "Once upon a time, there was a little girl named Lily."
        )
        inputs, targets = [token_ids[:-1]], [token_ids[1:]]

    inputs_jax = jnp.asarray(inputs)
    targets_jax = jnp.asarray(targets)

    params = init_params(jax.random.PRNGKey(22), tokenizer.n_vocab, d_model, 1)
    forward_key, gen_key = jax.random.split(jax.random.PRNGKey(9000))

    jit_forward_pass = jax.jit(
        transformer_forward_pass, static_argnames=('num_heads', 'drop_rate', 'training')
    )

    # A single forward pass must yield one logit vector per input position.
    logits = jit_forward_pass(inputs_jax, params, num_heads, forward_key,
                              drop_rate, False)
    assert logits.shape == (*inputs_jax.shape, tokenizer.n_vocab), logits.shape

    # Extend every input row by a few sampled tokens.
    generated = generate_text(
        jit_forward_pass, params, gen_key, inputs_jax, max_new_tokens,
        context_size, num_heads, drop_rate
    )

    for output_ids, target_ids in zip(generated, targets_jax):
        print(f"Target: {tokenizer.decode(target_ids.tolist())}")
        print(f"Output: {tokenizer.decode(output_ids.tolist())}\n")

## Training

In [None]:
def cross_entropy_loss(logits, targets):
    """Mean softmax cross-entropy of `logits` against integer class labels."""
    per_example_loss = optax.softmax_cross_entropy_with_integer_labels(
        logits=logits, labels=targets
    )
    return per_example_loss.mean()

In [None]:
def train_step(params, optimizer_state, optimizer, batch_inputs, batch_targets, drop_key, num_heads, drop_rate):
    """Performs a single training step: loss, gradients, and parameter update.

    Args:
        params: Current model parameters.
        optimizer_state: Current optax optimizer state.
        optimizer: The optax gradient transformation that owns
            `optimizer_state`.
        batch_inputs: Token ids fed to the model.
        batch_targets: Token ids the model should predict, same layout as
            `batch_inputs`.
        drop_key: PRNG key for this step's dropout masks.
        num_heads: Number of attention heads.
        drop_rate: Dropout probability.

    Returns:
        Tuple ``(new_params, new_optimizer_state, loss)``.
    """

    def compute_loss(current_params):
        logits = transformer_forward_pass(
            batch_inputs, current_params, num_heads, drop_key, drop_rate, training=True
        )
        # Flatten (batch, seq) so every token position is one classification.
        return cross_entropy_loss(
            logits.reshape(-1, logits.shape[-1]), batch_targets.reshape(-1)
        )

    loss, grads = jax.value_and_grad(compute_loss)(params)

    # Pass the current params as well: transformations that depend on the
    # parameter values (e.g. weight decay in optax.adamw) require them, and
    # the ones that do not simply ignore the argument.
    updates, optimizer_state = optimizer.update(grads, optimizer_state, params)
    params = optax.apply_updates(params, updates)

    return params, optimizer_state, loss

In [None]:
def eval_step(params, batch_inputs, batch_targets, num_heads, drop_rate, key):
    """Compute the loss for one batch with the model in inference mode."""
    logits = transformer_forward_pass(
        batch_inputs, params, num_heads, key, drop_rate, training=False
    )
    # Collapse every leading axis so each token position is one example.
    vocab_dim = logits.shape[-1]
    flat_logits = logits.reshape(-1, vocab_dim)
    flat_targets = batch_targets.reshape(-1)
    return cross_entropy_loss(flat_logits, flat_targets)

In [None]:
# Download (or load from the local cache) the TinyStories dataset.
ts_dataset = load_dataset("roneneldan/TinyStories")
batch_size = 32
max_length = 256
stride = 128

# NOTE: building each dataloader tokenizes the whole split up front, so this
# cell is slow on the full training split.
#
# num_workers=0 loads batches in the main process. The dataset already lives
# in memory as tensors, so a worker process buys nothing, and forking after
# JAX has started its threads triggers an `os.fork()` warning and can
# deadlock.
train_dataloader = create_dataloader(
    ts_dataset['train'], batch_size=batch_size, max_length=max_length,
    stride=stride, shuffle=True, drop_last=True, num_workers=0
)
val_dataloader = create_dataloader(
    ts_dataset['validation'], batch_size=batch_size, max_length=max_length,
    stride=stride, shuffle=False, drop_last=True, num_workers=0
)

KeyboardInterrupt: 

In [None]:
def do_training(train_dataloader, val_dataloader):
    """Train the transformer, validating, sampling and checkpointing each epoch.

    Args:
        train_dataloader: Yields ``(inputs, targets)`` batches for training;
            its ``dataset.tokenizer`` supplies the vocabulary size and is used
            to encode/decode the text sample.
        val_dataloader: Yields ``(inputs, targets)`` batches for validation.

    Side effects:
        Erases and recreates ``./model_checkpoints`` and writes one checkpoint
        per epoch into a sub-directory named after the 1-based epoch number.
        Progress, losses and a generated text sample are printed.
    """
    # --- 1. Hyperparameters ---
    d_model = 256
    num_layers = 6
    num_heads = 8
    dropout_rate = 0.1
    max_length = 256  # context window used when sampling text
    learning_rate = 1e-4
    num_epochs = 5

    # --- 2. Checkpointing ---
    path = ocp.test_utils.erase_and_create_empty('./model_checkpoints')
    checkpointer = ocp.AsyncCheckpointer(ocp.StandardCheckpointHandler())

    # Use the tokenizer the dataloader was built with.
    tokenizer = train_dataloader.dataset.tokenizer
    vocab_size = tokenizer.n_vocab

    # --- 3. Initialize Model and Optimizer ---
    main_key = jax.random.PRNGKey(0)
    init_key, train_key = jax.random.split(main_key)

    print("Initializing params...")

    params = init_params(init_key, vocab_size, d_model, num_layers)
    optimizer = optax.adam(learning_rate)
    optimizer_state = optimizer.init(params)

    jit_train_step = jax.jit(
        train_step, static_argnames=('optimizer', 'num_heads', 'drop_rate')
    )
    jit_eval_step = jax.jit(
        eval_step, static_argnames=('num_heads', 'drop_rate')
    )

    num_train_batches = len(train_dataloader)
    num_val_batches = len(val_dataloader)

    # The sampling prompt is the same every epoch, so encode it once.
    start_text = "as they turned around"
    start_tokens = jnp.asarray(tokenizer.encode(start_text))

    print("Begin training loop.")

    # --- 4. The Training Loop ---
    for epoch in range(num_epochs):
        # --- Training Phase ---
        total_train_loss = 0.0
        for i, (inputs, targets) in enumerate(train_dataloader):
            train_key, step_key = jax.random.split(train_key)

            inputs_jax = jnp.asarray(inputs)
            targets_jax = jnp.asarray(targets)

            params, optimizer_state, loss = jit_train_step(
                params, optimizer_state, optimizer, inputs_jax, targets_jax,
                step_key, num_heads, dropout_rate
            )

            total_train_loss += loss
            if (i + 1) % 50 == 0:
                print(f"--- Epoch {epoch+1}/{num_epochs} | Batch {i+1}/{num_train_batches} | Loss: {loss:.4f}")

        # Average once per epoch; an empty loader reports NaN instead of
        # failing.
        avg_train_loss = (
            total_train_loss / num_train_batches if num_train_batches else float('nan')
        )

        # --- Validation Phase ---
        total_val_loss = 0.0
        for inputs, targets in val_dataloader:
            # We don't need a new key for validation as dropout is disabled
            inputs_jax, targets_jax = jnp.asarray(inputs), jnp.asarray(targets)
            loss = jit_eval_step(
                params, inputs_jax, targets_jax, num_heads, dropout_rate, main_key
            )
            total_val_loss += loss

        avg_val_loss = (
            total_val_loss / num_val_batches if num_val_batches else float('nan')
        )
        print(f"--- End of Epoch {epoch+1} | Avg Train Loss: {avg_train_loss:.4f} | Avg Val Loss: {avg_val_loss:.4f} ---")

        # --- Generate and print a sample at the end of each epoch ---
        train_key, gen_key = jax.random.split(train_key)
        generated_tokens = generate_text(
            transformer_forward_pass, params, gen_key, start_tokens,
            max_new_tokens=32, context_size=max_length,
            num_heads=num_heads, dropout_rate=dropout_rate, temperature=0.5
        )

        # Decode the entire generated sequence
        decoded_text = tokenizer.decode(generated_tokens[0].tolist())
        print(f"Sample: {decoded_text}\n")

        # --- Checkpoint ---
        # Build the payload here so it holds the *current* params/optimizer
        # state, and give every epoch its own directory so a later save does
        # not collide with an earlier one.
        save_data = {'params': params, 'optimizer_state': optimizer_state}
        checkpointer.save(
            path / str(epoch + 1), args=ocp.args.StandardSave(save_data)
        )
        print(f"Checkpoint for epoch {epoch+1} saved.")

    # Saves run in the background; block until the last one has been written.
    checkpointer.wait_until_finished()


In [None]:
# Start the full training run on the dataloaders built above. Every epoch
# iterates over all training and validation batches, so this cell is
# long-running.
do_training(train_dataloader, val_dataloader)

Initializing params...
Begin training loop.


  self.pid = os.fork()


--- Epoch 1/5 | Batch 50/231441 | Loss: 10.4633
--- Epoch 1/5 | Batch 100/231441 | Loss: 10.0181
--- Epoch 1/5 | Batch 150/231441 | Loss: 9.5122
--- Epoch 1/5 | Batch 200/231441 | Loss: 9.0073
--- Epoch 1/5 | Batch 250/231441 | Loss: 8.3923
--- Epoch 1/5 | Batch 300/231441 | Loss: 7.8657
--- Epoch 1/5 | Batch 350/231441 | Loss: 7.3888
--- Epoch 1/5 | Batch 400/231441 | Loss: 6.9569
--- Epoch 1/5 | Batch 450/231441 | Loss: 6.6737
--- Epoch 1/5 | Batch 500/231441 | Loss: 6.4646
--- Epoch 1/5 | Batch 550/231441 | Loss: 6.3255
--- Epoch 1/5 | Batch 600/231441 | Loss: 6.1050
--- Epoch 1/5 | Batch 650/231441 | Loss: 6.0654
--- Epoch 1/5 | Batch 700/231441 | Loss: 6.0761
--- Epoch 1/5 | Batch 750/231441 | Loss: 5.9723
--- Epoch 1/5 | Batch 800/231441 | Loss: 5.9577
--- Epoch 1/5 | Batch 850/231441 | Loss: 5.9065
--- Epoch 1/5 | Batch 900/231441 | Loss: 6.0005
--- Epoch 1/5 | Batch 950/231441 | Loss: 6.0497
--- Epoch 1/5 | Batch 1000/231441 | Loss: 5.9992
--- Epoch 1/5 | Batch 1050/231441 | Lo

  self.pid = os.fork()


KeyboardInterrupt: 