Skip to content

Commit

Permalink
Enable JAX memory tests for GPUs and CPUs
Browse files Browse the repository at this point in the history
PjRt GPU and CPU have recently gained memory space support with just one memory space per device, so enabling relevant JAX memory tests. Most tests cannot be enabled yet because they rely on `unpinned_host`, so only enabling `ShardingMemoriesTest` for now.

PiperOrigin-RevId: 633335638
  • Loading branch information
junwhanahn authored and jax authors committed May 13, 2024
1 parent 72a81e5 commit cd6e012
Showing 1 changed file with 25 additions and 18 deletions.
43 changes: 25 additions & 18 deletions tests/memories_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from jax._src import test_util as jtu
from jax._src import xla_bridge as xb
from jax._src import config
from jax._src.lib import xla_extension_version
from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P
Expand Down Expand Up @@ -64,14 +65,18 @@ def _create_inputs(shape, pspec, mem_kind=None):
class ShardingMemoriesTest(jtu.JaxTestCase):

def setUp(self):
if not jtu.test_device_matches(["tpu"]):
if xla_extension_version < 265 and not jtu.test_device_matches(["tpu"]):
self.skipTest("Memories do not work on CPU and GPU backends yet.")
# TODO(b/311021572)
if jtu.is_cloud_tpu():
self.skipTest("Experimental feature not yet implemented on Cloud TPU")
super().setUp()
self.orig_memories_flag = config.enable_memories.value
jax.config.update('jax_enable_memories', True)
if jtu.test_device_matches(["cpu"]):
self._default_memory_kind = "unpinned_host"
else:
self._default_memory_kind = "device"

def tearDown(self):
jax.config.update('jax_enable_memories', self.orig_memories_flag)
Expand All @@ -87,17 +92,17 @@ def test_canonicalize_memory_kind(self, name):
if name == "named_sharding":
mesh = jtu.create_global_mesh((1,), "x")
ns = NamedSharding(mesh, P("x"))
self.assertEqual(ns.memory_kind, "device")
self.assertEqual(ns.memory_kind, self._default_memory_kind)
elif name == "positional_sharding":
ps = PositionalSharding(jax.devices())
self.assertEqual(ps.memory_kind, "device")
self.assertEqual(ps.memory_kind, self._default_memory_kind)
elif name == "single_device_sharding":
ss = SingleDeviceSharding(jax.devices()[0])
self.assertEqual(ss.memory_kind, "device")
self.assertEqual(ss.memory_kind, self._default_memory_kind)
else:
assert name == "gspmd_sharding"
gs = GSPMDSharding.get_replicated(jax.devices())
self.assertEqual(gs.memory_kind, "device")
self.assertEqual(gs.memory_kind, self._default_memory_kind)

@parameterized.named_parameters(
("named_sharding", "named_sharding"),
Expand All @@ -108,26 +113,26 @@ def test_canonicalize_memory_kind(self, name):
def test_wrong_memory_kind(self, name):
if name == "named_sharding":
with self.assertRaisesRegex(
ValueError, "Could not find memory addressable by device TPU.*"
ValueError, "Could not find memory addressable by device.*"
):
mesh = jtu.create_global_mesh((8,), ("x",))
NamedSharding(mesh, P("x"), memory_kind="hbm")
elif name == "positional_sharding":
with self.assertRaisesRegex(
ValueError, "Could not find memory addressable by device TPU.*"
ValueError, "Could not find memory addressable by device.*"
):
PositionalSharding(jax.devices(), memory_kind="gpu_hbm")
elif name == "single_device_sharding":
with self.assertRaisesRegex(
ValueError,
"Could not find memory addressable by device TPU.*Device TPU.*"
"Could not find memory addressable by device.*Device.*"
" can address the following memory kinds.*",
):
SingleDeviceSharding(jax.devices()[0], memory_kind="host")
else:
assert name == "gspmd_sharding"
with self.assertRaisesRegex(
ValueError, "Could not find memory addressable by device TPU.*"
ValueError, "Could not find memory addressable by device.*"
):
GSPMDSharding.get_replicated(jax.devices(), memory_kind="my_host")

Expand All @@ -138,11 +143,13 @@ def test_wrong_memory_kind(self, name):
("gspmd_sharding", "gspmd_sharding"),
)
def test_correct_tpu_memory_kind(self, name):
if not jtu.test_device_matches(["tpu"]):
self.skipTest("TPU memory kind test.")
if name == "named_sharding":
mesh = jtu.create_global_mesh((8,), ("x",))
NamedSharding(mesh, P("x"), memory_kind="device")
NamedSharding(mesh, P("x"), memory_kind=self._default_memory_kind)
elif name == "positional_sharding":
PositionalSharding(jax.devices(), memory_kind="device")
PositionalSharding(jax.devices(), memory_kind=self._default_memory_kind)
elif name == "single_device_sharding":
SingleDeviceSharding(jax.devices()[0], memory_kind="unpinned_host")
else:
Expand All @@ -159,19 +166,19 @@ def test_sharding_eq(self, name):
if name == "named_sharding":
mesh = jtu.create_global_mesh((8,), ("x",))
s1 = NamedSharding(mesh, P("x"))
s2 = NamedSharding(mesh, P("x"), memory_kind="device")
s2 = NamedSharding(mesh, P("x"), memory_kind=self._default_memory_kind)
self.assertEqual(s1, s2)
elif name == "positional_sharding":
s1 = PositionalSharding(jax.devices())
s2 = PositionalSharding(jax.devices(), memory_kind="device")
s2 = PositionalSharding(jax.devices(), memory_kind=self._default_memory_kind)
self.assertEqual(s1, s2)
elif name == "single_device_sharding":
s1 = SingleDeviceSharding(jax.devices()[0])
s2 = SingleDeviceSharding(jax.devices()[0], memory_kind="device")
s2 = SingleDeviceSharding(jax.devices()[0], memory_kind=self._default_memory_kind)
self.assertEqual(s1, s2)
elif name == "gspmd_sharding":
s1 = GSPMDSharding.get_replicated(jax.devices())
s2 = GSPMDSharding.get_replicated(jax.devices(), memory_kind="device")
s2 = GSPMDSharding.get_replicated(jax.devices(), memory_kind=self._default_memory_kind)
self.assertEqual(s1, s2)

def test_sharding_equivalent(self):
Expand All @@ -181,19 +188,19 @@ def test_sharding_equivalent(self):
gs1 = GSPMDSharding(
tuple(mesh.devices.flat),
ns1._to_xla_hlo_sharding(ndim),
memory_kind="device",
memory_kind=self._default_memory_kind,
)
self.assertTrue(ns1.is_equivalent_to(gs1, ndim))

ns2 = NamedSharding(mesh, P("x"), memory_kind="device")
ns2 = NamedSharding(mesh, P("x"), memory_kind=self._default_memory_kind)
gs2 = GSPMDSharding(
tuple(mesh.devices.flat), ns2._to_xla_hlo_sharding(ndim)
)
self.assertTrue(ns2.is_equivalent_to(gs2, ndim))

def test_default_memory_kind(self):
dev = jax.devices()[0]
self.assertEqual(dev.default_memory().kind, "device")
self.assertEqual(dev.default_memory().kind, self._default_memory_kind)


class MemoriesComputationTest(jtu.BufferDonationTestCase):
Expand Down

0 comments on commit cd6e012

Please sign in to comment.