[jax2tf] A few fixes for handling of float0 in jax2tf and call_tf
TF returns None or 0 for the gradients of functions with integer
arguments. JAX expects float0. We must convert to and from float0
at the JAX-TF boundary.
gnecula committed Jun 29, 2021
1 parent 238e8d0 commit ffd8fb8
Showing 5 changed files with 259 additions and 16 deletions.
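
For context (not part of the commit), a minimal sketch of what `float0` looks like on the JAX side: the gradient for an integer argument is a zero-size `float0` value, rather than the `None` or integer zero that TF produces.

```
import jax
import numpy as np

def f(x, i):  # x: float, i: integer (unused)
  return 3. * x

gx, gi = jax.grad(f, argnums=(0, 1), allow_int=True)(np.float32(2.), np.int32(5))
print(gx)                      # 3.0
print(gi.dtype == jax.float0)  # True: a zero-size float0, not None and not an int 0
```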
41 changes: 41 additions & 0 deletions jax/experimental/jax2tf/README.md
@@ -589,6 +589,47 @@ jax2tf.convert(jax.grad(f_jax, allow_int=True))(2))
# returns a `bfloat16` zero: tf.Tensor(0, shape=(), dtype=bfloat16)
```

### Different behavior for gradients of unused arguments

When differentiating a function with unused arguments, TF by default
returns `None` for the corresponding gradients. The `tape.gradient`
call accepts the option `unconnected_gradients=tf.UnconnectedGradients.ZERO`
to request zero gradients for unused arguments instead.

Functions converted with `jax2tf.convert` behave the same way when
`tf.UnconnectedGradients.ZERO` is used, but by default they return
`None` only for the gradients corresponding to integer arguments;
unused inexact (floating-point) arguments get a zero gradient.

```
# x1 and x3 are not used. x3 has integer type.
def fn(x0, x1, x2, x3):
  return x0 * 0. + x2 * 2.

xs = [tf.Variable(x) for x in [10., 11., 12., 13]]
with tf.GradientTape(persistent=True) as tape:
  res = fn(*xs)

g_tf_native = tape.gradient(res, xs)
# Returns: 0., None, 2., None

g_tf_native_0 = tape.gradient(res, xs,
    unconnected_gradients=tf.UnconnectedGradients.ZERO)
# Returns: 0., 0., 2., 0

# Now with jax2tf.convert
with tf.GradientTape(persistent=True) as tape:
  res = jax2tf.convert(fn, with_gradient=True)(*xs)

g_jax2tf = tape.gradient(res, xs)
# Returns: 0., 0., 2., None
# Note that the gradient for x1 is 0.

g_jax2tf_0 = tape.gradient(res, xs,
    unconnected_gradients=tf.UnconnectedGradients.ZERO)
# Returns: 0., 0., 2., 0
# In this case we get the same result as for TF native.
```
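
For comparison, a sketch (not in the original README) of what native JAX returns for the same function: unused inexact arguments get zero gradients and the integer argument gets a `float0` gradient, as described in the previous section.

```
import jax
import numpy as np

def fn(x0, x1, x2, x3):
  return x0 * 0. + x2 * 2.

g = jax.grad(fn, argnums=(0, 1, 2, 3), allow_int=True)(
    np.float32(10.), np.float32(11.), np.float32(12.), np.int32(13))
# g[0] == 0., g[1] == 0., g[2] == 2., and g[3] has dtype float0
```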

### Different 64-bit precision in JAX and TensorFlow

JAX behaves somewhat differently than TensorFlow in the handling
14 changes: 13 additions & 1 deletion jax/experimental/jax2tf/call_tf.py
@@ -33,6 +33,7 @@
from jax import numpy as jnp
from jax import tree_util
from jax._src import util
from jax._src import ad_util
from jax.interpreters import xla
from jax.lib import xla_client
from . import jax2tf as jax2tf_internal
@@ -162,7 +163,18 @@ def replace_non_float(arg):
return dres_darg

# Use call_tf to call the VJP function
return call_tf(tf_vjp_fun)(args_jax, ct_res_jax)
ct_args_jax = call_tf(tf_vjp_fun)(args_jax, ct_res_jax)
# We must construct the float0 cotangents that JAX expects
def fix_float0(arg_jax, ct_arg_jax):
arg_dtype = dtypes.result_type(arg_jax) # May be scalar
ct_arg_dtype = core.primal_dtype_to_tangent_dtype(arg_dtype)
if ct_arg_dtype != ct_arg_jax.dtype:
return ad_util.zeros_like_aval(core.ShapedArray(np.shape(arg_jax),
ct_arg_dtype))
return ct_arg_jax

ct_args_jax_fixed = tree_util.tree_map(fix_float0, args_jax, ct_args_jax)
return ct_args_jax_fixed

make_call.defvjp(make_call_vjp_fwd, make_call_vjp_bwd)
return util.wraps(callable_tf)(make_call)
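
A hedged usage sketch (not from this commit) of the intended behavior after the fix above: differentiating through `jax2tf.call_tf` with an integer argument should produce the `float0` cotangent that JAX expects, rather than the `None`/`0` that TF returns. The function names below are illustrative.

```
import jax
import numpy as np
import tensorflow as tf
from jax.experimental import jax2tf

def tf_fn(x, i):  # a TF computation; i is an (unused) integer argument
  return x * tf.constant(3., dtype=tf.float32)

f = jax2tf.call_tf(tf_fn)
gx, gi = jax.grad(f, argnums=(0, 1), allow_int=True)(np.float32(2.), np.int32(5))
# gx == 3.0, and gi has dtype float0 (converted at the JAX-TF boundary).
```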
14 changes: 13 additions & 1 deletion jax/experimental/jax2tf/jax2tf.py
@@ -349,7 +349,19 @@ def fun_vjp_jax(args_jax, out_cts_jax):
fun_vjp_jax,
with_gradient=False,
polymorphic_shapes=vjp_polymorphic_shapes)(args, out_cts)
return in_cts
# Here we fix up the float0 cotangents.
# TODO: it would be better to fix these somewhere inside the converter,
# because here we are already back in TF.
def fix_float0(arg: TfVal, in_ct: TfVal) -> TfVal:
_, arg_jax_dtype = _tfval_to_tensor_jax_dtype(arg) # Maybe it is a scalar
if np.issubdtype(arg_jax_dtype, np.inexact):
return in_ct
else:
assert in_ct.dtype.as_numpy_dtype == tf.bfloat16
return tf.zeros(arg.shape, arg.dtype)

in_cts_fixed = tf.nest.map_structure(fix_float0, args, in_cts)
return in_cts_fixed

try:
assert not _thread_local_state.shape_env, f"Unexpected shape environment {_thread_local_state.shape_env}"
31 changes: 31 additions & 0 deletions jax/experimental/jax2tf/tests/call_tf_test.py
@@ -363,6 +363,37 @@ def f(param, state, x):
g = jax.grad(lambda *args: jnp.sum(f(*args)[0]))(param, state, x)
self.assertAllClose(g_call_tf, g)

def test_grad_int_argument_unused(self):
batch_size = 5
inputs = np.ones((batch_size, 3), dtype=np.float32)
rng = np.array([1, 2], dtype=np.uint32)
params = np.float32(.5)

# rng is integer, unused
def jax_model(params, rng, inputs):
return jnp.ones([batch_size, 2], dtype=jnp.float32)

tf_model = jax2tf.convert(jax_model, with_gradient=True)

def _loss_fn(inference_fn, params, rng, inputs):
prediction = inference_fn(params, rng, inputs)
return jnp.mean(prediction)

jax_loss_fn = partial(_loss_fn, jax_model)
jax_grad = jax.grad(jax_loss_fn)(params, rng, inputs)

paramsv = tf.Variable(params)
with tf.GradientTape() as tape:
tf_prediction = tf_model(paramsv, rng, inputs)
tf_loss = tf.reduce_mean(tf_prediction)

tf_grad = tape.gradient(tf_loss, paramsv)
self.assertAllClose(jax_grad, tf_grad.numpy())

call_tf_loss_fn = partial(_loss_fn, jax2tf.call_tf(tf_model))
call_tf_grad = jax.grad(call_tf_loss_fn)(params, rng, inputs)
self.assertAllClose(jax_grad, call_tf_grad)

def test_grad_with_float0_result(self):
# Gradient over integer-argument functions, with float0 result
def f_jax(x, y): # x is an int, y is a float; res is a (int, float)
175 changes: 161 additions & 14 deletions jax/experimental/jax2tf/tests/jax2tf_test.py
@@ -352,29 +352,176 @@ def g(x): # x: i32
self.assertAllClose(jnp.zeros(np.shape(d_dx_jax), dtypes.bfloat16),
d_dx_tf.numpy())

@parameterized.named_parameters(jtu.cases_from_list(
dict(testcase_name=f"function={with_function}",
with_function=with_function)
for with_function in [False, True]))
def test_gradients_unused_argument_readme(self, with_function=True):
# x1 and x3 are not used. x3 has integer type.
def fn(x0, x1, x2, x3):
return x0 * 0. + x2 * 2.

def test_tf_gradients_int_argument(self):
# https://github.com/google/jax/issues/6975
# state is a pytree that contains an integer and a boolean.
# The function returns an integer and a boolean.
def f_jax(param, state, x):
return param * x, state
xs = [tf.Variable(x) for x in [10., 11., 12., 13]]
with tf.GradientTape(persistent=True) as tape:
res = fn(*xs)

g_tf_native = tape.gradient(res, xs)
self.assertAllClose(g_tf_native[0].numpy(), np.float32(0.))
self.assertIsNone(g_tf_native[1])
self.assertAllClose(g_tf_native[2].numpy(), np.float32(2.))
self.assertIsNone(g_tf_native[3])

g_tf_native_0 = tape.gradient(res, xs,
unconnected_gradients=tf.UnconnectedGradients.ZERO)
self.assertAllClose(g_tf_native_0[0].numpy(), np.float32(0.))
self.assertAllClose(g_tf_native_0[1].numpy(), np.float32(0.))
self.assertAllClose(g_tf_native_0[2].numpy(), np.float32(2.))
self.assertAllClose(g_tf_native_0[3].numpy(), np.int32(0))

# Now with jax2tf.convert
with tf.GradientTape(persistent=True) as tape:
conv_fn = jax2tf.convert(fn, with_gradient=True)
if with_function:
conv_fn = tf.function(conv_fn, autograph=False)
res = conv_fn(*xs)

g_jax2tf = tape.gradient(res, xs)
# Returns: 0., 0., 2., None
# Note that the gradient for x1 is 0.
self.assertAllClose(g_jax2tf[0].numpy(), np.float32(0.))
self.assertAllClose(g_jax2tf[1].numpy(), np.float32(0.))
self.assertAllClose(g_jax2tf[2].numpy(), np.float32(2.))
self.assertIsNone(g_jax2tf[3])

g_jax2tf = tape.gradient(res, xs,
unconnected_gradients=tf.UnconnectedGradients.ZERO)
self.assertAllClose(g_jax2tf[0].numpy(), np.float32(0.))
self.assertAllClose(g_jax2tf[1].numpy(), np.float32(0.))
self.assertAllClose(g_jax2tf[2].numpy(), np.float32(2.))
self.assertAllClose(g_jax2tf[3].numpy(), np.int32(0))

@parameterized.named_parameters(jtu.cases_from_list(
dict(testcase_name=f"function={with_function}",
with_function=with_function)
for with_function in [False, True]))
def test_gradients_int_argument(self, with_function=False):
# https://github.com/google/jax/issues/6975
# An expanded version of test_gradients_unused_argument
# param: f32
# state: dict(array:f32, counter:i32, truth: bool)
# xf: f32
# xf_unused: f32 unused
# xi_unused: i32 unused
# return f32, state
def f_jax(param, state, xf, xf_unused, xi_unused):
return param * xf, state

# everything has different shapes
param = np.array([0.7, 0.9], dtype=np.float32)
state = dict(array=1., counter=7, truth=True)
x = 3.
xf = np.array([11.], dtype=np.float32)
xf_unused = np.array([21., 22., 23., 24.], dtype=np.float32)
xi_unused = np.array([31, 32, 33, 34, 35], dtype=np.int32)

# Native JAX AD
g_jax = jax.grad(lambda *args: jnp.sum(f_jax(*args)[0]),
argnums=(0, 1, 2, 3, 4),
allow_int=True)(param, state, xf, xf_unused, xi_unused)
g_jax_param = np.array([11., 11.], dtype=np.float32)
g_jax_x = np.array([1.6], dtype=np.float32)

self.assertAllClose(g_jax[0], g_jax_param)
self.assertAllClose(g_jax[1]["array"], np.zeros_like(state["array"]))
self.assertEqual(g_jax[1]["counter"].dtype, jax.float0)
self.assertEqual(g_jax[1]["counter"].shape, ())
self.assertEqual(g_jax[1]["truth"].dtype, jax.float0)
self.assertEqual(g_jax[1]["truth"].shape, ())
self.assertAllClose(g_jax[2], g_jax_x)
self.assertAllClose(g_jax[3], np.zeros_like(xf_unused))
self.assertEqual(g_jax[4].dtype, jax.float0)
self.assertEqual(g_jax[4].shape, xi_unused.shape)

# Now native TF gradients, only to test how TF AD works
paramv = tf.Variable(param)
statev = tf.nest.map_structure(tf.Variable, state)
xfv = tf.Variable(xf)
xf_unusedv = tf.Variable(xf_unused)
xi_unusedv = tf.Variable(xi_unused)
with tf.GradientTape(persistent=True) as tape:
r, _ = f_jax(paramv, statev, xfv, xf_unusedv, xi_unusedv)
loss = tf.reduce_sum(r)

g_tf_native_0 = tape.gradient(
loss, (paramv, statev, xfv, xf_unusedv, xi_unusedv),
unconnected_gradients=tf.UnconnectedGradients.ZERO)
self.assertAllClose(g_tf_native_0[0].numpy(), g_jax_param)
self.assertAllClose(g_tf_native_0[1]["array"].numpy(), np.zeros_like(state["array"]).astype(np.float32))
self.assertAllClose(g_tf_native_0[1]["counter"].numpy(), np.zeros_like(state["counter"]).astype(np.int32))
self.assertAllClose(g_tf_native_0[1]["truth"].numpy(), np.zeros_like(state["truth"]))
self.assertAllClose(g_tf_native_0[2].numpy(), g_jax_x)
self.assertAllClose(g_tf_native_0[3].numpy(), np.zeros_like(xf_unused).astype(np.float32))
self.assertAllClose(g_tf_native_0[4].numpy(), np.zeros_like(xi_unused).astype(np.int32))

g_tf_native_None = tape.gradient(
loss, (paramv, statev, xfv, xf_unusedv, xi_unusedv),
unconnected_gradients=tf.UnconnectedGradients.NONE)
self.assertAllClose(g_tf_native_None[0].numpy(), g_jax_param)
self.assertIsNone(g_tf_native_None[1]["array"])
self.assertIsNone(g_tf_native_None[1]["counter"])
self.assertIsNone(g_tf_native_None[1]["truth"])
self.assertAllClose(g_tf_native_None[2].numpy(), g_jax_x)
self.assertIsNone(g_tf_native_None[3])
self.assertIsNone(g_tf_native_None[4])

# tf.function is important, without it the bug does not appear
f_tf = tf.function(jax2tf.convert(f_jax, with_gradient=True), autograph=False)
f_tf = jax2tf.convert(f_jax, with_gradient=True)
if with_function:
f_tf = tf.function(f_tf, autograph=False)

paramv = tf.Variable(param)
with tf.GradientTape() as tape:
r, _ = f_tf(paramv, state, x)
with tf.GradientTape(persistent=True) as tape:
r, _ = f_tf(paramv, statev, xfv, xf_unusedv, xi_unusedv)
loss = tf.reduce_sum(r)

g_tf = tape.gradient(loss, paramv)
self.assertAllClose(g_tf.numpy(),
jax.grad(lambda *args: jnp.sum(f_jax(*args)[0]))(param, state, x))
g_tf_0 = tape.gradient(loss, (paramv, statev, xfv, xf_unusedv, xi_unusedv),
unconnected_gradients=tf.UnconnectedGradients.ZERO)
# Same results as TF native AD with tf.UnconnectedGradients.ZERO
self.assertAllClose(g_tf_0[0].numpy(), g_jax_param)
self.assertAllClose(g_tf_0[1]["array"].numpy(), np.zeros_like(state["array"]).astype(np.float32))
self.assertAllClose(g_tf_0[1]["counter"].numpy(), np.zeros_like(state["counter"]).astype(np.int32))
self.assertAllClose(g_tf_0[1]["truth"].numpy(), np.zeros_like(state["truth"]))
self.assertAllClose(g_tf_0[2].numpy(), g_jax_x)
self.assertAllClose(g_tf_0[3].numpy(), np.zeros_like(xf_unused))
self.assertAllClose(g_tf_0[4].numpy(), np.zeros_like(xi_unused))

g_tf_None = tape.gradient(loss, (paramv, statev, xfv, xf_unusedv, xi_unusedv),
unconnected_gradients=tf.UnconnectedGradients.NONE)

# Almost the same results as TF native AD with tf.UnconnectedGradients.NONE,
# except that unused inputs of inexact type get 0. gradients instead of None.
self.assertAllClose(g_tf_None[0].numpy(), g_jax_param)
# The next one is different
self.assertAllClose(g_tf_None[1]["array"].numpy(), np.zeros_like(state["array"]).astype(np.float32))
self.assertIsNone(g_tf_None[1]["counter"])
self.assertIsNone(g_tf_None[1]["truth"])
self.assertAllClose(g_tf_None[2].numpy(), g_jax_x)
# The next one is different
self.assertAllClose(g_tf_None[3].numpy(), np.zeros_like(xf_unused))
self.assertIsNone(g_tf_None[4])

@parameterized.named_parameters(jtu.cases_from_list(
dict(testcase_name=f"function={with_function}",
with_function=with_function)
for with_function in [False, True]))
def test_tf_gradients_int_argument(self, with_function=False):
# https://github.com/google/jax/issues/6975
# param: f32
# state: dict(array:f32, counter:i32, truth: bool)
# xf: f32
# xf_unused: f32 unused
# xi_unused: i32 unused
# return f32, state
def f_jax(param, state, xf, xf_unused, xi_unused):
return param * xf, state

def test_convert_argument_non_callable_error(self):
with self.assertRaisesRegex(TypeError, "Expected a callable value"):
