diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index 5e1c4302a16c4..007b105d2328c 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -13,6 +13,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/CommonFolders.h" +#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributeInterfaces.h" #include "mlir/IR/BuiltinAttributes.h" @@ -2194,6 +2195,13 @@ OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) { if (matchPattern(condition, m_Zero())) return falseVal; + // If either operand is fully poisoned, return the other. + if (isa_and_nonnull(adaptor.getTrueValue())) + return falseVal; + + if (isa_and_nonnull(adaptor.getFalseValue())) + return trueVal; + // select %x, true, false => %x if (getType().isInteger(1) && matchPattern(getTrueValue(), m_One()) && matchPattern(getFalseValue(), m_Zero())) diff --git a/mlir/lib/Dialect/Arith/IR/CMakeLists.txt b/mlir/lib/Dialect/Arith/IR/CMakeLists.txt index fdbeb39e60c06..4beb99ccfdfba 100644 --- a/mlir/lib/Dialect/Arith/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Arith/IR/CMakeLists.txt @@ -28,6 +28,7 @@ add_mlir_dialect_library(MLIRArithDialect MLIRInferIntRangeInterface MLIRInferTypeOpInterface MLIRIR + MLIRUBDialect ) add_mlir_dialect_library(MLIRArithValueBoundsOpInterfaceImpl diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir index 5b392fe9cf58a..0c8e0974b017d 100644 --- a/mlir/test/Dialect/Arith/canonicalize.mlir +++ b/mlir/test/Dialect/Arith/canonicalize.mlir @@ -2567,3 +2567,20 @@ func.func @foldOrXor6(%arg0: index) -> index { %2 = arith.ori %arg0, %1 : index return %2 : index } + +// CHECK-LABEL: @selectOfPoison +// CHECK-SAME: %[[ARG:[[:alnum:]]+]]: i32 +// CHECK: %[[UB:.*]] = ub.poison : i32 +// CHECK: return %[[ARG]], %[[ARG]], %[[UB]], %[[ARG]] +func.func @selectOfPoison(%cond : i1, %arg: i32) -> (i32, i32, i32, i32) { + %poison = ub.poison : i32 + %select1 = arith.select %cond, %poison, %arg : i32 + %select2 = arith.select %cond, %arg, %poison : i32 + + // Check that constant folding is applied prior to poison handling. + %true = arith.constant true + %false = arith.constant false + %select3 = arith.select %true, %poison, %arg : i32 + %select4 = arith.select %false, %poison, %arg : i32 + return %select1, %select2, %select3, %select4 : i32, i32, i32, i32 +}