From 622472f78b61c34809cbe6a99ccf18a3c93a6ff6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20Longeri?= Date: Mon, 9 Oct 2023 16:02:00 -0700 Subject: [PATCH] [Mosaic] apply_vector_layout C++ rewrite: Handle elementwise ops by checking for the Elementwise trait and using the generic Operation interface, without templates PiperOrigin-RevId: 572065184 --- .../tpu/transforms/apply_vector_layout.cc | 231 +++++------------- 1 file changed, 56 insertions(+), 175 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index a47a837b949d..e45b3a1cfb32 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -52,6 +52,7 @@ #include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/include/mlir/IR/Builders.h" #include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/include/mlir/IR/OperationSupport.h" #include "jaxlib/mosaic/dialect/tpu/layout.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" #include "jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.h" @@ -450,12 +451,13 @@ FailureOr> getInLayout(Operation &op) { return in_layout; } -LogicalResult elementwise_op_rule( - RewriteContext &ctx, Operation &op, const ArrayRef layouts_in, - const ArrayRef layouts_out, - std::function(RewriteContext &, OpBuilder &, - ArrayRef)> - factory) { +LogicalResult elementwise_op_rule(RewriteContext &ctx, Operation &op, + const ArrayRef layouts_in, + const ArrayRef layouts_out) { + CHECK(OpTrait::hasElementwiseMappableTraits(&op)); + if (op.getNumResults() != 1) { + return op.emitError("Not implemented: Only ops with one result supported"); + } CHECK_EQ(layouts_in.size(), op.getNumOperands()); CHECK_GT(layouts_in.size(), 0); CHECK_EQ(layouts_out.size(), 1); @@ -463,158 +465,79 @@ LogicalResult elementwise_op_rule( if (!(layouts_out.front().has_value() && llvm::all_of(layouts_in, [&](const Layout &l) { return l.has_value(); }))) { - return op.emitOpError("null layout in elementwise operation"); + return op.emitOpError( + "Not implemented: Null layout / non-vector operand in elementwise " + "operation"); } - const auto vty = cast(op.getResult(0).getType()); + const auto out_ty = cast(op.getResult(0).getType()); const VectorLayout &layout_out = *layouts_out.front(); if (!llvm::all_of(layouts_in, [&](const Layout &l) { - return l->generalizes(layout_out, vty.getShape(), ctx.target_shape); + return l->generalizes(layout_out, out_ty.getShape(), ctx.target_shape); })) { - return op.emitOpError("incompatible layouts in elementwise operation"); + return op.emitOpError("Incompatible layouts in elementwise operation"); } const unsigned num_operands = op.getNumOperands(); - SmallVector> in_tile_arrays; - in_tile_arrays.reserve(num_operands); + SmallVector> in_vreg_arrays; + in_vreg_arrays.reserve(num_operands); for (unsigned i = 0; i < num_operands; ++i) { FAILUREOR_ASSIGN_OR_RETURN( xla::Array tile_array, disassemble(ctx, builder, *layouts_in[i], op.getOperand(i))); - in_tile_arrays.emplace_back(std::move(tile_array)); + in_vreg_arrays.emplace_back(std::move(tile_array)); } + FAILUREOR_ASSIGN_OR_RETURN( + const VectorType out_vreg_ty, + getNativeVregType(out_ty.getElementType(), ctx.target_shape)); + + NamedAttrList attributes(op.getAttrDictionary()); + attributes.erase("in_layout"); + attributes.erase("out_layout"); + // Note that we have to broadcast to handle replicate dimensions. SmallVector broadcasted_shape( - toArrayRef(in_tile_arrays[0].dimensions())); + toArrayRef(in_vreg_arrays[0].dimensions())); for (size_t i = 1; i < num_operands; ++i) { SmallVector new_broadcasted_shape; CHECK(OpTrait::util::getBroadcastedShape( - broadcasted_shape, toArrayRef(in_tile_arrays[i].dimensions()), + broadcasted_shape, toArrayRef(in_vreg_arrays[i].dimensions()), new_broadcasted_shape)); broadcasted_shape = std::move(new_broadcasted_shape); } + CHECK(broadcasted_shape == + layout_out.tileArrayShape(out_ty.getShape(), ctx.target_shape)); // TODO(tlongeri): Can we avoid initializing the array before filling values? - xla::Array out_tile_array(broadcasted_shape); - absl::Status status = - out_tile_array.EachStatus([&](absl::Span idx, Value *v) { - SmallVector operands(num_operands); - for (unsigned i = 0; i < num_operands; ++i) { - // Handle indices for broadcasted dimensions - SmallVector operand_idx(toArrayRef(idx)); - for (unsigned j = 0; j < idx.size(); ++j) { - if (in_tile_arrays[i].dim(j) == 1) { - operand_idx[j] = 0; - } - } - operands[i] = in_tile_arrays[i](operand_idx); - } - FailureOr failure_or_tile_op = - factory(ctx, builder, operands); - if (failed(failure_or_tile_op)) { - return absl::InvalidArgumentError(""); + xla::Array out_vreg_array(broadcasted_shape); + out_vreg_array.Each([&](absl::Span idx, Value *out_vreg) { + SmallVector operands(num_operands); + + for (unsigned i = 0; i < num_operands; ++i) { + // Handle indices for broadcasted dimensions + SmallVector operand_idx(toArrayRef(idx)); + for (unsigned j = 0; j < idx.size(); ++j) { + if (in_vreg_arrays[i].dim(j) == 1) { + operand_idx[j] = 0; } - Operation *tile_op = *failure_or_tile_op; - CHECK(tile_op); - CHECK_EQ(tile_op->getNumResults(), 1); - *v = tile_op->getResult(0); - return absl::OkStatus(); - }); - if (!status.ok()) { - return failure(); - } + } + operands[i] = in_vreg_arrays[i](operand_idx); + } + Operation *vreg_op = + builder.create(op.getLoc(), op.getName().getIdentifier(), operands, + out_vreg_ty, attributes.getAttrs()); + CHECK(vreg_op); + CHECK_EQ(vreg_op->getNumResults(), 1); + *out_vreg = vreg_op->getResult(0); + }); op.replaceAllUsesWith( - assemble(ctx, builder, vty, layout_out, std::move(out_tile_array))); + assemble(ctx, builder, out_ty, layout_out, std::move(out_vreg_array))); op.erase(); return success(); } -// Helper for index_sequence expansion -template -using Wrapper = T; - -template -LogicalResult elementwise_op_rule_unpacked_impl( - RewriteContext &ctx, Operation &op, const ArrayRef layout_in, - const ArrayRef layout_out, - std::function( - RewriteContext &ctx, OpBuilder &builder, Wrapper...)> - factory, - std::index_sequence) { - return elementwise_op_rule( - ctx, op, layout_in, layout_out, - [&](RewriteContext &ctx, OpBuilder &builder, - ArrayRef operands) -> FailureOr { - if (operands.size() != sizeof...(I)) { - return failure(); - } - return factory(ctx, builder, operands[I]...); - }); -} - -// Like elementwise_op_rule, but operands are "unpacked" into individual -// arguments for the factory. -// Returns failure if the number of operands is not the one expected (i.e. it -// doesn't match NumOperands). -template -LogicalResult elementwise_op_rule_unpacked(RewriteContext &ctx, Operation &op, - const ArrayRef layouts_in, - const ArrayRef layouts_out, - Func factory) { - return elementwise_op_rule_unpacked_impl( - ctx, op, layouts_in, layouts_out, std::move(factory), - std::make_index_sequence()); -} - using rule_type = std::function, ArrayRef)>; -LogicalResult arith_cmpf_rule(RewriteContext &ctx, Operation &op, - ArrayRef layouts_in, - ArrayRef layouts_out) { - auto cmpf_op = cast(op); - return elementwise_op_rule_unpacked<2>( - ctx, op, layouts_in, layouts_out, - [&](RewriteContext &ctx, OpBuilder &builder, const Value lhs, - const Value rhs) -> FailureOr { - return builder - .create(cmpf_op.getLoc(), cmpf_op.getPredicateAttr(), - lhs, rhs) - .getOperation(); - }); -} - -LogicalResult arith_cmpi_rule(RewriteContext &ctx, Operation &op, - const ArrayRef layouts_in, - const ArrayRef layouts_out) { - auto cmpi_op = cast(op); - return elementwise_op_rule_unpacked<2>( - ctx, op, layouts_in, layouts_out, - [&](RewriteContext &ctx, OpBuilder &builder, const Value lhs, - const Value rhs) -> FailureOr { - return builder - .create(cmpi_op.getLoc(), cmpi_op.getPredicateAttr(), - lhs, rhs) - .getOperation(); - }); -} - -LogicalResult arith_extui_rule(RewriteContext &ctx, Operation &op, - const ArrayRef layouts_in, - const ArrayRef layouts_out) { - auto extui_op = cast(op); - const Type elem_ty = - cast(extui_op.getResult().getType()).getElementType(); - return elementwise_op_rule_unpacked<1>( - ctx, op, layouts_in, layouts_out, - [&](RewriteContext &ctx, OpBuilder &builder, - const Value x) -> FailureOr { - const VectorType x_ty = cast(x.getType()); - const VectorType out_ty = VectorType::get(x_ty.getShape(), elem_ty); - return builder.create(extui_op.getLoc(), out_ty, x) - .getOperation(); - }); -} - template LogicalResult ext_op_rule_impl(RewriteContext &ctx, OpTy op, const VectorLayout &layout_in, @@ -2698,58 +2621,13 @@ LogicalResult vector_transpose_rule(RewriteContext &ctx, Operation &op, transpose_op->erase(); return success(); } - -template -std::pair rules_elementwise_op_entry() { - return { - Op::getOperationName(), - [](RewriteContext &ctx, Operation &op, const ArrayRef layouts_in, - const ArrayRef layouts_out) -> LogicalResult { - return elementwise_op_rule_unpacked( - ctx, op, layouts_in, layouts_out, - [&](RewriteContext &ctx, OpBuilder &builder, - auto... operands) -> FailureOr { - return builder.create(op.getLoc(), operands...) - .getOperation(); - }); - }}; -} - const llvm::StringMap &rules() { static auto rules = new llvm::StringMap{ {arith::ConstantOp::getOperationName(), arith_constant_rule}, - rules_elementwise_op_entry(), - rules_elementwise_op_entry(), - {arith::CmpFOp::getOperationName(), arith_cmpf_rule}, - {arith::CmpIOp::getOperationName(), arith_cmpi_rule}, {arith::ExtFOp::getOperationName(), arith_extf_rule}, {arith::ExtSIOp::getOperationName(), arith_extsi_rule}, - {arith::ExtUIOp::getOperationName(), arith_extui_rule}, {arith::TruncFOp::getOperationName(), arith_truncf_rule}, {arith::TruncIOp::getOperationName(), arith_trunci_rule}, - rules_elementwise_op_entry(), - rules_elementwise_op_entry(), - rules_elementwise_op_entry(), - rules_elementwise_op_entry(), - rules_elementwise_op_entry(), - rules_elementwise_op_entry(), - rules_elementwise_op_entry(), - rules_elementwise_op_entry(), - rules_elementwise_op_entry(), - rules_elementwise_op_entry(), - // TODO(tlongeri) arith::IndexCastOp - rules_elementwise_op_entry(), - rules_elementwise_op_entry(), - rules_elementwise_op_entry(), - rules_elementwise_op_entry(), - rules_elementwise_op_entry(), - rules_elementwise_op_entry(), - rules_elementwise_op_entry(), - rules_elementwise_op_entry(), - rules_elementwise_op_entry(), - rules_elementwise_op_entry(), - rules_elementwise_op_entry(), - rules_elementwise_op_entry(), {func::ReturnOp::getOperationName(), func_return_rule}, {scf::ForOp::getOperationName(), scf_for_rule}, {scf::IfOp::getOperationName(), scf_if_rule}, @@ -3479,7 +3357,8 @@ FailureOr relayout(RewriteContext &ctx, OpBuilder &builder, Value v, // replicated outputs satisfy that requirement. LogicalResult applyLayoutOp(RewriteContext &ctx, Operation &op) { // TODO(tlongeri): Once we support all ops, return failure instead. - if (!rules().contains(op.getName().getStringRef())) { + if (!rules().contains(op.getName().getStringRef()) && + !OpTrait::hasElementwiseMappableTraits(&op)) { return success(); } @@ -3551,8 +3430,10 @@ LogicalResult applyLayoutOp(RewriteContext &ctx, Operation &op) { if (auto rule_it = rules().find(op.getName().getStringRef()); rule_it != rules().end()) { const rule_type &rule = rule_it->getValue(); - LogicalResult res = rule(ctx, op, layout_in, layout_out); - return res; + return rule(ctx, op, layout_in, layout_out); + } + if (OpTrait::hasElementwiseMappableTraits(&op)) { + return elementwise_op_rule(ctx, op, layout_in, layout_out); } return op.emitError("Unsupported operation: ") << op.getName(); }