jax and xarray integration for automatic differentiation? #17107

Open
2 tasks done
tylerflex opened this issue Aug 14, 2023 · 12 comments
Labels
enhancement New feature or request

Comments

@tylerflex

I've been wondering if there has been any recent progress in integrating jax and xarray, specifically for automatic differentiation. For context, we have a simulation project that relies on xarray for our simulation output data but recently added jax support so users can automatically differentiate through these simulations. To make this work, we added code to emulate xr.DataArray functionality but with jax internals. However, this approach has been a headache to maintain and extend. It would be amazing if xarray had native support for gradient tracking in jax.

As an example, the code snippet below multiplies a JAX-traced value by an xarray.DataArray, does an interpolation, and then applies a JAX-traced operation. It would be great if we could differentiate through this. The forward pass works, but the backward pass gives a TracerArrayConversionError.

I've tried many other workarounds based on issues such as this, and on some other discussions, but without any luck. Are there any updates on the status of this, whether it would eventually be possible, or suggestions for possible workarounds? Any discussion or pointers toward a good approach would be really appreciated.

@shoyer

import numpy as np
import jax
import jax.numpy as jnp
import xarray as xr

shape = (3, 4, 5)
values = np.random.random(shape)
coords = {dim: np.arange(length) for dim, length in zip('xyz', shape)}
xarr = xr.DataArray(values, coords=coords)

def f(x):
    xarr_multiplied = x * xarr
    val = xarr_multiplied.interp(x=1, y=1, z=1)
    return jnp.sqrt(val.values)

f(1.0)
jax.grad(f)(1.0)
---------------------------------------------------------------------------
TracerArrayConversionError                Traceback (most recent call last)
Cell In[30], line 17
     14     return jnp.sqrt(val.values)
     16 f(1.0)
---> 17 jax.grad(f)(1.0)

    [... skipping hidden 10 frame]

Cell In[30], line 12, in f(x)
     11 def f(x):
---> 12     xarr_multiplied = x * xarr
     13     val = xarr_multiplied.interp(x=1, y=1, z=1)
     14     return jnp.sqrt(val.values)

File ~/.pyenv/versions/3.10.9/lib/python3.10/site-packages/xarray/core/_typed_ops.py:282, in DataArrayOpsMixin.__rmul__(self, other)
    281 def __rmul__(self, other):
--> 282     return self._binary_op(other, operator.mul, reflexive=True)

File ~/.pyenv/versions/3.10.9/lib/python3.10/site-packages/xarray/core/dataarray.py:4622, in DataArray._binary_op(self, other, f, reflexive)
   4616 other_variable = getattr(other, "variable", other)
   4617 other_coords = getattr(other, "coords", None)
   4619 variable = (
   4620     f(self.variable, other_variable)
   4621     if not reflexive
-> 4622     else f(other_variable, self.variable)
   4623 )
   4624 coords, indexes = self.coords._merge_raw(other_coords, reflexive)
   4625 name = self._result_name(other)

File ~/.pyenv/versions/3.10.9/lib/python3.10/site-packages/xarray/core/_typed_ops.py:488, in VariableOpsMixin.__rmul__(self, other)
    487 def __rmul__(self, other):
--> 488     return self._binary_op(other, operator.mul, reflexive=True)

File ~/.pyenv/versions/3.10.9/lib/python3.10/site-packages/xarray/core/variable.py:2707, in Variable._binary_op(self, other, f, reflexive)
   2703 with np.errstate(all="ignore"):
   2704     new_data = (
   2705         f(self_data, other_data) if not reflexive else f(other_data, self_data)
   2706     )
-> 2707 result = Variable(dims, new_data, attrs=attrs)
   2708 return result

File ~/.pyenv/versions/3.10.9/lib/python3.10/site-packages/xarray/core/variable.py:366, in Variable.__init__(self, dims, data, attrs, encoding, fastpath)
    346 def __init__(self, dims, data, attrs=None, encoding=None, fastpath=False):
    347     """
    348     Parameters
    349     ----------
   (...)
    364         unrecognized encoding items.
    365     """
--> 366     self._data = as_compatible_data(data, fastpath=fastpath)
    367     self._dims = self._parse_dimensions(dims)
    368     self._attrs = None

File ~/.pyenv/versions/3.10.9/lib/python3.10/site-packages/xarray/core/variable.py:293, in as_compatible_data(data, fastpath)
    290     return data
    292 # validate whether the data is valid data types.
--> 293 data = np.asarray(data)
    295 if isinstance(data, np.ndarray) and data.dtype.kind in "OMm":
    296     data = _possibly_convert_objects(data)

File ~/.pyenv/versions/3.10.9/lib/python3.10/site-packages/jax/_src/core.py:605, in Tracer.__array__(self, *args, **kw)
    604 def __array__(self, *args, **kw):
--> 605   raise TracerArrayConversionError(self)

TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape float32[3,4,5].
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
  • Check for duplicate requests.
  • Describe your goal, and if possible provide a code snippet with a motivating example.
@tylerflex tylerflex added the enhancement New feature or request label Aug 14, 2023
@jakevdp
Collaborator

jakevdp commented Aug 14, 2023

Hi, I think I recall this kind of thing coming up before. I don't know of any effort to do the full integration of JAX and xarray that you have in mind. The problem is that xarray is fundamentally built on the assumption that its arrays are numpy arrays, so np.asarray, for example, is used frequently in its implementation. For traced operations in JAX, np.asarray will raise a TracerArrayConversionError, because traced values cannot be converted to NumPy arrays.

To move forward, one of two things would have to happen:

  1. xarray would have to loosen its assumptions about the internal array representation. One way this could happen is if xarray adopted the Python Array API standard. It looks like there is some thought about this, but it would be a very big project.
  2. Somebody could write an entirely new xarray-like wrapper for JAX. I don't know of any projects like this (though I wouldn't be surprised if folks have experimented with it), but it would also be a very big project.

Short of a team of people undertaking one of those very big projects, I don't think there's any good way to do what you have in mind.
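
The failure mode described above can be reproduced without xarray at all. The following is a minimal sketch, assuming only that jax and numpy are installed; the function names uses_numpy and uses_jax_numpy are hypothetical and exist only for illustration.

```python
import jax
import jax.numpy as jnp
import numpy as np


def uses_numpy(x):
    # np.asarray calls x.__array__(), which a JAX tracer refuses to provide.
    return np.asarray(x).sum()


def uses_jax_numpy(x):
    # jnp.asarray keeps the value inside JAX, so tracing works.
    return jnp.asarray(x).sum()


# With a concrete array, the NumPy path is fine.
print(uses_numpy(jnp.ones(3)))

# Under jax.grad the input is a tracer, and the NumPy path fails.
try:
    jax.grad(uses_numpy)(jnp.ones(3))
    numpy_path_failed = False
except jax.errors.TracerArrayConversionError:
    numpy_path_failed = True

grad_ok = jax.grad(uses_jax_numpy)(jnp.ones(3))
print(numpy_path_failed, grad_ok)
```

Any library code that calls np.asarray on its inputs will hit the same error as soon as one of those inputs is traced.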

@shoyer
Member

shoyer commented Aug 14, 2023

I think this would be quite exciting!

I think the Python Array API standard would probably be the way to go. Xarray's support for the API standard is pretty close to complete, and most missing features would not be hard to add. Xarray in fact already supports wrapping many types of non-NumPy arrays, so supporting JAX arrays as well would not be a big lift.

To get Xarray objects working with JAX transforms like jax.grad, they need to be registered with tree_util. But I think that is also straightforward.
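
As a rough illustration of what registering a container with tree_util involves, here is a minimal sketch on a toy class. LabeledArray is a hypothetical stand-in, not an xarray or GraphCast class; the sketch only assumes jax.tree_util.register_pytree_node_class, which is part of JAX's public API.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class LabeledArray:
    """Hypothetical stand-in for a DataArray: data plus static dimension names."""

    def __init__(self, data, dims):
        self.data = data
        self.dims = tuple(dims)

    def tree_flatten(self):
        # Children are the leaves JAX traces; aux data (dims) stays static.
        return (self.data,), self.dims

    @classmethod
    def tree_unflatten(cls, dims, children):
        return cls(children[0], dims)


def loss(arr):
    # Operates on the wrapped JAX array only.
    return jnp.sum(arr.data ** 2)


arr = LabeledArray(jnp.arange(3.0), dims=("x",))
grads = jax.grad(loss)(arr)
print(type(grads).__name__, grads.dims, grads.data)
```

Once the class is a pytree, jax.grad returns a gradient with the same container structure as its input, with the static labels carried along unchanged.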

DeepMind's GraphCast project contains a bundled Xarray-JAX wrapper, which I think already does some version of both of these (maybe in a non-ideal way):
https://github.com/deepmind/graphcast/blob/main/graphcast/xarray_jax.py

@jakevdp
Collaborator

jakevdp commented Aug 14, 2023

(Side note: for the Array API approach, we'd also have to land some version of #16099 to make JAX compliant)

@shoyer
Member

shoyer commented Aug 14, 2023

CC @mjwillson who wrote the Xarray-JAX wrapper in GraphCast.

@tylerflex
Author

Thanks @shoyer! I'll have to study that GraphCast code. I tried something similar but could never get it working properly.

@tylerflex
Author

tylerflex commented Aug 15, 2023

I played around a bit with this GraphCast wrapper. It worked for the intended use case of applying @jax.jit to functions mapping from DataArray -> DataArray.

Unfortunately, jax.grad() still seems to give a TracerArrayConversionError. It looks like it might be occurring in the VJP function for multiplying the JAX-traced scalar by the DataArray.

It's pretty likely I'm doing something wrong here, so if @mjwillson / @shoyer spot anything, let me know!

import numpy as np
import jax
import jax.numpy as jnp
import xarray as xr

shape = (3, 4, 5)
values = np.random.random(shape)
coords = {dim: np.arange(length).tolist() for dim, length in zip('xyz', shape)}
xarr_jax = DataArray(values, dims=('x', 'y', 'z'), coords=coords) # note: GraphCast wrapper class

def f(x):
    val = x * xarr_jax
    val = val.interp(x=1, y=1, z=1)
    val = jnp.array(val.data)
    return jnp.sum(val)

f(1.0) # works
jax.grad(f)(1.0) # TracerArrayConversionError
---------------------------------------------------------------------------
TracerArrayConversionError                Traceback (most recent call last)
Cell In[16], line 18
     15     return jnp.sum(val)
     17 f(1.0) # works
---> 18 jax.grad(f)(1.0) # TracerArrayConversionError

    [... skipping hidden 10 frame]

Cell In[16], line 12, in f(x)
     11 def f(x):
---> 12     val = x * xarr_jax
     13     val = val.interp(x=1, y=1, z=1)
     14     val = jnp.array(val.data)

File ~/.pyenv/versions/3.10.9/lib/python3.10/site-packages/xarray/core/_typed_ops.py:282, in DataArrayOpsMixin.__rmul__(self, other)
    281 def __rmul__(self, other):
--> 282     return self._binary_op(other, operator.mul, reflexive=True)

File ~/.pyenv/versions/3.10.9/lib/python3.10/site-packages/xarray/core/dataarray.py:4622, in DataArray._binary_op(self, other, f, reflexive)
   4616 other_variable = getattr(other, "variable", other)
   4617 other_coords = getattr(other, "coords", None)
   4619 variable = (
   4620     f(self.variable, other_variable)
   4621     if not reflexive
-> 4622     else f(other_variable, self.variable)
   4623 )
   4624 coords, indexes = self.coords._merge_raw(other_coords, reflexive)
   4625 name = self._result_name(other)

File ~/.pyenv/versions/3.10.9/lib/python3.10/site-packages/xarray/core/_typed_ops.py:488, in VariableOpsMixin.__rmul__(self, other)
    487 def __rmul__(self, other):
--> 488     return self._binary_op(other, operator.mul, reflexive=True)

File ~/.pyenv/versions/3.10.9/lib/python3.10/site-packages/xarray/core/variable.py:2707, in Variable._binary_op(self, other, f, reflexive)
   2703 with np.errstate(all="ignore"):
   2704     new_data = (
   2705         f(self_data, other_data) if not reflexive else f(other_data, self_data)
   2706     )
-> 2707 result = Variable(dims, new_data, attrs=attrs)
   2708 return result

File ~/.pyenv/versions/3.10.9/lib/python3.10/site-packages/xarray/core/variable.py:366, in Variable.__init__(self, dims, data, attrs, encoding, fastpath)
    346 def __init__(self, dims, data, attrs=None, encoding=None, fastpath=False):
    347     """
    348     Parameters
    349     ----------
   (...)
    364         unrecognized encoding items.
    365     """
--> 366     self._data = as_compatible_data(data, fastpath=fastpath)
    367     self._dims = self._parse_dimensions(dims)
    368     self._attrs = None

File ~/.pyenv/versions/3.10.9/lib/python3.10/site-packages/xarray/core/variable.py:293, in as_compatible_data(data, fastpath)
    290     return data
    292 # validate whether the data is valid data types.
--> 293 data = np.asarray(data)
    295 if isinstance(data, np.ndarray) and data.dtype.kind in "OMm":
    296     data = _possibly_convert_objects(data)

File ~/.pyenv/versions/3.10.9/lib/python3.10/site-packages/jax/_src/core.py:605, in Tracer.__array__(self, *args, **kw)
    604 def __array__(self, *args, **kw):
--> 605   raise TracerArrayConversionError(self)

TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape float32[3,4,5].
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError

@jakevdp
Collaborator

jakevdp commented Aug 15, 2023

The error is happening because the gradient computation results in calling xarray's __rmul__, which attempts to cast the inputs to numpy arrays, because xarray is built on the assumption that all its buffers are numpy arrays. Casting a traced JAX array to a numpy array results in a TracerArrayConversionError.

There's no way to fix this without changing how xarray is implemented.

@shoyer
Member

shoyer commented Aug 15, 2023

> xarray is built on the assumption that all its buffers are numpy arrays

This isn't true -- xarray supports a number of duck arrays. As soon as JAX implements __array_namespace__ from the array API, you'll be able to wrap JAX arrays directly into xarray objects.
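
For readers unfamiliar with the array API hook mentioned here, the following is a minimal sketch of how a library can dispatch on __array_namespace__ instead of assuming NumPy. The helper names array_namespace_or_none and generic_mean_square are hypothetical, and the NumPy fallback is an assumption made for this illustration, not a description of how xarray does it.

```python
import numpy as np


def array_namespace_or_none(x):
    # Array API compliant arrays expose __array_namespace__(), which returns
    # the module-like namespace holding functions such as mean and sum.
    method = getattr(x, "__array_namespace__", None)
    return method() if callable(method) else None


def generic_mean_square(x):
    xp = array_namespace_or_none(x)
    if xp is None:
        # Assumed fallback for plain lists or arrays without the hook.
        xp = np
        x = np.asarray(x)
    # Only operators and namespace functions are used, never np.asarray on
    # an array that already brought its own namespace.
    return xp.mean(x * x)


print(generic_mean_square(np.array([1.0, 2.0, 3.0])))
```

Code written this way never forces a conversion to a NumPy array for inputs that supply their own namespace, which is the property a traced JAX value needs.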

If you use the GraphCast Xarray-JAX wrapper, you need to use its special constructors for DataArray/Dataset.

@jakevdp
Collaborator

jakevdp commented Aug 15, 2023

Oh, good to know! Progress on __array_namespace__ is in #16099, though it's been hampered by the fact that JAX arrays are immutable, and some corners of the Python array API and its primary testing framework assume mutability (hopefully xarray doesn't depend on any of these mutation APIs).
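
The immutability point can be seen in a short sketch, assuming only that jax is installed: in-place item assignment on a JAX array raises a TypeError, and the functional .at[...].set(...) form returns a new array instead.

```python
import jax.numpy as jnp

x = jnp.zeros(3)

# In-place assignment is rejected because JAX arrays are immutable.
try:
    x[0] = 1.0
    mutated_in_place = True
except TypeError:
    mutated_in_place = False

# The functional update leaves x untouched and returns a modified copy.
y = x.at[0].set(1.0)
print(mutated_in_place, x, y)
```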

@tylerflex
Author

> If you use the GraphCast Xarray-JAX wrapper, you need to use its special constructors for DataArray/Dataset.

Could you explain a bit more?

  • Is this something I could fix by building on top of the GraphCast wrapper? For example, defining a custom jax VJP for DataArray.__rmul__?
  • Or maybe something more subtle that I'm not doing in my example?
  • Is it something more internal to jax / xarray that can't be fixed right now?

@shoyer
Member

shoyer commented Aug 15, 2023

@jakevdp Indeed, Xarray doesn't rely on the mutation APIs (unless a user tries to mutate an array)

@tylerflex I see, it looks like you were already using the GraphCast wrapper. I don't know exactly what's going on, then.

@mjwillson
Contributor

mjwillson commented Aug 21, 2023

Hiya,

Firstly, just to note that xarray_jax isn't something we're officially supporting outside the GraphCast project for now. It has some rough edges and is partly a stop-gap measure until JAX supports the new array protocol, which will allow it to integrate better with xarray.

That said, about your example, you'll find the following very similar code works:

import numpy as np
import jax
import jax.numpy as jnp
import xarray as xr
from graphcast import xarray_jax

shape = (3, 4, 5)
values = jnp.asarray(np.random.random(shape))
coords = {dim: np.arange(length).tolist() for dim, length in zip('xyz', shape)}
xarr_jax = xarray_jax.DataArray(values, dims=('x', 'y', 'z'), coords=coords)

def f(x):
    val = x * xarr_jax
    val = xarray_jax.unwrap_data(val)
    return jnp.sum(val)

f(1.0)
jax.jit(f)(1.0)
jax.grad(f)(1.0)

Some issues in your code were:

  • It looks like xarray's interp method uses scipy to do the interpolation, so there's no way this is going to work with JAX (it will try to convert the JAX tracer to a numpy array and fail)
  • To access the underlying jax array use xarray_jax.unwrap_data -- jnp.array(val.data) isn't going to work as it'll try to go via a numpy array
  • It appears there is a corner-case bug where multiplying (or applying any other binary op to) a numpy-array-backed DataArray with a JAX tracer causes an error. To work around it, you should create the xarray_jax.DataArray with a jax array for the data. Alternatively, you could pass the DataArray in as an explicit argument to the function rather than just closing over it; that way the JIT process will ensure it gets converted to a JAX array under the hood.
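
Since interp goes through scipy, one possible way to keep the interpolation step differentiable is to do it on the raw JAX array. The sketch below swaps in jax.scipy.ndimage.map_coordinates, which is not what xarray or xarray_jax use; it supports only order 0 and 1, and it works in index space, so it assumes the coordinates equal the integer indices (as with the np.arange coords in the examples above).

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.scipy.ndimage import map_coordinates

values = jnp.asarray(np.random.default_rng(0).random((3, 4, 5)))

# One coordinate array per dimension; here a single query point at
# index position (1, 1, 1), matching interp(x=1, y=1, z=1) on arange coords.
point = [jnp.array([1.0]), jnp.array([1.0]), jnp.array([1.0])]


def f(x):
    scaled = x * values
    # order=1 is (tri)linear interpolation, done entirely in JAX.
    val = map_coordinates(scaled, point, order=1)
    return jnp.sqrt(val[0])


print(f(1.0))
grad = jax.grad(f)(1.0)
print(grad)
```

For coordinates that are not plain integer indices, the physical coordinate would first have to be mapped to a fractional index, which this sketch does not attempt.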


4 participants