Skip to content

Fix JaxArrayWrapper TypeError with JAX tracers during JIT#211

Open
Chessing234 wants to merge 1 commit intogoogle-deepmind:mainfrom
Chessing234:fix/issue-203-jax-tracer-ufunc
Open

Fix JaxArrayWrapper TypeError with JAX tracers during JIT#211
Chessing234 wants to merge 1 commit intogoogle-deepmind:mainfrom
Chessing234:fix/issue-203-jax-tracer-ufunc

Conversation

@Chessing234
Copy link
Copy Markdown

Summary

  • Fixes the TypeError: operand type(s) all returned NotImplemented from __array_ufunc__ that occurs during GenCast autoregressive rollout under jax.jit tracing
  • Adds jax.core.Tracer to three locations in xarray_jax.py where JAX array types are checked at runtime:
    • _WRAPPED_TYPES tuple, so tracers are properly wrapped when constructing xarray structures during tracing
    • JaxArrayWrapper.__array_ufunc__ isinstance check, so numpy ufuncs (e.g. multiply) correctly dispatch when one operand is a raw tracer and the other is a JaxArrayWrapper-wrapped tracer
    • unwrap() function, so tracers are passed through (like jax.Array) instead of falling to the error/passthrough branch

Root cause

During JIT-traced execution, JAX replaces concrete jax.Array objects with DynamicJaxprTracer instances. These inherit from jax.core.Tracer, not jax.Array. The existing type checks in xarray_jax.py did not account for tracer types, causing __array_ufunc__ to return NotImplemented when a raw tracer scalar was combined with a JaxArrayWrapper-wrapped tracer array.

Test plan

  • Verify existing xarray_jax_test.py tests pass
  • Run GenCast autoregressive rollout (as in the demo notebook) to confirm the TypeError no longer occurs

Fixes #203

🤖 Generated with Claude Code

Add jax.core.Tracer to the isinstance checks in JaxArrayWrapper's
__array_ufunc__ method, the unwrap function, and _WRAPPED_TYPES.

During JIT-traced execution (e.g. GenCast autoregressive rollout),
JAX uses tracer objects (DynamicJaxprTracer) instead of concrete
jax.Array instances. These tracers were not recognized by the type
checks in xarray_jax.py, causing __array_ufunc__ to return
NotImplemented and raising a TypeError when numpy ufuncs like
multiply were applied to a mix of raw tracers and JaxArrayWrapper-
wrapped tracers.

Fixes google-deepmind#203

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@google-cla
Copy link
Copy Markdown

google-cla bot commented Apr 6, 2026

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

GenCast autoregressive rollout fails with JaxArrayWrapper / DynamicJaxprTracer __array_ufunc__ TypeError on both GPU and Colab TPU

1 participant