[VE] (masked) load|store v256.32|64 isel
Add `vvp_load|store` nodes. Lower to `vld`, `vst` where possible. Use
`vgt` for masked loads for now.

Reviewed By: kaz7

Differential Revision: https://reviews.llvm.org/D120413
simoll committed Mar 2, 2022
1 parent 3ca1098 commit 9ebaec4
Showing 11 changed files with 554 additions and 7 deletions.
120 changes: 119 additions & 1 deletion llvm/lib/Target/VE/VECustomDAG.cpp
@@ -61,6 +61,10 @@ bool isMaskArithmetic(SDValue Op) {
/// \returns the VVP_* SDNode opcode corresponding to \p OC.
Optional<unsigned> getVVPOpcode(unsigned Opcode) {
  switch (Opcode) {
  case ISD::MLOAD:
    return VEISD::VVP_LOAD;
  case ISD::MSTORE:
    return VEISD::VVP_STORE;
#define HANDLE_VP_TO_VVP(VPOPC, VVPNAME) \
  case ISD::VPOPC: \
    return VEISD::VVPNAME;
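// [Illustrative note, not part of this patch] With the cases above,
// getVVPOpcode(ISD::MLOAD) yields VEISD::VVP_LOAD and
// getVVPOpcode(ISD::MSTORE) yields VEISD::VVP_STORE, so masked loads and
// stores take the same VVP lowering path as the VP_* ops mapped by
// HANDLE_VP_TO_VVP.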
@@ -166,8 +170,12 @@ Optional<int> getMaskPos(unsigned Opc) {
  if (isVVPBinaryOp(Opc))
    return 2;

  // Other opcodes.
  switch (Opc) {
  case ISD::MSTORE:
    return 4;
  case ISD::MLOAD:
    return 3;
  case VEISD::VVP_SELECT:
    return 2;
  }
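  // [Illustrative note, not part of this patch] These positions follow the
  // generic operand layouts:
  //   ISD::MLOAD  : chain(0), ptr(1), offset(2), mask(3), passthru(4)
  //   ISD::MSTORE : chain(0), value(1), ptr(2), offset(3), mask(4)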
@@ -177,6 +185,116 @@ Optional<int> getMaskPos(unsigned Opc) {

bool isLegalAVL(SDValue AVL) { return AVL->getOpcode() == VEISD::LEGALAVL; }

/// Node Properties {

SDValue getNodeChain(SDValue Op) {
  if (MemSDNode *MemN = dyn_cast<MemSDNode>(Op.getNode()))
    return MemN->getChain();

  switch (Op->getOpcode()) {
  case VEISD::VVP_LOAD:
  case VEISD::VVP_STORE:
    return Op->getOperand(0);
  }
  return SDValue();
}

SDValue getMemoryPtr(SDValue Op) {
  if (auto *MemN = dyn_cast<MemSDNode>(Op.getNode()))
    return MemN->getBasePtr();

  switch (Op->getOpcode()) {
  case VEISD::VVP_LOAD:
    return Op->getOperand(1);
  case VEISD::VVP_STORE:
    return Op->getOperand(2);
  }
  return SDValue();
}

Optional<EVT> getIdiomaticVectorType(SDNode *Op) {
  unsigned OC = Op->getOpcode();

  // For memory ops -> the transferred data type
  if (auto MemN = dyn_cast<MemSDNode>(Op))
    return MemN->getMemoryVT();

  switch (OC) {
  // Standard ISD.
  case ISD::SELECT: // not aliased with VVP_SELECT
  case ISD::CONCAT_VECTORS:
  case ISD::EXTRACT_SUBVECTOR:
  case ISD::VECTOR_SHUFFLE:
  case ISD::BUILD_VECTOR:
  case ISD::SCALAR_TO_VECTOR:
    return Op->getValueType(0);
  }

  // Translate to VVP where possible.
  if (auto VVPOpc = getVVPOpcode(OC))
    OC = *VVPOpc;

  switch (OC) {
  default:
  case VEISD::VVP_SETCC:
    return Op->getOperand(0).getValueType();

  case VEISD::VVP_SELECT:
#define ADD_BINARY_VVP_OP(VVP_NAME, ...) case VEISD::VVP_NAME:
#include "VVPNodes.def"
    return Op->getValueType(0);

  case VEISD::VVP_LOAD:
    return Op->getValueType(0);

  case VEISD::VVP_STORE:
    return Op->getOperand(1)->getValueType(0);

  // VEC
  case VEISD::VEC_BROADCAST:
    return Op->getValueType(0);
  }
}
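// [Illustrative note, not part of this patch] A store's result is only its
// chain, so the idiomatic type of a VVP_STORE is read off the stored value
// (operand 1) rather than getValueType(0).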

SDValue getLoadStoreStride(SDValue Op, VECustomDAG &CDAG) {
  if (Op->getOpcode() == VEISD::VVP_STORE)
    return Op->getOperand(3);
  if (Op->getOpcode() == VEISD::VVP_LOAD)
    return Op->getOperand(2);

  if (isa<MemSDNode>(Op.getNode())) {
    // Regular MLOAD/MSTORE/LOAD/STORE
    // No stride argument -> use the contiguous element size as stride.
    uint64_t ElemStride = getIdiomaticVectorType(Op.getNode())
                              ->getVectorElementType()
                              .getStoreSize();
    return CDAG.getConstant(ElemStride, MVT::i64);
  }
  return SDValue();
}
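// [Illustrative note, not part of this patch] Taken together, the accessors
// imply these operand layouts for the new nodes (mask/AVL positions are not
// visible in this hunk and are inferred):
//   VVP_LOAD  : chain(0), ptr(1), stride(2), mask, avl
//   VVP_STORE : chain(0), data(1), ptr(2), stride(3), mask, avl
// For a regular (contiguous) v256i32 access, for example, the synthesized
// stride is the 4-byte element store size.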

SDValue getStoredValue(SDValue Op) {
  switch (Op->getOpcode()) {
  case VEISD::VVP_STORE:
    return Op->getOperand(1);
  }
  if (auto *StoreN = dyn_cast<StoreSDNode>(Op.getNode()))
    return StoreN->getValue();
  if (auto *StoreN = dyn_cast<MaskedStoreSDNode>(Op.getNode()))
    return StoreN->getValue();
  if (auto *StoreN = dyn_cast<VPStoreSDNode>(Op.getNode()))
    return StoreN->getValue();
  return SDValue();
}

SDValue getNodePassthru(SDValue Op) {
  if (auto *N = dyn_cast<MaskedLoadSDNode>(Op.getNode()))
    return N->getPassThru();
  return SDValue();
}
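// [Illustrative note, not part of this patch] Only masked loads carry a
// passthru operand; plain loads and vp.load return SDValue() here. A lowering
// that uses this accessor can merge the passthru back into the disabled
// lanes, e.g. with a VVP_SELECT over (loaded value, passthru, mask, avl).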

/// } Node Properties

SDValue getNodeAVL(SDValue Op) {
  auto PosOpt = getAVLPos(Op->getOpcode());
  return PosOpt ? Op->getOperand(*PosOpt) : SDValue();
20 changes: 20 additions & 0 deletions llvm/lib/Target/VE/VECustomDAG.h
@@ -88,6 +88,22 @@ std::pair<SDValue, bool> getAnnotatedNodeAVL(SDValue);

/// } AVL Functions

/// Node Properties {

Optional<EVT> getIdiomaticVectorType(SDNode *Op);

SDValue getLoadStoreStride(SDValue Op, VECustomDAG &CDAG);

SDValue getMemoryPtr(SDValue Op);

SDValue getNodeChain(SDValue Op);

SDValue getStoredValue(SDValue Op);

SDValue getNodePassthru(SDValue Op);

/// } Node Properties

enum class Packing {
  Normal = 0, // 256 element standard mode.
  Dense = 1   // 512 element packed mode.
@@ -157,6 +173,10 @@ class VECustomDAG {
  SDValue getPack(EVT DestVT, SDValue LoVec, SDValue HiVec, SDValue AVL) const;
  /// } Packing

  SDValue getMergeValues(ArrayRef<SDValue> Values) const {
    return DAG.getMergeValues(Values, DL);
  }
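  // [Illustrative note, not part of this patch] Convenience for multi-result
  // lowerings: a lowered load yields {data, chain}, returned together as,
  // e.g., CDAG.getMergeValues({NewLoadVal, NewLoadChain}).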

  SDValue getConstant(uint64_t Val, EVT VT, bool IsTarget = false,
                      bool IsOpaque = false) const;

41 changes: 35 additions & 6 deletions llvm/lib/Target/VE/VEISelLowering.cpp
@@ -322,6 +322,17 @@ void VETargetLowering::initVPUActions() {
    setOperationAction(ISD::INSERT_VECTOR_ELT, LegalPackedVT, Custom);
    setOperationAction(ISD::EXTRACT_VECTOR_ELT, LegalPackedVT, Custom);
  }

  // vNt32, vNt64 ops (legal element types)
  for (MVT VT : MVT::vector_valuetypes()) {
    MVT ElemVT = VT.getVectorElementType();
    unsigned ElemBits = ElemVT.getScalarSizeInBits();
    if (ElemBits != 32 && ElemBits != 64)
      continue;

    for (unsigned MemOpc : {ISD::MLOAD, ISD::MSTORE, ISD::LOAD, ISD::STORE})
      setOperationAction(MemOpc, VT, Custom);
  }
}
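// [Illustrative note, not part of this patch] The loop in initVPUActions
// marks, e.g., v256i32, v256f32, v256i64 and v256f64 (and the 512-element
// packed forms with 32-bit lanes) as Custom for MLOAD/MSTORE/LOAD/STORE;
// mask-typed vectors (i1 lanes) are filtered out later in lowerLOAD and
// lowerSTORE via isMaskType.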

SDValue
@@ -1321,14 +1332,19 @@ static SDValue lowerLoadF128(SDValue Op, SelectionDAG &DAG) {
SDValue VETargetLowering::lowerLOAD(SDValue Op, SelectionDAG &DAG) const {
  LoadSDNode *LdNode = cast<LoadSDNode>(Op.getNode());

  EVT MemVT = LdNode->getMemoryVT();

  // Dispatch to vector isel.
  if (MemVT.isVector() && !isMaskType(MemVT))
    return lowerToVVP(Op, DAG);

  SDValue BasePtr = LdNode->getBasePtr();
  if (isa<FrameIndexSDNode>(BasePtr.getNode())) {
    // Do not expand load instruction with frame index here because of
    // dependency problems. We expand it later in eliminateFrameIndex().
    return Op;
  }

  if (MemVT == MVT::f128)
    return lowerLoadF128(Op, DAG);

@@ -1375,14 +1391,18 @@ SDValue VETargetLowering::lowerSTORE(SDValue Op, SelectionDAG &DAG) const {
  StoreSDNode *StNode = cast<StoreSDNode>(Op.getNode());
  assert(StNode && StNode->getOffset().isUndef() && "Unexpected node type");

  // Always expand non-mask vector stores to VVP.
  EVT MemVT = StNode->getMemoryVT();
  if (MemVT.isVector() && !isMaskType(MemVT))
    return lowerToVVP(Op, DAG);

  SDValue BasePtr = StNode->getBasePtr();
  if (isa<FrameIndexSDNode>(BasePtr.getNode())) {
    // Do not expand store instruction with frame index here because of
    // dependency problems. We expand it later in eliminateFrameIndex().
    return Op;
  }

  if (MemVT == MVT::f128)
    return lowerStoreF128(Op, DAG);

@@ -1699,12 +1719,9 @@ VETargetLowering::getCustomOperationAction(SDNode &Op) const {
SDValue VETargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
  LLVM_DEBUG(dbgs() << "::LowerOperation"; Op->print(dbgs()););
  unsigned Opcode = Op.getOpcode();

  /// Scalar isel.
  switch (Opcode) {
  case ISD::ATOMIC_FENCE:
    return lowerATOMIC_FENCE(Op, DAG);
  case ISD::ATOMIC_SWAP:
@@ -1748,6 +1765,16 @@ SDValue VETargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
    return lowerINSERT_VECTOR_ELT(Op, DAG);
  case ISD::EXTRACT_VECTOR_ELT:
    return lowerEXTRACT_VECTOR_ELT(Op, DAG);
  }

  /// Vector isel.
  LLVM_DEBUG(dbgs() << "::LowerOperation_VVP"; Op->print(dbgs()););
  if (ISD::isVPOpcode(Opcode))
    return lowerToVVP(Op, DAG);

  switch (Opcode) {
  default:
    llvm_unreachable("Should not custom lower this!");

  // Legalize the AVL of this internal node.
  case VEISD::VEC_BROADCAST:
@@ -1759,6 +1786,8 @@ SDValue VETargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
    return legalizeInternalVectorOp(Op, DAG);

  // Translate into a VEC_*/VVP_* layer operation.
  case ISD::MLOAD:
  case ISD::MSTORE:
#define ADD_VVP_OP(VVP_NAME, ISD_NAME) case ISD::ISD_NAME:
#include "VVPNodes.def"
    if (isMaskArithmetic(Op) && isPackedVectorType(Op.getValueType()))
2 changes: 2 additions & 0 deletions llvm/lib/Target/VE/VEISelLowering.h
@@ -186,6 +186,8 @@ class VETargetLowering : public TargetLowering {

  /// VVP Lowering {
  SDValue lowerToVVP(SDValue Op, SelectionDAG &DAG) const;
  SDValue lowerVVP_LOAD_STORE(SDValue Op, VECustomDAG &) const;

  SDValue legalizeInternalVectorOp(SDValue Op, SelectionDAG &DAG) const;
  SDValue splitVectorOp(SDValue Op, VECustomDAG &CDAG) const;
  SDValue legalizePackedAVL(SDValue Op, VECustomDAG &CDAG) const;
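For orientation, here is a minimal sketch of how `lowerVVP_LOAD_STORE` can be assembled from the VECustomDAG accessors above. This is not the verbatim patch (the definition is not in the hunks shown); it assumes a `getNodeMask` helper plus `VECustomDAG::getNode`/`getMergeValues` overloads like those used elsewhere in this diff:

SDValue VETargetLowering::lowerVVP_LOAD_STORE(SDValue Op,
                                              VECustomDAG &CDAG) const {
  SDValue Chain = getNodeChain(Op);
  SDValue BasePtr = getMemoryPtr(Op);
  SDValue StrideV = getLoadStoreStride(Op, CDAG);
  SDValue Mask = getNodeMask(Op); // assumed helper; all-true for plain ops
  SDValue AVL = getNodeAVL(Op);
  SDValue Data = getStoredValue(Op); // SDValue() for loads

  // Store path: VVP_STORE (chain, data, ptr, stride, mask, avl).
  if (Data)
    return CDAG.getNode(VEISD::VVP_STORE, MVT::Other,
                        {Chain, Data, BasePtr, StrideV, Mask, AVL});

  // Load path: VVP_LOAD (chain, ptr, stride, mask, avl) -> {data, chain}.
  EVT DataVT = *getIdiomaticVectorType(Op.getNode());
  SDValue NewLoadV = CDAG.getNode(VEISD::VVP_LOAD, {DataVT, MVT::Other},
                                  {Chain, BasePtr, StrideV, Mask, AVL});

  // Masked loads merge the passthru back into the disabled lanes
  // (assumed VVP_SELECT operand order: on-true, on-false, mask, avl).
  if (SDValue PassThru = getNodePassthru(Op)) {
    SDValue DataV = CDAG.getNode(VEISD::VVP_SELECT, DataVT,
                                 {NewLoadV, PassThru, Mask, AVL});
    return CDAG.getMergeValues({DataV, SDValue(NewLoadV.getNode(), 1)});
  }
  return CDAG.getMergeValues({NewLoadV, SDValue(NewLoadV.getNode(), 1)});
}

Per the commit message, truly masked loads are currently emitted as `vgt` (gather) underneath; the sketch leaves that detail to the VVP-to-VE instruction layer.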
41 changes: 41 additions & 0 deletions llvm/lib/Target/VE/VETargetTransformInfo.h
@@ -21,6 +21,32 @@
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/BasicTTIImpl.h"

static llvm::Type *getVectorElementType(llvm::Type *Ty) {
  return llvm::cast<llvm::FixedVectorType>(Ty)->getElementType();
}

static llvm::Type *getLaneType(llvm::Type *Ty) {
  using namespace llvm;
  if (!isa<VectorType>(Ty))
    return Ty;
  return getVectorElementType(Ty);
}

static bool isVectorLaneType(llvm::Type &ElemTy) {
  // check element sizes for vregs
  if (ElemTy.isIntegerTy()) {
    unsigned ScaBits = ElemTy.getScalarSizeInBits();
    return ScaBits == 1 || ScaBits == 32 || ScaBits == 64;
  }
  if (ElemTy.isPointerTy()) {
    return true;
  }
  if (ElemTy.isFloatTy() || ElemTy.isDoubleTy()) {
    return true;
  }
  return false;
}

namespace llvm {

class VETTIImpl : public BasicTTIImplBase<VETTIImpl> {
@@ -86,6 +112,21 @@ class VETTIImpl : public BasicTTIImplBase<VETTIImpl> {
    // output
    return false;
  }

  // Load & Store {
  bool isLegalMaskedLoad(Type *DataType, MaybeAlign Alignment) {
    return isVectorLaneType(*getLaneType(DataType));
  }
  bool isLegalMaskedStore(Type *DataType, MaybeAlign Alignment) {
    return isVectorLaneType(*getLaneType(DataType));
  }
  bool isLegalMaskedGather(Type *DataType, MaybeAlign Alignment) {
    return isVectorLaneType(*getLaneType(DataType));
  }
  bool isLegalMaskedScatter(Type *DataType, MaybeAlign Alignment) {
    return isVectorLaneType(*getLaneType(DataType));
  }
  // } Load & Store
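  // [Illustrative note, not part of this patch] These hooks let the
  // vectorizers emit llvm.masked.load/store/gather/scatter for VE: e.g.
  // <256 x float> and <256 x i64> are accepted (legal lane types), while
  // <256 x i16> is rejected by isVectorLaneType.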
};

} // namespace llvm