Skip to content
Permalink
Browse files

[DAGCombine] matchBinOpReduction - add partial reduction matching

This patch adds support for recognizing cases where a larger vector type is being used to reduce just the elements in the lower subvector:

e.g. <8 x i32> reduction pattern in a <16 x i32> vector:

<4,5,6,7,u,u,u,u,u,u,u,u,u,u,u,u>
<2,3,u,u,u,u,u,u,u,u,u,u,u,u,u,u>
<1,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u>

matchBinOpReduction returns the lower extracted subvector in such cases, assuming isExtractSubvectorCheap accepts the extraction.

I've only enabled it for X86 reduction sums so far. I intend to enable it for the bitop/minmax cases in future patches, and eventually I think its worth turning it on all the time. This is mainly just a case of ensuring calls to matchBinOpReduction don't make assumptions on the vector width based on the original vector extraction.

Fixes the x86 partial reduction sum cases in PR33758 and PR42023.

Differential Revision: https://reviews.llvm.org/D65047

llvm-svn: 366933
  • Loading branch information...
RKSimon committed Jul 24, 2019
1 parent e8bffd3 commit 7d318b2bb19771745021145730387d43c589a9a7
@@ -1588,9 +1588,12 @@ class SelectionDAG {
/// Extract. The reduction must use one of the opcodes listed in /p
/// CandidateBinOps and on success /p BinOp will contain the matching opcode.
/// Returns the vector that is being reduced on, or SDValue() if a reduction
/// was not matched.
/// was not matched. If \p AllowPartials is set then in the case of a
/// reduction pattern that only matches the first few stages, the extracted
/// subvector of the start of the reduction is returned.
SDValue matchBinOpReduction(SDNode *Extract, ISD::NodeType &BinOp,
ArrayRef<ISD::NodeType> CandidateBinOps);
ArrayRef<ISD::NodeType> CandidateBinOps,
bool AllowPartials = false);

/// Utility function used by legalize and lowering to
/// "unroll" a vector operation by splitting out the scalars and operating
@@ -9005,7 +9005,8 @@ void SDNode::intersectFlagsWith(const SDNodeFlags Flags) {

SDValue
SelectionDAG::matchBinOpReduction(SDNode *Extract, ISD::NodeType &BinOp,
ArrayRef<ISD::NodeType> CandidateBinOps) {
ArrayRef<ISD::NodeType> CandidateBinOps,
bool AllowPartials) {
// The pattern must end in an extract from index 0.
if (Extract->getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
!isNullConstant(Extract->getOperand(1)))
@@ -9019,6 +9020,23 @@ SelectionDAG::matchBinOpReduction(SDNode *Extract, ISD::NodeType &BinOp,
return Op.getOpcode() == unsigned(BinOp);
}))
return SDValue();
unsigned CandidateBinOp = Op.getOpcode();

// Matching failed - attempt to see if we did enough stages that a partial
// reduction from a subvector is possible.
auto PartialReduction = [&](SDValue Op, unsigned NumSubElts) {
if (!AllowPartials || !Op)
return SDValue();
EVT OpVT = Op.getValueType();
EVT OpSVT = OpVT.getScalarType();
EVT SubVT = EVT::getVectorVT(*getContext(), OpSVT, NumSubElts);
if (!TLI->isExtractSubvectorCheap(SubVT, OpVT, 0))
return SDValue();
BinOp = (ISD::NodeType)CandidateBinOp;
return getNode(
ISD::EXTRACT_SUBVECTOR, SDLoc(Op), SubVT, Op,
getConstant(0, SDLoc(Op), TLI->getVectorIdxTy(getDataLayout())));
};

// At each stage, we're looking for something that looks like:
// %s = shufflevector <8 x i32> %op, <8 x i32> undef,
@@ -9030,10 +9048,15 @@ SelectionDAG::matchBinOpReduction(SDNode *Extract, ISD::NodeType &BinOp,
// <4,5,6,7,u,u,u,u>
// <2,3,u,u,u,u,u,u>
// <1,u,u,u,u,u,u,u>
unsigned CandidateBinOp = Op.getOpcode();
// While a partial reduction match would be:
// <2,3,u,u,u,u,u,u>
// <1,u,u,u,u,u,u,u>
SDValue PrevOp;
for (unsigned i = 0; i < Stages; ++i) {
unsigned MaskEnd = (1 << i);

if (Op.getOpcode() != CandidateBinOp)
return SDValue();
return PartialReduction(PrevOp, MaskEnd);

SDValue Op0 = Op.getOperand(0);
SDValue Op1 = Op.getOperand(1);
@@ -9049,12 +9072,14 @@ SelectionDAG::matchBinOpReduction(SDNode *Extract, ISD::NodeType &BinOp,
// The first operand of the shuffle should be the same as the other operand
// of the binop.
if (!Shuffle || Shuffle->getOperand(0) != Op)
return SDValue();
return PartialReduction(PrevOp, MaskEnd);

// Verify the shuffle has the expected (at this stage of the pyramid) mask.
for (int Index = 0, MaskEnd = 1 << i; Index < MaskEnd; ++Index)
if (Shuffle->getMaskElt(Index) != MaskEnd + Index)
return SDValue();
for (int Index = 0; Index < (int)MaskEnd; ++Index)
if (Shuffle->getMaskElt(Index) != (MaskEnd + Index))
return PartialReduction(PrevOp, MaskEnd);

PrevOp = Op;
}

BinOp = (ISD::NodeType)CandidateBinOp;

// TODO: Allow FADD with reduction and/or reassociation and no-signed-zeros.
ISD::NodeType Opc;
SDValue Rdx = DAG.matchBinOpReduction(ExtElt, Opc, {ISD::ADD});
SDValue Rdx = DAG.matchBinOpReduction(ExtElt, Opc, {ISD::ADD}, true);
if (!Rdx)
return SDValue();

"Reduction doesn't end in an extract from index 0");

EVT VT = ExtElt->getValueType(0);
EVT VecVT = ExtElt->getOperand(0).getValueType();
EVT VecVT = Rdx.getValueType();
if (VecVT.getScalarType() != VT)
return SDValue();

// vXi8 reduction - sum lo/hi halves then use PSADBW.
if (VT == MVT::i8) {
while (Rdx.getValueSizeInBits() > 128) {
EVT RdxVT = Rdx.getValueType();
unsigned HalfSize = RdxVT.getSizeInBits() / 2;
unsigned HalfElts = RdxVT.getVectorNumElements() / 2;
unsigned HalfSize = VecVT.getSizeInBits() / 2;
unsigned HalfElts = VecVT.getVectorNumElements() / 2;
SDValue Lo = extractSubVector(Rdx, 0, DAG, DL, HalfSize);
SDValue Hi = extractSubVector(Rdx, HalfElts, DAG, DL, HalfSize);
Rdx = DAG.getNode(ISD::ADD, DL, Lo.getValueType(), Lo, Hi);
VecVT = Rdx.getValueType();
}
assert(Rdx.getValueType() == MVT::v16i8 && "v16i8 reduction expected");
assert(VecVT == MVT::v16i8 && "v16i8 reduction expected");

SDValue Hi = DAG.getVectorShuffle(
MVT::v16i8, DL, Rdx, Rdx,
unsigned NumElts = VecVT.getVectorNumElements();
SDValue Hi = extract128BitVector(Rdx, NumElts / 2, DAG, DL);
SDValue Lo = extract128BitVector(Rdx, 0, DAG, DL);
VecVT = EVT::getVectorVT(*DAG.getContext(), VT, NumElts / 2);
Rdx = DAG.getNode(HorizOpcode, DL, VecVT, Hi, Lo);
Rdx = DAG.getNode(HorizOpcode, DL, Lo.getValueType(), Hi, Lo);
VecVT = Rdx.getValueType();
}
if (!((VecVT == MVT::v8i16 || VecVT == MVT::v4i32) && Subtarget.hasSSSE3()) &&
!((VecVT == MVT::v4f32 || VecVT == MVT::v2f64) && Subtarget.hasSSE3()))
return SDValue();

// extract (add (shuf X), X), 0 --> extract (hadd X, X), 0
assert(Rdx.getValueType() == VecVT && "Unexpected reduction match");
unsigned ReductionSteps = Log2_32(VecVT.getVectorNumElements());
for (unsigned i = 0; i != ReductionSteps; ++i)
Rdx = DAG.getNode(HorizOpcode, DL, VecVT, Rdx, Rdx);
@@ -1699,8 +1699,7 @@ define i32 @partial_reduction_add_v8i32(<8 x i32> %x) {
;
; AVX-FAST-LABEL: partial_reduction_add_v8i32:
; AVX-FAST: # %bb.0:
; AVX-FAST-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
; AVX-FAST-NEXT: vpaddd %xmm1, %xmm0, %xmm0
; AVX-FAST-NEXT: vphaddd %xmm0, %xmm0, %xmm0
; AVX-FAST-NEXT: vphaddd %xmm0, %xmm0, %xmm0
; AVX-FAST-NEXT: vmovd %xmm0, %eax
; AVX-FAST-NEXT: vzeroupper
@@ -1741,34 +1740,13 @@ define i32 @partial_reduction_add_v16i32(<16 x i32> %x) {
; AVX-SLOW-NEXT: vzeroupper
; AVX-SLOW-NEXT: retq
;
; AVX1-FAST-LABEL: partial_reduction_add_v16i32:
; AVX1-FAST: # %bb.0:
; AVX1-FAST-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
; AVX1-FAST-NEXT: vpaddd %xmm1, %xmm0, %xmm0
; AVX1-FAST-NEXT: vphaddd %xmm0, %xmm0, %xmm0
; AVX1-FAST-NEXT: vmovd %xmm0, %eax
; AVX1-FAST-NEXT: vzeroupper
; AVX1-FAST-NEXT: retq
;
; AVX2-FAST-LABEL: partial_reduction_add_v16i32:
; AVX2-FAST: # %bb.0:
; AVX2-FAST-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
; AVX2-FAST-NEXT: vpaddd %xmm1, %xmm0, %xmm0
; AVX2-FAST-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[1,1,2,3]
; AVX2-FAST-NEXT: vpaddd %xmm1, %xmm0, %xmm0
; AVX2-FAST-NEXT: vmovd %xmm0, %eax
; AVX2-FAST-NEXT: vzeroupper
; AVX2-FAST-NEXT: retq
;
; AVX512-FAST-LABEL: partial_reduction_add_v16i32:
; AVX512-FAST: # %bb.0:
; AVX512-FAST-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
; AVX512-FAST-NEXT: vpaddd %xmm1, %xmm0, %xmm0
; AVX512-FAST-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[1,1,2,3]
; AVX512-FAST-NEXT: vpaddd %xmm1, %xmm0, %xmm0
; AVX512-FAST-NEXT: vmovd %xmm0, %eax
; AVX512-FAST-NEXT: vzeroupper
; AVX512-FAST-NEXT: retq
; AVX-FAST-LABEL: partial_reduction_add_v16i32:
; AVX-FAST: # %bb.0:
; AVX-FAST-NEXT: vphaddd %xmm0, %xmm0, %xmm0
; AVX-FAST-NEXT: vphaddd %xmm0, %xmm0, %xmm0
; AVX-FAST-NEXT: vmovd %xmm0, %eax
; AVX-FAST-NEXT: vzeroupper
; AVX-FAST-NEXT: retq
%x23 = shufflevector <16 x i32> %x, <16 x i32> undef, <16 x i32> <i32 2, i32 3, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef>
%x0213 = add <16 x i32> %x, %x23
%x13 = shufflevector <16 x i32> %x0213, <16 x i32> undef, <16 x i32> <i32 1, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef>
@@ -2010,8 +1988,7 @@ define i32 @hadd32_8(<8 x i32> %x225) {
;
; AVX-FAST-LABEL: hadd32_8:
; AVX-FAST: # %bb.0:
; AVX-FAST-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
; AVX-FAST-NEXT: vpaddd %xmm1, %xmm0, %xmm0
; AVX-FAST-NEXT: vphaddd %xmm0, %xmm0, %xmm0
; AVX-FAST-NEXT: vphaddd %xmm0, %xmm0, %xmm0
; AVX-FAST-NEXT: vmovd %xmm0, %eax
; AVX-FAST-NEXT: vzeroupper
@@ -2052,34 +2029,13 @@ define i32 @hadd32_16(<16 x i32> %x225) {
; AVX-SLOW-NEXT: vzeroupper
; AVX-SLOW-NEXT: retq
;
; AVX1-FAST-LABEL: hadd32_16:
; AVX1-FAST: # %bb.0:
; AVX1-FAST-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
; AVX1-FAST-NEXT: vpaddd %xmm1, %xmm0, %xmm0
; AVX1-FAST-NEXT: vphaddd %xmm0, %xmm0, %xmm0
; AVX1-FAST-NEXT: vmovd %xmm0, %eax
; AVX1-FAST-NEXT: vzeroupper
; AVX1-FAST-NEXT: retq
;
; AVX2-FAST-LABEL: hadd32_16:
; AVX2-FAST: # %bb.0:
; AVX2-FAST-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
; AVX2-FAST-NEXT: vpaddd %xmm1, %xmm0, %xmm0
; AVX2-FAST-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[1,1,2,3]
; AVX2-FAST-NEXT: vpaddd %xmm1, %xmm0, %xmm0
; AVX2-FAST-NEXT: vmovd %xmm0, %eax
; AVX2-FAST-NEXT: vzeroupper
; AVX2-FAST-NEXT: retq
;
; AVX512-FAST-LABEL: hadd32_16:
; AVX512-FAST: # %bb.0:
; AVX512-FAST-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
; AVX512-FAST-NEXT: vpaddd %xmm1, %xmm0, %xmm0
; AVX512-FAST-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[1,1,2,3]
; AVX512-FAST-NEXT: vpaddd %xmm1, %xmm0, %xmm0
; AVX512-FAST-NEXT: vmovd %xmm0, %eax
; AVX512-FAST-NEXT: vzeroupper
; AVX512-FAST-NEXT: retq
; AVX-FAST-LABEL: hadd32_16:
; AVX-FAST: # %bb.0:
; AVX-FAST-NEXT: vphaddd %xmm0, %xmm0, %xmm0
; AVX-FAST-NEXT: vphaddd %xmm0, %xmm0, %xmm0
; AVX-FAST-NEXT: vmovd %xmm0, %eax
; AVX-FAST-NEXT: vzeroupper
; AVX-FAST-NEXT: retq
%x226 = shufflevector <16 x i32> %x225, <16 x i32> undef, <16 x i32> <i32 2, i32 3, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef>
%x227 = add <16 x i32> %x225, %x226
%x228 = shufflevector <16 x i32> %x227, <16 x i32> undef, <16 x i32> <i32 1, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef>
@@ -2149,8 +2105,7 @@ define i32 @hadd32_8_optsize(<8 x i32> %x225) optsize {
;
; AVX-LABEL: hadd32_8_optsize:
; AVX: # %bb.0:
; AVX-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
; AVX-NEXT: vpaddd %xmm1, %xmm0, %xmm0
; AVX-NEXT: vphaddd %xmm0, %xmm0, %xmm0
; AVX-NEXT: vphaddd %xmm0, %xmm0, %xmm0
; AVX-NEXT: vmovd %xmm0, %eax
; AVX-NEXT: vzeroupper
@@ -2172,63 +2127,13 @@ define i32 @hadd32_16_optsize(<16 x i32> %x225) optsize {
; SSE3-NEXT: movd %xmm1, %eax
; SSE3-NEXT: retq
;
; AVX1-SLOW-LABEL: hadd32_16_optsize:
; AVX1-SLOW: # %bb.0:
; AVX1-SLOW-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
; AVX1-SLOW-NEXT: vpaddd %xmm1, %xmm0, %xmm0
; AVX1-SLOW-NEXT: vphaddd %xmm0, %xmm0, %xmm0
; AVX1-SLOW-NEXT: vmovd %xmm0, %eax
; AVX1-SLOW-NEXT: vzeroupper
; AVX1-SLOW-NEXT: retq
;
; AVX1-FAST-LABEL: hadd32_16_optsize:
; AVX1-FAST: # %bb.0:
; AVX1-FAST-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
; AVX1-FAST-NEXT: vpaddd %xmm1, %xmm0, %xmm0
; AVX1-FAST-NEXT: vphaddd %xmm0, %xmm0, %xmm0
; AVX1-FAST-NEXT: vmovd %xmm0, %eax
; AVX1-FAST-NEXT: vzeroupper
; AVX1-FAST-NEXT: retq
;
; AVX2-SLOW-LABEL: hadd32_16_optsize:
; AVX2-SLOW: # %bb.0:
; AVX2-SLOW-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
; AVX2-SLOW-NEXT: vpaddd %xmm1, %xmm0, %xmm0
; AVX2-SLOW-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[1,1,2,3]
; AVX2-SLOW-NEXT: vpaddd %xmm1, %xmm0, %xmm0
; AVX2-SLOW-NEXT: vmovd %xmm0, %eax
; AVX2-SLOW-NEXT: vzeroupper
; AVX2-SLOW-NEXT: retq
;
; AVX2-FAST-LABEL: hadd32_16_optsize:
; AVX2-FAST: # %bb.0:
; AVX2-FAST-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
; AVX2-FAST-NEXT: vpaddd %xmm1, %xmm0, %xmm0
; AVX2-FAST-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[1,1,2,3]
; AVX2-FAST-NEXT: vpaddd %xmm1, %xmm0, %xmm0
; AVX2-FAST-NEXT: vmovd %xmm0, %eax
; AVX2-FAST-NEXT: vzeroupper
; AVX2-FAST-NEXT: retq
;
; AVX512-SLOW-LABEL: hadd32_16_optsize:
; AVX512-SLOW: # %bb.0:
; AVX512-SLOW-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
; AVX512-SLOW-NEXT: vpaddd %xmm1, %xmm0, %xmm0
; AVX512-SLOW-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[1,1,2,3]
; AVX512-SLOW-NEXT: vpaddd %xmm1, %xmm0, %xmm0
; AVX512-SLOW-NEXT: vmovd %xmm0, %eax
; AVX512-SLOW-NEXT: vzeroupper
; AVX512-SLOW-NEXT: retq
;
; AVX512-FAST-LABEL: hadd32_16_optsize:
; AVX512-FAST: # %bb.0:
; AVX512-FAST-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
; AVX512-FAST-NEXT: vpaddd %xmm1, %xmm0, %xmm0
; AVX512-FAST-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[1,1,2,3]
; AVX512-FAST-NEXT: vpaddd %xmm1, %xmm0, %xmm0
; AVX512-FAST-NEXT: vmovd %xmm0, %eax
; AVX512-FAST-NEXT: vzeroupper
; AVX512-FAST-NEXT: retq
; AVX-LABEL: hadd32_16_optsize:
; AVX: # %bb.0:
; AVX-NEXT: vphaddd %xmm0, %xmm0, %xmm0
; AVX-NEXT: vphaddd %xmm0, %xmm0, %xmm0
; AVX-NEXT: vmovd %xmm0, %eax
; AVX-NEXT: vzeroupper
; AVX-NEXT: retq
%x226 = shufflevector <16 x i32> %x225, <16 x i32> undef, <16 x i32> <i32 2, i32 3, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef>
%x227 = add <16 x i32> %x225, %x226
%x228 = shufflevector <16 x i32> %x227, <16 x i32> undef, <16 x i32> <i32 1, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef>

0 comments on commit 7d318b2

Please sign in to comment.
You can’t perform that action at this time.