From ecabba04a35432ad94447d199cf6127d57415456 Mon Sep 17 00:00:00 2001 From: Philip Reames Date: Tue, 13 Dec 2022 09:02:01 -0800 Subject: [PATCH] [RISCV] Use lowerScalarInsert when folding op into reduction [nfc] This doesn't cause any functional change since this is being applied to a insert generated by the same routine. This is mostly about consolidating the logic for vmv.s.x into one place to simplify future changes. --- llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 26 +++++++-------------- 1 file changed, 9 insertions(+), 17 deletions(-) diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index 02345d89788a4..df8ec9a2927a6 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -8007,7 +8007,8 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N, } // Try to fold ( x, (reduction. vec, start)) -static SDValue combineBinOpToReduce(SDNode *N, SelectionDAG &DAG) { +static SDValue combineBinOpToReduce(SDNode *N, SelectionDAG &DAG, + const RISCVSubtarget &Subtarget) { auto BinOpToRVVReduce = [](unsigned Opc) { switch (Opc) { default: @@ -8084,20 +8085,11 @@ static SDValue combineBinOpToReduce(SDNode *N, SelectionDAG &DAG) { if (!ScalarV.hasOneUse()) return SDValue(); - EVT SplatVT = ScalarV.getValueType(); SDValue NewStart = N->getOperand(1 - ReduceIdx); - unsigned SplatOpc = RISCVISD::VFMV_S_F_VL; - if (SplatVT.isInteger()) { - auto *C = dyn_cast(NewStart.getNode()); - if (!C || C->isZero() || !isInt<5>(C->getSExtValue())) - SplatOpc = RISCVISD::VMV_S_X_VL; - else - SplatOpc = RISCVISD::VMV_V_X_VL; - } SDValue NewScalarV = - DAG.getNode(SplatOpc, SDLoc(N), SplatVT, ScalarV.getOperand(0), NewStart, - ScalarV.getOperand(2)); + lowerScalarInsert(NewStart, ScalarV.getOperand(2), ScalarV.getSimpleValueType(), + SDLoc(N), DAG, Subtarget); SDValue NewReduce = DAG.getNode(Reduce.getOpcode(), SDLoc(Reduce), Reduce.getValueType(), Reduce.getOperand(0), Reduce.getOperand(1), NewScalarV, @@ -8299,7 +8291,7 @@ static SDValue performADDCombine(SDNode *N, SelectionDAG &DAG, return V; if (SDValue V = transformAddShlImm(N, DAG, Subtarget)) return V; - if (SDValue V = combineBinOpToReduce(N, DAG)) + if (SDValue V = combineBinOpToReduce(N, DAG, Subtarget)) return V; // fold (add (select lhs, rhs, cc, 0, y), x) -> // (select lhs, rhs, cc, x, (add x, y)) @@ -8453,7 +8445,7 @@ static SDValue performANDCombine(SDNode *N, return DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, And); } - if (SDValue V = combineBinOpToReduce(N, DAG)) + if (SDValue V = combineBinOpToReduce(N, DAG, Subtarget)) return V; if (DCI.isAfterLegalizeDAG()) @@ -8469,7 +8461,7 @@ static SDValue performORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, const RISCVSubtarget &Subtarget) { SelectionDAG &DAG = DCI.DAG; - if (SDValue V = combineBinOpToReduce(N, DAG)) + if (SDValue V = combineBinOpToReduce(N, DAG, Subtarget)) return V; if (DCI.isAfterLegalizeDAG()) @@ -8497,7 +8489,7 @@ static SDValue performXORCombine(SDNode *N, SelectionDAG &DAG, DAG.getConstant(~1, DL, MVT::i64), N0.getOperand(1)); } - if (SDValue V = combineBinOpToReduce(N, DAG)) + if (SDValue V = combineBinOpToReduce(N, DAG, Subtarget)) return V; // fold (xor (select cond, 0, y), x) -> // (select cond, x, (xor x, y)) @@ -9717,7 +9709,7 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N, case ISD::SMIN: case ISD::FMAXNUM: case ISD::FMINNUM: - return combineBinOpToReduce(N, DAG); + return combineBinOpToReduce(N, DAG, Subtarget); case ISD::SETCC: return performSETCCCombine(N, DAG, Subtarget); case ISD::SIGN_EXTEND_INREG: