From 991e6ef7195d1a6a8286e9dd0cc6b7dfe2cb4494 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20Longeri?= Date: Wed, 4 Oct 2023 16:28:57 -0700 Subject: [PATCH] [Mosaic] apply_vector_layout C++ rewrite (13): scf.if, scf.yield PiperOrigin-RevId: 570845376 --- .../tpu/transforms/apply_vector_layout.cc | 278 +++++++++++++----- 1 file changed, 205 insertions(+), 73 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index 028d3598d384..5cc91138608c 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -14,11 +14,13 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVectorExtras.h" #include "llvm/ADT/StringMap.h" +#include "llvm/ADT/iterator_range.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/MathExtras.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Traits.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/AffineExpr.h" @@ -84,6 +86,13 @@ FailureOr> disassemble(RewriteContext &ctx, const VectorLayout &layout, Value val); namespace { +void moveAllRegions(Operation &src, Operation &dst) { + for (auto [src_region, dst_region] : + llvm::zip_equal(src.getRegions(), dst.getRegions())) { + dst_region.takeBody(src_region); + } +} + // Models Numpy's np.repeat, repeating each element `repeats` times along the // specified axis. For example, if `src` is [1, 2], `axis` is 0 and `repeats` is // 3, this will return [1, 1, 1, 2, 2, 2]. @@ -281,6 +290,76 @@ FailureOr getNativeVregType( elem_ty); } +// Get the layout from a VectorLayoutAttr or StringAttr. +mlir::FailureOr getLayoutFromAttr(Attribute attr) { + if (attr == nullptr) { + return failure(); + } + + if (auto layout_attr = dyn_cast(attr)) { + return layout_attr.getLayout(); + } + + // TODO(tlongeri): StringAttr support was only added temporarily to avoid + // having Python bindings for VectorLayoutAttr. Remove this once we get rid + // of the Python implementation + if (auto string_attr = dyn_cast(attr)) { + StringRef str = string_attr.getValue(); + if (!str.consume_front("#tpu.vpad<\"")) { + return failure(); + } + if (str.consume_front("none")) { + return kNoLayout; + } + if (auto layout = VectorLayout::parse(&str)) { + return layout; + } + return failure(); + } + + return failure(); +} + +// Returns empty vector on null attribute +FailureOr> getLayoutArrayFromAttr(const Attribute attr) { + if (const auto array_attr = dyn_cast_if_present(attr)) { + SmallVector out_layouts; + out_layouts.reserve(array_attr.size()); + for (const Attribute a : array_attr) { + FAILUREOR_ASSIGN_OR_RETURN(const Layout layout, getLayoutFromAttr(a)); + out_layouts.push_back(layout); + } + return out_layouts; + } + return SmallVector{}; +} + +// TODO(tlongeri): Unify with infer_vector_layout.cc's getOutLayout. +FailureOr> getOutLayout(Operation &op) { + // TODO(tlongeri): non-array attribute path should be removed after tests are + // updated + FailureOr failure_or_layout = + getLayoutFromAttr(op.getAttr("out_layout")); + if (succeeded(failure_or_layout)) { + return SmallVector{failure_or_layout.value()}; + } + FAILUREOR_ASSIGN_OR_RETURN(const SmallVector out_layout, + getLayoutArrayFromAttr(op.getAttr("out_layout"))); + if (out_layout.size() != op.getNumResults()) { + return failure(); + } + return out_layout; +} + +FailureOr> getInLayout(Operation &op) { + FAILUREOR_ASSIGN_OR_RETURN(const SmallVector in_layout, + getLayoutArrayFromAttr(op.getAttr("in_layout"))); + if (in_layout.size() != op.getNumOperands()) { + return failure(); + } + return in_layout; +} + LogicalResult elementwise_op_rule( RewriteContext &ctx, Operation &op, const ArrayRef layouts_in, const ArrayRef layouts_out, @@ -638,6 +717,121 @@ LogicalResult func_return_rule(RewriteContext &ctx, Operation &op, return success(); } +LogicalResult scf_if_rule(RewriteContext &ctx, Operation &op, + const ArrayRef layouts_in, + const ArrayRef layouts_out) { + CHECK_EQ(layouts_in.size(), 1); + CHECK(!layouts_in.front().has_value()); + scf::IfOp if_op = cast(op); + FAILUREOR_ASSIGN_OR_RETURN(const SmallVector then_yield_in_layouts, + getInLayout(*if_op.thenYield())); + // TODO(tlongeri): ArrayRef conversion should not be necessary, fix + // after LLVM adds const qualifiers to ==/!= operators. Also + // applies to else_yield_in_layouts comparison below. + if (!layouts_out.empty() && + ArrayRef(then_yield_in_layouts) != layouts_out) { + return op.emitOpError( + "Not implemented: different layouts in then yield's operands and if's " + "results"); + } + if (failed(applyLayoutBlock(ctx, *if_op.thenBlock()))) { + return failure(); + } + if (if_op.getElseRegion().empty()) { + CHECK_EQ(if_op->getNumResults(), 0) + << "Expected no results if op does not have an else block"; + CHECK_EQ(layouts_out.size(), 0); + return success(); + } + FAILUREOR_ASSIGN_OR_RETURN(const SmallVector else_yield_in_layouts, + getInLayout(*if_op.elseYield())); + if (!layouts_out.empty() && + ArrayRef(else_yield_in_layouts) != layouts_out) { + return op.emitOpError( + "Not implemented: different layouts in else yield's operands and if's " + "results"); + } + if (failed(applyLayoutBlock(ctx, *if_op.elseBlock()))) { + return failure(); + } + + // Apply layout to results after applying layout in the true and false + // regions. + if (if_op.getNumResults() == 0) { + CHECK_EQ(layouts_out.size(), 0); + return success(); + } + CHECK_EQ(if_op.getNumResults(), layouts_out.size()); + // If scf.if has results, it should have both non-empty true and false + // regions. + CHECK(!if_op.getThenRegion().empty() && !if_op.getElseRegion().empty()); + + // Move true and false regions to the new if op whose result has same type and + // layout as yield operand's. + auto new_op = ctx.builder.create( + if_op.getLoc(), TypeRange(if_op.thenYield().getResults()), + if_op.getCondition(), + /*withElseRegion =*/true); + moveAllRegions(*if_op, *new_op); + + int64_t index = 0; + SmallVector rolled_results; + for (auto [result, layout] : + llvm::zip_equal(if_op.getResults(), layouts_out)) { + if (const auto vty = dyn_cast(result.getType())) { + // When the result has a vector type, assemble the result. + CHECK(layout.has_value()); + const SmallVector tiles_shape = + layout->tileArrayShape(vty.getShape(), ctx.target_shape); + const int64_t num_vectors = ShapedType::getNumElements(tiles_shape); + xla::Array tiles(tiles_shape); + CHECK_LE(index + num_vectors, new_op.getResults().size()); + tiles.SetValues( + llvm::make_range(new_op.getResults().begin() + index, + new_op.getResults().begin() + index + num_vectors)); + index += num_vectors; + RollVectorsOp rolled_op = assemble(ctx, vty, *layout, tiles); + rolled_results.push_back(rolled_op); + } else { + CHECK(!layout.has_value()); + rolled_results.push_back(new_op.getResult(index)); + ++index; + } + } + if_op.replaceAllUsesWith(rolled_results); + if_op.erase(); + return success(); +} + +LogicalResult scf_yield_rule(RewriteContext &ctx, Operation &op, + const ArrayRef layouts_in, + const ArrayRef layouts_out) { + auto yield_op = cast(op); + CHECK_EQ(layouts_in.size(), yield_op.getNumOperands()); + CHECK_EQ(layouts_out.size(), 0); + if (yield_op.getNumOperands() == 0) { + return success(); + } + SmallVector unrolled; + for (auto [operand, layout] : + llvm::zip_equal(yield_op.getOperands(), layouts_in)) { + if (auto vty = dyn_cast(operand.getType())) { + // When the operand has vector type, disassemble the operand. + CHECK(layout.has_value()); + FAILUREOR_ASSIGN_OR_RETURN(const xla::Array tiles, + disassemble(ctx, *layout, operand)); + unrolled.append(tiles.begin(), tiles.end()); + } else { + CHECK(!layout.has_value()); + unrolled.push_back(operand); + } + } + + // Replace the old operands with unrolled operands. + yield_op->setOperands(unrolled); + return success(); +} + LogicalResult tpu_load_rule(RewriteContext &ctx, Operation &op, const ArrayRef layouts_in, const ArrayRef layouts_out) { @@ -1595,6 +1789,8 @@ const llvm::StringMap &rules() { rules_elementwise_op_entry(), rules_elementwise_op_entry(), {func::ReturnOp::getOperationName(), func_return_rule}, + {scf::IfOp::getOperationName(), scf_if_rule}, + {scf::YieldOp::getOperationName(), scf_yield_rule}, {tpu::IotaOp::getOperationName(), tpu_iota_rule}, {tpu::GatherOp::getOperationName(), tpu_gather_rule}, {tpu::LoadOp::getOperationName(), tpu_load_rule}, @@ -1610,76 +1806,6 @@ const llvm::StringMap &rules() { } } // namespace -// Get the layout from a VectorLayoutAttr or StringAttr. -mlir::FailureOr getLayoutFromAttr(Attribute attr) { - if (attr == nullptr) { - return failure(); - } - - if (auto layout_attr = dyn_cast(attr)) { - return layout_attr.getLayout(); - } - - // TODO(tlongeri): StringAttr support was only added temporarily to avoid - // having Python bindings for VectorLayoutAttr. Remove this once we get rid - // of the Python implementation - if (auto string_attr = dyn_cast(attr)) { - StringRef str = string_attr.getValue(); - if (!str.consume_front("#tpu.vpad<\"")) { - return failure(); - } - if (str.consume_front("none")) { - return kNoLayout; - } - if (auto layout = VectorLayout::parse(&str)) { - return layout; - } - return failure(); - } - - return failure(); -} - -// Returns empty vector on null attribute -FailureOr> getLayoutArrayFromAttr(const Attribute attr) { - if (const auto array_attr = dyn_cast_if_present(attr)) { - SmallVector out_layouts; - out_layouts.reserve(array_attr.size()); - for (const Attribute a : array_attr) { - FAILUREOR_ASSIGN_OR_RETURN(const Layout layout, getLayoutFromAttr(a)); - out_layouts.push_back(layout); - } - return out_layouts; - } - return SmallVector{}; -} - -// TODO(tlongeri): Unify with infer_vector_layout.cc's getOutLayout. -FailureOr> getOutLayout(Operation &op) { - // TODO(tlongeri): non-array attribute path should be removed after tests are - // updated - FailureOr failure_or_layout = - getLayoutFromAttr(op.getAttr("out_layout")); - if (succeeded(failure_or_layout)) { - return SmallVector{failure_or_layout.value()}; - } - FAILUREOR_ASSIGN_OR_RETURN(const SmallVector out_layout, - getLayoutArrayFromAttr(op.getAttr("out_layout"))); - if (out_layout.size() != op.getNumResults()) { - return failure(); - } - return out_layout; -} - -FailureOr> getInLayout(Operation &op) { - FAILUREOR_ASSIGN_OR_RETURN(const SmallVector in_layout, - getLayoutArrayFromAttr(op.getAttr("in_layout"))); - if (in_layout.size() != op.getNumOperands()) { - return failure(); - } - return in_layout; -} - template ArrayRef XlaArrayToFlatArrayRef(xla::Array xla_array) { return ArrayRef(xla_array.data(), xla_array.num_elements()); @@ -2357,7 +2483,6 @@ FailureOr relayout(RewriteContext &ctx, Value v, VectorLayout src, // For example, we should verify that ops that were supposed to generate // replicated outputs satisfy that requirement. LogicalResult applyLayoutOp(RewriteContext &ctx, Operation &op) { - ctx.builder.setInsertionPointAfter(&op); // TODO(tlongeri): Once we support all ops, return failure instead. if (!rules().contains(op.getName().getStringRef())) { return success(); @@ -2405,9 +2530,12 @@ LogicalResult applyLayoutOp(RewriteContext &ctx, Operation &op) { if (lo->generalizes(*li, vty.getShape(), ctx.target_shape)) { continue; } + const OpBuilder::InsertPoint ip = ctx.builder.saveInsertionPoint(); + ctx.builder.setInsertionPoint(&op); FAILUREOR_ASSIGN_OR_RETURN(Value new_v, relayout(ctx, operand, /*src=*/*lo, /*dst=*/*li)); + ctx.builder.restoreInsertionPoint(ip); op.setOperand(idx, new_v); } } @@ -2429,8 +2557,12 @@ LogicalResult applyLayoutOp(RewriteContext &ctx, Operation &op) { } if (auto rule_it = rules().find(op.getName().getStringRef()); rule_it != rules().end()) { + const OpBuilder::InsertPoint ip = ctx.builder.saveInsertionPoint(); + ctx.builder.setInsertionPointAfter(&op); const rule_type &rule = rule_it->getValue(); - return rule(ctx, op, layout_in, layout_out); + LogicalResult res = rule(ctx, op, layout_in, layout_out); + ctx.builder.restoreInsertionPoint(ip); + return res; } return op.emitError("Unsupported operation: ") << op.getName(); } @@ -2475,7 +2607,7 @@ struct ApplyVectorLayoutPass return; } func::FuncOp func = getOperation(); - OpBuilder builder(func.getBody()); + OpBuilder builder(func->getContext()); RewriteContext ctx{ func, builder, hardware_generation, {sublane_count, lane_count}}; if (failed(applyLayoutFunc(ctx, func))) {