Skip to content

Commit

Permalink
[MLIR] Remove ArithDialect dependency from Dialect/Utils
Browse files Browse the repository at this point in the history
This commit moves the inferExpandShapeOutputShape utility from the
Dialect/Utils/ReshapeOpsUtils.cpp to Arith/Utils/Utils.cpp in order to
remove specific dialect dependencies from the DialectUtils.

Signed-Off-by: Gaurav Shukla <gaurav.shukla@amd.com>
  • Loading branch information
Shukla-Gaurav committed Apr 26, 2024
1 parent 0000fe8 commit c4d9869
Show file tree
Hide file tree
Showing 7 changed files with 87 additions and 85 deletions.
23 changes: 23 additions & 0 deletions mlir/include/mlir/Dialect/Arith/Utils/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,29 @@

namespace mlir {

using ReassociationIndices = SmallVector<int64_t, 2>;

/// Infer the output shape for a {memref|tensor}.expand_shape when it is
/// possible to do so.
///
/// Note: This should *only* be used to implement
/// `ExpandShapeOp::inferOutputShape` in both the memref and tensor namespaces.
/// If you need to infer the output shape you should use the static method of
/// `ExpandShapeOp` instead of calling this.
///
/// `inputShape` is the shape of the tensor or memref being expanded as a
/// sequence of SSA values or constants. `expandedType` is the output shape of
/// the expand_shape operation. `reassociation` is the reassociation denoting
/// the output dims each input dim is mapped to.
///
/// Returns the output shape in `outputShape` and `staticOutputShape`, following
/// the conventions for the output_shape and static_output_shape inputs to the
/// expand_shape ops.
std::optional<SmallVector<OpFoldResult>>
inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType,
ArrayRef<ReassociationIndices> reassociation,
ArrayRef<OpFoldResult> inputShape);

/// Matches a ConstantIndexOp.
detail::op_matcher<arith::ConstantIndexOp> matchConstantIndex();

Expand Down
21 changes: 0 additions & 21 deletions mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,27 +30,6 @@ using ReassociationExprs = SmallVector<AffineExpr, 2>;
/// Attribute name for the ArrayAttr which encodes reassociation indices.
constexpr StringRef getReassociationAttrName() { return "reassociation"; }

// Infer the output shape for a {memref|tensor}.expand_shape when it is possible
// to do so.
//
// Note: This should *only* be used to implement
// `ExpandShapeOp::inferOutputShape` in both the memref and tensor namespaces.
// If you need to infer the output shape you should use the static method of
// `ExpandShapeOp` instead of calling this.
//
// `inputShape` is the shape of the tensor or memref being expanded as a
// sequence of SSA values or constants. `expandedType` is the output shape of
// the expand_shape operation. `reassociation` is the reassociation denoting
// the output dims each input dim is mapped to.
//
// Returns the output shape in `outputShape` and `staticOutputShape`, following
// the conventions for the output_shape and static_output_shape inputs to the
// expand_shape ops.
std::optional<SmallVector<OpFoldResult>>
inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType,
ArrayRef<ReassociationIndices> reassociation,
ArrayRef<OpFoldResult> inputShape);

/// Compose reassociation maps that are used in pair of reshape ops where one
/// is a producer and other is the consumer. Only valid to use this method when
/// both the producer and consumer are collapsing dimensions or both are
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Arith/Utils/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,6 @@ add_mlir_dialect_library(MLIRArithUtils
MLIRArithDialect
MLIRComplexDialect
MLIRDialect
MLIRDialectUtils
MLIRIR
)
62 changes: 62 additions & 0 deletions mlir/lib/Dialect/Arith/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,74 @@
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "llvm/ADT/SmallBitVector.h"
#include <numeric>

using namespace mlir;

std::optional<SmallVector<OpFoldResult>>
mlir::inferExpandShapeOutputShape(OpBuilder &b, Location loc,
ShapedType expandedType,
ArrayRef<ReassociationIndices> reassociation,
ArrayRef<OpFoldResult> inputShape) {

SmallVector<Value> outputShapeValues;
SmallVector<int64_t> outputShapeInts;
// For zero-rank inputs, all dims in result shape are unit extent.
if (inputShape.empty()) {
outputShapeInts.resize(expandedType.getRank(), 1);
return getMixedValues(outputShapeInts, outputShapeValues, b);
}

// Check for all static shapes.
if (expandedType.hasStaticShape()) {
ArrayRef<int64_t> staticShape = expandedType.getShape();
outputShapeInts.assign(staticShape.begin(), staticShape.end());
return getMixedValues(outputShapeInts, outputShapeValues, b);
}

outputShapeInts.resize(expandedType.getRank(), ShapedType::kDynamic);
for (const auto &it : llvm::enumerate(reassociation)) {
ReassociationIndices indexGroup = it.value();

int64_t indexGroupStaticSizesProductInt = 1;
bool foundDynamicShape = false;
for (int64_t index : indexGroup) {
int64_t outputDimSize = expandedType.getDimSize(index);
// Cannot infer expanded shape with multiple dynamic dims in the
// same reassociation group!
if (ShapedType::isDynamic(outputDimSize)) {
if (foundDynamicShape)
return std::nullopt;
foundDynamicShape = true;
} else {
outputShapeInts[index] = outputDimSize;
indexGroupStaticSizesProductInt *= outputDimSize;
}
}
if (!foundDynamicShape)
continue;

int64_t inputIndex = it.index();
// Call get<Value>() under the assumption that we're not casting
// dynamism.
Value indexGroupSize = inputShape[inputIndex].get<Value>();
Value indexGroupStaticSizesProduct =
b.create<arith::ConstantIndexOp>(loc, indexGroupStaticSizesProductInt);
Value dynamicDimSize = b.createOrFold<arith::DivUIOp>(
loc, indexGroupSize, indexGroupStaticSizesProduct);
outputShapeValues.push_back(dynamicDimSize);
}

if ((int64_t)outputShapeValues.size() !=
llvm::count(outputShapeInts, ShapedType::kDynamic))
return std::nullopt;

return getMixedValues(outputShapeInts, outputShapeValues, b);
}

/// Matches a ConstantIndexOp.
/// TODO: This should probably just be a general matcher that uses matchConstant
/// and checks the operation for an index type.
Expand Down
1 change: 0 additions & 1 deletion mlir/lib/Dialect/Utils/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,5 @@ add_mlir_library(MLIRDialectUtils
MLIRDialectUtilsIncGen

LINK_LIBS PUBLIC
MLIRArithDialect
MLIRIR
)
62 changes: 0 additions & 62 deletions mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"

Expand All @@ -17,67 +16,6 @@

using namespace mlir;

std::optional<SmallVector<OpFoldResult>>
mlir::inferExpandShapeOutputShape(OpBuilder &b, Location loc,
ShapedType expandedType,
ArrayRef<ReassociationIndices> reassociation,
ArrayRef<OpFoldResult> inputShape) {

SmallVector<Value> outputShapeValues;
SmallVector<int64_t> outputShapeInts;
// For zero-rank inputs, all dims in result shape are unit extent.
if (inputShape.empty()) {
outputShapeInts.resize(expandedType.getRank(), 1);
return getMixedValues(outputShapeInts, outputShapeValues, b);
}

// Check for all static shapes.
if (expandedType.hasStaticShape()) {
ArrayRef<int64_t> staticShape = expandedType.getShape();
outputShapeInts.assign(staticShape.begin(), staticShape.end());
return getMixedValues(outputShapeInts, outputShapeValues, b);
}

outputShapeInts.resize(expandedType.getRank(), ShapedType::kDynamic);
for (const auto &it : llvm::enumerate(reassociation)) {
ReassociationIndices indexGroup = it.value();

int64_t indexGroupStaticSizesProductInt = 1;
bool foundDynamicShape = false;
for (int64_t index : indexGroup) {
int64_t outputDimSize = expandedType.getDimSize(index);
// Cannot infer expanded shape with multiple dynamic dims in the
// same reassociation group!
if (ShapedType::isDynamic(outputDimSize)) {
if (foundDynamicShape)
return std::nullopt;
foundDynamicShape = true;
} else {
outputShapeInts[index] = outputDimSize;
indexGroupStaticSizesProductInt *= outputDimSize;
}
}
if (!foundDynamicShape)
continue;

int64_t inputIndex = it.index();
// Call get<Value>() under the assumption that we're not casting
// dynamism.
Value indexGroupSize = inputShape[inputIndex].get<Value>();
Value indexGroupStaticSizesProduct =
b.create<arith::ConstantIndexOp>(loc, indexGroupStaticSizesProductInt);
Value dynamicDimSize = b.createOrFold<arith::DivUIOp>(
loc, indexGroupSize, indexGroupStaticSizesProduct);
outputShapeValues.push_back(dynamicDimSize);
}

if ((int64_t)outputShapeValues.size() !=
llvm::count(outputShapeInts, ShapedType::kDynamic))
return std::nullopt;

return getMixedValues(outputShapeInts, outputShapeValues, b);
}

std::optional<SmallVector<ReassociationIndices>>
mlir::getReassociationIndicesForReshape(ShapedType sourceType,
ShapedType targetType) {
Expand Down
2 changes: 1 addition & 1 deletion utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -3875,7 +3875,6 @@ cc_library(
includes = ["include"],
deps = [
":DialectUtilsIncGen",
":ArithDialect",
":IR",
":Support",
"//llvm:Support",
Expand Down Expand Up @@ -12635,6 +12634,7 @@ cc_library(
deps = [
":ArithDialect",
":ComplexDialect",
":DialectUtils",
":IR",
"//llvm:Support",
],
Expand Down

0 comments on commit c4d9869

Please sign in to comment.