# In [None]:
import functools
from dataclasses import dataclass

import jax
import jax.numpy as jnp
from jax.experimental.pjit import pjit
from jax.sharding import Mesh, NamedSharding, PartitionSpec

import flax.linen as nn
import numpy as np
import optax
from flax.training import train_state

# --- 1. Configuration ---
@dataclass(frozen=True)
class TransformerConfig:
    """Hyperparameters for the toy Transformer and its training loop.

    The dataclass is frozen so that instances are immutable and hashable.
    Hashability matters because the config (and the Flax module that stores
    it as a field) is handed to jit/pjit as a static argument, and static
    arguments must support ``hash()``; a plain ``@dataclass`` with the
    default ``eq=True`` sets ``__hash__`` to ``None``.

    Attributes:
        vocab_size: Number of distinct token ids (embedding rows and width of
            the output logits).
        d_model: Width of the token/position embeddings and of every block.
        num_heads: Number of attention heads; ``d_model`` is integer-divided
            by this to obtain the per-head width.
        num_layers: Number of stacked transformer blocks.
        d_ff: Hidden width of the position-wise feed-forward network.
        dropout_rate: Dropout probability used throughout the model.
        max_len: Number of rows in the learned position-embedding table.
        learning_rate: Step size handed to the optimizer.
        global_batch_size: Batch size summed over all devices.
        num_train_steps: Number of optimisation steps to run.
    """
    vocab_size: int = 10000
    d_model: int = 256  # For a toy model, keep it small
    num_heads: int = 4
    num_layers: int = 3
    d_ff: int = 1024  # Feed-forward hidden size
    dropout_rate: float = 0.1
    max_len: int = 128

    # Training config
    learning_rate: float = 1e-4
    global_batch_size: int = 32
    num_train_steps: int = 100

# --- 2. Model Definition (Decoder-only Transformer) ---

class MultiHeadAttention(nn.Module):
    """Scaled dot-product self-attention computed over several heads.

    Attributes:
        config: Supplies ``d_model`` and ``num_heads``.
    """
    config: TransformerConfig

    @nn.compact
    def __call__(self, x, mask=None):
        """Attend over ``x`` and project the result back to ``d_model`` features.

        Args:
            x: Rank-3 input whose first two axes are unpacked as batch and
                sequence length.
            mask: Optional array that broadcasts against the
                ``(batch, heads, query, key)`` score tensor. Wherever it is
                falsy the score is replaced by ``-1e9`` before the softmax.

        Returns:
            Array of shape ``(batch, seq_len, d_model)``.
        """
        cfg = self.config
        n_batch, n_tokens, _ = x.shape
        n_heads = cfg.num_heads
        per_head = cfg.d_model // n_heads

        def to_heads(t):
            # (batch, seq, d_model) -> (batch, heads, seq, head_dim)
            return t.reshape(n_batch, n_tokens, n_heads, per_head).transpose(0, 2, 1, 3)

        # One fused projection, then cut the feature axis into Q, K and V.
        fused = nn.Dense(features=cfg.d_model * 3, name="qkv_proj")(x)
        queries, keys, values = (
            to_heads(part) for part in jnp.array_split(fused, 3, axis=-1)
        )

        logits = jnp.einsum("bhqd,bhkd->bhqk", queries, keys) / jnp.sqrt(per_head)
        if mask is not None:
            logits = jnp.where(mask, logits, -1e9)

        weights = nn.softmax(logits, axis=-1)
        context = jnp.einsum("bhqk,bhkd->bhqd", weights, values)

        # (batch, heads, seq, head_dim) -> (batch, seq, heads * head_dim)
        merged = context.transpose(0, 2, 1, 3).reshape(n_batch, n_tokens, -1)
        return nn.Dense(features=cfg.d_model, name="out_proj")(merged)

class PositionwiseFeedForward(nn.Module):
    """Two dense layers with a GELU in between, applied along the last axis.

    Attributes:
        config: Supplies ``d_ff`` (hidden width) and ``d_model`` (output width).
    """
    config: TransformerConfig

    @nn.compact
    def __call__(self, x):
        """Expand ``x`` to ``d_ff`` features, apply GELU, project to ``d_model``."""
        expanded = nn.Dense(features=self.config.d_ff)(x)
        activated = nn.gelu(expanded)
        return nn.Dense(features=self.config.d_model)(activated)

class TransformerBlock(nn.Module):
    """Self-attention followed by a feed-forward network.

    Each sublayer's output goes through dropout, is added to its input, and
    the sum is layer-normalised (normalisation comes after the residual add).

    Attributes:
        config: Model hyperparameters shared with the sublayers.
        deterministic: When true, the dropout layers are told to run
            deterministically.
    """
    config: TransformerConfig
    deterministic: bool

    @nn.compact
    def __call__(self, x, mask=None):
        """Run one block over ``x``; ``mask`` is forwarded to the attention layer."""
        rate = self.config.dropout_rate
        frozen = self.deterministic

        # --- attention sublayer ---
        attended = MultiHeadAttention(config=self.config, name="self_attention")(x, mask)
        attended = nn.Dropout(rate=rate)(attended, deterministic=frozen)
        after_attention = nn.LayerNorm()(x + attended)

        # --- feed-forward sublayer ---
        transformed = PositionwiseFeedForward(config=self.config, name="feed_forward")(after_attention)
        transformed = nn.Dropout(rate=rate)(transformed, deterministic=frozen)
        return nn.LayerNorm()(after_attention + transformed)

class Transformer(nn.Module):
    """Decoder-only Transformer that maps token ids to next-token logits.

    Attributes:
        config: Model hyperparameters.
    """
    config: TransformerConfig

    @nn.compact
    def __call__(self, x, *, deterministic: bool):
        """Compute logits for a batch of token ids.

        Args:
            x: Rank-2 array of token ids, unpacked as ``(batch, seq_len)``.
                The position table has ``max_len`` rows, so only that many
                positions are available.
            deterministic: Forwarded to every dropout layer.

        Returns:
            Logits of shape ``(batch, seq_len, vocab_size)``.
        """
        cfg = self.config
        _, n_positions = x.shape

        # Learned token embedding plus a learned, zero-initialised position table.
        token_vectors = nn.Embed(num_embeddings=cfg.vocab_size, features=cfg.d_model)(x)
        position_table = self.param(
            'pos_embedding', nn.initializers.zeros, (cfg.max_len, cfg.d_model)
        )
        hidden = token_vectors + position_table[:n_positions, :]
        hidden = nn.Dropout(rate=cfg.dropout_rate)(hidden, deterministic=deterministic)

        # The mask is built from a (batch, seq_len) slice of the activations.
        attention_mask = nn.make_causal_mask(hidden[:, :, 0])

        for layer_idx in range(cfg.num_layers):
            block = TransformerBlock(
                config=cfg, deterministic=deterministic, name=f"block_{layer_idx}"
            )
            hidden = block(hidden, mask=attention_mask)

        # Project back onto the vocabulary.
        return nn.Dense(features=cfg.vocab_size, name="output_head")(hidden)

# --- 3. Training State and Step Functions ---

class TrainState(train_state.TrainState):
    """Flax ``TrainState`` extended with a PRNG key for dropout.

    The key is carried in the state so that each training step can split a
    fresh sub-key from it and store the remainder for the next step.
    """
    # Annotated as `jax.Array`: the former `jax.random.KeyArray` alias has
    # been removed from JAX, and referencing it raises AttributeError as soon
    # as this class body is evaluated.
    dropout_key: jax.Array

def create_train_state(rng, model, optimizer, config):
    """Initialise the model parameters and wrap them in a `TrainState`.

    Args:
        rng: PRNG key. It is split in two so that parameter initialisation
            and dropout never consume the same key.
        model: Flax module providing ``init`` and ``apply``.
        optimizer: Optax gradient transformation stored as ``tx``.
        config: Provides ``max_len``, the length of the dummy input used to
            trace the model during initialisation.

    Returns:
        A `TrainState` holding the fresh parameters, the optimizer and an
        independent dropout key.
    """
    # Reusing one key for two purposes correlates their random streams, so
    # derive separate keys for initialisation and for dropout.
    init_key, dropout_key = jax.random.split(rng)

    # A dummy batch of one sequence is enough to create the parameters.
    dummy_input = jnp.ones((1, config.max_len), dtype=jnp.int32)
    variables = model.init(init_key, dummy_input, deterministic=True)

    return TrainState.create(
        apply_fn=model.apply,
        params=variables['params'],
        tx=optimizer,
        dropout_key=dropout_key,
    )

def train_step(state, batch, config):
    """Run one optimisation step and return the updated state with the loss.

    Args:
        state: Current `TrainState`; its ``dropout_key`` is split so this
            step and the next one use different dropout randomness.
        batch: Mapping with ``'input_ids'`` and ``'labels'`` entries.
        config: Accepted for the caller's convenience; not read here.

    Returns:
        ``(new_state, loss)`` where ``loss`` is the mean cross-entropy over
        the batch that was passed in.
    """
    step_key, carry_key = jax.random.split(state.dropout_key)

    def mean_cross_entropy(params):
        # Forward pass in training mode, with this step's dropout key.
        logits = state.apply_fn(
            {'params': params},
            x=batch['input_ids'],
            deterministic=False,
            rngs={'dropout': step_key},
        )
        per_token = optax.softmax_cross_entropy_with_integer_labels(
            logits=logits, labels=batch['labels']
        )
        return per_token.mean()

    # No explicit cross-device reduction is written here: the loss is a mean
    # over the whole batch handed to this function.
    loss, grads = jax.value_and_grad(mean_cross_entropy)(state.params)

    # Apply the update, then store the unused half of the key for next time.
    updated = state.apply_gradients(grads=grads).replace(dropout_key=carry_key)
    return updated, loss

# --- 4. Main Training Script ---

def main():
    """Train the toy Transformer with data parallelism over every device.

    Steps:
      1. Join the multi-host job if a cluster environment is detected,
         otherwise continue as a single process.
      2. Build a 1-D device mesh named ``'data'``; parameters are replicated
         and batches are split along their leading axis.
      3. Create the replicated `TrainState` and compile the training step.
      4. Run ``num_train_steps`` steps on synthetic token data, printing the
         loss from process 0.

    Raises:
        ValueError: If ``global_batch_size`` cannot be split evenly across
            the devices of the mesh.
    """
    # Must run before any other JAX call touches the backend. With no
    # arguments it relies on cluster auto-detection and raises ValueError
    # when nothing is detected (e.g. a plain single-machine run); in that
    # case fall back to single-process execution instead of crashing.
    # Other failures (such as an unreachable coordinator) still propagate.
    try:
        jax.distributed.initialize()
    except ValueError as err:
        print(f"jax.distributed.initialize() skipped ({err}). Running in a single process.")

    # Get process information
    process_id = jax.process_index()
    process_count = jax.process_count()

    if process_id == 0:
        print(f"Starting multi-host training on {process_count} processes.")
        print(f"Total devices: {jax.device_count()}, Devices per process: {jax.local_device_count()}")

    # --- Distributed Setup ---
    # 1-D mesh over all devices; the 'data' axis carries the batch dimension.
    devices = np.array(jax.devices())
    mesh = Mesh(devices, axis_names=('data',))

    if process_id == 0:
        print(f"Device mesh created: {mesh}")

    # (batch, seq_len) arrays are split along 'batch'; everything else is
    # replicated on every device.
    data_sharding = NamedSharding(mesh, PartitionSpec('data', None))
    replicated_sharding = NamedSharding(mesh, PartitionSpec())

    # --- Initialization ---
    config = TransformerConfig()
    if config.global_batch_size % devices.size != 0:
        raise ValueError(
            f"global_batch_size ({config.global_batch_size}) must be divisible by "
            f"the number of devices ({devices.size})."
        )
    model = Transformer(config)
    optimizer = optax.adam(config.learning_rate)

    key = jax.random.PRNGKey(42)

    # Bind the non-array arguments with functools.partial rather than
    # static_argnums: static arguments must be hashable, which a default
    # (non-frozen) dataclass config and a module holding it are not.
    pjit_create_state = pjit(
        functools.partial(
            create_train_state, model=model, optimizer=optimizer, config=config
        ),
        in_shardings=(replicated_sharding,),  # for the PRNG key
        out_shardings=replicated_sharding,    # The entire state is replicated
    )

    # Create the replicated TrainState.
    replicated_key = jax.device_put(key, replicated_sharding)
    state = pjit_create_state(replicated_key)

    if process_id == 0:
        print("TrainState created and replicated.")

    # --- Pjit the Training Step ---
    pjit_train_step = pjit(
        functools.partial(train_step, config=config),
        in_shardings=(replicated_sharding, data_sharding),
        out_shardings=(replicated_sharding, replicated_sharding),
    )

    # --- Synthetic data ---
    # Every process seeds the generator identically and therefore builds the
    # same global batch. Host memory is not shared between processes, so a
    # batch created only on process 0 would never reach the other hosts.
    data_rng = np.random.default_rng(0)

    def to_global_array(host_array):
        # The callback is asked only for the slices that live on this
        # process's devices, so this works in single- and multi-process runs.
        return jax.make_array_from_callback(
            host_array.shape, data_sharding, lambda index: host_array[index]
        )

    # --- Training Loop ---
    for step in range(config.num_train_steps):
        # Draw one extra token per row so the labels are the inputs shifted
        # left by one, without wrapping the first token round to the end.
        tokens = data_rng.integers(
            0,
            config.vocab_size,
            size=(config.global_batch_size, config.max_len + 1),
            dtype=np.int32,
        )
        sharded_batch = {
            'input_ids': to_global_array(tokens[:, :-1]),
            'labels': to_global_array(tokens[:, 1:]),
        }

        # Execute the training step
        state, loss = pjit_train_step(state, sharded_batch)

        # Wait for the computation to finish before printing the loss.
        jax.block_until_ready(loss)

        if process_id == 0:
            print(f"Step {step+1}/{config.num_train_steps}, Loss: {loss.item():.4f}")

    if process_id == 0:
        print("\nTraining finished!")

def run_on_single_machine():
    """Try to launch `main` once per local device, else run it in-process.

    Only the import is guarded. The previous version wrapped the launch
    itself in the ``try`` block, so a ``RuntimeError`` raised in the middle
    of training would have been swallowed and `main` started again from
    scratch.

    NOTE(review): ``run_on_hosts`` does not appear in the documented
    ``jax.experimental.multihost_utils`` API, so expect the fallback path to
    be the one that runs. Multi-process jobs are normally started by an
    external launcher, one process per host.
    """
    try:
        from jax.experimental.multihost_utils import run_on_hosts
    except ImportError as e:
        print(f"Could not use `run_on_hosts` ({e}). Running in a single process.")
        run_on_hosts = None

    if run_on_hosts is None:
        main()
        return

    print("Found multihost_utils. Spawning processes for demonstration.")
    run_on_hosts(main, n_hosts=jax.local_device_count())

if __name__ == "__main__":
    # Call main() directly, without querying devices first: a call such as
    # jax.local_device_count() initialises the JAX backend, and
    # jax.distributed.initialize() (the first statement of main()) has to run
    # before the backend exists. For a multi-process job, start this script
    # once per host with an external launcher.
    main()

