Skip to content

Commit

Permalink
[DAGCombiner] fold binops with constant into select-of-constants
Browse files Browse the repository at this point in the history
This is part of the ongoing attempt to improve select codegen for all targets and select 
canonicalization in IR (see D24480 for more background). The transform is a subset of what
is done in InstCombine's FoldOpIntoSelect().

I first noticed a regression in the x86 avx512-insert-extract.ll tests with a patch that 
hopes to convert more selects to basic math ops. This appears to be a general missing DAG
transform though, so I added tests for all standard binops in rL296621 
(PowerPC was chosen semi-randomly; it has scripted FileCheck support, but so do ARM and x86).

The poor output for "sel_constants_shl_constant" is tracked with:
https://bugs.llvm.org/show_bug.cgi?id=32105

Differential Revision: https://reviews.llvm.org/D30502

llvm-svn: 296699
  • Loading branch information
rotateright committed Mar 1, 2017
1 parent d80b69f commit 9293865
Show file tree
Hide file tree
Showing 7 changed files with 346 additions and 335 deletions.
112 changes: 112 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
Expand Up @@ -337,6 +337,7 @@ namespace {
SDValue visitShiftByConstant(SDNode *N, ConstantSDNode *Amt);

SDValue foldSelectOfConstants(SDNode *N);
SDValue foldBinOpIntoSelect(SDNode *BO);
bool SimplifySelectOps(SDNode *SELECT, SDValue LHS, SDValue RHS);
SDValue SimplifyBinOpWithSameOpcodeHands(SDNode *N);
SDValue SimplifySelect(const SDLoc &DL, SDValue N0, SDValue N1, SDValue N2);
Expand Down Expand Up @@ -1747,6 +1748,59 @@ static ConstantSDNode *getAsNonOpaqueConstant(SDValue N) {
return Const != nullptr && !Const->isOpaque() ? Const : nullptr;
}

SDValue DAGCombiner::foldBinOpIntoSelect(SDNode *BO) {
auto BinOpcode = BO->getOpcode();
assert((BinOpcode == ISD::ADD || BinOpcode == ISD::SUB ||
BinOpcode == ISD::MUL || BinOpcode == ISD::SDIV ||
BinOpcode == ISD::UDIV || BinOpcode == ISD::SREM ||
BinOpcode == ISD::UREM || BinOpcode == ISD::AND ||
BinOpcode == ISD::OR || BinOpcode == ISD::XOR ||
BinOpcode == ISD::SHL || BinOpcode == ISD::SRL ||
BinOpcode == ISD::SRA || BinOpcode == ISD::FADD ||
BinOpcode == ISD::FSUB || BinOpcode == ISD::FMUL ||
BinOpcode == ISD::FDIV || BinOpcode == ISD::FREM) &&
"Unexpected binary operator");

SDValue C1 = BO->getOperand(1);
if (!isConstantOrConstantVector(C1) &&
!isConstantFPBuildVectorOrConstantFP(C1))
return SDValue();

// Don't do this unless the old select is going away. We want to eliminate the
// binary operator, not replace a binop with a select.
// TODO: Handle ISD::SELECT_CC.
SDValue Sel = BO->getOperand(0);
if (Sel.getOpcode() != ISD::SELECT || !Sel.hasOneUse())
return SDValue();

SDValue CT = Sel.getOperand(1);
if (!isConstantOrConstantVector(CT) &&
!isConstantFPBuildVectorOrConstantFP(CT))
return SDValue();

SDValue CF = Sel.getOperand(2);
if (!isConstantOrConstantVector(CF) &&
!isConstantFPBuildVectorOrConstantFP(CF))
return SDValue();

// We have a select-of-constants followed by a binary operator with a
// constant. Eliminate the binop by pulling the constant math into the select.
// Example: add (select Cond, CT, CF), C1 --> select Cond, CT + C1, CF + C1
EVT VT = Sel.getValueType();
SDLoc DL(Sel);
SDValue NewCT = DAG.getNode(BinOpcode, DL, VT, CT, C1);
assert((isConstantOrConstantVector(NewCT) ||
isConstantFPBuildVectorOrConstantFP(NewCT)) &&
"Failed to constant fold a binop with constant operands");

SDValue NewCF = DAG.getNode(BinOpcode, DL, VT, CF, C1);
assert((isConstantOrConstantVector(NewCF) ||
isConstantFPBuildVectorOrConstantFP(NewCF)) &&
"Failed to constant fold a binop with constant operands");

return DAG.getSelect(DL, VT, Sel.getOperand(0), NewCT, NewCF);
}

SDValue DAGCombiner::visitADD(SDNode *N) {
SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);
Expand Down Expand Up @@ -1795,6 +1849,9 @@ SDValue DAGCombiner::visitADD(SDNode *N) {
}
}

if (SDValue NewSel = foldBinOpIntoSelect(N))
return NewSel;

// reassociate add
if (SDValue RADD = ReassociateOps(ISD::ADD, DL, N0, N1))
return RADD;
Expand Down Expand Up @@ -1999,6 +2056,9 @@ SDValue DAGCombiner::visitSUB(SDNode *N) {
N1.getNode());
}

if (SDValue NewSel = foldBinOpIntoSelect(N))
return NewSel;

ConstantSDNode *N1C = getAsNonOpaqueConstant(N1);

// fold (sub x, c) -> (add x, -c)
Expand Down Expand Up @@ -2210,6 +2270,10 @@ SDValue DAGCombiner::visitMUL(SDNode *N) {
// fold (mul x, 1) -> x
if (N1IsConst && ConstValue1 == 1 && IsFullSplat)
return N0;

if (SDValue NewSel = foldBinOpIntoSelect(N))
return NewSel;

// fold (mul x, -1) -> 0-x
if (N1IsConst && ConstValue1.isAllOnesValue()) {
SDLoc DL(N);
Expand Down Expand Up @@ -2401,6 +2465,9 @@ SDValue DAGCombiner::visitSDIV(SDNode *N) {
return DAG.getNode(ISD::SUB, DL, VT,
DAG.getConstant(0, DL, VT), N0);

if (SDValue NewSel = foldBinOpIntoSelect(N))
return NewSel;

// If we know the sign bits of both operands are zero, strength reduce to a
// udiv instead. Handles (X&15) /s 4 -> X&15 >> 2
if (DAG.SignBitIsZero(N1) && DAG.SignBitIsZero(N0))
Expand Down Expand Up @@ -2493,6 +2560,9 @@ SDValue DAGCombiner::visitUDIV(SDNode *N) {
N0C, N1C))
return Folded;

if (SDValue NewSel = foldBinOpIntoSelect(N))
return NewSel;

// fold (udiv x, (1 << c)) -> x >>u c
if (isConstantOrConstantVector(N1, /*NoOpaques*/ true) &&
DAG.isKnownToBeAPowerOfTwo(N1)) {
Expand Down Expand Up @@ -2561,6 +2631,9 @@ SDValue DAGCombiner::visitREM(SDNode *N) {
if (SDValue Folded = DAG.FoldConstantArithmetic(Opcode, DL, VT, N0C, N1C))
return Folded;

if (SDValue NewSel = foldBinOpIntoSelect(N))
return NewSel;

if (isSigned) {
// If we know the sign bits of both operands are zero, strength reduce to a
// urem instead. Handles (X & 0x0FFFFFFF) %s 16 -> X&15
Expand Down Expand Up @@ -3267,6 +3340,10 @@ SDValue DAGCombiner::visitAND(SDNode *N) {
if (N1C && DAG.MaskedValueIsZero(SDValue(N, 0),
APInt::getAllOnesValue(BitWidth)))
return DAG.getConstant(0, SDLoc(N), VT);

if (SDValue NewSel = foldBinOpIntoSelect(N))
return NewSel;

// reassociate and
if (SDValue RAND = ReassociateOps(ISD::AND, SDLoc(N), N0, N1))
return RAND;
Expand Down Expand Up @@ -4008,6 +4085,10 @@ SDValue DAGCombiner::visitOR(SDNode *N) {
// fold (or x, -1) -> -1
if (isAllOnesConstant(N1))
return N1;

if (SDValue NewSel = foldBinOpIntoSelect(N))
return NewSel;

// fold (or x, c) -> c iff (x & ~c) == 0
if (N1C && DAG.MaskedValueIsZero(N0, ~N1C->getAPIntValue()))
return N1;
Expand Down Expand Up @@ -4753,6 +4834,10 @@ SDValue DAGCombiner::visitXOR(SDNode *N) {
// fold (xor x, 0) -> x
if (isNullConstant(N1))
return N0;

if (SDValue NewSel = foldBinOpIntoSelect(N))
return NewSel;

// reassociate xor
if (SDValue RXOR = ReassociateOps(ISD::XOR, SDLoc(N), N0, N1))
return RXOR;
Expand Down Expand Up @@ -5040,6 +5125,10 @@ SDValue DAGCombiner::visitSHL(SDNode *N) {
// fold (shl undef, x) -> 0
if (N0.isUndef())
return DAG.getConstant(0, SDLoc(N), VT);

if (SDValue NewSel = foldBinOpIntoSelect(N))
return NewSel;

// if (shl x, c) is known to be zero, return 0
if (DAG.MaskedValueIsZero(SDValue(N, 0),
APInt::getAllOnesValue(OpSizeInBits)))
Expand Down Expand Up @@ -5243,6 +5332,10 @@ SDValue DAGCombiner::visitSRA(SDNode *N) {
// fold (sra x, 0) -> x
if (N1C && N1C->isNullValue())
return N0;

if (SDValue NewSel = foldBinOpIntoSelect(N))
return NewSel;

// fold (sra (shl x, c1), c1) -> sext_inreg for some c1 and target supports
// sext_inreg.
if (N1C && N0.getOpcode() == ISD::SHL && N1 == N0.getOperand(1)) {
Expand Down Expand Up @@ -5390,6 +5483,10 @@ SDValue DAGCombiner::visitSRL(SDNode *N) {
// fold (srl x, 0) -> x
if (N1C && N1C->isNullValue())
return N0;

if (SDValue NewSel = foldBinOpIntoSelect(N))
return NewSel;

// if (srl x, c) is known to be zero, return 0
if (N1C && DAG.MaskedValueIsZero(SDValue(N, 0),
APInt::getAllOnesValue(OpSizeInBits)))
Expand Down Expand Up @@ -9064,6 +9161,9 @@ SDValue DAGCombiner::visitFADD(SDNode *N) {
if (N0CFP && !N1CFP)
return DAG.getNode(ISD::FADD, DL, VT, N1, N0, Flags);

if (SDValue NewSel = foldBinOpIntoSelect(N))
return NewSel;

// fold (fadd A, (fneg B)) -> (fsub A, B)
if ((!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FSUB, VT)) &&
isNegatibleForFree(N1, LegalOperations, TLI, &Options) == 2)
Expand Down Expand Up @@ -9211,6 +9311,9 @@ SDValue DAGCombiner::visitFSUB(SDNode *N) {
if (N0CFP && N1CFP)
return DAG.getNode(ISD::FSUB, DL, VT, N0, N1, Flags);

if (SDValue NewSel = foldBinOpIntoSelect(N))
return NewSel;

// fold (fsub A, (fneg B)) -> (fadd A, B)
if (isNegatibleForFree(N1, LegalOperations, TLI, &Options))
return DAG.getNode(ISD::FADD, DL, VT, N0,
Expand Down Expand Up @@ -9290,6 +9393,9 @@ SDValue DAGCombiner::visitFMUL(SDNode *N) {
if (N1CFP && N1CFP->isExactlyValue(1.0))
return N0;

if (SDValue NewSel = foldBinOpIntoSelect(N))
return NewSel;

if (Options.UnsafeFPMath) {
// fold (fmul A, 0) -> 0
if (N1CFP && N1CFP->isZero())
Expand Down Expand Up @@ -9544,6 +9650,9 @@ SDValue DAGCombiner::visitFDIV(SDNode *N) {
if (N0CFP && N1CFP)
return DAG.getNode(ISD::FDIV, SDLoc(N), VT, N0, N1, Flags);

if (SDValue NewSel = foldBinOpIntoSelect(N))
return NewSel;

if (Options.UnsafeFPMath) {
// fold (fdiv X, c2) -> fmul X, 1/c2 if losing precision is acceptable.
if (N1CFP) {
Expand Down Expand Up @@ -9647,6 +9756,9 @@ SDValue DAGCombiner::visitFREM(SDNode *N) {
return DAG.getNode(ISD::FREM, SDLoc(N), VT, N0, N1,
&cast<BinaryWithFlagsSDNode>(N)->Flags);

if (SDValue NewSel = foldBinOpIntoSelect(N))
return NewSel;

return SDValue();
}

Expand Down
12 changes: 5 additions & 7 deletions llvm/test/CodeGen/ARM/select_xform.ll
Expand Up @@ -223,21 +223,19 @@ entry:
ret i32 %add
}

; Do not fold the xor into the select
; Fold the xor into the select.
define i32 @t15(i32 %p) {
entry:
; ARM-LABEL: t15:
; ARM: mov [[REG:r[0-9]+]], #2
; ARM: mov [[REG:r[0-9]+]], #3
; ARM: cmp r0, #8
; ARM: movwgt [[REG:r[0-9]+]], #1
; ARM: eor r0, [[REG:r[0-9]+]], #1
; ARM: movwgt [[REG:r[0-9]+]], #0

; T2-LABEL: t15:
; T2: movs [[REG:r[0-9]+]], #2
; T2: movs [[REG:r[0-9]+]], #3
; T2: cmp [[REG:r[0-9]+]], #8
; T2: it gt
; T2: movgt [[REG:r[0-9]+]], #1
; T2: eor r0, [[REG:r[0-9]+]], #1
; T2: movgt [[REG:r[0-9]+]], #0
%cmp = icmp sgt i32 %p, 8
%a = select i1 %cmp, i32 1, i32 2
%xor = xor i32 %a, 1
Expand Down

0 comments on commit 9293865

Please sign in to comment.