Skip to content

Commit

Permalink
[ConstraintElim] Add facts implied by MinMaxIntrinsic
Browse files Browse the repository at this point in the history
Fixes #63896 and rust-lang/rust#113757.
This patch adds facts implied by llvm.smin/smax/umin/umax intrinsics.

Reviewed By: fhahn

Differential Revision: https://reviews.llvm.org/D155412
  • Loading branch information
dtcxzyw committed Jul 24, 2023
1 parent 047273f commit 92a11eb
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 21 deletions.
36 changes: 26 additions & 10 deletions llvm/lib/Transforms/Scalar/ConstraintElimination.cpp
Expand Up @@ -784,6 +784,11 @@ void State::addInfoFor(BasicBlock &BB) {
continue;
}

if (isa<MinMaxIntrinsic>(&I)) {
WorkList.push_back(FactOrCheck::getFact(DT.getNode(&BB), &I));
continue;
}

Value *Cond;
// For now, just handle assumes with a single compare as condition.
if (match(&I, m_Intrinsic<Intrinsic::assume>(m_Value(Cond))) &&
Expand Down Expand Up @@ -1363,22 +1368,14 @@ static bool eliminateConstraints(Function &F, DominatorTree &DT,
}

LLVM_DEBUG(dbgs() << "fact to add to the system: " << *CB.Inst << "\n");
ICmpInst::Predicate Pred;
Value *A, *B;
Value *Cmp = CB.Inst;
match(Cmp, m_Intrinsic<Intrinsic::assume>(m_Value(Cmp)));
if (match(Cmp, m_ICmp(Pred, m_Value(A), m_Value(B)))) {
auto AddFact = [&](CmpInst::Predicate Pred, Value *A, Value *B) {
if (Info.getCS(CmpInst::isSigned(Pred)).size() > MaxRows) {
LLVM_DEBUG(
dbgs()
<< "Skip adding constraint because system has too many rows.\n");
continue;
return;
}

// Use the inverse predicate if required.
if (CB.Not)
Pred = CmpInst::getInversePredicate(Pred);

Info.addFact(Pred, A, B, CB.NumIn, CB.NumOut, DFSInStack);
if (ReproducerModule && DFSInStack.size() > ReproducerCondStack.size())
ReproducerCondStack.emplace_back(Pred, A, B);
Expand All @@ -1394,6 +1391,25 @@ static bool eliminateConstraints(Function &F, DominatorTree &DT,
nullptr, nullptr);
}
}
};

ICmpInst::Predicate Pred;
if (auto *MinMax = dyn_cast<MinMaxIntrinsic>(CB.Inst)) {
Pred = ICmpInst::getNonStrictPredicate(MinMax->getPredicate());
AddFact(Pred, MinMax, MinMax->getLHS());
AddFact(Pred, MinMax, MinMax->getRHS());
continue;
}

Value *A, *B;
Value *Cmp = CB.Inst;
match(Cmp, m_Intrinsic<Intrinsic::assume>(m_Value(Cmp)));
if (match(Cmp, m_ICmp(Pred, m_Value(A), m_Value(B)))) {
// Use the inverse predicate if required.
if (CB.Not)
Pred = CmpInst::getInversePredicate(Pred);

AddFact(Pred, A, B);
}
}

Expand Down
22 changes: 11 additions & 11 deletions llvm/test/Transforms/ConstraintElimination/minmax.ll
Expand Up @@ -11,7 +11,7 @@ define i1 @umax_ugt(i32 %x, i32 %y) {
; CHECK: if:
; CHECK-NEXT: [[CMP2:%.*]] = icmp ugt i32 [[Y]], [[X]]
; CHECK-NEXT: [[CMP3:%.*]] = icmp uge i32 [[Y]], [[X]]
; CHECK-NEXT: [[RET:%.*]] = xor i1 [[CMP2]], [[CMP3]]
; CHECK-NEXT: [[RET:%.*]] = xor i1 true, true
; CHECK-NEXT: ret i1 [[RET]]
; CHECK: end:
; CHECK-NEXT: ret i1 false
Expand Down Expand Up @@ -39,7 +39,7 @@ define i1 @umax_uge(i32 %x, i32 %y) {
; CHECK: if:
; CHECK-NEXT: [[CMP2:%.*]] = icmp ugt i32 [[Y]], [[X]]
; CHECK-NEXT: [[CMP3:%.*]] = icmp uge i32 [[Y]], [[X]]
; CHECK-NEXT: [[RET:%.*]] = xor i1 [[CMP2]], [[CMP3]]
; CHECK-NEXT: [[RET:%.*]] = xor i1 [[CMP2]], true
; CHECK-NEXT: ret i1 [[RET]]
; CHECK: end:
; CHECK-NEXT: ret i1 false
Expand Down Expand Up @@ -67,7 +67,7 @@ define i1 @umin_ult(i32 %x, i32 %y) {
; CHECK: if:
; CHECK-NEXT: [[CMP2:%.*]] = icmp ult i32 [[Y]], [[X]]
; CHECK-NEXT: [[CMP3:%.*]] = icmp ule i32 [[Y]], [[X]]
; CHECK-NEXT: [[RET:%.*]] = xor i1 [[CMP2]], [[CMP3]]
; CHECK-NEXT: [[RET:%.*]] = xor i1 true, true
; CHECK-NEXT: ret i1 [[RET]]
; CHECK: end:
; CHECK-NEXT: ret i1 false
Expand Down Expand Up @@ -95,7 +95,7 @@ define i1 @umin_ule(i32 %x, i32 %y) {
; CHECK: if:
; CHECK-NEXT: [[CMP2:%.*]] = icmp ult i32 [[Y]], [[X]]
; CHECK-NEXT: [[CMP3:%.*]] = icmp ule i32 [[Y]], [[X]]
; CHECK-NEXT: [[RET:%.*]] = xor i1 [[CMP2]], [[CMP3]]
; CHECK-NEXT: [[RET:%.*]] = xor i1 [[CMP2]], true
; CHECK-NEXT: ret i1 [[RET]]
; CHECK: end:
; CHECK-NEXT: ret i1 false
Expand Down Expand Up @@ -123,7 +123,7 @@ define i1 @smax_sgt(i32 %x, i32 %y) {
; CHECK: if:
; CHECK-NEXT: [[CMP2:%.*]] = icmp sgt i32 [[Y]], [[X]]
; CHECK-NEXT: [[CMP3:%.*]] = icmp sge i32 [[Y]], [[X]]
; CHECK-NEXT: [[RET:%.*]] = xor i1 [[CMP2]], [[CMP3]]
; CHECK-NEXT: [[RET:%.*]] = xor i1 true, true
; CHECK-NEXT: ret i1 [[RET]]
; CHECK: end:
; CHECK-NEXT: ret i1 false
Expand Down Expand Up @@ -151,7 +151,7 @@ define i1 @smax_sge(i32 %x, i32 %y) {
; CHECK: if:
; CHECK-NEXT: [[CMP2:%.*]] = icmp sgt i32 [[Y]], [[X]]
; CHECK-NEXT: [[CMP3:%.*]] = icmp sge i32 [[Y]], [[X]]
; CHECK-NEXT: [[RET:%.*]] = xor i1 [[CMP2]], [[CMP3]]
; CHECK-NEXT: [[RET:%.*]] = xor i1 [[CMP2]], true
; CHECK-NEXT: ret i1 [[RET]]
; CHECK: end:
; CHECK-NEXT: ret i1 false
Expand Down Expand Up @@ -179,7 +179,7 @@ define i1 @smin_slt(i32 %x, i32 %y) {
; CHECK: if:
; CHECK-NEXT: [[CMP2:%.*]] = icmp slt i32 [[Y]], [[X]]
; CHECK-NEXT: [[CMP3:%.*]] = icmp sle i32 [[Y]], [[X]]
; CHECK-NEXT: [[RET:%.*]] = xor i1 [[CMP2]], [[CMP3]]
; CHECK-NEXT: [[RET:%.*]] = xor i1 true, true
; CHECK-NEXT: ret i1 [[RET]]
; CHECK: end:
; CHECK-NEXT: ret i1 false
Expand Down Expand Up @@ -207,7 +207,7 @@ define i1 @smin_sle(i32 %x, i32 %y) {
; CHECK: if:
; CHECK-NEXT: [[CMP2:%.*]] = icmp slt i32 [[Y]], [[X]]
; CHECK-NEXT: [[CMP3:%.*]] = icmp sle i32 [[Y]], [[X]]
; CHECK-NEXT: [[RET:%.*]] = xor i1 [[CMP2]], [[CMP3]]
; CHECK-NEXT: [[RET:%.*]] = xor i1 [[CMP2]], true
; CHECK-NEXT: ret i1 [[RET]]
; CHECK: end:
; CHECK-NEXT: ret i1 false
Expand Down Expand Up @@ -235,7 +235,7 @@ define i1 @umax_uge_ugt_with_add_nuw(i32 %x, i32 %y) {
; CHECK-NEXT: br i1 [[CMP]], label [[IF:%.*]], label [[END:%.*]]
; CHECK: if:
; CHECK-NEXT: [[CMP2:%.*]] = icmp ugt i32 [[Y]], [[X]]
; CHECK-NEXT: ret i1 [[CMP2]]
; CHECK-NEXT: ret i1 true
; CHECK: end:
; CHECK-NEXT: ret i1 false
;
Expand Down Expand Up @@ -297,7 +297,7 @@ define i1 @umax_ugt_ugt_both(i32 %x, i32 %y, i32 %z) {
; CHECK: if:
; CHECK-NEXT: [[CMP2:%.*]] = icmp ugt i32 [[Z]], [[X]]
; CHECK-NEXT: [[CMP3:%.*]] = icmp ugt i32 [[Z]], [[Y]]
; CHECK-NEXT: [[AND:%.*]] = xor i1 [[CMP2]], [[CMP3]]
; CHECK-NEXT: [[AND:%.*]] = xor i1 true, true
; CHECK-NEXT: ret i1 [[AND]]
; CHECK: end:
; CHECK-NEXT: ret i1 false
Expand All @@ -323,7 +323,7 @@ define i1 @smin_branchless(i32 %x, i32 %y) {
; CHECK-NEXT: [[MIN:%.*]] = call i32 @llvm.smin.i32(i32 [[X]], i32 [[Y]])
; CHECK-NEXT: [[CMP1:%.*]] = icmp sle i32 [[MIN]], [[X]]
; CHECK-NEXT: [[CMP2:%.*]] = icmp sgt i32 [[MIN]], [[X]]
; CHECK-NEXT: [[RET:%.*]] = xor i1 [[CMP1]], [[CMP2]]
; CHECK-NEXT: [[RET:%.*]] = xor i1 true, false
; CHECK-NEXT: ret i1 [[RET]]
;
entry:
Expand Down

0 comments on commit 92a11eb

Please sign in to comment.