Skip to content

Commit

Permalink
Instead of tf.Graph protobuf, we switch to tf.Saved_model for back_…
Browse files Browse the repository at this point in the history
…compat_tf_test.

PiperOrigin-RevId: 555500398
  • Loading branch information
maxwillzq authored and jax authors committed Aug 10, 2023
1 parent 1ddc340 commit cf026ce
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 53 deletions.
2 changes: 1 addition & 1 deletion jax/experimental/jax2tf/tests/back_compat_test.py
Expand Up @@ -109,7 +109,7 @@ def test_custom_call_coverage(self):
cpu_schur_lapack_gees.data_2023_07_16,
cpu_svd_lapack_gesdd.data_2023_06_19,
cpu_triangular_solve_blas_trsm.data_2023_07_16,
tf_call_tf_function.data_2023_06_02, # This is tested in back_compat_tf_test.py
tf_call_tf_function.data_2023_07_29, # This is tested in back_compat_tf_test.py
tpu_Eigh.data, tpu_Lu.data_2023_03_21, tpu_Qr.data_2023_03_17,
tpu_Sharding.data_2023_03_16, tpu_ApproxTopK.data_2023_04_17,
tpu_ApproxTopK.data_2023_05_16,
Expand Down

Large diffs are not rendered by default.

107 changes: 72 additions & 35 deletions jax/experimental/jax2tf/tests/back_compat_tf_test.py
Expand Up @@ -17,28 +17,48 @@
these tests.
"""

import base64
from collections.abc import Sequence
import io
import os
import tarfile
from typing import Callable, Optional

from absl.testing import absltest

from jax import config
from jax._src import test_util as jtu
from jax._src.lib import xla_extension
from jax.experimental import jax2tf
from jax.experimental.jax2tf.tests import back_compat_test_util as bctu

from jax.experimental.jax2tf.tests.back_compat_testdata import tf_call_tf_function

import jax.numpy as jnp

from jax._src.lib import xla_extension
from jax._src import test_util as jtu

import tensorflow as tf
from tensorflow.core.framework import graph_pb2 # type: ignore[import]


config.parse_flags_with_absl()


def serialize_directory(directory_path):
"""Seriliaze the directory as a string."""
tar_buffer = io.BytesIO()
with tarfile.open(fileobj=tar_buffer, mode="w") as tar:
tar.add(directory_path, arcname=os.path.basename(directory_path))

# Convert the binary data to a base64-encoded string
serialized_string = base64.b64encode(tar_buffer.getvalue())
return serialized_string


def deserialize_directory(serialized_string, output_directory):
"""Deserialize the string to the diretory."""
# Convert the base64-encoded string back to binary data
tar_data = base64.b64decode(serialized_string)

# Extract the tar archive to the output directory
with tarfile.open(fileobj=io.BytesIO(tar_data), mode="r") as tar:
tar.extractall(output_directory)


class CompatTensoflowTest(bctu.CompatTestBase):
"""Compatibility tests that use TF.
Expand All @@ -48,9 +68,8 @@ class CompatTensoflowTest(bctu.CompatTestBase):
"""

def run_current(self, func: Callable, data: bctu.CompatTestData):
# Is there a better way to serialize/deserialize TF functions? I thought
# about using tf.saved_model, but then we have to zip/unzip a whole
# directory.
# Here we use tf.saved_model and provide string serialize/deserialize methods
# for the whole directory.
@tf.function(autograph=False, jit_compile=True)
def tf_func(the_input): # Use recognizeable names for input and result
res = jax2tf.convert(func, native_serialization=True)(the_input)
Expand All @@ -59,54 +78,72 @@ def tf_func(the_input): # Use recognizeable names for input and result
self.tf_func = tf_func
return tf_func(*data.inputs) # type: ignore

def serialize(self, func: Callable, data: bctu.CompatTestData,
polymorphic_shapes: Optional[Sequence[str]] = None,
allow_additional_custom_call_targets: Sequence[str] = ()):
def serialize(
self,
func: Callable,
data: bctu.CompatTestData,
polymorphic_shapes: Optional[Sequence[str]] = None,
allow_additional_custom_call_targets: Sequence[str] = (),
):
# We serialize as a tf.Graph
assert len(data.inputs) == 1 # We only support a single input now
tf_graph = self.tf_func.get_concrete_function(*data.inputs).graph
for op in tf_graph.get_operations():
if op.type == "XlaCallModule":
serialized_module = op.get_attr("module")
module_str = xla_extension.mlir.deserialize_portable_artifact(
serialized_module)
serialized_module
)
module_version = op.get_attr("version")
break
else:
raise ValueError("Cannot find an XlaCallModule")
tf_graph_def = tf_graph.as_graph_def()
# module_str is just for human readability, add both the MLIR module
# and the tf.Graph
module_str = ("# First the MLIR module:\n" + module_str +
"\n# Then the tf.Graph:\n" + str(tf_graph_def))
serialized = tf_graph_def.SerializeToString()
module_str = (
"# First the MLIR module:\n"
+ module_str
+ "\n# Then the tf.Graph:\n"
+ str(tf_graph_def)
)
# serialized = tf_graph_def.SerializeToString()
module = tf.Module()
module.call = self.tf_func.get_concrete_function(*data.inputs)
root_dir = self.create_tempdir()
saved_model_dir = os.path.join(root_dir, "saved_model")
os.mkdir(saved_model_dir)
tf.saved_model.save(
module,
saved_model_dir,
options=tf.saved_model.SaveOptions(experimental_custom_gradients=False),
)
serialized = serialize_directory(saved_model_dir)
return serialized, module_str, module_version

def run_serialized(self, data: bctu.CompatTestData,
polymorphic_shapes: Optional[Sequence[str]] = None):
loaded_f_tf_graph = graph_pb2.GraphDef()
loaded_f_tf_graph.ParseFromString(data.mlir_module_serialized)

@tf.function(autograph=False)
def loaded_fun(x):
result = tf.import_graph_def(loaded_f_tf_graph,
input_map={"the_input": x},
return_elements=["the_result:0"])
return result[0]

return (loaded_fun(*data.inputs).numpy(),)
def run_serialized(
self,
data: bctu.CompatTestData,
polymorphic_shapes: Optional[Sequence[str]] = None,
):
root_dir = self.create_tempdir()
deserialize_directory(data.mlir_module_serialized, root_dir)
saved_model_dir = os.path.join(root_dir, "saved_model")
loaded_model = tf.saved_model.load(saved_model_dir)
return (loaded_model.call(*data.inputs).numpy(),)

def test_tf_call_tf_function(self):
self.skipTest("b/286409830: brittle on function naming.")
# A custom call tf.call_tf_function is generated when we lower call_tf
# with the call_tf_graph=True option.
def func(x):
def func_tf(x):
return tf.math.sin(x)
return jnp.cos(jax2tf.call_tf(func_tf, output_shape_dtype=x,
call_tf_graph=True)(x))

data = self.load_testdata(tf_call_tf_function.data_2023_06_02)
return jnp.cos(
jax2tf.call_tf(func_tf, output_shape_dtype=x, call_tf_graph=True)(x)
)

data = self.load_testdata(tf_call_tf_function.data_2023_07_29)
self.run_one_test(func, data)


Expand Down

0 comments on commit cf026ce

Please sign in to comment.