Skip to content

Commit

Permalink
make core_test.py pass with core.call
Browse files Browse the repository at this point in the history
  • Loading branch information
mattjj committed May 11, 2022
1 parent 5bce808 commit 0b841cf
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 1 deletion.
2 changes: 1 addition & 1 deletion jax/interpreters/ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,7 +593,7 @@ def call_transpose(primitive, params, call_jaxpr, args, ct, _, reduce_axes):
fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr,
reduce_axes, False)
fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
if not config.jax_experimental_name_stack:
if 'name' in params and not config.jax_experimental_name_stack:
params = dict(params, name=wrap_name(params['name'], 'transpose'))
update_params = call_transpose_param_updaters.get(primitive)
if update_params:
Expand Down
12 changes: 12 additions & 0 deletions tests/core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,10 @@
from jax.core import UnshapedArray, ShapedArray
from jax.tree_util import (tree_flatten, tree_unflatten, tree_map, tree_reduce,
tree_leaves)
from jax.api_util import flatten_fun_nokwargs
from jax.interpreters import partial_eval as pe

from jax._src import util
from jax._src import test_util as jtu
from jax._src.abstract_arrays import make_shaped_array
from jax._src.lax import lax as lax_internal
Expand All @@ -47,6 +49,13 @@
def call(f, *args):
return jit(f)(*args)

@util.curry
def core_call(f, *args):
args, in_tree = tree_flatten(args)
f, out_tree = flatten_fun_nokwargs(lu.wrap_init(f), in_tree)
out = core.call_p.bind(f, *args)
return tree_unflatten(out_tree(), out)

def simple_fun(x, y):
return jnp.sin(x * y)

Expand Down Expand Up @@ -135,6 +144,9 @@ def jvp_unlinearized(f, primals, tangents):
test_specs.append(CallSpec(partial(jvp, ts.fun), (ts.args, ts.args)))
test_specs.append(CallSpec(jit(ts.fun), ts.args))
test_specs.append(CallSpec(jit(jit(ts.fun)), ts.args))
test_specs.append(CallSpec(core_call(ts.fun), ts.args))
test_specs.append(CallSpec(core_call(jit(ts.fun)), ts.args))
test_specs.append(CallSpec(core_call(core_call(ts.fun)), ts.args))
test_specs.append(CallSpec(partial(jvp_unlinearized, ts.fun),
(ts.args, ts.args)))

Expand Down

0 comments on commit 0b841cf

Please sign in to comment.