diff --git a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
index 2111a7c5810294..5e7945d9b04928 100644
--- a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
@@ -24,6 +24,29 @@ namespace mlir {
 
+using ReassociationIndices = SmallVector<int64_t, 2>;
+
+/// Infer the output shape for a {memref|tensor}.expand_shape when it is
+/// possible to do so.
+///
+/// Note: This should *only* be used to implement
+/// `ExpandShapeOp::inferOutputShape` in both the memref and tensor namespaces.
+/// If you need to infer the output shape you should use the static method of
+/// `ExpandShapeOp` instead of calling this.
+///
+/// `inputShape` is the shape of the tensor or memref being expanded as a
+/// sequence of SSA values or constants. `expandedType` is the output shape of
+/// the expand_shape operation. `reassociation` is the reassociation denoting
+/// the output dims each input dim is mapped to.
+///
+/// Returns the output shape in `outputShape` and `staticOutputShape`, following
+/// the conventions for the output_shape and static_output_shape inputs to the
+/// expand_shape ops.
+std::optional<SmallVector<OpFoldResult>>
+inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType,
+                            ArrayRef<ReassociationIndices> reassociation,
+                            ArrayRef<OpFoldResult> inputShape);
+
 /// Matches a ConstantIndexOp.
 detail::op_matcher<arith::ConstantIndexOp> matchConstantIndex();
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index 8a41a0a18b0ab3..e8f6edc3f133e1 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -30,27 +30,6 @@ using ReassociationExprs = SmallVector<AffineExpr, 2>;
 /// Attribute name for the ArrayAttr which encodes reassociation indices.
 constexpr StringRef getReassociationAttrName() { return "reassociation"; }
 
-// Infer the output shape for a {memref|tensor}.expand_shape when it is possible
-// to do so.
-//
-// Note: This should *only* be used to implement
-// `ExpandShapeOp::inferOutputShape` in both the memref and tensor namespaces.
-// If you need to infer the output shape you should use the static method of
-// `ExpandShapeOp` instead of calling this.
-//
-// `inputShape` is the shape of the tensor or memref being expanded as a
-// sequence of SSA values or constants. `expandedType` is the output shape of
-// the expand_shape operation. `reassociation` is the reassociation denoting
-// the output dims each input dim is mapped to.
-//
-// Returns the output shape in `outputShape` and `staticOutputShape`, following
-// the conventions for the output_shape and static_output_shape inputs to the
-// expand_shape ops.
-std::optional<SmallVector<OpFoldResult>>
-inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType,
-                            ArrayRef<ReassociationIndices> reassociation,
-                            ArrayRef<OpFoldResult> inputShape);
-
 /// Compose reassociation maps that are used in pair of reshape ops where one
 /// is a producer and other is the consumer. Only valid to use this method when
 /// both the producer and consumer are collapsing dimensions or both are
diff --git a/mlir/lib/Dialect/Arith/Utils/CMakeLists.txt b/mlir/lib/Dialect/Arith/Utils/CMakeLists.txt
index 2be2724d4a9172..07fa58b209b5e3 100644
--- a/mlir/lib/Dialect/Arith/Utils/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arith/Utils/CMakeLists.txt
@@ -8,5 +8,6 @@ add_mlir_dialect_library(MLIRArithUtils
   MLIRArithDialect
   MLIRComplexDialect
   MLIRDialect
+  MLIRDialectUtils
   MLIRIR
   )
diff --git a/mlir/lib/Dialect/Arith/Utils/Utils.cpp b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
index aa239f5e053969..4ce55a23820cf7 100644
--- a/mlir/lib/Dialect/Arith/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
@@ -13,12 +13,74 @@
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "llvm/ADT/SmallBitVector.h"
 #include <numeric>
 
 using namespace mlir;
 
+std::optional<SmallVector<OpFoldResult>>
+mlir::inferExpandShapeOutputShape(OpBuilder &b, Location loc,
+                                  ShapedType expandedType,
+                                  ArrayRef<ReassociationIndices> reassociation,
+                                  ArrayRef<OpFoldResult> inputShape) {
+
+  SmallVector<Value> outputShapeValues;
+  SmallVector<int64_t> outputShapeInts;
+  // For zero-rank inputs, all dims in result shape are unit extent.
+  if (inputShape.empty()) {
+    outputShapeInts.resize(expandedType.getRank(), 1);
+    return getMixedValues(outputShapeInts, outputShapeValues, b);
+  }
+
+  // Check for all static shapes.
+  if (expandedType.hasStaticShape()) {
+    ArrayRef<int64_t> staticShape = expandedType.getShape();
+    outputShapeInts.assign(staticShape.begin(), staticShape.end());
+    return getMixedValues(outputShapeInts, outputShapeValues, b);
+  }
+
+  outputShapeInts.resize(expandedType.getRank(), ShapedType::kDynamic);
+  for (const auto &it : llvm::enumerate(reassociation)) {
+    ReassociationIndices indexGroup = it.value();
+
+    int64_t indexGroupStaticSizesProductInt = 1;
+    bool foundDynamicShape = false;
+    for (int64_t index : indexGroup) {
+      int64_t outputDimSize = expandedType.getDimSize(index);
+      // Cannot infer expanded shape with multiple dynamic dims in the
+      // same reassociation group!
+      if (ShapedType::isDynamic(outputDimSize)) {
+        if (foundDynamicShape)
+          return std::nullopt;
+        foundDynamicShape = true;
+      } else {
+        outputShapeInts[index] = outputDimSize;
+        indexGroupStaticSizesProductInt *= outputDimSize;
+      }
+    }
+    if (!foundDynamicShape)
+      continue;
+
+    int64_t inputIndex = it.index();
+    // Call get<Value>() under the assumption that we're not casting
+    // dynamism.
+    Value indexGroupSize = inputShape[inputIndex].get<Value>();
+    Value indexGroupStaticSizesProduct =
+        b.create<arith::ConstantIndexOp>(loc, indexGroupStaticSizesProductInt);
+    Value dynamicDimSize = b.createOrFold<arith::DivUIOp>(
+        loc, indexGroupSize, indexGroupStaticSizesProduct);
+    outputShapeValues.push_back(dynamicDimSize);
+  }
+
+  if ((int64_t)outputShapeValues.size() !=
+      llvm::count(outputShapeInts, ShapedType::kDynamic))
+    return std::nullopt;
+
+  return getMixedValues(outputShapeInts, outputShapeValues, b);
+}
+
 /// Matches a ConstantIndexOp.
 /// TODO: This should probably just be a general matcher that uses matchConstant
 /// and checks the operation for an index type.
diff --git a/mlir/lib/Dialect/Utils/CMakeLists.txt b/mlir/lib/Dialect/Utils/CMakeLists.txt
index a4cd96db1ff434..a0096e5f299d59 100644
--- a/mlir/lib/Dialect/Utils/CMakeLists.txt
+++ b/mlir/lib/Dialect/Utils/CMakeLists.txt
@@ -8,6 +8,5 @@ add_mlir_library(MLIRDialectUtils
   MLIRDialectUtilsIncGen
 
   LINK_LIBS PUBLIC
-  MLIRArithDialect
   MLIRIR
   )
diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index 6161faf7e30e11..e4f387d40ced2e 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -8,7 +8,6 @@
 
 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
 
-#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Builders.h"
 
@@ -17,67 +16,6 @@ using namespace mlir;
 
-std::optional<SmallVector<OpFoldResult>>
-mlir::inferExpandShapeOutputShape(OpBuilder &b, Location loc,
-                                  ShapedType expandedType,
-                                  ArrayRef<ReassociationIndices> reassociation,
-                                  ArrayRef<OpFoldResult> inputShape) {
-
-  SmallVector<Value> outputShapeValues;
-  SmallVector<int64_t> outputShapeInts;
-  // For zero-rank inputs, all dims in result shape are unit extent.
-  if (inputShape.empty()) {
-    outputShapeInts.resize(expandedType.getRank(), 1);
-    return getMixedValues(outputShapeInts, outputShapeValues, b);
-  }
-
-  // Check for all static shapes.
-  if (expandedType.hasStaticShape()) {
-    ArrayRef<int64_t> staticShape = expandedType.getShape();
-    outputShapeInts.assign(staticShape.begin(), staticShape.end());
-    return getMixedValues(outputShapeInts, outputShapeValues, b);
-  }
-
-  outputShapeInts.resize(expandedType.getRank(), ShapedType::kDynamic);
-  for (const auto &it : llvm::enumerate(reassociation)) {
-    ReassociationIndices indexGroup = it.value();
-
-    int64_t indexGroupStaticSizesProductInt = 1;
-    bool foundDynamicShape = false;
-    for (int64_t index : indexGroup) {
-      int64_t outputDimSize = expandedType.getDimSize(index);
-      // Cannot infer expanded shape with multiple dynamic dims in the
-      // same reassociation group!
-      if (ShapedType::isDynamic(outputDimSize)) {
-        if (foundDynamicShape)
-          return std::nullopt;
-        foundDynamicShape = true;
-      } else {
-        outputShapeInts[index] = outputDimSize;
-        indexGroupStaticSizesProductInt *= outputDimSize;
-      }
-    }
-    if (!foundDynamicShape)
-      continue;
-
-    int64_t inputIndex = it.index();
-    // Call get<Value>() under the assumption that we're not casting
-    // dynamism.
-    Value indexGroupSize = inputShape[inputIndex].get<Value>();
-    Value indexGroupStaticSizesProduct =
-        b.create<arith::ConstantIndexOp>(loc, indexGroupStaticSizesProductInt);
-    Value dynamicDimSize = b.createOrFold<arith::DivUIOp>(
-        loc, indexGroupSize, indexGroupStaticSizesProduct);
-    outputShapeValues.push_back(dynamicDimSize);
-  }
-
-  if ((int64_t)outputShapeValues.size() !=
-      llvm::count(outputShapeInts, ShapedType::kDynamic))
-    return std::nullopt;
-
-  return getMixedValues(outputShapeInts, outputShapeValues, b);
-}
-
 std::optional<SmallVector<ReassociationIndices>>
 mlir::getReassociationIndicesForReshape(ShapedType sourceType,
                                         ShapedType targetType) {
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 1202fed564c46e..1e39bf252577bf 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -3875,7 +3875,6 @@ cc_library(
     includes = ["include"],
     deps = [
         ":DialectUtilsIncGen",
-        ":ArithDialect",
         ":IR",
         ":Support",
         "//llvm:Support",
@@ -12635,6 +12634,7 @@ cc_library(
     deps = [
         ":ArithDialect",
        ":ComplexDialect",
+        ":DialectUtils",
         ":IR",
         "//llvm:Support",
     ],
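For context, here is a minimal usage sketch (not part of the patch) showing what the relocated helper computes. The function name `inferExampleShape` and the SSA value `dim0` are hypothetical; per the doc comment, real code should reach this through `ExpandShapeOp::inferOutputShape` rather than calling the utility directly. The example expands `tensor<?x32xf32>` into `tensor<?x4x32xf32>` with reassociation `[[0, 1], [2]]`, so the one dynamic output dimension is derived from the dynamic input extent divided by the static product 4.

```cpp
// Hypothetical sketch: demonstrates the inference performed by
// mlir::inferExpandShapeOutputShape after the move into MLIRArithUtils.
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

// Expand tensor<?x32xf32> -> tensor<?x4x32xf32> with reassociation
// [[0, 1], [2]]: input dim 0 maps to output dims {0, 1}, input dim 1 to {2}.
// `dim0` is the dynamic extent of input dim 0 as an SSA index value.
static std::optional<SmallVector<OpFoldResult>>
inferExampleShape(OpBuilder &b, Location loc, Value dim0) {
  Type expanded =
      RankedTensorType::get({ShapedType::kDynamic, 4, 32}, b.getF32Type());
  auto expandedType = cast<ShapedType>(expanded);
  SmallVector<ReassociationIndices> reassociation = {{0, 1}, {2}};

  // Mixed input shape: dim 0 is the SSA value, dim 1 the constant 32.
  SmallVector<OpFoldResult> inputShape = {dim0, b.getIndexAttr(32)};

  // The single dynamic output dim (group {0, 1}) is materialized as dim0
  // divided by the static product 4 (an arith.divui); output dims 1 and 2
  // stay static. Two dynamic dims in one group would return std::nullopt.
  return inferExpandShapeOutputShape(b, loc, expandedType, reassociation,
                                     inputShape);
}
```

Because the helper now builds `arith` ops from inside MLIRArithUtils instead of MLIRDialectUtils, the dependency edges move with it, which is what the CMake and Bazel changes above reflect.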