[SVE] Refactor lowering for fixed length MGATHER/MSCATTER.
Lower fixed length MGATHER/MSCATTER operations to scalable vector
equivalents, which are then lowered to SVE-specific nodes. This
two-stage process is in preparation for making scalable vector
MGATHER/MSCATTER operations legal.

Differential Revision: https://reviews.llvm.org/D125192
paulwalker-arm committed May 21, 2022
1 parent 86fd1c1 commit 216f546
Showing 1 changed file with 98 additions and 69 deletions:
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
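For illustration only (nothing below is from the patch itself; the function name and types are made up): a fixed-length masked gather like this hypothetical IR is the kind of operation the new code path routes through a scalable-vector MGATHER when fixed-length vectors are implemented with SVE, e.g. under -aarch64-sve-vector-bits-min=256. Per the NOTE in the gather code, the f64 data would be handled as i64, gathered in a scalable container such as nxv2i64, then bitcast back to floating point.

; Hedged sketch, not a test from this commit.
define <4 x double> @gather_f64(double* %base, <4 x i64> %offsets, <4 x i1> %mask) {
  %ptrs = getelementptr double, double* %base, <4 x i64> %offsets
  %vals = call <4 x double> @llvm.masked.gather.v4f64.v4p0f64(
              <4 x double*> %ptrs, i32 8, <4 x i1> %mask, <4 x double> undef)
  ret <4 x double> %vals
}
declare <4 x double> @llvm.masked.gather.v4f64.v4p0f64(<4 x double*>, i32, <4 x i1>, <4 x double>)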
@@ -4696,33 +4696,65 @@ SDValue AArch64TargetLowering::LowerMGATHER(SDValue Op,
                                MGT->getMemOperand(), IndexType, ExtType);
   }
 
+  // Lower fixed length gather to a scalable equivalent.
+  if (VT.isFixedLengthVector()) {
+    assert(Subtarget->useSVEForFixedLengthVectors() &&
+           "Cannot lower when not using SVE for fixed vectors!");
+
+    // NOTE: Handle floating-point as if integer then bitcast the result.
+    EVT DataVT = VT.changeVectorElementTypeToInteger();
+    MemVT = MemVT.changeVectorElementTypeToInteger();
+
+    // Find the smallest integer fixed length vector we can use for the gather.
+    EVT PromotedVT = VT.changeVectorElementType(MVT::i32);
+    if (DataVT.getVectorElementType() == MVT::i64 ||
+        Index.getValueType().getVectorElementType() == MVT::i64 ||
+        Mask.getValueType().getVectorElementType() == MVT::i64)
+      PromotedVT = VT.changeVectorElementType(MVT::i64);
+
+    // Promote vector operands except for passthrough, which we know is either
+    // undef or zero, and thus best constructed directly.
+    unsigned ExtOpcode = IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
+    Index = DAG.getNode(ExtOpcode, DL, PromotedVT, Index);
+    Mask = DAG.getNode(ISD::SIGN_EXTEND, DL, PromotedVT, Mask);
+
+    // A promoted result type forces the need for an extending load.
+    if (PromotedVT != DataVT && ExtType == ISD::NON_EXTLOAD)
+      ExtType = ISD::EXTLOAD;
+
+    EVT ContainerVT = getContainerForFixedLengthVector(DAG, PromotedVT);
+
+    // Convert fixed length vector operands to scalable.
+    MemVT = ContainerVT.changeVectorElementType(MemVT.getVectorElementType());
+    Index = convertToScalableVector(DAG, ContainerVT, Index);
+    Mask = convertFixedMaskToScalableVector(Mask, DAG);
+    PassThru = PassThru->isUndef() ? DAG.getUNDEF(ContainerVT)
+                                   : DAG.getConstant(0, DL, ContainerVT);
+
+    // Emit equivalent scalable vector gather.
+    SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
+    SDValue Load =
+        DAG.getMaskedGather(DAG.getVTList(ContainerVT, MVT::Other), MemVT, DL,
+                            Ops, MGT->getMemOperand(), IndexType, ExtType);
+
+    // Extract fixed length data then convert to the required result type.
+    SDValue Result = convertFromScalableVector(DAG, PromotedVT, Load);
+    Result = DAG.getNode(ISD::TRUNCATE, DL, DataVT, Result);
+    if (VT.isFloatingPoint())
+      Result = DAG.getNode(ISD::BITCAST, DL, VT, Result);
+
+    return DAG.getMergeValues({Result, Load.getValue(1)}, DL);
+  }
+
   bool IdxNeedsExtend =
       getGatherScatterIndexIsExtended(Index) ||
       Index.getSimpleValueType().getVectorElementType() == MVT::i32;
 
   EVT IndexVT = Index.getSimpleValueType();
   SDValue InputVT = DAG.getValueType(MemVT);
 
-  bool IsFixedLength = MGT->getMemoryVT().isFixedLengthVector();
-
-  if (IsFixedLength) {
-    assert(Subtarget->useSVEForFixedLengthVectors() &&
-           "Cannot lower when not using SVE for fixed vectors");
-    if (MemVT.getScalarSizeInBits() <= IndexVT.getScalarSizeInBits()) {
-      IndexVT = getContainerForFixedLengthVector(DAG, IndexVT);
-      MemVT = IndexVT.changeVectorElementType(MemVT.getVectorElementType());
-    } else {
-      MemVT = getContainerForFixedLengthVector(DAG, MemVT);
-      IndexVT = MemVT.changeTypeToInteger();
-    }
-    InputVT = DAG.getValueType(MemVT.changeTypeToInteger());
-    Mask = DAG.getNode(
-        ISD::SIGN_EXTEND, DL,
-        VT.changeVectorElementType(IndexVT.getVectorElementType()), Mask);
-  }
-
   // Handle FP data by using an integer gather and casting the result.
-  if (VT.isFloatingPoint() && !IsFixedLength)
+  if (VT.isFloatingPoint())
     InputVT = DAG.getValueType(MemVT.changeVectorElementTypeToInteger());
 
   SDVTList VTs = DAG.getVTList(IndexVT, MVT::Other);
@@ -4737,25 +4769,11 @@ SDValue AArch64TargetLowering::LowerMGATHER(SDValue Op,
   if (ExtType == ISD::SEXTLOAD)
     Opcode = getSignExtendedGatherOpcode(Opcode);
 
-  if (IsFixedLength) {
-    if (Index.getSimpleValueType().isFixedLengthVector())
-      Index = convertToScalableVector(DAG, IndexVT, Index);
-    if (BasePtr.getSimpleValueType().isFixedLengthVector())
-      BasePtr = convertToScalableVector(DAG, IndexVT, BasePtr);
-    Mask = convertFixedMaskToScalableVector(Mask, DAG);
-  }
-
   SDValue Ops[] = {Chain, Mask, BasePtr, Index, InputVT};
   SDValue Result = DAG.getNode(Opcode, DL, VTs, Ops);
   Chain = Result.getValue(1);
 
-  if (IsFixedLength) {
-    Result = convertFromScalableVector(
-        DAG, VT.changeVectorElementType(IndexVT.getVectorElementType()),
-        Result);
-    Result = DAG.getNode(ISD::TRUNCATE, DL, VT.changeTypeToInteger(), Result);
-    Result = DAG.getNode(ISD::BITCAST, DL, VT, Result);
-  } else if (VT.isFloatingPoint())
+  if (VT.isFloatingPoint())
     Result = getSVESafeBitCast(VT, Result, DAG);
 
   return DAG.getMergeValues({Result, Chain}, DL);
@@ -4775,6 +4793,7 @@ SDValue AArch64TargetLowering::LowerMSCATTER(SDValue Op,
   EVT VT = StoreVal.getValueType();
   EVT MemVT = MSC->getMemoryVT();
   ISD::MemIndexType IndexType = MSC->getIndexType();
+  bool Truncating = MSC->isTruncatingStore();
 
   bool IsScaled = MSC->isIndexScaled();
   bool IsSigned = MSC->isIndexSigned();
@@ -4791,42 +4810,60 @@ SDValue AArch64TargetLowering::LowerMSCATTER(SDValue Op,
 
     SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale};
     return DAG.getMaskedScatter(MSC->getVTList(), MemVT, DL, Ops,
-                                MSC->getMemOperand(), IndexType,
-                                MSC->isTruncatingStore());
+                                MSC->getMemOperand(), IndexType, Truncating);
   }
 
+  // Lower fixed length scatter to a scalable equivalent.
+  if (VT.isFixedLengthVector()) {
+    assert(Subtarget->useSVEForFixedLengthVectors() &&
+           "Cannot lower when not using SVE for fixed vectors!");
+
+    // Once bitcast we treat floating-point scatters as if integer.
+    if (VT.isFloatingPoint()) {
+      VT = VT.changeVectorElementTypeToInteger();
+      MemVT = MemVT.changeVectorElementTypeToInteger();
+      StoreVal = DAG.getNode(ISD::BITCAST, DL, VT, StoreVal);
+    }
+
+    // Find the smallest integer fixed length vector we can use for the scatter.
+    EVT PromotedVT = VT.changeVectorElementType(MVT::i32);
+    if (VT.getVectorElementType() == MVT::i64 ||
+        Index.getValueType().getVectorElementType() == MVT::i64 ||
+        Mask.getValueType().getVectorElementType() == MVT::i64)
+      PromotedVT = VT.changeVectorElementType(MVT::i64);
+
+    // Promote vector operands.
+    unsigned ExtOpcode = IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
+    Index = DAG.getNode(ExtOpcode, DL, PromotedVT, Index);
+    Mask = DAG.getNode(ISD::SIGN_EXTEND, DL, PromotedVT, Mask);
+    StoreVal = DAG.getNode(ISD::ANY_EXTEND, DL, PromotedVT, StoreVal);
+
+    // A promoted value type forces the need for a truncating store.
+    if (PromotedVT != VT)
+      Truncating = true;
+
+    EVT ContainerVT = getContainerForFixedLengthVector(DAG, PromotedVT);
+
+    // Convert fixed length vector operands to scalable.
+    MemVT = ContainerVT.changeVectorElementType(MemVT.getVectorElementType());
+    Index = convertToScalableVector(DAG, ContainerVT, Index);
+    Mask = convertFixedMaskToScalableVector(Mask, DAG);
+    StoreVal = convertToScalableVector(DAG, ContainerVT, StoreVal);
+
+    // Emit equivalent scalable vector scatter.
+    SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale};
+    return DAG.getMaskedScatter(MSC->getVTList(), MemVT, DL, Ops,
+                                MSC->getMemOperand(), IndexType, Truncating);
+  }
+
   bool NeedsExtend =
       getGatherScatterIndexIsExtended(Index) ||
       Index.getSimpleValueType().getVectorElementType() == MVT::i32;
 
-  EVT IndexVT = Index.getSimpleValueType();
   SDVTList VTs = DAG.getVTList(MVT::Other);
   SDValue InputVT = DAG.getValueType(MemVT);
 
-  bool IsFixedLength = MSC->getMemoryVT().isFixedLengthVector();
-
-  if (IsFixedLength) {
-    assert(Subtarget->useSVEForFixedLengthVectors() &&
-           "Cannot lower when not using SVE for fixed vectors");
-    if (MemVT.getScalarSizeInBits() <= IndexVT.getScalarSizeInBits()) {
-      IndexVT = getContainerForFixedLengthVector(DAG, IndexVT);
-      MemVT = IndexVT.changeVectorElementType(MemVT.getVectorElementType());
-    } else {
-      MemVT = getContainerForFixedLengthVector(DAG, MemVT);
-      IndexVT = MemVT.changeTypeToInteger();
-    }
-    InputVT = DAG.getValueType(MemVT.changeTypeToInteger());
-
-    StoreVal =
-        DAG.getNode(ISD::BITCAST, DL, VT.changeTypeToInteger(), StoreVal);
-    StoreVal = DAG.getNode(
-        ISD::ANY_EXTEND, DL,
-        VT.changeVectorElementType(IndexVT.getVectorElementType()), StoreVal);
-    StoreVal = convertToScalableVector(DAG, IndexVT, StoreVal);
-    Mask = DAG.getNode(
-        ISD::SIGN_EXTEND, DL,
-        VT.changeVectorElementType(IndexVT.getVectorElementType()), Mask);
-  } else if (VT.isFloatingPoint()) {
+  if (VT.isFloatingPoint()) {
    // Handle FP data by casting the data so an integer scatter can be used.
     EVT StoreValVT = getPackedSVEVectorVT(VT.getVectorElementCount());
     StoreVal = getSVESafeBitCast(StoreValVT, StoreVal, DAG);
@@ -4840,14 +4877,6 @@ SDValue AArch64TargetLowering::LowerMSCATTER(SDValue Op,
   selectGatherScatterAddrMode(BasePtr, Index, IsScaled, MemVT, Opcode,
                               /*isGather=*/false, DAG);
 
-  if (IsFixedLength) {
-    if (Index.getSimpleValueType().isFixedLengthVector())
-      Index = convertToScalableVector(DAG, IndexVT, Index);
-    if (BasePtr.getSimpleValueType().isFixedLengthVector())
-      BasePtr = convertToScalableVector(DAG, IndexVT, BasePtr);
-    Mask = convertFixedMaskToScalableVector(Mask, DAG);
-  }
-
   SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, InputVT};
   return DAG.getNode(Opcode, DL, VTs, Ops);
 }

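A second hedged sketch (illustrative IR, not a test from this patch; names and types are made up): because the offsets below are i64, the promotion step in the new gather path picks an i64 PromotedVT even though the data is i16, which in turn forces an extending load; the gather is performed in a scalable i64 container and the result truncated back to v4i16 afterwards.

; Hedged sketch: i16 data with i64 offsets forces PromotedVT = v4i64,
; turning the gather into an extending load whose result is truncated.
define <4 x i16> @gather_i16(i16* %base, <4 x i64> %offsets, <4 x i1> %mask) {
  %ptrs = getelementptr i16, i16* %base, <4 x i64> %offsets
  %vals = call <4 x i16> @llvm.masked.gather.v4i16.v4p0i16(
              <4 x i16*> %ptrs, i32 2, <4 x i1> %mask, <4 x i16> undef)
  ret <4 x i16> %vals
}
declare <4 x i16> @llvm.masked.gather.v4i16.v4p0i16(<4 x i16*>, i32, <4 x i1>, <4 x i16>)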