diff --git a/llvm/include/llvm/Analysis/InstructionSimplify.h b/llvm/include/llvm/Analysis/InstructionSimplify.h index 0a2f199794f8b7..d75e0415679494 100644 --- a/llvm/include/llvm/Analysis/InstructionSimplify.h +++ b/llvm/include/llvm/Analysis/InstructionSimplify.h @@ -143,17 +143,36 @@ struct SimplifyQuery { // deprecated. // Please use the SimplifyQuery versions in new code. -/// Given operand for an FNeg, fold the result or return null. -Value *simplifyFNegInst(Value *Op, FastMathFlags FMF, const SimplifyQuery &Q); - /// Given operands for an Add, fold the result or return null. -Value *simplifyAddInst(Value *LHS, Value *RHS, bool isNSW, bool isNUW, +Value *simplifyAddInst(Value *LHS, Value *RHS, bool IsNSW, bool IsNUW, const SimplifyQuery &Q); /// Given operands for a Sub, fold the result or return null. -Value *simplifySubInst(Value *LHS, Value *RHS, bool isNSW, bool isNUW, +Value *simplifySubInst(Value *LHS, Value *RHS, bool IsNSW, bool IsNUW, + const SimplifyQuery &Q); + +/// Given operands for a Mul, fold the result or return null. +Value *simplifyMulInst(Value *LHS, Value *RHS, bool IsNSW, bool IsNUW, const SimplifyQuery &Q); +/// Given operands for an SDiv, fold the result or return null. +Value *simplifySDivInst(Value *LHS, Value *RHS, bool IsExact, + const SimplifyQuery &Q); + +/// Given operands for a UDiv, fold the result or return null. +Value *simplifyUDivInst(Value *LHS, Value *RHS, bool IsExact, + const SimplifyQuery &Q); + +/// Given operands for an SRem, fold the result or return null. +Value *simplifySRemInst(Value *LHS, Value *RHS, const SimplifyQuery &Q); + +/// Given operands for a URem, fold the result or return null. +Value *simplifyURemInst(Value *LHS, Value *RHS, const SimplifyQuery &Q); + +/// Given operand for an FNeg, fold the result or return null. +Value *simplifyFNegInst(Value *Op, FastMathFlags FMF, const SimplifyQuery &Q); + + /// Given operands for an FAdd, fold the result or return null. Value * simplifyFAddInst(Value *LHS, Value *RHS, FastMathFlags FMF, @@ -184,17 +203,6 @@ Value *simplifyFMAFMul(Value *LHS, Value *RHS, FastMathFlags FMF, fp::ExceptionBehavior ExBehavior = fp::ebIgnore, RoundingMode Rounding = RoundingMode::NearestTiesToEven); -/// Given operands for a Mul, fold the result or return null. -Value *simplifyMulInst(Value *LHS, Value *RHS, const SimplifyQuery &Q); - -/// Given operands for an SDiv, fold the result or return null. -Value *simplifySDivInst(Value *LHS, Value *RHS, bool IsExact, - const SimplifyQuery &Q); - -/// Given operands for a UDiv, fold the result or return null. -Value *simplifyUDivInst(Value *LHS, Value *RHS, bool IsExact, - const SimplifyQuery &Q); - /// Given operands for an FDiv, fold the result or return null. Value * simplifyFDivInst(Value *LHS, Value *RHS, FastMathFlags FMF, @@ -202,12 +210,6 @@ simplifyFDivInst(Value *LHS, Value *RHS, FastMathFlags FMF, fp::ExceptionBehavior ExBehavior = fp::ebIgnore, RoundingMode Rounding = RoundingMode::NearestTiesToEven); -/// Given operands for an SRem, fold the result or return null. -Value *simplifySRemInst(Value *LHS, Value *RHS, const SimplifyQuery &Q); - -/// Given operands for a URem, fold the result or return null. -Value *simplifyURemInst(Value *LHS, Value *RHS, const SimplifyQuery &Q); - /// Given operands for an FRem, fold the result or return null. Value * simplifyFRemInst(Value *LHS, Value *RHS, FastMathFlags FMF, @@ -216,15 +218,15 @@ simplifyFRemInst(Value *LHS, Value *RHS, FastMathFlags FMF, RoundingMode Rounding = RoundingMode::NearestTiesToEven); /// Given operands for a Shl, fold the result or return null. -Value *simplifyShlInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, +Value *simplifyShlInst(Value *Op0, Value *Op1, bool IsNSW, bool IsNUW, const SimplifyQuery &Q); /// Given operands for a LShr, fold the result or return null. -Value *simplifyLShrInst(Value *Op0, Value *Op1, bool isExact, +Value *simplifyLShrInst(Value *Op0, Value *Op1, bool IsExact, const SimplifyQuery &Q); /// Given operands for a AShr, fold the result or return nulll. -Value *simplifyAShrInst(Value *Op0, Value *Op1, bool isExact, +Value *simplifyAShrInst(Value *Op0, Value *Op1, bool IsExact, const SimplifyQuery &Q); /// Given operands for an And, fold the result or return null. diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp index 5f6548b9cd59c5..53434c4cb2ac3c 100644 --- a/llvm/lib/Analysis/InstructionSimplify.cpp +++ b/llvm/lib/Analysis/InstructionSimplify.cpp @@ -922,8 +922,8 @@ Value *llvm::simplifySubInst(Value *Op0, Value *Op1, bool IsNSW, bool IsNUW, /// Given operands for a Mul, see if we can fold the result. /// If not, this returns null. -static Value *simplifyMulInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, - unsigned MaxRecurse) { +static Value *simplifyMulInst(Value *Op0, Value *Op1, bool IsNSW, bool IsNUW, + const SimplifyQuery &Q, unsigned MaxRecurse) { if (Constant *C = foldOrCommuteConstant(Instruction::Mul, Op0, Op1, Q)) return C; @@ -980,8 +980,9 @@ static Value *simplifyMulInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, return nullptr; } -Value *llvm::simplifyMulInst(Value *Op0, Value *Op1, const SimplifyQuery &Q) { - return ::simplifyMulInst(Op0, Op1, Q, RecursionLimit); +Value *llvm::simplifyMulInst(Value *Op0, Value *Op1, bool IsNSW, bool IsNUW, + const SimplifyQuery &Q) { + return ::simplifyMulInst(Op0, Op1, IsNSW, IsNUW, Q, RecursionLimit); } /// Check for common or similar folds of integer division or integer remainder. @@ -5707,7 +5708,8 @@ static Value *simplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS, return simplifySubInst(LHS, RHS, /* IsNSW */ false, /* IsNUW */ false, Q, MaxRecurse); case Instruction::Mul: - return simplifyMulInst(LHS, RHS, Q, MaxRecurse); + return simplifyMulInst(LHS, RHS, /* IsNSW */ false, /* IsNUW */ false, Q, + MaxRecurse); case Instruction::SDiv: return simplifySDivInst(LHS, RHS, /* IsExact */ false, Q, MaxRecurse); case Instruction::UDiv: @@ -6582,7 +6584,9 @@ static Value *simplifyInstructionWithOperands(Instruction *I, case Instruction::FMul: return simplifyFMulInst(NewOps[0], NewOps[1], I->getFastMathFlags(), Q); case Instruction::Mul: - return simplifyMulInst(NewOps[0], NewOps[1], Q); + return simplifyMulInst(NewOps[0], NewOps[1], + Q.IIQ.hasNoSignedWrap(cast(I)), + Q.IIQ.hasNoUnsignedWrap(cast(I)), Q); case Instruction::SDiv: return simplifySDivInst(NewOps[0], NewOps[1], Q.IIQ.isExact(cast(I)), Q); diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp index 460731d29c8b74..97f129e200de72 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp @@ -187,7 +187,9 @@ static Value *foldMulShl1(BinaryOperator &Mul, bool CommuteOperands, Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); - if (Value *V = simplifyMulInst(Op0, Op1, SQ.getWithInstruction(&I))) + if (Value *V = + simplifyMulInst(Op0, Op1, I.hasNoSignedWrap(), I.hasNoUnsignedWrap(), + SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); if (SimplifyAssociativeOrCommutative(I)) diff --git a/llvm/test/Transforms/InstSimplify/mul.ll b/llvm/test/Transforms/InstSimplify/mul.ll index 902bde54841f6d..443a2250b0a200 100644 --- a/llvm/test/Transforms/InstSimplify/mul.ll +++ b/llvm/test/Transforms/InstSimplify/mul.ll @@ -50,12 +50,39 @@ define i32 @poison(i32 %x) { ret i32 %v } +define i1 @mul_i1(i1 %x, i1 %y) { +; CHECK-LABEL: @mul_i1( +; CHECK-NEXT: [[R:%.*]] = mul i1 [[X:%.*]], [[Y:%.*]] +; CHECK-NEXT: ret i1 [[R]] +; + %r = mul i1 %x, %y + ret i1 %r +} + +define i1 @mul_i1_nsw(i1 %x, i1 %y) { +; CHECK-LABEL: @mul_i1_nsw( +; CHECK-NEXT: [[R:%.*]] = mul nsw i1 [[X:%.*]], [[Y:%.*]] +; CHECK-NEXT: ret i1 [[R]] +; + %r = mul nsw i1 %x, %y + ret i1 %r +} + +define i1 @mul_i1_nuw(i1 %x, i1 %y) { +; CHECK-LABEL: @mul_i1_nuw( +; CHECK-NEXT: [[R:%.*]] = mul nuw i1 [[X:%.*]], [[Y:%.*]] +; CHECK-NEXT: ret i1 [[R]] +; + %r = mul nuw i1 %x, %y + ret i1 %r +} + define i1 @square_i1(i1 %x) { ; CHECK-LABEL: @square_i1( ; CHECK-NEXT: ret i1 [[X:%.*]] ; %r = mul i1 %x, %x - ret i1 %x + ret i1 %r } define i1 @square_i1_nsw(i1 %x) { @@ -63,7 +90,7 @@ define i1 @square_i1_nsw(i1 %x) { ; CHECK-NEXT: ret i1 [[X:%.*]] ; %r = mul nsw i1 %x, %x - ret i1 %x + ret i1 %r } define i1 @square_i1_nuw(i1 %x) { @@ -71,5 +98,5 @@ define i1 @square_i1_nuw(i1 %x) { ; CHECK-NEXT: ret i1 [[X:%.*]] ; %r = mul nuw i1 %x, %x - ret i1 %x + ret i1 %r }