From 11dbbecec5e432155c1f8ef40c2bd7b33ab6ea54 Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Wed, 12 Apr 2023 11:10:48 -0400 Subject: [PATCH] [StableHLO][NFC] Move pointwise to linalg conversion to source file This moves the pointwise StableHLO op to `linalg.generic` conversion out of the common header and to the matching source file. In addition, unify the common precondition checks code used by both pointwise patterns and pull it out to a non-template function. This is to reduce the amount of duplicated code. This reduces the compilation time of this source file from 28s to 22s on my machine. Also clean up the moved pattern. Issue: https://github.com/openxla/iree/issues/12678 --- .../StableHLO/LegalizeToLinalgUtils.h | 86 ----------- .../StableHLO/StableHLOToLinalgPointwise.cpp | 146 ++++++++++++++---- 2 files changed, 115 insertions(+), 117 deletions(-) diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/LegalizeToLinalgUtils.h b/compiler/src/iree/compiler/InputConversion/StableHLO/LegalizeToLinalgUtils.h index 0050085665d8..f33b18098bc3 100644 --- a/compiler/src/iree/compiler/InputConversion/StableHLO/LegalizeToLinalgUtils.h +++ b/compiler/src/iree/compiler/InputConversion/StableHLO/LegalizeToLinalgUtils.h @@ -100,92 +100,6 @@ bool allOperandsAreScalarTensors(Operation* op); /// Returns true if parent op is linalg. bool isInBodyOfLinalgOps(Operation* op); -/// Converts a HLO operation to a linalg.generic op that contains the -/// corresponding scalar operations. -template -class PointwiseToLinalgConverter : public OpConversionPattern { - public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite( - OpTy op, typename OpTy::Adaptor adaptor, - ConversionPatternRewriter& rewriter) const final { - auto loc = op.getLoc(); - // Find maximum rank / number of loops. - auto getRank = [](Value v) { - return v.getType().cast().getRank(); - }; - auto isScalar = [&](Value v) { return getRank(v) == 0; }; - auto it = llvm::find_if_not(adaptor.getOperands(), isScalar); - Value maxRankArg = - it != adaptor.getOperands().end() ? *it : adaptor.getOperands().front(); - int64_t nloops = getRank(maxRankArg); - - // Apply only if all operands are scalar or have the same rank. Some ops, - // like `mhlo.select`, support implicit broadcasting of scalars. - if (!llvm::all_of(adaptor.getOperands(), [&](Value v) { - int64_t r = getRank(v); - return r == 0 || r == nloops; - })) { - return rewriter.notifyMatchFailure( - op, "Operands must be os same rank or scalar."); - } - - // Find result type, if on tensors. - std::optional resultTy; - resultTy = this->typeConverter->convertType(op->getResultTypes().front()) - .template dyn_cast(); - - // Check result type compatibility. - if (!resultTy || !resultTy->hasRank() || resultTy->getRank() != nloops || - !(resultTy->getElementType().isSignlessIntOrFloat() || - resultTy->getElementType().isa())) { - return rewriter.notifyMatchFailure( - op, "mismatched operand/result types or iterator count"); - } - - if (allOperandsAreScalarTensors(op) && isInBodyOfLinalgOps(op)) - return failure(); - - // Find input/output values and types. - ValueRange inputs = adaptor.getOperands(); - Value output = - getEmptyTensorFor(rewriter, loc, *resultTy, op, adaptor.getOperands()); - - // Create indexing maps. - AffineMap scalarMap = AffineMap::get(nloops, 0, rewriter.getContext()); - AffineMap idMap = rewriter.getMultiDimIdentityMap(nloops); - SmallVector maps; - for (Value v : inputs) maps.push_back(isScalar(v) ? scalarMap : idMap); - maps.push_back(idMap); - - // Build `linalg.generic` op. - bool failed = false; - auto linalgOp = rewriter.create( - loc, resultTy ? *resultTy : TypeRange{}, inputs, output, maps, - getNParallelLoopsAttrs(nloops), - [&](OpBuilder& nestedBuilder, Location /*nested_loc*/, - ValueRange args) { - Type innerResultTy = getElementTypeOrSelf(output); - auto argvec = llvm::to_vector<2>(args.take_front(inputs.size())); - auto semiring = preSparsify(op, argvec, innerResultTy, &rewriter); - Value innerResult = mlir::stablehlo::StableHloOpToStdScalarOp::mapOp( - op, innerResultTy, argvec, &rewriter); - if (innerResult == nullptr) { - failed = true; - } else { - innerResult = postSparsify(op, semiring, innerResult, &rewriter); - nestedBuilder.create(loc, innerResult); - } - }, - linalg::getPrunedAttributeList(op)); - if (failed) return failure(); - - rewriter.replaceOp(op, linalgOp->getResults()); - return success(); - } -}; - } // namespace mlir::iree_compiler::stablehlo #endif // IREE_COMPILER_INPUTCONVERSION_STABLEHLO_LEGALIZE_TO_LINALG_UTILS_H_ diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalgPointwise.cpp b/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalgPointwise.cpp index 2c57ef64b47b..378d987b5b7b 100644 --- a/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalgPointwise.cpp +++ b/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalgPointwise.cpp @@ -12,6 +12,7 @@ #include "iree/compiler/InputConversion/StableHLO/MapStableHLOToScalarOp.h" #include "iree/compiler/InputConversion/StableHLO/Rewriters.h" #include "iree/compiler/InputConversion/StableHLO/TypeConversion.h" +#include "mlir/IR/ValueRange.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/DialectConversion.h" #include "stablehlo/dialect/StablehloOps.h" @@ -22,6 +23,16 @@ namespace stablehlo = mlir::stablehlo; int64_t getRank(Value v) { return cast(v.getType()).getRank(); } +int64_t getMaxRank(ValueRange operands) { + int64_t maxRank = 0; + for (Value operand : operands) { + maxRank = std::max(maxRank, getRank(operand)); + } + return maxRank; +} + +bool isScalar(Value v) { return getRank(v) == 0; } + /// Inserts block arguments in places where scalar inputs have a nullptr. SmallVector interleaveScalarAndBlockArgs(ValueRange scalarInputs, ValueRange blockArgs) { @@ -38,6 +49,47 @@ SmallVector interleaveScalarAndBlockArgs(ValueRange scalarInputs, return result; } +struct PointwiseConversionInfo { + int64_t maxOperandRank = 0; + ShapedType resultType; +}; + +/// Checks the preconditions for conversion of pointwise HLO ops to linalg. +/// Returns the max operand rank and the result type on success. +FailureOr checkOperandsAndResults( + Operation* op, ValueRange operands, TypeConverter& typeConverter, + ConversionPatternRewriter& rewriter) { + int64_t maxRank = getMaxRank(operands); + + // Apply only if all operands are scalar or have the same rank. Some ops, + // like `mhlo.select`, support implicit broadcasting of scalars. + if (!llvm::all_of(operands, [&](Value v) { + int64_t r = getRank(v); + return r == 0 || r == maxRank; + })) { + return rewriter.notifyMatchFailure( + op, "Operands must be of same rank or scalar."); + } + + // Find result type, if on tensors. + auto resultTy = dyn_cast_or_null( + typeConverter.convertType(op->getResultTypes().front())); + + // Check result type compatibility. + if (!resultTy || !resultTy.hasRank() || resultTy.getRank() != maxRank || + !(resultTy.getElementType().isSignlessIntOrFloat() || + isa(resultTy.getElementType()))) { + return rewriter.notifyMatchFailure( + op, "mismatched operand/result types or iterator count"); + } + + // All-scalar pointwise ops inside of linalg ops are processes by + // ScalarHloToArithmeticPattern. + if (maxRank == 0 && isInBodyOfLinalgOps(op)) return failure(); + + return PointwiseConversionInfo{maxRank, resultTy}; +} + /// Converts a HLO operation to a linalg.map op that contains the corresponding /// scalar operations. template @@ -48,34 +100,15 @@ struct PointwiseToLinalgMapConverter final : OpConversionPattern { LogicalResult matchAndRewrite( OpTy op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { - auto loc = op.getLoc(); - int64_t maxRank = getMaxRank(adaptor); - - // Apply only if all operands are scalar or have the same rank. Some ops, - // like `mhlo.select`, support implicit broadcasting of scalars. - if (!llvm::all_of(adaptor.getOperands(), [&](Value v) { - int64_t r = getRank(v); - return r == 0 || r == maxRank; - })) { - return rewriter.notifyMatchFailure( - op, "Operands must be of same rank or scalar."); - } - - // Find result type, if on tensors. - auto resultTy = dyn_cast_or_null( - this->typeConverter->convertType(op->getResultTypes().front())); - - // Check result type compatibility. - if (!resultTy || !resultTy.hasRank() || resultTy.getRank() != maxRank || - !(resultTy.getElementType().isSignlessIntOrFloat() || - isa(resultTy.getElementType()))) { - return rewriter.notifyMatchFailure( - op, "mismatched operand/result types or iterator count"); + auto conversionInfo = checkOperandsAndResults( + op, adaptor.getOperands(), *this->typeConverter, rewriter); + if (failed(conversionInfo)) { + return failure(); } - // All-scalar pointwise ops inside of linalg ops are processes by - // ScalarHloToArithmeticPattern. - if (maxRank == 0 && isInBodyOfLinalgOps(op)) return failure(); + int64_t maxRank = conversionInfo->maxOperandRank; + ShapedType resultTy = conversionInfo->resultType; + Location loc = op.getLoc(); // Find input/output values and types. Value emptyTensor = @@ -110,13 +143,64 @@ struct PointwiseToLinalgMapConverter final : OpConversionPattern { rewriter.replaceOp(op, mapOp->getResults()); return success(); } +}; + +/// Converts a HLO operation to a linalg.generic op that contains the +/// corresponding scalar operations. +template +struct PointwiseToLinalgConverter final : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename OpTy::Adaptor; - static int64_t getMaxRank(OpAdaptor adaptor) { - int64_t maxRank = 0; - for (Value operand : adaptor.getOperands()) { - maxRank = std::max(maxRank, getRank(operand)); + LogicalResult matchAndRewrite( + OpTy op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + auto conversionInfo = checkOperandsAndResults( + op, adaptor.getOperands(), *this->typeConverter, rewriter); + if (failed(conversionInfo)) { + return failure(); } - return maxRank; + + int64_t maxRank = conversionInfo->maxOperandRank; + ShapedType resultTy = conversionInfo->resultType; + Location loc = op.getLoc(); + + // Find input/output values and types. + ValueRange inputs = adaptor.getOperands(); + Value output = + getEmptyTensorFor(rewriter, loc, resultTy, op, adaptor.getOperands()); + + // Create indexing maps. + AffineMap scalarMap = AffineMap::get(maxRank, 0, rewriter.getContext()); + AffineMap idMap = rewriter.getMultiDimIdentityMap(maxRank); + SmallVector maps; + for (Value v : inputs) maps.push_back(isScalar(v) ? scalarMap : idMap); + maps.push_back(idMap); + + // Build `linalg.generic` op. + bool failed = false; + auto linalgOp = rewriter.create( + loc, resultTy ? resultTy : TypeRange{}, inputs, output, maps, + getNParallelLoopsAttrs(maxRank), + [&](OpBuilder& nestedBuilder, Location /*nested_loc*/, + ValueRange args) { + Type innerResultTy = getElementTypeOrSelf(output); + auto argvec = llvm::to_vector<2>(args.take_front(inputs.size())); + Value semiring = preSparsify(op, argvec, innerResultTy, &rewriter); + Value innerResult = mlir::stablehlo::StableHloOpToStdScalarOp::mapOp( + op, innerResultTy, argvec, &rewriter); + if (!innerResult) { + failed = true; + } else { + innerResult = postSparsify(op, semiring, innerResult, &rewriter); + nestedBuilder.create(loc, innerResult); + } + }, + linalg::getPrunedAttributeList(op)); + if (failed) return failure(); + + rewriter.replaceOp(op, linalgOp->getResults()); + return success(); } }; } // namespace