Skip to content

Commit

Permalink
[SVE][CodeGen] Lower scalable integer vector reductions
Browse files Browse the repository at this point in the history
This patch uses the existing LowerFixedLengthReductionToSVE function to also lower
scalable vector reductions. A separate function has been added to lower VECREDUCE_AND
& VECREDUCE_OR operations with predicate types using ptest.

Lowering scalable floating-point reductions will be addressed in a follow up patch,
for now these will hit the assertion added to expandVecReduce() in TargetLowering.

Reviewed By: paulwalker-arm

Differential Revision: https://reviews.llvm.org/D89382
  • Loading branch information
kmclaughlin-arm committed Nov 4, 2020
1 parent f202d32 commit f2412d3
Show file tree
Hide file tree
Showing 10 changed files with 1,284 additions and 28 deletions.
2 changes: 1 addition & 1 deletion llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
Expand Up @@ -20857,7 +20857,7 @@ SDValue DAGCombiner::visitVECREDUCE(SDNode *N) {
unsigned Opcode = N->getOpcode();

// VECREDUCE over 1-element vector is just an extract.
if (VT.getVectorNumElements() == 1) {
if (VT.getVectorElementCount().isScalar()) {
SDLoc dl(N);
SDValue Res =
DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, VT.getVectorElementType(), N0,
Expand Down
33 changes: 29 additions & 4 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
Expand Up @@ -3323,6 +3323,9 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
SDValue InVec = Op.getOperand(0);
SDValue EltNo = Op.getOperand(1);
EVT VecVT = InVec.getValueType();
// computeKnownBits not yet implemented for scalable vectors.
if (VecVT.isScalableVector())
break;
const unsigned EltBitWidth = VecVT.getScalarSizeInBits();
const unsigned NumSrcElts = VecVT.getVectorNumElements();

Expand Down Expand Up @@ -4809,6 +4812,16 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
case ISD::VSCALE:
assert(VT == Operand.getValueType() && "Unexpected VT!");
break;
case ISD::VECREDUCE_SMIN:
case ISD::VECREDUCE_UMAX:
if (Operand.getValueType().getScalarType() == MVT::i1)
return getNode(ISD::VECREDUCE_OR, DL, VT, Operand);
break;
case ISD::VECREDUCE_SMAX:
case ISD::VECREDUCE_UMIN:
if (Operand.getValueType().getScalarType() == MVT::i1)
return getNode(ISD::VECREDUCE_AND, DL, VT, Operand);
break;
}

SDNode *N;
Expand Down Expand Up @@ -5318,10 +5331,6 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
case ISD::MULHS:
case ISD::SDIV:
case ISD::SREM:
case ISD::SMIN:
case ISD::SMAX:
case ISD::UMIN:
case ISD::UMAX:
case ISD::SADDSAT:
case ISD::SSUBSAT:
case ISD::UADDSAT:
Expand All @@ -5330,6 +5339,22 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
assert(N1.getValueType() == N2.getValueType() &&
N1.getValueType() == VT && "Binary operator types must match!");
break;
case ISD::SMIN:
case ISD::UMAX:
assert(VT.isInteger() && "This operator does not apply to FP types!");
assert(N1.getValueType() == N2.getValueType() &&
N1.getValueType() == VT && "Binary operator types must match!");
if (VT.isVector() && VT.getVectorElementType() == MVT::i1)
return getNode(ISD::OR, DL, VT, N1, N2);
break;
case ISD::SMAX:
case ISD::UMIN:
assert(VT.isInteger() && "This operator does not apply to FP types!");
assert(N1.getValueType() == N2.getValueType() &&
N1.getValueType() == VT && "Binary operator types must match!");
if (VT.isVector() && VT.getVectorElementType() == MVT::i1)
return getNode(ISD::AND, DL, VT, N1, N2);
break;
case ISD::FADD:
case ISD::FSUB:
case ISD::FMUL:
Expand Down
4 changes: 4 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
Expand Up @@ -8000,6 +8000,10 @@ SDValue TargetLowering::expandVecReduce(SDNode *Node, SelectionDAG &DAG) const {
SDValue Op = Node->getOperand(0);
EVT VT = Op.getValueType();

if (VT.isScalableVector())
report_fatal_error(
"Expanding reductions for scalable vectors is undefined.");

// Try to use a shuffle reduction for power of two vectors.
if (VT.isPow2VectorType()) {
while (VT.getVectorNumElements() > 1) {
Expand Down
86 changes: 69 additions & 17 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Expand Up @@ -1013,6 +1013,14 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::SHL, VT, Custom);
setOperationAction(ISD::SRL, VT, Custom);
setOperationAction(ISD::SRA, VT, Custom);
setOperationAction(ISD::VECREDUCE_ADD, VT, Custom);
setOperationAction(ISD::VECREDUCE_AND, VT, Custom);
setOperationAction(ISD::VECREDUCE_OR, VT, Custom);
setOperationAction(ISD::VECREDUCE_XOR, VT, Custom);
setOperationAction(ISD::VECREDUCE_UMIN, VT, Custom);
setOperationAction(ISD::VECREDUCE_UMAX, VT, Custom);
setOperationAction(ISD::VECREDUCE_SMIN, VT, Custom);
setOperationAction(ISD::VECREDUCE_SMAX, VT, Custom);
}

// Illegal unpacked integer vector types.
Expand All @@ -1027,6 +1035,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::SETCC, VT, Custom);
setOperationAction(ISD::SPLAT_VECTOR, VT, Custom);
setOperationAction(ISD::TRUNCATE, VT, Custom);
setOperationAction(ISD::VECREDUCE_AND, VT, Custom);
setOperationAction(ISD::VECREDUCE_OR, VT, Custom);
setOperationAction(ISD::VECREDUCE_XOR, VT, Custom);

// There are no legal MVT::nxv16f## based types.
if (VT != MVT::nxv16i1) {
Expand Down Expand Up @@ -9815,30 +9826,35 @@ SDValue AArch64TargetLowering::LowerVECREDUCE(SDValue Op,
Op.getOpcode() == ISD::VECREDUCE_FADD ||
(Op.getOpcode() != ISD::VECREDUCE_ADD &&
SrcVT.getVectorElementType() == MVT::i64);
if (useSVEForFixedLengthVectorVT(SrcVT, OverrideNEON)) {
if (SrcVT.isScalableVector() ||
useSVEForFixedLengthVectorVT(SrcVT, OverrideNEON)) {

if (SrcVT.getVectorElementType() == MVT::i1)
return LowerPredReductionToSVE(Op, DAG);

switch (Op.getOpcode()) {
case ISD::VECREDUCE_ADD:
return LowerFixedLengthReductionToSVE(AArch64ISD::UADDV_PRED, Op, DAG);
return LowerReductionToSVE(AArch64ISD::UADDV_PRED, Op, DAG);
case ISD::VECREDUCE_AND:
return LowerFixedLengthReductionToSVE(AArch64ISD::ANDV_PRED, Op, DAG);
return LowerReductionToSVE(AArch64ISD::ANDV_PRED, Op, DAG);
case ISD::VECREDUCE_OR:
return LowerFixedLengthReductionToSVE(AArch64ISD::ORV_PRED, Op, DAG);
return LowerReductionToSVE(AArch64ISD::ORV_PRED, Op, DAG);
case ISD::VECREDUCE_SMAX:
return LowerFixedLengthReductionToSVE(AArch64ISD::SMAXV_PRED, Op, DAG);
return LowerReductionToSVE(AArch64ISD::SMAXV_PRED, Op, DAG);
case ISD::VECREDUCE_SMIN:
return LowerFixedLengthReductionToSVE(AArch64ISD::SMINV_PRED, Op, DAG);
return LowerReductionToSVE(AArch64ISD::SMINV_PRED, Op, DAG);
case ISD::VECREDUCE_UMAX:
return LowerFixedLengthReductionToSVE(AArch64ISD::UMAXV_PRED, Op, DAG);
return LowerReductionToSVE(AArch64ISD::UMAXV_PRED, Op, DAG);
case ISD::VECREDUCE_UMIN:
return LowerFixedLengthReductionToSVE(AArch64ISD::UMINV_PRED, Op, DAG);
return LowerReductionToSVE(AArch64ISD::UMINV_PRED, Op, DAG);
case ISD::VECREDUCE_XOR:
return LowerFixedLengthReductionToSVE(AArch64ISD::EORV_PRED, Op, DAG);
return LowerReductionToSVE(AArch64ISD::EORV_PRED, Op, DAG);
case ISD::VECREDUCE_FADD:
return LowerFixedLengthReductionToSVE(AArch64ISD::FADDV_PRED, Op, DAG);
return LowerReductionToSVE(AArch64ISD::FADDV_PRED, Op, DAG);
case ISD::VECREDUCE_FMAX:
return LowerFixedLengthReductionToSVE(AArch64ISD::FMAXNMV_PRED, Op, DAG);
return LowerReductionToSVE(AArch64ISD::FMAXNMV_PRED, Op, DAG);
case ISD::VECREDUCE_FMIN:
return LowerFixedLengthReductionToSVE(AArch64ISD::FMINNMV_PRED, Op, DAG);
return LowerReductionToSVE(AArch64ISD::FMINNMV_PRED, Op, DAG);
default:
llvm_unreachable("Unhandled fixed length reduction");
}
Expand Down Expand Up @@ -16333,20 +16349,56 @@ SDValue AArch64TargetLowering::LowerVECREDUCE_SEQ_FADD(SDValue ScalarOp,
return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ResVT, Rdx, Zero);
}

SDValue AArch64TargetLowering::LowerFixedLengthReductionToSVE(unsigned Opcode,
SDValue ScalarOp, SelectionDAG &DAG) const {
SDValue AArch64TargetLowering::LowerPredReductionToSVE(SDValue ReduceOp,
SelectionDAG &DAG) const {
SDLoc DL(ReduceOp);
SDValue Op = ReduceOp.getOperand(0);
EVT OpVT = Op.getValueType();
EVT VT = ReduceOp.getValueType();

if (!OpVT.isScalableVector() || OpVT.getVectorElementType() != MVT::i1)
return SDValue();

SDValue Pg = getPredicateForVector(DAG, DL, OpVT);

switch (ReduceOp.getOpcode()) {
default:
return SDValue();
case ISD::VECREDUCE_OR:
return getPTest(DAG, VT, Pg, Op, AArch64CC::ANY_ACTIVE);
case ISD::VECREDUCE_AND: {
Op = DAG.getNode(ISD::XOR, DL, OpVT, Op, Pg);
return getPTest(DAG, VT, Pg, Op, AArch64CC::NONE_ACTIVE);
}
case ISD::VECREDUCE_XOR: {
SDValue ID =
DAG.getTargetConstant(Intrinsic::aarch64_sve_cntp, DL, MVT::i64);
SDValue Cntp =
DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, MVT::i64, ID, Pg, Op);
return DAG.getAnyExtOrTrunc(Cntp, DL, VT);
}
}

return SDValue();
}

SDValue AArch64TargetLowering::LowerReductionToSVE(unsigned Opcode,
SDValue ScalarOp,
SelectionDAG &DAG) const {
SDLoc DL(ScalarOp);
SDValue VecOp = ScalarOp.getOperand(0);
EVT SrcVT = VecOp.getValueType();

SDValue Pg = getPredicateForVector(DAG, DL, SrcVT);
EVT ContainerVT = getContainerForFixedLengthVector(DAG, SrcVT);
VecOp = convertToScalableVector(DAG, ContainerVT, VecOp);
if (useSVEForFixedLengthVectorVT(SrcVT, true)) {
EVT ContainerVT = getContainerForFixedLengthVector(DAG, SrcVT);
VecOp = convertToScalableVector(DAG, ContainerVT, VecOp);
}

// UADDV always returns an i64 result.
EVT ResVT = (Opcode == AArch64ISD::UADDV_PRED) ? MVT::i64 :
SrcVT.getVectorElementType();

SDValue Pg = getPredicateForVector(DAG, DL, SrcVT);
SDValue Rdx = DAG.getNode(Opcode, DL, getPackedSVEVectorVT(ResVT), Pg, VecOp);
SDValue Res = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ResVT,
Rdx, DAG.getConstant(0, DL, MVT::i64));
Expand Down
5 changes: 3 additions & 2 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.h
Expand Up @@ -933,8 +933,9 @@ class AArch64TargetLowering : public TargetLowering {
SelectionDAG &DAG) const;
SDValue LowerFixedLengthVectorLoadToSVE(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerVECREDUCE_SEQ_FADD(SDValue ScalarOp, SelectionDAG &DAG) const;
SDValue LowerFixedLengthReductionToSVE(unsigned Opcode, SDValue ScalarOp,
SelectionDAG &DAG) const;
SDValue LowerPredReductionToSVE(SDValue ScalarOp, SelectionDAG &DAG) const;
SDValue LowerReductionToSVE(unsigned Opcode, SDValue ScalarOp,
SelectionDAG &DAG) const;
SDValue LowerFixedLengthVectorSelectToSVE(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerFixedLengthVectorSetccToSVE(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerFixedLengthVectorStoreToSVE(SDValue Op, SelectionDAG &DAG) const;
Expand Down

0 comments on commit f2412d3

Please sign in to comment.