Skip to content

Commit

Permalink
[shape_poly] Add a polymorphic shape refinement MLIR pass accessible …
Browse files Browse the repository at this point in the history
…to JAX Python.

At the moment we can run the StableHLO module lowered by jax2tf
with polymorphic shapes only with jax2tf, because the tf.XlaCallModule op has the
necessary shape refinement logic (which is necessary to legalize
the StableHLO module with dynamic shapes to MHLO). Here we
expose the shape refinement MLIR transformation to JAX Python.

For now this is used only in a test in jax_export_test.py.

PiperOrigin-RevId: 537485288
  • Loading branch information
gnecula authored and jax authors committed Jun 3, 2023
1 parent 8861858 commit ec8b855
Showing 1 changed file with 80 additions and 3 deletions.
83 changes: 80 additions & 3 deletions jax/experimental/jax2tf/tests/jax_export_test.py
Expand Up @@ -12,14 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import logging
import math
from typing import List
import unittest

from absl.testing import absltest, parameterized

import jax
from jax import tree_util

from jax import numpy as jnp
from jax import tree_util
from jax.config import config
from jax.experimental.jax2tf import jax_export
try:
Expand Down Expand Up @@ -292,6 +293,82 @@ def outer(x): # x: outer_poly_spec
res = jax2tf._run_exported_as_tf([outer_x], outer_exp)[0].numpy()
self.assertAllClose(2. * inner(outer_x), res)

def test_call_poly(self):
a_shape = (3, 4)
a = np.arange(math.prod(a_shape), dtype=np.float32).reshape(a_shape)

def f_inner(x): # x: f32[w, h]
return jnp.reshape(x, (-1,))

exp_inner = jax_export.export(f_inner)(
jax_export.poly_spec(a.shape, a.dtype, "w, h")
)

# There are dynamic shapes in the exported module
self.assertIn("?x", exp_inner.mlir_module)
self.assertIn("stablehlo.dynamic_reshape", exp_inner.mlir_module)

# Add a wrapper "main" func with static shapes
# TODO(necula): We will add this functionality to jax_export.
from jax._src.interpreters import mlir
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.lib.mlir.dialects import func as func_dialect
from jax.lib import xla_client as xc
from jax._src.lib import xla_extension

context = mlir.make_ir_context()
with context, ir.Location.unknown(context):
wrapped_module = ir.Module.parse(exp_inner.mlir_module)
symbol_table = ir.SymbolTable(wrapped_module.operation)
orig_main = symbol_table["main"]
orig_main.attributes["sym_visibility"] = ir.StringAttr.get("private")
symbol_table.set_symbol_name(orig_main, "_wrapped_jax_export_main")
orig_main_name = ir.StringAttr(symbol_table.insert(orig_main)).value
# Use static shapes
new_main_input_types = [
mlir.aval_to_ir_type(core.ShapedArray((3, 4), np.float32))
]
orig_output_types = orig_main.type.results
new_main_ftype = ir.FunctionType.get(
new_main_input_types, orig_output_types
)
new_main_op = func_dialect.FuncOp(
"main",
new_main_ftype,
ip=ir.InsertionPoint.at_block_begin(wrapped_module.body),
)
new_main_op.attributes["sym_visibility"] = ir.StringAttr.get("public")
symbol_table.insert(new_main_op)
entry_block = new_main_op.add_entry_block()
with ir.InsertionPoint(entry_block):
orig_main_args: List[ir.Value] = []
for new_arg, orig_arg_type in zip(
new_main_op.arguments, orig_main.type.inputs
):
orig_main_args.append(hlo.ConvertOp(orig_arg_type, new_arg).result)
call = func_dialect.CallOp(
orig_output_types,
ir.FlatSymbolRefAttr.get(orig_main_name),
orig_main_args,
)
func_dialect.ReturnOp(call.results)
symbol_table.set_symbol_name(new_main_op, "main")

# TODO(necula): need conditionals until jaxlib 0.4.12 is the minimum version
if xc.mlir_api_version >= 50:
refined_module_str = xla_extension.mlir.refine_polymorphic_shapes(
mlir.module_to_bytecode(wrapped_module)
)
context = mlir.make_ir_context()
with context:
refined_module = ir.Module.parse(refined_module_str)

logging.info("Postprocessed module %s", str(refined_module))
self.assertNotIn("?x", str(refined_module))
self.assertNotIn("stablehlo.dynamic_reshape", str(refined_module))
self.assertIn("stablehlo.reshape", str(refined_module))


if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit ec8b855

Please sign in to comment.