Skip to content

Commit

Permalink
[DAG] Fold Op(vecreduce(a), vecreduce(b)) into vecreduce(Op(a,b))
Browse files Browse the repository at this point in the history
So long as the operation is reassociative, we can reassociate the double
vecreduce from, for example, fadd(vecreduce(a), vecreduce(b)) to
vecreduce(fadd(a,b)). This will in general save a few instructions, but some
architectures (MVE) require the opposite fold, so a shouldReassociateReduction
hook is added to account for it. Only targets that use shouldExpandReduction
will be affected.

Differential Revision: https://reviews.llvm.org/D141870
  • Loading branch information
davemgreen committed Feb 8, 2023
1 parent 665ee0c commit 1af3f59
Show file tree
Hide file tree
Showing 11 changed files with 321 additions and 419 deletions.
6 changes: 6 additions & 0 deletions llvm/include/llvm/CodeGen/TargetLowering.h
Expand Up @@ -444,6 +444,12 @@ class TargetLoweringBase {
return true;
}

/// Return true if op(vecreduce(x), vecreduce(y)) should be reassociated to
/// vecreduce(op(x, y)) for the reduction opcode RedOpc.
///
/// \param RedOpc the reduction opcode being considered for the fold.
/// \param VT     the value type the query is made for (the DAG combiner passes
///               the type of the reduction's vector operand).
///
/// The default implementation allows the fold for every opcode and type;
/// targets override this hook to opt out.
virtual bool shouldReassociateReduction(unsigned RedOpc, EVT VT) const {
return true;
}

/// Return true if it is profitable to convert a select of FP constants into
/// a constant pool load whose address depends on the select condition. The
/// parameter may be used to differentiate a select with FP compare from
Expand Down
84 changes: 84 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
Expand Up @@ -550,6 +550,9 @@ namespace {
SDValue N1);
SDValue reassociateOps(unsigned Opc, const SDLoc &DL, SDValue N0,
SDValue N1, SDNodeFlags Flags);
/// Try to fold Opc(vecreduce(x), vecreduce(y)) -> vecreduce(Opc(x, y)),
/// where \p RedOpc is the reduction opcode carried by both \p N0 and \p N1
/// and \p Opc is the scalar/vector binary opcode being combined.
/// \p Flags is only expected to be passed for FP operations; integer callers
/// leave it at its default. Returns an empty SDValue when the fold does not
/// apply.
SDValue reassociateReduction(unsigned RedOpc, unsigned Opc, const SDLoc &DL,
                             EVT VT, SDValue N0, SDValue N1,
                             SDNodeFlags Flags = SDNodeFlags());

SDValue visitShiftByConstant(SDNode *N);

Expand Down Expand Up @@ -1310,6 +1313,25 @@ SDValue DAGCombiner::reassociateOps(unsigned Opc, const SDLoc &DL, SDValue N0,
return SDValue();
}

// Attempt the fold Opc(vecreduce(x), vecreduce(y)) -> vecreduce(Opc(x, y)).
// Flags are only expected from FP callers; integer callers must not forward
// their flags and should rely on the default-constructed value instead.
// Returns an empty SDValue when the pattern does not match or the target
// declines the transform.
SDValue DAGCombiner::reassociateReduction(unsigned RedOpc, unsigned Opc,
                                          const SDLoc &DL, EVT VT, SDValue N0,
                                          SDValue N1, SDNodeFlags Flags) {
  // Both operands have to be reductions of the requested kind. This is
  // checked first so that operand 0 is only inspected on matching nodes.
  if (N0.getOpcode() != RedOpc || N1.getOpcode() != RedOpc)
    return SDValue();

  SDValue LHSVec = N0.getOperand(0);
  SDValue RHSVec = N1.getOperand(0);
  EVT VecVT = LHSVec.getValueType();

  // The two reduced vectors must share a type so Opc can combine them.
  if (!(VecVT == RHSVec.getValueType()))
    return SDValue();

  // Only fold when each reduction node has a single use.
  if (!N0->hasOneUse() || !N1->hasOneUse())
    return SDValue();

  // The vector form of Opc must be available, and the target must agree
  // that reassociating this reduction is desirable.
  if (!TLI.isOperationLegalOrCustom(Opc, VecVT))
    return SDValue();
  if (!TLI.shouldReassociateReduction(RedOpc, VecVT))
    return SDValue();

  // Keep the inserter alive across both getNode calls below.
  SelectionDAG::FlagInserter FlagsInserter(DAG, Flags);
  SDValue Combined = DAG.getNode(Opc, DL, VecVT, LHSVec, RHSVec);
  return DAG.getNode(RedOpc, DL, VT, Combined);
}

SDValue DAGCombiner::CombineTo(SDNode *N, const SDValue *To, unsigned NumTo,
bool AddTo) {
assert(N->getNumValues() == NumTo && "Broken CombineTo call!");
Expand Down Expand Up @@ -2650,6 +2672,11 @@ SDValue DAGCombiner::visitADDLike(SDNode *N) {
return Add;
if (SDValue Add = ReassociateAddOr(N1, N0))
return Add;

// Fold add(vecreduce(x), vecreduce(y)) -> vecreduce(add(x, y))
if (SDValue SD =
reassociateReduction(ISD::VECREDUCE_ADD, ISD::ADD, DL, VT, N0, N1))
return SD;
}
// fold ((0-A) + B) -> B-A
if (N0.getOpcode() == ISD::SUB && isNullOrNullSplat(N0.getOperand(0)))
Expand Down Expand Up @@ -4351,6 +4378,11 @@ SDValue DAGCombiner::visitMUL(SDNode *N) {
if (SDValue RMUL = reassociateOps(ISD::MUL, DL, N0, N1, N->getFlags()))
return RMUL;

// Fold mul(vecreduce(x), vecreduce(y)) -> vecreduce(mul(x, y))
if (SDValue SD =
reassociateReduction(ISD::VECREDUCE_MUL, ISD::MUL, DL, VT, N0, N1))
return SD;

// Simplify the operands using demanded-bits information.
if (SimplifyDemandedBits(SDValue(N, 0)))
return SDValue(N, 0);
Expand Down Expand Up @@ -5486,6 +5518,25 @@ SDValue DAGCombiner::visitIMINMAX(SDNode *N) {
if (SDValue S = PerformUMinFpToSatCombine(N0, N1, N0, N1, ISD::SETULT, DAG))
return S;

// Fold min/max(vecreduce(x), vecreduce(y)) -> vecreduce(min/max(x, y))
auto ReductionOpcode = [](unsigned Opcode) {
switch (Opcode) {
case ISD::SMIN:
return ISD::VECREDUCE_SMIN;
case ISD::SMAX:
return ISD::VECREDUCE_SMAX;
case ISD::UMIN:
return ISD::VECREDUCE_UMIN;
case ISD::UMAX:
return ISD::VECREDUCE_UMAX;
default:
llvm_unreachable("Unexpected opcode");
}
};
if (SDValue SD = reassociateReduction(ReductionOpcode(Opcode), Opcode,
SDLoc(N), VT, N0, N1))
return SD;

// Simplify the operands using demanded-bits information.
if (SimplifyDemandedBits(SDValue(N, 0)))
return SDValue(N, 0);
Expand Down Expand Up @@ -6525,6 +6576,11 @@ SDValue DAGCombiner::visitAND(SDNode *N) {
if (SDValue RAND = reassociateOps(ISD::AND, SDLoc(N), N0, N1, N->getFlags()))
return RAND;

// Fold and(vecreduce(x), vecreduce(y)) -> vecreduce(and(x, y))
if (SDValue SD = reassociateReduction(ISD::VECREDUCE_AND, ISD::AND, SDLoc(N),
VT, N0, N1))
return SD;

// fold (and (or x, C), D) -> D if (C & D) == D
auto MatchSubset = [](ConstantSDNode *LHS, ConstantSDNode *RHS) {
return RHS->getAPIntValue().isSubsetOf(LHS->getAPIntValue());
Expand Down Expand Up @@ -7419,6 +7475,11 @@ SDValue DAGCombiner::visitOR(SDNode *N) {
if (SDValue ROR = reassociateOps(ISD::OR, SDLoc(N), N0, N1, N->getFlags()))
return ROR;

// Fold or(vecreduce(x), vecreduce(y)) -> vecreduce(or(x, y))
if (SDValue SD = reassociateReduction(ISD::VECREDUCE_OR, ISD::OR, SDLoc(N),
VT, N0, N1))
return SD;

// Canonicalize (or (and X, c1), c2) -> (and (or X, c2), c1|c2)
// iff (c1 & c2) != 0 or c1/c2 are undef.
auto MatchIntersect = [](ConstantSDNode *C1, ConstantSDNode *C2) {
Expand Down Expand Up @@ -8903,6 +8964,11 @@ SDValue DAGCombiner::visitXOR(SDNode *N) {
if (SDValue RXOR = reassociateOps(ISD::XOR, DL, N0, N1, N->getFlags()))
return RXOR;

// Fold xor(vecreduce(x), vecreduce(y)) -> vecreduce(xor(x, y))
if (SDValue SD =
reassociateReduction(ISD::VECREDUCE_XOR, ISD::XOR, DL, VT, N0, N1))
return SD;

// fold (a^b) -> (a|b) iff a and b share no bits.
if ((!LegalOperations || TLI.isOperationLegal(ISD::OR, VT)) &&
DAG.haveNoCommonBitsSet(N0, N1))
Expand Down Expand Up @@ -15621,6 +15687,11 @@ SDValue DAGCombiner::visitFADD(SDNode *N) {
DAG.getConstantFP(4.0, DL, VT));
}
}

// Fold fadd(vecreduce(x), vecreduce(y)) -> vecreduce(fadd(x, y))
if (SDValue SD = reassociateReduction(ISD::VECREDUCE_FADD, ISD::FADD, DL,
VT, N0, N1, Flags))
return SD;
} // enable-unsafe-fp-math

// FADD -> FMA combines:
Expand Down Expand Up @@ -15795,6 +15866,11 @@ SDValue DAGCombiner::visitFMUL(SDNode *N) {
SDValue MulConsts = DAG.getNode(ISD::FMUL, DL, VT, Two, N1);
return DAG.getNode(ISD::FMUL, DL, VT, N0.getOperand(0), MulConsts);
}

// Fold fmul(vecreduce(x), vecreduce(y)) -> vecreduce(fmul(x, y))
if (SDValue SD = reassociateReduction(ISD::VECREDUCE_FMUL, ISD::FMUL, DL,
VT, N0, N1, Flags))
return SD;
}

// fold (fmul X, 2.0) -> (fadd X, X)
Expand Down Expand Up @@ -16845,6 +16921,14 @@ SDValue DAGCombiner::visitFMinMax(SDNode *N) {
}
}

const TargetOptions &Options = DAG.getTarget().Options;
if ((Options.UnsafeFPMath && Options.NoSignedZerosFPMath) ||
(Flags.hasAllowReassociation() && Flags.hasNoSignedZeros()))
if (SDValue SD = reassociateReduction(IsMin ? ISD::VECREDUCE_FMIN
: ISD::VECREDUCE_FMAX,
Opc, SDLoc(N), VT, N0, N1, Flags))
return SD;

return SDValue();
}

Expand Down
4 changes: 4 additions & 0 deletions llvm/lib/Target/ARM/ARMISelLowering.h
Expand Up @@ -617,6 +617,10 @@ class VectorType;
return TargetLowering::shouldFormOverflowOp(Opcode, VT, true);
}

/// Opt out of reassociating op(vecreduce(x), vecreduce(y)) into
/// vecreduce(op(x, y)) for VECREDUCE_ADD only; every other reduction opcode
/// is still reassociated. VT is not consulted.
/// NOTE(review): the motivation given with this change is that MVE wants the
/// opposite fold for add reductions — confirm against the MVE lowering.
bool shouldReassociateReduction(unsigned Opc, EVT VT) const override {
return Opc != ISD::VECREDUCE_ADD;
}

/// Returns true if an argument of type Ty needs to be passed in a
/// contiguous block of registers in calling convention CallConv.
bool functionArgumentNeedsConsecutiveRegisters(
Expand Down
12 changes: 4 additions & 8 deletions llvm/test/CodeGen/AArch64/aarch64-addv.ll
Expand Up @@ -102,11 +102,9 @@ define i32 @oversized_ADDV_512(ptr %arr) {
define i8 @addv_combine_i8(<8 x i8> %a1, <8 x i8> %a2) {
; CHECK-LABEL: addv_combine_i8:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: add v0.8b, v0.8b, v1.8b
; CHECK-NEXT: addv b0, v0.8b
; CHECK-NEXT: addv b1, v1.8b
; CHECK-NEXT: fmov w8, s0
; CHECK-NEXT: fmov w9, s1
; CHECK-NEXT: add w0, w8, w9
; CHECK-NEXT: fmov w0, s0
; CHECK-NEXT: ret
entry:
%rdx.1 = call i8 @llvm.vector.reduce.add.v8i8(<8 x i8> %a1)
Expand All @@ -118,11 +116,9 @@ entry:
define i16 @addv_combine_i16(<4 x i16> %a1, <4 x i16> %a2) {
; CHECK-LABEL: addv_combine_i16:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: add v0.4h, v0.4h, v1.4h
; CHECK-NEXT: addv h0, v0.4h
; CHECK-NEXT: addv h1, v1.4h
; CHECK-NEXT: fmov w8, s0
; CHECK-NEXT: fmov w9, s1
; CHECK-NEXT: add w0, w8, w9
; CHECK-NEXT: fmov w0, s0
; CHECK-NEXT: ret
entry:
%rdx.1 = call i16 @llvm.vector.reduce.add.v4i16(<4 x i16> %a1)
Expand Down

0 comments on commit 1af3f59

Please sign in to comment.