-
Notifications
You must be signed in to change notification settings - Fork 15.4k
[SelectionDAG] Add SelectionDAG::getTypeSize. NFC #169764
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Similar to how getElementCount avoids the need to reason about fixed and scalable ElementCounts separately, this patch adds getTypeSize to do the same for TypeSize. It also goes through and replaces some of the manual uses of getVScale with getTypeSize/getElementCount where possible.
|
@llvm/pr-subscribers-llvm-selectiondag @llvm/pr-subscribers-backend-aarch64 Author: Luke Lau (lukel97) ChangesSimilar to how getElementCount avoids the need to reason about fixed and scalable ElementCounts separately, this patch adds getTypeSize to do the same for TypeSize. It also goes through and replaces some of the manual uses of getVScale with getTypeSize/getElementCount where possible. Full diff: https://github.com/llvm/llvm-project/pull/169764.diff 7 Files Affected:
diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index b024e8a68bd6e..338f124edd601 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -1190,6 +1190,9 @@ class SelectionDAG {
LLVM_ABI SDValue getElementCount(const SDLoc &DL, EVT VT, ElementCount EC,
bool ConstantFold = true);
+ LLVM_ABI SDValue getTypeSize(const SDLoc &DL, EVT VT, TypeSize TS,
+ bool ConstantFold = true);
+
/// Return a GLOBAL_OFFSET_TABLE node. This does not have a useful SDLoc.
SDValue getGLOBAL_OFFSET_TABLE(EVT VT) {
return getNode(ISD::GLOBAL_OFFSET_TABLE, SDLoc(), VT);
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index 4274e951446b8..53b7aede7b4a5 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -1702,10 +1702,8 @@ void DAGTypeLegalizer::SplitVecRes_LOOP_DEPENDENCE_MASK(SDNode *N, SDValue &Lo,
Lo = DAG.getNode(N->getOpcode(), DL, LoVT, PtrA, PtrB, N->getOperand(2));
unsigned EltSize = N->getConstantOperandVal(2);
- unsigned Offset = EltSize * HiVT.getVectorMinNumElements();
- SDValue Addend = HiVT.isScalableVT()
- ? DAG.getVScale(DL, MVT::i64, APInt(64, Offset))
- : DAG.getConstant(Offset, DL, MVT::i64);
+ ElementCount Offset = HiVT.getVectorElementCount() * EltSize;
+ SDValue Addend = DAG.getElementCount(DL, MVT::i64, Offset);
PtrA = DAG.getNode(ISD::ADD, DL, MVT::i64, PtrA, Addend);
Hi = DAG.getNode(N->getOpcode(), DL, HiVT, PtrA, PtrB, N->getOperand(2));
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 1b15a207a2d37..aa18df180fbb0 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -2102,13 +2102,24 @@ SDValue SelectionDAG::getVScale(const SDLoc &DL, EVT VT, APInt MulImm,
return getNode(ISD::VSCALE, DL, VT, getConstant(MulImm, DL, VT));
}
+template <typename Ty>
+static SDValue getFixedOrScalableQuantity(SelectionDAG &DAG, const SDLoc &DL,
+ EVT VT, Ty X, bool ConstantFold) {
+ if (X.isScalable())
+ return DAG.getVScale(DL, VT,
+ APInt(VT.getSizeInBits(), X.getKnownMinValue()));
+
+ return DAG.getConstant(X.getKnownMinValue(), DL, VT);
+}
+
SDValue SelectionDAG::getElementCount(const SDLoc &DL, EVT VT, ElementCount EC,
bool ConstantFold) {
- if (EC.isScalable())
- return getVScale(DL, VT,
- APInt(VT.getSizeInBits(), EC.getKnownMinValue()));
+ return getFixedOrScalableQuantity(*this, DL, VT, EC, ConstantFold);
+}
- return getConstant(EC.getKnownMinValue(), DL, VT);
+SDValue SelectionDAG::getTypeSize(const SDLoc &DL, EVT VT, TypeSize TS,
+ bool ConstantFold) {
+ return getFixedOrScalableQuantity(*this, DL, VT, TS, ConstantFold);
}
SDValue SelectionDAG::getStepVector(const SDLoc &DL, EVT ResVT) {
@@ -8485,16 +8496,7 @@ static SDValue getMemsetStringVal(EVT VT, const SDLoc &dl, SelectionDAG &DAG,
SDValue SelectionDAG::getMemBasePlusOffset(SDValue Base, TypeSize Offset,
const SDLoc &DL,
const SDNodeFlags Flags) {
- EVT VT = Base.getValueType();
- SDValue Index;
-
- if (Offset.isScalable())
- Index = getVScale(DL, Base.getValueType(),
- APInt(Base.getValueSizeInBits().getFixedValue(),
- Offset.getKnownMinValue()));
- else
- Index = getConstant(Offset.getFixedValue(), DL, VT);
-
+ SDValue Index = getTypeSize(DL, Base.getValueType(), Offset);
return getMemBasePlusOffset(Base, Index, DL, Flags);
}
@@ -13570,11 +13572,8 @@ std::pair<SDValue, SDValue> SelectionDAG::SplitEVL(SDValue N, EVT VecVT,
EVT VT = N.getValueType();
assert(VecVT.getVectorElementCount().isKnownEven() &&
"Expecting the mask to be an evenly-sized vector");
- unsigned HalfMinNumElts = VecVT.getVectorMinNumElements() / 2;
- SDValue HalfNumElts =
- VecVT.isFixedLengthVector()
- ? getConstant(HalfMinNumElts, DL, VT)
- : getVScale(DL, VT, APInt(VT.getScalarSizeInBits(), HalfMinNumElts));
+ SDValue HalfNumElts = getElementCount(
+ DL, VT, VecVT.getVectorElementCount().divideCoefficientBy(2));
SDValue Lo = getNode(ISD::UMIN, DL, VT, N, HalfNumElts);
SDValue Hi = getNode(ISD::USUBSAT, DL, VT, N, HalfNumElts);
return std::make_pair(Lo, Hi);
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 88b35582a9f7d..571f8b11dccf9 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -4583,17 +4583,8 @@ void SelectionDAGBuilder::visitAlloca(const AllocaInst &I) {
if (AllocSize.getValueType() != IntPtr)
AllocSize = DAG.getZExtOrTrunc(AllocSize, dl, IntPtr);
- if (TySize.isScalable())
- AllocSize = DAG.getNode(ISD::MUL, dl, IntPtr, AllocSize,
- DAG.getVScale(dl, IntPtr,
- APInt(IntPtr.getScalarSizeInBits(),
- TySize.getKnownMinValue())));
- else {
- SDValue TySizeValue =
- DAG.getConstant(TySize.getFixedValue(), dl, MVT::getIntegerVT(64));
- AllocSize = DAG.getNode(ISD::MUL, dl, IntPtr, AllocSize,
- DAG.getZExtOrTrunc(TySizeValue, dl, IntPtr));
- }
+ AllocSize = DAG.getNode(ISD::MUL, dl, IntPtr, AllocSize,
+ DAG.getTypeSize(dl, IntPtr, TySize));
// Handle alignment. If the requested alignment is less than or equal to
// the stack alignment, ignore it. If the size is greater than or equal to
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 5684e0e4c26c4..c7bab045f4bf1 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -10625,12 +10625,8 @@ TargetLowering::IncrementMemoryAddress(SDValue Addr, SDValue Mask,
SDValue Scale = DAG.getConstant(DataVT.getScalarSizeInBits() / 8, DL,
AddrVT);
Increment = DAG.getNode(ISD::MUL, DL, AddrVT, Increment, Scale);
- } else if (DataVT.isScalableVector()) {
- Increment = DAG.getVScale(DL, AddrVT,
- APInt(AddrVT.getFixedSizeInBits(),
- DataVT.getStoreSize().getKnownMinValue()));
} else
- Increment = DAG.getConstant(DataVT.getStoreSize(), DL, AddrVT);
+ Increment = DAG.getTypeSize(DL, AddrVT, DataVT.getStoreSize());
return DAG.getNode(ISD::ADD, DL, AddrVT, Addr, Increment);
}
@@ -11923,10 +11919,8 @@ SDValue TargetLowering::expandVectorSplice(SDNode *Node,
// Store the lo part of CONCAT_VECTORS(V1, V2)
SDValue StoreV1 = DAG.getStore(DAG.getEntryNode(), DL, V1, StackPtr, PtrInfo);
// Store the hi part of CONCAT_VECTORS(V1, V2)
- SDValue OffsetToV2 = DAG.getVScale(
- DL, PtrVT,
- APInt(PtrVT.getFixedSizeInBits(), VT.getStoreSize().getKnownMinValue()));
- SDValue StackPtr2 = DAG.getNode(ISD::ADD, DL, PtrVT, StackPtr, OffsetToV2);
+ SDValue VTBytes = DAG.getTypeSize(DL, PtrVT, VT.getStoreSize());
+ SDValue StackPtr2 = DAG.getNode(ISD::ADD, DL, PtrVT, StackPtr, VTBytes);
SDValue StoreV2 = DAG.getStore(StoreV1, DL, V2, StackPtr2, PtrInfo);
if (Imm >= 0) {
@@ -11945,13 +11939,8 @@ SDValue TargetLowering::expandVectorSplice(SDNode *Node,
SDValue TrailingBytes =
DAG.getConstant(TrailingElts * EltByteSize, DL, PtrVT);
- if (TrailingElts > VT.getVectorMinNumElements()) {
- SDValue VLBytes =
- DAG.getVScale(DL, PtrVT,
- APInt(PtrVT.getFixedSizeInBits(),
- VT.getStoreSize().getKnownMinValue()));
- TrailingBytes = DAG.getNode(ISD::UMIN, DL, PtrVT, TrailingBytes, VLBytes);
- }
+ if (TrailingElts > VT.getVectorMinNumElements())
+ TrailingBytes = DAG.getNode(ISD::UMIN, DL, PtrVT, TrailingBytes, VTBytes);
// Calculate the start address of the spliced result.
StackPtr2 = DAG.getNode(ISD::SUB, DL, PtrVT, StackPtr2, TrailingBytes);
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 83ce39fa314d1..1c91f5ca97eb3 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -8639,7 +8639,7 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
Subtarget->isWindowsArm64EC()) &&
"Indirect arguments should be scalable on most subtargets");
- uint64_t PartSize = VA.getValVT().getStoreSize().getKnownMinValue();
+ TypeSize PartSize = VA.getValVT().getStoreSize();
unsigned NumParts = 1;
if (Ins[i].Flags.isInConsecutiveRegs()) {
while (!Ins[i + NumParts - 1].Flags.isInConsecutiveRegsLast())
@@ -8656,16 +8656,8 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
InVals.push_back(ArgValue);
NumParts--;
if (NumParts > 0) {
- SDValue BytesIncrement;
- if (PartLoad.isScalableVector()) {
- BytesIncrement = DAG.getVScale(
- DL, Ptr.getValueType(),
- APInt(Ptr.getValueSizeInBits().getFixedValue(), PartSize));
- } else {
- BytesIncrement = DAG.getConstant(
- APInt(Ptr.getValueSizeInBits().getFixedValue(), PartSize), DL,
- Ptr.getValueType());
- }
+ SDValue BytesIncrement =
+ DAG.getTypeSize(DL, Ptr.getValueType(), PartSize);
Ptr = DAG.getNode(ISD::ADD, DL, Ptr.getValueType(), Ptr,
BytesIncrement, SDNodeFlags::NoUnsignedWrap);
ExtraArgLocs++;
@@ -9868,8 +9860,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
assert((isScalable || Subtarget->isWindowsArm64EC()) &&
"Indirect arguments should be scalable on most subtargets");
- uint64_t StoreSize = VA.getValVT().getStoreSize().getKnownMinValue();
- uint64_t PartSize = StoreSize;
+ TypeSize StoreSize = VA.getValVT().getStoreSize();
+ TypeSize PartSize = StoreSize;
unsigned NumParts = 1;
if (Outs[i].Flags.isInConsecutiveRegs()) {
while (!Outs[i + NumParts - 1].Flags.isInConsecutiveRegsLast())
@@ -9880,7 +9872,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
Type *Ty = EVT(VA.getValVT()).getTypeForEVT(*DAG.getContext());
Align Alignment = DAG.getDataLayout().getPrefTypeAlign(Ty);
MachineFrameInfo &MFI = MF.getFrameInfo();
- int FI = MFI.CreateStackObject(StoreSize, Alignment, false);
+ int FI =
+ MFI.CreateStackObject(StoreSize.getKnownMinValue(), Alignment, false);
if (isScalable) {
bool IsPred = VA.getValVT() == MVT::aarch64svcount ||
VA.getValVT().getVectorElementType() == MVT::i1;
@@ -9901,16 +9894,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
NumParts--;
if (NumParts > 0) {
- SDValue BytesIncrement;
- if (isScalable) {
- BytesIncrement = DAG.getVScale(
- DL, Ptr.getValueType(),
- APInt(Ptr.getValueSizeInBits().getFixedValue(), PartSize));
- } else {
- BytesIncrement = DAG.getConstant(
- APInt(Ptr.getValueSizeInBits().getFixedValue(), PartSize), DL,
- Ptr.getValueType());
- }
+ SDValue BytesIncrement =
+ DAG.getTypeSize(DL, Ptr.getValueType(), PartSize);
MPI = MachinePointerInfo(MPI.getAddrSpace());
Ptr = DAG.getNode(ISD::ADD, DL, Ptr.getValueType(), Ptr,
BytesIncrement, SDNodeFlags::NoUnsignedWrap);
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index be53f51afe79f..f8180e2d86664 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -12739,10 +12739,7 @@ SDValue RISCVTargetLowering::lowerVECTOR_INTERLEAVE(SDValue Op,
SmallVector<SDValue, 8> Loads(Factor);
- SDValue Increment =
- DAG.getVScale(DL, PtrVT,
- APInt(PtrVT.getFixedSizeInBits(),
- VecVT.getStoreSize().getKnownMinValue()));
+ SDValue Increment = DAG.getTypeSize(DL, PtrVT, VecVT.getStoreSize());
for (unsigned i = 0; i != Factor; ++i) {
if (i != 0)
StackPtr = DAG.getNode(ISD::ADD, DL, PtrVT, StackPtr, Increment);
@@ -14140,9 +14137,8 @@ RISCVTargetLowering::lowerVPReverseExperimental(SDValue Op,
// Slide off any elements from past EVL that were reversed into the low
// elements.
- unsigned MinElts = GatherVT.getVectorMinNumElements();
SDValue VLMax =
- DAG.getVScale(DL, XLenVT, APInt(XLenVT.getSizeInBits(), MinElts));
+ DAG.getElementCount(DL, XLenVT, GatherVT.getVectorElementCount());
SDValue Diff = DAG.getNode(ISD::SUB, DL, XLenVT, VLMax, EVL);
Result = getVSlidedown(DAG, Subtarget, DL, GatherVT,
|
🐧 Linux x64 Test Results
|
sdesmalen-arm
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice idea to make a dedicated function for this!
|
|
||
| template <typename Ty> | ||
| static SDValue getFixedOrScalableQuantity(SelectionDAG &DAG, const SDLoc &DL, | ||
| EVT VT, Ty X, bool ConstantFold) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ConstantFold is unused?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point. I guess it was originally meant to be plumbed through to getVScale but was forgotten? In any case it doesn't look like anything ever sets ConstantFold to false. It defaults to true. I've gone ahead and just removed it in 5c13c4a
| return getNode(ISD::VSCALE, DL, VT, getConstant(MulImm, DL, VT)); | ||
| } | ||
|
|
||
| template <typename Ty> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add documentation for this function to describe what VT and X mean and how they are used to create the end result?
Also, why does it need to be a templated function? Is that for distinguishing EVT and MVT?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added a comment in 5c13c4a.
It's templated since it needs to work over both ElementType and TypeSize, both of which inherit from details::FixedOrScalableQuantity<LeafTy, ScalarTy>, but with different LeafTy and ScalarTys.
sdesmalen-arm
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM with nit addressed
| if (EC.isScalable()) | ||
| return getVScale(DL, VT, | ||
| APInt(VT.getSizeInBits(), EC.getKnownMinValue())); | ||
| /// \returns a value of type \p VT that represents the runtime value of \p X, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit:
| /// \returns a value of type \p VT that represents the runtime value of \p X, | |
| /// \returns a value of type \p VT that represents the runtime value of \p Quantity, |
maybe also say that this can only be an ElementCount or a TypeSize?
Similar to how getElementCount avoids the need to reason about fixed and scalable ElementCounts separately, this patch adds getTypeSize to do the same for TypeSize. It also goes through and replaces some of the manual uses of getVScale with getTypeSize/getElementCount where possible.
Similar to how getElementCount avoids the need to reason about fixed and scalable ElementCounts separately, this patch adds getTypeSize to do the same for TypeSize. It also goes through and replaces some of the manual uses of getVScale with getTypeSize/getElementCount where possible.
Similar to how getElementCount avoids the need to reason about fixed and scalable ElementCounts separately, this patch adds getTypeSize to do the same for TypeSize.
It also goes through and replaces some of the manual uses of getVScale with getTypeSize/getElementCount where possible.