[jax2tf] Improved documentation and tests for pjit
gnecula committed Jul 26, 2022
1 parent 48edd21 commit 5b0b8ac
Showing 2 changed files with 89 additions and 27 deletions.
13 changes: 11 additions & 2 deletions jax/experimental/jax2tf/README.md
@@ -254,6 +254,16 @@ You have two options: either pass `with_gradient=False` to `jax2tf.convert`, or
set `tf.saved_model.SaveOptions(experimental_custom_gradients=False)`. In either case,
you will not be able to compute the gradients of the function loaded from the SavedModel.
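For example, a minimal sketch (the example function, input signature, and save path below are placeholders, not part of the README):

```python
import jax.numpy as jnp
import tensorflow as tf
from jax.experimental import jax2tf

def f_jax(x):
  return jnp.sin(x)

# Option 1: convert without gradient support at all.
f_tf_nograd = jax2tf.convert(f_jax, with_gradient=False)

# Option 2: convert with gradients, but do not save the custom gradients.
module = tf.Module()
module.f = tf.function(jax2tf.convert(f_jax), autograph=False,
                       input_signature=[tf.TensorSpec([None], tf.float32)])
tf.saved_model.save(
    module, "/tmp/jax2tf_no_custom_grads",
    options=tf.saved_model.SaveOptions(experimental_custom_gradients=False))
```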

## Support for partitioning

jax2tf supports JAX functions that use `jax.pjit`, for single-host meshes.
The conversion is similar to that of a plain `jax.jit`, except that the
arguments and results are wrapped with
`tensorflow.compiler.xla.experimental.xla_sharding.XlaSharding` TensorFlow ops.

Note that when saving a model, the parameters of the model are wrapped with
`tf.Variable` before calling the converted function (see [above](#saved_model_with_parameters)),
and therefore sit outside of the `XlaSharding` wrapper, as in the sketch below.
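As a rough sketch of how this fits together (the mesh axis name, shapes, and example function are made up; `Mesh` is assumed to be importable from `jax.experimental.maps` in this version of JAX):

```python
import functools
import jax
import numpy as np
import tensorflow as tf
from jax.experimental import jax2tf, pjit
from jax.experimental import maps  # assumed home of Mesh in this JAX version
from jax.interpreters.pxla import PartitionSpec as P

@functools.partial(pjit.pjit,
                   in_axis_resources=(P("x"), P("x")),
                   out_axis_resources=None)
def jax_func(params, inputs):  # placeholder function, for illustration only
  return params + inputs

params = np.ones((8, 10), dtype=np.float32)
inputs = np.ones((8, 10), dtype=np.float32)

n = 2 if len(jax.devices()) >= 2 else 1  # partition over 2 devices if available
with maps.Mesh(np.array(jax.devices()[:n]), ("x",)):
  f_tf = jax2tf.convert(jax_func)
  params_var = tf.Variable(params)  # the variable sits outside the XlaSharding wrapper
  f = tf.function(lambda x: f_tf(params_var, x),
                  autograph=False, jit_compile=True)
  print(f(inputs))
```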

## Shape-polymorphic conversion

@@ -318,7 +328,6 @@ specification for the argument `x` of a function, JAX will know that a condition
`x.shape[-2] == x.shape[-1]` is `True`, and will also know that `x` and `jnp.sin(x)` have the
same shape, namely a batch of square matrices that can be passed to `jnp.matmul`, as in the sketch below.
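For instance, a minimal sketch (the dimension-variable names `b` and `d` and the example function are illustrative, not from the README):

```python
import jax.numpy as jnp
import numpy as np
import tensorflow as tf
from jax.experimental import jax2tf

def f_jax(x):
  # x is a batch of square matrices, so jnp.sin(x) has the same shape
  # and the matmul below is well defined.
  return jnp.matmul(x, jnp.sin(x))

# "b" and "d" are dimension variables; the spec encodes x.shape[-2] == x.shape[-1].
f_tf = tf.function(jax2tf.convert(f_jax, polymorphic_shapes=["(b, d, d)"]),
                   autograph=False)

x = np.ones((3, 4, 4), dtype=np.float32)
print(f_tf(x).shape)  # (3, 4, 4)
```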


### Correctness of shape-polymorphic tracing

We want to trust that the converted program produces the same results as the
@@ -655,7 +664,7 @@ in [savedmodel_test.py](https://github.com/google/jax/blob/main/jax/experimental
### Missing converter features

There is currently no support for `pmap` or `xmap`, nor for the collective
operations. There is support for `sharded_jit` and `pjit`.
operations. There is support for `pjit`.

### SavedModel may be large

103 changes: 78 additions & 25 deletions jax/experimental/jax2tf/tests/sharding_test.py
@@ -15,6 +15,7 @@

import functools
import logging
import os
import re
from typing import Any, Sequence
import unittest
@@ -30,23 +31,37 @@
from jax.experimental.jax2tf.tests import tf_test_util
from jax.interpreters.pxla import PartitionSpec as P
import jax.numpy as jnp
import jax._src.lib.xla_bridge
from jax._src.lib import xla_bridge

import numpy as np

import tensorflow as tf # type: ignore[import]

config.parse_flags_with_absl()

prev_xla_flags = None
def setUpModule():
global prev_xla_flags
prev_xla_flags = os.getenv("XLA_FLAGS")
flags_str = prev_xla_flags or ""
# Don't override user-specified device count, or other XLA flags.
if "xla_force_host_platform_device_count" not in flags_str:
os.environ["XLA_FLAGS"] = (flags_str +
" --xla_force_host_platform_device_count=8")
# Clear any cached backends so new CPU backend will pick up the env var.
xla_bridge.get_backend.cache_clear()
jtu.set_spmd_lowering_flag(True)


def tearDownModule():
if prev_xla_flags is None:
del os.environ["XLA_FLAGS"]
else:
os.environ["XLA_FLAGS"] = prev_xla_flags
xla_bridge.get_backend.cache_clear()
jtu.restore_spmd_lowering_flag()


LOG_HLO = True

class ShardedJitHloTest(tf_test_util.JaxToTfTestCase):
"""Tests that inspect the HLO for the sharding annotations.
@@ -59,7 +74,8 @@ def _check_sharding_annotations(self,
*,
expected: Sequence[str],
expected_opt: Sequence[str],
num_partitions=2):
num_partitions=2,
num_variables=0):
"""Check expected patterns in the HLO generated from f_jax and its conversion.
We run this check on CPU also, which is useful for debugging locally.
@@ -69,18 +85,20 @@ def _check_sharding_annotations(self,
See `self.AssertShardingAnnotations` for documentation of `expected`
and `expected_opt`.
num_variables: the number of `args` to be wrapped with tf.Variable.
"""
if jtu.device_under_test() == "gpu":
raise unittest.SkipTest("Sharding HLO tests not useful for GPU")

jax_comp = f_jax.lower(*args).compiler_ir(dialect="hlo")
jax_hlo = jax_comp.as_hlo_text()
if LOG_HLO:
logging.info("[%s] got JAX HLO %s", self._testMethodName, jax_hlo)
self.AssertShardingAnnotations("JAX before optimizations", jax_hlo, expected)
logging.info("[%s] got JAX HLO %s", self._testMethodName, jax_hlo)
self._assert_sharding_annotations("JAX before optimizations", jax_hlo, expected)

# We only dump JAX optimized code on the TPU
if jtu.device_under_test() == "tpu":
backend = jax._src.lib.xla_bridge.get_backend()
backend = xla_bridge.get_backend()
num_replicas = 1
device_assignment = np.arange(num_partitions * num_replicas)
device_assignment = np.reshape(device_assignment, (-1, num_partitions))
@@ -93,30 +111,37 @@ def _check_sharding_annotations(self,
)
jax_optimized_hlo = backend.compile(
jax_comp, compile_options).hlo_modules()[0].to_string()
if LOG_HLO:
logging.info("[%s] got JAX optimized HLO for platform %s %s",
logging.info("[%s] got JAX optimized HLO for platform %s %s",
self._testMethodName, backend.platform, jax_optimized_hlo)
self.AssertShardingAnnotations("JAX after optimizations",
jax_optimized_hlo, expected_opt)

f_tf = jax2tf.convert(f_jax)
self._assert_sharding_annotations("JAX after optimizations",
jax_optimized_hlo, expected_opt)

f_tf_base = jax2tf.convert(f_jax, with_gradient=False)
if num_variables > 0:
args_vars = [tf.Variable(a) for a in args[:num_variables]]
args = args[num_variables:]  # the remaining args are passed as regular inputs
f_tf = lambda *inputs: f_tf_base(*args_vars, *inputs)
else:
f_tf = f_tf_base
f_tf_fun = tf.function(f_tf, jit_compile=True, autograph=False)
logging.info("[%s] Got TF graph %s",
self._testMethodName,
f_tf_fun.get_concrete_function(*args).graph.as_graph_def())
device_name = f"/device:{jtu.device_under_test().upper()}:0"
tf_hlo = (tf.function(f_tf, jit_compile=True, autograph=False)
tf_hlo = (f_tf_fun
.experimental_get_compiler_ir(*args)(stage="hlo",
device_name=device_name))
if LOG_HLO:
logging.info("[%s] got TF HLO %s", self._testMethodName, tf_hlo)
self.AssertShardingAnnotations("TF before optimizations", tf_hlo, expected)
logging.info("[%s] got TF HLO %s", self._testMethodName, tf_hlo)
self._assert_sharding_annotations("TF before optimizations", tf_hlo, expected)
tf_optimized_hlo = (
tf.function(f_tf, jit_compile=True)
.experimental_get_compiler_ir(*args)(stage="optimized_hlo",
device_name=device_name))
if LOG_HLO:
logging.info("[%s] got TF optimized HLO for %s: %s", self._testMethodName,
device_name, tf_optimized_hlo)
logging.info("[%s] got TF optimized HLO for %s: %s", self._testMethodName,
device_name, tf_optimized_hlo)

def AssertShardingAnnotations(self, what: str, hlo: str,
expected: Sequence[str]):
def _assert_sharding_annotations(self, what: str, hlo: str,
expected: Sequence[str]):
"""Args:
what: either 'JAX' or 'TF', used for messages only.
@@ -153,8 +178,8 @@ def jax_func(x, y):
shape = (8, 10)
x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
hlo = jax_func.lower(x, x).compiler_ir(dialect="hlo").as_hlo_text()
print(f"HLO is {hlo}")
print(f"JAXPR is {jax.make_jaxpr(jax_func)(x, x)}")
logging.info("HLO is %s", hlo)
logging.info("JAXPR is %s", jax.make_jaxpr(jax_func)(x, x))
self._check_sharding_annotations(
jax_func, [x, x],
expected=[
@@ -168,6 +193,34 @@ def jax_func(x, y):
],
num_partitions=2)

@jtu.with_mesh([("x", 2)])
def test_pjit_basic1D_variable(self):
# The first argument is a tf.Variable
@functools.partial(pjit.pjit,
in_axis_resources=(P("x"), P("x")),
out_axis_resources=None)
def jax_func(x, y):
return x + y

shape = (8, 10)
x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
hlo = jax_func.lower(x, x).compiler_ir(dialect="hlo").as_hlo_text()
logging.info("HLO is %s", hlo)
logging.info("JAXPR is %s", jax.make_jaxpr(jax_func)(x, x))
self._check_sharding_annotations(
jax_func, [x, x],
expected=[
r"f32\[8,10\].*sharding={devices=\[2,1\]", # x and y
r"f32\[8,10\].*sharding={replicated", # output
],
expected_opt=[
r"f32\[4,10\].*sharding={devices=\[2,1\]", # x and y
# TODO: why don't we see "sharding={replicated"
r"f32\[8,10\]", # output
],
num_partitions=2,
num_variables=1)

@jtu.with_mesh([("x", 2), ("y", 2)])
def test_pjit_basic2D(self):
@functools.partial(pjit.pjit,
