In [6]:
%%capture
!pip install datasets

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [23]:
import os
import math
import pickle
import jax
import jax.numpy as jnp
from jax import random, lax, vmap
import flax.linen as nn
import optax
from types import SimpleNamespace
from functools import partial
from tqdm import tqdm
from datasets import load_dataset
from tokenizers import SentencePieceUnigramTokenizer
from typing import Any, Callable, Dict, List, Optional, Tuple
from flax.training import train_state, orbax_utils
import orbax.checkpoint as ocp

In [8]:
os.environ['JAX_PLATFORM_NAME'] = 'tpu'
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
print("JAX devices:", len(jax.devices()))

JAX devices: 8


In [34]:
class LLaMAConfig:
    """Configuration for LLaMA model and dataset preparation"""
    # Model architecture settings
    vocab_size: int = 32000
    dim: int = 512  # Hidden dimension
    n_layers: int = 8  # Number of transformer layers
    n_heads: int = 8  # Number of attention heads
    n_kv_heads: int = 4  # Number of key/value heads (for grouped-query attention)
    max_seq_len: int = 2048  # Maximum sequence length
    dropout_rate: float = 0.0  # Dropout rate
    # RoPE settings
    rope_theta: float = 10000.0  # Base for rotary embeddings
    
    # Training settings
    batch_size: int = 16
    learning_rate: float = 3e-4
    weight_decay: float = 0.1
    warmup_steps: int = 1000
    max_steps: int = 100000
    
    # Generation settings
    temperature: float = 0.8
    top_k: int = 40
    top_p: float = 0.95
    
    # Dataset settings
    dataset_name: str = "wikitext"  # Default dataset
    dataset_config: str = "wikitext-2-raw-v1"  # Dataset configuration name
    dataset_path: str = None  # Path to local dataset
    split: str = "train"  # Dataset split to use
    text_column: str = None  # Column to use for text (auto-detected if None)
    
    # Tokenizer settings
    tokenizer_path: str = "llama_tokenizer.json"  # Path to save/load tokenizer
    tokenizer_save_path: str = None  # Optional different path to save tokenizer
    use_existing_tokenizer: bool = False  # Whether to use existing tokenizer
    tokenizer_sample_size: int = 10000  # Number of samples to use for tokenizer training
    special_tokens: list = ["<pad>", "<unk>", "<bos>", "<eos>"]  # List of special tokens (defaults to ["<pad>", "<unk>", "<bos>", "<eos>"])
    
    # Processing options
    format_type: str = "flat"  # Dataset format: "flat", "chunked", or default
    chunk_size: int = 1024  # Size of chunks for "chunked" format
    chunk_overlap: int = 0  # Overlap between chunks
    add_eos_between_examples: bool = True  # Whether to add EOS token between examples
    add_eos_between_segments: bool = False  # Whether to add EOS token between segments
    keep_text_column: bool = False  # Whether to keep original text column

    def __post_init__(self):
        """Initialize defaults for None values"""
        if self.special_tokens is None:
            self.special_tokens = ["<pad>", "<unk>", "<bos>", "<eos>"]
        if self.tokenizer_save_path is None:
            self.tokenizer_save_path = self.tokenizer_path

In [10]:
class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization"""
    dim: int
    eps: float = 1e-5

    @nn.compact
    def __call__(self, x):
        weight = self.param('weight', nn.initializers.ones, (self.dim,))
        variance = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
        x = x * jnp.reciprocal(jnp.sqrt(variance + self.eps))
        return x * weight

In [11]:
# Rotary Position Embeddings
def precompute_freqs_cis(dim: int, max_seq_len: int, theta: float = 10000.0):
    """Precompute the frequency tensor for complex exponentials (rotary embeddings)."""
    # Compute the frequencies for each feature dimension
    freqs = 1.0 / (theta ** (jnp.arange(0, dim // 2, dtype=jnp.float32) / dim))
    t = jnp.arange(max_seq_len, dtype=jnp.float32)
    # Create the frequency matrix by outer product
    freqs = jnp.outer(t, freqs)
    # Convert to complex exponentials
    return jnp.complex64(jnp.exp(1j * freqs))

def apply_rotary_emb(xq, xk, freqs_cis):
    """Apply rotary embeddings to the query and key tensors."""
    # Reshape inputs to isolate the last dimension into pairs for complex multiplication
    xq_r, xk_r = jnp.reshape(xq, (*xq.shape[:-1], -1, 2)), jnp.reshape(xk, (*xk.shape[:-1], -1, 2))

    # Convert to complex numbers
    xq_complex = jnp.complex64(xq_r[..., 0] + 1j * xq_r[..., 1])
    xk_complex = jnp.complex64(xk_r[..., 0] + 1j * xk_r[..., 1])

    # Reshape frequency cis for broadcasting
    freqs_cis = jnp.reshape(freqs_cis, (1, freqs_cis.shape[0], 1, freqs_cis.shape[1]))

    # Apply rotation through complex multiplication
    xq_out = xq_complex * freqs_cis
    xk_out = xk_complex * freqs_cis

    # Convert back to real tensor and reshape
    xq = jnp.stack([jnp.real(xq_out), jnp.imag(xq_out)], axis=-1).reshape(xq.shape)
    xk = jnp.stack([jnp.real(xk_out), jnp.imag(xk_out)], axis=-1).reshape(xk.shape)

    return xq, xk

In [12]:
@partial(jax.jit)
def flash_attention(q, k, v, mask=None, scale=None):
    """
    Optimized implementation of attention mechanism using JAX primitives
    for better compiler optimization and memory efficiency.
    """
    batch_size, num_heads, seq_len, head_dim = q.shape

    # Compute scale if not provided
    if scale is None:
        scale = 1.0 / jnp.sqrt(head_dim)

    # Compute attention scores with fused operation
    # Fuse transpose and matmul for better compiler optimization
    scores = jnp.einsum('bhid,bhjd->bhij', q, k) * scale

    # Apply causal mask if provided
    if mask is not None:
        scores = scores + mask

    # Stabilize softmax by subtracting max value
    # This prevents overflow and allows for better precision
    scores_max = jnp.max(scores, axis=-1, keepdims=True)
    scores = scores - lax.stop_gradient(scores_max)

    # Apply softmax with higher precision
    attn_weights = jnp.exp(scores)
    attn_weights = attn_weights / jnp.sum(attn_weights, axis=-1, keepdims=True)

    # Compute attention output with fused operation
    output = jnp.einsum('bhij,bhjd->bhid', attn_weights, v)

    return output

In [13]:
def swiglu(x, w1, w2, w3):
    """SwiGLU activation function using Flax modules"""
    return w2(jax.nn.silu(w3(x)) * w1(x))

In [14]:
class LLaMACausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with support for grouped-query attention"""
    config: LLaMAConfig

    def setup(self):
        config = self.config
        dim = config.dim
        n_heads = config.n_heads
        n_kv_heads = config.n_kv_heads
        head_dim = dim // n_heads

        # QKV projections
        self.wq = nn.Dense(n_heads * head_dim,
                          kernel_init=nn.initializers.variance_scaling(1.0, 'fan_in', 'normal'))
        self.wk = nn.Dense(n_kv_heads * head_dim,
                          kernel_init=nn.initializers.variance_scaling(1.0, 'fan_in', 'normal'))
        self.wv = nn.Dense(n_kv_heads * head_dim,
                          kernel_init=nn.initializers.variance_scaling(1.0, 'fan_in', 'normal'))

        # Output projection
        self.wo = nn.Dense(dim,
                          kernel_init=nn.initializers.variance_scaling(1.0, 'fan_in', 'normal'))

        # QK normalization for improved stability
        self.q_norm = RMSNorm(head_dim)
        self.k_norm = RMSNorm(head_dim)

    def __call__(self, x, freqs_cis, mask=None, deterministic=True):
        B, T, C = x.shape
        config = self.config
        n_heads = config.n_heads
        n_kv_heads = config.n_kv_heads
        head_dim = C // n_heads

        # Linear projections
        q = self.wq(x).reshape(B, T, n_heads, head_dim)
        k = self.wk(x).reshape(B, T, n_kv_heads, head_dim)
        v = self.wv(x).reshape(B, T, n_kv_heads, head_dim)

        # Apply QK normalization
        q = jnp.swapaxes(self.q_norm(jnp.swapaxes(q, 1, 2)), 1, 2)
        k = jnp.swapaxes(self.k_norm(jnp.swapaxes(k, 1, 2)), 1, 2)

        # Apply rotary embeddings
        q, k = apply_rotary_emb(q, k, freqs_cis[:T])

        # Repeat k and v heads if n_heads > n_kv_heads (grouped-query attention)
        if n_heads > n_kv_heads:
            k = jnp.repeat(k, n_heads // n_kv_heads, axis=2)
            v = jnp.repeat(v, n_heads // n_kv_heads, axis=2)

        # Transpose tensors for attention computation (B, H, T, D)
        q, k, v = map(lambda x: jnp.swapaxes(x, 1, 2), (q, k, v))

        # Use flash attention (conceptually)
        output = flash_attention(q, k, v, mask)

        # Transpose output and project back to full dimension
        output = jnp.swapaxes(output, 1, 2).reshape(B, T, -1)
        output = self.wo(output)

        return output

# LLaMA MLP Module
class LLaMAMLP(nn.Module):
    """Feed-forward network with SwiGLU activation"""
    config: LLaMAConfig

    def setup(self):
        dim = self.config.dim
        hidden_dim = 4 * dim  # 4x expansion

        # Linear projections
        self.w1 = nn.Dense(hidden_dim,
                         kernel_init=nn.initializers.variance_scaling(1.0, 'fan_in', 'normal'))
        self.w2 = nn.Dense(dim,
                         kernel_init=nn.initializers.variance_scaling(1.0, 'fan_in', 'normal'))
        self.w3 = nn.Dense(hidden_dim,
                         kernel_init=nn.initializers.variance_scaling(1.0, 'fan_in', 'normal'))

    def __call__(self, x):
        return swiglu(x, self.w1, self.w2, self.w3)

# LLaMA Transformer Block
class LLaMABlock(nn.Module):
    """LLaMA transformer block"""
    config: LLaMAConfig

    def setup(self):
        self.attention_norm = RMSNorm(self.config.dim)
        self.attention = LLaMACausalSelfAttention(self.config)
        self.ffn_norm = RMSNorm(self.config.dim)
        self.ffn = LLaMAMLP(self.config)
        self.dropout = nn.Dropout(self.config.dropout_rate)

    def __call__(self, x, freqs_cis, mask=None, deterministic=True):
        # Pre-norm for attention
        h = x + self.dropout(
            self.attention(self.attention_norm(x), freqs_cis, mask, deterministic),
            deterministic=deterministic
        )

        # Pre-norm for FFN
        out = h + self.dropout(
            self.ffn(self.ffn_norm(h)),
            deterministic=deterministic
        )

        return out

# Full LLaMA Model
class LLaMA3(nn.Module):
    """LLaMA language model"""
    config: LLaMAConfig

    def setup(self):
        config = self.config

        # Token embeddings
        self.token_embedding = nn.Embed(
            config.vocab_size,
            config.dim,
            embedding_init=nn.initializers.normal(stddev=0.02)
        )

        # Transformer blocks
        self.blocks = [LLaMABlock(config) for _ in range(config.n_layers)]

        # Final layer norm
        self.norm_f = RMSNorm(config.dim)

        # Output projection (tied with embeddings)
        self.lm_head = nn.Dense(
            config.vocab_size,
            kernel_init=nn.initializers.normal(stddev=0.02),
            use_bias=False
        )

        # Pre-compute rotary embeddings
        self.freqs_cis = precompute_freqs_cis(
            config.dim // config.n_heads,
            config.max_seq_len,
            config.rope_theta
        )

    def __call__(self, input_ids, deterministic=True):
        B, T = input_ids.shape

        # Create causal attention mask
        mask = jnp.tril(
            jnp.ones((self.config.max_seq_len, self.config.max_seq_len))
        )
        mask = jnp.where(mask == 0, jnp.finfo(jnp.float32).min, 0.0)
        mask = mask[None, None, :T, :T]

        # Get embeddings
        h = self.token_embedding(input_ids)

        # Apply transformer blocks
        for block in self.blocks:
            h = block(h, self.freqs_cis, mask, deterministic)

        # Apply final normalization
        h = self.norm_f(h)

        # Get logits
        logits = self.lm_head(h)

        return logits

    def generate(self, input_ids, max_new_tokens, rng_key, temperature=0.8, top_k=40, top_p=0.95):
        """Generate text using the model"""
        B, T = input_ids.shape

        # Create initial output array
        output = input_ids

        # Generate tokens
        for i in range(max_new_tokens):
            # Keep the context within max sequence length
            curr_input = output[:, -self.config.max_seq_len:]

            # Get logits for the next token
            logits = self(curr_input, deterministic=True)[:, -1, :]

            # Apply temperature
            logits = logits / temperature

            # Apply top-k filtering
            if top_k > 0:
                top_k_v, top_k_i = jax.lax.top_k(logits, top_k)
                indices_to_remove = jnp.broadcast_to(
                    jnp.arange(logits.shape[-1]) < top_k_i[:, -1:],
                    logits.shape
                )
                logits = jnp.where(indices_to_remove, logits, jnp.finfo(jnp.float32).min)

            # Apply top-p (nucleus) filtering
            if top_p < 1.0:
                sorted_indices = jnp.argsort(logits, axis=-1)[:, ::-1]  # Sort indices in descending order
                sorted_logits = jnp.take_along_axis(logits, sorted_indices, axis=-1)  # Get sorted values

                cumulative_probs = jnp.cumsum(jax.nn.softmax(sorted_logits, axis=-1), axis=-1)

                # Remove tokens with cumulative probability above the threshold
                sorted_indices_to_remove = cumulative_probs > top_p
                # Shift the indices to the right to keep the first token above threshold
                sorted_indices_to_remove = jnp.roll(sorted_indices_to_remove, 1, axis=1)
                sorted_indices_to_remove = sorted_indices_to_remove.at[:, 0].set(False)

                # Scatter sorted tensors to original indexing
                indices_to_remove = jnp.zeros_like(logits, dtype=bool)
                indices_to_remove = indices_to_remove.at[jnp.arange(B)[:, None], sorted_indices].set(sorted_indices_to_remove)
                logits = jnp.where(indices_to_remove, jnp.finfo(jnp.float32).min, logits)

            # Sample from the filtered distribution
            rng_key, sample_key = random.split(rng_key)
            next_token = random.categorical(sample_key, logits, shape=(B,))

            # Append the sampled token to the sequence
            output = jnp.concatenate([output, next_token[:, None]], axis=1)

        return output

In [15]:
# Create initial training state with Flax TrainState
def create_train_state(model, config, rng_key):
    """Create initial training state."""
    # Initialize model parameters
    init_params = model.init(rng_key, jnp.ones((1, 1), dtype=jnp.int32))

    # Create learning rate schedule
    lr_schedule = optax.warmup_cosine_decay_schedule(
        init_value=0.0,
        peak_value=config.learning_rate,
        warmup_steps=config.warmup_steps,
        decay_steps=config.max_steps,
        end_value=config.learning_rate * 0.1
    )

    # Create optimizer
    optimizer = optax.chain(
        optax.clip_by_global_norm(1.0),
        optax.adamw(
            learning_rate=lr_schedule,
            b1=0.9,
            b2=0.95,
            eps=1e-8,
            weight_decay=config.weight_decay
        )
    )

    # Create and return train state - ensure parameters have consistent structure
    return train_state.TrainState.create(
        apply_fn=model.apply,
        params=init_params,
        tx=optimizer
    )

In [16]:
def train_step(state, batch, dropout_rng):
    """Single training step"""
    inputs, targets = batch

    # Define loss function
    def loss_fn(params):
        # Apply model with correct parameter structure
        logits = state.apply_fn({"params": params}, inputs, deterministic=False, rngs={'dropout': dropout_rng})

        # Reshape for cross entropy
        logits = logits.reshape(-1, logits.shape[-1])
        targets_flat = targets.reshape(-1)

        # Compute cross entropy loss
        loss = optax.softmax_cross_entropy_with_integer_labels(
            logits, targets_flat
        ).mean()

        return loss

    # Get gradients - make sure we're getting gradients for the actual parameters
    grad_fn = jax.value_and_grad(loss_fn)

    # Check param structure and extract the inner params if needed
    if "params" in state.params:
        actual_params = state.params["params"]
    else:
        actual_params = state.params

    loss, grads = grad_fn(actual_params)

    # Now wrap the gradients in the same structure as state.params for apply_gradients
    if "params" in state.params:
        wrapped_grads = {"params": grads}
    else:
        wrapped_grads = grads

    # Update state with correctly structured gradients
    new_state = state.apply_gradients(grads=wrapped_grads)

    return new_state, loss

# JIT-compiled training step for efficiency
train_step_jit = jax.jit(train_step)

# FIXED: Define pmapped training step with improved dropout RNG handling
def train_step_pmap_wrapper(state, batch, dropout_rng):
    # Wrapper for consistency when pmapping
    return train_step(state, batch, dropout_rng)

train_step_pmap = jax.pmap(train_step_pmap_wrapper, axis_name='batch')

In [17]:
def evaluate(model_apply_fn, params, eval_data, config, num_batches=10):
    """Evaluate model on validation data"""
    key = random.PRNGKey(42)
    total_loss = 0.0

    for i in range(num_batches):
        key, batch_key = random.split(key)
        inputs, targets = get_batch(batch_key, eval_data, config)

        # Forward pass
        # Check if params already has a 'params' key
        if 'params' in params:
            logits = model_apply_fn(params, inputs, deterministic=True)
        else:
            logits = model_apply_fn({'params': params}, inputs, deterministic=True)

        # Reshape for cross entropy
        logits = logits.reshape(-1, logits.shape[-1])
        targets = targets.reshape(-1)

        # Compute cross entropy loss
        loss = optax.softmax_cross_entropy_with_integer_labels(
            logits, targets
        ).mean()

        total_loss += loss

    avg_loss = total_loss / num_batches
    perplexity = jnp.exp(avg_loss)

    return avg_loss, perplexity

In [18]:
def train_llama(config, num_epochs=5, steps_per_epoch=1000, save_every=1000):
    """Train LLaMA model"""
    # Initialize TPU
    n_devices = initialize_tpu()

    # Setup model
    model = LLaMA3(config)
    rng_key = random.PRNGKey(42)

    # Create training state
    state = create_train_state(model, config, rng_key)

    # Replicate the state across devices for multi-device training
    if n_devices > 1:
        state = jax.device_put_replicated(state, jax.devices())

    # Prepare datasets
    train_dataset, tokenizer = prepare_datasets(config)

    # Create checkpoint directory with absolute path
    checkpoint_dir = os.path.abspath("llama_checkpoints")
    os.makedirs(checkpoint_dir, exist_ok=True)
    
    # Setup Orbax checkpointer
    checkpointer = ocp.PyTreeCheckpointer()
    options = ocp.CheckpointManagerOptions(
        max_to_keep=3,
        create=True
    )
    checkpoint_manager = ocp.CheckpointManager(
        checkpoint_dir, 
        checkpointer, 
        options
    )

    # Training loop
    rng_key = random.PRNGKey(0)
    step = 0
    total_steps = num_epochs * steps_per_epoch

    print(f"Starting training for {num_epochs} epochs ({total_steps} steps)")

    for epoch in range(num_epochs):
        epoch_loss = 0.0

        for step_in_epoch in tqdm(range(steps_per_epoch), desc=f"Epoch {epoch+1}/{num_epochs}"):
            # Get a batch of data
            rng_key, batch_key = random.split(rng_key)
            batch_data = get_batch(batch_key, train_dataset, config)

            # Training step
            if n_devices > 1:
                # Calculate per-device batch size
                per_device_batch = config.batch_size // n_devices
                if per_device_batch == 0:
                    raise ValueError(f"Batch size {config.batch_size} is too small for {n_devices} devices")

                # Reshape batch data for multi-device training
                inputs, targets = batch_data

                # Reshape to (n_devices, per_device_batch, seq_len)
                inputs = inputs.reshape((n_devices, per_device_batch, inputs.shape[1]))
                targets = targets.reshape((n_devices, per_device_batch, targets.shape[1]))

                batch_data = (inputs, targets)

                # Create per-device RNG keys with proper shape for pmap
                rng_key, dropout_keys = create_device_rng_keys(rng_key, n_devices)

                # Apply pmapped training step
                state, loss = train_step_pmap(state, batch_data, dropout_keys)

                # Average loss across devices
                loss = jnp.mean(loss)
            else:
                # Single device training
                rng_key, dropout_key = random.split(rng_key)
                state, loss = train_step_jit(state, batch_data, dropout_key)

            epoch_loss += loss
            step += 1

            # Log progress
            if step % 100 == 0:
                print(f"Step {step}/{total_steps}, Loss: {loss:.4f}")

            # Save checkpoint
            if step % save_every == 0:
                if n_devices > 1:
                    # For multi-device, save only the first copy
                    save_state = jax.tree.map(lambda x: x[0], state)
                else:
                    save_state = state

                # Save checkpoint using Orbax
                save_args = orbax_utils.save_args_from_target(save_state)
                checkpoint_manager.save(step, save_state, save_kwargs={'save_args': save_args})
                print(f"Checkpoint saved at step {step}")

                # Generate sample text
                if n_devices > 1:
                    sample_params = jax.tree.map(lambda x: x[0], state.params)
                else:
                    sample_params = state.params

                prompt = tokenizer.encode("Once upon a time").ids
                prompt_tensor = jnp.array([prompt])

                sample_rng = random.PRNGKey(step)
                # Check if sample_params already has a 'params' key
                if 'params' in sample_params:
                    generated = model.apply(
                        sample_params,
                        prompt_tensor,
                        max_new_tokens=50,
                        rng_key=sample_rng,
                        temperature=config.temperature,
                        top_k=config.top_k,
                        top_p=config.top_p,
                        method=model.generate
                    )
                else:
                    generated = model.apply(
                        {"params": sample_params},
                        prompt_tensor,
                        max_new_tokens=50,
                        rng_key=sample_rng,
                        temperature=config.temperature,
                        top_k=config.top_k,
                        top_p=config.top_p,
                        method=model.generate
                    )
                generated_text = tokenizer.decode(generated[0].tolist())
                print(f"\nSample generation at step {step}:\n{generated_text}\n")

        # End of epoch
        avg_epoch_loss = epoch_loss / steps_per_epoch
        print(f"Epoch {epoch+1} complete. Average loss: {avg_epoch_loss:.4f}")

        # Evaluate on validation set
        if n_devices > 1:
            eval_params = jax.tree.map(lambda x: x[0], state.params)
        else:
            eval_params = state.params

        # Validation loss and perplexity
        val_loss, perplexity = evaluate(model.apply, eval_params, train_dataset, config)
        print(f"Validation Loss: {val_loss:.4f}, Perplexity: {perplexity:.2f}")

    # Save final model
    if n_devices > 1:
        final_state = jax.tree.map(lambda x: x[0], state)
    else:
        final_state = state

    # Save final checkpoint using Orbax
    save_args = orbax_utils.save_args_from_target(final_state)
    checkpoint_manager.save(
        total_steps, 
        final_state, 
        save_kwargs={'save_args': save_args}
    )
    print("Training complete. Final model saved.")

    return final_state

In [19]:
def generate_text(model, params, tokenizer, prompt, max_new_tokens=100, temperature=0.8):
    """Generate text from a prompt"""
    # Ensure prompt is a string
    if not isinstance(prompt, str):
        prompt = str(prompt)

    prompt_tokens = tokenizer.encode(prompt).ids
    prompt_tensor = jnp.array([prompt_tokens])

    rng_key = random.PRNGKey(0)
    # Check if params already has a 'params' key
    if 'params' in params:
        generated = model.apply(
            params,
            prompt_tensor,
            max_new_tokens=max_new_tokens,
            rng_key=rng_key,
            temperature=temperature,
            top_k=40,
            top_p=0.95,
            method=model.generate
        )
    else:
        generated = model.apply(
            {"params": params},
            prompt_tensor,
            max_new_tokens=max_new_tokens,
            rng_key=rng_key,
            temperature=temperature,
            top_k=40,
            top_p=0.95,
            method=model.generate
        )

    # Convert jnp array to Python list before decoding
    generated_text = tokenizer.decode(generated[0].tolist())
    return generated_text

In [None]:
def get_batch(key, data, config):
    """Create a batch of data for training"""
    batch_size = config.batch_size
    seq_len = config.max_seq_len

    # Generate random starting indices
    total_tokens = len(data["input_ids"]) - seq_len - 1  # -1 for target shifting

    # Make sure we have enough tokens
    if total_tokens <= 0:
        raise ValueError(f"Not enough tokens in dataset. Found {len(data['input_ids'])}, need at least {seq_len + 2}")

    # Generate batch_size random starting points
    ix = random.randint(key, (batch_size,), 0, total_tokens)

    # Create input and target sequences
    x = jnp.stack([jnp.array(data["input_ids"][i:i+seq_len]) for i in ix])
    y = jnp.stack([jnp.array(data["input_ids"][i+1:i+seq_len+1]) for i in ix])

    return x, y

In [20]:
def load_checkpoint(checkpoint_dir, step=None):
    """Load a checkpoint from the given directory"""
    checkpoint_dir = os.path.abspath(checkpoint_dir)
    
    # Setup Orbax checkpointer
    checkpointer = ocp.PyTreeCheckpointer()
    options = ocp.CheckpointManagerOptions(create=False)
    checkpoint_manager = ocp.CheckpointManager(
        checkpoint_dir, 
        checkpointer, 
        options
    )
    
    # Get the latest step if none specified
    if step is None:
        step = checkpoint_manager.latest_step()
        if step is None:
            raise ValueError(f"No checkpoints found in {checkpoint_dir}")
    
    # Create a dummy state to restore structure
    model = LLaMA3(LLaMAConfig())
    rng_key = random.PRNGKey(0)
    dummy_state = create_train_state(model, LLaMAConfig(), rng_key)
    
    # Restore checkpoint
    restored_state = checkpoint_manager.restore(step, dummy_state)
    print(f"Restored checkpoint from step {step}")
    
    return restored_state, step

In [36]:
def prepare_datasets(config):
    """Load and prepare datasets for LLaMA-like LLM training with flexible dataset support"""
    from datasets import load_dataset
    from tokenizers import SentencePieceUnigramTokenizer, Tokenizer, models, pre_tokenizers
    import os
    
    # Load dataset based on configuration
    if hasattr(config, 'dataset_name') and config.dataset_name:
        # Load from Hugging Face datasets
        if hasattr(config, 'dataset_config') and config.dataset_config:
            dataset = load_dataset(config.dataset_name, config.dataset_config, split=config.split)
        else:
            dataset = load_dataset(config.dataset_name, split=config.split)
    elif hasattr(config, 'dataset_path') and config.dataset_path:
        # Load from local files
        dataset = load_dataset(
            'json' if config.dataset_path.endswith('.json') else 'text', 
            data_files=config.dataset_path, 
            split=config.split
        )
    else:
        # Default to tiny_shakespeare as fallback
        dataset = load_dataset("karpathy/tiny_shakespeare", split="train")
    
    # Print the dataset structure
    print(f"Dataset loaded: {config.dataset_name if hasattr(config, 'dataset_name') else 'local'}")
    print("Dataset structure:", list(dataset.features.keys()))
    print(f"Dataset size: {len(dataset)} examples")
    
    # Determine text column
    text_column = getattr(config, 'text_column', None)
    if not text_column:
        column_names = list(dataset.features.keys())
        text_column = next((col for col in ['text', 'content', 'document'] 
                          if col in column_names), column_names[0])
    print(f"Using '{text_column}' as the text column")
    
    # Initialize tokenizer - either load existing or train new one
    tokenizer_path = getattr(config, 'tokenizer_path', "llama_tokenizer.json")
    if os.path.exists(tokenizer_path) and getattr(config, 'use_existing_tokenizer', False):
        print(f"Loading existing tokenizer from {tokenizer_path}")
        tokenizer = Tokenizer.from_file(tokenizer_path)
    else:
        print("Training new tokenizer")
        
        # Set special tokens from config or use defaults
        special_tokens = getattr(config, 'special_tokens', ["<pad>", "<unk>", "<bos>", "<eos>"])
        
        # Ensure <unk> token is in the special tokens
        if "<unk>" not in special_tokens:
            special_tokens.append("<unk>")
            print("Added <unk> token to special tokens")
        
        # Get sample texts for training the tokenizer
        sample_size = min(getattr(config, 'tokenizer_sample_size', 10000), len(dataset))
        
        # Sample texts, handling potentially different data formats
        sample_texts = []
        for example in dataset.select(range(sample_size)):
            text = example.get(text_column, "")
            if isinstance(text, str) and text.strip():
                sample_texts.append(text)
            elif isinstance(text, list) and all(isinstance(t, str) for t in text):
                sample_texts.extend([t for t in text if t.strip()])
        
        # Create and train the tokenizer
        tokenizer = SentencePieceUnigramTokenizer()
        tokenizer.train_from_iterator(
            sample_texts,
            vocab_size=config.vocab_size,
            special_tokens=special_tokens,
            unk_token="<unk>"  # Explicitly specify the unk token
        )
        
        # Save the tokenizer - make sure path is a valid string
        tokenizer_save_path = getattr(config, 'tokenizer_save_path', tokenizer_path)
        # Ensure the path is a valid string and not None
        if tokenizer_save_path is None:
            tokenizer_save_path = "llama_tokenizer.json"
        
        print(f"Saving tokenizer to {tokenizer_save_path}")
        tokenizer.save(tokenizer_save_path)
        print(f"Tokenizer saved to {tokenizer_save_path}")
    
    # Define tokenize function with handling for unknown tokens
    def tokenize_function(example):
        try:
            # Handle different text formats (string, list of strings, etc.)
            text = example.get(text_column, "")
            
            if isinstance(text, str) and text.strip():
                # Use encode with add_special_tokens=False to avoid adding special tokens automatically
                encoded = tokenizer.encode(text)
                return {"input_ids": encoded.ids}
            elif isinstance(text, list) and all(isinstance(t, str) for t in text):
                # Handle list of strings (e.g., for dialogue datasets)
                tokens = []
                for t in text:
                    if t.strip():
                        encoded = tokenizer.encode(t)
                        tokens.extend(encoded.ids)
                        if getattr(config, 'add_eos_between_segments', False):
                            # Add EOS token between segments if configured
                            eos_id = tokenizer.token_to_id("<eos>")
                            if eos_id is not None:
                                tokens.append(eos_id)
                return {"input_ids": tokens}
            
            return {"input_ids": []}
        except Exception as e:
            print(f"Error tokenizing example: {e}")
            # Return empty list for problematic examples
            return {"input_ids": []}
    
    # Tokenize dataset with progress reporting
    print("Tokenizing dataset...")
    tokenized_dataset = dataset.map(
        tokenize_function,
        remove_columns=[col for col in dataset.column_names if col != text_column] 
        if getattr(config, 'keep_text_column', False) else dataset.column_names,
        batched=False,  # Process one example at a time for better error handling
        desc="Tokenizing"
    )
    
    # Filter out empty examples
    original_size = len(tokenized_dataset)
    tokenized_dataset = tokenized_dataset.filter(lambda x: len(x["input_ids"]) > 0)
    filtered_size = len(tokenized_dataset)
    if original_size != filtered_size:
        print(f"Filtered out {original_size - filtered_size} empty examples")
    
    # Process based on the specified format
    format_type = getattr(config, 'format_type', "flat")
    
    if format_type == "flat":
        # Flatten all token sequences into one long sequence
        all_tokens = []
        for example in tokenized_dataset:
            all_tokens.extend(example["input_ids"])
            
            # Add EOS token between examples if configured
            if getattr(config, 'add_eos_between_examples', False):
                eos_id = tokenizer.token_to_id("<eos>")
                if eos_id is not None:
                    all_tokens.append(eos_id)
                    
        print(f"Created flattened dataset with {len(all_tokens)} tokens")
        return {"input_ids": all_tokens}, tokenizer
    
    elif format_type == "chunked":
        # Create fixed-length chunks of tokens for training
        chunk_size = getattr(config, 'chunk_size', 512)
        overlap = getattr(config, 'chunk_overlap', 0)
        
        chunked_datasets = []
        for example in tokenized_dataset:
            tokens = example["input_ids"]
            for i in range(0, len(tokens) - chunk_size + 1, chunk_size - overlap):
                chunk = tokens[i:i + chunk_size]
                if len(chunk) == chunk_size:  # Only keep full-sized chunks
                    chunked_datasets.append({"input_ids": chunk})
        
        print(f"Created {len(chunked_datasets)} chunks of size {chunk_size}")
        return chunked_datasets, tokenizer
    
    else:
        # Keep dataset as is (one example per entry)
        print(f"Using dataset with {len(tokenized_dataset)} separate examples")
        return tokenized_dataset, tokenizer

In [None]:
config = LLaMAConfig()
# Create checkpoint directory with absolute path
checkpoint_dir = os.path.abspath("llama_checkpoints")
os.makedirs(checkpoint_dir, exist_ok=True)

# Train the model
final_state = train_llama(config, num_epochs=5, steps_per_epoch=10,save_every= 10)

# Generate some text
model = LLaMA3(config)
dataset, tokenizer = prepare_datasets(config)

prompt = "In a distant galaxy"
generated_text = generate_text(model, final_state.params, tokenizer, prompt)

print("\nGenerated text:")
print(generated_text)

Found 8 JAX devices: [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1), TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1), TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
TPU devices detected. Setting up for distributed training.
Dataset loaded: wikitext
Dataset structure: ['text']
Dataset size: 36718 examples
Using 'text' as the text column
Training new tokenizer


Saving tokenizer to llama_tokenizer.json
Tokenizer saved to llama_tokenizer.json
Tokenizing dataset...


Tokenizing: 100%|██████████| 36718/36718 [00:10<00:00, 3586.19 examples/s]
Filter: 100%|██████████| 36718/36718 [00:01<00:00, 24067.68 examples/s]


Filtered out 12951 empty examples
Created flattened dataset with 3228097 tokens




Starting training for 5 epochs (50 steps)


Epoch 1/5:  90%|█████████ | 9/10 [00:20<00:00,  1.55it/s]

Checkpoint saved at step 10

Sample generation at step 10:
Once upon a time Lu Campbelnication Chronicle informal chamberoticgrass Forever chromo León Brooke faction Chic 10 CB Rad illness east explode Mennonite bridgehead Rand installment fel gam extensive month herodgesuniversit composition bloodWork website allegeontemporar Th Budd efficien Movie Indo mobili Denn glyph strong flood Mix



Epoch 1/5: 100%|██████████| 10/10 [00:45<00:00,  4.52s/it]


Epoch 1 complete. Average loss: 10.4922
Validation Loss: 10.4245, Perplexity: 33673.60


Epoch 2/5:  90%|█████████ | 9/10 [00:03<00:00,  2.85it/s]

Checkpoint saved at step 20

Sample generation at step 20:
Once upon a time stair Nc devou Dramaotte near afterzong come Kon Jesuit Toddicive naturali Patriot deemface Wrappsupplycast 800 Undermailengage news Thak grayny negative Finalediment Four Hero Hisrelevan Clin festivbe ก fighter striking ambassador speed weaken basinjosi Abraham



Epoch 2/5: 100%|██████████| 10/10 [00:18<00:00,  1.89s/it]


Epoch 2 complete. Average loss: 10.3284
Validation Loss: 10.1869, Perplexity: 26552.34


Epoch 3/5:  90%|█████████ | 9/10 [00:03<00:00,  2.89it/s]

Checkpoint saved at step 30

Sample generation at step 30:
Once upon a time use transitionprimari origin judgement geographicalday feti extravagan in pharmac Berardiulgar Op Vert microscop Qui19 Age televis Heli chalkmb automobileers Balk Hutch old consiste Iron Jackson Brown Paradis undul arrangement substitu Vir Cooper Shawn theropod commissionerule34 Eff grille Cre recurren



Epoch 3/5: 100%|██████████| 10/10 [00:19<00:00,  1.93s/it]


Epoch 3 complete. Average loss: 10.0531
Validation Loss: 9.8882, Perplexity: 19695.82


Epoch 4/5:  90%|█████████ | 9/10 [00:03<00:00,  2.59it/s]

Checkpoint saved at step 40

Sample generation at step 40:
Once upon a timedayac antibioticbase nucleophilic vertical 193 def child body spring Waldenabbar Rhyonly th commission Tower teacher keyolmenworkernburg Bate Aircraft David Ca Chaseytimid song Christmasrate oppos Pre commissionere flightroblemday view Yon resident ColightphanwideCarcole Play



Epoch 4/5: 100%|██████████| 10/10 [00:26<00:00,  2.70s/it]


Epoch 4 complete. Average loss: 9.7504
Validation Loss: 9.5693, Perplexity: 14318.66


Epoch 5/5:  90%|█████████ | 9/10 [00:03<00:00,  2.87it/s]

Checkpoint saved at step 50

Sample generation at step 50:
Once upon a timedway listen deportdu rareengagearily Alphabe Listsellal camprgin Plagi Span clergy cry drawclosureinterpretelips 16 piece Din structuralgene auto RAtrox communit Chia And enormous relate melodrama Sabre suitabl demolition accus Sincsed cyclo appreciate approach Jeffer believe LeoniLine



Epoch 5/5: 100%|██████████| 10/10 [00:27<00:00,  2.74s/it]


Epoch 5 complete. Average loss: 9.4183
Validation Loss: 9.2760, Perplexity: 10678.33
Training complete. Final model saved.
Dataset loaded: wikitext
Dataset structure: ['text']
Dataset size: 36718 examples
Using 'text' as the text column
Training new tokenizer


Saving tokenizer to llama_tokenizer.json
Tokenizer saved to llama_tokenizer.json
Tokenizing dataset...


Tokenizing: 100%|██████████| 36718/36718 [00:13<00:00, 2686.26 examples/s]
Filter: 100%|██████████| 36718/36718 [00:01<00:00, 23814.66 examples/s]


Filtered out 12951 empty examples
Created flattened dataset with 3228097 tokens
