Skip to content

Conversation

badumbatish
Copy link
Contributor

Fixes #159650.

The current implementation ICE out if we access an IfOp's terminator when it doesn't have it. Instead the PR defers the job of verifying that a block would have at least a terminator.

The current implementation also crashes with cast if the terminator is not a YieldOp, the PR also defers the job of verification to the op itself.

@llvmbot
Copy link
Member

llvmbot commented Sep 19, 2025

@llvm/pr-subscribers-mlir-tosa

@llvm/pr-subscribers-mlir

Author: Jasmine Tang (badumbatish)

Changes

Fixes #159650.

The current implementation ICE out if we access an IfOp's terminator when it doesn't have it. Instead the PR defers the job of verifying that a block would have at least a terminator.

The current implementation also crashes with cast<YieldOp> if the terminator is not a YieldOp, the PR also defers the job of verification to the op itself.


Full diff: https://github.com/llvm/llvm-project/pull/159756.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+20-12)
  • (modified) mlir/test/Dialect/Tosa/verifier.mlir (+25)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index b4c87a34a0e5a..309ef0bfd4ca3 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -4025,19 +4025,27 @@ LogicalResult IfOp::verify() {
           .failed())
     return failure();
 
-  auto thenYield = cast<tosa::YieldOp>(getThenGraph().front().getTerminator());
-  if (errorIfTypeOrShapeMismatch(*this, thenYield.getInputs(),
-                                 "'then_graph' results", getOutputList(),
-                                 "'output_list'")
-          .failed())
-    return failure();
+  // MLIR will verify the absence of the terminator for us if otherwise.
+  if (getThenGraph().front().mightHaveTerminator()) {
+    auto thenYield =
+        dyn_cast<tosa::YieldOp>(getThenGraph().front().getTerminator());
+    if (thenYield && errorIfTypeOrShapeMismatch(
+                         *this, thenYield.getInputs(), "'then_graph' results",
+                         getOutputList(), "'output_list'")
+                         .failed())
+      return failure();
+  }
 
-  auto elseYield = cast<tosa::YieldOp>(getElseGraph().front().getTerminator());
-  if (errorIfTypeOrShapeMismatch(*this, elseYield.getInputs(),
-                                 "'else_graph' results", getOutputList(),
-                                 "'output_list'")
-          .failed())
-    return failure();
+  // MLIR will verify the absence of the terminator for us if otherwise.
+  if (getElseGraph().front().mightHaveTerminator()) {
+    auto elseYield =
+        dyn_cast<tosa::YieldOp>(getElseGraph().front().getTerminator());
+    if (elseYield && errorIfTypeOrShapeMismatch(
+                         *this, elseYield.getInputs(), "'else_graph' results",
+                         getOutputList(), "'output_list'")
+                         .failed())
+      return failure();
+  }
 
   auto condType = getCondition().getType();
   if (errorIfShapeNotSizeOne(*this, condType).failed())
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index f58ddb180ce4f..cee2af2ad0c07 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -438,6 +438,31 @@ func.func @test_pad_invalid_padding_value(%arg0: tensor<10xi8>, %arg1: tensor<1x
   return %1 : tensor<10xi8>
 }
 
+// -----
+func.func @test_cond_if_wrong_terminator_op(%arg0: tensor<i1>) -> tensor<i32> {
+  %0 = "tosa.cond_if"(%arg0) ({
+    %1 = "tosa.const"() <{values = dense<1> : tensor<i32>}> : () -> tensor<i32>
+    "tosa.yield"(%1) : (tensor<i32>) -> ()
+  }, {
+    // expected-error@+2 {{'func.return' op expects parent op 'func.func'}}
+    %2 = "tosa.const"() <{values = dense<2> : tensor<i32>}> : () -> tensor<i32>
+    "func.return"(%2) : (tensor<i32>) -> ()
+  }) : (tensor<i1>) -> tensor<i32>
+  return %0 : tensor<i32>
+}
+
+// -----
+func.func @test_cond_if_missing_terminator(%arg0: tensor<i1>) -> tensor<i32> {
+  %0 = "tosa.cond_if"(%arg0) ({
+    %1 = "tosa.const"() <{values = dense<1> : tensor<i32>}> : () -> tensor<i32>
+    "tosa.yield"(%1) : (tensor<i32>) -> ()
+  }, {
+    // expected-error@+1 {{block with no terminator}}
+    %2 = "tosa.const"() <{values = dense<2> : tensor<i32>}> : () -> tensor<i32>
+  }) : (tensor<i1>) -> tensor<i32>
+  return %0 : tensor<i32>
+}
+
 // -----
 
 func.func @test_cond_if_input_list_mismatch_then_block(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {

Copy link
Collaborator

@joker-eph joker-eph left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LG overall

@lhutton1 lhutton1 changed the title [MLIR] Robustify Tosa_IfOp against null dereference and wrong assertion [mlir][tosa] Robustify Tosa_IfOp against null dereference and wrong assertion Sep 19, 2025
Copy link
Contributor

@lhutton1 lhutton1 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM overall, thanks @badumbatish. I suspect while might need a similar check?

@badumbatish
Copy link
Contributor Author

LGTM overall, thanks @badumbatish. I suspect while might need a similar check?

yep i'll give that operation a touch up too

@lhutton1 lhutton1 merged commit 370ea51 into llvm:main Sep 22, 2025
9 checks passed
lhutton1 pushed a commit that referenced this pull request Sep 22, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[MLIR] TOSA IfOp crashes when region lacks required terminator
4 participants