diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index 606626dfe4d2c..34e7e4200cd44 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -1302,9 +1302,11 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor) { auto intVal = operand.getSplatValue(); auto bitwidth = outETy.getIntOrFloatBitWidth(); - if (trunc) { + // i1 types are boolean in TOSA + if (outETy.isInteger(1)) { + intVal = APInt(bitwidth, intVal.isZero() ? 0 : 1); + } else if (trunc) { intVal = intVal.trunc(bitwidth); - // i1 types are boolean in TOSA } else if (unsignIn || inIntType.isInteger(1)) { intVal = intVal.zext(bitwidth); } else { diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir index 11c8d54fda055..6b55442a82a0a 100644 --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -1349,3 +1349,14 @@ func.func @test_fold_i1_to_i32_cast() -> tensor { %1 = "tosa.cast"(%0) : (tensor) -> tensor return %1 : tensor } + +// ----- + +// CHECK-LABEL: @test_fold_i32_to_i1_cast +// CHECK: %[[OUT:.*]] = "tosa.const"() <{values = dense : tensor}> : () -> tensor +// CHECK: return %[[OUT]] : tensor +func.func @test_fold_i32_to_i1_cast() -> tensor { + %0 = "tosa.const"() <{values = dense<10> : tensor}> : () -> tensor + %1 = "tosa.cast"(%0) : (tensor) -> tensor + return %1 : tensor +}