diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index 9bee523c7b7e5..6a0c75f768737 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -2857,9 +2857,37 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) {
         I, Builder.CreateIntrinsic(Intrinsic::usub_sat, {Ty}, {X, Op1}));
 
   // Op0 - umin(X, Op0) --> usub.sat(Op0, X)
-  if (match(Op1, m_OneUse(m_c_UMin(m_Value(X), m_Specific(Op0)))))
-    return replaceInstUsesWith(
-        I, Builder.CreateIntrinsic(Intrinsic::usub_sat, {Ty}, {Op0, X}));
+  if (match(Op1, m_OneUse(m_c_UMin(m_Value(X), m_Specific(Op0))))) {
+    // Fold the sub together with its single user:
+    //   (Op0 - umin(Op0, C)) s< C2 --> Op0 s< C + C2
+    // Required, besides the sub being nsw with exactly one use:
+    //   1. C >= 0 (signed).
+    //   2. C2 > 0.
+    //   3. C + C2 does not overflow as a signed addition.
+    const APInt *C;
+    if (I.hasOneUse() && I.hasNoSignedWrap() && match(X, m_APInt(C)) &&
+        C->isNonNegative()) {
+      const APInt *C2;
+      CmpPredicate Pred;
+      // Instruction::user_back() already returns an Instruction *.
+      Instruction *MustICmp = I.user_back();
+      if (match(MustICmp, m_ICmp(Pred, m_Specific(&I), m_APInt(C2))) &&
+          Pred == CmpInst::ICMP_SLT) {
+        bool Overflow;
+        APInt Sum = C->sadd_ov(*C2, Overflow);
+        if (C2->isStrictlyPositive() && !Overflow) {
+          Value *NewCmp =
+              Builder.CreateICmpSLT(Op0, ConstantInt::get(Ty, Sum));
+          eraseInstFromFunction(*replaceInstUsesWith(*MustICmp, NewCmp));
+          // NOTE(review): the sub is left without uses and is not erased
+          // here; presumably later cleanup removes it -- confirm.
+          return nullptr;
+        }
+      }
+    }
+    Value *USub = Builder.CreateIntrinsic(Intrinsic::usub_sat, {Ty}, {Op0, X});
+    return replaceInstUsesWith(I, USub);
+  }
 
   // Op0 - umax(X, Op0) --> 0 - usub.sat(X, Op0)
   if (match(Op1, m_OneUse(m_c_UMax(m_Value(X), m_Specific(Op0))))) {
diff --git a/llvm/test/Transforms/InstCombine/icmp-sub.ll b/llvm/test/Transforms/InstCombine/icmp-sub.ll
index 13ed7ba0c1703..8bf7e2041d76a 100644
--- a/llvm/test/Transforms/InstCombine/icmp-sub.ll
+++ b/llvm/test/Transforms/InstCombine/icmp-sub.ll
@@ -3,6 +3,123 @@
 
 declare void @use(i32)
 declare void @use_vec(<2 x i8>)
+declare i8 @llvm.umin.i8(i8, i8)
+
+; 1. Positive Tests
+; Basic valid case
+define i1 @test_basic_opt(i8 %x) {
+; CHECK-LABEL: @test_basic_opt(
+; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i8 [[X:%.*]], 30
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %min = call i8 @llvm.umin.i8(i8 %x, i8 10)
+  %sub = sub nsw i8 %x, %min
+  %cmp = icmp slt i8 %sub, 20
+  ret i1 %cmp
+}
+
+; Boundary case - C is 0
+define i1 @test_c_is_zero(i8 %x) {
+; CHECK-LABEL: @test_c_is_zero(
+; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i8 [[X:%.*]], 10
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %min = call i8 @llvm.umin.i8(i8 %x, i8 0)
+  %sub = sub nsw i8 %x, %min
+  %cmp = icmp slt i8 %sub, 10
+  ret i1 %cmp
+}
+
+; Boundary case - Sum is exactly SMAX (127)
+define i1 @test_sum_is_smax(i8 %x) {
+; CHECK-LABEL: @test_sum_is_smax(
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ne i8 [[X:%.*]], 127
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %min = call i8 @llvm.umin.i8(i8 %x, i8 100)
+  %sub = sub nsw i8 %x, %min
+  %cmp = icmp slt i8 %sub, 27
+  ret i1 %cmp
+}
+
+; 2. Negative Tests
+; Missing 'nsw' flag on sub
+define i1 @fail_no_nsw(i8 %x) {
+; CHECK-LABEL: @fail_no_nsw(
+; CHECK-NEXT:    [[TMP1:%.*]] = add i8 [[X:%.*]], 118
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i8 [[TMP1]], -108
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %min = call i8 @llvm.umin.i8(i8 %x, i8 10)
+  %sub = sub i8 %x, %min
+  %cmp = icmp slt i8 %sub, 20
+  ret i1 %cmp
+}
+
+; C is Negative (Constraint 1 Violation)
+define i1 @fail_c_negative(i8 %x) {
+; CHECK-LABEL: @fail_c_negative(
+; CHECK-NEXT:    ret i1 true
+;
+  %min = call i8 @llvm.umin.i8(i8 %x, i8 -10)
+  %sub = sub nsw i8 %x, %min
+  %cmp = icmp slt i8 %sub, 20
+  ret i1 %cmp
+}
+
+; C2 is Zero (Constraint 2 Violation)
+define i1 @fail_c2_zero(i8 %x) {
+; CHECK-LABEL: @fail_c2_zero(
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ugt i8 [[X:%.*]], -119
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %min = call i8 @llvm.umin.i8(i8 %x, i8 10)
+  %sub = sub nsw i8 %x, %min
+  %cmp = icmp slt i8 %sub, 0
+  ret i1 %cmp
+}
+
+; C2 is Negative (Constraint 2 Violation)
+define i1 @fail_c2_negative(i8 %x) {
+; CHECK-LABEL: @fail_c2_negative(
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ugt i8 [[X:%.*]], -119
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %min = call i8 @llvm.umin.i8(i8 %x, i8 10)
+  %sub = sub nsw i8 %x, %min
+  %cmp = icmp slt i8 %sub, -5
+  ret i1 %cmp
+}
+
+; Signed Overflow in C + C2 (Constraint 3 Violation)
+define i1 @fail_sum_overflow(i8 %x) {
+; CHECK-LABEL: @fail_sum_overflow(
+; CHECK-NEXT:    [[TMP1:%.*]] = add i8 [[X:%.*]], 28
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i8 [[TMP1]], -98
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %min = call i8 @llvm.umin.i8(i8 %x, i8 100)
+  %sub = sub nsw i8 %x, %min
+  %cmp = icmp slt i8 %sub, 30
+  ret i1 %cmp
+}
+
+; Multi-use of sub instruction
+define i8 @fail_multi_use(i8 %x) {
+; CHECK-LABEL: @fail_multi_use(
+; CHECK-NEXT:    [[SUB:%.*]] = call i8 @llvm.usub.sat.i8(i8 [[X:%.*]], i8 10)
+; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i8 [[SUB]], 20
+; CHECK-NEXT:    [[RES:%.*]] = zext i1 [[CMP]] to i8
+; CHECK-NEXT:    [[RET:%.*]] = add nuw i8 [[SUB]], [[RES]]
+; CHECK-NEXT:    ret i8 [[RET]]
+;
+  %min = call i8 @llvm.umin.i8(i8 %x, i8 10)
+  %sub = sub nsw i8 %x, %min
+  %cmp = icmp slt i8 %sub, 20
+  %res = zext i1 %cmp to i8
+  %ret = add i8 %sub, %res
+  ret i8 %ret
+}
 
 define i1 @test_nuw_and_unsigned_pred(i64 %x) {
 ; CHECK-LABEL: @test_nuw_and_unsigned_pred(