# Distributed MNIST with JAX and Kubeflow Trainer

This notebook demonstrates how to run distributed JAX training on Kubernetes using the Kubeflow Trainer SDK.

## Install the Kubeflow SDK

You need to install the Kubeflow SDK to interact with Kubeflow Trainer APIs:

In [1]:
 #!pip install -U kubeflow

## Define Training Function

This function will be serialized and executed on each JAX worker node. It uses:
- **Flax NNX** for model definition
- **Optax** for optimization (functional API)
- **`@jax.jit`** for XLA compilation
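- **Data sharding** across nodes: each process trains on its own slice of MNIST (as written, the function does not exchange gradients between processes)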
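
Data parallelism usually means that each process computes gradients on its own shard and the gradients are then averaged across processes. The following is a minimal pure-Python sketch of that idea on a toy one-parameter model. It is an illustration only: `mse_grad`, `shard`, and the sample values are made up for this example, and the training function in the next cell shards the data but does not perform this averaging.

```python
def mse_grad(w, xs, ys):
    """Gradient of mean((w * x - y) ** 2) with respect to w."""
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def shard(seq, num_processes, process_id):
    """Contiguous, equal-sized shard, using the same slicing as the notebook."""
    per_process = len(seq) // num_processes
    start = process_id * per_process
    return seq[start:start + per_process]


# Toy data (hypothetical values, chosen only for illustration).
xs = [0.5, 1.0, 1.5, 2.0, 2.5, 3.0]
ys = [1.0, 2.1, 2.9, 4.2, 5.0, 6.1]
w = 0.3
num_processes = 3

# Each simulated process computes a gradient on its own shard.
local_grads = [
    mse_grad(w, shard(xs, num_processes, p), shard(ys, num_processes, p))
    for p in range(num_processes)
]

# Averaging the per-shard gradients (what an all-reduce mean would do).
averaged_grad = sum(local_grads) / num_processes

# With equal-sized shards this matches the gradient over the full dataset.
full_grad = mse_grad(w, xs, ys)
print(averaged_grad, full_grad)
```

Because the shards are equal in size, the averaged gradient equals the full-dataset gradient up to floating-point rounding, which is why synchronized data-parallel training behaves like training with a larger batch.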

In [2]:
def jax_train_mnist():
    import os
    import time
    import jax
    import jax.numpy as jnp
    import numpy as np
    from flax import nnx
    import optax

    # Initialize JAX distributed using environment variables set by the runtime.
    coordinator_address = os.environ.get("JAX_COORDINATOR_ADDRESS")
    num_processes = int(os.environ.get("JAX_NUM_PROCESSES", 1))
    process_id = int(os.environ.get("JAX_PROCESS_ID", 0))

    jax.distributed.initialize(
        coordinator_address=coordinator_address,
        num_processes=num_processes,
        process_id=process_id,
    )

    print(f"JAX distributed initialized: Process {process_id}/{num_processes}")
    print(f"Available devices: {jax.devices()}")

    # Define simple MLP model with Flax NNX.
    class MLP(nnx.Module):
        def __init__(self, in_dims: int, hidden_dims: int, out_dims: int, *, rngs: nnx.Rngs):
            self.linear1 = nnx.Linear(in_dims, hidden_dims, rngs=rngs)
            self.linear2 = nnx.Linear(hidden_dims, hidden_dims, rngs=rngs)
            self.linear3 = nnx.Linear(hidden_dims, out_dims, rngs=rngs)

        def __call__(self, x):
            x = nnx.relu(self.linear1(x))
            x = nnx.relu(self.linear2(x))
            return self.linear3(x)

    # Download MNIST dataset.
    def load_mnist():
        import urllib.request
        import gzip

        def download(url, filename):
            filepath = f"/tmp/{filename}"
            if not os.path.exists(filepath):
                urllib.request.urlretrieve(url, filepath)
            with gzip.open(filepath, "rb") as f:
                data = np.frombuffer(f.read(), np.uint8, offset=16 if "images" in filename else 8)
            return data

        base_url = "https://storage.googleapis.com/cvdf-datasets/mnist/"
        train_images = download(base_url + "train-images-idx3-ubyte.gz", "train-images").reshape(-1, 784)
        train_labels = download(base_url + "train-labels-idx1-ubyte.gz", "train-labels")
        return train_images.astype(np.float32) / 255.0, train_labels

    # Load and partition dataset across processes.
    if process_id == 0:
        print("Downloading MNIST dataset...")
    train_images, train_labels = load_mnist()

    # Partition data for distributed training.
    samples_per_process = len(train_images) // num_processes
    start_idx = process_id * samples_per_process
    end_idx = start_idx + samples_per_process
    local_images = train_images[start_idx:end_idx]
    local_labels = train_labels[start_idx:end_idx]

    print(f"Process {process_id}: Training on samples {start_idx} to {end_idx}")

    # Create model and split into graphdef + state for functional JIT.
    model = MLP(in_dims=784, hidden_dims=128, out_dims=10, rngs=nnx.Rngs(0))
    graphdef, state = nnx.split(model)

    # Create optimizer (functional - no stateful wrapper).
    tx = optax.adam(learning_rate=0.001)
    opt_state = tx.init(state)

    # Pure functional training step with @jax.jit.
    @jax.jit
    def train_step(state, opt_state, x, y):
        def loss_fn(state):
            model = nnx.merge(graphdef, state)
            logits = model(x)
            loss = optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()
            return loss, logits

        (loss, logits), grads = jax.value_and_grad(loss_fn, has_aux=True)(state)
        updates, new_opt_state = tx.update(grads, opt_state, state)
        new_state = optax.apply_updates(state, updates)
        acc = jnp.mean(jnp.argmax(logits, axis=-1) == y)
        return loss, acc, new_state, new_opt_state

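    # Training hyperparameters.
    BATCH_SIZE = 128
    EPOCHS = 5

    # Warm up JIT compilation before training. jax.jit compiles per input shape
    # and dtype, so the dummy batch matches the real batches (float32 images,
    # uint8 labels); otherwise the first real step would compile again.
    dummy_x = jnp.zeros((BATCH_SIZE, 784), dtype=jnp.float32)
    dummy_y = jnp.zeros(BATCH_SIZE, dtype=jnp.uint8)
    _ = train_step(state, opt_state, dummy_x, dummy_y)
    print(f"Process {process_id}: JIT warmup complete")

    # Training loop.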

    if process_id == 0:
        print(f"Starting training for {EPOCHS} epochs...")

    for epoch in range(EPOCHS):
        epoch_start = time.perf_counter()
        total_loss = 0.0
        total_acc = 0.0
        num_batches = 0

        # Shuffle local data.
        perm = np.random.permutation(len(local_images))
        local_images = local_images[perm]
        local_labels = local_labels[perm]

        for i in range(0, len(local_images), BATCH_SIZE):
            batch_x = jnp.array(local_images[i:i + BATCH_SIZE])
            batch_y = jnp.array(local_labels[i:i + BATCH_SIZE])

            loss, acc, state, opt_state = train_step(state, opt_state, batch_x, batch_y)
            total_loss += float(loss)
            total_acc += float(acc)
            num_batches += 1

        avg_loss = total_loss / num_batches
        avg_acc = total_acc / num_batches

        print(f"Epoch {epoch + 1}/{EPOCHS} - Loss: {avg_loss:.4f}, Acc: {avg_acc:.4f}, Time: {time.perf_counter() - epoch_start:.2f}s")

    print(f"Process {process_id}: Training complete!")
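
`jax.jit` compiles a function once per distinct input shape and dtype, so the set of batch sizes fed to `train_step` determines how many compilations happen. As a rough sketch, the helper below (a hypothetical name, `batch_sizes`, not part of the notebook) reproduces the slicing of the inner loop for a 20,000-sample shard and a batch size of 128.

```python
from collections import Counter


def batch_sizes(num_samples, batch_size):
    """Sizes of the slices produced by range(0, num_samples, batch_size)."""
    return [
        min(batch_size, num_samples - start)
        for start in range(0, num_samples, batch_size)
    ]


# One shard of MNIST when 60,000 images are split across 3 processes.
sizes = Counter(batch_sizes(20_000, 128))
num_batches = sum(sizes.values())

print(sizes)        # how many batches of each size occur per epoch
print(num_batches)  # total number of train_step calls per epoch
```

Under these assumptions an epoch has 156 full batches and one final batch of 32 samples. The smaller final batch has a different shape, so it triggers one additional compilation, and it counts as much as a full batch in the per-epoch averages of loss and accuracy.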

## Scale JAX with Kubeflow TrainJob

You can use `TrainerClient()` from the Kubeflow SDK to communicate with Kubeflow Trainer APIs and scale your training function across multiple JAX training nodes.

`TrainerClient()` verifies that you have the required access to the Kubernetes cluster.

Kubeflow Trainer creates a `TrainJob` resource and automatically sets the appropriate environment variables to set up JAX distributed training.
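
The training function reads these variables with single-process defaults. The sketch below isolates that parsing so it can be tried locally without a cluster. `read_jax_env` is a hypothetical helper, the coordinator address is a made-up value, and the range check is an extra sanity check that the notebook itself does not perform.

```python
import os


def read_jax_env(env=None):
    """Parse the JAX distributed settings the same way the training function does."""
    env = os.environ if env is None else env
    coordinator_address = env.get("JAX_COORDINATOR_ADDRESS")
    num_processes = int(env.get("JAX_NUM_PROCESSES", 1))
    process_id = int(env.get("JAX_PROCESS_ID", 0))
    # Extra sanity check (not in the notebook): the process ID must be in range.
    if not 0 <= process_id < num_processes:
        raise ValueError(f"process_id {process_id} out of range for {num_processes} processes")
    return coordinator_address, num_processes, process_id


# Hypothetical environment for the third process of a 3-node job.
example_env = {
    "JAX_COORDINATOR_ADDRESS": "example-coordinator:1234",
    "JAX_NUM_PROCESSES": "3",
    "JAX_PROCESS_ID": "2",
}

print(read_jax_env(example_env))
print(read_jax_env({}))  # falls back to a single process with ID 0
```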



In [4]:
from kubeflow.trainer import CustomTrainer, TrainerClient

client = TrainerClient()

## List the Training Runtimes

You can get the list of available Training Runtimes to start your TrainJob.

Each runtime may also report the available accelerator type and the number of available resources.

In [5]:
for runtime in client.list_runtimes():
    print(runtime)

RuntimeError: Failed to list ClusterTrainingRuntimes

## Run the Distributed TrainJob

Kubeflow TrainJob will train the above model on 3 JAX nodes.

In [None]:
# Parameters
num_cpu = 3
num_gpu = 0
num_nodes = 3

In [None]:
resources_per_node = {
    "cpu": num_cpu,
}
if num_gpu > 0:
    resources_per_node["gpu"] = num_gpu

job_name = client.train(
    trainer=CustomTrainer(
        func=jax_train_mnist,
        # Set how many JAX nodes you want to use for distributed training.
        num_nodes=num_nodes,
        resources_per_node=resources_per_node,
    ),
    runtime="jax-distributed",
)

print(f"Training job {job_name} submitted with {num_cpu} CPU and {num_gpu} GPU")

Training job p87c1612ceae submitted with 3 CPU and 0 GPU


## Check the TrainJob steps

You can check the components of the TrainJob that was created.

Since the TrainJob performs distributed training across 3 nodes, it generates 3 steps: `node-0` .. `node-2`.

You can get the individual status for each of these steps.

In [46]:
# Wait for the running status.
client.wait_for_job_status(name=job_name, status={"Running"})

TrainJob(name='p87c1612ceae', runtime=Runtime(name='jax-distributed', trainer=RuntimeTrainer(trainer_type=<TrainerType.CUSTOM_TRAINER: 'CustomTrainer'>, framework='jax', image='nvcr.io/nvidia/jax:25.10-py3', num_nodes=1, device='Unknown', device_count='Unknown'), pretrained_model=None), steps=[Step(name='node-0', status='Running', pod_name='p87c1612ceae-node-0-0-wwtrw', device='cpu', device_count='3'), Step(name='node-1', status='Running', pod_name='p87c1612ceae-node-0-1-dscq5', device='cpu', device_count='3'), Step(name='node-2', status='Running', pod_name='p87c1612ceae-node-0-2-7vm82', device='cpu', device_count='3')], num_nodes=3, creation_timestamp=datetime.datetime(2026, 2, 6, 17, 6, 45, tzinfo=TzInfo(0)), status='Running')

In [47]:
for c in client.get_job(name=job_name).steps:
    print(f"Step: {c.name}, Status: {c.status}, Devices: {c.device} x {c.device_count}\n")

Step: node-0, Status: Running, Devices: cpu x 3

Step: node-1, Status: Running, Devices: cpu x 3

Step: node-2, Status: Running, Devices: cpu x 3
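
When a job has many steps, it can help to aggregate the per-step status rather than read every line. The sketch below uses a stand-in `StepInfo` dataclass (hypothetical, carrying only the `name` and `status` fields seen in the output above) so it runs without a cluster. With a real job, the objects in `client.get_job(name=job_name).steps` could presumably be passed instead.

```python
from collections import Counter
from dataclasses import dataclass


@dataclass
class StepInfo:
    """Stand-in for a TrainJob step; only the fields used here."""
    name: str
    status: str


def all_steps_in(steps, wanted_status):
    """True when every step reports the wanted status."""
    return all(step.status == wanted_status for step in steps)


# Mirrors the three steps printed above.
steps = [StepInfo(name=f"node-{i}", status="Running") for i in range(3)]

status_counts = Counter(step.status for step in steps)
print(status_counts)
print(all_steps_in(steps, "Running"))
```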



## Watch the TrainJob logs

We can use the `get_job_logs()` API to get the TrainJob logs.

Since training runs on 3 CPU nodes, every JAX process uses 60,000/3 = 20,000 images from the dataset.
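
The arithmetic behind that split can be sketched as follows. `shard_bounds` is a hypothetical helper that mirrors the integer-division slicing in the training function. The 7-process case is included only to show that samples are dropped when the dataset size is not divisible by the number of processes.

```python
def shard_bounds(num_samples, num_processes, process_id):
    """Start (inclusive) and end (exclusive) index of one process's shard."""
    per_process = num_samples // num_processes
    start = process_id * per_process
    return start, start + per_process


# The 3-node job from this notebook: three equal shards of 20,000 images.
bounds = [shard_bounds(60_000, 3, p) for p in range(3)]
print(bounds)

# A non-divisible case: the trailing remainder is not assigned to any process.
last_end = shard_bounds(60_000, 7, 6)[1]
dropped = 60_000 - last_end
print(dropped)
```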

In [48]:
for i in range(3):
    print(f"\n** Distributed JAX env on node-{i} **")
    print("=====================================")
    print("\n".join(client.get_job_logs(name=job_name, follow=True, step=f"node-{i}")))


** Distributed JAX env on node-0 **
ERROR:2026-02-06 17:06:48,122:jax._src.xla_bridge:487: Jax plugin configuration error: Exception when calling jax_plugins.xla_cuda13.initialize()
Traceback (most recent call last):
  File "/opt/jax/jax/_src/xla_bridge.py", line 485, in discover_pjrt_plugins
    plugin_module.initialize()
  File "/opt/jaxlibs/jax_cuda13_pjrt/jax_plugins/xla_cuda13/__init__.py", line 328, in initialize
    _check_cuda_versions(raise_on_first_error=True)
  File "/opt/jaxlibs/jax_cuda13_pjrt/jax_plugins/xla_cuda13/__init__.py", line 285, in _check_cuda_versions
    local_device_count = cuda_versions.cuda_device_count()
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: jaxlib/cuda/versions_helpers.cc:113: operation cuInit(0) failed: Unknown CUDA error 303; cuGetErrorName failed. This probably means that JAX was unable to load the CUDA libraries.
JAX distributed initialized: Process 0/3
Available devices: [CpuDevice(id=0), CpuDevice(id=2048), CpuDev

## Delete the TrainJob

When the TrainJob has finished, you can delete the resource.


In [None]:
#client.delete_job(job_name)