[export] Add jax.global_constant MLIR attributes for dimension variable arguments

In the presence of shape polymorphism and multi-platform lowering
we pass the global values for the dimension variables and the
platform index to all inner functions. At the moment, prior to
compilation we run a shape refinement pass to infer which of
the arguments of a function carry such global values.
This inference can yield false positives, e.g., when a
user-defined function is called with a constant int32 as its first argument.

With this change we no longer need to infer which arguments
carry global constants. This is in preparation for a more
reliable implementation of shape refinement.
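A minimal sketch of the observable effect (assuming the
`from jax.experimental.export import export` import path; `export.export`,
`poly_spec`, and `mlir_module` are the APIs exercised by the tests in this
commit): export a shape-polymorphic function and grep the lowered module for
the new attribute.

    import re
    import numpy as np
    import jax.numpy as jnp
    from jax.experimental.export import export

    a = np.arange(12, dtype=np.float32).reshape((3, 4))

    def f(a, b):  # a: f32[2*w, h]  b: f32[w, h]
      return jnp.concatenate([a, b], axis=0)

    exp = export.export(f)(
        export.poly_spec(a.shape, a.dtype, "(2*w, h)"),
        export.poly_spec(a.shape, a.dtype, "(w, h)"))
    module_str = str(exp.mlir_module())

    # Dimension-variable arguments of the wrapped main function now carry a
    # jax.global_constant string attribute naming the dimension variable.
    print(re.findall(r'jax\.global_constant = "(\w+)"', module_str))  # e.g. ['h', 'w']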
gnecula committed Oct 20, 2023
1 parent 741b71f commit 8d5a858
Showing 5 changed files with 101 additions and 50 deletions.
25 changes: 14 additions & 11 deletions jax/_src/interpreters/mlir.py
@@ -413,7 +413,7 @@ def __init__(self,
dim_vars: tuple[str, ...],
lowering_platforms: tuple[str, ...] | None):
if lowering_platforms is not None and len(lowering_platforms) > 1:
dim_vars = ("platform_index_",) + tuple(dim_vars)
dim_vars = ("_platform_index",) + tuple(dim_vars)
self.has_platform_index_argument = True
else:
self.has_platform_index_argument = False
@@ -435,6 +435,13 @@ class LoweringParameters:
# This is used only in export and jax2tf.
platforms: tuple[str, ...] | None = None

# Signals that the entire computation being lowered operates on global
# constants. This will result in adding jax.global_constant attributes
# to the arguments of all functions that are created, e.g., floor_divide.
# This is used only in export and jax2tf in presence of shape polymorphism
# or multi-platform lowering.
global_constant_computation: bool = False

@property
def override_platform(self) -> str | None:
"""Overrides the lowering platform for cross-platform lowering.
@@ -1127,16 +1134,12 @@ def aval_to_types(aval):
attrs["tf.aliasing_output"] = i32_attr(alias)

if num_dim_vars > 0:
if ctx.shape_poly_state.has_platform_index_argument:
num_platform_index_vars = 1
else:
num_platform_index_vars = 0
platform_arg_attrs = arg_attrs[0:num_platform_index_vars]
for attrs in platform_arg_attrs:
attrs["jax.platform_index"] = ir.BoolAttr.get(True)
dim_var_arg_attrs = arg_attrs[num_platform_index_vars:num_dim_vars]
for attrs in dim_var_arg_attrs:
attrs["jax.dimension_variable"] = ir.BoolAttr.get(True)
for var_name, attrs in zip(ctx.shape_poly_state.dim_vars,
arg_attrs[:num_dim_vars]):
attrs["jax.global_constant"] = ir.StringAttr.get(var_name)
elif ctx.lowering_parameters.global_constant_computation:
for attrs in arg_attrs:
attrs["jax.global_constant"] = ir.StringAttr.get("")

if num_tokens > 0:
token_arg_attrs = arg_attrs[num_dim_vars:num_dim_vars + num_tokens]
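The tagging logic above can be summarized with an illustrative, self-contained
sketch (plain dicts stand in for the per-argument MLIR attribute dictionaries
and plain strings for `ir.StringAttr` values; this is not the actual mlir.py
code):

    # Illustrative only: mirrors the two branches added in the hunk above.
    def tag_global_constants(arg_attrs, dim_vars, global_constant_computation):
        if dim_vars:
            # Dimension-variable (and platform-index) arguments get the name
            # of the global constant they carry.
            for var_name, attrs in zip(dim_vars, arg_attrs[:len(dim_vars)]):
                attrs["jax.global_constant"] = var_name   # ir.StringAttr.get(var_name)
        elif global_constant_computation:
            # Inside a global-constant computation every argument is a global
            # constant, but its name is not known, hence the empty string.
            for attrs in arg_attrs:
                attrs["jax.global_constant"] = ""         # ir.StringAttr.get("")
        return arg_attrs

    print(tag_global_constants([{}, {}, {}], ("_platform_index", "h"), False))
    # [{'jax.global_constant': '_platform_index'}, {'jax.global_constant': 'h'}, {}]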
5 changes: 4 additions & 1 deletion jax/_src/test_util.py
@@ -1283,7 +1283,10 @@ def parameterized_filterable(*,
if one_containing is not None:
filtered = tuple(kw for kw in kwargs_with_testcase_name
if one_containing in kw["testcase_name"])
assert filtered, f"No testcase_name contains '{one_containing}'"
assert filtered, (
f"No testcase_name contains '{one_containing}'. "
"The testcase_name values are\n " +
"\n ".join(kw["testcase_name"] for kw in kwargs_with_testcase_name))
kw = filtered[0]
kw["testcase_name"] = ""
return parameterized.named_parameters([kw])
34 changes: 31 additions & 3 deletions jax/experimental/export/export.py
@@ -182,20 +182,45 @@ class Exported:
Assume that we use multi-platform lowering, and we have
ordered effects. The `main` function will be as follows:
func public main(platform_index: i32, arg: f32[?, ?]) {
func public main(
platform_index: i32 {jax.global_constant="_platform_index"},
arg: f32[?, ?]) {
arg_w = hlo.get_dimension_size(arg, 0)
dim1 = hlo.get_dimension_size(arg, 1)
arg_h = hlo.floordiv(dim1, 2)
call _check_shape_assertions(arg) # See below
token = new_token()
token_out, res = call _wrapped_jax_export_main(platform_index, arg_h, arg_w, token_in, arg)
token_out, res = call _wrapped_jax_export_main(platform_index,
arg_h,
arg_w,
token_in,
arg)
return res
}
The actual computation is in `_wrapped_jax_export_main`, taking also
the values of `h` and `w` and the token. Proper exporting of
functions with side-effects and tokens is still work-in-progress.
The signature of the `_wrapped_jax_export_main` is:
func private _wrapped_jax_export_main(
platform_index: i32 {jax.global_constant="_platform_index"},
arg_h: i32 {jax.global_constant="h"},
arg_w: i32 {jax.global_constant="w"},
arg_token: stablehlo.token {jax.token=True},
arg: f32[?, ?])
Starting with serialization version 9, function arguments that contain
the platform index or the dimension variable values have a
`jax.global_constant` string attribute whose value is the name of the
global constant, either `_platform_index` or a dimension variable name.
The global constant name may be empty if it is not known.
Some global constant computations use inner functions, e.g., for
`floor_divide`. All the arguments of such functions have the
`jax.global_constant` attribute, meaning that the result of the function is
also a global constant.
Note that `main` contains a call to `_check_shape_assertions`.
JAX tracing assumes that `arg.shape[1]` is even, and that both `w` and `h`
have values >= 1. We must check these constraints when we invoke the
@@ -602,13 +627,16 @@ def is_token(attrs):
symbol_table.insert(new_main_op)
entry_block = new_main_op.add_entry_block()
with ir.InsertionPoint(entry_block):
# Make a context just for lowering the dimension value computations
module_context = mlir.ModuleContext(
backend_or_name="cpu", platform="cpu",
axis_context=sharding_impls.ShardingContext([]),
name_stack=source_info_util.new_name_stack(),
keepalives=[], channel_iterator=itertools.count(1),
host_callbacks=[], module=wrapped_module, context=context,
lowering_parameters=mlir.LoweringParameters())
lowering_parameters=mlir.LoweringParameters(
global_constant_computation=True
))
ctx = mlir.LoweringRuleContext(
module_context=module_context, primitive=None,
avals_in=args_avals_flat, avals_out=None,
2 changes: 2 additions & 0 deletions jax/experimental/jax2tf/README.md
@@ -828,6 +828,8 @@ We list here a history of the serialization version numbers:
attribute and enables the shape refinement pass only when the
attribute is present. Supported by XlaCallModule since July 21st, 2023
(cl/549973693) and available in JAX since July 26th, 2023 (JAX 0.4.14).
* Version 9 (not yet in use) adds support for the `jax.global_constant`
attribute.


## Known issues
85 changes: 50 additions & 35 deletions tests/export_test.py
@@ -150,28 +150,6 @@ def f(a_b_pair, *, a, b):
self.assertEqual(exp.out_tree, tree_util.tree_flatten(f(*args, **kwargs))[1])
self.assertEqual(exp.out_avals, (a_aval, b_aval, a_aval, b_aval, a_aval, b_aval))

def test_poly_export_only(self):
a = np.arange(12, dtype=np.float32).reshape((3, 4))
def f(a, b): # a: f32[2w,h] b: f32[w,h]
return jnp.concatenate([a, b], axis=0)

exp = export.export(f)(
export.poly_spec(a.shape, a.dtype, "(2*w, h)"),
export.poly_spec(a.shape, a.dtype, "(w, h)"))
self.assertEqual("(2*w, h)", str(exp.in_avals[0].shape))
self.assertEqual("(w, h)", str(exp.in_avals[1].shape))
self.assertEqual("(3*w, h)", str(exp.out_avals[0].shape))

def test_poly_pytree_export_only(self):
a = np.arange(12, dtype=np.float32).reshape((3, 4))
def f(a0, a1, *, ak):
return jnp.concatenate([a0, a1, ak], axis=0)

a_poly_spec = export.poly_spec(a.shape, a.dtype, "(w, h)")
exp = export.export(f)(a_poly_spec, a_poly_spec, ak=a_poly_spec)
self.assertEqual("(w, h)", str(exp.in_avals[0].shape))
self.assertEqual("(3*w, h)", str(exp.out_avals[0].shape))

def test_basic(self):
f = jnp.sin
x = np.arange(4, dtype=np.float32)
@@ -361,12 +339,60 @@ def f2(x):
self.assertAllClose(jnp.cos(jnp.sin(jnp.sin(a))),
export.call_exported(exp_f2)(a))

def test_poly_export_only(self):
a = np.arange(12, dtype=np.float32).reshape((3, 4))
def f(a, b): # a: f32[2w,h] b: f32[w,h]
return jnp.concatenate([a, b], axis=0)

exp = export.export(f)(
export.poly_spec(a.shape, a.dtype, "(2*w, h)"),
export.poly_spec(a.shape, a.dtype, "(w, h)"))
self.assertEqual("(2*w, h)", str(exp.in_avals[0].shape))
self.assertEqual("(w, h)", str(exp.in_avals[1].shape))
self.assertEqual("(3*w, h)", str(exp.out_avals[0].shape))

# Peek at the module
module_str = exp.mlir_module()
self.assertEqual(config.jax_serialization_version.value >= 7,
"shape_assertion" in module_str)
self.assertIn("jax.uses_shape_polymorphism = true", module_str)
wrapped_main_expected_re = (
r"@_wrapped_jax_export_main\("
r"%arg0: tensor<i..> {jax.global_constant = \"h\"}.*"
r"%arg1: tensor<i..> {jax.global_constant = \"w\"}.*"
r"%arg2: tensor<\?x\?xf32>"
)
self.assertRegex(module_str, wrapped_main_expected_re)

# Look for private inner functions that are generated to compute the
# dimension variables and shape assertions. All those functions must
# have jax.global_constant attributes on all the arguments.
for func_name, func_args in re.findall(
r"func.func private @([\w]+)\((.+)\) ->",
module_str):
if func_name == "_wrapped_jax_export_main":
continue
func_args_count = len(re.findall(r"%arg\d+", func_args))
func_args_constant_attrs = len(re.findall(r"jax.global_constant = ",
func_args))
self.assertEqual(func_args_count, func_args_constant_attrs)

def test_poly_pytree_export_only(self):
a = np.arange(12, dtype=np.float32).reshape((3, 4))
def f(a0, a1, *, ak):
return jnp.concatenate([a0, a1, ak], axis=0)

a_poly_spec = export.poly_spec(a.shape, a.dtype, "(w, h)")
exp = export.export(f)(a_poly_spec, a_poly_spec, ak=a_poly_spec)
self.assertEqual("(w, h)", str(exp.in_avals[0].shape))
self.assertEqual("(3*w, h)", str(exp.out_avals[0].shape))

@jtu.parameterized_filterable(
kwargs=[
dict(v=v)
for v in range(export.minimum_supported_serialization_version - 1,
export.maximum_supported_serialization_version + 2)])
def test_shape_poly_basic_versions(self, v: int):
def test_poly_basic_versions(self, v: int):
self.override_serialization_version(v)
with contextlib.ExitStack() as e:
if not (export.minimum_supported_serialization_version <= v
@@ -377,17 +403,6 @@ def test_shape_poly_basic_versions(self, v: int):

exp = export.export(jnp.sin)(
export.poly_spec((3, 4), np.float32, "w, h"))
# Peek at the module
module_str = exp.mlir_module()
self.assertEqual(config.jax_serialization_version.value >= 7,
"shape_assertion" in module_str)
self.assertIn("jax.uses_shape_polymorphism = true",
module_str)
dim_vars = re.findall(
r"(%arg\d):\s*tensor<i..>\s*{jax.dimension_variable = true}",
module_str)
self.assertEqual(["%arg0", "%arg1"], dim_vars,
f"Found {dim_vars} in {module_str}")
x = np.arange(30, dtype=np.float32).reshape((5, 6))
res = export.call_exported(exp)(x)
self.assertAllClose(res, np.sin(x))
@@ -752,7 +767,7 @@ def test_multi_platform(self):
module_str = str(exp.mlir_module())
expected_main_re = (
r"@main\("
r"%arg0: tensor<i..> {jax.platform_index = true}.*, "
r"%arg0: tensor<i..> {jax.global_constant = \"_platform_index\"}.*, "
r"%arg1: tensor<8xf32>.* ->")
self.assertRegex(module_str, expected_main_re)

