Skip to content

Commit

Permalink
Reverts f0382a5
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 587231484
  • Loading branch information
yashk2810 authored and jax authors committed Dec 2, 2023
1 parent 86661c8 commit f0bc7e0
Show file tree
Hide file tree
Showing 8 changed files with 106 additions and 113 deletions.
5 changes: 0 additions & 5 deletions CHANGELOG.md
Expand Up @@ -29,11 +29,6 @@ Remember to align the itemized text with the first line of an item within a list
that cannot be converted to a JAX array is deprecated and now raises a
{obj}`DeprecationWaning`. Currently the functions return False, in the future this
will raise an exception.
* The `device_buffer` and `device_buffers` properties of JAX arrays are deprecated.
Explicit buffers have been replaced by the more flexible array sharding interface,
but the previous outputs can be recovered this way:
* `arr.device_buffer` becomes `arr.addressable_data(0)`
* `arr.device_buffers` becomes `[x.data for x in arr.addressable_shards]`
* The `device()` method of JAX arrays deprecated. Depending on the context, it may
be replaced with one of the following:
- {meth}`jax.Array.devices` returns the set of all devices used by the array.
Expand Down
4 changes: 2 additions & 2 deletions jax/_src/api.py
Expand Up @@ -102,8 +102,8 @@ def _nan_check_posthook(fun, args, kwargs, output):
"""Hook function called by the C++ jit/pmap to perform NaN checking."""
buffers = []
for leaf in tree_leaves(output):
if hasattr(leaf, "addressable_shards"):
buffers.extend([shard.data for shard in leaf.addressable_shards])
if hasattr(leaf, "device_buffers"):
buffers.extend(leaf.device_buffers)

try:
dispatch.check_special(pjit.pjit_p.name, buffers)
Expand Down
8 changes: 0 additions & 8 deletions jax/_src/array.py
Expand Up @@ -474,10 +474,6 @@ def devices(self) -> set[Device]:
# deleted.
@property
def device_buffer(self) -> ArrayImpl:
# Added 2023 Nov 29
warnings.warn(
"arr.device_buffer is deprecated. Use arr.addressable_data(0)",
DeprecationWarning, stacklevel=2)
self._check_if_deleted()
if len(self._arrays) == 1:
return self._arrays[0]
Expand All @@ -488,10 +484,6 @@ def device_buffer(self) -> ArrayImpl:
# deleted.
@property
def device_buffers(self) -> Sequence[ArrayImpl]:
# Added 2023 Nov 29
warnings.warn(
"arr.device_buffers is deprecated. Use [x.data for x in arr.addressable_shards]",
DeprecationWarning, stacklevel=2)
self._check_if_deleted()
return cast(Sequence[ArrayImpl], self._arrays)

Expand Down
8 changes: 5 additions & 3 deletions jax/_src/test_util.py
Expand Up @@ -1116,11 +1116,13 @@ class BufferDonationTestCase(JaxTestCase):
assertNotDeleted = lambda self, x: self._assertDeleted(x, False)

def _assertDeleted(self, x, deleted):
if hasattr(x, "_arrays") or hasattr(x, "is_deleted"):
if hasattr(x, "_arrays"):
self.assertEqual(x.is_deleted(), deleted)
elif hasattr(x, "device_buffer"):
self.assertEqual(x.device_buffer.is_deleted(), deleted)
else:
for shard in x.addressable_shards:
self.assertEqual(shard.data.is_deleted(), deleted)
for buffer in x.device_buffers:
self.assertEqual(buffer.is_deleted(), deleted)


@contextmanager
Expand Down
2 changes: 0 additions & 2 deletions tests/array_test.py
Expand Up @@ -593,8 +593,6 @@ def test_array_jnp_array_copy_multi_device(self):
self.assertNotEqual(a.data.unsafe_buffer_pointer(),
c.data.unsafe_buffer_pointer())

@jtu.ignore_warning(category=DeprecationWarning,
message="arr.device_buffers? is deprecated")
def test_array_device_buffer(self):
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
input_shape = (8, 2)
Expand Down
186 changes: 97 additions & 89 deletions tests/notebooks/colab_cpu.ipynb
@@ -1,10 +1,23 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "JAX Colab CPU Test",
"provenance": [],
"collapsed_sections": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "view-in-github"
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/github/google/jax/blob/main/tests/notebooks/colab_cpu.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
Expand All @@ -13,8 +26,8 @@
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "WkadOyTDCAWD"
"id": "WkadOyTDCAWD",
"colab_type": "text"
},
"source": [
"# JAX Colab CPU Test\n",
Expand All @@ -24,69 +37,77 @@
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"id": "_tKNrbqqBHwu",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 68
},
"colab_type": "code",
"id": "_tKNrbqqBHwu",
"outputId": "071fb360-ddf5-41ae-d772-acc08ec71d9b"
},
"source": [
"import jax\n",
"import jaxlib\n",
"\n",
"!cat /var/colab/hostname\n",
"print(jax.__version__)\n",
"print(jaxlib.__version__)"
],
"execution_count": 6,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"m-s-1p12yf76kgzz\n",
"0.1.64\n",
"0.1.45\n"
]
],
"name": "stdout"
}
],
"source": [
"import jax\n",
"import jaxlib\n",
"\n",
"!cat /var/colab/hostname\n",
"print(jax.__version__)\n",
"print(jaxlib.__version__)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "oqEG21rADO1F"
"id": "oqEG21rADO1F",
"colab_type": "text"
},
"source": [
"## Confirm Device"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"colab_type": "code",
"id": "8BwzMYhKGQj6",
"outputId": "f79a44e3-4303-494c-9288-a4e582bb34cb",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 68
},
"colab_type": "code",
"id": "8BwzMYhKGQj6",
"outputId": "f79a44e3-4303-494c-9288-a4e582bb34cb"
}
},
"source": [
"from jaxlib import xla_extension\n",
"import jax\n",
"key = jax.random.PRNGKey(1701)\n",
"arr = jax.random.normal(key, (1000,))\n",
"device = arr.device_buffer.device()\n",
"print(f\"JAX device type: {device}\")\n",
"assert device.platform == \"cpu\", f\"unexpected JAX device type: {device.platform}\""
],
"execution_count": 2,
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.6/dist-packages/jax/lib/xla_bridge.py:123: UserWarning: No GPU/TPU found, falling back to CPU.\n",
" warnings.warn('No GPU/TPU found, falling back to CPU.')\n"
]
],
"name": "stderr"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"JAX device type: cpu:0\n"
Expand All @@ -106,34 +127,24 @@
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "z0FUY9yUC4k1"
"id": "z0FUY9yUC4k1",
"colab_type": "text"
},
"source": [
"## Matrix Multiplication"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"colab_type": "code",
"id": "eXn8GUl6CG5N",
"outputId": "307aa669-76f1-4117-b62a-7acb2aee2c16",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"colab_type": "code",
"id": "eXn8GUl6CG5N",
"outputId": "307aa669-76f1-4117-b62a-7acb2aee2c16"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1.0216691\n"
]
}
],
},
"source": [
"import jax\n",
"import numpy as np\n",
Expand All @@ -143,40 +154,39 @@
"x = jax.random.normal(key, (3000, 3000))\n",
"result = jax.numpy.dot(x, x.T).mean()\n",
"print(result)"
],
"execution_count": 3,
"outputs": [
{
"output_type": "stream",
"text": [
"1.0216691\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "0zTA2Q19DW4G"
"id": "0zTA2Q19DW4G",
"colab_type": "text"
},
"source": [
"## Linear Algebra"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"id": "uW9j84_UDYof",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 51
},
"colab_type": "code",
"id": "uW9j84_UDYof",
"outputId": "3dd5d7c0-9d47-4be1-c6f7-068b432b69f7"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[6.9178133 5.9580317 5.581113 4.506963 4.111582 3.973543 3.3307292\n",
" 2.8664916 1.8229378 1.5478933]\n"
]
}
],
"source": [
"import jax.numpy as jnp\n",
"import jax.random as rand\n",
Expand All @@ -190,61 +200,59 @@
"assert u.shape == (N, N)\n",
"assert vt.shape == (M, M)\n",
"print(s)"
],
"execution_count": 4,
"outputs": [
{
"output_type": "stream",
"text": [
"[6.9178133 5.9580317 5.581113 4.506963 4.111582 3.973543 3.3307292\n",
" 2.8664916 1.8229378 1.5478933]\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "jCyKUn4-DCXn"
"id": "jCyKUn4-DCXn",
"colab_type": "text"
},
"source": [
"## XLA Compilation"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"colab_type": "code",
"id": "2GOn_HhDPuEn",
"outputId": "41a40dd9-3680-458d-cedd-81ebcc2ab06f",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 51
},
"colab_type": "code",
"id": "2GOn_HhDPuEn",
"outputId": "41a40dd9-3680-458d-cedd-81ebcc2ab06f"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[ 0.34676832 -0.7532232 1.7060695 ... 2.1208048 -0.42621925\n",
" 0.13093236]\n"
]
}
],
},
"source": [
"@jax.jit\n",
"def selu(x, alpha=1.67, lmbda=1.05):\n",
" return lmbda * jax.numpy.where(x > 0, x, alpha * jax.numpy.exp(x) - alpha)\n",
"x = jax.random.normal(key, (5000,))\n",
"result = selu(x).block_until_ready()\n",
"print(result)"
],
"execution_count": 5,
"outputs": [
{
"output_type": "stream",
"text": [
"[ 0.34676832 -0.7532232 1.7060695 ... 2.1208048 -0.42621925\n",
" 0.13093236]\n"
],
"name": "stdout"
}
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
"name": "JAX Colab CPU Test",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
]
}

0 comments on commit f0bc7e0

Please sign in to comment.