-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[mlir][tosa] Robustify Tosa_IfOp against null dereference and wrong assertion #159756
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-tosa @llvm/pr-subscribers-mlir Author: Jasmine Tang (badumbatish) ChangesFixes #159650. The current implementation ICE out if we access an IfOp's terminator when it doesn't have it. Instead the PR defers the job of verifying that a block would have at least a terminator. The current implementation also crashes with cast<YieldOp> if the terminator is not a YieldOp, the PR also defers the job of verification to the op itself. Full diff: https://github.com/llvm/llvm-project/pull/159756.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index b4c87a34a0e5a..309ef0bfd4ca3 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -4025,19 +4025,27 @@ LogicalResult IfOp::verify() {
.failed())
return failure();
- auto thenYield = cast<tosa::YieldOp>(getThenGraph().front().getTerminator());
- if (errorIfTypeOrShapeMismatch(*this, thenYield.getInputs(),
- "'then_graph' results", getOutputList(),
- "'output_list'")
- .failed())
- return failure();
+ // MLIR will verify the absence of the terminator for us if otherwise.
+ if (getThenGraph().front().mightHaveTerminator()) {
+ auto thenYield =
+ dyn_cast<tosa::YieldOp>(getThenGraph().front().getTerminator());
+ if (thenYield && errorIfTypeOrShapeMismatch(
+ *this, thenYield.getInputs(), "'then_graph' results",
+ getOutputList(), "'output_list'")
+ .failed())
+ return failure();
+ }
- auto elseYield = cast<tosa::YieldOp>(getElseGraph().front().getTerminator());
- if (errorIfTypeOrShapeMismatch(*this, elseYield.getInputs(),
- "'else_graph' results", getOutputList(),
- "'output_list'")
- .failed())
- return failure();
+ // MLIR will verify the absence of the terminator for us if otherwise.
+ if (getElseGraph().front().mightHaveTerminator()) {
+ auto elseYield =
+ dyn_cast<tosa::YieldOp>(getElseGraph().front().getTerminator());
+ if (elseYield && errorIfTypeOrShapeMismatch(
+ *this, elseYield.getInputs(), "'else_graph' results",
+ getOutputList(), "'output_list'")
+ .failed())
+ return failure();
+ }
auto condType = getCondition().getType();
if (errorIfShapeNotSizeOne(*this, condType).failed())
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index f58ddb180ce4f..cee2af2ad0c07 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -438,6 +438,31 @@ func.func @test_pad_invalid_padding_value(%arg0: tensor<10xi8>, %arg1: tensor<1x
return %1 : tensor<10xi8>
}
+// -----
+func.func @test_cond_if_wrong_terminator_op(%arg0: tensor<i1>) -> tensor<i32> {
+ %0 = "tosa.cond_if"(%arg0) ({
+ %1 = "tosa.const"() <{values = dense<1> : tensor<i32>}> : () -> tensor<i32>
+ "tosa.yield"(%1) : (tensor<i32>) -> ()
+ }, {
+ // expected-error@+2 {{'func.return' op expects parent op 'func.func'}}
+ %2 = "tosa.const"() <{values = dense<2> : tensor<i32>}> : () -> tensor<i32>
+ "func.return"(%2) : (tensor<i32>) -> ()
+ }) : (tensor<i1>) -> tensor<i32>
+ return %0 : tensor<i32>
+}
+
+// -----
+func.func @test_cond_if_missing_terminator(%arg0: tensor<i1>) -> tensor<i32> {
+ %0 = "tosa.cond_if"(%arg0) ({
+ %1 = "tosa.const"() <{values = dense<1> : tensor<i32>}> : () -> tensor<i32>
+ "tosa.yield"(%1) : (tensor<i32>) -> ()
+ }, {
+ // expected-error@+1 {{block with no terminator}}
+ %2 = "tosa.const"() <{values = dense<2> : tensor<i32>}> : () -> tensor<i32>
+ }) : (tensor<i1>) -> tensor<i32>
+ return %0 : tensor<i32>
+}
+
// -----
func.func @test_cond_if_input_list_mismatch_then_block(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LG overall
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM overall, thanks @badumbatish. I suspect while
might need a similar check?
yep i'll give that operation a touch up too |
Fixes #159650.
The current implementation ICE out if we access an IfOp's terminator when it doesn't have it. Instead the PR defers the job of verifying that a block would have at least a terminator.
The current implementation also crashes with cast if the terminator is not a YieldOp, the PR also defers the job of verification to the op itself.