diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index 472eceed4577a..ebb32bd599668 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -1040,7 +1040,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, setJumpIsExpensive(); setTargetDAGCombine({ISD::INTRINSIC_WO_CHAIN, ISD::ADD, ISD::SUB, ISD::AND, - ISD::OR, ISD::XOR, ISD::SETCC, ISD::SELECT, ISD::MUL}); + ISD::OR, ISD::XOR, ISD::SETCC, ISD::SELECT}); if (Subtarget.is64Bit()) setTargetDAGCombine(ISD::SRA); @@ -8644,134 +8644,6 @@ static SDValue combineDeMorganOfBoolean(SDNode *N, SelectionDAG &DAG) { return DAG.getNode(ISD::XOR, DL, VT, Logic, DAG.getConstant(1, DL, VT)); } -static SDValue performMULCombine(SDNode *N, SelectionDAG &DAG, - const RISCVSubtarget &Subtarget) { - SDLoc DL(N); - const MVT XLenVT = Subtarget.getXLenVT(); - const EVT VT = N->getValueType(0); - - // An MUL is usually smaller than any alternative sequence for legal type. - const TargetLowering &TLI = DAG.getTargetLoweringInfo(); - if (DAG.getMachineFunction().getFunction().hasMinSize() && - TLI.isOperationLegal(ISD::MUL, VT)) - return SDValue(); - - SDValue N0 = N->getOperand(0); - SDValue N1 = N->getOperand(1); - ConstantSDNode *ConstOp = dyn_cast(N1); - // Any optimization requires a constant RHS. - if (!ConstOp) - return SDValue(); - - const APInt &C = ConstOp->getAPIntValue(); - // A multiply-by-pow2 will be reduced to a shift by the - // architecture-independent code. - if (C.isPowerOf2()) - return SDValue(); - - // The below optimizations only work for non-negative constants - if (!C.isNonNegative()) - return SDValue(); - - auto Shl = [&](SDValue Value, unsigned ShiftAmount) { - if (!ShiftAmount) - return Value; - - SDValue ShiftAmountConst = DAG.getConstant(ShiftAmount, DL, XLenVT); - return DAG.getNode(ISD::SHL, DL, Value.getValueType(), Value, - ShiftAmountConst); - }; - auto Add = [&](SDValue Addend1, SDValue Addend2) { - return DAG.getNode(ISD::ADD, DL, Addend1.getValueType(), Addend1, Addend2); - }; - - if (Subtarget.hasVendorXTHeadBa()) { - // We try to simplify using shift-and-add instructions into up to - // 3 instructions (e.g. 2x shift-and-add and 1x shift). - - auto isDivisibleByShiftedAddConst = [&](APInt C, APInt &N, - APInt &Quotient) { - unsigned BitWidth = C.getBitWidth(); - for (unsigned i = 3; i >= 1; --i) { - APInt X(BitWidth, (1 << i) + 1); - APInt Remainder; - APInt::sdivrem(C, X, Quotient, Remainder); - if (Remainder == 0) { - N = X; - return true; - } - } - return false; - }; - auto isShiftedAddConst = [&](APInt C, APInt &N) { - APInt Quotient; - return isDivisibleByShiftedAddConst(C, N, Quotient) && Quotient == 1; - }; - auto isSmallShiftAmount = [](APInt C) { - return (C == 2) || (C == 4) || (C == 8); - }; - - auto ShiftAndAdd = [&](SDValue Value, unsigned ShiftAmount, - SDValue Addend) { - return Add(Shl(Value, ShiftAmount), Addend); - }; - auto AnyExt = [&](SDValue Value) { - return DAG.getNode(ISD::ANY_EXTEND, DL, XLenVT, Value); - }; - auto Trunc = [&](SDValue Value) { - return DAG.getNode(ISD::TRUNCATE, DL, VT, Value); - }; - - unsigned TrailingZeroes = C.countTrailingZeros(); - const APInt ShiftedC = C.ashr(TrailingZeroes); - const APInt ShiftedCMinusOne = ShiftedC - 1; - - // the below comments use the following notation: - // n, m .. a shift-amount for a shift-and-add instruction - // (i.e. in { 2, 4, 8 }) - // k .. a power-of-2 that is equivalent to shifting by - // TrailingZeroes bits - // i, j .. a power-of-2 - - APInt ShiftAmt1; - APInt ShiftAmt2; - APInt Quotient; - - // C = (m + 1) * k - if (isShiftedAddConst(ShiftedC, ShiftAmt1)) { - SDValue Op0 = AnyExt(N0); - SDValue Result = ShiftAndAdd(Op0, ShiftAmt1.logBase2(), Op0); - return Trunc(Shl(Result, TrailingZeroes)); - } - // C = (m + 1) * (n + 1) * k - if (isDivisibleByShiftedAddConst(ShiftedC, ShiftAmt1, Quotient) && - isShiftedAddConst(Quotient, ShiftAmt2)) { - SDValue Op0 = AnyExt(N0); - SDValue Result = ShiftAndAdd(Op0, ShiftAmt1.logBase2(), Op0); - Result = ShiftAndAdd(Result, ShiftAmt2.logBase2(), Result); - return Trunc(Shl(Result, TrailingZeroes)); - } - // C = ((m + 1) * n + 1) * k - if (isDivisibleByShiftedAddConst(ShiftedCMinusOne, ShiftAmt1, ShiftAmt2) && - isSmallShiftAmount(ShiftAmt2)) { - SDValue Op0 = AnyExt(N0); - SDValue Result = ShiftAndAdd(Op0, ShiftAmt1.logBase2(), Op0); - Result = ShiftAndAdd(Result, Quotient.logBase2(), Op0); - return Trunc(Shl(Result, TrailingZeroes)); - } - - // C has 2 bits set: synthesize using 2 shifts and 1 add (which may - // see one of the shifts merged into a shift-and-add, if feasible) - if (C.countPopulation() == 2) { - APInt HighBit(C.getBitWidth(), (1 << C.logBase2())); - APInt LowBit = C - HighBit; - return Add(Shl(N0, HighBit.logBase2()), Shl(N0, LowBit.logBase2())); - } - } - - return SDValue(); -} - static SDValue performTRUNCATECombine(SDNode *N, SelectionDAG &DAG, const RISCVSubtarget &Subtarget) { SDValue N0 = N->getOperand(0); @@ -10421,8 +10293,6 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N, return performADDCombine(N, DAG, Subtarget); case ISD::SUB: return performSUBCombine(N, DAG, Subtarget); - case ISD::MUL: - return performMULCombine(N, DAG, Subtarget); case ISD::AND: return performANDCombine(N, DCI, Subtarget); case ISD::OR: diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoXTHead.td b/llvm/lib/Target/RISCV/RISCVInstrInfoXTHead.td index 60d41093451ab..e3dbc670d514b 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoXTHead.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoXTHead.td @@ -240,6 +240,67 @@ def : Pat<(add sh2add_op:$rs1, non_imm12:$rs2), (TH_ADDSL GPR:$rs2, sh2add_op:$rs1, 2)>; def : Pat<(add sh3add_op:$rs1, non_imm12:$rs2), (TH_ADDSL GPR:$rs2, sh3add_op:$rs1, 3)>; + +def : Pat<(add (mul_oneuse GPR:$rs1, (XLenVT 6)), GPR:$rs2), + (TH_ADDSL GPR:$rs2, (TH_ADDSL GPR:$rs1, GPR:$rs1, 1), 1)>; +def : Pat<(add (mul_oneuse GPR:$rs1, (XLenVT 10)), GPR:$rs2), + (TH_ADDSL GPR:$rs2, (TH_ADDSL GPR:$rs1, GPR:$rs1, 2), 1)>; +def : Pat<(add (mul_oneuse GPR:$rs1, (XLenVT 18)), GPR:$rs2), + (TH_ADDSL GPR:$rs2, (TH_ADDSL GPR:$rs1, GPR:$rs1, 3), 1)>; +def : Pat<(add (mul_oneuse GPR:$rs1, (XLenVT 12)), GPR:$rs2), + (TH_ADDSL GPR:$rs2, (TH_ADDSL GPR:$rs1, GPR:$rs1, 1), 2)>; +def : Pat<(add (mul_oneuse GPR:$rs1, (XLenVT 20)), GPR:$rs2), + (TH_ADDSL GPR:$rs2, (TH_ADDSL GPR:$rs1, GPR:$rs1, 2), 2)>; +def : Pat<(add (mul_oneuse GPR:$rs1, (XLenVT 36)), GPR:$rs2), + (TH_ADDSL GPR:$rs2, (TH_ADDSL GPR:$rs1, GPR:$rs1, 3), 2)>; +def : Pat<(add (mul_oneuse GPR:$rs1, (XLenVT 24)), GPR:$rs2), + (TH_ADDSL GPR:$rs2, (TH_ADDSL GPR:$rs1, GPR:$rs1, 1), 3)>; +def : Pat<(add (mul_oneuse GPR:$rs1, (XLenVT 40)), GPR:$rs2), + (TH_ADDSL GPR:$rs2, (TH_ADDSL GPR:$rs1, GPR:$rs1, 2), 3)>; +def : Pat<(add (mul_oneuse GPR:$rs1, (XLenVT 72)), GPR:$rs2), + (TH_ADDSL GPR:$rs2, (TH_ADDSL GPR:$rs1, GPR:$rs1, 3), 3)>; + +def : Pat<(add GPR:$r, CSImm12MulBy4:$i), + (TH_ADDSL GPR:$r, (ADDI X0, (SimmShiftRightBy2XForm CSImm12MulBy4:$i)), 2)>; +def : Pat<(add GPR:$r, CSImm12MulBy8:$i), + (TH_ADDSL GPR:$r, (ADDI X0, (SimmShiftRightBy3XForm CSImm12MulBy8:$i)), 3)>; + +def : Pat<(mul GPR:$r, C3LeftShift:$i), + (SLLI (TH_ADDSL GPR:$r, GPR:$r, 1), + (TrailingZeros C3LeftShift:$i))>; +def : Pat<(mul GPR:$r, C5LeftShift:$i), + (SLLI (TH_ADDSL GPR:$r, GPR:$r, 2), + (TrailingZeros C5LeftShift:$i))>; +def : Pat<(mul GPR:$r, C9LeftShift:$i), + (SLLI (TH_ADDSL GPR:$r, GPR:$r, 3), + (TrailingZeros C9LeftShift:$i))>; + +def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 11)), + (TH_ADDSL GPR:$r, (TH_ADDSL GPR:$r, GPR:$r, 2), 1)>; +def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 19)), + (TH_ADDSL GPR:$r, (TH_ADDSL GPR:$r, GPR:$r, 3), 1)>; +def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 13)), + (TH_ADDSL GPR:$r, (TH_ADDSL GPR:$r, GPR:$r, 1), 2)>; +def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 21)), + (TH_ADDSL GPR:$r, (TH_ADDSL GPR:$r, GPR:$r, 2), 2)>; +def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 37)), + (TH_ADDSL GPR:$r, (TH_ADDSL GPR:$r, GPR:$r, 3), 2)>; +def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 25)), + (TH_ADDSL (TH_ADDSL GPR:$r, GPR:$r, 2), (TH_ADDSL GPR:$r, GPR:$r, 2), 2)>; +def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 41)), + (TH_ADDSL GPR:$r, (TH_ADDSL GPR:$r, GPR:$r, 2), 3)>; +def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 73)), + (TH_ADDSL GPR:$r, (TH_ADDSL GPR:$r, GPR:$r, 3), 3)>; +def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 27)), + (TH_ADDSL (TH_ADDSL GPR:$r, GPR:$r, 3), (TH_ADDSL GPR:$r, GPR:$r, 3), 1)>; +def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 45)), + (TH_ADDSL (TH_ADDSL GPR:$r, GPR:$r, 3), (TH_ADDSL GPR:$r, GPR:$r, 3), 2)>; +def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 81)), + (TH_ADDSL (TH_ADDSL GPR:$r, GPR:$r, 3), (TH_ADDSL GPR:$r, GPR:$r, 3), 3)>; + +def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 200)), + (SLLI (TH_ADDSL (TH_ADDSL GPR:$r, GPR:$r, 2), + (TH_ADDSL GPR:$r, GPR:$r, 2), 2), 3)>; } // Predicates = [HasVendorXTHeadBa] let Predicates = [HasVendorXTHeadBb] in {