Skip to content

Commit

Permalink
[export] Fix the serialization of effects
Browse files Browse the repository at this point in the history
We currently support only the serialization of effects with
nullary constructors. We must also ensure that upon deserialization
we produce an event that tests equal to the original one.
Here we add explicit error checks and tests.

We also make the CallTfEffect to have this property.
  • Loading branch information
gnecula committed Dec 15, 2023
1 parent a7b6023 commit 552010a
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 39 deletions.
23 changes: 15 additions & 8 deletions jax/experimental/export/serialization.py
Expand Up @@ -386,17 +386,24 @@ def _deserialize_sharding(s: ser_flatbuf.Sharding) -> export.Sharding:


def _serialize_effect(builder: flatbuffers.Builder, eff: core.Effect) -> int:
# TODO(necula): for now serialize just the name of the class
try:
_ = eff.__class__()
except:
eff_replica = eff.__class__()
except Exception:
raise NotImplementedError(
f"serializing effect {eff} that does not have a nullary class"
" constructor"
f"Effect {eff} must have a nullary constructor to be serializable"
)
try:
hash_eff = hash(eff)
hash_eff_replica = hash(eff_replica)
except Exception:
raise NotImplementedError(
f"Effect {eff} must be hashable to be serializable"
)
if eff != eff_replica or hash_eff != hash_eff_replica:
raise NotImplementedError(
f"Effect {eff} must have a nullary class constructor that produces an "
"equal effect object."
)
# TODO: fix the effects serialization and deserialization, to ensure that
# upon deserialization we reconstruct an effect that compares equal to the
# one that was serialized.
effect_type_name = str(eff.__class__)
effect_type_name_offset = builder.CreateString(effect_type_name)
ser_flatbuf.EffectStart(builder)
Expand Down
6 changes: 2 additions & 4 deletions jax/experimental/jax2tf/call_tf.py
Expand Up @@ -26,6 +26,7 @@
from __future__ import annotations

from collections.abc import Sequence
import dataclasses
import functools
from typing import Any, Callable, Optional

Expand All @@ -35,20 +36,16 @@
from jax import dtypes
from jax import numpy as jnp
from jax import tree_util
from jax._src import ad_checkpoint
from jax._src import ad_util
from jax._src import core
from jax._src import custom_derivatives
from jax._src import effects
from jax._src import util
from jax._src.lax import control_flow as lax_control_flow
from jax._src.lib import xla_client
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import func as func_dialect
from jax._src.lib.mlir.dialects import hlo
from jax.experimental.jax2tf import jax2tf as jax2tf_internal
from jax.interpreters import mlir
from jax.interpreters import xla
import numpy as np
import tensorflow as tf

Expand Down Expand Up @@ -376,6 +373,7 @@ def _get_concrete_function_tf(function_flat_tf, args_flat_sig_tf): # -> tf.Conc


# Mark the effectful instances of call_tf
@dataclasses.dataclass(frozen=True)
class CallTfEffect(effects.Effect):
__str__ = lambda _: "CallTfEffect"

Expand Down
86 changes: 59 additions & 27 deletions tests/export_test.py
Expand Up @@ -14,6 +14,7 @@
from __future__ import annotations

import contextlib
import dataclasses
import functools
import logging
import math
Expand Down Expand Up @@ -57,19 +58,35 @@ def tearDownModule():
prev_xla_flags()

### Setup for testing lowering with effects
class TestingOrderedEffect1(effects.Effect):
__str__ = lambda _: "TestingOrderedEffect1"
@dataclasses.dataclass(frozen=True)
class ForTestingOrderedEffect1(effects.Effect):
pass

class TestingOrderedEffect2(effects.Effect):
__str__ = lambda _: "TestingOrderedEffect2"
@dataclasses.dataclass(frozen=True)
class ForTestingOrderedEffect2(effects.Effect):
pass

@dataclasses.dataclass(frozen=True)
class ForTestingUnorderedEffect1(effects.Effect):
pass


class ForTestingOrderedEffect4NoNullary(effects.Effect):
def __init__(self, _):
pass

@dataclasses.dataclass(eq=False)
class ForTestingOrderedEffect5NoEq(effects.Effect):
pass

class TestingUnorderedEffect1(effects.Effect):
__str__ = lambda _: "TestingUnorderedEffect1"

_testing_effects = dict(
TestingOrderedEffect1=TestingOrderedEffect1(),
TestingOrderedEffect2=TestingOrderedEffect2(),
TestingUnorderedEffect1=TestingUnorderedEffect1())
ForTestingOrderedEffect1=ForTestingOrderedEffect1(),
ForTestingOrderedEffect2=ForTestingOrderedEffect2(),
ForTestingUnorderedEffect1=ForTestingUnorderedEffect1(),
ForTestingOrderedEffect4NoNullary=ForTestingOrderedEffect4NoNullary(42),
ForTestingOrderedEffect5NoEq=ForTestingOrderedEffect5NoEq(),
)
# Register the effects
for effect in _testing_effects.values():
effect_class = effect.__class__
Expand Down Expand Up @@ -1015,23 +1032,20 @@ def f_jax(x): # x: f32[3]
# Test also the calling convention for inner functions
def f_jax_inner(x):
return (
testing_primitive_with_effect_p.bind(x, effect_class_name="TestingOrderedEffect2") +
testing_primitive_with_effect_p.bind(x, effect_class_name="TestingUnorderedEffect1"))
testing_primitive_with_effect_p.bind(x, effect_class_name="ForTestingOrderedEffect2") +
testing_primitive_with_effect_p.bind(x, effect_class_name="ForTestingUnorderedEffect1"))
return (
10. +
jax.jit(f_jax_inner)(x) +
testing_primitive_with_effect_p.bind(x, effect_class_name="TestingOrderedEffect1") +
testing_primitive_with_effect_p.bind(x, effect_class_name="TestingOrderedEffect2")
testing_primitive_with_effect_p.bind(x, effect_class_name="ForTestingOrderedEffect1") +
testing_primitive_with_effect_p.bind(x, effect_class_name="ForTestingOrderedEffect2")
)

# TODO(necula): at the moment serializing and deserializing effects breaks
# the effect equality, and this results in this test failing. So, for now
# we disable the serization round-trip
exp = export.export(f_jax)(x) # get_exported(f_jax)(x)
exp = get_exported(f_jax)(x)
if exp.mlir_module_serialization_version >= export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
self.assertEqual(["TestingOrderedEffect1", "TestingOrderedEffect2"],
self.assertEqual(["ForTestingOrderedEffect1()", "ForTestingOrderedEffect2()"],
sorted(str(e) for e in exp.ordered_effects))
self.assertEqual(["TestingUnorderedEffect1"],
self.assertEqual(["ForTestingUnorderedEffect1()"],
[str(e) for e in exp.unordered_effects])
else:
self.assertEqual([], [str(e) for e in exp.ordered_effects])
Expand Down Expand Up @@ -1074,19 +1088,19 @@ def f_jax_inner(x):
def f_outer(x):
return (
testing_primitive_with_effect_p.bind(
x, effect_class_name="TestingOrderedEffect2") +
x, effect_class_name="ForTestingOrderedEffect2") +
testing_primitive_with_effect_p.bind(
x, effect_class_name="TestingUnorderedEffect1") +
x, effect_class_name="ForTestingUnorderedEffect1") +
export.call_exported(exp)(x))

lowered_outer = jax.jit(f_outer).lower(x)
if exp.mlir_module_serialization_version < export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
self.assertEqual(["TestingOrderedEffect2"],
self.assertEqual(["ForTestingOrderedEffect2()"],
[str(e) for e in lowered_outer._lowering.compile_args["ordered_effects"]])
else:
self.assertEqual(["TestingOrderedEffect1", "TestingOrderedEffect2"],
self.assertEqual(["ForTestingOrderedEffect1()", "ForTestingOrderedEffect2()"],
sorted(str(e) for e in lowered_outer._lowering.compile_args["ordered_effects"]))
self.assertEqual(["TestingUnorderedEffect1"],
self.assertEqual(["ForTestingUnorderedEffect1()"],
sorted([str(e) for e in lowered_outer._lowering.compile_args["unordered_effects"]]))

mlir_outer_module_str = str(lowered_outer.compiler_ir())
Expand All @@ -1106,7 +1120,7 @@ def test_ordered_effects_poly(self, *, v: int):
self.override_serialization_version(v)
x = np.arange(12, dtype=np.float32).reshape((3, 4))
def f_jax(x): # x: f32[b1, b2]
return 10. + testing_primitive_with_effect_p.bind(x, effect_class_name="TestingOrderedEffect1")
return 10. + testing_primitive_with_effect_p.bind(x, effect_class_name="ForTestingOrderedEffect1")
exp = get_exported(f_jax)(jax.ShapeDtypeStruct(
export.symbolic_shape("b2, b1"), x.dtype))
mlir_module_str = str(exp.mlir_module())
Expand Down Expand Up @@ -1150,7 +1164,7 @@ def test_ordered_effects_multi_platform_and_poly(self, *, v: int):
x = np.ones((3, 4), dtype=np.float32)
def f_jax(x): # x: f32[b1, b2]
return 10. + _testing_multi_platform_func(x,
effect_class_name="TestingOrderedEffect1")
effect_class_name="ForTestingOrderedEffect1")
exp = get_exported(
f_jax,
lowering_platforms=("cpu", "tpu")
Expand Down Expand Up @@ -1196,7 +1210,7 @@ def test_ordered_effects_with_donation(self, *, v: int):

def f_jax(x):
return testing_primitive_with_effect_p.bind(
x, effect_class_name="TestingOrderedEffect1"
x, effect_class_name="ForTestingOrderedEffect1"
)

f_jax = jax.jit(f_jax, donate_argnums=(0,))
Expand All @@ -1209,6 +1223,24 @@ def f_jax(x):
self.assertRegex(mlir_module_str, r"@main.*tf.aliasing_output = 1")
self.assertRegex(mlir_module_str, r"@_wrapped_jax_export_main.*tf.aliasing_output = 1")

@jtu.parameterized_filterable(
kwargs=[
dict(name=name, expect_error=expect_error)
# name is the suffix for event name: ForTestingOrderedEffectxxx
for name, expect_error in (
("4NoNullary", "must have a nullary constructor"),
("5NoEq", "must have a nullary class constructor that produces an "
"equal effect object"),
)
])
def test_ordered_effects_error(self, *, name: str, expect_error: str):
x = np.ones((3, 4), dtype=np.float32)
def f_jax(x):
return 10. + _testing_multi_platform_func(
x,
effect_class_name="ForTestingOrderedEffect" + name)
with self.assertRaisesRegex(Exception, expect_error):
_ = get_exported(f_jax)(jax.ShapeDtypeStruct((3, 4), x.dtype))

if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit 552010a

Please sign in to comment.