Skip to content

Commit

Permalink
[ARM] Recognize "double extend" reduction patterns
Browse files Browse the repository at this point in the history
We can sometimes get code that does:
  xe = zext i16 x to i32
  ye = zext i16 y to i32
  m = mul i32 xe, ye
  me = zext i32 m to i64
  r = vecreduce.add(me)
This "double extend" can trip up the reduction identification, but
should give identical results.

This extends the pattern matching to handle them.

Differential Revision: https://reviews.llvm.org/D87276
  • Loading branch information
davemgreen committed Sep 12, 2020
1 parent 36e2e2e commit c437446
Show file tree
Hide file tree
Showing 3 changed files with 151 additions and 2,338 deletions.
31 changes: 28 additions & 3 deletions llvm/lib/Target/ARM/ARMISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14765,10 +14765,25 @@ static SDValue PerformVECREDUCE_ADDCombine(SDNode *N, SelectionDAG &DAG,
};
auto IsVMLAV = [&](MVT RetTy, unsigned ExtendCode, ArrayRef<MVT> ExtTypes,
SDValue &A, SDValue &B) {
if (ResVT != RetTy || N0->getOpcode() != ISD::MUL)
// For a vmla we are trying to match a larger pattern:
// ExtA = sext/zext A
// ExtB = sext/zext B
// Mul = mul ExtA, ExtB
// vecreduce.add Mul
// There might also be en extra extend between the mul and the addreduce, so
// long as the bitwidth is high enough to make them equivalent (for example
// original v8i16 might be mul at v8i32 and the reduce happens at v8i64).
if (ResVT != RetTy)
return false;
SDValue ExtA = N0->getOperand(0);
SDValue ExtB = N0->getOperand(1);
SDValue Mul = N0;
if (Mul->getOpcode() == ExtendCode &&
Mul->getOperand(0).getScalarValueSizeInBits() * 2 >=
ResVT.getScalarSizeInBits())
Mul = Mul->getOperand(0);
if (Mul->getOpcode() != ISD::MUL)
return false;
SDValue ExtA = Mul->getOperand(0);
SDValue ExtB = Mul->getOperand(1);
if (ExtA->getOpcode() != ExtendCode && ExtB->getOpcode() != ExtendCode)
return false;
A = ExtA->getOperand(0);
Expand All @@ -14780,11 +14795,21 @@ static SDValue PerformVECREDUCE_ADDCombine(SDNode *N, SelectionDAG &DAG,
};
auto IsPredVMLAV = [&](MVT RetTy, unsigned ExtendCode, ArrayRef<MVT> ExtTypes,
SDValue &A, SDValue &B, SDValue &Mask) {
// Same as the pattern above with a select for the zero predicated lanes
// ExtA = sext/zext A
// ExtB = sext/zext B
// Mul = mul ExtA, ExtB
// N0 = select Mask, Mul, 0
// vecreduce.add N0
if (ResVT != RetTy || N0->getOpcode() != ISD::VSELECT ||
!ISD::isBuildVectorAllZeros(N0->getOperand(2).getNode()))
return false;
Mask = N0->getOperand(0);
SDValue Mul = N0->getOperand(1);
if (Mul->getOpcode() == ExtendCode &&
Mul->getOperand(0).getScalarValueSizeInBits() * 2 >=
ResVT.getScalarSizeInBits())
Mul = Mul->getOperand(0);
if (Mul->getOpcode() != ISD::MUL)
return false;
SDValue ExtA = Mul->getOperand(0);
Expand Down
Loading

0 comments on commit c437446

Please sign in to comment.