Skip to content

Commit

Permalink
If an input to jnp.asarray is a numpy array, then convert it to a j…
Browse files Browse the repository at this point in the history
…ax.Array via device_put to avoid a copy.

Do a similar thing for jax.Array too if dtypes match.

Fixes #17702

PiperOrigin-RevId: 567644997
  • Loading branch information
yashk2810 authored and jax authors committed Sep 22, 2023
1 parent 51589bb commit 4269705
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 4 deletions.
15 changes: 11 additions & 4 deletions jax/_src/numpy/lax_numpy.py
Expand Up @@ -289,7 +289,7 @@ def load(*args: Any, **kwargs: Any) -> Array:
out = out.view(bfloat16)
try:
out = asarray(out)
except TypeError: # Unsupported dtype
except (TypeError, AssertionError): # Unsupported dtype
pass
return out

Expand Down Expand Up @@ -2017,6 +2017,12 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True,
# array([1, 2, 3])
weak_type = dtype is None and dtypes.is_weakly_typed(object)

# Use device_put to avoid a copy for ndarray inputs.
if (not copy and isinstance(object, np.ndarray) and
(dtype is None or dtype == object.dtype) and (ndmin <= object.ndim)):
# Keep the output uncommitted.
return jax.device_put(object)

# For Python scalar literals, call coerce_to_array to catch any overflow
# errors. We don't use dtypes.is_python_scalar because we don't want this
# triggering for traced values. We do this here because it matters whether or
Expand All @@ -2027,8 +2033,8 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True,

if hasattr(object, '__jax_array__'):
object = object.__jax_array__()
object = tree_map(lambda leaf: leaf.__jax_array__() if hasattr(leaf, "__jax_array__") else leaf,
object)
object = tree_map(lambda leaf: leaf.__jax_array__()
if hasattr(leaf, "__jax_array__") else leaf, object)
leaves = tree_leaves(object)
if dtype is None:
# Use lattice_result_type rather than result_type to avoid canonicalization.
Expand Down Expand Up @@ -2070,7 +2076,8 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True,

raise TypeError(f"Unexpected input type for array: {type(object)}")

out_array: Array = lax_internal._convert_element_type(out, dtype, weak_type=weak_type)
out_array: Array = lax_internal._convert_element_type(
out, dtype, weak_type=weak_type)
if ndmin > ndim(out_array):
out_array = lax.expand_dims(out_array, range(ndmin - ndim(out_array)))
return out_array
Expand Down
6 changes: 6 additions & 0 deletions tests/api_test.py
Expand Up @@ -4451,6 +4451,12 @@ def test_scalar_conversion_errors(self):
self.assertRaises(TracerBoolConversionError, jax.jit(bool), scalar_int)
_ = bool(scalar_int) # no error

@jtu.run_on_devices('cpu')
def test_asarray_no_copy_np(self):
x = np.random.uniform(0, 1, (1000, 2000)).astype("float32")
out = jnp.asarray(x)
self.assertTrue(np.shares_memory(out, x))


class RematTest(jtu.JaxTestCase):

Expand Down

0 comments on commit 4269705

Please sign in to comment.