From d30c3ad0d0e3825021421187c20239c2c86457a4 Mon Sep 17 00:00:00 2001
From: Yash Katariya
Date: Mon, 5 Feb 2024 17:27:32 -0800
Subject: [PATCH] Support activation offloading to host in JAX!

Offloading currently works on TPU; GPU support is being added.

PiperOrigin-RevId: 604482085
---
 jax/_src/ad_checkpoint.py     | 16 ++++++-
 jax/_src/interpreters/mlir.py | 20 ++-------
 jax/_src/interpreters/pxla.py |  9 ++--
 tests/memories_test.py        | 78 +++++++++++++++++++++--------
 4 files changed, 69 insertions(+), 54 deletions(-)

diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py
index 6baf89d72ad5..9696c6b0b3e4 100644
--- a/jax/_src/ad_checkpoint.py
+++ b/jax/_src/ad_checkpoint.py
@@ -107,6 +107,19 @@ def policy(prim, *_, **params):
     return False  # not saveable unless it's in the allow-list
   return policy
 
+def save_and_offload_only_these_names(
+    *, names_which_can_be_saved, names_which_can_be_offloaded,
+    offload_src, offload_dst):
+  names_which_can_be_saved = set(names_which_can_be_saved)
+  names_which_can_be_offloaded = set(names_which_can_be_offloaded)
+  def policy(prim, *_, **params):
+    if prim is name_p and params['name'] in names_which_can_be_saved:
+      return pe.Saveable
+    if prim is name_p and params['name'] in names_which_can_be_offloaded:
+      return pe.Offloadable(src=offload_src, dst=offload_dst)
+    return pe.Recompute  # not saveable unless it's in the allow-list
+  return policy
+
 
 
 def save_from_both_policies(policy_1, policy_2):
@@ -126,7 +139,8 @@ def policy(prim, *args, **params):
     save_anything_except_these_names=save_anything_except_these_names,
     save_any_names_but_these=save_any_names_but_these,
     save_only_these_names=save_only_these_names,
-    save_from_both_policies=save_from_both_policies)
+    save_from_both_policies=save_from_both_policies,
+    save_and_offload_only_these_names=save_and_offload_only_these_names)
 
 
 ### Main API
diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py
index c9745e258b31..60032e721f6d 100644
--- a/jax/_src/interpreters/mlir.py
+++ b/jax/_src/interpreters/mlir.py
@@ -1395,7 +1395,7 @@ def aval_to_types(aval):
 
   if ir_arg_memory_kinds is not None:
     flat_args = [
-        a if mk is None else wrap_with_memory_kind(a, mk, a_aval, is_input=True)
+        a if mk is None else wrap_with_memory_kind(a, mk, a_aval)
         for a, mk, a_aval in zip(flat_args, ir_arg_memory_kinds, input_avals)]
 
   _, token_args, unflattened_args = util.split_list(
@@ -1446,27 +1446,15 @@ def aval_to_types(aval):
   return func_op
 
 
-def get_compute_type(memory_kind: str) -> str:
-  if memory_kind == 'tpu_hbm':
-    return 'dense'
-  elif memory_kind == 'unpinned_host':
-    return 'host'
-  raise ValueError(f'Unknown memory_kind: {memory_kind}')
-
-
 def wrap_with_memory_kind(
-    x: ir.Value, memory_kind: str, aval_out: core.AbstractValue,
-    is_input: bool = False) -> ir.Value:
+    x: ir.Value, memory_kind: str, aval_out: core.AbstractValue) -> ir.Value:
   if aval_out is None:
     result_type = x.type
   else:
     result_type = aval_to_ir_type(aval_out)
   op = custom_call("annotate_device_placement", result_types=[result_type],
-                   operands=[x], api_version=1)
-  mka = get_compute_type(memory_kind)
-  dict_attr = {"_xla_compute_type": ir.StringAttr.get(mka)}
-  if is_input and mka == 'host':
-    dict_attr.update({"_xla_buffer_placement": ir.StringAttr.get("arg")})
+                   operands=[x], has_side_effect=True, api_version=1)
+  dict_attr = {"_xla_buffer_placement": ir.StringAttr.get(memory_kind)}
   op.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get(dict_attr)
   return op.result
 
diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index ec33950498de..7ab1ca605f1e 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -1993,15 +1993,15 @@ def lower_sharding_computation(
           for js, source_info in util.stable_unique(jaxpr_sharding))),
      devices_from_context)
 
-  transfer_mem_kind_in_jaxpr = list(jaxpr_transfer_mem_kinds(jaxpr))
+  # TODO(yashkatariya): Enable this when offload APIs are stable.
+  # transfer_mem_kind_in_jaxpr = list(jaxpr_transfer_mem_kinds(jaxpr))
 
   committed = bool(
       devices_from_context or
       len(device_assignment) > 1 or
       any(not is_unspecified(i) for i in in_shardings) or
       any(not is_unspecified(js) for js, _ in jaxpr_sharding) or
-      any(not is_unspecified(o) for o in out_shardings) or
-      transfer_mem_kind_in_jaxpr)
+      any(not is_unspecified(o) for o in out_shardings))
 
   gs = sharding_impls.GSPMDSharding.get_replicated(device_assignment)
   in_shardings = tuple(gs if is_unspecified(i) else i for i in in_shardings)
@@ -2010,8 +2010,7 @@ def lower_sharding_computation(
 
   all_default_mem_kind = are_all_shardings_default_mem_kind(
       da_object,
-      it.chain(in_shardings, out_shardings, [js for js, _ in jaxpr_sharding],  # type: ignore
-               transfer_mem_kind_in_jaxpr))
+      it.chain(in_shardings, out_shardings, [js for js, _ in jaxpr_sharding]))  # type: ignore
 
   if not da_object.is_fully_addressable:  # type: ignore
     if inline and config.spmd_mode.value != 'allow_all':
diff --git a/tests/memories_test.py b/tests/memories_test.py
index 47a31e62b547..9353f21ead0f 100644
--- a/tests/memories_test.py
+++ b/tests/memories_test.py
@@ -22,6 +22,7 @@
 from jax._src import test_util as jtu
 from jax._src import xla_bridge as xb
 from jax._src import config
+from jax.ad_checkpoint import checkpoint_name
 import jax.numpy as jnp
 from jax.sharding import PartitionSpec as P
 from jax.ad_checkpoint import Offloadable, remat
@@ -199,20 +200,8 @@ def test_default_memory_kind(self):
 class MemoriesComputationTest(jtu.BufferDonationTestCase):
 
   def setUp(self):
-    if not jtu.test_device_matches(["tpu"]):
-      self.skipTest("Memories do not work on CPU and GPU backends yet.")
-    # TODO(b/311021572)
-    if jtu.is_cloud_tpu():
-      self.skipTest("Experimental feature not yet implemented on Cloud TPU")
+    self.skipTest("Compute via memories does not work yet.")
     super().setUp()
-    self.orig_memories_flag = config.enable_memories.value
-    jax.config.update('jax_enable_memories', True)
-    FLAGS.xla_tpu_enable_host_aware_passes = True
-
-  def tearDown(self):
-    jax.config.update('jax_enable_memories', self.orig_memories_flag)
-    FLAGS.xla_tpu_enable_host_aware_passes = False
-    super().tearDown()
 
   def _check_mem_kind(self, executable_kind, out_sharding, expected_kind):
     out_kind = out_sharding.memory_kind
@@ -1115,7 +1104,7 @@ def test_remat_jaxpr_offloadable(self):
     inp = jax.device_put(np.arange(16.), NamedSharding(mesh, P("x")))
 
     def policy(prim, *avals, **params):
-      return Offloadable(src="tpu_hbm", dst="unpinned_host")
+      return Offloadable(src="tpu_hbm", dst="pinned_host")
 
     @functools.partial(remat, policy=policy)
     def f(x):
@@ -1126,43 +1115,68 @@ def f(x):
 
     fwd_jaxpr, bwd_jaxpr = jtu.fwd_bwd_jaxprs(f, inp)
 
-    self.assertLen(fwd_jaxpr.out_avals, 4) # 1 output, 3 offloaded residuals
+    self.assertLen(fwd_jaxpr.out_avals, 4)  # 1 output, 3 offloaded residuals
     fwd_mem_kind_count = str(fwd_jaxpr).count(
-        "TransferToMemoryKind(memory_kind='unpinned_host')")
+        "TransferToMemoryKind(memory_kind='pinned_host')")
     self.assertEqual(fwd_mem_kind_count, 3)
 
-    self.assertLen(bwd_jaxpr.in_avals, 4) # 3 offloaded residuals, 1 input
+    self.assertLen(bwd_jaxpr.in_avals, 4)  # 3 offloaded residuals, 1 input
     bwd_mem_kind_count = str(bwd_jaxpr).count(
         "TransferToMemoryKind(memory_kind='tpu_hbm')")
     self.assertEqual(bwd_mem_kind_count, 3)
 
+    # Execution test.
+    f = jax.jit(jax.grad(f))
+    f(inp)  # doesn't crash
+
+    compiled_text = f.lower(inp).compile().as_text()
+    if compiled_text is not None:
+      self.assertIn('S(5)', compiled_text)
+      self.assertRegex(compiled_text, r"copy-start.*S\(5\)")
+      self.assertRegex(compiled_text, r"copy-done.*S\(5\)")
+
   def test_remat_scan_jaxpr_offloadable(self):
     mesh = jtu.create_global_mesh((2,), ("x",))
-    inp = jax.device_put(np.arange(16.), NamedSharding(mesh, P("x")))
+    shape = (256, 128)
+    np_inp = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
+    inp = jax.device_put(np_inp, NamedSharding(mesh, P("x")))
 
-    def policy(prim, *avals, **params):
-      return Offloadable(src="tpu_hbm", dst="unpinned_host")
+    policy = jax.checkpoint_policies.save_and_offload_only_these_names(
+        names_which_can_be_saved=["y"], names_which_can_be_offloaded=["z", "w"],
+        offload_src='tpu_hbm', offload_dst='pinned_host')
 
+    @functools.partial(remat, policy=policy)
     def f(x):
-      @functools.partial(remat, policy=policy)
-      def g(y, _):
-        y = jnp.sin(y)
-        y = jnp.sin(y)
-        y = jnp.sin(y)
-        return y, None
-      return jax.lax.scan(g, x, None, length=1)[0]
+      def g(ys, _):
+        y, _ = ys
+        y = checkpoint_name(jnp.sin(y), "y")
+        z = checkpoint_name(jnp.sin(y), "z")
+        w = checkpoint_name(jnp.sin(z), "w")
+        return (w, jnp.sum(w)), None
+      _, scan_out = jax.lax.scan(g, (x, np.array(1, dtype=np.float32)), [np_inp])[0]
+      return scan_out
 
     fwd_jaxpr, bwd_jaxpr = jtu.fwd_bwd_jaxprs(f, inp)
 
-    self.assertLen(fwd_jaxpr.out_avals, 4) # 1 output, 3 offloaded residuals
+    self.assertLen(fwd_jaxpr.out_avals, 5)  # 2 output, 3 offloaded residuals
     fwd_mem_kind_count = str(fwd_jaxpr).count(
-        "TransferToMemoryKind(memory_kind='unpinned_host')")
-    self.assertEqual(fwd_mem_kind_count, 3)
+        "TransferToMemoryKind(memory_kind='pinned_host')")
+    self.assertEqual(fwd_mem_kind_count, 2)
 
-    self.assertLen(bwd_jaxpr.in_avals, 4) # 3 offloaded residuals, 1 input
+    self.assertLen(bwd_jaxpr.in_avals, 5)  # 3 offloaded residuals, 2 input
     bwd_mem_kind_count = str(bwd_jaxpr).count(
         "TransferToMemoryKind(memory_kind='tpu_hbm')")
-    self.assertEqual(bwd_mem_kind_count, 3)
+    self.assertEqual(bwd_mem_kind_count, 2)
+
+    f = jax.jit(jax.grad(f))
+    f(inp)  # doesn't crash
+
+    compiled_text = f.lower(inp).compile().as_text()
+    if compiled_text is not None:
+      self.assertIn('S(5)', compiled_text)
+      self.assertNotRegex(compiled_text, r"copy-start.*S\(5\)")
+      self.assertNotRegex(compiled_text, r"copy-done.*S\(5\)")
+
 
 if __name__ == "__main__":
   absltest.main(testLoader=jtu.JaxTestLoader())
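
Usage sketch: the policy added to jax/_src/ad_checkpoint.py above is exposed as
jax.checkpoint_policies.save_and_offload_only_these_names and pairs with
jax.ad_checkpoint.checkpoint_name, as exercised by
test_remat_scan_jaxpr_offloadable. The snippet below is a minimal illustration,
not part of the patch; it assumes a TPU backend with jax_enable_memories turned
on, and the memory-kind strings "tpu_hbm" and "pinned_host" follow the tests in
this patch and are platform-specific. The function and input values are
arbitrary.

    import functools

    import jax
    import jax.numpy as jnp
    from jax.ad_checkpoint import checkpoint_name

    # Residuals named "y" stay in device memory; "z" and "w" are offloaded to
    # pinned host memory during the forward pass and copied back for the
    # backward pass.
    policy = jax.checkpoint_policies.save_and_offload_only_these_names(
        names_which_can_be_saved=["y"],
        names_which_can_be_offloaded=["z", "w"],
        offload_src="tpu_hbm",
        offload_dst="pinned_host",
    )

    @functools.partial(jax.checkpoint, policy=policy)  # jax.checkpoint == remat
    def f(x):
      y = checkpoint_name(jnp.sin(x), "y")  # saved on device
      z = checkpoint_name(jnp.sin(y), "z")  # offloaded to host
      w = checkpoint_name(jnp.sin(z), "w")  # offloaded to host
      return jnp.sum(w)

    grad_f = jax.jit(jax.grad(f))
    grad_f(jnp.arange(16.0, dtype=jnp.float32))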