_globalize_single_replica_arrays "incompatible devices" error persists after #3164 fix — with sd_mesh: does not override active training mesh context #3183

@evelyn22chen

Description

Thank you for addressing #3164! I tested the applied fix (using "with sd_mesh:" as the mesh context manager) and confirmed that the "incompatible devices" error still occurs. It appears that "with sd_mesh:" does not override the active training mesh context, so JAX still resolves JIT compilation against the full multi-device mesh.

The fix should use with jax.set_mesh(sd_mesh): instead of with sd_mesh: to properly override the active mesh context for single-device operations.

Environment

Error

[rank0]: ValueError: Received incompatible devices for jitted computation. Got argument args[0] of broadcast_in_dim with shape float32[1] and device ids [0] on platform GPU and jit's context mesh with device ids [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] on platform GPU

Root Cause

https://github.com/google/orbax/blob/main/checkpoint/orbax/checkpoint/_src/multihost/multislice.py#L253,#L268

The current code uses with sd_mesh: to scope single-device operations:

sd_mesh = jax.sharding.Mesh(np.array([s.device]), ('_single',))
with sd_mesh:
    source_device_map[s.device] = jnp.expand_dims(s.data, axis=0)

with sd_mesh: enters the mesh as a context manager but does not take precedence over the active training mesh. When jnp.expand_dims triggers JIT compilation (e.g., broadcast_in_dim), JAX resolves the mesh context to the training mesh (32 devices) rather than the intended single-device mesh, causing the incompatible devices error.

Fix

Replace "with sd_mesh:" with "with jax.set_mesh(sd_mesh):" in both branches:

if is_source:
    for s in inp.addressable_shards:
        sd_mesh = jax.sharding.Mesh(np.array([s.device]), ('_single',))
        with jax.set_mesh(sd_mesh):
            source_device_map[s.device] = jnp.expand_dims(s.data, axis=0)

...

else:
    slice_shape = _get_slice_shape(index, global_shape)
    sd_mesh = jax.sharding.Mesh(np.array([d]), ('_single',))
    with jax.set_mesh(sd_mesh):
        zero_data = jnp.zeros(slice_shape, dtype=inp.dtype, device=d)
    device_buffers.append(zero_data)

jax.set_mesh properly overrides the active mesh context for JIT compilation, while with sd_mesh: does not. We have verified that with jax.set_mesh(sd_mesh): resolves this issue in our multi-host training environment.

Questions

What is the difference between "with sd_mesh:" and "with jax.set_mesh(sd_mesh):", and why is "with sd_mesh:" unable to override the active training mesh context?

Related
