Skip to content

Commit

Permalink
[host_callback] Fix the JVP rule for id_tap(result=...)
Browse files Browse the repository at this point in the history
The previous rule was leaking ad.Zero.
  • Loading branch information
gnecula committed Jul 26, 2021
1 parent 36d06db commit 0b697ce
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 25 deletions.
58 changes: 33 additions & 25 deletions jax/experimental/host_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -798,17 +798,24 @@ def _values_to_avals(vals) -> Sequence[core.ShapedArray]:
id_tap_dep_p.def_impl(lambda r, _: r)
xla.translations[id_tap_dep_p] = lambda comp, a_res, a_tap: a_res
id_tap_dep_p.def_abstract_eval(lambda r_a, _: r_a)
ad.primitive_jvps[id_tap_dep_p] = (
lambda primals, tangents: (
id_tap_dep_p.bind(primals[0], primals[1]),
id_tap_dep_p.bind(tangents[0], tangents[1])))

def _id_tap_dep_jvp_rule(primals, tangents):
tangents_instantiated = tuple(map(_instantiate_zeros, tangents, primals))
return (id_tap_dep_p.bind(primals[0], primals[1]),
id_tap_dep_p.bind(tangents_instantiated[0], tangents_instantiated[1]))

def _id_tap_dep_transpose_rule(cts, arg_res, arg_tap):
assert ad.is_undefined_primal(arg_res)
assert ad.is_undefined_primal(arg_tap)
return (_instantiate_zeros(arg_res, cts), ad.Zero(arg_tap.aval))
ad.primitive_jvps[id_tap_dep_p] = _id_tap_dep_jvp_rule

def _id_tap_dep_transpose_rule(cts, arg_res, arg_tap):
if ad.is_undefined_primal(arg_res):
ct_res = _instantiate_zeros(cts, arg_res)
else:
ct_res = None
if ad.is_undefined_primal(arg_tap):
ct_tap = ad.Zero(arg_tap.aval)
else:
ct_tap = None
return (ct_res, ct_tap)

ad.primitive_transposes[id_tap_dep_p] = _id_tap_dep_transpose_rule

Expand Down Expand Up @@ -1170,36 +1177,37 @@ def _add_transform(params: Dict, name: str, *transform_params) -> Dict:
def _aval_is_empty(aval) -> bool:
return np.prod(aval.shape) == 0

# TODO(necula): there must be a better way to do this.
# The AttributeError is for regular values, the KeyError is for ConcreteArray
def _instantiate_zeros(arg, tan):
"""Turn special ad.zero tangents into arrays of 0s for sending to host."""
# return ad.instantiate_zeros(tan)
def _instantiate_zeros(tan, arg):
"""Turn special ad.zero tangents into arrays of 0s for sending to host.
Args:
tan: the tangent.
arg: the argument for which we need to instantiate the tangent
Returns: tan if is is not ad.Zero, otherwise a 0 array of appropriate type
and shape
"""
if type(tan) is not ad.Zero:
return tan
if tan.aval is not core.abstract_unit:
return ad.instantiate_zeros_aval(tan.aval, tan)

if tan.aval is core.abstract_unit:
if ad.is_undefined_primal(arg):
aval = arg.aval
else:
aval = core.raise_to_shaped(core.get_aval(arg))
if ad.is_undefined_primal(arg):
aval = arg.aval
else:
aval = tan.aval
res = ad.instantiate_zeros_aval(aval, tan)
return res

aval = core.raise_to_shaped(core.get_aval(arg))
return ad.instantiate_zeros_aval(aval, tan)

def _outside_call_jvp_rule(primals, tangents, **params):
assert "has_token" not in params
if not params["identity"]:
raise NotImplementedError("JVP rule is implemented only for id_tap, not for call.")
tangent_instantiated = tuple(map(_instantiate_zeros, primals, tangents))
tangents_instantiated = tuple(map(_instantiate_zeros, tangents, primals))

arg_treedef = params["arg_treedef"]
# The argument to the jvp tap is a pair of the tapped primals and tangents
jvp_flat_args, jvp_arg_treedef = api.tree_flatten(
(arg_treedef.unflatten(primals),
arg_treedef.unflatten(tangent_instantiated)))
arg_treedef.unflatten(tangents_instantiated)))
out_all = outside_call_p.bind(
*jvp_flat_args,
**dict(_add_transform(params, "jvp"),
Expand Down Expand Up @@ -1264,7 +1272,7 @@ def _outside_call_transpose_rule(cts, *args, **params):
raise NotImplementedError("differentiation rules are implemented only for id_tap, not for call.")
assert "has_token" not in params
assert len(cts) == len(args)
cts_instantiated = tuple(map(_instantiate_zeros, args, cts))
cts_instantiated = tuple(map(_instantiate_zeros, cts, args))

# The args have been prepared by the id_tap_jvp_rule: tapped_primals, tapped_tangents, rest_primals, rest_tangents
transforms = params.get("transforms", ())
Expand Down
77 changes: 77 additions & 0 deletions tests/host_callback_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from absl.testing import parameterized

import jax
from jax import core
from jax._src import api
from jax.config import config
from jax import dtypes
Expand Down Expand Up @@ -1046,6 +1047,82 @@ def func(x, yint):
( 2.00
False )""", testing_stream.output)

def test_tap_grad_float0_result(self):
# https://github.com/google/jax/issues/7340
# x is a Tuple[f32[2], s32[3]]
x = (np.array([.7, .8], dtype=np.float32),
np.array([11, 12, 13], dtype=np.int32))
def f_jax(x):
x = hcb.id_print(x, result=x, output_stream=testing_stream) # result= is important
return (3. * x[0], x[1])

def f_jax_vjp(x):
res, pullback = jax.vjp(f_jax, x)
g, = pullback((np.ones(x[0].shape, dtype=x[0].dtype),
np.zeros(x[1].shape, dtype=dtypes.float0)))
return g

g = f_jax_vjp(x)
self.assertAllClose(np.array([3., 3.], dtype=np.float32), g[0])
self.assertEqual(dtypes.float0, g[1].dtype)
hcb.barrier_wait()
assertMultiLineStrippedEqual(self, """
( [0.70 0.80]
[11 12 13] )
transforms: ['jvp', 'transpose']
( [0.00 0.00]
[False False False] )""", testing_stream.output)

def test_tap_higher_order_grad_float0_result(self):
# https://github.com/google/jax/issues/7340
# x is a Tuple[f32[2], s32[3]]
x = (np.array([.7, .8], dtype=np.float32),
np.array([11, 12, 13], dtype=np.int32))
def f_jax(x):
x = hcb.id_print(x, result=x, output_stream=testing_stream) # result= is important
return (jnp.sin(x[0]), x[1])

def wrap_vjp(f, args, res_f_of_args):
# Given a function "f" and "args" return the f_vjp and args_vjp
def make_ct(res):
res_dtype = np.result_type(res)
if res_dtype == dtypes.float0:
return res
ct_dtype = core.primal_dtype_to_tangent_dtype(res_dtype)
return np.ones(np.shape(res), dtype=ct_dtype)
cts = tree_util.tree_map(make_ct, res_f_of_args)
def f_vjp(args, cts):
res, pullback = jax.vjp(f, *args)
return pullback(cts)
return (f_vjp, (args, cts))

res = f_jax(x)
hcb.barrier_wait()
assertMultiLineStrippedEqual(self, """
( [0.70 0.80]
[11 12 13] )""", testing_stream.output)
testing_stream.reset()

# 1st order
f_jax_vjp1, args_vjp1 = wrap_vjp(f_jax, (x,), res)
res_vjp1 = f_jax_vjp1(*args_vjp1)
hcb.barrier_wait()
assertMultiLineStrippedEqual(self, """
( [0.70 0.80]
[11 12 13] )
transforms: ['jvp', 'transpose']
( [0.00 0.00]
[False False False] )""", testing_stream.output)
testing_stream.reset()

# 2nd order
f_jax_vjp2, args_vjp2 = wrap_vjp(f_jax_vjp1, args_vjp1, res_vjp1)
res_vjp2 = f_jax_vjp2(*args_vjp2)

# 3rd order
f_jax_vjp3, args_vjp3 = wrap_vjp(f_jax_vjp2, args_vjp2, res_vjp2)
_ = f_jax_vjp3(*args_vjp3)

def test_tap_vmap(self):
vmap_fun1 = api.vmap(fun1)
vargs = jnp.array([jnp.float32(4.), jnp.float32(5.)])
Expand Down

0 comments on commit 0b697ce

Please sign in to comment.