[ARM] Add codegen for SMMULR, SMMLAR and SMMLSR
This patch teaches the Arm back-end to generate the SMMULR, SMMLAR and SMMLSR
instructions from equivalent IR patterns.

Differential Revision: https://reviews.llvm.org/D41775

llvm-svn: 322361
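
For context (not part of the commit), a rough C model of what the three instructions compute, based on my reading of the Arm architecture pseudocode: SMMULR returns the rounded high half of a signed 32x32-bit multiply, while SMMLAR and SMMLSR add that product to, or subtract it from, an accumulator before rounding. The names below are made up for illustration.

#include <stdint.h>

/* Illustrative reference model only. Assumes two's-complement arithmetic and
 * an arithmetic right shift of negative values, as on typical targets. */
static int32_t smmulr_model(int32_t rn, int32_t rm) {
  return (int32_t)(((int64_t)rn * rm + 0x80000000LL) >> 32);
}

static int32_t smmlar_model(int32_t rn, int32_t rm, int32_t ra) {
  return (int32_t)((((int64_t)ra << 32) + (int64_t)rn * rm + 0x80000000LL) >> 32);
}

static int32_t smmlsr_model(int32_t rn, int32_t rm, int32_t ra) {
  return (int32_t)((((int64_t)ra << 32) - (int64_t)rn * rm + 0x80000000LL) >> 32);
}

Adding 0x80000000 (2^31) before taking bits [63:32] rounds the result to the nearest high word instead of truncating; that constant is exactly what the new DAG combine below checks for.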
avieira-arm committed Jan 12, 2018
1 parent 26b9de9 commit 5627c21
Showing 5 changed files with 292 additions and 61 deletions.
153 changes: 98 additions & 55 deletions llvm/lib/Target/ARM/ARMISelLowering.cpp
@@ -1337,6 +1337,8 @@ const char *ARMTargetLowering::getTargetNodeName(unsigned Opcode) const {
case ARMISD::SMLALDX: return "ARMISD::SMLALDX";
case ARMISD::SMLSLD: return "ARMISD::SMLSLD";
case ARMISD::SMLSLDX: return "ARMISD::SMLSLDX";
case ARMISD::SMMLAR: return "ARMISD::SMMLAR";
case ARMISD::SMMLSR: return "ARMISD::SMMLSR";
case ARMISD::BUILD_VECTOR: return "ARMISD::BUILD_VECTOR";
case ARMISD::BFI: return "ARMISD::BFI";
case ARMISD::VORRIMM: return "ARMISD::VORRIMM";
@@ -9860,7 +9862,7 @@ static SDValue AddCombineTo64BitSMLAL16(SDNode *AddcNode, SDNode *AddeNode,
return resNode;
}

static SDValue AddCombineTo64bitMLAL(SDNode *AddeNode,
static SDValue AddCombineTo64bitMLAL(SDNode *AddeSubeNode,
TargetLowering::DAGCombinerInfo &DCI,
const ARMSubtarget *Subtarget) {
// Look for multiply add opportunities.
@@ -9877,49 +9879,61 @@ static SDValue AddCombineTo64bitMLAL(SDNode *AddeNode,
// V V
// ADDE <- hiAdd
//
assert(AddeNode->getOpcode() == ARMISD::ADDE && "Expect an ADDE");

assert(AddeNode->getNumOperands() == 3 &&
AddeNode->getOperand(2).getValueType() == MVT::i32 &&
// In the special case where only the higher part of a signed result is used
// and the add to the low part of the result of ISD::UMUL_LOHI adds or subtracts
// a constant with the exact value of 0x80000000, we recognize we are dealing
// with a "rounded multiply and add" (or subtract) and transform it into
// either a ARMISD::SMMLAR or ARMISD::SMMLSR respectively.

assert((AddeSubeNode->getOpcode() == ARMISD::ADDE ||
AddeSubeNode->getOpcode() == ARMISD::SUBE) &&
"Expect an ADDE or SUBE");

assert(AddeSubeNode->getNumOperands() == 3 &&
AddeSubeNode->getOperand(2).getValueType() == MVT::i32 &&
"ADDE node has the wrong inputs");

// Check that we are chained to the right ADDC node.
SDNode* AddcNode = AddeNode->getOperand(2).getNode();
if (AddcNode->getOpcode() != ARMISD::ADDC)
// Check that we are chained to the right ADDC or SUBC node.
SDNode *AddcSubcNode = AddeSubeNode->getOperand(2).getNode();
if ((AddeSubeNode->getOpcode() == ARMISD::ADDE &&
AddcSubcNode->getOpcode() != ARMISD::ADDC) ||
(AddeSubeNode->getOpcode() == ARMISD::SUBE &&
AddcSubcNode->getOpcode() != ARMISD::SUBC))
return SDValue();

SDValue AddcOp0 = AddcNode->getOperand(0);
SDValue AddcOp1 = AddcNode->getOperand(1);
SDValue AddcSubcOp0 = AddcSubcNode->getOperand(0);
SDValue AddcSubcOp1 = AddcSubcNode->getOperand(1);

// Check if the two operands are from the same mul_lohi node.
if (AddcOp0.getNode() == AddcOp1.getNode())
if (AddcSubcOp0.getNode() == AddcSubcOp1.getNode())
return SDValue();

assert(AddcNode->getNumValues() == 2 &&
AddcNode->getValueType(0) == MVT::i32 &&
assert(AddcSubcNode->getNumValues() == 2 &&
AddcSubcNode->getValueType(0) == MVT::i32 &&
"Expect ADDC with two result values. First: i32");

// Check that the ADDC adds the low result of the S/UMUL_LOHI. If not, it
// maybe a SMLAL which multiplies two 16-bit values.
if (AddcOp0->getOpcode() != ISD::UMUL_LOHI &&
AddcOp0->getOpcode() != ISD::SMUL_LOHI &&
AddcOp1->getOpcode() != ISD::UMUL_LOHI &&
AddcOp1->getOpcode() != ISD::SMUL_LOHI)
return AddCombineTo64BitSMLAL16(AddcNode, AddeNode, DCI, Subtarget);
if (AddeSubeNode->getOpcode() == ARMISD::ADDE &&
AddcSubcOp0->getOpcode() != ISD::UMUL_LOHI &&
AddcSubcOp0->getOpcode() != ISD::SMUL_LOHI &&
AddcSubcOp1->getOpcode() != ISD::UMUL_LOHI &&
AddcSubcOp1->getOpcode() != ISD::SMUL_LOHI)
return AddCombineTo64BitSMLAL16(AddcSubcNode, AddeSubeNode, DCI, Subtarget);

// Check for the triangle shape.
SDValue AddeOp0 = AddeNode->getOperand(0);
SDValue AddeOp1 = AddeNode->getOperand(1);
SDValue AddeSubeOp0 = AddeSubeNode->getOperand(0);
SDValue AddeSubeOp1 = AddeSubeNode->getOperand(1);

// Make sure that the ADDE operands are not coming from the same node.
if (AddeOp0.getNode() == AddeOp1.getNode())
// Make sure that the ADDE/SUBE operands are not coming from the same node.
if (AddeSubeOp0.getNode() == AddeSubeOp1.getNode())
return SDValue();

// Find the MUL_LOHI node walking up ADDE's operands.
// Find the MUL_LOHI node walking up ADDE/SUBE's operands.
bool IsLeftOperandMUL = false;
SDValue MULOp = findMUL_LOHI(AddeOp0);
SDValue MULOp = findMUL_LOHI(AddeSubeOp0);
if (MULOp == SDValue())
MULOp = findMUL_LOHI(AddeOp1);
MULOp = findMUL_LOHI(AddeSubeOp1);
else
IsLeftOperandMUL = true;
if (MULOp == SDValue())
@@ -9930,63 +9944,88 @@ static SDValue AddCombineTo64bitMLAL(SDNode *AddeNode,
unsigned FinalOpc = (Opc == ISD::SMUL_LOHI) ? ARMISD::SMLAL : ARMISD::UMLAL;

// Figure out the high and low input values to the MLAL node.
SDValue* HiAdd = nullptr;
SDValue* LoMul = nullptr;
SDValue* LowAdd = nullptr;
SDValue *HiAddSub = nullptr;
SDValue *LoMul = nullptr;
SDValue *LowAddSub = nullptr;

// Ensure that ADDE is from high result of ISD::xMUL_LOHI.
if ((AddeOp0 != MULOp.getValue(1)) && (AddeOp1 != MULOp.getValue(1)))
// Ensure that ADDE/SUBE is from high result of ISD::xMUL_LOHI.
if ((AddeSubeOp0 != MULOp.getValue(1)) && (AddeSubeOp1 != MULOp.getValue(1)))
return SDValue();

if (IsLeftOperandMUL)
HiAdd = &AddeOp1;
HiAddSub = &AddeSubeOp1;
else
HiAdd = &AddeOp0;

HiAddSub = &AddeSubeOp0;

// Ensure that LoMul and LowAdd are taken from correct ISD::SMUL_LOHI node
// whose low result is fed to the ADDC we are checking.
// Ensure that LoMul and LowAddSub are taken from correct ISD::SMUL_LOHI node
// whose low result is fed to the ADDC/SUBC we are checking.

if (AddcOp0 == MULOp.getValue(0)) {
LoMul = &AddcOp0;
LowAdd = &AddcOp1;
if (AddcSubcOp0 == MULOp.getValue(0)) {
LoMul = &AddcSubcOp0;
LowAddSub = &AddcSubcOp1;
}
if (AddcOp1 == MULOp.getValue(0)) {
LoMul = &AddcOp1;
LowAdd = &AddcOp0;
if (AddcSubcOp1 == MULOp.getValue(0)) {
LoMul = &AddcSubcOp1;
LowAddSub = &AddcSubcOp0;
}

if (!LoMul)
return SDValue();

// If HiAdd is the same node as ADDC or is a predecessor of ADDC the
// replacement below will create a cycle.
if (AddcNode == HiAdd->getNode() ||
AddcNode->isPredecessorOf(HiAdd->getNode()))
// If HiAddSub is the same node as ADDC/SUBC or is a predecessor of ADDC/SUBC
// the replacement below will create a cycle.
if (AddcSubcNode == HiAddSub->getNode() ||
AddcSubcNode->isPredecessorOf(HiAddSub->getNode()))
return SDValue();

// Create the merged node.
SelectionDAG &DAG = DCI.DAG;

// Build operand list.
// Start building operand list.
SmallVector<SDValue, 8> Ops;
Ops.push_back(LoMul->getOperand(0));
Ops.push_back(LoMul->getOperand(1));
Ops.push_back(*LowAdd);
Ops.push_back(*HiAdd);

SDValue MLALNode = DAG.getNode(FinalOpc, SDLoc(AddcNode),
// Check whether we can use SMMLAR, SMMLSR or SMMULR instead. For this to be
// the case, we must be doing signed multiplication and only use the higher
// part of the result of the MLAL, furthermore the LowAddSub must be a constant
// addition or subtraction with the value of 0x800000.
if (Subtarget->hasV6Ops() && Subtarget->hasDSP() && Subtarget->useMulOps() &&
FinalOpc == ARMISD::SMLAL && !AddeSubeNode->hasAnyUseOfValue(1) &&
LowAddSub->getNode()->getOpcode() == ISD::Constant &&
static_cast<ConstantSDNode *>(LowAddSub->getNode())->getZExtValue() ==
0x80000000) {
Ops.push_back(*HiAddSub);
if (AddcSubcNode->getOpcode() == ARMISD::SUBC) {
FinalOpc = ARMISD::SMMLSR;
} else {
FinalOpc = ARMISD::SMMLAR;
}
SDValue NewNode = DAG.getNode(FinalOpc, SDLoc(AddcSubcNode), MVT::i32, Ops);
DAG.ReplaceAllUsesOfValueWith(SDValue(AddeSubeNode, 0), NewNode);

return SDValue(AddeSubeNode, 0);
} else if (AddcSubcNode->getOpcode() == ARMISD::SUBC)
// SMMLS is generated during instruction selection and the rest of this
// function can not handle the case where AddcSubcNode is a SUBC.
return SDValue();

// Finish building the operand list for {U/S}MLAL
Ops.push_back(*LowAddSub);
Ops.push_back(*HiAddSub);

SDValue MLALNode = DAG.getNode(FinalOpc, SDLoc(AddcSubcNode),
DAG.getVTList(MVT::i32, MVT::i32), Ops);

// Replace the ADDs' nodes uses by the MLA node's values.
SDValue HiMLALResult(MLALNode.getNode(), 1);
DAG.ReplaceAllUsesOfValueWith(SDValue(AddeNode, 0), HiMLALResult);
DAG.ReplaceAllUsesOfValueWith(SDValue(AddeSubeNode, 0), HiMLALResult);

SDValue LoMLALResult(MLALNode.getNode(), 0);
DAG.ReplaceAllUsesOfValueWith(SDValue(AddcNode, 0), LoMLALResult);
DAG.ReplaceAllUsesOfValueWith(SDValue(AddcSubcNode, 0), LoMLALResult);

// Return original node to notify the driver to stop replacing.
return SDValue(AddeNode, 0);
return SDValue(AddeSubeNode, 0);
}

static SDValue AddCombineTo64bitUMAAL(SDNode *AddeNode,
@@ -10098,9 +10137,11 @@ static SDValue PerformAddcSubcCombine(SDNode *N,
return SDValue();
}

static SDValue PerformAddeSubeCombine(SDNode *N, SelectionDAG &DAG,
static SDValue PerformAddeSubeCombine(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI,
const ARMSubtarget *Subtarget) {
if (Subtarget->isThumb1Only()) {
SelectionDAG &DAG = DCI.DAG;
SDValue RHS = N->getOperand(1);
if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(RHS)) {
int64_t imm = C->getSExtValue();
Expand All @@ -10118,6 +10159,8 @@ static SDValue PerformAddeSubeCombine(SDNode *N, SelectionDAG &DAG,
N->getOperand(0), RHS, N->getOperand(2));
}
}
} else if (N->getOperand(1)->getOpcode() == ISD::SMUL_LOHI) {
return AddCombineTo64bitMLAL(N, DCI, Subtarget);
}
return SDValue();
}
@@ -10130,7 +10173,7 @@ static SDValue PerformADDECombine(SDNode *N,
const ARMSubtarget *Subtarget) {
// Only ARM and Thumb2 support UMLAL/SMLAL.
if (Subtarget->isThumb1Only())
return PerformAddeSubeCombine(N, DCI.DAG, Subtarget);
return PerformAddeSubeCombine(N, DCI, Subtarget);

// Only perform the checks after legalize when the pattern is available.
if (DCI.isBeforeLegalize()) return SDValue();
@@ -12338,7 +12381,7 @@ SDValue ARMTargetLowering::PerformDAGCombine(SDNode *N,
case ISD::AND: return PerformANDCombine(N, DCI, Subtarget);
case ARMISD::ADDC:
case ARMISD::SUBC: return PerformAddcSubcCombine(N, DCI, Subtarget);
case ARMISD::SUBE: return PerformAddeSubeCombine(N, DCI.DAG, Subtarget);
case ARMISD::SUBE: return PerformAddeSubeCombine(N, DCI, Subtarget);
case ARMISD::BFI: return PerformBFICombine(N, DCI);
case ARMISD::VMOVRRD: return PerformVMOVRRDCombine(N, DCI, Subtarget);
case ARMISD::VMOVDRR: return PerformVMOVDRRCombine(N, DCI.DAG);
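
A note on the combine above, not part of the commit itself: when a 64-bit expression like the sketch below is legalised for a 32-bit Arm target, the widened multiply typically becomes an ISD::SMUL_LOHI and the 64-bit additions become ADDC/ADDE (or SUBC/SUBE) pairs, which is the shape AddCombineTo64bitMLAL walks. If only the high word of the result is used and the low word only receives the 0x80000000 rounding constant, the chain can be folded into a single SMMLAR, SMMLSR or SMMULR node. The function below is a hypothetical example, not taken from the commit's tests; whether a given compiler build actually selects smmlar for it depends on how the expression is canonicalised.

#include <stdint.h>

int32_t rounded_high_mla(int32_t a, int32_t b, int32_t acc) {
  int64_t prod = (int64_t)a * b;              /* widened multiply: ISD::SMUL_LOHI        */
  int64_t sum  = ((int64_t)acc << 32) + prod  /* 64-bit adds: ADDC/ADDE pairs            */
               + 0x80000000LL;                /* the rounding constant the combine checks */
  return (int32_t)(sum >> 32);                /* only the high word is kept              */
}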
2 changes: 2 additions & 0 deletions llvm/lib/Target/ARM/ARMISelLowering.h
@@ -203,6 +203,8 @@ class VectorType;
SMLALDX, // Signed multiply accumulate long dual exchange
SMLSLD, // Signed multiply subtract long dual
SMLSLDX, // Signed multiply subtract long dual exchange
SMMLAR, // Signed multiply long, round and add
SMMLSR, // Signed multiply long, subtract and round

// Operands of the standard BUILD_VECTOR node are not legalized, which
// is fine if BUILD_VECTORs are always lowered to shuffles or other
17 changes: 14 additions & 3 deletions llvm/lib/Target/ARM/ARMInstrInfo.td
@@ -105,6 +105,14 @@ def ARMSmlaldx : SDNode<"ARMISD::SMLALDX", SDT_LongMac>;
def ARMSmlsld : SDNode<"ARMISD::SMLSLD", SDT_LongMac>;
def ARMSmlsldx : SDNode<"ARMISD::SMLSLDX", SDT_LongMac>;

def SDT_MulHSR : SDTypeProfile<1, 3, [SDTCisVT<0,i32>,
SDTCisSameAs<0, 1>,
SDTCisSameAs<0, 2>,
SDTCisSameAs<0, 3>]>;

def ARMsmmlar : SDNode<"ARMISD::SMMLAR", SDT_MulHSR>;
def ARMsmmlsr : SDNode<"ARMISD::SMMLSR", SDT_MulHSR>;

// Node definitions.
def ARMWrapper : SDNode<"ARMISD::Wrapper", SDTIntUnaryOp>;
def ARMWrapperPIC : SDNode<"ARMISD::WrapperPIC", SDTIntUnaryOp>;
@@ -4143,7 +4151,8 @@ def SMMUL : AMul2I <0b0111010, 0b0001, (outs GPR:$Rd), (ins GPR:$Rn, GPR:$Rm),
}

def SMMULR : AMul2I <0b0111010, 0b0011, (outs GPR:$Rd), (ins GPR:$Rn, GPR:$Rm),
IIC_iMUL32, "smmulr", "\t$Rd, $Rn, $Rm", []>,
IIC_iMUL32, "smmulr", "\t$Rd, $Rn, $Rm",
[(set GPR:$Rd, (ARMsmmlar GPR:$Rn, GPR:$Rm, (i32 0)))]>,
Requires<[IsARM, HasV6]>,
Sched<[WriteMUL32, ReadMUL, ReadMUL]> {
let Inst{15-12} = 0b1111;
@@ -4158,7 +4167,8 @@ def SMMLA : AMul2Ia <0b0111010, 0b0001, (outs GPR:$Rd),

def SMMLAR : AMul2Ia <0b0111010, 0b0011, (outs GPR:$Rd),
(ins GPR:$Rn, GPR:$Rm, GPR:$Ra),
IIC_iMAC32, "smmlar", "\t$Rd, $Rn, $Rm, $Ra", []>,
IIC_iMAC32, "smmlar", "\t$Rd, $Rn, $Rm, $Ra",
[(set GPR:$Rd, (ARMsmmlar GPR:$Rn, GPR:$Rm, GPR:$Ra))]>,
Requires<[IsARM, HasV6]>,
Sched<[WriteMAC32, ReadMUL, ReadMUL, ReadMAC]>;

@@ -4170,7 +4180,8 @@ def SMMLS : AMul2Ia <0b0111010, 0b1101, (outs GPR:$Rd),

def SMMLSR : AMul2Ia <0b0111010, 0b1111, (outs GPR:$Rd),
(ins GPR:$Rn, GPR:$Rm, GPR:$Ra),
IIC_iMAC32, "smmlsr", "\t$Rd, $Rn, $Rm, $Ra", []>,
IIC_iMAC32, "smmlsr", "\t$Rd, $Rn, $Rm, $Ra",
[(set GPR:$Rd, (ARMsmmlsr GPR:$Rn, GPR:$Rm, GPR:$Ra))]>,
Requires<[IsARM, HasV6]>,
Sched<[WriteMAC32, ReadMUL, ReadMUL, ReadMAC]>;
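
One detail worth noting in the patterns above: SMMULR takes no accumulator, so its selection pattern reuses the ARMsmmlar node with a constant zero third operand. That works because a rounded high-half multiply is just a rounded high-half multiply-accumulate with Ra = 0, as the small self-contained check below illustrates (reference model only, not part of the commit).

#include <assert.h>
#include <stdint.h>

/* Same illustrative SMMLAR model as sketched earlier on this page. */
static int32_t smmlar_model(int32_t rn, int32_t rm, int32_t ra) {
  return (int32_t)((((int64_t)ra << 32) + (int64_t)rn * rm + 0x80000000LL) >> 32);
}

int main(void) {
  int32_t a = 123456789, b = -1987654321;
  /* SMMULR(a, b) should equal SMMLAR(a, b, 0): the same rounded high half. */
  assert(smmlar_model(a, b, 0) ==
         (int32_t)(((int64_t)a * b + 0x80000000LL) >> 32));
  return 0;
}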

10 changes: 7 additions & 3 deletions llvm/lib/Target/ARM/ARMInstrThumb2.td
@@ -2661,7 +2661,9 @@ class T2SMMUL<bits<4> op7_4, string opc, list<dag> pattern>
}
def t2SMMUL : T2SMMUL<0b0000, "smmul", [(set rGPR:$Rd, (mulhs rGPR:$Rn,
rGPR:$Rm))]>;
def t2SMMULR : T2SMMUL<0b0001, "smmulr", []>;
def t2SMMULR :
T2SMMUL<0b0001, "smmulr",
[(set rGPR:$Rd, (ARMsmmlar rGPR:$Rn, rGPR:$Rm, (i32 0)))]>;

class T2FourRegSMMLA<bits<3> op22_20, bits<4> op7_4, string opc,
list<dag> pattern>
@@ -2677,9 +2679,11 @@ class T2FourRegSMMLA<bits<3> op22_20, bits<4> op7_4, string opc,

def t2SMMLA : T2FourRegSMMLA<0b101, 0b0000, "smmla",
[(set rGPR:$Rd, (add (mulhs rGPR:$Rm, rGPR:$Rn), rGPR:$Ra))]>;
def t2SMMLAR: T2FourRegSMMLA<0b101, 0b0001, "smmlar", []>;
def t2SMMLAR: T2FourRegSMMLA<0b101, 0b0001, "smmlar",
[(set rGPR:$Rd, (ARMsmmlar rGPR:$Rn, rGPR:$Rm, rGPR:$Ra))]>;
def t2SMMLS: T2FourRegSMMLA<0b110, 0b0000, "smmls", []>;
def t2SMMLSR: T2FourRegSMMLA<0b110, 0b0001, "smmlsr", []>;
def t2SMMLSR: T2FourRegSMMLA<0b110, 0b0001, "smmlsr",
[(set rGPR:$Rd, (ARMsmmlsr rGPR:$Rn, rGPR:$Rm, rGPR:$Ra))]>;

class T2ThreeRegSMUL<bits<3> op22_20, bits<2> op5_4, string opc,
list<dag> pattern>
