diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index eeb5eb8a262de..e09cba5b0059c 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -38999,6 +38999,26 @@ void X86TargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
     }
     break;
   }
+  case X86ISD::VPMADD52L:
+  case X86ISD::VPMADD52H: {
+    EVT VT = Op.getValueType();
+    assert(VT.isVector() && VT.getScalarType() == MVT::i64 &&
+           "Unexpected VPMADD52 type");
+    KnownBits K0 =
+        DAG.computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
+    KnownBits K1 =
+        DAG.computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
+    KnownBits KAcc =
+        DAG.computeKnownBits(Op.getOperand(2), DemandedElts, Depth + 1);
+    K0 = K0.trunc(52);
+    K1 = K1.trunc(52);
+    KnownBits KnownMul = (Op.getOpcode() == X86ISD::VPMADD52L)
+                             ? KnownBits::mul(K0, K1)
+                             : KnownBits::mulhu(K0, K1);
+    KnownMul = KnownMul.zext(64);
+    Known = KnownBits::add(KAcc, KnownMul);
+    return;
+  }
   }
 
   // Handle target shuffles.
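For reference, the per-lane arithmetic the new known-bits case models can be written out directly. The sketch below is an editorial illustration and not part of the patch: the helper name is made up and it assumes a compiler with unsigned __int128. VPMADD52L/VPMADD52H form the 104-bit product of the low 52 bits of the two source operands and add the low (L) or high (H) 52 bits of that product to the 64-bit accumulator, which is exactly the trunc(52), mul/mulhu, zext(64), add chain above.

#include <cstdint>
#include <cstdio>

// Illustrative only: one 64-bit lane of VPMADD52L (HighHalf = false) or
// VPMADD52H (HighHalf = true). Assumes unsigned __int128 support.
static uint64_t vpmadd52_lane(uint64_t Acc, uint64_t X, uint64_t Y,
                              bool HighHalf) {
  const uint64_t Low52 = (1ULL << 52) - 1;
  unsigned __int128 Prod =
      (unsigned __int128)(X & Low52) * (unsigned __int128)(Y & Low52);
  uint64_t Half = HighHalf ? (uint64_t)(Prod >> 52) : ((uint64_t)Prod & Low52);
  return Acc + Half; // wraps modulo 2^64, like the instruction
}

int main() {
  // With 25-bit sources the product never reaches bit 52, so the H form
  // returns the accumulator unchanged -- the fold the kb52h tests below rely on.
  uint64_t X = (1ULL << 25) - 1, Y = (1ULL << 25) - 1;
  std::printf("L: %llu H: %llu\n",
              (unsigned long long)vpmadd52_lane(1, X, Y, false),
              (unsigned long long)vpmadd52_lane(1, X, Y, true));
  return 0;
}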
diff --git a/llvm/test/CodeGen/X86/combine-vpmadd52HL.ll b/llvm/test/CodeGen/X86/combine-vpmadd52HL.ll
new file mode 100644
index 0000000000000..0b5be5fc9900b
--- /dev/null
+++ b/llvm/test/CodeGen/X86/combine-vpmadd52HL.ll
@@ -0,0 +1,138 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -mtriple=x86_64-- -mattr=+avx512ifma,+avx512vl | FileCheck %s --check-prefixes=AVX512VL
+
+; H path: take the high 52 bits of the product and add them to the accumulator
+; 25-bit = (1<<25)-1 = 33554431
+; 26-bit = (1<<26)-1 = 67108863
+
+declare <2 x i64> @llvm.x86.avx512.vpmadd52h.uq.128(<2 x i64>, <2 x i64>, <2 x i64>)
+declare <4 x i64> @llvm.x86.avx512.vpmadd52h.uq.256(<4 x i64>, <4 x i64>, <4 x i64>)
+declare <8 x i64> @llvm.x86.avx512.vpmadd52h.uq.512(<8 x i64>, <8 x i64>, <8 x i64>)
+
+define <2 x i64> @kb52h_128_mask25_and1(<2 x i64> %x, <2 x i64> %y) {
+; AVX512VL-LABEL: kb52h_128_mask25_and1:
+; AVX512VL:       # %bb.0:
+; AVX512VL-NEXT:    vmovddup {{.*#+}} xmm0 = [1,1]
+; AVX512VL-NEXT:    # xmm0 = mem[0,0]
+; AVX512VL-NEXT:    retq
+  %mx = and <2 x i64> %x, <i64 33554431, i64 33554431>
+  %my = and <2 x i64> %y, <i64 33554431, i64 33554431>
+  %r = call <2 x i64> @llvm.x86.avx512.vpmadd52h.uq.128(
+      <2 x i64> <i64 1, i64 1>, ; acc
+      <2 x i64> %mx,            ; x (masked to 25-bit)
+      <2 x i64> %my)            ; y (masked to 25-bit)
+  %ret = and <2 x i64> %r, <i64 1, i64 1>
+  ret <2 x i64> %ret
+}
+
+define <4 x i64> @kb52h_256_mask25x26_acc1(<4 x i64> %x, <4 x i64> %y) {
+; AVX512VL-LABEL: kb52h_256_mask25x26_acc1:
+; AVX512VL:       # %bb.0:
+; AVX512VL-NEXT:    vbroadcastsd {{.*#+}} ymm0 = [1,1,1,1]
+; AVX512VL-NEXT:    retq
+  %mx = and <4 x i64> %x, <i64 33554431, i64 33554431, i64 33554431, i64 33554431>
+  %my = and <4 x i64> %y, <i64 67108863, i64 67108863, i64 67108863, i64 67108863>
+  %r = call <4 x i64> @llvm.x86.avx512.vpmadd52h.uq.256(
+      <4 x i64> <i64 1, i64 1, i64 1, i64 1>,
+      <4 x i64> %mx,
+      <4 x i64> %my)
+  ret <4 x i64> %r
+}
+
+define <8 x i64> @kb52h_512_mask25_and1(<8 x i64> %x, <8 x i64> %y) {
+; AVX512VL-LABEL: kb52h_512_mask25_and1:
+; AVX512VL:       # %bb.0:
+; AVX512VL-NEXT:    vbroadcastsd {{.*#+}} zmm0 = [1,1,1,1,1,1,1,1]
+; AVX512VL-NEXT:    retq
+  %mx = and <8 x i64> %x, <i64 33554431, i64 33554431, i64 33554431, i64 33554431, i64 33554431, i64 33554431, i64 33554431, i64 33554431>
+  %my = and <8 x i64> %y, <i64 33554431, i64 33554431, i64 33554431, i64 33554431, i64 33554431, i64 33554431, i64 33554431, i64 33554431>
+  %r = call <8 x i64> @llvm.x86.avx512.vpmadd52h.uq.512(
+      <8 x i64> <i64 1, i64 1, i64 1, i64 1, i64 1, i64 1, i64 1, i64 1>,
+      <8 x i64> %mx,
+      <8 x i64> %my)
+  %ret = and <8 x i64> %r, <i64 1, i64 1, i64 1, i64 1, i64 1, i64 1, i64 1, i64 1>
+  ret <8 x i64> %ret
+}
+
+; 26-bit = 67108863 = (1<<26)-1
+; 50-bit = 1125899906842623 = (1<<50)-1
+
+declare <2 x i64> @llvm.x86.avx512.vpmadd52l.uq.128(<2 x i64>, <2 x i64>, <2 x i64>)
+declare <4 x i64> @llvm.x86.avx512.vpmadd52l.uq.256(<4 x i64>, <4 x i64>, <4 x i64>)
+declare <8 x i64> @llvm.x86.avx512.vpmadd52l.uq.512(<8 x i64>, <8 x i64>, <8 x i64>)
+
+define <2 x i64> @kb52l_128_mask26x26_add_intrin(<2 x i64> %x, <2 x i64> %y, <2 x i64> %acc) {
+; AVX512VL-LABEL: kb52l_128_mask26x26_add_intrin:
+; AVX512VL:       # %bb.0:
+; AVX512VL-NEXT:    vpbroadcastq {{.*#+}} xmm3 = [67108863,67108863]
+; AVX512VL-NEXT:    vpand %xmm3, %xmm0, %xmm0
+; AVX512VL-NEXT:    vpand %xmm3, %xmm1, %xmm1
+; AVX512VL-NEXT:    vpmadd52luq %xmm1, %xmm0, %xmm2
+; AVX512VL-NEXT:    vmovdqa %xmm2, %xmm0
+; AVX512VL-NEXT:    retq
+  %xm = and <2 x i64> %x, <i64 67108863, i64 67108863>
+  %ym = and <2 x i64> %y, <i64 67108863, i64 67108863>
+  %r = call <2 x i64> @llvm.x86.avx512.vpmadd52l.uq.128(<2 x i64> %acc, <2 x i64> %xm, <2 x i64> %ym)
+  ret <2 x i64> %r
+}
+
+define <4 x i64> @kb52l_256_mask50x3_add_intrin(<4 x i64> %x, <4 x i64> %y, <4 x i64> %acc) {
+; AVX512VL-LABEL: kb52l_256_mask50x3_add_intrin:
+; AVX512VL:       # %bb.0:
+; AVX512VL-NEXT:    vpandq {{\.?LCPI[0-9]+_[0-9]+}}(%rip){1to4}, %ymm0, %ymm0
+; AVX512VL-NEXT:    vpandq {{\.?LCPI[0-9]+_[0-9]+}}(%rip){1to4}, %ymm1, %ymm1
+; AVX512VL-NEXT:    vpmadd52luq %ymm1, %ymm0, %ymm2
+; AVX512VL-NEXT:    vmovdqa %ymm2, %ymm0
+; AVX512VL-NEXT:    retq
+  %xm = and <4 x i64> %x, <i64 1125899906842623, i64 1125899906842623, i64 1125899906842623, i64 1125899906842623>
+  %ym = and <4 x i64> %y, <i64 3, i64 3, i64 3, i64 3>
+  %r = call <4 x i64> @llvm.x86.avx512.vpmadd52l.uq.256(<4 x i64> %acc, <4 x i64> %xm, <4 x i64> %ym)
+  ret <4 x i64> %r
+}
+
+define <8 x i64> @kb52l_512_mask26x26_add_intrin(<8 x i64> %x, <8 x i64> %y, <8 x i64> %acc) {
+; AVX512-NOVL-LABEL: kb52l_512_mask26x26_add_intrin:
+; AVX512-NOVL: vpmadd52luq
+; AVX512-NOVL: retq
+; AVX512VL-LABEL: kb52l_512_mask26x26_add_intrin:
+; AVX512VL:       # %bb.0:
+; AVX512VL-NEXT:    vpbroadcastq {{.*#+}} zmm3 = [67108863,67108863,67108863,67108863,67108863,67108863,67108863,67108863]
+; AVX512VL-NEXT:    vpandq %zmm3, %zmm0, %zmm0
+; AVX512VL-NEXT:    vpandq %zmm3, %zmm1, %zmm1
+; AVX512VL-NEXT:    vpmadd52luq %zmm1, %zmm0, %zmm2
+; AVX512VL-NEXT:    vmovdqa64 %zmm2, %zmm0
+; AVX512VL-NEXT:    retq
+  %xm = and <8 x i64> %x, <i64 67108863, i64 67108863, i64 67108863, i64 67108863, i64 67108863, i64 67108863, i64 67108863, i64 67108863>
+  %ym = and <8 x i64> %y, <i64 67108863, i64 67108863, i64 67108863, i64 67108863, i64 67108863, i64 67108863, i64 67108863, i64 67108863>
+  %r = call <8 x i64> @llvm.x86.avx512.vpmadd52l.uq.512(<8 x i64> %acc, <8 x i64> %xm, <8 x i64> %ym)
+  ret <8 x i64> %r
+}
+
+define <2 x i64> @kb52l_128_neg_27x27_plain(<2 x i64> %x, <2 x i64> %y, <2 x i64> %acc) {
+; AVX512VL-LABEL: kb52l_128_neg_27x27_plain:
+; AVX512VL:       # %bb.0:
+; AVX512VL-NEXT:    vpbroadcastq {{.*#+}} xmm3 = [67108864,67108864]
+; AVX512VL-NEXT:    vpand %xmm3, %xmm0, %xmm0
+; AVX512VL-NEXT:    vpand %xmm3, %xmm1, %xmm1
+; AVX512VL-NEXT:    vpmuldq %xmm1, %xmm0, %xmm0
+; AVX512VL-NEXT:    vpaddq %xmm2, %xmm0, %xmm0
+; AVX512VL-NEXT:    retq
+  %xm = and <2 x i64> %x, <i64 67108864, i64 67108864> ; 1<<26
+  %ym = and <2 x i64> %y, <i64 67108864, i64 67108864>
+  %mul = mul <2 x i64> %xm, %ym
+  %res = add <2 x i64> %mul, %acc
+  ret <2 x i64> %res
+}
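The masks in the tests above are chosen so that the width of the product is provable from known bits alone. The standalone check below is an editorial sketch, not part of the patch; the constant names mirror the test comments, and the 50-bit-by-3 pairing is inferred from the kb52l_256_mask50x3 test name.

#include <cassert>
#include <cstdint>

int main() {
  const unsigned __int128 One = 1;
  const uint64_t Mask25 = (1ULL << 25) - 1; // 33554431
  const uint64_t Mask26 = (1ULL << 26) - 1; // 67108863
  const uint64_t Mask50 = (1ULL << 50) - 1; // 1125899906842623

  // H-path folds: a 25-bit by 25- or 26-bit product has no bits above bit 50,
  // so the high 52 bits are zero and VPMADD52H reduces to its accumulator.
  assert(((unsigned __int128)Mask25 * Mask26) >> 52 == 0);

  // L-path folds: 26x26 and 50-bit-by-3 products fit in 52 bits, so the
  // low 52 bits are the exact product.
  assert((unsigned __int128)Mask26 * Mask26 < (One << 52));
  assert((unsigned __int128)Mask50 * 3 < (One << 52));

  // Negative test: with bit 26 possibly set on both operands (27-bit values),
  // the product can reach 2^52, so the plain mul+add must stay a mul+add.
  assert((unsigned __int128)(1ULL << 26) * (1ULL << 26) == (One << 52));
  return 0;
}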