[X86] Shrink width of masked loads/stores #105451
Conversation
@llvm/pr-subscribers-backend-x86 @llvm/pr-subscribers-llvm-selectiondag

Author: None (goldsteinn)

Changes

Patch is 42.88 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/105451.diff

8 Files Affected:
diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
index 88549d9c9a2858..b2151cd32f1088 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -1848,6 +1848,11 @@ bool isOneOrOneSplat(SDValue V, bool AllowUndefs = false);
/// Does not permit build vector implicit truncation.
bool isAllOnesOrAllOnesSplat(SDValue V, bool AllowUndefs = false);
+/// Returns the demanded elements from the mask of a masked op (i.e
+/// MSTORE/MLOAD).
+APInt getDemandedEltsForMaskedOp(SDValue Mask, unsigned NumElts,
+ SmallVector<SDValue> *MaskEltsOut = nullptr);
+
/// Return true if \p V is either a integer or FP constant.
inline bool isIntOrFPConstant(SDValue V) {
return isa<ConstantSDNode>(V) || isa<ConstantFPSDNode>(V);
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 18a3b7bce104a7..cb0e098a1e511e 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -12128,6 +12128,24 @@ bool llvm::isAllOnesOrAllOnesSplat(SDValue N, bool AllowUndefs) {
return C && C->isAllOnes() && C->getValueSizeInBits(0) == BitWidth;
}
+APInt llvm::getDemandedEltsForMaskedOp(SDValue Mask, unsigned NumElts,
+ SmallVector<SDValue> *MaskEltsOut) {
+ if (!ISD::isBuildVectorOfConstantSDNodes(Mask.getNode()))
+ return APInt::getAllOnes(NumElts);
+ APInt Demanded = APInt::getZero(NumElts);
+ BuildVectorSDNode *MaskBV = cast<BuildVectorSDNode>(Mask);
+ for (unsigned i = 0; i < MaskBV->getNumOperands(); ++i) {
+ APInt V;
+ if (!sd_match(MaskBV->getOperand(i), m_ConstInt(V)))
+ return APInt::getAllOnes(NumElts);
+ if (V.isNegative())
+ Demanded.setBit(i);
+ if (MaskEltsOut)
+ MaskEltsOut->emplace_back(MaskBV->getOperand(i));
+ }
+ return Demanded;
+}
+
HandleSDNode::~HandleSDNode() {
DropOperands();
}
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 169c955f0ba89f..e8bbfaa41b2cb4 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -51536,20 +51536,111 @@ combineMaskedLoadConstantMask(MaskedLoadSDNode *ML, SelectionDAG &DAG,
return DCI.CombineTo(ML, Blend, NewML.getValue(1), true);
}
+static bool tryShrinkMaskedOperation(SelectionDAG &DAG, const SDLoc &DL,
+ SDValue Mask, EVT OrigVT,
+ SDValue *ValInOut, EVT *NewVTOut,
+ SDValue *NewMaskOut) {
+ // Ensure we have a reasonable input type.
+ // Also ensure ensure input bits is larger then xmm, otherwise its not
+ // profitable to try to shrink.
+ if (!OrigVT.isSimple() || !OrigVT.isVector() ||
+ OrigVT.getSizeInBits() <= 128 || !isPowerOf2_64(OrigVT.getSizeInBits()) ||
+ !isPowerOf2_64(OrigVT.getScalarSizeInBits()))
+ return false;
+
+ SmallVector<SDValue> OrigMask;
+ APInt DemandedElts = getDemandedEltsForMaskedOp(
+ Mask, OrigVT.getVectorNumElements(), &OrigMask);
+ if (DemandedElts.isAllOnes() || DemandedElts.isZero())
+ return false;
+
+ unsigned OrigNumElts = OrigVT.getVectorNumElements();
+ unsigned ReqElts =
+ DemandedElts.getBitWidth() - DemandedElts.countLeadingZeros();
+ // We can't shrink out vector category in a meaningful way.
+ if (ReqElts > OrigNumElts / 2U)
+ return false;
+
+ // At most shrink to xmm.
+ unsigned NewNumElts =
+ std::max(128U / OrigVT.getScalarSizeInBits(), PowerOf2Ceil(ReqElts));
+
+ EVT NewVT =
+ EVT::getVectorVT(*DAG.getContext(), OrigVT.getScalarType(), NewNumElts);
+ if (!NewVT.isSimple())
+ return false;
+
+ // Extract all the value arguments;
+ if (ValInOut && *ValInOut)
+ *ValInOut = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NewVT, *ValInOut,
+ DAG.getIntPtrConstant(0, DL));
+ if (NewVTOut)
+ *NewVTOut = NewVT;
+ *NewMaskOut = SDValue();
+ // The mask was just truncating, so don't need it anymore.
+ if (NewNumElts == ReqElts && DemandedElts.isMask())
+ return true;
+
+ // Get smaller mask.
+ EVT NewMaskVT = EVT::getVectorVT(
+ *DAG.getContext(), Mask.getValueType().getScalarType(), NewNumElts);
+ OrigMask.truncate(NewNumElts);
+ *NewMaskOut = DAG.getBuildVector(NewMaskVT, DL, OrigMask);
+ return true;
+}
+
+static bool tryShrinkMaskedOperation(SelectionDAG &DAG, const SDLoc &DL,
+ SDValue Mask, EVT OrigVT, EVT *NewVTOut,
+ SDValue *NewMaskOut) {
+ return tryShrinkMaskedOperation(DAG, DL, Mask, OrigVT, nullptr, NewVTOut,
+ NewMaskOut);
+}
+
+static bool tryShrinkMaskedOperation(SelectionDAG &DAG, const SDLoc &DL,
+ SDValue Mask, EVT OrigVT,
+ SDValue *ValInOut, SDValue *NewMaskOut) {
+ return tryShrinkMaskedOperation(DAG, DL, Mask, OrigVT, ValInOut, nullptr,
+ NewMaskOut);
+}
+
static SDValue combineMaskedLoad(SDNode *N, SelectionDAG &DAG,
TargetLowering::DAGCombinerInfo &DCI,
const X86Subtarget &Subtarget) {
+ using namespace llvm::SDPatternMatch;
auto *Mld = cast<MaskedLoadSDNode>(N);
+ SDLoc DL(N);
// TODO: Expanding load with constant mask may be optimized as well.
if (Mld->isExpandingLoad())
return SDValue();
+ SDValue Mask = Mld->getMask();
+ EVT VT = Mld->getValueType(0);
if (Mld->getExtensionType() == ISD::NON_EXTLOAD) {
if (SDValue ScalarLoad =
reduceMaskedLoadToScalarLoad(Mld, DAG, DCI, Subtarget))
return ScalarLoad;
+ SDValue NewMask;
+ EVT NewVT;
+ if (sd_match(Mld->getPassThru(), m_Zero()) &&
+ tryShrinkMaskedOperation(DAG, DL, Mask, VT, &NewVT, &NewMask)) {
+ SDValue NewLoad;
+ if (NewMask)
+ NewLoad = DAG.getMaskedLoad(
+ NewVT, DL, Mld->getChain(), Mld->getBasePtr(), Mld->getOffset(),
+ NewMask, getZeroVector(NewVT.getSimpleVT(), Subtarget, DAG, DL),
+ Mld->getMemoryVT(), Mld->getMemOperand(), Mld->getAddressingMode(),
+ Mld->getExtensionType());
+ else
+ NewLoad = DAG.getLoad(NewVT, DL, Mld->getChain(), Mld->getBasePtr(),
+ Mld->getMemOperand());
+
+ SDValue R = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, Mld->getPassThru(),
+ NewLoad, DAG.getIntPtrConstant(0, DL));
+ return DCI.CombineTo(Mld, R, NewLoad.getValue(1), true);
+ }
+
// TODO: Do some AVX512 subsets benefit from this transform?
if (!Subtarget.hasAVX512())
if (SDValue Blend = combineMaskedLoadConstantMask(Mld, DAG, DCI))
@@ -51558,9 +51649,7 @@ static SDValue combineMaskedLoad(SDNode *N, SelectionDAG &DAG,
// If the mask value has been legalized to a non-boolean vector, try to
// simplify ops leading up to it. We only demand the MSB of each lane.
- SDValue Mask = Mld->getMask();
if (Mask.getScalarValueSizeInBits() != 1) {
- EVT VT = Mld->getValueType(0);
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
APInt DemandedBits(APInt::getSignMask(VT.getScalarSizeInBits()));
if (TLI.SimplifyDemandedBits(Mask, DemandedBits, DCI)) {
@@ -51622,6 +51711,8 @@ static SDValue combineMaskedStore(SDNode *N, SelectionDAG &DAG,
if (Mst->isCompressingStore())
return SDValue();
+
+
EVT VT = Mst->getValue().getValueType();
SDLoc dl(Mst);
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
@@ -51651,6 +51742,17 @@ static SDValue combineMaskedStore(SDNode *N, SelectionDAG &DAG,
}
SDValue Value = Mst->getValue();
+ SDValue NewMask;
+ if (tryShrinkMaskedOperation(DAG, dl, Mask, VT, &Value, &NewMask)) {
+ if (NewMask)
+ return DAG.getMaskedStore(Mst->getChain(), dl, Value, Mst->getBasePtr(),
+ Mst->getOffset(), NewMask, Mst->getMemoryVT(),
+ Mst->getMemOperand(), Mst->getAddressingMode());
+ return DAG.getStore(Mst->getChain(), SDLoc(N), Value, Mst->getBasePtr(),
+ Mst->getPointerInfo(), Mst->getOriginalAlign(),
+ Mst->getMemOperand()->getFlags());
+ }
+
if (Value.getOpcode() == ISD::TRUNCATE && Value.getNode()->hasOneUse() &&
TLI.isTruncStoreLegal(Value.getOperand(0).getValueType(),
Mst->getMemoryVT())) {
diff --git a/llvm/test/CodeGen/X86/masked-load-store-shrink.ll b/llvm/test/CodeGen/X86/masked-load-store-shrink.ll
new file mode 100644
index 00000000000000..c1cdef54bc3737
--- /dev/null
+++ b/llvm/test/CodeGen/X86/masked-load-store-shrink.ll
@@ -0,0 +1,716 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc < %s -disable-peephole -mtriple=x86_64-apple-darwin -mattr=sse2 | FileCheck %s --check-prefixes=SSE,SSE2
+; RUN: llc < %s -disable-peephole -mtriple=x86_64-apple-darwin -mattr=sse4.2 | FileCheck %s --check-prefixes=SSE,SSE42
+; RUN: llc < %s -disable-peephole -mtriple=x86_64-apple-darwin -mattr=avx | FileCheck %s --check-prefixes=AVX,AVX1OR2,AVX1
+; RUN: llc < %s -disable-peephole -mtriple=x86_64-apple-darwin -mattr=avx2 | FileCheck %s --check-prefixes=AVX,AVX1OR2,AVX2
+; RUN: llc < %s -disable-peephole -mtriple=x86_64-apple-darwin -mattr=avx512f | FileCheck %s --check-prefixes=AVX,AVX512,AVX512F
+; RUN: llc < %s -disable-peephole -mtriple=x86_64-apple-darwin -mattr=avx512f,avx512dq,avx512vl | FileCheck %s --check-prefixes=AVX,AVX512,AVX512VL,AVX512VLDQ
+; RUN: llc < %s -disable-peephole -mtriple=x86_64-apple-darwin -mattr=avx512f,avx512bw,avx512vl | FileCheck %s --check-prefixes=AVX,AVX512,AVX512VL,AVX512VLBW
+; RUN: llc < %s -mtriple=i686-apple-darwin -mattr=avx512f,avx512bw,avx512dq,avx512vl -verify-machineinstrs | FileCheck %s --check-prefixes=X86-AVX512
+
+define <4 x i64> @mload256_to_load128(ptr %p) nounwind {
+; SSE-LABEL: mload256_to_load128:
+; SSE: ## %bb.0:
+; SSE-NEXT: movups (%rdi), %xmm0
+; SSE-NEXT: xorps %xmm1, %xmm1
+; SSE-NEXT: retq
+;
+; AVX-LABEL: mload256_to_load128:
+; AVX: ## %bb.0:
+; AVX-NEXT: vmovaps (%rdi), %xmm0
+; AVX-NEXT: retq
+;
+; X86-AVX512-LABEL: mload256_to_load128:
+; X86-AVX512: ## %bb.0:
+; X86-AVX512-NEXT: movl {{[0-9]+}}(%esp), %eax
+; X86-AVX512-NEXT: vmovaps (%eax), %xmm0
+; X86-AVX512-NEXT: retl
+ %tmp = tail call <8 x float> @llvm.masked.load.v8f32.p0(ptr %p, i32 32, <8 x i1> <i1 true, i1 true, i1 true, i1 true, i1 false, i1 false, i1 false, i1 false>, <8 x float> <float poison, float poison, float poison, float poison, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00>)
+ %r = bitcast <8 x float> %tmp to <4 x i64>
+ ret <4 x i64> %r
+}
+
+define <8 x i64> @mload512_to_load256(ptr %p) nounwind {
+; SSE-LABEL: mload512_to_load256:
+; SSE: ## %bb.0:
+; SSE-NEXT: movups (%rdi), %xmm0
+; SSE-NEXT: movups 16(%rdi), %xmm1
+; SSE-NEXT: xorps %xmm2, %xmm2
+; SSE-NEXT: xorps %xmm3, %xmm3
+; SSE-NEXT: retq
+;
+; AVX1OR2-LABEL: mload512_to_load256:
+; AVX1OR2: ## %bb.0:
+; AVX1OR2-NEXT: vmovups (%rdi), %ymm0
+; AVX1OR2-NEXT: vxorps %xmm1, %xmm1, %xmm1
+; AVX1OR2-NEXT: retq
+;
+; AVX512-LABEL: mload512_to_load256:
+; AVX512: ## %bb.0:
+; AVX512-NEXT: vmovups (%rdi), %ymm0
+; AVX512-NEXT: retq
+;
+; X86-AVX512-LABEL: mload512_to_load256:
+; X86-AVX512: ## %bb.0:
+; X86-AVX512-NEXT: movl {{[0-9]+}}(%esp), %eax
+; X86-AVX512-NEXT: vmovups (%eax), %ymm0
+; X86-AVX512-NEXT: retl
+ %tmp = tail call <32 x i16> @llvm.masked.load.v32i16.p0(ptr %p, i32 1, <32 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false>, <32 x i16> <i16 poison, i16 poison, i16 poison, i16 poison, i16 poison, i16 poison, i16 poison, i16 poison, i16 poison, i16 poison, i16 poison, i16 poison, i16 poison, i16 poison, i16 poison, i16 poison, i16 0, i16 0, i16 0, i16 0, i16 0, i16 0, i16 0, i16 0, i16 0, i16 0, i16 0, i16 0, i16 0, i16 0, i16 0, i16 0>)
+ %r = bitcast <32 x i16> %tmp to <8 x i64>
+ ret <8 x i64> %r
+}
+
+define <8 x i64> @mload512_to_mload128(ptr %p) nounwind {
+; SSE-LABEL: mload512_to_mload128:
+; SSE: ## %bb.0:
+; SSE-NEXT: movsd {{.*#+}} xmm0 = mem[0],zero
+; SSE-NEXT: xorps %xmm1, %xmm1
+; SSE-NEXT: xorps %xmm2, %xmm2
+; SSE-NEXT: xorps %xmm3, %xmm3
+; SSE-NEXT: retq
+;
+; AVX1OR2-LABEL: mload512_to_mload128:
+; AVX1OR2: ## %bb.0:
+; AVX1OR2-NEXT: vmovsd {{.*#+}} xmm0 = [4294967295,4294967295,0,0]
+; AVX1OR2-NEXT: vmaskmovps (%rdi), %xmm0, %xmm0
+; AVX1OR2-NEXT: vxorps %xmm1, %xmm1, %xmm1
+; AVX1OR2-NEXT: retq
+;
+; AVX512F-LABEL: mload512_to_mload128:
+; AVX512F: ## %bb.0:
+; AVX512F-NEXT: movw $3, %ax
+; AVX512F-NEXT: kmovw %eax, %k1
+; AVX512F-NEXT: vmovaps (%rdi), %zmm0 {%k1} {z}
+; AVX512F-NEXT: vmovaps %xmm0, %xmm0
+; AVX512F-NEXT: retq
+;
+; AVX512VLDQ-LABEL: mload512_to_mload128:
+; AVX512VLDQ: ## %bb.0:
+; AVX512VLDQ-NEXT: movb $3, %al
+; AVX512VLDQ-NEXT: kmovw %eax, %k1
+; AVX512VLDQ-NEXT: vmovaps (%rdi), %xmm0 {%k1} {z}
+; AVX512VLDQ-NEXT: retq
+;
+; AVX512VLBW-LABEL: mload512_to_mload128:
+; AVX512VLBW: ## %bb.0:
+; AVX512VLBW-NEXT: movb $3, %al
+; AVX512VLBW-NEXT: kmovd %eax, %k1
+; AVX512VLBW-NEXT: vmovaps (%rdi), %xmm0 {%k1} {z}
+; AVX512VLBW-NEXT: retq
+;
+; X86-AVX512-LABEL: mload512_to_mload128:
+; X86-AVX512: ## %bb.0:
+; X86-AVX512-NEXT: movl {{[0-9]+}}(%esp), %eax
+; X86-AVX512-NEXT: movb $3, %cl
+; X86-AVX512-NEXT: kmovd %ecx, %k1
+; X86-AVX512-NEXT: vmovaps (%eax), %xmm0 {%k1} {z}
+; X86-AVX512-NEXT: retl
+ %tmp = tail call <16 x float> @llvm.masked.load.v16f32.p0(ptr %p, i32 64, <16 x i1> <i1 true, i1 true, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false>, <16 x float> <float poison, float poison, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00>)
+ %r = bitcast <16 x float> %tmp to <8 x i64>
+ ret <8 x i64> %r
+}
+
+define <4 x i64> @mload256_to_mload128(ptr %p) nounwind {
+; SSE2-LABEL: mload256_to_mload128:
+; SSE2: ## %bb.0:
+; SSE2-NEXT: movss {{.*#+}} xmm0 = mem[0],zero,zero,zero
+; SSE2-NEXT: movss {{.*#+}} xmm1 = mem[0],zero,zero,zero
+; SSE2-NEXT: movlhps {{.*#+}} xmm0 = xmm0[0],xmm1[0]
+; SSE2-NEXT: xorps %xmm1, %xmm1
+; SSE2-NEXT: retq
+;
+; SSE42-LABEL: mload256_to_mload128:
+; SSE42: ## %bb.0:
+; SSE42-NEXT: movss {{.*#+}} xmm0 = mem[0],zero,zero,zero
+; SSE42-NEXT: insertps {{.*#+}} xmm0 = xmm0[0],zero,mem[0],zero
+; SSE42-NEXT: xorps %xmm1, %xmm1
+; SSE42-NEXT: retq
+;
+; AVX1OR2-LABEL: mload256_to_mload128:
+; AVX1OR2: ## %bb.0:
+; AVX1OR2-NEXT: vmovddup {{.*#+}} xmm0 = [4294967295,0,4294967295,0]
+; AVX1OR2-NEXT: ## xmm0 = mem[0,0]
+; AVX1OR2-NEXT: vmaskmovps (%rdi), %xmm0, %xmm0
+; AVX1OR2-NEXT: retq
+;
+; AVX512F-LABEL: mload256_to_mload128:
+; AVX512F: ## %bb.0:
+; AVX512F-NEXT: movw $5, %ax
+; AVX512F-NEXT: kmovw %eax, %k1
+; AVX512F-NEXT: vmovups (%rdi), %zmm0 {%k1} {z}
+; AVX512F-NEXT: vmovaps %xmm0, %xmm0
+; AVX512F-NEXT: retq
+;
+; AVX512VLDQ-LABEL: mload256_to_mload128:
+; AVX512VLDQ: ## %bb.0:
+; AVX512VLDQ-NEXT: movb $5, %al
+; AVX512VLDQ-NEXT: kmovw %eax, %k1
+; AVX512VLDQ-NEXT: vmovaps (%rdi), %xmm0 {%k1} {z}
+; AVX512VLDQ-NEXT: retq
+;
+; AVX512VLBW-LABEL: mload256_to_mload128:
+; AVX512VLBW: ## %bb.0:
+; AVX512VLBW-NEXT: movb $5, %al
+; AVX512VLBW-NEXT: kmovd %eax, %k1
+; AVX512VLBW-NEXT: vmovaps (%rdi), %xmm0 {%k1} {z}
+; AVX512VLBW-NEXT: retq
+;
+; X86-AVX512-LABEL: mload256_to_mload128:
+; X86-AVX512: ## %bb.0:
+; X86-AVX512-NEXT: movl {{[0-9]+}}(%esp), %eax
+; X86-AVX512-NEXT: movb $5, %cl
+; X86-AVX512-NEXT: kmovd %ecx, %k1
+; X86-AVX512-NEXT: vmovaps (%eax), %xmm0 {%k1} {z}
+; X86-AVX512-NEXT: retl
+ %tmp = tail call <8 x float> @llvm.masked.load.v8f32.p0(ptr %p, i32 32, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 false>, <8 x float> <float poison, float 0.000000e+00, float poison, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00>)
+ %r = bitcast <8 x float> %tmp to <4 x i64>
+ ret <4 x i64> %r
+}
+
+define <8 x i64> @mload512_to_mload256(ptr %p) nounwind {
+; SSE-LABEL: mload512_to_mload256:
+; SSE: ## %bb.0:
+; SSE-NEXT: xorps %xmm0, %xmm0
+; SSE-NEXT: movhps {{.*#+}} xmm0 = xmm0[0,1],mem[0,1]
+; SSE-NEXT: movss {{.*#+}} xmm1 = mem[0],zero,zero,zero
+; SSE-NEXT: xorps %xmm2, %xmm2
+; SSE-NEXT: xorps %xmm3, %xmm3
+; SSE-NEXT: retq
+;
+; AVX1OR2-LABEL: mload512_to_mload256:
+; AVX1OR2: ## %bb.0:
+; AVX1OR2-NEXT: vmovaps {{.*#+}} ymm0 = [0,0,4294967295,4294967295,4294967295,0,0,0]
+; AVX1OR2-NEXT: vmaskmovps (%rdi), %ymm0, %ymm0
+; AVX1OR2-NEXT: vxorps %xmm1, %xmm1, %xmm1
+; AVX1OR2-NEXT: retq
+;
+; AVX512F-LABEL: mload512_to_mload256:
+; AVX512F: ## %bb.0:
+; AVX512F-NEXT: movw $28, %ax
+; AVX512F-NEXT: kmovw %eax, %k1
+; AVX512F-NEXT: vmovaps (%rdi), %zmm0 {%k1} {z}
+; AVX512F-NEXT: vmovaps %ymm0, %ymm0
+; AVX512F-NEXT: retq
+;
+; AVX512VLDQ-LABEL: mload512_to_mload256:
+; AVX512VLDQ: ## %bb.0:
+; AVX512VLDQ-NEXT: movb $28, %al
+; AVX512VLDQ-NEXT: kmovw %eax, %k1
+; AVX512VLDQ-NEXT: vmovaps (%rdi), %ymm0 {%k1} {z}
+; AVX512VLDQ-NEXT: retq
+;
+; AVX512VLBW-LABEL: mload512_to_mload256:
+; AVX512VLBW: ## %bb.0:
+; AVX512VLBW-NEXT: movb $28, %al
+; AVX512VLBW-NEXT: kmovd %eax, %k1
+; AVX512VLBW-NEXT: vmovaps (%rdi), %ymm0 {%k1} {z}
+; AVX512VLBW-NEXT: retq
+;
+; X86-AVX512-LABEL: mload512_to_mload256:
+; X86-AVX512: ## %bb.0:
+; X86-AVX512-NEXT: movl {{[0-9]+}}(%esp), %eax
+; X86-AVX512-NEXT: movb $28, %cl
+; X86-AVX512-NEXT: kmovd %ecx, %k1
+; X86-AVX512-NEXT: vmovaps (%eax), %ymm0 {%k1} {z}
+; X86-AVX512-NEXT: retl
+ %tmp = tail call <16 x float> @llvm.masked.load.v16f32.p0(ptr %p, i32 64, <16 x i1> <i1 false, i1 false, i1 true, i1 true, i1 true, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false>, <16 x float> <float 0.000000e+00, float 0.000000e+00, float poison, float poison, float poison, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00>)
+ %r = bitcast <16 x float> %tmp to <8 x i64>
+ ret <8 x i64> %r
+}
+
+define <8 x i64> @mload512_fail_no_possible_shrink(ptr %p) nounwind {
+; SSE-LABEL: mload512_fail_no_possible_shrink:
+; SSE: ## %bb.0:
+; SSE-NEXT: movss {{.*#+}} xmm2 = mem[0],zero,zero,zero
+; SSE-NEXT: movups (%rdi), %xmm0
+; SSE-NEXT: movups 16(%rdi), %xmm1
+; SSE-NEXT: xorps %xmm3, %xmm3
+; SSE-NEXT: retq
+;
+; AVX1OR2-LABEL: mload512_fail_no_possible_shrink:
+; AVX1OR2: ## %bb.0:
+; AVX1OR2-NEXT: vmovss {{.*#+}} xmm0 = [4294967295,0,0,0]
+; AVX1OR2-NEXT: vmaskmovps 32(%rdi), %xmm0, %xmm1
+; AVX1OR2-NEXT: vmovaps (%rdi), %ymm0
+; AVX1OR2-NEXT: retq
+;
+; AVX512F-LABEL: mload512_fail_no_possible_shrink:
+; AVX512F: ## %bb.0:
+; AVX512F-NEXT: movw $511, %ax ## imm = 0x1FF
+; AVX512F-NEXT: kmovw %eax, %k1
+; AVX512F-NEXT: vmovaps (%rdi), %zmm0 {%k1} {z}
+; AVX512F-NEXT: retq
+;
+; AVX512VLDQ-LABEL: mload512_fail_no_possible_shrink:
+; AVX512VLDQ: ## %bb.0:
+; AVX512VLDQ-NEXT: movw $511, %ax ## imm = 0x1FF
+; AVX512VLDQ-NEXT: kmovw %eax, %k1
+; AVX512VLDQ-NEXT: vmovaps (%rdi), %zmm0 {%k1} {z}
+; AVX512...
[truncated]
You can test this locally with the following command:

git-clang-format --diff a16f0dc9c2f0690e28622b0d80bd154fb0e6a30a 03cedf6315570e1d403380cddd688deb03198448 --extensions h,cpp -- llvm/include/llvm/CodeGen/SelectionDAGNodes.h llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp llvm/lib/Target/X86/X86ISelLowering.cpp

View the diff from clang-format here:

diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index b576127702..2b8dca91c0 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -51712,8 +51712,6 @@ static SDValue combineMaskedStore(SDNode *N, SelectionDAG &DAG,
if (Mst->isCompressingStore())
return SDValue();
-
-
EVT VT = Mst->getValue().getValueType();
SDLoc dl(Mst);
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
Can we do the shrink in IR phase?
                                     SDValue *ValInOut, EVT *NewVTOut,
                                     SDValue *NewMaskOut) {
  // Ensure we have a reasonable input type.
  // Also ensure ensure input bits is larger then xmm, otherwise its not
`ensure ensure` -> `ensure`
  if (!OrigVT.isSimple() || !OrigVT.isVector() ||
      OrigVT.getSizeInBits() <= 128 || !isPowerOf2_64(OrigVT.getSizeInBits()) ||
      !isPowerOf2_64(OrigVT.getScalarSizeInBits()))
Can we just check if it's legal type and size is 256 or 512?
+1, maybe we can use `is256BitVector` / `is512BitVector`.
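For illustration only, a minimal sketch of what the suggested check could look like (the helper name and its placement are my assumption, not part of the patch):

#include "llvm/CodeGen/ValueTypes.h"

// Hypothetical helper: gate the combine on simple 256/512-bit vectors
// instead of the generic power-of-two size checks used in the patch.
static bool isShrinkableMaskedOpVT(llvm::EVT OrigVT) {
  // is256BitVector()/is512BitVector() already imply the type is a vector.
  return OrigVT.isSimple() &&
         (OrigVT.is256BitVector() || OrigVT.is512BitVector());
}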
  unsigned ReqElts =
      DemandedElts.getBitWidth() - DemandedElts.countLeadingZeros();
I think we are able to load an arbitrary subvector instead of the first one?
We could, but the codegen is only "free" for the lower one. I'll add it as a potential TODO.
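A small standalone illustration of why only the lowest subvector is "free" (my own example, not from the patch): reading the low 128 bits of a ymm register is just a cast, while any other 128-bit chunk needs an actual extract instruction.

#include <immintrin.h>

// Low half: no instruction is emitted; the xmm register aliases the low
// lanes of the ymm register.
__m128 low_half(__m256 v) { return _mm256_castps256_ps128(v); }

// High half: requires a real vextractf128.
__m128 high_half(__m256 v) { return _mm256_extractf128_ps(v, 1); }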
    APInt V;
    if (!sd_match(MaskBV->getOperand(i), m_ConstInt(V)))
      return APInt::getAllOnes(NumElts);
    if (V.isNegative())
`isNegative()` -> `isOne()`?
Negative is "more" correct I think.
For example:
// If the mask value has been legalized to a non-boolean vector, try to
// simplify ops leading up to it. We only demand the MSB of each lane.
if (Mask.getScalarValueSizeInBits() != 1) {
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
APInt DemandedBits(APInt::getSignMask(VT.getScalarSizeInBits()));
if (TLI.SimplifyDemandedBits(Mask, DemandedBits, DCI)) {
if (N->getOpcode() != ISD::DELETED_NODE)
DCI.AddToWorklist(N);
return SDValue(N, 0);
}
if (SDValue NewMask =
TLI.SimplifyMultipleUseDemandedBits(Mask, DemandedBits, DAG))
return DAG.getMaskedLoad(
VT, SDLoc(N), Mld->getChain(), Mld->getBasePtr(), Mld->getOffset(),
NewMask, Mld->getPassThru(), Mld->getMemoryVT(), Mld->getMemOperand(),
Mld->getAddressingMode(), Mld->getExtensionType());
}
Doesn't `Mask` have to respect `getBooleanContents`?
Hmm? Don't see how this violates that.
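A tiny standalone example of the distinction being discussed (my illustration, assuming the legalized mask lanes are all-ones/zero integers): `isNegative()` fires on an all-ones lane, `isOne()` does not.

#include "llvm/ADT/APInt.h"
#include <cassert>

int main() {
  // An "on" lane of a <4 x i32> mask legalized from <4 x i1> is 0xFFFFFFFF.
  llvm::APInt OnLane(32, -1, /*isSigned=*/true);
  assert(OnLane.isNegative() && !OnLane.isOne());

  // An "off" lane is 0.
  llvm::APInt OffLane(32, 0);
  assert(!OffLane.isNegative());
  return 0;
}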
  SmallVector<SDValue> OrigMask;
  APInt DemandedElts = getDemandedEltsForMaskedOp(
      Mask, OrigVT.getVectorNumElements(), &OrigMask);
  if (DemandedElts.isAllOnes() || DemandedElts.isZero())
Maybe we should use `assert` here. AllOnes masked.load/store for 256/512 should be optimized away before SelectionDAG.
Slightly opposed. It might be produced as a temporary no?
> Slightly opposed. It might be produced as a temporary no?
Probably, although I wasn't sure about 1) the profitability on other arch, 2) if it might mess up vectorization to have different sized vecs in the middle-end, and 3) I wasn't 100% sure on the semantics, particularly: "Only the masked-on lanes of the vector need to be inbounds of an allocation (but all these lanes need to be inbounds of the same allocation)". I wasn't sure if we could guarantee the second point, but I'm not an expert in the memory model.
In the best case we can convert a masked load/store to a narrower normal load/store, i.e. `_mm512_maskz_load_ps(p, 0xff)` can be done with just a normal `ymm` load. As well, if the mask is entirely contained in a lower sub-vector, we can shrink the load/store, i.e. `_mm512_maskz_load_ps(p, 0x1c)` is the same as `_mm256_maskz_load_ps(p, 0x1c)`.

Force-pushed from b6fc529 to 03cedf6
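To make the claim concrete, here is a self-contained sketch of the equivalence (my own example, assuming AVX-512F plus AVX-512DQ for `_mm512_insertf32x8`; it is not code from the patch): mask 0x00FF covers only the low 8 of 16 float lanes, i.e. the low 256 bits, so the masked 512-bit load can be replaced by an ordinary 256-bit load with the upper half zeroed.

#include <immintrin.h>

// Masked form: zero-masked 512-bit load with only the low 8 lanes enabled.
__m512 masked_form(const float *p) {
  return _mm512_maskz_loadu_ps(0x00FF, p);
}

// Shrunk form: plain ymm load, upper 256 bits zeroed.
__m512 shrunk_form(const float *p) {
  __m256 lo = _mm256_loadu_ps(p);
  return _mm512_insertf32x8(_mm512_setzero_ps(), lo, 0);
}

Both functions return the same value for any in-bounds `p`; the second avoids the mask register entirely, which is the point of the combine.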
  unsigned NewNumElts =
      std::max(128U / OrigVT.getScalarSizeInBits(), PowerOf2Ceil(ReqElts));
IIUC, line 51559 assumes the mask is a boolean vector while line 51567 does not?
How so? We are re-using the orig scalar type.
Line 51559 uses the number of non-zero bits as `ReqElts`. I think this assumes the mask is a vector of booleans?
That's elements, not bits.
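A worked example of the sizing math under discussion (my illustration, mirroring the `NewNumElts` computation quoted above): for a v16f32 masked load with mask 0x1C, the highest demanded lane is 4, so `ReqElts` = 5 and the narrowed type keeps max(128/32, PowerOf2Ceil(5)) = 8 lanes, i.e. a ymm-sized v8f32.

#include "llvm/Support/MathExtras.h"
#include <algorithm>
#include <cstdio>

int main() {
  unsigned ScalarSizeInBits = 32; // f32 lanes
  unsigned ReqElts = 5;           // highest demanded lane of mask 0x1C is 4
  unsigned NewNumElts = std::max(
      128U / ScalarSizeInBits, (unsigned)llvm::PowerOf2Ceil(ReqElts));
  std::printf("NewNumElts = %u\n", NewNumElts); // prints 8 -> v8f32 (ymm)
  return 0;
}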
  if (!NewVT.isSimple())
    return false;
Should this be unreachable? Since you check `OrigVT.isSimple()` at the beginning.
Probably, although since we are re-creating it with a different element count, it's theoretically not unreachable.