Skip to content

Commit

Permalink
[mlir][arith] Fold select with poison
Browse files Browse the repository at this point in the history
If either of the operands of `select` is fully poisoned we can simply return the other.
This PR implements this optimization inside the `fold` method.

Note that this patch is the first to add a dependency on the UB dialect within Arith. Given this was inevitable (and part of the motivation) it should be fine I believe.

Differential Revision: https://reviews.llvm.org/D158986
  • Loading branch information
zero9178 committed Aug 29, 2023
1 parent b667e9c commit bbf0733
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 0 deletions.
8 changes: 8 additions & 0 deletions mlir/lib/Dialect/Arith/IR/ArithOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/CommonFolders.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinAttributes.h"
Expand Down Expand Up @@ -2194,6 +2195,13 @@ OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
if (matchPattern(condition, m_Zero()))
return falseVal;

// If either operand is fully poisoned, return the other.
if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getTrueValue()))
return falseVal;

if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getFalseValue()))
return trueVal;

// select %x, true, false => %x
if (getType().isInteger(1) && matchPattern(getTrueValue(), m_One()) &&
matchPattern(getFalseValue(), m_Zero()))
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Arith/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ add_mlir_dialect_library(MLIRArithDialect
MLIRInferIntRangeInterface
MLIRInferTypeOpInterface
MLIRIR
MLIRUBDialect
)

add_mlir_dialect_library(MLIRArithValueBoundsOpInterfaceImpl
Expand Down
17 changes: 17 additions & 0 deletions mlir/test/Dialect/Arith/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2567,3 +2567,20 @@ func.func @foldOrXor6(%arg0: index) -> index {
%2 = arith.ori %arg0, %1 : index
return %2 : index
}

// CHECK-LABEL: @selectOfPoison
// CHECK-SAME: %[[ARG:[[:alnum:]]+]]: i32
// CHECK: %[[UB:.*]] = ub.poison : i32
// CHECK: return %[[ARG]], %[[ARG]], %[[UB]], %[[ARG]]
func.func @selectOfPoison(%cond : i1, %arg: i32) -> (i32, i32, i32, i32) {
%poison = ub.poison : i32
%select1 = arith.select %cond, %poison, %arg : i32
%select2 = arith.select %cond, %arg, %poison : i32

// Check that constant folding is applied prior to poison handling.
%true = arith.constant true
%false = arith.constant false
%select3 = arith.select %true, %poison, %arg : i32
%select4 = arith.select %false, %poison, %arg : i32
return %select1, %select2, %select3, %select4 : i32, i32, i32, i32
}

0 comments on commit bbf0733

Please sign in to comment.