Skip to content

Commit

Permalink
[RISCV] Use lowerScalarInsert when folding op into reduction [nfc]
Browse files Browse the repository at this point in the history
This doesn't cause any functional change since this is being applied to a insert generated by the same routine.  This is mostly about consolidating the logic for vmv.s.x into one place to simplify future changes.
  • Loading branch information
preames committed Dec 13, 2022
1 parent 44e0427 commit ecabba0
Showing 1 changed file with 9 additions and 17 deletions.
26 changes: 9 additions & 17 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Expand Up @@ -8007,7 +8007,8 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
}

// Try to fold (<bop> x, (reduction.<bop> vec, start))
static SDValue combineBinOpToReduce(SDNode *N, SelectionDAG &DAG) {
static SDValue combineBinOpToReduce(SDNode *N, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
auto BinOpToRVVReduce = [](unsigned Opc) {
switch (Opc) {
default:
Expand Down Expand Up @@ -8084,20 +8085,11 @@ static SDValue combineBinOpToReduce(SDNode *N, SelectionDAG &DAG) {
if (!ScalarV.hasOneUse())
return SDValue();

EVT SplatVT = ScalarV.getValueType();
SDValue NewStart = N->getOperand(1 - ReduceIdx);
unsigned SplatOpc = RISCVISD::VFMV_S_F_VL;
if (SplatVT.isInteger()) {
auto *C = dyn_cast<ConstantSDNode>(NewStart.getNode());
if (!C || C->isZero() || !isInt<5>(C->getSExtValue()))
SplatOpc = RISCVISD::VMV_S_X_VL;
else
SplatOpc = RISCVISD::VMV_V_X_VL;
}

SDValue NewScalarV =
DAG.getNode(SplatOpc, SDLoc(N), SplatVT, ScalarV.getOperand(0), NewStart,
ScalarV.getOperand(2));
lowerScalarInsert(NewStart, ScalarV.getOperand(2), ScalarV.getSimpleValueType(),
SDLoc(N), DAG, Subtarget);
SDValue NewReduce =
DAG.getNode(Reduce.getOpcode(), SDLoc(Reduce), Reduce.getValueType(),
Reduce.getOperand(0), Reduce.getOperand(1), NewScalarV,
Expand Down Expand Up @@ -8299,7 +8291,7 @@ static SDValue performADDCombine(SDNode *N, SelectionDAG &DAG,
return V;
if (SDValue V = transformAddShlImm(N, DAG, Subtarget))
return V;
if (SDValue V = combineBinOpToReduce(N, DAG))
if (SDValue V = combineBinOpToReduce(N, DAG, Subtarget))
return V;
// fold (add (select lhs, rhs, cc, 0, y), x) ->
// (select lhs, rhs, cc, x, (add x, y))
Expand Down Expand Up @@ -8453,7 +8445,7 @@ static SDValue performANDCombine(SDNode *N,
return DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, And);
}

if (SDValue V = combineBinOpToReduce(N, DAG))
if (SDValue V = combineBinOpToReduce(N, DAG, Subtarget))
return V;

if (DCI.isAfterLegalizeDAG())
Expand All @@ -8469,7 +8461,7 @@ static SDValue performORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
const RISCVSubtarget &Subtarget) {
SelectionDAG &DAG = DCI.DAG;

if (SDValue V = combineBinOpToReduce(N, DAG))
if (SDValue V = combineBinOpToReduce(N, DAG, Subtarget))
return V;

if (DCI.isAfterLegalizeDAG())
Expand Down Expand Up @@ -8497,7 +8489,7 @@ static SDValue performXORCombine(SDNode *N, SelectionDAG &DAG,
DAG.getConstant(~1, DL, MVT::i64), N0.getOperand(1));
}

if (SDValue V = combineBinOpToReduce(N, DAG))
if (SDValue V = combineBinOpToReduce(N, DAG, Subtarget))
return V;
// fold (xor (select cond, 0, y), x) ->
// (select cond, x, (xor x, y))
Expand Down Expand Up @@ -9717,7 +9709,7 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
case ISD::SMIN:
case ISD::FMAXNUM:
case ISD::FMINNUM:
return combineBinOpToReduce(N, DAG);
return combineBinOpToReduce(N, DAG, Subtarget);
case ISD::SETCC:
return performSETCCCombine(N, DAG, Subtarget);
case ISD::SIGN_EXTEND_INREG:
Expand Down

0 comments on commit ecabba0

Please sign in to comment.