From a3ffdeed54dcd7889afde3a794405d4efb3db3d6 Mon Sep 17 00:00:00 2001 From: George Necula Date: Thu, 8 Sep 2022 11:16:39 +0300 Subject: [PATCH] [jax2tf] Updates to the documentation One change I made throughout is to use the verb lower instead of convert when refering to what jax2tf does. This is more accurate, because what is happening is that the lowering to HLO is replaced with a lowering to TF ops. Some users were confused by the use of the verb convert as suggesting some sort of source-to-source translation of JAX programs to TF, which sounds very complicated and risky. --- jax/experimental/jax2tf/README.md | 559 +++++++++--------- .../jax2tf/g3doc/convert_models_results.md | 2 +- 2 files changed, 284 insertions(+), 277 deletions(-) diff --git a/jax/experimental/jax2tf/README.md b/jax/experimental/jax2tf/README.md index 4704ee2ba7d5..3a58644451b3 100644 --- a/jax/experimental/jax2tf/README.md +++ b/jax/experimental/jax2tf/README.md @@ -7,13 +7,16 @@ This package provides experimental support for interoperation between JAX and Te There are two interoperation directions: - `jax2tf.convert`: for using JAX functions in a TensorFlow context, e.g., -for eager or graph execution, or for saving as a TensorFlow SavedModel; and +for eager or graph TensorFlow execution, +or for saving as a TensorFlow SavedModel; and - `jax2tf.call_tf`: for using TensorFlow functions in a JAX context, e.g., to call a TensorFlow library or a SavedModel inside a JAX function. -The `jax2tf.convert` mechanism can wrap a function -written in JAX, possibly including JAX transformations, and turn it into -a function that uses only TensorFlow operations. The converted function +`jax2tf.convert` directs JAX to use an alternative code +generator (lowering) and emit TensorFlow operations instead of the regular HLO operations +emitted in native JAX lowering. In all other respects the JAX function is +processed as in native JAX execution, e.g., for the JAX transformations. +The resulting function can be called or traced from TensorFlow and will behave as if it was written in TensorFlow. In practice this means that you can take some code written in JAX and execute it using TensorFlow eager mode, or stage it out as a TensorFlow graph, even use it @@ -26,8 +29,8 @@ or TensorFlow Hub. This package also contains the `jax2tf.call_tf` mechanism to call TensorFlow functions from JAX. These functions can be called in JAX's op-by-op execution mode, -in which case the callee is executed in eager mode, or in JAX's jit (staged) context, -in which case the callee is compiled to XLA and embedded in JAX's staged XLA. +in which case the callee is executed in TensorFlow eager mode, or in JAX's jit (staged) context, +in which case the callee is compiled to XLA and embedded in JAX's lowered HLO. Both interoperation directions rely on the ability of TensorFlow to use the XLA compiler (`tf.function(jit_compile=True)`). For the @@ -35,9 +38,10 @@ TensorFlow to use the XLA compiler (`tf.function(jit_compile=True)`). For the that the performance characteristics of the code match those of the JAX source. For the `call_tf` direction, JIT compilation is an essential part of the implementation mechanism. Only TensorFlow functions that can be JIT-compiled can be called from -JAX. Since the TensorFlow functions that are produced by `jax2tf.convert` can -be JIT-compiled by design, we can round-trip from JAX to TensorFlow -(e.g., a SavedModel) and back. +JAX in a jit context. +Since the TensorFlow functions that are produced by `jax2tf.convert` can +be JIT-compiled by design, we can call them using `jax2tf.call_tf` thus achieving +a round-trip from JAX to TensorFlow (e.g., a SavedModel) and back. We describe below some general concepts and capabilities, first for `jax2tf.convert` and [later](#calling-tensorflow-functions-from-jax) @@ -51,13 +55,12 @@ For details on saving a batch-polymorphic SavedModel see [below](#shape-polymorp See also some internal ongoing design discussions at `go/jax2tf-doc`. -## Usage: converting basic functions. +## Usage: basic functions. As a rule of thumb, if you can `jax.jit` your function then you should be able to use `jax2tf.convert`: ```python -import jax from jax.experimental import jax2tf from jax import numpy as jnp @@ -67,7 +70,7 @@ import tensorflow as tf def f_jax(x): return jnp.sin(jnp.cos(x)) -# jax2tf.convert is a higher order function that returns a wrapped function with +# jax2tf.convert is a higher-order function that returns a wrapped function with # the same signature as your input function but accepting TensorFlow tensors (or # variables) as input. f_tf = jax2tf.convert(f_jax) @@ -81,10 +84,10 @@ f_tf_graph = tf.function(f_tf, autograph=False) ``` The Autograph feature of `tf.function` cannot be expected to work on -functions converted from JAX as above, so it is recommended to +functions lowered from JAX as above, so it is recommended to set `autograph=False` in order to avoid warnings or outright errors. -It is a good idea to use XLA to compile the converted function; that is +It is a good idea to use XLA to compile the lowered function; that is the scenario for which we are optimizing for numerical and performance accuracy w.r.t. the JAX execution: @@ -118,7 +121,7 @@ restored_model = tf.saved_model.load('/some/directory') ``` An important point is that in the above code snippet **everything after the -jax2tf conversion is standard TensorFlow code. +jax2tf invocation is standard TensorFlow code. In particular, the saving of the model is not directly part of the jax2tf API, and the user has full control over how to create the SavedModel**. @@ -149,19 +152,19 @@ def model_jax(inputs): return param0 + param1 * inputs ``` -If you just convert and save the model directly, the values of +If you just lower and save the model directly, the values of `param0` and `param1` will be embedded in the computation graph. In fact, the value of `param1` is needed for the gradient computation and will be embedded twice: once in the computation graph for the forward computation and once for the backward computation, -unless you turn off the conversion of gradients or their saving as discussed +unless you turn off the staging of gradients or their saving as discussed further below (e.g., `with_gradient=False`). Note also that if one views the above function as an ML model parameterized by `param0` and `param1` then the gradient function will be w.r.t. the inputs, while you probably want gradients w.r.t. the parameters. A better way to deal with parameters (or any large constants) is to -pass them as parameters to the function to be converted: +pass them as parameters to the function to be lowered: ```python def model_jax(params, inputs): @@ -194,19 +197,20 @@ For examples of how to save a Flax model as a SavedModel see the ### Saved model and differentiation -The converted code supports differentiation from TensorFlow. In order to +The code lowered from JAX supports differentiation from TensorFlow. In order to ensure that the result of TensorFlow differentiation is identical to the -one that JAX differentiation would produce, the jax2tf converter will -annotate the converter function with a ``tf.custom_gradient`` that, +one that JAX differentiation would produce, we will +annotate the lowered primal function with a ``tf.custom_gradient`` that, upon TensorFlow differentiation, will lazily -call into JAX to compute the ``jax.vjp`` of the converted function, followed by -jax2tf conversion. This ensures that ultimately it is JAX that performs the +call into JAX to compute the ``jax.vjp`` of the lowered primal function, followed by +jax2tf lowering of the gradient function. +This ensures that ultimately it is JAX that performs the differentiation, thus respecting any custom gradients that may be present in the original function. -The jax2tf converter has an option ``with_gradient=False`` to skip the -custom gradients and wrap instead the converted function with -``tf.raw_ops.PreventGradient`` to generated an error in case a gradient +The `jax2tf.convert` function has an option ``with_gradient=False`` to skip the +custom gradients and wrap instead the lowered function with +``tf.raw_ops.PreventGradient`` to generate an error in case a gradient computation is attempted. SavedModels enables saving custom derivative rules by using the `experimental_custom_gradients` option: @@ -257,21 +261,21 @@ you will not be able to compute the gradients of the function loaded from the Sa ## Support for partitioning jax2tf supports JAX functions that use `jax.pjit`, for single-host meshes. -The conversion is actually similar as for a `jax.jit`, except that the +The lowering is actually similar as for a `jax.jit`, except that the arguments and results will be wrapped with `tensorflow.compiler.xla.experimental.xla_sharding.XlaSharding` TensorFlow ops. Note that when saving a model, the parameters to the model are wrapped with -`tf.Variable` before calling the converted function (see [above](#saved_model_with_parameters)), +`tf.Variable` before calling the lowered function (see [above](#saved_model_with_parameters)), therefore outside of the `XlaSharding` wrapper. ## Shape-polymorphic conversion **The shape polymorphism support is work in progress. It is meant to be sound, -but it may fail to convert some programs. Please report any bugs you encounter.** +but it may fail to lower some programs. Please report any bugs you encounter.** We described above how to include in the SavedModel several specializations -of a converted function for a few specific input shapes. The converter can +of a lowered function for a few specific input shapes. `jax2tf` can also produce a shape-polymorphic TensorFlow graph that is usable with inputs of any shape matching certain constraints. This is useful, e.g., to allow a single SavedModel @@ -312,7 +316,7 @@ error messages. The real need for named shape variables arises when there are multiple unknown dimensions and there is a relationship between them. For example, -if the function to be converted is also polymorphic on the size of each +if the function to be lowered is also polymorphic on the size of each image while requiring the images to be square, we would add a dimension variable `d` to stand for the unknown image size: @@ -330,7 +334,7 @@ same shape of a batch of square matrices that can be passed to `jnp.matmul`. ### Correctness of shape-polymorphic tracing -We want to trust that the converted program produces the same results as the +We want to trust that the lowered program produces the same results as the original JAX program. More precisely: For any function `f_jax` and any input signature `abs_sig` containing partially @@ -354,22 +358,22 @@ by reusing the same JAX tracing and shape checking mechanism as when the shapes ### Coverage of shape-polymorphic tracing -Besides correctness, a secondary goal is to be able to convert many shape-polymorphic programs, +Besides correctness, a secondary goal is to be able to lower many shape-polymorphic programs, but at the very least batch-size-polymorphic programs, so that one SavedModel can be used for any batch sizes. For example, we want to ensure that any function written using `jax.vmap` at the top level can be -converted with the batch dimension polymorphic and the remaining dimensions concrete. +lowered with the batch dimension polymorphic and the remaining dimensions concrete. It is reasonable to expect that there will be JAX programs for which there is a -shape-polymorphic TensorFlow graph, but which will give an error when converting with jax2tf. +shape-polymorphic TensorFlow graph, but which will give an error when lowering with jax2tf. ### Details In order to be able to use shape polymorphism effectively with jax2tf, it -is worth considering what happens under the hood. When the converted function -is invoked with a `TensorSpec`, the jax2tf converter will combine the +is worth considering what happens under the hood. When the lowered function +is invoked with a `TensorSpec`, `jax2tf` will combine the `TensorSpec` from the actual argument with the `polymorphic_shapes` parameter to -obtain a shape abstraction to be used to specialize the converted function. +obtain a shape abstraction to be used to specialize the lowered function. Normally, the shape abstraction contains the dimension sizes, but in the presence of shape polymorphism, some dimensions may be dimension variables. @@ -406,7 +410,7 @@ A few examples of shape specifications and uses: * `polymorphic_shapes=["(b, _, _)", None]` can be used for a function with two arguments, the first having a batch leading dimension that should be polymorphic. The other dimensions for the first argument and the shape of the second argument are specialized based on the actual - `TensorSpec`, which must be known. The converted function can be used, e.g., + `TensorSpec`, which must be known. The lowered function can be used, e.g., with `TensorSpec`s `[None, 28, 28]` and `[28, 16]` for the first and second argument respectively. An alternative `TensorSpec` pair can be `[1, 28, 28]` and `[28, 16]`, in which case the JAX tracing is done for the same polymorphic shape given by @@ -481,13 +485,13 @@ jax2tf.convert(lambda x: 0 if x.shape[0] + 1 == x.shape[1] else 1, ``` Note that it would be unsound for JAX to compute `x.shape[0] + 1 == x.shape[1]` -as `False` and produce a converted function that returns `1` just because the dimension polynomials +as `False` and produce a lowered function that returns `1` just because the dimension polynomials are not identical: there are some concrete input shapes for which the function should return `0`. ### Dimension variables appearing in the numeric computation -There are some situations when dimension variables arise in the staged computation itself. +There are some situations when dimension variables arise in the lowered computation itself. You can see in the following example how elements from the input shapes `(1024, 28, 28)` and `(28, 28)` appear in the computation and specifically in the `shape` parameter of the `broadcast_in_dim` JAX primitive. @@ -508,12 +512,12 @@ print(jax.make_jaxpr(image_mask_jax)(np.ones((1024, 28, 28)), np.ones((28, 28))) jax2tf.convert(image_mask_jax, polymorphic_shapes=["(b, w, w)", "(w, w)"]) ``` -When tracing and converting with abstract shapes some primitive parameters will be dimension variables +When tracing and lowering with abstract shapes some primitive parameters will be dimension variables instead of just constants, e.g., the `shape` parameter of `broadcast_in_dim` will be `(1, w, w)`. Note that JAX primitives distinguish the inputs, which are array values, e.g., `b` for `broadcast_in_dim` above, and the parameters, e.g., `broadcast_dimensions` and `shape`. -The conversion of `image_mask_jax` would use `tf.shape` to compute the +The lowering of `image_mask_jax` would use `tf.shape` to compute the values of the dimension variables `b` and `w`: ```python @@ -524,7 +528,7 @@ def image_mask_tf(images, mask): [b, w, w])) ``` -To achieve this, when we start converting a function we construct a shape environment, +To achieve this, when we start lowering a function we construct a shape environment, mapping the dimension variables in the `polymorphic_shapes` specification to TensorFlow expressions using `tf.shape` on the input parameters. @@ -559,7 +563,7 @@ will want to ensure the size of the two axes is the same (`v == 4`). Note that `v` can stand for any integer greater than 0, so the value of the equality expression can be true or false. Since it is not always true that `v == 4`, the shape checking rules fail with the above error. -Since the converted function works only for square matrices, the correct +Since the lowered function works only for square matrices, the correct `polymorphic_shapes` is `["(v, v)"]`. @@ -618,27 +622,97 @@ jax2tf.convert(lambda x: jnp.reshape(x, (2, -1)), ## Known issues +`jax2tf` has been in use since 2020 and the vast majority of users encounter +no problems. However, there are a few rare corner cases +in which the different conventions of JAX and TensorFlow result in a breakage. +We try to give an exhaustive list below. + +### Different 64-bit precision in JAX and TensorFlow + +JAX behaves somewhat differently than TensorFlow in the handling +of 32-bit vs. 64-bit values. However, the `jax2tf` lowered function +always behaves like the JAX function. + +JAX interprets the type of Python scalars differently based on +`JAX_ENABLE_X64` flag. (See +[JAX - The Sharp Bits: Double (64bit) precision](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision).) +In the default configuration, the +flag is unset, and JAX interprets Python constants as 32-bit, +e.g., the type of `3.14` is `float32`. This is also what +TensorFlow always does. JAX goes further, it forces +all explicitly-specified 64-bit values to be interpreted as +32-bit: + +```python +# with JAX_ENABLE_X64=0 +jnp.sin(3.14) # Has type float32 +tf.math.sin(3.14) # Has type float32 + +jnp.sin(np.float64(3.14)) # Also has type float32 +tf.math.sin(np.float64(3.14)) # Has type float64 + +# The jax2tf.convert function behaves like the JAX function. +jax2tf.convert(jnp.sin)(3.14) # Has type float32 +jax2tf.convert(jnp.sin)(np.float64(3.14)) # Has type float32 + +# The following will still compute `sin` in float32 (with a tf.cast on the argument). +tf.function(jax2tf.convert(jnp.sin))(tf.Variable(3.14, tf.float64)) +``` + +When the `JAX_ENABLE_X64` flas is set, JAX uses 64-bit types +for Python scalars and respects the explicit 64-bit types: + +```python +# with JAX_ENABLE_X64=1 +jnp.sin(3.14) # Has type float64 +tf.math.sin(3.14) # Has type float32 + +# The jax2tf.convert function behaves like the JAX function. +jax2tf.convert(jnp.sin)(3.14) # Has type float64 + +# The following will compute `sin` in float64. +tf.function(jax2tf.convert(jnp.sin))(tf.Variable(3.14, tf.float64)) + +# The following will compute `sin` in float32. +tf.function(jax2tf.convert(jnp.sin))(tf.Variable(3.14)) +``` + +This is achieved by inserting `tf.cast` operations +on the input arguments inside the lowered function, +if necessary. + +If you want to create a `tf.Variable` or `tf.TensorSpec` with the +same dtype, you should use `jax2tf.dtype_of_val`: + +```python +# The following two calls will lower jax_fun at the same dtypes +# independently of the value of JAX_ENABLE_X64. +jax2tf.convert(jax_fun)(3.14) +jax2tf.convert(jax_fun)(tf.Variable(3.14, dtype=jax2tf.dtype_of_val(3.14)) +``` + ### Incomplete TensorFlow data type coverage There are a number of cases when the TensorFlow ops that are used by the -jax2tf converter are not supported by TensorFlow for the same data types as in JAX. +`jax2tf` are not supported by TensorFlow for the same data types as in JAX. There is an [up-to-date list of unimplemented cases](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md). -If you try to convert and run in TensorFlow a program with partially supported primitives, you may see TensorFlow errors that -a TensorFlow op is used with an supported data type, or that +If you try to lower and run in TensorFlow a program with partially supported primitives, +you may see TensorFlow errors that +a TensorFlow op is used with an unsupported data type, or that there is no supported TensorFlow kernel for the op for the given data type. The former case can happen even if you `jit_compile` the TensorFlow program, and it is a priority to fit. The latter -case only appears in TensorFlow non-compiled mode and you can +case only appears in TensorFlow non-compiled mode; you can avoid the problem if you use XLA to `jit_compile` (always recommended). Our priority is to ensure numerical and performance accuracy for -the converted program **when using XLA to compile the converted program**. -It is always a good idea to use XLA on the JAX-converted function. +the lowered program **when using XLA to compile the lowered program**. +It is always a good idea to use XLA on the lowered function. Sometimes you cannot compile the entire TensorFlow function for your -model, because in addition to the function that is converted from JAX, +model, because in addition to the function that is lowered from JAX, it may include some pre-processing TensorFlow code that is not compileable with XLA, e.g., string parsing. Even in those situations you can instruct TensorFlow to compile only the portion that originates @@ -647,40 +721,101 @@ from JAX: ```python def entire_tf_fun(x): y = preprocess_tf_fun_not_compileable(x) - # Compile the code that is converted from JAX + # Compile the code that is lowered from JAX z = tf.function(jax2tf.convert(compute_jax_fn), autograph=False, jit_compile=True)(y) return postprocess_tf_fun_not_compileable(z) ``` You won't be able to compile the `entire_tf_fun`, but you can still execute -it knowing that the JAX-converted code is compiled. You can even save +it knowing that the jax2tf-lowered code is compiled. You can even save the function to a SavedModel, knowing that upon restore the -JAX-converted code will be compiled. +jax2tf-lowered code will be compiled. For a more elaborate example, see the test `test_tf_mix_jax_with_uncompileable` in [savedmodel_test.py](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/tests/savedmodel_test.py). -### Missing converter features +### Functions whose arguments and results are nested Python data structures -There is currently no support for `pmap` or`xmap`, nor for the collective -operations. There is support for `pjit`. +`jax2tf` can lower functions with arguments and results that are nested +collections (tuples, lists, dictionaries) of numeric values or JAX arrays +([pytrees](https://jax.readthedocs.io/en/latest/pytrees.html)). The +resulting TensorFlow function will take the same kind of arguments except the +leaves can be numeric values or TensorFlow tensors (`tf.Tensor`, `tf.TensorSpec`, `tf.Variable`). -### SavedModel may be large +As long as the arguments use only standard Python containers (tuple, list, dictionaries), +both JAX and TensorFlow can flatten and unflatten them and you can use the lowered +function in TensorFlow without limitations. -If you suspect that the SavedModel is larger than it should be, check first -that you are not including the parameters as constants in the graph (see [above](#usage-saved-model)). +However, if your JAX function takes a custom container, you can register it with +the JAX `tree_util` module so that JAX will know how to operate with it, and you +can still lower the function to use it in TensorFlow +eager and with `tf.function`, but you won't be able to save it to a SavedModel, nor +will you be able to compute gradients with TensorFlow +(code from `jax2tf_test.test_custom_pytree_readme`): -### SavedModel supports only first-order gradients +```python +class CustomPair: + def __init__(self, a, b): + self.a = a + self.b = b -The `jax2tf`-converted function supports higher-order gradients, but when the -function is saved in a SavedModel, only the first-order gradient is saved. +# Register it with the JAX tree_util module +jax.tree_util.register_pytree_node(CustomPair, + lambda x: ((x.a, x.b), None), + lambda _, ab: CustomPair(*ab)) +def f_jax(pair: CustomPair): + return 2. * pair.a + 3. * pair.b + +x = CustomPair(4., 5.) +res_jax = f_jax(x) +# TF execution works as long as JAX can flatten the arguments +res_tf = jax2tf.convert(f_jax)(x) +self.assertAllClose(res_jax, res_tf.numpy()) +res_tf_2 = tf.function(jax2tf.convert(f_jax), autograph=False, jit_compile=True)(x) +``` -### Converting gradients for functions with integer arguments or unused arguments +If you want to save the function in a SavedModel or compute gradients, +you should construct a wrapper: + +```python + # wrapped TF function to use only standard containers +def f_tf_wrapped(a, b): + return f_tf(CustomPair(a, b)) + +# Try to put into SavedModel +my_model = tf.Module() +# Save a function that can take scalar inputs. +my_model.f = tf.function(f_tf_wrapped, autograph=False, + input_signature=[tf.TensorSpec([], tf.float32), + tf.TensorSpec([], tf.float32)]) +model_dir = os.path.join(absltest.get_default_test_tmpdir(), str(id(my_model))) +tf.saved_model.save(my_model, model_dir, + options=tf.saved_model.SaveOptions(experimental_custom_gradients=True)) + +# Restoring (note: the restored model does *not* require JAX to run, just XLA). +restored_model = tf.saved_model.load(model_dir) +def restored_f(pair: CustomPair): + return restored_model.f(pair.a, pair.b) + +res_tf_3 = restored_f(x) +self.assertAllClose(res_jax, res_tf_3) +grad_jax = jax.grad(f_jax)(x) + +x_v = [tf.Variable(x.a), tf.Variable(x.b)] +with tf.GradientTape() as tape: + res = f_tf_wrapped(*x_v) + grad_tf = tape.gradient(res, x_v) + +self.assertAllClose(grad_jax.a, grad_tf[0]) +self.assertAllClose(grad_jax.b, grad_tf[1]) +``` + +### Lowering gradients for functions with integer arguments or unused arguments When JAX differentiates functions with integer or boolean arguments, the gradients will be zero-vectors with a special `float0` type (see PR 4039](https://github.com/google/jax/pull/4039)). -This type is translated to `int32` when converting to TF. +This type is translated to `int32` when lowering to TF. For example, ```python @@ -719,7 +854,7 @@ returns the value `None` for the corresponding gradients. The `tape.gradient` function takes the option `tf.UnconnectedGradients.ZERO` to ask that gradients for unused arguments be zero. -Functions converted with `jax2tf.convert` behave the same way under +Functions lowered with `jax2tf.convert` behave the same way under `tf.UnconnectedGradients.ZERO`, but by default, they will return `None` only for gradients corresponding to integer arguments. @@ -747,159 +882,58 @@ g_jax2tf = tape.gradient(res, xs) # Returns: 0., 0., 2., None # Note that the gradient for x1 is 0. -g_jaxx2tf_0 = tape.gradient(res, xs, +g_jax2tf_0 = tape.gradient(res, xs, unconnected_gradients=tf.UnconnectedGradients.ZERO) # Returns: 0., 0., 2., 0 # In this case we get the same result as for TF native. ``` -### Functions whose arguments and results are Python nested data structures - -jax2tf can convert functions with arguments and results that are nested -collections (tuples, lists, dictionaries) of numeric values or JAX arrays -([pytrees](https://jax.readthedocs.io/en/latest/pytrees.html)). The -resulting TensorFlow function will take the same kind of arguments except the -leaves can be numeric values or TensorFlow tensors (`tf.Tensor`, `tf.TensorSpec`, `tf.Variable`). - -As long as the arguments use only standard Python containers (tuple, list, dictionaries), -both JAX and TensorFlow can flatten and unflatten them and you can use the converted -function in TensorFlow without limitations. - -However, if your JAX function takes a custom container, you can register it with -the JAX `tree_util` module so that JAX will know how to operate with it, and you -can still convert the function to use it in TensorFlow -eager and with `tf.function`, but you won't be able to save it to a SavedModel, nor -will you be able to compute gradients with TensorFlow -(code from `jax2tf_test.test_custom_pytree_readme`): -```python -class CustomPair: - def __init__(self, a, b): - self.a = a - self.b = b - -# Register it with the JAX tree_util module -jax.tree_util.register_pytree_node(CustomPair, - lambda x: ((x.a, x.b), None), - lambda _, ab: CustomPair(*ab)) -def f_jax(pair: CustomPair): - return 2. * pair.a + 3. * pair.b +### Errors due to tf.Module magic conversion during attribute assignment -x = CustomPair(4., 5.) -res_jax = f_jax(x) -# TF execution works as long as JAX can flatten the arguments -res_tf = jax2tf.convert(f_jax)(x) -self.assertAllClose(res_jax, res_tf.numpy()) -res_tf_2 = tf.function(jax2tf.convert(f_jax), autograph=False, jit_compile=True)(x) -``` +`tf.Module` will automatically wrap the standard Python container data types into +trackable classes during attribute assignment. +Python Dict/List/Tuple are changed to _DictWrapper/_ListWrapper/_TupleWrapper +classes. +In most situation, these Wrapper classes work exactly as the standard +Python data types. However, the low-level pytree data structures are different +and this can lead to errors. -If you want to save the function in a SavedModel or compute gradients, -you should construct a wrapper: +In such cases, the user can use this workaround: ```python - # wrapped TF function to use only standard containers -def f_tf_wrapped(a, b): - return f_tf(CustomPair(a, b)) - -# Try to put into SavedModel -my_model = tf.Module() -# Save a function that can take scalar inputs. -my_model.f = tf.function(f_tf_wrapped, autograph=False, - input_signature=[tf.TensorSpec([], tf.float32), - tf.TensorSpec([], tf.float32)]) -model_dir = os.path.join(absltest.get_default_test_tmpdir(), str(id(my_model))) -tf.saved_model.save(my_model, model_dir, - options=tf.saved_model.SaveOptions(experimental_custom_gradients=True)) - -# Restoring (note: the restored model does *not* require JAX to run, just XLA). -restored_model = tf.saved_model.load(model_dir) -def restored_f(pair: CustomPair): - return restored_model.f(pair.a, pair.b) - -res_tf_3 = restored_f(x) -self.assertAllClose(res_jax, res_tf_3) -grad_jax = jax.grad(f_jax)(x) - -x_v = [tf.Variable(x.a), tf.Variable(x.b)] -with tf.GradientTape() as tape: - res = f_tf_wrapped(*x_v) - grad_tf = tape.gradient(res, x_v) +import tensorflow as tf +input_data = #Any data object -self.assertAllClose(grad_jax.a, grad_tf[0]) -self.assertAllClose(grad_jax.b, grad_tf[1]) +m = tf.Module() +flat, tree_def = jax.tree_util.tree_flatten(input_data) +m.input_data = {"flat": flat, "tree_def": tree_def} ``` -### Different 64-bit precision in JAX and TensorFlow - -JAX behaves somewhat differently than TensorFlow in the handling -of 32-bit vs. 64-bit values. However, the `jax2tf.convert` function -always behaves like the JAX function. - -JAX interprets the type of Python scalars differently based on -`JAX_ENABLE_X64` flag. (See -[JAX - The Sharp Bits: Double (64bit) precision](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision).) -In the default configuration, the -flag is unset, and JAX interprets Python constants as 32-bit, -e.g., the type of `3.14` is `float32`. This is also what -TensorFlow always does. JAX goes further, it forces -all explicitly-specified 64-bit values to be interpreted as -32-bit: +Later the user can use `tree_unflatten` for the reverse process: ```python -# with JAX_ENABLE_X64=0 -jnp.sin(3.14) # Has type float32 -tf.math.sin(3.14) # Has type float32 - -jnp.sin(np.float64(3.14)) # Also has type float32 -tf.math.sin(np.float64(3.14)) # Has type float64 - -# The jax2tf.convert function behaves like the JAX function. -jax2tf.convert(jnp.sin)(3.14) # Has type float32 -jax2tf.convert(jnp.sin)(np.float64(3.14)) # Has type float32 - -# The following will still compute `sin` in float32 (with a tf.cast on the argument). -tf.function(jax2tf.convert(jnp.sin))(tf.Variable(3.14, tf.float64)) +input_data = jax.tree_util.tree_unflatten(m.input_data['tree_def'], m.input_data['flat']) ``` -When the `JAX_ENABLE_X64` flas is set, JAX uses 64-bit types -for Python scalars and respects the explicit 64-bit types: +### Unimplemented jax2tf features -```python -# with JAX_ENABLE_X64=1 -jnp.sin(3.14) # Has type float64 -tf.math.sin(3.14) # Has type float32 - -# The jax2tf.convert function behaves like the JAX function. -jax2tf.convert(jnp.sin)(3.14) # Has type float64 - -# The following will compute `sin` in float64. -tf.function(jax2tf.convert(jnp.sin))(tf.Variable(3.14, tf.float64)) - -# The following will compute `sin` in float32. -tf.function(jax2tf.convert(jnp.sin))(tf.Variable(3.14)) -``` - -This is achieved by inserting `tf.cast` operations -on the input arguments inside the converted function, -if necessary. +There is currently no support for `pmap` or`xmap`, nor for the collective +operations. There is support for `pjit`. -If you want to create a `tf.Variable` or `tf.TensorSpec` with the -same dtype, you should use `jax2tf.dtype_of_val`: +### SavedModel supports only first-order gradients -```python -# The following two calls will convert jax_fun at the same dtypes -# independently of the value of JAX_ENABLE_X64. -jax2tf.convert(jax_fun)(3.14) -jax2tf.convert(jax_fun)(tf.Variable(3.14, dtype=jax2tf.dtype_of_val(3.14)) -``` +The `jax2tf`-lowered function supports higher-order gradients, but when the +function is saved in a SavedModel, only the first-order gradient is saved. +This is primarily a limitation of the SavedModel support for custom gradients. ### Slow implementation of associative reductions for CPU -Operations like ``jax.numpy.cumsum`` are compiled by JAX differently based -on the platform. For TPU, the compilation uses the [HLO ReduceWindow](https://www.tensorflow.org/xla/operation_semantics#reducewindow) +Operations like ``jax.numpy.cumsum`` are lowered by JAX differently based +on the platform. For TPU, the lowering uses the [HLO ReduceWindow](https://www.tensorflow.org/xla/operation_semantics#reducewindow) operation, which has an efficient implementation for the cases when the reduction function is associative. For CPU and GPU, JAX uses an alternative -implementation using [associative scans](https://github.com/google/jax/blob/f08bb50bfa9f6cf2de1f3f78f76e1aee4a78735d/jax/_src/lax/control_flow.py#L2801). +lowering using [associative scans](https://github.com/google/jax/blob/f08bb50bfa9f6cf2de1f3f78f76e1aee4a78735d/jax/_src/lax/control_flow.py#L2801). jax2tf uses the TPU lowering (because it does not support backend-specific lowering) and hence it can be slow in some cases on CPU and GPU. @@ -914,100 +948,51 @@ Use this only if it improves the performance for your application. Note that this lowering may not work as well as the default one in presence of shape polymorphism. -### Unchecked assumption that the dimension variables take strictly positive values - -The shape polymorphic conversion is sound with the assumption that the dimension -variables take non-zero values. In the following example, the function to be converted -has different behavior for empty shapes. The broken assumption is caught by jax2tf if -the converted function is executed eagerly, but not if it is first traced to a -TensorFlow graph: - -```python -def f_jax(x): - return 0 if x.shape[0] == 0 else 1 - -x0 = np.array([], np.float32) -self.assertEqual(0, f_jax(x0)) # JAX sees that the x.shape[0] == 0 - -# jax2tf catches the broken assumption b >= 1 if the converted function is executed -# eagerly. -# Raises: ValueError: Dimension variable b must have integer value >= 1. Found value 0 when solving b == 0 -jax2tf.convert(f_jax, polymorphic_shapes=["b"])(x0)) - -# However, if we first trace to a TensorFlow graph, we may miss the broken assumption: -f_tf = tf.function( - jax2tf.convert(f_jax, polymorphic_shapes=["b"])).get_concrete_function(tf.TensorSpec([None], dtype=np.float32)) -self.assertEqual(1, f_tf(x0)) -``` - -Another possible source of unsoundness is that JAX assumes that all unknown -dimensions represented by the same dimension variable have equal size. As before, -this assumption is checked if the converted function is executed eagerly, but -it may be missed if it is first traced to a TensorFlow graph: - -```python -def f_jax(x): - return 0 if x.shape[0] != x.shape[1] else 1 - -x45 = np.ones((4, 5), dtype=np.float32) -self.assertEqual(0, f_jax(x45)) # JAX seems that x.shape[0] != x.shape[1] - -# jax2tf catches the broken assumption x.shape[0] == x.shape[1] if the converted -# function is executed eagerly. -# Raises: ValueError: polymorphic shape ('b, b',) has dimension variable 'b' corresponding to multiple values {4, 5}, for argument shapes (TensorShape([4, 5]),) -jax2tf.convert(f_jax, polymorphic_shapes=["b, b"])(x45) - -# However, if we first trace to a TensorFlow graph, we may miss the broken assumption. -f_tf = tf.function( - jax2tf.convert(f_jax, polymorphic_shapes=["b, b"])).get_concrete_function(tf.TensorSpec([None, None], dtype=np.float32)) -self.assertEqual(1, f_tf(x45)) -``` - ### TensorFlow XLA ops -For most JAX primitives there is a natural TF op that fits the needed semantics. +For most JAX primitives there is a natural TensorFlow op that fits the needed semantics. There are a few (listed below) JAX primitives for which there is no -single TF op with matching semantics. +single TensorFlow op with matching semantics. This is not so surprising, because JAX primitives have been designed to be compiled to [HLO ops](https://www.tensorflow.org/xla/operation_semantics), -while the corresponding TF ops are sometimes higher-level. -For the cases when there is no matching canonical TF op, -we use a set of special TF ops that are thin wrappers over HLO ops +while the corresponding TensorFlow ops are sometimes higher-level. +For the cases when there is no matching canonical TensorFlow op, +we use a set of special TensorFlow ops that are thin wrappers over HLO ops (a subset of those registered in [tf2xla/ops/xla_ops.cc](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/tf2xla/ops/xla_ops.cc) and implemented in, e.g., [tf2xla/kernels/xla_pad_op.cc](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/tf2xla/kernels/xla_pad_op.cc).) -We refer to these ops here as the XLA TF ops. Note that these are +We refer to these ops here as the XLA TensorFlow ops. Note that these are still regular TF ops, e.g., they can be saved in a SavedModel. -There are several drawbacks of using XLA TF ops: +There are several drawbacks of using XLA TensorFlow ops: * These ops will only be executable by a consumer that has XLA linked in. This should not be a problem for TPU execution, since that requires XLA anyway. * These ops are not yet recognized by tools that process tf.Graph, e.g., TensorFlow.js converter or the TensorFlow Lite converter. -As an experimental feature we implemented alternative conversions to avoid the XLA TF ops. +As an experimental feature we implemented alternative conversions to avoid the XLA TensorFlow ops. You can enable this with the `enable_xla=False` parameter to `jax2tf.convert`. For more details see [no_xla_limitations.md](g3doc/no_xla_limitations.md). ### Different performance characteristics -The converted code may have slightly different performance characteristics than +The lowered code may have slightly different performance characteristics than the original JAX code. -We do expect that the performance characteristics of converted code -should approximate those of JAX when used with the XLA compiler (`tf.function(jit_compile=True)`). +We do expect that the performance characteristics of lowered code +should be the same as those of JAX when used with the XLA compiler (`tf.function(jit_compile=True)`). This is because -during conversion we try to generate one TensorFlow op for one JAX primitive. +during lowering we try to generate one TensorFlow op for one JAX primitive. We expect that the lowering that XLA does is similar to that done by JAX before conversion. (This is a hypothesis, we have not yet verified it extensively.) -There is one know case when the performance of the converted code will be different. +There is one know case when the performance of the lowered code will be different. JAX programs use a [stateless deterministic PRNG](https://github.com/google/jax/blob/main/docs/design_notes/prng.md) and it has an internal JAX primitive for it. -This primitive is at the moment converted to a soup of tf.bitwise operations, +This primitive is at the moment lowered to a soup of tf.bitwise operations, which has a clear performance penalty. We plan to look into using the HLO [RNGBitGenerator](https://www.tensorflow.org/xla/operation_semantics#rngbitgenerator) (exposed as a TFXLA op), which does implement @@ -1025,38 +1010,60 @@ a custom C++ “high-level” kernel implementing batch normalization is execute In JAX, there is no primitive for batch normalization, and instead the operation is decomposed into low-level primitives (e.g., [flax.linen.BatchNorm](https://flax.readthedocs.io/en/latest/_autosummary/flax.linen.BatchNorm.html), or haiku.BatchNorm). -Once those primitives are converted to TensorFlow, and the resulting code is +Once those primitives are lowered to TensorFlow, and the resulting code is run without XLA, the ensemble of the kernels executed will quite possibly behave differently, performance-wise or even numerically, than either the TensorFlow native or JAX native batch normalization. A similar example is that of an LSTM cell. -### Errors due to tf.Module magic conversion during attribute assignment - -tf.Module will automatically wrap the standard Python container data types into -trackable classes during attribute assignment. -Python Dict/List/Tuple are changed to _DictWrapper/_ListWrapper/_TupleWrapper -classes. -In most situation, these Wrapper classes work exactly as the standard -Python data types. However, the low-level pytree data structures are different -and this can lead to errors. +### Unchecked assumption that the dimension variables take strictly positive values -In such cases, the user can use this walkaround: +The shape polymorphic conversion is sound with the assumption that the dimension +variables take non-zero values. In the following example, the function to be lowered +has different behavior for empty shapes. The broken assumption is caught by jax2tf if +the lowered function is executed eagerly, but not if it is first traced to a +TensorFlow graph: ```python -import tensorflow as tf -input_data = #Any data object +def f_jax(x): + return 0 if x.shape[0] == 0 else 1 -m = tf.Module() -flat, tree_def = jax.tree_util.tree_flatten(input_data) -m.input_data = {"flat": flat, "tree_def": tree_def} +x0 = np.array([], np.float32) +self.assertEqual(0, f_jax(x0)) # JAX sees that the x.shape[0] == 0 + +# jax2tf catches the broken assumption b >= 1 if the lowered function is executed +# eagerly. +# Raises: ValueError: Dimension variable b must have integer value >= 1. Found value 0 when solving b == 0 +jax2tf.convert(f_jax, polymorphic_shapes=["b"])(x0)) + +# However, if we first trace to a TensorFlow graph, we may miss the broken assumption: +f_tf = tf.function( + jax2tf.convert(f_jax, polymorphic_shapes=["b"])).get_concrete_function(tf.TensorSpec([None], dtype=np.float32)) +self.assertEqual(1, f_tf(x0)) ``` -Later the user can use `tree_unflatten` for the reverse process: +Another possible source of unsoundness is that JAX assumes that all unknown +dimensions represented by the same dimension variable have equal size. As before, +this assumption is checked if the lowered function is executed eagerly, but +it may be missed if it is first traced to a TensorFlow graph: ```python -input_data = jax.tree_util.tree_unflatten(m.input_data['tree_def'], m.input_data['flat']) +def f_jax(x): + return 0 if x.shape[0] != x.shape[1] else 1 + +x45 = np.ones((4, 5), dtype=np.float32) +self.assertEqual(0, f_jax(x45)) # JAX seems that x.shape[0] != x.shape[1] + +# jax2tf catches the broken assumption x.shape[0] == x.shape[1] if the lowered +# function is executed eagerly. +# Raises: ValueError: polymorphic shape ('b, b',) has dimension variable 'b' corresponding to multiple values {4, 5}, for argument shapes (TensorShape([4, 5]),) +jax2tf.convert(f_jax, polymorphic_shapes=["b, b"])(x45) + +# However, if we first trace to a TensorFlow graph, we may miss the broken assumption. +f_tf = tf.function( + jax2tf.convert(f_jax, polymorphic_shapes=["b, b"])).get_concrete_function(tf.TensorSpec([None, None], dtype=np.float32)) +self.assertEqual(1, f_tf(x45)) ``` # Calling TensorFlow functions from JAX diff --git a/jax/experimental/jax2tf/g3doc/convert_models_results.md b/jax/experimental/jax2tf/g3doc/convert_models_results.md index 659abd6fde22..bab2eeba6411 100644 --- a/jax/experimental/jax2tf/g3doc/convert_models_results.md +++ b/jax/experimental/jax2tf/g3doc/convert_models_results.md @@ -120,4 +120,4 @@ support. After that, it converts the SavedModel to TFLite using the ### `jax2tflite+flex` This is similar to the `jax2tflite` path, but then links in the Select ops. See -[here](https://www.tensorflow.org/lite/guide/ops_select) for more details. \ No newline at end of file +[here](https://www.tensorflow.org/lite/guide/ops_select) for more details.