Skip to content

Commit

Permalink
Make jax.grad and compute_on work correctly. If the forward pass …
Browse files Browse the repository at this point in the history
…has an annotation to execute on CPU, then its backward pass also executes on CPU.

PiperOrigin-RevId: 634917402
  • Loading branch information
yashk2810 authored and jax authors committed May 17, 2024
1 parent 1043e24 commit 02c19e9
Show file tree
Hide file tree
Showing 6 changed files with 69 additions and 33 deletions.
1 change: 1 addition & 0 deletions jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,7 @@ pytype_strict_library(
"_src/linear_util.py",
],
deps = [
":compute_on",
":config",
":dtypes",
":effects",
Expand Down
25 changes: 20 additions & 5 deletions jax/_src/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,8 @@
from collections.abc import (Collection, Generator, Hashable, Iterable,
Iterator, Set, Sequence, MutableSet,
MutableMapping)
from contextlib import contextmanager
from contextlib import contextmanager, ContextDecorator, ExitStack
from dataclasses import dataclass
import dataclasses
import functools
from functools import partial, partialmethod, total_ordering
import gc
Expand All @@ -40,6 +39,7 @@
from jax._src import dtypes
from jax._src import config
from jax._src import effects
from jax._src import compute_on
from jax._src.errors import (
ConcretizationTypeError, TracerArrayConversionError, TracerBoolConversionError,
TracerIntegerConversionError, UnexpectedTracerError)
Expand Down Expand Up @@ -260,9 +260,24 @@ def jaxpr_as_fun(closed_jaxpr: ClosedJaxpr, *args):
return eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.consts, *args)


@dataclasses.dataclass(frozen=True)
class JaxprEqnContext:
compute_type: str | None
class JaxprEqnContext(ContextDecorator):

def __init__(self, compute_type: str | None):
self.compute_type = compute_type
self._exit_stack = ExitStack()
self._managers = [(compute_on.extend_compute_type, self.compute_type)]

def __enter__(self):
for manager, val in self._managers:
self._exit_stack.enter_context(manager(val))
return self

def __exit__(self, exc_type, exc_value, traceback):
self._exit_stack.close()
return False

def __repr__(self):
return f'JaxprEqnContext(compute_type={self.compute_type})'


class JaxprEqn(NamedTuple):
Expand Down
3 changes: 2 additions & 1 deletion jax/_src/interpreters/ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,8 @@ def write_primal(v, val):
else:
cts_in, = map(read_cotangent, eqn.outvars)
name_stack = source_info_util.current_name_stack() + eqn.source_info.name_stack
with source_info_util.user_context(eqn.source_info.traceback, name_stack=name_stack):
with source_info_util.user_context(
eqn.source_info.traceback, name_stack=name_stack), eqn.ctx:
if eqn.primitive.call_primitive or eqn.primitive.map_primitive:
cts_in_avals = [v.aval for v in eqn.outvars]
params = dict(eqn.params)
Expand Down
42 changes: 28 additions & 14 deletions jax/_src/interpreters/partial_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,16 +226,18 @@ def default_process_primitive(self, primitive, tracers, params):
out_aval, effects = primitive.abstract_eval(*avals, **params)
name_stack = self._current_truncated_name_stack()
source = source_info_util.current().replace(name_stack=name_stack)
eqn_ctx = core.JaxprEqnContext(compute_on.current_compute_type())
if primitive.multiple_results:
out_tracers = [JaxprTracer(self, PartialVal.unknown(aval), None)
for aval in out_aval]
eqn = new_eqn_recipe(tracers, out_tracers, primitive, params, effects, source)
eqn = new_eqn_recipe(tracers, out_tracers, primitive, params, effects,
source, eqn_ctx)
for t in out_tracers: t.recipe = eqn
return out_tracers
else:
out_tracer = JaxprTracer(self, PartialVal.unknown(out_aval), None)
out_tracer.recipe = new_eqn_recipe(tracers, [out_tracer], primitive,
params, effects, source)
params, effects, source, eqn_ctx)
return out_tracer

def process_call(self, primitive, f, tracers, params):
Expand Down Expand Up @@ -315,9 +317,10 @@ def process_call(self, primitive, f, tracers, params):
for a in out_type]
name_stack = self._current_truncated_name_stack()
source = source_info_util.current().replace(name_stack=name_stack)
eqn_ctx = core.JaxprEqnContext(compute_on.current_compute_type())
eqn = new_eqn_recipe((*res_tracers, *env_tracers, *unknown_arg_tracers),
out_tracers, primitive, staged_params, jaxpr.effects,
source)
source, eqn_ctx)
for t in out_tracers: t.recipe = eqn
return merge_lists(out_knowns, out_tracers, out_consts)

Expand Down Expand Up @@ -383,8 +386,10 @@ def const_out_axes_thunk():
for a in out_avals]
effs = core.filter_named_axis_effects(jaxpr.effects, {params['axis_name']})
src_info = source_info_util.current()
eqn_ctx = core.JaxprEqnContext(compute_on.current_compute_type())
eqn = new_eqn_recipe((*const_tracers, *env_tracers, *unknown_arg_tracers), # type: ignore[arg-type]
out_tracers, primitive, staged_params, effs, src_info)
out_tracers, primitive, staged_params, effs, src_info,
eqn_ctx)
for t in out_tracers: t.recipe = eqn

return merge_lists(out_knowns, out_tracers, out_consts)
Expand All @@ -409,8 +414,9 @@ def todo(out):
new_params = dict(new_params, call_jaxpr=convert_constvars_jaxpr(jaxpr))
name_stack = self._current_truncated_name_stack()
source = source_info_util.current().replace(name_stack=name_stack)
eqn_ctx = core.JaxprEqnContext(compute_on.current_compute_type())
eqn = new_eqn_recipe(in_tracers, out_tracers, primitive, new_params,
jaxpr.effects, source)
jaxpr.effects, source, eqn_ctx)
for t in out_tracers: t.recipe = eqn
return merge_lists(out_knowns, out_tracers, out_consts)

Expand Down Expand Up @@ -448,8 +454,10 @@ def todo(out):
for a in out_avals]
name_stack = self._current_truncated_name_stack()
source = source_info_util.current().replace(name_stack=name_stack)
eqn_ctx = core.JaxprEqnContext(compute_on.current_compute_type())
eqn = new_eqn_recipe((*const_tracers, *env_tracers), out_tracers,
primitive, staged_params, jaxpr.effects, source)
primitive, staged_params, jaxpr.effects, source,
eqn_ctx)
for t in out_tracers: t.recipe = eqn
return merge_lists(out_knowns, out_tracers, out_consts)

Expand Down Expand Up @@ -490,8 +498,9 @@ def process_custom_transpose(self, prim, call, tracers, **params):
for aval in params['out_types']]
in_tracers = map(self.instantiate_const, tracers)
new_params = dict(params, call=call)
eqn_ctx = core.JaxprEqnContext(compute_on.current_compute_type())
eqn = new_eqn_recipe(in_tracers, out_tracers, prim, new_params,
core.no_effects, source_info_util.current())
core.no_effects, source_info_util.current(), eqn_ctx)
for t in out_tracers: t.recipe = eqn
return out_tracers

Expand Down Expand Up @@ -528,14 +537,15 @@ def fwd_jaxpr_thunk(*zeros):

name_stack = self._current_truncated_name_stack()
source = source_info_util.current().replace(name_stack=name_stack)
eqn_ctx = core.JaxprEqnContext(compute_on.current_compute_type())
eqn = new_eqn_recipe((*res_tracers, *env_tracers, *tracers),
out_tracers, prim.initial_style,
dict(fun_jaxpr=closed_jaxpr,
fwd_jaxpr_thunk=fwd_jaxpr_thunk,
num_consts=len(res) + len(env),
bwd=bwd, out_trees=out_trees,
symbolic_zeros=symbolic_zeros),
jaxpr.effects, source)
jaxpr.effects, source, eqn_ctx)
for t in out_tracers: t.recipe = eqn
return merge_lists(out_knowns, out_tracers, out_consts)

Expand Down Expand Up @@ -876,14 +886,15 @@ class JaxprEqnRecipe(NamedTuple):
params: dict[str, Any]
effects: core.Effects
source_info: source_info_util.SourceInfo
ctx: JaxprEqnContext

def new_eqn_recipe(in_tracers: Sequence[JaxprTracer],
out_tracers: Sequence[JaxprTracer],
primitive: Primitive,
params: dict[str, Any],
effects: core.Effects,
source_info: source_info_util.SourceInfo
) -> JaxprEqnRecipe:
source_info: source_info_util.SourceInfo,
ctx: JaxprEqnContext | None = None) -> JaxprEqnRecipe:
# TODO(necula): move these checks to core.check_jaxpr, and call in more places
if primitive.call_primitive or primitive.map_primitive:
assert "call_jaxpr" in params
Expand All @@ -895,18 +906,21 @@ def new_eqn_recipe(in_tracers: Sequence[JaxprTracer],
assert ("donated_invars" in params and
len(params["donated_invars"]) == len(params["call_jaxpr"].invars))
out_avals = [core.raise_to_shaped(t.aval) for t in out_tracers]
ctx = ctx or JaxprEqnContext(None)
return JaxprEqnRecipe(object(), tuple(in_tracers), map(ref, out_tracers),
out_avals, primitive, params, effects, source_info)
out_avals, primitive, params, effects, source_info,
ctx)


def recipe_to_eqn(getvar: Callable[[JaxprTracer], Atom],
recipe: JaxprEqnRecipe) -> core.JaxprEqn:
(_, in_tracers, out_tracer_refs, out_avals, prim, params, eff, src) = recipe
(_, in_tracers, out_tracer_refs, out_avals, prim, params, eff, src,
ctx) = recipe
invars = [getvar(t) for t in in_tracers]
out_tracers = [t_ref() for t_ref in out_tracer_refs]
outvars = [DropVar(a) if t is None else getvar(t) # type: ignore
for a, t in zip(out_avals, out_tracers)]
return new_jaxpr_eqn(invars, outvars, prim, params, eff, src)
return new_jaxpr_eqn(invars, outvars, prim, params, eff, src, ctx)

def tracers_to_jaxpr(
in_tracers: Sequence[JaxprTracer],
Expand Down Expand Up @@ -959,7 +973,7 @@ def type_substitute(aval: AbstractValue) -> AbstractValue:
outvars = [DropVar(type_substitute(a)) if rf() is None else newvar(rf())
for a, rf in zip(r.out_avals, r.out_tracer_refs)]
eqns.append(new_jaxpr_eqn(in_atoms, outvars, r.primitive, r.params,
r.effects, r.source_info))
r.effects, r.source_info, r.ctx))
processed_eqn_ids.add(r.eqn_id)
elif isinstance(r, LambdaBinding):
if not any(t is in_tracer for in_tracer in in_tracers):
Expand Down
4 changes: 3 additions & 1 deletion jax/_src/pjit.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from jax._src import tree_util
from jax._src import util
from jax._src import xla_bridge as xb
from jax._src import compute_on
from jax._src.api_util import (
argnums_partial_except, flatten_axes, flatten_fun, flatten_fun_nokwargs,
donation_vector, shaped_abstractify, check_callable, resolve_argnums,
Expand Down Expand Up @@ -1962,12 +1963,13 @@ def keep_where(l, should_keep):
pe.JaxprTracer(trace, pe.PartialVal.unknown(aval), None)
for aval in unknown_out_avals
]
eqn_ctx = core.JaxprEqnContext(compute_on.current_compute_type())
eqn = pe.new_eqn_recipe((*unknown_tracers_in, *residual_tracers),
unknown_tracers_out,
pjit_p,
unknown_params,
unknown_jaxpr.effects,
source_info_util.current())
source_info_util.current(), eqn_ctx)
for t in unknown_tracers_out: t.recipe = eqn
return merge_lists(unknown_outs, known_out_vals, unknown_tracers_out)

Expand Down
27 changes: 15 additions & 12 deletions tests/memories_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1317,20 +1317,23 @@ def f2(x):
" yet."):
f2(jnp.arange(8))

# def test_compute_on_grad(self):
# @compute_on('device_host')
# @jax.jit
# def g(x):
# return x * 2
def test_compute_on_grad(self):
@compute_on('device_host')
@jax.jit
def g(x):
return jnp.sin(x)

# def f(x):
# y = g(x)
# return jnp.sum(y * 3)
def f(x):
y = g(x)
return jnp.sum(y)

# inp = jnp.arange(8)
# jf = jax.jit(jax.grad(f))
# out = jf(inp)
# print(jax.jit(jax.grad(f)).lower(inp).as_text())
inp = jnp.arange(8.)
jf = jax.jit(jax.grad(f))

jtu.check_grads(jf, (inp,), order=2)

lowered_text = jf.lower(inp).as_text()
self.assertEqual(lowered_text.count('_xla_compute_type = "host"'), 2)

# def test_sharded_compute_on_host(self):
# mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
Expand Down

0 comments on commit 02c19e9

Please sign in to comment.