Skip to content

Commit

Permalink
[ARM] Lower (select_cc k k (select_cc ~k ~k x)) into (SSAT l_k x)
Browse files Browse the repository at this point in the history
Summary:
SSAT saturates an integer, making sure that its value lies within
an interval [~k, k], i.e. [-(k + 1), k]. Since the constant is given
to SSAT as the number of bits set to one, k + 1 must be a power of 2,
otherwise the optimization is not possible. Also, the two select_cc
nodes must use < and > respectively so that they define an interval.

Reviewers: mcrosier, jmolloy, rengolin

Subscribers: aemerson, rengolin, llvm-commits

Differential Revision: http://reviews.llvm.org/D21372

llvm-svn: 273581
  • Loading branch information
pbarrio committed Jun 23, 2016
1 parent 80771b9 commit 7a64346
Show file tree
Hide file tree
Showing 6 changed files with 359 additions and 1 deletion.
3 changes: 3 additions & 0 deletions llvm/include/llvm/Target/TargetSelectionDAG.td
Expand Up @@ -116,6 +116,9 @@ def SDTIntBinOp : SDTypeProfile<1, 2, [ // add, and, or, xor, udiv, etc.
def SDTIntShiftOp : SDTypeProfile<1, 2, [ // shl, sra, srl
SDTCisSameAs<0, 1>, SDTCisInt<0>, SDTCisInt<2>
]>;
// Type profile for a saturation node that takes no shift amount: one result
// and two operands. Value 0 (the result) has the same type as value 1 (the
// first operand), and value 2 (the second operand) must be an integer.
def SDTIntSatNoShOp : SDTypeProfile<1, 2, [ // ssat with no shift
SDTCisSameAs<0, 1>, SDTCisInt<2>
]>;
def SDTIntBinHiLoOp : SDTypeProfile<2, 2, [ // mulhi, mullo, sdivrem, udivrem
SDTCisSameAs<0, 1>, SDTCisSameAs<0, 2>, SDTCisSameAs<0, 3>,SDTCisInt<0>
]>;
Expand Down
134 changes: 133 additions & 1 deletion llvm/lib/Target/ARM/ARMISelLowering.cpp
Expand Up @@ -1136,6 +1136,8 @@ const char *ARMTargetLowering::getTargetNodeName(unsigned Opcode) const {

case ARMISD::CMOV: return "ARMISD::CMOV";

case ARMISD::SSAT: return "ARMISD::SSAT";

case ARMISD::SRL_FLAG: return "ARMISD::SRL_FLAG";
case ARMISD::SRA_FLAG: return "ARMISD::SRA_FLAG";
case ARMISD::RRX: return "ARMISD::RRX";
Expand Down Expand Up @@ -3728,14 +3730,144 @@ SDValue ARMTargetLowering::getCMOV(const SDLoc &dl, EVT VT, SDValue FalseVal,
}
}

// Returns true if CC is a "greater than" or "greater than or equal" condition
// code (ISD::SETGT or ISD::SETGE).
//
// File-local helper: declared static so that this generically named function
// does not get external linkage.
static bool isGTorGE(ISD::CondCode CC) {
  return CC == ISD::SETGT || CC == ISD::SETGE;
}

// Returns true if CC is a "less than" or "less than or equal" condition code
// (ISD::SETLT or ISD::SETLE).
//
// File-local helper: declared static so that this generically named function
// does not get external linkage.
static bool isLTorLE(ISD::CondCode CC) {
  return CC == ISD::SETLT || CC == ISD::SETLE;
}

// See if a conditional (LHS CC RHS ? TrueVal : FalseVal) is lower-saturating,
// i.e. whether it computes max(x, K) for the bound K.
//
// All of these conditions (and their <= and >= counterparts) will do:
//
//     x < k ? k : x
//     x > k ? x : k
//     k < x ? x : k
//     k > x ? k : x
//
// K must be the very same SDValue that appears both as one side of the
// comparison and as one of the selected values; any other condition code
// yields false.
static bool isLowerSaturate(const SDValue LHS, const SDValue RHS,
                            const SDValue TrueVal, const SDValue FalseVal,
                            const ISD::CondCode CC, const SDValue K) {
  // "k > x ? k : x" and "x > k ? x : k".
  if (isGTorGE(CC))
    return (K == LHS && K == TrueVal) || (K == RHS && K == FalseVal);

  // "x < k ? k : x" and "k < x ? x : k".
  if (isLTorLE(CC))
    return (K == RHS && K == TrueVal) || (K == LHS && K == FalseVal);

  return false;
}

// See if a conditional (LHS CC RHS ? TrueVal : FalseVal) is upper-saturating,
// i.e. whether it computes min(x, K) for the bound K.
//
// All of these conditions (and their <= and >= counterparts) will do:
//
//     x > k ? k : x
//     x < k ? x : k
//     k > x ? x : k
//     k < x ? k : x
//
// K must be the very same SDValue that appears both as one side of the
// comparison and as one of the selected values; any other condition code
// yields false.
static bool isUpperSaturate(const SDValue LHS, const SDValue RHS,
                            const SDValue TrueVal, const SDValue FalseVal,
                            const ISD::CondCode CC, const SDValue K) {
  // "x > k ? k : x" and "k > x ? x : k".
  if (isGTorGE(CC))
    return (K == RHS && K == TrueVal) || (K == LHS && K == FalseVal);

  // "k < x ? k : x" and "x < k ? x : k".
  if (isLTorLE(CC))
    return (K == LHS && K == TrueVal) || (K == RHS && K == FalseVal);

  return false;
}

// Check if two chained conditionals could be converted into SSAT.
//
// SSAT can replace a set of two conditional selectors that bound a number to an
// interval of type [k, ~k] when k + 1 is a power of 2. Here are some examples:
//
// x < -k ? -k : (x > k ? k : x)
// x < -k ? -k : (x < k ? x : k)
// x > -k ? (x > k ? k : x) : -k
// x < k ? (x < -k ? -k : x) : k
// etc.
//
// It returns true if the conversion can be done, false otherwise.
// Additionally, the variable is returned in parameter V and the constant in K.
bool isSaturatingConditional(const SDValue &Op, SDValue &V, uint64_t &K) {

SDValue LHS1 = Op.getOperand(0);
SDValue RHS1 = Op.getOperand(1);
SDValue TrueVal1 = Op.getOperand(2);
SDValue FalseVal1 = Op.getOperand(3);
ISD::CondCode CC1 = cast<CondCodeSDNode>(Op.getOperand(4))->get();

const SDValue Op2 = isa<ConstantSDNode>(TrueVal1) ? FalseVal1 : TrueVal1;
if (Op2.getOpcode() != ISD::SELECT_CC)
return false;

SDValue LHS2 = Op2.getOperand(0);
SDValue RHS2 = Op2.getOperand(1);
SDValue TrueVal2 = Op2.getOperand(2);
SDValue FalseVal2 = Op2.getOperand(3);
ISD::CondCode CC2 = cast<CondCodeSDNode>(Op2.getOperand(4))->get();

// Find out which are the constants and which are the variables
// in each conditional
SDValue *K1 = isa<ConstantSDNode>(LHS1) ? &LHS1 : isa<ConstantSDNode>(RHS1)
? &RHS1
: NULL;
SDValue *K2 = isa<ConstantSDNode>(LHS2) ? &LHS2 : isa<ConstantSDNode>(RHS2)
? &RHS2
: NULL;
SDValue K2Tmp = isa<ConstantSDNode>(TrueVal2) ? TrueVal2 : FalseVal2;
SDValue V1Tmp = (K1 && *K1 == LHS1) ? RHS1 : LHS1;
SDValue V2Tmp = (K2 && *K2 == LHS2) ? RHS2 : LHS2;
SDValue V2 = (K2Tmp == TrueVal2) ? FalseVal2 : TrueVal2;

// We must detect cases where the original operations worked with 16- or
// 8-bit values. In such case, V2Tmp != V2 because the comparison operations
// must work with sign-extended values but the select operations return
// the original non-extended value.
SDValue V2TmpReg = V2Tmp;
if (V2Tmp->getOpcode() == ISD::SIGN_EXTEND_INREG)
V2TmpReg = V2Tmp->getOperand(0);

// Check that the registers and the constants have the correct values
// in both conditionals
if (!K1 || !K2 || *K1 == Op2 || *K2 != K2Tmp || V1Tmp != V2Tmp ||
V2TmpReg != V2)
return false;

// Figure out which conditional is saturating the lower/upper bound.
const SDValue *LowerCheckOp =
isLowerSaturate(LHS1, RHS1, TrueVal1, FalseVal1, CC1, *K1)
? &Op
: isLowerSaturate(LHS2, RHS2, TrueVal2, FalseVal2, CC2, *K2) ? &Op2
: NULL;
const SDValue *UpperCheckOp =
isUpperSaturate(LHS1, RHS1, TrueVal1, FalseVal1, CC1, *K1)
? &Op
: isUpperSaturate(LHS2, RHS2, TrueVal2, FalseVal2, CC2, *K2) ? &Op2
: NULL;

if (!UpperCheckOp || !LowerCheckOp || LowerCheckOp == UpperCheckOp)
return false;

// Check that the constant in the lower-bound check is
// the opposite of the constant in the upper-bound check
// in 1's complement.
int64_t Val1 = cast<ConstantSDNode>(*K1)->getSExtValue();
int64_t Val2 = cast<ConstantSDNode>(*K2)->getSExtValue();
int64_t PosVal = std::max(Val1, Val2);

if (((Val1 > Val2 && UpperCheckOp == &Op) ||
(Val1 < Val2 && UpperCheckOp == &Op2)) &&
Val1 == ~Val2 && isPowerOf2_64(PosVal + 1)) {

V = V2;
K = (uint64_t)PosVal; // At this point, PosVal is guaranteed to be positive
return true;
}

return false;
}

SDValue ARMTargetLowering::LowerSELECT_CC(SDValue Op, SelectionDAG &DAG) const {

EVT VT = Op.getValueType();
SDLoc dl(Op);

// Try to convert two saturating conditional selects into a single SSAT
SDValue SatValue;
uint64_t SatConstant;
if (isSaturatingConditional(Op, SatValue, SatConstant))
return DAG.getNode(ARMISD::SSAT, dl, VT, SatValue,
DAG.getConstant(countTrailingOnes(SatConstant), dl, VT));

SDValue LHS = Op.getOperand(0);
SDValue RHS = Op.getOperand(1);
ISD::CondCode CC = cast<CondCodeSDNode>(Op.getOperand(4))->get();
SDValue TrueVal = Op.getOperand(2);
SDValue FalseVal = Op.getOperand(3);
SDLoc dl(Op);

if (Subtarget->isFPOnlySP() && LHS.getValueType() == MVT::f64) {
DAG.getTargetLoweringInfo().softenSetCCOperands(DAG, MVT::f64, LHS, RHS, CC,
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/Target/ARM/ARMISelLowering.h
Expand Up @@ -60,6 +60,8 @@ namespace llvm {

CMOV, // ARM conditional move instructions.

SSAT, // Signed saturation

BCC_i64,

SRL_FLAG, // V,Flag = srl_flag X -> srl X, 1 + save carry out.
Expand Down
4 changes: 4 additions & 0 deletions llvm/lib/Target/ARM/ARMInstrInfo.td
Expand Up @@ -129,6 +129,8 @@ def ARMintretflag : SDNode<"ARMISD::INTRET_FLAG", SDT_ARMcall,
def ARMcmov : SDNode<"ARMISD::CMOV", SDT_ARMCMov,
[SDNPInGlue]>;

// SelectionDAG node for ARMISD::SSAT (signed saturation), typed by the
// SDTIntSatNoShOp profile and declared with an empty node-property list.
def ARMssatnoshift : SDNode<"ARMISD::SSAT", SDTIntSatNoShOp, []>;

def ARMbrcond : SDNode<"ARMISD::BRCOND", SDT_ARMBrcond,
[SDNPHasChain, SDNPInGlue, SDNPOutGlue]>;

Expand Down Expand Up @@ -3713,6 +3715,8 @@ def : ARMV6Pat<(int_arm_ssat GPRnopc:$a, imm1_32:$pos),
(SSAT imm1_32:$pos, GPRnopc:$a, 0)>;
def : ARMV6Pat<(int_arm_usat GPRnopc:$a, imm0_31:$pos),
(USAT imm0_31:$pos, GPRnopc:$a, 0)>;
// Select an ARMssatnoshift node to the SSAT instruction. The node's immediate
// is forwarded as SSAT's first operand and the source register as its second.
// NOTE(review): the trailing 0 is presumably the shift operand ("no shift");
// confirm against the SSAT instruction definition.
def : ARMPat<(ARMssatnoshift GPRnopc:$Rn, imm0_31:$imm),
(SSAT imm0_31:$imm, GPRnopc:$Rn, 0)>;

//===----------------------------------------------------------------------===//
// Bitwise Instructions.
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/Target/ARM/ARMInstrThumb2.td
Expand Up @@ -2287,6 +2287,8 @@ def t2USAT16: T2SatI<(outs rGPR:$Rd), (ins imm0_15:$sat_imm, rGPR:$Rn),

def : T2Pat<(int_arm_ssat GPR:$a, imm1_32:$pos), (t2SSAT imm1_32:$pos, GPR:$a, 0)>;
def : T2Pat<(int_arm_usat GPR:$a, imm0_31:$pos), (t2USAT imm0_31:$pos, GPR:$a, 0)>;
// Select an ARMssatnoshift node to the t2SSAT instruction. The node's
// immediate is forwarded as t2SSAT's first operand and the source register as
// its second.
// NOTE(review): the trailing 0 is presumably the shift operand ("no shift");
// confirm against the t2SSAT instruction definition.
def : T2Pat<(ARMssatnoshift GPRnopc:$Rn, imm0_31:$imm),
(t2SSAT imm0_31:$imm, GPRnopc:$Rn, 0)>;

//===----------------------------------------------------------------------===//
// Shift and rotate Instructions.
Expand Down

0 comments on commit 7a64346

Please sign in to comment.