Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
189 changes: 155 additions & 34 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -698,6 +698,11 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::ABS, MVT::i64, Custom);
}

setOperationAction(ISD::ABDS, MVT::i32, Custom);
setOperationAction(ISD::ABDS, MVT::i64, Custom);
setOperationAction(ISD::ABDU, MVT::i32, Custom);
setOperationAction(ISD::ABDU, MVT::i64, Custom);

setOperationAction(ISD::SDIVREM, MVT::i32, Expand);
setOperationAction(ISD::SDIVREM, MVT::i64, Expand);
for (MVT VT : MVT::fixedlen_vector_valuetypes()) {
Expand Down Expand Up @@ -3712,7 +3717,8 @@ static SDValue emitStrictFPComparison(SDValue LHS, SDValue RHS, const SDLoc &DL,
}

static SDValue emitComparison(SDValue LHS, SDValue RHS, ISD::CondCode CC,
const SDLoc &DL, SelectionDAG &DAG) {
const SDLoc &DL, SelectionDAG &DAG,
bool MIOrPLSupported = false) {
EVT VT = LHS.getValueType();
const bool FullFP16 = DAG.getSubtarget<AArch64Subtarget>().hasFullFP16();

Expand Down Expand Up @@ -3755,6 +3761,33 @@ static SDValue emitComparison(SDValue LHS, SDValue RHS, ISD::CondCode CC,
} else if (LHS.getOpcode() == AArch64ISD::ANDS) {
// Use result of ANDS
return LHS.getValue(1);
} else if (MIOrPLSupported) {
// For MIOrPLSupported, optimize SUB/ADD operations with zero comparison
if (LHS.getOpcode() == ISD::SUB && CC == ISD::SETLT) {
// SUB(x, y) < 0 -> SUBS(x, y)
return DAG
.getNode(AArch64ISD::SUBS, DL, DAG.getVTList(VT, FlagsVT),
LHS.getOperand(0), LHS.getOperand(1))
.getValue(1);
} else if (LHS.getOpcode() == ISD::ADD && CC == ISD::SETGE) {
// ADD(x, y) >= 0 -> ADDS(x, y)
return DAG
.getNode(AArch64ISD::ADDS, DL, DAG.getVTList(VT, FlagsVT),
LHS.getOperand(0), LHS.getOperand(1))
.getValue(1);
} else if (LHS.getOpcode() == ISD::ADD && CC == ISD::SETLT) {
// ADD(x, y) < 0 -> ADDS(x, y)
return DAG
.getNode(AArch64ISD::ADDS, DL, DAG.getVTList(VT, FlagsVT),
LHS.getOperand(0), LHS.getOperand(1))
.getValue(1);
} else if (LHS.getOpcode() == ISD::SUB && CC == ISD::SETGE) {
// SUB(x, y) >= 0 -> SUBS(x, y)
return DAG
.getNode(AArch64ISD::SUBS, DL, DAG.getVTList(VT, FlagsVT),
LHS.getOperand(0), LHS.getOperand(1))
.getValue(1);
}
}
}

Expand Down Expand Up @@ -3819,7 +3852,8 @@ static SDValue emitConditionalComparison(SDValue LHS, SDValue RHS,
ISD::CondCode CC, SDValue CCOp,
AArch64CC::CondCode Predicate,
AArch64CC::CondCode OutCC,
const SDLoc &DL, SelectionDAG &DAG) {
const SDLoc &DL, SelectionDAG &DAG,
bool MIOrPLSupported = false) {
unsigned Opcode = 0;
const bool FullFP16 = DAG.getSubtarget<AArch64Subtarget>().hasFullFP16();

Expand All @@ -3846,6 +3880,30 @@ static SDValue emitConditionalComparison(SDValue LHS, SDValue RHS,
// we combine a (CCMP (sub 0, op1), op2) into a CCMN instruction ?
Opcode = AArch64ISD::CCMN;
LHS = LHS.getOperand(1);
} else if (isNullConstant(RHS) && !isUnsignedIntSetCC(CC) &&
MIOrPLSupported) {
// For MIOrPLSupported, optimize SUB/ADD operations with zero comparison
if (LHS.getOpcode() == ISD::SUB && CC == ISD::SETLT) {
// SUB(x, y) < 0 -> CCMP(x, y) with appropriate condition
Opcode = AArch64ISD::CCMP;
RHS = LHS.getOperand(1);
LHS = LHS.getOperand(0);
} else if (LHS.getOpcode() == ISD::ADD && CC == ISD::SETGE) {
// ADD(x, y) >= 0 -> CCMN(x, y) with appropriate condition
Opcode = AArch64ISD::CCMN;
RHS = LHS.getOperand(1);
LHS = LHS.getOperand(0);
} else if (LHS.getOpcode() == ISD::ADD && CC == ISD::SETLT) {
// ADD(x, y) < 0 -> CCMP(x, -y) with appropriate condition
Opcode = AArch64ISD::CCMN;
RHS = LHS.getOperand(1);
LHS = LHS.getOperand(0);
} else if (LHS.getOpcode() == ISD::SUB && CC == ISD::SETGE) {
// SUB(x, y) >= 0 -> CCMP(x, y) with appropriate condition
Opcode = AArch64ISD::CCMP;
RHS = LHS.getOperand(1);
LHS = LHS.getOperand(0);
}
}
if (Opcode == 0)
Opcode = AArch64ISD::CCMP;
Expand Down Expand Up @@ -3972,7 +4030,7 @@ static SDValue emitConjunctionRec(SelectionDAG &DAG, SDValue Val,
return emitComparison(LHS, RHS, CC, DL, DAG);
// Otherwise produce a ccmp.
return emitConditionalComparison(LHS, RHS, CC, CCOp, Predicate, OutCC, DL,
DAG);
DAG, true);
}
assert(Val->hasOneUse() && "Valid conjunction/disjunction tree");

Expand Down Expand Up @@ -4251,7 +4309,7 @@ static SDValue getAArch64Cmp(SDValue LHS, SDValue RHS, ISD::CondCode CC,
}

if (!Cmp) {
Cmp = emitComparison(LHS, RHS, CC, DL, DAG);
Cmp = emitComparison(LHS, RHS, CC, DL, DAG, true);
AArch64CC = changeIntCCToAArch64CC(CC, RHS);
}
AArch64cc = getCondCode(DAG, AArch64CC);
Expand Down Expand Up @@ -7371,13 +7429,100 @@ SDValue AArch64TargetLowering::LowerABS(SDValue Op, SelectionDAG &DAG) const {
return LowerToPredicatedOp(Op, DAG, AArch64ISD::ABS_MERGE_PASSTHRU);

SDLoc DL(Op);
SDValue Neg = DAG.getNegative(Op.getOperand(0), DL, VT);

// Generate SUBS & CSEL.
SDValue Cmp = DAG.getNode(AArch64ISD::SUBS, DL, DAG.getVTList(VT, FlagsVT),
Op.getOperand(0), DAG.getConstant(0, DL, VT));
// Generate CMP & CSEL.
SDValue Cmp = emitComparison(Op.getOperand(0), DAG.getConstant(0, DL, VT),
ISD::SETGE, DL, DAG, true);
SDValue Neg = DAG.getNegative(Op.getOperand(0), DL, VT);
return DAG.getNode(AArch64ISD::CSEL, DL, VT, Op.getOperand(0), Neg,
getCondCode(DAG, AArch64CC::PL), Cmp.getValue(1));
getCondCode(DAG, AArch64CC::PL), Cmp);
}

/// Custom-lower ISD::ABDS / ISD::ABDU (signed / unsigned absolute difference).
///
/// Vector types are forwarded to LowerToPredicatedOp with ABDS_PRED or
/// ABDU_PRED. For scalars this emits a flag-setting SUBS (or ADDS, see below)
/// and a CSEL that picks either the computed difference or its negation,
/// depending on the flags produced by that same instruction.
///
/// \param Op  The ABDS or ABDU node to lower.
/// \param DAG The DAG in which replacement nodes are created.
/// \returns The lowered value.
SDValue AArch64TargetLowering::LowerABD(SDValue Op, SelectionDAG &DAG) const {
  MVT VT = Op.getSimpleValueType();

  bool IsSigned = Op.getOpcode() == ISD::ABDS;
  if (VT.isVector())
    return LowerToPredicatedOp(
        Op, DAG, IsSigned ? AArch64ISD::ABDS_PRED : AArch64ISD::ABDU_PRED);

  SDValue LHS = Op.getOperand(0);
  SDValue RHS = Op.getOperand(1);
  SDLoc DL(Op);

  // Absolute difference is commutative, so the operands may be exchanged.
  // Unless RHS is already a legal compare immediate, move the operand that
  // getCmpOperandFoldingProfit rates higher into the RHS position. Operands
  // that isCMN accepts are rated by their second operand instead.
  if (!isa<ConstantSDNode>(RHS) || !isLegalCmpImmed(RHS->getAsAPIntVal())) {
    SDValue TheLHS = isCMN(LHS, IsSigned ? ISD::SETGE : ISD::SETUGE, DAG)
                         ? LHS.getOperand(1)
                         : LHS;
    SDValue TheRHS = isCMN(RHS, IsSigned ? ISD::SETGE : ISD::SETUGE, DAG)
                         ? RHS.getOperand(1)
                         : RHS;
    if (getCmpOperandFoldingProfit(TheLHS) >
        getCmpOperandFoldingProfit(TheRHS))
      std::swap(LHS, RHS);
  }

  // If the subtract doesn't overflow then just use abs(sub()).
  // When both operands have a clear sign bit, the unsigned case is checked
  // with the signed overflow query as well.
  bool IsNonNegative = DAG.SignBitIsZero(LHS) && DAG.SignBitIsZero(RHS);

  if (DAG.willNotOverflowSub(IsSigned || IsNonNegative, LHS, RHS))
    return DAG.getNode(ISD::ABS, DL, VT,
                       DAG.getNode(ISD::SUB, DL, VT, LHS, RHS));

  if (DAG.willNotOverflowSub(IsSigned || IsNonNegative, RHS, LHS))
    return DAG.getNode(ISD::ABS, DL, VT,
                       DAG.getNode(ISD::SUB, DL, VT, RHS, LHS));

  unsigned Opcode = AArch64ISD::SUBS;
  // Check if RHS is a subtraction against 0: (0 - X). If so, LHS - (0 - X)
  // can be computed as LHS + X with ADDS, provided the flags tested below
  // still come out the same.
  if (RHS.getOpcode() == ISD::SUB) {
    SDValue SubLHS = RHS.getOperand(0);
    SDValue SubRHS = RHS.getOperand(1);

    if (isNullConstant(SubLHS)) {
      // Signed (GE test): only if the negation cannot wrap, i.e. it carries
      // the nsw flag or X is known never to be the minimum signed value.
      // Unsigned (HS test): only if X is known to never be zero.
      bool CanUseAdd = IsSigned
                           ? (RHS->getFlags().hasNoSignedWrap() ||
                              !DAG.computeKnownBits(SubRHS)
                                   .getSignedMinValue()
                                   .isMinSignedValue())
                           : DAG.isKnownNeverZero(SubRHS);

      if (CanUseAdd) {
        Opcode = AArch64ISD::ADDS;
        RHS = SubRHS; // Replace RHS with X, so we do LHS + X instead of
                      // LHS - (0 - X)
      }
    }
  }

  // Compute a - b (value 0) together with the flags (value 1).
  SDValue Cmp = DAG.getNode(Opcode, DL, DAG.getVTList(VT, FlagsVT), LHS, RHS);

  // Compute b - a as the negation of a - b.
  SDValue Neg = DAG.getNegative(Cmp.getValue(0), DL, VT);

  // For unsigned: use HS (a >= b) to select a-b, otherwise b-a
  // For signed: use GE (a >= b) to select a-b, otherwise b-a
  AArch64CC::CondCode CC = IsSigned ? AArch64CC::GE : AArch64CC::HS;

  // CSEL: if a >= b, select a-b, otherwise b-a
  return DAG.getNode(AArch64ISD::CSEL, DL, VT, Cmp.getValue(0), Neg,
                     getCondCode(DAG, CC), Cmp.getValue(1));
}

static SDValue LowerBRCOND(SDValue Op, SelectionDAG &DAG) {
Expand Down Expand Up @@ -7832,9 +7977,8 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
case ISD::ABS:
return LowerABS(Op, DAG);
case ISD::ABDS:
return LowerToPredicatedOp(Op, DAG, AArch64ISD::ABDS_PRED);
case ISD::ABDU:
return LowerToPredicatedOp(Op, DAG, AArch64ISD::ABDU_PRED);
return LowerABD(Op, DAG);
case ISD::AVGFLOORS:
return LowerAVG(Op, DAG, AArch64ISD::HADDS_PRED);
case ISD::AVGFLOORU:
Expand Down Expand Up @@ -25815,29 +25959,6 @@ static SDValue performCSELCombine(SDNode *N,
}
}

// CSEL a, b, cc, SUBS(SUB(x,y), 0) -> CSEL a, b, cc, SUBS(x,y) if cc doesn't
// use overflow flags, to avoid the comparison with zero. In case of success,
// this also replaces the original SUB(x,y) with the newly created SUBS(x,y).
// NOTE: Perhaps in the future use performFlagSettingCombine to replace SUB
// nodes with their SUBS equivalent as is already done for other flag-setting
// operators, in which case doing the replacement here becomes redundant.
if (Cond.getOpcode() == AArch64ISD::SUBS && Cond->hasNUsesOfValue(1, 1) &&
isNullConstant(Cond.getOperand(1))) {
SDValue Sub = Cond.getOperand(0);
AArch64CC::CondCode CC =
static_cast<AArch64CC::CondCode>(N->getConstantOperandVal(2));
if (Sub.getOpcode() == ISD::SUB &&
(CC == AArch64CC::EQ || CC == AArch64CC::NE || CC == AArch64CC::MI ||
CC == AArch64CC::PL)) {
SDLoc DL(N);
SDValue Subs = DAG.getNode(AArch64ISD::SUBS, DL, Cond->getVTList(),
Sub.getOperand(0), Sub.getOperand(1));
DCI.CombineTo(Sub.getNode(), Subs);
DCI.CombineTo(Cond.getNode(), Subs, Subs.getValue(1));
return SDValue(N, 0);
}
}

// CSEL (LASTB P, Z), X, NE(ANY P) -> CLASTB P, X, Z
if (SDValue CondLast = foldCSELofLASTB(N, DAG))
return CondLast;
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -604,6 +604,7 @@ class AArch64TargetLowering : public TargetLowering {
SDValue LowerSTORE(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerStore128(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerABS(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerABD(SDValue Op, SelectionDAG &DAG) const;

SDValue LowerMGATHER(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerMSCATTER(SDValue Op, SelectionDAG &DAG) const;
Expand Down
16 changes: 16 additions & 0 deletions llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1730,12 +1730,20 @@ static unsigned sForm(MachineInstr &Instr) {

case AArch64::ADDSWrr:
case AArch64::ADDSWri:
case AArch64::ADDSWrx:
case AArch64::ADDSXrr:
case AArch64::ADDSXri:
case AArch64::ADDSXrx:
case AArch64::SUBSWrr:
case AArch64::SUBSWri:
case AArch64::SUBSWrx:
case AArch64::SUBSXrr:
case AArch64::SUBSXri:
case AArch64::SUBSXrx:
case AArch64::ADCSWr:
case AArch64::ADCSXr:
case AArch64::SBCSWr:
case AArch64::SBCSXr:
return Instr.getOpcode();

case AArch64::ADDWrr:
Expand All @@ -1746,6 +1754,10 @@ static unsigned sForm(MachineInstr &Instr) {
return AArch64::ADDSXrr;
case AArch64::ADDXri:
return AArch64::ADDSXri;
case AArch64::ADDWrx:
return AArch64::ADDSWrx;
case AArch64::ADDXrx:
return AArch64::ADDSXrx;
case AArch64::ADCWr:
return AArch64::ADCSWr;
case AArch64::ADCXr:
Expand All @@ -1758,6 +1770,10 @@ static unsigned sForm(MachineInstr &Instr) {
return AArch64::SUBSXrr;
case AArch64::SUBXri:
return AArch64::SUBSXri;
case AArch64::SUBWrx:
return AArch64::SUBSWrx;
case AArch64::SUBXrx:
return AArch64::SUBSXrx;
case AArch64::SBCWr:
return AArch64::SBCSWr;
case AArch64::SBCXr:
Expand Down
Loading
Loading