diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index 1c0a6a618fcd2..c5133dfa9609e 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -4065,16 +4065,26 @@ LogicalResult WhileOp::verify() { .failed()) return failure(); - auto bodyYield = cast(getBodyGraph().front().getTerminator()); - if (errorIfTypeOrShapeMismatch(*this, bodyYield.getInputs(), - "'body_graph' results", getInputList(), - "'input_list'") - .failed()) - return failure(); + if (getBodyGraph().front().mightHaveTerminator()) { + auto bodyYield = + dyn_cast(getBodyGraph().front().getTerminator()); + if (bodyYield && errorIfTypeOrShapeMismatch(*this, bodyYield.getInputs(), + "'body_graph' results", + getInputList(), "'input_list'") + .failed()) + return failure(); + } // Condition block output must be a single element tensor with a single bool // value. - auto condYield = cast(getCondGraph().front().getTerminator()); + if (!getCondGraph().front().mightHaveTerminator()) + return success(); + + auto condYield = + dyn_cast(getCondGraph().front().getTerminator()); + if (!condYield) + return success(); + if (condYield.getInputs().size() != 1) return emitOpError() << "require 'cond_graph' only have one result"; diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir index f58ddb180ce4f..2e18fe46e21c6 100644 --- a/mlir/test/Dialect/Tosa/verifier.mlir +++ b/mlir/test/Dialect/Tosa/verifier.mlir @@ -649,6 +649,48 @@ func.func @test_cond_if_incorrect_type_simple(%arg0: tensor, %arg1: tensor< return %0 : tensor } +// ----- +func.func @test_while_loop_wrong_terminator(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = tosa.while_loop (%arg2 = %arg0) : (tensor) -> tensor { + // expected-error@+2 {{'func.return' op expects parent op 'func.func'}} + %1 = tosa.greater_equal %arg1, %arg2 : (tensor, tensor) -> tensor + "func.return"(%arg2) : (tensor) -> () + } do { + ^bb0(%arg2: tensor): + %1 = "tosa.const"() <{values = dense<1> : tensor}> : () -> tensor + %2 = tosa.add %arg2, %1 : (tensor, tensor) -> tensor + tosa.yield %2 : tensor + } + return %0 : tensor +} + +// ----- +func.func @test_while_loop_missing_cond_terminator(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = tosa.while_loop (%arg2 = %arg0) : (tensor) -> tensor { + // expected-error@+1 {{block with no terminator}} + %1 = tosa.greater_equal %arg1, %arg2 : (tensor, tensor) -> tensor + } do { + ^bb0(%arg2: tensor): + %1 = "tosa.const"() <{values = dense<1> : tensor}> : () -> tensor + %2 = tosa.add %arg2, %1 : (tensor, tensor) -> tensor + tosa.yield %2 : tensor + } + return %0 : tensor +} + +// ----- +func.func @test_while_loop_missing_body_terminator(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = tosa.while_loop (%arg2 = %arg0) : (tensor) -> tensor { + %1 = tosa.greater_equal %arg1, %arg2 : (tensor, tensor) -> tensor + tosa.yield %1 : tensor + } do { + ^bb0(%arg2: tensor): + // expected-error@+1 {{block with no terminator}} + %1 = "tosa.const"() <{values = dense<1> : tensor}> : () -> tensor + } + return %0 : tensor +} + // ----- func.func @test_while_loop_input_list_mismatch_body_block_in(%arg0: tensor<10xi32>, %arg1: tensor) {