
Commit

Support remat + compute_on. If the rematted computation is annotated to run on host, the backward pass will also execute on host. Also enable no-op nested compute tests.

PiperOrigin-RevId: 634943450
yashk2810 authored and jax authors committed May 18, 2024
1 parent 641d5c8 commit 25aa13c
Showing 4 changed files with 67 additions and 34 deletions.
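For context, a minimal usage sketch of the behavior described above, adapted from the new test_compute_on_remat test added to memories_test.py below. The jax.experimental.compute_on import path and the availability of a host memory space on the backend are assumptions here; treat this as an illustration rather than part of the commit.

import functools

import jax
import jax.numpy as jnp
from jax.ad_checkpoint import remat, Recompute
from jax.experimental.compute_on import compute_on

def policy(prim, *avals, **params):
  # Save nothing: every residual is recomputed on the backward pass.
  return Recompute

@compute_on('device_host')
@jax.jit
def g(x):
  x = jnp.sin(x)
  x = jnp.sin(x)
  x = jnp.sin(x)
  return x

@functools.partial(remat, policy=policy)
def f(x):
  return jnp.sum(g(x))

# With this change, the rematerialized forward computation inside the
# backward pass keeps the 'device_host' annotation, so it also runs on host.
grads = jax.jit(jax.grad(f))(jnp.arange(16.))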
6 changes: 4 additions & 2 deletions jax/_src/ad_checkpoint.py
@@ -33,6 +33,7 @@
from jax._src import source_info_util
from jax._src import traceback_util
from jax._src import util
from jax._src import compute_on
from jax._src.api_util import flatten_fun, shaped_abstractify
from jax._src.interpreters import ad
from jax._src.interpreters import batching
@@ -558,9 +559,10 @@ def remat_partial_eval(trace, *tracers, jaxpr, **params):
out_jaxpr_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(x.aval), None)
for x in jaxpr_unknown.outvars]
new_params = dict(params, jaxpr=jaxpr_unknown, differentiated=True)
ctx = core.JaxprEqnContext(compute_on.current_compute_type())
recipe = pe.new_eqn_recipe(in_jaxpr_tracers, out_jaxpr_tracers, remat_p,
new_params, jaxpr_unknown.effects,
source_info_util.current())
source_info_util.current(), ctx)

# log info about saved residuals
log_level = logging.WARNING if config.log_checkpoint_residuals.value else logging.DEBUG
@@ -687,7 +689,7 @@ def remat_dce(used_outputs: list[bool], eqn: core.JaxprEqn
new_eqn = pe.new_jaxpr_eqn(
[v for v, used in zip(eqn.invars, used_inputs) if used],
[v for v, used in zip(eqn.outvars, used_outputs) if used],
eqn.primitive, new_params, new_jaxpr.effects, eqn.source_info)
eqn.primitive, new_params, new_jaxpr.effects, eqn.source_info, eqn.ctx)
return used_inputs, new_eqn
pe.dce_rules[remat_p] = remat_dce

4 changes: 2 additions & 2 deletions jax/_src/compute_on.py
@@ -28,10 +28,10 @@ def __init__(self):
@contextmanager
def extend_compute_type(c_type: str):
compute_on_context.stack.append(c_type)
if len(set(compute_on_context.stack)) > 1:
if len(set(filter(lambda x: x is not None, set(compute_on_context.stack)))) > 1:
raise NotImplementedError(
'Nesting `compute_on` with different compute types is not supported'
' yet.')
f' yet. Current stack: {compute_on_context.stack}')
try:
yield compute_on_context.stack[-1]
finally:
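The compute_on.py change above filters None entries out of the stack before checking for mixed compute types, and the error now reports the offending stack. A rough illustration of that behavior using the private module directly; the None entry stands in for an equation context that carries no compute_type annotation, and this snippet is illustrative rather than part of the test suite.

from jax._src import compute_on

# Allowed: a None entry is ignored by the mixed-type check.
with compute_on.extend_compute_type(None):
  with compute_on.extend_compute_type('device_host'):
    pass

# Still rejected: genuinely different compute types in the same stack.
try:
  with compute_on.extend_compute_type('device_host'):
    with compute_on.extend_compute_type('device'):
      pass
except NotImplementedError as err:
  print(err)  # the message now includes the current compute_on stack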
5 changes: 4 additions & 1 deletion jax/_src/core.py
@@ -261,6 +261,9 @@ def jaxpr_as_fun(closed_jaxpr: ClosedJaxpr, *args):


class JaxprEqnContext(ContextDecorator):
compute_type: str | None
_exit_stack: ExitStack
_managers: list[tuple[Any, Any]]

def __init__(self, compute_type: str | None):
self.compute_type = compute_type
@@ -478,7 +481,7 @@ def write(v: Var, val: Any) -> None:
subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params)
name_stack = source_info_util.current_name_stack() + eqn.source_info.name_stack
traceback = eqn.source_info.traceback if propagate_source_info else None
with source_info_util.user_context(traceback, name_stack=name_stack):
with source_info_util.user_context(traceback, name_stack=name_stack), eqn.ctx:
ans = eqn.primitive.bind(*subfuns, *map(read, eqn.invars), **bind_params)
if eqn.primitive.multiple_results:
map(write, eqn.outvars, ans)
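The core.py hunks add explicit attribute annotations to JaxprEqnContext and make eval_jaxpr enter each equation's ctx, so an equation annotated with a compute type is re-evaluated under that annotation, which is what lets the rematerialized backward computation stay on host. Below is a simplified, hypothetical sketch of how such a ContextDecorator could be wired through an ExitStack; it mirrors the annotated attributes but is not the actual jax._src.core implementation.

from __future__ import annotations

from contextlib import ContextDecorator, ExitStack
from typing import Any

from jax._src import compute_on


class EqnContextSketch(ContextDecorator):
  """Hypothetical stand-in for core.JaxprEqnContext."""
  compute_type: str | None
  _exit_stack: ExitStack
  _managers: list[tuple[Any, Any]]

  def __init__(self, compute_type: str | None):
    self.compute_type = compute_type
    # Re-entered every time the equation is evaluated, so the compute_on
    # annotation is restored during retracing and rematerialization.
    self._managers = [(compute_on.extend_compute_type, self.compute_type)]

  def __enter__(self):
    self._exit_stack = ExitStack()
    for manager, value in self._managers:
      self._exit_stack.enter_context(manager(value))
    return self

  def __exit__(self, *exc_info):
    self._exit_stack.close()
    return False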
86 changes: 57 additions & 29 deletions tests/memories_test.py
@@ -28,7 +28,7 @@
from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P
from jax.ad_checkpoint import Offloadable, remat
from jax.ad_checkpoint import Offloadable, remat, Recompute
from jax._src.sharding_impls import (NamedSharding, PositionalSharding,
SingleDeviceSharding, GSPMDSharding,
TransferToMemoryKind,
@@ -46,7 +46,7 @@ def get_memory_kinds_from_executable(f, args):


def _create_inputs(shape, pspec, mem_kind=None):
mesh = jtu.create_global_mesh((4, 2), ("x", "y"))
mesh = jtu.create_global_mesh((2, 2), ("x", "y"))
np_inp = np.arange(math.prod(shape)).reshape(shape)
s = NamedSharding(mesh, pspec, memory_kind=mem_kind)
inp = jax.device_put(np_inp, s)
@@ -669,7 +669,7 @@ def mul(x):
self.assertArraysEqual(out2, np_inp2 @ np_inp2.T)

def test_sharding_devices_indices_map_cache_hit(self):
mesh = jtu.create_global_mesh((4, 2), ("x", "y"))
mesh = jtu.create_global_mesh((2, 2), ("x", "y"))
shape = (8, 2)
s1 = NamedSharding(mesh, P("x", "y"))
s2 = NamedSharding(mesh, P("x", "y"), memory_kind="device")
@@ -682,7 +682,7 @@ def test_sharding_devices_indices_map_cache_hit(self):
self.assertEqual(cache_info2.misses, cache_info1.misses)

def test_jit_host_inputs_via_device_put_outside(self):
mesh = jtu.create_global_mesh((4, 2), ("x", "y"))
mesh = jtu.create_global_mesh((2, 2), ("x", "y"))
s_host = NamedSharding(mesh, P("x", "y"), memory_kind="unpinned_host")
s_hbm = s_host.with_memory_kind("device")
inp = jnp.arange(16).reshape(8, 2)
@@ -747,7 +747,7 @@ def f(x):
("host_to_hbm", "unpinned_host", "device")
)
def test_device_put_memory_kind_no_sharding(self, inp_mem_kind, out_mem_kind):
mesh = jtu.create_global_mesh((4, 2), ("x", "y"))
mesh = jtu.create_global_mesh((2, 2), ("x", "y"))
np_inp = np.arange(16).reshape(8, 2)
s = NamedSharding(mesh, P("x", "y"), memory_kind=inp_mem_kind)
inp = jax.device_put(np_inp, s)
@@ -773,7 +773,7 @@ def f(x):
)
def test_device_put_memory_kind_no_sharding_output(
self, inp_mem_kind, out_mem_kind):
mesh = jtu.create_global_mesh((4, 2), ("x", "y"))
mesh = jtu.create_global_mesh((2, 2), ("x", "y"))
np_inp = np.arange(16).reshape(8, 2)
s = NamedSharding(mesh, P("x", "y"), memory_kind=inp_mem_kind)
inp = jax.device_put(np_inp, s)
@@ -798,7 +798,7 @@ def f(x):
)
def test_device_put_memory_kind_no_sharding_input(
self, inp_mem_kind, out_mem_kind):
mesh = jtu.create_global_mesh((4, 2), ("x", "y"))
mesh = jtu.create_global_mesh((2, 2), ("x", "y"))
np_inp = np.arange(16).reshape(8, 2)
s = NamedSharding(mesh, P("x", "y"), memory_kind=inp_mem_kind)
inp = jax.device_put(np_inp, s)
@@ -946,7 +946,7 @@ def _check_device_put_addressable_shards(

@parameterized.parameters("unpinned_host", "pinned_host")
def test_device_put_host_to_hbm(self, host_memory_kind: str):
mesh = jtu.create_global_mesh((4, 2), ("x", "y"))
mesh = jtu.create_global_mesh((2, 2), ("x", "y"))
s_host = NamedSharding(mesh, P("y"), memory_kind=host_memory_kind)
np_inp = np.arange(16).reshape(8, 2)

@@ -960,7 +960,7 @@ def test_device_put_host_to_hbm(self, host_memory_kind: str):

@parameterized.parameters("unpinned_host", "pinned_host")
def test_device_put_hbm_to_host(self, host_memory_kind: str):
mesh = jtu.create_global_mesh((4, 2), ("x", "y"))
mesh = jtu.create_global_mesh((2, 2), ("x", "y"))
s_host = NamedSharding(mesh, P("y"), memory_kind=host_memory_kind)
inp = jnp.arange(16).reshape(8, 2)

@@ -1063,7 +1063,7 @@ def test_device_put_on_different_device_with_the_same_memory_kind(

@parameterized.parameters("unpinned_host", "pinned_host")
def test_device_put_numpy_array(self, host_memory_kind: str):
mesh = jtu.create_global_mesh((4, 2), ("x", "y"))
mesh = jtu.create_global_mesh((2, 2), ("x", "y"))
np_inp = np.arange(16).reshape(8, 2)
s_hbm = NamedSharding(mesh, P(("x", "y")), memory_kind="device")
s_host = s_hbm.with_memory_kind(host_memory_kind)
@@ -1140,7 +1140,7 @@ def f(a, b):
out2, np_inp * np_inp * 2, s_host, 'pinned_host')

def test_parameter_streaming_with_scalar_and_constant(self):
mesh = jtu.create_global_mesh((2, 4), ("x", "y"))
mesh = jtu.create_global_mesh((2, 2), ("x", "y"))
scalar_inp = 1
s_host = NamedSharding(mesh, P(), memory_kind="pinned_host")

@@ -1335,6 +1335,52 @@ def f(x):
lowered_text = jf.lower(inp).as_text()
self.assertEqual(lowered_text.count('_xla_compute_type = "host"'), 2)

def test_compute_on_remat(self):
inp = jnp.arange(16.)

def policy(prim, *avals, **params):
return Recompute

@compute_on('device_host')
@jax.jit
def g(x):
x = jnp.sin(x)
x = jnp.sin(x)
x = jnp.sin(x)
return x

@functools.partial(remat, policy=policy)
def f(x):
x = g(x)
return jnp.sum(x)

# Execution test.
jf = jax.jit(jax.grad(f))
jf(inp) # doesn't crash

lowered_text = jf.lower(inp).as_text()
self.assertEqual(lowered_text.count('_xla_compute_type = "host"'), 2)

def test_nested_no_op_compute(self):
@compute_on('device_host')
@jax.jit
def f0(x):
return x * 2

@compute_on('device_host')
@jax.jit
def f1(x):
x = x * 3
return f0(x)

@jax.jit
def f2(x):
return f1(x)

inp = jnp.arange(8)
out = f2(inp)
self.assertArraysEqual(out, inp * 6)

# def test_sharded_compute_on_host(self):
# mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
# s = NamedSharding(mesh, P('x', 'y'))
@@ -1355,24 +1401,6 @@ def f(x):
# self.assertEqual(out.sharding, s)
# self.assertArraysEqual(out, np_inp * 6)

# def test_nested_no_op_compute(self):
# @compute_on('device_host')
# @jax.jit
# def f0(x):
# return x * 2

# @compute_on('device_host')
# @jax.jit
# def f1(x):
# return f0(x)

# @jax.jit
# def f2(x):
# return f1(x)

# print(f2.lower(jnp.arange(8)).as_text('hlo'))
# out = f2(jnp.arange(8))

# def test_eager_compute(self):
# inp = jnp.arange(8)
# with compute_on('device_host'):
