Skip to content

Commit

Permalink
Merge pull request #18786 from gnecula:test_export_effects
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 587333913
  • Loading branch information
jax authors committed Dec 2, 2023
2 parents b51b80e + bd7c1aa commit 61e79cd
Showing 1 changed file with 35 additions and 23 deletions.
58 changes: 35 additions & 23 deletions tests/export_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -946,22 +946,22 @@ def f_jax_inner(x):

exp = export.export(f_jax)(x)
if exp.serialization_version >= export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
self.assertSetEqual({"TestingOrderedEffect1", "TestingOrderedEffect2"},
{str(e) for e in exp.ordered_effects})
self.assertEqual({"TestingUnorderedEffect1"},
{str(e) for e in exp.unordered_effects})
self.assertEqual(["TestingOrderedEffect1", "TestingOrderedEffect2"],
sorted(str(e) for e in exp.ordered_effects))
self.assertEqual(["TestingUnorderedEffect1"],
[str(e) for e in exp.unordered_effects])
else:
self.assertSetEqual(set(), {str(e) for e in exp.ordered_effects})
self.assertSetEqual(set(), {str(e) for e in exp.unordered_effects})
self.assertEqual([], [str(e) for e in exp.ordered_effects])
self.assertEqual([], [str(e) for e in exp.unordered_effects])
mlir_module_str = str(exp.mlir_module())

# Inner functions use stablehlo.token for all versions
inner_fun_expected_re = (
r"func.func private @f_jax_inner\("
r"%arg0: !stablehlo.token {jax.token = true}.*"
r"%arg0: !stablehlo.token .*jax.token = true.*"
r"%arg1: tensor<3xf32>.*->.*"
# Results
r"!stablehlo.token {jax.token = true}.*"
r"!stablehlo.token .*jax.token = true.*"
r"tensor<3xf32>"
)
self.assertRegex(mlir_module_str, inner_fun_expected_re)
Expand All @@ -970,11 +970,11 @@ def f_jax_inner(x):
# i1[0] before version 9.
wrapped_main_expected_re = (
r"@_wrapped_jax_export_main\("
r"%arg0: !stablehlo.token {jax.token = true}.*"
r"%arg1: !stablehlo.token {jax.token = true}.*->.*"
r"%arg0: !stablehlo.token .*jax.token = true.*"
r"%arg1: !stablehlo.token .*jax.token = true.*->.*"
# Results
r"!stablehlo.token {jax.token = true}.*"
r"!stablehlo.token {jax.token = true}.*")
r"!stablehlo.token .*jax.token = true.*"
r"!stablehlo.token .*jax.token = true.*")
if exp.serialization_version < export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
wrapped_main_expected_re = wrapped_main_expected_re.replace("!stablehlo.token", "tensor<0xi1>")
self.assertRegex(mlir_module_str, wrapped_main_expected_re)
Expand All @@ -987,20 +987,32 @@ def f_jax_inner(x):
main_expected_re = wrapped_main_expected_re.replace("@_wrapped_jax_export_main", "@main")
self.assertRegex(mlir_module_str, main_expected_re)

lowered = jax.jit(export.call_exported(exp)).lower(x)
# Now call the exported from a function that uses its own effects
def f_outer(x):
return (
testing_primitive_with_effect_p.bind(
x, effect_class_name="TestingOrderedEffect2") +
testing_primitive_with_effect_p.bind(
x, effect_class_name="TestingUnorderedEffect1") +
export.call_exported(exp)(x))

lowered_outer = jax.jit(f_outer).lower(x)
if exp.serialization_version < export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
self.assertSetEqual(set(),
{str(e) for e in lowered._lowering.compile_args["ordered_effects"]})
self.assertSetEqual(set(),
{str(e) for e in lowered._lowering.compile_args["unordered_effects"]})
self.assertEqual(["TestingOrderedEffect2"],
[str(e) for e in lowered_outer._lowering.compile_args["ordered_effects"]])
else:
self.assertSetEqual({"TestingOrderedEffect1", "TestingOrderedEffect2"},
{str(e) for e in lowered._lowering.compile_args["ordered_effects"]})
self.assertSetEqual({"TestingUnorderedEffect1"},
{str(e) for e in lowered._lowering.compile_args["unordered_effects"]})
self.assertEqual(["TestingOrderedEffect1", "TestingOrderedEffect2"],
sorted(str(e) for e in lowered_outer._lowering.compile_args["ordered_effects"]))
self.assertEqual(["TestingUnorderedEffect1"],
sorted([str(e) for e in lowered_outer._lowering.compile_args["unordered_effects"]]))

res = export.call_exported(exp)(x)
self.assertAllClose(10. + 4. * 2. * x, res)
mlir_outer_module_str = str(lowered_outer.compiler_ir())
if exp.serialization_version >= export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
main_expected_re = main_expected_re.replace("!stablehlo.token", "tensor<0xi1>")
self.assertRegex(mlir_outer_module_str, main_expected_re)

res = jax.jit(f_outer)(x)
self.assertAllClose(2. * 2. * x + 10. + 4. * 2. * x, res)

@jtu.parameterized_filterable(
kwargs=[
Expand Down

0 comments on commit 61e79cd

Please sign in to comment.