Skip to content

Commit

Permalink
[JAX:CPU] Enable buffer donation on CPU.
Browse files Browse the repository at this point in the history
Fix a bug in PJRT where a buffer that was not owned (e.g., one aliasing a NumPy buffer) could still be donated, leading to a use-after-free.

PiperOrigin-RevId: 484001545
  • Loading branch information
hawkinsp authored and jax authors committed Oct 26, 2022
1 parent b4fdc12 commit ce9e009
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 15 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ Remember to align the itemized text with the first line of an item within a list
and forbids passing lists or tuples in place of arrays ({jax-issue}`#12958`)

## jaxlib 0.3.24
* Changes
* Buffer donation now works on CPU. This may break code that marked buffers
for donation on CPU but relied on donation not being implemented.

## jax 0.3.23 (Oct 12, 2022)
* Changes
Expand Down
13 changes: 9 additions & 4 deletions jax/interpreters/mlir.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from jax._src import ad_util
from jax._src import device_array
from jax._src import dtypes
from jax._src.lib import mlir_api_version
from jax._src.lib import mlir_api_version, xla_extension_version
from jax._src.lib import version as jaxlib_version
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import chlo
Expand Down Expand Up @@ -570,6 +570,12 @@ class LoweringResult(NamedTuple):
host_callbacks: List[Any]


# Platforms on which buffer donation (input/output aliasing) is honored.
# CPU support landed in xla_extension_version 102 (i.e., jaxlib > 0.3.22),
# so "cpu" is included only when the installed runtime is new enough.
_platforms_with_donation = (
    (["cpu"] if xla_extension_version >= 102 else []) + ["cuda", "rocm", "tpu"])


def lower_jaxpr_to_module(
module_name: str,
jaxpr: core.ClosedJaxpr,
Expand Down Expand Up @@ -608,8 +614,7 @@ def lower_jaxpr_to_module(
out_aval, = out_aval.dtype._rules.physical_avals(out_aval)
out_avals.append(sharded_aval(out_aval, out_sharding))

platforms_with_donation = ("cuda", "rocm", "tpu")
if platform in platforms_with_donation:
if platform in _platforms_with_donation:
input_output_aliases, donated_args = _set_up_aliases(
in_avals, out_avals, donated_args)
if any(eff not in lowerable_effects for eff in jaxpr.effects):
Expand All @@ -619,7 +624,7 @@ def lower_jaxpr_to_module(
unused_donations = [str(a) for a, d in zip(in_avals, donated_args)
if d]
msg = "See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer-donation."
if platform not in platforms_with_donation:
if platform not in _platforms_with_donation:
msg = f"Donation is not implemented for {platform}.\n{msg}"
warnings.warn(f"Some donated buffers were not usable: {', '.join(unused_donations)}.\n{msg}")

Expand Down
30 changes: 19 additions & 11 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
from jax._src import device_array
from jax._src import prng
from jax._src.lib import xla_client
from jax._src.lib import xla_extension_version
from jax._src import test_util as jtu
from jax import tree_util
from jax import linear_util as lu
Expand Down Expand Up @@ -411,8 +412,9 @@ def test_jit_donate_argnums_warning_raised(self):
"Some donated buffers were not usable:",
str(w[-1].message))

@jtu.skip_on_devices("cpu") # In/out aliasing not supported on CPU.
def test_jit_donate_argnums_invalidates_input(self):
if jtu.device_under_test() == "cpu" and xla_extension_version < 102:
raise unittest.SkipTest("CPU buffer donation requires jaxlib > 0.3.22")
# We can't just use `lambda x: x` because JAX simplifies this away to an
# empty XLA computation.
move = self.jit(lambda x: x + x - x, donate_argnums=0)
Expand All @@ -421,31 +423,34 @@ def test_jit_donate_argnums_invalidates_input(self):
self.assertDeleted(x)
self.assertEqual(y, 1.)

@jtu.skip_on_devices("cpu") # In/out aliasing not supported on CPU.
def test_jit_donate_argnums_static_argnums(self):
if jtu.device_under_test() == "cpu" and xla_extension_version < 102:
raise unittest.SkipTest("CPU buffer donation requires jaxlib > 0.3.22")
jit_fun = self.jit(
lambda a, b, c, d: ((a + b + c), (a + b + d)),
static_argnums=(0, 1),
donate_argnums=(2, 3))

c = jax.device_put(jnp.array([1., 1.]))
d = jax.device_put(jnp.array([1., 1., 1.]))
c = jax.device_put(jnp.array([2., 2.]))
d = jax.device_put(jnp.array([1., 1., 1., 1.]))
e, f = jit_fun(1, 2, c, d)
np.testing.assert_allclose(e, jnp.array([4., 4.]))
np.testing.assert_allclose(f, jnp.array([4., 4., 4.]))
np.testing.assert_allclose(e, jnp.array([5., 5.]))
np.testing.assert_allclose(f, jnp.array([4., 4., 4., 4.]))
self.assertDeleted(c)
self.assertDeleted(d)

@jtu.skip_on_devices("cpu") # In/out aliasing not supported on CPU.
def test_jit_donate_argnums_weak_type(self):
  """Donation invalidates a weakly-typed input even when the output's type differs."""
  if jtu.device_under_test() == "cpu" and xla_extension_version < 102:
    raise unittest.SkipTest("CPU buffer donation requires jaxlib > 0.3.22")
  # The input carries a weak type; casting to int produces a non-weak output,
  # so this exercises donation across a weak-type mismatch.
  donating_fn = self.jit(lambda x: x.astype(int), donate_argnums=0)
  arg = jnp.broadcast_to(2, (3,))
  donating_fn(arg)
  self.assertDeleted(arg)

@jtu.skip_on_devices("cpu") # In/out aliasing not supported on CPU.
def test_jnp_array_copy(self):
if jtu.device_under_test() == "cpu" and xla_extension_version < 102:
raise unittest.SkipTest("CPU buffer donation requires jaxlib > 0.3.22")
# https://github.com/google/jax/issues/3412

@partial(self.jit, donate_argnums=(0,))
Expand Down Expand Up @@ -1013,8 +1018,9 @@ def f(*args):
f_exe = self.jit(f).lower(1., 1.).compile()
self.assertAllClose(f_exe(1., 1.), 1.)

@jtu.skip_on_devices("cpu") # no donation on cpu, so this would warn
def test_jit_lower_donate_argnums_available(self):
if jtu.device_under_test() == "cpu" and xla_extension_version < 102:
raise unittest.SkipTest("CPU buffer donation requires jaxlib > 0.3.22")
def f(*args):
x, *_ = args
return x + 4.
Expand Down Expand Up @@ -2582,9 +2588,10 @@ def test_xla_computation_psum_constant(self):
f = lambda: jax.lax.psum(1, "i")
api.xla_computation(f, axis_env=[("i", 2)])() # doesn't crash

@jtu.skip_on_devices("cpu")
@jtu.ignore_warning(message="Some donated buffers were not usable")
def test_xla_computation_donate_argnums(self):
  """xla_computation with donate_argnums builds without crashing."""
  on_cpu = jtu.device_under_test() == "cpu"
  if on_cpu and xla_extension_version < 102:
    raise unittest.SkipTest("CPU buffer donation requires jaxlib > 0.3.22")
  # Only construction is under test here; the result is discarded.
  api.xla_computation(lambda x: None, donate_argnums=(0,))(3)  # doesn't crash

def test_xla_computation_lower_fun_axis_env(self):
Expand Down Expand Up @@ -9089,8 +9096,9 @@ def test_def_method_forwarding_all_permutations(self):

class BufferDonationTest(jtu.BufferDonationTestCase):

@jtu.skip_on_devices("cpu") # In/out aliasing not supported on CPU.
def test_pmap_donate_argnums_invalidates_input(self):
if jtu.device_under_test() == "cpu" and xla_extension_version < 102:
raise unittest.SkipTest("CPU buffer donation requires jaxlib > 0.3.22")
move = api.pmap(lambda x: x + x - x, donate_argnums=0)
n = jax.local_device_count()
x = api.pmap(lambda x: x)(jnp.ones([n]))
Expand Down

0 comments on commit ce9e009

Please sign in to comment.