DAG: Setting Masked-Expand-Load as a variant of Masked-Load node
The masked-expand-load node represents a load operation that loads a variable number of elements from memory, according to the number of "true" bits in the mask, and expands the loaded elements into the result vector according to their positions in the mask.
Right now, the node is used to lower the intrinsics for the VEXPAND* instructions.
This work is a step towards the implementation of the masked.expandload and masked.compressstore intrinsics.
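
As a minimal scalar sketch of the semantics above (plain C++; the element type T, the vector length N, and the expandLoad helper are illustrative assumptions, not the SelectionDAG code):

  // Consecutive memory elements are consumed only for set mask bits; lanes
  // whose mask bit is clear take the corresponding pass-through element.
  template <typename T, unsigned N>
  void expandLoad(const T *Mem, const bool (&Mask)[N], const T (&PassThru)[N],
                  T (&Result)[N]) {
    unsigned MemIdx = 0; // advances only when a mask bit is "true"
    for (unsigned I = 0; I < N; ++I)
      Result[I] = Mask[I] ? Mem[MemIdx++] : PassThru[I];
  }

A mask with K set bits therefore reads exactly K contiguous elements from memory.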

Differential Revision: https://reviews.llvm.org/D25322

llvm-svn: 283694
Elena Demikhovsky committed on Oct 9, 2016 · commit 5b10aa1 (1 parent: b270288)
Showing 9 changed files with 97 additions and 46 deletions.
7 changes: 4 additions & 3 deletions llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -965,11 +965,12 @@ class SelectionDAG {

SDValue getMaskedLoad(EVT VT, const SDLoc &dl, SDValue Chain, SDValue Ptr,
SDValue Mask, SDValue Src0, EVT MemVT,
MachineMemOperand *MMO, ISD::LoadExtType);
MachineMemOperand *MMO, ISD::LoadExtType,
bool IsExpanding = false);
SDValue getMaskedStore(SDValue Chain, const SDLoc &dl, SDValue Val,
SDValue Ptr, SDValue Mask, EVT MemVT,
MachineMemOperand *MMO, bool IsTrunc,
bool isCompressing = false);
MachineMemOperand *MMO, bool IsTruncating = false,
bool IsCompressing = false);
SDValue getMaskedGather(SDVTList VTs, EVT VT, const SDLoc &dl,
ArrayRef<SDValue> Ops, MachineMemOperand *MMO);
SDValue getMaskedScatter(SDVTList VTs, EVT VT, const SDLoc &dl,
9 changes: 7 additions & 2 deletions llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -444,6 +444,7 @@ class SDNode : public FoldingSetNode, public ilist_node<SDNode> {
uint16_t : NumLSBaseSDNodeBits;

uint16_t ExtTy : 2; // enum ISD::LoadExtType
uint16_t IsExpanding : 1;
};

class StoreSDNodeBitfields {
@@ -473,7 +474,7 @@ class SDNode : public FoldingSetNode, public ilist_node<SDNode> {
static_assert(sizeof(ConstantSDNodeBitfields) <= 2, "field too wide");
static_assert(sizeof(MemSDNodeBitfields) <= 2, "field too wide");
static_assert(sizeof(LSBaseSDNodeBitfields) <= 2, "field too wide");
static_assert(sizeof(LoadSDNodeBitfields) <= 2, "field too wide");
static_assert(sizeof(LoadSDNodeBitfields) <= 4, "field too wide");
static_assert(sizeof(StoreSDNodeBitfields) <= 2, "field too wide");

private:
@@ -1939,9 +1940,11 @@ class MaskedLoadSDNode : public MaskedLoadStoreSDNode {
public:
friend class SelectionDAG;
MaskedLoadSDNode(unsigned Order, const DebugLoc &dl, SDVTList VTs,
ISD::LoadExtType ETy, EVT MemVT, MachineMemOperand *MMO)
ISD::LoadExtType ETy, bool IsExpanding, EVT MemVT,
MachineMemOperand *MMO)
: MaskedLoadStoreSDNode(ISD::MLOAD, Order, dl, VTs, MemVT, MMO) {
LoadSDNodeBits.ExtTy = ETy;
LoadSDNodeBits.IsExpanding = IsExpanding;
}

ISD::LoadExtType getExtensionType() const {
@@ -1952,6 +1955,8 @@ class MaskedLoadSDNode : public MaskedLoadStoreSDNode {
static bool classof(const SDNode *N) {
return N->getOpcode() == ISD::MLOAD;
}

bool isExpandingLoad() const { return LoadSDNodeBits.IsExpanding; }
};

/// This class is used to represent an MSTORE node
12 changes: 6 additions & 6 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -5347,23 +5347,23 @@ SDValue SelectionDAG::getIndexedStore(SDValue OrigStore, const SDLoc &dl,
SDValue SelectionDAG::getMaskedLoad(EVT VT, const SDLoc &dl, SDValue Chain,
SDValue Ptr, SDValue Mask, SDValue Src0,
EVT MemVT, MachineMemOperand *MMO,
ISD::LoadExtType ExtTy) {
ISD::LoadExtType ExtTy, bool isExpanding) {

SDVTList VTs = getVTList(VT, MVT::Other);
SDValue Ops[] = { Chain, Ptr, Mask, Src0 };
FoldingSetNodeID ID;
AddNodeIDNode(ID, ISD::MLOAD, VTs, Ops);
ID.AddInteger(VT.getRawBits());
ID.AddInteger(getSyntheticNodeSubclassData<MaskedLoadSDNode>(
dl.getIROrder(), VTs, ExtTy, MemVT, MMO));
dl.getIROrder(), VTs, ExtTy, isExpanding, MemVT, MMO));
ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
void *IP = nullptr;
if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) {
cast<MaskedLoadSDNode>(E)->refineAlignment(MMO);
return SDValue(E, 0);
}
auto *N = newSDNode<MaskedLoadSDNode>(dl.getIROrder(), dl.getDebugLoc(), VTs,
ExtTy, MemVT, MMO);
ExtTy, isExpanding, MemVT, MMO);
createOperands(N, Ops);

CSEMap.InsertNode(N, IP);
@@ -5374,7 +5374,7 @@ SDValue SelectionDAG::getMaskedLoad(EVT VT, const SDLoc &dl, SDValue Chain,
SDValue SelectionDAG::getMaskedStore(SDValue Chain, const SDLoc &dl,
SDValue Val, SDValue Ptr, SDValue Mask,
EVT MemVT, MachineMemOperand *MMO,
bool isTrunc, bool isCompress) {
bool IsTruncating, bool IsCompressing) {
assert(Chain.getValueType() == MVT::Other &&
"Invalid chain type");
EVT VT = Val.getValueType();
@@ -5384,15 +5384,15 @@ SDValue SelectionDAG::getMaskedStore(SDValue Chain, const SDLoc &dl,
AddNodeIDNode(ID, ISD::MSTORE, VTs, Ops);
ID.AddInteger(VT.getRawBits());
ID.AddInteger(getSyntheticNodeSubclassData<MaskedStoreSDNode>(
dl.getIROrder(), VTs, isTrunc, isCompress, MemVT, MMO));
dl.getIROrder(), VTs, IsTruncating, IsCompressing, MemVT, MMO));
ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
void *IP = nullptr;
if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) {
cast<MaskedStoreSDNode>(E)->refineAlignment(MMO);
return SDValue(E, 0);
}
auto *N = newSDNode<MaskedStoreSDNode>(dl.getIROrder(), dl.getDebugLoc(), VTs,
isTrunc, isCompress, MemVT, MMO);
IsTruncating, IsCompressing, MemVT, MMO);
createOperands(N, Ops);

CSEMap.InsertNode(N, IP);
2 changes: 1 addition & 1 deletion llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -3821,7 +3821,7 @@ void SelectionDAGBuilder::visitMaskedLoad(const CallInst &I) {
Alignment, AAInfo, Ranges);

SDValue Load = DAG.getMaskedLoad(VT, sdl, InChain, Ptr, Mask, Src0, VT, MMO,
ISD::NON_EXTLOAD);
ISD::NON_EXTLOAD, false);
if (AddToChain) {
SDValue OutChain = Load.getValue(1);
DAG.setRoot(OutChain);
23 changes: 12 additions & 11 deletions llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -18854,7 +18854,8 @@ static SDValue LowerINTRINSIC_W_CHAIN(SDValue Op, const X86Subtarget &Subtarget,
SDValue VMask = getMaskNode(Mask, MaskVT, Subtarget, DAG, dl);

return DAG.getMaskedStore(Chain, dl, DataToCompress, Addr, VMask, VT,
MemIntr->getMemOperand(), false, true);
MemIntr->getMemOperand(),
false /* truncating */, true /* compressing */);
}
case TRUNCATE_TO_MEM_VI8:
case TRUNCATE_TO_MEM_VI16:
@@ -18877,7 +18878,7 @@ static SDValue LowerINTRINSIC_W_CHAIN(SDValue Op, const X86Subtarget &Subtarget,
SDValue VMask = getMaskNode(Mask, MaskVT, Subtarget, DAG, dl);

return DAG.getMaskedStore(Chain, dl, DataToTruncate, Addr, VMask, VT,
MemIntr->getMemOperand(), true);
MemIntr->getMemOperand(), true /* truncating */);
}
case EXPAND_FROM_MEM: {
SDValue Mask = Op.getOperand(4);
@@ -18889,16 +18890,16 @@ static SDValue LowerINTRINSIC_W_CHAIN(SDValue Op, const X86Subtarget &Subtarget,
MemIntrinsicSDNode *MemIntr = dyn_cast<MemIntrinsicSDNode>(Op);
assert(MemIntr && "Expected MemIntrinsicSDNode!");

SDValue DataToExpand = DAG.getLoad(VT, dl, Chain, Addr,
MemIntr->getMemOperand());
if (isAllOnesConstant(Mask)) // Return a regular (unmasked) vector load.
return DAG.getLoad(VT, dl, Chain, Addr, MemIntr->getMemOperand());
if (X86::isZeroNode(Mask))
return DAG.getUNDEF(VT);

if (isAllOnesConstant(Mask)) // return just a load
return DataToExpand;

SDValue Results[] = {
getVectorMaskingNode(DAG.getNode(IntrData->Opc0, dl, VT, DataToExpand),
Mask, PassThru, Subtarget, DAG), Chain};
return DAG.getMergeValues(Results, dl);
MVT MaskVT = MVT::getVectorVT(MVT::i1, VT.getVectorNumElements());
SDValue VMask = getMaskNode(Mask, MaskVT, Subtarget, DAG, dl);
return DAG.getMaskedLoad(VT, dl, Chain, Addr, VMask, PassThru, VT,
MemIntr->getMemOperand(), ISD::NON_EXTLOAD,
true /* expanding */);
}
}
}
21 changes: 18 additions & 3 deletions llvm/lib/Target/X86/X86InstrAVX512.td
@@ -7552,13 +7552,28 @@ multiclass expand_by_vec_width<bits<8> opc, X86VectorVTInfo _,
AVX5128IBase, EVEX_CD8<_.EltSize, CD8VT1>;
}

multiclass expand_by_vec_width_lowering<X86VectorVTInfo _ > {

def : Pat<(_.VT (X86mExpandingLoad addr:$src, _.KRCWM:$mask, undef)),
(!cast<Instruction>(NAME#_.ZSuffix##rmkz)
_.KRCWM:$mask, addr:$src)>;

def : Pat<(_.VT (X86mExpandingLoad addr:$src, _.KRCWM:$mask,
(_.VT _.RC:$src0))),
(!cast<Instruction>(NAME#_.ZSuffix##rmk)
_.RC:$src0, _.KRCWM:$mask, addr:$src)>;
}

multiclass expand_by_elt_width<bits<8> opc, string OpcodeStr,
AVX512VLVectorVTInfo VTInfo> {
defm Z : expand_by_vec_width<opc, VTInfo.info512, OpcodeStr>, EVEX_V512;
defm Z : expand_by_vec_width<opc, VTInfo.info512, OpcodeStr>,
expand_by_vec_width_lowering<VTInfo.info512>, EVEX_V512;

let Predicates = [HasVLX] in {
defm Z256 : expand_by_vec_width<opc, VTInfo.info256, OpcodeStr>, EVEX_V256;
defm Z128 : expand_by_vec_width<opc, VTInfo.info128, OpcodeStr>, EVEX_V128;
defm Z256 : expand_by_vec_width<opc, VTInfo.info256, OpcodeStr>,
expand_by_vec_width_lowering<VTInfo.info256>, EVEX_V256;
defm Z128 : expand_by_vec_width<opc, VTInfo.info128, OpcodeStr>,
expand_by_vec_width_lowering<VTInfo.info128>, EVEX_V128;
}
}

32 changes: 19 additions & 13 deletions llvm/lib/Target/X86/X86InstrFragmentsSIMD.td
@@ -919,30 +919,36 @@ def vinsert256_insert : PatFrag<(ops node:$bigvec, node:$smallvec,
return X86::isVINSERT256Index(N);
}], INSERT_get_vinsert256_imm>;

def masked_load_aligned128 : PatFrag<(ops node:$src1, node:$src2, node:$src3),
def X86mload : PatFrag<(ops node:$src1, node:$src2, node:$src3),
(masked_load node:$src1, node:$src2, node:$src3), [{
if (auto *Load = dyn_cast<MaskedLoadSDNode>(N))
return Load->getAlignment() >= 16;
return false;
return !cast<MaskedLoadSDNode>(N)->isExpandingLoad() &&
cast<MaskedLoadSDNode>(N)->getExtensionType() == ISD::NON_EXTLOAD;
}]>;

def masked_load_aligned128 : PatFrag<(ops node:$src1, node:$src2, node:$src3),
(X86mload node:$src1, node:$src2, node:$src3), [{
return cast<MaskedLoadSDNode>(N)->getAlignment() >= 16;
}]>;

def masked_load_aligned256 : PatFrag<(ops node:$src1, node:$src2, node:$src3),
(masked_load node:$src1, node:$src2, node:$src3), [{
if (auto *Load = dyn_cast<MaskedLoadSDNode>(N))
return Load->getAlignment() >= 32;
return false;
(X86mload node:$src1, node:$src2, node:$src3), [{
return cast<MaskedLoadSDNode>(N)->getAlignment() >= 32;
}]>;

def masked_load_aligned512 : PatFrag<(ops node:$src1, node:$src2, node:$src3),
(masked_load node:$src1, node:$src2, node:$src3), [{
if (auto *Load = dyn_cast<MaskedLoadSDNode>(N))
return Load->getAlignment() >= 64;
return false;
(X86mload node:$src1, node:$src2, node:$src3), [{
return cast<MaskedLoadSDNode>(N)->getAlignment() >= 64;
}]>;

def masked_load_unaligned : PatFrag<(ops node:$src1, node:$src2, node:$src3),
(masked_load node:$src1, node:$src2, node:$src3), [{
return isa<MaskedLoadSDNode>(N);
return !cast<MaskedLoadSDNode>(N)->isExpandingLoad() &&
cast<MaskedLoadSDNode>(N)->getExtensionType() == ISD::NON_EXTLOAD;
}]>;

def X86mExpandingLoad : PatFrag<(ops node:$src1, node:$src2, node:$src3),
(masked_load node:$src1, node:$src2, node:$src3), [{
return cast<MaskedLoadSDNode>(N)->isExpandingLoad();
}]>;

// Masked store fragments.
6 changes: 3 additions & 3 deletions llvm/lib/Target/X86/X86InstrSSE.td
@@ -8622,12 +8622,12 @@ multiclass maskmov_lowering<string InstrStr, RegisterClass RC, ValueType VT,
def: Pat<(X86mstore addr:$ptr, (MaskVT RC:$mask), (VT RC:$src)),
(!cast<Instruction>(InstrStr#"mr") addr:$ptr, RC:$mask, RC:$src)>;
// masked load
def: Pat<(VT (masked_load addr:$ptr, (MaskVT RC:$mask), undef)),
def: Pat<(VT (X86mload addr:$ptr, (MaskVT RC:$mask), undef)),
(!cast<Instruction>(InstrStr#"rm") RC:$mask, addr:$ptr)>;
def: Pat<(VT (masked_load addr:$ptr, (MaskVT RC:$mask),
def: Pat<(VT (X86mload addr:$ptr, (MaskVT RC:$mask),
(VT (bitconvert (ZeroVT immAllZerosV))))),
(!cast<Instruction>(InstrStr#"rm") RC:$mask, addr:$ptr)>;
def: Pat<(VT (masked_load addr:$ptr, (MaskVT RC:$mask), (VT RC:$src0))),
def: Pat<(VT (X86mload addr:$ptr, (MaskVT RC:$mask), (VT RC:$src0))),
(!cast<Instruction>(BlendStr#"rr")
RC:$src0,
(!cast<Instruction>(InstrStr#"rm") RC:$mask, addr:$ptr),
31 changes: 27 additions & 4 deletions llvm/test/CodeGen/X86/avx512vl-intrinsics.ll
@@ -1042,6 +1042,29 @@ define <4 x i32> @expand10(<4 x i32> %data, i8 %mask) {

declare <4 x i32> @llvm.x86.avx512.mask.expand.d.128(<4 x i32> %data, <4 x i32> %src0, i8 %mask)

define <8 x i64> @expand11(i8* %addr) {
; CHECK-LABEL: expand11:
; CHECK: ## BB#0:
; CHECK-NEXT: vmovups (%rdi), %zmm0 ## encoding: [0x62,0xf1,0x7c,0x48,0x10,0x07]
; CHECK-NEXT: retq ## encoding: [0xc3]
%res = call <8 x i64> @llvm.x86.avx512.mask.expand.load.q.512(i8* %addr, <8 x i64> undef, i8 -1)
ret <8 x i64> %res
}

define <8 x i64> @expand12(i8* %addr, i8 %mask) {
; CHECK-LABEL: expand12:
; CHECK: ## BB#0:
; CHECK-NEXT: kmovw %esi, %k1 ## encoding: [0xc5,0xf8,0x92,0xce]
; CHECK-NEXT: vpexpandq (%rdi), %zmm0 {%k1} {z} ## encoding: [0x62,0xf2,0xfd,0xc9,0x89,0x07]
; CHECK-NEXT: retq ## encoding: [0xc3]
%laddr = bitcast i8* %addr to <8 x i64>*
%data = load <8 x i64>, <8 x i64>* %laddr, align 1
%res = call <8 x i64> @llvm.x86.avx512.mask.expand.q.512(<8 x i64> %data, <8 x i64>zeroinitializer, i8 %mask)
ret <8 x i64> %res
}

declare <8 x i64> @llvm.x86.avx512.mask.expand.q.512(<8 x i64> , <8 x i64>, i8)

define < 2 x i64> @test_mask_mul_epi32_rr_128(< 4 x i32> %a, < 4 x i32> %b) {
; CHECK-LABEL: test_mask_mul_epi32_rr_128:
; CHECK: ## BB#0:
@@ -5250,9 +5273,9 @@ define <8 x i32>@test_int_x86_avx512_mask_psrav8_si_const() {
; CHECK: ## BB#0:
; CHECK-NEXT: vmovdqa32 {{.*#+}} ymm0 = [2,9,4294967284,23,4294967270,37,4294967256,51]
; CHECK-NEXT: ## encoding: [0x62,0xf1,0x7d,0x28,0x6f,0x05,A,A,A,A]
; CHECK-NEXT: ## fixup A - offset: 6, value: LCPI309_0-4, kind: reloc_riprel_4byte
; CHECK-NEXT: ## fixup A - offset: 6, value: LCPI311_0-4, kind: reloc_riprel_4byte
; CHECK-NEXT: vpsravd {{.*}}(%rip), %ymm0, %ymm0 ## encoding: [0x62,0xf2,0x7d,0x28,0x46,0x05,A,A,A,A]
; CHECK-NEXT: ## fixup A - offset: 6, value: LCPI309_1-4, kind: reloc_riprel_4byte
; CHECK-NEXT: ## fixup A - offset: 6, value: LCPI311_1-4, kind: reloc_riprel_4byte
; CHECK-NEXT: retq ## encoding: [0xc3]
%res = call <8 x i32> @llvm.x86.avx512.mask.psrav8.si(<8 x i32> <i32 2, i32 9, i32 -12, i32 23, i32 -26, i32 37, i32 -40, i32 51>, <8 x i32> <i32 1, i32 18, i32 35, i32 52, i32 69, i32 15, i32 32, i32 49>, <8 x i32> zeroinitializer, i8 -1)
ret <8 x i32> %res
@@ -5283,9 +5306,9 @@ define <2 x i64>@test_int_x86_avx512_mask_psrav_q_128_const(i8 %x3) {
; CHECK: ## BB#0:
; CHECK-NEXT: vmovdqa64 {{.*#+}} xmm0 = [2,18446744073709551607]
; CHECK-NEXT: ## encoding: [0x62,0xf1,0xfd,0x08,0x6f,0x05,A,A,A,A]
; CHECK-NEXT: ## fixup A - offset: 6, value: LCPI311_0-4, kind: reloc_riprel_4byte
; CHECK-NEXT: ## fixup A - offset: 6, value: LCPI313_0-4, kind: reloc_riprel_4byte
; CHECK-NEXT: vpsravq {{.*}}(%rip), %xmm0, %xmm0 ## encoding: [0x62,0xf2,0xfd,0x08,0x46,0x05,A,A,A,A]
; CHECK-NEXT: ## fixup A - offset: 6, value: LCPI311_1-4, kind: reloc_riprel_4byte
; CHECK-NEXT: ## fixup A - offset: 6, value: LCPI313_1-4, kind: reloc_riprel_4byte
; CHECK-NEXT: retq ## encoding: [0xc3]
%res = call <2 x i64> @llvm.x86.avx512.mask.psrav.q.128(<2 x i64> <i64 2, i64 -9>, <2 x i64> <i64 1, i64 90>, <2 x i64> zeroinitializer, i8 -1)
ret <2 x i64> %res