diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp index 790bbf77877bc..e6091df367754 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp @@ -1257,8 +1257,8 @@ bool checkErrorIfCondIf(Operation *op) { // tosa.yield %arg4 // } - return failed(checkIsolatedRegion(op, ifOp.getThenGraph(), "then")) || - failed(checkIsolatedRegion(op, ifOp.getElseGraph(), "else")); + return succeeded(checkIsolatedRegion(op, ifOp.getThenGraph(), "then")) && + succeeded(checkIsolatedRegion(op, ifOp.getElseGraph(), "else")); } bool checkErrorIfWhileLoop(Operation *op) { @@ -1266,8 +1266,8 @@ bool checkErrorIfWhileLoop(Operation *op) { if (!whileOp) return true; - return failed(checkIsolatedRegion(op, whileOp.getCondGraph(), "cond")) || - failed(checkIsolatedRegion(op, whileOp.getBodyGraph(), "body")); + return succeeded(checkIsolatedRegion(op, whileOp.getCondGraph(), "cond")) && + succeeded(checkIsolatedRegion(op, whileOp.getBodyGraph(), "body")); } bool checkErrorIfScatter(Operation *op) { diff --git a/mlir/test/Dialect/Tosa/error_if_check.mlir b/mlir/test/Dialect/Tosa/error_if_check.mlir index 290773b23193f..2f9421c43d2fb 100644 --- a/mlir/test/Dialect/Tosa/error_if_check.mlir +++ b/mlir/test/Dialect/Tosa/error_if_check.mlir @@ -269,20 +269,6 @@ func.func @test_cond_if_simplified_form_not_isolated_from_above(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { - %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({ - ^bb0(%arg3: tensor, %arg4: tensor): - tosa.yield %arg3 : tensor - }, { - ^bb0(%arg3: tensor, %arg4: tensor): - tosa.yield %arg4 : tensor - }) : (tensor, tensor, tensor) -> tensor - return %0 : tensor -} - -// ----- - func.func @test_while_loop_cond_not_isolated_from_above(%arg0: tensor, %arg1: tensor, %arg2: tensor) { %0 = "tosa.const"() {values = dense<0> : tensor} : () -> tensor // expected-error@+1 {{'tosa.while_loop' op is not conformant to the TOSA specification. It requires the 'cond' region is isolated from above.}} @@ -318,22 +304,3 @@ func.func @test_while_loop_body_not_isolated_from_above(%arg0: tensor, %arg }) : (tensor) -> (tensor) return } - -// ----- - -// Check isolated while_loops are valid -func.func @test_while_loop_isolated_from_above(%arg0: tensor, %arg1: tensor) { - %0 = "tosa.const"() {values = dense<0> : tensor} : () -> tensor - %1:3 = "tosa.while_loop"(%0, %arg0, %arg1) ({ - ^bb0(%arg3: tensor, %arg4: tensor, %arg5: tensor): - %2 = "tosa.greater_equal"(%arg3, %arg5) : (tensor, tensor) -> tensor - %3 = "tosa.logical_not"(%2) : (tensor) -> tensor - "tosa.yield"(%3) : (tensor) -> () - }, { - ^bb0(%arg3: tensor, %arg4: tensor, %arg5: tensor): - %2 = "tosa.const"() {values = dense<1> : tensor} : () -> tensor - %3 = "tosa.add"(%arg3, %2) : (tensor, tensor) -> tensor - "tosa.yield"(%3, %arg4, %arg5) : (tensor, tensor, tensor) -> () - }) : (tensor, tensor, tensor) -> (tensor, tensor, tensor) - return -} diff --git a/mlir/test/Dialect/Tosa/tosa-validation-valid-strict.mlir b/mlir/test/Dialect/Tosa/tosa-validation-valid-strict.mlir new file mode 100644 index 0000000000000..f05ae7f58261d --- /dev/null +++ b/mlir/test/Dialect/Tosa/tosa-validation-valid-strict.mlir @@ -0,0 +1,34 @@ +// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=pro_int,pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround strict-op-spec-alignment" | FileCheck %s + +// ----- + +// CHECK-LABEL: test_cond_if_isolated_from_above +func.func @test_cond_if_isolated_from_above(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({ + ^bb0(%arg3: tensor, %arg4: tensor): + tosa.yield %arg3 : tensor + }, { + ^bb0(%arg3: tensor, %arg4: tensor): + tosa.yield %arg4 : tensor + }) : (tensor, tensor, tensor) -> tensor + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: test_while_loop_isolated_from_above +func.func @test_while_loop_isolated_from_above(%arg0: tensor, %arg1: tensor) { + %0 = "tosa.const"() {values = dense<0> : tensor} : () -> tensor + %1:3 = "tosa.while_loop"(%0, %arg0, %arg1) ({ + ^bb0(%arg3: tensor, %arg4: tensor, %arg5: tensor): + %2 = "tosa.greater_equal"(%arg3, %arg5) : (tensor, tensor) -> tensor + %3 = "tosa.logical_not"(%2) : (tensor) -> tensor + "tosa.yield"(%3) : (tensor) -> () + }, { + ^bb0(%arg3: tensor, %arg4: tensor, %arg5: tensor): + %2 = "tosa.const"() {values = dense<1> : tensor} : () -> tensor + %3 = "tosa.add"(%arg3, %2) : (tensor, tensor) -> tensor + "tosa.yield"(%3, %arg4, %arg5) : (tensor, tensor, tensor) -> () + }) : (tensor, tensor, tensor) -> (tensor, tensor, tensor) + return +}