Increase minimum jaxlib version to 0.3.0.
hawkinsp committed Mar 4, 2022
1 parent f6a5f0d commit c978df5
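
This commit raises the minimum supported jaxlib to 0.3.0, which is why the version-guarded lowering paths in the diff below can be deleted: JAX can now assume a new enough `jax._src.lib._xla_extension_version` unconditionally. For illustration, here is a minimal sketch of the kind of import-time floor check this implies. The names `MINIMUM_JAXLIB_VERSION` and `check_jaxlib_version` are hypothetical, not JAX's actual implementation:

# Hypothetical sketch of a minimum-jaxlib check; these names are
# illustrative, not JAX's real ones. Assumes release-style 'X.Y.Z'
# version strings.
import jaxlib.version

MINIMUM_JAXLIB_VERSION = (0, 3, 0)

def check_jaxlib_version():
  installed = tuple(int(x) for x in jaxlib.version.__version__.split('.')[:3])
  if installed < MINIMUM_JAXLIB_VERSION:
    raise RuntimeError(
        f'jaxlib {jaxlib.version.__version__} is too old; at least '
        f'{".".join(map(str, MINIMUM_JAXLIB_VERSION))} is required.')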
Showing 13 changed files with 131 additions and 371 deletions.
310 changes: 89 additions & 221 deletions jax/_src/lax/control_flow.py
@@ -644,167 +644,75 @@ def _pred_bcast_select_mhlo(
   return mhlo.SelectOp(bcast_pred, x, y).results
 
 
-if jax._src.lib._xla_extension_version < 48:
-  def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts,
-                      body_nconsts):
-    pred_aval = cond_jaxpr.out_avals[0]
-    batched = bool(pred_aval.shape)
-
-    # Since jaxprs don't have tuples and have multiple return values, but we need
-    # the HLO While loop to take a single tuple input and output a single boolean
-    # (for the cond computation) or a single tuple output (for the body
-    # computation), we build XLA computations that handle the tuple munging before
-    # generating a Call into the computations formed from the jaxprs.
-
-    loop_carry_types = _map(mlir.aval_to_ir_types, ctx.avals_in)
-    flat_loop_carry_types = util.flatten(loop_carry_types)
-    loop_carry_tuple_type = ir.TupleType.get_tuple(flat_loop_carry_types)
-
-    flat_args = mlir.flatten_lowering_ir_args(args)
-    init_carry = mhlo.TupleOp(loop_carry_tuple_type, flat_args)
-    while_op = mhlo.WhileOp([loop_carry_tuple_type], [init_carry.result])
-
-    # Loop condition
-    cond_block = while_op.regions[0].blocks.append(loop_carry_tuple_type)
-    with ir.InsertionPoint(cond_block):
-      flat_cond_args = [
-          mhlo.GetTupleElementOp(input_type, cond_block.arguments[0],
-                                 mlir.i32_attr(i)).result
-          for i, input_type in enumerate(flat_loop_carry_types)
-      ]
-      cond_args = util.unflatten(flat_cond_args, _map(len, loop_carry_types))
-      x, _, z = util.split_list(cond_args, [cond_nconsts, body_nconsts])
-      cond_ctx = ctx.module_context.replace(
-          name_stack=xla.extend_name_stack(ctx.module_context.name_stack,
-                                           'cond'))
-      (pred,), = mlir.jaxpr_subcomp(cond_ctx, cond_jaxpr.jaxpr,
-                                    _map(mlir.ir_constants, cond_jaxpr.consts),
-                                    *(x + z))
-      if batched:
-        pred_ctx = mlir.LoweringRuleContext(
-            module_context=ctx.module_context,
-            primitive=None,
-            avals_in=[pred_aval],
-            avals_out=[pred_aval.update(shape=())])
-        pred, = lax._unary_reduce_lower(
-            mhlo.OrOp,
-            lambda dtype: np.array(False, dtype),
-            pred_ctx,
-            pred,
-            axes=tuple(range(len(pred_aval.shape))))
-      mhlo.ReturnOp([pred])
-
-    # Loop body
-    body_block = while_op.regions[1].blocks.append(loop_carry_tuple_type)
-    with ir.InsertionPoint(body_block):
-      flat_body_args = [
-          mhlo.GetTupleElementOp(input_type, body_block.arguments[0],
-                                 mlir.i32_attr(i)).result
-          for i, input_type in enumerate(flat_loop_carry_types)
-      ]
-      body_args = util.unflatten(flat_body_args, _map(len, loop_carry_types))
-      x, y, z = util.split_list(body_args, [cond_nconsts, body_nconsts])
-      body_ctx = ctx.module_context.replace(
-          name_stack=xla.extend_name_stack(ctx.module_context.name_stack,
-                                           'body'))
-      new_z = mlir.jaxpr_subcomp(body_ctx, body_jaxpr.jaxpr,
-                                 _map(mlir.ir_constants, body_jaxpr.consts),
-                                 *(y + z))
-      if batched:
-        body_pred_ctx = ctx.module_context.replace(
-            name_stack=xla.extend_name_stack(ctx.module_context.name_stack,
-                                             'body_pred'))
-        (body_pred,), = mlir.jaxpr_subcomp(
-            body_pred_ctx, cond_jaxpr.jaxpr,
-            _map(mlir.ir_constants, cond_jaxpr.consts), *(x + z))
-        new_z = _map(
-            partial(_pred_bcast_select_mhlo, pred_aval, body_pred), new_z, z,
-            body_jaxpr.out_avals)
-
-      new_carry = mhlo.TupleOp(
-          loop_carry_tuple_type,
-          [*util.flatten(x), *util.flatten(y), *util.flatten(new_z)])
-      mhlo.ReturnOp([new_carry.result])
-
-    outputs = util.unflatten([
-        mhlo.GetTupleElementOp(output_type, while_op.result,
-                               mlir.i32_attr(i)).result
-        for i, output_type in enumerate(flat_loop_carry_types)
-    ], _map(len, loop_carry_types))
-    _, _, z = util.split_list(outputs, [cond_nconsts, body_nconsts])
-    return z
-else:
-
-  def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts,
-                      body_nconsts):
-    pred_aval = cond_jaxpr.out_avals[0]
-    batched = bool(pred_aval.shape)
-
-    loop_carry_types = _map(mlir.aval_to_ir_types, ctx.avals_in)
-    flat_loop_carry_types = util.flatten(loop_carry_types)
-
-    flat_args = mlir.flatten_lowering_ir_args(args)
-    while_op = mhlo.WhileOp(flat_loop_carry_types, flat_args)
-
-    # Loop condition
-    cond_block = while_op.regions[0].blocks.append(*flat_loop_carry_types)
-    with ir.InsertionPoint(cond_block):
-      flat_cond_args = [
-          cond_block.arguments[i] for i in range(len(flat_loop_carry_types))
-      ]
-      cond_args = util.unflatten(flat_cond_args, _map(len, loop_carry_types))
-      x, _, z = util.split_list(cond_args, [cond_nconsts, body_nconsts])
-      cond_ctx = ctx.module_context.replace(
-          name_stack=xla.extend_name_stack(ctx.module_context.name_stack,
-                                           'cond'))
-      (pred,), = mlir.jaxpr_subcomp(cond_ctx, cond_jaxpr.jaxpr,
-                                    _map(mlir.ir_constants, cond_jaxpr.consts),
-                                    *(x + z))
-      if batched:
-        pred_ctx = mlir.LoweringRuleContext(
-            module_context=ctx.module_context,
-            primitive=None,
-            avals_in=[pred_aval],
-            avals_out=[pred_aval.update(shape=())])
-        pred, = lax._unary_reduce_lower(
-            mhlo.OrOp,
-            lambda dtype: np.array(False, dtype),
-            pred_ctx,
-            pred,
-            axes=tuple(range(len(pred_aval.shape))))
-      mhlo.ReturnOp([pred])
-
-    # Loop body
-    body_block = while_op.regions[1].blocks.append(*flat_loop_carry_types)
-    with ir.InsertionPoint(body_block):
-      flat_body_args = [
-          body_block.arguments[i] for i in range(len(flat_loop_carry_types))
-      ]
-      body_args = util.unflatten(flat_body_args, _map(len, loop_carry_types))
-      x, y, z = util.split_list(body_args, [cond_nconsts, body_nconsts])
-      body_ctx = ctx.module_context.replace(
-          name_stack=xla.extend_name_stack(ctx.module_context.name_stack,
-                                           'body'))
-      new_z = mlir.jaxpr_subcomp(body_ctx, body_jaxpr.jaxpr,
-                                 _map(mlir.ir_constants, body_jaxpr.consts),
-                                 *(y + z))
-      if batched:
-        body_pred_ctx = ctx.module_context.replace(
-            name_stack=xla.extend_name_stack(ctx.module_context.name_stack,
-                                             'body_pred'))
-        (body_pred,), = mlir.jaxpr_subcomp(
-            body_pred_ctx, cond_jaxpr.jaxpr,
-            _map(mlir.ir_constants, cond_jaxpr.consts), *(x + z))
-        new_z = _map(
-            partial(_pred_bcast_select_mhlo, pred_aval, body_pred), new_z, z,
-            body_jaxpr.out_avals)
-
-      mhlo.ReturnOp([*util.flatten(x), *util.flatten(y), *util.flatten(new_z)])
-
-    outputs = util.unflatten(while_op.results, _map(len, loop_carry_types))
-    _, _, z = util.split_list(outputs, [cond_nconsts, body_nconsts])
-    return z
+def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts,
+                    body_nconsts):
+  pred_aval = cond_jaxpr.out_avals[0]
+  batched = bool(pred_aval.shape)
+
+  loop_carry_types = _map(mlir.aval_to_ir_types, ctx.avals_in)
+  flat_loop_carry_types = util.flatten(loop_carry_types)
+
+  flat_args = mlir.flatten_lowering_ir_args(args)
+  while_op = mhlo.WhileOp(flat_loop_carry_types, flat_args)
+
+  # Loop condition
+  cond_block = while_op.regions[0].blocks.append(*flat_loop_carry_types)
+  with ir.InsertionPoint(cond_block):
+    flat_cond_args = [
+        cond_block.arguments[i] for i in range(len(flat_loop_carry_types))
+    ]
+    cond_args = util.unflatten(flat_cond_args, _map(len, loop_carry_types))
+    x, _, z = util.split_list(cond_args, [cond_nconsts, body_nconsts])
+    cond_ctx = ctx.module_context.replace(
+        name_stack=xla.extend_name_stack(ctx.module_context.name_stack,
+                                         'cond'))
+    (pred,), = mlir.jaxpr_subcomp(cond_ctx, cond_jaxpr.jaxpr,
+                                  _map(mlir.ir_constants, cond_jaxpr.consts),
+                                  *(x + z))
+    if batched:
+      pred_ctx = mlir.LoweringRuleContext(
+          module_context=ctx.module_context,
+          primitive=None,
+          avals_in=[pred_aval],
+          avals_out=[pred_aval.update(shape=())])
+      pred, = lax._unary_reduce_lower(
+          mhlo.OrOp,
+          lambda dtype: np.array(False, dtype),
+          pred_ctx,
+          pred,
+          axes=tuple(range(len(pred_aval.shape))))
+    mhlo.ReturnOp([pred])
+
+  # Loop body
+  body_block = while_op.regions[1].blocks.append(*flat_loop_carry_types)
+  with ir.InsertionPoint(body_block):
+    flat_body_args = [
+        body_block.arguments[i] for i in range(len(flat_loop_carry_types))
+    ]
+    body_args = util.unflatten(flat_body_args, _map(len, loop_carry_types))
+    x, y, z = util.split_list(body_args, [cond_nconsts, body_nconsts])
+    body_ctx = ctx.module_context.replace(
+        name_stack=xla.extend_name_stack(ctx.module_context.name_stack,
+                                         'body'))
+    new_z = mlir.jaxpr_subcomp(body_ctx, body_jaxpr.jaxpr,
+                               _map(mlir.ir_constants, body_jaxpr.consts),
+                               *(y + z))
+    if batched:
+      body_pred_ctx = ctx.module_context.replace(
+          name_stack=xla.extend_name_stack(ctx.module_context.name_stack,
+                                           'body_pred'))
+      (body_pred,), = mlir.jaxpr_subcomp(
+          body_pred_ctx, cond_jaxpr.jaxpr,
+          _map(mlir.ir_constants, cond_jaxpr.consts), *(x + z))
+      new_z = _map(
+          partial(_pred_bcast_select_mhlo, pred_aval, body_pred), new_z, z,
+          body_jaxpr.out_avals)
+
+    mhlo.ReturnOp([*util.flatten(x), *util.flatten(y), *util.flatten(new_z)])
+
+  outputs = util.unflatten(while_op.results, _map(len, loop_carry_types))
+  _, _, z = util.split_list(outputs, [cond_nconsts, body_nconsts])
+  return z
 
 mlir.register_lowering(while_p, _while_lowering)
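
Context for readers: the `_while_lowering` rule kept above is what a `lax.while_loop` call reaches when JAX lowers a jaxpr to MHLO. Below is a minimal usage sketch, illustrative only; `cumulative_sum` is a made-up example, and a jax/jaxlib pair at or above the new floor is assumed:

# Minimal example that exercises the while_p lowering above.
import jax
from jax import lax

def cumulative_sum(n):
  # carry = (i, total); keep looping while i < n
  def cond(carry):
    i, _ = carry
    return i < n
  def body(carry):
    i, total = carry
    return i + 1, total + i
  _, total = lax.while_loop(cond, body, (0, 0))
  return total

print(jax.jit(cumulative_sum)(10))  # 45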

@@ -1408,68 +1316,28 @@ def cond_bind(*args, branches, linear):
 pe.partial_eval_jaxpr_custom_rules[cond_p] = \
     partial(pe.partial_eval_jaxpr_custom_rule_not_implemented, 'cond')
 
-if jax._src.lib._xla_extension_version < 51:
-
-  def _cond_lowering(ctx, index, *args, branches, linear):
-    del linear  # Unused.
-    arg_avals = ctx.avals_in[1:]
-    input_types = _map(mlir.aval_to_ir_types, arg_avals)
-    output_types = _map(mlir.aval_to_ir_types, ctx.avals_out)
-    flat_input_types = util.flatten(input_types)
-    flat_output_types = util.flatten(output_types)
-    input_tuple_type = ir.TupleType.get_tuple(flat_input_types)
-    output_tuple_type = ir.TupleType.get_tuple(flat_output_types)
-    op = mhlo.TupleOp(input_tuple_type,
-                      mlir.flatten_lowering_ir_args(args)).result
-    # TODO(phawkins): avoid build_generic when CaseOp is fixed.
-    case_op = mhlo.CaseOp.build_generic([output_tuple_type],
-                                        [index] + [op] * len(branches),
-                                        regions=len(branches))
-    for i, jaxpr in enumerate(branches):
-      branch = case_op.regions[i].blocks.append(input_tuple_type)
-      with ir.InsertionPoint(branch):
-        args = [
-            mhlo.GetTupleElementOp(input_type, branch.arguments[0],
-                                   mlir.i32_attr(i)).result
-            for i, input_type in enumerate(flat_input_types)
-        ]
-        unflattened_args = util.unflatten(args, _map(len, input_types))
-        out_vals = mlir.jaxpr_subcomp(ctx.module_context, jaxpr.jaxpr,
-                                      jaxpr.consts, *unflattened_args)
-        out = mhlo.TupleOp(output_tuple_type, util.flatten(out_vals)).results
-        mhlo.ReturnOp(out)
-
-    results = [
-        mhlo.GetTupleElementOp(output_type, case_op.result,
-                               mlir.i32_attr(i)).result
-        for i, output_type in enumerate(flat_output_types)
-    ]
-    return util.unflatten(results, _map(len, output_types))
-
-else:
-
-  def _cond_lowering(ctx, index, *args, branches, linear):
-    del linear  # Unused.
-    output_types = _map(mlir.aval_to_ir_types, ctx.avals_out)
-    flat_output_types = util.flatten(output_types)
-
-    # mhlo.CaseOp takes a single argument 'index' and the corresponding blocks
-    # have no arguments; the computation within the block uses implicit
-    # captures.
-
-    # TODO(phawkins): avoid build_generic when CaseOp is fixed.
-    case_op = mhlo.CaseOp.build_generic(
-        flat_output_types, [index], regions=len(branches))
-    for i, jaxpr in enumerate(branches):
-      branch = case_op.regions[i].blocks.append()
-      with ir.InsertionPoint(branch):
-        out_vals = mlir.jaxpr_subcomp(
-            ctx.module_context, jaxpr.jaxpr,
-            _map(mlir.ir_constants, jaxpr.consts),
-            *_map(mlir.wrap_singleton_ir_values, args))
-        mhlo.ReturnOp(util.flatten(out_vals))
-
-    return util.unflatten(case_op.results, _map(len, output_types))
+def _cond_lowering(ctx, index, *args, branches, linear):
+  del linear  # Unused.
+  output_types = _map(mlir.aval_to_ir_types, ctx.avals_out)
+  flat_output_types = util.flatten(output_types)
+
+  # mhlo.CaseOp takes a single argument 'index' and the corresponding blocks
+  # have no arguments; the computation within the block uses implicit
+  # captures.
+
+  # TODO(phawkins): avoid build_generic when CaseOp is fixed.
+  case_op = mhlo.CaseOp.build_generic(
+      flat_output_types, [index], regions=len(branches))
+  for i, jaxpr in enumerate(branches):
+    branch = case_op.regions[i].blocks.append()
+    with ir.InsertionPoint(branch):
+      out_vals = mlir.jaxpr_subcomp(
+          ctx.module_context, jaxpr.jaxpr,
+          _map(mlir.ir_constants, jaxpr.consts),
+          *_map(mlir.wrap_singleton_ir_values, args))
+      mhlo.ReturnOp(util.flatten(out_vals))
+
+  return util.unflatten(case_op.results, _map(len, output_types))
 
 mlir.register_lowering(cond_p, _cond_lowering)
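
Similarly, the `_cond_lowering` rule kept above is reached by `lax.cond` and `lax.switch`: `index` selects a branch and `branches` holds the branch jaxprs. A minimal usage sketch, illustrative only; `apply_branch` is a made-up example:

# Minimal example that reaches the cond_p lowering: lax.switch lowers to a
# multi-branch cond_p whose index operand picks one of the branch jaxprs.
import jax
from jax import lax

def apply_branch(index, x):
  branches = (
      lambda operand: operand + 1.0,  # branch 0
      lambda operand: operand * 2.0,  # branch 1
      lambda operand: -operand,       # branch 2
  )
  return lax.switch(index, branches, x)

print(jax.jit(apply_branch)(1, 3.0))  # 6.0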

