diff --git a/jax/experimental/jax2tf/README.md b/jax/experimental/jax2tf/README.md
index 352d51fe46e1..010e6a780dd8 100644
--- a/jax/experimental/jax2tf/README.md
+++ b/jax/experimental/jax2tf/README.md
@@ -381,6 +381,10 @@ lowered with the batch dimension polymorphic and the remaining dimensions concre
 
 It is reasonable to expect that there will be JAX programs for which there is
 a shape-polymorphic TensorFlow graph, but which will give an error when lowering with jax2tf.
+In general, you should expect that shape polymorphism can handle those programs for which
+all the intermediate shapes can be expressed as polynomials in the dimension variables
+appearing in the input shapes. In particular, this does not include programs whose
+intermediate shapes depend on the data.
 
 ### Details
 
@@ -613,6 +617,38 @@ jax2tf.convert(lambda x: jnp.reshape(x, (2, -1)),
                polymorphic_shapes=["(2*b, ...)"])(np.ones((4, 5, 7)))
 ```
 
+### Dimension variables must be solvable from the input shapes
+
+`jax2tf` will generate code to derive the values of the dimension variables
+from the input shapes. This works only if the dimension polynomials in the input shapes are linear.
+For example, the following `polymorphic_shapes` will result in errors:
+
+```python
+polymorphic_shapes = ["a * a"]  # Not a linear polynomial
+polymorphic_shapes = ["a + b"]  # Too few equations to derive both `a` and `b`
+```
+
+If you are using native lowering, the restrictions are stronger: every dimension
+variable must occur as the value of some dimension of some input, e.g.,
+the following will work:
+
+```python
+polymorphic_shapes = ["a, 2*a, b"]
+polymorphic_shapes = ["a * a, a"]
+```
+
+Furthermore, when using native lowering, the inputs that are not needed in the computation
+are ignored, so the dimension variables must be derivable from the used inputs alone.
+In the following example, `x_unused` is not part of the computation, so its
+shape cannot be used to derive the dimension variables, and you will
+get an error that `a` cannot be derived:
+
+```python
+jax2tf.convert(lambda x_unused, y: y * 2.,
+               polymorphic_shapes=["b, a", "b, 2 * a"])(x, y)
+```
+
+
 ## Known issues
 
 `jax2tf` has been in use since 2020 and the vast majority of users encounter
diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py
index b5441da301dc..9cc48382e1d3 100644
--- a/jax/experimental/jax2tf/jax2tf.py
+++ b/jax/experimental/jax2tf/jax2tf.py
@@ -595,46 +595,28 @@ def _lower_native_and_run(fun_jax: Callable,
     Special care must be taken in presence of shape polymorphism.
   """
   # Look for shape polymorphism
-
-  # For each dimension variable, encode how to compute its value from the
-  # shape of the explicit arguments. E.g., "2.1" denotes args_tf[2].shape[1].
-  # The order of the dimension variables must match the order of the first N
-  # arguments of the lowered function.
   # We now have two implementations for the native lowering. If --jax_dynamic_shapes
-  # then we use JAX's in-progress support for native dynamic shapes. In that
-  # case we assume that the dimension variables are listed in the order in which
-  # they are encountered by scanning the arguments and their shapes in order.
-  # If we don't use --jax_dynamic_shapes then the dimension variables are passed
-  # in the alphabetical order of their names.
-  abstracted_axes: Sequence[Dict[int, str]] = []
-  dim_args_spec_dict: Dict[str, str] = {}  # map dim var name to dim_args_spec
-  dim_vars_seen: List[str] = []  # the dim var names in order
-  for arg_idx, aval in enumerate(args_avals):
-    one_abstract_axes = {}
-    for axis_idx, d in enumerate(aval.shape):
-      if not core.is_constant_dim(d):
-        d_var = d.to_var()
-        if d_var is None:
-          raise ValueError(f"Only simple dimension variables supported: {aval.shape}")
-        if not d_var in dim_vars_seen:
-          dim_args_spec_dict[d_var] = f"{arg_idx}.{axis_idx}"
-          dim_vars_seen.append(d_var)
-        one_abstract_axes[axis_idx] = d_var
-    abstracted_axes.append(one_abstract_axes)
-
-  if any(abstracted_axes):
-    if config.jax_dynamic_shapes:
+  # then we use JAX's in-progress support for native dynamic shapes, and we pass
+  # abstracted_axes to lowering functions. Otherwise, we just lower using
+  # abstract values whose shapes may include polynomials (already in args_avals).
+  if config.jax_dynamic_shapes:
+    abstracted_axes: Sequence[Dict[int, str]] = []
+    for arg_idx, aval in enumerate(args_avals):
+      one_abstract_axes = {}
+      for axis_idx, d in enumerate(aval.shape):
+        if not core.is_constant_dim(d):
+          d_var = d.to_var()
+          if d_var is None:
+            raise ValueError(f"Only trivial dimension polynomials on input: {aval.shape}")
+          one_abstract_axes[axis_idx] = d_var
+      abstracted_axes.append(one_abstract_axes)
+
+    if any(abstracted_axes):
       abstracted_axes = tuple(abstracted_axes)
-      # In the order we have seen them
-      dim_args_spec = [dim_args_spec_dict[d_var] for d_var in dim_vars_seen]
     else:
       abstracted_axes = None  # type: ignore
-      # In sorted order by name
-      dim_args_spec = [dim_args_spec_dict[d_var] for d_var in sorted(dim_vars_seen)]
   else:
     abstracted_axes = None  # type: ignore
-    dim_args_spec = []
 
   arg_specs_jax = [
       jax.ShapeDtypeStruct(aval.shape, aval.dtype, named_shape=aval.named_shape)
@@ -647,7 +629,6 @@ def _lower_native_and_run(fun_jax: Callable,
     # convert(f_jax), in which case a "jit" is implied. We also add a jit when
     # we need to pass the abstracted axes.
     fun_jax_lower = jax.jit(fun_jax, backend=backend,
-                            keep_unused=True,  # TODO: allow dropping unused
                             abstracted_axes=abstracted_axes).lower
   else:
     fun_jax_lower = fun_jax.lower
@@ -658,10 +639,6 @@ def _lower_native_and_run(fun_jax: Callable,
   else:
     mhlo_module = lowered.mhlo()
     xla_call_module_version = 1
-  if logging.vlog_is_on(3):
-    mhlo_module_text = mlir.module_to_string(mhlo_module)
-    logging.vlog(3, "XlaCallModule (version=%d)\n%s", xla_call_module_version,
-                 mhlo_module_text)
   mhlo_serialized_module = mlir.module_to_bytecode(mhlo_module)
 
   # Figure out the result types and shapes
@@ -685,6 +662,62 @@ def _out_type(jax_type):
     return jax_type
 
   out_types = tuple(_out_type(out_aval.dtype) for out_aval in out_avals)
+  module_kept_var_idx = lowered.compile_args["kept_var_idx"]
+  # We must compute the dim_args_spec: for each dimension variable, encode how
+  # to compute its value from the shape of the explicit arguments. E.g., "2.1"
+  # denotes args_tf[2].shape[1]. The order of the dimension variables must match
+  # the order of the first N arguments of the lowered function.
+  # If we use --jax_dynamic_shapes, the dimension variables are listed in the
+  # order in which they are encountered by scanning the arguments and their
+  # shapes in order. Otherwise, the dimension variables are passed in the
+  # alphabetical order of their names.
+  dim_args_spec_dict: Dict[str, str] = {}  # map dim var name to dim_args_spec
+  dim_vars_order: List[str] = []
+  all_dim_vars: Set[str] = set()
+  current_kept_arg_idx = -1  # The index among the kept arguments
+  for arg_idx, aval in enumerate(args_avals):
+    is_kept = arg_idx in module_kept_var_idx
+    if is_kept:
+      current_kept_arg_idx += 1
+
+    for axis_idx, d in enumerate(aval.shape):
+      if not core.is_constant_dim(d):
+        # We collect dimension variables even from dropped args
+        all_dim_vars = all_dim_vars.union(d.get_vars())
+        if not is_kept: continue
+        d_var = d.to_var()
+        # We can compute dim vars only from trivial polynomials
+        if d_var is None: continue
+        if not d_var in dim_args_spec_dict:
+          dim_vars_order.append(d_var)
+          dim_args_spec_dict[d_var] = f"{current_kept_arg_idx}.{axis_idx}"
+
+  if all_dim_vars:
+    dim_args_spec_set = set(dim_vars_order)
+    if dim_args_spec_set != all_dim_vars:
+      missing = all_dim_vars.difference(dim_args_spec_set)
+      args_list = [f"  Arg[{arg_idx}] - {'KEPT ' if arg_idx in module_kept_var_idx else 'DROPPED'}: {aval}"
+                   for arg_idx, aval in enumerate(args_avals)]
+      raise ValueError(
+          "The following dimension variables cannot be computed from the static "
+          f"shapes of the kept lowered arguments: {missing}. These are the "
+          "argument shapes:\n" +
+          "\n".join(args_list) +
+          "\n"
+          "Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#dimension-variables-must-be-solvable-from-the-input-shapes for more details.")
+
+    if config.jax_dynamic_shapes:
+      # In the order we have seen them
+      dim_args_spec = [dim_args_spec_dict[d_var] for d_var in dim_vars_order]
+    else:
+      # In sorted order by name
+      dim_args_spec = [dim_args_spec_dict[d_var] for d_var in sorted(dim_vars_order)]
+  else:
+    dim_args_spec = []
+
+  args_avals = [aval for i, aval in enumerate(args_avals) if i in module_kept_var_idx]
+  args_tf = [atf for i, atf in enumerate(args_tf) if i in module_kept_var_idx]
+
   # Apply the shardings on arguments and results for pjit. This is redundant
   # because the mhlo_module_text will already contain the shardings, but it
   # makes it easier for tools like the TPU inference converter to see the
@@ -694,6 +727,11 @@ def _out_type(jax_type):
     args_tf = tuple(
         map(_shard_value, args_tf, args_avals, lowered.compile_args["in_shardings"]))
 
+  if logging.vlog_is_on(3):
+    mhlo_module_text = mlir.module_to_string(mhlo_module)
+    logging.vlog(3, "XlaCallModule (version=%d, dim_args_spec=%s)\n%s",
+                 xla_call_module_version, ", ".join(dim_args_spec),
+                 mhlo_module_text)
   res = tfxla.call_module(
       args_tf,
       version=xla_call_module_version,
diff --git a/jax/experimental/jax2tf/shape_poly.py b/jax/experimental/jax2tf/shape_poly.py
index ca33b1690ebb..032b1fec735e 100644
--- a/jax/experimental/jax2tf/shape_poly.py
+++ b/jax/experimental/jax2tf/shape_poly.py
@@ -886,5 +886,7 @@ def process_one_eqn(eqn: DimEquation) -> bool:
   err_msg = (
       f"Cannot solve for values of dimension variables {unsolved_vars} from "
      f"the remaining dimension polynomials\n {eqns_str}.{_shapeenv_to_str()} "
-      "Dimension variables can be solved only from linear polynomials.")
+      "Dimension variables can be solved only from linear polynomials.\n"
+      "\n"
+      "Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#dimension-variables-must-be-solvable-from-the-input-shapes for more details.")
   raise ValueError(err_msg)
diff --git a/jax/experimental/jax2tf/tests/jax2tf_test.py b/jax/experimental/jax2tf/tests/jax2tf_test.py
index 39d6eb8f4a36..b5f1525b5d37 100644
--- a/jax/experimental/jax2tf/tests/jax2tf_test.py
+++ b/jax/experimental/jax2tf/tests/jax2tf_test.py
@@ -818,6 +818,28 @@ def inner1(y):
 
     jax2tf.convert(func)(2.)  # No error
 
+  def test_jit_unused(self):
+    def f_jax(x, y_unused):
+      return x * np.float32(2.)
+    x, y_unused = np.float32(5.), np.arange(7, dtype=np.int32)
+    res_tf = jax2tf.convert(jax.jit(f_jax, keep_unused=False))(x, y_unused)
+    self.assertAllClose(f_jax(x, None), res_tf)
+
+  def test_jit_unused_grad(self):
+    def f_jax(x, y_unused):
+      return x * np.float32(2.)
+
+    x, y_unused = np.float32(5.), np.arange(7, dtype=np.int32)
+    f_tf = jax2tf.convert(jax.jit(f_jax, keep_unused=False))
+    xv, y_unused_v = tf.Variable(x), tf.Variable(y_unused)
+    with tf.GradientTape() as tape:
+      res_tf = f_tf(xv, y_unused_v)
+    grad_tf_x, grad_tf_y = tape.gradient(res_tf, (xv, y_unused_v))
+
+    self.assertAllClose(f_jax(x, None), res_tf)
+    self.assertAllClose(np.float32(2.), grad_tf_x)
+    self.assertIsNone(grad_tf_y)
+
   def test_nested_convert_error(self):
     def outer(y):
       return jax2tf.convert(jnp.sin)(y)  # Inner convert takes tracer args
diff --git a/jax/experimental/jax2tf/tests/shape_poly_test.py b/jax/experimental/jax2tf/tests/shape_poly_test.py
index e4e1f8a19193..e34a47db389a 100644
--- a/jax/experimental/jax2tf/tests/shape_poly_test.py
+++ b/jax/experimental/jax2tf/tests/shape_poly_test.py
@@ -20,7 +20,6 @@
 import collections
 import functools
 from functools import partial
-import logging
 import operator
 import re
 
@@ -722,6 +721,70 @@ def f_jax(x):  # x: f32[w, h]
         input_signature=[tf.TensorSpec([None, None])],
         polymorphic_shapes=["w, h"])
 
+  def test_non_trivial_polynomials(self):
+    if config.jax_dynamic_shapes:
+      raise unittest.SkipTest("--jax_dynamic_shapes supports only trivial polynomials")
+    # We can handle non-trivial polynomials in the input shape,
+    # as long as all variables also occur in trivial polynomials
+    self.CheckShapePolymorphism(
+        lambda x, y: x + y.reshape((-1,)),
+        input_signature=[tf.TensorSpec([None]), tf.TensorSpec([None, None])],
+        polymorphic_shapes=["b * b", "b, b"])
+
+  def test_unused_args(self):
+    # Tests with functions that do not use their inputs.
+
+    # First arg unused, not polymorphic
+    self.CheckShapePolymorphism(
+        lambda x_unused, y: y * 2.0,
+        input_signature=[tf.TensorSpec([]), tf.TensorSpec([None])],
+        polymorphic_shapes=[None, "b"])
+
+    # Some args unused, not polymorphic
+    self.CheckShapePolymorphism(
+        lambda x_unused, y, z_unused, w: jnp.concatenate([y, w]),
+        input_signature=[tf.TensorSpec([]), tf.TensorSpec([None]),
+                         tf.TensorSpec([]), tf.TensorSpec([None])],
+        polymorphic_shapes=[None, "b1", None, "b2"])
+
+    # A polymorphic arg is not used, but the dimension var appears
+    # in a used arg also
+    self.CheckShapePolymorphism(
+        lambda x_unused, y: y * 2.0,
+        input_signature=[tf.TensorSpec([None]), tf.TensorSpec([None])],
+        polymorphic_shapes=["b", "b"])
+
+    # A polymorphic arg is not used, and the dimension var does not appear
+    # elsewhere.
+    if config.jax2tf_default_experimental_native_lowering:
+      with self.assertRaisesRegex(ValueError,
+                                  "The following dimension variables cannot be computed"):
+        self.CheckShapePolymorphism(
+            lambda x_unused, y: y * 2.0,
+            input_signature=[tf.TensorSpec([None]), tf.TensorSpec([None])],
+            polymorphic_shapes=["b1", "b2"])
+    else:
+      self.CheckShapePolymorphism(
+          lambda x_unused, y: y * 2.0,
+          input_signature=[tf.TensorSpec([None]), tf.TensorSpec([None])],
+          polymorphic_shapes=["b1", "b2"])
+
+    # A polymorphic arg is not used, and the dimension var does appear
+    # elsewhere but not as a trivial monomial.
+    if config.jax2tf_default_experimental_native_lowering:
+      with self.assertRaisesRegex(ValueError,
+                                  "The following dimension variables cannot be computed"):
+        self.CheckShapePolymorphism(
+            lambda x_unused, y: y * 2.0,
+            input_signature=[tf.TensorSpec([None]), tf.TensorSpec([None])],
+            polymorphic_shapes=["b1", "b1 * b1"])
+    else:
+      self.CheckShapePolymorphism(
+          lambda x_unused, y: y * 2.0,
+          input_signature=[tf.TensorSpec([None]), tf.TensorSpec([None])],
+          polymorphic_shapes=["b1", "b1 * b1"])
+
+
   def test_with_custom_vjp(self):
     """Shape-polymorphic custom VJP."""
 
@@ -1065,6 +1128,11 @@ def f_jax(x):
         jax2tf.convert(f_jax, polymorphic_shapes=["b, b"])).get_concrete_function(tf.TensorSpec([None, None], dtype=np.float32))
     self.assertEqual(1, f_tf(x45))
 
+    x = np.ones((5,), dtype=np.float32)
+    with self.assertRaisesRegex(ValueError,
+                                "Cannot solve for values of dimension variables"):
+      jax2tf.convert(lambda x: x, polymorphic_shapes=["a + b"])(x)
+
 
 class DimAsValueTest(tf_test_util.JaxToTfTestCase):
   """Dimension polynomials used as values in the JAX computation."""
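
As a quick illustration of the behavior documented in the new README section and exercised by the new tests: a dimension variable can be used only if its value is recoverable from the concrete input shapes at call time. This is a sketch only (the array and its shape are invented for the example, and it assumes a jax2tf version that includes this change); the error text is the message extended in shape_poly.py above.

```python
import numpy as np
from jax.experimental import jax2tf

x = np.ones((5,), dtype=np.float32)

# Works: the dimension variable `a` is itself the value of the input dimension,
# so its value (5) can be read off the concrete input shape at call time.
jax2tf.convert(lambda x: x, polymorphic_shapes=["a"])(x)

# Fails: a single input dimension equal to `a + b` is not enough to solve for
# both variables, so the call raises
# "Cannot solve for values of dimension variables ...".
try:
  jax2tf.convert(lambda x: x, polymorphic_shapes=["a + b"])(x)
except ValueError as e:
  print(e)
```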
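As a reading aid for the `dim_args_spec` computation added in jax2tf.py above: each entry is a string of the form `"<kept_arg_index>.<axis_index>"` telling the lowered module where to read a dimension variable from at call time. The helper below is a minimal standalone sketch, not the jax2tf implementation (its name is invented, it assumes all arguments are kept, only trivial dimension polynomials, and `--jax_dynamic_shapes` off, so variables are emitted in alphabetical order).

```python
from typing import Dict, List, Sequence, Tuple, Union

def sketch_dim_args_spec(arg_shapes: Sequence[Tuple[Union[int, str], ...]]) -> List[str]:
  """Illustrative re-implementation of the dim_args_spec encoding.

  Dimension variables are given as strings, concrete dimensions as ints.
  """
  spec: Dict[str, str] = {}
  for arg_idx, shape in enumerate(arg_shapes):
    for axis_idx, d in enumerate(shape):
      # Record the first place where each dimension variable appears.
      if isinstance(d, str) and d not in spec:
        spec[d] = f"{arg_idx}.{axis_idx}"
  # Without --jax_dynamic_shapes the values are passed sorted by variable name.
  return [spec[v] for v in sorted(spec)]

# For arguments shaped f32[b, 8] and f32[b, c]:
# `b` is read from args_tf[0].shape[0] and `c` from args_tf[1].shape[1].
assert sketch_dim_args_spec([("b", 8), ("b", "c")]) == ["0.0", "1.1"]
```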