Skip to content

Commit

Permalink
[CVP] Handle use-site conditions in domain-based folds
Browse files Browse the repository at this point in the history
As a side-effect, this switchem them to use getConstantRange() rather
than getPredicateAt(). getPredicateAt() is not supposed to be more
powerful than getConstantRange() for non-equality comparisons (as
long as block values are used).
  • Loading branch information
nikic committed Jan 17, 2023
1 parent 12cb1cb commit a444fe0
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 35 deletions.
42 changes: 13 additions & 29 deletions llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp
Expand Up @@ -692,26 +692,13 @@ static bool processCallSite(CallBase &CB, LazyValueInfo *LVI) {
return true;
}

static bool isNonNegative(Value *V, LazyValueInfo *LVI, Instruction *CxtI) {
Constant *Zero = ConstantInt::get(V->getType(), 0);
auto Result = LVI->getPredicateAt(ICmpInst::ICMP_SGE, V, Zero, CxtI,
/*UseBlockValue=*/true);
return Result == LazyValueInfo::True;
}

static bool isNonPositive(Value *V, LazyValueInfo *LVI, Instruction *CxtI) {
Constant *Zero = ConstantInt::get(V->getType(), 0);
auto Result = LVI->getPredicateAt(ICmpInst::ICMP_SLE, V, Zero, CxtI,
/*UseBlockValue=*/true);
return Result == LazyValueInfo::True;
}

enum class Domain { NonNegative, NonPositive, Unknown };

Domain getDomain(Value *V, LazyValueInfo *LVI, Instruction *CxtI) {
if (isNonNegative(V, LVI, CxtI))
static Domain getDomain(const Use &U, LazyValueInfo *LVI) {
ConstantRange CR = LVI->getConstantRangeAtUse(U);
if (CR.isAllNonNegative())
return Domain::NonNegative;
if (isNonPositive(V, LVI, CxtI))
if (CR.icmp(ICmpInst::ICMP_SLE, APInt::getNullValue(CR.getBitWidth())))
return Domain::NonPositive;
return Domain::Unknown;
}
Expand Down Expand Up @@ -906,10 +893,9 @@ static bool processSRem(BinaryOperator *SDI, LazyValueInfo *LVI) {
};
std::array<Operand, 2> Ops;

for (const auto I : zip(Ops, SDI->operands())) {
Operand &Op = std::get<0>(I);
Op.V = std::get<1>(I);
Op.D = getDomain(Op.V, LVI, SDI);
for (const auto &[Op, U] : zip(Ops, SDI->operands())) {
Op.V = U;
Op.D = getDomain(U, LVI);
if (Op.D == Domain::Unknown)
return false;
}
Expand Down Expand Up @@ -964,10 +950,9 @@ static bool processSDiv(BinaryOperator *SDI, LazyValueInfo *LVI) {
};
std::array<Operand, 2> Ops;

for (const auto I : zip(Ops, SDI->operands())) {
Operand &Op = std::get<0>(I);
Op.V = std::get<1>(I);
Op.D = getDomain(Op.V, LVI, SDI);
for (const auto &[Op, U] : zip(Ops, SDI->operands())) {
Op.V = U;
Op.D = getDomain(U, LVI);
if (Op.D == Domain::Unknown)
return false;
}
Expand Down Expand Up @@ -1038,7 +1023,7 @@ static bool processAShr(BinaryOperator *SDI, LazyValueInfo *LVI) {
return true;
}

if (!isNonNegative(SDI->getOperand(0), LVI, SDI))
if (!LRange.isAllNonNegative())
return false;

++NumAShrsConverted;
Expand All @@ -1057,9 +1042,8 @@ static bool processSExt(SExtInst *SDI, LazyValueInfo *LVI) {
if (SDI->getType()->isVectorTy())
return false;

Value *Base = SDI->getOperand(0);

if (!isNonNegative(Base, LVI, SDI))
const Use &Base = SDI->getOperandUse(0);
if (!LVI->getConstantRangeAtUse(Base).isAllNonNegative())
return false;

++NumSExt;
Expand Down
16 changes: 10 additions & 6 deletions llvm/test/Transforms/CorrelatedValuePropagation/cond-at-use.ll
Expand Up @@ -425,9 +425,11 @@ define i16 @srem_narrow(i16 %x) {

define i16 @srem_convert(i16 %x) {
; CHECK-LABEL: @srem_convert(
; CHECK-NEXT: [[SREM:%.*]] = srem i16 [[X:%.*]], 42
; CHECK-NEXT: [[X_NONNEG:%.*]] = sub i16 0, [[X:%.*]]
; CHECK-NEXT: [[SREM1:%.*]] = urem i16 [[X_NONNEG]], 42
; CHECK-NEXT: [[SREM1_NEG:%.*]] = sub i16 0, [[SREM1]]
; CHECK-NEXT: [[CMP:%.*]] = icmp slt i16 [[X]], 0
; CHECK-NEXT: [[SEL:%.*]] = select i1 [[CMP]], i16 [[SREM]], i16 24
; CHECK-NEXT: [[SEL:%.*]] = select i1 [[CMP]], i16 [[SREM1_NEG]], i16 24
; CHECK-NEXT: ret i16 [[SEL]]
;
%srem = srem i16 %x, 42
Expand All @@ -438,9 +440,11 @@ define i16 @srem_convert(i16 %x) {

define i16 @sdiv_convert(i16 %x) {
; CHECK-LABEL: @sdiv_convert(
; CHECK-NEXT: [[SREM:%.*]] = sdiv i16 [[X:%.*]], 42
; CHECK-NEXT: [[X_NONNEG:%.*]] = sub i16 0, [[X:%.*]]
; CHECK-NEXT: [[SREM1:%.*]] = udiv i16 [[X_NONNEG]], 42
; CHECK-NEXT: [[SREM1_NEG:%.*]] = sub i16 0, [[SREM1]]
; CHECK-NEXT: [[CMP:%.*]] = icmp slt i16 [[X]], 0
; CHECK-NEXT: [[SEL:%.*]] = select i1 [[CMP]], i16 [[SREM]], i16 24
; CHECK-NEXT: [[SEL:%.*]] = select i1 [[CMP]], i16 [[SREM1_NEG]], i16 24
; CHECK-NEXT: ret i16 [[SEL]]
;
%srem = sdiv i16 %x, 42
Expand Down Expand Up @@ -503,7 +507,7 @@ define i16 @umin_elide(i16 %x) {

define i16 @ashr_convert(i16 %x, i16 %y) {
; CHECK-LABEL: @ashr_convert(
; CHECK-NEXT: [[ASHR:%.*]] = ashr i16 [[X:%.*]], [[Y:%.*]]
; CHECK-NEXT: [[ASHR:%.*]] = lshr i16 [[X:%.*]], [[Y:%.*]]
; CHECK-NEXT: [[CMP:%.*]] = icmp sge i16 [[X]], 0
; CHECK-NEXT: [[SEL:%.*]] = select i1 [[CMP]], i16 [[ASHR]], i16 24
; CHECK-NEXT: ret i16 [[SEL]]
Expand All @@ -516,7 +520,7 @@ define i16 @ashr_convert(i16 %x, i16 %y) {

define i32 @sext_convert(i16 %x) {
; CHECK-LABEL: @sext_convert(
; CHECK-NEXT: [[EXT:%.*]] = sext i16 [[X:%.*]] to i32
; CHECK-NEXT: [[EXT:%.*]] = zext i16 [[X:%.*]] to i32
; CHECK-NEXT: [[CMP:%.*]] = icmp sge i16 [[X]], 0
; CHECK-NEXT: [[SEL:%.*]] = select i1 [[CMP]], i32 [[EXT]], i32 24
; CHECK-NEXT: ret i32 [[SEL]]
Expand Down

0 comments on commit a444fe0

Please sign in to comment.