From 2c32660a8fb9056c32d41be141474bb8f8f3dacd Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 18 Aug 2023 16:50:36 -0400 Subject: [PATCH] Replace references to DeviceArray with Array. A number of stale references are lurking in our documentation. --- cloud_tpu_colabs/Pmap_Cookbook.ipynb | 4 +-- docs/autodidax.ipynb | 12 ++++---- docs/autodidax.md | 12 ++++---- docs/autodidax.py | 12 ++++---- docs/jax-101/01-jax-basics.ipynb | 24 ++++++++-------- docs/jax-101/01-jax-basics.md | 2 +- docs/jax-101/02-jitting.ipynb | 2 +- docs/jax-101/03-vectorization.ipynb | 14 +++++----- docs/jax-101/04-advanced-autodiff.ipynb | 28 +++++++++---------- docs/notebooks/Common_Gotchas_in_JAX.ipynb | 24 ++++++++-------- docs/notebooks/Common_Gotchas_in_JAX.md | 4 +-- ...tom_derivative_rules_for_Python_code.ipynb | 18 ++++++------ ...Custom_derivative_rules_for_Python_code.md | 4 +-- docs/notebooks/autodiff_cookbook.ipynb | 22 +++++++-------- docs/notebooks/external_callbacks.ipynb | 2 +- docs/notebooks/thinking_in_jax.ipynb | 24 ++++++++-------- docs/notebooks/vmapped_log_probs.ipynb | 14 +++++----- jax/_src/array.py | 8 ++---- jax/_src/core.py | 2 +- jax/_src/dlpack.py | 14 +++++----- jax/_src/interpreters/batching.py | 2 +- jax/_src/interpreters/pxla.py | 6 ++-- jax/_src/numpy/array_methods.py | 10 +++---- jax/_src/profiler.py | 2 +- jax/_src/sharding_specs.py | 16 +++++------ .../array_serialization/serialization.py | 6 ++-- jax/experimental/jax2tf/call_tf.py | 4 +-- jax/experimental/jax2tf/jax2tf.py | 4 +-- jax/experimental/jax2tf/tests/jax2tf_test.py | 2 +- .../jax2tf/tests/shape_poly_test.py | 2 +- jax/experimental/sparse/bcoo.py | 2 +- jax/experimental/sparse/bcsr.py | 2 +- tests/api_test.py | 2 +- tests/array_test.py | 2 +- tests/dtypes_test.py | 2 +- tests/lax_control_flow_test.py | 2 +- tests/lax_numpy_test.py | 6 ++-- tests/multiprocess_gpu_test.py | 10 ------- tests/pickle_test.py | 4 +-- tests/pmap_test.py | 2 +- 40 files changed, 161 insertions(+), 173 deletions(-) diff --git a/cloud_tpu_colabs/Pmap_Cookbook.ipynb b/cloud_tpu_colabs/Pmap_Cookbook.ipynb index e5b4c4d61907..67174a9a01ea 100644 --- a/cloud_tpu_colabs/Pmap_Cookbook.ipynb +++ b/cloud_tpu_colabs/Pmap_Cookbook.ipynb @@ -251,9 +251,9 @@ "colab_type": "text" }, "source": [ - "A `ShardedDeviceArray` is effectively an `ndarray` subclass, but it's stored in pieces spread across the memory of multiple devices. Results from `pmap` functions are left sharded in device memory so that they can be operated on by subsequent `pmap` functions without moving data around, at least in some cases. But these results logically appear just like a single array.\n", + "A sharded `Array` is effectively an `ndarray` subclass, but it's stored in pieces spread across the memory of multiple devices. Results from `pmap` functions are left sharded in device memory so that they can be operated on by subsequent `pmap` functions without moving data around, at least in some cases. 
But these results logically appear just like a single array.\n", "\n", - "When you call a non-`pmap` function on a `ShardedDeviceArray`, like a standard `jax.numpy` function, communication happens behind the scenes to bring the values to one device (or back to the host in the case of the matplotlib function above):" + "When you call a non-`pmap` function on an `Array`, like a standard `jax.numpy` function, communication happens behind the scenes to bring the values to one device (or back to the host in the case of the matplotlib function above):" ] }, { diff --git a/docs/autodidax.ipynb b/docs/autodidax.ipynb index 7ccaa0bedf57..af92a7e581e0 100644 --- a/docs/autodidax.ipynb +++ b/docs/autodidax.ipynb @@ -2344,8 +2344,8 @@ "One piece missing is device memory persistence for arrays. That is, we've\n", "defined `handle_result` to transfer results back to CPU memory as NumPy\n", "arrays, but it's often preferable to avoid transferring results just to\n", - "transfer them back for the next operation. We can do that by introducing a\n", - "`DeviceArray` class, which can wrap XLA buffers and otherwise duck-type\n", + "transfer them back for the next operation. We can do that by introducing an\n", + "`Array` class, which can wrap XLA buffers and otherwise duck-type\n", "`numpy.ndarray`s:" ] }, @@ -2356,9 +2356,9 @@ "outputs": [], "source": [ "def handle_result(aval: ShapedArray, buf): # noqa: F811\n", - " return DeviceArray(aval, buf)\n", + " return Array(aval, buf)\n", "\n", - "class DeviceArray:\n", + "class Array:\n", " buf: Any\n", " aval: ShapedArray\n", "\n", @@ -2381,9 +2381,9 @@ " _rmul = staticmethod(mul)\n", " _gt = staticmethod(greater)\n", " _lt = staticmethod(less)\n", - "input_handlers[DeviceArray] = lambda x: x.buf\n", + "input_handlers[Array] = lambda x: x.buf\n", "\n", - "jax_types.add(DeviceArray)" + "jax_types.add(Array)" ] }, { diff --git a/docs/autodidax.md b/docs/autodidax.md index 9fdaf7c85e53..a718c0c57156 100644 --- a/docs/autodidax.md +++ b/docs/autodidax.md @@ -1822,15 +1822,15 @@ print(ys) One piece missing is device memory persistence for arrays. That is, we've defined `handle_result` to transfer results back to CPU memory as NumPy arrays, but it's often preferable to avoid transferring results just to -transfer them back for the next operation. We can do that by introducing a -`DeviceArray` class, which can wrap XLA buffers and otherwise duck-type +transfer them back for the next operation. We can do that by introducing an +`Array` class, which can wrap XLA buffers and otherwise duck-type `numpy.ndarray`s: ```{code-cell} def handle_result(aval: ShapedArray, buf): # noqa: F811 - return DeviceArray(aval, buf) + return Array(aval, buf) -class DeviceArray: +class Array: buf: Any aval: ShapedArray @@ -1853,9 +1853,9 @@ class DeviceArray: _rmul = staticmethod(mul) _gt = staticmethod(greater) _lt = staticmethod(less) -input_handlers[DeviceArray] = lambda x: x.buf +input_handlers[Array] = lambda x: x.buf -jax_types.add(DeviceArray) +jax_types.add(Array) ``` ```{code-cell} diff --git a/docs/autodidax.py b/docs/autodidax.py index bae462a05cd9..199cce92b40d 100644 --- a/docs/autodidax.py +++ b/docs/autodidax.py @@ -1813,15 +1813,15 @@ def f(x): # One piece missing is device memory persistence for arrays. That is, we've # defined `handle_result` to transfer results back to CPU memory as NumPy # arrays, but it's often preferable to avoid transferring results just to -# transfer them back for the next operation. 
We can do that by introducing a -# `DeviceArray` class, which can wrap XLA buffers and otherwise duck-type +# transfer them back for the next operation. We can do that by introducing an +# `Array` class, which can wrap XLA buffers and otherwise duck-type # `numpy.ndarray`s: # + def handle_result(aval: ShapedArray, buf): # noqa: F811 - return DeviceArray(aval, buf) + return Array(aval, buf) -class DeviceArray: +class Array: buf: Any aval: ShapedArray @@ -1844,9 +1844,9 @@ def __str__(self): return str(np.asarray(self.buf)) _rmul = staticmethod(mul) _gt = staticmethod(greater) _lt = staticmethod(less) -input_handlers[DeviceArray] = lambda x: x.buf +input_handlers[Array] = lambda x: x.buf -jax_types.add(DeviceArray) +jax_types.add(Array) # + diff --git a/docs/jax-101/01-jax-basics.ipynb b/docs/jax-101/01-jax-basics.ipynb index c2a784713ef4..20aba4440e05 100644 --- a/docs/jax-101/01-jax-basics.ipynb +++ b/docs/jax-101/01-jax-basics.ipynb @@ -68,7 +68,7 @@ "source": [ "So far, everything is just like NumPy. A big appeal of JAX is that you don't need to learn a new API. Many common NumPy programs would run just as well in JAX if you substitute `np` for `jnp`. However, there are some important differences which we touch on at the end of this section.\n", "\n", - "You can notice the first difference if you check the type of `x`. It is a variable of type `DeviceArray`, which is the way JAX represents arrays." + "You can notice the first difference if you check the type of `x`. It is a variable of type `Array`, which is the way JAX represents arrays." ] }, { @@ -81,7 +81,7 @@ { "data": { "text/plain": [ - "DeviceArray([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32)" + "Array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32)" ] }, "execution_count": 2, @@ -277,8 +277,8 @@ { "data": { "text/plain": [ - "(DeviceArray([-0.20000005, -0.19999981, -0.19999981, -0.19999981], dtype=float32),\n", - " DeviceArray([0.20000005, 0.19999981, 0.19999981, 0.19999981], dtype=float32))" + "(Array([-0.20000005, -0.19999981, -0.19999981, -0.19999981], dtype=float32),\n", + " Array([0.20000005, 0.19999981, 0.19999981, 0.19999981], dtype=float32))" ] }, "execution_count": 7, @@ -338,8 +338,8 @@ { "data": { "text/plain": [ - "(DeviceArray(0.03999995, dtype=float32),\n", - " DeviceArray([-0.20000005, -0.19999981, -0.19999981, -0.19999981], dtype=float32))" + "(Array(0.03999995, dtype=float32),\n", + " Array([-0.20000005, -0.19999981, -0.19999981, -0.19999981], dtype=float32))" ] }, "execution_count": 8, @@ -395,7 +395,7 @@ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mFilteredStackTrace\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgrad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msquared_error_with_aux\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;31mFilteredStackTrace\u001b[0m: TypeError: Gradient only defined for scalar-output functions. Output was (DeviceArray(0.03999995, dtype=float32), DeviceArray([-0.10000002, -0.0999999 , -0.0999999 , -0.0999999 ], dtype=float32)).\n\nThe stack trace above excludes JAX-internal frames." 
+ "\u001b[0;31mFilteredStackTrace\u001b[0m: TypeError: Gradient only defined for scalar-output functions. Output was (Array(0.03999995, dtype=float32), Array([-0.10000002, -0.0999999 , -0.0999999 , -0.0999999 ], dtype=float32)).\n\nThe stack trace above excludes JAX-internal frames." ] } ], @@ -425,8 +425,8 @@ { "data": { "text/plain": [ - "(DeviceArray([-0.20000005, -0.19999981, -0.19999981, -0.19999981], dtype=float32),\n", - " DeviceArray([-0.10000002, -0.0999999 , -0.0999999 , -0.0999999 ], dtype=float32))" + "(Array([-0.20000005, -0.19999981, -0.19999981, -0.19999981], dtype=float32),\n", + " Array([-0.10000002, -0.0999999 , -0.0999999 , -0.0999999 ], dtype=float32))" ] }, "execution_count": 10, @@ -530,7 +530,7 @@ "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0min_place_modify\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mjnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# Raises error when we cast input to jnp.ndarray\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[0;32m\u001b[0m in \u001b[0;36min_place_modify\u001b[0;34m(x)\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0min_place_modify\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 6\u001b[0;31m \u001b[0mx\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m123\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 7\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/_src/numpy/lax_numpy.py\u001b[0m in \u001b[0;36m_unimplemented_setitem\u001b[0;34m(self, i, x)\u001b[0m\n\u001b[1;32m 6594\u001b[0m \u001b[0;34m\"or another .at[] method: \"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6595\u001b[0m \"https://jax.readthedocs.io/en/latest/jax.ops.html\")\n\u001b[0;32m-> 6596\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mTypeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 6597\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6598\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_operator_round\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnumber\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mndigits\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mTypeError\u001b[0m: '' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html" + "\u001b[0;31mTypeError\u001b[0m: '' object does not support item assignment. JAX arrays are immutable. 
Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html" ] } ], @@ -557,7 +557,7 @@ { "data": { "text/plain": [ - "DeviceArray([123, 2, 3], dtype=int32)" + "Array([123, 2, 3], dtype=int32)" ] }, "execution_count": 13, @@ -594,7 +594,7 @@ { "data": { "text/plain": [ - "DeviceArray([1, 2, 3], dtype=int32)" + "Array([1, 2, 3], dtype=int32)" ] }, "execution_count": 14, diff --git a/docs/jax-101/01-jax-basics.md b/docs/jax-101/01-jax-basics.md index e7c1278b8c91..4431c085a3e8 100644 --- a/docs/jax-101/01-jax-basics.md +++ b/docs/jax-101/01-jax-basics.md @@ -45,7 +45,7 @@ print(x) So far, everything is just like NumPy. A big appeal of JAX is that you don't need to learn a new API. Many common NumPy programs would run just as well in JAX if you substitute `np` for `jnp`. However, there are some important differences which we touch on at the end of this section. -You can notice the first difference if you check the type of `x`. It is a variable of type `DeviceArray`, which is the way JAX represents arrays. +You can notice the first difference if you check the type of `x`. It is a variable of type `Array`, which is the way JAX represents arrays. ```{code-cell} ipython3 :id: 3fLtgPUAn7mi diff --git a/docs/jax-101/02-jitting.ipynb b/docs/jax-101/02-jitting.ipynb index de93aabf10f3..d72b310531e6 100644 --- a/docs/jax-101/02-jitting.ipynb +++ b/docs/jax-101/02-jitting.ipynb @@ -401,7 +401,7 @@ { "data": { "text/plain": [ - "DeviceArray(30, dtype=int32, weak_type=True)" + "Array(30, dtype=int32, weak_type=True)" ] }, "execution_count": 8, diff --git a/docs/jax-101/03-vectorization.ipynb b/docs/jax-101/03-vectorization.ipynb index d5dd3bbd4887..cbcf120d4812 100644 --- a/docs/jax-101/03-vectorization.ipynb +++ b/docs/jax-101/03-vectorization.ipynb @@ -37,7 +37,7 @@ { "data": { "text/plain": [ - "DeviceArray([11., 20., 29.], dtype=float32)" + "Array([11., 20., 29.], dtype=float32)" ] }, "execution_count": 1, @@ -104,7 +104,7 @@ { "data": { "text/plain": [ - "DeviceArray([[11., 20., 29.],\n", + "Array([[11., 20., 29.],\n", " [11., 20., 29.]], dtype=float32)" ] }, @@ -149,7 +149,7 @@ { "data": { "text/plain": [ - "DeviceArray([[11., 20., 29.],\n", + "Array([[11., 20., 29.],\n", " [11., 20., 29.]], dtype=float32)" ] }, @@ -201,7 +201,7 @@ { "data": { "text/plain": [ - "DeviceArray([[11., 20., 29.],\n", + "Array([[11., 20., 29.],\n", " [11., 20., 29.]], dtype=float32)" ] }, @@ -240,7 +240,7 @@ { "data": { "text/plain": [ - "DeviceArray([[11., 11.],\n", + "Array([[11., 11.],\n", " [20., 20.],\n", " [29., 29.]], dtype=float32)" ] @@ -281,7 +281,7 @@ { "data": { "text/plain": [ - "DeviceArray([[11., 20., 29.],\n", + "Array([[11., 20., 29.],\n", " [11., 20., 29.]], dtype=float32)" ] }, @@ -320,7 +320,7 @@ { "data": { "text/plain": [ - "DeviceArray([[11., 20., 29.],\n", + "Array([[11., 20., 29.],\n", " [11., 20., 29.]], dtype=float32)" ] }, diff --git a/docs/jax-101/04-advanced-autodiff.ipynb b/docs/jax-101/04-advanced-autodiff.ipynb index c4380a8cae74..7573a5839732 100644 --- a/docs/jax-101/04-advanced-autodiff.ipynb +++ b/docs/jax-101/04-advanced-autodiff.ipynb @@ -175,9 +175,9 @@ { "data": { "text/plain": [ - "DeviceArray([[2., 0., 0.],\n", - " [0., 2., 0.],\n", - " [0., 0., 2.]], dtype=float32)" + "Array([[2., 0., 0.],\n", + " [0., 2., 0.],\n", + " [0., 0., 2.]], dtype=float32)" ] }, "execution_count": 6, @@ -312,7 +312,7 @@ { "data": { "text/plain": [ - "DeviceArray([ 2.4, -2.4, 2.4], dtype=float32)" + "Array([ 
2.4, -2.4, 2.4], dtype=float32)" ] }, "execution_count": 9, @@ -356,7 +356,7 @@ { "data": { "text/plain": [ - "DeviceArray([-2.4, -4.8, 2.4], dtype=float32)" + "Array([-2.4, -4.8, 2.4], dtype=float32)" ] }, "execution_count": 10, @@ -459,8 +459,8 @@ { "data": { "text/plain": [ - "DeviceArray([[-2.4, -4.8, 2.4],\n", - " [-2.4, -4.8, 2.4]], dtype=float32)" + "Array([[-2.4, -4.8, 2.4],\n", + " [-2.4, -4.8, 2.4]], dtype=float32)" ] }, "execution_count": 12, @@ -503,7 +503,7 @@ { "data": { "text/plain": [ - "DeviceArray([-2.4, -4.8, 2.4], dtype=float32)" + "Array([-2.4, -4.8, 2.4], dtype=float32)" ] }, "execution_count": 13, @@ -548,8 +548,8 @@ { "data": { "text/plain": [ - "DeviceArray([[-2.4, -4.8, 2.4],\n", - " [-2.4, -4.8, 2.4]], dtype=float32)" + "Array([[-2.4, -4.8, 2.4],\n", + " [-2.4, -4.8, 2.4]], dtype=float32)" ] }, "execution_count": 14, @@ -586,8 +586,8 @@ { "data": { "text/plain": [ - "DeviceArray([[-2.4, -4.8, 2.4],\n", - " [-2.4, -4.8, 2.4]], dtype=float32)" + "Array([[-2.4, -4.8, 2.4],\n", + " [-2.4, -4.8, 2.4]], dtype=float32)" ] }, "execution_count": 15, @@ -623,8 +623,8 @@ { "data": { "text/plain": [ - "DeviceArray([[-2.4, -4.8, 2.4],\n", - " [-2.4, -4.8, 2.4]], dtype=float32)" + "Array([[-2.4, -4.8, 2.4],\n", + " [-2.4, -4.8, 2.4]], dtype=float32)" ] }, "execution_count": 16, diff --git a/docs/notebooks/Common_Gotchas_in_JAX.ipynb b/docs/notebooks/Common_Gotchas_in_JAX.ipynb index 4b892c1c954b..fb5ea832db27 100644 --- a/docs/notebooks/Common_Gotchas_in_JAX.ipynb +++ b/docs/notebooks/Common_Gotchas_in_JAX.ipynb @@ -347,7 +347,7 @@ "evalue": "ignored", "output_type": "error", "traceback": [ - "\u001b[0;31mTypeError\u001b[0m\u001b[0;31m:\u001b[0m '' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html\n" + "\u001b[0;31mTypeError\u001b[0m\u001b[0;31m:\u001b[0m '' object does not support item assignment. JAX arrays are immutable. 
Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html\n" ] } ], @@ -587,7 +587,7 @@ { "data": { "text/plain": [ - "DeviceArray(9, dtype=int32)" + "Array(9, dtype=int32)" ] }, "execution_count": 14, @@ -754,7 +754,7 @@ { "data": { "text/plain": [ - "DeviceArray(45, dtype=int32)" + "Array(45, dtype=int32)" ] }, "execution_count": 17, @@ -848,7 +848,7 @@ { "data": { "text/plain": [ - "DeviceArray(45, dtype=int32)" + "Array(45, dtype=int32)" ] }, "execution_count": 19, @@ -1020,7 +1020,7 @@ { "data": { "text/plain": [ - "DeviceArray([0, 0], dtype=uint32)" + "Array([0, 0], dtype=uint32)" ] }, "execution_count": 23, @@ -1408,7 +1408,7 @@ { "data": { "text/plain": [ - "DeviceArray(5., dtype=float32)" + "Array(5., dtype=float32)" ] }, "execution_count": 33, @@ -1553,7 +1553,7 @@ { "data": { "text/plain": [ - "DeviceArray(4, dtype=int32, weak_type=True)" + "Array(4, dtype=int32, weak_type=True)" ] }, "execution_count": 37, @@ -1616,7 +1616,7 @@ { "data": { "text/plain": [ - "DeviceArray([-1.], dtype=float32)" + "Array([-1.], dtype=float32)" ] }, "execution_count": 38, @@ -1689,7 +1689,7 @@ { "data": { "text/plain": [ - "DeviceArray(10, dtype=int32, weak_type=True)" + "Array(10, dtype=int32, weak_type=True)" ] }, "execution_count": 39, @@ -1733,7 +1733,7 @@ { "data": { "text/plain": [ - "DeviceArray(45, dtype=int32, weak_type=True)" + "Array(45, dtype=int32, weak_type=True)" ] }, "execution_count": 40, @@ -2000,7 +2000,7 @@ " 104 if np.any(np.isnan(py_val)):\n", "--> 105 raise FloatingPointError(\"invalid value\")\n", " 106 else:\n", - " 107 return DeviceArray(device_buffer, *result_shape)\n", + " 107 return Array(device_buffer, *result_shape)\n", "\n", "FloatingPointError: invalid value\n", "```" @@ -2222,7 +2222,7 @@ " array([254, 255, 0, 1], dtype=uint8)\n", "\n", " >>> jnp.arange(254.0, 258.0).astype('uint8')\n", - " DeviceArray([254, 255, 255, 255], dtype=uint8)\n", + " Array([254, 255, 255, 255], dtype=uint8)\n", " ```\n", " This sort of mismatch would typically arise when casting extreme values from floating to integer types or vice versa.\n", "\n", diff --git a/docs/notebooks/Common_Gotchas_in_JAX.md b/docs/notebooks/Common_Gotchas_in_JAX.md index 7b5e09494d13..619ebcc630b9 100644 --- a/docs/notebooks/Common_Gotchas_in_JAX.md +++ b/docs/notebooks/Common_Gotchas_in_JAX.md @@ -987,7 +987,7 @@ FloatingPointError Traceback (most recent call last) 104 if np.any(np.isnan(py_val)): --> 105 raise FloatingPointError("invalid value") 106 else: - 107 return DeviceArray(device_buffer, *result_shape) + 107 return Array(device_buffer, *result_shape) FloatingPointError: invalid value ``` @@ -1142,7 +1142,7 @@ Many such cases are discussed in detail in the sections above; here we list seve array([254, 255, 0, 1], dtype=uint8) >>> jnp.arange(254.0, 258.0).astype('uint8') - DeviceArray([254, 255, 255, 255], dtype=uint8) + Array([254, 255, 255, 255], dtype=uint8) ``` This sort of mismatch would typically arise when casting extreme values from floating to integer types or vice versa. 
diff --git a/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb b/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb index 3e2bf601867a..70f499a89a5d 100644 --- a/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb +++ b/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb @@ -234,7 +234,7 @@ { "data": { "text/plain": [ - "DeviceArray(3.0485873, dtype=float32)" + "Array(3.0485873, dtype=float32)" ] }, "execution_count": 8, @@ -1578,7 +1578,7 @@ { "data": { "text/plain": [ - "DeviceArray(-0.14112, dtype=float32)" + "Array(-0.14112, dtype=float32)" ] }, "execution_count": 50, @@ -1901,7 +1901,7 @@ "output_type": "stream", "text": [ "called f_bwd!\n", - "(DeviceArray(-0.9899925, dtype=float32),)\n" + "(Array(-0.9899925, dtype=float32),)\n" ] } ], @@ -2013,9 +2013,9 @@ "> (12)debug_bwd()\n", "-> return g\n", "(Pdb) p x\n", - "DeviceArray(9., dtype=float32)\n", + "Array(9., dtype=float32)\n", "(Pdb) p g\n", - "DeviceArray(-0.91113025, dtype=float32)\n", + "Array(-0.91113025, dtype=float32)\n", "(Pdb) q\n", "```" ] @@ -2085,7 +2085,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "{'a': 1.0, 'b': (DeviceArray(0.841471, dtype=float32), DeviceArray(-0.4161468, dtype=float32))}\n" + "{'a': 1.0, 'b': (Array(0.841471, dtype=float32), Array(-0.4161468, dtype=float32))}\n" ] } ], @@ -2107,7 +2107,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Point(x=DeviceArray(2.5403023, dtype=float32), y=array(0., dtype=float32))\n" + "Point(x=Array(2.5403023, dtype=float32), y=array(0., dtype=float32))\n" ] } ], @@ -2166,7 +2166,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "{'a': 1.0, 'b': (DeviceArray(0.841471, dtype=float32), DeviceArray(-0.4161468, dtype=float32))}\n" + "{'a': 1.0, 'b': (Array(0.841471, dtype=float32), Array(-0.4161468, dtype=float32))}\n" ] } ], @@ -2188,7 +2188,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Point(x=DeviceArray(2.5403023, dtype=float32), y=DeviceArray(-0., dtype=float32))\n" + "Point(x=Array(2.5403023, dtype=float32), y=Array(-0., dtype=float32))\n" ] } ], diff --git a/docs/notebooks/Custom_derivative_rules_for_Python_code.md b/docs/notebooks/Custom_derivative_rules_for_Python_code.md index aea0f3e276ef..003ea18edc54 100644 --- a/docs/notebooks/Custom_derivative_rules_for_Python_code.md +++ b/docs/notebooks/Custom_derivative_rules_for_Python_code.md @@ -1036,9 +1036,9 @@ jax.grad(foo)(3.) 
> (12)debug_bwd() -> return g (Pdb) p x -DeviceArray(9., dtype=float32) +Array(9., dtype=float32) (Pdb) p g -DeviceArray(-0.91113025, dtype=float32) +Array(-0.91113025, dtype=float32) (Pdb) q ``` diff --git a/docs/notebooks/autodiff_cookbook.ipynb b/docs/notebooks/autodiff_cookbook.ipynb index 24bf17850a63..39aad749de47 100644 --- a/docs/notebooks/autodiff_cookbook.ipynb +++ b/docs/notebooks/autodiff_cookbook.ipynb @@ -236,7 +236,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "{'W': DeviceArray([-0.16965576, -0.8774645 , -1.4901344 ], dtype=float32), 'b': DeviceArray(-0.29227236, dtype=float32)}\n" + "{'W': Array([-0.16965576, -0.8774645 , -1.4901344 ], dtype=float32), 'b': Array(-0.29227236, dtype=float32)}\n" ] } ], @@ -1204,7 +1204,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "(DeviceArray(3.1415927, dtype=float32),)\n" + "(Array(3.1415927, dtype=float32),)\n" ] } ], @@ -1470,7 +1470,7 @@ { "data": { "text/plain": [ - "DeviceArray(6.-8.j, dtype=complex64)" + "Array(6.-8.j, dtype=complex64)" ] }, "execution_count": 31, @@ -1511,7 +1511,7 @@ { "data": { "text/plain": [ - "DeviceArray(-27.034945-3.8511531j, dtype=complex64)" + "Array(-27.034945-3.8511531j, dtype=complex64)" ] }, "execution_count": 32, @@ -1549,7 +1549,7 @@ { "data": { "text/plain": [ - "DeviceArray(1.-0.j, dtype=complex64)" + "Array(1.-0.j, dtype=complex64)" ] }, "execution_count": 33, @@ -1602,12 +1602,12 @@ { "data": { "text/plain": [ - "DeviceArray([[-0.75342447 +0.j , -3.0509021 -10.940544j ,\n", - " 5.989684 +3.5422976j],\n", - " [-3.0509021 +10.940544j , -8.904487 +0.j ,\n", - " -5.1351547 -6.5593696j],\n", - " [ 5.989684 -3.5422976j, -5.1351547 +6.5593696j,\n", - " 0.01320434 +0.j ]], dtype=complex64)" + "Array([[-0.75342447 +0.j , -3.0509021 -10.940544j ,\n", + " 5.989684 +3.5422976j],\n", + " [-3.0509021 +10.940544j , -8.904487 +0.j ,\n", + " -5.1351547 -6.5593696j],\n", + " [ 5.989684 -3.5422976j, -5.1351547 +6.5593696j,\n", + " 0.01320434 +0.j ]], dtype=complex64)" ] }, "execution_count": 34, diff --git a/docs/notebooks/external_callbacks.ipynb b/docs/notebooks/external_callbacks.ipynb index d0dab1da1db9..e55d35bcfbb0 100644 --- a/docs/notebooks/external_callbacks.ipynb +++ b/docs/notebooks/external_callbacks.ipynb @@ -1071,7 +1071,7 @@ { "data": { "text/plain": [ - "DeviceArray(-0.4003078, dtype=float32, weak_type=True)" + "Array(-0.4003078, dtype=float32, weak_type=True)" ] }, "execution_count": 8, diff --git a/docs/notebooks/thinking_in_jax.ipynb b/docs/notebooks/thinking_in_jax.ipynb index f107bc485a5c..5e28802cfec8 100644 --- a/docs/notebooks/thinking_in_jax.ipynb +++ b/docs/notebooks/thinking_in_jax.ipynb @@ -170,7 +170,7 @@ { "data": { "text/plain": [ - "jax.interpreters.xla._DeviceArray" + "jaxlib.xla_extension.ArrayImpl" ] }, "execution_count": 5, @@ -248,7 +248,7 @@ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# JAX: immutable arrays\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mjnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m10\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m 
\u001b[0mx\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m10\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;31mTypeError\u001b[0m: '' object does not support item assignment. JAX arrays are immutable; perhaps you want jax.ops.index_update or jax.ops.index_add instead?" + "\u001b[0;31mTypeError\u001b[0m: '' object does not support item assignment. JAX arrays are immutable; perhaps you want jax.ops.index_update or jax.ops.index_add instead?" ] } ], @@ -327,7 +327,7 @@ { "data": { "text/plain": [ - "DeviceArray(2., dtype=float32)" + "Array(2., dtype=float32)" ] }, "execution_count": 9, @@ -390,7 +390,7 @@ { "data": { "text/plain": [ - "DeviceArray(2., dtype=float32)" + "Array(2., dtype=float32)" ] }, "execution_count": 11, @@ -426,7 +426,7 @@ { "data": { "text/plain": [ - "DeviceArray([1., 3., 4., 4., 4., 4., 4., 4., 4., 4., 3., 1.], dtype=float32)" + "Array([1., 3., 4., 4., 4., 4., 4., 4., 4., 4., 3., 1.], dtype=float32)" ] }, "execution_count": 12, @@ -462,7 +462,7 @@ { "data": { "text/plain": [ - "DeviceArray([1., 3., 4., 4., 4., 4., 4., 4., 4., 4., 3., 1.], dtype=float32)" + "Array([1., 3., 4., 4., 4., 4., 4., 4., 4., 4., 3., 1.], dtype=float32)" ] }, "execution_count": 13, @@ -638,7 +638,7 @@ { "data": { "text/plain": [ - "DeviceArray([-0.10570311, -0.59403396, -0.8680282 , -0.23489487], dtype=float32)" + "Array([-0.10570311, -0.59403396, -0.8680282 , -0.23489487], dtype=float32)" ] }, "execution_count": 18, @@ -739,7 +739,7 @@ { "data": { "text/plain": [ - "DeviceArray([0.25773212, 5.3623195 , 5.4032435 ], dtype=float32)" + "Array([0.25773212, 5.3623195 , 5.4032435 ], dtype=float32)" ] }, "execution_count": 20, @@ -788,7 +788,7 @@ { "data": { "text/plain": [ - "DeviceArray([1.4344584, 4.3004413, 7.9897013], dtype=float32)" + "Array([1.4344584, 4.3004413, 7.9897013], dtype=float32)" ] }, "execution_count": 21, @@ -908,7 +908,7 @@ { "data": { "text/plain": [ - "DeviceArray(-1, dtype=int32)" + "Array(-1, dtype=int32)" ] }, "execution_count": 24, @@ -948,7 +948,7 @@ { "data": { "text/plain": [ - "DeviceArray(1, dtype=int32)" + "Array(1, dtype=int32)" ] }, "execution_count": 25, @@ -1086,7 +1086,7 @@ { "data": { "text/plain": [ - "DeviceArray([1., 1., 1., 1., 1., 1.], dtype=float32)" + "Array([1., 1., 1., 1., 1., 1.], dtype=float32)" ] }, "execution_count": 28, diff --git a/docs/notebooks/vmapped_log_probs.ipynb b/docs/notebooks/vmapped_log_probs.ipynb index 536f14431022..4ee2e4924d53 100644 --- a/docs/notebooks/vmapped_log_probs.ipynb +++ b/docs/notebooks/vmapped_log_probs.ipynb @@ -145,7 +145,7 @@ { "data": { "text/plain": [ - "DeviceArray(-213.23558, dtype=float32)" + "Array(-213.23558, dtype=float32)" ] }, "execution_count": 13, @@ -229,9 +229,9 @@ { "data": { "text/plain": [ - "DeviceArray([-147.84033203, -207.02204895, -109.26074982, -243.80830383,\n", - " -163.02911377, -143.84848022, -160.28771973, -113.77169037,\n", - " -126.60544586, -190.81988525], dtype=float32)" + "Array([-147.84033203, -207.02204895, -109.26074982, -243.80830383,\n", + " -163.02911377, -143.84848022, -160.28771973, -113.77169037,\n", + " -126.60544586, -190.81988525], dtype=float32)" ] }, "execution_count": 16, @@ -270,9 +270,9 @@ { "data": { "text/plain": [ - "DeviceArray([-147.84033203, -207.02204895, -109.26074982, -243.80830383,\n", - " -163.02911377, -143.84848022, -160.28771973, -113.77169037,\n", - " -126.60544586, -190.81988525], dtype=float32)" + "Array([-147.84033203, -207.02204895, -109.26074982, 
-243.80830383,\n", + " -163.02911377, -143.84848022, -160.28771973, -113.77169037,\n", + " -126.60544586, -190.81988525], dtype=float32)" ] }, "execution_count": 17, diff --git a/jax/_src/array.py b/jax/_src/array.py index 41cbbe723a2a..cb88f9beee4c 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -649,9 +649,8 @@ def make_array_from_single_device_arrays( ) -> ArrayImpl: r"""Returns a ``jax.Array`` from a sequence of ``jax.Array``\s on a single device. - ``jax.Array`` on a single device is analogous to a ``DeviceArray``. You can use - this function if you have already ``jax.device_put`` the value on a single - device and want to create a global Array. The smaller ``jax.Array``\s should be + You can use this function if you have already ``jax.device_put`` the value on + a single device and want to create a global Array. The smaller ``jax.Array``\s should be addressable and belong to the current process. Args: @@ -702,8 +701,7 @@ def make_array_from_single_device_arrays( aval = core.ShapedArray(shape, arrays[0].dtype, weak_type=False) if dtypes.issubdtype(aval.dtype, dtypes.extended): return aval.dtype._rules.make_sharded_array(aval, sharding, arrays, committed=True) - # TODO(phawkins): ideally the cast() could be checked. Revisit this after - # removing DeviceArray. + # TODO(phawkins): ideally the cast() could be checked. return ArrayImpl(aval, sharding, cast(Sequence[ArrayImpl], arrays), committed=True) diff --git a/jax/_src/core.py b/jax/_src/core.py index f39f896196fd..1b6f22d27650 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1620,7 +1620,7 @@ def update(self, dtype=None, val=None, weak_type=None): def __eq__(self, other): if (type(self) is type(other) and self.dtype == other.dtype and self.shape == other.shape and self.weak_type == other.weak_type): - with eval_context(): # in case self.val is a DeviceArray + with eval_context(): # in case self.val is an Array return (self.val == other.val).all() else: return False diff --git a/jax/_src/dlpack.py b/jax/_src/dlpack.py index 70e12aae926c..eff01b8fc327 100644 --- a/jax/_src/dlpack.py +++ b/jax/_src/dlpack.py @@ -30,21 +30,21 @@ def to_dlpack(x: Array, take_ownership: bool = False, stream: int | None = None): - """Returns a DLPack tensor that encapsulates a ``DeviceArray`` `x`. + """Returns a DLPack tensor that encapsulates a :class:`~jax.Array` ``x``. - Takes ownership of the contents of ``x``; leaves `x` in an invalid/deleted + Takes ownership of the contents of ``x``; leaves ``x`` in an invalid/deleted state. Args: - x: a ``DeviceArray``, on either CPU or GPU. + x: a :class:`~jax.Array`, on either CPU or GPU. take_ownership: If ``True``, JAX hands ownership of the buffer to DLPack, and the consumer is free to mutate the buffer; the JAX buffer acts as if it were deleted. If ``False``, JAX retains ownership of the buffer; it is undefined behavior if the DLPack consumer writes to a buffer that JAX owns. stream: optional platform-dependent stream to wait on until the buffer is - ready. This corresponds to the `stream` argument to __dlpack__ documented - in https://dmlc.github.io/dlpack/latest/python_spec.html. + ready. This corresponds to the `stream` argument to ``__dlpack__`` + documented in https://dmlc.github.io/dlpack/latest/python_spec.html. """ if not isinstance(x, array.ArrayImpl): raise TypeError("Argument to to_dlpack must be a jax.Array, " @@ -64,9 +64,9 @@ def to_dlpack(x: Array, take_ownership: bool = False, def from_dlpack(dlpack): - """Returns a ``DeviceArray`` representation of a DLPack tensor. 
+ """Returns a :class:`~jax.Array` representation of a DLPack tensor. - The returned ``DeviceArray`` shares memory with ``dlpack``. + The returned :class:`~jax.Array` shares memory with ``dlpack``. Args: dlpack: a DLPack tensor, on either CPU or GPU. diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py index 36d098b6bbe2..d470dca93105 100644 --- a/jax/_src/interpreters/batching.py +++ b/jax/_src/interpreters/batching.py @@ -65,7 +65,7 @@ def __repr__(self) -> str: replace = dataclasses.replace # Jumble(aval=a:3 => f32[[3 1 4].a], -# data=DeviceArray([0., 1., 2., 0., 0., 1., 2., 3.], dtype=float32)) +# data=Array([0., 1., 2., 0., 0., 1., 2., 3.], dtype=float32)) @dataclasses.dataclass(frozen=True) class Jumble: aval: JumbleTy diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 1ee5819a4927..4ae94e671c40 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -217,7 +217,7 @@ def local_aval_to_result_handler( Returns: A function for handling the Buffers that will eventually be produced for this output. The function will return an object suitable for returning - to the user, e.g. a ShardedDeviceArray. + to the user, e.g. an Array. """ try: return local_result_handlers[(type(aval))](aval, sharding, indices) @@ -247,7 +247,7 @@ def global_aval_to_result_handler( Returns: A function for handling the Buffers that will eventually be produced for this output. The function will return an object suitable for returning - to the user, e.g. a ShardedDeviceArray. + to the user, e.g. an Array. """ try: return global_result_handlers[type(aval)]( @@ -1048,7 +1048,7 @@ def __str__(self): class ResultsHandler: - # `out_avals` is the `GlobalDeviceArray` global avals when using pjit or xmap + # `out_avals` is the `Array` global avals when using pjit or xmap # with `config.parallel_functions_output_gda=True`. It is the local one # otherwise, and also when using `pmap`. __slots__ = ("handlers", "out_shardings", "out_avals") diff --git a/jax/_src/numpy/array_methods.py b/jax/_src/numpy/array_methods.py index ae0abce8a324..f68555d2db43 100644 --- a/jax/_src/numpy/array_methods.py +++ b/jax/_src/numpy/array_methods.py @@ -47,7 +47,7 @@ ### add method and operator overloads to arraylike classes -# We add operator overloads to DeviceArray and ShapedArray. These method and +# We add operator overloads to Array and ShapedArray. These method and # operator overloads mainly just forward calls to the corresponding lax_numpy # functions, which can themselves handle instances from any of these classes. @@ -240,7 +240,7 @@ def _view(arr: Array, dtype: DTypeLike = None, type: None = None) -> Array: def _notimplemented_flat(self): - raise NotImplementedError("JAX DeviceArrays do not implement the arr.flat property: " + raise NotImplementedError("JAX Arrays do not implement the arr.flat property: " "consider arr.flatten() instead.") _accepted_binop_types = (int, float, complex, np.generic, np.ndarray, Array) @@ -308,8 +308,8 @@ def _multi_slice(arr: ArrayLike, removed_dims: tuple[tuple[int, ...]]) -> list[Array]: """Extracts multiple slices from `arr`. - This is used to shard DeviceArray arguments to pmap. It's implemented as a - DeviceArray method here to avoid circular imports. + This is used to shard Array arguments to pmap. It's implemented as a + Array method here to avoid circular imports. 
""" results: list[Array] = [] for starts, limits, removed in zip(start_indices, limit_indices, removed_dims): @@ -746,7 +746,7 @@ def _set_tracer_aval_forwarding(tracer, exclude=()): setattr(tracer, prop_name, _forward_property_to_aval(prop_name)) def _set_array_base_attributes(device_array, include=None, exclude=None): - # Forward operators, methods, and properties on DeviceArray to lax_numpy + # Forward operators, methods, and properties on Array to lax_numpy # functions (with no Tracers involved; this forwarding is direct) def maybe_setattr(attr_name, target): if exclude is not None and attr_name in exclude: diff --git a/jax/_src/profiler.py b/jax/_src/profiler.py index 30fe521c8d16..f1d7d3832eb1 100644 --- a/jax/_src/profiler.py +++ b/jax/_src/profiler.py @@ -321,7 +321,7 @@ def device_memory_profile(backend: Optional[str] = None) -> bytes: """Captures a JAX device memory profile as ``pprof``-format protocol buffer. A device memory profile is a snapshot of the state of memory, that describes the JAX - :class:`jax.DeviceArray` and executable objects present in memory and their + :class:`~jax.Array` and executable objects present in memory and their allocation sites. For more information how to use the device memory profiler, see diff --git a/jax/_src/sharding_specs.py b/jax/_src/sharding_specs.py index 0d6774f5e5d6..e01c67eb65a9 100644 --- a/jax/_src/sharding_specs.py +++ b/jax/_src/sharding_specs.py @@ -13,12 +13,12 @@ # limitations under the License. # A ShardingSpec describes at a high level how a logical array is sharded across -# devices (each ShardedDeviceArray has a ShardingSpec, and ShardingSpecs also -# describe how to shard inputs to a parallel computation). spec_to_indices() -# encodes exactly how a given ShardingSpec is translated to device buffers, i.e. -# how the sharded array is "laid out" across devices. Given a sequence of -# devices, we shard the data across the devices in row-major order, with -# replication treated as an extra inner dimension. +# devices (each array sharded with a `PmapSharding` has a ShardingSpec, and +# ShardingSpecs also describe how to shard inputs to a parallel computation). +# spec_to_indices() encodes exactly how a given ShardingSpec is translated to +# device buffers, i.e. how the sharded array is "laid out" across devices. Given +# a sequence of devices, we shard the data across the devices in row-major +# order, with replication treated as an extra inner dimension. # # For example, given the logical data array [1, 2, 3, 4], if we were to # partition this array 4 ways with a replication factor of 2, for a total of 8 @@ -233,8 +233,8 @@ def spec_to_indices(shape: Sequence[int], """Returns numpy-style indices corresponding to a sharding spec. Each index describes a shard of the array. The order of the indices is the - same as the device_buffers of a ShardedDeviceArray (i.e. the data is laid out - row-major). + same as the device_buffers of a Array sharded using PmapSharding (i.e. the + data is laid out row-major). Args: shape: The shape of the logical array being sharded. diff --git a/jax/experimental/array_serialization/serialization.py b/jax/experimental/array_serialization/serialization.py index 0dd65614ae02..fd23d36f6a0e 100644 --- a/jax/experimental/array_serialization/serialization.py +++ b/jax/experimental/array_serialization/serialization.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-"""GlobalDeviceArray serialization and deserialization.""" +"""Array serialization and deserialization.""" import abc import asyncio @@ -482,7 +482,7 @@ class GlobalAsyncCheckpointManager(AsyncManager, GlobalAsyncCheckpointManagerBas """Responsible for serializing GDAs via TensorStore.""" def serialize(self, arrays, tensorstore_specs, *, on_commit_callback): - """Serializes GlobalDeviceArrays or Arrays via TensorStore asynchronously. + """Serializes Arrays or Arrays via TensorStore asynchronously. TensorStore writes to a storage layer in 2 steps: * Reading/copying from the source after which the source can be modified. @@ -494,7 +494,7 @@ def serialize(self, arrays, tensorstore_specs, *, on_commit_callback): finish in a separate thread allowing other computation to proceed. Args: - arrays: GlobalDeviceArrays or Arrays that should be serialized. + arrays: Arrays or Arrays that should be serialized. tensorstore_specs: TensorStore specs that are used to serialize GDAs or Arrays. on_commit_callback: This callback will be executed after all processes diff --git a/jax/experimental/jax2tf/call_tf.py b/jax/experimental/jax2tf/call_tf.py index ae1299966501..301216131fa7 100644 --- a/jax/experimental/jax2tf/call_tf.py +++ b/jax/experimental/jax2tf/call_tf.py @@ -57,7 +57,7 @@ TfVal = jax2tf_internal.TfVal # The platforms for which to use DLPack to avoid copying (only works on GPU -# and CPU at the moment, and only for DeviceArray). For CPU we don't need +# and CPU at the moment, and only for Array). For CPU we don't need # DLPack, if we are careful. _DLPACK_PLATFORMS = ("gpu",) @@ -335,7 +335,7 @@ def _arg_jax_to_tf(arg_jax): arg_jax.dtype in dlpack.SUPPORTED_DTYPES): arg_dlpack = jax.dlpack.to_dlpack(arg_jax, take_ownership=False) return tf.experimental.dlpack.from_dlpack(arg_dlpack) - # The following avoids copies to the host on CPU, always for DeviceArray + # The following avoids copies to the host on CPU, always for Array # and even for ndarray if they are sufficiently aligned. # TODO(necula): on TPU this copies to the host! return tf.constant(np.asarray(arg_jax)) diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 1d47498c5e7d..9a3e36551fa4 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -1814,11 +1814,11 @@ def _not(x): Numpy and JAX support bitwise not for booleans by applying a logical not! This means that applying bitwise_not yields an unexpected result: jnp.bitwise_not(jnp.array([True, False])) - >> DeviceArray([False, True], dtype=bool) + >> Array([False, True], dtype=bool) if you assume that booleans are simply casted to integers. jnp.bitwise_not(jnp.array([True, False]).astype(np.int32)).astype(bool) - >> DeviceArray([True, True], dtype=bool) + >> Array([True, True], dtype=bool) """ if x.dtype == tf.bool: return tf.logical_not(x) diff --git a/jax/experimental/jax2tf/tests/jax2tf_test.py b/jax/experimental/jax2tf/tests/jax2tf_test.py index c2a01a54ccff..0d6eca9f26dd 100644 --- a/jax/experimental/jax2tf/tests/jax2tf_test.py +++ b/jax/experimental/jax2tf/tests/jax2tf_test.py @@ -189,7 +189,7 @@ def test_converts_jax_arrays(self): f_tf = tf.function(lambda x: x + x) self.assertEqual(f_tf(jnp.ones([])).numpy(), 2.) - # Test with ShardedDeviceArray. + # Test with a PmapSharding-sharded Array. 
     n = jax.local_device_count()
     mk_sharded = lambda f: jax.pmap(lambda x: x)(f([n]))
     f_tf = tf.function(lambda x: x)
diff --git a/jax/experimental/jax2tf/tests/shape_poly_test.py b/jax/experimental/jax2tf/tests/shape_poly_test.py
index 43d19b994001..b266724911ca 100644
--- a/jax/experimental/jax2tf/tests/shape_poly_test.py
+++ b/jax/experimental/jax2tf/tests/shape_poly_test.py
@@ -226,7 +226,7 @@ def test_dim_vars_symbolic_equal(self):
     self.assertFalse(core.definitely_equal_one_of_dim(1, [2, b]))
     self.assertFalse(core.definitely_equal_one_of_dim(3, []))
 
-    self.assertTrue(core.definitely_equal(1, jnp.add(0, 1)))  # A DeviceArray
+    self.assertTrue(core.definitely_equal(1, jnp.add(0, 1)))  # An Array
     self.assertFalse(core.definitely_equal(1, "a"))
 
   def test_poly_bounds(self):
diff --git a/jax/experimental/sparse/bcoo.py b/jax/experimental/sparse/bcoo.py
index 177162d450d6..a243720c21b8 100644
--- a/jax/experimental/sparse/bcoo.py
+++ b/jax/experimental/sparse/bcoo.py
@@ -2505,7 +2505,7 @@ def sum(self) -> BCOO:
   @classmethod
   def fromdense(cls, mat: Array, *, nse: int | None = None, index_dtype: DTypeLike = np.int32,
                 n_dense: int = 0, n_batch: int = 0) -> BCOO:
-    """Create a BCOO array from a (dense) :class:`DeviceArray`."""
+    """Create a BCOO array from a (dense) :class:`~jax.Array`."""
     return bcoo_fromdense(
       mat, nse=nse, index_dtype=index_dtype, n_dense=n_dense, n_batch=n_batch)
 
diff --git a/jax/experimental/sparse/bcsr.py b/jax/experimental/sparse/bcsr.py
index 72ba9bce424e..dba7a25595bf 100644
--- a/jax/experimental/sparse/bcsr.py
+++ b/jax/experimental/sparse/bcsr.py
@@ -848,7 +848,7 @@ def sum_duplicates(self, nse: int | None = None, remove_zeros: bool = True) -> B
   @classmethod
   def fromdense(cls, mat, *, nse=None, index_dtype=np.int32, n_dense=0, n_batch=0):
-    """Create a BCSR array from a (dense) :class:`DeviceArray`."""
+    """Create a BCSR array from a (dense) :class:`~jax.Array`."""
     return bcsr_fromdense(mat, nse=nse, index_dtype=index_dtype, n_dense=n_dense,
                           n_batch=n_batch)
 
diff --git a/tests/api_test.py b/tests/api_test.py
index 9d300b738b04..1d38f3f58296 100644
--- a/tests/api_test.py
+++ b/tests/api_test.py
@@ -2669,7 +2669,7 @@ def test_float0_error(self):
     error_text = "float0s do not support any operations by design"
 
     with self.assertRaisesRegex(TypeError, error_text):
-      # dispatch via DeviceArray
+      # dispatch via Array
       _ = float0_array + jnp.zeros(())
 
     with self.assertRaisesRegex(TypeError, error_text):
diff --git a/tests/array_test.py b/tests/array_test.py
index 0bd905e2e19b..bc0eaeed9f32 100644
--- a/tests/array_test.py
+++ b/tests/array_test.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Tests for GlobalDeviceArray.""" +"""Tests for Array.""" import contextlib import math diff --git a/tests/dtypes_test.py b/tests/dtypes_test.py index b02a231a59c5..d6482ea482f4 100644 --- a/tests/dtypes_test.py +++ b/tests/dtypes_test.py @@ -661,7 +661,7 @@ def testBinaryNonPromotion(self, dtype, weak_type, promotion): for dtype in all_dtypes for weak_type in [True, False] ) - def testDeviceArrayRepr(self, dtype, weak_type): + def testArrayRepr(self, dtype, weak_type): val = lax_internal._convert_element_type(0, dtype, weak_type=weak_type) rep = repr(val) self.assertStartsWith(rep, 'Array(') diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py index adcd97dbd05d..ca52948871f7 100644 --- a/tests/lax_control_flow_test.py +++ b/tests/lax_control_flow_test.py @@ -2567,7 +2567,7 @@ def cumprod(x): # TODO(mattjj): make the numpy.ndarray test pass w/ remat raise unittest.SkipTest("new-remat-of-scan doesn't convert numpy.ndarray") - x = rng.randn(32, 2, 32).astype('float32') # numpy.ndarray, not DeviceArray + x = rng.randn(32, 2, 32).astype('float32') # numpy.ndarray, not Array _, vjp_fun = jax.vjp(cumprod, x) *_, ext_res = vjp_fun.args[0].args[0] self.assertIsInstance(ext_res, jax.Array) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index b34cfb43940c..8d1d5002d2a8 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -3196,7 +3196,7 @@ def testArrayFromList(self): def testIssue121(self): assert not np.isscalar(jnp.array(3)) - def testArrayOutputsDeviceArrays(self): + def testArrayOutputsArrays(self): assert type(jnp.array([])) is array.ArrayImpl assert type(jnp.array(np.array([]))) is array.ArrayImpl @@ -3206,10 +3206,10 @@ def __array__(self, dtype=None): assert type(jnp.array(NDArrayLike())) is array.ArrayImpl # NOTE(mattjj): disabled b/c __array__ must produce ndarrays - # class DeviceArrayLike: + # class ArrayLike: # def __array__(self, dtype=None): # return jnp.array([], dtype=dtype) - # assert xla.type_is_device_array(jnp.array(DeviceArrayLike())) + # assert xla.type_is_device_array(jnp.array(ArrayLike())) def testArrayMethod(self): class arraylike: diff --git a/tests/multiprocess_gpu_test.py b/tests/multiprocess_gpu_test.py index c2531119d5f5..dafe0dfe74d7 100644 --- a/tests/multiprocess_gpu_test.py +++ b/tests/multiprocess_gpu_test.py @@ -320,8 +320,6 @@ def test_gpu_multi_node_transparent_initialize_and_psum(self): self.assertEqual(y[0], jax.device_count()) print(y) - # TODO(sudhakarsingh27): To change/omit test in favor of using `Array` - # since `GlobalDeviceArray` is going to be deprecated in the future def test_pjit_gda_multi_input_multi_output(self): jax.distributed.initialize() global_mesh = jtu.create_global_mesh((8, 2), ("x", "y")) @@ -370,8 +368,6 @@ def f(x, y, z): np.testing.assert_array_equal(np.asarray(s.data), global_input_data[s.index]) - # TODO(sudhakarsingh27): To change/omit test in favor of using `Array` - # since `GlobalDeviceArray` is going to be deprecated in the future def test_pjit_gda_non_contiguous_mesh(self): jax.distributed.initialize() devices = self.sorted_devices() @@ -428,8 +424,6 @@ def cb(index): np.testing.assert_array_equal(np.asarray(s.data), global_input_data[expected_index]) - # TODO(sudhakarsingh27): To change/omit test in favor of using `Array` - # since `GlobalDeviceArray` is going to be deprecated in the future def test_pjit_gda_non_contiguous_mesh_2d(self): jax.distributed.initialize() global_mesh = self.create_2d_non_contiguous_mesh() @@ -504,8 +498,6 @@ def cb(index): # Fully 
replicated values + GDA allows a non-contiguous mesh. out1, out2 = f(global_input_data, gda2) - # TODO(sudhakarsingh27): To change/omit test in favor of using `Array` - # since `GlobalDeviceArray` is going to be deprecated in the future def test_pjit_gda_non_contiguous_mesh_2d_aot(self): jax.distributed.initialize() global_mesh = self.create_2d_non_contiguous_mesh() @@ -531,8 +523,6 @@ def test_pjit_gda_non_contiguous_mesh_2d_aot(self): self.assertEqual(out1.shape, (8, 2)) self.assertEqual(out2.shape, (8, 2)) - # TODO(sudhakarsingh27): To change/omit test in favor of using `Array` - # since `GlobalDeviceArray` is going to be deprecated in the future def test_pjit_gda_eval_shape(self): jax.distributed.initialize() diff --git a/tests/pickle_test.py b/tests/pickle_test.py index 019476ed721c..1ef76d9b2844 100644 --- a/tests/pickle_test.py +++ b/tests/pickle_test.py @@ -98,7 +98,7 @@ def g(z): class PickleTest(jtu.JaxTestCase): - def testPickleOfDeviceArray(self): + def testPickleOfArray(self): x = jnp.arange(10.0) s = pickle.dumps(x) y = pickle.loads(s) @@ -106,7 +106,7 @@ def testPickleOfDeviceArray(self): self.assertIsInstance(y, type(x)) self.assertEqual(x.aval, y.aval) - def testPickleOfDeviceArrayWeakType(self): + def testPickleOfArrayWeakType(self): x = jnp.array(4.0) self.assertEqual(x.aval.weak_type, True) s = pickle.dumps(x) diff --git a/tests/pmap_test.py b/tests/pmap_test.py index 7bc72ca0e4e6..7667330749eb 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -842,7 +842,7 @@ def testArrays(self): self.assertNotIsInstance(z, np.ndarray) self.assertAllClose(z, 2 * 2 * x, check_dtypes=False) - # test that we can pass in a regular DeviceArray + # test that we can pass in a regular Array y = f(device_put(x)) self.assertIsInstance(y, array.ArrayImpl) self.assertAllClose(y, 2 * x, check_dtypes=False)