diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index bff09f5676680..c91d20be10e9f 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -1487,6 +1487,9 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
   if (Instruction *Phi = foldBinopWithPhiOperands(I))
     return Phi;
 
+  if (Value *R = foldOpOfXWithXEqC(&I, SQ.getWithInstruction(&I)))
+    return replaceInstUsesWith(I, R);
+
   // (A*B)+(A*C) -> A*(B+C) etc
   if (Value *V = foldUsingDistributiveLaws(I))
     return replaceInstUsesWith(I, V);
@@ -2092,6 +2095,9 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) {
   if (Instruction *Phi = foldBinopWithPhiOperands(I))
     return Phi;
 
+  if (Value *R = foldOpOfXWithXEqC(&I, SQ.getWithInstruction(&I)))
+    return replaceInstUsesWith(I, R);
+
   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
 
   // If this is a 'B = x-(-A)', change to B = x+A.
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 8695e9e69df20..16f6dfbd995e0 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -2275,6 +2275,9 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) {
   if (Instruction *Phi = foldBinopWithPhiOperands(I))
     return Phi;
 
+  if (Value *R = foldOpOfXWithXEqC(&I, SQ.getWithInstruction(&I)))
+    return replaceInstUsesWith(I, R);
+
   // See if we can simplify any instructions used by the instruction whose sole
   // purpose is to compute bits we don't care about.
   if (SimplifyDemandedInstructionBits(I))
@@ -3438,6 +3441,9 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
   if (Instruction *Phi = foldBinopWithPhiOperands(I))
     return Phi;
 
+  if (Value *R = foldOpOfXWithXEqC(&I, SQ.getWithInstruction(&I)))
+    return replaceInstUsesWith(I, R);
+
   // See if we can simplify any instructions used by the instruction whose sole
   // purpose is to compute bits we don't care about.
   if (SimplifyDemandedInstructionBits(I))
@@ -4571,6 +4577,9 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) {
   if (Instruction *NewXor = foldXorToXor(I, Builder))
     return NewXor;
 
+  if (Value *R = foldOpOfXWithXEqC(&I, SQ.getWithInstruction(&I)))
+    return replaceInstUsesWith(I, R);
+
   // (A&B)^(A&C) -> A&(B^C) etc
   if (Value *V = foldUsingDistributiveLaws(I))
     return replaceInstUsesWith(I, V);
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index b6f339da31f7f..636c1284caea3 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -1491,6 +1491,9 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
   IntrinsicInst *II = dyn_cast<IntrinsicInst>(&CI);
   if (!II)
     return visitCallBase(CI);
 
+  if (Value *R = foldOpOfXWithXEqC(II, SQ.getWithInstruction(&CI)))
+    return replaceInstUsesWith(CI, R);
+
   // For atomic unordered mem intrinsics if len is not a positive or
   // not a multiple of element size then behavior is undefined.
   if (auto *AMI = dyn_cast<AtomicMemIntrinsic>(II))
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index db7838bbe3c25..79a500f532329 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -755,6 +755,7 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
 
   Value *EvaluateInDifferentType(Value *V, Type *Ty, bool isSigned);
 
+  Value *foldOpOfXWithXEqC(Value *Op, const SimplifyQuery &SQ);
   bool tryToSinkInstruction(Instruction *I, BasicBlock *DestBlock);
   void tryToSinkInstructionDbgValues(
       Instruction *I, BasicBlock::iterator InsertPos, BasicBlock *SrcBlock,
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index ca1b1921404d8..86958b040002e 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -204,6 +204,9 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) {
   if (Instruction *Phi = foldBinopWithPhiOperands(I))
     return Phi;
 
+  if (Value *R = foldOpOfXWithXEqC(&I, SQ.getWithInstruction(&I)))
+    return replaceInstUsesWith(I, R);
+
   if (Value *V = foldUsingDistributiveLaws(I))
     return replaceInstUsesWith(I, V);
 
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
index 0f1979fbe0c76..4e4a0a97ddee8 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -1020,6 +1020,9 @@ Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) {
   if (Instruction *V = dropRedundantMaskingOfLeftShiftInput(&I, Q, Builder))
     return V;
 
+  if (Value *R = foldOpOfXWithXEqC(&I, Q))
+    return replaceInstUsesWith(I, R);
+
   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
   Type *Ty = I.getType();
   unsigned BitWidth = Ty->getScalarSizeInBits();
@@ -1252,6 +1255,9 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) {
   if (Instruction *R = commonShiftTransforms(I))
     return R;
 
+  if (Value *R = foldOpOfXWithXEqC(&I, SQ.getWithInstruction(&I)))
+    return replaceInstUsesWith(I, R);
+
   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
   Type *Ty = I.getType();
   Value *X;
@@ -1625,6 +1631,9 @@ Instruction *InstCombinerImpl::visitAShr(BinaryOperator &I) {
   if (Instruction *R = commonShiftTransforms(I))
     return R;
 
+  if (Value *R = foldOpOfXWithXEqC(&I, SQ.getWithInstruction(&I)))
+    return replaceInstUsesWith(I, R);
+
   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
   Type *Ty = I.getType();
   unsigned BitWidth = Ty->getScalarSizeInBits();
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index eb48157af009c..f43d97bb1bcf9 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -4829,6 +4829,89 @@ void InstCombinerImpl::tryToSinkInstructionDbgValues(
   }
 }
 
+// If we have:
+//   `(op X, (zext/sext (icmp eq X, C)))`
+// we can transform it to:
+//   `(select (icmp eq X, C), (op C, (zext/sext 1)), (op X, 0))`
+// We do so only if the `zext/sext` has one use and `(op X, 0)` simplifies.
+Value *InstCombinerImpl::foldOpOfXWithXEqC(Value *Op, const SimplifyQuery &SQ) {
+  Value *Cond;
+  Constant *C, *ExtC;
+
+  // Match `(op X, (zext/sext (icmp eq X, C)))` and see if `(op X, 0)`
+  // simplifies.
+  // If we match and simplify, store the `icmp` in `Cond` and `(zext/sext 1)`
+  // in `ExtC`.
+  auto MatchXWithXEqC = [&](Value *Op0, Value *Op1) -> Value * {
+    if (match(Op0, m_OneUse(m_ZExtOrSExt(m_Value(Cond))))) {
+      ICmpInst::Predicate Pred;
+      if (!match(Cond, m_ICmp(Pred, m_Specific(Op1), m_ImmConstant(C))) ||
+          Pred != ICmpInst::ICMP_EQ)
+        return nullptr;
+
+      ExtC = isa<SExtInst>(Op0) ? ConstantInt::getAllOnesValue(C->getType())
+                                : ConstantInt::get(C->getType(), 1);
+      return simplifyWithOpReplaced(Op, Op0,
+                                    Constant::getNullValue(Op1->getType()), SQ,
+                                    /*AllowRefinement=*/true);
+    }
+    return nullptr;
+  };
+
+  Value *SimpleOp = nullptr, *ConstOp = nullptr;
+  if (auto *BO = dyn_cast<BinaryOperator>(Op)) {
+    switch (BO->getOpcode()) {
+    // Potential TODO: For all of these, if Op1 is the compare, the compare
+    // must be true and we could replace Op0 with C (otherwise immediate UB).
+    case Instruction::UDiv:
+    case Instruction::SDiv:
+    case Instruction::URem:
+    case Instruction::SRem:
+      return nullptr;
+    default:
+      break;
+    }
+
+    // Try X is Op0.
+    if ((SimpleOp = MatchXWithXEqC(BO->getOperand(0), BO->getOperand(1))))
+      ConstOp = Builder.CreateBinOp(BO->getOpcode(), ExtC, C);
+    // Try X is Op1.
+    else if ((SimpleOp = MatchXWithXEqC(BO->getOperand(1), BO->getOperand(0))))
+      ConstOp = Builder.CreateBinOp(BO->getOpcode(), C, ExtC);
+  } else if (auto *II = dyn_cast<IntrinsicInst>(Op)) {
+    switch (II->getIntrinsicID()) {
+    default:
+      return nullptr;
+    case Intrinsic::sshl_sat:
+    case Intrinsic::ushl_sat:
+    case Intrinsic::umax:
+    case Intrinsic::umin:
+    case Intrinsic::smax:
+    case Intrinsic::smin:
+    case Intrinsic::uadd_sat:
+    case Intrinsic::usub_sat:
+    case Intrinsic::sadd_sat:
+    case Intrinsic::ssub_sat:
+      // Try X is Op0.
+      if ((SimpleOp =
+               MatchXWithXEqC(II->getArgOperand(0), II->getArgOperand(1))))
+        ConstOp = Builder.CreateBinaryIntrinsic(II->getIntrinsicID(), ExtC, C);
+      // Try X is Op1.
+      else if ((SimpleOp =
+                    MatchXWithXEqC(II->getArgOperand(1), II->getArgOperand(0))))
+        ConstOp = Builder.CreateBinaryIntrinsic(II->getIntrinsicID(), C, ExtC);
+      break;
+    }
+  }
+
+  assert((SimpleOp == nullptr) == (ConstOp == nullptr) &&
+         "Simplified Op and Constant Op are de-synced!");
+  if (SimpleOp == nullptr)
+    return nullptr;
+
+  return Builder.CreateSelect(Cond, ConstOp, SimpleOp);
+}
+
 void InstCombinerImpl::tryToSinkInstructionDbgVariableRecords(
     Instruction *I, BasicBlock::iterator InsertPos, BasicBlock *SrcBlock,
     BasicBlock *DestBlock,
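A minimal IR sketch of the new fold (illustrative only, not part of the patch; it mirrors the fold_mul_sext_eq_12 test below): when the compare is true, X equals C and the ext is 1/-1, so the true arm constant-folds; when it is false, the ext is 0, so the fold only fires if `(op X, 0)` simplifies.

  ; Before:
  %x_eq = icmp eq i8 %x, 12
  %x_eq_ext = sext i1 %x_eq to i8    ; -1 if %x == 12, 0 otherwise
  %r = mul i8 %x, %x_eq_ext
  ; After: `mul i8 %x, 0` simplifies to 0 (false arm), and the true arm
  ; constant-folds to `mul i8 -1, 12` = -12.
  %x_eq = icmp eq i8 %x, 12
  %r = select i1 %x_eq, i8 -12, i8 0
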
diff --git a/llvm/test/Transforms/InstCombine/apint-shift.ll b/llvm/test/Transforms/InstCombine/apint-shift.ll
index 05c3db70ce1ca..f508939b73321 100644
--- a/llvm/test/Transforms/InstCombine/apint-shift.ll
+++ b/llvm/test/Transforms/InstCombine/apint-shift.ll
@@ -564,14 +564,7 @@ define i40 @test26(i40 %A) {
 ; https://bugs.chromium.org/p/oss-fuzz/issues/detail?id=9880
 define i177 @ossfuzz_9880(i177 %X) {
 ; CHECK-LABEL: @ossfuzz_9880(
-; CHECK-NEXT:    [[A:%.*]] = alloca i177, align 8
-; CHECK-NEXT:    [[L1:%.*]] = load i177, ptr [[A]], align 4
-; CHECK-NEXT:    [[TMP1:%.*]] = icmp eq i177 [[L1]], -1
-; CHECK-NEXT:    [[B5_NEG:%.*]] = sext i1 [[TMP1]] to i177
-; CHECK-NEXT:    [[B14:%.*]] = add i177 [[L1]], [[B5_NEG]]
-; CHECK-NEXT:    [[TMP2:%.*]] = icmp eq i177 [[B14]], -1
-; CHECK-NEXT:    [[B1:%.*]] = zext i1 [[TMP2]] to i177
-; CHECK-NEXT:    ret i177 [[B1]]
+; CHECK-NEXT:    ret i177 0
 ;
   %A = alloca i177
   %L1 = load i177, ptr %A
diff --git a/llvm/test/Transforms/InstCombine/fold-ext-eq-c-with-op.ll b/llvm/test/Transforms/InstCombine/fold-ext-eq-c-with-op.ll
new file mode 100644
index 0000000000000..870f976a97c5f
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/fold-ext-eq-c-with-op.ll
@@ -0,0 +1,189 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+
+declare void @use.i8(i8)
+
+define i8 @fold_add_zext_eq_0(i8 %x) {
+; CHECK-LABEL: @fold_add_zext_eq_0(
+; CHECK-NEXT:    [[R:%.*]] = call i8 @llvm.umax.i8(i8 [[X:%.*]], i8 1)
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %x_eq = icmp eq i8 %x, 0
+  %x_eq_ext = zext i1 %x_eq to i8
+  %r = add i8 %x, %x_eq_ext
+  ret i8 %r
+}
+
+define i8 @fold_add_sext_eq_0(i8 %x) {
+; CHECK-LABEL: @fold_add_sext_eq_0(
+; CHECK-NEXT:    [[X_EQ:%.*]] = icmp eq i8 [[X:%.*]], 0
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[X_EQ]], i8 -1, i8 [[X]]
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %x_eq = icmp eq i8 %x, 0
+  %x_eq_ext = sext i1 %x_eq to i8
+  %r = add i8 %x, %x_eq_ext
+  ret i8 %r
+}
+
+define i8 @fold_add_zext_eq_0_fail_multiuse_exp(i8 %x) {
+; CHECK-LABEL: @fold_add_zext_eq_0_fail_multiuse_exp(
+; CHECK-NEXT:    [[X_EQ:%.*]] = icmp eq i8 [[X:%.*]], 0
+; CHECK-NEXT:    [[X_EQ_EXT:%.*]] = zext i1 [[X_EQ]] to i8
+; CHECK-NEXT:    [[R:%.*]] = add i8 [[X_EQ_EXT]], [[X]]
+; CHECK-NEXT:    call void @use.i8(i8 [[X_EQ_EXT]])
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %x_eq = icmp eq i8 %x, 0
+  %x_eq_ext = zext i1 %x_eq to i8
+  %r = add i8 %x, %x_eq_ext
+  call void @use.i8(i8 %x_eq_ext)
+  ret i8 %r
+}
+
+define i8 @fold_mul_sext_eq_12(i8 %x) {
+; CHECK-LABEL: @fold_mul_sext_eq_12(
+; CHECK-NEXT:    [[X_EQ:%.*]] = icmp eq i8 [[X:%.*]], 12
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[X_EQ]], i8 -12, i8 0
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %x_eq = icmp eq i8 %x, 12
+  %x_eq_ext = sext i1 %x_eq to i8
+  %r = mul i8 %x, %x_eq_ext
+  ret i8 %r
+}
+
+define i8 @fold_mul_sext_eq_12_fail_multiuse(i8 %x) {
+; CHECK-LABEL: @fold_mul_sext_eq_12_fail_multiuse(
+; CHECK-NEXT:    [[X_EQ:%.*]] = icmp eq i8 [[X:%.*]], 12
+; CHECK-NEXT:    [[X_EQ_EXT:%.*]] = sext i1 [[X_EQ]] to i8
+; CHECK-NEXT:    [[R:%.*]] = mul i8 [[X_EQ_EXT]], [[X]]
+; CHECK-NEXT:    call void @use.i8(i8 [[X_EQ_EXT]])
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %x_eq = icmp eq i8 %x, 12
+  %x_eq_ext = sext i1 %x_eq to i8
+  %r = mul i8 %x, %x_eq_ext
+  call void @use.i8(i8 %x_eq_ext)
+  ret i8 %r
+}
+
+define i8 @fold_shl_zext_eq_3_rhs(i8 %x) {
+; CHECK-LABEL: @fold_shl_zext_eq_3_rhs(
+; CHECK-NEXT:    [[X_EQ:%.*]] = icmp eq i8 [[X:%.*]], 3
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[X_EQ]], i8 6, i8 [[X]]
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %x_eq = icmp eq i8 %x, 3
+  %x_eq_ext = zext i1 %x_eq to i8
+  %r = shl i8 %x, %x_eq_ext
+  ret i8 %r
+}
+
+define i8 @fold_shl_zext_eq_3_lhs(i8 %x) {
+; CHECK-LABEL: @fold_shl_zext_eq_3_lhs(
+; CHECK-NEXT:    [[X_EQ:%.*]] = icmp eq i8 [[X:%.*]], 3
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[X_EQ]], i8 8, i8 0
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %x_eq = icmp eq i8 %x, 3
+  %x_eq_ext = zext i1 %x_eq to i8
+  %r = shl i8 %x_eq_ext, %x
+  ret i8 %r
+}
+
+define <2 x i8> @fold_lshr_sext_eq_15_5_lhs(<2 x i8> %x) {
+; CHECK-LABEL: @fold_lshr_sext_eq_15_5_lhs(
+; CHECK-NEXT:    [[X_EQ:%.*]] = icmp eq <2 x i8> [[X:%.*]], <i8 15, i8 5>
+; CHECK-NEXT:    [[R:%.*]] = select <2 x i1> [[X_EQ]], <2 x i8> <i8 poison, i8 7>, <2 x i8> zeroinitializer
+; CHECK-NEXT:    ret <2 x i8> [[R]]
+;
+  %x_eq = icmp eq <2 x i8> %x, <i8 15, i8 5>
+  %x_eq_ext = sext <2 x i1> %x_eq to <2 x i8>
+  %r = lshr <2 x i8> %x_eq_ext, %x
+  ret <2 x i8> %r
+}
+
+define <2 x i8> @fold_lshr_sext_eq_15_poison_rhs(<2 x i8> %x) {
+; CHECK-LABEL: @fold_lshr_sext_eq_15_poison_rhs(
+; CHECK-NEXT:    ret <2 x i8> [[X:%.*]]
+;
+  %x_eq = icmp eq <2 x i8> %x, <i8 15, i8 poison>
+  %x_eq_ext = sext <2 x i1> %x_eq to <2 x i8>
+  %r = lshr <2 x i8> %x, %x_eq_ext
+  ret <2 x i8> %r
+}
+
+define i8 @fold_umax_zext_eq_9(i8 %x) {
+; CHECK-LABEL: @fold_umax_zext_eq_9(
+; CHECK-NEXT:    [[X_EQ:%.*]] = icmp eq i8 [[X:%.*]], 9
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[X_EQ]], i8 -1, i8 [[X]]
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %x_eq = icmp eq i8 %x, 9
+  %x_eq_ext = sext i1 %x_eq to i8
+  %r = call i8 @llvm.umax.i8(i8 %x, i8 %x_eq_ext)
+  ret i8 %r
+}
+
+define i8 @fold_sshl_sat_sext_eq_3_rhs(i8 %x) {
+; CHECK-LABEL: @fold_sshl_sat_sext_eq_3_rhs(
+; CHECK-NEXT:    [[X_EQ:%.*]] = icmp eq i8 [[X:%.*]], 3
+; CHECK-NEXT:    [[X_EQ_EXT:%.*]] = sext i1 [[X_EQ]] to i8
+; CHECK-NEXT:    [[R:%.*]] = call i8 @llvm.sshl.sat.i8(i8 [[X]], i8 [[X_EQ_EXT]])
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %x_eq = icmp eq i8 %x, 3
+  %x_eq_ext = sext i1 %x_eq to i8
+  %r = call i8 @llvm.sshl.sat.i8(i8 %x, i8 %x_eq_ext)
+  ret i8 %r
+}
+
+define i8 @fold_ushl_sat_zext_eq_3_lhs(i8 %x) {
+; CHECK-LABEL: @fold_ushl_sat_zext_eq_3_lhs(
+; CHECK-NEXT:    [[X_EQ:%.*]] = icmp eq i8 [[X:%.*]], 3
+; CHECK-NEXT:    [[X_EQ_EXT:%.*]] = zext i1 [[X_EQ]] to i8
+; CHECK-NEXT:    [[R:%.*]] = call i8 @llvm.ushl.sat.i8(i8 [[X_EQ_EXT]], i8 [[X]])
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %x_eq = icmp eq i8 %x, 3
+  %x_eq_ext = zext i1 %x_eq to i8
+  %r = call i8 @llvm.ushl.sat.i8(i8 %x_eq_ext, i8 %x)
+  ret i8 %r
+}
+
+define i8 @fold_uadd_sat_zext_eq_3_rhs(i8 %x) {
+; CHECK-LABEL: @fold_uadd_sat_zext_eq_3_rhs(
+; CHECK-NEXT:    [[X_EQ:%.*]] = icmp eq i8 [[X:%.*]], 3
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[X_EQ]], i8 4, i8 [[X]]
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %x_eq = icmp eq i8 %x, 3
+  %x_eq_ext = zext i1 %x_eq to i8
+  %r = call i8 @llvm.uadd.sat.i8(i8 %x, i8 %x_eq_ext)
+  ret i8 %r
+}
+
+define i8 @fold_ssub_sat_sext_eq_99_lhs_fail(i8 %x) {
+; CHECK-LABEL: @fold_ssub_sat_sext_eq_99_lhs_fail(
+; CHECK-NEXT:    [[X_EQ:%.*]] = icmp eq i8 [[X:%.*]], 99
+; CHECK-NEXT:    [[X_EQ_EXT:%.*]] = sext i1 [[X_EQ]] to i8
+; CHECK-NEXT:    [[R:%.*]] = call i8 @llvm.ssub.sat.i8(i8 [[X_EQ_EXT]], i8 [[X]])
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %x_eq = icmp eq i8 %x, 99
+  %x_eq_ext = sext i1 %x_eq to i8
+  %r = call i8 @llvm.ssub.sat.i8(i8 %x_eq_ext, i8 %x)
+  ret i8 %r
+}
+
+define i8 @fold_ssub_sat_zext_eq_99_rhs(i8 %x) {
+; CHECK-LABEL: @fold_ssub_sat_zext_eq_99_rhs(
+; CHECK-NEXT:    [[X_EQ:%.*]] = icmp eq i8 [[X:%.*]], 99
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[X_EQ]], i8 98, i8 [[X]]
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %x_eq = icmp eq i8 %x, 99
+  %x_eq_ext = zext i1 %x_eq to i8
+  %r = call i8 @llvm.ssub.sat.i8(i8 %x, i8 %x_eq_ext)
+  ret i8 %r
+}
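
A worked illustration of the "Potential TODO" note in foldOpOfXWithXEqC (a sketch, not part of the patch): for div/rem, when the extended compare is the divisor, the compare cannot be false on a well-defined path, so X itself could be replaced with C.

  ; Hypothetical example: if %x != 0, the divisor below is 0 and the udiv is
  ; immediate UB; on any well-defined path %x == 0 must hold, so this could
  ; fold to `udiv i8 0, 1`, i.e. 0.
  %x_eq = icmp eq i8 %x, 0
  %x_eq_ext = zext i1 %x_eq to i8
  %r = udiv i8 %x, %x_eq_ext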