Support activation offloading to host in JAX!
Currently only TPU support works. GPU support is being added.

PiperOrigin-RevId: 604482085
yashk2810 authored and jax authors committed Feb 6, 2024
1 parent fb6fa04 commit d30c3ad
Showing 4 changed files with 69 additions and 54 deletions.
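In brief: residuals selected by a rematerialization policy can now be moved to host memory on the forward pass and fetched back for the backward pass. A minimal sketch modeled on the tests below, assuming a TPU backend with memories enabled; "tpu_hbm" and "pinned_host" are the memory kinds the tests use:

import functools
import jax
import jax.numpy as jnp
import numpy as np
from jax.ad_checkpoint import Offloadable, remat

# Offload every residual to host memory instead of keeping it in HBM.
def policy(prim, *avals, **params):
  return Offloadable(src="tpu_hbm", dst="pinned_host")

@functools.partial(remat, policy=policy)
def f(x):
  x = jnp.sin(x)
  x = jnp.sin(x)
  return jnp.sum(jnp.sin(x))

# Residuals of f live on the host between the forward and backward passes.
grads = jax.jit(jax.grad(f))(np.arange(16.0, dtype=np.float32))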
jax/_src/ad_checkpoint.py: 15 additions & 1 deletion
@@ -107,6 +107,19 @@ def policy(prim, *_, **params):
     return False  # not saveable unless it's in the allow-list
   return policy

+def save_and_offload_only_these_names(
+    *, names_which_can_be_saved, names_which_can_be_offloaded,
+    offload_src, offload_dst):
+  names_which_can_be_saved = set(names_which_can_be_saved)
+  names_which_can_be_offloaded = set(names_which_can_be_offloaded)
+  def policy(prim, *_, **params):
+    if prim is name_p and params['name'] in names_which_can_be_saved:
+      return pe.Saveable
+    if prim is name_p and params['name'] in names_which_can_be_offloaded:
+      return pe.Offloadable(src=offload_src, dst=offload_dst)
+    return pe.Recompute  # not saveable unless it's in the allow-list
+  return policy
+

 def save_from_both_policies(policy_1, policy_2):

@@ -126,7 +139,8 @@ def policy(prim, *args, **params):
     save_anything_except_these_names=save_anything_except_these_names,
     save_any_names_but_these=save_any_names_but_these,
     save_only_these_names=save_only_these_names,
-    save_from_both_policies=save_from_both_policies)
+    save_from_both_policies=save_from_both_policies,
+    save_and_offload_only_these_names=save_and_offload_only_these_names)


 ### Main API
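The new save_and_offload_only_these_names policy is name-driven: only values tagged with checkpoint_name are saved or offloaded, and everything else is recomputed on the backward pass. A usage sketch under the same TPU assumptions; the names "y" and "z" are illustrative:

import functools
import jax
import jax.numpy as jnp
import numpy as np
from jax.ad_checkpoint import checkpoint_name, remat

policy = jax.checkpoint_policies.save_and_offload_only_these_names(
    names_which_can_be_saved=["y"],      # kept on device
    names_which_can_be_offloaded=["z"],  # moved to host memory
    offload_src="tpu_hbm", offload_dst="pinned_host")

@functools.partial(remat, policy=policy)
def f(x):
  y = checkpoint_name(jnp.sin(x), "y")  # saved in HBM
  z = checkpoint_name(jnp.sin(y), "z")  # offloaded to pinned host memory
  w = jnp.sin(z)                        # recomputed on the backward pass
  return jnp.sum(w)

jax.jit(jax.grad(f))(np.arange(16.0, dtype=np.float32))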
jax/_src/interpreters/mlir.py: 4 additions & 16 deletions
@@ -1395,7 +1395,7 @@ def aval_to_types(aval):

   if ir_arg_memory_kinds is not None:
     flat_args = [
-        a if mk is None else wrap_with_memory_kind(a, mk, a_aval, is_input=True)
+        a if mk is None else wrap_with_memory_kind(a, mk, a_aval)
         for a, mk, a_aval in zip(flat_args, ir_arg_memory_kinds, input_avals)]

     _, token_args, unflattened_args = util.split_list(
@@ -1446,27 +1446,15 @@ def aval_to_types(aval):
   return func_op


-def get_compute_type(memory_kind: str) -> str:
-  if memory_kind == 'tpu_hbm':
-    return 'dense'
-  elif memory_kind == 'unpinned_host':
-    return 'host'
-  raise ValueError(f'Unknown memory_kind: {memory_kind}')
-
-
 def wrap_with_memory_kind(
-    x: ir.Value, memory_kind: str, aval_out: core.AbstractValue,
-    is_input: bool = False) -> ir.Value:
+    x: ir.Value, memory_kind: str, aval_out: core.AbstractValue) -> ir.Value:
   if aval_out is None:
     result_type = x.type
   else:
     result_type = aval_to_ir_type(aval_out)
   op = custom_call("annotate_device_placement", result_types=[result_type],
-                   operands=[x], api_version=1)
-  mka = get_compute_type(memory_kind)
-  dict_attr = {"_xla_compute_type": ir.StringAttr.get(mka)}
-  if is_input and mka == 'host':
-    dict_attr.update({"_xla_buffer_placement": ir.StringAttr.get("arg")})
+                   operands=[x], has_side_effect=True, api_version=1)
+  dict_attr = {"_xla_buffer_placement": ir.StringAttr.get(memory_kind)}
   op.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get(dict_attr)
   return op.result

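At the lowering level, wrap_with_memory_kind now records the memory kind directly in a _xla_buffer_placement frontend attribute on the annotate_device_placement custom call, and marks the call side-effecting, presumably so the annotation cannot be optimized away. A sketch of user code that should exercise this path, assuming the backend exposes a pinned_host memory kind:

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(jax.devices(), ("x",))
# A sharding can carry a memory kind; lowering a computation that uses it
# wraps the value in an annotate_device_placement custom call carrying
# _xla_buffer_placement="pinned_host".
s_host = NamedSharding(mesh, P("x"), memory_kind="pinned_host")
x_host = jax.device_put(jnp.arange(16.0), s_host)  # lives in host memory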
jax/_src/interpreters/pxla.py: 4 additions & 5 deletions
@@ -1993,15 +1993,15 @@ def lower_sharding_computation(
          for js, source_info in util.stable_unique(jaxpr_sharding))),
      devices_from_context)

-  transfer_mem_kind_in_jaxpr = list(jaxpr_transfer_mem_kinds(jaxpr))
+  # TODO(yashkatariya): Enable this when offload APIs are stable.
+  # transfer_mem_kind_in_jaxpr = list(jaxpr_transfer_mem_kinds(jaxpr))

   committed = bool(
       devices_from_context or
       len(device_assignment) > 1 or
       any(not is_unspecified(i) for i in in_shardings) or
       any(not is_unspecified(js) for js, _ in jaxpr_sharding) or
-      any(not is_unspecified(o) for o in out_shardings) or
-      transfer_mem_kind_in_jaxpr)
+      any(not is_unspecified(o) for o in out_shardings))

   gs = sharding_impls.GSPMDSharding.get_replicated(device_assignment)
   in_shardings = tuple(gs if is_unspecified(i) else i for i in in_shardings)
@@ -2010,8 +2010,7 @@

   all_default_mem_kind = are_all_shardings_default_mem_kind(
       da_object,
-      it.chain(in_shardings, out_shardings, [js for js, _ in jaxpr_sharding],  # type: ignore
-               transfer_mem_kind_in_jaxpr))
+      it.chain(in_shardings, out_shardings, [js for js, _ in jaxpr_sharding]))  # type: ignore

   if not da_object.is_fully_addressable:  # type: ignore
     if inline and config.spmd_mode.value != 'allow_all':
tests/memories_test.py: 46 additions & 32 deletions
@@ -22,6 +22,7 @@
 from jax._src import test_util as jtu
 from jax._src import xla_bridge as xb
 from jax._src import config
+from jax.ad_checkpoint import checkpoint_name
 import jax.numpy as jnp
 from jax.sharding import PartitionSpec as P
 from jax.ad_checkpoint import Offloadable, remat
@@ -199,20 +200,8 @@ def test_default_memory_kind(self):
 class MemoriesComputationTest(jtu.BufferDonationTestCase):

   def setUp(self):
-    if not jtu.test_device_matches(["tpu"]):
-      self.skipTest("Memories do not work on CPU and GPU backends yet.")
-    # TODO(b/311021572)
-    if jtu.is_cloud_tpu():
-      self.skipTest("Experimental feature not yet implemented on Cloud TPU")
+    self.skipTest("Compute via memories does not work yet.")
     super().setUp()
-    self.orig_memories_flag = config.enable_memories.value
-    jax.config.update('jax_enable_memories', True)
-    FLAGS.xla_tpu_enable_host_aware_passes = True
-
-  def tearDown(self):
-    jax.config.update('jax_enable_memories', self.orig_memories_flag)
-    FLAGS.xla_tpu_enable_host_aware_passes = False
-    super().tearDown()

   def _check_mem_kind(self, executable_kind, out_sharding, expected_kind):
     out_kind = out_sharding.memory_kind
@@ -1115,7 +1104,7 @@ def test_remat_jaxpr_offloadable(self):
     inp = jax.device_put(np.arange(16.), NamedSharding(mesh, P("x")))

     def policy(prim, *avals, **params):
-      return Offloadable(src="tpu_hbm", dst="unpinned_host")
+      return Offloadable(src="tpu_hbm", dst="pinned_host")

     @functools.partial(remat, policy=policy)
     def f(x):
@@ -1126,43 +1115,68 @@ def f(x):

     fwd_jaxpr, bwd_jaxpr = jtu.fwd_bwd_jaxprs(f, inp)

-    self.assertLen(fwd_jaxpr.out_avals, 4) # 1 output, 3 offloaded residuals
+    self.assertLen(fwd_jaxpr.out_avals, 4)  # 1 output, 3 offloaded residuals
     fwd_mem_kind_count = str(fwd_jaxpr).count(
-        "TransferToMemoryKind(memory_kind='unpinned_host')")
+        "TransferToMemoryKind(memory_kind='pinned_host')")
     self.assertEqual(fwd_mem_kind_count, 3)

-    self.assertLen(bwd_jaxpr.in_avals, 4) # 3 offloaded residuals, 1 input
+    self.assertLen(bwd_jaxpr.in_avals, 4)  # 3 offloaded residuals, 1 input
     bwd_mem_kind_count = str(bwd_jaxpr).count(
         "TransferToMemoryKind(memory_kind='tpu_hbm')")
     self.assertEqual(bwd_mem_kind_count, 3)

+    # Execution test.
+    f = jax.jit(jax.grad(f))
+    f(inp)  # doesn't crash
+
+    compiled_text = f.lower(inp).compile().as_text()
+    if compiled_text is not None:
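+      # 'S(5)' marks buffers that XLA assigned to memory space 5, i.e. host memory.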
+      self.assertIn('S(5)', compiled_text)
+      self.assertRegex(compiled_text, r"copy-start.*S\(5\)")
+      self.assertRegex(compiled_text, r"copy-done.*S\(5\)")
+
   def test_remat_scan_jaxpr_offloadable(self):
     mesh = jtu.create_global_mesh((2,), ("x",))
-    inp = jax.device_put(np.arange(16.), NamedSharding(mesh, P("x")))
+    shape = (256, 128)
+    np_inp = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
+    inp = jax.device_put(np_inp, NamedSharding(mesh, P("x")))

-    def policy(prim, *avals, **params):
-      return Offloadable(src="tpu_hbm", dst="unpinned_host")
+    policy = jax.checkpoint_policies.save_and_offload_only_these_names(
+        names_which_can_be_saved=["y"], names_which_can_be_offloaded=["z", "w"],
+        offload_src='tpu_hbm', offload_dst='pinned_host')

     @functools.partial(remat, policy=policy)
     def f(x):
-      @functools.partial(remat, policy=policy)
-      def g(y, _):
-        y = jnp.sin(y)
-        y = jnp.sin(y)
-        y = jnp.sin(y)
-        return y, None
-      return jax.lax.scan(g, x, None, length=1)[0]
+      def g(ys, _):
+        y, _ = ys
+        y = checkpoint_name(jnp.sin(y), "y")
+        z = checkpoint_name(jnp.sin(y), "z")
+        w = checkpoint_name(jnp.sin(z), "w")
+        return (w, jnp.sum(w)), None
+      _, scan_out = jax.lax.scan(g, (x, np.array(1, dtype=np.float32)), [np_inp])[0]
+      return scan_out

     fwd_jaxpr, bwd_jaxpr = jtu.fwd_bwd_jaxprs(f, inp)

-    self.assertLen(fwd_jaxpr.out_avals, 4)  # 1 output, 3 offloaded residuals
+    self.assertLen(fwd_jaxpr.out_avals, 5)  # 2 output, 3 offloaded residuals
     fwd_mem_kind_count = str(fwd_jaxpr).count(
-        "TransferToMemoryKind(memory_kind='unpinned_host')")
-    self.assertEqual(fwd_mem_kind_count, 3)
+        "TransferToMemoryKind(memory_kind='pinned_host')")
+    self.assertEqual(fwd_mem_kind_count, 2)

-    self.assertLen(bwd_jaxpr.in_avals, 4)  # 3 offloaded residuals, 1 input
+    self.assertLen(bwd_jaxpr.in_avals, 5)  # 3 offloaded residuals, 2 input
     bwd_mem_kind_count = str(bwd_jaxpr).count(
         "TransferToMemoryKind(memory_kind='tpu_hbm')")
-    self.assertEqual(bwd_mem_kind_count, 3)
+    self.assertEqual(bwd_mem_kind_count, 2)

+    f = jax.jit(jax.grad(f))
+    f(inp)  # doesn't crash
+
+    compiled_text = f.lower(inp).compile().as_text()
+    if compiled_text is not None:
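+      # Host memory space S(5) appears, but no copy-start/copy-done pairs are
+      # expected here.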
+      self.assertIn('S(5)', compiled_text)
+      self.assertNotRegex(compiled_text, r"copy-start.*S\(5\)")
+      self.assertNotRegex(compiled_text, r"copy-done.*S\(5\)")
+

 if __name__ == "__main__":
   absltest.main(testLoader=jtu.JaxTestLoader())
