# Basic Transformer Implementation with JAX and Flax

This notebook demonstrates how to implement and train a basic transformer model using JAX and Flax. We'll cover:
1. Data preparation
2. Model setup
3. Training loop
4. Performance optimization

In [None]:
import jax
import jax.numpy as jnp
import flax
import optax
import numpy as np
from typing import Dict

# Import our model and training utilities
from src.models.transformer import SimpleLanguageModel
from src.utils.training import create_train_state, train_step, eval_step

# Show the devices JAX can see; the pmap cell at the end only runs when
# more than one device is available.
print(f"JAX devices: {jax.devices()}")

## Create a Small Example Dataset

For this demonstration, we'll create a tiny synthetic dataset.

In [None]:
def create_dummy_batch(batch_size: int = 4, seq_len: int = 16, vocab_size: int = 1000,
                       rng=None) -> Dict[str, jnp.ndarray]:
    """Creates a dummy batch of random token data for testing.

    Args:
        batch_size: Number of sequences in the batch.
        seq_len: Number of tokens per sequence.
        vocab_size: Exclusive upper bound for the sampled token ids.
        rng: Optional JAX PRNG key. Passing the same key yields the same
            batch. When omitted, a key is seeded from NumPy's global RNG so
            that successive calls return different batches.

    Returns:
        A dict with three arrays of shape ``(batch_size, seq_len)``:
        ``'input_ids'`` and ``'labels'`` hold integers in ``[0, vocab_size)``,
        and ``'attention_mask'`` is all ones.
    """
    if rng is None:
        # jax.numpy has no `randint`; JAX sampling always needs an explicit key.
        rng = jax.random.PRNGKey(int(np.random.randint(0, 2**31 - 1)))

    # Independent keys so that labels are not a copy of the inputs.
    input_rng, label_rng = jax.random.split(rng)
    shape = (batch_size, seq_len)
    return {
        'input_ids': jax.random.randint(input_rng, shape, 0, vocab_size),
        'labels': jax.random.randint(label_rng, shape, 0, vocab_size),
        'attention_mask': jnp.ones(shape)
    }

# Build one batch with the default sizes and report the shape of its
# token-id array as a quick sanity check.
batch = create_dummy_batch()
example_shape = batch['input_ids'].shape
print("Input shape:", example_shape)

## Initialize Model and Training State

In [None]:
# Hyperparameters for the demo model. `vocab_size` equals the default
# `vocab_size` of `create_dummy_batch`, so the dummy token ids stay in range.
config = dict(
    vocab_size=1000,
    hidden_dim=256,
    num_layers=2,
    num_heads=4,
    mlp_dim=512,
    dropout_rate=0.1,
)

# Instantiate the model from the configuration above.
model = SimpleLanguageModel(**config)

# Build the training state from a fixed seed so that the setup is repeatable.
rng = jax.random.PRNGKey(0)
state = create_train_state(rng=rng, model=model, learning_rate=1e-4, weight_decay=0.01)

## Training Loop

Let's run a few training steps to demonstrate the training process.

In [None]:
# Training loop
num_steps = 10
# Carry key: it is only ever split, never handed to a consumer directly.
dropout_rng = jax.random.PRNGKey(1)

for step in range(num_steps):
    batch = create_dummy_batch()
    # Split off a single-use key for this step and keep the other half as the
    # carry. A JAX key must not be both consumed and split again, which is
    # what happened when the carry key itself was passed to train_step.
    dropout_rng, step_dropout_rng = jax.random.split(dropout_rng)

    state, metrics = train_step(state, batch, step_dropout_rng)

    # Log every other step to keep the output short.
    if step % 2 == 0:
        print(f"Step {step}: loss = {metrics['loss']:.4f}")

# Evaluation on one freshly generated batch
eval_batch = create_dummy_batch()
eval_metrics = eval_step(state, eval_batch)
print(f"\nEval loss: {eval_metrics['eval_loss']:.4f}")

## Performance Optimization

Our training is already optimized with:
1. JIT compilation (@jax.jit)
2. Efficient memory usage
3. GPU/TPU support

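For multi-device training, we can wrap the training step with `jax.pmap`. Every mapped input must then carry a leading axis whose size equals the number of local devices: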

In [None]:
# Example of pmap (only runs if multiple devices are available)
num_devices = jax.local_device_count()  # pmap maps over local devices
if num_devices > 1:
    print("Multiple devices detected, demonstrating pmap...")

    # Replicate state across devices. A separate name keeps the
    # single-device `state` intact, so this cell can safely be re-run.
    replicated_state = flax.jax_utils.replicate(state)

    # Define pmapped training step
    # NOTE(review): train_step lives in src.utils.training and is not visible
    # here. Confirm that it averages gradients over the 'batch' axis (e.g. via
    # jax.lax.pmean); otherwise the replicas will drift apart.
    p_train_step = jax.pmap(train_step, axis_name='batch')

    # Create a larger batch whose size is a multiple of the device count,
    # then reshape every array to (num_devices, per_device_batch, ...),
    # because pmap maps over the leading axis of each input.
    per_device_batch_size = 4
    global_batch = create_dummy_batch(batch_size=num_devices * per_device_batch_size)
    sharded_batch = jax.tree_util.tree_map(
        lambda x: x.reshape((num_devices, -1) + x.shape[1:]),
        global_batch,
    )

    # One dropout key per device; a single unsplit key has no device axis.
    device_dropout_rngs = jax.random.split(dropout_rng, num_devices)

    # Run parallel training step
    replicated_state, metrics = p_train_step(
        replicated_state, sharded_batch, device_dropout_rngs
    )
    print("Parallel training step completed!")