diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index fba1ccf2c8c9b..734d27c70705e 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -6001,6 +6001,26 @@ Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) {
 
   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
   const CmpInst::Predicate Pred = I.getPredicate();
+
+  // icmp (shl nsw X, Log2), (add nsw (shl nsw Y, Log2), K) -> icmp X, (add nsw
+  // Y, 1)
+  Value *X, *Y;
+  const APInt *CLog2M0, *CLog2M1, *CVal;
+  auto M0 = m_NSWShl(m_Value(X), m_APIntAllowPoison(CLog2M0));
+  auto M1 = m_NSWAdd(m_NSWShl(m_Value(Y), m_APIntAllowPoison(CLog2M1)),
+                     m_APIntAllowPoison(CVal));
+
+  if (match(&I, m_c_ICmp(M0, M1)) && *CLog2M0 == *CLog2M1) {
+    unsigned BitWidth = CLog2M0->getBitWidth();
+    unsigned ShAmt = (unsigned)CLog2M0->getLimitedValue(BitWidth);
+    APInt ExpectedK = APInt::getOneBitSet(BitWidth, ShAmt);
+    if (*CVal == ExpectedK) {
+      Value *NewRHS = Builder.CreateAdd(Y, ConstantInt::get(Y->getType(), 1),
+                                        "", /*HasNUW=*/false, /*HasNSW=*/true);
+      return new ICmpInst(Pred, X, NewRHS);
+    }
+  }
+
   Value *A, *B, *C, *D;
   if (match(Op0, m_Xor(m_Value(A), m_Value(B)))) {
     if (A == Op1 || B == Op1) { // (A^B) == A -> B == 0
diff --git a/llvm/test/Transforms/InstCombine/icmp-shl-add-to-add.ll b/llvm/test/Transforms/InstCombine/icmp-shl-add-to-add.ll
new file mode 100644
index 0000000000000..1523b283b8b08
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/icmp-shl-add-to-add.ll
@@ -0,0 +1,172 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+
+; Test case: Fold (X << 5) == ((Y << 5) + 32) into X == (Y + 1).
+; This corresponds to the provided alive2 proof.
+
+declare void @use_i64(i64)
+
+define i1 @shl_add_const_eq_base(i64 %v0, i64 %v3) {
+; CHECK-LABEL: @shl_add_const_eq_base(
+; CHECK-NEXT:    [[V5:%.*]] = add nsw i64 [[V3:%.*]], 1
+; CHECK-NEXT:    [[V6:%.*]] = icmp eq i64 [[V1:%.*]], [[V5]]
+; CHECK-NEXT:    ret i1 [[V6]]
+;
+  %v1 = shl nsw i64 %v0, 5
+  %v4 = shl nsw i64 %v3, 5
+  %v5 = add nsw i64 %v4, 32
+  %v6 = icmp eq i64 %v1, %v5
+  ret i1 %v6
+}
+
+; Test: icmp ne
+define i1 @shl_add_const_ne(i64 %v0, i64 %v3) {
+; CHECK-LABEL: @shl_add_const_ne(
+; CHECK-NEXT:    [[V5:%.*]] = add nsw i64 [[V3:%.*]], 1
+; CHECK-NEXT:    [[V6:%.*]] = icmp ne i64 [[V1:%.*]], [[V5]]
+; CHECK-NEXT:    ret i1 [[V6]]
+;
+  %v1 = shl nsw i64 %v0, 5
+  %v4 = shl nsw i64 %v3, 5
+  %v5 = add nsw i64 %v4, 32
+  %v6 = icmp ne i64 %v1, %v5 ; Note: icmp ne
+  ret i1 %v6
+}
+
+; Test: shl amounts do not match (5 vs 4).
+define i1 @shl_add_const_eq_mismatch_shl_amt(i64 %v0, i64 %v3) {
+; CHECK-LABEL: @shl_add_const_eq_mismatch_shl_amt(
+; CHECK-NEXT:    [[V1:%.*]] = shl nsw i64 [[V0:%.*]], 5
+; CHECK-NEXT:    [[V4:%.*]] = shl nsw i64 [[V3:%.*]], 4
+; CHECK-NEXT:    [[V5:%.*]] = add nsw i64 [[V4]], 16
+; CHECK-NEXT:    [[V6:%.*]] = icmp eq i64 [[V1]], [[V5]]
+; CHECK-NEXT:    ret i1 [[V6]]
+;
+  %v1 = shl nsw i64 %v0, 5
+  %v4 = shl nsw i64 %v3, 4 ; Shift amount mismatch
+  %v5 = add nsw i64 %v4, 16
+  %v6 = icmp eq i64 %v1, %v5
+  ret i1 %v6
+}
+
+; Test: Constant is wrong (32 vs 64).
+define i1 @shl_add_const_eq_wrong_constant(i64 %v0, i64 %v3) {
+; CHECK-LABEL: @shl_add_const_eq_wrong_constant(
+; CHECK-NEXT:    [[V1:%.*]] = shl nsw i64 [[V0:%.*]], 5
+; CHECK-NEXT:    [[V4:%.*]] = shl nsw i64 [[V3:%.*]], 5
+; CHECK-NEXT:    [[V5:%.*]] = add nsw i64 [[V4]], 64
+; CHECK-NEXT:    [[V6:%.*]] = icmp eq i64 [[V1]], [[V5]]
+; CHECK-NEXT:    ret i1 [[V6]]
+;
+  %v1 = shl nsw i64 %v0, 5
+  %v4 = shl nsw i64 %v3, 5
+  %v5 = add nsw i64 %v4, 64 ; Constant mismatch
+  %v6 = icmp eq i64 %v1, %v5
+  ret i1 %v6
+}
+
+; Test: Missing NSW flag on one of the shl instructions.
+define i1 @shl_add_const_eq_no_nsw_on_v1(i64 %v0, i64 %v3) {
+; CHECK-LABEL: @shl_add_const_eq_no_nsw_on_v1(
+; CHECK-NEXT:    [[V1:%.*]] = shl i64 [[V0:%.*]], 5
+; CHECK-NEXT:    [[V4:%.*]] = shl nsw i64 [[V3:%.*]], 5
+; CHECK-NEXT:    [[V5:%.*]] = add nsw i64 [[V4]], 32
+; CHECK-NEXT:    [[V6:%.*]] = icmp eq i64 [[V1]], [[V5]]
+; CHECK-NEXT:    ret i1 [[V6]]
+;
+  %v1 = shl i64 %v0, 5 ; Missing nsw
+  %v4 = shl nsw i64 %v3, 5
+  %v5 = add nsw i64 %v4, 32
+  %v6 = icmp eq i64 %v1, %v5
+  ret i1 %v6
+}
+
+; Test: Lower bit width (i8) and different shift amount (3). Constant is 8.
+define i1 @shl_add_const_eq_i8(i8 %v0, i8 %v3) {
+; CHECK-LABEL: @shl_add_const_eq_i8(
+; CHECK-NEXT:    [[TMP1:%.*]] = add nsw i8 [[V3:%.*]], 1
+; CHECK-NEXT:    [[V6:%.*]] = icmp eq i8 [[V0:%.*]], [[TMP1]]
+; CHECK-NEXT:    ret i1 [[V6]]
+;
+  %v1 = shl nsw i8 %v0, 3
+  %v4 = shl nsw i8 %v3, 3
+  %v5 = add nsw i8 %v4, 8 ; 2^3 = 8
+  %v6 = icmp eq i8 %v1, %v5
+  ret i1 %v6
+}
+
+; Test: i32 bit width and larger shift amount (10). Constant is 1024.
+define i1 @shl_add_const_eq_i32(i32 %v0, i32 %v3) {
+; CHECK-LABEL: @shl_add_const_eq_i32(
+; CHECK-NEXT:    [[TMP1:%.*]] = add nsw i32 [[V3:%.*]], 1
+; CHECK-NEXT:    [[V6:%.*]] = icmp eq i32 [[V0:%.*]], [[TMP1]]
+; CHECK-NEXT:    ret i1 [[V6]]
+;
+  %v1 = shl nsw i32 %v0, 10
+  %v4 = shl nsw i32 %v3, 10
+  %v5 = add nsw i32 %v4, 1024 ; 2^10 = 1024
+  %v6 = icmp eq i32 %v1, %v5
+  ret i1 %v6
+}
+
+; Test: Multi-use case. The optimization should still occur if applicable,
+; but the extraneous call must be preserved.
+define i1 @shl_add_const_eq_multi_use(i64 %v0, i64 %v3) {
+; CHECK-LABEL: @shl_add_const_eq_multi_use(
+; CHECK-NEXT:    [[V1:%.*]] = shl nsw i64 [[V0:%.*]], 5
+; CHECK-NEXT:    call void @use_i64(i64 [[V1]])
+; CHECK-NEXT:    [[TMP1:%.*]] = add nsw i64 [[V3:%.*]], 1
+; CHECK-NEXT:    [[V6:%.*]] = icmp eq i64 [[V0]], [[TMP1]]
+; CHECK-NEXT:    ret i1 [[V6]]
+;
+  %v1 = shl nsw i64 %v0, 5
+  call void @use_i64(i64 %v1) ; Additional use of v1
+  %v4 = shl nsw i64 %v3, 5
+  %v5 = add nsw i64 %v4, 32
+  %v6 = icmp eq i64 %v1, %v5
+  ret i1 %v6
+}
+
+; Test: Vector splat. Should fold once optimization is applied.
+define <2 x i1> @shl_add_const_eq_vec_splat(<2 x i64> %v0, <2 x i64> %v3) {
+; CHECK-LABEL: @shl_add_const_eq_vec_splat(
+; CHECK-NEXT:    [[V5:%.*]] = add nsw <2 x i64> [[V3:%.*]], splat (i64 1)
+; CHECK-NEXT:    [[V6:%.*]] = icmp eq <2 x i64> [[V1:%.*]], [[V5]]
+; CHECK-NEXT:    ret <2 x i1> [[V6]]
+;
+  %v1 = shl nsw <2 x i64> %v0, <i64 5, i64 5>
+  %v4 = shl nsw <2 x i64> %v3, <i64 5, i64 5>
+  %v5 = add nsw <2 x i64> %v4, <i64 32, i64 32>
+  %v6 = icmp eq <2 x i64> %v1, %v5
+  ret <2 x i1> %v6
+}
+
+; Test: Vector splat with poison. Should fold once optimization is applied.
+define <2 x i1> @shl_add_const_eq_vec_splat_poison(<2 x i64> %v0, <2 x i64> %v3) {
+; CHECK-LABEL: @shl_add_const_eq_vec_splat_poison(
+; CHECK-NEXT:    [[V5:%.*]] = add nsw <2 x i64> [[V3:%.*]], splat (i64 1)
+; CHECK-NEXT:    [[V6:%.*]] = icmp eq <2 x i64> [[V1:%.*]], [[V5]]
+; CHECK-NEXT:    ret <2 x i1> [[V6]]
+;
+  %v1 = shl nsw <2 x i64> %v0, <i64 5, i64 poison>
+  %v4 = shl nsw <2 x i64> %v3, <i64 5, i64 poison>
+  %v5 = add nsw <2 x i64> %v4, <i64 32, i64 poison>
+  %v6 = icmp eq <2 x i64> %v1, %v5
+  ret <2 x i1> %v6
+}
+
+; Test: Vector non-splat (should not fold).
+define <2 x i1> @shl_add_const_eq_vec_non_splat(<2 x i64> %v0, <2 x i64> %v3) {
+; CHECK-LABEL: @shl_add_const_eq_vec_non_splat(
+; CHECK-NEXT:    [[V1:%.*]] = shl nsw <2 x i64> [[V0:%.*]], <i64 5, i64 4>
+; CHECK-NEXT:    [[V4:%.*]] = shl nsw <2 x i64> [[V3:%.*]], <i64 5, i64 4>
+; CHECK-NEXT:    [[V5:%.*]] = add nsw <2 x i64> [[V4]], <i64 32, i64 16>
+; CHECK-NEXT:    [[V6:%.*]] = icmp eq <2 x i64> [[V1]], [[V5]]
+; CHECK-NEXT:    ret <2 x i1> [[V6]]
+;
+  %v1 = shl nsw <2 x i64> %v0, <i64 5, i64 4>
+  %v4 = shl nsw <2 x i64> %v3, <i64 5, i64 4>
+  %v5 = add nsw <2 x i64> %v4, <i64 32, i64 16>
+  %v6 = icmp eq <2 x i64> %v1, %v5
+  ret <2 x i1> %v6
+}
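
Note on the "alive2 proof" mentioned in the test header: the proof itself is not included in this patch. For the scalar i64 shape exercised by @shl_add_const_eq_base, a minimal Alive2-style src/tgt pair would look roughly like the sketch below; the @src/@tgt names and this exact formulation are illustrative, not the proof the comment refers to.

define i1 @src(i64 %x, i64 %y) {
  %xs = shl nsw i64 %x, 5      ; X * 32 with no signed wrap
  %ys = shl nsw i64 %y, 5      ; Y * 32 with no signed wrap
  %rhs = add nsw i64 %ys, 32   ; (Y + 1) * 32 with no signed wrap
  %c = icmp eq i64 %xs, %rhs
  ret i1 %c
}

define i1 @tgt(i64 %x, i64 %y) {
  %y1 = add nsw i64 %y, 1      ; Y + 1 cannot wrap if (Y << 5) + 32 did not
  %c = icmp eq i64 %x, %y1
  ret i1 %c
}

Because all three nsw flags rule out signed wrap, both sides of the original compare are exact multiples of 32, so the equality can be divided through by 32, which is what the new combine does.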