Skip to content

Commit

Permalink
[RISCV] Rework gather/scatter DAG combine structure [NFC]
Browse files Browse the repository at this point in the history
Instead of switching on type before and after common code, use a helper function.  This matches the style of DAGCombine.cpp more closely, and makes porting candidate changes from one place to the other much easier.
  • Loading branch information
preames committed Sep 12, 2023
1 parent 2e106d5 commit 17b071d
Showing 1 changed file with 83 additions and 56 deletions.
139 changes: 83 additions & 56 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13482,6 +13482,34 @@ static SDValue combineToVWMACC(SDNode *N, SelectionDAG &DAG,
return DAG.getNode(Opc, DL, VT, Ops);
}

static bool legalizeScatterGatherIndexType(SDLoc DL, SDValue &Index,
ISD::MemIndexType &IndexType,
RISCVTargetLowering::DAGCombinerInfo &DCI) {
if (!DCI.isBeforeLegalize())
return false;

SelectionDAG &DAG = DCI.DAG;
const MVT XLenVT =
DAG.getMachineFunction().getSubtarget<RISCVSubtarget>().getXLenVT();

const EVT IndexVT = Index.getValueType();
const bool IsIndexSigned = isIndexTypeSigned(IndexType);

// RISC-V indexed loads only support the "unsigned unscaled" addressing
// mode, so anything else must be manually legalized.
if (!IsIndexSigned || !IndexVT.getVectorElementType().bitsLT(XLenVT))
return false;

// Any index legalization should first promote to XLenVT, so we don't lose
// bits when scaling. This may create an illegal index type so we let
// LLVM's legalization take care of the splitting.
// FIXME: LLVM can't split VP_GATHER or VP_SCATTER yet.
Index = DAG.getNode(ISD::SIGN_EXTEND, DL,
IndexVT.changeVectorElementType(XLenVT), Index);
IndexType = ISD::UNSIGNED_SCALED;
return true;
}

SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
DAGCombinerInfo &DCI) const {
SelectionDAG &DAG = DCI.DAG;
Expand Down Expand Up @@ -13827,74 +13855,73 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
return DAG.getNode(ISD::FCOPYSIGN, DL, VT, N->getOperand(0),
DAG.getNode(ISD::FNEG, DL, VT, NewFPExtRound));
}
case ISD::MGATHER:
case ISD::MSCATTER:
case ISD::VP_GATHER:
case ISD::VP_SCATTER: {
if (!DCI.isBeforeLegalize())
break;
SDValue Index, ScaleOp;
bool IsIndexSigned = false;
if (const auto *VPGSN = dyn_cast<VPGatherScatterSDNode>(N)) {
Index = VPGSN->getIndex();
ScaleOp = VPGSN->getScale();
IsIndexSigned = VPGSN->isIndexSigned();
assert(!VPGSN->isIndexScaled() &&
"Scaled gather/scatter should not be formed");
} else {
const auto *MGSN = cast<MaskedGatherScatterSDNode>(N);
Index = MGSN->getIndex();
ScaleOp = MGSN->getScale();
IsIndexSigned = MGSN->isIndexSigned();
assert(!MGSN->isIndexScaled() &&
"Scaled gather/scatter should not be formed");

}
EVT IndexVT = Index.getValueType();
// RISC-V indexed loads only support the "unsigned unscaled" addressing
// mode, so anything else must be manually legalized.
bool NeedsIdxLegalization =
(IsIndexSigned && IndexVT.getVectorElementType().bitsLT(XLenVT));
if (!NeedsIdxLegalization)
break;
case ISD::MGATHER: {
const auto *MGN = dyn_cast<MaskedGatherSDNode>(N);
SDValue Index = MGN->getIndex();
SDValue ScaleOp = MGN->getScale();
ISD::MemIndexType IndexType = MGN->getIndexType();
assert(!MGN->isIndexScaled() &&
"Scaled gather/scatter should not be formed");

SDLoc DL(N);
if (legalizeScatterGatherIndexType(DL, Index, IndexType, DCI))
return DAG.getMaskedGather(
N->getVTList(), MGN->getMemoryVT(), DL,
{MGN->getChain(), MGN->getPassThru(), MGN->getMask(),
MGN->getBasePtr(), Index, ScaleOp},
MGN->getMemOperand(), IndexType, MGN->getExtensionType());
break;
}
case ISD::MSCATTER:{
const auto *MSN = dyn_cast<MaskedScatterSDNode>(N);
SDValue Index = MSN->getIndex();
SDValue ScaleOp = MSN->getScale();
ISD::MemIndexType IndexType = MSN->getIndexType();
assert(!MSN->isIndexScaled() &&
"Scaled gather/scatter should not be formed");

// Any index legalization should first promote to XLenVT, so we don't lose
// bits when scaling. This may create an illegal index type so we let
// LLVM's legalization take care of the splitting.
// FIXME: LLVM can't split VP_GATHER or VP_SCATTER yet.
if (IndexVT.getVectorElementType().bitsLT(XLenVT)) {
IndexVT = IndexVT.changeVectorElementType(XLenVT);
Index = DAG.getNode(IsIndexSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND,
DL, IndexVT, Index);
}
SDLoc DL(N);
if (legalizeScatterGatherIndexType(DL, Index, IndexType, DCI))
return DAG.getMaskedScatter(
N->getVTList(), MSN->getMemoryVT(), DL,
{MSN->getChain(), MSN->getValue(), MSN->getMask(), MSN->getBasePtr(),
Index, ScaleOp},
MSN->getMemOperand(), IndexType, MSN->isTruncatingStore());
break;
}
case ISD::VP_GATHER: {
const auto *VPGN = dyn_cast<VPGatherSDNode>(N);
SDValue Index = VPGN->getIndex();
SDValue ScaleOp = VPGN->getScale();
ISD::MemIndexType IndexType = VPGN->getIndexType();
assert(!VPGN->isIndexScaled() &&
"Scaled gather/scatter should not be formed");

ISD::MemIndexType NewIndexTy = ISD::UNSIGNED_SCALED;
if (const auto *VPGN = dyn_cast<VPGatherSDNode>(N))
SDLoc DL(N);
if (legalizeScatterGatherIndexType(DL, Index, IndexType, DCI))
return DAG.getGatherVP(N->getVTList(), VPGN->getMemoryVT(), DL,
{VPGN->getChain(), VPGN->getBasePtr(), Index,
ScaleOp, VPGN->getMask(),
VPGN->getVectorLength()},
VPGN->getMemOperand(), NewIndexTy);
if (const auto *VPSN = dyn_cast<VPScatterSDNode>(N))
VPGN->getMemOperand(), IndexType);
break;
}
case ISD::VP_SCATTER: {
const auto *VPSN = dyn_cast<VPScatterSDNode>(N);
SDValue Index = VPSN->getIndex();
SDValue ScaleOp = VPSN->getScale();
ISD::MemIndexType IndexType = VPSN->getIndexType();
assert(!VPSN->isIndexScaled() &&
"Scaled gather/scatter should not be formed");

SDLoc DL(N);
if (legalizeScatterGatherIndexType(DL, Index, IndexType, DCI))
return DAG.getScatterVP(N->getVTList(), VPSN->getMemoryVT(), DL,
{VPSN->getChain(), VPSN->getValue(),
VPSN->getBasePtr(), Index, ScaleOp,
VPSN->getMask(), VPSN->getVectorLength()},
VPSN->getMemOperand(), NewIndexTy);
if (const auto *MGN = dyn_cast<MaskedGatherSDNode>(N))
return DAG.getMaskedGather(
N->getVTList(), MGN->getMemoryVT(), DL,
{MGN->getChain(), MGN->getPassThru(), MGN->getMask(),
MGN->getBasePtr(), Index, ScaleOp},
MGN->getMemOperand(), NewIndexTy, MGN->getExtensionType());
const auto *MSN = cast<MaskedScatterSDNode>(N);
return DAG.getMaskedScatter(
N->getVTList(), MSN->getMemoryVT(), DL,
{MSN->getChain(), MSN->getValue(), MSN->getMask(), MSN->getBasePtr(),
Index, ScaleOp},
MSN->getMemOperand(), NewIndexTy, MSN->isTruncatingStore());
VPSN->getMemOperand(), IndexType);
break;
}
case RISCVISD::SRA_VL:
case RISCVISD::SRL_VL:
Expand Down

0 comments on commit 17b071d

Please sign in to comment.