[shape_poly] Add partial support for call_exported with polymorphic shapes

Until now, jax_export.call_exported did not allow calling functions
that were exported with polymorphic shapes. We now add that support,
including resolving the dimension variables of the called function
in terms of the shapes at the call site (which themselves may contain
dimension variables), and then computing the output shape of the
called function.
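
For example, when the callee was exported with the polymorphic spec "3,a,a+b"
and is called with a concrete f32[3,4,12] argument, the resolution proceeds as
in the following minimal sketch (a toy helper with hand-unrolled logic; the
real implementation uses shape_poly.unify_avals_with_args and
core.evaluate_shape):

def resolve_dim_vars(call_site_shape):
  # Unify the exported spec "3,a,a+b" with the concrete call-site shape:
  # dim 0: 3 must equal 3; dim 1: a = 4; dim 2: a + b = 12, so b = 8.
  assert call_site_shape[0] == 3
  a = call_site_shape[1]
  b = call_site_shape[2] - a
  assert b >= 1, "Dimension variable 'b' must have integer value >= 1"
  return dict(a=a, b=b)

env = resolve_dim_vars((3, 4, 12))
assert env == dict(a=4, b=8)
# If the callee's output shape is (3*a, a+b), it now evaluates to (12, 12).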

The support is partial in that we can export a JAX function that
calls an exported polymorphic function, but we cannot invoke it.
This is because we do not yet have access to the shape refinement
machinery that XlaCallModule uses. For now, we use XlaCallModule
for invoking exported modules that include shape polymorphism.
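
A sketch of the resulting user-facing flow, assembled from the tests added in
this commit (jax2tf._run_exported_as_tf is an internal, temporary helper):

import numpy as np
from jax import numpy as jnp
from jax.experimental.jax2tf import jax_export, jax2tf

def inner(x):
  return jnp.reshape(x, (x.shape[0] * x.shape[1], x.shape[2]))

x1 = np.arange(3 * 4 * 6, dtype=np.float32).reshape((3, 4, 6))
exp1 = jax_export.export(inner)(
    jax_export.poly_spec(x1.shape, x1.dtype, "3,a,a+b"))  # polymorphic export

def outer(x):
  return jax_export.call_exported(exp1)(x) + inner(x)

# Exporting a caller of a polymorphic exported function now works, even when
# the caller has polymorphic shapes of its own.
x2 = np.concatenate([x1, x1], axis=2)  # f32[3,4,12]
exp2 = jax_export.export(outer)(
    jax_export.poly_spec(x2.shape, x2.dtype, "3,4,c"))

# Invoking exp2 still goes through XlaCallModule until shape refinement gets
# Python bindings; jax_export.call_exported(exp2)(x2) is not yet supported.
res2 = jax2tf._run_exported_as_tf([x2], exp2)[0].numpy()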
gnecula committed May 26, 2023
1 parent 7833528 commit 46a258b
Showing 2 changed files with 76 additions and 14 deletions.
32 changes: 26 additions & 6 deletions jax/experimental/jax2tf/jax_export.py
@@ -755,15 +755,19 @@ def f_imported(*args, **kwargs):
 def _call_exported_abstract_eval(*in_avals: core.AbstractValue,
                                  exported: Exported) -> Tuple[core.AbstractValue, ...]:
   exported_dim_vars = shape_poly.all_dim_vars(exported.in_avals)
-  if exported_dim_vars:
-    raise NotImplementedError("call_exported for exported with polymorphic shapes")
   assert len(in_avals) == len(exported.in_avals)  # since the pytrees have the same structure
   # Must express the exported_dim_vars in terms of the shapes in in_avals.
-  _ = shape_poly.unify_avals_with_args(
+  exported_dim_values = shape_poly.unify_avals_with_args(
       exported.in_avals, exported_dim_vars, *in_avals,  # type: ignore
       use_static_dimension_size=True,
       args_kwargs_tree=exported.in_tree)
-  return exported.out_avals
+
+  return tuple(
+      core.ShapedArray(core.evaluate_shape(out_aval.shape, exported_dim_vars,
+                                           *exported_dim_values),
+                       dtype=out_aval.dtype, weak_type=out_aval.weak_type,
+                       named_shape=out_aval.named_shape)
+      for out_aval in exported.out_avals)
 
 
 call_exported_p.def_abstract_eval(_call_exported_abstract_eval)
@@ -783,16 +787,32 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args,
                      f"on '{platform}'.")
   submodule = ir.Module.parse(exported.mlir_module)
   symtab = ir.SymbolTable(submodule.operation)
+  # The called function may have been exported with polymorphic shapes and called
+  # now with more refined shapes. We insert hlo.ConvertOp to ensure the module
+  # is valid.
+  def convert_shape(x: ir.Value, x_aval: core.AbstractValue, new_aval: core.AbstractValue) -> ir.Value:
+    new_ir_type = mlir.aval_to_ir_type(new_aval)
+    if x.type != new_ir_type:
+      return mlir.convert_hlo(ctx, x, x_aval, new_aval)
+    else:
+      return x
+
   callee_result_types = symtab["main"].type.results
   # TODO: maybe cache multiple calls
   fn = mlir.merge_mlir_modules(ctx.module_context.module,
                                f"call_exported_{exported.fun_name}",
                                submodule)
-  kept_args = [a for i, a in enumerate(args) if i in exported.module_kept_var_idx]
+  kept_args = [
+      convert_shape(a, a_aval, exported_in_aval)
+      for i, (a, a_aval, exported_in_aval) in enumerate(zip(args, ctx.avals_in, exported.in_avals))
+      if i in exported.module_kept_var_idx]
   call = func_dialect.CallOp(callee_result_types,
                              ir.FlatSymbolRefAttr.get(fn),
                              kept_args)
-  return call.results
+  # The ctx.avals_out already contain the abstract values refined by
+  # _call_exported_abstract_eval.
+  return tuple(convert_shape(out, out_aval, refined_out_aval)
+               for out, out_aval, refined_out_aval in zip(call.results, exported.out_avals, ctx.avals_out))
 
 
 for _p in ("cpu", "tpu", "cuda", "rocm"):
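
The convert_shape helper above inserts a conversion only when the callee's
declared IR type (possibly polymorphic, e.g. tensor<3x?x?xf32>) differs from
the type at the call site (e.g. the refined tensor<3x4x12xf32>). A minimal
stand-in sketch of that dispatch, with toy types instead of the real MLIR API:

from dataclasses import dataclass

@dataclass(frozen=True)
class ToyValue:
  ir_type: str  # e.g. "tensor<3x?x?xf32>" or "tensor<3x4x12xf32>"

def convert_shape_toy(x: ToyValue, new_type: str) -> ToyValue:
  # Mirrors convert_shape: wrap the value only when the IR types differ.
  if x.ir_type != new_type:
    return ToyValue(new_type)  # stands in for mlir.convert_hlo
  return x

arg = ToyValue("tensor<3x4x12xf32>")  # refined argument at the call site
assert convert_shape_toy(arg, "tensor<3x?x?xf32>").ir_type == "tensor<3x?x?xf32>"
assert convert_shape_toy(arg, arg.ir_type) is arg  # no-op when types already match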
58 changes: 50 additions & 8 deletions jax/experimental/jax2tf/tests/jax_export_test.py
@@ -23,6 +23,11 @@
 from jax import numpy as jnp
 from jax.config import config
 from jax.experimental.jax2tf import jax_export
+try:
+  from jax.experimental.jax2tf import jax2tf  # TODO: temporary
+except ImportError:
+  jax2tf = None
+
 from jax._src import core
 from jax._src import test_util as jtu
 from jax._src import xla_bridge as xb
@@ -208,14 +213,51 @@ def f2(x):
     self.assertAllClose(jnp.cos(jnp.sin(jnp.sin(a))),
                         jax_export.call_exported(exp_f2)(a))
 
-  def test_call_poly_error(self):
-    a = np.arange(4, dtype=np.float32)
-    exp_f1 = jax_export.export(jnp.sin)(
-      jax_export.poly_spec(a.shape, a.dtype, "b, ...")
-    )
-    with self.assertRaisesRegex(NotImplementedError,
-                                "call_exported for exported with polymorphic shapes"):
-      jax_export.call_exported(exp_f1)(a)
+  # An inner function is exported with polymorphic shapes inner_poly_spec, and
+  # is called from an outer function, which is exported with outer_poly_spec.
+  @parameterized.named_parameters(
+      dict(testcase_name=f"inner={inner_poly_spec}_outer={outer_poly_spec}",
+           inner_poly_spec=inner_poly_spec, outer_poly_spec=outer_poly_spec,
+           expect_error=expect_error)
+      for inner_poly_spec, outer_poly_spec, expect_error in (
+          ("3,a,a+b", "3,4,12", None),
+          ("3,a,a+b", "3,4,c", None),
+          ("3,a,a+b", "3,c,c", r"Dimension variable.*b.*must have.* >= 1. Found value 0"),
+          ("3,a,a+b", "c,4,12", r"Shape mismatch for args\[0\] in dimension 0"),
+          ("3,a,a+b", "3,c+4,12", None),  # TODO: This should be an error, c = 0
+          ("3,4,3*a", "3,4,12", None),
+          ("3,4,5*a", "3,4,12", r"Dimension variable 'a' must have integer value >= 1. Found value 2.4"),
+          # ("3,a,a", "3,a,a", None),  # TODO: wrong error. It should be shape mismatch
+          # ("3,4,5*a", "3,4,c", None),  # TODO: wrong error. It should be "not divisible by 5"
+      ))
+  def test_poly(self, inner_poly_spec="3,a,a+b",
+                outer_poly_spec="3,4,12", expect_error=None):
+    # Polymorphic export called with static or polymorphic shapes.
+    def inner(x):  # x: inner_poly_spec
+      return jnp.reshape(x, (x.shape[0] * x.shape[1], x.shape[2]))
+
+    x1 = np.arange(3 * 4 * 6, dtype=np.float32).reshape((3, 4, 6))  # x1: f32[3,4,6]
+    exp1 = jax_export.export(inner)(jax_export.poly_spec(x1.shape, x1.dtype, inner_poly_spec))
+
+    x2 = np.concatenate([x1, x1], axis=2)  # x2: f32[3,4,12]
+    def outer(x):  # x: outer_poly_spec
+      # Use an addition to test that the shapes are refined properly for the
+      # result of the call_exported.
+      return jax_export.call_exported(exp1)(x) + inner(x)
+
+    with contextlib.ExitStack() as stack:
+      if expect_error is not None:
+        stack.push(self.assertRaisesRegex(ValueError, expect_error))
+
+      # Export outer again, now with polymorphic shapes.
+      exp2 = jax_export.export(outer)(
+          jax_export.poly_spec(x2.shape, x2.dtype, outer_poly_spec))
+      # TODO: for now, we use XlaCallModule to run modules with polymorphic shapes
+      # until we create the python bindings to invoke shape refinement.
+      if jax2tf is not None:
+        res2 = jax2tf._run_exported_as_tf([x2], exp2)[0].numpy()
+        # res2 = jax_export.call_exported(exp2)(x2)
+        self.assertAllClose(2. * inner(x2), res2)
 
 
 if __name__ == "__main__":
