Commit 4d966bb

Return gradients as flattened list from custom-defined gradient op in jax2tf.

Previous code returned gradients as a tree. The tree was then flattened by tf's custom_gradient, but that flattening is performed by tf.nest rather than by tree_util, which can leave the gradients in a different order than the corresponding inputs.
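To illustrate the mismatch, here is a minimal sketch (not part of the commit): jax.tree_util preserves the insertion order of an OrderedDict, while tf.nest flattens mappings in sorted-key order, so the two flatteners can disagree on the leaf order.

import collections

import jax
import tensorflow as tf

d = collections.OrderedDict()
d['r'] = 1.0  # inserted first
d['d'] = 2.0  # inserted second, but sorts first alphabetically

# jax.tree_util keeps the insertion order of an OrderedDict:
leaves, _ = jax.tree_util.tree_flatten(d)
print(leaves)              # [1.0, 2.0]  ('r' before 'd')

# tf.nest sorts mapping keys, ignoring OrderedDict insertion order:
print(tf.nest.flatten(d))  # [2.0, 1.0]  ('d' before 'r')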

PiperOrigin-RevId: 435363026
ofirnachum authored and jax authors committed Mar 17, 2022
1 parent 1f0a5b3 commit 4d966bb
Showing 2 changed files with 29 additions and 3 deletions.
jax/experimental/jax2tf/jax2tf.py (4 changes: 1 addition & 3 deletions)

@@ -407,9 +407,7 @@ def fix_in_ct(in_ct, arg_aval: core.ShapedArray):
           fun_vjp_jax,
           with_gradient=False,
           polymorphic_shapes=vjp_polymorphic_shapes)(args_flat, out_cts_flat)
-      in_cts, kwin_cts = tree_util.tree_unflatten(in_tree, in_cts_flat)
-      assert not kwin_cts
-      return in_cts
+      return in_cts_flat
 
     try:
       assert not _thread_local_state.shape_env, f"Unexpected shape environment {_thread_local_state.shape_env}"
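The flat return is safe because tf.custom_gradient expects its grad function to return one gradient per flat positional input, in input order; returning a nested structure instead would be flattened by tf.nest. A minimal sketch of that contract (illustrative only, not code from this commit):

import tensorflow as tf

@tf.custom_gradient
def add(x, y):
  def grad(dz):
    # One gradient per positional input, in the same order as (x, y).
    # Returning a nested structure here would be flattened by tf.nest,
    # whose leaf order need not match JAX's tree_util.
    return dz, dz
  return x + y, grad

x = tf.constant(1.0)
y = tf.constant(2.0)
with tf.GradientTape() as tape:
  tape.watch([x, y])
  z = add(x, y)
print(tape.gradient(z, [x, y]))  # both gradients are 1.0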
jax/experimental/jax2tf/tests/jax2tf_test.py (28 changes: 28 additions & 0 deletions)

@@ -21,6 +21,8 @@
 from absl.testing import absltest
 from absl.testing import parameterized
 
+from collections import OrderedDict
+
 import jax
 from jax import ad_checkpoint
 from jax import dtypes
@@ -286,6 +288,32 @@ def f(xy: Tuple[float, float]) -> Dict[str, float]:
     self.assertAllClose(5., tape.gradient(uv["two"], x))
     self.assertAllClose(4., tape.gradient(uv["two"], y))
 
+  @parameterized.named_parameters(jtu.cases_from_list(
+      dict(testcase_name=f"function={with_function}",
+           with_function=with_function)
+      for with_function in [False, True]))
+  def test_gradients_with_ordered_dict_input(self, with_function=True):
+    def f(inputs):
+      out = 0.0
+      for v in inputs.values():
+        out += jnp.sum(v)
+      return out
+
+    f_tf = jax2tf.convert(f, with_gradient=True)
+    if with_function:
+      f_tf = tf.function(f_tf, autograph=False)
+    default_float_type = jax2tf.dtype_of_val(4.)
+    x = tf.Variable([4.], dtype=default_float_type)
+    y = tf.Variable([4., 5.], dtype=default_float_type)
+    inputs = OrderedDict()
+    inputs['r'] = x
+    inputs['d'] = y
+    with tf.GradientTape(persistent=True) as tape:
+      u = f_tf(inputs)
+
+    self.assertAllClose(np.array([1.]), tape.gradient(u, x).numpy())
+    self.assertAllClose(np.array([1., 1.]), tape.gradient(u, y).numpy())
+
   @parameterized.named_parameters(jtu.cases_from_list(
       dict(testcase_name=f"function={with_function}",
           with_function=with_function)
