GenCast autoregressive rollout fails with JaxArrayWrapper / DynamicJaxprTracer __array_ufunc__ TypeError on both GPU and Colab TPU #203

@DarinaAndr

Hi, I would really appreciate some help with a TypeError I get while running the GenCast autoregressive rollout. The error involves xarray_jax.JaxArrayWrapper interacting with a DynamicJaxprTracer.

The failure occurs during the autoregressive rollout step (chunked_prediction_generator_multiple_runs) in the GenCast demo.

I can reproduce the same error both on a local multi-GPU machine and on a Google Colab TPU (v5e-1, runtime version 2025.07). I think it is caused by the GenCast rollout path interacting with xarray_jax.

The simpler GraphCast forward pass works correctly.

Execution fails with:

TypeError: operand type(s) all returned NotImplemented from __array_ufunc__(
<ufunc 'multiply'>, '__call__',
JitTracer(float32[]),
xarray_jax.JaxArrayWrapper(JitTracer(float32[1,1,181,360]))
): 'DynamicJaxprTracer', 'JaxArrayWrapper'
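
For readers unfamiliar with this message: NumPy raises it when every operand of a ufunc call defines `__array_ufunc__` and each of them returns `NotImplemented`. The sketch below reproduces only that NumPy mechanism. `FakeTracer` and `FakeWrapper` are hypothetical stand-ins, not the real JAX tracer or `xarray_jax.JaxArrayWrapper`, and the sketch does not claim to identify the root cause in the GenCast rollout.

```python
import numpy as np


class FakeTracer:
    """Hypothetical stand-in for a JAX tracer (not the real class)."""

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        # Decline to handle the ufunc call.
        return NotImplemented


class FakeWrapper:
    """Hypothetical stand-in for xarray_jax.JaxArrayWrapper (not the real class)."""

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        # Also decline, so no operand is left to handle the call.
        return NotImplemented


def try_multiply(a, b):
    """Return the TypeError message raised by np.multiply, or None if it succeeds."""
    try:
        np.multiply(a, b)
    except TypeError as exc:
        return str(exc)
    return None


message = try_multiply(FakeTracer(), FakeWrapper())
print(message)
```

The printed message has the same shape as the one above: the ufunc, the method name, the operands, and finally the type names of the operands that declined the call.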

To reproduce:
In the GenCast demo notebook, the failure occurs in the "Autoregressive rollout" cell:

chunks = []
for chunk in rollout.chunked_prediction_generator_multiple_runs(
    predictor_fn=run_forward_pmap,
    rngs=rngs,
    inputs=eval_inputs,
    targets_template=eval_targets * np.nan,
    forcings=eval_forcings,
    num_steps_per_chunk=1,
    num_samples=num_ensemble_members,
    pmap_devices=jax.local_devices(),
):
    chunks.append(chunk)

predictions = xarray.combine_by_coords(chunks)

Setup:

Machine:

Linux cluster
2× NVIDIA GPUs

Python environment:

Python 3.11
jax 0.4.38
jaxlib 0.4.38
numpy 2.4.2
xarray 2024.11.0
pandas 2.2.3

Devices detected:

devices: [CudaDevice(id=0), CudaDevice(id=1)]
backend: gpu
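
For anyone comparing environments, a version listing like the one above can be collected with a short script. This is a minimal sketch using only the standard library; the helper name `collect_versions` and the package list are illustrative assumptions, not part of the GenCast demo.

```python
import sys
from importlib import metadata

# Packages listed in the setup above (illustrative choice).
PACKAGES = ("jax", "jaxlib", "numpy", "xarray", "pandas")


def collect_versions(packages=PACKAGES):
    """Return a dict mapping 'python' and each package name to a version string."""
    versions = {"python": sys.version.split()[0]}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # Report missing packages instead of failing.
            versions[name] = "not installed"
    return versions


versions = collect_versions()
for name, version in versions.items():
    print(f"{name} {version}")
```

When JAX is installed, `jax.devices()` and `jax.default_backend()` report the devices and backend in the same session.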

