From d6c172d53e06483d24dba767edf667a93d89757e Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Thu, 21 Jul 2022 11:22:54 -0700 Subject: [PATCH] Fix PE not allowing double JIT-ted effectful functions --- jax/_src/ad_checkpoint.py | 5 +++-- jax/interpreters/mlir.py | 3 ++- jax/interpreters/partial_eval.py | 6 +++--- tests/debugging_primitives_test.py | 11 +++++++++++ tests/jaxpr_effects_test.py | 24 ++++++++++++++---------- 5 files changed, 33 insertions(+), 16 deletions(-) diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py index f4e8dc2df0fb..0fbdf8d7236c 100644 --- a/jax/_src/ad_checkpoint.py +++ b/jax/_src/ad_checkpoint.py @@ -280,8 +280,6 @@ def remat_impl(*args, jaxpr, prevent_cse, differentiated, policy): @remat_p.def_effectful_abstract_eval def remat_abstract_eval(*args, jaxpr, prevent_cse, differentiated, policy): del args, prevent_cse, differentiated, policy # Unused. - if jaxpr.effects: - raise NotImplementedError('Effects not supported in `remat`.') return [v.aval for v in jaxpr.outvars], jaxpr.effects def remat_jvp(primals, tangents, jaxpr, prevent_cse, differentiated, policy): @@ -303,6 +301,9 @@ def remat_jvp(primals, tangents, jaxpr, prevent_cse, differentiated, policy): def remat_partial_eval(trace, *tracers, jaxpr, **params): assert not jaxpr.constvars + if jaxpr.effects: + raise NotImplementedError( + 'Effects not supported in partial-eval of `checkpoint`/`remat`.') policy = params['policy'] or nothing_saveable in_unknowns = [not t.is_known() for t in tracers] jaxpr_known, jaxpr_staged, out_unknowns, out_inst, num_res = \ diff --git a/jax/interpreters/mlir.py b/jax/interpreters/mlir.py index c38e8c01dddf..2e451380cf4b 100644 --- a/jax/interpreters/mlir.py +++ b/jax/interpreters/mlir.py @@ -1071,9 +1071,10 @@ def _call_lowering(fn_name, stack_name, call_jaxpr, backend, ctx, avals_in, if isinstance(call_jaxpr, core.Jaxpr): call_jaxpr = core.ClosedJaxpr(call_jaxpr, ()) xla.check_backend_matches(backend, ctx.platform) + effects = tokens_in.effects() output_types = map(aval_to_ir_types, avals_out) + output_types = [token_type()] * len(effects) + output_types flat_output_types = util.flatten(output_types) - effects = tokens_in.effects() symbol_name = lower_jaxpr_to_fun(ctx, fn_name, call_jaxpr, effects).name.value args = [*tokens_in.tokens(), *args] call = func_dialect.CallOp(flat_output_types, diff --git a/jax/interpreters/partial_eval.py b/jax/interpreters/partial_eval.py index 942593570f65..bc2a387d2a91 100644 --- a/jax/interpreters/partial_eval.py +++ b/jax/interpreters/partial_eval.py @@ -1149,9 +1149,11 @@ def _remat_partial_eval(trace, _, f, tracers, params): f, aux = partial_eval_wrapper_nounits(f, tuple(in_knowns), tuple(in_avals)) consts = remat_call_p.bind(f, **params) # no known inputs _, out_avals, jaxpr, env = aux() + if jaxpr.effects: + raise NotImplementedError( + 'Effects not supported in partial-eval of `checkpoint`/`remat`.') env_tracers = map(trace.full_raise, env) jaxpr = convert_constvars_jaxpr(jaxpr) - if jaxpr.effects: raise NotImplementedError del in_pvals, in_knowns, in_avals, out_avals, f, aux, env # When concrete=True, we could avoid some redundant computation by extracting # values from any ConcreteArrays in `out_avals`, but we eschew that @@ -1823,8 +1825,6 @@ def process_call(self, call_primitive, f, explicit_tracers, params): else: jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2_memoized( f, self.main).val - if jaxpr.effects: - raise NotImplementedError('Effects not supported for call primitives.') if params.get('inline', False): return core.eval_jaxpr(jaxpr, consts, *in_tracers) source_info = source_info_util.current() diff --git a/tests/debugging_primitives_test.py b/tests/debugging_primitives_test.py index 7d75faacb3a7..84e2b9250abf 100644 --- a/tests/debugging_primitives_test.py +++ b/tests/debugging_primitives_test.py @@ -114,6 +114,17 @@ def f(x): jax.effects_barrier() self.assertEqual(output(), "x: 2\n") + @jtu.skip_on_devices(*disabled_backends) + def test_can_double_stage_out_ordered_print(self): + @jax.jit + @jax.jit + def f(x): + debug_print('x: {x}', x=x, ordered=True) + with capture_stdout() as output: + f(2) + jax.effects_barrier() + self.assertEqual(output(), "x: 2\n") + @jtu.skip_on_devices(*disabled_backends) def test_can_stage_out_ordered_print_with_pytree(self): @jax.jit diff --git a/tests/jaxpr_effects_test.py b/tests/jaxpr_effects_test.py index 8a9f70bf8d29..424cb16ba3da 100644 --- a/tests/jaxpr_effects_test.py +++ b/tests/jaxpr_effects_test.py @@ -25,6 +25,7 @@ from jax import lax from jax import linear_util as lu from jax.config import config +from jax.interpreters import ad from jax.experimental import maps from jax.experimental import pjit from jax.interpreters import mlir @@ -41,8 +42,12 @@ effect_p.multiple_results = True @effect_p.def_effectful_abstract_eval -def _(*, effect): - return [], {effect} +def _(*avals, effect): + return avals, {effect} + +def effect_jvp_rule(primals, tangents, effect): + return effect_p.bind(*primals, effect=effect), tangents +ad.primitive_jvps[effect_p] = effect_jvp_rule mlir.lowerable_effects.add('foo') mlir.lowerable_effects.add('foo2') @@ -189,8 +194,7 @@ def f_(x): effect_p.bind(effect='bar') return [x] return core.call(f_, x)[0] - with self.assertRaisesRegex(NotImplementedError, 'Effects not supported'): - jax.make_jaxpr(f)(2.) + jax.make_jaxpr(f)(2.) def test_xla_call_primitive_inherits_effects(self): @@ -199,8 +203,7 @@ def f(x): effect_p.bind(effect='foo') effect_p.bind(effect='bar') return x - with self.assertRaisesRegex(NotImplementedError, 'Effects not supported'): - jax.make_jaxpr(f)(2.) + jax.make_jaxpr(f)(2.) @parameterized.named_parameters(jtu.cases_from_list( dict(testcase_name=f"_{flavor}", flavor=flavor) @@ -210,11 +213,12 @@ def test_remat_call_primitive_inherits_effects(self, flavor): @remat def f(x): - effect_p.bind(effect='foo') - effect_p.bind(effect='bar') + x, = effect_p.bind(x, effect='foo') + x, = effect_p.bind(x, effect='bar') return x - with self.assertRaisesRegex(NotImplementedError, 'Effects not supported'): - jax.make_jaxpr(f)(2.) + jax.make_jaxpr(f)(2.) + with self.assertRaisesRegex(NotImplementedError, "Effects not supported"): + jax.make_jaxpr(lambda x: jax.linearize(f, x)[1](x))(2.) def test_custom_jvp_primitive_inherits_effects(self):