[MHLO] Change jax.xla_computation() to use MHLO lowering internally.
This change is in preparation for removing the non-MHLO (XlaBuilder-based) lowering path.

PiperOrigin-RevId: 441460875
hawkinsp authored and jax authors committed Apr 13, 2022
1 parent af89426 commit ad8e6ad
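For orientation, here is a minimal sketch of what stays the same from a user's point of view, assuming a JAX build from around the time of this commit: `jax.xla_computation` still returns an `XlaComputation` whose HLO can be inspected; only the internal lowering now goes through MHLO first.

```python
import jax
import jax.numpy as jnp

# Sketch (JAX circa this commit): the public surface of xla_computation is
# unchanged; the HLO is now produced by converting an MHLO module rather
# than by emitting XLA ops directly through XlaBuilder.
c = jax.xla_computation(jnp.sin)(1.0)
print(c.as_hlo_text())                  # textual HLO, as before
print(c.get_hlo_module().to_string())   # HLO module view, as before
```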
Showing 11 changed files with 139 additions and 136 deletions.
35 changes: 15 additions & 20 deletions jax/_src/api.py
@@ -82,6 +82,7 @@
custom_vjp, linear_call)
from jax.custom_transpose import custom_transpose
from jax.interpreters import partial_eval as pe
from jax.interpreters import mlir
from jax.interpreters import xla
from jax.interpreters import pxla
from jax.interpreters import ad
@@ -811,27 +812,21 @@ def computation_maker(*args, **kwargs):
else:
out_parts_flat = tuple(flatten_axes(
"xla_computation out_parts", out_tree(), out_parts))
c = xc.XlaBuilder(f"xla_computation_{fun_name}")
xla_consts = map(partial(xla.pyval_to_ir_constant, c), consts)
m = mlir.lower_jaxpr_to_module(
f"xla_computation_{fun_name}",
core.ClosedJaxpr(jaxpr, consts),
platform=backend,
axis_context=mlir.ReplicaAxisContext(axis_env_),
name_stack=new_name_stack(wrap_name(fun_name, "xla_computation")),
donated_args=donated_invars,
arg_shardings=(None if in_parts_flat is None else
map(xla.sharding_to_proto, in_parts_flat)),
result_shardings=(None if out_parts_flat is None else
map(xla.sharding_to_proto, out_parts_flat)))
should_tuple = tuple_args if tuple_args is not None else (len(avals) > 100)
xla_args, donated_invars = xla._xla_callable_args(
c, avals, should_tuple, partitions=in_parts_flat, donated_invars=donated_invars)
name_stack = new_name_stack(wrap_name(fun_name, "xla_computation"))
ctx = xla.TranslationContext(c, backend, axis_env_, name_stack)
out_nodes = xla.jaxpr_subcomp(ctx, jaxpr, xla_consts, *xla_args)
build_out_tuple = partial(xc.ops.Tuple, c, out_nodes)
if out_parts is not None:
out_tuple = xla.with_sharding(c, out_parts_flat, build_out_tuple)
else:
out_tuple = build_out_tuple()

if any(donated_invars):
donated_invars = xla.set_up_aliases(c, xla_args, c.GetShape(out_tuple),
donated_invars, tuple_args)
if any(donated_invars):
shapes = [str(c.GetShape(a)) for a, d in zip(xla_args, donated_invars) if d]
warn(f"Some donated buffers were not usable: {', '.join(shapes)}")
built = c.build(out_tuple)
built = xc._xla.mlir.mlir_module_to_xla_computation(
mlir.module_to_string(m), use_tuple_args=should_tuple,
return_tuple=True)
out_shapes_flat = [
ShapeDtypeStruct(a.shape, a.dtype, a.named_shape) for a in out_avals]
out_shape = tree_unflatten(out_tree(), out_shapes_flat)
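For readers tracing the new path, here is a hedged, self-contained sketch of the same two-step conversion using user-level handles: lower a function to MHLO, then hand the serialized module to the jaxlib binding that rebuilds an `XlaComputation`. `xc._xla.mlir.mlir_module_to_xla_computation` is a private binding taken from the hunk above, and `compiler_ir(dialect="mhlo")` is assumed to be available in the ahead-of-time API of this era; treat the snippet as illustrative, not as the implementation.

```python
import jax
import jax.numpy as jnp
from jax._src.lib import xla_client as xc

# Illustrative only: mirror the conversion step added above by serializing an
# MHLO module and rebuilding an XlaComputation from its text. The helper is a
# private jaxlib binding and may move between versions.
mhlo_module = jax.jit(jnp.sin).lower(1.0).compiler_ir(dialect="mhlo")
built = xc._xla.mlir.mlir_module_to_xla_computation(
    str(mhlo_module), use_tuple_args=False, return_tuple=True)
print(built.as_hlo_text())
```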
14 changes: 8 additions & 6 deletions jax/_src/lax/control_flow.py
@@ -663,15 +663,15 @@ def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts,

# Loop condition
cond_block = while_op.regions[0].blocks.append(*flat_loop_carry_types)
name_stack = extend_name_stack(ctx.module_context.name_stack, 'while')
with ir.InsertionPoint(cond_block):
flat_cond_args = [
cond_block.arguments[i] for i in range(len(flat_loop_carry_types))
]
cond_args = util.unflatten(flat_cond_args, _map(len, loop_carry_types))
x, _, z = util.split_list(cond_args, [cond_nconsts, body_nconsts])
cond_ctx = ctx.module_context.replace(
name_stack=xla.extend_name_stack(ctx.module_context.name_stack,
'cond'))
name_stack=xla.extend_name_stack(name_stack, 'cond'))
(pred,), = mlir.jaxpr_subcomp(cond_ctx, cond_jaxpr.jaxpr,
_map(mlir.ir_constants, cond_jaxpr.consts),
*(x + z))
@@ -698,14 +698,13 @@ def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts,
body_args = util.unflatten(flat_body_args, _map(len, loop_carry_types))
x, y, z = util.split_list(body_args, [cond_nconsts, body_nconsts])
body_ctx = ctx.module_context.replace(
name_stack=xla.extend_name_stack(ctx.module_context.name_stack,
'body'))
name_stack=xla.extend_name_stack(name_stack, 'body'))
new_z = mlir.jaxpr_subcomp(body_ctx, body_jaxpr.jaxpr,
_map(mlir.ir_constants, body_jaxpr.consts),
*(y + z))
if batched:
body_pred_ctx = ctx.module_context.replace(
name_stack=xla.extend_name_stack(ctx.module_context.name_stack,
name_stack=xla.extend_name_stack(name_stack,
'body_pred'))
(body_pred,), = mlir.jaxpr_subcomp(
body_pred_ctx, cond_jaxpr.jaxpr,
@@ -1360,11 +1359,14 @@ def _cond_lowering(ctx, index, *args, branches, linear):
# TODO(phawkins): avoid build_generic when CaseOp is fixed.
case_op = mhlo.CaseOp.build_generic(
flat_output_types, [index], regions=len(branches))
name_stack = extend_name_stack(ctx.module_context.name_stack, 'cond')
for i, jaxpr in enumerate(branches):
branch = case_op.regions[i].blocks.append()
with ir.InsertionPoint(branch):
sub_ctx = ctx.module_context.replace(
name_stack=xla.extend_name_stack(name_stack, f'branch_{i}_fun'))
out_vals = mlir.jaxpr_subcomp(
ctx.module_context, jaxpr.jaxpr,
sub_ctx, jaxpr.jaxpr,
_map(mlir.ir_constants, jaxpr.consts),
*_map(mlir.wrap_singleton_ir_values, args))
mhlo.ReturnOp(util.flatten(out_vals))
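The effect of threading the name stack through `_while_lowering` and `_cond_lowering` is visible in the HLO metadata: branch bodies are now scoped as `cond/branch_{i}_fun/...` (asserted in `metadata_test.py` later in this commit). A small illustration, written with the newer `lax.cond` signature rather than the deprecated five-argument form used in the test:

```python
import jax
import jax.numpy as jnp

def f(which, x):
  # Branch 0 is the false branch (cos), branch 1 the true branch (sin),
  # matching the expectations in test_cond_metadata below.
  return jax.lax.cond(which, jnp.sin, jnp.cos, x)

hlo = jax.xla_computation(f)(True, 1.0).get_hlo_module().to_string()
print("cond/branch_0_fun" in hlo, "cond/branch_1_fun" in hlo)
```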
8 changes: 4 additions & 4 deletions jax/experimental/jax2tf/tests/sharding_test.py
@@ -73,7 +73,7 @@ def _check_sharding_annotations(self,
if jtu.device_under_test() == "gpu":
raise unittest.SkipTest("Sharding HLO tests not useful for GPU")

jax_comp = jax.xla_computation(f_jax)(*args)
jax_comp = f_jax.lower(*args).compiler_ir(dialect="hlo")
jax_hlo = jax_comp.as_hlo_text()
if LOG_HLO:
logging.info("[%s] got JAX HLO %s", self._testMethodName, jax_hlo)
@@ -152,7 +152,7 @@ def jax_func(x, y):

shape = (8, 10)
x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
hlo = jax.xla_computation(jax_func)(x, x).as_hlo_text()
hlo = jax_func.lower(x, x).compiler_ir(dialect="hlo").as_hlo_text()
print(f"HLO is {hlo}")
print(f"JAXPR is {jax.make_jaxpr(jax_func)(x, x)}")
self._check_sharding_annotations(
@@ -222,7 +222,7 @@ def jax_func(x, y):
r"f32\[2,2\].*sharding={devices=\[4,1\]0,1,2,3|f32\[6,8\].*sharding={devices=\[4,1\]0,1,2,3",
r"f32\[2,2\].*sharding={devices=\[4,1\]0,1,2,3|f32\[6,8\].*sharding={devices=\[4,1\]0,1,2,3",
# TODO: why we cannot see .*sharding={devices=\[4,1\]0,1,2,3
r"f32\[1,6,2\]", # output
r"f32\[6,2\]", # output
],
num_partitions=4)

@@ -247,7 +247,7 @@ def jax_func(x): # x: f32[12, 8]
expected_opt=[
r"f32\[12,8\].*sharding={replicated}", # x
# TODO: why can't we see "sharding={devices=\[2,1\]0,1"
r"f32\[1,12,8\]", # y
r"f32\[12,8\]", # y
# TODO: why can't we see "sharding={replicated}" ?
r"f32\[6,8\]", # output
],
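The test updates in this file (and in `host_callback_test.py` below) drop the `jax.xla_computation(f_jax)(*args)` entry point in favour of the ahead-of-time lowering API, which exposes the same HLO. A minimal sketch of that pattern, assuming `f_jax` is a jitted function:

```python
import jax
import jax.numpy as jnp

# Sketch of the replacement pattern used by the updated tests: lower a jitted
# function ahead of time and request its compiler IR in the HLO dialect.
f_jax = jax.jit(lambda x: jnp.sin(x) + 1.0)
hlo_text = f_jax.lower(1.0).compiler_ir(dialect="hlo").as_hlo_text()
print(hlo_text)
```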
26 changes: 19 additions & 7 deletions jax/interpreters/mlir.py
@@ -41,6 +41,7 @@
from jax._src.lib import xla_client as xc
from jax._src import source_info_util
import jax._src.util as util
from jax.config import config
import jax.interpreters.ad as ad
import jax.interpreters.partial_eval as pe
import jax.interpreters.xla as xla
@@ -288,7 +289,12 @@ def _source_info_to_location(
primitive: core.Primitive, params: Dict,
source_info: source_info_util.SourceInfo,
name_stack: Union[str, source_info_util.NameStack] = "") -> ir.Location:
eqn_str = str(name_stack) + core.str_eqn_compact(primitive.name, params)
if config.jax_experimental_name_stack:
eqn_str = (f'{str(source_info.name_stack)}/'
f'{core.str_eqn_compact(primitive.name, params)}')
else:
assert isinstance(name_stack, str)
eqn_str = name_stack + core.str_eqn_compact(primitive.name, params)
frame = source_info_util.user_frame(source_info)
if frame is None:
loc = ir.Location.unknown()
@@ -746,7 +752,13 @@ def write(v: core.Var, node: Sequence[ir.Value]):
map(write, jaxpr.invars, args)
for eqn in jaxpr.eqns:
in_nodes = map(read, eqn.invars)
loc = _source_info_to_location(eqn.primitive, eqn.params, eqn.source_info,
if config.jax_experimental_name_stack:
assert isinstance(ctx.name_stack, source_info_util.NameStack), type(ctx.name_stack)
source_info = eqn.source_info.replace(
name_stack=ctx.name_stack + eqn.source_info.name_stack)
else:
source_info = eqn.source_info
loc = _source_info_to_location(eqn.primitive, eqn.params, source_info,
name_stack=ctx.name_stack)
with source_info_util.user_context(eqn.source_info.traceback), loc:
if eqn.primitive in _platform_specific_lowerings[ctx.platform]:
@@ -762,8 +774,10 @@ def write(v: core.Var, node: Sequence[ir.Value]):
f"MLIR translation rule for primitive '{eqn.primitive.name}' not "
f"found for platform {ctx.platform}")

eqn_ctx = (ctx.replace(name_stack=source_info.name_stack) if
config.jax_experimental_name_stack else ctx)
rule_ctx = LoweringRuleContext(
module_context=ctx, primitive=eqn.primitive,
module_context=eqn_ctx, primitive=eqn.primitive,
avals_in=map(aval, eqn.invars), avals_out=map(aval, eqn.outvars))
ans = rule(rule_ctx, *map(_unwrap_singleton_ir_values, in_nodes),
**eqn.params)
@@ -812,9 +826,7 @@ def _call_lowering(fn_name, stack_name, call_jaxpr, backend, ctx, avals_in,
xla.check_backend_matches(backend, ctx.platform)
output_types = map(aval_to_ir_types, avals_out)
flat_output_types = util.flatten(output_types)
sub_ctx = ctx.replace(
name_stack=xla.extend_name_stack(ctx.name_stack, stack_name))
symbol_name = lower_jaxpr_to_fun(sub_ctx, fn_name,
symbol_name = lower_jaxpr_to_fun(ctx, fn_name,
core.ClosedJaxpr(call_jaxpr, ())).name.value
call = func_dialect.CallOp(flat_output_types,
ir.FlatSymbolRefAttr.get(symbol_name),
@@ -825,7 +837,7 @@ def _xla_call_lower(ctx, *args,
backend=None, name, call_jaxpr, donated_invars, inline=None,
device=None):
del device, donated_invars, inline # Ignored.
return _call_lowering(f"jit_{name}", xla.wrap_name(name, "jit"), call_jaxpr,
return _call_lowering(name, xla.wrap_name(name, "jit"), call_jaxpr,
backend, ctx.module_context, ctx.avals_in, ctx.avals_out,
*args)

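Two behavioural notes on the `mlir.py` hunk: `_xla_call_lower` no longer prefixes the function symbol with `jit_`, and when the experimental name-stack flag is enabled the equation's own name stack is folded into the op location. The user-visible result, per the updated `metadata_test.py` below, is op names such as `xla_computation(foo)/jit(main)/sin`. A hedged check of that behaviour:

```python
import jax
import jax.numpy as jnp

def foo(x):
  return jnp.sin(x)

# Per the updated metadata_test.py expectations, op_name metadata now carries
# a jit(main) scope, e.g. op_name="xla_computation(foo)/jit(main)/sin".
hlo = jax.xla_computation(foo)(1.0).get_hlo_module().to_string()
print("jit(main)/sin" in hlo)
```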
18 changes: 6 additions & 12 deletions jax/interpreters/sharded_jit.py
@@ -230,12 +230,9 @@ def _sharded_jit_lowering(ctx, *in_nodes,
args = []
for ns, sharding in safe_zip(
safe_map(mlir.wrap_singleton_ir_values, in_nodes), in_parts):
if sharding is not None:
args.append(
[mlir.wrap_with_sharding_op(n, xla.sharding_to_proto(sharding))
for n in ns])
else:
args.append(ns)
args.append(
[mlir.wrap_with_sharding_op(n, xla.sharding_to_proto(sharding))
for n in ns])

sub_ctx = ctx.module_context.replace(
name_stack=new_name_stack(wrap_name(name, "sharded_jit")))
@@ -252,12 +249,9 @@ def _sharded_jit_lowering(ctx, *in_nodes,
out_parts = out_parts_thunk()
outputs = []
for ns, sharding in safe_zip(out_nodes, out_parts):
if sharding is not None:
outputs.append(
[mlir.wrap_with_sharding_op(n, xla.sharding_to_proto(sharding))
for n in ns])
else:
outputs.append(ns)
outputs.append(
[mlir.wrap_with_sharding_op(n, xla.sharding_to_proto(sharding))
for n in ns])
return outputs


4 changes: 2 additions & 2 deletions tests/api_test.py
@@ -2158,8 +2158,8 @@ def f(x, y):
self.assertIn("constant(3)", hlo_text)
# The static arguments should be removed from the function being compiled,
# thus the function should have only a single argument.
self.assertIn("parameter.1", hlo_text)
self.assertNotIn("parameter.2", hlo_text)
self.assertIn("parameter(0)", hlo_text)
self.assertNotIn("parameter(1)", hlo_text)

def test_xla_computation_return_shape(self):
_, shape_tree = api.xla_computation(lambda x: (x + 1, jnp.zeros(2, jnp.float32)),
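The new assertions reflect how the MHLO-derived HLO names parameters: the remaining live argument shows up as `parameter(0)` rather than as an instruction named `parameter.1`. A hypothetical reconstruction of the scenario the test covers (the `static_argnums` setup and the literal `3` are inferred from the surrounding assertions, not shown in the hunk):

```python
import jax

def f(x, y):
  return x + y

# Hypothetical reconstruction: with y marked static, it is baked into the
# computation as constant(3) and only x remains as an HLO parameter,
# rendered as parameter(0) in the HLO text.
hlo_text = jax.xla_computation(f, static_argnums=(1,))(2.0, 3).as_hlo_text()
print("constant(3)" in hlo_text, "parameter(0)" in hlo_text)
```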
2 changes: 1 addition & 1 deletion tests/host_callback_test.py
@@ -179,7 +179,7 @@ def helper_log_ir(name,
num_partitions=None,
strip_metadata=False):
print(f"Jaxpr[{name}]: {jax.make_jaxpr(f_jax)(*args)}")
jax_comp = jax.xla_computation(f_jax, backend=jtu.device_under_test())(*args)
jax_comp = f_jax.lower(*args).compiler_ir(dialect="hlo")
print(f"HLO[{name}]: {jax_comp.as_hlo_text()}")

backend = xla_bridge.get_backend()
28 changes: 10 additions & 18 deletions tests/metadata_test.py
@@ -27,13 +27,13 @@ class MetadataTest(jtu.JaxTestCase):

def test_jit_metadata(self):
hlo = jax.xla_computation(jnp.sin)(1.).get_hlo_module().to_string()
self.assertRegex(hlo, 'op_type="sin"')
self.assertRegex(hlo, 'op_name="xla_computation\\(sin\\)/sin"')
self.assertRegex(hlo,
'op_name="xla_computation\\(sin\\)/jit\\(main\\)/sin"')
def foo(x):
return jnp.sin(x)
hlo = jax.xla_computation(foo)(1.).get_hlo_module().to_string()
self.assertRegex(hlo, 'op_type="sin"')
self.assertRegex(hlo, 'op_name="xla_computation\\(foo\\)/sin"')
self.assertRegex(hlo,
'op_name="xla_computation\\(foo\\)/jit\\(main\\)/sin"')

@unittest.skip("TODO") # TODO(jekbradbury)
def test_nested_jit_metadata(self):
@@ -62,28 +62,20 @@ def test_grad_jit_metadata(self):
def foo(x):
return jnp.sin(x)
hlo = jax.xla_computation(jax.grad(foo))(1.).get_hlo_module().to_string()
self.assertRegex(hlo, 'op_type="sin"')
self.assertRegex(hlo, 'op_type="cos"')
self.assertRegex(hlo, 'op_type="mul"')
# TODO(mattjj,jekbradbury): update these tests post-omnistaging
# self.assertRegex(hlo, 'op_name=".*jit\\(jvp\\(foo\\)\\)/sin"')
# self.assertRegex(hlo, 'op_name=".*jit\\(jvp\\(foo\\)\\)/cos"')
# self.assertRegex(hlo, 'op_name=".*jit\\(transpose\\('
# 'jvp\\(foo\\)\\)\\)/mul"')
self.assertRegex(hlo, 'op_name=".*jit\\(jvp\\(foo\\)\\)/sin"')
self.assertRegex(hlo, 'op_name=".*jit\\(jvp\\(foo\\)\\)/cos"')
self.assertRegex(hlo, 'op_name=".*jit\\(transpose\\(jvp\\(foo\\)\\)\\)/mul"')

def test_cond_metadata(self):
def true_fun(x):
return jnp.sin(x)
def false_fun(x):
return jnp.cos(x)
def f(x):
return jax.lax.cond(True, x, true_fun, x, false_fun)
hlo = jax.xla_computation(f)(1.).get_hlo_module().to_string()
self.assertRegex(hlo, 'op_type="cond"')
def f(which, x):
return jax.lax.cond(which, x, true_fun, x, false_fun)
hlo = jax.xla_computation(f)(True, 1.).get_hlo_module().to_string()
self.assertRegex(hlo, 'op_name=".*cond\\[linear=\\(False, False\\)\\]"')
self.assertRegex(hlo, 'op_type="cos"')
self.assertRegex(hlo, 'op_name=".*cond/branch_0_fun/cos"')
self.assertRegex(hlo, 'op_type="sin"')
self.assertRegex(hlo, 'op_name=".*cond/branch_1_fun/sin"')

def test_source_file_prefix_removal(self):
