diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index c12426458c3a5..ffd15483b246e 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -10328,31 +10328,57 @@ SDValue RISCVTargetLowering::lowerShiftRightParts(SDValue Op, SelectionDAG &DAG, SDValue Shamt = Op.getOperand(2); EVT VT = Lo.getValueType(); - // With P extension on RV32, use NSRL/NSRA for the low part. - if (Subtarget.hasStdExtP() && !Subtarget.is64Bit()) { - SDValue LoRes = DAG.getNode(IsSRA ? RISCVISD::NSRA : RISCVISD::NSRL, DL, VT, - Lo, Hi, Shamt); - // Mask shift amount to avoid UB when Shamt >= 32. + // With P extension, use NSRL/NSRA for RV32 or FSHR (SRX) for RV64. + if (Subtarget.hasStdExtP()) { + unsigned XLen = Subtarget.getXLen(); + + SDValue LoRes; + if (Subtarget.is64Bit()) { + // On RV64, use FSHR (SRX instruction) for the low part. We will need + // to fix this later if ShAmt >= 64. + LoRes = DAG.getNode(ISD::FSHR, DL, VT, Hi, Lo, Shamt); + } else { + // On RV32, use NSRL/NSRA for the low part. + // NSRL/NSRA read 6 bits of shift amount, so they handle Shamt >= 32 + // correctly. + LoRes = DAG.getNode(IsSRA ? RISCVISD::NSRA : RISCVISD::NSRL, DL, VT, Lo, + Hi, Shamt); + } + + // Mask shift amount to avoid UB when Shamt >= XLen. SDValue ShamtMasked = - DAG.getNode(ISD::AND, DL, VT, Shamt, DAG.getConstant(31, DL, VT)); + DAG.getNode(ISD::AND, DL, VT, Shamt, DAG.getConstant(XLen - 1, DL, VT)); SDValue HiRes = DAG.getNode(IsSRA ? ISD::SRA : ISD::SRL, DL, VT, Hi, ShamtMasked); - // Create a mask that is -1 when Shamt >= 32, 0 otherwise. + // Create a mask that is -1 when Shamt >= XLen, 0 otherwise. // FIXME: We should use a select and let LowerSelect make the // optimizations. SDValue ShAmtExt = - DAG.getNode(ISD::SHL, DL, VT, Shamt, DAG.getConstant(26, DL, VT)); - SDValue Mask = - DAG.getNode(ISD::SRA, DL, VT, ShAmtExt, DAG.getConstant(31, DL, VT)); + DAG.getNode(ISD::SHL, DL, VT, Shamt, + DAG.getConstant(XLen - Log2_32(XLen) - 1, DL, VT)); + SDValue Mask = DAG.getNode(ISD::SRA, DL, VT, ShAmtExt, + DAG.getConstant(XLen - 1, DL, VT)); + + if (Subtarget.is64Bit()) { + // On RV64, FSHR masks shift amount to 63. We need to replace LoRes + // with HiRes when Shamt >= 64. + // LoRes = (LoRes & ~Mask) | (HiRes & Mask) + SDValue LoMasked = + DAG.getNode(ISD::AND, DL, VT, LoRes, DAG.getNOT(DL, Mask, VT)); + SDValue HiMasked = DAG.getNode(ISD::AND, DL, VT, HiRes, Mask); + LoRes = DAG.getNode(ISD::OR, DL, VT, LoMasked, HiMasked, + SDNodeFlags::Disjoint); + } + // If ShAmt >= XLen, we need to replace HiRes with 0 or sign bits. if (IsSRA) { - // sra hi, hi, (mask & 31) - shifts by 31 when shamt >= 32 - SDValue MaskAmt = - DAG.getNode(ISD::AND, DL, VT, Mask, DAG.getConstant(31, DL, VT)); + // sra hi, hi, (mask & (XLen-1)) - shifts by XLen-1 when shamt >= XLen + SDValue MaskAmt = DAG.getNode(ISD::AND, DL, VT, Mask, + DAG.getConstant(XLen - 1, DL, VT)); HiRes = DAG.getNode(ISD::SRA, DL, VT, HiRes, MaskAmt); } else { - // andn hi, hi, mask - clears hi when shamt >= 32 + // andn hi, hi, mask - clears hi when shamt >= XLen HiRes = DAG.getNode(ISD::AND, DL, VT, HiRes, DAG.getNOT(DL, Mask, VT)); } diff --git a/llvm/test/CodeGen/RISCV/rv64p.ll b/llvm/test/CodeGen/RISCV/rv64p.ll index 670022a537e00..747a676b134fa 100644 --- a/llvm/test/CodeGen/RISCV/rv64p.ll +++ b/llvm/test/CodeGen/RISCV/rv64p.ll @@ -391,21 +391,12 @@ define i128 @slli_i128_large(i128 %x) { define i128 @srl_i128(i128 %x, i128 %y) { ; CHECK-LABEL: srl_i128: ; CHECK: # %bb.0: -; CHECK-NEXT: addi a4, a2, -64 ; CHECK-NEXT: srl a3, a1, a2 -; CHECK-NEXT: bltz a4, .LBB32_2 -; CHECK-NEXT: # %bb.1: -; CHECK-NEXT: mv a0, a3 -; CHECK-NEXT: j .LBB32_3 -; CHECK-NEXT: .LBB32_2: -; CHECK-NEXT: srl a0, a0, a2 -; CHECK-NEXT: not a2, a2 -; CHECK-NEXT: slli a1, a1, 1 -; CHECK-NEXT: sll a1, a1, a2 -; CHECK-NEXT: or a0, a0, a1 -; CHECK-NEXT: .LBB32_3: -; CHECK-NEXT: srai a1, a4, 63 -; CHECK-NEXT: and a1, a1, a3 +; CHECK-NEXT: srx a0, a1, a2 +; CHECK-NEXT: slli a2, a2, 57 +; CHECK-NEXT: srai a2, a2, 63 +; CHECK-NEXT: mvm a0, a3, a2 +; CHECK-NEXT: andn a1, a3, a2 ; CHECK-NEXT: ret %b = lshr i128 %x, %y ret i128 %b @@ -461,21 +452,12 @@ define i128 @srli_i128_large(i128 %x) { define i128 @sra_i128(i128 %x, i128 %y) { ; CHECK-LABEL: sra_i128: ; CHECK: # %bb.0: -; CHECK-NEXT: mv a3, a1 -; CHECK-NEXT: addi a4, a2, -64 -; CHECK-NEXT: sra a1, a1, a2 -; CHECK-NEXT: bltz a4, .LBB37_2 -; CHECK-NEXT: # %bb.1: -; CHECK-NEXT: srai a3, a3, 63 -; CHECK-NEXT: mv a0, a1 -; CHECK-NEXT: mv a1, a3 -; CHECK-NEXT: ret -; CHECK-NEXT: .LBB37_2: -; CHECK-NEXT: srl a0, a0, a2 -; CHECK-NEXT: not a2, a2 -; CHECK-NEXT: slli a3, a3, 1 -; CHECK-NEXT: sll a2, a3, a2 -; CHECK-NEXT: or a0, a0, a2 +; CHECK-NEXT: sra a3, a1, a2 +; CHECK-NEXT: srx a0, a1, a2 +; CHECK-NEXT: slli a2, a2, 57 +; CHECK-NEXT: srai a2, a2, 63 +; CHECK-NEXT: mvm a0, a3, a2 +; CHECK-NEXT: sra a1, a3, a2 ; CHECK-NEXT: ret %b = ashr i128 %x, %y ret i128 %b