Skip to content

Commit

Permalink
[Mosaic] apply_vector_layout C++ rewrite (13): scf.if, scf.yield
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 570845376
  • Loading branch information
tlongeri authored and jax authors committed Oct 4, 2023
1 parent 9e3d64a commit 991e6ef
Showing 1 changed file with 205 additions and 73 deletions.
278 changes: 205 additions & 73 deletions jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
Expand Up @@ -14,11 +14,13 @@
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/MathExtras.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
Expand Down Expand Up @@ -84,6 +86,13 @@ FailureOr<xla::Array<Value>> disassemble(RewriteContext &ctx,
const VectorLayout &layout, Value val);
namespace {

// Hands the body of each region of `src` over to the region of `dst` at the
// same position by calling Region::takeBody on the destination region.
// Regions are paired with llvm::zip_equal, so both operations are expected to
// have the same number of regions.
void moveAllRegions(Operation &src, Operation &dst) {
  for (auto [from_region, to_region] :
       llvm::zip_equal(src.getRegions(), dst.getRegions())) {
    to_region.takeBody(from_region);
  }
}

// Models Numpy's np.repeat, repeating each element `repeats` times along the
// specified axis. For example, if `src` is [1, 2], `axis` is 0 and `repeats` is
// 3, this will return [1, 1, 1, 2, 2, 2].
Expand Down Expand Up @@ -281,6 +290,76 @@ FailureOr<VectorType> getNativeVregType(
elem_ty);
}

// Extracts a Layout from `attr`, which may be a VectorLayoutAttr or a
// StringAttr.
//
// Fails when `attr` is null, when it is any other kind of attribute, or when
// it is a string that does not start with `#tpu.vpad<"` followed by either
// `none` (which yields kNoLayout) or text accepted by VectorLayout::parse.
mlir::FailureOr<Layout> getLayoutFromAttr(Attribute attr) {
  if (!attr) {
    return failure();
  }
  if (auto vector_layout_attr = dyn_cast<VectorLayoutAttr>(attr)) {
    return vector_layout_attr.getLayout();
  }

  // TODO(tlongeri): StringAttr support was only added temporarily to avoid
  // having Python bindings for VectorLayoutAttr. Remove this once we get rid
  // of the Python implementation.
  auto string_attr = dyn_cast<StringAttr>(attr);
  if (!string_attr) {
    return failure();
  }
  StringRef remaining = string_attr.getValue();
  if (!remaining.consume_front("#tpu.vpad<\"")) {
    return failure();
  }
  if (remaining.consume_front("none")) {
    return kNoLayout;
  }
  if (auto parsed = VectorLayout::parse(&remaining)) {
    return parsed;
  }
  return failure();
}

// Converts an ArrayAttr of layout attributes into a vector of Layouts, one per
// array element and in the same order. Fails as soon as any element cannot be
// converted by getLayoutFromAttr.
//
// Returns an empty vector when `attr` is null. An attribute that is present
// but is not an ArrayAttr also produces an empty vector, because the
// dyn_cast_if_present below yields null for it.
FailureOr<SmallVector<Layout>> getLayoutArrayFromAttr(const Attribute attr) {
  SmallVector<Layout> layouts;
  const auto array_attr = dyn_cast_if_present<ArrayAttr>(attr);
  if (!array_attr) {
    return layouts;
  }
  layouts.reserve(array_attr.size());
  for (const Attribute element : array_attr) {
    FAILUREOR_ASSIGN_OR_RETURN(const Layout layout,
                               getLayoutFromAttr(element));
    layouts.push_back(layout);
  }
  return layouts;
}

// Reads the "out_layout" attribute of `op` and returns the result layouts.
//
// A single (non-array) layout attribute is returned as a one-element vector
// without comparing against the number of results. Otherwise the attribute is
// read as an array, and the call fails unless it holds exactly one layout per
// result of `op`.
//
// TODO(tlongeri): Unify with infer_vector_layout.cc's getOutLayout.
FailureOr<SmallVector<Layout>> getOutLayout(Operation &op) {
  const Attribute out_layout_attr = op.getAttr("out_layout");
  // TODO(tlongeri): non-array attribute path should be removed after tests are
  // updated
  if (FailureOr<Layout> single_layout = getLayoutFromAttr(out_layout_attr);
      succeeded(single_layout)) {
    return SmallVector<Layout>{single_layout.value()};
  }
  FAILUREOR_ASSIGN_OR_RETURN(const SmallVector<Layout> layouts,
                             getLayoutArrayFromAttr(out_layout_attr));
  if (layouts.size() == op.getNumResults()) {
    return layouts;
  }
  return failure();
}

// Reads the "in_layout" array attribute of `op` and returns the operand
// layouts. Fails if any element cannot be converted to a Layout or if the
// number of layouts differs from the number of operands of `op`.
FailureOr<SmallVector<Layout>> getInLayout(Operation &op) {
  FAILUREOR_ASSIGN_OR_RETURN(const SmallVector<Layout> layouts,
                             getLayoutArrayFromAttr(op.getAttr("in_layout")));
  if (layouts.size() == op.getNumOperands()) {
    return layouts;
  }
  return failure();
}

LogicalResult elementwise_op_rule(
RewriteContext &ctx, Operation &op, const ArrayRef<Layout> layouts_in,
const ArrayRef<Layout> layouts_out,
Expand Down Expand Up @@ -638,6 +717,121 @@ LogicalResult func_return_rule(RewriteContext &ctx, Operation &op,
return success();
}

// Layout application rule for scf.if.
//
// Steps, in order:
//   1. Check that the single operand (the condition) carries no layout.
//   2. Require the then-yield's operand layouts to equal `layouts_out` (when
//      `layouts_out` is non-empty), then apply layouts inside the then block.
//   3. If there is no else region, the op must have no results; stop here.
//   4. Repeat step 2 for the else-yield and the else block.
//   5. If the op has results, create a replacement scf.if whose result types
//      are taken from the then-yield, move both regions into it, re-assemble
//      each vector result from the matching run of new results, and replace
//      and erase the original op.
//
// Returns failure if the yield layouts cannot be read, if they disagree with
// `layouts_out`, or if applying layouts to either block fails.
LogicalResult scf_if_rule(RewriteContext &ctx, Operation &op,
                          const ArrayRef<Layout> layouts_in,
                          const ArrayRef<Layout> layouts_out) {
  CHECK_EQ(layouts_in.size(), 1);
  CHECK(!layouts_in.front().has_value());
  scf::IfOp if_op = cast<scf::IfOp>(op);
  const bool has_out_layouts = !layouts_out.empty();

  // TODO(tlongeri): ArrayRef<Layout> conversion should not be necessary, fix
  //                 after LLVM adds const qualifiers to ==/!= operators. Also
  //                 applies to the else-yield comparison below.
  FAILUREOR_ASSIGN_OR_RETURN(const SmallVector<Layout> then_layouts,
                             getInLayout(*if_op.thenYield()));
  if (has_out_layouts && ArrayRef<Layout>(then_layouts) != layouts_out) {
    return op.emitOpError(
        "Not implemented: different layouts in then yield's operands and if's "
        "results");
  }
  if (failed(applyLayoutBlock(ctx, *if_op.thenBlock()))) {
    return failure();
  }

  // Without an else region there is nothing further to rewrite.
  if (if_op.getElseRegion().empty()) {
    CHECK_EQ(if_op->getNumResults(), 0)
        << "Expected no results if op does not have an else block";
    CHECK_EQ(layouts_out.size(), 0);
    return success();
  }

  FAILUREOR_ASSIGN_OR_RETURN(const SmallVector<Layout> else_layouts,
                             getInLayout(*if_op.elseYield()));
  if (has_out_layouts && ArrayRef<Layout>(else_layouts) != layouts_out) {
    return op.emitOpError(
        "Not implemented: different layouts in else yield's operands and if's "
        "results");
  }
  if (failed(applyLayoutBlock(ctx, *if_op.elseBlock()))) {
    return failure();
  }

  // Results are handled only after both branches have had layouts applied.
  if (if_op.getNumResults() == 0) {
    CHECK_EQ(layouts_out.size(), 0);
    return success();
  }
  CHECK_EQ(if_op.getNumResults(), layouts_out.size());
  // An scf.if that produces results must have non-empty then and else regions.
  CHECK(!if_op.getThenRegion().empty() && !if_op.getElseRegion().empty());

  // Build the replacement op, whose results have the same types and layout as
  // the yield operands, and give it the regions of the original op.
  auto new_op = ctx.builder.create<scf::IfOp>(
      if_op.getLoc(), TypeRange(if_op.thenYield().getResults()),
      if_op.getCondition(),
      /*withElseRegion =*/true);
  moveAllRegions(*if_op, *new_op);

  // `next_result` is the position in new_op's results of the first value that
  // belongs to the original result currently being processed.
  int64_t next_result = 0;
  SmallVector<Value> replacements;
  for (auto [old_result, layout] :
       llvm::zip_equal(if_op.getResults(), layouts_out)) {
    const auto vty = dyn_cast<VectorType>(old_result.getType());
    if (!vty) {
      // Non-vector results map one-to-one onto results of the new op.
      CHECK(!layout.has_value());
      replacements.push_back(new_op.getResult(next_result++));
      continue;
    }
    // Vector results are re-assembled from a contiguous run of new results.
    CHECK(layout.has_value());
    const SmallVector<int64_t> tiles_shape =
        layout->tileArrayShape(vty.getShape(), ctx.target_shape);
    const int64_t num_vectors = ShapedType::getNumElements(tiles_shape);
    xla::Array<Value> tiles(tiles_shape);
    CHECK_LE(next_result + num_vectors, new_op.getResults().size());
    const auto first = new_op.getResults().begin() + next_result;
    tiles.SetValues(llvm::make_range(first, first + num_vectors));
    next_result += num_vectors;
    RollVectorsOp rolled_op = assemble(ctx, vty, *layout, tiles);
    replacements.push_back(rolled_op);
  }
  if_op.replaceAllUsesWith(replacements);
  if_op.erase();
  return success();
}

// Layout application rule for scf.yield.
//
// Every vector-typed operand is disassembled according to its layout and
// replaced by the resulting values; non-vector operands are passed through
// unchanged. The yield's operand list is then overwritten with the flattened
// sequence. Expects one input layout per operand and no output layouts.
//
// Returns failure if disassembling any vector operand fails.
LogicalResult scf_yield_rule(RewriteContext &ctx, Operation &op,
                             const ArrayRef<Layout> layouts_in,
                             const ArrayRef<Layout> layouts_out) {
  auto yield_op = cast<scf::YieldOp>(op);
  CHECK_EQ(layouts_in.size(), yield_op.getNumOperands());
  CHECK_EQ(layouts_out.size(), 0);
  if (yield_op.getNumOperands() == 0) {
    return success();
  }

  SmallVector<Value> flattened;
  for (auto [operand, layout] :
       llvm::zip_equal(yield_op.getOperands(), layouts_in)) {
    if (!isa<VectorType>(operand.getType())) {
      // Non-vector operands must not carry a layout and are kept as they are.
      CHECK(!layout.has_value());
      flattened.push_back(operand);
      continue;
    }
    // Vector operands are split into their constituent values.
    CHECK(layout.has_value());
    FAILUREOR_ASSIGN_OR_RETURN(const xla::Array<Value> vregs,
                               disassemble(ctx, *layout, operand));
    flattened.append(vregs.begin(), vregs.end());
  }

  // Swap the original operands for the flattened ones.
  yield_op->setOperands(flattened);
  return success();
}

LogicalResult tpu_load_rule(RewriteContext &ctx, Operation &op,
const ArrayRef<Layout> layouts_in,
const ArrayRef<Layout> layouts_out) {
Expand Down Expand Up @@ -1595,6 +1789,8 @@ const llvm::StringMap<rule_type> &rules() {
rules_elementwise_op_entry<math::RsqrtOp, 1>(),
rules_elementwise_op_entry<math::TanhOp, 1>(),
{func::ReturnOp::getOperationName(), func_return_rule},
{scf::IfOp::getOperationName(), scf_if_rule},
{scf::YieldOp::getOperationName(), scf_yield_rule},
{tpu::IotaOp::getOperationName(), tpu_iota_rule},
{tpu::GatherOp::getOperationName(), tpu_gather_rule},
{tpu::LoadOp::getOperationName(), tpu_load_rule},
Expand All @@ -1610,76 +1806,6 @@ const llvm::StringMap<rule_type> &rules() {
}
} // namespace

// Extracts a Layout from `attr`, which may be a VectorLayoutAttr or a
// StringAttr.
//
// Fails when `attr` is null, when it is any other kind of attribute, or when
// it is a string that does not start with `#tpu.vpad<"` followed by either
// `none` (which yields kNoLayout) or text accepted by VectorLayout::parse.
mlir::FailureOr<Layout> getLayoutFromAttr(Attribute attr) {
  if (!attr) {
    return failure();
  }
  if (auto vector_layout_attr = dyn_cast<VectorLayoutAttr>(attr)) {
    return vector_layout_attr.getLayout();
  }

  // TODO(tlongeri): StringAttr support was only added temporarily to avoid
  // having Python bindings for VectorLayoutAttr. Remove this once we get rid
  // of the Python implementation.
  auto string_attr = dyn_cast<StringAttr>(attr);
  if (!string_attr) {
    return failure();
  }
  StringRef remaining = string_attr.getValue();
  if (!remaining.consume_front("#tpu.vpad<\"")) {
    return failure();
  }
  if (remaining.consume_front("none")) {
    return kNoLayout;
  }
  if (auto parsed = VectorLayout::parse(&remaining)) {
    return parsed;
  }
  return failure();
}

// Converts an ArrayAttr of layout attributes into a vector of Layouts, one per
// array element and in the same order. Fails as soon as any element cannot be
// converted by getLayoutFromAttr.
//
// Returns an empty vector when `attr` is null. An attribute that is present
// but is not an ArrayAttr also produces an empty vector, because the
// dyn_cast_if_present below yields null for it.
FailureOr<SmallVector<Layout>> getLayoutArrayFromAttr(const Attribute attr) {
  SmallVector<Layout> layouts;
  const auto array_attr = dyn_cast_if_present<ArrayAttr>(attr);
  if (!array_attr) {
    return layouts;
  }
  layouts.reserve(array_attr.size());
  for (const Attribute element : array_attr) {
    FAILUREOR_ASSIGN_OR_RETURN(const Layout layout,
                               getLayoutFromAttr(element));
    layouts.push_back(layout);
  }
  return layouts;
}

// Reads the "out_layout" attribute of `op` and returns the result layouts.
//
// A single (non-array) layout attribute is returned as a one-element vector
// without comparing against the number of results. Otherwise the attribute is
// read as an array, and the call fails unless it holds exactly one layout per
// result of `op`.
//
// TODO(tlongeri): Unify with infer_vector_layout.cc's getOutLayout.
FailureOr<SmallVector<Layout>> getOutLayout(Operation &op) {
  const Attribute out_layout_attr = op.getAttr("out_layout");
  // TODO(tlongeri): non-array attribute path should be removed after tests are
  // updated
  if (FailureOr<Layout> single_layout = getLayoutFromAttr(out_layout_attr);
      succeeded(single_layout)) {
    return SmallVector<Layout>{single_layout.value()};
  }
  FAILUREOR_ASSIGN_OR_RETURN(const SmallVector<Layout> layouts,
                             getLayoutArrayFromAttr(out_layout_attr));
  if (layouts.size() == op.getNumResults()) {
    return layouts;
  }
  return failure();
}

// Reads the "in_layout" array attribute of `op` and returns the operand
// layouts. Fails if any element cannot be converted to a Layout or if the
// number of layouts differs from the number of operands of `op`.
FailureOr<SmallVector<Layout>> getInLayout(Operation &op) {
  FAILUREOR_ASSIGN_OR_RETURN(const SmallVector<Layout> layouts,
                             getLayoutArrayFromAttr(op.getAttr("in_layout")));
  if (layouts.size() == op.getNumOperands()) {
    return layouts;
  }
  return failure();
}

template <typename T>
ArrayRef<T> XlaArrayToFlatArrayRef(xla::Array<T> xla_array) {
return ArrayRef<T>(xla_array.data(), xla_array.num_elements());
Expand Down Expand Up @@ -2357,7 +2483,6 @@ FailureOr<Value> relayout(RewriteContext &ctx, Value v, VectorLayout src,
// For example, we should verify that ops that were supposed to generate
// replicated outputs satisfy that requirement.
LogicalResult applyLayoutOp(RewriteContext &ctx, Operation &op) {
ctx.builder.setInsertionPointAfter(&op);
// TODO(tlongeri): Once we support all ops, return failure instead.
if (!rules().contains(op.getName().getStringRef())) {
return success();
Expand Down Expand Up @@ -2405,9 +2530,12 @@ LogicalResult applyLayoutOp(RewriteContext &ctx, Operation &op) {
if (lo->generalizes(*li, vty.getShape(), ctx.target_shape)) {
continue;
}
const OpBuilder::InsertPoint ip = ctx.builder.saveInsertionPoint();
ctx.builder.setInsertionPoint(&op);
FAILUREOR_ASSIGN_OR_RETURN(Value new_v,
relayout(ctx, operand, /*src=*/*lo,
/*dst=*/*li));
ctx.builder.restoreInsertionPoint(ip);
op.setOperand(idx, new_v);
}
}
Expand All @@ -2429,8 +2557,12 @@ LogicalResult applyLayoutOp(RewriteContext &ctx, Operation &op) {
}
if (auto rule_it = rules().find(op.getName().getStringRef());
rule_it != rules().end()) {
const OpBuilder::InsertPoint ip = ctx.builder.saveInsertionPoint();
ctx.builder.setInsertionPointAfter(&op);
const rule_type &rule = rule_it->getValue();
return rule(ctx, op, layout_in, layout_out);
LogicalResult res = rule(ctx, op, layout_in, layout_out);
ctx.builder.restoreInsertionPoint(ip);
return res;
}
return op.emitError("Unsupported operation: ") << op.getName();
}
Expand Down Expand Up @@ -2475,7 +2607,7 @@ struct ApplyVectorLayoutPass
return;
}
func::FuncOp func = getOperation();
OpBuilder builder(func.getBody());
OpBuilder builder(func->getContext());
RewriteContext ctx{
func, builder, hardware_generation, {sublane_count, lane_count}};
if (failed(applyLayoutFunc(ctx, func))) {
Expand Down

0 comments on commit 991e6ef

Please sign in to comment.