
Conversation

@goldsteinn (Contributor) commented Aug 20, 2024

  • [X86] Add tests for shrinking width of masked loads/stores; NFC
  • [X86] Shrink width of masked loads/stores

In the best case we can convert a masked load/store to a narrower
normal load/store, e.g. _mm512_maskz_load_ps(p, 0xff) can be done
with just a normal ymm load. Likewise, if the mask is entirely
contained in a lower sub-vector, we can shrink the load/store, e.g.
_mm512_maskz_load_ps(p, 0x1c) is the same as
_mm256_maskz_load_ps(p, 0x1c).
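
For illustration only (editorial addition, not part of the patch): a minimal intrinsics-level sketch of the first equivalence, assuming p is valid and 64-byte aligned so the narrower aligned load is also legal; the function names are made up for the example.

#include <immintrin.h>

// Masked 512-bit load that only demands the low 8 lanes and zeroes the rest.
__m512 load_low_half_masked(const float *p) {
  return _mm512_maskz_load_ps(0x00ff, p);
}

// Same result via a plain 256-bit load zero-extended to 512 bits, which is
// what the combine aims to produce (requires AVX-512F for the zext intrinsic).
__m512 load_low_half_plain(const float *p) {
  return _mm512_zextps256_ps512(_mm256_load_ps(p));
}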

@goldsteinn goldsteinn requested a review from RKSimon August 20, 2024 23:57
@llvmbot llvmbot added the backend:X86 and llvm:SelectionDAG labels Aug 20, 2024
@llvmbot (Member) commented Aug 20, 2024

@llvm/pr-subscribers-backend-x86

@llvm/pr-subscribers-llvm-selectiondag

Author: None (goldsteinn)

Changes
  • [X86] Add tests for shrinking width of masked loads/stores; NFC
  • [X86] Shrinking width of masked loads/stores

Patch is 42.88 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/105451.diff

8 Files Affected:

  • (modified) llvm/include/llvm/CodeGen/SelectionDAGNodes.h (+5)
  • (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp (+18)
  • (modified) llvm/lib/Target/X86/X86ISelLowering.cpp (+104-2)
  • (added) llvm/test/CodeGen/X86/masked-load-store-shrink.ll (+716)
  • (modified) llvm/test/CodeGen/X86/masked_load.ll (+6-6)
  • (modified) llvm/test/CodeGen/X86/masked_loadstore_split.ll (+3-3)
  • (modified) llvm/test/CodeGen/X86/masked_store.ll (+4-4)
  • (modified) llvm/test/CodeGen/X86/pr46532.ll (+1-1)
diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
index 88549d9c9a2858..b2151cd32f1088 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -1848,6 +1848,11 @@ bool isOneOrOneSplat(SDValue V, bool AllowUndefs = false);
 /// Does not permit build vector implicit truncation.
 bool isAllOnesOrAllOnesSplat(SDValue V, bool AllowUndefs = false);
 
+/// Returns the demanded elements from the mask of a masked op (i.e
+/// MSTORE/MLOAD).
+APInt getDemandedEltsForMaskedOp(SDValue Mask, unsigned NumElts,
+                                 SmallVector<SDValue> *MaskEltsOut = nullptr);
+
 /// Return true if \p V is either a integer or FP constant.
 inline bool isIntOrFPConstant(SDValue V) {
   return isa<ConstantSDNode>(V) || isa<ConstantFPSDNode>(V);
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 18a3b7bce104a7..cb0e098a1e511e 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -12128,6 +12128,24 @@ bool llvm::isAllOnesOrAllOnesSplat(SDValue N, bool AllowUndefs) {
   return C && C->isAllOnes() && C->getValueSizeInBits(0) == BitWidth;
 }
 
+APInt llvm::getDemandedEltsForMaskedOp(SDValue Mask, unsigned NumElts,
+                                       SmallVector<SDValue> *MaskEltsOut) {
+  if (!ISD::isBuildVectorOfConstantSDNodes(Mask.getNode()))
+    return APInt::getAllOnes(NumElts);
+  APInt Demanded = APInt::getZero(NumElts);
+  BuildVectorSDNode *MaskBV = cast<BuildVectorSDNode>(Mask);
+  for (unsigned i = 0; i < MaskBV->getNumOperands(); ++i) {
+    APInt V;
+    if (!sd_match(MaskBV->getOperand(i), m_ConstInt(V)))
+      return APInt::getAllOnes(NumElts);
+    if (V.isNegative())
+      Demanded.setBit(i);
+    if (MaskEltsOut)
+      MaskEltsOut->emplace_back(MaskBV->getOperand(i));
+  }
+  return Demanded;
+}
+
 HandleSDNode::~HandleSDNode() {
   DropOperands();
 }
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 169c955f0ba89f..e8bbfaa41b2cb4 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -51536,20 +51536,111 @@ combineMaskedLoadConstantMask(MaskedLoadSDNode *ML, SelectionDAG &DAG,
   return DCI.CombineTo(ML, Blend, NewML.getValue(1), true);
 }
 
+static bool tryShrinkMaskedOperation(SelectionDAG &DAG, const SDLoc &DL,
+                                     SDValue Mask, EVT OrigVT,
+                                     SDValue *ValInOut, EVT *NewVTOut,
+                                     SDValue *NewMaskOut) {
+  // Ensure we have a reasonable input type.
+  // Also ensure ensure input bits is larger then xmm, otherwise its not
+  // profitable to try to shrink.
+  if (!OrigVT.isSimple() || !OrigVT.isVector() ||
+      OrigVT.getSizeInBits() <= 128 || !isPowerOf2_64(OrigVT.getSizeInBits()) ||
+      !isPowerOf2_64(OrigVT.getScalarSizeInBits()))
+    return false;
+
+  SmallVector<SDValue> OrigMask;
+  APInt DemandedElts = getDemandedEltsForMaskedOp(
+      Mask, OrigVT.getVectorNumElements(), &OrigMask);
+  if (DemandedElts.isAllOnes() || DemandedElts.isZero())
+    return false;
+
+  unsigned OrigNumElts = OrigVT.getVectorNumElements();
+  unsigned ReqElts =
+      DemandedElts.getBitWidth() - DemandedElts.countLeadingZeros();
+  // We can't shrink out vector category in a meaningful way.
+  if (ReqElts > OrigNumElts / 2U)
+    return false;
+
+  // At most shrink to xmm.
+  unsigned NewNumElts =
+      std::max(128U / OrigVT.getScalarSizeInBits(), PowerOf2Ceil(ReqElts));
+
+  EVT NewVT =
+      EVT::getVectorVT(*DAG.getContext(), OrigVT.getScalarType(), NewNumElts);
+  if (!NewVT.isSimple())
+    return false;
+
+  // Extract all the value arguments;
+  if (ValInOut && *ValInOut)
+    *ValInOut = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NewVT, *ValInOut,
+                            DAG.getIntPtrConstant(0, DL));
+  if (NewVTOut)
+    *NewVTOut = NewVT;
+  *NewMaskOut = SDValue();
+  // The mask was just truncating, so don't need it anymore.
+  if (NewNumElts == ReqElts && DemandedElts.isMask())
+    return true;
+
+  // Get smaller mask.
+  EVT NewMaskVT = EVT::getVectorVT(
+      *DAG.getContext(), Mask.getValueType().getScalarType(), NewNumElts);
+  OrigMask.truncate(NewNumElts);
+  *NewMaskOut = DAG.getBuildVector(NewMaskVT, DL, OrigMask);
+  return true;
+}
+
+static bool tryShrinkMaskedOperation(SelectionDAG &DAG, const SDLoc &DL,
+                                     SDValue Mask, EVT OrigVT, EVT *NewVTOut,
+                                     SDValue *NewMaskOut) {
+  return tryShrinkMaskedOperation(DAG, DL, Mask, OrigVT, nullptr, NewVTOut,
+                                  NewMaskOut);
+}
+
+static bool tryShrinkMaskedOperation(SelectionDAG &DAG, const SDLoc &DL,
+                                     SDValue Mask, EVT OrigVT,
+                                     SDValue *ValInOut, SDValue *NewMaskOut) {
+  return tryShrinkMaskedOperation(DAG, DL, Mask, OrigVT, ValInOut, nullptr,
+                                  NewMaskOut);
+}
+
 static SDValue combineMaskedLoad(SDNode *N, SelectionDAG &DAG,
                                  TargetLowering::DAGCombinerInfo &DCI,
                                  const X86Subtarget &Subtarget) {
+  using namespace llvm::SDPatternMatch;
   auto *Mld = cast<MaskedLoadSDNode>(N);
+  SDLoc DL(N);
 
   // TODO: Expanding load with constant mask may be optimized as well.
   if (Mld->isExpandingLoad())
     return SDValue();
 
+  SDValue Mask = Mld->getMask();
+  EVT VT = Mld->getValueType(0);
   if (Mld->getExtensionType() == ISD::NON_EXTLOAD) {
     if (SDValue ScalarLoad =
             reduceMaskedLoadToScalarLoad(Mld, DAG, DCI, Subtarget))
       return ScalarLoad;
 
+    SDValue NewMask;
+    EVT NewVT;
+    if (sd_match(Mld->getPassThru(), m_Zero()) &&
+        tryShrinkMaskedOperation(DAG, DL, Mask, VT, &NewVT, &NewMask)) {
+      SDValue NewLoad;
+      if (NewMask)
+        NewLoad = DAG.getMaskedLoad(
+            NewVT, DL, Mld->getChain(), Mld->getBasePtr(), Mld->getOffset(),
+            NewMask, getZeroVector(NewVT.getSimpleVT(), Subtarget, DAG, DL),
+            Mld->getMemoryVT(), Mld->getMemOperand(), Mld->getAddressingMode(),
+            Mld->getExtensionType());
+      else
+        NewLoad = DAG.getLoad(NewVT, DL, Mld->getChain(), Mld->getBasePtr(),
+                              Mld->getMemOperand());
+
+      SDValue R = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, Mld->getPassThru(),
+                              NewLoad, DAG.getIntPtrConstant(0, DL));
+      return DCI.CombineTo(Mld, R, NewLoad.getValue(1), true);
+    }
+
     // TODO: Do some AVX512 subsets benefit from this transform?
     if (!Subtarget.hasAVX512())
       if (SDValue Blend = combineMaskedLoadConstantMask(Mld, DAG, DCI))
@@ -51558,9 +51649,7 @@ static SDValue combineMaskedLoad(SDNode *N, SelectionDAG &DAG,
 
   // If the mask value has been legalized to a non-boolean vector, try to
   // simplify ops leading up to it. We only demand the MSB of each lane.
-  SDValue Mask = Mld->getMask();
   if (Mask.getScalarValueSizeInBits() != 1) {
-    EVT VT = Mld->getValueType(0);
     const TargetLowering &TLI = DAG.getTargetLoweringInfo();
     APInt DemandedBits(APInt::getSignMask(VT.getScalarSizeInBits()));
     if (TLI.SimplifyDemandedBits(Mask, DemandedBits, DCI)) {
@@ -51622,6 +51711,8 @@ static SDValue combineMaskedStore(SDNode *N, SelectionDAG &DAG,
   if (Mst->isCompressingStore())
     return SDValue();
 
+
+
   EVT VT = Mst->getValue().getValueType();
   SDLoc dl(Mst);
   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
@@ -51651,6 +51742,17 @@ static SDValue combineMaskedStore(SDNode *N, SelectionDAG &DAG,
   }
 
   SDValue Value = Mst->getValue();
+  SDValue NewMask;
+  if (tryShrinkMaskedOperation(DAG, dl, Mask, VT, &Value, &NewMask)) {
+    if (NewMask)
+      return DAG.getMaskedStore(Mst->getChain(), dl, Value, Mst->getBasePtr(),
+                                Mst->getOffset(), NewMask, Mst->getMemoryVT(),
+                                Mst->getMemOperand(), Mst->getAddressingMode());
+    return DAG.getStore(Mst->getChain(), SDLoc(N), Value, Mst->getBasePtr(),
+                        Mst->getPointerInfo(), Mst->getOriginalAlign(),
+                        Mst->getMemOperand()->getFlags());
+  }
+
   if (Value.getOpcode() == ISD::TRUNCATE && Value.getNode()->hasOneUse() &&
       TLI.isTruncStoreLegal(Value.getOperand(0).getValueType(),
                             Mst->getMemoryVT())) {
diff --git a/llvm/test/CodeGen/X86/masked-load-store-shrink.ll b/llvm/test/CodeGen/X86/masked-load-store-shrink.ll
new file mode 100644
index 00000000000000..c1cdef54bc3737
--- /dev/null
+++ b/llvm/test/CodeGen/X86/masked-load-store-shrink.ll
@@ -0,0 +1,716 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc < %s -disable-peephole -mtriple=x86_64-apple-darwin -mattr=sse2    | FileCheck %s --check-prefixes=SSE,SSE2
+; RUN: llc < %s -disable-peephole -mtriple=x86_64-apple-darwin -mattr=sse4.2  | FileCheck %s --check-prefixes=SSE,SSE42
+; RUN: llc < %s -disable-peephole -mtriple=x86_64-apple-darwin -mattr=avx     | FileCheck %s --check-prefixes=AVX,AVX1OR2,AVX1
+; RUN: llc < %s -disable-peephole -mtriple=x86_64-apple-darwin -mattr=avx2    | FileCheck %s --check-prefixes=AVX,AVX1OR2,AVX2
+; RUN: llc < %s -disable-peephole -mtriple=x86_64-apple-darwin -mattr=avx512f | FileCheck %s --check-prefixes=AVX,AVX512,AVX512F
+; RUN: llc < %s -disable-peephole -mtriple=x86_64-apple-darwin -mattr=avx512f,avx512dq,avx512vl | FileCheck %s --check-prefixes=AVX,AVX512,AVX512VL,AVX512VLDQ
+; RUN: llc < %s -disable-peephole -mtriple=x86_64-apple-darwin -mattr=avx512f,avx512bw,avx512vl | FileCheck %s --check-prefixes=AVX,AVX512,AVX512VL,AVX512VLBW
+; RUN: llc < %s -mtriple=i686-apple-darwin -mattr=avx512f,avx512bw,avx512dq,avx512vl -verify-machineinstrs | FileCheck %s --check-prefixes=X86-AVX512
+
+define <4 x i64> @mload256_to_load128(ptr %p) nounwind {
+; SSE-LABEL: mload256_to_load128:
+; SSE:       ## %bb.0:
+; SSE-NEXT:    movups (%rdi), %xmm0
+; SSE-NEXT:    xorps %xmm1, %xmm1
+; SSE-NEXT:    retq
+;
+; AVX-LABEL: mload256_to_load128:
+; AVX:       ## %bb.0:
+; AVX-NEXT:    vmovaps (%rdi), %xmm0
+; AVX-NEXT:    retq
+;
+; X86-AVX512-LABEL: mload256_to_load128:
+; X86-AVX512:       ## %bb.0:
+; X86-AVX512-NEXT:    movl {{[0-9]+}}(%esp), %eax
+; X86-AVX512-NEXT:    vmovaps (%eax), %xmm0
+; X86-AVX512-NEXT:    retl
+  %tmp = tail call <8 x float> @llvm.masked.load.v8f32.p0(ptr %p, i32 32, <8 x i1> <i1 true, i1 true, i1 true, i1 true, i1 false, i1 false, i1 false, i1 false>, <8 x float> <float poison, float poison, float poison, float poison, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00>)
+  %r = bitcast <8 x float> %tmp to <4 x i64>
+  ret <4 x i64> %r
+}
+
+define <8 x i64> @mload512_to_load256(ptr %p) nounwind {
+; SSE-LABEL: mload512_to_load256:
+; SSE:       ## %bb.0:
+; SSE-NEXT:    movups (%rdi), %xmm0
+; SSE-NEXT:    movups 16(%rdi), %xmm1
+; SSE-NEXT:    xorps %xmm2, %xmm2
+; SSE-NEXT:    xorps %xmm3, %xmm3
+; SSE-NEXT:    retq
+;
+; AVX1OR2-LABEL: mload512_to_load256:
+; AVX1OR2:       ## %bb.0:
+; AVX1OR2-NEXT:    vmovups (%rdi), %ymm0
+; AVX1OR2-NEXT:    vxorps %xmm1, %xmm1, %xmm1
+; AVX1OR2-NEXT:    retq
+;
+; AVX512-LABEL: mload512_to_load256:
+; AVX512:       ## %bb.0:
+; AVX512-NEXT:    vmovups (%rdi), %ymm0
+; AVX512-NEXT:    retq
+;
+; X86-AVX512-LABEL: mload512_to_load256:
+; X86-AVX512:       ## %bb.0:
+; X86-AVX512-NEXT:    movl {{[0-9]+}}(%esp), %eax
+; X86-AVX512-NEXT:    vmovups (%eax), %ymm0
+; X86-AVX512-NEXT:    retl
+  %tmp = tail call <32 x i16> @llvm.masked.load.v32i16.p0(ptr %p, i32 1, <32 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false>, <32 x i16> <i16 poison, i16 poison, i16 poison, i16 poison, i16 poison, i16 poison, i16 poison, i16 poison, i16 poison, i16 poison, i16 poison, i16 poison, i16 poison, i16 poison, i16 poison, i16 poison, i16 0, i16 0, i16 0, i16 0, i16 0, i16 0, i16 0, i16 0, i16 0, i16 0, i16 0, i16 0, i16 0, i16 0, i16 0, i16 0>)
+  %r = bitcast <32 x i16> %tmp to <8 x i64>
+  ret <8 x i64> %r
+}
+
+define <8 x i64> @mload512_to_mload128(ptr %p) nounwind {
+; SSE-LABEL: mload512_to_mload128:
+; SSE:       ## %bb.0:
+; SSE-NEXT:    movsd {{.*#+}} xmm0 = mem[0],zero
+; SSE-NEXT:    xorps %xmm1, %xmm1
+; SSE-NEXT:    xorps %xmm2, %xmm2
+; SSE-NEXT:    xorps %xmm3, %xmm3
+; SSE-NEXT:    retq
+;
+; AVX1OR2-LABEL: mload512_to_mload128:
+; AVX1OR2:       ## %bb.0:
+; AVX1OR2-NEXT:    vmovsd {{.*#+}} xmm0 = [4294967295,4294967295,0,0]
+; AVX1OR2-NEXT:    vmaskmovps (%rdi), %xmm0, %xmm0
+; AVX1OR2-NEXT:    vxorps %xmm1, %xmm1, %xmm1
+; AVX1OR2-NEXT:    retq
+;
+; AVX512F-LABEL: mload512_to_mload128:
+; AVX512F:       ## %bb.0:
+; AVX512F-NEXT:    movw $3, %ax
+; AVX512F-NEXT:    kmovw %eax, %k1
+; AVX512F-NEXT:    vmovaps (%rdi), %zmm0 {%k1} {z}
+; AVX512F-NEXT:    vmovaps %xmm0, %xmm0
+; AVX512F-NEXT:    retq
+;
+; AVX512VLDQ-LABEL: mload512_to_mload128:
+; AVX512VLDQ:       ## %bb.0:
+; AVX512VLDQ-NEXT:    movb $3, %al
+; AVX512VLDQ-NEXT:    kmovw %eax, %k1
+; AVX512VLDQ-NEXT:    vmovaps (%rdi), %xmm0 {%k1} {z}
+; AVX512VLDQ-NEXT:    retq
+;
+; AVX512VLBW-LABEL: mload512_to_mload128:
+; AVX512VLBW:       ## %bb.0:
+; AVX512VLBW-NEXT:    movb $3, %al
+; AVX512VLBW-NEXT:    kmovd %eax, %k1
+; AVX512VLBW-NEXT:    vmovaps (%rdi), %xmm0 {%k1} {z}
+; AVX512VLBW-NEXT:    retq
+;
+; X86-AVX512-LABEL: mload512_to_mload128:
+; X86-AVX512:       ## %bb.0:
+; X86-AVX512-NEXT:    movl {{[0-9]+}}(%esp), %eax
+; X86-AVX512-NEXT:    movb $3, %cl
+; X86-AVX512-NEXT:    kmovd %ecx, %k1
+; X86-AVX512-NEXT:    vmovaps (%eax), %xmm0 {%k1} {z}
+; X86-AVX512-NEXT:    retl
+  %tmp = tail call <16 x float> @llvm.masked.load.v16f32.p0(ptr %p, i32 64, <16 x i1> <i1 true, i1 true, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false>, <16 x float> <float poison, float poison, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00>)
+  %r = bitcast <16 x float> %tmp to <8 x i64>
+  ret <8 x i64> %r
+}
+
+define <4 x i64> @mload256_to_mload128(ptr %p) nounwind {
+; SSE2-LABEL: mload256_to_mload128:
+; SSE2:       ## %bb.0:
+; SSE2-NEXT:    movss {{.*#+}} xmm0 = mem[0],zero,zero,zero
+; SSE2-NEXT:    movss {{.*#+}} xmm1 = mem[0],zero,zero,zero
+; SSE2-NEXT:    movlhps {{.*#+}} xmm0 = xmm0[0],xmm1[0]
+; SSE2-NEXT:    xorps %xmm1, %xmm1
+; SSE2-NEXT:    retq
+;
+; SSE42-LABEL: mload256_to_mload128:
+; SSE42:       ## %bb.0:
+; SSE42-NEXT:    movss {{.*#+}} xmm0 = mem[0],zero,zero,zero
+; SSE42-NEXT:    insertps {{.*#+}} xmm0 = xmm0[0],zero,mem[0],zero
+; SSE42-NEXT:    xorps %xmm1, %xmm1
+; SSE42-NEXT:    retq
+;
+; AVX1OR2-LABEL: mload256_to_mload128:
+; AVX1OR2:       ## %bb.0:
+; AVX1OR2-NEXT:    vmovddup {{.*#+}} xmm0 = [4294967295,0,4294967295,0]
+; AVX1OR2-NEXT:    ## xmm0 = mem[0,0]
+; AVX1OR2-NEXT:    vmaskmovps (%rdi), %xmm0, %xmm0
+; AVX1OR2-NEXT:    retq
+;
+; AVX512F-LABEL: mload256_to_mload128:
+; AVX512F:       ## %bb.0:
+; AVX512F-NEXT:    movw $5, %ax
+; AVX512F-NEXT:    kmovw %eax, %k1
+; AVX512F-NEXT:    vmovups (%rdi), %zmm0 {%k1} {z}
+; AVX512F-NEXT:    vmovaps %xmm0, %xmm0
+; AVX512F-NEXT:    retq
+;
+; AVX512VLDQ-LABEL: mload256_to_mload128:
+; AVX512VLDQ:       ## %bb.0:
+; AVX512VLDQ-NEXT:    movb $5, %al
+; AVX512VLDQ-NEXT:    kmovw %eax, %k1
+; AVX512VLDQ-NEXT:    vmovaps (%rdi), %xmm0 {%k1} {z}
+; AVX512VLDQ-NEXT:    retq
+;
+; AVX512VLBW-LABEL: mload256_to_mload128:
+; AVX512VLBW:       ## %bb.0:
+; AVX512VLBW-NEXT:    movb $5, %al
+; AVX512VLBW-NEXT:    kmovd %eax, %k1
+; AVX512VLBW-NEXT:    vmovaps (%rdi), %xmm0 {%k1} {z}
+; AVX512VLBW-NEXT:    retq
+;
+; X86-AVX512-LABEL: mload256_to_mload128:
+; X86-AVX512:       ## %bb.0:
+; X86-AVX512-NEXT:    movl {{[0-9]+}}(%esp), %eax
+; X86-AVX512-NEXT:    movb $5, %cl
+; X86-AVX512-NEXT:    kmovd %ecx, %k1
+; X86-AVX512-NEXT:    vmovaps (%eax), %xmm0 {%k1} {z}
+; X86-AVX512-NEXT:    retl
+  %tmp = tail call <8 x float> @llvm.masked.load.v8f32.p0(ptr %p, i32 32, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 false>, <8 x float> <float poison, float 0.000000e+00, float poison, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00>)
+  %r = bitcast <8 x float> %tmp to <4 x i64>
+  ret <4 x i64> %r
+}
+
+define <8 x i64> @mload512_to_mload256(ptr %p) nounwind {
+; SSE-LABEL: mload512_to_mload256:
+; SSE:       ## %bb.0:
+; SSE-NEXT:    xorps %xmm0, %xmm0
+; SSE-NEXT:    movhps {{.*#+}} xmm0 = xmm0[0,1],mem[0,1]
+; SSE-NEXT:    movss {{.*#+}} xmm1 = mem[0],zero,zero,zero
+; SSE-NEXT:    xorps %xmm2, %xmm2
+; SSE-NEXT:    xorps %xmm3, %xmm3
+; SSE-NEXT:    retq
+;
+; AVX1OR2-LABEL: mload512_to_mload256:
+; AVX1OR2:       ## %bb.0:
+; AVX1OR2-NEXT:    vmovaps {{.*#+}} ymm0 = [0,0,4294967295,4294967295,4294967295,0,0,0]
+; AVX1OR2-NEXT:    vmaskmovps (%rdi), %ymm0, %ymm0
+; AVX1OR2-NEXT:    vxorps %xmm1, %xmm1, %xmm1
+; AVX1OR2-NEXT:    retq
+;
+; AVX512F-LABEL: mload512_to_mload256:
+; AVX512F:       ## %bb.0:
+; AVX512F-NEXT:    movw $28, %ax
+; AVX512F-NEXT:    kmovw %eax, %k1
+; AVX512F-NEXT:    vmovaps (%rdi), %zmm0 {%k1} {z}
+; AVX512F-NEXT:    vmovaps %ymm0, %ymm0
+; AVX512F-NEXT:    retq
+;
+; AVX512VLDQ-LABEL: mload512_to_mload256:
+; AVX512VLDQ:       ## %bb.0:
+; AVX512VLDQ-NEXT:    movb $28, %al
+; AVX512VLDQ-NEXT:    kmovw %eax, %k1
+; AVX512VLDQ-NEXT:    vmovaps (%rdi), %ymm0 {%k1} {z}
+; AVX512VLDQ-NEXT:    retq
+;
+; AVX512VLBW-LABEL: mload512_to_mload256:
+; AVX512VLBW:       ## %bb.0:
+; AVX512VLBW-NEXT:    movb $28, %al
+; AVX512VLBW-NEXT:    kmovd %eax, %k1
+; AVX512VLBW-NEXT:    vmovaps (%rdi), %ymm0 {%k1} {z}
+; AVX512VLBW-NEXT:    retq
+;
+; X86-AVX512-LABEL: mload512_to_mload256:
+; X86-AVX512:       ## %bb.0:
+; X86-AVX512-NEXT:    movl {{[0-9]+}}(%esp), %eax
+; X86-AVX512-NEXT:    movb $28, %cl
+; X86-AVX512-NEXT:    kmovd %ecx, %k1
+; X86-AVX512-NEXT:    vmovaps (%eax), %ymm0 {%k1} {z}
+; X86-AVX512-NEXT:    retl
+  %tmp = tail call <16 x float> @llvm.masked.load.v16f32.p0(ptr %p, i32 64, <16 x i1> <i1 false, i1 false, i1 true, i1 true, i1 true, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false>, <16 x float> <float 0.000000e+00, float 0.000000e+00, float poison, float poison, float poison, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00>)
+  %r = bitcast <16 x float> %tmp to <8 x i64>
+  ret <8 x i64> %r
+}
+
+define <8 x i64> @mload512_fail_no_possible_shrink(ptr %p) nounwind {
+; SSE-LABEL: mload512_fail_no_possible_shrink:
+; SSE:       ## %bb.0:
+; SSE-NEXT:    movss {{.*#+}} xmm2 = mem[0],zero,zero,zero
+; SSE-NEXT:    movups (%rdi), %xmm0
+; SSE-NEXT:    movups 16(%rdi), %xmm1
+; SSE-NEXT:    xorps %xmm3, %xmm3
+; SSE-NEXT:    retq
+;
+; AVX1OR2-LABEL: mload512_fail_no_possible_shrink:
+; AVX1OR2:       ## %bb.0:
+; AVX1OR2-NEXT:    vmovss {{.*#+}} xmm0 = [4294967295,0,0,0]
+; AVX1OR2-NEXT:    vmaskmovps 32(%rdi), %xmm0, %xmm1
+; AVX1OR2-NEXT:    vmovaps (%rdi), %ymm0
+; AVX1OR2-NEXT:    retq
+;
+; AVX512F-LABEL: mload512_fail_no_possible_shrink:
+; AVX512F:       ## %bb.0:
+; AVX512F-NEXT:    movw $511, %ax ## imm = 0x1FF
+; AVX512F-NEXT:    kmovw %eax, %k1
+; AVX512F-NEXT:    vmovaps (%rdi), %zmm0 {%k1} {z}
+; AVX512F-NEXT:    retq
+;
+; AVX512VLDQ-LABEL: mload512_fail_no_possible_shrink:
+; AVX512VLDQ:       ## %bb.0:
+; AVX512VLDQ-NEXT:    movw $511, %ax ## imm = 0x1FF
+; AVX512VLDQ-NEXT:    kmovw %eax, %k1
+; AVX512VLDQ-NEXT:    vmovaps (%rdi), %zmm0 {%k1} {z}
+; AVX512...
[truncated]

@goldsteinn goldsteinn requested a review from KanRobert August 20, 2024 23:57
@goldsteinn goldsteinn changed the title from goldsteinn/shrink masked memory ops to [X86] Shrink width of masked loads/stores Aug 20, 2024

github-actions bot commented Aug 21, 2024

⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
git-clang-format --diff a16f0dc9c2f0690e28622b0d80bd154fb0e6a30a 03cedf6315570e1d403380cddd688deb03198448 --extensions h,cpp -- llvm/include/llvm/CodeGen/SelectionDAGNodes.h llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp llvm/lib/Target/X86/X86ISelLowering.cpp
View the diff from clang-format here.
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index b576127702..2b8dca91c0 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -51712,8 +51712,6 @@ static SDValue combineMaskedStore(SDNode *N, SelectionDAG &DAG,
   if (Mst->isCompressingStore())
     return SDValue();
 
-
-
   EVT VT = Mst->getValue().getValueType();
   SDLoc dl(Mst);
   const TargetLowering &TLI = DAG.getTargetLoweringInfo();

@phoebewang (Contributor) left a comment

Can we do the shrink in IR phase?

SDValue *ValInOut, EVT *NewVTOut,
SDValue *NewMaskOut) {
// Ensure we have a reasonable input type.
// Also ensure ensure input bits is larger then xmm, otherwise its not
Contributor

ensure ensure -> ensure

Comment on lines 51546 to 51548
if (!OrigVT.isSimple() || !OrigVT.isVector() ||
OrigVT.getSizeInBits() <= 128 || !isPowerOf2_64(OrigVT.getSizeInBits()) ||
!isPowerOf2_64(OrigVT.getScalarSizeInBits()))
Contributor

Can we just check if it's legal type and size is 256 or 512?

Contributor

+1, maybe we can use is256BitVector / is512BitVector

Comment on lines 51558 to 51559
unsigned ReqElts =
DemandedElts.getBitWidth() - DemandedElts.countLeadingZeros();
Contributor

I think we are able to load an arbitrary subvector instead of the first one?

Contributor Author

We could, but the codegen is only "free" for the lower one. I'll add it as a potential TODO.

APInt V;
if (!sd_match(MaskBV->getOperand(i), m_ConstInt(V)))
return APInt::getAllOnes(NumElts);
if (V.isNegative())
Contributor

isNegative() -> isOne()?

Contributor Author

Negative is "more" correct I think.
For example:

  // If the mask value has been legalized to a non-boolean vector, try to
  // simplify ops leading up to it. We only demand the MSB of each lane.
  if (Mask.getScalarValueSizeInBits() != 1) {
    const TargetLowering &TLI = DAG.getTargetLoweringInfo();
    APInt DemandedBits(APInt::getSignMask(VT.getScalarSizeInBits()));
    if (TLI.SimplifyDemandedBits(Mask, DemandedBits, DCI)) {
      if (N->getOpcode() != ISD::DELETED_NODE)
        DCI.AddToWorklist(N);
      return SDValue(N, 0);
    }
    if (SDValue NewMask =
            TLI.SimplifyMultipleUseDemandedBits(Mask, DemandedBits, DAG))
      return DAG.getMaskedLoad(
          VT, SDLoc(N), Mld->getChain(), Mld->getBasePtr(), Mld->getOffset(),
          NewMask, Mld->getPassThru(), Mld->getMemoryVT(), Mld->getMemOperand(),
          Mld->getAddressingMode(), Mld->getExtensionType());
  }

Collaborator

Doesn't Mask have to respect getBooleanContents?

Contributor Author

Hmm? Don't see how this violates that.
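
(Editorial note, not from the thread: X86 uses zero-or-negative-one vector boolean contents, so after legalization a "true" mask lane is typically all-ones, while a plain i1 "true" lane is the value 1, whose single bit is also its sign bit. A sign-bit test therefore accepts both forms, whereas isOne() would reject the legalized one. A minimal standalone sketch using llvm::APInt:)

#include "llvm/ADT/APInt.h"
#include <cassert>

int main() {
  // A legalized "true" mask lane (e.g. a -1 element of a v8i32 mask).
  llvm::APInt WideTrue = llvm::APInt::getAllOnes(32);
  // An unlegalized i1 "true" lane.
  llvm::APInt NarrowTrue(1, 1);
  assert(WideTrue.isNegative() && !WideTrue.isOne()); // sign-bit test works
  assert(NarrowTrue.isNegative() && NarrowTrue.isOne());
  assert(!llvm::APInt::getZero(32).isNegative());     // "false" lanes stay off
  return 0;
}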

SmallVector<SDValue> OrigMask;
APInt DemandedElts = getDemandedEltsForMaskedOp(
Mask, OrigVT.getVectorNumElements(), &OrigMask);
if (DemandedElts.isAllOnes() || DemandedElts.isZero())
Contributor

Maybe we should use an assert here. An all-ones masked.load/store for 256/512 should be optimized away before SelectionDAG.

Contributor Author

Slightly opposed. It might be produced as a temporary, no?


@goldsteinn (Contributor Author)

Can we do the shrink in IR phase?

Probably, although I wasn't sure about 1) the profitability on other arches, 2) whether having different-sized vectors in the middle-end might mess up vectorization, and 3) the semantics, particularly: "Only the masked-on lanes of the vector need to be inbounds of an allocation (but all these lanes need to be inbounds of the same allocation)". I wasn't sure if we could guarantee the second point, but I'm not an expert in the memory model.

In the best case we can convert a masked load/store to a narrower
normal load/store, e.g. `_mm512_maskz_load_ps(p, 0xff)` can be done
with just a normal `ymm` load. Likewise, if the mask is entirely
contained in a lower sub-vector, we can shrink the load/store, e.g.
`_mm512_maskz_load_ps(p, 0x1c)` is the same as
`_mm256_maskz_load_ps(p, 0x1c)`.
@goldsteinn goldsteinn force-pushed the goldsteinn/shrink-masked-memory-ops branch from b6fc529 to 03cedf6 on August 21, 2024 18:40
Comment on lines +51566 to +51567
unsigned NewNumElts =
std::max(128U / OrigVT.getScalarSizeInBits(), PowerOf2Ceil(ReqElts));
Contributor

IIUC, line 51559 assumes the mask is a boolean vector while line 51567 does not?

Contributor Author

How so? We are re-using the orig scalar type.

Contributor

Line 51559 uses the number of non-zero bits as ReqElts. I think this assumes the mask is a vector of booleans?

Contributor Author

That's elements, not bits.
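
(Editorial worked example, not from the thread: tracing those two lines of the patch for the <16 x float>, mask 0x1c case from the PR description. A standalone sketch using the LLVM helpers the patch relies on:)

#include "llvm/ADT/APInt.h"
#include "llvm/Support/MathExtras.h"
#include <algorithm>
#include <cstdio>

int main() {
  const unsigned OrigNumElts = 16, ScalarBits = 32;                     // v16f32
  llvm::APInt DemandedElts(OrigNumElts, 0x1c);                          // elements 2..4 on
  unsigned ReqElts =
      DemandedElts.getBitWidth() - DemandedElts.countLeadingZeros();    // 16 - 11 = 5
  // ReqElts (5) <= OrigNumElts / 2 (8), so the shrink is considered profitable.
  unsigned NewNumElts = std::max(128U / ScalarBits,
                                 (unsigned)llvm::PowerOf2Ceil(ReqElts)); // max(4, 8) = 8
  // 0x1c is not a low-bit mask, so a narrower mask is still needed: the result
  // is a v8f32 masked load, matching _mm256_maskz_load_ps(p, 0x1c).
  std::printf("v%uf32 -> v%uf32\n", OrigNumElts, NewNumElts);
  return 0;
}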

Comment on lines +51571 to +51572
if (!NewVT.isSimple())
return false;
Contributor

Should this be unreachable? Since you check OrigVT.isSimple() at the beginning.

Contributor Author

Probably, although since we are re-creating the type with a different element count it's theoretically not unreachable.
