Skip to content

Commit

Permalink
[AArch64][SVE] Remove LD1/ST1 dependency on llvm.masked.load/store
Browse files Browse the repository at this point in the history
Summary:
The SVE masked load and store intrinsics introduced in D76688 rely on
common llvm.masked.load/store nodes. This patch creates new ISD nodes
for LD1(S) & ST1 to remove this dependency.

Additionally, this adds support for sign & zero extending
loads and truncating stores.

Reviewers: sdesmalen, efriedma, cameron.mcinally, c-rhodes, rengolin

Reviewed By: efriedma

Subscribers: tschuett, kristof.beyls, hiraditya, rkruppe, psnobl, danielkiss, andwar, cfe-commits, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D78204
  • Loading branch information
kmclaughlin-arm committed Apr 20, 2020
1 parent 1f67508 commit 33ffce5
Show file tree
Hide file tree
Showing 7 changed files with 1,012 additions and 242 deletions.
4 changes: 4 additions & 0 deletions llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
Expand Up @@ -4646,9 +4646,13 @@ static EVT getMemVTFromNode(LLVMContext &Ctx, SDNode *Root) {
// For custom ISD nodes, we have to look at them individually to extract the
// type of the data moved to/from memory.
switch (Opcode) {
case AArch64ISD::LD1:
case AArch64ISD::LD1S:
case AArch64ISD::LDNF1:
case AArch64ISD::LDNF1S:
return cast<VTSDNode>(Root->getOperand(3))->getVT();
case AArch64ISD::ST1:
return cast<VTSDNode>(Root->getOperand(4))->getVT();
default:
break;
}
Expand Down
103 changes: 69 additions & 34 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Expand Up @@ -1415,6 +1415,8 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
case AArch64ISD::INSR: return "AArch64ISD::INSR";
case AArch64ISD::PTEST: return "AArch64ISD::PTEST";
case AArch64ISD::PTRUE: return "AArch64ISD::PTRUE";
case AArch64ISD::LD1: return "AArch64ISD::LD1";
case AArch64ISD::LD1S: return "AArch64ISD::LD1S";
case AArch64ISD::LDNF1: return "AArch64ISD::LDNF1";
case AArch64ISD::LDNF1S: return "AArch64ISD::LDNF1S";
case AArch64ISD::LDFF1: return "AArch64ISD::LDFF1";
Expand Down Expand Up @@ -1454,6 +1456,8 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
case AArch64ISD::GLDNT1_INDEX: return "AArch64ISD::GLDNT1_INDEX";
case AArch64ISD::GLDNT1S: return "AArch64ISD::GLDNT1S";

case AArch64ISD::ST1: return "AArch64ISD::ST1";

case AArch64ISD::SST1: return "AArch64ISD::SST1";
case AArch64ISD::SST1_SCALED: return "AArch64ISD::SST1_SCALED";
case AArch64ISD::SST1_SXTW: return "AArch64ISD::SST1_SXTW";
Expand Down Expand Up @@ -9041,7 +9045,6 @@ bool AArch64TargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info,
Info.align = Align(16);
Info.flags = MachineMemOperand::MOStore | MachineMemOperand::MOVolatile;
return true;
case Intrinsic::aarch64_sve_ld1:
case Intrinsic::aarch64_sve_ldnt1: {
PointerType *PtrTy = cast<PointerType>(I.getArgOperand(1)->getType());
Info.opc = ISD::INTRINSIC_W_CHAIN;
Expand All @@ -9054,7 +9057,6 @@ bool AArch64TargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info,
Info.flags |= MachineMemOperand::MONonTemporal;
return true;
}
case Intrinsic::aarch64_sve_st1:
case Intrinsic::aarch64_sve_stnt1: {
PointerType *PtrTy = cast<PointerType>(I.getArgOperand(2)->getType());
Info.opc = ISD::INTRINSIC_W_CHAIN;
Expand Down Expand Up @@ -10515,6 +10517,7 @@ static SDValue performSVEAndCombine(SDNode *N,
// SVE load instructions perform an implicit zero-extend, which makes them
// perfect candidates for combining.
switch (Src->getOpcode()) {
case AArch64ISD::LD1:
case AArch64ISD::LDNF1:
case AArch64ISD::LDFF1:
MemVT = cast<VTSDNode>(Src->getOperand(3))->getVT();
Expand Down Expand Up @@ -11581,7 +11584,33 @@ static MVT getSVEContainerType(EVT ContentTy) {
}
}

static SDValue performLD1Combine(SDNode *N, SelectionDAG &DAG) {
static SDValue performLD1Combine(SDNode *N, SelectionDAG &DAG, unsigned Opc) {
SDLoc DL(N);
EVT VT = N->getValueType(0);

if (VT.getSizeInBits().getKnownMinSize() > AArch64::SVEBitsPerBlock)
return SDValue();

EVT ContainerVT = VT;
if (ContainerVT.isInteger())
ContainerVT = getSVEContainerType(ContainerVT);

SDVTList VTs = DAG.getVTList(ContainerVT, MVT::Other);
SDValue Ops[] = { N->getOperand(0), // Chain
N->getOperand(2), // Pg
N->getOperand(3), // Base
DAG.getValueType(VT) };

SDValue Load = DAG.getNode(Opc, DL, VTs, Ops);
SDValue LoadChain = SDValue(Load.getNode(), 1);

if (ContainerVT.isInteger() && (VT != ContainerVT))
Load = DAG.getNode(ISD::TRUNCATE, DL, VT, Load.getValue(0));

return DAG.getMergeValues({ Load, LoadChain }, DL);
}

static SDValue performLDNT1Combine(SDNode *N, SelectionDAG &DAG) {
SDLoc DL(N);
EVT VT = N->getValueType(0);
EVT PtrTy = N->getOperand(3).getValueType();
Expand All @@ -11608,6 +11637,32 @@ static SDValue performLD1Combine(SDNode *N, SelectionDAG &DAG) {

static SDValue performST1Combine(SDNode *N, SelectionDAG &DAG) {
SDLoc DL(N);
SDValue Data = N->getOperand(2);
EVT DataVT = Data.getValueType();
EVT HwSrcVt = getSVEContainerType(DataVT);
SDValue InputVT = DAG.getValueType(DataVT);

if (DataVT.isFloatingPoint())
InputVT = DAG.getValueType(HwSrcVt);

SDValue SrcNew;
if (Data.getValueType().isFloatingPoint())
SrcNew = DAG.getNode(ISD::BITCAST, DL, HwSrcVt, Data);
else
SrcNew = DAG.getNode(ISD::ANY_EXTEND, DL, HwSrcVt, Data);

SDValue Ops[] = { N->getOperand(0), // Chain
SrcNew,
N->getOperand(4), // Base
N->getOperand(3), // Pg
InputVT
};

return DAG.getNode(AArch64ISD::ST1, DL, N->getValueType(0), Ops);
}

static SDValue performSTNT1Combine(SDNode *N, SelectionDAG &DAG) {
SDLoc DL(N);

SDValue Data = N->getOperand(2);
EVT DataVT = Data.getValueType();
Expand All @@ -11623,32 +11678,6 @@ static SDValue performST1Combine(SDNode *N, SelectionDAG &DAG) {
ISD::UNINDEXED, false, false);
}

static SDValue performLDNF1Combine(SDNode *N, SelectionDAG &DAG, unsigned Opc) {
SDLoc DL(N);
EVT VT = N->getValueType(0);

if (VT.getSizeInBits().getKnownMinSize() > AArch64::SVEBitsPerBlock)
return SDValue();

EVT ContainerVT = VT;
if (ContainerVT.isInteger())
ContainerVT = getSVEContainerType(ContainerVT);

SDVTList VTs = DAG.getVTList(ContainerVT, MVT::Other);
SDValue Ops[] = { N->getOperand(0), // Chain
N->getOperand(2), // Pg
N->getOperand(3), // Base
DAG.getValueType(VT) };

SDValue Load = DAG.getNode(Opc, DL, VTs, Ops);
SDValue LoadChain = SDValue(Load.getNode(), 1);

if (ContainerVT.isInteger() && (VT != ContainerVT))
Load = DAG.getNode(ISD::TRUNCATE, DL, VT, Load.getValue(0));

return DAG.getMergeValues({ Load, LoadChain }, DL);
}

/// Replace a splat of zeros to a vector store by scalar stores of WZR/XZR. The
/// load store optimizer pass will merge them to store pair stores. This should
/// be better than a movi to create the vector zero followed by a vector store
Expand Down Expand Up @@ -12963,6 +12992,10 @@ performSignExtendInRegCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
unsigned NewOpc;
unsigned MemVTOpNum = 4;
switch (Opc) {
case AArch64ISD::LD1:
NewOpc = AArch64ISD::LD1S;
MemVTOpNum = 3;
break;
case AArch64ISD::LDNF1:
NewOpc = AArch64ISD::LDNF1S;
MemVTOpNum = 3;
Expand Down Expand Up @@ -13189,9 +13222,8 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
case Intrinsic::aarch64_neon_st3lane:
case Intrinsic::aarch64_neon_st4lane:
return performNEONPostLDSTCombine(N, DCI, DAG);
case Intrinsic::aarch64_sve_ld1:
case Intrinsic::aarch64_sve_ldnt1:
return performLD1Combine(N, DAG);
return performLDNT1Combine(N, DAG);
case Intrinsic::aarch64_sve_ldnt1_gather_scalar_offset:
return performGatherLoadCombine(N, DAG, AArch64ISD::GLDNT1);
case Intrinsic::aarch64_sve_ldnt1_gather:
Expand All @@ -13200,13 +13232,16 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
return performGatherLoadCombine(N, DAG, AArch64ISD::GLDNT1_INDEX);
case Intrinsic::aarch64_sve_ldnt1_gather_uxtw:
return performGatherLoadCombine(N, DAG, AArch64ISD::GLDNT1);
case Intrinsic::aarch64_sve_ld1:
return performLD1Combine(N, DAG, AArch64ISD::LD1);
case Intrinsic::aarch64_sve_ldnf1:
return performLDNF1Combine(N, DAG, AArch64ISD::LDNF1);
return performLD1Combine(N, DAG, AArch64ISD::LDNF1);
case Intrinsic::aarch64_sve_ldff1:
return performLDNF1Combine(N, DAG, AArch64ISD::LDFF1);
return performLD1Combine(N, DAG, AArch64ISD::LDFF1);
case Intrinsic::aarch64_sve_st1:
case Intrinsic::aarch64_sve_stnt1:
return performST1Combine(N, DAG);
case Intrinsic::aarch64_sve_stnt1:
return performSTNT1Combine(N, DAG);
case Intrinsic::aarch64_sve_stnt1_scatter_scalar_offset:
return performScatterStoreCombine(N, DAG, AArch64ISD::SSTNT1);
case Intrinsic::aarch64_sve_stnt1_scatter_uxtw:
Expand Down
4 changes: 4 additions & 0 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.h
Expand Up @@ -226,6 +226,8 @@ enum NodeType : unsigned {

REINTERPRET_CAST,

LD1,
LD1S,
LDNF1,
LDNF1S,
LDFF1,
Expand Down Expand Up @@ -272,6 +274,8 @@ enum NodeType : unsigned {
GLDNT1_INDEX,
GLDNT1S,

ST1,

// Scatter store
SST1,
SST1_SCALED,
Expand Down
125 changes: 99 additions & 26 deletions llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
Expand Up @@ -13,18 +13,23 @@
def SVE8BitLslImm : ComplexPattern<i32, 2, "SelectSVE8BitLslImm", [imm]>;
def SVELShiftImm64 : ComplexPattern<i32, 1, "SelectSVEShiftImm64<0, 64>", []>;

// Non-faulting & first-faulting loads - node definitions
// Contiguous loads - node definitions
//
def SDT_AArch64_LDNF1 : SDTypeProfile<1, 3, [
def SDT_AArch64_LD1 : SDTypeProfile<1, 3, [
SDTCisVec<0>, SDTCisVec<1>, SDTCisPtrTy<2>,
SDTCVecEltisVT<1,i1>, SDTCisSameNumEltsAs<0,1>
]>;

def AArch64ldnf1 : SDNode<"AArch64ISD::LDNF1", SDT_AArch64_LDNF1, [SDNPHasChain, SDNPMayLoad, SDNPOptInGlue, SDNPOutGlue]>;
def AArch64ldff1 : SDNode<"AArch64ISD::LDFF1", SDT_AArch64_LDNF1, [SDNPHasChain, SDNPMayLoad, SDNPOptInGlue, SDNPOutGlue]>;
def AArch64ld1 : SDNode<"AArch64ISD::LD1", SDT_AArch64_LD1, [SDNPHasChain, SDNPMayLoad, SDNPOptInGlue]>;
def AArch64ld1s : SDNode<"AArch64ISD::LD1S", SDT_AArch64_LD1, [SDNPHasChain, SDNPMayLoad, SDNPOptInGlue]>;

// Non-faulting & first-faulting loads - node definitions
//
def AArch64ldnf1 : SDNode<"AArch64ISD::LDNF1", SDT_AArch64_LD1, [SDNPHasChain, SDNPMayLoad, SDNPOptInGlue, SDNPOutGlue]>;
def AArch64ldff1 : SDNode<"AArch64ISD::LDFF1", SDT_AArch64_LD1, [SDNPHasChain, SDNPMayLoad, SDNPOptInGlue, SDNPOutGlue]>;

def AArch64ldnf1s : SDNode<"AArch64ISD::LDNF1S", SDT_AArch64_LDNF1, [SDNPHasChain, SDNPMayLoad, SDNPOptInGlue, SDNPOutGlue]>;
def AArch64ldff1s : SDNode<"AArch64ISD::LDFF1S", SDT_AArch64_LDNF1, [SDNPHasChain, SDNPMayLoad, SDNPOptInGlue, SDNPOutGlue]>;
def AArch64ldnf1s : SDNode<"AArch64ISD::LDNF1S", SDT_AArch64_LD1, [SDNPHasChain, SDNPMayLoad, SDNPOptInGlue, SDNPOutGlue]>;
def AArch64ldff1s : SDNode<"AArch64ISD::LDFF1S", SDT_AArch64_LD1, [SDNPHasChain, SDNPMayLoad, SDNPOptInGlue, SDNPOutGlue]>;

// Gather loads - node definitions
//
Expand Down Expand Up @@ -73,6 +78,15 @@ def AArch64ldff1s_gather_imm : SDNode<"AArch64ISD::GLDFF1S_IMM",
def AArch64ldnt1_gather : SDNode<"AArch64ISD::GLDNT1", SDT_AArch64_GATHER_VS, [SDNPHasChain, SDNPMayLoad]>;
def AArch64ldnt1s_gather : SDNode<"AArch64ISD::GLDNT1S", SDT_AArch64_GATHER_VS, [SDNPHasChain, SDNPMayLoad]>;

// Contiguous stores - node definitions
//
def SDT_AArch64_ST1 : SDTypeProfile<0, 4, [
SDTCisVec<0>, SDTCisPtrTy<1>, SDTCisVec<2>,
SDTCVecEltisVT<2,i1>, SDTCisSameNumEltsAs<0,2>
]>;

def AArch64st1 : SDNode<"AArch64ISD::ST1", SDT_AArch64_ST1, [SDNPHasChain, SDNPMayStore]>;

// Scatter stores - node definitions
//
def SDT_AArch64_SCATTER_SV : SDTypeProfile<0, 5, [
Expand Down Expand Up @@ -1554,7 +1568,7 @@ multiclass sve_prefetch<SDPatternOperator prefetch, ValueType PredTy, Instructio
defm Pat_Load_P4 : unpred_load_predicate<nxv4i1, LDR_PXI>;
defm Pat_Load_P2 : unpred_load_predicate<nxv2i1, LDR_PXI>;

multiclass ldnf1<Instruction I, ValueType Ty, SDPatternOperator Load, ValueType PredTy, ValueType MemVT> {
multiclass ld1<Instruction I, ValueType Ty, SDPatternOperator Load, ValueType PredTy, ValueType MemVT> {
// scalar + immediate (mul vl)
let AddedComplexity = 1 in {
def : Pat<(Ty (Load (PredTy PPR:$gp), (am_sve_indexed_s4 GPR64sp:$base, simm4s1:$offset), MemVT)),
Expand All @@ -1566,32 +1580,60 @@ multiclass sve_prefetch<SDPatternOperator prefetch, ValueType PredTy, Instructio
(I PPR:$gp, GPR64sp:$base, (i64 0))>;
}

// 2-element contiguous loads
defm : ld1<LD1B_D_IMM, nxv2i64, AArch64ld1, nxv2i1, nxv2i8>;
defm : ld1<LD1SB_D_IMM, nxv2i64, AArch64ld1s, nxv2i1, nxv2i8>;
defm : ld1<LD1H_D_IMM, nxv2i64, AArch64ld1, nxv2i1, nxv2i16>;
defm : ld1<LD1SH_D_IMM, nxv2i64, AArch64ld1s, nxv2i1, nxv2i16>;
defm : ld1<LD1W_D_IMM, nxv2i64, AArch64ld1, nxv2i1, nxv2i32>;
defm : ld1<LD1SW_D_IMM, nxv2i64, AArch64ld1s, nxv2i1, nxv2i32>;
defm : ld1<LD1D_IMM, nxv2i64, AArch64ld1, nxv2i1, nxv2i64>;
defm : ld1<LD1D_IMM, nxv2f64, AArch64ld1, nxv2i1, nxv2f64>;

// 4-element contiguous loads
defm : ld1<LD1B_S_IMM, nxv4i32, AArch64ld1, nxv4i1, nxv4i8>;
defm : ld1<LD1SB_S_IMM, nxv4i32, AArch64ld1s, nxv4i1, nxv4i8>;
defm : ld1<LD1H_S_IMM, nxv4i32, AArch64ld1, nxv4i1, nxv4i16>;
defm : ld1<LD1SH_S_IMM, nxv4i32, AArch64ld1s, nxv4i1, nxv4i16>;
defm : ld1<LD1W_IMM, nxv4i32, AArch64ld1, nxv4i1, nxv4i32>;
defm : ld1<LD1W_IMM, nxv4f32, AArch64ld1, nxv4i1, nxv4f32>;

// 8-element contiguous loads
defm : ld1<LD1B_H_IMM, nxv8i16, AArch64ld1, nxv8i1, nxv8i8>;
defm : ld1<LD1SB_H_IMM, nxv8i16, AArch64ld1s, nxv8i1, nxv8i8>;
defm : ld1<LD1H_IMM, nxv8i16, AArch64ld1, nxv8i1, nxv8i16>;
defm : ld1<LD1H_IMM, nxv8f16, AArch64ld1, nxv8i1, nxv8f16>;

// 16-element contiguous loads
defm : ld1<LD1B_IMM, nxv16i8, AArch64ld1, nxv16i1, nxv16i8>;


// 2-element contiguous non-faulting loads
defm : ldnf1<LDNF1B_D_IMM, nxv2i64, AArch64ldnf1, nxv2i1, nxv2i8>;
defm : ldnf1<LDNF1SB_D_IMM, nxv2i64, AArch64ldnf1s, nxv2i1, nxv2i8>;
defm : ldnf1<LDNF1H_D_IMM, nxv2i64, AArch64ldnf1, nxv2i1, nxv2i16>;
defm : ldnf1<LDNF1SH_D_IMM, nxv2i64, AArch64ldnf1s, nxv2i1, nxv2i16>;
defm : ldnf1<LDNF1W_D_IMM, nxv2i64, AArch64ldnf1, nxv2i1, nxv2i32>;
defm : ldnf1<LDNF1SW_D_IMM, nxv2i64, AArch64ldnf1s, nxv2i1, nxv2i32>;
defm : ldnf1<LDNF1D_IMM, nxv2i64, AArch64ldnf1, nxv2i1, nxv2i64>;
defm : ldnf1<LDNF1D_IMM, nxv2f64, AArch64ldnf1, nxv2i1, nxv2f64>;
defm : ld1<LDNF1B_D_IMM, nxv2i64, AArch64ldnf1, nxv2i1, nxv2i8>;
defm : ld1<LDNF1SB_D_IMM, nxv2i64, AArch64ldnf1s, nxv2i1, nxv2i8>;
defm : ld1<LDNF1H_D_IMM, nxv2i64, AArch64ldnf1, nxv2i1, nxv2i16>;
defm : ld1<LDNF1SH_D_IMM, nxv2i64, AArch64ldnf1s, nxv2i1, nxv2i16>;
defm : ld1<LDNF1W_D_IMM, nxv2i64, AArch64ldnf1, nxv2i1, nxv2i32>;
defm : ld1<LDNF1SW_D_IMM, nxv2i64, AArch64ldnf1s, nxv2i1, nxv2i32>;
defm : ld1<LDNF1D_IMM, nxv2i64, AArch64ldnf1, nxv2i1, nxv2i64>;
defm : ld1<LDNF1D_IMM, nxv2f64, AArch64ldnf1, nxv2i1, nxv2f64>;

// 4-element contiguous non-faulting loads
defm : ldnf1<LDNF1B_S_IMM, nxv4i32, AArch64ldnf1, nxv4i1, nxv4i8>;
defm : ldnf1<LDNF1SB_S_IMM, nxv4i32, AArch64ldnf1s, nxv4i1, nxv4i8>;
defm : ldnf1<LDNF1H_S_IMM, nxv4i32, AArch64ldnf1, nxv4i1, nxv4i16>;
defm : ldnf1<LDNF1SH_S_IMM, nxv4i32, AArch64ldnf1s, nxv4i1, nxv4i16>;
defm : ldnf1<LDNF1W_IMM, nxv4i32, AArch64ldnf1, nxv4i1, nxv4i32>;
defm : ldnf1<LDNF1W_IMM, nxv4f32, AArch64ldnf1, nxv4i1, nxv4f32>;
defm : ld1<LDNF1B_S_IMM, nxv4i32, AArch64ldnf1, nxv4i1, nxv4i8>;
defm : ld1<LDNF1SB_S_IMM, nxv4i32, AArch64ldnf1s, nxv4i1, nxv4i8>;
defm : ld1<LDNF1H_S_IMM, nxv4i32, AArch64ldnf1, nxv4i1, nxv4i16>;
defm : ld1<LDNF1SH_S_IMM, nxv4i32, AArch64ldnf1s, nxv4i1, nxv4i16>;
defm : ld1<LDNF1W_IMM, nxv4i32, AArch64ldnf1, nxv4i1, nxv4i32>;
defm : ld1<LDNF1W_IMM, nxv4f32, AArch64ldnf1, nxv4i1, nxv4f32>;

// 8-element contiguous non-faulting loads
defm : ldnf1<LDNF1B_H_IMM, nxv8i16, AArch64ldnf1, nxv8i1, nxv8i8>;
defm : ldnf1<LDNF1SB_H_IMM, nxv8i16, AArch64ldnf1s, nxv8i1, nxv8i8>;
defm : ldnf1<LDNF1H_IMM, nxv8i16, AArch64ldnf1, nxv8i1, nxv8i16>;
defm : ldnf1<LDNF1H_IMM, nxv8f16, AArch64ldnf1, nxv8i1, nxv8f16>;
defm : ld1<LDNF1B_H_IMM, nxv8i16, AArch64ldnf1, nxv8i1, nxv8i8>;
defm : ld1<LDNF1SB_H_IMM, nxv8i16, AArch64ldnf1s, nxv8i1, nxv8i8>;
defm : ld1<LDNF1H_IMM, nxv8i16, AArch64ldnf1, nxv8i1, nxv8i16>;
defm : ld1<LDNF1H_IMM, nxv8f16, AArch64ldnf1, nxv8i1, nxv8f16>;

// 16-element contiguous non-faulting loads
defm : ldnf1<LDNF1B_IMM, nxv16i8, AArch64ldnf1, nxv16i1, nxv16i8>;
defm : ld1<LDNF1B_IMM, nxv16i8, AArch64ldnf1, nxv16i1, nxv16i8>;

multiclass ldff1<Instruction I, ValueType Ty, SDPatternOperator Load, ValueType PredTy, ValueType MemVT, ComplexPattern AddrCP> {
// reg + reg
Expand Down Expand Up @@ -1632,6 +1674,37 @@ multiclass sve_prefetch<SDPatternOperator prefetch, ValueType PredTy, Instructio

// 16-element contiguous first faulting loads
defm : ldff1<LDFF1B, nxv16i8, AArch64ldff1, nxv16i1, nxv16i8, am_sve_regreg_lsl0>;

multiclass st1<Instruction I, ValueType Ty, SDPatternOperator Store, ValueType PredTy, ValueType MemVT> {
// scalar + immediate (mul vl)
let AddedComplexity = 1 in {
def : Pat<(Store (Ty ZPR:$vec), (am_sve_indexed_s4 GPR64sp:$base, simm4s1:$offset), (PredTy PPR:$gp), MemVT),
(I ZPR:$vec, PPR:$gp, GPR64sp:$base, simm4s1:$offset)>;
}

// base
def : Pat<(Store (Ty ZPR:$vec), GPR64:$base, (PredTy PPR:$gp), MemVT),
(I ZPR:$vec, PPR:$gp, GPR64:$base, (i64 0))>;
}

// 2-element contiguous store
defm : st1<ST1B_D_IMM, nxv2i64, AArch64st1, nxv2i1, nxv2i8>;
defm : st1<ST1H_D_IMM, nxv2i64, AArch64st1, nxv2i1, nxv2i16>;
defm : st1<ST1W_D_IMM, nxv2i64, AArch64st1, nxv2i1, nxv2i32>;
defm : st1<ST1D_IMM, nxv2i64, AArch64st1, nxv2i1, nxv2i64>;

// 4-element contiguous store
defm : st1<ST1B_S_IMM, nxv4i32, AArch64st1, nxv4i1, nxv4i8>;
defm : st1<ST1H_S_IMM, nxv4i32, AArch64st1, nxv4i1, nxv4i16>;
defm : st1<ST1W_IMM, nxv4i32, AArch64st1, nxv4i1, nxv4i32>;

// 8-element contiguous store
defm : st1<ST1B_H_IMM, nxv8i16, AArch64st1, nxv8i1, nxv8i8>;
defm : st1<ST1H_IMM, nxv8i16, AArch64st1, nxv8i1, nxv8i16>;

// 16-element contiguous store
defm : st1<ST1B_IMM, nxv16i8, AArch64st1, nxv16i1, nxv16i8>;

}

let Predicates = [HasSVE2] in {
Expand Down

0 comments on commit 33ffce5

Please sign in to comment.