Skip to content

Commit

Permalink
[PPC] Move the combine "a << (b % (sizeof(a) * 8)) -> (PPCshl a, b)" …
Browse files Browse the repository at this point in the history
…to the backend. NFC.

Summary:
Eli pointed out that it's unsafe to combine the shifts to ISD::SHL etc.,
because those are not defined for b > sizeof(a) * 8, even after some of
the combiners run.

However, PPCISD::SHL defines that behavior (as the instructions themselves).
Move the combination to the backend.

The tests in shift_mask.ll still pass.

Reviewers: echristo, hfinkel, efriedma, iteratee

Subscribers: nemanjai, llvm-commits

Differential Revision: https://reviews.llvm.org/D33076

llvm-svn: 302937
  • Loading branch information
timshen91 committed May 12, 2017
1 parent dd3a739 commit 10c64e6
Show file tree
Hide file tree
Showing 5 changed files with 108 additions and 58 deletions.
8 changes: 0 additions & 8 deletions llvm/include/llvm/Target/TargetLowering.h
Expand Up @@ -2063,14 +2063,6 @@ class TargetLoweringBase {
return false;
}

// Return true if the instruction that performs a << b actually performs
// a << (b % (sizeof(a) * 8)).
virtual bool supportsModuloShift(ISD::NodeType Inst, EVT ReturnType) const {
assert((Inst == ISD::SHL || Inst == ISD::SRA || Inst == ISD::SRL) &&
"Expect a shift instruction");
return false;
}

//===--------------------------------------------------------------------===//
// Runtime Library hooks
//
Expand Down
33 changes: 0 additions & 33 deletions llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
Expand Up @@ -5313,17 +5313,6 @@ SDValue DAGCombiner::visitSHL(SDNode *N) {
}
}

// If the target supports masking y in (shl, y),
// fold (shl x, (and y, ((1 << numbits(x)) - 1))) -> (shl x, y)
if (TLI.isOperationLegal(ISD::SHL, VT) &&
TLI.supportsModuloShift(ISD::SHL, VT) && N1->getOpcode() == ISD::AND) {
if (ConstantSDNode *Mask = isConstOrConstSplat(N1->getOperand(1))) {
if (Mask->getZExtValue() == OpSizeInBits - 1) {
return DAG.getNode(ISD::SHL, SDLoc(N), VT, N0, N1->getOperand(0));
}
}
}

ConstantSDNode *N1C = isConstOrConstSplat(N1);

// fold (shl c1, c2) -> c1<<c2
Expand Down Expand Up @@ -5522,17 +5511,6 @@ SDValue DAGCombiner::visitSRA(SDNode *N) {
EVT VT = N0.getValueType();
unsigned OpSizeInBits = VT.getScalarSizeInBits();

// If the target supports masking y in (sra, y),
// fold (sra x, (and y, ((1 << numbits(x)) - 1))) -> (sra x, y)
if (TLI.isOperationLegal(ISD::SRA, VT) &&
TLI.supportsModuloShift(ISD::SRA, VT) && N1->getOpcode() == ISD::AND) {
if (ConstantSDNode *Mask = isConstOrConstSplat(N1->getOperand(1))) {
if (Mask->getZExtValue() == OpSizeInBits - 1) {
return DAG.getNode(ISD::SRA, SDLoc(N), VT, N0, N1->getOperand(0));
}
}
}

// Arithmetic shifting an all-sign-bit value is a no-op.
// fold (sra 0, x) -> 0
// fold (sra -1, x) -> -1
Expand Down Expand Up @@ -5687,17 +5665,6 @@ SDValue DAGCombiner::visitSRL(SDNode *N) {
EVT VT = N0.getValueType();
unsigned OpSizeInBits = VT.getScalarSizeInBits();

// If the target supports masking y in (srl, y),
// fold (srl x, (and y, ((1 << numbits(x)) - 1))) -> (srl x, y)
if (TLI.isOperationLegal(ISD::SRL, VT) &&
TLI.supportsModuloShift(ISD::SRL, VT) && N1->getOpcode() == ISD::AND) {
if (ConstantSDNode *Mask = isConstOrConstSplat(N1->getOperand(1))) {
if (Mask->getZExtValue() == OpSizeInBits - 1) {
return DAG.getNode(ISD::SRL, SDLoc(N), VT, N0, N1->getOperand(0));
}
}
}

// fold vector ops
if (VT.isVector())
if (SDValue FoldedVOp = SimplifyVBinOp(N))
Expand Down
64 changes: 64 additions & 0 deletions llvm/lib/Target/PowerPC/PPCISelLowering.cpp
Expand Up @@ -923,6 +923,9 @@ PPCTargetLowering::PPCTargetLowering(const PPCTargetMachine &TM,
setStackPointerRegisterToSaveRestore(isPPC64 ? PPC::X1 : PPC::R1);

// We have target-specific dag combine patterns for the following nodes:
setTargetDAGCombine(ISD::SHL);
setTargetDAGCombine(ISD::SRA);
setTargetDAGCombine(ISD::SRL);
setTargetDAGCombine(ISD::SINT_TO_FP);
setTargetDAGCombine(ISD::BUILD_VECTOR);
if (Subtarget.hasFPCVT())
Expand Down Expand Up @@ -11312,6 +11315,12 @@ SDValue PPCTargetLowering::PerformDAGCombine(SDNode *N,
SDLoc dl(N);
switch (N->getOpcode()) {
default: break;
case ISD::SHL:
return combineSHL(N, DCI);
case ISD::SRA:
return combineSRA(N, DCI);
case ISD::SRL:
return combineSRL(N, DCI);
case PPCISD::SHL:
if (isNullConstant(N->getOperand(0))) // 0 << V -> 0.
return N->getOperand(0);
Expand Down Expand Up @@ -12944,3 +12953,58 @@ bool PPCTargetLowering::isFPImmLegal(const APFloat &Imm, EVT VT) const {
return Imm.isPosZero();
}
}

// For vector shift operation op, fold
// (op x, (and y, ((1 << numbits(x)) - 1))) -> (target op x, y)
static SDValue stripModuloOnShift(const TargetLowering &TLI, SDNode *N,
SelectionDAG &DAG) {
SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);
EVT VT = N0.getValueType();
unsigned OpSizeInBits = VT.getScalarSizeInBits();
unsigned Opcode = N->getOpcode();
unsigned TargetOpcode;

switch (Opcode) {
default:
llvm_unreachable("Unexpected shift operation");
case ISD::SHL:
TargetOpcode = PPCISD::SHL;
break;
case ISD::SRL:
TargetOpcode = PPCISD::SRL;
break;
case ISD::SRA:
TargetOpcode = PPCISD::SRA;
break;
}

if (VT.isVector() && TLI.isOperationLegal(Opcode, VT) &&
N1->getOpcode() == ISD::AND)
if (ConstantSDNode *Mask = isConstOrConstSplat(N1->getOperand(1)))
if (Mask->getZExtValue() == OpSizeInBits - 1)
return DAG.getNode(TargetOpcode, SDLoc(N), VT, N0, N1->getOperand(0));

return SDValue();
}

SDValue PPCTargetLowering::combineSHL(SDNode *N, DAGCombinerInfo &DCI) const {
if (auto Value = stripModuloOnShift(*this, N, DCI.DAG))
return Value;

return SDValue();
}

SDValue PPCTargetLowering::combineSRA(SDNode *N, DAGCombinerInfo &DCI) const {
if (auto Value = stripModuloOnShift(*this, N, DCI.DAG))
return Value;

return SDValue();
}

SDValue PPCTargetLowering::combineSRL(SDNode *N, DAGCombinerInfo &DCI) const {
if (auto Value = stripModuloOnShift(*this, N, DCI.DAG))
return Value;

return SDValue();
}
21 changes: 10 additions & 11 deletions llvm/lib/Target/PowerPC/PPCISelLowering.h
Expand Up @@ -117,9 +117,13 @@ namespace llvm {
/// at function entry, used for PIC code.
GlobalBaseReg,

/// These nodes represent the 32-bit PPC shifts that operate on 6-bit
/// shift amounts. These nodes are generated by the multi-precision shift
/// code.
/// These nodes represent PPC shifts.
///
/// For scalar types, only the last `n + 1` bits of the shift amounts
/// are used, where n is log2(sizeof(element) * 8). See sld/slw, etc.
/// for exact behaviors.
///
/// For vector types, only the last n bits are used. See vsld.
SRL, SRA, SHL,

/// The combination of sra[wd]i and addze used to implemented signed
Expand Down Expand Up @@ -999,6 +1003,9 @@ namespace llvm {
SDValue DAGCombineBuildVector(SDNode *N, DAGCombinerInfo &DCI) const;
SDValue DAGCombineTruncBoolExt(SDNode *N, DAGCombinerInfo &DCI) const;
SDValue combineFPToIntToFP(SDNode *N, DAGCombinerInfo &DCI) const;
SDValue combineSHL(SDNode *N, DAGCombinerInfo &DCI) const;
SDValue combineSRA(SDNode *N, DAGCombinerInfo &DCI) const;
SDValue combineSRL(SDNode *N, DAGCombinerInfo &DCI) const;

/// ConvertSETCCToSubtract - looks at SETCC that compares ints. It replaces
/// SETCC with integer subtraction when (1) there is a legal way of doing it
Expand All @@ -1017,14 +1024,6 @@ namespace llvm {
SDValue
combineElementTruncationToVectorTruncation(SDNode *N,
DAGCombinerInfo &DCI) const;

bool supportsModuloShift(ISD::NodeType Inst,
EVT ReturnType) const override {
assert((Inst == ISD::SHL || Inst == ISD::SRA || Inst == ISD::SRL) &&
"Expect a shift instruction");
assert(isOperationLegal(Inst, ReturnType));
return ReturnType.isVector();
}
};

namespace PPC {
Expand Down
40 changes: 34 additions & 6 deletions llvm/lib/Target/PowerPC/PPCInstrAltivec.td
Expand Up @@ -987,20 +987,38 @@ def : Pat<(v8i16 (shl v8i16:$vA, v8i16:$vB)),
(v8i16 (VSLH $vA, $vB))>;
def : Pat<(v4i32 (shl v4i32:$vA, v4i32:$vB)),
(v4i32 (VSLW $vA, $vB))>;
def : Pat<(v16i8 (PPCshl v16i8:$vA, v16i8:$vB)),
(v16i8 (VSLB $vA, $vB))>;
def : Pat<(v8i16 (PPCshl v8i16:$vA, v8i16:$vB)),
(v8i16 (VSLH $vA, $vB))>;
def : Pat<(v4i32 (PPCshl v4i32:$vA, v4i32:$vB)),
(v4i32 (VSLW $vA, $vB))>;

def : Pat<(v16i8 (srl v16i8:$vA, v16i8:$vB)),
(v16i8 (VSRB $vA, $vB))>;
def : Pat<(v8i16 (srl v8i16:$vA, v8i16:$vB)),
(v8i16 (VSRH $vA, $vB))>;
def : Pat<(v4i32 (srl v4i32:$vA, v4i32:$vB)),
(v4i32 (VSRW $vA, $vB))>;
def : Pat<(v16i8 (PPCsrl v16i8:$vA, v16i8:$vB)),
(v16i8 (VSRB $vA, $vB))>;
def : Pat<(v8i16 (PPCsrl v8i16:$vA, v8i16:$vB)),
(v8i16 (VSRH $vA, $vB))>;
def : Pat<(v4i32 (PPCsrl v4i32:$vA, v4i32:$vB)),
(v4i32 (VSRW $vA, $vB))>;

def : Pat<(v16i8 (sra v16i8:$vA, v16i8:$vB)),
(v16i8 (VSRAB $vA, $vB))>;
def : Pat<(v8i16 (sra v8i16:$vA, v8i16:$vB)),
(v8i16 (VSRAH $vA, $vB))>;
def : Pat<(v4i32 (sra v4i32:$vA, v4i32:$vB)),
(v4i32 (VSRAW $vA, $vB))>;
def : Pat<(v16i8 (PPCsra v16i8:$vA, v16i8:$vB)),
(v16i8 (VSRAB $vA, $vB))>;
def : Pat<(v8i16 (PPCsra v8i16:$vA, v8i16:$vB)),
(v8i16 (VSRAH $vA, $vB))>;
def : Pat<(v4i32 (PPCsra v4i32:$vA, v4i32:$vB)),
(v4i32 (VSRAW $vA, $vB))>;

// Float to integer and integer to float conversions
def : Pat<(v4i32 (fp_to_sint v4f32:$vA)),
Expand Down Expand Up @@ -1072,14 +1090,24 @@ def:Pat<(vmrgow_swapped_shuffle v16i8:$vA, v16i8:$vB),
// Vector shifts
def VRLD : VX1_Int_Ty<196, "vrld", int_ppc_altivec_vrld, v2i64>;
def VSLD : VXForm_1<1476, (outs vrrc:$vD), (ins vrrc:$vA, vrrc:$vB),
"vsld $vD, $vA, $vB", IIC_VecGeneral,
[(set v2i64:$vD, (shl v2i64:$vA, v2i64:$vB))]>;
"vsld $vD, $vA, $vB", IIC_VecGeneral, []>;
def VSRD : VXForm_1<1732, (outs vrrc:$vD), (ins vrrc:$vA, vrrc:$vB),
"vsrd $vD, $vA, $vB", IIC_VecGeneral,
[(set v2i64:$vD, (srl v2i64:$vA, v2i64:$vB))]>;
"vsrd $vD, $vA, $vB", IIC_VecGeneral, []>;
def VSRAD : VXForm_1<964, (outs vrrc:$vD), (ins vrrc:$vA, vrrc:$vB),
"vsrad $vD, $vA, $vB", IIC_VecGeneral,
[(set v2i64:$vD, (sra v2i64:$vA, v2i64:$vB))]>;
"vsrad $vD, $vA, $vB", IIC_VecGeneral, []>;

def : Pat<(v2i64 (shl v2i64:$vA, v2i64:$vB)),
(v2i64 (VSLD $vA, $vB))>;
def : Pat<(v2i64 (PPCshl v2i64:$vA, v2i64:$vB)),
(v2i64 (VSLD $vA, $vB))>;
def : Pat<(v2i64 (srl v2i64:$vA, v2i64:$vB)),
(v2i64 (VSRD $vA, $vB))>;
def : Pat<(v2i64 (PPCsrl v2i64:$vA, v2i64:$vB)),
(v2i64 (VSRD $vA, $vB))>;
def : Pat<(v2i64 (sra v2i64:$vA, v2i64:$vB)),
(v2i64 (VSRAD $vA, $vB))>;
def : Pat<(v2i64 (PPCsra v2i64:$vA, v2i64:$vB)),
(v2i64 (VSRAD $vA, $vB))>;

// Vector Integer Arithmetic Instructions
let isCommutable = 1 in {
Expand Down

0 comments on commit 10c64e6

Please sign in to comment.