diff --git a/llvm/lib/Transforms/Scalar/EarlyCSE.cpp b/llvm/lib/Transforms/Scalar/EarlyCSE.cpp index 4990fa9f8b5ea..f736d429cb638 100644 --- a/llvm/lib/Transforms/Scalar/EarlyCSE.cpp +++ b/llvm/lib/Transforms/Scalar/EarlyCSE.cpp @@ -314,14 +314,14 @@ static unsigned getHashValueImpl(SimpleValue Val) { "Invalid/unknown instruction"); // Handle intrinsics with commutative operands. - // TODO: Extend this to handle intrinsics with >2 operands where the 1st - // 2 operands are commutative. auto *II = dyn_cast(Inst); - if (II && II->isCommutative() && II->arg_size() == 2) { + if (II && II->isCommutative() && II->arg_size() >= 2) { Value *LHS = II->getArgOperand(0), *RHS = II->getArgOperand(1); if (LHS > RHS) std::swap(LHS, RHS); - return hash_combine(II->getOpcode(), LHS, RHS); + return hash_combine( + II->getOpcode(), LHS, RHS, + hash_combine_range(II->value_op_begin() + 2, II->value_op_end())); } // gc.relocate is 'special' call: its second and third operands are @@ -396,13 +396,14 @@ static bool isEqualImpl(SimpleValue LHS, SimpleValue RHS) { LHSCmp->getSwappedPredicate() == RHSCmp->getPredicate(); } - // TODO: Extend this for >2 args by matching the trailing N-2 args. auto *LII = dyn_cast(LHSI); auto *RII = dyn_cast(RHSI); if (LII && RII && LII->getIntrinsicID() == RII->getIntrinsicID() && - LII->isCommutative() && LII->arg_size() == 2) { + LII->isCommutative() && LII->arg_size() >= 2) { return LII->getArgOperand(0) == RII->getArgOperand(1) && - LII->getArgOperand(1) == RII->getArgOperand(0); + LII->getArgOperand(1) == RII->getArgOperand(0) && + std::equal(LII->arg_begin() + 2, LII->arg_end(), + RII->arg_begin() + 2, RII->arg_end()); } // See comment above in `getHashValue()`. diff --git a/llvm/test/Transforms/EarlyCSE/commute.ll b/llvm/test/Transforms/EarlyCSE/commute.ll index 6aaaf992e4414..1cf7ddda7f0dd 100644 --- a/llvm/test/Transforms/EarlyCSE/commute.ll +++ b/llvm/test/Transforms/EarlyCSE/commute.ll @@ -999,14 +999,10 @@ define i4 @smin_umin(i4 %a, i4 %b) { ret i4 %o } -; TODO: handle >2 args - define i16 @smul_fix(i16 %a, i16 %b) { ; CHECK-LABEL: @smul_fix( ; CHECK-NEXT: [[X:%.*]] = call i16 @llvm.smul.fix.i16(i16 [[A:%.*]], i16 [[B:%.*]], i32 3) -; CHECK-NEXT: [[Y:%.*]] = call i16 @llvm.smul.fix.i16(i16 [[B]], i16 [[A]], i32 3) -; CHECK-NEXT: [[O:%.*]] = or i16 [[X]], [[Y]] -; CHECK-NEXT: ret i16 [[O]] +; CHECK-NEXT: ret i16 [[X]] ; %x = call i16 @llvm.smul.fix.i16(i16 %a, i16 %b, i32 3) %y = call i16 @llvm.smul.fix.i16(i16 %b, i16 %a, i32 3) @@ -1014,14 +1010,10 @@ define i16 @smul_fix(i16 %a, i16 %b) { ret i16 %o } -; TODO: handle >2 args - define i16 @umul_fix(i16 %a, i16 %b, i32 %s) { ; CHECK-LABEL: @umul_fix( ; CHECK-NEXT: [[X:%.*]] = call i16 @llvm.umul.fix.i16(i16 [[A:%.*]], i16 [[B:%.*]], i32 1) -; CHECK-NEXT: [[Y:%.*]] = call i16 @llvm.umul.fix.i16(i16 [[B]], i16 [[A]], i32 1) -; CHECK-NEXT: [[O:%.*]] = or i16 [[X]], [[Y]] -; CHECK-NEXT: ret i16 [[O]] +; CHECK-NEXT: ret i16 [[X]] ; %x = call i16 @llvm.umul.fix.i16(i16 %a, i16 %b, i32 1) %y = call i16 @llvm.umul.fix.i16(i16 %b, i16 %a, i32 1) @@ -1029,14 +1021,10 @@ define i16 @umul_fix(i16 %a, i16 %b, i32 %s) { ret i16 %o } -; TODO: handle >2 args - define <3 x i16> @smul_fix_sat(<3 x i16> %a, <3 x i16> %b) { ; CHECK-LABEL: @smul_fix_sat( ; CHECK-NEXT: [[X:%.*]] = call <3 x i16> @llvm.smul.fix.sat.v3i16(<3 x i16> [[A:%.*]], <3 x i16> [[B:%.*]], i32 2) -; CHECK-NEXT: [[Y:%.*]] = call <3 x i16> @llvm.smul.fix.sat.v3i16(<3 x i16> [[B]], <3 x i16> [[A]], i32 2) -; CHECK-NEXT: [[O:%.*]] = or <3 x i16> [[X]], [[Y]] -; CHECK-NEXT: ret <3 x i16> [[O]] +; CHECK-NEXT: ret <3 x i16> [[X]] ; %x = call <3 x i16> @llvm.smul.fix.sat.v3i16(<3 x i16> %a, <3 x i16> %b, i32 2) %y = call <3 x i16> @llvm.smul.fix.sat.v3i16(<3 x i16> %b, <3 x i16> %a, i32 2) @@ -1044,14 +1032,10 @@ define <3 x i16> @smul_fix_sat(<3 x i16> %a, <3 x i16> %b) { ret <3 x i16> %o } -; TODO: handle >2 args - define <3 x i16> @umul_fix_sat(<3 x i16> %a, <3 x i16> %b) { ; CHECK-LABEL: @umul_fix_sat( ; CHECK-NEXT: [[X:%.*]] = call <3 x i16> @llvm.umul.fix.sat.v3i16(<3 x i16> [[A:%.*]], <3 x i16> [[B:%.*]], i32 3) -; CHECK-NEXT: [[Y:%.*]] = call <3 x i16> @llvm.umul.fix.sat.v3i16(<3 x i16> [[B]], <3 x i16> [[A]], i32 3) -; CHECK-NEXT: [[O:%.*]] = or <3 x i16> [[X]], [[Y]] -; CHECK-NEXT: ret <3 x i16> [[O]] +; CHECK-NEXT: ret <3 x i16> [[X]] ; %x = call <3 x i16> @llvm.umul.fix.sat.v3i16(<3 x i16> %a, <3 x i16> %b, i32 3) %y = call <3 x i16> @llvm.umul.fix.sat.v3i16(<3 x i16> %b, <3 x i16> %a, i32 3) @@ -1085,17 +1069,26 @@ define i16 @umul_fix_scale(i16 %a, i16 %b, i32 %s) { ret i16 %o } -; TODO: handle >2 args - define float @fma(float %a, float %b, float %c) { ; CHECK-LABEL: @fma( ; CHECK-NEXT: [[X:%.*]] = call float @llvm.fma.f32(float [[A:%.*]], float [[B:%.*]], float [[C:%.*]]) -; CHECK-NEXT: [[Y:%.*]] = call float @llvm.fma.f32(float [[B]], float [[A]], float [[C]]) +; CHECK-NEXT: ret float 1.000000e+00 +; + %x = call float @llvm.fma.f32(float %a, float %b, float %c) + %y = call float @llvm.fma.f32(float %b, float %a, float %c) + %r = fdiv nnan float %x, %y + ret float %r +} + +define float @fma_fail(float %a, float %b, float %c) { +; CHECK-LABEL: @fma_fail( +; CHECK-NEXT: [[X:%.*]] = call float @llvm.fma.f32(float [[A:%.*]], float [[B:%.*]], float [[C:%.*]]) +; CHECK-NEXT: [[Y:%.*]] = call float @llvm.fma.f32(float [[A]], float [[C]], float [[B]]) ; CHECK-NEXT: [[R:%.*]] = fdiv nnan float [[X]], [[Y]] ; CHECK-NEXT: ret float [[R]] ; %x = call float @llvm.fma.f32(float %a, float %b, float %c) - %y = call float @llvm.fma.f32(float %b, float %a, float %c) + %y = call float @llvm.fma.f32(float %a, float %c, float %b) %r = fdiv nnan float %x, %y ret float %r } @@ -1113,17 +1106,39 @@ define float @fma_different_add_ops(float %a, float %b, float %c, float %d) { ret float %r } -; TODO: handle >2 args - define <2 x double> @fmuladd(<2 x double> %a, <2 x double> %b, <2 x double> %c) { ; CHECK-LABEL: @fmuladd( ; CHECK-NEXT: [[X:%.*]] = call <2 x double> @llvm.fmuladd.v2f64(<2 x double> [[A:%.*]], <2 x double> [[B:%.*]], <2 x double> [[C:%.*]]) -; CHECK-NEXT: [[Y:%.*]] = call <2 x double> @llvm.fmuladd.v2f64(<2 x double> [[B]], <2 x double> [[A]], <2 x double> [[C]]) +; CHECK-NEXT: ret <2 x double> +; + %x = call <2 x double> @llvm.fmuladd.v2f64(<2 x double> %a, <2 x double> %b, <2 x double> %c) + %y = call <2 x double> @llvm.fmuladd.v2f64(<2 x double> %b, <2 x double> %a, <2 x double> %c) + %r = fdiv nnan <2 x double> %x, %y + ret <2 x double> %r +} + +define <2 x double> @fmuladd_fail1(<2 x double> %a, <2 x double> %b, <2 x double> %c) { +; CHECK-LABEL: @fmuladd_fail1( +; CHECK-NEXT: [[X:%.*]] = call <2 x double> @llvm.fmuladd.v2f64(<2 x double> [[A:%.*]], <2 x double> [[B:%.*]], <2 x double> [[C:%.*]]) +; CHECK-NEXT: [[Y:%.*]] = call <2 x double> @llvm.fmuladd.v2f64(<2 x double> [[C]], <2 x double> [[B]], <2 x double> [[A]]) ; CHECK-NEXT: [[R:%.*]] = fdiv nnan <2 x double> [[X]], [[Y]] ; CHECK-NEXT: ret <2 x double> [[R]] ; %x = call <2 x double> @llvm.fmuladd.v2f64(<2 x double> %a, <2 x double> %b, <2 x double> %c) - %y = call <2 x double> @llvm.fmuladd.v2f64(<2 x double> %b, <2 x double> %a, <2 x double> %c) + %y = call <2 x double> @llvm.fmuladd.v2f64(<2 x double> %c, <2 x double> %b, <2 x double> %a) + %r = fdiv nnan <2 x double> %x, %y + ret <2 x double> %r +} + +define <2 x double> @fmuladd_fail2(<2 x double> %a, <2 x double> %b, <2 x double> %c) { +; CHECK-LABEL: @fmuladd_fail2( +; CHECK-NEXT: [[X:%.*]] = call <2 x double> @llvm.fmuladd.v2f64(<2 x double> [[A:%.*]], <2 x double> [[B:%.*]], <2 x double> [[C:%.*]]) +; CHECK-NEXT: [[Y:%.*]] = call <2 x double> @llvm.fmuladd.v2f64(<2 x double> [[A]], <2 x double> [[C]], <2 x double> [[B]]) +; CHECK-NEXT: [[R:%.*]] = fdiv nnan <2 x double> [[X]], [[Y]] +; CHECK-NEXT: ret <2 x double> [[R]] +; + %x = call <2 x double> @llvm.fmuladd.v2f64(<2 x double> %a, <2 x double> %b, <2 x double> %c) + %y = call <2 x double> @llvm.fmuladd.v2f64(<2 x double> %a, <2 x double> %c, <2 x double> %b) %r = fdiv nnan <2 x double> %x, %y ret <2 x double> %r }