diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/LegalizeToLinalgUtils.h b/compiler/src/iree/compiler/InputConversion/StableHLO/LegalizeToLinalgUtils.h index 0050085665d8..f33b18098bc3 100644 --- a/compiler/src/iree/compiler/InputConversion/StableHLO/LegalizeToLinalgUtils.h +++ b/compiler/src/iree/compiler/InputConversion/StableHLO/LegalizeToLinalgUtils.h @@ -100,92 +100,6 @@ bool allOperandsAreScalarTensors(Operation* op); /// Returns true if parent op is linalg. bool isInBodyOfLinalgOps(Operation* op); -/// Converts a HLO operation to a linalg.generic op that contains the -/// corresponding scalar operations. -template -class PointwiseToLinalgConverter : public OpConversionPattern { - public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite( - OpTy op, typename OpTy::Adaptor adaptor, - ConversionPatternRewriter& rewriter) const final { - auto loc = op.getLoc(); - // Find maximum rank / number of loops. - auto getRank = [](Value v) { - return v.getType().cast().getRank(); - }; - auto isScalar = [&](Value v) { return getRank(v) == 0; }; - auto it = llvm::find_if_not(adaptor.getOperands(), isScalar); - Value maxRankArg = - it != adaptor.getOperands().end() ? *it : adaptor.getOperands().front(); - int64_t nloops = getRank(maxRankArg); - - // Apply only if all operands are scalar or have the same rank. Some ops, - // like `mhlo.select`, support implicit broadcasting of scalars. - if (!llvm::all_of(adaptor.getOperands(), [&](Value v) { - int64_t r = getRank(v); - return r == 0 || r == nloops; - })) { - return rewriter.notifyMatchFailure( - op, "Operands must be os same rank or scalar."); - } - - // Find result type, if on tensors. - std::optional resultTy; - resultTy = this->typeConverter->convertType(op->getResultTypes().front()) - .template dyn_cast(); - - // Check result type compatibility. 
- if (!resultTy || !resultTy->hasRank() || resultTy->getRank() != nloops || - !(resultTy->getElementType().isSignlessIntOrFloat() || - resultTy->getElementType().isa())) { - return rewriter.notifyMatchFailure( - op, "mismatched operand/result types or iterator count"); - } - - if (allOperandsAreScalarTensors(op) && isInBodyOfLinalgOps(op)) - return failure(); - - // Find input/output values and types. - ValueRange inputs = adaptor.getOperands(); - Value output = - getEmptyTensorFor(rewriter, loc, *resultTy, op, adaptor.getOperands()); - - // Create indexing maps. - AffineMap scalarMap = AffineMap::get(nloops, 0, rewriter.getContext()); - AffineMap idMap = rewriter.getMultiDimIdentityMap(nloops); - SmallVector maps; - for (Value v : inputs) maps.push_back(isScalar(v) ? scalarMap : idMap); - maps.push_back(idMap); - - // Build `linalg.generic` op. - bool failed = false; - auto linalgOp = rewriter.create( - loc, resultTy ? *resultTy : TypeRange{}, inputs, output, maps, - getNParallelLoopsAttrs(nloops), - [&](OpBuilder& nestedBuilder, Location /*nested_loc*/, - ValueRange args) { - Type innerResultTy = getElementTypeOrSelf(output); - auto argvec = llvm::to_vector<2>(args.take_front(inputs.size())); - auto semiring = preSparsify(op, argvec, innerResultTy, &rewriter); - Value innerResult = mlir::stablehlo::StableHloOpToStdScalarOp::mapOp( - op, innerResultTy, argvec, &rewriter); - if (innerResult == nullptr) { - failed = true; - } else { - innerResult = postSparsify(op, semiring, innerResult, &rewriter); - nestedBuilder.create(loc, innerResult); - } - }, - linalg::getPrunedAttributeList(op)); - if (failed) return failure(); - - rewriter.replaceOp(op, linalgOp->getResults()); - return success(); - } -}; - } // namespace mlir::iree_compiler::stablehlo #endif // IREE_COMPILER_INPUTCONVERSION_STABLEHLO_LEGALIZE_TO_LINALG_UTILS_H_ diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalgPointwise.cpp 
b/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalgPointwise.cpp index 2c57ef64b47b..378d987b5b7b 100644 --- a/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalgPointwise.cpp +++ b/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalgPointwise.cpp @@ -12,6 +12,7 @@ #include "iree/compiler/InputConversion/StableHLO/MapStableHLOToScalarOp.h" #include "iree/compiler/InputConversion/StableHLO/Rewriters.h" #include "iree/compiler/InputConversion/StableHLO/TypeConversion.h" +#include "mlir/IR/ValueRange.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/DialectConversion.h" #include "stablehlo/dialect/StablehloOps.h" @@ -22,6 +23,16 @@ namespace stablehlo = mlir::stablehlo; int64_t getRank(Value v) { return cast(v.getType()).getRank(); } +int64_t getMaxRank(ValueRange operands) { + int64_t maxRank = 0; + for (Value operand : operands) { + maxRank = std::max(maxRank, getRank(operand)); + } + return maxRank; +} + +bool isScalar(Value v) { return getRank(v) == 0; } + /// Inserts block arguments in places where scalar inputs have a nullptr. SmallVector interleaveScalarAndBlockArgs(ValueRange scalarInputs, ValueRange blockArgs) { @@ -38,6 +49,47 @@ SmallVector interleaveScalarAndBlockArgs(ValueRange scalarInputs, return result; } +struct PointwiseConversionInfo { + int64_t maxOperandRank = 0; + ShapedType resultType; +}; + +/// Checks the preconditions for conversion of pointwise HLO ops to linalg. +/// Returns the max operand rank and the result type on success. +FailureOr checkOperandsAndResults( + Operation* op, ValueRange operands, TypeConverter& typeConverter, + ConversionPatternRewriter& rewriter) { + int64_t maxRank = getMaxRank(operands); + + // Apply only if all operands are scalar or have the same rank. Some ops, + // like `mhlo.select`, support implicit broadcasting of scalars. 
+ if (!llvm::all_of(operands, [&](Value v) { + int64_t r = getRank(v); + return r == 0 || r == maxRank; + })) { + return rewriter.notifyMatchFailure( + op, "Operands must be of same rank or scalar."); + } + + // Find result type, if on tensors. + auto resultTy = dyn_cast_or_null( + typeConverter.convertType(op->getResultTypes().front())); + + // Check result type compatibility. + if (!resultTy || !resultTy.hasRank() || resultTy.getRank() != maxRank || + !(resultTy.getElementType().isSignlessIntOrFloat() || + isa(resultTy.getElementType()))) { + return rewriter.notifyMatchFailure( + op, "mismatched operand/result types or iterator count"); + } + + // All-scalar pointwise ops inside of linalg ops are processed by + // ScalarHloToArithmeticPattern. + if (maxRank == 0 && isInBodyOfLinalgOps(op)) return failure(); + + return PointwiseConversionInfo{maxRank, resultTy}; +} + +/// Converts a HLO operation to a linalg.map op that contains the corresponding /// scalar operations. template @@ -48,34 +100,15 @@ struct PointwiseToLinalgMapConverter final : OpConversionPattern { LogicalResult matchAndRewrite( OpTy op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { - auto loc = op.getLoc(); - int64_t maxRank = getMaxRank(adaptor); - - // Apply only if all operands are scalar or have the same rank. Some ops, - // like `mhlo.select`, support implicit broadcasting of scalars. - if (!llvm::all_of(adaptor.getOperands(), [&](Value v) { - int64_t r = getRank(v); - return r == 0 || r == maxRank; - })) { - return rewriter.notifyMatchFailure( - op, "Operands must be of same rank or scalar."); - } - - // Find result type, if on tensors. - auto resultTy = dyn_cast_or_null( - this->typeConverter->convertType(op->getResultTypes().front())); - - // Check result type compatibility. 
- if (!resultTy || !resultTy.hasRank() || resultTy.getRank() != maxRank || - !(resultTy.getElementType().isSignlessIntOrFloat() || - isa(resultTy.getElementType()))) { - return rewriter.notifyMatchFailure( - op, "mismatched operand/result types or iterator count"); + auto conversionInfo = checkOperandsAndResults( + op, adaptor.getOperands(), *this->typeConverter, rewriter); + if (failed(conversionInfo)) { + return failure(); } - // All-scalar pointwise ops inside of linalg ops are processes by - // ScalarHloToArithmeticPattern. - if (maxRank == 0 && isInBodyOfLinalgOps(op)) return failure(); + int64_t maxRank = conversionInfo->maxOperandRank; + ShapedType resultTy = conversionInfo->resultType; + Location loc = op.getLoc(); // Find input/output values and types. Value emptyTensor = @@ -110,13 +143,64 @@ struct PointwiseToLinalgMapConverter final : OpConversionPattern { rewriter.replaceOp(op, mapOp->getResults()); return success(); } +}; + +/// Converts a HLO operation to a linalg.generic op that contains the +/// corresponding scalar operations. +template +struct PointwiseToLinalgConverter final : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename OpTy::Adaptor; - static int64_t getMaxRank(OpAdaptor adaptor) { - int64_t maxRank = 0; - for (Value operand : adaptor.getOperands()) { - maxRank = std::max(maxRank, getRank(operand)); + LogicalResult matchAndRewrite( + OpTy op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + auto conversionInfo = checkOperandsAndResults( + op, adaptor.getOperands(), *this->typeConverter, rewriter); + if (failed(conversionInfo)) { + return failure(); } - return maxRank; + + int64_t maxRank = conversionInfo->maxOperandRank; + ShapedType resultTy = conversionInfo->resultType; + Location loc = op.getLoc(); + + // Find input/output values and types. 
+ ValueRange inputs = adaptor.getOperands(); + Value output = + getEmptyTensorFor(rewriter, loc, resultTy, op, adaptor.getOperands()); + + // Create indexing maps. + AffineMap scalarMap = AffineMap::get(maxRank, 0, rewriter.getContext()); + AffineMap idMap = rewriter.getMultiDimIdentityMap(maxRank); + SmallVector maps; + for (Value v : inputs) maps.push_back(isScalar(v) ? scalarMap : idMap); + maps.push_back(idMap); + + // Build `linalg.generic` op. + bool failed = false; + auto linalgOp = rewriter.create( + loc, resultTy ? resultTy : TypeRange{}, inputs, output, maps, + getNParallelLoopsAttrs(maxRank), + [&](OpBuilder& nestedBuilder, Location /*nested_loc*/, + ValueRange args) { + Type innerResultTy = getElementTypeOrSelf(output); + auto argvec = llvm::to_vector<2>(args.take_front(inputs.size())); + Value semiring = preSparsify(op, argvec, innerResultTy, &rewriter); + Value innerResult = mlir::stablehlo::StableHloOpToStdScalarOp::mapOp( + op, innerResultTy, argvec, &rewriter); + if (!innerResult) { + failed = true; + } else { + innerResult = postSparsify(op, semiring, innerResult, &rewriter); + nestedBuilder.create(loc, innerResult); + } + }, + linalg::getPrunedAttributeList(op)); + if (failed) return failure(); + + rewriter.replaceOp(op, linalgOp->getResults()); + return success(); } }; } // namespace