[mlir][linalg-transform] dyn_cast DestinationStyleOpInterface and early return #166299
@llvm/pr-subscribers-mlir
Author: Hsiang-Chieh Tsou (hsjts0u)
Changes
Use `dyn_cast` instead of `cast` and return early if the op does not implement the `DestinationStyleOpInterface`. Before this change, the following IR caused a segfault when the transform interpreter ran, where `myop.a` and `myop.b` implement the `TilingInterface` but not the `DestinationStyleOpInterface`. I looked for ops in the upstream dialects that implement the `TilingInterface` but not the `DestinationStyleOpInterface` to add a test, but could not find any.
module {
  func.func @fuse(%arg0: tensor<4x4x4xf32>, %arg1: tensor<4x4x4xf32>) -> tensor<4x4x4xf32> {
    %mul = "myop.a"(%arg0, %arg1) : (tensor<4x4x4xf32>, tensor<4x4x4xf32>) -> tensor<4x4x4xf32>
    %add = "myop.b"(%mul, %mul) : (tensor<4x4x4xf32>, tensor<4x4x4xf32>) -> tensor<4x4x4xf32>
    return %add : tensor<4x4x4xf32>
  }
  transform.sequence failures(propagate) {
  ^bb0(%func: !transform.any_op):
    %mul = transform.structured.match ops{["myop.a"]} in %func : (!transform.any_op) -> !transform.any_op
    %add = transform.structured.match ops{["myop.b"]} in %func : (!transform.any_op) -> !transform.any_op
    %loop, %tiled = transform.structured.tile_using_forall %add tile_sizes [1, 2, 4] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    %mul_fused, %mul_containing = transform.structured.fuse_into_containing_op %mul into %tiled : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
  }
}
Full diff: https://github.com/llvm/llvm-project/pull/166299.diff
1 Files Affected:
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 3a433825fd31a..59629c422a034 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -997,8 +997,11 @@ tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag,
// Iterate over the outputs of the producer and over the loop bbArgs and
// check if any bbArg points to the same value as the producer output. In
// such case, make the producer output point to the bbArg directly.
- for (OpOperand &initOperandPtr :
- cast<DestinationStyleOpInterface>(clone).getDpsInitsMutable()) {
+ auto dpsInterface = dyn_cast<DestinationStyleOpInterface>(clone);
+ if (!dpsInterface)
+ return;
+
+ for (OpOperand &initOperandPtr : dpsInterface.getDpsInitsMutable()) {
Value producerOperand =
clone->getOperand(initOperandPtr.getOperandNumber());
for (BlockArgument containerIterArg :
@llvm/pr-subscribers-mlir-linalg
Author: Hsiang-Chieh Tsou (hsjts0u)
Changes: same description and diff as in the @llvm/pr-subscribers-mlir notification.
cc @pabloantoniom and @ftynse to review or tag relevant reviewers. Thanks!
Bump on this.
MaheshRavishankar
left a comment
Could you check the producerOp instead of the clone? That way we won't even clone if we don't need to.
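A minimal sketch of this suggestion in plain C++, using `dynamic_cast` as a stand-in for LLVM's `dyn_cast`. The names `Op`, `DpsOp`, `cloneCount`, and `cloneIfDestinationStyle` are hypothetical and are not part of the MLIR API; the point is only that testing the producer before cloning avoids creating a clone that would go unused.

```cpp
#include <cassert>
#include <memory>

// Hypothetical stand-in for a generic operation; not the real mlir::Operation.
struct Op {
  virtual ~Op() = default;
  virtual std::unique_ptr<Op> clone() const {
    return std::make_unique<Op>(*this);
  }
};

// Hypothetical stand-in for an op implementing DestinationStyleOpInterface.
struct DpsOp : Op {
  std::unique_ptr<Op> clone() const override {
    return std::make_unique<DpsOp>(*this);
  }
};

// Counts clones so the effect of the early check is observable.
static int cloneCount = 0;

// Checks the producer itself, so no clone is created when the interface is
// missing. Returns nullptr in that case.
std::unique_ptr<Op> cloneIfDestinationStyle(const Op &producer) {
  if (dynamic_cast<const DpsOp *>(&producer) == nullptr)
    return nullptr; // nothing to rewire, so skip the clone entirely
  std::unique_ptr<Op> clone = producer.clone();
  assert(clone && "cloning a valid op should not fail");
  ++cloneCount;
  return clone;
}
```

With this shape, an op that lacks the interface leaves `cloneCount` unchanged, whereas checking the clone would have paid for the copy first.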
ftynse
left a comment
For testing, you can add new ops into the test dialect and give them some trivial implementation of the interface, like do nothing or return failure. We only need to check there's no crash, not a particular behavior.
Force-pushed from 2c30f37 to 93b8734.
✅ With the latest revision this PR passed the C/C++ code formatter.
Force-pushed from 9091c8f to 238d9eb.
Force-pushed from 238d9eb to 72dd81d.
@ftynse Can you help merge this PR? Thank you!
LLVM Buildbot has detected a new failure on builder
Full details are available at: https://lab.llvm.org/buildbot/#/builders/169/builds/16808
…ly return (llvm#166299)

Use `dyn_cast` instead of `cast` and early return if op does not implement the `DestinationStyleOpInterface`. Before the change the following IR would cause a segfault when the transform interpreter is run, where `myop.a` and `myop.b` implement the `TilingInterface` and not the `DestinationStyleOpInterface`. Tried looking for ops in the upstream dialect that implement the `TilingInterface` and not the `DestinationStyleOpInterface` to add a test but could not find any.

```mlir
module {
  func.func @fuse(%arg0: tensor<4x4x4xf32>, %arg1: tensor<4x4x4xf32>) -> tensor<4x4x4xf32> {
    %mul = "myop.a"(%arg0, %arg1) : (tensor<4x4x4xf32>, tensor<4x4x4xf32>) -> tensor<4x4x4xf32>
    %add = "myop.b"(%mul, %mul) : (tensor<4x4x4xf32>, tensor<4x4x4xf32>) -> tensor<4x4x4xf32>
    return %add : tensor<4x4x4xf32>
  }
  transform.sequence failures(propagate) {
  ^bb0(%func: !transform.any_op):
    %mul = transform.structured.match ops{["myop.a"]} in %func : (!transform.any_op) -> !transform.any_op
    %add = transform.structured.match ops{["myop.b"]} in %func : (!transform.any_op) -> !transform.any_op
    %loop, %tiled = transform.structured.tile_using_forall %add tile_sizes [1, 2, 4] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    %mul_fused, %mul_containing = transform.structured.fuse_into_containing_op %mul into %tiled : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
  }
}
```
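To illustrate the difference the commit message relies on, here is a small standalone analogy in standard C++. It does not use the LLVM headers: `Operation`, `DestinationStyleOp`, `TilingOnlyOp`, `castToDps`, and `visitInits` are hypothetical names, and `dynamic_cast` stands in for `dyn_cast`. In LLVM, `cast<>` asserts that the type matches (and must not be used when it might not), while `dyn_cast<>` yields a null value on mismatch so the caller can bail out.

```cpp
#include <cassert>
#include <vector>

// Hypothetical base class standing in for an MLIR operation.
struct Operation {
  virtual ~Operation() = default;
};

// Hypothetical op with destination-style "init" operands.
struct DestinationStyleOp : Operation {
  std::vector<int> inits;
};

// Hypothetical op that can be tiled but has no inits (like `myop.a`).
struct TilingOnlyOp : Operation {};

// Analogue of cast<>: the caller promises the type is right, and a wrong
// promise trips the assertion.
DestinationStyleOp *castToDps(Operation *op) {
  auto *dps = dynamic_cast<DestinationStyleOp *>(op);
  assert(dps && "op does not implement the destination-style interface");
  return dps;
}

// Analogue of the dyn_cast + early-return pattern: ops without the
// interface are skipped instead of being dereferenced.
int visitInits(Operation *op) {
  auto *dps = dynamic_cast<DestinationStyleOp *>(op);
  if (!dps)
    return 0; // early return: there are no inits to visit
  int visited = 0;
  for (int init : dps->inits) {
    (void)init; // a real pass would rewire the operand here
    ++visited;
  }
  return visited;
}
```

Calling `visitInits` on a `TilingOnlyOp` simply returns 0, whereas `castToDps` on the same op would hit the assertion.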
Use `dyn_cast` instead of `cast` and early return if op does not implement the `DestinationStyleOpInterface`. Before the change the following IR would cause a segfault when the transform interpreter is run, where `myop.a` and `myop.b` implement the `TilingInterface` and not the `DestinationStyleOpInterface`. Tried looking for ops in the upstream dialect that implement the `TilingInterface` and not the `DestinationStyleOpInterface` to add a test but could not find any.