[jax2tf] Add more documentation about saving models with custom gradients
gnecula committed Jun 15, 2021
1 parent c169ee3 commit 2888e7c
Showing 5 changed files with 167 additions and 25 deletions.
32 changes: 30 additions & 2 deletions jax/experimental/jax2tf/README.md
@@ -95,7 +95,8 @@ is trivial:
my_model = tf.Module()
# Save a function that can take scalar inputs.
my_model.f = tf.function(jax2tf.convert(f_jax), input_signature=[tf.TensorSpec([], tf.float32)])
tf.saved_model.save(my_model, '/some/directory')
tf.saved_model.save(my_model, '/some/directory',
options=tf.saved_model.SaveOptions(experimental_custom_gradients=True))

# Restoring (note: the restored model does *not* require JAX to run, just XLA).
restored_model = tf.saved_model.load('/some/directory')
@@ -113,7 +114,8 @@ SavedModel multiple versions of a function for different input shapes, by
my_model.f = tf.function(jax2tf.convert(f_jax), autograph=False)
my_model.f(tf.ones([1, 28, 28])) # a batch size of 1
my_model.f(tf.ones([16, 28, 28])) # a batch size of 16
tf.saved_model.save(my_model, '/some/directory')
tf.saved_model.save(my_model, '/some/directory',
options=tf.saved_model.SaveOptions(experimental_custom_gradients=True))
```

For examples of how to save a Flax model as a SavedModel see the
@@ -144,6 +146,32 @@ options = tf.saved_model.SaveOptions(experimental_custom_gradients=True)
tf.saved_model.save(model, path, options=options)
```

If you use `with_gradient=True` but forget to pass `options=tf.saved_model.SaveOptions(experimental_custom_gradients=True)`
to `tf.saved_model.save`, then when you later load the saved model you will see a warning:

```
WARNING:absl:Importing a function (__inference_converted_fun_25) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
```

and if you do attempt to take a gradient of the loaded model, you may get an error:

```
TypeError: An op outside of the function building code is being passed
a "Graph" tensor. It is possible to have Graph tensors
leak out of the function building context by including a
tf.init_scope in your function building code.
For example, the following function will fail:
@tf.function
def has_init_scope():
my_constant = tf.constant(1.)
with tf.init_scope():
added = my_constant * 2
The graph tensor has name: args_0:0
```

(We are working with the TF team to give a more explicit error in this case.)
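
To make the happy path concrete, here is a minimal sketch of the full round trip: save with custom gradients enabled, reload, and differentiate the restored function with `tf.GradientTape`. The function `f_jax` and the directory below are placeholders, not part of the commit itself:

```python
import jax.numpy as jnp
import tensorflow as tf
from jax.experimental import jax2tf

f_jax = lambda x: jnp.sin(x) * x  # placeholder JAX function

my_model = tf.Module()
my_model.f = tf.function(jax2tf.convert(f_jax, with_gradient=True),
                         autograph=False,
                         input_signature=[tf.TensorSpec([], tf.float32)])
# Saving with experimental_custom_gradients=True keeps the converted
# function differentiable after it is reloaded.
tf.saved_model.save(
    my_model, '/some/directory',
    options=tf.saved_model.SaveOptions(experimental_custom_gradients=True))

restored_model = tf.saved_model.load('/some/directory')
x = tf.Variable(0.7, dtype=tf.float32)
with tf.GradientTape() as tape:
  y = restored_model.f(x)
print(tape.gradient(y, x))  # gradient computed via the saved custom gradient
```

If the `options` argument is omitted, the same code produces the warning and the `TypeError` shown above as soon as the gradient is requested.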


## Shape-polymorphic conversion

**The shape polymorphism support is work in progress. It is meant to be sound,
11 changes: 8 additions & 3 deletions jax/experimental/jax2tf/examples/saved_model_lib.py
@@ -41,7 +41,7 @@ def convert_and_save_model(
with_gradient: bool = False,
enable_xla: bool = True,
compile_model: bool = True,
save_model_options: Optional[tf.saved_model.SaveOptions] = None):
saved_model_options: Optional[tf.saved_model.SaveOptions] = None):
"""Convert a JAX function and saves a SavedModel.
This is an example, for serious uses you will likely want to copy and
@@ -89,7 +89,7 @@ def convert_and_save_model(
`polymorphic_shapes` argument to jax2tf.convert for the second parameter of
`jax_fn`. In this case, a single `input_signatures` is supported, and
should have `None` in the polymorphic dimensions.
save_model_options: options to pass to savedmodel.save.
saved_model_options: options to pass to tf.saved_model.save.
"""
if not input_signatures:
raise ValueError("At least one input_signature must be given")
@@ -124,8 +124,13 @@ def convert_and_save_model(
# If there are more signatures, trace and cache a TF function for each one
tf_graph.get_concrete_function(input_signature)
wrapper = _ReusableSavedModelWrapper(tf_graph, param_vars)
if with_gradient:
if not saved_model_options:
saved_model_options = tf.saved_model.SaveOptions(experimental_custom_gradients=True)
else:
saved_model_options.experimental_custom_gradients = True
tf.saved_model.save(wrapper, model_dir, signatures=signatures,
options=save_model_options)
options=saved_model_options)


class _ReusableSavedModelWrapper(tf.train.Checkpoint):
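
For context, a minimal usage sketch of `convert_and_save_model` with the new `saved_model_options` behavior. The `predict` function, its parameters, and the output directory are made-up placeholders, and the positional argument order `(jax_fn, params, model_dir)` is assumed rather than shown in this diff:

```python
import numpy as np
import jax.numpy as jnp
import tensorflow as tf
from jax.experimental.jax2tf.examples import saved_model_lib

# Hypothetical two-argument model function: jax_fn(params, inputs).
def predict(params, inputs):
  return jnp.dot(inputs, params)

params = np.ones((28 * 28, 10), dtype=np.float32)

# With with_gradient=True, convert_and_save_model now enables
# experimental_custom_gradients on the SaveOptions it passes to
# tf.saved_model.save, creating default options if none were given.
saved_model_lib.convert_and_save_model(
    predict, params, '/tmp/jax2tf_saved_model',
    input_signatures=[tf.TensorSpec((1, 28 * 28), tf.float32)],
    with_gradient=True)
```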
42 changes: 42 additions & 0 deletions jax/experimental/jax2tf/tests/call_tf_test.py
@@ -415,6 +415,48 @@ def g(x):
self.assertAllClose(g(x), g_rt(x))
self.assertAllClose(jax.grad(g)(x), jax.grad(g_rt)(x))

def test_round_trip_without_gradient_saved_model(self):
# Explicitly with_gradient=False
f_jax = jnp.sum

x = np.array([0.7, 0.8], dtype=np.float32)
f_tf = tf_test_util.SaveAndLoadFunction(
jax2tf.convert(f_jax, with_gradient=False),
[tf.TensorSpec(x.shape, dtype=x.dtype)])
f_rt = jax2tf.call_tf(f_tf)

self.assertAllClose(f_jax(x), f_rt(x))
with self.assertRaisesRegex(Exception,
"Gradient explicitly disabled.*jax2tf-converted function does not support gradients. Use `with_gradient` parameter to enable gradients"):
jax.grad(f_rt)(x)

def test_round_trip_saved_model_no_gradients(self):
# Save without gradients
f_jax = jnp.sum

x = np.array([0.7, 0.8], dtype=np.float32)
f_tf = tf_test_util.SaveAndLoadFunction(
jax2tf.convert(f_jax, with_gradient=True),
[tf.TensorSpec(x.shape, dtype=x.dtype)],
save_gradients=False)
f_rt = jax2tf.call_tf(f_tf)

self.assertAllClose(f_jax(x), f_rt(x))
# TODO: clean this up b/191117111: it should fail with a clear error
# The following results in a confusing error:
# TypeError: An op outside of the function building code is being passed
# a "Graph" tensor. It is possible to have Graph tensors
# leak out of the function building context by including a
# tf.init_scope in your function building code.
# For example, the following function will fail:
# @tf.function
# def has_init_scope():
# my_constant = tf.constant(1.)
# with tf.init_scope():
# added = my_constant * 2
# The graph tensor has name: args_0:0
# g = jax.grad(f_rt)(x)

def test_module_documentation(self):
def cos_tf(x):
return tf.math.cos(x)
97 changes: 81 additions & 16 deletions jax/experimental/jax2tf/tests/savedmodel_test.py
@@ -42,26 +42,37 @@ def test_eval(self):
restored_model = tf_test_util.SaveAndLoadModel(model)
self.assertAllClose(restored_model.f(x), f_jax(x))

def test_gradient_disabled(self):
f_jax = lambda x: x * x
def test_gradient(self):
"""Save and restore the custom gradient."""
@jax.custom_jvp
def f_jax(x):
return x * x

@f_jax.defjvp
def f_jax_jvp(primals, tangents):
# 3 * x * x_t
x, = primals
x_dot, = tangents
primal_out = f_jax(x)
tangent_out = x * x_dot * 3.
return primal_out, tangent_out

model = tf.Module()
model.f = tf.function(jax2tf.convert(f_jax, with_gradient=False),
model.f = tf.function(jax2tf.convert(f_jax, with_gradient=True),
autograph=False,
input_signature=[tf.TensorSpec([], tf.float32)])
x = np.array(0.7, dtype=jnp.float32)
self.assertAllClose(model.f(x), f_jax(x))
restored_model = tf_test_util.SaveAndLoadModel(model)
xv = tf.Variable(0.7, dtype=jnp.float32)
xv = tf.Variable(x)
self.assertAllClose(restored_model.f(x), f_jax(x))
with tf.GradientTape() as tape:
y = restored_model.f(xv)
self.assertAllClose(tape.gradient(y, xv).numpy(),
jax.grad(f_jax)(x))

with self.assertRaisesRegex(LookupError,
"Gradient explicitly disabled.*The jax2tf-converted function does not support gradients"):
with tf.GradientTape():
_ = restored_model.f(xv)

def test_gradient(self):
"""Save and restore the custom gradient."""
def test_gradient_nested(self):
"""Save and restore the custom gradient, when combined with other TF code."""
@jax.custom_jvp
def f_jax(x):
return x * x
@@ -76,18 +87,72 @@ def f_jax_jvp(primals, tangents):
return primal_out, tangent_out

model = tf.Module()
model.f = tf.function(jax2tf.convert(f_jax, with_gradient=True),
# After conversion, we wrap with some pure TF code
model.f = tf.function(lambda x: tf.math.sin(jax2tf.convert(f_jax, with_gradient=True)(x)),
autograph=False,
input_signature=[tf.TensorSpec([], tf.float32)])
f_jax_equiv = lambda x: jnp.sin(f_jax(x))
x = np.array(0.7, dtype=jnp.float32)
self.assertAllClose(model.f(x), f_jax(x))
self.assertAllClose(model.f(x), f_jax_equiv(x))
restored_model = tf_test_util.SaveAndLoadModel(model)
xv = tf.Variable(0.7, dtype=jnp.float32)
self.assertAllClose(restored_model.f(x), f_jax(x))
xv = tf.Variable(x)
self.assertAllClose(restored_model.f(x), f_jax_equiv(x))
with tf.GradientTape() as tape:
y = restored_model.f(xv)
self.assertAllClose(tape.gradient(y, xv).numpy(),
jax.grad(f_jax)(x).astype(np.float32))
jax.grad(f_jax_equiv)(x))

def test_gradient_disabled(self):
f_jax = lambda x: x * x

model = tf.Module()
model.f = tf.function(jax2tf.convert(f_jax, with_gradient=False),
autograph=False,
input_signature=[tf.TensorSpec([], tf.float32)])
x = np.array(0.7, dtype=jnp.float32)
self.assertAllClose(model.f(x), f_jax(x))
restored_model = tf_test_util.SaveAndLoadModel(model)
xv = tf.Variable(0.7, dtype=jnp.float32)
self.assertAllClose(restored_model.f(x), f_jax(x))

with self.assertRaisesRegex(LookupError,
"Gradient explicitly disabled.*The jax2tf-converted function does not support gradients"):
with tf.GradientTape():
_ = restored_model.f(xv)

def test_save_without_gradients(self):
f_jax = lambda x: x * x

x = np.array(0.7, dtype=jnp.float32)
model = tf.Module()
model.f = tf.function(jax2tf.convert(f_jax, with_gradient=True),
autograph=False,
input_signature=[tf.TensorSpec(x.shape, x.dtype)])

self.assertAllClose(model.f(x), f_jax(x))
restored_model = tf_test_util.SaveAndLoadModel(model,
save_gradients=False)
self.assertAllClose(restored_model.f(x), f_jax(x))

xv = tf.Variable(x)
with tf.GradientTape() as tape:
res = restored_model.f(xv)
# TODO: clean this up b/191117111: it should fail with a clear error
# The following results in a confusing error:
# TypeError: An op outside of the function building code is being passed
# a "Graph" tensor. It is possible to have Graph tensors
# leak out of the function building context by including a
# tf.init_scope in your function building code.
# For example, the following function will fail:
# @tf.function
# def has_init_scope():
# my_constant = tf.constant(1.)
# with tf.init_scope():
# added = my_constant * 2
# The graph tensor has name: args_0:0
# g = tape.gradient(res, xv)
# self.assertAllClose(g.numpy(), jax.grad(f_jax)(x))


def _compare_with_saved_model(self, f_jax, *args):
# Certain ops are converted to ensure an XLA context, e.g.,
10 changes: 6 additions & 4 deletions jax/experimental/jax2tf/tests/tf_test_util.py
@@ -75,23 +75,25 @@ class OpMetadataGraph:
source_line: str


def SaveAndLoadModel(model: tf.Module) -> tf.Module:
def SaveAndLoadModel(model: tf.Module,
save_gradients=True) -> tf.Module:
# Roundtrip through saved model on disk.
model_dir = os.path.join(absltest.get_default_test_tmpdir(), str(id(model)))
tf.saved_model.save(
model, model_dir,
options=tf.saved_model.SaveOptions(experimental_custom_gradients=True))
options=tf.saved_model.SaveOptions(experimental_custom_gradients=save_gradients))
restored_model = tf.saved_model.load(model_dir)
return restored_model

def SaveAndLoadFunction(f_tf: Callable,
input_signature: Sequence[tf.TensorSpec]) -> Callable:
input_signature: Sequence[tf.TensorSpec],
save_gradients=True) -> Callable:
# Roundtrip through saved model on disk
model = tf.Module()
model.f = tf.function(f_tf,
autograph=False,
input_signature=input_signature)
restored = SaveAndLoadModel(model)
restored = SaveAndLoadModel(model, save_gradients=save_gradients)
return restored.f


