Skip to content

Commit

Permalink
[X86] Replace selectScalarSSELoad ComplexPattern with PatFrags to han…
Browse files Browse the repository at this point in the history
…dle the 3 types of loads we currently match.

This ensures we create mem operands for these instructions fixing PR45949.

Unfortunately, it increases the size of X86GenDAGISel.inc, but additional DAG
combine canonicalization could reduce the number of load forms we need to match.
  • Loading branch information
topperc committed May 16, 2020
1 parent 0ec5f50 commit 135b877
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 138 deletions.
75 changes: 0 additions & 75 deletions llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
Expand Up @@ -229,11 +229,6 @@ namespace {
bool selectTLSADDRAddr(SDValue N, SDValue &Base,
SDValue &Scale, SDValue &Index, SDValue &Disp,
SDValue &Segment);
bool selectScalarSSELoad(SDNode *Root, SDNode *Parent, SDValue N,
SDValue &Base, SDValue &Scale,
SDValue &Index, SDValue &Disp,
SDValue &Segment,
SDValue &NodeWithChain);
bool selectRelocImm(SDValue N, SDValue &Op);

bool tryFoldLoad(SDNode *Root, SDNode *P, SDValue N,
Expand Down Expand Up @@ -2473,76 +2468,6 @@ bool X86DAGToDAGISel::selectAddr(SDNode *Parent, SDValue N, SDValue &Base,
return true;
}

// Folding a load is only safe when every node on the use chain leading from
// User up to Root has exactly one use; any extra use along the way could
// cause the load to be duplicated.
//
// Starting at User, repeatedly step to the node's first (and, once checked,
// only) user until Root is reached. Returns false as soon as a node with a
// use count other than one is encountered, true otherwise.
static bool hasSingleUsesFromRoot(SDNode *Root, SDNode *User) {
  for (SDNode *Cur = User; Cur != Root; Cur = *Cur->use_begin())
    if (!Cur->hasOneUse())
      return false;

  return true;
}

/// Match a scalar SSE load. In particular, we want to match a load whose top
/// elements are either undef or zeros. The load flavor is derived from the
/// type of N, which is either v4f32 or v2f64.
///
/// Three forms of N are accepted, tried in this order:
///   1. a simple non-extending load,
///   2. an X86ISD::VZEXT_LOAD node,
///   3. a single-use ISD::SCALAR_TO_VECTOR whose operand is a non-extending
///      load.
/// Each form must additionally pass IsProfitableToFold and IsLegalToFold, and
/// every node from Parent up to Root must have a single use.
///
/// On a match, the address of the load is handed to selectAddr together with
/// Base, Scale, Index, Disp and Segment, and selectAddr's result is returned.
///
/// We also return:
/// PatternNodeWithChain: this is the matched node that has a chain input and
/// output. Note that it may be assigned even when this function returns
/// false (forms 2 and 3 write it before running the fold checks).
bool X86DAGToDAGISel::selectScalarSSELoad(SDNode *Root, SDNode *Parent,
SDValue N, SDValue &Base,
SDValue &Scale, SDValue &Index,
SDValue &Disp, SDValue &Segment,
SDValue &PatternNodeWithChain) {
// Bail out if any node between Parent and Root has more than one use.
if (!hasSingleUsesFromRoot(Root, Parent))
return false;

// We can allow a full vector load here since narrowing a load is ok unless
// it's volatile or atomic.
if (ISD::isNON_EXTLoad(N.getNode())) {
LoadSDNode *LD = cast<LoadSDNode>(N);
if (LD->isSimple() &&
IsProfitableToFold(N, LD, Root) &&
IsLegalToFold(N, Parent, Root, OptLevel)) {
PatternNodeWithChain = N;
return selectAddr(LD, LD->getBasePtr(), Base, Scale, Index, Disp,
Segment);
}
}

// We can also match the special zero extended load opcode.
if (N.getOpcode() == X86ISD::VZEXT_LOAD) {
// The out-parameter is written before the fold checks below.
PatternNodeWithChain = N;
if (IsProfitableToFold(PatternNodeWithChain, N.getNode(), Root) &&
IsLegalToFold(PatternNodeWithChain, Parent, Root, OptLevel)) {
auto *MI = cast<MemIntrinsicSDNode>(PatternNodeWithChain);
return selectAddr(MI, MI->getBasePtr(), Base, Scale, Index, Disp,
Segment);
}
}

// Need to make sure that the SCALAR_TO_VECTOR and load are both only used
// once. Otherwise the load might get duplicated and the chain output of the
// duplicate load will not be observed by all dependencies.
if (N.getOpcode() == ISD::SCALAR_TO_VECTOR && N.getNode()->hasOneUse()) {
// The node with the chain is the scalar operand, not the SCALAR_TO_VECTOR
// itself; the fold checks use the SCALAR_TO_VECTOR node as the user.
PatternNodeWithChain = N.getOperand(0);
if (ISD::isNON_EXTLoad(PatternNodeWithChain.getNode()) &&
IsProfitableToFold(PatternNodeWithChain, N.getNode(), Root) &&
IsLegalToFold(PatternNodeWithChain, N.getNode(), Root, OptLevel)) {
LoadSDNode *LD = cast<LoadSDNode>(PatternNodeWithChain);
return selectAddr(LD, LD->getBasePtr(), Base, Scale, Index, Disp,
Segment);
}
}

// None of the supported load forms matched.
return false;
}


bool X86DAGToDAGISel::selectMOV64Imm32(SDValue N, SDValue &Imm) {
if (const ConstantSDNode *CN = dyn_cast<ConstantSDNode>(N)) {
uint64_t ImmVal = CN->getZExtValue();
Expand Down
42 changes: 21 additions & 21 deletions llvm/lib/Target/X86/X86InstrAVX512.td
Expand Up @@ -76,11 +76,11 @@ class X86VectorVTInfo<int numelts, ValueType eltvt, RegisterClass rc,
PatFrag ScalarLdFrag = !cast<PatFrag>("load" # EltVT);
PatFrag BroadcastLdFrag = !cast<PatFrag>("X86VBroadcastld" # EltSizeName);

ComplexPattern ScalarIntMemCPat = !if (!eq (EltTypeName, "f32"),
!cast<ComplexPattern>("sse_load_f32"),
!if (!eq (EltTypeName, "f64"),
!cast<ComplexPattern>("sse_load_f64"),
?));
PatFrags ScalarIntMemFrags = !if (!eq (EltTypeName, "f32"),
!cast<PatFrags>("sse_load_f32"),
!if (!eq (EltTypeName, "f64"),
!cast<PatFrags>("sse_load_f64"),
?));

// The string to specify embedded broadcast in assembly.
string BroadcastStr = "{1to" # NumElts # "}";
Expand Down Expand Up @@ -2065,9 +2065,9 @@ multiclass avx512_cmp_scalar<X86VectorVTInfo _, SDNode OpNode, SDNode OpNodeSAE,
(ins _.RC:$src1, _.IntScalarMemOp:$src2, u8imm:$cc),
"vcmp"#_.Suffix,
"$cc, $src2, $src1", "$src1, $src2, $cc",
(OpNode (_.VT _.RC:$src1), _.ScalarIntMemCPat:$src2,
(OpNode (_.VT _.RC:$src1), (_.ScalarIntMemFrags addr:$src2),
timm:$cc),
(OpNode_su (_.VT _.RC:$src1), _.ScalarIntMemCPat:$src2,
(OpNode_su (_.VT _.RC:$src1), (_.ScalarIntMemFrags addr:$src2),
timm:$cc)>, EVEX_4V, VEX_LIG, EVEX_CD8<_.EltSize, CD8VT1>,
Sched<[sched.Folded, sched.ReadAfterFold]>, SIMD_EXC;

Expand Down Expand Up @@ -2643,15 +2643,15 @@ multiclass avx512_scalar_fpclass<bits<8> opc, string OpcodeStr,
OpcodeStr#_.Suffix#
"\t{$src2, $src1, $dst|$dst, $src1, $src2}",
[(set _.KRC:$dst,
(X86Vfpclasss _.ScalarIntMemCPat:$src1,
(i32 timm:$src2)))]>,
(X86Vfpclasss (_.ScalarIntMemFrags addr:$src1),
(i32 timm:$src2)))]>,
Sched<[sched.Folded, sched.ReadAfterFold]>;
def rmk : AVX512<opc, MRMSrcMem, (outs _.KRC:$dst),
(ins _.KRCWM:$mask, _.IntScalarMemOp:$src1, i32u8imm:$src2),
OpcodeStr#_.Suffix#
"\t{$src2, $src1, $dst {${mask}}|$dst {${mask}}, $src1, $src2}",
[(set _.KRC:$dst,(and _.KRCWM:$mask,
(X86Vfpclasss_su _.ScalarIntMemCPat:$src1,
(X86Vfpclasss_su (_.ScalarIntMemFrags addr:$src1),
(i32 timm:$src2))))]>,
EVEX_K, Sched<[sched.Folded, sched.ReadAfterFold]>;
}
Expand Down Expand Up @@ -5293,7 +5293,7 @@ multiclass avx512_fp_scalar<bits<8> opc, string OpcodeStr,X86VectorVTInfo _,
(ins _.RC:$src1, _.IntScalarMemOp:$src2), OpcodeStr,
"$src2, $src1", "$src1, $src2",
(_.VT (VecNode _.RC:$src1,
_.ScalarIntMemCPat:$src2))>,
(_.ScalarIntMemFrags addr:$src2)))>,
Sched<[sched.Folded, sched.ReadAfterFold]>;
let isCodeGenOnly = 1, Predicates = [HasAVX512] in {
def rr : I< opc, MRMSrcReg, (outs _.FRC:$dst),
Expand Down Expand Up @@ -5339,7 +5339,7 @@ multiclass avx512_fp_scalar_sae<bits<8> opc, string OpcodeStr,X86VectorVTInfo _,
(ins _.RC:$src1, _.IntScalarMemOp:$src2), OpcodeStr,
"$src2, $src1", "$src1, $src2",
(_.VT (VecNode _.RC:$src1,
_.ScalarIntMemCPat:$src2))>,
(_.ScalarIntMemFrags addr:$src2)))>,
Sched<[sched.Folded, sched.ReadAfterFold]>, SIMD_EXC;

let isCodeGenOnly = 1, Predicates = [HasAVX512],
Expand Down Expand Up @@ -5628,7 +5628,7 @@ multiclass avx512_fp_scalef_scalar<bits<8> opc, string OpcodeStr, SDNode OpNode,
defm rm: AVX512_maskable_scalar<opc, MRMSrcMem, _, (outs _.RC:$dst),
(ins _.RC:$src1, _.IntScalarMemOp:$src2), OpcodeStr#_.Suffix,
"$src2, $src1", "$src1, $src2",
(OpNode _.RC:$src1, _.ScalarIntMemCPat:$src2)>,
(OpNode _.RC:$src1, (_.ScalarIntMemFrags addr:$src2))>,
Sched<[sched.Folded, sched.ReadAfterFold]>;
}
}
Expand Down Expand Up @@ -7227,7 +7227,7 @@ multiclass avx512_cvt_s_int_round<bits<8> opc, X86VectorVTInfo SrcVT,
def rm_Int : SI<opc, MRMSrcMem, (outs DstVT.RC:$dst), (ins SrcVT.IntScalarMemOp:$src),
!strconcat(asm,"\t{$src, $dst|$dst, $src}"),
[(set DstVT.RC:$dst, (OpNode
(SrcVT.VT SrcVT.ScalarIntMemCPat:$src)))]>,
(SrcVT.ScalarIntMemFrags addr:$src)))]>,
EVEX, VEX_LIG, Sched<[sched.Folded, sched.ReadAfterFold]>, SIMD_EXC;
} // Predicates = [HasAVX512]

Expand Down Expand Up @@ -7419,7 +7419,7 @@ let Predicates = [HasAVX512], ExeDomain = _SrcRC.ExeDomain in {
(ins _SrcRC.IntScalarMemOp:$src),
!strconcat(asm,"\t{$src, $dst|$dst, $src}"),
[(set _DstRC.RC:$dst,
(OpNodeInt (_SrcRC.VT _SrcRC.ScalarIntMemCPat:$src)))]>,
(OpNodeInt (_SrcRC.ScalarIntMemFrags addr:$src)))]>,
EVEX, VEX_LIG, Sched<[sched.Folded, sched.ReadAfterFold]>, SIMD_EXC;
} //HasAVX512

Expand Down Expand Up @@ -7476,7 +7476,7 @@ multiclass avx512_cvt_fp_scalar<bits<8> opc, string OpcodeStr, X86VectorVTInfo _
(ins _.RC:$src1, _Src.IntScalarMemOp:$src2), OpcodeStr,
"$src2, $src1", "$src1, $src2",
(_.VT (OpNode (_.VT _.RC:$src1),
(_Src.VT _Src.ScalarIntMemCPat:$src2)))>,
(_Src.ScalarIntMemFrags addr:$src2)))>,
EVEX_4V, VEX_LIG,
Sched<[sched.Folded, sched.ReadAfterFold]>;

Expand Down Expand Up @@ -8710,7 +8710,7 @@ multiclass avx512_fp14_s<bits<8> opc, string OpcodeStr, SDNode OpNode,
(ins _.RC:$src1, _.IntScalarMemOp:$src2), OpcodeStr,
"$src2, $src1", "$src1, $src2",
(OpNode (_.VT _.RC:$src1),
_.ScalarIntMemCPat:$src2)>, EVEX_4V, VEX_LIG,
(_.ScalarIntMemFrags addr:$src2))>, EVEX_4V, VEX_LIG,
Sched<[sched.Folded, sched.ReadAfterFold]>;
}
}
Expand Down Expand Up @@ -8798,7 +8798,7 @@ multiclass avx512_fp28_s<bits<8> opc, string OpcodeStr,X86VectorVTInfo _,
defm m : AVX512_maskable_scalar<opc, MRMSrcMem, _, (outs _.RC:$dst),
(ins _.RC:$src1, _.IntScalarMemOp:$src2), OpcodeStr,
"$src2, $src1", "$src1, $src2",
(OpNode (_.VT _.RC:$src1), _.ScalarIntMemCPat:$src2)>,
(OpNode (_.VT _.RC:$src1), (_.ScalarIntMemFrags addr:$src2))>,
Sched<[sched.Folded, sched.ReadAfterFold]>, SIMD_EXC;
}
}
Expand Down Expand Up @@ -8977,7 +8977,7 @@ multiclass avx512_sqrt_scalar<bits<8> opc, string OpcodeStr, X86FoldableSchedWri
(ins _.RC:$src1, _.IntScalarMemOp:$src2), OpcodeStr,
"$src2, $src1", "$src1, $src2",
(X86fsqrts (_.VT _.RC:$src1),
_.ScalarIntMemCPat:$src2)>,
(_.ScalarIntMemFrags addr:$src2))>,
Sched<[sched.Folded, sched.ReadAfterFold]>, SIMD_EXC;
let Uses = [MXCSR] in
defm rb_Int : AVX512_maskable_scalar<opc, MRMSrcReg, _, (outs _.RC:$dst),
Expand Down Expand Up @@ -9050,7 +9050,7 @@ multiclass avx512_rndscale_scalar<bits<8> opc, string OpcodeStr,
OpcodeStr,
"$src3, $src2, $src1", "$src1, $src2, $src3",
(_.VT (X86RndScales _.RC:$src1,
_.ScalarIntMemCPat:$src2, (i32 timm:$src3)))>,
(_.ScalarIntMemFrags addr:$src2), (i32 timm:$src3)))>,
Sched<[sched.Folded, sched.ReadAfterFold]>, SIMD_EXC;

let isCodeGenOnly = 1, hasSideEffects = 0, Predicates = [HasAVX512] in {
Expand Down Expand Up @@ -10221,7 +10221,7 @@ multiclass avx512_fp_scalar_imm<bits<8> opc, string OpcodeStr, SDNode OpNode,
(ins _.RC:$src1, _.IntScalarMemOp:$src2, i32u8imm:$src3),
OpcodeStr, "$src3, $src2, $src1", "$src1, $src2, $src3",
(OpNode (_.VT _.RC:$src1),
(_.VT _.ScalarIntMemCPat:$src2),
(_.ScalarIntMemFrags addr:$src2),
(i32 timm:$src3))>,
Sched<[sched.Folded, sched.ReadAfterFold]>;
}
Expand Down
34 changes: 17 additions & 17 deletions llvm/lib/Target/X86/X86InstrFragmentsSIMD.td
Expand Up @@ -789,23 +789,6 @@ def SDTX86MaskedStore: SDTypeProfile<0, 3, [ // masked store
SDTCisVec<0>, SDTCisPtrTy<1>, SDTCisVec<2>, SDTCisSameNumEltsAs<0, 2>
]>;

//===----------------------------------------------------------------------===//
// SSE Complex Patterns
//===----------------------------------------------------------------------===//

// These are 'extloads' from a scalar to the low element of a vector, zeroing
// the top elements. These are used for the SSE 'ss' and 'sd' instruction
// forms.
// Both patterns are matched by the C++ selector "selectScalarSSELoad" and
// produce 5 operands; the SDNPWantRoot and SDNPWantParent properties
// correspond to the Root and Parent arguments that selector takes.
// sse_load_f32 is the v4f32 flavor.
def sse_load_f32 : ComplexPattern<v4f32, 5, "selectScalarSSELoad", [],
[SDNPHasChain, SDNPMayLoad, SDNPMemOperand,
SDNPWantRoot, SDNPWantParent]>;
// sse_load_f64 is the v2f64 flavor.
def sse_load_f64 : ComplexPattern<v2f64, 5, "selectScalarSSELoad", [],
[SDNPHasChain, SDNPMayLoad, SDNPMemOperand,
SDNPWantRoot, SDNPWantParent]>;

def ssmem : X86MemOperand<"printdwordmem", X86Mem32AsmOperand>;
def sdmem : X86MemOperand<"printqwordmem", X86Mem64AsmOperand>;

//===----------------------------------------------------------------------===//
// SSE pattern fragments
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -976,6 +959,23 @@ def X86VBroadcastld64 : PatFrag<(ops node:$src),
return cast<MemIntrinsicSDNode>(N)->getMemoryVT().getStoreSize() == 8;
}]>;

// Scalar SSE intrinsic fragments to match several different types of loads.
// Used by scalar SSE intrinsic instructions which have 128-bit types, but
// only load a single element.
// Each fragment lists three alternatives for the same pointer operand:
//   - a simple_load of the full vector type,
//   - an X86vzload32/X86vzload64 of the vector type,
//   - a scalar_to_vector of a scalar loadf32/loadf64.
// FIXME: We should add more canonicalization in DAGCombine, particularly
// removing the simple_load case.
def sse_load_f32 : PatFrags<(ops node:$ptr),
[(v4f32 (simple_load node:$ptr)),
(v4f32 (X86vzload32 node:$ptr)),
(v4f32 (scalar_to_vector (loadf32 node:$ptr)))]>;
def sse_load_f64 : PatFrags<(ops node:$ptr),
[(v2f64 (simple_load node:$ptr)),
(v2f64 (X86vzload64 node:$ptr)),
(v2f64 (scalar_to_vector (loadf64 node:$ptr)))]>;

def ssmem : X86MemOperand<"printdwordmem", X86Mem32AsmOperand>;
def sdmem : X86MemOperand<"printqwordmem", X86Mem64AsmOperand>;


def fp32imm0 : PatLeaf<(f32 fpimm), [{
return N->isExactlyValue(+0.0);
Expand Down

0 comments on commit 135b877

Please sign in to comment.