Commit 461b37b
[jax2tf] Fixed stale documentation about XLA metadata.
jax2tf does not yet support passing source location information
through to TF. The mechanism is partially implemented but disabled.
Here we remove misleading documentation that suggests the mechanism
is enabled.
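
For context (not part of this commit; the function and inputs below are illustrative), a minimal sketch of converting a JAX function with jax2tf. After this change the converter keeps its internal include_xla_op_metadata flag off, so no XLA op metadata (source location information) is attached to the generated TF ops:

import jax.numpy as jnp
import tensorflow as tf
from jax.experimental import jax2tf

def f(x):
  return jnp.sin(jnp.cos(x))

# Convert the JAX function to a TF-compatible callable. With this commit the
# converter forces its internal include_xla_op_metadata flag to False, so the
# resulting graph (and any SavedModel built from it) carries no XLA op metadata.
f_tf = tf.function(jax2tf.convert(f), autograph=False)
print(f_tf(tf.constant(1.0)))
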
gnecula committed Feb 16, 2022
1 parent c49fb9c commit 461b37b
Showing 2 changed files with 3 additions and 8 deletions.
7 changes: 0 additions & 7 deletions jax/experimental/jax2tf/README.md
@@ -606,13 +606,6 @@ operations. There is support for `sharded_jit` and `pjit`.
If you suspect that the SavedModel is larger than it should be, check first
that you are not including the parameters as constants in the graph (see [above](#usage-saved-model)).

-Additionally, the SavedModel obtained from a `jax2tf.convert`-ed function may include source
-location information. This ensures that the debugging experience is similar
-for JAX with XLA vs. `jax2tf.convert` with XLA. However, this debugging information
-increases the size of the SavedModel, even possibly doubling it. You can
-disable the generation of this metadata with the parameter
-`include_xla_op_metadata`.
-
### SavedModel supports only first-order gradients

The `jax2tf`-converted function supports higher-order gradients, but when the
4 changes: 3 additions & 1 deletion jax/experimental/jax2tf/jax2tf.py
@@ -163,7 +163,8 @@ def __init__(self):
self.shape_env: Sequence[Tuple[str, TfVal]] = ()

# Whether to actually include XLA op metadata in the generated TF ops
-self.include_xla_op_metadata = True
+# TODO(b/189306134): implement support for XLA metadata
+self.include_xla_op_metadata = False

# A cache for the tf.convert_to_tensor for constants. We try to preserve
# sharing for constants, to enable tf.Graph to take advantage of it.
@@ -415,6 +416,7 @@ def fix_in_ct(in_ct, arg_aval: core.ShapedArray):
_thread_local_state.enable_xla = enable_xla

prev_include_xla_op_metadata = _thread_local_state.include_xla_op_metadata
+# TODO(b/189306134): implement support for XLA metadata
_thread_local_state.include_xla_op_metadata = False

_thread_local_state.shape_env = shape_env
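
For readers unfamiliar with the pattern in the hunk above, a hedged sketch of the save/restore bookkeeping around the thread-local flag (the class and helper below are illustrative stand-ins, not the library's actual internals):

import threading

class _ThreadLocalState(threading.local):
  def __init__(self):
    # Whether to actually include XLA op metadata in the generated TF ops
    # TODO(b/189306134): implement support for XLA metadata
    self.include_xla_op_metadata = False

_thread_local_state = _ThreadLocalState()

def _call_with_metadata_disabled(fn, *args):
  # Save the previous value, force the flag off for the duration of the call,
  # and restore it on exit, mirroring the prev_include_xla_op_metadata
  # bookkeeping shown in the diff.
  prev_include_xla_op_metadata = _thread_local_state.include_xla_op_metadata
  _thread_local_state.include_xla_op_metadata = False
  try:
    return fn(*args)
  finally:
    _thread_local_state.include_xla_op_metadata = prev_include_xla_op_metadata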
