Skip to content

Commit

Permalink
[Mosaic] apply_vector_layout C++ rewrite: Handle elementwise ops by c…
Browse files Browse the repository at this point in the history
…hecking for the Elementwise trait and using the generic Operation interface, without templates

PiperOrigin-RevId: 572065184
  • Loading branch information
tlongeri authored and jax authors committed Oct 9, 2023
1 parent a86d4dd commit 622472f
Showing 1 changed file with 56 additions and 175 deletions.
231 changes: 56 additions & 175 deletions jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
Expand Up @@ -52,6 +52,7 @@
#include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/include/mlir/IR/Builders.h"
#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/include/mlir/IR/OperationSupport.h"
#include "jaxlib/mosaic/dialect/tpu/layout.h"
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
#include "jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.h"
Expand Down Expand Up @@ -450,171 +451,93 @@ FailureOr<SmallVector<Layout>> getInLayout(Operation &op) {
return in_layout;
}

LogicalResult elementwise_op_rule(
RewriteContext &ctx, Operation &op, const ArrayRef<Layout> layouts_in,
const ArrayRef<Layout> layouts_out,
std::function<FailureOr<Operation *>(RewriteContext &, OpBuilder &,
ArrayRef<Value>)>
factory) {
LogicalResult elementwise_op_rule(RewriteContext &ctx, Operation &op,
const ArrayRef<Layout> layouts_in,
const ArrayRef<Layout> layouts_out) {
CHECK(OpTrait::hasElementwiseMappableTraits(&op));
if (op.getNumResults() != 1) {
return op.emitError("Not implemented: Only ops with one result supported");
}
CHECK_EQ(layouts_in.size(), op.getNumOperands());
CHECK_GT(layouts_in.size(), 0);
CHECK_EQ(layouts_out.size(), 1);
OpBuilder builder(&op);
if (!(layouts_out.front().has_value() &&
llvm::all_of(layouts_in,
[&](const Layout &l) { return l.has_value(); }))) {
return op.emitOpError("null layout in elementwise operation");
return op.emitOpError(
"Not implemented: Null layout / non-vector operand in elementwise "
"operation");
}
const auto vty = cast<VectorType>(op.getResult(0).getType());
const auto out_ty = cast<VectorType>(op.getResult(0).getType());
const VectorLayout &layout_out = *layouts_out.front();
if (!llvm::all_of(layouts_in, [&](const Layout &l) {
return l->generalizes(layout_out, vty.getShape(), ctx.target_shape);
return l->generalizes(layout_out, out_ty.getShape(), ctx.target_shape);
})) {
return op.emitOpError("incompatible layouts in elementwise operation");
return op.emitOpError("Incompatible layouts in elementwise operation");
}
const unsigned num_operands = op.getNumOperands();
SmallVector<xla::Array<Value>> in_tile_arrays;
in_tile_arrays.reserve(num_operands);
SmallVector<xla::Array<Value>> in_vreg_arrays;
in_vreg_arrays.reserve(num_operands);
for (unsigned i = 0; i < num_operands; ++i) {
FAILUREOR_ASSIGN_OR_RETURN(
xla::Array<Value> tile_array,
disassemble(ctx, builder, *layouts_in[i], op.getOperand(i)));
in_tile_arrays.emplace_back(std::move(tile_array));
in_vreg_arrays.emplace_back(std::move(tile_array));
}

FAILUREOR_ASSIGN_OR_RETURN(
const VectorType out_vreg_ty,
getNativeVregType(out_ty.getElementType(), ctx.target_shape));

NamedAttrList attributes(op.getAttrDictionary());
attributes.erase("in_layout");
attributes.erase("out_layout");

// Note that we have to broadcast to handle replicate dimensions.
SmallVector<int64_t> broadcasted_shape(
toArrayRef(in_tile_arrays[0].dimensions()));
toArrayRef(in_vreg_arrays[0].dimensions()));
for (size_t i = 1; i < num_operands; ++i) {
SmallVector<int64_t> new_broadcasted_shape;
CHECK(OpTrait::util::getBroadcastedShape(
broadcasted_shape, toArrayRef(in_tile_arrays[i].dimensions()),
broadcasted_shape, toArrayRef(in_vreg_arrays[i].dimensions()),
new_broadcasted_shape));
broadcasted_shape = std::move(new_broadcasted_shape);
}
CHECK(broadcasted_shape ==
layout_out.tileArrayShape(out_ty.getShape(), ctx.target_shape));

// TODO(tlongeri): Can we avoid initializing the array before filling values?
xla::Array<Value> out_tile_array(broadcasted_shape);
absl::Status status =
out_tile_array.EachStatus([&](absl::Span<const int64_t> idx, Value *v) {
SmallVector<Value> operands(num_operands);
for (unsigned i = 0; i < num_operands; ++i) {
// Handle indices for broadcasted dimensions
SmallVector<int64_t> operand_idx(toArrayRef(idx));
for (unsigned j = 0; j < idx.size(); ++j) {
if (in_tile_arrays[i].dim(j) == 1) {
operand_idx[j] = 0;
}
}
operands[i] = in_tile_arrays[i](operand_idx);
}
FailureOr<Operation *> failure_or_tile_op =
factory(ctx, builder, operands);
if (failed(failure_or_tile_op)) {
return absl::InvalidArgumentError("");
xla::Array<Value> out_vreg_array(broadcasted_shape);
out_vreg_array.Each([&](absl::Span<const int64_t> idx, Value *out_vreg) {
SmallVector<Value> operands(num_operands);

for (unsigned i = 0; i < num_operands; ++i) {
// Handle indices for broadcasted dimensions
SmallVector<int64_t> operand_idx(toArrayRef(idx));
for (unsigned j = 0; j < idx.size(); ++j) {
if (in_vreg_arrays[i].dim(j) == 1) {
operand_idx[j] = 0;
}
Operation *tile_op = *failure_or_tile_op;
CHECK(tile_op);
CHECK_EQ(tile_op->getNumResults(), 1);
*v = tile_op->getResult(0);
return absl::OkStatus();
});
if (!status.ok()) {
return failure();
}
}
operands[i] = in_vreg_arrays[i](operand_idx);
}
Operation *vreg_op =
builder.create(op.getLoc(), op.getName().getIdentifier(), operands,
out_vreg_ty, attributes.getAttrs());
CHECK(vreg_op);
CHECK_EQ(vreg_op->getNumResults(), 1);
*out_vreg = vreg_op->getResult(0);
});
op.replaceAllUsesWith(
assemble(ctx, builder, vty, layout_out, std::move(out_tile_array)));
assemble(ctx, builder, out_ty, layout_out, std::move(out_vreg_array)));
op.erase();
return success();
}

// Helper for index_sequence expansion
template <typename T, std::size_t>
using Wrapper = T;

template <std::size_t... I>
LogicalResult elementwise_op_rule_unpacked_impl(
RewriteContext &ctx, Operation &op, const ArrayRef<Layout> layout_in,
const ArrayRef<Layout> layout_out,
std::function<FailureOr<Operation *>(
RewriteContext &ctx, OpBuilder &builder, Wrapper<Value, I>...)>
factory,
std::index_sequence<I...>) {
return elementwise_op_rule(
ctx, op, layout_in, layout_out,
[&](RewriteContext &ctx, OpBuilder &builder,
ArrayRef<Value> operands) -> FailureOr<Operation *> {
if (operands.size() != sizeof...(I)) {
return failure();
}
return factory(ctx, builder, operands[I]...);
});
}

// Like elementwise_op_rule, but operands are "unpacked" into individual
// arguments for the factory.
// Returns failure if the number of operands is not the one expected (i.e. it
// doesn't match NumOperands).
template <std::size_t NumOperands, typename Func>
LogicalResult elementwise_op_rule_unpacked(RewriteContext &ctx, Operation &op,
const ArrayRef<Layout> layouts_in,
const ArrayRef<Layout> layouts_out,
Func factory) {
return elementwise_op_rule_unpacked_impl(
ctx, op, layouts_in, layouts_out, std::move(factory),
std::make_index_sequence<NumOperands>());
}

using rule_type = std::function<LogicalResult(
RewriteContext &, Operation &, ArrayRef<Layout>, ArrayRef<Layout>)>;

LogicalResult arith_cmpf_rule(RewriteContext &ctx, Operation &op,
ArrayRef<Layout> layouts_in,
ArrayRef<Layout> layouts_out) {
auto cmpf_op = cast<arith::CmpFOp>(op);
return elementwise_op_rule_unpacked<2>(
ctx, op, layouts_in, layouts_out,
[&](RewriteContext &ctx, OpBuilder &builder, const Value lhs,
const Value rhs) -> FailureOr<Operation *> {
return builder
.create<arith::CmpFOp>(cmpf_op.getLoc(), cmpf_op.getPredicateAttr(),
lhs, rhs)
.getOperation();
});
}

LogicalResult arith_cmpi_rule(RewriteContext &ctx, Operation &op,
const ArrayRef<Layout> layouts_in,
const ArrayRef<Layout> layouts_out) {
auto cmpi_op = cast<arith::CmpIOp>(op);
return elementwise_op_rule_unpacked<2>(
ctx, op, layouts_in, layouts_out,
[&](RewriteContext &ctx, OpBuilder &builder, const Value lhs,
const Value rhs) -> FailureOr<Operation *> {
return builder
.create<arith::CmpIOp>(cmpi_op.getLoc(), cmpi_op.getPredicateAttr(),
lhs, rhs)
.getOperation();
});
}

LogicalResult arith_extui_rule(RewriteContext &ctx, Operation &op,
const ArrayRef<Layout> layouts_in,
const ArrayRef<Layout> layouts_out) {
auto extui_op = cast<arith::ExtUIOp>(op);
const Type elem_ty =
cast<VectorType>(extui_op.getResult().getType()).getElementType();
return elementwise_op_rule_unpacked<1>(
ctx, op, layouts_in, layouts_out,
[&](RewriteContext &ctx, OpBuilder &builder,
const Value x) -> FailureOr<Operation *> {
const VectorType x_ty = cast<VectorType>(x.getType());
const VectorType out_ty = VectorType::get(x_ty.getShape(), elem_ty);
return builder.create<arith::ExtUIOp>(extui_op.getLoc(), out_ty, x)
.getOperation();
});
}

template <typename OpTy>
LogicalResult ext_op_rule_impl(RewriteContext &ctx, OpTy op,
const VectorLayout &layout_in,
Expand Down Expand Up @@ -2698,58 +2621,13 @@ LogicalResult vector_transpose_rule(RewriteContext &ctx, Operation &op,
transpose_op->erase();
return success();
}

template <typename Op, std::size_t NumOperands>
std::pair<StringRef, rule_type> rules_elementwise_op_entry() {
return {
Op::getOperationName(),
[](RewriteContext &ctx, Operation &op, const ArrayRef<Layout> layouts_in,
const ArrayRef<Layout> layouts_out) -> LogicalResult {
return elementwise_op_rule_unpacked<NumOperands>(
ctx, op, layouts_in, layouts_out,
[&](RewriteContext &ctx, OpBuilder &builder,
auto... operands) -> FailureOr<Operation *> {
return builder.create<Op>(op.getLoc(), operands...)
.getOperation();
});
}};
}

const llvm::StringMap<rule_type> &rules() {
static auto rules = new llvm::StringMap<rule_type>{
{arith::ConstantOp::getOperationName(), arith_constant_rule},
rules_elementwise_op_entry<arith::AddFOp, 2>(),
rules_elementwise_op_entry<arith::AddIOp, 2>(),
{arith::CmpFOp::getOperationName(), arith_cmpf_rule},
{arith::CmpIOp::getOperationName(), arith_cmpi_rule},
{arith::ExtFOp::getOperationName(), arith_extf_rule},
{arith::ExtSIOp::getOperationName(), arith_extsi_rule},
{arith::ExtUIOp::getOperationName(), arith_extui_rule},
{arith::TruncFOp::getOperationName(), arith_truncf_rule},
{arith::TruncIOp::getOperationName(), arith_trunci_rule},
rules_elementwise_op_entry<arith::SubFOp, 2>(),
rules_elementwise_op_entry<arith::SubIOp, 2>(),
rules_elementwise_op_entry<arith::MulFOp, 2>(),
rules_elementwise_op_entry<arith::MulIOp, 2>(),
rules_elementwise_op_entry<arith::DivFOp, 2>(),
rules_elementwise_op_entry<arith::DivSIOp, 2>(),
rules_elementwise_op_entry<arith::RemSIOp, 2>(),
rules_elementwise_op_entry<arith::MaximumFOp, 2>(),
rules_elementwise_op_entry<arith::MinimumFOp, 2>(),
rules_elementwise_op_entry<arith::SelectOp, 3>(),
// TODO(tlongeri) arith::IndexCastOp
rules_elementwise_op_entry<arith::AndIOp, 2>(),
rules_elementwise_op_entry<arith::OrIOp, 2>(),
rules_elementwise_op_entry<arith::NegFOp, 1>(),
rules_elementwise_op_entry<arith::XOrIOp, 2>(),
rules_elementwise_op_entry<arith::ShLIOp, 2>(),
rules_elementwise_op_entry<arith::ShRUIOp, 2>(),
rules_elementwise_op_entry<math::ExpOp, 1>(),
rules_elementwise_op_entry<math::CosOp, 1>(),
rules_elementwise_op_entry<math::SinOp, 1>(),
rules_elementwise_op_entry<math::PowFOp, 2>(),
rules_elementwise_op_entry<math::RsqrtOp, 1>(),
rules_elementwise_op_entry<math::TanhOp, 1>(),
{func::ReturnOp::getOperationName(), func_return_rule},
{scf::ForOp::getOperationName(), scf_for_rule},
{scf::IfOp::getOperationName(), scf_if_rule},
Expand Down Expand Up @@ -3479,7 +3357,8 @@ FailureOr<Value> relayout(RewriteContext &ctx, OpBuilder &builder, Value v,
// replicated outputs satisfy that requirement.
LogicalResult applyLayoutOp(RewriteContext &ctx, Operation &op) {
// TODO(tlongeri): Once we support all ops, return failure instead.
if (!rules().contains(op.getName().getStringRef())) {
if (!rules().contains(op.getName().getStringRef()) &&
!OpTrait::hasElementwiseMappableTraits(&op)) {
return success();
}

Expand Down Expand Up @@ -3551,8 +3430,10 @@ LogicalResult applyLayoutOp(RewriteContext &ctx, Operation &op) {
if (auto rule_it = rules().find(op.getName().getStringRef());
rule_it != rules().end()) {
const rule_type &rule = rule_it->getValue();
LogicalResult res = rule(ctx, op, layout_in, layout_out);
return res;
return rule(ctx, op, layout_in, layout_out);
}
if (OpTrait::hasElementwiseMappableTraits(&op)) {
return elementwise_op_rule(ctx, op, layout_in, layout_out);
}
return op.emitError("Unsupported operation: ") << op.getName();
}
Expand Down

0 comments on commit 622472f

Please sign in to comment.