diff --git a/jax/experimental/export/serialization.py b/jax/experimental/export/serialization.py index f8dfe399df5d..694648b23f98 100644 --- a/jax/experimental/export/serialization.py +++ b/jax/experimental/export/serialization.py @@ -386,17 +386,24 @@ def _deserialize_sharding(s: ser_flatbuf.Sharding) -> export.Sharding: def _serialize_effect(builder: flatbuffers.Builder, eff: core.Effect) -> int: - # TODO(necula): for now serialize just the name of the class try: - _ = eff.__class__() - except: + eff_replica = eff.__class__() + except Exception: raise NotImplementedError( - f"serializing effect {eff} that does not have a nullary class" - " constructor" + f"Effect {eff} must have a nullary constructor to be serializable" + ) + try: + hash_eff = hash(eff) + hash_eff_replica = hash(eff_replica) + except Exception: + raise NotImplementedError( + f"Effect {eff} must be hashable to be serializable" + ) + if eff != eff_replica or hash_eff != hash_eff_replica: + raise NotImplementedError( + f"Effect {eff} must have a nullary class constructor that produces an " + "equal effect object." ) - # TODO: fix the effects serialization and deserialization, to ensure that - # upon deserialization we reconstruct an effect that compares equal to the - # one that was serialized. effect_type_name = str(eff.__class__) effect_type_name_offset = builder.CreateString(effect_type_name) ser_flatbuf.EffectStart(builder) diff --git a/jax/experimental/jax2tf/call_tf.py b/jax/experimental/jax2tf/call_tf.py index f1f58b3a7f17..6810f7bb5d31 100644 --- a/jax/experimental/jax2tf/call_tf.py +++ b/jax/experimental/jax2tf/call_tf.py @@ -26,6 +26,7 @@ from __future__ import annotations from collections.abc import Sequence +import dataclasses import functools from typing import Any, Callable, Optional @@ -35,20 +36,16 @@ from jax import dtypes from jax import numpy as jnp from jax import tree_util -from jax._src import ad_checkpoint from jax._src import ad_util from jax._src import core -from jax._src import custom_derivatives from jax._src import effects from jax._src import util -from jax._src.lax import control_flow as lax_control_flow from jax._src.lib import xla_client from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import func as func_dialect from jax._src.lib.mlir.dialects import hlo from jax.experimental.jax2tf import jax2tf as jax2tf_internal from jax.interpreters import mlir -from jax.interpreters import xla import numpy as np import tensorflow as tf @@ -376,6 +373,7 @@ def _get_concrete_function_tf(function_flat_tf, args_flat_sig_tf): # -> tf.Conc # Mark the effectful instances of call_tf +@dataclasses.dataclass(frozen=True) class CallTfEffect(effects.Effect): __str__ = lambda _: "CallTfEffect" diff --git a/tests/export_test.py b/tests/export_test.py index 5c6af12193e8..f78c5d4eb0a1 100644 --- a/tests/export_test.py +++ b/tests/export_test.py @@ -14,6 +14,7 @@ from __future__ import annotations import contextlib +import dataclasses import functools import logging import math @@ -57,19 +58,35 @@ def tearDownModule(): prev_xla_flags() ### Setup for testing lowering with effects -class TestingOrderedEffect1(effects.Effect): - __str__ = lambda _: "TestingOrderedEffect1" +@dataclasses.dataclass(frozen=True) +class ForTestingOrderedEffect1(effects.Effect): + pass -class TestingOrderedEffect2(effects.Effect): - __str__ = lambda _: "TestingOrderedEffect2" +@dataclasses.dataclass(frozen=True) +class ForTestingOrderedEffect2(effects.Effect): + pass + +@dataclasses.dataclass(frozen=True) +class ForTestingUnorderedEffect1(effects.Effect): + pass + + +class ForTestingOrderedEffect4NoNullary(effects.Effect): + def __init__(self, _): + pass + +@dataclasses.dataclass(eq=False) +class ForTestingOrderedEffect5NoEq(effects.Effect): + pass -class TestingUnorderedEffect1(effects.Effect): - __str__ = lambda _: "TestingUnorderedEffect1" _testing_effects = dict( - TestingOrderedEffect1=TestingOrderedEffect1(), - TestingOrderedEffect2=TestingOrderedEffect2(), - TestingUnorderedEffect1=TestingUnorderedEffect1()) + ForTestingOrderedEffect1=ForTestingOrderedEffect1(), + ForTestingOrderedEffect2=ForTestingOrderedEffect2(), + ForTestingUnorderedEffect1=ForTestingUnorderedEffect1(), + ForTestingOrderedEffect4NoNullary=ForTestingOrderedEffect4NoNullary(42), + ForTestingOrderedEffect5NoEq=ForTestingOrderedEffect5NoEq(), +) # Register the effects for effect in _testing_effects.values(): effect_class = effect.__class__ @@ -1015,23 +1032,20 @@ def f_jax(x): # x: f32[3] # Test also the calling convention for inner functions def f_jax_inner(x): return ( - testing_primitive_with_effect_p.bind(x, effect_class_name="TestingOrderedEffect2") + - testing_primitive_with_effect_p.bind(x, effect_class_name="TestingUnorderedEffect1")) + testing_primitive_with_effect_p.bind(x, effect_class_name="ForTestingOrderedEffect2") + + testing_primitive_with_effect_p.bind(x, effect_class_name="ForTestingUnorderedEffect1")) return ( 10. + jax.jit(f_jax_inner)(x) + - testing_primitive_with_effect_p.bind(x, effect_class_name="TestingOrderedEffect1") + - testing_primitive_with_effect_p.bind(x, effect_class_name="TestingOrderedEffect2") + testing_primitive_with_effect_p.bind(x, effect_class_name="ForTestingOrderedEffect1") + + testing_primitive_with_effect_p.bind(x, effect_class_name="ForTestingOrderedEffect2") ) - # TODO(necula): at the moment serializing and deserializing effects breaks - # the effect equality, and this results in this test failing. So, for now - # we disable the serization round-trip - exp = export.export(f_jax)(x) # get_exported(f_jax)(x) + exp = get_exported(f_jax)(x) if exp.mlir_module_serialization_version >= export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS: - self.assertEqual(["TestingOrderedEffect1", "TestingOrderedEffect2"], + self.assertEqual(["ForTestingOrderedEffect1()", "ForTestingOrderedEffect2()"], sorted(str(e) for e in exp.ordered_effects)) - self.assertEqual(["TestingUnorderedEffect1"], + self.assertEqual(["ForTestingUnorderedEffect1()"], [str(e) for e in exp.unordered_effects]) else: self.assertEqual([], [str(e) for e in exp.ordered_effects]) @@ -1074,19 +1088,19 @@ def f_jax_inner(x): def f_outer(x): return ( testing_primitive_with_effect_p.bind( - x, effect_class_name="TestingOrderedEffect2") + + x, effect_class_name="ForTestingOrderedEffect2") + testing_primitive_with_effect_p.bind( - x, effect_class_name="TestingUnorderedEffect1") + + x, effect_class_name="ForTestingUnorderedEffect1") + export.call_exported(exp)(x)) lowered_outer = jax.jit(f_outer).lower(x) if exp.mlir_module_serialization_version < export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS: - self.assertEqual(["TestingOrderedEffect2"], + self.assertEqual(["ForTestingOrderedEffect2()"], [str(e) for e in lowered_outer._lowering.compile_args["ordered_effects"]]) else: - self.assertEqual(["TestingOrderedEffect1", "TestingOrderedEffect2"], + self.assertEqual(["ForTestingOrderedEffect1()", "ForTestingOrderedEffect2()"], sorted(str(e) for e in lowered_outer._lowering.compile_args["ordered_effects"])) - self.assertEqual(["TestingUnorderedEffect1"], + self.assertEqual(["ForTestingUnorderedEffect1()"], sorted([str(e) for e in lowered_outer._lowering.compile_args["unordered_effects"]])) mlir_outer_module_str = str(lowered_outer.compiler_ir()) @@ -1106,7 +1120,7 @@ def test_ordered_effects_poly(self, *, v: int): self.override_serialization_version(v) x = np.arange(12, dtype=np.float32).reshape((3, 4)) def f_jax(x): # x: f32[b1, b2] - return 10. + testing_primitive_with_effect_p.bind(x, effect_class_name="TestingOrderedEffect1") + return 10. + testing_primitive_with_effect_p.bind(x, effect_class_name="ForTestingOrderedEffect1") exp = get_exported(f_jax)(jax.ShapeDtypeStruct( export.symbolic_shape("b2, b1"), x.dtype)) mlir_module_str = str(exp.mlir_module()) @@ -1150,7 +1164,7 @@ def test_ordered_effects_multi_platform_and_poly(self, *, v: int): x = np.ones((3, 4), dtype=np.float32) def f_jax(x): # x: f32[b1, b2] return 10. + _testing_multi_platform_func(x, - effect_class_name="TestingOrderedEffect1") + effect_class_name="ForTestingOrderedEffect1") exp = get_exported( f_jax, lowering_platforms=("cpu", "tpu") @@ -1196,7 +1210,7 @@ def test_ordered_effects_with_donation(self, *, v: int): def f_jax(x): return testing_primitive_with_effect_p.bind( - x, effect_class_name="TestingOrderedEffect1" + x, effect_class_name="ForTestingOrderedEffect1" ) f_jax = jax.jit(f_jax, donate_argnums=(0,)) @@ -1209,6 +1223,24 @@ def f_jax(x): self.assertRegex(mlir_module_str, r"@main.*tf.aliasing_output = 1") self.assertRegex(mlir_module_str, r"@_wrapped_jax_export_main.*tf.aliasing_output = 1") + @jtu.parameterized_filterable( + kwargs=[ + dict(name=name, expect_error=expect_error) + # name is the suffix for event name: ForTestingOrderedEffectxxx + for name, expect_error in ( + ("4NoNullary", "must have a nullary constructor"), + ("5NoEq", "must have a nullary class constructor that produces an " + "equal effect object"), + ) + ]) + def test_ordered_effects_error(self, *, name: str, expect_error: str): + x = np.ones((3, 4), dtype=np.float32) + def f_jax(x): + return 10. + _testing_multi_platform_func( + x, + effect_class_name="ForTestingOrderedEffect" + name) + with self.assertRaisesRegex(Exception, expect_error): + _ = get_exported(f_jax)(jax.ShapeDtypeStruct((3, 4), x.dtype)) if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader())