Skip to content

Commit

Permalink
Fix PE not allowing double JIT-ted effectful functions
Browse files Browse the repository at this point in the history
  • Loading branch information
sharadmv committed Jul 21, 2022
1 parent a4e7548 commit d6c172d
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 16 deletions.
5 changes: 3 additions & 2 deletions jax/_src/ad_checkpoint.py
Expand Up @@ -280,8 +280,6 @@ def remat_impl(*args, jaxpr, prevent_cse, differentiated, policy):
@remat_p.def_effectful_abstract_eval
def remat_abstract_eval(*args, jaxpr, prevent_cse, differentiated, policy):
del args, prevent_cse, differentiated, policy # Unused.
if jaxpr.effects:
raise NotImplementedError('Effects not supported in `remat`.')
return [v.aval for v in jaxpr.outvars], jaxpr.effects

def remat_jvp(primals, tangents, jaxpr, prevent_cse, differentiated, policy):
Expand All @@ -303,6 +301,9 @@ def remat_jvp(primals, tangents, jaxpr, prevent_cse, differentiated, policy):

def remat_partial_eval(trace, *tracers, jaxpr, **params):
assert not jaxpr.constvars
if jaxpr.effects:
raise NotImplementedError(
'Effects not supported in partial-eval of `checkpoint`/`remat`.')
policy = params['policy'] or nothing_saveable
in_unknowns = [not t.is_known() for t in tracers]
jaxpr_known, jaxpr_staged, out_unknowns, out_inst, num_res = \
Expand Down
3 changes: 2 additions & 1 deletion jax/interpreters/mlir.py
Expand Up @@ -1071,9 +1071,10 @@ def _call_lowering(fn_name, stack_name, call_jaxpr, backend, ctx, avals_in,
if isinstance(call_jaxpr, core.Jaxpr):
call_jaxpr = core.ClosedJaxpr(call_jaxpr, ())
xla.check_backend_matches(backend, ctx.platform)
effects = tokens_in.effects()
output_types = map(aval_to_ir_types, avals_out)
output_types = [token_type()] * len(effects) + output_types
flat_output_types = util.flatten(output_types)
effects = tokens_in.effects()
symbol_name = lower_jaxpr_to_fun(ctx, fn_name, call_jaxpr, effects).name.value
args = [*tokens_in.tokens(), *args]
call = func_dialect.CallOp(flat_output_types,
Expand Down
6 changes: 3 additions & 3 deletions jax/interpreters/partial_eval.py
Expand Up @@ -1149,9 +1149,11 @@ def _remat_partial_eval(trace, _, f, tracers, params):
f, aux = partial_eval_wrapper_nounits(f, tuple(in_knowns), tuple(in_avals))
consts = remat_call_p.bind(f, **params) # no known inputs
_, out_avals, jaxpr, env = aux()
if jaxpr.effects:
raise NotImplementedError(
'Effects not supported in partial-eval of `checkpoint`/`remat`.')
env_tracers = map(trace.full_raise, env)
jaxpr = convert_constvars_jaxpr(jaxpr)
if jaxpr.effects: raise NotImplementedError
del in_pvals, in_knowns, in_avals, out_avals, f, aux, env
# When concrete=True, we could avoid some redundant computation by extracting
# values from any ConcreteArrays in `out_avals`, but we eschew that
Expand Down Expand Up @@ -1823,8 +1825,6 @@ def process_call(self, call_primitive, f, explicit_tracers, params):
else:
jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2_memoized(
f, self.main).val
if jaxpr.effects:
raise NotImplementedError('Effects not supported for call primitives.')
if params.get('inline', False):
return core.eval_jaxpr(jaxpr, consts, *in_tracers)
source_info = source_info_util.current()
Expand Down
11 changes: 11 additions & 0 deletions tests/debugging_primitives_test.py
Expand Up @@ -114,6 +114,17 @@ def f(x):
jax.effects_barrier()
self.assertEqual(output(), "x: 2\n")

@jtu.skip_on_devices(*disabled_backends)
def test_can_double_stage_out_ordered_print(self):
@jax.jit
@jax.jit
def f(x):
debug_print('x: {x}', x=x, ordered=True)
with capture_stdout() as output:
f(2)
jax.effects_barrier()
self.assertEqual(output(), "x: 2\n")

@jtu.skip_on_devices(*disabled_backends)
def test_can_stage_out_ordered_print_with_pytree(self):
@jax.jit
Expand Down
24 changes: 14 additions & 10 deletions tests/jaxpr_effects_test.py
Expand Up @@ -25,6 +25,7 @@
from jax import lax
from jax import linear_util as lu
from jax.config import config
from jax.interpreters import ad
from jax.experimental import maps
from jax.experimental import pjit
from jax.interpreters import mlir
Expand All @@ -41,8 +42,12 @@
effect_p.multiple_results = True

@effect_p.def_effectful_abstract_eval
def _(*, effect):
return [], {effect}
def _(*avals, effect):
return avals, {effect}

def effect_jvp_rule(primals, tangents, effect):
return effect_p.bind(*primals, effect=effect), tangents
ad.primitive_jvps[effect_p] = effect_jvp_rule

mlir.lowerable_effects.add('foo')
mlir.lowerable_effects.add('foo2')
Expand Down Expand Up @@ -189,8 +194,7 @@ def f_(x):
effect_p.bind(effect='bar')
return [x]
return core.call(f_, x)[0]
with self.assertRaisesRegex(NotImplementedError, 'Effects not supported'):
jax.make_jaxpr(f)(2.)
jax.make_jaxpr(f)(2.)

def test_xla_call_primitive_inherits_effects(self):

Expand All @@ -199,8 +203,7 @@ def f(x):
effect_p.bind(effect='foo')
effect_p.bind(effect='bar')
return x
with self.assertRaisesRegex(NotImplementedError, 'Effects not supported'):
jax.make_jaxpr(f)(2.)
jax.make_jaxpr(f)(2.)

@parameterized.named_parameters(jtu.cases_from_list(
dict(testcase_name=f"_{flavor}", flavor=flavor)
Expand All @@ -210,11 +213,12 @@ def test_remat_call_primitive_inherits_effects(self, flavor):

@remat
def f(x):
effect_p.bind(effect='foo')
effect_p.bind(effect='bar')
x, = effect_p.bind(x, effect='foo')
x, = effect_p.bind(x, effect='bar')
return x
with self.assertRaisesRegex(NotImplementedError, 'Effects not supported'):
jax.make_jaxpr(f)(2.)
jax.make_jaxpr(f)(2.)
with self.assertRaisesRegex(NotImplementedError, "Effects not supported"):
jax.make_jaxpr(lambda x: jax.linearize(f, x)[1](x))(2.)

def test_custom_jvp_primitive_inherits_effects(self):

Expand Down

0 comments on commit d6c172d

Please sign in to comment.