# Optimization Techniques for Training LLMs

This notebook demonstrates advanced optimization techniques for training Large Language Models using JAX and Flax:

1. Gradient Accumulation
2. Mixed Precision Training
3. Model Parallelism
4. Memory Efficient Attention

In [None]:
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
from typing import Any, Tuple

# Show which accelerators JAX can see before running any examples.
print("JAX devices: {}".format(jax.devices()))

## 1. Gradient Accumulation

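Gradient accumulation enables larger effective batch sizes. It splits a batch into smaller micro-batches, runs a forward/backward pass on each one, and averages the resulting gradients. Peak activation memory is therefore bounded by the micro-batch size rather than by the full batch size.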

In [None]:
def create_train_state(model, learning_rate, weight_decay, input_shape=(1, 8, 64)):
    """Initialize model parameters and an AdamW optimizer state.

    Args:
        model: A Flax module; it is initialized on an all-ones dummy input.
        learning_rate: AdamW learning rate.
        weight_decay: AdamW weight-decay coefficient.
        input_shape: Shape of the dummy input passed to ``model.init``.
            Defaults to ``(1, 8, 64)``, the shape previously hard-coded.

    Returns:
        dict with
          ``'params'``: the model's parameter tree, unwrapped from the
            ``{'params': ...}`` collection, so callers pass
            ``{'params': state['params']}`` to ``model.apply``;
          ``'opt_state'``: the AdamW state for those parameters.

    NOTE(review): the optax transformation itself is not returned; a caller
    applying updates must rebuild it with the same hyperparameters.
    """
    variables = model.init(jax.random.PRNGKey(0), jnp.ones(input_shape))
    # `init` returns the full variables dict {'params': {...}}. Keep only the
    # 'params' collection: consumers re-wrap it as {'params': params} for
    # `model.apply`, which would otherwise double-nest the tree and fail.
    params = variables['params']
    tx = optax.adamw(learning_rate, weight_decay=weight_decay)
    return {'params': params, 'opt_state': tx.init(params)}

def accumulate_gradients(state, batch, model, n_accumulation_steps):
    """Average gradients of an MSE reconstruction loss over micro-batches.

    The batch is split along axis 0 into ``n_accumulation_steps`` equal
    micro-batches. Gradients are computed one micro-batch at a time inside
    ``jax.lax.fori_loop`` and then averaged. All micro-batches have the same
    size, so the result equals the gradient of the mean loss over the whole
    batch.

    Args:
        state: dict whose ``'params'`` entry is the parameter tree passed to
            ``model.apply`` as ``{'params': params}``.
        batch: array of shape ``(batch_size, ...)``, of any rank >= 1.
        model: object with a Flax-style ``apply(variables, x)`` whose output
            broadcasts against ``x`` (the loss is ``mean((out - x) ** 2)``).
        n_accumulation_steps: number of micro-batches. Must be positive and
            must divide ``batch_size``.

    Returns:
        Gradient pytree with the same structure as ``state['params']``.

    Raises:
        ValueError: if ``n_accumulation_steps`` is not a positive divisor of
            a non-zero batch size. Previously, leftover samples were silently
            dropped (or empty chunks produced NaN).
    """
    def compute_loss(params, x):
        preds = model.apply({'params': params}, x)
        return jnp.mean((preds - x) ** 2)

    batch_size = batch.shape[0]
    if (n_accumulation_steps <= 0
            or batch_size < n_accumulation_steps
            or batch_size % n_accumulation_steps != 0):
        raise ValueError(
            f"n_accumulation_steps={n_accumulation_steps} must be a positive "
            f"divisor of the batch size {batch_size}"
        )
    chunk_size = batch_size // n_accumulation_steps

    params = state['params']
    grad_fn = jax.grad(compute_loss)

    def accumulate_step(i, grad_acc):
        # The start index is traced but the slice length is static, as
        # dynamic slicing requires. Slicing only axis 0 works for any rank.
        chunk = jax.lax.dynamic_slice_in_dim(
            batch, i * chunk_size, chunk_size, axis=0)
        grads = grad_fn(params, chunk)
        return jax.tree_util.tree_map(jnp.add, grad_acc, grads)

    grad_acc = jax.tree_util.tree_map(jnp.zeros_like, params)
    grad_acc = jax.lax.fori_loop(
        0, n_accumulation_steps, accumulate_step, grad_acc)

    # Mean of equal-size micro-batch gradients == full-batch gradient.
    return jax.tree_util.tree_map(
        lambda g: g / n_accumulation_steps, grad_acc)

# Example usage
model = nn.Dense(64)
state = create_train_state(model, 1e-4, 0.01)
batch = jax.random.normal(jax.random.PRNGKey(0), (32, 8, 64))
# `model` and the step count (4) are closed over as Python constants, so
# jit only traces the state and batch arrays.
accumulated_grads = jax.jit(lambda s, b: accumulate_gradients(s, b, model, 4))(state, batch)
# `jax.tree_map` is deprecated (removed in recent JAX); use jax.tree_util.
print("Accumulated gradients shape:", jax.tree_util.tree_map(lambda x: x.shape, accumulated_grads))

## 2. Mixed Precision Training

Using mixed precision (float16/bfloat16) can significantly reduce memory usage and speed up training.

In [None]:
def create_mp_train_state(model, learning_rate, input_shape=(1, 8, 64)):
    """Create a mixed-precision state: bfloat16 compute params plus fp32 masters.

    Args:
        model: A Flax module; it is initialized on an all-ones float32 dummy
            input.
        learning_rate: Adam learning rate.
        input_shape: Shape of the dummy input passed to ``model.init``.
            Defaults to ``(1, 8, 64)``, the shape previously hard-coded.

    Returns:
        dict with
          ``'params'``: parameter tree with float32 leaves cast to bfloat16,
            used for the forward pass;
          ``'params_fp32'``: the original float32 parameter tree, kept as the
            master copy the optimizer should update;
          ``'opt_state'``: Adam state initialized on the float32 params.
        Both parameter trees are unwrapped from the ``{'params': ...}``
        collection, so callers pass ``{'params': state['params']}`` to
        ``model.apply``.

    NOTE(review): the optax transformation itself is not returned; a caller
    applying updates must rebuild it with the same learning rate.
    """
    variables = model.init(jax.random.PRNGKey(0),
                           jnp.ones(input_shape, dtype=jnp.float32))
    # Unwrap the 'params' collection; keeping the full variables dict would
    # double-nest it when re-wrapped as {'params': params} in `model.apply`.
    params = variables['params']

    # Cast only float32 leaves to bfloat16; any other dtype is left as is.
    mp_params = jax.tree_util.tree_map(
        lambda x: x.astype(jnp.bfloat16) if x.dtype == jnp.float32 else x,
        params,
    )

    optimizer = optax.adam(learning_rate)
    opt_state = optimizer.init(params)

    return {
        'params': mp_params,
        'params_fp32': params,  # fp32 master copy for the optimizer
        'opt_state': opt_state,
    }

@jax.jit
def mp_train_step(state, batch):
    """Compute float32 gradients of an MSE loss using a bfloat16 forward pass.

    Despite its name, this step applies no optimizer update. It returns the
    gradients with respect to ``state['params']`` (the bfloat16 tree),
    upcast to float32 so they can be applied to the float32 master weights.

    Args:
        state: dict with a ``'params'`` parameter tree, passed to
            ``model.apply`` as ``{'params': params}``.
        batch: input array; it is also the reconstruction target.

    Returns:
        Gradient pytree matching ``state['params']``, with float32 leaves.

    NOTE: ``model`` is not a parameter. It is read from module scope (the
    global created in the gradient-accumulation example) when jit traces
    this function.
    """
    def loss_fn(params):
        # Run the forward pass in bfloat16 to match the bfloat16 params.
        output = model.apply({'params': params}, batch.astype(jnp.bfloat16))
        # Compute the loss in float32 against the original-precision targets.
        return jnp.mean((output.astype(jnp.float32) - batch) ** 2)

    grads = jax.grad(loss_fn)(state['params'])

    # Gradients come back in the params' dtype; upcast them for the optimizer.
    return jax.tree_util.tree_map(lambda g: g.astype(jnp.float32), grads)

# Example usage
batch = jax.random.normal(jax.random.PRNGKey(0), (16, 8, 64))
mp_state = create_mp_train_state(model, 1e-4)
mp_grads = mp_train_step(mp_state, batch)
# `jax.tree_map` is deprecated (removed in recent JAX); use jax.tree_util.
print("Mixed precision gradients dtype:", jax.tree_util.tree_map(lambda x: x.dtype, mp_grads))

## 3. Model Parallelism

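Model parallelism splits a model's parameters across multiple devices, so that models too large for a single device can still be trained. In this example, the attention heads are partitioned into one group per device.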

In [None]:
class ShardedTransformerBlock(nn.Module):
    """Self-attention whose heads are split into one group per device.

    Let ``n = jax.device_count()``. The ``num_heads`` heads are divided into
    ``n`` groups. Each group is an independent ``nn.SelfAttention`` with
    ``num_heads // n`` heads and ``hidden_dim // n`` query/key/value
    features. The groups are created with ``nn.vmap``, so every parameter
    gets a leading axis of size ``n`` (one slice per group). That axis is the
    one to shard across devices, e.g. with ``jax.device_put`` and a
    ``NamedSharding``.

    Each group's output projection maps back to the input feature size.
    The group outputs are summed, which matches the all-reduce that follows
    a row-partitioned output projection.

    Attributes:
        num_heads: total number of attention heads; must be divisible by n.
        hidden_dim: total query/key/value feature size; must be divisible by
            ``num_heads``.
    """
    num_heads: int
    hidden_dim: int

    @nn.compact
    def __call__(self, x):
        num_shards = jax.device_count()
        if self.num_heads % num_shards != 0:
            raise ValueError(
                f"num_heads={self.num_heads} must be divisible by the "
                f"device count {num_shards}"
            )
        if self.hidden_dim % self.num_heads != 0:
            raise ValueError(
                f"hidden_dim={self.hidden_dim} must be divisible by "
                f"num_heads={self.num_heads}"
            )

        # Lift SelfAttention over a new leading "head group" axis.
        # variable_axes gives each group its own parameter slice, and
        # split_rngs gives each group an independent initialization.
        GroupedAttention = nn.vmap(
            nn.SelfAttention,
            variable_axes={'params': 0},
            split_rngs={'params': True},
            in_axes=0,
            out_axes=0,
            axis_size=num_shards,
        )

        # Every group attends over the same input, so replicate x along the
        # new group axis.
        x_per_group = jnp.broadcast_to(x, (num_shards,) + x.shape)
        group_outputs = GroupedAttention(
            num_heads=self.num_heads // num_shards,
            qkv_features=self.hidden_dim // num_shards,
        )(x_per_group)

        # Combine the partial outputs of all head groups.
        return group_outputs.sum(axis=0)

# Example usage. Needs more than one device, and the device count must
# divide the 8 heads used below.
num_devices = jax.device_count()
if num_devices > 1 and 8 % num_devices == 0:
    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec

    block = ShardedTransformerBlock(num_heads=8, hidden_dim=64)
    x = jax.random.normal(jax.random.PRNGKey(0), (16, 8, 64))
    params = block.init(jax.random.PRNGKey(0), x)

    # Shard the leading axis of every parameter over a 1-D 'model' mesh,
    # placing one head group per device. This assumes the block gives each
    # parameter a leading head-group axis of size jax.device_count().
    # (pmap would instead map every argument over its leading axis,
    # including the batch of x, which is not what we want here.)
    mesh = Mesh(np.array(jax.devices()), ('model',))
    params = jax.device_put(params, NamedSharding(mesh, PartitionSpec('model')))

    # Under jit, XLA can partition the computation to follow the parameter
    # sharding.
    sharded_output = jax.jit(block.apply)(params, x)
    print("Sharded output shape:", sharded_output.shape)

## 4. Memory Efficient Attention

Memory-efficient attention processes keys and values in fixed-size chunks, so the full query-by-key score matrix is never materialized at once.

In [None]:
def memory_efficient_attention(query, key, value, chunk_size=128):
    """Exact softmax attention computed one key/value chunk at a time.

    Keys and values are processed ``chunk_size`` positions per
    ``jax.lax.scan`` step. Only a ``(batch, q_len, chunk_size)`` block of
    scores exists at any time, never the full ``(batch, q_len, kv_len)``
    matrix.

    An online softmax keeps three running values per query: the maximum
    score seen so far, the sum of exponentials, and the exp-weighted sum of
    values. Whenever a larger maximum appears, the sums are rescaled. The
    result therefore equals
    ``softmax(query @ key^T / sqrt(dim)) @ value`` up to rounding. Summing
    per-chunk softmax outputs, as done before, does not give this result.

    Args:
        query: array of shape ``(batch, q_len, dim)``.
        key: array of shape ``(batch, kv_len, dim)``.
        value: array of shape ``(batch, kv_len, v_dim)``.
        chunk_size: key/value positions per scan step, clamped to ``kv_len``.
            ``kv_len`` need not be a multiple of it; the last chunk is padded
            and the padding is masked out.

    Returns:
        Array of shape ``(batch, q_len, v_dim)`` with dtype
        ``jnp.result_type(query, key, value)``.

    Raises:
        ValueError: if ``chunk_size`` is not positive or ``kv_len`` is 0.
    """
    batch_size, q_len, dim = query.shape
    kv_len = key.shape[1]
    v_dim = value.shape[-1]
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if kv_len == 0:
        raise ValueError("key/value must contain at least one position")

    chunk_size = min(chunk_size, kv_len)
    num_chunks = -(-kv_len // chunk_size)  # ceiling division
    pad = num_chunks * chunk_size - kv_len

    # Pad the sequence axis so every chunk has the same static size; dynamic
    # slices need a static length. Padding is < chunk_size, so every chunk
    # still contains at least one real key and its max score stays finite.
    pad_width = ((0, 0), (0, pad), (0, 0))
    key_padded = jnp.pad(key, pad_width)
    value_padded = jnp.pad(value, pad_width)

    scale = dim ** -0.5
    out_dtype = jnp.result_type(query, key, value)
    # Accumulate in at least float32 so low-precision inputs stay stable.
    acc_dtype = jnp.promote_types(out_dtype, jnp.float32)

    def chunk_scanner(carry, chunk_start):
        acc, row_max, row_sum = carry
        k_chunk = jax.lax.dynamic_slice_in_dim(
            key_padded, chunk_start, chunk_size, axis=1)
        v_chunk = jax.lax.dynamic_slice_in_dim(
            value_padded, chunk_start, chunk_size, axis=1)

        scores = jnp.einsum('bqd,bkd->bqk', query, k_chunk) * scale
        # Mask out padded key positions.
        valid = chunk_start + jnp.arange(chunk_size) < kv_len
        scores = jnp.where(valid, scores, -jnp.inf)

        new_max = jnp.maximum(row_max, scores.max(axis=-1)).astype(acc_dtype)
        # Rescale the previous partial sums to the new max. On the first
        # chunk row_max is -inf, so the correction is exp(-inf) = 0.
        correction = jnp.exp(row_max - new_max)
        probs = jnp.exp(scores - new_max[..., None]).astype(acc_dtype)

        row_sum = row_sum * correction + probs.sum(axis=-1)
        acc = acc * correction[..., None] + jnp.einsum(
            'bqk,bkd->bqd', probs, v_chunk).astype(acc_dtype)
        return (acc, new_max, row_sum), None

    init = (
        jnp.zeros((batch_size, q_len, v_dim), acc_dtype),
        jnp.full((batch_size, q_len), -jnp.inf, acc_dtype),
        jnp.zeros((batch_size, q_len), acc_dtype),
    )
    chunk_starts = jnp.arange(num_chunks) * chunk_size
    (acc, _, row_sum), _ = jax.lax.scan(chunk_scanner, init, chunk_starts)

    return (acc / row_sum[..., None]).astype(out_dtype)

# Smoke-test the chunked attention: query, key and value are drawn from
# PRNG seeds 0, 1 and 2 respectively.
q, k, v = (
    jax.random.normal(jax.random.PRNGKey(seed), (2, 512, 64))
    for seed in range(3)
)

output = memory_efficient_attention(q, k, v)
print("Memory-efficient attention output shape:", output.shape)