Skip to content

Commit

Permalink
[RISCV] Share reduction lowering code for vp.reduce
Browse files Browse the repository at this point in the history
We can consolidate code and clarify edge case behavior at the same time.

There are two functional differences here.

First, I remove the ResVT handling, and always use the reduction element type. This appears to be dead code. There's no test coverage, and this code doesn't need to account for scalar type legalization anyways.

Second, if the VL happens to be known non-zero, we can avoid passing through start. This is mostly needed to allow reuse of the existing code; I don't consider it interesting as an optimization on it's own.

Differential Revision: https://reviews.llvm.org/D139733
  • Loading branch information
preames committed Dec 9, 2022
1 parent 6b2829d commit 1ebe8f4
Showing 1 changed file with 12 additions and 16 deletions.
28 changes: 12 additions & 16 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Expand Up @@ -5796,6 +5796,13 @@ SDValue RISCVTargetLowering::lowerVectorMaskVecReduction(SDValue Op,
return DAG.getNode(BaseOpc, DL, XLenVT, SetCC, Op.getOperand(0));
}

static bool hasNonZeroAVL(SDValue AVL) {
auto *RegisterAVL = dyn_cast<RegisterSDNode>(AVL);
auto *ImmAVL = dyn_cast<ConstantSDNode>(AVL);
return (RegisterAVL && RegisterAVL->getReg() == RISCV::X0) ||
(ImmAVL && ImmAVL->getZExtValue() >= 1);
}

/// Helper to lower a reduction sequence of the form:
/// scalar = reduce_op vec, scalar_start
static SDValue lowerReductionSeq(unsigned RVVOpcode, SDValue StartValue, SDValue Vec, SDValue Mask, SDValue VL,
Expand All @@ -5808,7 +5815,8 @@ static SDValue lowerReductionSeq(unsigned RVVOpcode, SDValue StartValue, SDValue
SDValue InitialSplat =
lowerScalarSplat(SDValue(), StartValue, DAG.getConstant(1, DL, XLenVT),
M1VT, DL, DAG, Subtarget);
SDValue Reduction = DAG.getNode(RVVOpcode, DL, M1VT, DAG.getUNDEF(M1VT), Vec,
SDValue PassThru = hasNonZeroAVL(VL) ? DAG.getUNDEF(M1VT) : InitialSplat;
SDValue Reduction = DAG.getNode(RVVOpcode, DL, M1VT, PassThru, Vec,
InitialSplat, Mask, VL);
return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VecEltVT, Reduction,
DAG.getConstant(0, DL, XLenVT));
Expand Down Expand Up @@ -5951,29 +5959,17 @@ SDValue RISCVTargetLowering::lowerVPREDUCE(SDValue Op,
return SDValue();

MVT VecVT = VecEVT.getSimpleVT();
MVT VecEltVT = VecVT.getVectorElementType();
unsigned RVVOpcode = getRVVVPReductionOp(Op.getOpcode());

MVT ContainerVT = VecVT;
if (VecVT.isFixedLengthVector()) {
ContainerVT = getContainerForFixedLengthVector(VecVT);
auto ContainerVT = getContainerForFixedLengthVector(VecVT);
Vec = convertToScalableVector(ContainerVT, Vec, DAG, Subtarget);
}

SDValue VL = Op.getOperand(3);
SDValue Mask = Op.getOperand(2);

MVT M1VT = getLMUL1VT(ContainerVT);
MVT XLenVT = Subtarget.getXLenVT();
MVT ResVT = !VecVT.isInteger() || VecEltVT.bitsGE(XLenVT) ? VecEltVT : XLenVT;

SDValue StartSplat = lowerScalarSplat(SDValue(), Op.getOperand(0),
DAG.getConstant(1, DL, XLenVT), M1VT,
DL, DAG, Subtarget);
SDValue Reduction =
DAG.getNode(RVVOpcode, DL, M1VT, StartSplat, Vec, StartSplat, Mask, VL);
SDValue Elt0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ResVT, Reduction,
DAG.getConstant(0, DL, XLenVT));
SDValue Elt0 = lowerReductionSeq(RVVOpcode, Op.getOperand(0), Vec, Mask, VL,
DL, DAG, Subtarget);
if (!VecVT.isInteger())
return Elt0;
return DAG.getSExtOrTrunc(Elt0, DL, Op.getValueType());
Expand Down

0 comments on commit 1ebe8f4

Please sign in to comment.