[jax2tf] Refactored the handling of float0.
JAX and TF have different ways of dealing with tangents and co-tangents for
exact types. JAX uses float0 values. TF sometimes uses None, sometimes
integer (or boolean) zeros. In the JAX VJP function we convert the
Nones to zeros. On exit from the VJP function we convert the float0
values to zeros.
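
A minimal sketch of the two conversions the commit message describes; the helper
names (`tf_ct_to_jax`, `jax_tangent_to_tf`) and the `aval` argument are
illustrative only, not the actual jax2tf internals:

```python
import numpy as np
from jax import dtypes  # dtypes.float0 is the dtype of tangents for exact types

def tf_ct_to_jax(ct, aval):
  # TF sometimes represents a co-tangent as None; replace it with a zero
  # of the corresponding shape/dtype before invoking the JAX VJP function.
  if ct is None:
    return np.zeros(aval.shape, aval.dtype)
  return ct

def jax_tangent_to_tf(tangent):
  # JAX represents tangents/co-tangents of exact (integer/boolean) types
  # as float0 values; replace them with integer zeros on exit from the VJP.
  if tangent.dtype == dtypes.float0:
    return np.zeros(tangent.shape, np.int32)
  return tangent
```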
gnecula committed Jul 28, 2021
1 parent e7f0307 commit c73c0ad
Showing 7 changed files with 343 additions and 247 deletions.
35 changes: 27 additions & 8 deletions jax/experimental/jax2tf/README.md
@@ -572,24 +572,43 @@ disable the generation of this metadata with the parameter
 The `jax2tf`-converted function supports higher-order gradients, but when the
 function is saved in a SavedModel, only the first-order gradient is saved.
 
-### Converting gradients for integer-argument functions
+### Converting gradients for functions with integer arguments or unused arguments
 
-When JAX differentiates over functions with integer arguments, the gradients will
+When JAX differentiates functions with integer or boolean arguments, the gradients will
 be zero-vectors with a special `float0` type (see [PR 4039](https://github.com/google/jax/pull/4039)).
-This type is translated to `bfloat16` when converting to TF. For example,
+This type is translated to `int32` when converting to TF.
+For example,
 
 ```python
-def f_jax(x):  # x: int32
+x = np.int16(2)
+def f_jax(x):  # x: int16
   return x * 2.
 
-jax.grad(f_jax, allow_int=True)(2)
+jax.grad(f_jax, allow_int=True)(x)
 # returns a special `float0`: array((b'',), dtype=[('float0', 'V')])
 
-jax2tf.convert(jax.grad(f_jax, allow_int=True))(2)
-# returns a `bfloat16` zero: tf.Tensor(0, shape=(), dtype=bfloat16)
+jax2tf.convert(jax.grad(f_jax, allow_int=True))(x)
+# returns a tf.Tensor(0, shape=(), dtype=int32)
 ```
 
-### Different behavior for gradients for unused arguments
+Note that this is different from how TensorFlow handles gradients
+for integer or boolean arguments: sometimes the gradient is `None`,
+sometimes it is a zero with the same dtype as the argument, and
+sometimes it is a one with the same dtype as the argument (e.g.,
+for the identity function).
+
+```python
+def f_tf(x):  # x: int16
+  return tf.cast(x, tf.float32) * 2.
+
+xv = tf.Variable(x)
+with tf.GradientTape(persistent=True) as tape:
+  print(tape.gradient(f_tf(xv), xv))
+  # returns None
+  print(tape.gradient(f_tf(xv), xv,
+                      unconnected_gradients=tf.UnconnectedGradients.ZERO))
+  # returns 0 with the same shape and dtype as x
+```
 
 When differentiating functions with unused arguments, TF by default
 returns the value `None` for the corresponding gradients. The
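The last paragraph of the hunk above (truncated by the diff view) notes that TF
by default returns `None` for the gradients of unused arguments. A small
self-contained sketch of that behavior; the function `g_tf` and its values are
invented for this example:

```python
import tensorflow as tf

def g_tf(x, y):  # y is unused in the computation
  return x * 2.

xv, yv = tf.Variable(3.), tf.Variable(4.)
with tf.GradientTape() as tape:
  res = g_tf(xv, yv)
print(tape.gradient(res, yv))
# prints None: yv does not participate in computing res
```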
2 changes: 1 addition & 1 deletion jax/experimental/jax2tf/call_tf.py
@@ -374,7 +374,7 @@ def canonical_res_aval(res_shape: xla.XlaShape) -> core.ShapedArray:
   # call_tf is a multiple_results primitive.
   result_shapes = (result_shape,)
 else:
-  result_shapes = result_shape.tuple_shapes()
+  result_shapes = result_shape.tuple_shapes()  # type: ignore
 
 result_avals = tuple(map(canonical_res_aval, result_shapes))  # type: ignore

@@ -18,7 +18,7 @@
 import grpc  # type: ignore[import]
 import json
 import logging
-import requests
+import requests  # type: ignore[import]
 
 from absl import app
 from absl import flags
