Merge pull request #13603 from gnecula:native_unused
PiperOrigin-RevId: 494769977
jax authors committed Dec 12, 2022
2 parents 5e8c0ec + 2f9dd04 commit 23001ae
Showing 5 changed files with 207 additions and 41 deletions.
36 changes: 36 additions & 0 deletions jax/experimental/jax2tf/README.md
@@ -381,6 +381,10 @@ lowered with the batch dimension polymorphic and the remaining dimensions concrete

It is reasonable to expect that there will be JAX programs for which there is a
shape-polymorphic TensorFlow graph, but which will give an error when lowering with jax2tf.
In general, you should expect that shape polymorphism can handle those programs for which
all the intermediate shapes can be expressed as polynomials in the dimension variables
appearing in the input shapes. In particular, this does not include programs whose
intermediate shapes depend on the data.
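
As a hypothetical illustration (not part of this change): a function whose intermediate
shape depends on the values in its input, for example one built on `jnp.nonzero`, cannot
be expressed this way and will raise an error when traced:

```python
# Hypothetical sketch: the size of jnp.nonzero's result depends on the values
# in `x`, not only on its shape, so it cannot be written as a polynomial of
# the dimension variable `b`, and lowering raises an error.
jax2tf.convert(lambda x: jnp.nonzero(x)[0],
               polymorphic_shapes=["b"])(np.arange(5, dtype=np.float32))
```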

### Details

@@ -613,6 +617,38 @@ jax2tf.convert(lambda x: jnp.reshape(x, (2, -1)),
polymorphic_shapes=["(2*b, ...)"])(np.ones((4, 5, 7)))
```

### Dimension variables must be solvable from the input shapes

`jax2tf` will generate code to derive the values of the dimension variables
from the input shapes. This works only if dimension polynomials in the input shapes are linear.
For example, the following `polymorphic_shapes` will result in errors:

```python
polymorphic_shapes = ["a * a"] # Not a linear polynomial
polymorphic_shapes = ["a + b"] # Too few equations to derive both `a` and `b`
```
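
By contrast, shapes in which every dimension variable can be recovered from a linear
equation pose no problem; two sketches based on the examples elsewhere in this README:

```python
polymorphic_shapes = ["2*b, ..."]  # Linear: `b` is half of the first input dimension
polymorphic_shapes = ["a, b"]      # Each variable is a dimension by itself
```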

If you are using native lowering, the restrictions are stronger: every dimension
variable must occur by itself as the value of some dimension of some input. For example,
the following will work:

```python
polymorphic_shapes = ["a, 2*a, b"]
polymorphic_shapes = ["a * a, a"]
```

Furthermore, when using native lowering, inputs that are not needed in the computation
are ignored, so the dimension variables must be derivable from the shapes of the used
inputs alone. In the following example, `x_unused` is not part of the computation, so its
shape cannot be used to derive the dimension variables, and you will get an error that
`a` cannot be derived:

```python
jax2tf.convert(lambda x_unused, y: y * 2.,
polymorphic_shapes=["b, a", "b, 2 * a"])(x, y)
```
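
By contrast (a sketch mirroring the unit tests added in this change), if every dimension
variable also occurs by itself as a dimension of a used input, the conversion succeeds
even though `x_unused` is ignored:

```python
x = np.ones((3,), dtype=np.float32)
y = np.ones((3,), dtype=np.float32)
# `b` occurs by itself in the shape of the used input `y`, so it can be derived.
jax2tf.convert(lambda x_unused, y: y * 2.,
               polymorphic_shapes=["b", "b"])(x, y)
```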


## Known issues

`jax2tf` has been in use since 2020 and the vast majority of users encounter
116 changes: 77 additions & 39 deletions jax/experimental/jax2tf/jax2tf.py
@@ -595,46 +595,28 @@ def _lower_native_and_run(fun_jax: Callable,
Special care must be taken in presence of shape polymorphism.
"""
# Look for shape polymorphism

# For each dimension variable, encode how to compute its value from the
# shape of the explicit arguments. E.g., "2.1" denotes args_tf[2].shape[1].
# The order of the dimension variables must match the order of the first N
# arguments of the lowered function.

# We now have two implementations for the native lowering. If --jax_dynamic_shapes
# then we use JAX's in-progress support for native dynamic shapes. In that
# case we assume that the dimension variables are listed in the order in which
# they are encountered by scanning the arguments and their shapes in order.
# If we don't use --jax_dynamic_shapes then the dimension variables are passed
# in the alphabetical order of their names.
abstracted_axes: Sequence[Dict[int, str]] = []
dim_args_spec_dict: Dict[str, str] = {} # map dim var name to dim_args_spec
dim_vars_seen: List[str] = [] # the dim var names in order
for arg_idx, aval in enumerate(args_avals):
one_abstract_axes = {}
for axis_idx, d in enumerate(aval.shape):
if not core.is_constant_dim(d):
d_var = d.to_var()
if d_var is None:
raise ValueError(f"Only simple dimension variables supported: {aval.shape}")
if not d_var in dim_vars_seen:
dim_args_spec_dict[d_var] = f"{arg_idx}.{axis_idx}"
dim_vars_seen.append(d_var)
one_abstract_axes[axis_idx] = d_var
abstracted_axes.append(one_abstract_axes)

if any(abstracted_axes):
if config.jax_dynamic_shapes:
# then we use JAX's in-progress support for native dynamic shapes, and we pass
# abstracted_axes to lowering functions. Otherwise, we just lower using
# abstract values whose shapes may include polynomials (already in args_avals).
if config.jax_dynamic_shapes:
abstracted_axes: Sequence[Dict[int, str]] = []
for arg_idx, aval in enumerate(args_avals):
one_abstract_axes = {}
for axis_idx, d in enumerate(aval.shape):
if not core.is_constant_dim(d):
d_var = d.to_var()
if d_var is None:
raise ValueError(f"Only trivial dimension polynomials on input: {aval.shape}")
one_abstract_axes[axis_idx] = d_var
abstracted_axes.append(one_abstract_axes)

if any(abstracted_axes):
abstracted_axes = tuple(abstracted_axes)
# In the order we have seen them
dim_args_spec = [dim_args_spec_dict[d_var] for d_var in dim_vars_seen]
else:
abstracted_axes = None # type: ignore
# In sorted order by name
dim_args_spec = [dim_args_spec_dict[d_var] for d_var in sorted(dim_vars_seen)]
else:
abstracted_axes = None # type: ignore
dim_args_spec = []

arg_specs_jax = [
jax.ShapeDtypeStruct(aval.shape, aval.dtype, named_shape=aval.named_shape)
@@ -647,7 +629,6 @@ def _lower_native_and_run(fun_jax: Callable,
# convert(f_jax), in which case a "jit" is implied. We also add a jit when
# we need to pass the abstracted axes.
fun_jax_lower = jax.jit(fun_jax, backend=backend,
keep_unused=True, # TODO: allow dropping unused
abstracted_axes=abstracted_axes).lower
else:
fun_jax_lower = fun_jax.lower
@@ -658,10 +639,6 @@ def _lower_native_and_run(fun_jax: Callable,
else:
mhlo_module = lowered.mhlo()
xla_call_module_version = 1
if logging.vlog_is_on(3):
mhlo_module_text = mlir.module_to_string(mhlo_module)
logging.vlog(3, "XlaCallModule (version=%d)\n%s", xla_call_module_version,
mhlo_module_text)

mhlo_serialized_module = mlir.module_to_bytecode(mhlo_module)
# Figure out the result types and shapes
@@ -685,6 +662,62 @@ def _out_type(jax_type):
return jax_type
out_types = tuple(_out_type(out_aval.dtype) for out_aval in out_avals)

module_kept_var_idx = lowered.compile_args["kept_var_idx"]
# We must compute the dim_args_spec: for each dimension variable, encode how
# to compute its value from the shape of the explicit arguments. E.g., "2.1"
# denotes args_tf[2].shape[1]. The order of the dimension variables must match
# the order of the first N arguments of the lowered function.
# If we use --jax_dynamic_shapes, the dimension variables are listed in the
# order in which they are encountered by scanning the arguments and their
# shapes in order. Otherwise, the dimension variables are passed in the
# alphabetical order of their names.
dim_args_spec_dict: Dict[str, str] = {} # map dim var name to dim_args_spec
dim_vars_order: List[str] = []
all_dim_vars: Set[str] = set()
current_kept_arg_idx = -1 # The index among the kept arguments
for arg_idx, aval in enumerate(args_avals):
is_kept = arg_idx in module_kept_var_idx
if is_kept:
current_kept_arg_idx += 1

for axis_idx, d in enumerate(aval.shape):
if not core.is_constant_dim(d):
# We collect dimension variables even from dropped args
all_dim_vars = all_dim_vars.union(d.get_vars())
if not is_kept: continue
d_var = d.to_var()
# We can compute dim vars only from trivial polynomials
if d_var is None: continue
if not d_var in dim_args_spec_dict:
dim_vars_order.append(d_var)
dim_args_spec_dict[d_var] = f"{current_kept_arg_idx}.{axis_idx}"

if all_dim_vars:
dim_args_spec_set = set(dim_vars_order)
if dim_args_spec_set != all_dim_vars:
missing = all_dim_vars.difference(dim_args_spec_set)
args_list = [f" Arg[{arg_idx}] - {'KEPT ' if arg_idx in module_kept_var_idx else 'DROPPED'}: {aval}"
for arg_idx, aval in enumerate(args_avals)]
raise ValueError(
"The following dimension variables cannot be computed from the static "
f"shapes of the kept lowered arguments: {missing}. These are the "
"argument shapes:\n" +
"\n".join(args_list) +
"\n"
"Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#dimension-variables-must-be-solvable-from-the-input-shapes for more details.")

if config.jax_dynamic_shapes:
# In the order we have seen them
dim_args_spec = [dim_args_spec_dict[d_var] for d_var in dim_vars_order]
else:
# In sorted order by name
dim_args_spec = [dim_args_spec_dict[d_var] for d_var in sorted(dim_vars_order)]
else:
dim_args_spec = []

args_avals = [aval for i, aval in enumerate(args_avals) if i in module_kept_var_idx]
args_tf = [atf for i, atf in enumerate(args_tf) if i in module_kept_var_idx]

# Apply the shardings on arguments and results for pjit. This is redundant
# because the mhlo_module_text will already contain the shardings, but it
# makes it easier for tools like the TPU inference converter to see the
@@ -694,6 +727,11 @@ def _out_type(jax_type):
args_tf = tuple(
map(_shard_value, args_tf, args_avals, lowered.compile_args["in_shardings"]))

if logging.vlog_is_on(3):
mhlo_module_text = mlir.module_to_string(mhlo_module)
logging.vlog(3, "XlaCallModule (version=%d, dim_args_spec=%s)\n%s",
xla_call_module_version, ", ".join(dim_args_spec),
mhlo_module_text)
res = tfxla.call_module(
args_tf,
version=xla_call_module_version,
4 changes: 3 additions & 1 deletion jax/experimental/jax2tf/shape_poly.py
@@ -886,5 +886,7 @@ def process_one_eqn(eqn: DimEquation) -> bool:
err_msg = (
f"Cannot solve for values of dimension variables {unsolved_vars} from "
f"the remaining dimension polynomials\n {eqns_str}.{_shapeenv_to_str()} "
"Dimension variables can be solved only from linear polynomials.")
"Dimension variables can be solved only from linear polynomials.\n"
"\n"
"Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#dimension-variables-must-be-solvable-from-the-input-shapes for more details.")
raise ValueError(err_msg)
22 changes: 22 additions & 0 deletions jax/experimental/jax2tf/tests/jax2tf_test.py
@@ -818,6 +818,28 @@ def inner1(y):

jax2tf.convert(func)(2.) # No error

def test_jit_unused(self):
def f_jax(x, y_unused):
return x * np.float32(2.)
x, y_unused = np.float32(5.), np.arange(7, dtype=np.int32)
res_tf = jax2tf.convert(jax.jit(f_jax, keep_unused=False))(x, y_unused)
self.assertAllClose(f_jax(x, None), res_tf)

def test_jit_unused_grad(self):
def f_jax(x, y_unused):
return x * np.float32(2.)

x, y_unused = np.float32(5.), np.arange(7, dtype=np.int32)
f_tf = jax2tf.convert(jax.jit(f_jax, keep_unused=False))
xv, y_unused_v = tf.Variable(x), tf.Variable(y_unused)
with tf.GradientTape() as tape:
res_tf = f_tf(xv, y_unused_v)
grad_tf_x, grad_tf_y = tape.gradient(res_tf, (xv, y_unused_v))

self.assertAllClose(f_jax(x, None), res_tf)
self.assertAllClose(np.float32(2.), grad_tf_x)
self.assertIsNone(grad_tf_y)

def test_nested_convert_error(self):
def outer(y):
return jax2tf.convert(jnp.sin)(y) # Inner convert takes tracer args
70 changes: 69 additions & 1 deletion jax/experimental/jax2tf/tests/shape_poly_test.py
@@ -20,7 +20,6 @@
import collections
import functools
from functools import partial
import logging
import operator
import re

@@ -722,6 +721,70 @@ def f_jax(x): # x: f32[w, h]
input_signature=[tf.TensorSpec([None, None])],
polymorphic_shapes=["w, h"])

def test_non_trivial_polynomials(self):
if config.jax_dynamic_shapes:
raise unittest.SkipTest("--jax_dynamic_shapes supports only trivial polynomials")
# We can handle non-trivial polynomials in the input shape,
# as long as all variables also occur in trivial polynomials
self.CheckShapePolymorphism(
lambda x, y: x + y.reshape((-1,)),
input_signature=[tf.TensorSpec([None]), tf.TensorSpec([None, None])],
polymorphic_shapes=["b * b", "b, b"])

def test_unused_args(self):
# Tests with functions that do not use their inputs.

# First arg unused, not polymorphic
self.CheckShapePolymorphism(
lambda x_unused, y: y * 2.0,
input_signature=[tf.TensorSpec([]), tf.TensorSpec([None])],
polymorphic_shapes=[None, "b"])

# Some args unused, not polymorphic
self.CheckShapePolymorphism(
lambda x_unused, y, z_unused, w: jnp.concatenate([y, w]),
input_signature=[tf.TensorSpec([]), tf.TensorSpec([None]),
tf.TensorSpec([]), tf.TensorSpec([None])],
polymorphic_shapes=[None, "b1", None, "b2"])

# A polymorphic arg is not used, but the dimension var appears
# in a used arg also
self.CheckShapePolymorphism(
lambda x_unused, y: y * 2.0,
input_signature=[tf.TensorSpec([None]), tf.TensorSpec([None])],
polymorphic_shapes=["b", "b"])

# A polymorphic arg is not used, and the dimension var does not appear
# elsewhere.
if config.jax2tf_default_experimental_native_lowering:
with self.assertRaisesRegex(ValueError,
"The following dimension variables cannot be computed"):
self.CheckShapePolymorphism(
lambda x_unused, y: y * 2.0,
input_signature=[tf.TensorSpec([None]), tf.TensorSpec([None])],
polymorphic_shapes=["b1", "b2"])
else:
self.CheckShapePolymorphism(
lambda x_unused, y: y * 2.0,
input_signature=[tf.TensorSpec([None]), tf.TensorSpec([None])],
polymorphic_shapes=["b1", "b2"])

# A polymorphic arg is not used, and the dimension var does appear
# elsewhere but not as a trivial monomial.
if config.jax2tf_default_experimental_native_lowering:
with self.assertRaisesRegex(ValueError,
"The following dimension variables cannot be computed"):
self.CheckShapePolymorphism(
lambda x_unused, y: y * 2.0,
input_signature=[tf.TensorSpec([None]), tf.TensorSpec([None])],
polymorphic_shapes=["b1", "b1 * b1"])
else:
self.CheckShapePolymorphism(
lambda x_unused, y: y * 2.0,
input_signature=[tf.TensorSpec([None]), tf.TensorSpec([None])],
polymorphic_shapes=["b1", "b1 * b1"])


def test_with_custom_vjp(self):
"""Shape-polymorphic custom VJP."""

@@ -1065,6 +1128,11 @@ def f_jax(x):
jax2tf.convert(f_jax, polymorphic_shapes=["b, b"])).get_concrete_function(tf.TensorSpec([None, None], dtype=np.float32))
self.assertEqual(1, f_tf(x45))

x = np.ones((5,), dtype=np.float32)
with self.assertRaisesRegex(ValueError,
"Cannot solve for values of dimension variables"):
jax2tf.convert(lambda x: x, polymorphic_shapes=["a + b"])(x)


class DimAsValueTest(tf_test_util.JaxToTfTestCase):
"""Dimension polynomials used as values in the JAX computation."""
