[jax2tf] Fix stale comments
gnecula committed Sep 12, 2021
1 parent ab544cb commit b59db5b
Showing 2 changed files with 15 additions and 24 deletions.
21 changes: 7 additions & 14 deletions jax/experimental/jax2tf/examples/saved_model_lib.py
@@ -43,8 +43,9 @@ def convert_and_save_model(
saved_model_options: Optional[tf.saved_model.SaveOptions] = None):
"""Convert a JAX function and saves a SavedModel.
- This is an example, for serious uses you will likely want to copy and
- expand it as needed (see note at the top of the model).
+ This is an example, we do not promise backwards compatibility for this code.
+ For serious uses, please copy and expand it as needed (see note at the top
+ of the module).
Use this function if you have a trained ML model that has both a prediction
function and trained parameters, which you want to save separately from the
@@ -74,14 +74,10 @@ def convert_and_save_model(
the default serving signature. The additional signatures will be used
only to ensure that the `jax_fn` is traced and converted to TF for the
corresponding input shapes.
- with_gradient: whether the SavedModel should support gradients. If True,
- then a custom gradient is saved. If False, then a
- tf.raw_ops.PreventGradient is saved to error if a gradient is attempted.
- (At the moment due to a bug in SavedModel, custom gradients are not
- supported.)
- enable_xla: whether the jax2tf converter is allowed to use TFXLA ops. If
- False, the conversion tries harder to use purely TF ops and raises an
- exception if it is not possible. (default: True)
+ with_gradient: the value to use for the `with_gradient` parameter for
+ `jax2tf.convert`.
+ enable_xla: the value to use for the `enable_xla` parameter for
+ `jax2tf.convert`.
compile_model: use TensorFlow jit_compiler on the SavedModel. This
is needed if the SavedModel will be used for TensorFlow serving.
polymorphic_shapes: if given then it will be used as the
@@ -105,10 +102,6 @@ def convert_and_save_model(
# Create tf.Variables for the parameters. If you want more useful variable
# names, you can use `tree.map_structure_with_path` from the `dm-tree` package
param_vars = tf.nest.map_structure(
- # Due to a bug in SavedModel it is not possible to use tf.GradientTape on
- # a function converted with jax2tf and loaded from SavedModel. Thus, we
- # mark the variables as non-trainable to ensure that users of the
- # SavedModel will not try to fine tune them.
lambda param: tf.Variable(param, trainable=with_gradient),
params)
tf_graph = tf.function(lambda inputs: tf_fn(param_vars, inputs),
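For orientation, here is a minimal sketch of the flow this example module implements, assuming `jax_fn(params, inputs)` is the trained prediction function and `params` is a pytree of trained weights; the helper name `save_jax_model` and its `input_signature` argument are illustrative only, not part of jax2tf or of this example:

import tensorflow as tf
from jax.experimental import jax2tf

def save_jax_model(jax_fn, params, model_dir, input_signature,
                   with_gradient=True, enable_xla=True):
  # Convert the JAX prediction function into a function built from TF ops;
  # with_gradient and enable_xla are forwarded to jax2tf.convert as documented above.
  tf_fn = jax2tf.convert(jax_fn, with_gradient=with_gradient,
                         enable_xla=enable_xla)
  # Store the trained parameters as tf.Variables so they end up in the
  # SavedModel; mark them trainable only when a custom gradient is saved.
  param_vars = tf.nest.map_structure(
      lambda p: tf.Variable(p, trainable=with_gradient), params)
  module = tf.Module()
  module.param_vars = tf.nest.flatten(param_vars)  # keep the variables tracked
  module.predict = tf.function(
      lambda inputs: tf_fn(param_vars, inputs),
      autograph=False,
      input_signature=input_signature)
  tf.saved_model.save(module, model_dir)
  return module

A call might look like save_jax_model(predict_fn, trained_params, "/tmp/jax_model", [tf.TensorSpec([None, 28, 28, 1], tf.float32)]); the real convert_and_save_model above additionally handles serving signatures, shape polymorphism, and save options.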
18 changes: 8 additions & 10 deletions jax/experimental/jax2tf/jax2tf.py
@@ -226,22 +226,20 @@ def convert(fun: Callable,
for more details.
in_shapes: DEPRECATED in favor of `polymorphic_shapes`.
- with_gradient: if set, will add a tf.custom_gradient to the converted
- function, by converting the ``jax.vjp(fun)``. Only first-order
- differentiation is supported for now. If the converted function is saved
- in a SavedModel, the custom gradients are currently lost and an error will
- be raised if a gradient computation is attempted. This is due to a current
- bug in TensorFlow.
+ with_gradient: if set (default), add a tf.custom_gradient to the converted
+ function, by converting the ``jax.vjp(fun)``. This means that reverse-mode
+ TensorFlow AD is supported for the output TensorFlow function, and the
+ value of the gradient will be JAX-accurate.
enable_xla: if set (default), the converter will use the simplest conversion
and use XLA TF ops when necessary. These ops are known to create issues
for the TFLite and TFjs converters. For those cases, unset this parameter
- so the converter tries harder to use non-XLA TF ops to convert the function,
- and raises an error if it can not be converted
- without resorting to XLA ops.
+ so the converter tries harder to use non-XLA TF ops to convert the
+ function and aborts if this is not possible.
Returns:
A version of `fun` that expects TfVals as arguments (or
- tuple/lists/dicts) thereof, and returns TfVals as outputs.
+ tuple/lists/dicts) thereof, and returns TfVals as outputs, and uses
+ only TensorFlow ops.
"""
api._check_callable(fun)
fun_name = getattr(fun, "__name__", "unknown")
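To make the `with_gradient` description above concrete, a small self-contained sketch (the toy function `f` and its inputs are invented for illustration):

import jax.numpy as jnp
import tensorflow as tf
from jax.experimental import jax2tf

def f(x):
  # Any jittable JAX function of arrays works here.
  return jnp.sum(jnp.sin(x) ** 2)

# with_gradient=True (the default) attaches a tf.custom_gradient obtained by
# converting jax.vjp(f), so TensorFlow reverse-mode AD returns JAX-accurate values.
f_tf = jax2tf.convert(f, with_gradient=True, enable_xla=True)

x = tf.constant([0.1, 0.2, 0.3])
with tf.GradientTape() as tape:
  tape.watch(x)
  y = f_tf(x)
dy_dx = tape.gradient(y, x)  # matches jax.grad(f)(jnp.array([0.1, 0.2, 0.3]))

With with_gradient=False, the gradient computation above is expected to raise an error instead of silently returning wrong values.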
