[jax2tf] Simplify back_compat_test.py to use jax_export mechanisms to run
the serialized module, instead of relying on tf.XlaCallModule.

PiperOrigin-RevId: 528061968
gnecula authored and chrisflesher committed Jun 3, 2023
1 parent 2a561b6 commit 7b82d5e
Showing 1 changed file with 45 additions and 71 deletions.
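In outline, the change swaps the TF-based execution path for a JAX-native one. A minimal sketch, assuming only the names as they appear in the diff below; the elided arguments are the ones visible in the test file and the concrete values are illustrative, not part of the commit:

    # Before: run the stored StableHLO module through the TF op XlaCallModule.
    #   res = tfxla.call_module(args_tf, module=data.mlir_module_serialized, ...)
    #
    # After: rebuild a jax_export.Exported from the stored test data and call it
    # directly from JAX, so TF is no longer needed to execute the module.
    #   exported = jax_export.Exported(
    #       ...,
    #       mlir_module_serialized=data.mlir_module_serialized,
    #       xla_call_module_version=data.xla_call_module_version)
    #   res = pjit.pjit(jax_export.call_exported(exported))(*data.inputs)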
116 changes: 45 additions & 71 deletions jax/experimental/jax2tf/tests/back_compat_test.py
@@ -78,8 +78,10 @@ def func(...): ...

import jax
from jax import config
from jax import core
from jax import lax
from jax.experimental import jax2tf
from jax import tree_util
from jax.experimental.jax2tf import jax_export
from jax.experimental.jax2tf.tests.back_compat_testdata import cpu_ducc_fft
from jax.experimental.jax2tf.tests.back_compat_testdata import cpu_lapack_geqrf
from jax.experimental.jax2tf.tests.back_compat_testdata import cpu_lapack_syev
@@ -98,17 +100,10 @@ def func(...): ...
from jax.sharding import Mesh
from jax.sharding import PartitionSpec as P

from jax._src import core
from jax._src import test_util as jtu
from jax._src.interpreters import pxla
from jax._src import xla_bridge as xb

import tensorflow as tf # type: ignore[import]

# pylint: disable=g-direct-tensorflow-import
from tensorflow.compiler.tf2xla.python import xla as tfxla # type: ignore[import]
# pylint: enable=g-direct-tensorflow-import


config.parse_flags_with_absl()


@@ -183,16 +178,13 @@ class CompatTest(jtu.JaxTestCase):

def run_one_test(self, func: Callable[..., jax.Array],
data: CompatTestData,
run_tf = None,
rtol = None,
check_results: Optional[Callable[..., None]] = None):
"""Run one compatibility test.
Args:
func: the JAX function to serialize and run
data: the test data
run_tf: (optional) a function to invoke the XlaCallModule TF op. Takes
a TensorFlow callable and the arguments.
rtol: relative tolerance for numerical comparisons
check_results: invoked with the results obtained from running the
serialized code, and those stored in the test data, and the kwarg rtol.
@@ -207,7 +199,7 @@ def run_one_test(self, func: Callable[..., jax.Array],
res_from_jax_run_now = tuple(np.array(a) for a in res_from_jax_run_now)

# Use the native exporter, to make sure we get the proper serialized module.
exported = jax2tf.jax_export.export(
exported = jax_export.export(
jax.jit(func),
lowering_platform=default_jax_backend(),
# Must turn off strict checks because the custom calls may be unallowed.
@@ -230,13 +222,15 @@ def run_one_test(self, func: Callable[..., jax.Array],
serialized_date={repr(datetime.date.today())},
inputs={repr(data.inputs)},
expected_outputs={repr(res_from_jax_run_now)},
mlir_module_text=\"\"\"\n{module_str}\"\"\",
mlir_module_text=r\"\"\"\n{module_str}\"\"\",
mlir_module_serialized={repr(exported.mlir_module_serialized)},
xla_call_module_version={exported.xla_call_module_version},
) # End paste
"""
output_dir = os.getenv("TEST_UNDECLARED_OUTPUTS_DIR",
"/tmp/back_compat_testdata")
if not os.path.exists(output_dir):
os.makedirs(output_dir)
output_file = os.path.join(output_dir, f"{self._testMethodName}.py")
logging.info("Writing the up-to-date testdata at %s", output_file)
with open(output_file, "w") as f:
@@ -249,49 +243,41 @@ def run_one_test(self, func: Callable[..., jax.Array],
else:
self.assertAllClose(res_from_jax_run_now, data.expected_outputs, rtol=rtol)

res_from_serialized_run_now = self.run_serialized(data, run_tf=run_tf)
res_from_serialized_run_now = self.run_serialized(data)
logging.info("Result of serialized run is %s", res_from_serialized_run_now)
if check_results is not None:
check_results(res_from_serialized_run_now, data.expected_outputs, rtol=rtol)
else:
self.assertAllClose(res_from_serialized_run_now, data.expected_outputs, rtol=rtol)
self.assertListEqual(custom_call_targets, data.custom_call_targets)

def run_serialized(self, data: CompatTestData, run_tf=None):
# Run the serialized module. For now, use XlaCallModule. This has the
# disadvantage that it brings TF and jax2tf in the picture, but has the
# advantage that it is simple (e.g., XlaCallModule already has the
# machinery to deserialize and run), and also it is the way users actually
# run serialized modules today.
# TODO(necula): come up with a JAX-native way of running serialized modules.
tf_preferred_devices = (
tf.config.list_logical_devices("TPU")
+ tf.config.list_logical_devices("GPU")
+ tf.config.list_logical_devices()
)
# We need --config=cuda build flag for TF to see the GPUs
self.assertEqual(
jtu.device_under_test().upper(), tf_preferred_devices[0].device_type
)

def f_tf(*args_tf):
return tfxla.call_module(
args_tf,
version=data.xla_call_module_version,
Tout=[r.dtype for r in res_tf],
Sout=[r.shape for r in res_tf],
module=data.mlir_module_serialized,
platforms=[data.platform.upper()])

# We need this to run the TPU code on the TPU
with tf.device(tf_preferred_devices[0]):
args_tf = [tf.constant(a) for a in data.inputs]
res_tf = [tf.constant(r) for r in data.expected_outputs]
if run_tf is not None:
res = run_tf(f_tf, *args_tf)
else:
res = f_tf(*args_tf)
return tuple(r.numpy() for r in res)
def run_serialized(self, data: CompatTestData):
def ndarray_to_aval(a: np.ndarray) -> core.ShapedArray:
return core.ShapedArray(a.shape, a.dtype)
in_avals_tree = tree_util.tree_map(ndarray_to_aval, data.inputs)
out_avals_tree = tree_util.tree_map(ndarray_to_aval, data.expected_outputs)
# in_tree must be for (args, kwargs)
in_avals, in_tree = tree_util.tree_flatten((in_avals_tree, {}))
out_avals, out_tree = tree_util.tree_flatten(out_avals_tree)
def _get_vjp(_):
assert False # We do not have and do not need VJP
exported = jax_export.Exported(
fun_name="run_serialized",
in_tree=in_tree,
in_avals=tuple(in_avals),
out_tree=out_tree,
out_avals=tuple(out_avals),
in_shardings=(pxla.UNSPECIFIED,) * len(in_avals),
out_shardings=(pxla.UNSPECIFIED,) * len(out_avals),
lowering_platform=data.platform,
strict_checks=True,
mlir_module_serialized=data.mlir_module_serialized,
xla_call_module_version=data.xla_call_module_version,
module_kept_var_idx=tuple(range(len(in_avals))),
_get_vjp=_get_vjp)

# We use pjit in case there are shardings in the exported module.
return pjit.pjit(jax_export.call_exported(exported))(*data.inputs)

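# Hedged illustration (not part of the commit): how the (args, kwargs) pytree
# flattening in run_serialized produces the flat avals and tree definitions
# that Exported expects. The input shape below is made up for the example.
import numpy as np
from jax import tree_util
from jax._src import core

example_inputs = (np.zeros((2, 3), np.float32),)
avals_tree = tree_util.tree_map(
    lambda a: core.ShapedArray(a.shape, a.dtype), example_inputs)
in_avals, in_tree = tree_util.tree_flatten((avals_tree, {}))
# in_avals is a flat list with one ShapedArray((2, 3), float32); in_tree
# records the ((args,), {}) structure needed to rebuild the call signature.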
def test_dummy(self):
# Tests the test mechanism. Let this test run on all platforms
Expand All @@ -318,7 +304,7 @@ def test_detect_different_custom_calls(self):
self.run_one_test(jnp.sin, platform_dummy_data)

def test_custom_call_coverage(self):
targets_to_cover = set(jax2tf.jax_export._CUSTOM_CALL_TARGETS_GUARANTEED_STABLE)
targets_to_cover = set(jax_export._CUSTOM_CALL_TARGETS_GUARANTEED_STABLE)
# Add here all the testdatas that should cover the targets guaranteed
# stable
covering_testdatas = [
@@ -362,11 +348,16 @@ def check_eigh_results(self, operand, res_now, res_expected, *,
_, w_expected = res_expected
n, m = operand.shape
assert n == m
assert v_now.shape == operand.shape
assert w_now.shape == (n,)
self.assertLessEqual(
np.linalg.norm(np.eye(n) - np.matmul(np.conj(np.swapaxes(v_now, -1, -2)), v_now)),
rtol)
self.assertLessEqual(np.linalg.norm(np.matmul(operand, v_now) - w_now * v_now),
rtol * np.linalg.norm(operand))
# w_now : f64[n] while v_now: c128[n, n]
w_now_like_v = w_now[np.newaxis, :].astype(v_now.dtype)
self.assertLessEqual(
np.linalg.norm(np.matmul(operand, v_now) - w_now_like_v * v_now),
rtol * np.linalg.norm(operand))
self.assertAllClose(w_expected, w_now, rtol=rtol)

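# Hedged illustration (not part of the commit): the two properties that
# check_eigh_results verifies, shown with NumPy on a small symmetric matrix.
import numpy as np

a = np.array([[2.0, 1.0], [1.0, 3.0]])
w, v = np.linalg.eigh(a)  # eigenvalues w, eigenvectors in the columns of v
n = a.shape[0]
tol = 1e-6
# The eigenvector columns are orthonormal: I - v^H v is (numerically) zero.
assert np.linalg.norm(np.eye(n) - np.conj(v.T) @ v) <= tol
# Each column satisfies a @ v[:, i] == w[i] * v[:, i], up to a residual that
# is small relative to ||a||.
assert np.linalg.norm(a @ v - v * w[np.newaxis, :]) <= tol * np.linalg.norm(a)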
@parameterized.named_parameters(
@@ -492,26 +483,9 @@ def func(x): # b: f32[2, 4]
perm = [(j, (j + 1) % axis_size) for j in range(axis_size)]
return lax.ppermute(x, 'a', perm=perm)

# We need these only because for now we run the serialized module with TF
tf_tpus = tf.config.list_logical_devices("TPU")
self.assertNotEmpty(tf_tpus)

resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
tf.config.experimental_connect_to_cluster(resolver)
topology = tf.tpu.experimental.initialize_tpu_system(resolver)
device_assignment = tf.tpu.experimental.DeviceAssignment.build(
topology, computation_shape=[1, 1, 1, 2],
num_replicas=1)
def run_tf(f_tf, x):
def wrapped_f_tf(x):
return tf.compat.v1.tpu.rewrite(
f_tf, [tf.convert_to_tensor(x)],
device_assignment=device_assignment)
return tf.function(wrapped_f_tf, autograph=False, jit_compile=True)(x)

data = load_testdata(tpu_Sharding.data_2023_03_16)
with mesh:
self.run_one_test(func, data, run_tf=run_tf)
self.run_one_test(func, data)


if __name__ == "__main__":
