-
Notifications
You must be signed in to change notification settings - Fork 15.4k
[mlir] Fold memref.cast static-to-dynamic to memref.expand_shape #170037
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
|
Thank you for submitting a Pull Request (PR) to the LLVM Project! This PR will be automatically labeled and the relevant teams will be notified. If you wish to, you can add reviewers by using the "Reviewers" section on this page. If this is not working for you, it is probably because you do not have write permissions for the repository. In which case you can instead tag reviewers by name in a comment by using If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers. If you have further questions, they may be answered by the LLVM GitHub User Guide. You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums. |
|
@llvm/pr-subscribers-mlir-memref @llvm/pr-subscribers-mlir Author: None (kdmitry1) Changesmemref.expand_shape didn't have memref.cast op folder. Added canonicalization pattern to allow folding of memref.cast from static to dynamic. Example: %0 = memref.cast %arg0 : memref<8x4xf32> to memref<?x4xf32>
%c0 = arith.constant 0 : index
%dim0 = memref.dim %0, %c0 : memref<?x4xf32>
%1 = memref.expand_shape %0 [[0, 1], [2]] output_shape [%dim0, 1, 4] : memref<?x4xf32> into memref<?x1x4xf32>is converted to: %expand_shape = memref.expand_shape %arg0 [[0, 1], [2]] output_shape [8, 1, 4] : memref<8x4xf32> into memref<8x1x4xf32>
%cast = memref.cast %expand_shape : memref<8x1x4xf32> to memref<?x1x4xf32>
Full diff: https://github.com/llvm/llvm-project/pull/170037.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 1035d7cb46e6e..49dc23b702875 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2504,11 +2504,89 @@ LogicalResult ExpandShapeOp::verify() {
return success();
}
+struct ExpandShapeOpMemRefCastFolder : public OpRewritePattern<ExpandShapeOp> {
+public:
+ using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ExpandShapeOp op,
+ PatternRewriter &rewriter) const override {
+ auto cast = op.getSrc().getDefiningOp<CastOp>();
+ if (!cast)
+ return failure();
+
+ if (!CastOp::canFoldIntoConsumerOp(cast))
+ return failure();
+
+ auto originalOutputShape = op.getMixedOutputShape();
+ auto newOutputShape = originalOutputShape;
+ SmallVector<int64_t> newOutputShapeSizes;
+ SmallVector<Value> newOperands;
+
+ // Convert output shape dims from dynamic to static where possible.
+ for (auto [dimIdx, dimSize] : enumerate(originalOutputShape)) {
+ auto dimVal = dimSize.dyn_cast<Value>();
+ if (!dimVal) {
+ newOutputShapeSizes.push_back(getConstantIntValue(dimSize).value());
+ continue;
+ }
+
+ auto constOp = dimVal.getDefiningOp<arith::ConstantIndexOp>();
+ if (!constOp) {
+ newOperands.push_back(dimVal);
+ newOutputShapeSizes.push_back(ShapedType::kDynamic);
+ continue;
+ }
+
+ newOutputShape[dimIdx] = constOp.getValue();
+ newOutputShapeSizes.push_back(
+ getConstantIntValue(constOp.getValue()).value());
+ }
+
+ if (newOperands.size() == op->getNumOperands())
+ return rewriter.notifyMatchFailure(
+ op, "no static-to-dynamic conversions found");
+
+ auto castSource = cast.getSource();
+ auto castSourceType = llvm::cast<MemRefType>(castSource.getType());
+ auto reassociationIndices = op.getReassociationIndices();
+ for (auto [idx, group] : llvm::enumerate(reassociationIndices)) {
+ int64_t castSourceDynCount = castSourceType.isDynamicDim(idx) ? 1 : 0;
+ auto newOutputShapeSizesSlice =
+ ArrayRef(newOutputShapeSizes).slice(group.front(), group.size());
+ int64_t newOutputDynCount =
+ llvm::count_if(newOutputShapeSizesSlice, ShapedType::isDynamic);
+ if (castSourceDynCount != newOutputDynCount)
+ return rewriter.notifyMatchFailure(
+ op, "folding cast will result in changing dynamicity in "
+ "reassociation group");
+ }
+
+ auto newResultTypeOrFailure = ExpandShapeOp::computeExpandedType(
+ castSourceType, newOutputShapeSizes, reassociationIndices);
+
+ if (failed(newResultTypeOrFailure))
+ return rewriter.notifyMatchFailure(
+ op, "could not compute new expanded type after folding cast");
+
+ if (*newResultTypeOrFailure == op.getResultType()) {
+ rewriter.modifyOpInPlace(
+ op, [&]() { op.getSrcMutable().assign(castSource); });
+ } else {
+ Value newOp = ExpandShapeOp::create(rewriter, op->getLoc(),
+ *newResultTypeOrFailure, castSource,
+ reassociationIndices, newOutputShape);
+ rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
+ }
+ return success();
+ }
+};
+
void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<
ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
- ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>>(context);
+ ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
+ ExpandShapeOpMemRefCastFolder>(context);
}
FailureOr<std::optional<SmallVector<Value>>>
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index e02717a2f5689..c2d0376fc9723 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -551,6 +551,90 @@ func.func @fold_memref_expand_cast(%arg0 : memref<?x?xf32>) -> memref<2x4x4xf32>
// -----
+// CHECK-LABEL: @fold_memref_expand_with_static_to_dynamic_cast
+// CHECK-NOT: memref.cast
+// CHECK: return
+func.func @fold_memref_expand_with_static_to_dynamic_cast(%arg0 : memref<8x4xf32>) -> memref<2x1x4x4xf32> {
+ %0 = memref.cast %arg0 : memref<8x4xf32> to memref<?x4xf32>
+ %c0 = arith.constant 0 : index
+ %dim0 = memref.dim %0, %c0 : memref<?x4xf32>
+ %c4 = arith.constant 4 : index
+ %dim_ext = arith.divui %dim0 , %c4: index
+ %1 = memref.expand_shape %0 [[0, 1, 2], [3]] output_shape [%dim_ext, 1, 4, 4]
+ : memref<?x4xf32> into memref<?x1x4x4xf32>
+ %2 = memref.cast %1 : memref<?x1x4x4xf32> to memref<2x1x4x4xf32>
+ return %2 : memref<2x1x4x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @fold_memref_expand_static_to_dynamic_partial(
+// CHECK-NOT: memref.cast
+// CHECK: return
+func.func @fold_memref_expand_static_to_dynamic_partial(%arg0 : memref<8x?xf32>) -> memref<1x8x1x?xf32> {
+ %0 = memref.cast %arg0 : memref<8x?xf32> to memref<?x?xf32>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %dim0 = memref.dim %0, %c0 : memref<?x?xf32>
+ %dim1 = memref.dim %0, %c1 : memref<?x?xf32>
+ %1 = memref.expand_shape %0 [[0, 1], [2, 3]] output_shape [1, %dim0, 1, %dim1]
+ : memref<?x?xf32> into memref<1x?x1x?xf32>
+ %2 = memref.cast %1 : memref<1x?x1x?xf32> to memref<1x8x1x?xf32>
+ return %2 : memref<1x8x1x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @fold_memref_expand_static_to_dynamic_partial1(
+// CHECK-NOT: memref.cast
+// CHECK: return
+func.func @fold_memref_expand_static_to_dynamic_partial1(%arg0 : memref<8x?xf32>) -> memref<1x8x1x?xf32> {
+ %0 = memref.cast %arg0 : memref<8x?xf32> to memref<?x?xf32>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %dim0 = memref.dim %0, %c0 : memref<?x?xf32>
+ %dim1 = memref.dim %0, %c1 : memref<?x?xf32>
+ %1 = memref.expand_shape %0 [[0, 1], [2, 3]] output_shape [%c1, %dim0, %c1, %dim1]
+ : memref<?x?xf32> into memref<?x?x?x?xf32>
+ %2 = memref.cast %1 : memref<?x?x?x?xf32> to memref<1x8x1x?xf32>
+ return %2 : memref<1x8x1x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @not_fold_memref_expand_with_dynamic_to_static_cast(
+// CHECK: memref.cast
+// CHECK: memref.expand_shape
+// CHECK: return
+// CHECK: }
+func.func @not_fold_memref_expand_with_dynamic_to_static_cast(%arg0 : memref<?x4xf32>) -> memref<2x1x4x4xf32> {
+ %0 = memref.cast %arg0 : memref<?x4xf32> to memref<8x4xf32>
+ %1 = memref.expand_shape %0 [[0, 1, 2], [3]] output_shape [2, 1, 4, 4]
+ : memref<8x4xf32> into memref<2x1x4x4xf32>
+ return %1 : memref<2x1x4x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @not_fold_memref_expand_static_to_dynamic_cast_if_really_dynamic(
+// CHECK: memref.cast
+// CHECK: memref.expand_shape
+// CHECK: memref.cast
+// CHECK: return
+// CHECK: }
+func.func @not_fold_memref_expand_static_to_dynamic_cast_if_really_dynamic(%arg0 : memref<8x4xf32>, %arg1 : index) -> memref<2x1x4x4xf32> {
+ %0 = memref.cast %arg0 : memref<8x4xf32> to memref<?x4xf32>
+ %c0 = arith.constant 0 : index
+ %dim0 = memref.dim %0, %c0 : memref<?x4xf32>
+ %dim_ext = arith.divui %dim0 , %arg1: index
+ %1 = memref.expand_shape %0 [[0, 1, 2], [3]] output_shape [%dim_ext, 1, 4, 4]
+ : memref<?x4xf32> into memref<?x1x4x4xf32>
+ %2 = memref.cast %1 : memref<?x1x4x4xf32> to memref<2x1x4x4xf32>
+ return %2 : memref<2x1x4x4xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @collapse_after_memref_cast_type_change(
// CHECK-SAME: %[[INPUT:.*]]: memref<?x512x1x1xf32>) -> memref<?x?xf32> {
// CHECK: %[[COLLAPSED:.*]] = memref.collapse_shape %[[INPUT]]
|
| if (!CastOp::canFoldIntoConsumerOp(cast)) | ||
| return failure(); | ||
|
|
||
| auto originalOutputShape = op.getMixedOutputShape(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you spell out the type here and for newOutputShape?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
|
|
||
| // Convert output shape dims from dynamic to static where possible. | ||
| for (auto [dimIdx, dimSize] : enumerate(originalOutputShape)) { | ||
| auto dimVal = dimSize.dyn_cast<Value>(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This notation is deprecated. Use dyn_cast<Value>(dimSize).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed
| for (auto [dimIdx, dimSize] : enumerate(originalOutputShape)) { | ||
| auto dimVal = dimSize.dyn_cast<Value>(); | ||
| if (!dimVal) { | ||
| newOutputShapeSizes.push_back(getConstantIntValue(dimSize).value()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
getConstantIntValue works on OpFoldResult. There's no need to check whether it's a Value or Attribute.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Modified the whole loop to follow your proposal
| continue; | ||
| } | ||
|
|
||
| auto constOp = dimVal.getDefiningOp<arith::ConstantIndexOp>(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No need to hard-code ConstantIndexOp here. getConstantIntValue will extract the int64_t. If it fails, you can assume it's a dynamic value.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for your review
| // ----- | ||
|
|
||
| // CHECK-LABEL: @fold_memref_expand_with_static_to_dynamic_cast | ||
| // CHECK-NOT: memref.cast |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you update all the new test cases to include the types of the memref.expand_shape? E.g., // CHECK: memref.expand_shape {{.*}} : memref<...> to memref<...>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
qedawkins
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This LGTM % a few nits, make sure you get an approval from Matthias as well before landing.
| // CHECK-NOT: memref.cast | ||
| // CHECK: memref.expand_shape {{.*}} {{\[\[}}0, 1], [2, 3]] output_shape [1, 8, 1, %{{.*}}] : memref<8x?xf32> into memref<1x8x1x?xf32> | ||
| // CHECK-NOT: memref.cast | ||
| // CHECK: return |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: Here and below, instead of using CHECK-NOT lines to make sure the cast propagated/folded the way you expected, you can match the labels like
// CHECK: %[[EXPAND:.+]] = memref.expand_shape ...
// CHECK: return %[[EXPAND]]
which is both a strong check and doesn't need CHECK-NOT's.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
| // CHECK-LABEL: func.func @not_fold_memref_expand_static_to_dynamic_cast_if_really_dynamic( | ||
| // CHECK-SAME: %[[ARG0:.*]]: memref<8x4xf32>, | ||
| // CHECK-SAME: %[[ARG1:.*]]: index) -> memref<2x1x4x4xf32> { | ||
| // CHECK: %[[CONSTANT_0:.*]] = arith.constant 8 : index |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: This is a more standard label name.
| // CHECK: %[[CONSTANT_0:.*]] = arith.constant 8 : index | |
| // CHECK: %[[C8:.*]] = arith.constant 8 : index |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
| // CHECK-SAME: %[[ARG1:.*]]: index, %[[ARG2:.*]]: index) -> memref<8x1x?x?xf32> { | ||
| // CHECK-NOT: memref.cast | ||
| // CHECK: %[[EXPAND_SHAPE_0:.*]] = memref.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2, 3]] output_shape [8, 1, %[[ARG1]], %[[ARG2]]] : memref<8x?xf32> into memref<8x1x?x?xf32> | ||
| // CHECK-NOT: memref.cast |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the CHECK-NOT's for memref.cast aren't really needed here. We already know the live IR was transformed the way we wanted so we don't really care if there is still a cast sticking around.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done. Got the idea. Thank you.
|
@qedawkins, @matthias-springer, please note I've added some minor changes |
memref.expand_shape didn't have memref.cast op folder. Added canonicalization pattern to allow folding of memref.cast from static to dynamic.
Example:
is converted to: