Skip to content

Commit

Permalink
[jax2tf] Improve the documentation for jax2tf and SavedModel
Browse files Browse the repository at this point in the history
In particular, document better how to avoid embedding large constants
in the SavedModel.
  • Loading branch information
gnecula committed Nov 4, 2021
1 parent 752823e commit 27de09c
Showing 1 changed file with 44 additions and 15 deletions.
59 changes: 44 additions & 15 deletions jax/experimental/jax2tf/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -103,11 +103,12 @@ tf.saved_model.save(my_model, '/some/directory',
restored_model = tf.saved_model.load('/some/directory')
```

An important point is that in the above code snippet **everything is standard
TensorFlow code. In particular, the saving of the model is not directly part
An important point is that in the above code snippet **everything after the
jax2tf conversion is standard TensorFlow code.
In particular, the saving of the model is not directly part
of the jax2tf API, and the user has full control over how to create the SavedModel**.

Just like for regular TensorFlow functions, it is possible to include in the
For example, just like for regular TensorFlow functions, it is possible to include in the
SavedModel multiple versions of a function for different input shapes, by
"warming up" the function on different input shapes:

Expand All @@ -119,29 +120,41 @@ tf.saved_model.save(my_model, '/some/directory',
options=tf.saved_model.SaveOptions(experimental_custom_gradients=True))
```

Note that if the JAX function is not reverse-mode differentiable, e.g., uses `lax.while_loop` then
attempting to save its conversion to a SavedModel will fail with
```
ValueError: Error when tracing gradients for SavedModel
```

You have two options, either pass `enable_gradients=False` to `jax2tf.convert`, or
set `tf.saved_model.SaveOption(experimental_custom_gradients=False)`. In either case,
you will not be able to compute the gradients of the function loaded from the SavedModel.
### Saved model with parameters

Some special care is needed to ensure that the model parameters are not embedded
as constants in the graph and are instead saved separately as variables.
This is useful for two reasons:
the parameters could be very large and exceed the limits of the
the parameters could be very large and exceed the 2GB limits of the
GraphDef part of the SavedModel, or you may want to fine-tune the
model and change the value of the parameters.

For example, consider the following function:
```python
def model_jax(inputs):
return param0 + param1 * inputs
```

If you just convert and save the model directly, the values of
`param0` and `param1` will be embedded in the computation graph. In fact, the
value of `param1` is needed for the gradient computation and
will be embedded twice: once in the computation
graph for the forward computation and once for the backward computation,
unless you turn off the conversion of gradients or their saving as discussed
further below (e.g., `with_gradient=False`). Note also that if one
views the above function as an ML model parameterized by `param0` and `param1`
then the gradient function will be w.r.t. the inputs, while you probably
want gradients w.r.t. the parameters.

A better way to deal with parameters (or any large constants) is to
pass them as parameters to the function to be converted:
```
def model_jax(params, inputs):
return params[0] + params[1] * inputs
# Wrap the parameter constants as tf.Variables; this will signal to the model
# saving code to save those constants as variables.
# saving code to save those constants as variables, separate from the
# computation graph.
params_vars = tf.nest.map_structure(tf.Variable, params)
# Build the prediction function by closing over the `params_vars`. If you
Expand All @@ -156,11 +169,15 @@ my_model.f = tf.function(prediction_tf, jit_compile=True)
tf.saved_model.save(my_model)
```

This strategy will avoid any copies of the large parameters in the computation
graph (they will be saved in a `variables` area of the model, which is not
subject to the 2GB limitation).

For examples of how to save a Flax model as a SavedModel see the
[examples directory](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/README.md).


## Differentiation
### Saved model and differentiation

The converted code supports differentiation from TensorFlow. In order to
ensure that the result of TensorFlow differentiation is identical to the
Expand Down Expand Up @@ -209,6 +226,18 @@ The graph tensor has name: args_0:0

(We are working with the TF team to give a more explicit error in this case.)

### Saved model for non-differentiable JAX functions

Note that if the JAX function is not reverse-mode differentiable, e.g., uses `lax.while_loop` then
attempting to save its conversion to a SavedModel will fail with
```
ValueError: Error when tracing gradients for SavedModel
```

You have two options, either pass `enable_gradients=False` to `jax2tf.convert`, or
set `tf.saved_model.SaveOption(experimental_custom_gradients=False)`. In either case,
you will not be able to compute the gradients of the function loaded from the SavedModel.


## Shape-polymorphic conversion

Expand Down

0 comments on commit 27de09c

Please sign in to comment.