From 5b0b8ac58aa43b55de32eedd65c47e20e43db52b Mon Sep 17 00:00:00 2001
From: George Necula
Date: Tue, 26 Jul 2022 12:41:40 +0300
Subject: [PATCH] [jax2tf] Improved documentation and tests for pjit

---
 jax/experimental/jax2tf/README.md             |  50 +++++++++-
 .../jax2tf/tests/sharding_test.py             | 105 ++++++++++++------
 2 files changed, 127 insertions(+), 28 deletions(-)

diff --git a/jax/experimental/jax2tf/README.md b/jax/experimental/jax2tf/README.md
index 921fabaa50be..d8b78523955a 100644
--- a/jax/experimental/jax2tf/README.md
+++ b/jax/experimental/jax2tf/README.md
@@ -254,6 +254,53 @@ You have two options, either pass `enable_gradients=False` to `jax2tf.convert`, or
 set `tf.saved_model.SaveOption(experimental_custom_gradients=False)`. In either
 case, you will not be able to compute the gradients of the function loaded
 from the SavedModel.
+## Support for partitioning
+
+jax2tf supports JAX functions that use `jax.pjit`, for single-host meshes.
+The conversion is similar to that for `jax.jit`, except that the
+arguments and results will be wrapped with
+`tensorflow.compiler.xla.experimental.xla_sharding.XlaSharding` TensorFlow ops.
+
+Note that when saving a model, the parameters to the model are wrapped with
+`tf.Variable` before calling the converted function (see [above](#saved_model_with_parameters)),
+and are therefore outside of the `XlaSharding` wrapper.
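+
+For illustration, here is a minimal sketch of converting and calling a
+`pjit`-ed function. The function and the mesh axis name are made up for this
+example, and it assumes at least 2 devices (e.g., on CPU, by setting
+`--xla_force_host_platform_device_count=8` in `XLA_FLAGS`):
+
+```python
+import functools
+
+import jax
+from jax.experimental import jax2tf
+from jax.experimental import pjit
+from jax.experimental.maps import Mesh
+from jax.interpreters.pxla import PartitionSpec as P
+import jax.numpy as jnp
+import numpy as np
+import tensorflow as tf
+
+# A hypothetical pjit-ed function: the argument is sharded on the "x" mesh
+# axis; the result is replicated.
+@functools.partial(pjit.pjit,
+                   in_axis_resources=(P("x"),),
+                   out_axis_resources=None)
+def f_jax(x):
+  return jnp.sin(x)
+
+x = np.ones((8, 10), dtype=np.float32)
+# The converted function must be traced and called under the mesh that the
+# pjit-ed function expects.
+with Mesh(np.array(jax.devices()[:2]), ("x",)):
+  f_tf = tf.function(jax2tf.convert(f_jax), autograph=False,
+                     jit_compile=True)
+  res = f_tf(x)
+```
+
+The `XlaSharding` ops show up in the HLO that TensorFlow compiles, e.g., in
+the output of `f_tf.experimental_get_compiler_ir(x)(stage="hlo")`.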
 
 ## Shape-polymorphic conversion
 
@@ -318,7 +365,6 @@ specification for the argument `x` of a function, JAX will know that
 a condition `x.shape[-2] == x.shape[-1]` is `True`, and will also know that `x` and
 `jnp.sin(x)` have the same shape of a batch of square matrices that can be
 passed to `jnp.matmul`.
-
 ### Correctness of shape-polymorphic tracing
 
 We want to trust that the converted program produces the same results as the
@@ -655,7 +701,7 @@ in [savedmodel_test.py](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/tests/savedmodel_test.py).
 
 ### Missing converter features
 
 There is currently no support for `pmap` or `xmap`, nor for the collective
-operations. There is support for `sharded_jit` and `pjit`.
+operations. There is support for `pjit`.
 
 ### SavedModel may be large
 
diff --git a/jax/experimental/jax2tf/tests/sharding_test.py b/jax/experimental/jax2tf/tests/sharding_test.py
index 9729f4f7ee0c..03e04ffc4f49 100644
--- a/jax/experimental/jax2tf/tests/sharding_test.py
+++ b/jax/experimental/jax2tf/tests/sharding_test.py
@@ -15,6 +15,7 @@
 
 import functools
 import logging
+import os
 import re
 from typing import Any, Sequence
 import unittest
@@ -30,7 +31,7 @@
 from jax.experimental.jax2tf.tests import tf_test_util
 from jax.interpreters.pxla import PartitionSpec as P
 import jax.numpy as jnp
-import jax._src.lib.xla_bridge
+from jax._src.lib import xla_bridge
 import numpy as np
 
 
@@ -38,15 +39,29 @@
 config.parse_flags_with_absl()
 
+prev_xla_flags = None
 
 def setUpModule():
+  global prev_xla_flags
+  prev_xla_flags = os.getenv("XLA_FLAGS")
+  flags_str = prev_xla_flags or ""
+  # Don't override user-specified device count, or other XLA flags.
+  if "xla_force_host_platform_device_count" not in flags_str:
+    os.environ["XLA_FLAGS"] = (flags_str +
+                               " --xla_force_host_platform_device_count=8")
+  # Clear any cached backends so the new CPU backend will pick up the env var.
+  xla_bridge.get_backend.cache_clear()
   jtu.set_spmd_lowering_flag(True)
 
+
 def tearDownModule():
+  if prev_xla_flags is None:
+    del os.environ["XLA_FLAGS"]
+  else:
+    os.environ["XLA_FLAGS"] = prev_xla_flags
+  xla_bridge.get_backend.cache_clear()
   jtu.restore_spmd_lowering_flag()
 
-LOG_HLO = True
-
 class ShardedJitHloTest(tf_test_util.JaxToTfTestCase):
   """Tests that inspect the HLO for the sharding annotations.
 
@@ -59,7 +74,8 @@ def _check_sharding_annotations(self,
                                   *,
                                   expected: Sequence[str],
                                   expected_opt: Sequence[str],
-                                  num_partitions=2):
+                                  num_partitions=2,
+                                  num_variables=0):
     """Check expected patterns in the HLO generated from f_jax and its conversion.
 
     We run this check on CPU also, which is useful for debugging locally.
@@ -69,18 +85,20 @@ def _check_sharding_annotations(self,
 
-    See `self.AssertShardingAnnotations` for documentation of `expected` and
-    `expected_opt`.
+    See `self._assert_sharding_annotations` for documentation of `expected` and
+    `expected_opt`.
+
+    num_variables: the number of leading `args` to be wrapped with `tf.Variable`.
     """
     if jtu.device_under_test() == "gpu":
       raise unittest.SkipTest("Sharding HLO tests not useful for GPU")
 
     jax_comp = f_jax.lower(*args).compiler_ir(dialect="hlo")
     jax_hlo = jax_comp.as_hlo_text()
-    if LOG_HLO:
-      logging.info("[%s] got JAX HLO %s", self._testMethodName, jax_hlo)
-    self.AssertShardingAnnotations("JAX before optimizations", jax_hlo, expected)
+    logging.info("[%s] got JAX HLO %s", self._testMethodName, jax_hlo)
+    self._assert_sharding_annotations("JAX before optimizations", jax_hlo, expected)
+
+    # We only dump JAX optimized code on the TPU.
     if jtu.device_under_test() == "tpu":
-      backend = jax._src.lib.xla_bridge.get_backend()
+      backend = xla_bridge.get_backend()
       num_replicas = 1
       device_assignment = np.arange(num_partitions * num_replicas)
       device_assignment = np.reshape(device_assignment, (-1, num_partitions))
@@ -93,30 +111,37 @@ def _check_sharding_annotations(self,
       )
       jax_optimized_hlo = backend.compile(
           jax_comp, compile_options).hlo_modules()[0].to_string()
-      if LOG_HLO:
-        logging.info("[%s] got JAX optimized HLO for platform %s %s",
+      logging.info("[%s] got JAX optimized HLO for platform %s %s",
                    self._testMethodName, backend.platform, jax_optimized_hlo)
-      self.AssertShardingAnnotations("JAX after optimizations",
-                                     jax_optimized_hlo, expected_opt)
-
-    f_tf = jax2tf.convert(f_jax)
+      self._assert_sharding_annotations("JAX after optimizations",
+                                        jax_optimized_hlo, expected_opt)
+
+    f_tf_base = jax2tf.convert(f_jax, with_gradient=False)
+    if num_variables > 0:
+      # Wrap the leading num_variables arguments with tf.Variable; the
+      # remaining arguments stay as inputs of the tf.function.
+      args_vars = [tf.Variable(a) for a in args[:num_variables]]
+      args = args[num_variables:]
+      f_tf = lambda *inputs: f_tf_base(*args_vars, *inputs)
+    else:
+      f_tf = f_tf_base
+    f_tf_fun = tf.function(f_tf, jit_compile=True, autograph=False)
+    logging.info("[%s] got TF graph %s",
+                 self._testMethodName,
+                 f_tf_fun.get_concrete_function(*args).graph.as_graph_def())
     device_name = f"/device:{jtu.device_under_test().upper()}:0"
-    tf_hlo = (tf.function(f_tf, jit_compile=True, autograph=False)
+    tf_hlo = (f_tf_fun
              .experimental_get_compiler_ir(*args)(stage="hlo",
                                                   device_name=device_name))
-    if LOG_HLO:
-      logging.info("[%s] got TF HLO %s", self._testMethodName, tf_hlo)
-    self.AssertShardingAnnotations("TF before optimizations", tf_hlo, expected)
+    logging.info("[%s] got TF HLO %s", self._testMethodName, tf_hlo)
+    self._assert_sharding_annotations("TF before optimizations", tf_hlo, expected)
     tf_optimized_hlo = (
         tf.function(f_tf, jit_compile=True)
        .experimental_get_compiler_ir(*args)(stage="optimized_hlo",
                                             device_name=device_name))
-    if LOG_HLO:
-      logging.info("[%s] got TF optimized HLO for %s: %s", self._testMethodName,
-                   device_name, tf_optimized_hlo)
+    logging.info("[%s] got TF optimized HLO for %s: %s", self._testMethodName,
+                 device_name, tf_optimized_hlo)
 
-  def AssertShardingAnnotations(self, what: str, hlo: str,
-                                expected: Sequence[str]):
+  def _assert_sharding_annotations(self, what: str, hlo: str,
+                                   expected: Sequence[str]):
     """Args:
       what: either 'JAX' or 'TF', used for messages only.
%s", self._testMethodName, - device_name, tf_optimized_hlo) + logging.info("[%s] got TF optimized HLO for %s: %s", self._testMethodName, + device_name, tf_optimized_hlo) - def AssertShardingAnnotations(self, what: str, hlo: str, - expected: Sequence[str]): + def _assert_sharding_annotations(self, what: str, hlo: str, + expected: Sequence[str]): """Args: what: either 'JAX' or 'TF', used for messages only. @@ -153,8 +178,8 @@ def jax_func(x, y): shape = (8, 10) x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape) hlo = jax_func.lower(x, x).compiler_ir(dialect="hlo").as_hlo_text() - print(f"HLO is {hlo}") - print(f"JAXPR is {jax.make_jaxpr(jax_func)(x, x)}") + logging.info("HLO is %s", hlo) + logging.info("JAXPR is %s", jax.make_jaxpr(jax_func)(x, x)) self._check_sharding_annotations( jax_func, [x, x], expected=[ @@ -168,6 +193,34 @@ def jax_func(x, y): ], num_partitions=2) + @jtu.with_mesh([("x", 2)]) + def test_pjit_basic1D_variable(self): + # The first argument is a tf.Variable + @functools.partial(pjit.pjit, + in_axis_resources=(P("x"), P("x")), + out_axis_resources=None) + def jax_func(x, y): + return x + y + + shape = (8, 10) + x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape) + hlo = jax_func.lower(x, x).compiler_ir(dialect="hlo").as_hlo_text() + logging.info("HLO is %s", hlo) + logging.info("JAXPR is %s", jax.make_jaxpr(jax_func)(x, x)) + self._check_sharding_annotations( + jax_func, [x, x], + expected=[ + r"f32\[8,10\].*sharding={devices=\[2,1\]", # x and y + r"f32\[8,10\].*sharding={replicated", # output + ], + expected_opt=[ + r"f32\[4,10\].*sharding={devices=\[2,1\]", # x and y + # TODO: why don't we see "sharding={replicated" + r"f32\[8,10\]", # output + ], + num_partitions=2, + num_variables=1) + @jtu.with_mesh([("x", 2), ("y", 2)]) def test_pjit_basic2D(self): @functools.partial(pjit.pjit,