diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp index eb4332fbc0959..c9362892d1bfd 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -1156,6 +1156,63 @@ static Value *foldAbsDiff(ICmpInst *Cmp, Value *TVal, Value *FVal, return nullptr; } +/// Fold select patterns involving abs intrinsic: +/// X == Positive ? X : ABS(X) -> ABS(X) +/// X == Positive ? Positive : ABS(X) -> ABS(X) +/// X > Positive ? X : ABS(X) -> ABS(X) +/// X >= Positive ? X : ABS(X) -> ABS(X) +static Value *foldSelectPositiveAbs(ICmpInst *Cmp, Value *TVal, Value *FVal, + InstCombiner::BuilderTy &Builder, + InstCombinerImpl &IC) { + // Check if false value is abs(X) + Value *X; + Constant *IntMinIsPoison; + if (!match(FVal, m_Intrinsic(m_Value(X), + m_Constant(IntMinIsPoison)))) + return nullptr; + + ICmpInst::Predicate Pred = Cmp->getPredicate(); + Value *CmpLHS = Cmp->getOperand(0); + Value *CmpRHS = Cmp->getOperand(1); + + // Normalize so that X is on the LHS of comparison + if (CmpLHS != X) { + if (CmpRHS == X) { + std::swap(CmpLHS, CmpRHS); + Pred = ICmpInst::getSwappedPredicate(Pred); + } else { + return nullptr; + } + } + + // Check if RHS is non-negative + if (!isKnownNonNegative(CmpRHS, IC.getSimplifyQuery())) + return nullptr; + + // Handle different patterns + switch (Pred) { + case ICmpInst::ICMP_EQ: + // X == Positive ? X : ABS(X) -> ABS(X) + if (TVal == X) + return FVal; + // X == Positive ? Positive : ABS(X) -> ABS(X) + if (TVal == CmpRHS) + return FVal; + break; + case ICmpInst::ICMP_SGT: + case ICmpInst::ICMP_SGE: + // X > Positive ? X : ABS(X) -> ABS(X) + // X >= Positive ? X : ABS(X) -> ABS(X) + if (TVal == X) + return FVal; + break; + default: + break; + } + + return nullptr; +} + /// Fold the following code sequence: /// \code /// int a = ctlz(x & -x); @@ -2068,6 +2125,9 @@ Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI, if (Value *V = foldAbsDiff(ICI, TrueVal, FalseVal, Builder)) return replaceInstUsesWith(SI, V); + if (Value *V = foldSelectPositiveAbs(ICI, TrueVal, FalseVal, Builder, *this)) + return replaceInstUsesWith(SI, V); + if (Value *V = foldSelectWithConstOpToBinOp(ICI, TrueVal, FalseVal)) return replaceInstUsesWith(SI, V); diff --git a/llvm/test/Transforms/InstCombine/select-abs-positive.ll b/llvm/test/Transforms/InstCombine/select-abs-positive.ll new file mode 100644 index 0000000000000..a123e6e3a3272 --- /dev/null +++ b/llvm/test/Transforms/InstCombine/select-abs-positive.ll @@ -0,0 +1,158 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py +; RUN: opt < %s -passes=instcombine -S | FileCheck %s + +declare i32 @llvm.abs.i32(i32, i1) +declare i64 @llvm.abs.i64(i64, i1) + +; X == Positive ? X : ABS(X) -> ABS(X) +define i32 @feq_1(i32 noundef %a, i8 noundef zeroext %b) { +; CHECK-LABEL: @feq_1( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[COND:%.*]] = tail call i32 @llvm.abs.i32(i32 [[A:%.*]], i1 true) +; CHECK-NEXT: ret i32 [[COND]] +; +entry: + %conv = zext i8 %b to i32 + %cmp = icmp eq i32 %a, %conv + %cond = tail call i32 @llvm.abs.i32(i32 %a, i1 true) + %retval.0 = select i1 %cmp, i32 %a, i32 %cond + ret i32 %retval.0 +} + +; X == Positive ? Positive : ABS(X) -> ABS(X) +define i32 @feq_2(i32 noundef %a, i8 noundef zeroext %b) { +; CHECK-LABEL: @feq_2( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[COND:%.*]] = tail call i32 @llvm.abs.i32(i32 [[A:%.*]], i1 true) +; CHECK-NEXT: ret i32 [[COND]] +; +entry: + %conv = zext i8 %b to i32 + %cmp = icmp eq i32 %a, %conv + %cond = tail call i32 @llvm.abs.i32(i32 %a, i1 true) + %retval.0 = select i1 %cmp, i32 %conv, i32 %cond + ret i32 %retval.0 +} + +; X > Positive ? X : ABS(X) -> ABS(X) +define i32 @fgt_1(i32 noundef %a, i8 noundef zeroext %b) { +; CHECK-LABEL: @fgt_1( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[COND:%.*]] = tail call i32 @llvm.abs.i32(i32 [[A:%.*]], i1 true) +; CHECK-NEXT: ret i32 [[COND]] +; +entry: + %conv = zext i8 %b to i32 + %cmp = icmp sgt i32 %a, %conv + %cond = tail call i32 @llvm.abs.i32(i32 %a, i1 true) + %retval.0 = select i1 %cmp, i32 %a, i32 %cond + ret i32 %retval.0 +} + +; X >= Positive ? X : ABS(X) -> ABS(X) +define i32 @fge_1(i32 noundef %a, i8 noundef zeroext %b) { +; CHECK-LABEL: @fge_1( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[CONV:%.*]] = zext i8 [[B:%.*]] to i32 +; CHECK-NEXT: [[CMP_NOT:%.*]] = icmp slt i32 [[A:%.*]], [[CONV]] +; CHECK-NEXT: [[COND:%.*]] = tail call i32 @llvm.abs.i32(i32 [[A]], i1 true) +; CHECK-NEXT: [[RETVAL_0:%.*]] = select i1 [[CMP_NOT]], i32 [[COND]], i32 [[A]] +; CHECK-NEXT: ret i32 [[RETVAL_0]] +; +entry: + %conv = zext i8 %b to i32 + %cmp = icmp sge i32 %a, %conv + %cond = tail call i32 @llvm.abs.i32(i32 %a, i1 true) + %retval.0 = select i1 %cmp, i32 %a, i32 %cond + ret i32 %retval.0 +} + +; Test with constant positive value +define i32 @constant_positive(i32 noundef %a) { +; CHECK-LABEL: @constant_positive( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[COND:%.*]] = tail call i32 @llvm.abs.i32(i32 [[A:%.*]], i1 true) +; CHECK-NEXT: ret i32 [[COND]] +; +entry: + %cmp = icmp eq i32 %a, 42 + %cond = tail call i32 @llvm.abs.i32(i32 %a, i1 true) + %retval.0 = select i1 %cmp, i32 %a, i32 %cond + ret i32 %retval.0 +} + +; Negative test: Should not optimize when comparing with negative value +define i32 @negative_value(i32 noundef %a) { +; CHECK-LABEL: @negative_value( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[A:%.*]], -42 +; CHECK-NEXT: [[COND:%.*]] = tail call i32 @llvm.abs.i32(i32 [[A]], i1 true) +; CHECK-NEXT: [[RETVAL_0:%.*]] = select i1 [[CMP]], i32 -42, i32 [[COND]] +; CHECK-NEXT: ret i32 [[RETVAL_0]] +; +entry: + %cmp = icmp eq i32 %a, -42 + %cond = tail call i32 @llvm.abs.i32(i32 %a, i1 true) + %retval.0 = select i1 %cmp, i32 %a, i32 %cond + ret i32 %retval.0 +} + +; Negative test: Should not optimize when true value is not X or positive +define i32 @wrong_true_value(i32 noundef %a, i32 noundef %c) { +; CHECK-LABEL: @wrong_true_value( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[A:%.*]], 42 +; CHECK-NEXT: [[COND:%.*]] = tail call i32 @llvm.abs.i32(i32 [[A]], i1 true) +; CHECK-NEXT: [[RETVAL_0:%.*]] = select i1 [[CMP]], i32 [[C:%.*]], i32 [[COND]] +; CHECK-NEXT: ret i32 [[RETVAL_0]] +; +entry: + %cmp = icmp eq i32 %a, 42 + %cond = tail call i32 @llvm.abs.i32(i32 %a, i1 true) + %retval.0 = select i1 %cmp, i32 %c, i32 %cond + ret i32 %retval.0 +} + +; Negative test: Should not optimize when false value is not abs(X) +define i32 @wrong_false_value(i32 noundef %a, i32 noundef %c) { +; CHECK-LABEL: @wrong_false_value( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[A:%.*]], 42 +; CHECK-NEXT: [[RETVAL_0:%.*]] = select i1 [[CMP]], i32 42, i32 [[C:%.*]] +; CHECK-NEXT: ret i32 [[RETVAL_0]] +; +entry: + %cmp = icmp eq i32 %a, 42 + %retval.0 = select i1 %cmp, i32 %a, i32 %c + ret i32 %retval.0 +} + +; Test with different types (i64) +define i64 @i64_test(i64 noundef %a, i32 noundef zeroext %b) { +; CHECK-LABEL: @i64_test( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[COND:%.*]] = tail call i64 @llvm.abs.i64(i64 [[A:%.*]], i1 true) +; CHECK-NEXT: ret i64 [[COND]] +; +entry: + %conv = zext i32 %b to i64 + %cmp = icmp eq i64 %a, %conv + %cond = tail call i64 @llvm.abs.i64(i64 %a, i1 true) + %retval.0 = select i1 %cmp, i64 %a, i64 %cond + ret i64 %retval.0 +} + +; Test with swapped comparison operands +define i32 @swapped_comparison(i32 noundef %a, i8 noundef zeroext %b) { +; CHECK-LABEL: @swapped_comparison( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[COND:%.*]] = tail call i32 @llvm.abs.i32(i32 [[A:%.*]], i1 true) +; CHECK-NEXT: ret i32 [[COND]] +; +entry: + %conv = zext i8 %b to i32 + %cmp = icmp eq i32 %conv, %a + %cond = tail call i32 @llvm.abs.i32(i32 %a, i1 true) + %retval.0 = select i1 %cmp, i32 %a, i32 %cond + ret i32 %retval.0 +}