Commit

[jax2tf] Improve support for converting functions with kwargs
The previous conversion for kwargs did not work for AD.

Bug: #6791
gnecula committed Aug 5, 2021
1 parent efc5e25 commit b4e4acd
Showing 3 changed files with 64 additions and 15 deletions.
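
The gist of the change, restated outside the diff: jax.vjp accepts only positional arguments, so convert() now folds keyword arguments into the positional argument list and differentiates a positional-only wrapper. A minimal standalone sketch of that wrapping pattern (the helper name and example function below are illustrative, not part of the library):

```python
import jax
import jax.numpy as jnp

def wrap_kwargs_as_positional(fun, nr_positional_args, kw_names):
  """Builds a positional-only wrapper; trailing arguments are routed to kwargs."""
  def fun_no_kwargs(*args_and_kwargs):
    args = args_and_kwargs[:nr_positional_args]
    kwargs = {kw: args_and_kwargs[nr_positional_args + i]
              for i, kw in enumerate(kw_names)}
    return fun(*args, **kwargs)
  return fun_no_kwargs

def f(x, *, scale):
  return scale * jnp.sum(x)

f_pos = wrap_kwargs_as_positional(f, nr_positional_args=1, kw_names=["scale"])
# jax.vjp does not accept keyword arguments, but it handles the wrapper fine.
primal, pullback = jax.vjp(f_pos, jnp.ones(3), 2.0)
print(primal)                     # 6.0
print(pullback(jnp.float32(1.)))  # cotangents for (x, scale)
```
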
37 changes: 23 additions & 14 deletions jax/experimental/jax2tf/jax2tf.py
@@ -36,7 +36,6 @@
from jax._src.lax import lax
from jax._src.lax import linalg as lax_linalg
import jax._src.random
from jax.api_util import flatten_fun
from jax.experimental import maps
from jax.experimental import pjit
from jax.interpreters import ad
@@ -254,16 +253,28 @@ def converted_fun(*args: TfVal, **kwargs: TfVal) -> TfVal:
raise ValueError("convert must be used outside all JAX transformations." +
f"Trace state: {core.thread_local_state.trace_state.trace_stack}")

# We support kwargs by wrapping the function to take only positional arguments.
# This is in part because jax.vjp does not support kwargs.
nr_positional_args = len(args)
kw_names = kwargs.keys()
args = tuple(args) + tuple(kwargs[kw] for kw in kw_names)

def fun_no_kwargs(*args_and_kwargs):
assert len(args_and_kwargs) == nr_positional_args + len(kw_names)
args = args_and_kwargs[:nr_positional_args]
kwargs = {kw: args_and_kwargs[nr_positional_args + i]
for i, kw in enumerate(kw_names)}
return fun(*args, **kwargs)

def check_arg(a):
if not _is_tfval(a):
msg = (f"Argument {a} of type {type(a)} of jax2tf.convert(f) should "
"be NumPy array, scalar, tf.Variable, or tf.Tensor")
raise TypeError(msg)

tree_util.tree_map(check_arg, args)
tree_util.tree_map(check_arg, list(kwargs.values()))

args_flat, in_tree = tree_util.tree_flatten((args, kwargs))
args_flat, in_tree = tree_util.tree_flatten((args, {}))
# May need to cast the arguments to have the type assumed by JAX
args_and_dtypes_flat = tuple(map(_tfval_to_tensor_jax_dtype, args_flat))
args_flat, arg_dtypes_flat = util.unzip2(args_and_dtypes_flat)
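
For context, a standalone illustration (not part of the diff) of the flattening step above: once the keyword values have been appended to the positional tuple, the pytree that is flattened carries an empty kwargs dict. Variable names mirror the diff, but the snippet is self-contained:

```python
import numpy as np
from jax import tree_util

args = (np.zeros(3),)        # positional arguments
kwargs = {"y": np.ones(2)}   # keyword arguments
kw_names = kwargs.keys()

# As in the new code: keyword values are appended to the positional tuple ...
args = tuple(args) + tuple(kwargs[kw] for kw in kw_names)
# ... so only an empty dict is left to flatten in the kwargs slot.
args_flat, in_tree = tree_util.tree_flatten((args, {}))
assert len(args_flat) == 2   # one leaf per positional and keyword value
```
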
@@ -277,19 +288,16 @@ def _apply_name(a: TfVal, suffix) -> TfVal:
elif isinstance(polymorphic_shapes, (PolyShape, str)):
polymorphic_shapes_ = (polymorphic_shapes,) * len(args) # type: ignore
else:
if not isinstance(polymorphic_shapes, Sequence) or len(args) != len(polymorphic_shapes):
if not isinstance(polymorphic_shapes, Sequence) or len(polymorphic_shapes) != len(args) - len(kw_names):
msg = ("polymorphic_shapes must be a sequence with the same length as the positional argument list "
f"({len(args)}). Got polymorphic_shapes={repr(polymorphic_shapes)}.")
raise TypeError(msg)
polymorphic_shapes_ = tuple(polymorphic_shapes)
polymorphic_shapes_ = tuple(polymorphic_shapes) + (None,) * len(kw_names)

# Expand the polymorphic_shapes to match the argument pytree
polymorphic_shapes_flat = tuple(api_util.flatten_axes("jax2tf.convert polymorphic_shapes",
in_tree.children()[0],
polymorphic_shapes_))
# Add kwargs shapes.
polymorphic_shapes_flat = polymorphic_shapes_flat + tuple(
(None,) * (len(args_flat) - len(polymorphic_shapes_flat)))

# Construct the abstract values for the flat arguments, possibly based on
# the input shapes and the polymorphic_shapes if given. May create new shape
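
A small check of the padding above (illustrative values): each positional argument keeps its user-supplied spec, while every keyword argument gets a None spec, i.e. a fully concrete shape:

```python
polymorphic_shapes = ["b, ..."]  # one spec per positional argument
kw_names = ["y"]                 # keyword arguments receive no polymorphism spec
polymorphic_shapes_ = tuple(polymorphic_shapes) + (None,) * len(kw_names)
assert polymorphic_shapes_ == ("b, ...", None)
```
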
@@ -300,9 +308,9 @@ def _apply_name(a: TfVal, suffix) -> TfVal:

# This function may take pytrees of TfVals. We can only set
# tf.custom_gradient on functions that take a flat argument list.
f = lu.wrap_init(fun)
f = lu.wrap_init(fun_no_kwargs)
# out_tree_thunk() will be the output tree, after running _interpret_fun.
flat_fun, out_tree_thunk = flatten_fun(f, in_tree)
flat_fun, out_tree_thunk = api_util.flatten_fun(f, in_tree)
# out_tree_thunk will be ready after _interpret_fun below.

# Prepare the grad_fn for tf.custom_gradient.
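
The comment above refers to tf.custom_gradient, which wraps a function taking a flat list of tensors and pairs it with a gradient (VJP) function. As a minimal reminder of that calling convention (unrelated to jax2tf internals; the scale-by-two gradient is made up for illustration):

```python
import tensorflow as tf

@tf.custom_gradient
def scaled_identity(x):   # flat, positional, tensor-only argument list
  def grad(dy):           # VJP: maps the output cotangent to the input cotangent
    return 2.0 * dy       # deliberately scales the gradient, to make it visible
  return x, grad

x = tf.constant(3.0)
with tf.GradientTape() as tape:
  tape.watch(x)
  y = scaled_identity(x)
print(tape.gradient(y, x))  # tf.Tensor(2.0, shape=(), dtype=float32)
```
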
@@ -332,8 +340,9 @@ def fun_vjp_jax(args_flat_jax, out_cts_flat_jax):
# pullback may contain captured tracers from the conversion of the
# main function. Those tracers will confuse the conversion of the
# pullback. So, we construct the vjp anew and we convert it separately.
args_jax, _ = tree_util.tree_unflatten(in_tree, args_flat_jax)
_, pullback_jax = jax.vjp(fun, *args_jax)
args_jax, kwargs_jax = tree_util.tree_unflatten(in_tree, args_flat_jax)
assert not kwargs_jax
_, pullback_jax = jax.vjp(fun_no_kwargs, *args_jax)

def fix_out_ct(out_ct_jax, out_ct_aval: core.ShapedArray):
# If the primal function has outputs of integer or bool types, and if we are
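
For readers unfamiliar with the pullback that is rebuilt above, a tiny self-contained jax.vjp example (the function and values are illustrative, chosen to mirror test_grad_kwargs below):

```python
import jax
import jax.numpy as jnp

def g(x, y):
  return jnp.sum(x) + 2.0 * jnp.sum(y)

primal, pullback = jax.vjp(g, jnp.zeros(3), jnp.zeros(4))
# The pullback maps an output cotangent to one cotangent per input.
x_ct, y_ct = pullback(jnp.float32(1.0))
print(x_ct)  # [1. 1. 1.]
print(y_ct)  # [2. 2. 2. 2.]
```
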
@@ -365,13 +374,13 @@ def fix_in_ct(in_ct, arg_aval: core.ShapedArray):
return in_cts_fixed_flat_jax

# TODO: enable higher-order gradients
# TODO: I think that this does not work with kwargs
with tf.name_scope("jax2tf_vjp"):
in_cts_flat = convert(
fun_vjp_jax,
with_gradient=False,
polymorphic_shapes=vjp_polymorphic_shapes)(args_flat, out_cts_flat)
in_cts, _ = tree_util.tree_unflatten(in_tree, in_cts_flat)
in_cts, kwin_cts = tree_util.tree_unflatten(in_tree, in_cts_flat)
assert not kwin_cts
return in_cts

try:
30 changes: 29 additions & 1 deletion jax/experimental/jax2tf/tests/jax2tf_test.py
@@ -745,15 +745,43 @@ def jax_fn_array(x):
tf_fn_array(np.array([3, 4, 5])), np.array([4.5, 10, 17.5],
jnp.bfloat16))

def test_kwargs(self):
@parameterized.named_parameters(jtu.cases_from_list(
dict(testcase_name=f"function={with_function}",
with_function=with_function)
for with_function in [False, True]))
def test_kwargs(self, with_function=True):
# Re: https://github.com/google/jax/issues/6791
def f_jax(*, x):
return jnp.sum(x)
f_tf = jax2tf.convert(f_jax)
if with_function:
f_tf = tf.function(f_tf)
self.assertAllClose(
f_tf(x=np.zeros(3, dtype=np.float32)), # Call with kwargs.
np.zeros((), dtype=np.float32))

@parameterized.named_parameters(jtu.cases_from_list(
dict(testcase_name=f"function={with_function}",
with_function=with_function)
for with_function in [False, True]))
def test_grad_kwargs(self, with_function=False):
# Re: https://github.com/google/jax/issues/6791
x = (np.zeros(3, dtype=np.float32),
np.zeros(4, dtype=np.float32))
def f_jax(*, x=(1., 2.)):
return jnp.sum(x[0]) + 2. * jnp.sum(x[1])
f_tf = jax2tf.convert(f_jax)
if with_function:
f_tf = tf.function(f_tf)
xv = tf.nest.map_structure(tf.Variable, x)
with tf.GradientTape() as tape:
res = f_tf(x=xv)
grad_tf = tape.gradient(res, xv)
self.assertAllClose((np.full_like(x[0], fill_value=1.),
np.full_like(x[1], fill_value=2.)),
(grad_tf[0].numpy(), grad_tf[1].numpy()))


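Taken together, the two tests above exercise the user-facing pattern that this commit fixes; a compact standalone version (assuming jax2tf is importable; it mirrors the tests rather than adding new behavior):

```python
import numpy as np
import tensorflow as tf
import jax.numpy as jnp
from jax.experimental import jax2tf

def f_jax(*, x):
  return jnp.sum(x)

f_tf = tf.function(jax2tf.convert(f_jax))
v = tf.Variable(np.ones(3, dtype=np.float32))
with tf.GradientTape() as tape:
  res = f_tf(x=v)               # call the converted function with a kwarg
print(tape.gradient(res, v))    # gradient of the sum: all ones
```
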
def test_enable_xla(self):
# Tests that enable_xla flag is properly scoped to a conversion.
def fun(x):
12 changes: 12 additions & 0 deletions jax/experimental/jax2tf/tests/shape_poly_test.py
@@ -375,6 +375,18 @@ def test_forgot_polymorphic_shapes_error(self):
input_signature=[tf.TensorSpec([1, None])],
polymorphic_shapes=None)

def test_kwargs(self):
"""Test shape polymorphism for a function with kwargs."""

x = np.ones(3, dtype=np.float32)
y = np.ones(1, dtype=np.float32)
def f_jax(x, *, y):
return x + jnp.sin(y)

f_tf = jax2tf.convert(f_jax, polymorphic_shapes=["b, ..."])
f_tf(x, y=y)


def test_arg_avals(self):
"""Test conversion of actual arguments to abstract values."""

