Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

an experiment in handling instances with __jax_array__ #4725

Merged
merged 3 commits into from Dec 16, 2020
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 3 additions & 0 deletions jax/_src/lax/lax.py
Expand Up @@ -437,6 +437,9 @@ def convert_element_type(operand: Array, new_dtype: DType = None,
msg = "Casting complex values to real discards the imaginary part"
warnings.warn(msg, np.ComplexWarning, stacklevel=2)

if hasattr(operand, '__jax_array__'):
operand = operand.__jax_array__()

if not isinstance(operand, (core.Tracer, xla.DeviceArray)):
return _device_put_raw(np.asarray(operand, dtype=new_dtype),
weak_type=new_weak_type)
Expand Down
6 changes: 4 additions & 2 deletions jax/_src/numpy/lax_numpy.py
Expand Up @@ -300,9 +300,11 @@ def _result_dtype(op, *args):
return _dtype(op(*args))


def _arraylike(x): return isinstance(x, ndarray) or isscalar(x)
def _arraylike(x):
return isinstance(x, ndarray) or isscalar(x) or hasattr(x, '__jax_array__')
jakevdp marked this conversation as resolved.
Show resolved Hide resolved

def _check_arraylike(fun_name, *args):
"""Check if all args fit JAX's definition of arraylike (ndarray or scalar)."""
"""Check if all args fit JAX's definition of arraylike."""
assert isinstance(fun_name, str), f"fun_name must be a string. Got {fun_name}"
if _any(not _arraylike(arg) for arg in args):
pos, arg = next((i, arg) for i, arg in enumerate(args)
Expand Down
2 changes: 2 additions & 0 deletions jax/core.py
Expand Up @@ -859,6 +859,8 @@ def concrete_aval(x):
for typ in type(x).mro():
handler = pytype_aval_mappings.get(typ)
if handler: return handler(x)
if hasattr(x, '__jax_array__'):
return concrete_aval(x.__jax_array__())
raise TypeError(f"{type(x)} is not a valid JAX type")


Expand Down
17 changes: 12 additions & 5 deletions jax/interpreters/xla.py
Expand Up @@ -122,11 +122,14 @@ def array_result_handler(device: Optional[Device], aval: core.ShapedArray):
}

def device_put(x, device: Optional[Device] = None) -> Tuple[Any]:
x = canonicalize_dtype(x)
try:
return device_put_handlers[type(x)](x, device)
except KeyError as err:
raise TypeError(f"No device_put handler for type: {type(x)}") from err
handler = device_put_handlers.get(type(x))
if handler:
x = canonicalize_dtype(x)
Copy link
Member

@shoyer shoyer Feb 5, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like this PR broke internal tests due to this line. In particular, some type like IntEnum no longer are accepted by device_put, e.g., consider this currently valid behavior:

In [1]: import jax

In [2]: import enum

In [3]: class X(enum.IntEnum):
   ...:     Y = 1
   ...:     Z = 2
   ...:

In [4]: jax.device_put(X.Y)
Out[4]: DeviceArray(1, dtype=int32)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh man, I think you nailed it! Is it just because I dropped the x = canonicalize_dtype(x) beforehand?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That was the only thing I saw that could have broken the existing code path.

return handler(x, device)
elif hasattr(x, '__jax_array__'):
return device_put(x.__jax_array__(), device)
jakevdp marked this conversation as resolved.
Show resolved Hide resolved
else:
raise TypeError(f"No device_put handler for type: {type(x)}")

def _device_put_array(x, device: Optional[Device]):
backend = xb.get_device_backend(device)
Expand All @@ -151,6 +154,8 @@ def canonicalize_dtype(x):
for typ in typ.mro():
handler = canonicalize_dtype_handlers.get(typ)
if handler: return handler(x)
if hasattr(x, '__jax_array__'):
return canonicalize_dtype(x.__jax_array__())
raise TypeError(f"No canonicalize_dtype handler for type: {type(x)}")

def _canonicalize_ndarray_dtype(x):
Expand All @@ -173,6 +178,8 @@ def abstractify(x) -> core.AbstractValue:
for typ in typ.mro():
aval_fn = pytype_aval_mappings.get(typ)
if aval_fn: return aval_fn(x)
if hasattr(x, '__jax_array__'):
return abstractify(x.__jax_array__())
raise TypeError(f"Argument '{x}' of type '{type(x)}' is not a valid JAX type")

def _make_abstract_python_scalar(typ, _):
Expand Down
21 changes: 21 additions & 0 deletions tests/api_test.py
Expand Up @@ -2068,6 +2068,27 @@ def test_linearize_aval_error(self):
with self.assertRaisesRegex(ValueError, "tangent values inconsistent"):
f_jvp(np.ones(2, np.int32))

def test_dunder_jax_array(self):
# https://github.com/google/jax/pull/4725

class AlexArray:
def __init__(self, jax_val):
self.jax_val = jax_val
def __jax_array__(self):
return self.jax_val
dtype = property(lambda self: self.jax_val.dtype)
shape = property(lambda self: self.jax_val.shape)

x = AlexArray(jnp.array([1., 2., 3.]))
y = jnp.sin(x)
jakevdp marked this conversation as resolved.
Show resolved Hide resolved
self.assertAllClose(y, jnp.sin(jnp.array([1., 2., 3.])))
y = api.grad(api.jit(lambda x: jnp.sin(x).sum()))(x)
self.assertAllClose(y, jnp.cos(jnp.array([1., 2., 3.])))

x = AlexArray(jnp.array([[1., 2., 3.]]))
y = api.pmap(jnp.sin)(x)
self.assertAllClose(y, jnp.sin(jnp.array([[1., 2., 3.]])))


class RematTest(jtu.JaxTestCase):

Expand Down