diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp index 3492e60662380..76ecd4fccfd85 100644 --- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp @@ -1871,6 +1871,38 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) { CurDAG->RemoveDeadNode(Node); return; } + case RISCVISD::WSLL: + case RISCVISD::WSLA: { + // Custom select WSLL/WSLA for RV32P. + assert(Subtarget->hasStdExtP() && !Subtarget->is64Bit() && VT == MVT::i32 && + "Unexpected opcode"); + + bool IsSigned = Node->getOpcode() == RISCVISD::WSLA; + + SDValue ShAmt = Node->getOperand(1); + + unsigned Opc; + + // Prefer the immediate form when the shift amount is a constant below 64. + auto *ShAmtC = dyn_cast<ConstantSDNode>(ShAmt); + if (ShAmtC && ShAmtC->getZExtValue() < 64) { + Opc = IsSigned ? RISCV::WSLAI : RISCV::WSLLI; + ShAmt = CurDAG->getTargetConstant(ShAmtC->getZExtValue(), DL, XLenVT); + } else { + Opc = IsSigned ? RISCV::WSLA : RISCV::WSLL; + } + + SDNode *WShift = CurDAG->getMachineNode(Opc, DL, MVT::Untyped, + Node->getOperand(0), ShAmt); + + // The shift has a single untyped result; extractGPRPair splits it into the + // Lo/Hi values that replace the node's two results. + auto [Lo, Hi] = extractGPRPair(CurDAG, DL, SDValue(WShift, 0)); + ReplaceUses(SDValue(Node, 0), Lo); + ReplaceUses(SDValue(Node, 1), Hi); + CurDAG->RemoveDeadNode(Node); + return; + } case ISD::LOAD: { if (tryIndexedLoad(Node)) return; diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index 891bc22a7463d..29ff12aa96efd 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -15418,6 +15418,32 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N, assert(!Subtarget.is64Bit() && Subtarget.hasStdExtP() && "Unexpected custom legalisation"); + SDValue LHS = N->getOperand(0); + SDValue ShAmt = N->getOperand(1); + + // WSLL/WSLA are widening left shifts, so only ISD::SHL may use them; + // SRL/SRA must fall through to the handling below. + unsigned WideOpc = 0; + if (N->getOpcode() == ISD::SHL) { + APInt HighMask = APInt::getHighBitsSet(64, 32); + if (DAG.MaskedValueIsZero(LHS, HighMask)) + WideOpc = RISCVISD::WSLL; + else if (DAG.ComputeMaxSignificantBits(LHS) <= 32) + WideOpc = RISCVISD::WSLA; + } + + // Shift the truncated LHS with the widening node and rebuild the result + // from its two i32 halves. + if (WideOpc) 
{ + SDValue Res = + DAG.getNode(WideOpc, DL, DAG.getVTList(MVT::i32, MVT::i32), + DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, LHS), + DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, ShAmt)); + Results.push_back(DAG.getNode(ISD::BUILD_PAIR, DL, N->getValueType(0), + Res, Res.getValue(1))); + return; + } + // Only handle constant shifts < 32. Non-constant shifts are handled by // lowerShiftLeftParts/lowerShiftRightParts, and shifts >= 32 use default // legalization. @@ -15425,22 +15451,22 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N, if (!ShAmtC || ShAmtC->getZExtValue() >= 32) break; - auto [Lo, Hi] = DAG.SplitScalar(N->getOperand(0), DL, MVT::i32, MVT::i32); + auto [Lo, Hi] = DAG.SplitScalar(LHS, DL, MVT::i32, MVT::i32); SDValue LoRes, HiRes; if (N->getOpcode() == ISD::SHL) { // Lo = slli Lo, shamt // Hi = nsrli {Hi, Lo}, (32 - shamt) uint64_t ShAmtVal = ShAmtC->getZExtValue(); - LoRes = DAG.getNode(ISD::SHL, DL, MVT::i32, Lo, N->getOperand(1)); + LoRes = DAG.getNode(ISD::SHL, DL, MVT::i32, Lo, ShAmt); HiRes = DAG.getNode(RISCVISD::NSRL, DL, MVT::i32, Lo, Hi, DAG.getConstant(32 - ShAmtVal, DL, MVT::i32)); } else { bool IsSRA = N->getOpcode() == ISD::SRA; LoRes = DAG.getNode(IsSRA ? RISCVISD::NSRA : RISCVISD::NSRL, DL, - MVT::i32, Lo, Hi, N->getOperand(1)); - HiRes = DAG.getNode(IsSRA ? ISD::SRA : ISD::SRL, DL, MVT::i32, Hi, - N->getOperand(1)); + MVT::i32, Lo, Hi, ShAmt); + HiRes = + DAG.getNode(IsSRA ? 
ISD::SRA : ISD::SRL, DL, MVT::i32, Hi, ShAmt); } SDValue Res = DAG.getNode(ISD::BUILD_PAIR, DL, MVT::i64, LoRes, HiRes); Results.push_back(Res); diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoP.td b/llvm/lib/Target/RISCV/RISCVInstrInfoP.td index 7bb9ad5feb219..ea16cf28bfd7c 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoP.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoP.td @@ -1633,6 +1633,15 @@ def riscv_wsubau : RVSDNode<"WSUBAU", SDT_RISCVWideningAddSubAccumulate>; def riscv_wmulsu : RVSDNode<"WMULSU", SDTIntBinHiLoOp>; +// Widening shift left: lo, hi = wsll/wsla(x, shamt). All results and +// operands are i32. +def SDT_RISCVWideningShiftLeft : SDTypeProfile<2, 2, [SDTCisVT<0, i32>, + SDTCisSameAs<0, 1>, + SDTCisSameAs<0, 2>, + SDTCisSameAs<0, 3>]>; +def riscv_wsll : RVSDNode<"WSLL", SDT_RISCVWideningShiftLeft>; +def riscv_wsla : RVSDNode<"WSLA", SDT_RISCVWideningShiftLeft>; + // Narrowing shift: res = nsrl(lo, hi, shamt) is equivalent to // res = truncate (srl (build_pair lo, hi), shamt), XLenVT def SDT_RISCVNarrowingShift : SDTypeProfile<1, 3, [SDTCisVT<0, i32>, diff --git a/llvm/test/CodeGen/RISCV/rv32p.ll b/llvm/test/CodeGen/RISCV/rv32p.ll index cc00f427126ba..fdc7d98e5d833 100644 --- a/llvm/test/CodeGen/RISCV/rv32p.ll +++ b/llvm/test/CodeGen/RISCV/rv32p.ll @@ -781,6 +781,46 @@ define i64 @wmulsu_i32(i32 %x, i32 %y) { ret i64 %c } +define i64 @wsla_i32(i32 %x, i64 %y) { +; CHECK-LABEL: wsla_i32: +; CHECK: # %bb.0: +; CHECK-NEXT: wsla a0, a0, a1 +; CHECK-NEXT: ret + %a = sext i32 %x to i64 + %b = shl i64 %a, %y + ret i64 %b +} + +define i64 @wsll_i32(i32 %x, i64 %y) { +; CHECK-LABEL: wsll_i32: +; CHECK: # %bb.0: +; CHECK-NEXT: wsll a0, a0, a1 +; CHECK-NEXT: ret + %a = zext i32 %x to i64 + %b = shl i64 %a, %y + ret i64 %b +} + +define i64 @wslai_i32(i32 %x) { +; CHECK-LABEL: wslai_i32: +; CHECK: # %bb.0: +; CHECK-NEXT: wslai a0, a0, 23 +; CHECK-NEXT: ret + %a = sext i32 %x to i64 + %b = shl i64 %a, 23 + ret i64 %b +} + +define i64 @wslli_i32(i32 %x, i64 %y) { +; CHECK-LABEL: wslli_i32: +; CHECK: # %bb.0: +; CHECK-NEXT: wslli a0, a0, 10 +; 
CHECK-NEXT: ret + %a = zext i32 %x to i64 + %b = shl i64 %a, 10 + ret i64 %b +} + ; Test that mulh continues to be used with P. define i32 @mulh_i32(i32 %x, i32 %y) { ; CHECK-LABEL: mulh_i32: