diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 3631016b0f5c7..635169b2feea8 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -60186,8 +60186,30 @@ static SDValue combineVPMADD(SDNode *N, SelectionDAG &DAG, static SDValue combineVPMADD52LH(SDNode *N, SelectionDAG &DAG, TargetLowering::DAGCombinerInfo &DCI) { MVT VT = N->getSimpleValueType(0); - unsigned NumEltBits = VT.getScalarSizeInBits(); + + bool AddLow = N->getOpcode() == X86ISD::VPMADD52L; + SDValue Op0 = N->getOperand(0); + SDValue Op1 = N->getOperand(1); + SDValue Op2 = N->getOperand(2); + SDLoc DL(N); + + APInt C0, C1; + bool HasC0 = X86::isConstantSplat(Op0, C0), + HasC1 = X86::isConstantSplat(Op1, C1); + + // lo/hi(C * X) + Z --> lo/hi(X * C) + Z + if (HasC0 && !HasC1) + return DAG.getNode(N->getOpcode(), DL, VT, Op1, Op0, Op2); + + // lo(X * 1) + Z --> lo(X) + Z iff X == lo(X) + if (AddLow && HasC1 && C1.trunc(52).isOne()) { + KnownBits KnownOp0 = DAG.computeKnownBits(Op0); + if (KnownOp0.countMinLeadingZeros() >= 12) + return DAG.getNode(ISD::ADD, DL, VT, Op0, Op2); + } + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + unsigned NumEltBits = VT.getScalarSizeInBits(); if (TLI.SimplifyDemandedBits(SDValue(N, 0), APInt::getAllOnes(NumEltBits), DCI)) return SDValue(N, 0); diff --git a/llvm/test/CodeGen/X86/combine-vpmadd52.ll b/llvm/test/CodeGen/X86/combine-vpmadd52.ll index 2cb060ea92b14..8b741e9ef9482 100644 --- a/llvm/test/CodeGen/X86/combine-vpmadd52.ll +++ b/llvm/test/CodeGen/X86/combine-vpmadd52.ll @@ -398,3 +398,60 @@ define <2 x i64> @test3_knownbits_vpmadd52h_negative(<2 x i64> %x0, <2 x i64> %x %ret = and <2 x i64> %madd, splat (i64 1) ret <2 x i64> %ret } + +define <2 x i64> @test_vpmadd52l_mul_one(<2 x i64> %x0, <2 x i32> %x1) { +; CHECK-LABEL: test_vpmadd52l_mul_one: +; CHECK: # %bb.0: +; CHECK-NEXT: vpmovzxdq {{.*#+}} xmm1 = xmm1[0],zero,xmm1[1],zero +; CHECK-NEXT: vpaddq %xmm0, %xmm1, %xmm0 +; CHECK-NEXT: retq + %ext = zext <2 x i32> %x1 to <2 x i64> + %ifma = call <2 x i64> @llvm.x86.avx512.vpmadd52l.uq.128(<2 x i64> %x0, <2 x i64> splat(i64 1), <2 x i64> %ext) + ret <2 x i64> %ifma +} + +define <2 x i64> @test_vpmadd52l_mul_one_commuted(<2 x i64> %x0, <2 x i32> %x1) { +; CHECK-LABEL: test_vpmadd52l_mul_one_commuted: +; CHECK: # %bb.0: +; CHECK-NEXT: vpmovzxdq {{.*#+}} xmm1 = xmm1[0],zero,xmm1[1],zero +; CHECK-NEXT: vpaddq %xmm0, %xmm1, %xmm0 +; CHECK-NEXT: retq + %ext = zext <2 x i32> %x1 to <2 x i64> + %ifma = call <2 x i64> @llvm.x86.avx512.vpmadd52l.uq.128(<2 x i64> %x0, <2 x i64> %ext, <2 x i64> splat(i64 1)) + ret <2 x i64> %ifma +} + +define <2 x i64> @test_vpmadd52l_mul_one_no_mask(<2 x i64> %x0, <2 x i64> %x1) { +; AVX512-LABEL: test_vpmadd52l_mul_one_no_mask: +; AVX512: # %bb.0: +; AVX512-NEXT: vpmadd52luq {{\.?LCPI[0-9]+_[0-9]+}}(%rip){1to2}, %xmm1, %xmm0 +; AVX512-NEXT: retq +; +; AVX-LABEL: test_vpmadd52l_mul_one_no_mask: +; AVX: # %bb.0: +; AVX-NEXT: {vex} vpmadd52luq {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm0 +; AVX-NEXT: retq + %ifma = call <2 x i64> @llvm.x86.avx512.vpmadd52l.uq.128(<2 x i64> %x0, <2 x i64> splat(i64 1), <2 x i64> %x1) + ret <2 x i64> %ifma +} + +; Mul by (1 << 52) + 1 +define <2 x i64> @test_vpmadd52l_mul_one_in_52bits(<2 x i64> %x0, <2 x i32> %x1) { +; CHECK-LABEL: test_vpmadd52l_mul_one_in_52bits: +; CHECK: # %bb.0: +; CHECK-NEXT: vpmovzxdq {{.*#+}} xmm1 = xmm1[0],zero,xmm1[1],zero +; CHECK-NEXT: vpaddq %xmm0, %xmm1, %xmm0 +; CHECK-NEXT: retq + %ext = zext <2 x i32> %x1 to <2 x i64> + %ifma = call <2 x i64> @llvm.x86.avx512.vpmadd52l.uq.128(<2 x i64> %x0, <2 x i64> splat(i64 4503599627370497), <2 x i64> %ext) + ret <2 x i64> %ifma +} + +; lo(x1) * 1 = lo(x1), the high 52 bits are zeroes still. +define <2 x i64> @test_vpmadd52h_mul_one(<2 x i64> %x0, <2 x i64> %x1) { +; CHECK-LABEL: test_vpmadd52h_mul_one: +; CHECK: # %bb.0: +; CHECK-NEXT: retq + %ifma = call <2 x i64> @llvm.x86.avx512.vpmadd52h.uq.128(<2 x i64> %x0, <2 x i64> splat(i64 1), <2 x i64> %x1) + ret <2 x i64> %ifma +}