Hi, I would really appreciate some help with the following issue: the GenCast autoregressive rollout fails with a TypeError raised when xarray_jax.JaxArrayWrapper interacts with a DynamicJaxprTracer.
The failure occurs during the autoregressive rollout step (chunked_prediction_generator_multiple_runs) in the GenCast demo.
I can reproduce the same error both on a local multi-GPU machine and on a Google Colab TPU (v5e-1, 2025.07 runtime version), so I believe the problem lies in the GenCast rollout path's interaction with xarray_jax rather than in my setup.
The simpler GraphCast forward pass works correctly.
Execution fails with:
TypeError: operand type(s) all returned NotImplemented from __array_ufunc__(
<ufunc 'multiply'>, '__call__',
JitTracer(float32[]),
xarray_jax.JaxArrayWrapper(JitTracer(float32[1,1,181,360]))
): 'DynamicJaxprTracer', 'JaxArrayWrapper'
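For context on what this message means, here is a minimal sketch with a toy wrapper class (OpaqueWrapper is my own stand-in, not graphcast's xarray_jax.JaxArrayWrapper): NumPy raises exactly this TypeError when every operand that overrides __array_ufunc__ declines the ufunc, i.e. neither side of the multiply knows how to handle the other's type.

```python
import numpy as np

class OpaqueWrapper:
    """Toy duck-array wrapper that opts out of NumPy ufunc dispatch."""

    def __init__(self, value):
        self.value = value

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        # Declining the ufunc, as JaxArrayWrapper appears to do when it is
        # handed a tracer type it does not recognise.
        return NotImplemented

try:
    np.multiply(np.float32(2.0), OpaqueWrapper(np.ones((1, 1, 181, 360))))
except TypeError as exc:
    # Message: "operand type(s) all returned NotImplemented from __array_ufunc__(...)"
    print(exc)
```

In the real failure the scalar operand is a JAX tracer rather than np.float32, but the dispatch mechanism producing the message is the same.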
For reproduction:
Using the GenCast demo notebook, the failure occurs in the "Autoregressive rollout" cell:
chunks = []
for chunk in rollout.chunked_prediction_generator_multiple_runs(
    predictor_fn=run_forward_pmap,
    rngs=rngs,
    inputs=eval_inputs,
    targets_template=eval_targets * np.nan,
    forcings=eval_forcings,
    num_steps_per_chunk=1,
    num_samples=num_ensemble_members,
    pmap_devices=jax.local_devices(),
):
    chunks.append(chunk)
predictions = xarray.combine_by_coords(chunks)
Setup:
Machine:
Linux cluster
2× NVIDIA GPUs
Python environment:
Python 3.11
jax 0.4.38
jaxlib 0.4.38
numpy 2.4.2
xarray 2024.11.0
pandas 2.2.3
Devices detected:
devices: [CudaDevice(id=0), CudaDevice(id=1)]
backend: gpu
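For completeness, the version and device numbers above were collected with a small snippet along these lines (a sketch; it assumes the packages are installed and queryable via importlib.metadata):

```python
import sys
from importlib import metadata

# Interpreter and package versions as reported by installed metadata.
print("Python", sys.version.split()[0])
for pkg in ("jax", "jaxlib", "numpy", "xarray", "pandas"):
    try:
        print(pkg, metadata.version(pkg))
    except metadata.PackageNotFoundError:
        print(pkg, "not installed")

# Device listing; guarded so the snippet still runs where jax is absent.
try:
    import jax
    print("devices:", jax.local_devices())
    print("backend:", jax.default_backend())
except ImportError:
    print("jax not importable")
```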