# Model Offloading Demo

This notebook demonstrates how to use model offloading to train a model so large that its forward pass would not otherwise fit into the accelerator's high-bandwidth memory (HBM).

In [1]:
from etils import ecolab
import flax.nnx as nnx
import jax
import jax.numpy as jnp
import optax

with ecolab.adhoc(reload='parallax'):
  import parallax
  from parallax.examples import models
  from parallax.examples import utils

In [2]:
# Show every device visible to JAX, then the memory statistics reported by
# the first GPU device.
visible_devices = jax.devices()
print('All devices:', visible_devices)

gpu_device = jax.devices(backend='gpu')[0]
gpu_memory_stats = gpu_device.memory_stats()
print(gpu_memory_stats)

INFO:2025-09-23 09:49:21,408:jax._src.xla_bridge:822: Unable to initialize backend 'tpu': UNKNOWN: TPU initialization failed: No jellyfish device found.
INFO:2025-09-23 09:49:21,410:jax._src.xla_bridge:822: Unable to initialize backend 'pathways': Could not initialize backend 'pathways'
INFO:2025-09-23 09:49:21,411:jax._src.xla_bridge:822: Unable to initialize backend 'proxy': INVALID_ARGUMENT: IFRT proxy server address must be '<transport-type>://<backend-address>' (e.g., 'grpc://localhost'), but got 


All devices: [CudaDevice(id=0)]
{'num_allocs': 0, 'bytes_in_use': 0, 'peak_bytes_in_use': 0, 'largest_alloc_size': 0, 'bytes_limit': 31724126208, 'bytes_reserved': 0, 'peak_bytes_reserved': 0, 'largest_free_block_bytes': 0, 'pool_bytes': 0, 'peak_pool_bytes': 0}


In [15]:
# Model hyperparameters.
VOCAB_SIZE = 10_000
MAX_LEN = 1024


def _build_mini_gpt():
  """Constructs the MiniGPT7B example model with a fixed RNG seed of 42."""
  return models.MiniGPT7B(VOCAB_SIZE, nnx.Rngs(42), MAX_LEN)


# The model is handed over as a zero-argument factory, not as an already
# constructed instance.
model = parallax.create_offloaded_model(_build_mini_gpt)

# Print the device that holds one of the model's kernels.
print(model.layers[1].linear1.kernel.device)

TFRT_CPU_0


In [16]:
# 1. Build a one-device mesh and a sharding whose memory kind is pinned host
#    memory, then create the jitted, offloaded model call from it.
mesh = jax.make_mesh((1,), ('x',))
s_host = jax.sharding.NamedSharding(
    mesh,
    jax.sharding.PartitionSpec(None),
    memory_kind='pinned_host',
)
jitted_model_call = parallax.jit_offloaded_model(model, s_host)

# 2. The state of the model must be on the host to begin with.
#    `jax.device_put` accepts a whole pytree and applies the single sharding
#    to every leaf, so no explicit `jax.tree.map` over the state is needed.
state = nnx.state(model)
host_state = jax.device_put(state, s_host)

# 3. Run the forward pass with offloading.
inputs, _ = utils.make_gpt_inputs(batch_size=8, max_len=MAX_LEN)
outputs = jitted_model_call(host_state, inputs)
outputs

Array([[[-1.80469, 0.0622559, -0.503906, ..., 0.59375, -0.617188,
         0.0981445],
        [-1.80469, 0.0625, -0.503906, ..., 0.589844, -0.613281,
         0.0917969],
        [-1.79688, 0.0629883, -0.5, ..., 0.59375, -0.617188, 0.0883789],
        ...,
        [-1.83594, 0.109863, -0.542969, ..., 0.582031, -0.621094,
         0.0141602],
        [-1.83594, 0.109375, -0.542969, ..., 0.578125, -0.617188,
         0.0185547],
        [-1.83594, 0.11377, -0.546875, ..., 0.574219, -0.617188,
         0.0113525]],

       [[-1.80469, 0.0622559, -0.503906, ..., 0.59375, -0.617188,
         0.0981445],
        [-1.80469, 0.0625, -0.503906, ..., 0.589844, -0.613281,
         0.0917969],
        [-1.79688, 0.0629883, -0.5, ..., 0.59375, -0.617188, 0.0883789],
        ...,
        [-1.83594, 0.109863, -0.542969, ..., 0.582031, -0.621094,
         0.0141602],
        [-1.83594, 0.109375, -0.542969, ..., 0.578125, -0.617188,
         0.0185547],
        [-1.83594, 0.11377, -0.546875, ..., 0.57