Replace references to DeviceArray with Array.
A number of stale references are lurking in our documentation.
hawkinsp committed Aug 18, 2023
1 parent 97af33c commit 2c32660
Showing 40 changed files with 161 additions and 173 deletions.
4 changes: 2 additions & 2 deletions cloud_tpu_colabs/Pmap_Cookbook.ipynb
@@ -251,9 +251,9 @@
"colab_type": "text"
},
"source": [
"A `ShardedDeviceArray` is effectively an `ndarray` subclass, but it's stored in pieces spread across the memory of multiple devices. Results from `pmap` functions are left sharded in device memory so that they can be operated on by subsequent `pmap` functions without moving data around, at least in some cases. But these results logically appear just like a single array.\n",
"A sharded `Array` is effectively an `ndarray` subclass, but it's stored in pieces spread across the memory of multiple devices. Results from `pmap` functions are left sharded in device memory so that they can be operated on by subsequent `pmap` functions without moving data around, at least in some cases. But these results logically appear just like a single array.\n",
"\n",
"When you call a non-`pmap` function on a `ShardedDeviceArray`, like a standard `jax.numpy` function, communication happens behind the scenes to bring the values to one device (or back to the host in the case of the matplotlib function above):"
"When you call a non-`pmap` function on an `Array`, like a standard `jax.numpy` function, communication happens behind the scenes to bring the values to one device (or back to the host in the case of the matplotlib function above):"
]
},
{
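The paragraph above is easy to see in action. A minimal sketch, assuming only a stock JAX install (with a single device the sharding is trivial, but the code still runs):

```python
import jax
import jax.numpy as jnp

n = jax.device_count()
xs = jnp.arange(4.0 * n).reshape(n, 4)

ys = jax.pmap(lambda x: x ** 2)(xs)   # result is left sharded in device memory
zs = jax.pmap(lambda y: y + 1.0)(ys)  # a second pmap consumes it without data movement

# A non-pmap jax.numpy op gathers the shards behind the scenes:
print(jnp.sum(zs))
```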
12 changes: 6 additions & 6 deletions docs/autodidax.ipynb
@@ -2344,8 +2344,8 @@
"One piece missing is device memory persistence for arrays. That is, we've\n",
"defined `handle_result` to transfer results back to CPU memory as NumPy\n",
"arrays, but it's often preferable to avoid transferring results just to\n",
"transfer them back for the next operation. We can do that by introducing a\n",
"`DeviceArray` class, which can wrap XLA buffers and otherwise duck-type\n",
"transfer them back for the next operation. We can do that by introducing an\n",
"`Array` class, which can wrap XLA buffers and otherwise duck-type\n",
"`numpy.ndarray`s:"
]
},
@@ -2356,9 +2356,9 @@
"outputs": [],
"source": [
"def handle_result(aval: ShapedArray, buf): # noqa: F811\n",
" return DeviceArray(aval, buf)\n",
" return Array(aval, buf)\n",
"\n",
"class DeviceArray:\n",
"class Array:\n",
" buf: Any\n",
" aval: ShapedArray\n",
"\n",
@@ -2381,9 +2381,9 @@
" _rmul = staticmethod(mul)\n",
" _gt = staticmethod(greater)\n",
" _lt = staticmethod(less)\n",
"input_handlers[DeviceArray] = lambda x: x.buf\n",
"input_handlers[Array] = lambda x: x.buf\n",
"\n",
"jax_types.add(DeviceArray)"
"jax_types.add(Array)"
]
},
{
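The essence of this hunk, a wrapper that holds a device buffer and duck-types `numpy.ndarray`, can be sketched standalone. Here a plain NumPy array stands in for the XLA buffer; the names mirror the diff, but this is an illustration of the pattern, not autodidax's actual runtime:

```python
from typing import Any
import numpy as np

class Array:
    def __init__(self, buf: Any):
        self.buf = buf  # would be an XLA device buffer in autodidax

    # Just enough ndarray duck-typing for downstream code:
    @property
    def shape(self): return np.asarray(self.buf).shape
    @property
    def dtype(self): return np.asarray(self.buf).dtype
    def __array__(self, dtype=None, copy=None): return np.asarray(self.buf, dtype)
    def __mul__(self, other): return Array(np.asarray(self.buf) * other)
    def __str__(self): return str(np.asarray(self.buf))

x = Array(np.arange(3.0))
y = x * 2.0        # wrapped again: no forced transfer between operations
print(np.sum(y))   # conversion happens only when NumPy actually needs values
```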
12 changes: 6 additions & 6 deletions docs/autodidax.md
@@ -1822,15 +1822,15 @@ print(ys)
One piece missing is device memory persistence for arrays. That is, we've
defined `handle_result` to transfer results back to CPU memory as NumPy
arrays, but it's often preferable to avoid transferring results just to
transfer them back for the next operation. We can do that by introducing a
`DeviceArray` class, which can wrap XLA buffers and otherwise duck-type
transfer them back for the next operation. We can do that by introducing an
`Array` class, which can wrap XLA buffers and otherwise duck-type
`numpy.ndarray`s:

```{code-cell}
def handle_result(aval: ShapedArray, buf): # noqa: F811
return DeviceArray(aval, buf)
return Array(aval, buf)
class DeviceArray:
class Array:
buf: Any
aval: ShapedArray
@@ -1853,9 +1853,9 @@ class DeviceArray:
_rmul = staticmethod(mul)
_gt = staticmethod(greater)
_lt = staticmethod(less)
input_handlers[DeviceArray] = lambda x: x.buf
input_handlers[Array] = lambda x: x.buf
jax_types.add(DeviceArray)
jax_types.add(Array)
```

```{code-cell}
12 changes: 6 additions & 6 deletions docs/autodidax.py
@@ -1813,15 +1813,15 @@ def f(x):
# One piece missing is device memory persistence for arrays. That is, we've
# defined `handle_result` to transfer results back to CPU memory as NumPy
# arrays, but it's often preferable to avoid transferring results just to
# transfer them back for the next operation. We can do that by introducing a
# `DeviceArray` class, which can wrap XLA buffers and otherwise duck-type
# transfer them back for the next operation. We can do that by introducing an
# `Array` class, which can wrap XLA buffers and otherwise duck-type
# `numpy.ndarray`s:

# +
def handle_result(aval: ShapedArray, buf): # noqa: F811
return DeviceArray(aval, buf)
return Array(aval, buf)

class DeviceArray:
class Array:
buf: Any
aval: ShapedArray

@@ -1844,9 +1844,9 @@ def __str__(self): return str(np.asarray(self.buf))
_rmul = staticmethod(mul)
_gt = staticmethod(greater)
_lt = staticmethod(less)
input_handlers[DeviceArray] = lambda x: x.buf
input_handlers[Array] = lambda x: x.buf

jax_types.add(DeviceArray)
jax_types.add(Array)


# +
24 changes: 12 additions & 12 deletions docs/jax-101/01-jax-basics.ipynb
@@ -68,7 +68,7 @@
"source": [
"So far, everything is just like NumPy. A big appeal of JAX is that you don't need to learn a new API. Many common NumPy programs would run just as well in JAX if you substitute `np` for `jnp`. However, there are some important differences which we touch on at the end of this section.\n",
"\n",
"You can notice the first difference if you check the type of `x`. It is a variable of type `DeviceArray`, which is the way JAX represents arrays."
"You can notice the first difference if you check the type of `x`. It is a variable of type `Array`, which is the way JAX represents arrays."
]
},
{
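A quick check of the claim in this hunk; the exact `type()` string is an implementation detail that varies across jax versions, while `isinstance(x, jax.Array)` is the stable public test:

```python
import jax
import jax.numpy as jnp

x = jnp.arange(10)
print(repr(x))                   # Array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32)
print(isinstance(x, jax.Array))  # True: jax.Array is the public array type
```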
@@ -81,7 +81,7 @@
{
"data": {
"text/plain": [
"DeviceArray([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32)"
"Array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32)"
]
},
"execution_count": 2,
@@ -277,8 +277,8 @@
{
"data": {
"text/plain": [
"(DeviceArray([-0.20000005, -0.19999981, -0.19999981, -0.19999981], dtype=float32),\n",
" DeviceArray([0.20000005, 0.19999981, 0.19999981, 0.19999981], dtype=float32))"
"(Array([-0.20000005, -0.19999981, -0.19999981, -0.19999981], dtype=float32),\n",
" Array([0.20000005, 0.19999981, 0.19999981, 0.19999981], dtype=float32))"
]
},
"execution_count": 7,
@@ -338,8 +338,8 @@
{
"data": {
"text/plain": [
"(DeviceArray(0.03999995, dtype=float32),\n",
" DeviceArray([-0.20000005, -0.19999981, -0.19999981, -0.19999981], dtype=float32))"
"(Array(0.03999995, dtype=float32),\n",
" Array([-0.20000005, -0.19999981, -0.19999981, -0.19999981], dtype=float32))"
]
},
"execution_count": 8,
@@ -395,7 +395,7 @@
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mFilteredStackTrace\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-9-7433a86e7375>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgrad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msquared_error_with_aux\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;31mFilteredStackTrace\u001b[0m: TypeError: Gradient only defined for scalar-output functions. Output was (DeviceArray(0.03999995, dtype=float32), DeviceArray([-0.10000002, -0.0999999 , -0.0999999 , -0.0999999 ], dtype=float32)).\n\nThe stack trace above excludes JAX-internal frames."
"\u001b[0;31mFilteredStackTrace\u001b[0m: TypeError: Gradient only defined for scalar-output functions. Output was (Array(0.03999995, dtype=float32), Array([-0.10000002, -0.0999999 , -0.0999999 , -0.0999999 ], dtype=float32)).\n\nThe stack trace above excludes JAX-internal frames."
]
}
],
@@ -425,8 +425,8 @@
{
"data": {
"text/plain": [
"(DeviceArray([-0.20000005, -0.19999981, -0.19999981, -0.19999981], dtype=float32),\n",
" DeviceArray([-0.10000002, -0.0999999 , -0.0999999 , -0.0999999 ], dtype=float32))"
"(Array([-0.20000005, -0.19999981, -0.19999981, -0.19999981], dtype=float32),\n",
" Array([-0.10000002, -0.0999999 , -0.0999999 , -0.0999999 ], dtype=float32))"
]
},
"execution_count": 10,
@@ -530,7 +530,7 @@
"\u001b[0;32m<ipython-input-12-709e2d7ddd3f>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0min_place_modify\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mjnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# Raises error when we cast input to jnp.ndarray\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m<ipython-input-11-fce65eb843c7>\u001b[0m in \u001b[0;36min_place_modify\u001b[0;34m(x)\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0min_place_modify\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 6\u001b[0;31m \u001b[0mx\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m123\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 7\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/_src/numpy/lax_numpy.py\u001b[0m in \u001b[0;36m_unimplemented_setitem\u001b[0;34m(self, i, x)\u001b[0m\n\u001b[1;32m 6594\u001b[0m \u001b[0;34m\"or another .at[] method: \"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6595\u001b[0m \"https://jax.readthedocs.io/en/latest/jax.ops.html\")\n\u001b[0;32m-> 6596\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mTypeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 6597\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6598\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_operator_round\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnumber\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mndigits\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mTypeError\u001b[0m: '<class 'jaxlib.xla_extension.DeviceArray'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html"
"\u001b[0;31mTypeError\u001b[0m: '<class 'jaxlib.xla_extension.Array'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html"
]
}
],
@@ -557,7 +557,7 @@
{
"data": {
"text/plain": [
"DeviceArray([123, 2, 3], dtype=int32)"
"Array([123, 2, 3], dtype=int32)"
]
},
"execution_count": 13,
@@ -594,7 +594,7 @@
{
"data": {
"text/plain": [
"DeviceArray([1, 2, 3], dtype=int32)"
"Array([1, 2, 3], dtype=int32)"
]
},
"execution_count": 14,
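The immutability hunks above pair an error with its fix. The functional-update idiom in full (`.at[idx].set()` is JAX's documented API, so this one is not an assumption):

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3])
# x[0] = 123 would raise a TypeError: JAX arrays are immutable.
y = x.at[0].set(123)  # returns an updated copy
print(repr(y))        # Array([123,   2,   3], dtype=int32)
print(repr(x))        # Array([1, 2, 3], dtype=int32) -- original untouched
```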
2 changes: 1 addition & 1 deletion docs/jax-101/01-jax-basics.md
@@ -45,7 +45,7 @@ print(x)

So far, everything is just like NumPy. A big appeal of JAX is that you don't need to learn a new API. Many common NumPy programs would run just as well in JAX if you substitute `np` for `jnp`. However, there are some important differences which we touch on at the end of this section.

You can notice the first difference if you check the type of `x`. It is a variable of type `DeviceArray`, which is the way JAX represents arrays.
You can notice the first difference if you check the type of `x`. It is a variable of type `Array`, which is the way JAX represents arrays.

```{code-cell} ipython3
:id: 3fLtgPUAn7mi
2 changes: 1 addition & 1 deletion docs/jax-101/02-jitting.ipynb
@@ -401,7 +401,7 @@
{
"data": {
"text/plain": [
"DeviceArray(30, dtype=int32, weak_type=True)"
"Array(30, dtype=int32, weak_type=True)"
]
},
"execution_count": 8,
14 changes: 7 additions & 7 deletions docs/jax-101/03-vectorization.ipynb
@@ -37,7 +37,7 @@
{
"data": {
"text/plain": [
"DeviceArray([11., 20., 29.], dtype=float32)"
"Array([11., 20., 29.], dtype=float32)"
]
},
"execution_count": 1,
@@ -104,7 +104,7 @@
{
"data": {
"text/plain": [
"DeviceArray([[11., 20., 29.],\n",
"Array([[11., 20., 29.],\n",
" [11., 20., 29.]], dtype=float32)"
]
},
@@ -149,7 +149,7 @@
{
"data": {
"text/plain": [
"DeviceArray([[11., 20., 29.],\n",
"Array([[11., 20., 29.],\n",
" [11., 20., 29.]], dtype=float32)"
]
},
@@ -201,7 +201,7 @@
{
"data": {
"text/plain": [
"DeviceArray([[11., 20., 29.],\n",
"Array([[11., 20., 29.],\n",
" [11., 20., 29.]], dtype=float32)"
]
},
@@ -240,7 +240,7 @@
{
"data": {
"text/plain": [
"DeviceArray([[11., 11.],\n",
"Array([[11., 11.],\n",
" [20., 20.],\n",
" [29., 29.]], dtype=float32)"
]
@@ -281,7 +281,7 @@
{
"data": {
"text/plain": [
"DeviceArray([[11., 20., 29.],\n",
"Array([[11., 20., 29.],\n",
" [11., 20., 29.]], dtype=float32)"
]
},
@@ -320,7 +320,7 @@
{
"data": {
"text/plain": [
"DeviceArray([[11., 20., 29.],\n",
"Array([[11., 20., 29.],\n",
" [11., 20., 29.]], dtype=float32)"
]
},
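All of the outputs in this file come from the tutorial's 1-D convolution demo. Reconstructed here from the printed values, so the function body is inferred rather than copied; the vmap calls themselves are standard API:

```python
import jax
import jax.numpy as jnp

x = jnp.arange(5.0)
w = jnp.array([2.0, 3.0, 4.0])

def convolve(x, w):
    output = []
    for i in range(1, len(x) - 1):
        output.append(jnp.dot(x[i - 1 : i + 2], w))
    return jnp.array(output)

print(convolve(x, w))  # [11. 20. 29.]

xs = jnp.stack([x, x])
ws = jnp.stack([w, w])
print(jax.vmap(convolve)(xs, ws))              # 2x3: batched over axis 0 of both args
print(jax.vmap(convolve, out_axes=1)(xs, ws))  # 3x2: the transposed output shown above
```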
28 changes: 14 additions & 14 deletions docs/jax-101/04-advanced-autodiff.ipynb
@@ -175,9 +175,9 @@
{
"data": {
"text/plain": [
"DeviceArray([[2., 0., 0.],\n",
" [0., 2., 0.],\n",
" [0., 0., 2.]], dtype=float32)"
"Array([[2., 0., 0.],\n",
" [0., 2., 0.],\n",
" [0., 0., 2.]], dtype=float32)"
]
},
"execution_count": 6,
@@ -312,7 +312,7 @@
{
"data": {
"text/plain": [
"DeviceArray([ 2.4, -2.4, 2.4], dtype=float32)"
"Array([ 2.4, -2.4, 2.4], dtype=float32)"
]
},
"execution_count": 9,
@@ -356,7 +356,7 @@
{
"data": {
"text/plain": [
"DeviceArray([-2.4, -4.8, 2.4], dtype=float32)"
"Array([-2.4, -4.8, 2.4], dtype=float32)"
]
},
"execution_count": 10,
@@ -459,8 +459,8 @@
{
"data": {
"text/plain": [
"DeviceArray([[-2.4, -4.8, 2.4],\n",
" [-2.4, -4.8, 2.4]], dtype=float32)"
"Array([[-2.4, -4.8, 2.4],\n",
" [-2.4, -4.8, 2.4]], dtype=float32)"
]
},
"execution_count": 12,
@@ -503,7 +503,7 @@
{
"data": {
"text/plain": [
"DeviceArray([-2.4, -4.8, 2.4], dtype=float32)"
"Array([-2.4, -4.8, 2.4], dtype=float32)"
]
},
"execution_count": 13,
@@ -548,8 +548,8 @@
{
"data": {
"text/plain": [
"DeviceArray([[-2.4, -4.8, 2.4],\n",
" [-2.4, -4.8, 2.4]], dtype=float32)"
"Array([[-2.4, -4.8, 2.4],\n",
" [-2.4, -4.8, 2.4]], dtype=float32)"
]
},
"execution_count": 14,
@@ -586,8 +586,8 @@
{
"data": {
"text/plain": [
"DeviceArray([[-2.4, -4.8, 2.4],\n",
" [-2.4, -4.8, 2.4]], dtype=float32)"
"Array([[-2.4, -4.8, 2.4],\n",
" [-2.4, -4.8, 2.4]], dtype=float32)"
]
},
"execution_count": 15,
@@ -623,8 +623,8 @@
{
"data": {
"text/plain": [
"DeviceArray([[-2.4, -4.8, 2.4],\n",
" [-2.4, -4.8, 2.4]], dtype=float32)"
"Array([[-2.4, -4.8, 2.4],\n",
" [-2.4, -4.8, 2.4]], dtype=float32)"
]
},
"execution_count": 16,
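The repeated two-row outputs above are per-example gradients: the same example batched twice, so the gradient rows coincide. The pattern, as a hedged standalone sketch rather than the tutorial's exact model:

```python
import jax
import jax.numpy as jnp

def loss(params, x, y):
    return jnp.sum((params * x - y) ** 2)

params = jnp.array([1.0, 2.0, 3.0])
xs = jnp.stack([jnp.ones(3), jnp.ones(3)])  # a batch of two identical examples
ys = jnp.zeros((2, 3))

# grad differentiates w.r.t. params; vmap maps over the example axis only:
per_example_grads = jax.vmap(jax.grad(loss), in_axes=(None, 0, 0))(params, xs, ys)
print(per_example_grads)  # two identical rows, one gradient per example
```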
