From 552010a38147222c13677c2a8d34794052d05bcd Mon Sep 17 00:00:00 2001 From: George Necula Date: Wed, 13 Dec 2023 15:43:12 +0100 Subject: [PATCH] [export] Fix the serialization of effects We currently support only the serialization of effects with nullary constructors. We must also ensure that upon deserialization we produce an event that tests equal to the original one. Here we add explicit error checks and tests. We also make the CallTfEffect to have this property. --- jax/experimental/export/serialization.py | 23 ++++--- jax/experimental/jax2tf/call_tf.py | 6 +- tests/export_test.py | 86 ++++++++++++++++-------- 3 files changed, 76 insertions(+), 39 deletions(-) diff --git a/jax/experimental/export/serialization.py b/jax/experimental/export/serialization.py index f8dfe399df5d..694648b23f98 100644 --- a/jax/experimental/export/serialization.py +++ b/jax/experimental/export/serialization.py @@ -386,17 +386,24 @@ def _deserialize_sharding(s: ser_flatbuf.Sharding) -> export.Sharding: def _serialize_effect(builder: flatbuffers.Builder, eff: core.Effect) -> int: - # TODO(necula): for now serialize just the name of the class try: - _ = eff.__class__() - except: + eff_replica = eff.__class__() + except Exception: raise NotImplementedError( - f"serializing effect {eff} that does not have a nullary class" - " constructor" + f"Effect {eff} must have a nullary constructor to be serializable" + ) + try: + hash_eff = hash(eff) + hash_eff_replica = hash(eff_replica) + except Exception: + raise NotImplementedError( + f"Effect {eff} must be hashable to be serializable" + ) + if eff != eff_replica or hash_eff != hash_eff_replica: + raise NotImplementedError( + f"Effect {eff} must have a nullary class constructor that produces an " + "equal effect object." ) - # TODO: fix the effects serialization and deserialization, to ensure that - # upon deserialization we reconstruct an effect that compares equal to the - # one that was serialized. effect_type_name = str(eff.__class__) effect_type_name_offset = builder.CreateString(effect_type_name) ser_flatbuf.EffectStart(builder) diff --git a/jax/experimental/jax2tf/call_tf.py b/jax/experimental/jax2tf/call_tf.py index f1f58b3a7f17..6810f7bb5d31 100644 --- a/jax/experimental/jax2tf/call_tf.py +++ b/jax/experimental/jax2tf/call_tf.py @@ -26,6 +26,7 @@ from __future__ import annotations from collections.abc import Sequence +import dataclasses import functools from typing import Any, Callable, Optional @@ -35,20 +36,16 @@ from jax import dtypes from jax import numpy as jnp from jax import tree_util -from jax._src import ad_checkpoint from jax._src import ad_util from jax._src import core -from jax._src import custom_derivatives from jax._src import effects from jax._src import util -from jax._src.lax import control_flow as lax_control_flow from jax._src.lib import xla_client from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import func as func_dialect from jax._src.lib.mlir.dialects import hlo from jax.experimental.jax2tf import jax2tf as jax2tf_internal from jax.interpreters import mlir -from jax.interpreters import xla import numpy as np import tensorflow as tf @@ -376,6 +373,7 @@ def _get_concrete_function_tf(function_flat_tf, args_flat_sig_tf): # -> tf.Conc # Mark the effectful instances of call_tf +@dataclasses.dataclass(frozen=True) class CallTfEffect(effects.Effect): __str__ = lambda _: "CallTfEffect" diff --git a/tests/export_test.py b/tests/export_test.py index 5c6af12193e8..f78c5d4eb0a1 100644 --- a/tests/export_test.py +++ b/tests/export_test.py @@ -14,6 +14,7 @@ from __future__ import annotations import contextlib +import dataclasses import functools import logging import math @@ -57,19 +58,35 @@ def tearDownModule(): prev_xla_flags() ### Setup for testing lowering with effects -class TestingOrderedEffect1(effects.Effect): - __str__ = lambda _: "TestingOrderedEffect1" +@dataclasses.dataclass(frozen=True) +class ForTestingOrderedEffect1(effects.Effect): + pass -class TestingOrderedEffect2(effects.Effect): - __str__ = lambda _: "TestingOrderedEffect2" +@dataclasses.dataclass(frozen=True) +class ForTestingOrderedEffect2(effects.Effect): + pass + +@dataclasses.dataclass(frozen=True) +class ForTestingUnorderedEffect1(effects.Effect): + pass + + +class ForTestingOrderedEffect4NoNullary(effects.Effect): + def __init__(self, _): + pass + +@dataclasses.dataclass(eq=False) +class ForTestingOrderedEffect5NoEq(effects.Effect): + pass -class TestingUnorderedEffect1(effects.Effect): - __str__ = lambda _: "TestingUnorderedEffect1" _testing_effects = dict( - TestingOrderedEffect1=TestingOrderedEffect1(), - TestingOrderedEffect2=TestingOrderedEffect2(), - TestingUnorderedEffect1=TestingUnorderedEffect1()) + ForTestingOrderedEffect1=ForTestingOrderedEffect1(), + ForTestingOrderedEffect2=ForTestingOrderedEffect2(), + ForTestingUnorderedEffect1=ForTestingUnorderedEffect1(), + ForTestingOrderedEffect4NoNullary=ForTestingOrderedEffect4NoNullary(42), + ForTestingOrderedEffect5NoEq=ForTestingOrderedEffect5NoEq(), +) # Register the effects for effect in _testing_effects.values(): effect_class = effect.__class__ @@ -1015,23 +1032,20 @@ def f_jax(x): # x: f32[3] # Test also the calling convention for inner functions def f_jax_inner(x): return ( - testing_primitive_with_effect_p.bind(x, effect_class_name="TestingOrderedEffect2") + - testing_primitive_with_effect_p.bind(x, effect_class_name="TestingUnorderedEffect1")) + testing_primitive_with_effect_p.bind(x, effect_class_name="ForTestingOrderedEffect2") + + testing_primitive_with_effect_p.bind(x, effect_class_name="ForTestingUnorderedEffect1")) return ( 10. + jax.jit(f_jax_inner)(x) + - testing_primitive_with_effect_p.bind(x, effect_class_name="TestingOrderedEffect1") + - testing_primitive_with_effect_p.bind(x, effect_class_name="TestingOrderedEffect2") + testing_primitive_with_effect_p.bind(x, effect_class_name="ForTestingOrderedEffect1") + + testing_primitive_with_effect_p.bind(x, effect_class_name="ForTestingOrderedEffect2") ) - # TODO(necula): at the moment serializing and deserializing effects breaks - # the effect equality, and this results in this test failing. So, for now - # we disable the serization round-trip - exp = export.export(f_jax)(x) # get_exported(f_jax)(x) + exp = get_exported(f_jax)(x) if exp.mlir_module_serialization_version >= export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS: - self.assertEqual(["TestingOrderedEffect1", "TestingOrderedEffect2"], + self.assertEqual(["ForTestingOrderedEffect1()", "ForTestingOrderedEffect2()"], sorted(str(e) for e in exp.ordered_effects)) - self.assertEqual(["TestingUnorderedEffect1"], + self.assertEqual(["ForTestingUnorderedEffect1()"], [str(e) for e in exp.unordered_effects]) else: self.assertEqual([], [str(e) for e in exp.ordered_effects]) @@ -1074,19 +1088,19 @@ def f_jax_inner(x): def f_outer(x): return ( testing_primitive_with_effect_p.bind( - x, effect_class_name="TestingOrderedEffect2") + + x, effect_class_name="ForTestingOrderedEffect2") + testing_primitive_with_effect_p.bind( - x, effect_class_name="TestingUnorderedEffect1") + + x, effect_class_name="ForTestingUnorderedEffect1") + export.call_exported(exp)(x)) lowered_outer = jax.jit(f_outer).lower(x) if exp.mlir_module_serialization_version < export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS: - self.assertEqual(["TestingOrderedEffect2"], + self.assertEqual(["ForTestingOrderedEffect2()"], [str(e) for e in lowered_outer._lowering.compile_args["ordered_effects"]]) else: - self.assertEqual(["TestingOrderedEffect1", "TestingOrderedEffect2"], + self.assertEqual(["ForTestingOrderedEffect1()", "ForTestingOrderedEffect2()"], sorted(str(e) for e in lowered_outer._lowering.compile_args["ordered_effects"])) - self.assertEqual(["TestingUnorderedEffect1"], + self.assertEqual(["ForTestingUnorderedEffect1()"], sorted([str(e) for e in lowered_outer._lowering.compile_args["unordered_effects"]])) mlir_outer_module_str = str(lowered_outer.compiler_ir()) @@ -1106,7 +1120,7 @@ def test_ordered_effects_poly(self, *, v: int): self.override_serialization_version(v) x = np.arange(12, dtype=np.float32).reshape((3, 4)) def f_jax(x): # x: f32[b1, b2] - return 10. + testing_primitive_with_effect_p.bind(x, effect_class_name="TestingOrderedEffect1") + return 10. + testing_primitive_with_effect_p.bind(x, effect_class_name="ForTestingOrderedEffect1") exp = get_exported(f_jax)(jax.ShapeDtypeStruct( export.symbolic_shape("b2, b1"), x.dtype)) mlir_module_str = str(exp.mlir_module()) @@ -1150,7 +1164,7 @@ def test_ordered_effects_multi_platform_and_poly(self, *, v: int): x = np.ones((3, 4), dtype=np.float32) def f_jax(x): # x: f32[b1, b2] return 10. + _testing_multi_platform_func(x, - effect_class_name="TestingOrderedEffect1") + effect_class_name="ForTestingOrderedEffect1") exp = get_exported( f_jax, lowering_platforms=("cpu", "tpu") @@ -1196,7 +1210,7 @@ def test_ordered_effects_with_donation(self, *, v: int): def f_jax(x): return testing_primitive_with_effect_p.bind( - x, effect_class_name="TestingOrderedEffect1" + x, effect_class_name="ForTestingOrderedEffect1" ) f_jax = jax.jit(f_jax, donate_argnums=(0,)) @@ -1209,6 +1223,24 @@ def f_jax(x): self.assertRegex(mlir_module_str, r"@main.*tf.aliasing_output = 1") self.assertRegex(mlir_module_str, r"@_wrapped_jax_export_main.*tf.aliasing_output = 1") + @jtu.parameterized_filterable( + kwargs=[ + dict(name=name, expect_error=expect_error) + # name is the suffix for event name: ForTestingOrderedEffectxxx + for name, expect_error in ( + ("4NoNullary", "must have a nullary constructor"), + ("5NoEq", "must have a nullary class constructor that produces an " + "equal effect object"), + ) + ]) + def test_ordered_effects_error(self, *, name: str, expect_error: str): + x = np.ones((3, 4), dtype=np.float32) + def f_jax(x): + return 10. + _testing_multi_platform_func( + x, + effect_class_name="ForTestingOrderedEffect" + name) + with self.assertRaisesRegex(Exception, expect_error): + _ = get_exported(f_jax)(jax.ShapeDtypeStruct((3, 4), x.dtype)) if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader())