diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index e3fa268732b308..f8b71ee6428777 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -4696,6 +4696,56 @@ SDValue AArch64TargetLowering::LowerMGATHER(SDValue Op, MGT->getMemOperand(), IndexType, ExtType); } + // Lower fixed length gather to a scalable equivalent. + if (VT.isFixedLengthVector()) { + assert(Subtarget->useSVEForFixedLengthVectors() && + "Cannot lower when not using SVE for fixed vectors!"); + + // NOTE: Handle floating-point as if integer then bitcast the result. + EVT DataVT = VT.changeVectorElementTypeToInteger(); + MemVT = MemVT.changeVectorElementTypeToInteger(); + + // Find the smallest integer fixed length vector we can use for the gather. + EVT PromotedVT = VT.changeVectorElementType(MVT::i32); + if (DataVT.getVectorElementType() == MVT::i64 || + Index.getValueType().getVectorElementType() == MVT::i64 || + Mask.getValueType().getVectorElementType() == MVT::i64) + PromotedVT = VT.changeVectorElementType(MVT::i64); + + // Promote vector operands except for passthrough, which we know is either + // undef or zero, and thus best constructed directly. + unsigned ExtOpcode = IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND; + Index = DAG.getNode(ExtOpcode, DL, PromotedVT, Index); + Mask = DAG.getNode(ISD::SIGN_EXTEND, DL, PromotedVT, Mask); + + // A promoted result type forces the need for an extending load. + if (PromotedVT != DataVT && ExtType == ISD::NON_EXTLOAD) + ExtType = ISD::EXTLOAD; + + EVT ContainerVT = getContainerForFixedLengthVector(DAG, PromotedVT); + + // Convert fixed length vector operands to scalable. + MemVT = ContainerVT.changeVectorElementType(MemVT.getVectorElementType()); + Index = convertToScalableVector(DAG, ContainerVT, Index); + Mask = convertFixedMaskToScalableVector(Mask, DAG); + PassThru = PassThru->isUndef() ? DAG.getUNDEF(ContainerVT) + : DAG.getConstant(0, DL, ContainerVT); + + // Emit equivalent scalable vector gather. + SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale}; + SDValue Load = + DAG.getMaskedGather(DAG.getVTList(ContainerVT, MVT::Other), MemVT, DL, + Ops, MGT->getMemOperand(), IndexType, ExtType); + + // Extract fixed length data then convert to the required result type. + SDValue Result = convertFromScalableVector(DAG, PromotedVT, Load); + Result = DAG.getNode(ISD::TRUNCATE, DL, DataVT, Result); + if (VT.isFloatingPoint()) + Result = DAG.getNode(ISD::BITCAST, DL, VT, Result); + + return DAG.getMergeValues({Result, Load.getValue(1)}, DL); + } + bool IdxNeedsExtend = getGatherScatterIndexIsExtended(Index) || Index.getSimpleValueType().getVectorElementType() == MVT::i32; @@ -4703,26 +4753,8 @@ SDValue AArch64TargetLowering::LowerMGATHER(SDValue Op, EVT IndexVT = Index.getSimpleValueType(); SDValue InputVT = DAG.getValueType(MemVT); - bool IsFixedLength = MGT->getMemoryVT().isFixedLengthVector(); - - if (IsFixedLength) { - assert(Subtarget->useSVEForFixedLengthVectors() && - "Cannot lower when not using SVE for fixed vectors"); - if (MemVT.getScalarSizeInBits() <= IndexVT.getScalarSizeInBits()) { - IndexVT = getContainerForFixedLengthVector(DAG, IndexVT); - MemVT = IndexVT.changeVectorElementType(MemVT.getVectorElementType()); - } else { - MemVT = getContainerForFixedLengthVector(DAG, MemVT); - IndexVT = MemVT.changeTypeToInteger(); - } - InputVT = DAG.getValueType(MemVT.changeTypeToInteger()); - Mask = DAG.getNode( - ISD::SIGN_EXTEND, DL, - VT.changeVectorElementType(IndexVT.getVectorElementType()), Mask); - } - // Handle FP data by using an integer gather and casting the result. - if (VT.isFloatingPoint() && !IsFixedLength) + if (VT.isFloatingPoint()) InputVT = DAG.getValueType(MemVT.changeVectorElementTypeToInteger()); SDVTList VTs = DAG.getVTList(IndexVT, MVT::Other); @@ -4737,25 +4769,11 @@ SDValue AArch64TargetLowering::LowerMGATHER(SDValue Op, if (ExtType == ISD::SEXTLOAD) Opcode = getSignExtendedGatherOpcode(Opcode); - if (IsFixedLength) { - if (Index.getSimpleValueType().isFixedLengthVector()) - Index = convertToScalableVector(DAG, IndexVT, Index); - if (BasePtr.getSimpleValueType().isFixedLengthVector()) - BasePtr = convertToScalableVector(DAG, IndexVT, BasePtr); - Mask = convertFixedMaskToScalableVector(Mask, DAG); - } - SDValue Ops[] = {Chain, Mask, BasePtr, Index, InputVT}; SDValue Result = DAG.getNode(Opcode, DL, VTs, Ops); Chain = Result.getValue(1); - if (IsFixedLength) { - Result = convertFromScalableVector( - DAG, VT.changeVectorElementType(IndexVT.getVectorElementType()), - Result); - Result = DAG.getNode(ISD::TRUNCATE, DL, VT.changeTypeToInteger(), Result); - Result = DAG.getNode(ISD::BITCAST, DL, VT, Result); - } else if (VT.isFloatingPoint()) + if (VT.isFloatingPoint()) Result = getSVESafeBitCast(VT, Result, DAG); return DAG.getMergeValues({Result, Chain}, DL); @@ -4775,6 +4793,7 @@ SDValue AArch64TargetLowering::LowerMSCATTER(SDValue Op, EVT VT = StoreVal.getValueType(); EVT MemVT = MSC->getMemoryVT(); ISD::MemIndexType IndexType = MSC->getIndexType(); + bool Truncating = MSC->isTruncatingStore(); bool IsScaled = MSC->isIndexScaled(); bool IsSigned = MSC->isIndexSigned(); @@ -4791,42 +4810,60 @@ SDValue AArch64TargetLowering::LowerMSCATTER(SDValue Op, SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale}; return DAG.getMaskedScatter(MSC->getVTList(), MemVT, DL, Ops, - MSC->getMemOperand(), IndexType, - MSC->isTruncatingStore()); + MSC->getMemOperand(), IndexType, Truncating); + } + + // Lower fixed length scatter to a scalable equivalent. + if (VT.isFixedLengthVector()) { + assert(Subtarget->useSVEForFixedLengthVectors() && + "Cannot lower when not using SVE for fixed vectors!"); + + // Once bitcast we treat floating-point scatters as if integer. + if (VT.isFloatingPoint()) { + VT = VT.changeVectorElementTypeToInteger(); + MemVT = MemVT.changeVectorElementTypeToInteger(); + StoreVal = DAG.getNode(ISD::BITCAST, DL, VT, StoreVal); + } + + // Find the smallest integer fixed length vector we can use for the scatter. + EVT PromotedVT = VT.changeVectorElementType(MVT::i32); + if (VT.getVectorElementType() == MVT::i64 || + Index.getValueType().getVectorElementType() == MVT::i64 || + Mask.getValueType().getVectorElementType() == MVT::i64) + PromotedVT = VT.changeVectorElementType(MVT::i64); + + // Promote vector operands. + unsigned ExtOpcode = IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND; + Index = DAG.getNode(ExtOpcode, DL, PromotedVT, Index); + Mask = DAG.getNode(ISD::SIGN_EXTEND, DL, PromotedVT, Mask); + StoreVal = DAG.getNode(ISD::ANY_EXTEND, DL, PromotedVT, StoreVal); + + // A promoted value type forces the need for a truncating store. + if (PromotedVT != VT) + Truncating = true; + + EVT ContainerVT = getContainerForFixedLengthVector(DAG, PromotedVT); + + // Convert fixed length vector operands to scalable. + MemVT = ContainerVT.changeVectorElementType(MemVT.getVectorElementType()); + Index = convertToScalableVector(DAG, ContainerVT, Index); + Mask = convertFixedMaskToScalableVector(Mask, DAG); + StoreVal = convertToScalableVector(DAG, ContainerVT, StoreVal); + + // Emit equivalent scalable vector scatter. + SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale}; + return DAG.getMaskedScatter(MSC->getVTList(), MemVT, DL, Ops, + MSC->getMemOperand(), IndexType, Truncating); } bool NeedsExtend = getGatherScatterIndexIsExtended(Index) || Index.getSimpleValueType().getVectorElementType() == MVT::i32; - EVT IndexVT = Index.getSimpleValueType(); SDVTList VTs = DAG.getVTList(MVT::Other); SDValue InputVT = DAG.getValueType(MemVT); - bool IsFixedLength = MSC->getMemoryVT().isFixedLengthVector(); - - if (IsFixedLength) { - assert(Subtarget->useSVEForFixedLengthVectors() && - "Cannot lower when not using SVE for fixed vectors"); - if (MemVT.getScalarSizeInBits() <= IndexVT.getScalarSizeInBits()) { - IndexVT = getContainerForFixedLengthVector(DAG, IndexVT); - MemVT = IndexVT.changeVectorElementType(MemVT.getVectorElementType()); - } else { - MemVT = getContainerForFixedLengthVector(DAG, MemVT); - IndexVT = MemVT.changeTypeToInteger(); - } - InputVT = DAG.getValueType(MemVT.changeTypeToInteger()); - - StoreVal = - DAG.getNode(ISD::BITCAST, DL, VT.changeTypeToInteger(), StoreVal); - StoreVal = DAG.getNode( - ISD::ANY_EXTEND, DL, - VT.changeVectorElementType(IndexVT.getVectorElementType()), StoreVal); - StoreVal = convertToScalableVector(DAG, IndexVT, StoreVal); - Mask = DAG.getNode( - ISD::SIGN_EXTEND, DL, - VT.changeVectorElementType(IndexVT.getVectorElementType()), Mask); - } else if (VT.isFloatingPoint()) { + if (VT.isFloatingPoint()) { // Handle FP data by casting the data so an integer scatter can be used. EVT StoreValVT = getPackedSVEVectorVT(VT.getVectorElementCount()); StoreVal = getSVESafeBitCast(StoreValVT, StoreVal, DAG); @@ -4840,14 +4877,6 @@ SDValue AArch64TargetLowering::LowerMSCATTER(SDValue Op, selectGatherScatterAddrMode(BasePtr, Index, IsScaled, MemVT, Opcode, /*isGather=*/false, DAG); - if (IsFixedLength) { - if (Index.getSimpleValueType().isFixedLengthVector()) - Index = convertToScalableVector(DAG, IndexVT, Index); - if (BasePtr.getSimpleValueType().isFixedLengthVector()) - BasePtr = convertToScalableVector(DAG, IndexVT, BasePtr); - Mask = convertFixedMaskToScalableVector(Mask, DAG); - } - SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, InputVT}; return DAG.getNode(Opcode, DL, VTs, Ops); }