In [1]:
import jax
import jax.numpy as jnp
import numpy as np
import optax
from flax import nnx
import jax.sharding
import tqdm
import jax.numpy as jnp
from jax.sharding import Mesh
from jax.experimental import mesh_utils
from ml_collections import ConfigDict
import time

from etils import ecolab
import treescope
treescope.basic_interactive_setup(autovisualize_arrays=True)

with ecolab.adhoc('xmanager-codelab', reload='parallax'):
  import parallax
  from parallax import ddp
  from parallax.examples import encoder_decoder
  from parallax.examples.utils import print_tpu_memory, clear_tpu_memory
  from parallax import DataParallelTraining
  from parallax import FSDPTraining
  from parallax.examples import models
  from parallax import sharding_utils

In [2]:
# Show the devices JAX can see; the value is displayed as the cell's output.
jax.devices()

INFO:2025-06-10 22:03:50,033:jax._src.xla_bridge:752: Unable to initialize backend 'pathways': Could not initialize backend 'pathways'
INFO:2025-06-10 22:03:50,035:jax._src.xla_bridge:752: Unable to initialize backend 'proxy': INVALID_ARGUMENT: IFRT proxy server address must be '<transport-type>://<backend-address>' (e.g., 'grpc://localhost'), but got 
INFO:2025-06-10 22:03:50,038:jax._src.xla_bridge:752: Unable to initialize backend 'mlcr': Could not initialize backend 'mlcr'
INFO:2025-06-10 22:03:50,039:jax._src.xla_bridge:752: Unable to initialize backend 'sliceme': Could not initialize backend 'sliceme'


In [3]:
# Show the `device_kind` string of the first visible device.
jax.devices()[0].device_kind

In [4]:
# Drop any live arrays left over from earlier runs, then print the
# per-device memory usage so the starting point is visible.
clear_tpu_memory()
print_tpu_memory()

Reduced live arrays from 0 to 0
TPU0: 0.0GB/16.6GB (0.0%) | TPU1: 0.0GB/16.6GB (0.0%) | TPU2: 0.0GB/16.6GB (0.0%) | TPU3: 0.0GB/16.6GB (0.0%) | TPU4: 0.0GB/16.6GB (0.0%) | TPU5: 0.0GB/16.6GB (0.0%) | TPU6: 0.0GB/16.6GB (0.0%) | TPU7: 0.0GB/16.6GB (0.0%) | 


# Utils to train a model

In [5]:
def generate_reverse_batch(batch_size, seq_len, vocab_size, key):
    """Create one synthetic batch for the "reverse the sequence" task.

    Args:
        batch_size: Number of sequences in the batch.
        seq_len: Number of tokens per sequence.
        vocab_size: Exclusive upper bound for sampled token ids.
        key: JAX PRNG key used to sample the tokens.

    Returns:
        A dict with "input_ids" (random token ids drawn from
        [1, vocab_size), so id 0 is never produced) and "labels" (the same
        tokens with the sequence axis reversed). Both have shape
        (batch_size, seq_len).
    """
    tokens = jax.random.randint(
        key, shape=(batch_size, seq_len), minval=1, maxval=vocab_size
    )
    # Reversing along axis 1 turns every row into its mirror image.
    return {"input_ids": tokens, "labels": tokens[:, ::-1]}

def compute_loss(logits, labels):
    """Return the mean softmax cross-entropy of `logits` against integer `labels`.

    The element-wise losses come from
    `optax.softmax_cross_entropy_with_integer_labels` and are averaged over
    every element into a single scalar.
    """
    elementwise_loss = optax.softmax_cross_entropy_with_integer_labels(
        logits=logits, labels=labels
    )
    return elementwise_loss.mean()

@nnx.jit
def train_step(model, optimizer, batch):
    """Run one optimisation step on `batch` and return the scalar loss.

    Args:
        model: Callable model invoked as `model(encoder_input, target_output)`.
        optimizer: Optimizer object whose `update(grads)` applies the step.
        batch: Mapping with "input_ids" and "labels" entries.

    Returns:
        The loss computed before the parameter update.
    """
    encoder_tokens = jnp.array(batch["input_ids"])
    target_tokens = jnp.array(batch["labels"])

    def loss_fn(net, enc_in, dec_target):
        # NOTE(review): the model is given the very targets it is scored
        # against, unshifted. Unless TransformerModel shifts/masks them
        # internally, it could learn to copy its decoder input -- confirm.
        return compute_loss(net(enc_in, dec_target), dec_target)

    # Gradients are taken with respect to the first argument (the model).
    loss, grads = nnx.value_and_grad(loss_fn)(model, encoder_tokens, target_tokens)
    optimizer.update(grads)
    return loss

@nnx.jit
def eval_step(model, batch, eval_metrics):
    """Run the model on `batch` and fold the results into `eval_metrics`.

    Args:
        model: Callable model invoked as `model(encoder_input, target_output)`.
        batch: Mapping with "input_ids" and "labels" entries.
        eval_metrics: Metric container updated in place with the loss,
            logits and labels of this batch.
    """
    encoder_tokens = jnp.array(batch["input_ids"])
    target_tokens = jnp.array(batch["labels"])

    logits = model(encoder_tokens, target_tokens)
    eval_metrics.update(
        loss=compute_loss(logits, target_tokens),
        logits=logits,
        labels=target_tokens,
    )

# Evaluation metrics: a running average of the loss plus accuracy.
# evaluate_model() resets this container before every evaluation.
eval_metrics = nnx.MultiMetric(
    loss=nnx.metrics.Average('loss'),
    accuracy=nnx.metrics.Accuracy(),
)

# One entry per training step, appended by train_one_epoch().
train_metrics_history = {"train_loss": []}

# One entry per evaluation, appended by evaluate_model().
eval_metrics_history = {"test_loss": [], "test_accuracy": []}

# Batches are generated on the fly; this only fixes how many examples make up
# one "epoch" (steps per epoch = num_train_data // batch_size).
num_train_data = 16_888
num_epochs = 10

# Layout of the tqdm progress line used during training.
bar_format = "{desc}[{n_fmt}/{total_fmt}]{postfix} [{elapsed}<{remaining}]"


def train_one_epoch(epoch, batch_size, sharding=False):
    """Train the global `model` for one epoch of freshly generated batches.

    Reads the notebook-level globals `model`, `optimizer`, `num_train_data`,
    `enc_seq_len`, `vocab_size`, `bar_format` and appends every step's loss to
    `train_metrics_history["train_loss"]`.

    Args:
        epoch: Epoch index; used for the progress-bar label and as the PRNG
            seed, so each epoch sees different (but reproducible) data.
        batch_size: Number of sequences per generated batch. The epoch runs
            `num_train_data // batch_size` steps.
        sharding: When True, each batch is passed through
            `sharding_utils.get_sharded_data(batch, None)` and the first
            element of the returned pair is used for the step.
    """
    model.train()  # Set model to training mode

    train_total_steps = num_train_data // batch_size  # fixed global steps
    key = jax.random.PRNGKey(epoch)

    with tqdm.tqdm(
        desc=f"[train] epoch: {epoch}",
        total=train_total_steps,
        bar_format=bar_format,
        leave=True,
    ) as pbar:
        for _step in range(train_total_steps):
            key, subkey = jax.random.split(key)

            batch = generate_reverse_batch(
                batch_size=batch_size,
                seq_len=enc_seq_len,
                vocab_size=vocab_size,
                key=subkey,
            )

            if sharding:
                batch, _ = sharding_utils.get_sharded_data(batch, None)

            # Pull the scalar to the host exactly once per step; `.item()`
            # blocks until the device result is ready.
            loss_value = train_step(model, optimizer, batch).item()
            train_metrics_history["train_loss"].append(loss_value)

            pbar.set_postfix({"loss": loss_value})
            pbar.update(1)


def evaluate_model(epoch, eval_batch_size):
    """Evaluate the global `model` on one freshly generated batch and print the result.

    Reads the notebook-level globals `model`, `eval_metrics`, `enc_seq_len`,
    `vocab_size`, `num_epochs` and appends the computed metrics to
    `eval_metrics_history` under the keys "test_<metric name>".

    Args:
        epoch: Zero-based epoch index; seeds the evaluation batch
            (PRNGKey(999 + epoch)) and is printed one-based.
        eval_batch_size: Number of sequences in the evaluation batch.
    """
    model.eval()  # Set model to evaluation mode
    eval_metrics.reset()  # Start from empty metrics for this evaluation

    # Only a single generated validation batch is scored per call.
    val_batch = generate_reverse_batch(
        batch_size=eval_batch_size,
        seq_len=enc_seq_len,
        vocab_size=vocab_size,
        key=jax.random.PRNGKey(999 + epoch),
    )
    eval_step(model, val_batch, eval_metrics)

    for metric, value in eval_metrics.compute().items():
        # Store plain Python floats, matching train_metrics_history, instead
        # of keeping device arrays alive in the history lists.
        eval_metrics_history[f'test_{metric}'].append(float(value))

    print(f"[test] epoch: {epoch + 1}/{num_epochs}")
    print(f"- total loss: {eval_metrics_history['test_loss'][-1]:0.4f}")
    print(f"- Accuracy: {eval_metrics_history['test_accuracy'][-1]:0.4f}")




In [6]:
def count_parameters(params_tree):
    """Count the array elements held in a pytree and print the total.

    Args:
        params_tree: Any pytree (dict, list, NNX state, ...). Leaves that are
            JAX or NumPy arrays contribute their `size`; every other leaf
            (Python scalars, strings, ...) contributes 0.

    Returns:
        The total number of array elements as an int (0 for an empty tree).
    """
    def count_fn(x):
        # NumPy leaves are counted too; checking only jnp.ndarray would
        # silently report 0 for host-side parameter trees.
        return x.size if isinstance(x, (jnp.ndarray, np.ndarray)) else 0

    total_params = sum(
        count_fn(leaf) for leaf in jax.tree_util.tree_leaves(params_tree)
    )
    print(f"Total parameters: {total_params}")
    return total_params

def estimate_model_size_gb(params_tree, dtype_bytes=4):  # float32 = 4 bytes
    """Estimate the memory footprint of the arrays in `params_tree`.

    Args:
        params_tree: Pytree passed to `count_parameters`.
        dtype_bytes: Assumed bytes per element; the actual dtypes of the
            arrays are not inspected.

    Returns:
        The estimated size as a float. The divisor is 1024**3, so the value
        is in binary gigabytes even though it is printed as "GB".
    """
    bytes_per_unit = 1024 ** 3
    n_params = count_parameters(params_tree)
    size_gb = (n_params * dtype_bytes) / bytes_per_unit
    print(f"Estimated model size: {size_gb:.2f} GB")
    return size_gb

# Let's create a large model, one that is too big to train on a single device.


In [7]:
# Free whatever is still alive on the devices before allocating the model.
clear_tpu_memory()

# --- Vocabulary and sequence lengths ---------------------------------------
vocab_size = 1000
pad_token = 0  # generate_reverse_batch samples ids from 1, so 0 never appears in the data
enc_seq_len = 80
dec_seq_len = 80

# --- Layer widths ----------------------------------------------------------
emb_dim = 4000
ff_dim = 8000
num_heads = 8
attn_dim = num_heads * 64

# --- Depth and regularisation ----------------------------------------------
num_encoder = 6
num_decoder = 6
dropout_rate = 0.1

# Build the full model without any sharding, together with an Adam optimizer.
model = encoder_decoder.TransformerModel(
    num_encoder=num_encoder,
    num_decoder=num_decoder,
    enc_seq_len=enc_seq_len,
    dec_seq_len=dec_seq_len,
    vocab_size=vocab_size,
    embed_dim=emb_dim,
    feedforward_dim=ff_dim,
    num_heads=num_heads,
    dropout_rate=dropout_rate,
    attn_dim=attn_dim,
    rngs=nnx.Rngs(0),
)
tx = optax.adam(5e-5)
optimizer = nnx.ModelAndOptimizer(model, tx)

Reduced live arrays from 4 to 0


In [8]:
# Element count and estimated size (4 bytes per element) of the model state.
estimate_model_size_gb(nnx.state(model))

Total parameters: 928580650
Estimated model size: 3.46 GB


In [9]:
# Same estimate for the optimizer state, which is what training has to hold.
estimate_model_size_gb(nnx.state(optimizer))

Total parameters: 2785741948
Estimated model size: 10.38 GB


In [10]:
# Try to train the unsharded model. In the recorded run this fails on the first
# step with RESOURCE_EXHAUSTED, which is the motivation for sharding below.
# perf_counter() is monotonic, so the elapsed time cannot be skewed by
# wall-clock adjustments during the run.
start_time = time.perf_counter()

for epoch in range(num_epochs):
    train_one_epoch(epoch, batch_size=32, sharding=False)
    evaluate_model(epoch, eval_batch_size=32)

end_time = time.perf_counter()
print(f"Total time: {end_time - start_time:.2f} seconds")

[train] epoch: 0[0/527] [00:04<?]


XlaRuntimeError: RESOURCE_EXHAUSTED: Error allocating device buffer: Attempting to allocate 123.05M. That was not possible. There are 100.58M free.; (0x0x0_HBM0)

# Now let's create a sharded version of the same model with Parallax!

In [11]:
# Free the arrays of the unsharded attempt before building the sharded model.
clear_tpu_memory()


def model_init():
    """Build a fresh TransformerModel from the notebook-level hyperparameters.

    Takes no arguments so it can be handed to FSDPTraining as a model factory.
    The hyperparameters are read from the globals at call time, and the
    parameter RNG is always seeded with 0.
    """
    return encoder_decoder.TransformerModel(
        num_encoder=num_encoder,
        num_decoder=num_decoder,
        enc_seq_len=enc_seq_len,
        dec_seq_len=dec_seq_len,
        vocab_size=vocab_size,
        embed_dim=emb_dim,
        feedforward_dim=ff_dim,
        num_heads=num_heads,
        dropout_rate=dropout_rate,
        attn_dim=attn_dim,
        rngs=nnx.Rngs(0),
    )


tx = optax.adam(5e-5)

Reduced live arrays from 782 to 0


In [12]:
# Hand the model factory and the optax transform to Parallax's FSDP helper.
# The returned pair rebinds the globals `model` and `optimizer`, which
# train_one_epoch() and evaluate_model() read.
fsdp = FSDPTraining(model_init, tx)
model, optimizer = fsdp.get_sharded_components()

INFO:2025-06-10 22:05:09,394:jax._src.mesh_utils:83: Reordering mesh to physical ring order on single-tray TPU v2/v3.


In [13]:
# Show how one representative weight (the first feedforward kernel of the
# first decoder layer) is laid out across the devices.
print("Sharded model after Parallax FSDP: ")
jax.debug.visualize_array_sharding(model.decoder.layers[0].feedforward.layers[0].kernel.value)

Sharded model after Parallax FSDP: 
┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐
│       │       │       │       │       │       │       │       │
│       │       │       │       │       │       │       │       │
│       │       │       │       │       │       │       │       │
│       │       │       │       │       │       │       │       │
│ TPU 0 │ TPU 1 │ TPU 2 │ TPU 3 │ TPU 6 │ TPU 7 │ TPU 4 │ TPU 5 │
│       │       │       │       │       │       │       │       │
│       │       │       │       │       │       │       │       │
│       │       │       │       │       │       │       │       │
│       │       │       │       │       │       │       │       │
└───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┘


# Let's train the sharded big model

In [14]:
# Train the sharded model: batches are sharded before every training step,
# and one evaluation batch is scored after each epoch.
# perf_counter() is monotonic, so the elapsed time cannot be skewed by
# wall-clock adjustments during the (multi-minute) run.
start_time = time.perf_counter()

for epoch in range(num_epochs):
    train_one_epoch(epoch, batch_size=256, sharding=True)
    evaluate_model(epoch, eval_batch_size=128)

end_time = time.perf_counter()
print(f"Total time: {end_time - start_time:.2f} seconds")

[train] epoch: 0[0/65] [00:00<?]INFO:2025-06-10 22:05:19,748:jax._src.mesh_utils:83: Reordering mesh to physical ring order on single-tray TPU v2/v3.
[train] epoch: 0[1/65], loss=7.68 [00:05<06:23]INFO:2025-06-10 22:05:25,277:jax._src.mesh_utils:83: Reordering mesh to physical ring order on single-tray TPU v2/v3.
[train] epoch: 0[2/65], loss=7.8 [00:11<05:51]INFO:2025-06-10 22:05:30,573:jax._src.mesh_utils:83: Reordering mesh to physical ring order on single-tray TPU v2/v3.
[train] epoch: 0[3/65], loss=7.43 [00:12<03:37]INFO:2025-06-10 22:05:31,609:jax._src.mesh_utils:83: Reordering mesh to physical ring order on single-tray TPU v2/v3.
[train] epoch: 0[4/65], loss=7.24 [00:13<02:35]INFO:2025-06-10 22:05:32,678:jax._src.mesh_utils:83: Reordering mesh to physical ring order on single-tray TPU v2/v3.
[train] epoch: 0[5/65], loss=7.16 [00:14<02:00]INFO:2025-06-10 22:05:33,716:jax._src.mesh_utils:83: Reordering mesh to physical ring order on single-tray TPU v2/v3.
[train] epoch: 0[6/65], lo

[test] epoch: 1/10
- total loss: 6.9706
- Accuracy: 0.0010


[train] epoch: 1[0/65] [00:00<?]INFO:2025-06-10 22:06:39,001:jax._src.mesh_utils:83: Reordering mesh to physical ring order on single-tray TPU v2/v3.
[train] epoch: 1[1/65], loss=6.97 [00:03<03:36]INFO:2025-06-10 22:06:42,387:jax._src.mesh_utils:83: Reordering mesh to physical ring order on single-tray TPU v2/v3.
[train] epoch: 1[2/65], loss=6.97 [00:07<03:44]INFO:2025-06-10 22:06:46,084:jax._src.mesh_utils:83: Reordering mesh to physical ring order on single-tray TPU v2/v3.
[train] epoch: 1[3/65], loss=6.97 [00:08<02:29]INFO:2025-06-10 22:06:47,122:jax._src.mesh_utils:83: Reordering mesh to physical ring order on single-tray TPU v2/v3.
[train] epoch: 1[4/65], loss=6.97 [00:09<01:55]INFO:2025-06-10 22:06:48,202:jax._src.mesh_utils:83: Reordering mesh to physical ring order on single-tray TPU v2/v3.
[train] epoch: 1[5/65], loss=6.97 [00:10<01:34]INFO:2025-06-10 22:06:49,240:jax._src.mesh_utils:83: Reordering mesh to physical ring order on single-tray TPU v2/v3.
[train] epoch: 1[6/65], l

[test] epoch: 2/10
- total loss: 6.9696
- Accuracy: 0.0012


[train] epoch: 2[0/65] [00:00<?]INFO:2025-06-10 22:07:53,418:jax._src.mesh_utils:83: Reordering mesh to physical ring order on single-tray TPU v2/v3.
[train] epoch: 2[1/65], loss=6.97 [00:01<01:06]INFO:2025-06-10 22:07:54,455:jax._src.mesh_utils:83: Reordering mesh to physical ring order on single-tray TPU v2/v3.
[train] epoch: 2[2/65], loss=6.97 [00:02<01:05]INFO:2025-06-10 22:07:55,493:jax._src.mesh_utils:83: Reordering mesh to physical ring order on single-tray TPU v2/v3.
[train] epoch: 2[3/65], loss=6.97 [00:03<01:04]INFO:2025-06-10 22:07:56,530:jax._src.mesh_utils:83: Reordering mesh to physical ring order on single-tray TPU v2/v3.
[train] epoch: 2[4/65], loss=6.96 [00:04<01:03]INFO:2025-06-10 22:07:57,568:jax._src.mesh_utils:83: Reordering mesh to physical ring order on single-tray TPU v2/v3.
[train] epoch: 2[5/65], loss=6.97 [00:05<01:11]INFO:2025-06-10 22:07:59,053:jax._src.mesh_utils:83: Reordering mesh to physical ring order on single-tray TPU v2/v3.
[train] epoch: 2[6/65], l

[test] epoch: 3/10
- total loss: 6.9584
- Accuracy: 0.0011


[train] epoch: 3[0/65] [00:00<?]INFO:2025-06-10 22:09:02,413:jax._src.mesh_utils:83: Reordering mesh to physical ring order on single-tray TPU v2/v3.
[train] epoch: 3[1/65], loss=6.96 [00:01<01:06]INFO:2025-06-10 22:09:03,450:jax._src.mesh_utils:83: Reordering mesh to physical ring order on single-tray TPU v2/v3.
[train] epoch: 3[2/65], loss=6.96 [00:02<01:05]INFO:2025-06-10 22:09:04,500:jax._src.mesh_utils:83: Reordering mesh to physical ring order on single-tray TPU v2/v3.
[train] epoch: 3[3/65], loss=6.96 [00:03<01:04]INFO:2025-06-10 22:09:05,538:jax._src.mesh_utils:83: Reordering mesh to physical ring order on single-tray TPU v2/v3.
[train] epoch: 3[4/65], loss=6.96 [00:04<01:03]INFO:2025-06-10 22:09:06,575:jax._src.mesh_utils:83: Reordering mesh to physical ring order on single-tray TPU v2/v3.
[train] epoch: 3[5/65], loss=6.96 [00:05<01:02]INFO:2025-06-10 22:09:07,613:jax._src.mesh_utils:83: Reordering mesh to physical ring order on single-tray TPU v2/v3.
[train] epoch: 3[6/65], l

[test] epoch: 4/10
- total loss: 6.9643
- Accuracy: 0.0004


[train] epoch: 4[0/65] [00:00<?]INFO:2025-06-10 22:10:11,318:jax._src.mesh_utils:83: Reordering mesh to physical ring order on single-tray TPU v2/v3.
[train] epoch: 4[1/65], loss=6.96 [00:01<01:06]INFO:2025-06-10 22:10:12,354:jax._src.mesh_utils:83: Reordering mesh to physical ring order on single-tray TPU v2/v3.
[train] epoch: 4[2/65], loss=6.96 [00:02<01:05]INFO:2025-06-10 22:10:13,392:jax._src.mesh_utils:83: Reordering mesh to physical ring order on single-tray TPU v2/v3.
[train] epoch: 4[3/65], loss=6.96 [00:03<01:04]INFO:2025-06-10 22:10:14,429:jax._src.mesh_utils:83: Reordering mesh to physical ring order on single-tray TPU v2/v3.
[train] epoch: 4[4/65], loss=6.96 [00:04<01:03]INFO:2025-06-10 22:10:15,470:jax._src.mesh_utils:83: Reordering mesh to physical ring order on single-tray TPU v2/v3.
[train] epoch: 4[5/65], loss=6.96 [00:05<01:02]INFO:2025-06-10 22:10:16,509:jax._src.mesh_utils:83: Reordering mesh to physical ring order on single-tray TPU v2/v3.
[train] epoch: 4[6/65], l

[test] epoch: 5/10
- total loss: 6.9524
- Accuracy: 0.0013


[train] epoch: 5[0/65] [00:00<?]INFO:2025-06-10 22:11:20,298:jax._src.mesh_utils:83: Reordering mesh to physical ring order on single-tray TPU v2/v3.
[train] epoch: 5[1/65], loss=6.96 [00:01<01:06]INFO:2025-06-10 22:11:21,334:jax._src.mesh_utils:83: Reordering mesh to physical ring order on single-tray TPU v2/v3.
[train] epoch: 5[2/65], loss=6.96 [00:02<01:05]INFO:2025-06-10 22:11:22,369:jax._src.mesh_utils:83: Reordering mesh to physical ring order on single-tray TPU v2/v3.
[train] epoch: 5[3/65], loss=6.96 [00:03<01:04]INFO:2025-06-10 22:11:23,406:jax._src.mesh_utils:83: Reordering mesh to physical ring order on single-tray TPU v2/v3.
[train] epoch: 5[4/65], loss=6.95 [00:04<01:03]INFO:2025-06-10 22:11:24,446:jax._src.mesh_utils:83: Reordering mesh to physical ring order on single-tray TPU v2/v3.
[train] epoch: 5[5/65], loss=6.95 [00:05<01:02]INFO:2025-06-10 22:11:25,484:jax._src.mesh_utils:83: Reordering mesh to physical ring order on single-tray TPU v2/v3.
[train] epoch: 5[6/65], l

[test] epoch: 6/10
- total loss: 0.3603
- Accuracy: 0.9500


[train] epoch: 6[0/65] [00:00<?]INFO:2025-06-10 22:12:29,251:jax._src.mesh_utils:83: Reordering mesh to physical ring order on single-tray TPU v2/v3.
[train] epoch: 6[1/65], loss=0.358 [00:01<01:06]INFO:2025-06-10 22:12:30,287:jax._src.mesh_utils:83: Reordering mesh to physical ring order on single-tray TPU v2/v3.
[train] epoch: 6[2/65], loss=0.354 [00:02<01:05]INFO:2025-06-10 22:12:31,324:jax._src.mesh_utils:83: Reordering mesh to physical ring order on single-tray TPU v2/v3.
[train] epoch: 6[3/65], loss=0.356 [00:03<01:04]INFO:2025-06-10 22:12:32,360:jax._src.mesh_utils:83: Reordering mesh to physical ring order on single-tray TPU v2/v3.
[train] epoch: 6[4/65], loss=0.354 [00:04<01:03]INFO:2025-06-10 22:12:33,397:jax._src.mesh_utils:83: Reordering mesh to physical ring order on single-tray TPU v2/v3.
[train] epoch: 6[5/65], loss=0.354 [00:05<01:02]INFO:2025-06-10 22:12:34,434:jax._src.mesh_utils:83: Reordering mesh to physical ring order on single-tray TPU v2/v3.
[train] epoch: 6[6/6

[test] epoch: 7/10
- total loss: 0.0852
- Accuracy: 0.9878


[train] epoch: 7[0/65] [00:00<?]INFO:2025-06-10 22:13:37,773:jax._src.mesh_utils:83: Reordering mesh to physical ring order on single-tray TPU v2/v3.
[train] epoch: 7[1/65], loss=0.0851 [00:01<01:34]INFO:2025-06-10 22:13:39,248:jax._src.mesh_utils:83: Reordering mesh to physical ring order on single-tray TPU v2/v3.
[train] epoch: 7[2/65], loss=0.0836 [00:02<01:16]INFO:2025-06-10 22:13:40,285:jax._src.mesh_utils:83: Reordering mesh to physical ring order on single-tray TPU v2/v3.
[train] epoch: 7[3/65], loss=0.084 [00:03<01:10]INFO:2025-06-10 22:13:41,323:jax._src.mesh_utils:83: Reordering mesh to physical ring order on single-tray TPU v2/v3.
[train] epoch: 7[4/65], loss=0.0828 [00:04<01:06]INFO:2025-06-10 22:13:42,359:jax._src.mesh_utils:83: Reordering mesh to physical ring order on single-tray TPU v2/v3.
[train] epoch: 7[5/65], loss=0.0834 [00:05<01:04]INFO:2025-06-10 22:13:43,395:jax._src.mesh_utils:83: Reordering mesh to physical ring order on single-tray TPU v2/v3.
[train] epoch: 7

[test] epoch: 8/10
- total loss: 0.0063
- Accuracy: 1.0000


[train] epoch: 8[0/65] [00:00<?]INFO:2025-06-10 22:14:46,719:jax._src.mesh_utils:83: Reordering mesh to physical ring order on single-tray TPU v2/v3.
[train] epoch: 8[1/65], loss=0.00544 [00:01<01:06]INFO:2025-06-10 22:14:47,754:jax._src.mesh_utils:83: Reordering mesh to physical ring order on single-tray TPU v2/v3.
[train] epoch: 8[2/65], loss=0.00515 [00:02<01:20]INFO:2025-06-10 22:14:49,214:jax._src.mesh_utils:83: Reordering mesh to physical ring order on single-tray TPU v2/v3.
[train] epoch: 8[3/65], loss=0.00509 [00:03<01:12]INFO:2025-06-10 22:14:50,250:jax._src.mesh_utils:83: Reordering mesh to physical ring order on single-tray TPU v2/v3.
[train] epoch: 8[4/65], loss=0.00491 [00:04<01:08]INFO:2025-06-10 22:14:51,286:jax._src.mesh_utils:83: Reordering mesh to physical ring order on single-tray TPU v2/v3.
[train] epoch: 8[5/65], loss=0.00441 [00:05<01:05]INFO:2025-06-10 22:14:52,323:jax._src.mesh_utils:83: Reordering mesh to physical ring order on single-tray TPU v2/v3.
[train] ep

[test] epoch: 9/10
- total loss: 0.0011
- Accuracy: 1.0000


[train] epoch: 9[0/65] [00:00<?]INFO:2025-06-10 22:15:55,520:jax._src.mesh_utils:83: Reordering mesh to physical ring order on single-tray TPU v2/v3.
[train] epoch: 9[1/65], loss=0.00123 [00:01<01:06]INFO:2025-06-10 22:15:56,555:jax._src.mesh_utils:83: Reordering mesh to physical ring order on single-tray TPU v2/v3.
[train] epoch: 9[2/65], loss=0.00114 [00:02<01:05]INFO:2025-06-10 22:15:57,591:jax._src.mesh_utils:83: Reordering mesh to physical ring order on single-tray TPU v2/v3.
[train] epoch: 9[3/65], loss=0.00114 [00:03<01:04]INFO:2025-06-10 22:15:58,626:jax._src.mesh_utils:83: Reordering mesh to physical ring order on single-tray TPU v2/v3.
[train] epoch: 9[4/65], loss=0.00107 [00:04<01:03]INFO:2025-06-10 22:15:59,665:jax._src.mesh_utils:83: Reordering mesh to physical ring order on single-tray TPU v2/v3.
[train] epoch: 9[5/65], loss=0.00106 [00:05<01:10]INFO:2025-06-10 22:16:01,099:jax._src.mesh_utils:83: Reordering mesh to physical ring order on single-tray TPU v2/v3.
[train] ep

[test] epoch: 10/10
- total loss: 0.0007
- Accuracy: 1.0000
Total time: 705.00 seconds



