diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index 1c930acd9c4a0..0af03cd2b5f1b 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -16485,6 +16485,35 @@ static SDValue expandMulToAddOrSubOfShl(SDNode *N, SelectionDAG &DAG, return DAG.getNode(Op, DL, VT, Shift1, Shift2); } +static SDValue getShlAddShlAdd(SDNode *N, SelectionDAG &DAG, unsigned ShX, + unsigned ShY) { + SDLoc DL(N); + EVT VT = N->getValueType(0); + SDValue X = N->getOperand(0); + SDValue Mul359 = DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X, + DAG.getConstant(ShY, DL, VT), X); + return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Mul359, + DAG.getConstant(ShX, DL, VT), Mul359); +} + +static SDValue expandMulToShlAddShlAdd(SDNode *N, SelectionDAG &DAG, + uint64_t MulAmt) { + switch (MulAmt) { + case 5 * 3: + return getShlAddShlAdd(N, DAG, 2, 1); + case 9 * 3: + return getShlAddShlAdd(N, DAG, 3, 1); + case 5 * 5: + return getShlAddShlAdd(N, DAG, 2, 2); + case 9 * 5: + return getShlAddShlAdd(N, DAG, 3, 2); + case 9 * 9: + return getShlAddShlAdd(N, DAG, 3, 3); + default: + return SDValue(); + } +} + // Try to expand a scalar multiply to a faster sequence. static SDValue expandMul(SDNode *N, SelectionDAG &DAG, TargetLowering::DAGCombinerInfo &DCI, @@ -16514,18 +16543,17 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG, if (Subtarget.hasVendorXqciac() && isInt<12>(CNode->getSExtValue())) return SDValue(); - // WARNING: The code below is knowingly incorrect with regards to undef semantics. - // We're adding additional uses of X here, and in principle, we should be freezing - // X before doing so. However, adding freeze here causes real regressions, and no - // other target properly freezes X in these cases either. - SDValue X = N->getOperand(0); - + // WARNING: The code below is knowingly incorrect with regards to undef + // semantics. We're adding additional uses of X here, and in principle, we + // should be freezing X before doing so. However, adding freeze here causes + // real regressions, and no other target properly freezes X in these cases + // either. if (Subtarget.hasShlAdd(3)) { + SDValue X = N->getOperand(0); int Shift; if (int ShXAmount = isShifted359(MulAmt, Shift)) { // 3/5/9 * 2^N -> shl (shXadd X, X), N SDLoc DL(N); - SDValue X = N->getOperand(0); // Put the shift first if we can fold a zext into the shift forming // a slli.uw. if (X.getOpcode() == ISD::AND && isa(X.getOperand(1)) && @@ -16544,38 +16572,8 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG, } // 3/5/9 * 3/5/9 -> shXadd (shYadd X, X), (shYadd X, X) - int ShX; - int ShY; - switch (MulAmt) { - case 3 * 5: - ShY = 1; - ShX = 2; - break; - case 3 * 9: - ShY = 1; - ShX = 3; - break; - case 5 * 5: - ShX = ShY = 2; - break; - case 5 * 9: - ShY = 2; - ShX = 3; - break; - case 9 * 9: - ShX = ShY = 3; - break; - default: - ShX = ShY = 0; - break; - } - if (ShX) { - SDLoc DL(N); - SDValue Mul359 = DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X, - DAG.getConstant(ShY, DL, VT), X); - return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Mul359, - DAG.getConstant(ShX, DL, VT), Mul359); - } + if (SDValue V = expandMulToShlAddShlAdd(N, DAG, MulAmt)) + return V; // If this is a power 2 + 2/4/8, we can use a shift followed by a single // shXadd. First check if this a sum of two power of 2s because that's @@ -16638,23 +16636,12 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG, } } - for (uint64_t Divisor : {3, 5, 9}) { - if (MulAmt % Divisor != 0) - continue; - uint64_t MulAmt2 = MulAmt / Divisor; - // 3/5/9 * 3/5/9 * 2^N - In particular, this covers multiples - // of 25 which happen to be quite common. - if (int ShBAmount = isShifted359(MulAmt2, Shift)) { - SDLoc DL(N); - SDValue Mul359A = - DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X, - DAG.getConstant(Log2_64(Divisor - 1), DL, VT), X); - SDValue Mul359B = - DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Mul359A, - DAG.getConstant(ShBAmount, DL, VT), Mul359A); - return DAG.getNode(ISD::SHL, DL, VT, Mul359B, - DAG.getConstant(Shift, DL, VT)); - } + // 3/5/9 * 3/5/9 * 2^N - In particular, this covers multiples + // of 25 which happen to be quite common. + Shift = llvm::countr_zero(MulAmt); + if (SDValue V = expandMulToShlAddShlAdd(N, DAG, MulAmt >> Shift)) { + SDLoc DL(N); + return DAG.getNode(ISD::SHL, DL, VT, V, DAG.getConstant(Shift, DL, VT)); } }