Skip to content

Commit

Permalink
[llvm][CodeGen] Addressing modes for SVE ldN.
Browse files Browse the repository at this point in the history
Reviewers: c-rhodes, efriedma, sdesmalen

Subscribers: huihuiz, tschuett, hiraditya, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D77251
  • Loading branch information
Francesco Petrogalli committed Jul 27, 2020
1 parent e574641 commit adb28e0
Show file tree
Hide file tree
Showing 3 changed files with 798 additions and 20 deletions.
64 changes: 44 additions & 20 deletions llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
Expand Up @@ -245,7 +245,8 @@ class AArch64DAGToDAGISel : public SelectionDAGISel {
unsigned SubRegIdx);
void SelectLoadLane(SDNode *N, unsigned NumVecs, unsigned Opc);
void SelectPostLoadLane(SDNode *N, unsigned NumVecs, unsigned Opc);
void SelectPredicatedLoad(SDNode *N, unsigned NumVecs, const unsigned Opc);
void SelectPredicatedLoad(SDNode *N, unsigned NumVecs, unsigned Scale,
unsigned Opc_rr, unsigned Opc_ri);

bool SelectAddrModeFrameIndexSVE(SDValue N, SDValue &Base, SDValue &OffImm);
/// SVE Reg+Imm addressing mode.
Expand Down Expand Up @@ -1434,14 +1435,23 @@ AArch64DAGToDAGISel::findAddrModeSVELoadStore(SDNode *N, unsigned Opc_rr,
}

void AArch64DAGToDAGISel::SelectPredicatedLoad(SDNode *N, unsigned NumVecs,
const unsigned Opc) {
unsigned Scale, unsigned Opc_ri,
unsigned Opc_rr) {
assert(Scale < 4 && "Invalid scaling value.");
SDLoc DL(N);
EVT VT = N->getValueType(0);
SDValue Chain = N->getOperand(0);

// Optimize addressing mode.
SDValue Base, Offset;
unsigned Opc;
std::tie(Opc, Base, Offset) = findAddrModeSVELoadStore(
N, Opc_rr, Opc_ri, N->getOperand(2),
CurDAG->getTargetConstant(0, DL, MVT::i64), Scale);

SDValue Ops[] = {N->getOperand(1), // Predicate
N->getOperand(2), // Memory operand
CurDAG->getTargetConstant(0, DL, MVT::i64), Chain};
Base, // Memory operand
Offset, Chain};

const EVT ResTys[] = {MVT::Untyped, MVT::Other};

Expand Down Expand Up @@ -4726,51 +4736,51 @@ void AArch64DAGToDAGISel::Select(SDNode *Node) {
}
case AArch64ISD::SVE_LD2_MERGE_ZERO: {
if (VT == MVT::nxv16i8) {
SelectPredicatedLoad(Node, 2, AArch64::LD2B_IMM);
SelectPredicatedLoad(Node, 2, 0, AArch64::LD2B_IMM, AArch64::LD2B);
return;
} else if (VT == MVT::nxv8i16 || VT == MVT::nxv8f16 ||
(VT == MVT::nxv8bf16 && Subtarget->hasBF16())) {
SelectPredicatedLoad(Node, 2, AArch64::LD2H_IMM);
SelectPredicatedLoad(Node, 2, 1, AArch64::LD2H_IMM, AArch64::LD2H);
return;
} else if (VT == MVT::nxv4i32 || VT == MVT::nxv4f32) {
SelectPredicatedLoad(Node, 2, AArch64::LD2W_IMM);
SelectPredicatedLoad(Node, 2, 2, AArch64::LD2W_IMM, AArch64::LD2W);
return;
} else if (VT == MVT::nxv2i64 || VT == MVT::nxv2f64) {
SelectPredicatedLoad(Node, 2, AArch64::LD2D_IMM);
SelectPredicatedLoad(Node, 2, 3, AArch64::LD2D_IMM, AArch64::LD2D);
return;
}
break;
}
case AArch64ISD::SVE_LD3_MERGE_ZERO: {
if (VT == MVT::nxv16i8) {
SelectPredicatedLoad(Node, 3, AArch64::LD3B_IMM);
SelectPredicatedLoad(Node, 3, 0, AArch64::LD3B_IMM, AArch64::LD3B);
return;
} else if (VT == MVT::nxv8i16 || VT == MVT::nxv8f16 ||
(VT == MVT::nxv8bf16 && Subtarget->hasBF16())) {
SelectPredicatedLoad(Node, 3, AArch64::LD3H_IMM);
SelectPredicatedLoad(Node, 3, 1, AArch64::LD3H_IMM, AArch64::LD3H);
return;
} else if (VT == MVT::nxv4i32 || VT == MVT::nxv4f32) {
SelectPredicatedLoad(Node, 3, AArch64::LD3W_IMM);
SelectPredicatedLoad(Node, 3, 2, AArch64::LD3W_IMM, AArch64::LD3W);
return;
} else if (VT == MVT::nxv2i64 || VT == MVT::nxv2f64) {
SelectPredicatedLoad(Node, 3, AArch64::LD3D_IMM);
SelectPredicatedLoad(Node, 3, 3, AArch64::LD3D_IMM, AArch64::LD3D);
return;
}
break;
}
case AArch64ISD::SVE_LD4_MERGE_ZERO: {
if (VT == MVT::nxv16i8) {
SelectPredicatedLoad(Node, 4, AArch64::LD4B_IMM);
SelectPredicatedLoad(Node, 4, 0, AArch64::LD4B_IMM, AArch64::LD4B);
return;
} else if (VT == MVT::nxv8i16 || VT == MVT::nxv8f16 ||
(VT == MVT::nxv8bf16 && Subtarget->hasBF16())) {
SelectPredicatedLoad(Node, 4, AArch64::LD4H_IMM);
SelectPredicatedLoad(Node, 4, 1, AArch64::LD4H_IMM, AArch64::LD4H);
return;
} else if (VT == MVT::nxv4i32 || VT == MVT::nxv4f32) {
SelectPredicatedLoad(Node, 4, AArch64::LD4W_IMM);
SelectPredicatedLoad(Node, 4, 2, AArch64::LD4W_IMM, AArch64::LD4W);
return;
} else if (VT == MVT::nxv2i64 || VT == MVT::nxv2f64) {
SelectPredicatedLoad(Node, 4, AArch64::LD4D_IMM);
SelectPredicatedLoad(Node, 4, 3, AArch64::LD4D_IMM, AArch64::LD4D);
return;
}
break;
Expand All @@ -4790,10 +4800,14 @@ FunctionPass *llvm::createAArch64ISelDag(AArch64TargetMachine &TM,

/// When \p PredVT is a scalable vector predicate in the form
/// MVT::nx<M>xi1, it builds the correspondent scalable vector of
/// integers MVT::nx<M>xi<bits> s.t. M x bits = 128. If the input
/// integers MVT::nx<M>xi<bits> s.t. M x bits = 128. When targeting
/// structured vectors (NumVec > 1), the output data type is
/// MVT::nx<M*NumVec>xi<bits> s.t. M x bits = 128. If the input
/// PredVT is not in the form MVT::nx<M>xi1, it returns an invalid
/// EVT.
static EVT getPackedVectorTypeFromPredicateType(LLVMContext &Ctx, EVT PredVT) {
static EVT getPackedVectorTypeFromPredicateType(LLVMContext &Ctx, EVT PredVT,
unsigned NumVec) {
assert(NumVec > 0 && NumVec < 5 && "Invalid number of vectors.");
if (!PredVT.isScalableVector() || PredVT.getVectorElementType() != MVT::i1)
return EVT();

Expand All @@ -4803,7 +4817,8 @@ static EVT getPackedVectorTypeFromPredicateType(LLVMContext &Ctx, EVT PredVT) {

ElementCount EC = PredVT.getVectorElementCount();
EVT ScalarVT = EVT::getIntegerVT(Ctx, AArch64::SVEBitsPerBlock / EC.Min);
EVT MemVT = EVT::getVectorVT(Ctx, ScalarVT, EC);
EVT MemVT = EVT::getVectorVT(Ctx, ScalarVT, EC * NumVec);

return MemVT;
}

Expand All @@ -4827,6 +4842,15 @@ static EVT getMemVTFromNode(LLVMContext &Ctx, SDNode *Root) {
return cast<VTSDNode>(Root->getOperand(3))->getVT();
case AArch64ISD::ST1_PRED:
return cast<VTSDNode>(Root->getOperand(4))->getVT();
case AArch64ISD::SVE_LD2_MERGE_ZERO:
return getPackedVectorTypeFromPredicateType(
Ctx, Root->getOperand(1)->getValueType(0), /*NumVec=*/2);
case AArch64ISD::SVE_LD3_MERGE_ZERO:
return getPackedVectorTypeFromPredicateType(
Ctx, Root->getOperand(1)->getValueType(0), /*NumVec=*/3);
case AArch64ISD::SVE_LD4_MERGE_ZERO:
return getPackedVectorTypeFromPredicateType(
Ctx, Root->getOperand(1)->getValueType(0), /*NumVec=*/4);
default:
break;
}
Expand All @@ -4842,7 +4866,7 @@ static EVT getMemVTFromNode(LLVMContext &Ctx, SDNode *Root) {
// We are using an SVE prefetch intrinsic. Type must be inferred
// from the width of the predicate.
return getPackedVectorTypeFromPredicateType(
Ctx, Root->getOperand(2)->getValueType(0));
Ctx, Root->getOperand(2)->getValueType(0), /*NumVec=*/1);
}

/// SelectAddrModeIndexedSVE - Attempt selection of the addressing mode:
Expand Down

0 comments on commit adb28e0

Please sign in to comment.