Skip to content

Commit

Permalink
[VP] implementation of sdag support for VP memory intrinsics
Browse files Browse the repository at this point in the history
Followup to D99355: SDAG support for vector-predicated load/store/gather/scatter.

Reviewed By: frasercrmck

Differential Revision: https://reviews.llvm.org/D105871
  • Loading branch information
hmk46 authored and simoll committed Aug 31, 2021
1 parent e79474d commit 524ded7
Show file tree
Hide file tree
Showing 6 changed files with 814 additions and 4 deletions.
67 changes: 67 additions & 0 deletions llvm/include/llvm/CodeGen/SelectionDAG.h
Expand Up @@ -1307,6 +1307,73 @@ class SelectionDAG {
SDValue getIndexedStore(SDValue OrigStore, const SDLoc &dl, SDValue Base,
SDValue Offset, ISD::MemIndexedMode AM);

/// Create a vector-predicated (VP) load. Mask is a vector of i1 selecting
/// the active lanes; EVL is the explicit-vector-length operand. This
/// overload takes a known (non-optional) alignment.
SDValue getLoadVP(ISD::MemIndexedMode AM, ISD::LoadExtType ExtType, EVT VT,
                  const SDLoc &dl, SDValue Chain, SDValue Ptr, SDValue Offset,
                  SDValue Mask, SDValue EVL, MachinePointerInfo PtrInfo,
                  EVT MemVT, Align Alignment,
                  MachineMemOperand::Flags MMOFlags, const AAMDNodes &AAInfo,
                  const MDNode *Ranges = nullptr, bool IsExpanding = false);
/// Convenience overload with an optional alignment; defaults a missing
/// alignment to the natural alignment of \p MemVT before forwarding to the
/// Align overload above.
inline SDValue
getLoadVP(ISD::MemIndexedMode AM, ISD::LoadExtType ExtType, EVT VT,
          const SDLoc &dl, SDValue Chain, SDValue Ptr, SDValue Offset,
          SDValue Mask, SDValue EVL, MachinePointerInfo PtrInfo, EVT MemVT,
          MaybeAlign Alignment = MaybeAlign(),
          MachineMemOperand::Flags MMOFlags = MachineMemOperand::MONone,
          const AAMDNodes &AAInfo = AAMDNodes(),
          const MDNode *Ranges = nullptr, bool IsExpanding = false) {
  // Ensures that codegen never sees a None Alignment.
  return getLoadVP(AM, ExtType, VT, dl, Chain, Ptr, Offset, Mask, EVL,
                   PtrInfo, MemVT, Alignment.getValueOr(getEVTAlign(MemVT)),
                   MMOFlags, AAInfo, Ranges, IsExpanding);
}
/// VP load taking a pre-constructed MachineMemOperand.
SDValue getLoadVP(ISD::MemIndexedMode AM, ISD::LoadExtType ExtType, EVT VT,
                  const SDLoc &dl, SDValue Chain, SDValue Ptr, SDValue Offset,
                  SDValue Mask, SDValue EVL, EVT MemVT,
                  MachineMemOperand *MMO, bool IsExpanding = false);
/// Unindexed, non-extending VP load.
SDValue getLoadVP(EVT VT, const SDLoc &dl, SDValue Chain, SDValue Ptr,
                  SDValue Mask, SDValue EVL, MachinePointerInfo PtrInfo,
                  MaybeAlign Alignment, MachineMemOperand::Flags MMOFlags,
                  const AAMDNodes &AAInfo, const MDNode *Ranges = nullptr,
                  bool IsExpanding = false);
SDValue getLoadVP(EVT VT, const SDLoc &dl, SDValue Chain, SDValue Ptr,
                  SDValue Mask, SDValue EVL, MachineMemOperand *MMO,
                  bool IsExpanding = false);
/// Extending VP load; \p ExtType describes how the loaded \p MemVT value is
/// widened to \p VT (see ISD::LoadExtType).
SDValue getExtLoadVP(ISD::LoadExtType ExtType, const SDLoc &dl, EVT VT,
                     SDValue Chain, SDValue Ptr, SDValue Mask, SDValue EVL,
                     MachinePointerInfo PtrInfo, EVT MemVT,
                     MaybeAlign Alignment, MachineMemOperand::Flags MMOFlags,
                     const AAMDNodes &AAInfo, bool IsExpanding = false);
SDValue getExtLoadVP(ISD::LoadExtType ExtType, const SDLoc &dl, EVT VT,
                     SDValue Chain, SDValue Ptr, SDValue Mask, SDValue EVL,
                     EVT MemVT, MachineMemOperand *MMO,
                     bool IsExpanding = false);
/// Convert an existing unindexed VP load into a pre/post-indexed one.
SDValue getIndexedLoadVP(SDValue OrigLoad, const SDLoc &dl, SDValue Base,
                         SDValue Offset, ISD::MemIndexedMode AM);
/// Create a vector-predicated store of \p Val to \p Ptr.
SDValue getStoreVP(SDValue Chain, const SDLoc &dl, SDValue Val, SDValue Ptr,
                   SDValue Mask, SDValue EVL, MachinePointerInfo PtrInfo,
                   Align Alignment, MachineMemOperand::Flags MMOFlags,
                   const AAMDNodes &AAInfo, bool IsCompressing = false);
SDValue getStoreVP(SDValue Chain, const SDLoc &dl, SDValue Val, SDValue Ptr,
                   SDValue Mask, SDValue EVL, MachineMemOperand *MMO,
                   bool IsCompressing = false);
/// Truncating VP store: \p Val is narrowed to \p SVT in memory.
SDValue getTruncStoreVP(SDValue Chain, const SDLoc &dl, SDValue Val,
                        SDValue Ptr, SDValue Mask, SDValue EVL,
                        MachinePointerInfo PtrInfo, EVT SVT, Align Alignment,
                        MachineMemOperand::Flags MMOFlags,
                        const AAMDNodes &AAInfo, bool IsCompressing = false);
SDValue getTruncStoreVP(SDValue Chain, const SDLoc &dl, SDValue Val,
                        SDValue Ptr, SDValue Mask, SDValue EVL, EVT SVT,
                        MachineMemOperand *MMO, bool IsCompressing = false);
/// Convert an existing unindexed VP store into a pre/post-indexed one.
SDValue getIndexedStoreVP(SDValue OrigStore, const SDLoc &dl, SDValue Base,
                          SDValue Offset, ISD::MemIndexedMode AM);

/// Create a VP_GATHER / VP_SCATTER node. \p IndexType describes how the
/// index operand in \p Ops is combined with the base pointer.
SDValue getGatherVP(SDVTList VTs, EVT VT, const SDLoc &dl,
                    ArrayRef<SDValue> Ops, MachineMemOperand *MMO,
                    ISD::MemIndexType IndexType);
SDValue getScatterVP(SDVTList VTs, EVT VT, const SDLoc &dl,
                     ArrayRef<SDValue> Ops, MachineMemOperand *MMO,
                     ISD::MemIndexType IndexType);

SDValue getMaskedLoad(EVT VT, const SDLoc &dl, SDValue Chain, SDValue Base,
SDValue Offset, SDValue Mask, SDValue Src0, EVT MemVT,
MachineMemOperand *MMO, ISD::MemIndexedMode AM,
Expand Down
213 changes: 213 additions & 0 deletions llvm/include/llvm/CodeGen/SelectionDAGNodes.h
Expand Up @@ -509,24 +509,30 @@ BEGIN_TWO_BYTE_PACK()

// Bitfield storage shared by the load/store-style memory-node hierarchies
// (plain, masked, and vector-predicated). Each hierarchy reinterprets the
// 3-bit AddressingMode field as its own enumeration (see list below).
class LSBaseSDNodeBitfields {
  friend class LSBaseSDNode;
  friend class VPLoadStoreSDNode;
  friend class MaskedLoadStoreSDNode;
  friend class MaskedGatherScatterSDNode;
  friend class VPGatherScatterSDNode;

  uint16_t : NumMemSDNodeBits;

  // This storage is shared between disparate class hierarchies to hold an
  // enumeration specific to the class hierarchy in use.
  // LSBaseSDNode => enum ISD::MemIndexedMode
  // VPLoadStoreBaseSDNode => enum ISD::MemIndexedMode
  // MaskedLoadStoreBaseSDNode => enum ISD::MemIndexedMode
  // VPGatherScatterSDNode => enum ISD::MemIndexType
  // MaskedGatherScatterSDNode => enum ISD::MemIndexType
  uint16_t AddressingMode : 3; // Writers assert the value fits in 3 bits.
};
enum { NumLSBaseSDNodeBits = NumMemSDNodeBits + 3 };

class LoadSDNodeBitfields {
friend class LoadSDNode;
friend class VPLoadSDNode;
friend class MaskedLoadSDNode;
friend class MaskedGatherSDNode;
friend class VPGatherSDNode;

uint16_t : NumLSBaseSDNodeBits;

Expand All @@ -536,8 +542,10 @@ BEGIN_TWO_BYTE_PACK()

class StoreSDNodeBitfields {
friend class StoreSDNode;
friend class VPStoreSDNode;
friend class MaskedStoreSDNode;
friend class MaskedScatterSDNode;
friend class VPScatterSDNode;

uint16_t : NumLSBaseSDNodeBits;

Expand Down Expand Up @@ -1353,10 +1361,13 @@ class MemSDNode : public SDNode {
const SDValue &getBasePtr() const {
switch (getOpcode()) {
case ISD::STORE:
case ISD::VP_STORE:
case ISD::MSTORE:
return getOperand(2);
case ISD::MGATHER:
case ISD::MSCATTER:
case ISD::VP_GATHER:
case ISD::VP_SCATTER:
return getOperand(3);
default:
return getOperand(1);
Expand Down Expand Up @@ -1393,6 +1404,10 @@ class MemSDNode : public SDNode {
case ISD::MSTORE:
case ISD::MGATHER:
case ISD::MSCATTER:
case ISD::VP_LOAD:
case ISD::VP_STORE:
case ISD::VP_GATHER:
case ISD::VP_SCATTER:
return true;
default:
return N->isMemIntrinsic() || N->isTargetMemoryOpcode();
Expand Down Expand Up @@ -2318,6 +2333,116 @@ class StoreSDNode : public LSBaseSDNode {
}
};

/// This base class is used to represent VP_LOAD and VP_STORE nodes.
class VPLoadStoreSDNode : public MemSDNode {
public:
  friend class SelectionDAG;

  /// \p AM is stored in the 3-bit field shared via LSBaseSDNodeBits; the
  /// assert guards against truncation of the enumerator.
  VPLoadStoreSDNode(ISD::NodeType NodeTy, unsigned Order, const DebugLoc &dl,
                    SDVTList VTs, ISD::MemIndexedMode AM, EVT MemVT,
                    MachineMemOperand *MMO)
      : MemSDNode(NodeTy, Order, dl, VTs, MemVT, MMO) {
    LSBaseSDNodeBits.AddressingMode = AM;
    assert(getAddressingMode() == AM && "Value truncated");
  }

  // Operand layout:
  //   VPLoadSDNode  (Chain, Ptr, Offset, Mask, EVL)
  //   VPStoreSDNode (Chain, Data, Ptr, Offset, Mask, EVL)
  // Mask is a vector of i1 elements;
  // the type of EVL is TLI.getVPExplicitVectorLengthTy().
  const SDValue &getOffset() const {
    // Bug fix: this previously tested `getOpcode() == ISD::MLOAD` (copied
    // from the masked-load hierarchy). Since classof() restricts this node
    // to VP_LOAD/VP_STORE, that test was always false and a VP_LOAD's
    // "offset" accessor returned operand 3 (the mask) instead of operand 2.
    return getOperand(getOpcode() == ISD::VP_LOAD ? 2 : 3);
  }
  const SDValue &getBasePtr() const {
    return getOperand(getOpcode() == ISD::VP_LOAD ? 1 : 2);
  }
  const SDValue &getMask() const {
    return getOperand(getOpcode() == ISD::VP_LOAD ? 3 : 4);
  }
  const SDValue &getVectorLength() const {
    return getOperand(getOpcode() == ISD::VP_LOAD ? 4 : 5);
  }

  /// Return the addressing mode for this load or store:
  /// unindexed, pre-inc, pre-dec, post-inc, or post-dec.
  ISD::MemIndexedMode getAddressingMode() const {
    return static_cast<ISD::MemIndexedMode>(LSBaseSDNodeBits.AddressingMode);
  }

  /// Return true if this is a pre/post inc/dec load/store.
  bool isIndexed() const { return getAddressingMode() != ISD::UNINDEXED; }

  /// Return true if this is NOT a pre/post inc/dec load/store.
  bool isUnindexed() const { return getAddressingMode() == ISD::UNINDEXED; }

  static bool classof(const SDNode *N) {
    return N->getOpcode() == ISD::VP_LOAD || N->getOpcode() == ISD::VP_STORE;
  }
};

/// A single VP_LOAD node.
class VPLoadSDNode : public VPLoadStoreSDNode {
public:
  friend class SelectionDAG;

  VPLoadSDNode(unsigned Order, const DebugLoc &dl, SDVTList VTs,
               ISD::MemIndexedMode AM, ISD::LoadExtType ETy, bool isExpanding,
               EVT MemVT, MachineMemOperand *MMO)
      : VPLoadStoreSDNode(ISD::VP_LOAD, Order, dl, VTs, AM, MemVT, MMO) {
    LoadSDNodeBits.ExtTy = ETy;
    LoadSDNodeBits.IsExpanding = isExpanding;
  }

  // Operands: (Chain, Ptr, Offset, Mask, EVL).
  const SDValue &getBasePtr() const { return getOperand(1); }
  const SDValue &getOffset() const { return getOperand(2); }
  const SDValue &getMask() const { return getOperand(3); }
  const SDValue &getVectorLength() const { return getOperand(4); }

  /// The kind of extension performed by this load (see ISD::LoadExtType).
  ISD::LoadExtType getExtensionType() const {
    return static_cast<ISD::LoadExtType>(LoadSDNodeBits.ExtTy);
  }

  /// The IsExpanding flag this node was created with.
  bool isExpandingLoad() const { return LoadSDNodeBits.IsExpanding; }

  static bool classof(const SDNode *N) {
    return N->getOpcode() == ISD::VP_LOAD;
  }
};

/// A single VP_STORE node.
class VPStoreSDNode : public VPLoadStoreSDNode {
public:
  friend class SelectionDAG;

  VPStoreSDNode(unsigned Order, const DebugLoc &dl, SDVTList VTs,
                ISD::MemIndexedMode AM, bool isTrunc, bool isCompressing,
                EVT MemVT, MachineMemOperand *MMO)
      : VPLoadStoreSDNode(ISD::VP_STORE, Order, dl, VTs, AM, MemVT, MMO) {
    StoreSDNodeBits.IsTruncating = isTrunc;
    StoreSDNodeBits.IsCompressing = isCompressing;
  }

  // Operands: (Chain, Data, Ptr, Offset, Mask, EVL).
  const SDValue &getValue() const { return getOperand(1); }
  const SDValue &getBasePtr() const { return getOperand(2); }
  const SDValue &getOffset() const { return getOperand(3); }
  const SDValue &getMask() const { return getOperand(4); }
  const SDValue &getVectorLength() const { return getOperand(5); }

  /// True if the stored value is truncated down to the memory type.
  /// For integers this matches a TRUNCATE followed by a store; for floats,
  /// an FP_ROUND followed by a store.
  bool isTruncatingStore() const { return StoreSDNodeBits.IsTruncating; }

  /// True if the op compresses the vector before storing: the active
  /// elements (those whose bit is set in the writemask) are stored
  /// contiguously to unaligned memory at the base address.
  bool isCompressingStore() const { return StoreSDNodeBits.IsCompressing; }

  static bool classof(const SDNode *N) {
    return N->getOpcode() == ISD::VP_STORE;
  }
};

/// This base class is used to represent MLOAD and MSTORE nodes
class MaskedLoadStoreSDNode : public MemSDNode {
public:
Expand Down Expand Up @@ -2423,6 +2548,94 @@ class MaskedStoreSDNode : public MaskedLoadStoreSDNode {
}
};

/// This is a base class used to represent
/// VP_GATHER and VP_SCATTER nodes
///
class VPGatherScatterSDNode : public MemSDNode {
public:
  friend class SelectionDAG;

  VPGatherScatterSDNode(ISD::NodeType NodeTy, unsigned Order,
                        const DebugLoc &dl, SDVTList VTs, EVT MemVT,
                        MachineMemOperand *MMO, ISD::MemIndexType IndexType)
      : MemSDNode(NodeTy, Order, dl, VTs, MemVT, MMO) {
    // For this hierarchy the shared AddressingMode bitfield holds a
    // MemIndexType rather than a MemIndexedMode.
    LSBaseSDNodeBits.AddressingMode = IndexType;
    assert(getIndexType() == IndexType && "Value truncated");
  }

  /// How is Index applied to BasePtr when computing addresses.
  ISD::MemIndexType getIndexType() const {
    return static_cast<ISD::MemIndexType>(LSBaseSDNodeBits.AddressingMode);
  }
  /// True if the index type is one of the *_SCALED kinds.
  bool isIndexScaled() const {
    return (getIndexType() == ISD::SIGNED_SCALED) ||
           (getIndexType() == ISD::UNSIGNED_SCALED);
  }
  /// True if the index type is one of the SIGNED_* kinds.
  bool isIndexSigned() const {
    return (getIndexType() == ISD::SIGNED_SCALED) ||
           (getIndexType() == ISD::SIGNED_UNSCALED);
  }

  // Operand layout (the scatter carries the stored value as operand 1,
  // shifting every later operand up by one):
  //   VPGatherSDNode  (Chain, base, index, scale, mask, vlen)
  //   VPScatterSDNode (Chain, value, base, index, scale, mask, vlen)
  // Mask is a vector of i1 elements.
  const SDValue &getBasePtr() const {
    return getOperand((getOpcode() == ISD::VP_GATHER) ? 1 : 2);
  }
  const SDValue &getIndex() const {
    return getOperand((getOpcode() == ISD::VP_GATHER) ? 2 : 3);
  }
  const SDValue &getScale() const {
    return getOperand((getOpcode() == ISD::VP_GATHER) ? 3 : 4);
  }
  const SDValue &getMask() const {
    return getOperand((getOpcode() == ISD::VP_GATHER) ? 4 : 5);
  }
  const SDValue &getVectorLength() const {
    return getOperand((getOpcode() == ISD::VP_GATHER) ? 5 : 6);
  }

  static bool classof(const SDNode *N) {
    return N->getOpcode() == ISD::VP_GATHER ||
           N->getOpcode() == ISD::VP_SCATTER;
  }
};

/// A single VP_GATHER node; all operand accessors live in the base class.
class VPGatherSDNode : public VPGatherScatterSDNode {
public:
  friend class SelectionDAG;

  VPGatherSDNode(unsigned Order, const DebugLoc &dl, SDVTList VTs, EVT MemVT,
                 MachineMemOperand *MMO, ISD::MemIndexType IndexType)
      : VPGatherScatterSDNode(ISD::VP_GATHER, Order, dl, VTs, MemVT, MMO,
                              IndexType) {}

  /// LLVM-style RTTI support (isa/dyn_cast).
  static bool classof(const SDNode *N) {
    return N->getOpcode() == ISD::VP_GATHER;
  }
};

/// A single VP_SCATTER node; operand accessors other than getValue() live
/// in the base class.
class VPScatterSDNode : public VPGatherScatterSDNode {
public:
  friend class SelectionDAG;

  VPScatterSDNode(unsigned Order, const DebugLoc &dl, SDVTList VTs, EVT MemVT,
                  MachineMemOperand *MMO, ISD::MemIndexType IndexType)
      : VPGatherScatterSDNode(ISD::VP_SCATTER, Order, dl, VTs, MemVT, MMO,
                              IndexType) {}

  /// The vector being scattered (operand 1).
  const SDValue &getValue() const { return getOperand(1); }

  /// LLVM-style RTTI support (isa/dyn_cast).
  static bool classof(const SDNode *N) {
    return N->getOpcode() == ISD::VP_SCATTER;
  }
};

/// This is a base class used to represent
/// MGATHER and MSCATTER nodes
///
Expand Down
6 changes: 4 additions & 2 deletions llvm/include/llvm/IR/VPIntrinsics.def
Expand Up @@ -209,7 +209,8 @@ HELPER_REGISTER_BINARY_FP_VP(frem, VP_FREM, FRem)

///// Memory Operations {
// llvm.vp.store(ptr,val,mask,vlen)
BEGIN_REGISTER_VP(vp_store, 2, 3, VP_STORE, 0)
BEGIN_REGISTER_VP_INTRINSIC(vp_store, 2, 3)
BEGIN_REGISTER_VP_SDNODE(VP_STORE, 0, vp_store, 3, 4)
HANDLE_VP_TO_OPC(Store)
HANDLE_VP_TO_INTRIN(masked_store)
HANDLE_VP_IS_MEMOP(vp_store, 1, 0)
Expand All @@ -222,7 +223,8 @@ HANDLE_VP_IS_MEMOP(vp_scatter, 1, 0)
END_REGISTER_VP(vp_scatter, VP_SCATTER)

// llvm.vp.load(ptr,mask,vlen)
BEGIN_REGISTER_VP(vp_load, 1, 2, VP_LOAD, -1)
BEGIN_REGISTER_VP_INTRINSIC(vp_load, 1, 2)
BEGIN_REGISTER_VP_SDNODE(VP_LOAD, -1, vp_load, 2, 3)
HANDLE_VP_TO_OPC(Load)
HANDLE_VP_TO_INTRIN(masked_load)
HANDLE_VP_IS_MEMOP(vp_load, 0, None)
Expand Down

0 comments on commit 524ded7

Please sign in to comment.