diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/BUILD.bazel b/compiler/src/iree/compiler/InputConversion/StableHLO/BUILD.bazel index 7c6a8bf6e3d8..7804b2f7cd17 100644 --- a/compiler/src/iree/compiler/InputConversion/StableHLO/BUILD.bazel +++ b/compiler/src/iree/compiler/InputConversion/StableHLO/BUILD.bazel @@ -49,6 +49,7 @@ iree_compiler_cc_library( "LegalizeToLinalgUtils.cpp", "Passes.cpp", "StableHLOToLinalg.cpp", + "StableHLOToLinalgPointwise.cpp", "TypeConversion.cpp", ], hdrs = [ diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/CMakeLists.txt b/compiler/src/iree/compiler/InputConversion/StableHLO/CMakeLists.txt index 78a3216e1487..a2b91145e630 100644 --- a/compiler/src/iree/compiler/InputConversion/StableHLO/CMakeLists.txt +++ b/compiler/src/iree/compiler/InputConversion/StableHLO/CMakeLists.txt @@ -46,6 +46,7 @@ iree_cc_library( "LegalizeToLinalgUtils.cpp" "Passes.cpp" "StableHLOToLinalg.cpp" + "StableHLOToLinalgPointwise.cpp" "TypeConversion.cpp" DEPS ::PassHeaders diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/LegalizeToLinalgUtils.cpp b/compiler/src/iree/compiler/InputConversion/StableHLO/LegalizeToLinalgUtils.cpp index 87f2f47de6c1..a95dbdbf46c1 100644 --- a/compiler/src/iree/compiler/InputConversion/StableHLO/LegalizeToLinalgUtils.cpp +++ b/compiler/src/iree/compiler/InputConversion/StableHLO/LegalizeToLinalgUtils.cpp @@ -81,6 +81,13 @@ Value getEmptyTensorFor(OpBuilder& b, Location loc, ShapedType resultType, : getEmptyTensor(b, loc, resultType, sizes); } +Value coerceTensorShape(OpBuilder& builder, Location loc, + TypedValue value, ShapedType targetType) { + return builder.createOrFold( + loc, targetType.cloneWith(std::nullopt, value.getType().getElementType()), + value); +} + Value preSparsify(Operation* op, llvm::SmallVector& values, Type rtp, OpBuilder* b) { // Apply for semi-ring operations that lower to elaborate code diff --git 
a/compiler/src/iree/compiler/InputConversion/StableHLO/LegalizeToLinalgUtils.h b/compiler/src/iree/compiler/InputConversion/StableHLO/LegalizeToLinalgUtils.h index 5f36df166b9c..ed1499be9e59 100644 --- a/compiler/src/iree/compiler/InputConversion/StableHLO/LegalizeToLinalgUtils.h +++ b/compiler/src/iree/compiler/InputConversion/StableHLO/LegalizeToLinalgUtils.h @@ -62,6 +62,11 @@ Value getEmptyTensor(OpBuilder& b, Location loc, ShapedType type, Value getEmptyTensorFor(OpBuilder& b, Location loc, ShapedType resultType, Operation* op, ValueRange operands); +/// Ensures a tensor has the same shape (not including the element type) as +/// another. +Value coerceTensorShape(OpBuilder& builder, Location loc, + TypedValue value, ShapedType targetType); + /// Sparsifies a (block of) operation(s) that cannot be handled directly /// by the sparse compiler but has well-known semi-ring semantics. /// diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/Rewriters.h b/compiler/src/iree/compiler/InputConversion/StableHLO/Rewriters.h index 36cea233fb3f..d2dd40480dd7 100644 --- a/compiler/src/iree/compiler/InputConversion/StableHLO/Rewriters.h +++ b/compiler/src/iree/compiler/InputConversion/StableHLO/Rewriters.h @@ -17,6 +17,17 @@ void populateStableHloToLinalgConversionPatterns(MLIRContext* context, RewritePatternSet* patterns, bool enablePrimitiveOps); +//===----------------------------------------------------------------------===// +// Fine-grained patterns used by the implementation. +//===----------------------------------------------------------------------===// +namespace detail { +/// Populates patterns converting pointwise StableHLO ops to Linalg on tensors. 
+void populatePointwiseStableHloToLinalgConversionPatterns( + MLIRContext* context, TypeConverter& typeConverter, + RewritePatternSet* patterns, bool enablePrimitiveOps); + +} // namespace detail + } // namespace mlir::iree_compiler::stablehlo #endif // IREE_COMPILER_INPUTCONVERSION_STABLEHLO_REWRITERS_H_ diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalg.cpp b/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalg.cpp index 2e8cec3ea0c0..de74e26a99df 100644 --- a/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalg.cpp +++ b/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalg.cpp @@ -135,15 +135,6 @@ Value extractIndexFromTensor(OpBuilder& builder, Location loc, Value tensor, loc, builder.getIndexType(), extracted); } -/// Ensures a tensor has the same shape (not including the element type) as -/// another. -Value coerceTensorShape(OpBuilder& builder, Location loc, - TypedValue value, ShapedType targetType) { - return builder.createOrFold( - loc, targetType.cloneWith(std::nullopt, value.getType().getElementType()), - value); -} - //===----------------------------------------------------------------------===// // stablehlo.RngOp conversion patterns. //===----------------------------------------------------------------------===// @@ -2909,110 +2900,6 @@ class DotGeneralOpConversion } }; -/// Converts a HLO operation to a linalg.map op that contains the corresponding -/// scalar operations. -template -class PointwiseToLinalgMapConverter : public OpConversionPattern { - public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite( - OpTy op, typename OpTy::Adaptor adaptor, - ConversionPatternRewriter& rewriter) const final { - auto loc = op.getLoc(); - int64_t maxRank = getMaxRank(adaptor); - - // Apply only if all operands are scalar or have the same rank. Some ops, - // like `mhlo.select`, support implicit broadcasting of scalars. 
- if (!llvm::all_of(adaptor.getOperands(), [&](Value v) { - int64_t r = getRank(v); - return r == 0 || r == maxRank; - })) { - return rewriter.notifyMatchFailure( - op, "Operands must be of same rank or scalar."); - } - - // Find result type, if on tensors. - std::optional resultTy; - resultTy = this->typeConverter->convertType(op->getResultTypes().front()) - .template dyn_cast(); - - // Check result type compatibility. - if (!resultTy || !resultTy->hasRank() || resultTy->getRank() != maxRank || - !(resultTy->getElementType().isSignlessIntOrFloat() || - resultTy->getElementType().isa())) { - return rewriter.notifyMatchFailure( - op, "mismatched operand/result types or iterator count"); - } - - // All-scalar pointwise ops inside of linalg ops are processes by - // ScalarHloToArithmeticPattern. - if (maxRank == 0 && isInBodyOfLinalgOps(op)) return failure(); - - // Find input/output values and types. - Value emptyTensor = - getEmptyTensorFor(rewriter, loc, *resultTy, op, adaptor.getOperands()); - - // Mapped inputs are cast to the same shape as the init tensor. - // Values from scalar inputs are extracted and used directly in the block. 
- SmallVector mappedInputs; - SmallVector scalarInputs; - for (Value input : adaptor.getOperands()) { - if (getRank(input) == maxRank) { - mappedInputs.push_back(coerceTensorShape( - rewriter, loc, cast>(input), - emptyTensor.getType())); - scalarInputs.push_back(nullptr); - } else { - scalarInputs.push_back(rewriter.create(loc, input)); - } - } - - auto mapOp = rewriter.create( - loc, mappedInputs, emptyTensor, - [&](OpBuilder& b, Location loc, ValueRange args) { - Value innerResult = stablehlo::StableHloOpToStdScalarOp::mapOp( - op, getElementTypeOrSelf(emptyTensor), - interleaveScalarAndBlockArgs(scalarInputs, args), &b); - - b.create(loc, innerResult); - }, - linalg::getPrunedAttributeList(op)); - - rewriter.replaceOp(op, mapOp->getResults()); - return success(); - } - - protected: - int64_t getRank(Value v) const { - return v.getType().cast().getRank(); - } - - int64_t getMaxRank(typename OpTy::Adaptor adaptor) const { - int64_t maxRank = 0; - for (auto operand : adaptor.getOperands()) { - maxRank = std::max(maxRank, getRank(operand)); - } - return maxRank; - } - - // Inserts block arguments in places where scalar inputs have a nullptr. 
- SmallVector interleaveScalarAndBlockArgs(ValueRange scalarInputs, - ValueRange blockArgs) const { - SmallVector result; - auto argsIter = blockArgs.begin(); - for (Value scalarInput : scalarInputs) { - if (scalarInput) { - result.push_back(scalarInput); - } else { - result.push_back(*argsIter); - ++argsIter; - } - } - return result; - } -}; - class SetDimensionSizeConverter : public OpConversionPattern { public: @@ -3108,6 +2995,9 @@ void populateStableHloToLinalgConversionPatterns(MLIRContext* context, SelectAndScatterNoOverlapConverter, ReduceRegionReturnOpConversion>(typeConverter, context); + detail::populatePointwiseStableHloToLinalgConversionPatterns( + context, typeConverter, patterns, enablePrimitiveOps); + if (enablePrimitiveOps) { patterns->add< BroadcastInDimOpToBroadcastConverter, @@ -3116,52 +3006,6 @@ void populateStableHloToLinalgConversionPatterns(MLIRContext* context, IotaToMapConverter, IotaToMapConverter, MapOpToMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - 
PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, - PointwiseToLinalgMapConverter, ReduceOpToReduceConverter, TransposeOpToTransposeConverter >(typeConverter, context); @@ -3173,52 +3017,6 @@ void populateStableHloToLinalgConversionPatterns(MLIRContext* context, HloBroadcastInDimConverter, HloDynamicBroadcastInDimConverter, MapOpToGenericConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - 
PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, ReduceOpToGenericConverter, TransposeConverter >(typeConverter, context); diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalgPointwise.cpp b/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalgPointwise.cpp new file mode 100644 index 000000000000..2c57ef64b47b --- /dev/null +++ b/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalgPointwise.cpp @@ -0,0 +1,229 @@ +// Copyright 2019 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +// Implements logic for lowering StableHLO/CHLO pointwise ops to Linalg dialect. +// These patterns are separated out to their own file to save on the compilation +// times, given that we instantiate a large number of class templates here. + +#include "iree/compiler/InputConversion/StableHLO/LegalizeToLinalgUtils.h" +#include "iree/compiler/InputConversion/StableHLO/MapStableHLOToScalarOp.h" +#include "iree/compiler/InputConversion/StableHLO/Rewriters.h" +#include "iree/compiler/InputConversion/StableHLO/TypeConversion.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/DialectConversion.h" +#include "stablehlo/dialect/StablehloOps.h" + +namespace mlir::iree_compiler::stablehlo { +namespace { +namespace stablehlo = mlir::stablehlo; + +int64_t getRank(Value v) { return cast(v.getType()).getRank(); } + +/// Inserts block arguments in places where scalar inputs have a nullptr. 
+SmallVector interleaveScalarAndBlockArgs(ValueRange scalarInputs, + ValueRange blockArgs) { + SmallVector result; + auto argsIter = blockArgs.begin(); + for (Value scalarInput : scalarInputs) { + if (scalarInput) { + result.push_back(scalarInput); + } else { + result.push_back(*argsIter); + ++argsIter; + } + } + return result; +} + +/// Converts a HLO operation to a linalg.map op that contains the corresponding +/// scalar operations. +template +struct PointwiseToLinalgMapConverter final : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename OpTy::Adaptor; + + LogicalResult matchAndRewrite( + OpTy op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + auto loc = op.getLoc(); + int64_t maxRank = getMaxRank(adaptor); + + // Apply only if all operands are scalar or have the same rank. Some ops, + // like `stablehlo.select`, support implicit broadcasting of scalars. + if (!llvm::all_of(adaptor.getOperands(), [&](Value v) { + int64_t r = getRank(v); + return r == 0 || r == maxRank; + })) { + return rewriter.notifyMatchFailure( + op, "Operands must be of same rank or scalar."); + } + + // Find result type, if on tensors. + auto resultTy = dyn_cast_or_null( + this->typeConverter->convertType(op->getResultTypes().front())); + + // Check result type compatibility. + if (!resultTy || !resultTy.hasRank() || resultTy.getRank() != maxRank || + !(resultTy.getElementType().isSignlessIntOrFloat() || + isa(resultTy.getElementType()))) { + return rewriter.notifyMatchFailure( + op, "mismatched operand/result types or iterator count"); + } + + // All-scalar pointwise ops inside of linalg ops are processed by + // ScalarHloToArithmeticPattern. + if (maxRank == 0 && isInBodyOfLinalgOps(op)) return failure(); + + // Find input/output values and types. + Value emptyTensor = + getEmptyTensorFor(rewriter, loc, resultTy, op, adaptor.getOperands()); + + // Mapped inputs are cast to the same shape as the init tensor. 
+ // Values from scalar inputs are extracted and used directly in the block. + SmallVector mappedInputs; + SmallVector scalarInputs; + for (Value input : adaptor.getOperands()) { + if (getRank(input) == maxRank) { + mappedInputs.push_back(coerceTensorShape( + rewriter, loc, cast>(input), + emptyTensor.getType())); + scalarInputs.push_back(nullptr); + } else { + scalarInputs.push_back(rewriter.create(loc, input)); + } + } + + auto mapOp = rewriter.create( + loc, mappedInputs, emptyTensor, + [&](OpBuilder& b, Location loc, ValueRange args) { + Value innerResult = stablehlo::StableHloOpToStdScalarOp::mapOp( + op, getElementTypeOrSelf(emptyTensor), + interleaveScalarAndBlockArgs(scalarInputs, args), &b); + + b.create(loc, innerResult); + }, + linalg::getPrunedAttributeList(op)); + + rewriter.replaceOp(op, mapOp->getResults()); + return success(); + } + + static int64_t getMaxRank(OpAdaptor adaptor) { + int64_t maxRank = 0; + for (Value operand : adaptor.getOperands()) { + maxRank = std::max(maxRank, getRank(operand)); + } + return maxRank; + } +}; +} // namespace + +namespace detail { +void populatePointwiseStableHloToLinalgConversionPatterns( + MLIRContext* context, TypeConverter& typeConverter, + RewritePatternSet* patterns, bool enablePrimitiveOps) { + if (enablePrimitiveOps) { + patterns + ->add, + PointwiseToLinalgMapConverter, + PointwiseToLinalgMapConverter, + PointwiseToLinalgMapConverter, + PointwiseToLinalgMapConverter, + PointwiseToLinalgMapConverter, + PointwiseToLinalgMapConverter, + PointwiseToLinalgMapConverter, + PointwiseToLinalgMapConverter, + PointwiseToLinalgMapConverter, + PointwiseToLinalgMapConverter, + PointwiseToLinalgMapConverter, + PointwiseToLinalgMapConverter, + PointwiseToLinalgMapConverter, + PointwiseToLinalgMapConverter, + PointwiseToLinalgMapConverter, + PointwiseToLinalgMapConverter, + PointwiseToLinalgMapConverter, + PointwiseToLinalgMapConverter, + PointwiseToLinalgMapConverter, + PointwiseToLinalgMapConverter, + 
PointwiseToLinalgMapConverter, + PointwiseToLinalgMapConverter, + PointwiseToLinalgMapConverter, + PointwiseToLinalgMapConverter, + PointwiseToLinalgMapConverter, + PointwiseToLinalgMapConverter, + PointwiseToLinalgMapConverter, + PointwiseToLinalgMapConverter, + PointwiseToLinalgMapConverter, + PointwiseToLinalgMapConverter, + PointwiseToLinalgMapConverter, + PointwiseToLinalgMapConverter, + PointwiseToLinalgMapConverter, + PointwiseToLinalgMapConverter, + PointwiseToLinalgMapConverter, + PointwiseToLinalgMapConverter, + PointwiseToLinalgMapConverter, + PointwiseToLinalgMapConverter, + PointwiseToLinalgMapConverter, + PointwiseToLinalgMapConverter, + PointwiseToLinalgMapConverter, + PointwiseToLinalgMapConverter, + PointwiseToLinalgMapConverter, + PointwiseToLinalgMapConverter, + PointwiseToLinalgMapConverter>(typeConverter, + context); + return; + } + + patterns->add, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + 
PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter>(typeConverter, + context); +} +} // namespace detail +} // namespace mlir::iree_compiler::stablehlo