Skip to content

Conversation

badumbatish
Copy link
Contributor

Follow up to #159756

@llvmbot
Copy link
Member

llvmbot commented Sep 20, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-tosa

Author: Jasmine Tang (badumbatish)

Changes

Follow up to #159756


Full diff: https://github.com/llvm/llvm-project/pull/159910.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+17-7)
  • (modified) mlir/test/Dialect/Tosa/verifier.mlir (+42)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 1c0a6a618fcd2..c5133dfa9609e 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -4065,16 +4065,26 @@ LogicalResult WhileOp::verify() {
           .failed())
     return failure();
 
-  auto bodyYield = cast<tosa::YieldOp>(getBodyGraph().front().getTerminator());
-  if (errorIfTypeOrShapeMismatch(*this, bodyYield.getInputs(),
-                                 "'body_graph' results", getInputList(),
-                                 "'input_list'")
-          .failed())
-    return failure();
+  if (getBodyGraph().front().mightHaveTerminator()) {
+    auto bodyYield =
+        dyn_cast<tosa::YieldOp>(getBodyGraph().front().getTerminator());
+    if (bodyYield && errorIfTypeOrShapeMismatch(*this, bodyYield.getInputs(),
+                                                "'body_graph' results",
+                                                getInputList(), "'input_list'")
+                         .failed())
+      return failure();
+  }
 
   // Condition block output must be a single element tensor with a single bool
   // value.
-  auto condYield = cast<tosa::YieldOp>(getCondGraph().front().getTerminator());
+  if (!getCondGraph().front().mightHaveTerminator())
+    return success();
+
+  auto condYield =
+      dyn_cast<tosa::YieldOp>(getCondGraph().front().getTerminator());
+  if (!condYield)
+    return success();
+
   if (condYield.getInputs().size() != 1)
     return emitOpError() << "require 'cond_graph' only have one result";
 
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index f58ddb180ce4f..2e18fe46e21c6 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -649,6 +649,48 @@ func.func @test_cond_if_incorrect_type_simple(%arg0: tensor<f32>, %arg1: tensor<
   return %0 : tensor<f32>
 }
 
+// -----
+func.func @test_while_loop_wrong_terminator(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
+    %0 = tosa.while_loop (%arg2 = %arg0) : (tensor<i32>) -> tensor<i32> {
+      // expected-error@+2 {{'func.return' op expects parent op 'func.func'}}
+      %1 = tosa.greater_equal %arg1, %arg2 : (tensor<i32>, tensor<i32>) -> tensor<i1>
+      "func.return"(%arg2) : (tensor<i32>) -> ()
+    } do {
+    ^bb0(%arg2: tensor<i32>):
+      %1 = "tosa.const"() <{values = dense<1> : tensor<i32>}> : () -> tensor<i32>
+      %2 = tosa.add %arg2, %1 : (tensor<i32>, tensor<i32>) -> tensor<i32>
+      tosa.yield %2 : tensor<i32>
+    }
+    return %0 : tensor<i32>
+}
+
+// -----
+func.func @test_while_loop_missing_cond_terminator(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
+    %0 = tosa.while_loop (%arg2 = %arg0) : (tensor<i32>) -> tensor<i32> {
+      // expected-error@+1 {{block with no terminator}}
+      %1 = tosa.greater_equal %arg1, %arg2 : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    } do {
+    ^bb0(%arg2: tensor<i32>):
+      %1 = "tosa.const"() <{values = dense<1> : tensor<i32>}> : () -> tensor<i32>
+      %2 = tosa.add %arg2, %1 : (tensor<i32>, tensor<i32>) -> tensor<i32>
+      tosa.yield %2 : tensor<i32>
+    }
+    return %0 : tensor<i32>
+}
+
+// -----
+func.func @test_while_loop_missing_body_terminator(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
+    %0 = tosa.while_loop (%arg2 = %arg0) : (tensor<i32>) -> tensor<i32> {
+      %1 = tosa.greater_equal %arg1, %arg2 : (tensor<i32>, tensor<i32>) -> tensor<i1>
+      tosa.yield %1 : tensor<i1>
+    } do {
+    ^bb0(%arg2: tensor<i32>):
+      // expected-error@+1 {{block with no terminator}}
+      %1 = "tosa.const"() <{values = dense<1> : tensor<i32>}> : () -> tensor<i32>
+    }
+    return %0 : tensor<i32>
+}
+
 // -----
 
 func.func @test_while_loop_input_list_mismatch_body_block_in(%arg0: tensor<10xi32>, %arg1: tensor<i32>) {

Copy link
Contributor

@lhutton1 lhutton1 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @badumbatish!

@lhutton1 lhutton1 merged commit eede476 into llvm:main Sep 22, 2025
12 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants