Skip to content

Commit

Permalink
Handle jaxpr constants correctly in MLIR lowering of conditional bran…
Browse files Browse the repository at this point in the history
…ches.

Add some dynamic type checks and type annotations to catch this kind of problem sooner.

There's no test case, because I'm not entirely sure how to make a test case for this. In fact, I'm not even sure it's legal for a conditional branch to have non-empty constants. We'll dig into that separately.

PiperOrigin-RevId: 431697808
  • Loading branch information
hawkinsp authored and jax authors committed Mar 1, 2022
1 parent c7508d1 commit cffe997
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 7 deletions.
3 changes: 2 additions & 1 deletion jax/_src/lax/control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1464,7 +1464,8 @@ def _cond_lowering(ctx, index, *args, branches, linear):
branch = case_op.regions[i].blocks.append()
with ir.InsertionPoint(branch):
out_vals = mlir.jaxpr_subcomp(
ctx.module_context, jaxpr.jaxpr, jaxpr.consts,
ctx.module_context, jaxpr.jaxpr,
_map(mlir.ir_constants, jaxpr.consts),
*_map(mlir.wrap_singleton_ir_values, args))
mhlo.ReturnOp(util.flatten(out_vals))

Expand Down
16 changes: 10 additions & 6 deletions jax/interpreters/mlir.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,10 @@ def ir_constants(val: Any,
"""
for t in type(val).__mro__:
handler = _constant_handlers.get(t)
if handler: return handler(val, canonicalize_types)
if handler:
out = handler(val, canonicalize_types)
assert all(isinstance(v, ir.Value) for v in out), (type(val), out)
return out
if hasattr(val, '__jax_array__'):
return ir_constants(val.__jax_array__(), canonicalize_types)
raise TypeError("No constant handler for type: {}".format(type(val)))
Expand Down Expand Up @@ -259,7 +262,7 @@ def _device_array_constant_handler(val, canonicalize_types):
register_constant_handler(t, _device_array_constant_handler)

register_constant_handler(
core.Token, lambda _, __: [mhlo.CreateTokenOp(mhlo.TokenType.get())])
core.Token, lambda _, __: [mhlo.CreateTokenOp(mhlo.TokenType.get()).result])

# Source locations

Expand Down Expand Up @@ -633,27 +636,28 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
Assumes that an MLIR context, location, and insertion point are set.
"""
def read(v):
def read(v: core.Var) -> Sequence[ir.Value]:
if type(v) is core.Literal:
return ir_constants(v.val, canonicalize_types=True)
else:
return env[v]

def aval(v):
def aval(v: core.Var) -> core.AbstractValue:
if type(v) is core.Literal:
return xla.abstractify(v.val)
else:
return v.aval

def write(v, node):
def write(v: core.Var, node: Sequence[ir.Value]):
assert node is not None
env[v] = tuple(node)


env: Dict[core.Var, Tuple[ir.Value]] = {}
env: Dict[core.Var, Tuple[ir.Value, ...]] = {}

assert len(args) == len(jaxpr.invars), (jaxpr, args)
assert len(consts) == len(jaxpr.constvars), (jaxpr, consts)
assert all(isinstance(v, ir.Value) for vs in consts for v in vs), consts
write(core.unitvar, ())
map(write, jaxpr.constvars, consts)
map(write, jaxpr.invars, args)
Expand Down

0 comments on commit cffe997

Please sign in to comment.