Skip to content

Commit

Permalink
[Mosaic] Support scf.while and scf.condition.
Browse files Browse the repository at this point in the history
This allows lowering while loops of a more general form than "for i" loops.
Improving generality here allows us to implement more interesting dynamic looping behaviors, such as progressive scans in VMEM.

PiperOrigin-RevId: 625411151
  • Loading branch information
jax authors committed Apr 16, 2024
1 parent 1a650cd commit 5bd6013
Show file tree
Hide file tree
Showing 2 changed files with 301 additions and 0 deletions.
181 changes: 181 additions & 0 deletions jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1000,6 +1000,185 @@ LogicalResult scf_for_rule(RewriteContext &ctx, Operation &op,
return success();
}

// Applies vector layouts to an scf.while op.
//
// Both regions are rewritten in place first (via applyLayoutBlock), then the
// op is replaced by a new scf.while whose vector operands, block arguments and
// results are unrolled into their per-tile values. Vector results of the old
// op are re-assembled from the new op's results so existing users keep working.
//
// Args:
//   ctx: Rewrite context; only ctx.target_shape is read directly here.
//   op: The scf.while operation to rewrite. It is erased on success unless it
//     has no results.
//   layouts_in: One layout per while operand (no layout for non-vectors).
//   layouts_out: One layout per while result (no layout for non-vectors).
//
// Returns:
//   success() once the op has been rewritten, or a failure (with a diagnostic
//   emitted on `op` for the checks performed here) otherwise.
LogicalResult scf_while_rule(RewriteContext &ctx, Operation &op,
                             const ArrayRef<Layout> layouts_in,
                             const ArrayRef<Layout> layouts_out) {
  scf::WhileOp while_op = cast<scf::WhileOp>(op);
  TPU_ASSERT_EQ_OP(layouts_in.size(), while_op->getNumOperands());
  TPU_ASSERT_EQ_OP(layouts_out.size(), while_op->getNumResults());
  TPU_ASSERT_EQ_OP(layouts_in.size(), layouts_out.size());

  // The terminator for the before region is the condition op.
  // It takes multiple arguments -- the first being the decision to execute the
  // after region or branch to the exit. The remaining ones are forwarded to
  // the after region / the results, so they must agree with layouts_out.
  FAILUREOR_ASSIGN_OR_RETURN(
      const SmallVector<Layout> condition_in_layouts,
      getInLayouts(*while_op.getBeforeBody()->getTerminator(),
                   ctx.target_shape));
  if (!llvm::equal(ArrayRef<Layout>(condition_in_layouts).drop_front(1),
                   layouts_out)) {
    return op.emitOpError(
        "Mismatched layouts between scf.while result and its before region "
        "condition.");
  }

  if (failed(applyLayoutBlock(ctx, *while_op.getBeforeBody()))) {
    return failure();
  }

  // The yield of the after region feeds the next iteration's before-region
  // arguments, which are tiled with layouts_in below. Its operand layouts must
  // therefore match the operand layouts (not the result layouts).
  FAILUREOR_ASSIGN_OR_RETURN(
      const SmallVector<Layout> after_yield_in_layouts,
      getInLayouts(*while_op.getYieldOp(), ctx.target_shape));
  if (ArrayRef<Layout>(after_yield_in_layouts) != layouts_in) {
    return op.emitOpError(
        "Mismatched layouts between scf.while operands and its after region "
        "yield.");
  }

  if (failed(applyLayoutBlock(ctx, *while_op.getAfterBody()))) {
    return failure();
  }

  // Sizes of layouts_in and layouts_out were asserted equal above, so a
  // result-less loop carries no values and there is nothing left to unroll.
  if (op.getNumResults() == 0) {
    return success();
  }

  // Unroll every vector operand into its tiles; pass other operands through.
  OpBuilder builder(&op);
  SmallVector<Value> unrolled_args;
  for (auto [operand, layout] :
       llvm::zip_equal(while_op->getOperands(), layouts_in)) {
    if (auto vector_operand = dyn_cast<TypedValue<VectorType>>(operand)) {
      if (!layout.has_value()) {
        return op.emitOpError("Expected layout for vector operand");
      }
      FAILUREOR_ASSIGN_OR_RETURN(
          const xla::Array<Value> tiles,
          disassemble(builder, *layout, vector_operand, ctx.target_shape));
      unrolled_args.append(tiles.begin(), tiles.end());
    } else {
      if (layout.has_value()) {
        return op.emitOpError("Expected no layout for scalar operand");
      }
      unrolled_args.push_back(operand);
    }
  }

  // Create a new scf::WhileOp with unrolled args. The result types are taken
  // from the condition op's forwarded operands, which were already rewritten
  // when the before block was processed above.
  auto new_op = builder.create<scf::WhileOp>(
      while_op->getLoc(),
      TypeRange(while_op.getConditionOp().getOperands().drop_front(1)),
      unrolled_args, nullptr, nullptr);

  // Rewrites the argument list of `old_body` to match `new_body`: appends
  // arguments with new_body's argument types, redirects all uses of each
  // original argument to the appended ones (re-assembling vectors from their
  // tiles), and finally erases the original arguments. `layouts` holds one
  // entry per original argument of old_body.
  const auto tile_body_args = [&](::mlir::Block *old_body,
                                  ::mlir::Block *new_body,
                                  const ArrayRef<Layout> layouts) {
    TPU_ASSERT_OP(old_body != nullptr);
    TPU_ASSERT_OP(new_body != nullptr);
    int num_old_args = old_body->getNumArguments();
    SmallVector<Location> locs(new_body->getNumArguments(), while_op.getLoc());
    old_body->addArguments(TypeRange(new_body->getArguments()), locs);
    builder.setInsertionPointToStart(old_body);
    // Index of the next not-yet-consumed appended argument.
    auto arg_idx = num_old_args;
    for (auto [old_arg, layout] : llvm::zip_equal(
             old_body->getArguments().take_front(num_old_args), layouts)) {
      if (const auto vty = dyn_cast<VectorType>(old_arg.getType())) {
        TPU_ASSERT_OP(layout.has_value());
        const SmallVector<int64_t> tiles_shape =
            layout->tileArrayShape(vty.getShape(), ctx.target_shape);
        const int64_t num_vectors = ShapedType::getNumElements(tiles_shape);
        xla::Array<Value> tiles(tiles_shape);
        TPU_ASSERT_LE_OP(arg_idx + num_vectors, old_body->getNumArguments());
        tiles.SetValues(llvm::make_range(
            old_body->getArguments().begin() + arg_idx,
            old_body->getArguments().begin() + arg_idx + num_vectors));
        arg_idx += num_vectors;
        RollVectorsOp rolled_op =
            assemble(builder, vty, *layout, tiles, ctx.target_shape);
        // Keep the roll op itself out of the replacement set.
        old_arg.replaceUsesWithIf(rolled_op, [&](OpOperand &operand) {
          return operand.getOwner() != rolled_op;
        });
      } else {
        TPU_ASSERT_OP(!layout.has_value());
        old_arg.replaceAllUsesWith(old_body->getArgument(arg_idx));
        ++arg_idx;
      }
    }
    old_body->eraseArguments(0, num_old_args);
    return success();
  };

  // Before-region arguments mirror the operands; after-region arguments mirror
  // the results. Each region is moved into the new op once it is retiled.
  const auto before_status = tile_body_args(while_op.getBeforeBody(),
                                            new_op.getBeforeBody(), layouts_in);
  if (before_status.failed()) return before_status;
  new_op.getBefore().takeBody(while_op.getBefore());

  const auto after_status = tile_body_args(while_op.getAfterBody(),
                                           new_op.getAfterBody(), layouts_out);
  if (after_status.failed()) return after_status;
  new_op.getAfter().takeBody(while_op.getAfter());

  // Re-assemble the old op's vector results from the new op's unrolled results
  // so that users of the old op can be redirected.
  builder.setInsertionPointAfter(new_op);
  int64_t res_idx = 0;
  SmallVector<Value> rolled_results;
  for (auto [result, layout] :
       llvm::zip_equal(while_op.getResults(), layouts_out)) {
    if (const auto vty = dyn_cast<VectorType>(result.getType())) {
      TPU_ASSERT_OP(layout.has_value());
      const SmallVector<int64_t> tiles_shape =
          layout->tileArrayShape(vty.getShape(), ctx.target_shape);
      const int64_t num_vectors = ShapedType::getNumElements(tiles_shape);
      xla::Array<Value> tiles(tiles_shape);
      TPU_ASSERT_LE_OP(res_idx + num_vectors, new_op.getResults().size());
      tiles.SetValues(llvm::make_range(
          new_op.getResults().begin() + res_idx,
          new_op.getResults().begin() + res_idx + num_vectors));
      res_idx += num_vectors;
      RollVectorsOp rolled_op =
          assemble(builder, vty, *layout, tiles, ctx.target_shape);
      rolled_results.push_back(rolled_op);
    } else {
      TPU_ASSERT_OP(!layout.has_value());
      rolled_results.push_back(new_op.getResult(res_idx));
      ++res_idx;
    }
  }

  while_op.replaceAllUsesWith(rolled_results);
  while_op.erase();
  return success();
}

// Applies vector layouts to an scf.condition op by replacing each vector
// operand with its disassembled tiles; non-vector operands are kept as-is.
//
// Args:
//   ctx: Rewrite context; only ctx.target_shape is read here.
//   op: The scf.condition operation, modified in place.
//   layouts_in: One layout per operand (no layout for non-vectors).
//   layouts_out: Must be empty.
//
// Returns:
//   success() after the operand list has been replaced, or a failure if an
//   assertion does not hold or a vector operand cannot be disassembled.
LogicalResult scf_condition_rule(RewriteContext &ctx, Operation &op,
                                 const ArrayRef<Layout> layouts_in,
                                 const ArrayRef<Layout> layouts_out) {
  auto condition_op = cast<scf::ConditionOp>(op);
  TPU_ASSERT_EQ_OP(layouts_in.size(), condition_op.getNumOperands());
  TPU_ASSERT_EQ_OP(layouts_out.size(), 0);

  // Any ops produced while disassembling are inserted right before `op`.
  OpBuilder builder(&op);
  SmallVector<Value> flattened_operands;
  for (auto [value, maybe_layout] :
       llvm::zip_equal(condition_op.getOperands(), layouts_in)) {
    auto vector_value = dyn_cast<TypedValue<VectorType>>(value);
    if (!vector_value) {
      // Non-vector operands carry no layout and are forwarded untouched.
      TPU_ASSERT_OP(!maybe_layout.has_value());
      flattened_operands.push_back(value);
      continue;
    }
    // Vector operands are split into their tiles.
    TPU_ASSERT_OP(maybe_layout.has_value());
    FAILUREOR_ASSIGN_OR_RETURN(
        const xla::Array<Value> vector_tiles,
        disassemble(builder, *maybe_layout, vector_value, ctx.target_shape));
    flattened_operands.append(vector_tiles.begin(), vector_tiles.end());
  }

  // Swap the original operand list for the flattened one.
  condition_op->setOperands(flattened_operands);
  return success();
}

LogicalResult scf_if_rule(RewriteContext &ctx, Operation &op,
const ArrayRef<Layout> layouts_in,
const ArrayRef<Layout> layouts_out) {
Expand Down Expand Up @@ -3634,6 +3813,8 @@ const llvm::StringMap<rule_type> &rules() {
{arith::TruncIOp::getOperationName(), arith_trunci_rule},
{func::ReturnOp::getOperationName(), func_return_rule},
{scf::ForOp::getOperationName(), scf_for_rule},
{scf::WhileOp::getOperationName(), scf_while_rule},
{scf::ConditionOp::getOperationName(), scf_condition_rule},
{scf::IfOp::getOperationName(), scf_if_rule},
{scf::YieldOp::getOperationName(), scf_yield_rule},
{tpu::RotateOp::getOperationName(), tpu_rotate_rule},
Expand Down
120 changes: 120 additions & 0 deletions jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,14 @@ class VectorLayoutInferer {
if (infer(op).failed()) {
return failure();
}
} else if (auto op = dyn_cast<scf::WhileOp>(any_op)) {
if (infer(op).failed()) {
return failure();
}
} else if (auto op = dyn_cast<scf::ConditionOp>(any_op)) {
if (infer(op).failed()) {
return failure();
}
} else if (auto op = dyn_cast<tpu::RotateOp>(any_op)) {
if (infer(op).failed()) {
return failure();
Expand Down Expand Up @@ -536,6 +544,118 @@ class VectorLayoutInferer {
return success();
}

// Infers layouts for an scf.while op and both of its regions.
//
// The layouts of the init operands are used for the operands, the results,
// the before- and after-region block arguments, and the operands of both
// region terminators. Vector block arguments are routed through a
// tpu.assume_layout op so the layout is visible inside the region.
//
// Returns failure if an init operand has an unsupported type or if inference
// of either region fails.
LogicalResult infer(scf::WhileOp op) {
  // Terminator checks handed to inferBlock for the two regions.
  static LogicalResult (*match_condition)(Operation *) = [](Operation *op) {
    TPU_CHECK_OP(isa<scf::ConditionOp>(op), "expected condition terminator");
    return success();
  };
  static LogicalResult (*match_yield)(Operation *) = [](Operation *op) {
    TPU_CHECK_OP(isa<scf::YieldOp>(op), "expected yield terminator");
    return success();
  };
  TPU_CHECK_OP(op.getNumRegions() == 2, "expected two blocks for scf.while");

  // Appends the layout for `value` to `sink`: no layout for signless
  // int/index/float values, the already-known layout for vectors, and an
  // error for everything else.
  const auto layout_for_type = [&op, this](const ::mlir::Value &value,
                                           SmallVector<Layout> *sink) {
    const auto value_type = value.getType();
    if (value_type.isSignlessIntOrIndexOrFloat()) {
      sink->push_back(kNoLayout);
      return success();
    }
    if (isa<VectorType>(value_type)) {
      sink->push_back(getLayout(value));
      return success();
    }
    op.emitOpError() << "unsupported arg type " << value_type
                     << " in scf.while";
    return failure();
  };

  // Wraps every vector argument of `body` in a tpu.assume_layout op carrying
  // the matching entry of `layouts`, and redirects all other uses of the
  // argument to that op.
  const auto annotate_block_args = [&](::mlir::Block *body,
                                       const SmallVector<Layout> &layouts) {
    ImplicitLocOpBuilder body_builder =
        ImplicitLocOpBuilder::atBlockBegin(op.getLoc(), body);
    for (auto [block_arg, arg_layout] :
         llvm::zip_equal(body->getArguments(), layouts)) {
      if (!isa<VectorType>(block_arg.getType())) {
        continue;
      }
      auto assumed =
          body_builder.create<AssumeLayoutOp>(block_arg.getType(), block_arg);
      setLayout(assumed, arg_layout, arg_layout);
      block_arg.replaceUsesWithIf(assumed, [&](OpOperand &use) {
        return use.getOwner() != assumed;
      });
    }
  };

  SmallVector<Layout> in_layouts;
  in_layouts.reserve(op->getNumOperands());
  for (const auto &init : op.getInits()) {
    if (layout_for_type(init, &in_layouts).failed()) {
      return failure();
    }
  }

  // Formally, the types and layouts of the results should follow the layout
  // of the condition op in the Before region, rather than mimicking the input
  // layouts. In practice these are constrained to be the same for our current
  // pipelines, but this does not capture the full expressiveness of scf.while.
  // TODO(hmckenzie): Base output layout on ConditionOp, not inputs.
  SmallVector<Layout> out_layouts = in_layouts;

  // Before region: arguments take the operand layouts.
  annotate_block_args(op.getBeforeBody(), in_layouts);
  if (inferBlock(*op.getBeforeBody(), match_condition).failed()) {
    return failure();
  }

  // After region: arguments take the result layouts.
  annotate_block_args(op.getAfterBody(), out_layouts);
  if (inferBlock(*op.getAfterBody(), match_yield).failed()) {
    return failure();
  }

  // The condition's first operand is the loop predicate and gets no layout;
  // the remaining operands take the result layouts.
  auto *condition_op = op.getBeforeBody()->getTerminator();
  SmallVector<Layout> cond_layout;
  cond_layout.reserve(out_layouts.size() + 1);
  cond_layout.push_back(kNoLayout);
  cond_layout.append(out_layouts.begin(), out_layouts.end());
  setInLayout(condition_op, cond_layout);

  // The yield's operands take the operand layouts.
  auto *yield_op = op.getAfterBody()->getTerminator();
  setInLayout(yield_op, in_layouts);

  setLayout(op, in_layouts, out_layouts);
  return success();
}
// Records layouts on an scf.condition op: vector operands keep the layout
// already known for their value, signless int/index/float operands get no
// layout, and any other operand type is rejected with an error.
//
// NOTE(review): the out-layout list passed to setLayout has one entry per
// operand after the first, whereas scf_condition_rule asserts that
// layouts_out is empty -- confirm how these two are reconciled.
LogicalResult infer(scf::ConditionOp op) {
  SmallVector<Layout> operand_layouts;
  operand_layouts.reserve(op->getNumOperands());
  for (const auto &operand : op.getOperands()) {
    const auto operand_type = operand.getType();
    if (operand_type.isSignlessIntOrIndexOrFloat()) {
      operand_layouts.push_back(kNoLayout);
      continue;
    }
    if (isa<VectorType>(operand_type)) {
      operand_layouts.push_back(getLayout(operand));
      continue;
    }
    op.emitOpError() << "unsupported arg type " << operand_type
                     << " in scf::condition";
    return failure();
  }
  // The first operand is dropped from the out-layout list.
  setLayout(op, operand_layouts,
            ArrayRef<Layout>(operand_layouts).drop_front(1));
  return success();
}

LogicalResult infer(tpu::RotateOp op) {
auto bitwidth = op.getType().getElementTypeBitWidth();
if (bitwidth != 32) {
Expand Down

0 comments on commit 5bd6013

Please sign in to comment.