diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h
index 49d51a27e3c0f..a45e5b26a5819 100644
--- a/llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -710,6 +710,13 @@ enum NodeType {
   FSHL,
   FSHR,
 
+  // Represents (ADD (SHL a, b), c) with the arguments appearing in the order
+  // a, b, c. 'b' must be a constant, and follows the rules for shift amount
+  // types described just above. This is used solely post-legalization when
+  // lowering MUL to target-specific instructions - e.g. LEA on x86 or
+  // sh1add/sh2add/sh3add on RISC-V.
+  SHL_ADD,
+
   /// Byte Swap and Counting operators.
   BSWAP,
   CTTZ,
diff --git a/llvm/include/llvm/Target/TargetSelectionDAG.td b/llvm/include/llvm/Target/TargetSelectionDAG.td
index ea3520835fa07..aeef292d34b85 100644
--- a/llvm/include/llvm/Target/TargetSelectionDAG.td
+++ b/llvm/include/llvm/Target/TargetSelectionDAG.td
@@ -121,6 +121,10 @@ def SDTIntShiftOp : SDTypeProfile<1, 2, [   // shl, sra, srl
 def SDTIntShiftDOp: SDTypeProfile<1, 3, [   // fshl, fshr
   SDTCisSameAs<0, 1>, SDTCisSameAs<0, 2>, SDTCisInt<0>, SDTCisInt<3>
 ]>;
+def SDTIntShiftAddOp : SDTypeProfile<1, 3, [   // shl_add
+  SDTCisSameAs<0, 1>, SDTCisSameAs<0, 3>, SDTCisInt<0>, SDTCisInt<2>,
+  SDTCisInt<3>
+]>;
 def SDTIntSatNoShOp : SDTypeProfile<1, 2, [   // ssat with no shift
   SDTCisSameAs<0, 1>, SDTCisInt<2>
 ]>;
@@ -411,6 +415,7 @@ def rotl       : SDNode<"ISD::ROTL"      , SDTIntShiftOp>;
 def rotr       : SDNode<"ISD::ROTR"      , SDTIntShiftOp>;
 def fshl       : SDNode<"ISD::FSHL"      , SDTIntShiftDOp>;
 def fshr       : SDNode<"ISD::FSHR"      , SDTIntShiftDOp>;
+def shl_add    : SDNode<"ISD::SHL_ADD"   , SDTIntShiftAddOp>;
 def and        : SDNode<"ISD::AND"       , SDTIntBinOp,
                         [SDNPCommutative, SDNPAssociative]>;
 def or         : SDNode<"ISD::OR"        , SDTIntBinOp,
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index ca0a95750ba8d..0e8297fdb020b 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -3521,6 +3521,13 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
     Known = KnownBits::ashr(Known, Known2, /*ShAmtNonZero=*/false,
                             Op->getFlags().hasExact());
     break;
+  case ISD::SHL_ADD:
+    Known = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
+    Known2 = computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
+    Known = KnownBits::computeForAddSub(
+        true, false, false, KnownBits::shl(Known, Known2),
+        computeKnownBits(Op.getOperand(2), DemandedElts, Depth + 1));
+    break;
   case ISD::FSHL:
   case ISD::FSHR:
     if (ConstantSDNode *C = isConstOrConstSplat(Op.getOperand(2), DemandedElts)) {
@@ -7346,6 +7353,11 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
     if (N1.getValueType() == VT)
       return N1;
     break;
+  case ISD::SHL_ADD:
+    assert(VT == N1.getValueType() && VT == N3.getValueType());
+    assert(TLI->isTypeLegal(VT) && "Created only post legalize");
+    assert(isa<ConstantSDNode>(N2) && "Constant shift expected");
+    break;
   }
 
   // Memoize node if it doesn't produce a glue result.
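Reviewer note (annotation, not part of the patch): ISD::SHL_ADD computes (a << b) + c, which is why the computeKnownBits case above composes KnownBits::shl with KnownBits::computeForAddSub. A minimal scalar model of the semantics, as standalone C++; the helper name shl_add here is illustrative only:

    #include <cassert>
    #include <cstdint>

    // Scalar model of ISD::SHL_ADD: (a << b) + c, where b is a constant
    // that obeys the usual shift-amount rules.
    static uint64_t shl_add(uint64_t a, unsigned b, uint64_t c) {
      assert(b < 64 && "shift amount must be in range");
      return (a << b) + c;
    }

    int main() {
      uint64_t x = 7;
      // x * 9 can be lowered as shl_add(x, 3, x) == (x << 3) + x.
      assert(shl_add(x, 3, x) == 9 * x);
      return 0;
    }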
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
index 6691aa41face3..cc9dafcfa0c72 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
@@ -264,6 +264,7 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
   case ISD::SRL:                        return "srl";
   case ISD::ROTL:                       return "rotl";
   case ISD::ROTR:                       return "rotr";
+  case ISD::SHL_ADD:                    return "shl_add";
   case ISD::FSHL:                       return "fshl";
   case ISD::FSHR:                       return "fshr";
   case ISD::FADD:                       return "fadd";
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 27387595164a4..cc64ccbedee92 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -12789,10 +12789,9 @@ static SDValue transformAddShlImm(SDNode *N, SelectionDAG &DAG,
   SDLoc DL(N);
   SDValue NS = (C0 < C1) ? N0->getOperand(0) : N1->getOperand(0);
   SDValue NL = (C0 > C1) ? N0->getOperand(0) : N1->getOperand(0);
-  SDValue NA0 =
-      DAG.getNode(ISD::SHL, DL, VT, NL, DAG.getConstant(Diff, DL, VT));
-  SDValue NA1 = DAG.getNode(ISD::ADD, DL, VT, NA0, NS);
-  return DAG.getNode(ISD::SHL, DL, VT, NA1, DAG.getConstant(Bits, DL, VT));
+  SDValue SHADD =
+      DAG.getNode(ISD::SHL_ADD, DL, VT, NL, DAG.getConstant(Diff, DL, VT), NS);
+  return DAG.getNode(ISD::SHL, DL, VT, SHADD, DAG.getConstant(Bits, DL, VT));
 }
 
 // Combine a constant select operand into its use:
@@ -13028,14 +13027,17 @@ static SDValue combineAddOfBooleanXor(SDNode *N, SelectionDAG &DAG) {
                      N0.getOperand(0));
 }
 
-static SDValue performADDCombine(SDNode *N, SelectionDAG &DAG,
+static SDValue performADDCombine(SDNode *N,
+                                 TargetLowering::DAGCombinerInfo &DCI,
                                  const RISCVSubtarget &Subtarget) {
+  SelectionDAG &DAG = DCI.DAG;
   if (SDValue V = combineAddOfBooleanXor(N, DAG))
     return V;
   if (SDValue V = transformAddImmMulImm(N, DAG, Subtarget))
     return V;
-  if (SDValue V = transformAddShlImm(N, DAG, Subtarget))
-    return V;
+  if (!DCI.isBeforeLegalize())
+    if (SDValue V = transformAddShlImm(N, DAG, Subtarget))
+      return V;
   if (SDValue V = combineBinOpToReduce(N, DAG, Subtarget))
     return V;
   if (SDValue V = combineBinOpOfExtractToReduceTree(N, DAG, Subtarget))
@@ -15894,7 +15896,7 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
       return V;
     if (SDValue V = combineToVWMACC(N, DAG, Subtarget))
       return V;
-    return performADDCombine(N, DAG, Subtarget);
+    return performADDCombine(N, DCI, Subtarget);
   }
   case ISD::SUB: {
     if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget))
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoZb.td b/llvm/lib/Target/RISCV/RISCVInstrInfoZb.td
index 434b071e628a0..8837c66d60377 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoZb.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoZb.td
@@ -678,6 +678,8 @@ foreach i = {1,2,3} in {
   defvar shxadd = !cast<Instruction>("SH"#i#"ADD");
   def : Pat<(XLenVT (add_like_non_imm12 (shl GPR:$rs1, (XLenVT i)), GPR:$rs2)),
             (shxadd GPR:$rs1, GPR:$rs2)>;
+  def : Pat<(XLenVT (shl_add GPR:$rs1, (XLenVT i), GPR:$rs2)),
+            (shxadd GPR:$rs1, GPR:$rs2)>;
 
   defvar pat = !cast<ComplexPattern>("sh"#i#"add_op");
   // More complex cases use a ComplexPattern.
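Reviewer note (annotation, not part of the patch): transformAddShlImm now folds the smaller shift into a single SHL_ADD, rewriting (x << c0) + (y << c1) as ((nl << |c0 - c1|) + ns) << min(c0, c1), where nl/ns carry the larger/smaller shift; on Zba, when the difference is 1, 2, or 3, the inner node matches shNadd and the outer shift becomes an slli. A quick standalone check of the algebra (C++; the variable names are illustrative, not from the patch):

    #include <cassert>
    #include <cstdint>

    int main() {
      uint64_t x = 123, y = 456;
      unsigned c0 = 3, c1 = 1; // c0 > c1: diff = 2 -> sh2add, then slli by 1
      uint64_t original = (x << c0) + (y << c1);
      uint64_t rewritten = ((x << (c0 - c1)) + y) << c1; // shl_add, then shl
      assert(original == rewritten);
      return 0;
    }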
@@ -881,6 +883,9 @@ foreach i = {1,2,3} in {
   defvar shxadd = !cast<Instruction>("SH"#i#"ADD");
   def : Pat<(i32 (add_like_non_imm12 (shl GPR:$rs1, (i64 i)), GPR:$rs2)),
             (shxadd GPR:$rs1, GPR:$rs2)>;
+  def : Pat<(i32 (shl_add GPR:$rs1, (i32 i), GPR:$rs2)),
+            (shxadd GPR:$rs1, GPR:$rs2)>;
+
 }
 }
diff --git a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
index 4e4241efd63d6..49baebee62652 100644
--- a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
+++ b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
@@ -2519,7 +2519,6 @@ bool X86DAGToDAGISel::matchAddressRecursively(SDValue N, X86ISelAddressMode &AM,
     if (N.getResNo() != 0) break;
     [[fallthrough]];
   case ISD::MUL:
-  case X86ISD::MUL_IMM:
     // X*[3,5,9] -> X+X*[2,4,8]
     if (AM.BaseType == X86ISelAddressMode::RegBase &&
         AM.Base_Reg.getNode() == nullptr &&
@@ -2551,7 +2550,44 @@ bool X86DAGToDAGISel::matchAddressRecursively(SDValue N, X86ISelAddressMode &AM,
       }
     }
     break;
-
+  case ISD::SHL_ADD: {
+    // X << [1,2,3] + Y (we should never create anything else)
+    auto *CN = cast<ConstantSDNode>(N.getOperand(1));
+    assert(CN->getZExtValue() == 1 || CN->getZExtValue() == 2 ||
+           CN->getZExtValue() == 3);
+    if (AM.BaseType == X86ISelAddressMode::RegBase &&
+        AM.Base_Reg.getNode() == nullptr && AM.IndexReg.getNode() == nullptr) {
+      AM.Scale = unsigned(2 << (CN->getZExtValue() - 1));
+
+      if (N.getOperand(0) == N.getOperand(2)) {
+        SDValue MulVal = N.getOperand(0);
+        SDValue Reg;
+
+        // Okay, we know that we have a scale by now. However, if the scaled
+        // value is an add of something and a constant, we can fold the
+        // constant into the disp field here.
+        if (MulVal.getNode()->getOpcode() == ISD::ADD &&
+            N->isOnlyUserOf(MulVal.getNode()) &&
+            isa<ConstantSDNode>(MulVal.getOperand(1))) {
+          Reg = MulVal.getOperand(0);
+          auto *AddVal = cast<ConstantSDNode>(MulVal.getOperand(1));
+          uint64_t Disp = AddVal->getSExtValue() * (AM.Scale + 1);
+          if (foldOffsetIntoAddress(Disp, AM))
+            Reg = N.getOperand(0);
+        } else {
+          Reg = N.getOperand(0);
+        }
+        AM.IndexReg = AM.Base_Reg = Reg;
+        return false;
+      }
+      // TODO: If N.getOperand(2) is a constant, we could try folding
+      // the displacement analogously to the above.
+      AM.IndexReg = N.getOperand(0);
+      AM.Base_Reg = N.getOperand(2);
+      return false;
+    }
+    break;
+  }
   case ISD::SUB: {
     // Given A-B, if A can be completely folded into the address and
     // the index field with the index field unused, use -B as the index.
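Reviewer note (annotation, not part of the patch): in the SHL_ADD case above, shl_add(x, s, x) equals x * (2^s + 1), so the matcher sets AM.Scale = 2 << (s - 1) and uses x as both base and index register; when x is itself (r + k) with k constant, k folds into the displacement scaled by (Scale + 1). A standalone C++ sketch of that arithmetic, under my reading of the matcher (illustrative names):

    #include <cassert>
    #include <cstdint>

    int main() {
      uint64_t r = 1000, k = 5;       // x = r + k, k is a foldable constant
      unsigned s = 1;                 // shl_add(x, 1, x) == x * 3
      unsigned scale = 2u << (s - 1); // s: 1 -> 2, 2 -> 4, 3 -> 8
      uint64_t x = r + k;
      uint64_t value = (x << s) + x;
      // LEA-style decomposition: base + index*scale + disp,
      // with base = index = r and disp = k * (scale + 1).
      uint64_t lea = r + r * scale + k * (scale + 1);
      assert(value == lea);
      return 0;
    }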
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index f16a751a166d6..1748b372e29b5 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -33553,7 +33553,6 @@ const char *X86TargetLowering::getTargetNodeName(unsigned Opcode) const {
   NODE_NAME_CASE(BZHI)
   NODE_NAME_CASE(PDEP)
   NODE_NAME_CASE(PEXT)
-  NODE_NAME_CASE(MUL_IMM)
   NODE_NAME_CASE(MOVMSK)
   NODE_NAME_CASE(PTEST)
   NODE_NAME_CASE(TESTP)
@@ -36845,13 +36844,6 @@ void X86TargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
   Known.resetAll();
   switch (Opc) {
   default: break;
-  case X86ISD::MUL_IMM: {
-    KnownBits Known2;
-    Known = DAG.computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
-    Known2 = DAG.computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
-    Known = KnownBits::mul(Known, Known2);
-    break;
-  }
   case X86ISD::SETCC:
     Known.Zero.setBitsFrom(1);
     break;
@@ -46905,12 +46897,18 @@ static SDValue reduceVMULWidth(SDNode *N, SelectionDAG &DAG,
   return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, ResLo, ResHi);
 }
 
+static SDValue createMulImm(uint64_t MulAmt, SDValue N, SelectionDAG &DAG,
+                            EVT VT, const SDLoc &DL) {
+  assert(MulAmt == 3 || MulAmt == 5 || MulAmt == 9);
+  SDValue ShAmt = DAG.getConstant(Log2_64(MulAmt - 1), DL, MVT::i8);
+  return DAG.getNode(ISD::SHL_ADD, DL, VT, N, ShAmt, N);
+}
+
 static SDValue combineMulSpecial(uint64_t MulAmt, SDNode *N, SelectionDAG &DAG,
                                  EVT VT, const SDLoc &DL) {
   auto combineMulShlAddOrSub = [&](int Mult, int Shift, bool isAdd) {
-    SDValue Result = DAG.getNode(X86ISD::MUL_IMM, DL, VT, N->getOperand(0),
-                                 DAG.getConstant(Mult, DL, VT));
+    SDValue Result = createMulImm(Mult, N->getOperand(0), DAG, VT, DL);
     Result = DAG.getNode(ISD::SHL, DL, VT, Result,
                          DAG.getConstant(Shift, DL, MVT::i8));
     Result = DAG.getNode(isAdd ? ISD::ADD : ISD::SUB, DL, VT, Result,
@@ -46919,10 +46917,8 @@ static SDValue combineMulSpecial(uint64_t MulAmt, SDNode *N, SelectionDAG &DAG,
   };
   auto combineMulMulAddOrSub = [&](int Mul1, int Mul2, bool isAdd) {
-    SDValue Result = DAG.getNode(X86ISD::MUL_IMM, DL, VT, N->getOperand(0),
-                                 DAG.getConstant(Mul1, DL, VT));
-    Result = DAG.getNode(X86ISD::MUL_IMM, DL, VT, Result,
-                         DAG.getConstant(Mul2, DL, VT));
+    SDValue Result = createMulImm(Mul1, N->getOperand(0), DAG, VT, DL);
+    Result = createMulImm(Mul2, Result, DAG, VT, DL);
     Result = DAG.getNode(isAdd ? ISD::ADD : ISD::SUB, DL, VT, Result,
                          N->getOperand(0));
     return Result;
@@ -46982,9 +46978,8 @@ static SDValue combineMulSpecial(uint64_t MulAmt, SDNode *N, SelectionDAG &DAG,
     unsigned ShiftAmt = Log2_64((MulAmt & (MulAmt - 1)));
     SDValue Shift1 = DAG.getNode(ISD::SHL, DL, VT, N->getOperand(0),
                                  DAG.getConstant(ShiftAmt, DL, MVT::i8));
-    SDValue Shift2 = DAG.getNode(ISD::SHL, DL, VT, N->getOperand(0),
-                                 DAG.getConstant(ScaleShift, DL, MVT::i8));
-    return DAG.getNode(ISD::ADD, DL, VT, Shift1, Shift2);
+    return DAG.getNode(ISD::SHL_ADD, DL, VT, N->getOperand(0),
+                       DAG.getConstant(ScaleShift, DL, MVT::i8), Shift1);
   }
 }
 
@@ -47204,8 +47199,7 @@ static SDValue combineMul(SDNode *N, SelectionDAG &DAG,
   SDValue NewMul = SDValue();
   if (VT == MVT::i64 || VT == MVT::i32) {
     if (AbsMulAmt == 3 || AbsMulAmt == 5 || AbsMulAmt == 9) {
-      NewMul = DAG.getNode(X86ISD::MUL_IMM, DL, VT, N->getOperand(0),
-                           DAG.getConstant(AbsMulAmt, DL, VT));
+      NewMul = createMulImm(AbsMulAmt, N->getOperand(0), DAG, VT, DL);
       if (SignMulAmt < 0)
         NewMul = DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT),
                              NewMul);
@@ -47243,15 +47237,13 @@ static SDValue combineMul(SDNode *N, SelectionDAG &DAG,
       NewMul = DAG.getNode(ISD::SHL, DL, VT, N->getOperand(0),
                            DAG.getConstant(Log2_64(MulAmt1), DL, MVT::i8));
     else
-      NewMul = DAG.getNode(X86ISD::MUL_IMM, DL, VT, N->getOperand(0),
-                           DAG.getConstant(MulAmt1, DL, VT));
+      NewMul = createMulImm(MulAmt1, N->getOperand(0), DAG, VT, DL);
 
     if (isPowerOf2_64(MulAmt2))
       NewMul = DAG.getNode(ISD::SHL, DL, VT, NewMul,
                            DAG.getConstant(Log2_64(MulAmt2), DL, MVT::i8));
     else
-      NewMul = DAG.getNode(X86ISD::MUL_IMM, DL, VT, NewMul,
-                           DAG.getConstant(MulAmt2, DL, VT));
+      NewMul = createMulImm(MulAmt2, NewMul, DAG, VT, DL);
 
     // Negate the result.
     if (SignMulAmt < 0)
diff --git a/llvm/lib/Target/X86/X86ISelLowering.h b/llvm/lib/Target/X86/X86ISelLowering.h
index 0a1e8ca442731..7c5bfac3308c8 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.h
+++ b/llvm/lib/Target/X86/X86ISelLowering.h
@@ -417,9 +417,6 @@ namespace llvm {
     PDEP,
     PEXT,
 
-    // X86-specific multiply by immediate.
-    MUL_IMM,
-
     // Vector sign bit extraction.
     MOVMSK,
 
diff --git a/llvm/lib/Target/X86/X86InstrFragments.td b/llvm/lib/Target/X86/X86InstrFragments.td
index f14c7200af968..faeeccab7dac7 100644
--- a/llvm/lib/Target/X86/X86InstrFragments.td
+++ b/llvm/lib/Target/X86/X86InstrFragments.td
@@ -284,8 +284,6 @@ def X86bzhi : SDNode<"X86ISD::BZHI", SDTIntBinOp>;
 def X86pdep : SDNode<"X86ISD::PDEP", SDTIntBinOp>;
 def X86pext : SDNode<"X86ISD::PEXT", SDTIntBinOp>;
 
-def X86mul_imm : SDNode<"X86ISD::MUL_IMM", SDTIntBinOp>;
-
 def X86DynAlloca : SDNode<"X86ISD::DYN_ALLOCA", SDT_X86DYN_ALLOCA,
                           [SDNPHasChain, SDNPOutGlue]>;
 
@@ -341,11 +339,11 @@ def X86cmpccxadd : SDNode<"X86ISD::CMPCCXADD", SDTX86Cmpccxadd,
 // Define X86-specific addressing mode.
 def addr      : ComplexPattern<iPTR, 5, "selectAddr", [], [SDNPWantParent]>;
 def lea32addr : ComplexPattern<i32, 5, "selectLEAAddr",
-                               [add, sub, mul, X86mul_imm, shl, or, xor,
+                               [add, sub, mul, shl_add, shl, or, xor,
                                 frameindex],
                                []>;
 // In 64-bit mode 32-bit LEAs can use RIP-relative addressing.
 def lea64_32addr : ComplexPattern<i32, 5, "selectLEA64_32Addr",
-                                  [add, sub, mul, X86mul_imm, shl, or,
+                                  [add, sub, mul, shl_add, shl, or,
                                    frameindex, X86WrapperRIP],
                                   []>;
@@ -356,7 +354,7 @@ def tls32baseaddr : ComplexPattern<i32, 5, "selectTLSADDRAddr",
 
 def lea64addr : ComplexPattern<i64, 5, "selectLEAAddr",
-                               [add, sub, mul, X86mul_imm, shl, or, xor,
+                               [add, sub, mul, shl_add, shl, or, xor,
                                 frameindex, X86WrapperRIP],
                                []>;
 
 def tls64addr : ComplexPattern
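Reviewer note (annotation, not part of the patch): createMulImm derives the shift amount from the multiplier via Log2_64(MulAmt - 1), mapping 3/5/9 to shifts 1/2/3, and combineMul chains two such nodes when the multiplier factors into those values (e.g. 45 = 9 * 5). A standalone C++ check of that arithmetic; the log2u helper is illustrative, not from the patch:

    #include <cassert>
    #include <cstdint>

    // Portable log2 for the small powers of two used here.
    static unsigned log2u(uint64_t v) {
      unsigned r = 0;
      while (v >>= 1)
        ++r;
      return r;
    }

    int main() {
      uint64_t x = 42;
      for (uint64_t amt : {3, 5, 9}) {
        unsigned sh = log2u(amt - 1); // 3 -> 1, 5 -> 2, 9 -> 3
        assert(((x << sh) + x) == x * amt);
      }
      // Two shl_add steps cover MulAmt1 * MulAmt2, e.g. x*45 == (x*9)*5.
      uint64_t x9 = (x << 3) + x;
      assert(((x9 << 2) + x9) == x * 45);
      return 0;
    }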