Skip to content

Commit

Permalink
[ARM] Add predicated mla reduction patterns
Browse files Browse the repository at this point in the history
Similar to 8fa824d but this time for MLA patterns, this selects
predicated vmlav/vmlava/vmlalv/vmlava instructions from
vecreduce.add(select(p, mul(x, y), 0)) nodes.

Differential Revision: https://reviews.llvm.org/D84102
  • Loading branch information
davemgreen committed Jul 23, 2020
1 parent cee60bb commit b37e922
Show file tree
Hide file tree
Showing 4 changed files with 467 additions and 2,163 deletions.
43 changes: 43 additions & 0 deletions llvm/lib/Target/ARM/ARMISelLowering.cpp
Expand Up @@ -1730,10 +1730,16 @@ const char *ARMTargetLowering::getTargetNodeName(unsigned Opcode) const {
case ARMISD::VADDLVApu: return "ARMISD::VADDLVApu";
case ARMISD::VMLAVs: return "ARMISD::VMLAVs";
case ARMISD::VMLAVu: return "ARMISD::VMLAVu";
case ARMISD::VMLAVps: return "ARMISD::VMLAVps";
case ARMISD::VMLAVpu: return "ARMISD::VMLAVpu";
case ARMISD::VMLALVs: return "ARMISD::VMLALVs";
case ARMISD::VMLALVu: return "ARMISD::VMLALVu";
case ARMISD::VMLALVps: return "ARMISD::VMLALVps";
case ARMISD::VMLALVpu: return "ARMISD::VMLALVpu";
case ARMISD::VMLALVAs: return "ARMISD::VMLALVAs";
case ARMISD::VMLALVAu: return "ARMISD::VMLALVAu";
case ARMISD::VMLALVAps: return "ARMISD::VMLALVAps";
case ARMISD::VMLALVApu: return "ARMISD::VMLALVApu";
case ARMISD::UMAAL: return "ARMISD::UMAAL";
case ARMISD::UMLAL: return "ARMISD::UMLAL";
case ARMISD::SMLAL: return "ARMISD::SMLAL";
Expand Down Expand Up @@ -12261,6 +12267,14 @@ static SDValue PerformADDVecReduce(SDNode *N,
return M;
if (SDValue M = MakeVecReduce(ARMISD::VMLALVu, ARMISD::VMLALVAu, N1, N0))
return M;
if (SDValue M = MakeVecReduce(ARMISD::VMLALVps, ARMISD::VMLALVAps, N0, N1))
return M;
if (SDValue M = MakeVecReduce(ARMISD::VMLALVpu, ARMISD::VMLALVApu, N0, N1))
return M;
if (SDValue M = MakeVecReduce(ARMISD::VMLALVps, ARMISD::VMLALVAps, N1, N0))
return M;
if (SDValue M = MakeVecReduce(ARMISD::VMLALVpu, ARMISD::VMLALVApu, N1, N0))
return M;
return SDValue();
}

Expand Down Expand Up @@ -14760,6 +14774,26 @@ static SDValue PerformVECREDUCE_ADDCombine(SDNode *N, SelectionDAG &DAG,
return true;
return false;
};
auto IsPredVMLAV = [&](MVT RetTy, unsigned ExtendCode, ArrayRef<MVT> ExtTypes,
SDValue &A, SDValue &B, SDValue &Mask) {
if (ResVT != RetTy || N0->getOpcode() != ISD::VSELECT ||
!ISD::isBuildVectorAllZeros(N0->getOperand(2).getNode()))
return false;
Mask = N0->getOperand(0);
SDValue Mul = N0->getOperand(1);
if (Mul->getOpcode() != ISD::MUL)
return false;
SDValue ExtA = Mul->getOperand(0);
SDValue ExtB = Mul->getOperand(1);
if (ExtA->getOpcode() != ExtendCode && ExtB->getOpcode() != ExtendCode)
return false;
A = ExtA->getOperand(0);
B = ExtB->getOperand(0);
if (A.getValueType() == B.getValueType() &&
llvm::any_of(ExtTypes, [&A](MVT Ty) { return A.getValueType() == Ty; }))
return true;
return false;
};
auto Create64bitNode = [&](unsigned Opcode, ArrayRef<SDValue> Ops) {
SDValue Node = DAG.getNode(Opcode, dl, {MVT::i32, MVT::i32}, Ops);
return DAG.getNode(ISD::BUILD_PAIR, dl, MVT::i64, Node,
Expand Down Expand Up @@ -14794,6 +14828,15 @@ static SDValue PerformVECREDUCE_ADDCombine(SDNode *N, SelectionDAG &DAG,
return Create64bitNode(ARMISD::VMLALVs, {A, B});
if (IsVMLAV(MVT::i64, ISD::ZERO_EXTEND, {MVT::v8i16, MVT::v4i32}, A, B))
return Create64bitNode(ARMISD::VMLALVu, {A, B});

if (IsPredVMLAV(MVT::i32, ISD::SIGN_EXTEND, {MVT::v8i16, MVT::v16i8}, A, B, Mask))
return DAG.getNode(ARMISD::VMLAVps, dl, ResVT, A, B, Mask);
if (IsPredVMLAV(MVT::i32, ISD::ZERO_EXTEND, {MVT::v8i16, MVT::v16i8}, A, B, Mask))
return DAG.getNode(ARMISD::VMLAVpu, dl, ResVT, A, B, Mask);
if (IsPredVMLAV(MVT::i64, ISD::SIGN_EXTEND, {MVT::v8i16, MVT::v4i32}, A, B, Mask))
return Create64bitNode(ARMISD::VMLALVps, {A, B, Mask});
if (IsPredVMLAV(MVT::i64, ISD::ZERO_EXTEND, {MVT::v8i16, MVT::v4i32}, A, B, Mask))
return Create64bitNode(ARMISD::VMLALVpu, {A, B, Mask});
return SDValue();
}

Expand Down
18 changes: 12 additions & 6 deletions llvm/lib/Target/ARM/ARMISelLowering.h
Expand Up @@ -229,12 +229,18 @@ class VectorType;
VADDLVpu,
VADDLVAps, // Same as VADDLVp[su] but with a v4i1 predicate mask
VADDLVApu,
VMLAVs,
VMLAVu,
VMLALVs,
VMLALVu,
VMLALVAs,
VMLALVAu,
VMLAVs, // sign- or zero-extend the elements of two vectors to i32, multiply them
VMLAVu, // and add the results together, returning an i32 of their sum
VMLAVps, // Same as VMLAV[su] with a v4i1 predicate mask
VMLAVpu,
VMLALVs, // Same as VMLAV but with i64, returning the low and
VMLALVu, // high 32-bit halves of the sum
VMLALVps, // Same as VMLALV[su] with a v4i1 predicate mask
VMLALVpu,
VMLALVAs, // Same as VMLALV but also add an input accumulator
VMLALVAu, // provided as low and high halves
VMLALVAps, // Same as VMLALVA[su] with a v4i1 predicate mask
VMLALVApu,

SMULWB, // Signed multiply word by half word, bottom
SMULWT, // Signed multiply word by half word, top
Expand Down
91 changes: 86 additions & 5 deletions llvm/lib/Target/ARM/ARMInstrMVE.td
Expand Up @@ -1105,12 +1105,28 @@ def SDTVecReduce2LA : SDTypeProfile<2, 4, [ // VMLALVA
SDTCisInt<0>, SDTCisInt<1>, SDTCisInt<2>, SDTCisInt<3>,
SDTCisVec<4>, SDTCisVec<5>
]>;
def SDTVecReduce2P : SDTypeProfile<1, 3, [ // VMLAV
SDTCisInt<0>, SDTCisVec<1>, SDTCisVec<2>, SDTCisVec<3>
]>;
def SDTVecReduce2LP : SDTypeProfile<2, 3, [ // VMLALV
SDTCisInt<0>, SDTCisInt<1>, SDTCisVec<2>, SDTCisVec<3>, SDTCisVec<4>
]>;
def SDTVecReduce2LAP : SDTypeProfile<2, 5, [ // VMLALVA
SDTCisInt<0>, SDTCisInt<1>, SDTCisInt<2>, SDTCisInt<3>,
SDTCisVec<4>, SDTCisVec<5>, SDTCisVec<6>
]>;
def ARMVMLAVs : SDNode<"ARMISD::VMLAVs", SDTVecReduce2>;
def ARMVMLAVu : SDNode<"ARMISD::VMLAVu", SDTVecReduce2>;
def ARMVMLALVs : SDNode<"ARMISD::VMLALVs", SDTVecReduce2L>;
def ARMVMLALVu : SDNode<"ARMISD::VMLALVu", SDTVecReduce2L>;
def ARMVMLALVAs : SDNode<"ARMISD::VMLALVAs", SDTVecReduce2LA>;
def ARMVMLALVAu : SDNode<"ARMISD::VMLALVAu", SDTVecReduce2LA>;
def ARMVMLALVAs : SDNode<"ARMISD::VMLALVAs", SDTVecReduce2LA>;
def ARMVMLALVAu : SDNode<"ARMISD::VMLALVAu", SDTVecReduce2LA>;
def ARMVMLAVps : SDNode<"ARMISD::VMLAVps", SDTVecReduce2P>;
def ARMVMLAVpu : SDNode<"ARMISD::VMLAVpu", SDTVecReduce2P>;
def ARMVMLALVps : SDNode<"ARMISD::VMLALVps", SDTVecReduce2LP>;
def ARMVMLALVpu : SDNode<"ARMISD::VMLALVpu", SDTVecReduce2LP>;
def ARMVMLALVAps : SDNode<"ARMISD::VMLALVAps", SDTVecReduce2LAP>;
def ARMVMLALVApu : SDNode<"ARMISD::VMLALVApu", SDTVecReduce2LAP>;

let Predicates = [HasMVEInt] in {
def : Pat<(i32 (vecreduce_add (mul (v4i32 MQPR:$src1), (v4i32 MQPR:$src2)))),
Expand All @@ -1129,22 +1145,68 @@ let Predicates = [HasMVEInt] in {
(i32 (MVE_VMLADAVu8 (v16i8 MQPR:$val1), (v16i8 MQPR:$val2)))>;

def : Pat<(i32 (add (i32 (vecreduce_add (mul (v4i32 MQPR:$src1), (v4i32 MQPR:$src2)))),
(i32 tGPREven:$src3))),
(i32 tGPREven:$src3))),
(i32 (MVE_VMLADAVau32 $src3, $src1, $src2))>;
def : Pat<(i32 (add (i32 (vecreduce_add (mul (v8i16 MQPR:$src1), (v8i16 MQPR:$src2)))),
(i32 tGPREven:$src3))),
(i32 tGPREven:$src3))),
(i32 (MVE_VMLADAVau16 $src3, $src1, $src2))>;
def : Pat<(i32 (add (ARMVMLAVs (v8i16 MQPR:$val1), (v8i16 MQPR:$val2)), tGPREven:$Rd)),
(i32 (MVE_VMLADAVas16 tGPREven:$Rd, (v8i16 MQPR:$val1), (v8i16 MQPR:$val2)))>;
def : Pat<(i32 (add (ARMVMLAVu (v8i16 MQPR:$val1), (v8i16 MQPR:$val2)), tGPREven:$Rd)),
(i32 (MVE_VMLADAVau16 tGPREven:$Rd, (v8i16 MQPR:$val1), (v8i16 MQPR:$val2)))>;
def : Pat<(i32 (add (i32 (vecreduce_add (mul (v16i8 MQPR:$src1), (v16i8 MQPR:$src2)))),
(i32 tGPREven:$src3))),
(i32 tGPREven:$src3))),
(i32 (MVE_VMLADAVau8 $src3, $src1, $src2))>;
def : Pat<(i32 (add (ARMVMLAVs (v16i8 MQPR:$val1), (v16i8 MQPR:$val2)), tGPREven:$Rd)),
(i32 (MVE_VMLADAVas8 tGPREven:$Rd, (v16i8 MQPR:$val1), (v16i8 MQPR:$val2)))>;
def : Pat<(i32 (add (ARMVMLAVu (v16i8 MQPR:$val1), (v16i8 MQPR:$val2)), tGPREven:$Rd)),
(i32 (MVE_VMLADAVau8 tGPREven:$Rd, (v16i8 MQPR:$val1), (v16i8 MQPR:$val2)))>;

// Predicated
def : Pat<(i32 (vecreduce_add (vselect (v4i1 VCCR:$pred),
(mul (v4i32 MQPR:$src1), (v4i32 MQPR:$src2)),
(v4i32 ARMimmAllZerosV)))),
(i32 (MVE_VMLADAVu32 $src1, $src2, ARMVCCThen, $pred))>;
def : Pat<(i32 (vecreduce_add (vselect (v8i1 VCCR:$pred),
(mul (v8i16 MQPR:$src1), (v8i16 MQPR:$src2)),
(v8i16 ARMimmAllZerosV)))),
(i32 (MVE_VMLADAVu16 $src1, $src2, ARMVCCThen, $pred))>;
def : Pat<(i32 (ARMVMLAVps (v8i16 MQPR:$val1), (v8i16 MQPR:$val2), (v8i1 VCCR:$pred))),
(i32 (MVE_VMLADAVs16 (v8i16 MQPR:$val1), (v8i16 MQPR:$val2), ARMVCCThen, $pred))>;
def : Pat<(i32 (ARMVMLAVpu (v8i16 MQPR:$val1), (v8i16 MQPR:$val2), (v8i1 VCCR:$pred))),
(i32 (MVE_VMLADAVu16 (v8i16 MQPR:$val1), (v8i16 MQPR:$val2), ARMVCCThen, $pred))>;
def : Pat<(i32 (vecreduce_add (vselect (v16i1 VCCR:$pred),
(mul (v16i8 MQPR:$src1), (v16i8 MQPR:$src2)),
(v16i8 ARMimmAllZerosV)))),
(i32 (MVE_VMLADAVu8 $src1, $src2, ARMVCCThen, $pred))>;
def : Pat<(i32 (ARMVMLAVps (v16i8 MQPR:$val1), (v16i8 MQPR:$val2), (v16i1 VCCR:$pred))),
(i32 (MVE_VMLADAVs8 (v16i8 MQPR:$val1), (v16i8 MQPR:$val2), ARMVCCThen, $pred))>;
def : Pat<(i32 (ARMVMLAVpu (v16i8 MQPR:$val1), (v16i8 MQPR:$val2), (v16i1 VCCR:$pred))),
(i32 (MVE_VMLADAVu8 (v16i8 MQPR:$val1), (v16i8 MQPR:$val2), ARMVCCThen, $pred))>;

def : Pat<(i32 (add (i32 (vecreduce_add (vselect (v4i1 VCCR:$pred),
(mul (v4i32 MQPR:$src1), (v4i32 MQPR:$src2)),
(v4i32 ARMimmAllZerosV)))),
(i32 tGPREven:$src3))),
(i32 (MVE_VMLADAVau32 $src3, $src1, $src2, ARMVCCThen, $pred))>;
def : Pat<(i32 (add (i32 (vecreduce_add (vselect (v8i1 VCCR:$pred),
(mul (v8i16 MQPR:$src1), (v8i16 MQPR:$src2)),
(v8i16 ARMimmAllZerosV)))),
(i32 tGPREven:$src3))),
(i32 (MVE_VMLADAVau16 $src3, $src1, $src2, ARMVCCThen, $pred))>;
def : Pat<(i32 (add (ARMVMLAVps (v8i16 MQPR:$val1), (v8i16 MQPR:$val2), (v8i1 VCCR:$pred)), tGPREven:$Rd)),
(i32 (MVE_VMLADAVas16 tGPREven:$Rd, (v8i16 MQPR:$val1), (v8i16 MQPR:$val2), ARMVCCThen, $pred))>;
def : Pat<(i32 (add (ARMVMLAVpu (v8i16 MQPR:$val1), (v8i16 MQPR:$val2), (v8i1 VCCR:$pred)), tGPREven:$Rd)),
(i32 (MVE_VMLADAVau16 tGPREven:$Rd, (v8i16 MQPR:$val1), (v8i16 MQPR:$val2), ARMVCCThen, $pred))>;
def : Pat<(i32 (add (i32 (vecreduce_add (vselect (v16i1 VCCR:$pred),
(mul (v16i8 MQPR:$src1), (v16i8 MQPR:$src2)),
(v16i8 ARMimmAllZerosV)))),
(i32 tGPREven:$src3))),
(i32 (MVE_VMLADAVau8 $src3, $src1, $src2, ARMVCCThen, $pred))>;
def : Pat<(i32 (add (ARMVMLAVps (v16i8 MQPR:$val1), (v16i8 MQPR:$val2), (v16i1 VCCR:$pred)), tGPREven:$Rd)),
(i32 (MVE_VMLADAVas8 tGPREven:$Rd, (v16i8 MQPR:$val1), (v16i8 MQPR:$val2), ARMVCCThen, $pred))>;
def : Pat<(i32 (add (ARMVMLAVpu (v16i8 MQPR:$val1), (v16i8 MQPR:$val2), (v16i1 VCCR:$pred)), tGPREven:$Rd)),
(i32 (MVE_VMLADAVau8 tGPREven:$Rd, (v16i8 MQPR:$val1), (v16i8 MQPR:$val2), ARMVCCThen, $pred))>;
}

// vmlav aliases vmladav
Expand Down Expand Up @@ -1264,6 +1326,25 @@ let Predicates = [HasMVEInt] in {
(MVE_VMLALDAVas16 tGPREven:$Rda, tGPROdd:$Rdb, (v8i16 MQPR:$val1), (v8i16 MQPR:$val2))>;
def : Pat<(ARMVMLALVAu tGPREven:$Rda, tGPROdd:$Rdb, (v8i16 MQPR:$val1), (v8i16 MQPR:$val2)),
(MVE_VMLALDAVau16 tGPREven:$Rda, tGPROdd:$Rdb, (v8i16 MQPR:$val1), (v8i16 MQPR:$val2))>;

// Predicated
def : Pat<(ARMVMLALVps (v4i32 MQPR:$val1), (v4i32 MQPR:$val2), (v4i1 VCCR:$pred)),
(MVE_VMLALDAVs32 (v4i32 MQPR:$val1), (v4i32 MQPR:$val2), ARMVCCThen, $pred)>;
def : Pat<(ARMVMLALVpu (v4i32 MQPR:$val1), (v4i32 MQPR:$val2), (v4i1 VCCR:$pred)),
(MVE_VMLALDAVu32 (v4i32 MQPR:$val1), (v4i32 MQPR:$val2), ARMVCCThen, $pred)>;
def : Pat<(ARMVMLALVps (v8i16 MQPR:$val1), (v8i16 MQPR:$val2), (v8i1 VCCR:$pred)),
(MVE_VMLALDAVs16 (v8i16 MQPR:$val1), (v8i16 MQPR:$val2), ARMVCCThen, $pred)>;
def : Pat<(ARMVMLALVpu (v8i16 MQPR:$val1), (v8i16 MQPR:$val2), (v8i1 VCCR:$pred)),
(MVE_VMLALDAVu16 (v8i16 MQPR:$val1), (v8i16 MQPR:$val2), ARMVCCThen, $pred)>;

def : Pat<(ARMVMLALVAps tGPREven:$Rda, tGPROdd:$Rdb, (v4i32 MQPR:$val1), (v4i32 MQPR:$val2), (v4i1 VCCR:$pred)),
(MVE_VMLALDAVas32 tGPREven:$Rda, tGPROdd:$Rdb, (v4i32 MQPR:$val1), (v4i32 MQPR:$val2), ARMVCCThen, $pred)>;
def : Pat<(ARMVMLALVApu tGPREven:$Rda, tGPROdd:$Rdb, (v4i32 MQPR:$val1), (v4i32 MQPR:$val2), (v4i1 VCCR:$pred)),
(MVE_VMLALDAVau32 tGPREven:$Rda, tGPROdd:$Rdb, (v4i32 MQPR:$val1), (v4i32 MQPR:$val2), ARMVCCThen, $pred)>;
def : Pat<(ARMVMLALVAps tGPREven:$Rda, tGPROdd:$Rdb, (v8i16 MQPR:$val1), (v8i16 MQPR:$val2), (v8i1 VCCR:$pred)),
(MVE_VMLALDAVas16 tGPREven:$Rda, tGPROdd:$Rdb, (v8i16 MQPR:$val1), (v8i16 MQPR:$val2), ARMVCCThen, $pred)>;
def : Pat<(ARMVMLALVApu tGPREven:$Rda, tGPROdd:$Rdb, (v8i16 MQPR:$val1), (v8i16 MQPR:$val2), (v8i1 VCCR:$pred)),
(MVE_VMLALDAVau16 tGPREven:$Rda, tGPROdd:$Rdb, (v8i16 MQPR:$val1), (v8i16 MQPR:$val2), ARMVCCThen, $pred)>;
}

// vmlalv aliases vmlaldav
Expand Down

0 comments on commit b37e922

Please sign in to comment.