diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index b4c87a34a0e5a..309ef0bfd4ca3 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -4025,19 +4025,27 @@ LogicalResult IfOp::verify() { .failed()) return failure(); - auto thenYield = cast(getThenGraph().front().getTerminator()); - if (errorIfTypeOrShapeMismatch(*this, thenYield.getInputs(), - "'then_graph' results", getOutputList(), - "'output_list'") - .failed()) - return failure(); + // MLIR will verify the absence of the terminator for us if otherwise. + if (getThenGraph().front().mightHaveTerminator()) { + auto thenYield = + dyn_cast(getThenGraph().front().getTerminator()); + if (thenYield && errorIfTypeOrShapeMismatch( + *this, thenYield.getInputs(), "'then_graph' results", + getOutputList(), "'output_list'") + .failed()) + return failure(); + } - auto elseYield = cast(getElseGraph().front().getTerminator()); - if (errorIfTypeOrShapeMismatch(*this, elseYield.getInputs(), - "'else_graph' results", getOutputList(), - "'output_list'") - .failed()) - return failure(); + // MLIR will verify the absence of the terminator for us if otherwise. + if (getElseGraph().front().mightHaveTerminator()) { + auto elseYield = + dyn_cast(getElseGraph().front().getTerminator()); + if (elseYield && errorIfTypeOrShapeMismatch( + *this, elseYield.getInputs(), "'else_graph' results", + getOutputList(), "'output_list'") + .failed()) + return failure(); + } auto condType = getCondition().getType(); if (errorIfShapeNotSizeOne(*this, condType).failed()) diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir index f58ddb180ce4f..e5571b6b4412c 100644 --- a/mlir/test/Dialect/Tosa/verifier.mlir +++ b/mlir/test/Dialect/Tosa/verifier.mlir @@ -438,6 +438,43 @@ func.func @test_pad_invalid_padding_value(%arg0: tensor<10xi8>, %arg1: tensor<1x return %1 : tensor<10xi8> } +// ----- +func.func @test_cond_if_wrong_terminator_op(%arg0: tensor) -> tensor { + %0 = "tosa.cond_if"(%arg0) ({ + %1 = "tosa.const"() <{values = dense<1> : tensor}> : () -> tensor + "tosa.yield"(%1) : (tensor) -> () + }, { + // expected-error@+2 {{'func.return' op expects parent op 'func.func'}} + %2 = "tosa.const"() <{values = dense<2> : tensor}> : () -> tensor + "func.return"(%2) : (tensor) -> () + }) : (tensor) -> tensor + return %0 : tensor +} + +// ----- +func.func @test_cond_if_missing_then_terminator(%arg0: tensor) -> tensor { + %0 = "tosa.cond_if"(%arg0) ({ + // expected-error@+1 {{block with no terminator}} + %1 = "tosa.const"() <{values = dense<1> : tensor}> : () -> tensor + }, { + %2 = "tosa.const"() <{values = dense<2> : tensor}> : () -> tensor + "tosa.yield"(%2) : (tensor) -> () + }) : (tensor) -> tensor + return %0 : tensor +} + +// ----- +func.func @test_cond_if_missing_else_terminator(%arg0: tensor) -> tensor { + %0 = "tosa.cond_if"(%arg0) ({ + %1 = "tosa.const"() <{values = dense<1> : tensor}> : () -> tensor + "tosa.yield"(%1) : (tensor) -> () + }, { + // expected-error@+1 {{block with no terminator}} + %2 = "tosa.const"() <{values = dense<2> : tensor}> : () -> tensor + }) : (tensor) -> tensor + return %0 : tensor +} + // ----- func.func @test_cond_if_input_list_mismatch_then_block(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor {