Skip to content

Commit

Permalink
[DAG] Introduce getSplat utility for common dispatch pattern [nfc]
Browse files Browse the repository at this point in the history
We have a very common pattern of dispatching between BUILD_VECTOR and SPLAT_VECTOR creation repeated in many cases in code.  Common the pattern into a utility function.
  • Loading branch information
preames committed Oct 3, 2022
1 parent 939a3d2 commit a200b0f
Show file tree
Hide file tree
Showing 6 changed files with 27 additions and 50 deletions.
10 changes: 10 additions & 0 deletions llvm/include/llvm/CodeGen/SelectionDAG.h
Original file line number Diff line number Diff line change
Expand Up @@ -862,6 +862,16 @@ class SelectionDAG {
return getNode(ISD::SPLAT_VECTOR, DL, VT, Op);
}

/// Returns a node representing a splat of one value into all lanes
/// of the provided vector type. This is a utility which returns
/// either a BUILD_VECTOR or SPLAT_VECTOR depending on the
/// scalability of the desired vector type.
SDValue getSplat(EVT VT, const SDLoc &DL, SDValue Op) {
assert(VT.isVector() && "Can't splat to non-vector type");
return VT.isScalableVector() ?
getSplatVector(VT, DL, Op) : getSplatBuildVector(VT, DL, Op);
}

/// Returns a vector of type ResVT whose elements contain the linear sequence
/// <0, Step, Step * 2, Step * 3, ...>
SDValue getStepVector(const SDLoc &DL, EVT ResVT, APInt StepVal);
Expand Down
18 changes: 5 additions & 13 deletions llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3469,11 +3469,8 @@ SDValue DAGCombiner::visitSUB(SDNode *N) {
if (VT.isVector()) {
SDValue N1S = DAG.getSplatValue(N1, true);
if (N1S && N1S.getOpcode() == ISD::SUB &&
isNullConstant(N1S.getOperand(0))) {
if (VT.isScalableVector())
return DAG.getSplatVector(VT, DL, N1S.getOperand(1));
return DAG.getSplatBuildVector(VT, DL, N1S.getOperand(1));
}
isNullConstant(N1S.getOperand(0)))
return DAG.getSplat(VT, DL, N1S.getOperand(1));
}
}

Expand Down Expand Up @@ -19778,11 +19775,8 @@ SDValue DAGCombiner::visitINSERT_VECTOR_ELT(SDNode *N) {
if (!IndexC) {
// If this is variable insert to undef vector, it might be better to splat:
// inselt undef, InVal, EltNo --> build_vector < InVal, InVal, ... >
if (InVec.isUndef() && TLI.shouldSplatInsEltVarIndex(VT)) {
if (VT.isScalableVector())
return DAG.getSplatVector(VT, DL, InVal);
return DAG.getSplatBuildVector(VT, DL, InVal);
}
if (InVec.isUndef() && TLI.shouldSplatInsEltVarIndex(VT))
return DAG.getSplat(VT, DL, InVal);
return SDValue();
}

Expand Down Expand Up @@ -23817,9 +23811,7 @@ static SDValue scalarizeBinOpOfSplats(SDNode *N, SelectionDAG &DAG,
}

// bo (splat X, Index), (splat Y, Index) --> splat (bo X, Y), Index
if (VT.isScalableVector())
return DAG.getSplatVector(VT, DL, ScalarBO);
return DAG.getSplatBuildVector(VT, DL, ScalarBO);
return DAG.getSplat(VT, DL, ScalarBO);
}

/// Visit a binary vector operation, like ADD.
Expand Down
8 changes: 2 additions & 6 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -963,10 +963,7 @@ SDValue VectorLegalizer::ExpandSELECT(SDNode *Node) {
DAG.getConstant(0, DL, BitTy));

// Broadcast the mask so that the entire vector is all one or all zero.
if (VT.isFixedLengthVector())
Mask = DAG.getSplatBuildVector(MaskTy, DL, Mask);
else
Mask = DAG.getSplatVector(MaskTy, DL, Mask);
Mask = DAG.getSplat(MaskTy, DL, Mask);

// Bitcast the operands to be the same type as the mask.
// This is needed when we select between FP types because
Expand Down Expand Up @@ -1309,8 +1306,7 @@ SDValue VectorLegalizer::ExpandVP_MERGE(SDNode *Node) {
return DAG.UnrollVectorOp(Node);

SDValue StepVec = DAG.getStepVector(DL, EVLVecVT);
SDValue SplatEVL = IsFixedLen ? DAG.getSplatBuildVector(EVLVecVT, DL, EVL)
: DAG.getSplatVector(EVLVecVT, DL, EVL);
SDValue SplatEVL = DAG.getSplat(EVLVecVT, DL, EVL);
SDValue EVLMask =
DAG.getSetCC(DL, MaskVT, StepVec, SplatEVL, ISD::CondCode::SETULT);

Expand Down
13 changes: 4 additions & 9 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1607,11 +1607,8 @@ SDValue SelectionDAG::getConstant(const ConstantInt &Val, const SDLoc &DL,
}

SDValue Result(N, 0);
if (VT.isScalableVector())
Result = getSplatVector(VT, DL, Result);
else if (VT.isVector())
Result = getSplatBuildVector(VT, DL, Result);

if (VT.isVector())
Result = getSplat(VT, DL, Result);
return Result;
}

Expand Down Expand Up @@ -1663,10 +1660,8 @@ SDValue SelectionDAG::getConstantFP(const ConstantFP &V, const SDLoc &DL,
}

SDValue Result(N, 0);
if (VT.isScalableVector())
Result = getSplatVector(VT, DL, Result);
else if (VT.isVector())
Result = getSplatBuildVector(VT, DL, Result);
if (VT.isVector())
Result = getSplat(VT, DL, Result);
NewSDValueDbgMsg(Result, "Creating fp constant: ", this);
return Result;
}
Expand Down
24 changes: 5 additions & 19 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1695,9 +1695,7 @@ SDValue SelectionDAGBuilder::getValueImpl(const Value *V) {
else
Op = DAG.getConstant(0, getCurSDLoc(), EltVT);

if (isa<ScalableVectorType>(VecTy))
return NodeMap[V] = DAG.getSplatVector(VT, getCurSDLoc(), Op);
return NodeMap[V] = DAG.getSplatBuildVector(VT, getCurSDLoc(), Op);
return NodeMap[V] = DAG.getSplat(VT, getCurSDLoc(), Op);
}

llvm_unreachable("Unknown vector constant");
Expand Down Expand Up @@ -3904,10 +3902,7 @@ void SelectionDAGBuilder::visitGetElementPtr(const User &I) {
if (IsVectorGEP && !N.getValueType().isVector()) {
LLVMContext &Context = *DAG.getContext();
EVT VT = EVT::getVectorVT(Context, N.getValueType(), VectorElementCount);
if (VectorElementCount.isScalable())
N = DAG.getSplatVector(VT, dl, N);
else
N = DAG.getSplatBuildVector(VT, dl, N);
N = DAG.getSplat(VT, dl, N);
}

for (gep_type_iterator GTI = gep_type_begin(&I), E = gep_type_end(&I);
Expand Down Expand Up @@ -3979,10 +3974,7 @@ void SelectionDAGBuilder::visitGetElementPtr(const User &I) {
if (!IdxN.getValueType().isVector() && IsVectorGEP) {
EVT VT = EVT::getVectorVT(*Context, IdxN.getValueType(),
VectorElementCount);
if (VectorElementCount.isScalable())
IdxN = DAG.getSplatVector(VT, dl, IdxN);
else
IdxN = DAG.getSplatBuildVector(VT, dl, IdxN);
IdxN = DAG.getSplat(VT, dl, IdxN);
}

// If the index is smaller or larger than intptr_t, truncate or extend
Expand Down Expand Up @@ -7247,14 +7239,8 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
SDValue TripCount = getValue(I.getOperand(1));
auto VecTy = CCVT.changeVectorElementType(ElementVT);

SDValue VectorIndex, VectorTripCount;
if (VecTy.isScalableVector()) {
VectorIndex = DAG.getSplatVector(VecTy, sdl, Index);
VectorTripCount = DAG.getSplatVector(VecTy, sdl, TripCount);
} else {
VectorIndex = DAG.getSplatBuildVector(VecTy, sdl, Index);
VectorTripCount = DAG.getSplatBuildVector(VecTy, sdl, TripCount);
}
SDValue VectorIndex = DAG.getSplat(VecTy, sdl, Index);
SDValue VectorTripCount = DAG.getSplat(VecTy, sdl, TripCount);
SDValue VectorStep = DAG.getStepVector(sdl, VecTy);
SDValue VectorInduction = DAG.getNode(
ISD::UADDSAT, sdl, VecTy, VectorIndex, VectorStep);
Expand Down
4 changes: 1 addition & 3 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4192,9 +4192,7 @@ SDValue RISCVTargetLowering::lowerSELECT(SDValue Op, SelectionDAG &DAG) const {
// Lower vector SELECTs to VSELECTs by splatting the condition.
if (VT.isVector()) {
MVT SplatCondVT = VT.changeVectorElementType(MVT::i1);
SDValue CondSplat = VT.isScalableVector()
? DAG.getSplatVector(SplatCondVT, DL, CondV)
: DAG.getSplatBuildVector(SplatCondVT, DL, CondV);
SDValue CondSplat = DAG.getSplat(SplatCondVT, DL, CondV);
return DAG.getNode(ISD::VSELECT, DL, VT, CondSplat, TrueV, FalseV);
}

Expand Down

0 comments on commit a200b0f

Please sign in to comment.