diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
index d69fdf0c490eb..235afeb34126c 100644
--- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
@@ -2049,7 +2049,9 @@ bool RISCVDAGToDAGISel::selectZExti32(SDValue N, SDValue &Val) {
 /// SHXADD we are trying to match.
 bool RISCVDAGToDAGISel::selectSHXADDOp(SDValue N, unsigned ShAmt,
                                        SDValue &Val) {
-  if (N.getOpcode() == ISD::SHL && isa<ConstantSDNode>(N.getOperand(1))) {
+  bool LeftShift = N.getOpcode() == ISD::SHL;
+  if ((LeftShift || N.getOpcode() == ISD::SRL) &&
+      isa<ConstantSDNode>(N.getOperand(1))) {
     unsigned C1 = N.getConstantOperandVal(1);
     SDValue N0 = N.getOperand(0);
     if (N0.getOpcode() == ISD::AND && N0.hasOneUse() &&
@@ -2061,12 +2063,26 @@ bool RISCVDAGToDAGISel::selectSHXADDOp(SDValue N, unsigned ShAmt,
       unsigned Trailing = countTrailingZeros(Mask);
       // Look for (shl (and X, Mask), C1) where Mask has 32 leading zeros and
       // C3 trailing zeros. If C1+C3==ShAmt we can use SRLIW+SHXADD.
-      if (Leading == 32 && Trailing > 0 && (C1 + Trailing) == ShAmt) {
+      if (LeftShift && Leading == 32 && Trailing > 0 &&
+          (Trailing + C1) == ShAmt) {
         SDLoc DL(N);
         EVT VT = N.getValueType();
         Val = SDValue(CurDAG->getMachineNode(
-                          RISCV::SRLIW, DL, VT, N0->getOperand(0),
-                          CurDAG->getTargetConstant(Trailing, DL, VT)), 0);
+                          RISCV::SRLIW, DL, VT, N0.getOperand(0),
+                          CurDAG->getTargetConstant(Trailing, DL, VT)),
+                      0);
+        return true;
+      }
+      // Look for (srl (and X, Mask), C1) where Mask has 32 leading zeros and
+      // C3 trailing zeros. If C3-C1==ShAmt we can use SRLIW+SHXADD.
+      if (!LeftShift && Leading == 32 && Trailing > C1 &&
+          (Trailing - C1) == ShAmt) {
+        SDLoc DL(N);
+        EVT VT = N.getValueType();
+        Val = SDValue(CurDAG->getMachineNode(
+                          RISCV::SRLIW, DL, VT, N0.getOperand(0),
+                          CurDAG->getTargetConstant(Trailing, DL, VT)),
+                      0);
         return true;
       }
     }
diff --git a/llvm/test/CodeGen/RISCV/rv64zba.ll b/llvm/test/CodeGen/RISCV/rv64zba.ll
index 06d256ee4e16c..4d0bc2e8437df 100644
--- a/llvm/test/CodeGen/RISCV/rv64zba.ll
+++ b/llvm/test/CodeGen/RISCV/rv64zba.ll
@@ -1375,3 +1375,70 @@ define i64 @srliw_2_sh3add(i64* %0, i32 signext %1) {
   %6 = load i64, i64* %5, align 8
   ret i64 %6
 }
+
+define signext i16 @srliw_2_sh1add(i16* %0, i32 signext %1) {
+; RV64I-LABEL: srliw_2_sh1add:
+; RV64I:       # %bb.0:
+; RV64I-NEXT:    srliw a1, a1, 2
+; RV64I-NEXT:    slli a1, a1, 1
+; RV64I-NEXT:    add a0, a0, a1
+; RV64I-NEXT:    lh a0, 0(a0)
+; RV64I-NEXT:    ret
+;
+; RV64ZBA-LABEL: srliw_2_sh1add:
+; RV64ZBA:       # %bb.0:
+; RV64ZBA-NEXT:    srliw a1, a1, 2
+; RV64ZBA-NEXT:    sh1add a0, a1, a0
+; RV64ZBA-NEXT:    lh a0, 0(a0)
+; RV64ZBA-NEXT:    ret
+  %3 = lshr i32 %1, 2
+  %4 = zext i32 %3 to i64
+  %5 = getelementptr inbounds i16, i16* %0, i64 %4
+  %6 = load i16, i16* %5, align 2
+  ret i16 %6
+}
+
+
+define signext i32 @srliw_3_sh2add(i32* %0, i32 signext %1) {
+; RV64I-LABEL: srliw_3_sh2add:
+; RV64I:       # %bb.0:
+; RV64I-NEXT:    srliw a1, a1, 3
+; RV64I-NEXT:    slli a1, a1, 2
+; RV64I-NEXT:    add a0, a0, a1
+; RV64I-NEXT:    lw a0, 0(a0)
+; RV64I-NEXT:    ret
+;
+; RV64ZBA-LABEL: srliw_3_sh2add:
+; RV64ZBA:       # %bb.0:
+; RV64ZBA-NEXT:    srliw a1, a1, 3
+; RV64ZBA-NEXT:    sh2add a0, a1, a0
+; RV64ZBA-NEXT:    lw a0, 0(a0)
+; RV64ZBA-NEXT:    ret
+  %3 = lshr i32 %1, 3
+  %4 = zext i32 %3 to i64
+  %5 = getelementptr inbounds i32, i32* %0, i64 %4
+  %6 = load i32, i32* %5, align 4
+  ret i32 %6
+}
+
+define i64 @srliw_4_sh3add(i64* %0, i32 signext %1) {
+; RV64I-LABEL: srliw_4_sh3add:
+; RV64I:       # %bb.0:
+; RV64I-NEXT:    srliw a1, a1, 4
+; RV64I-NEXT:    slli a1, a1, 3
+; RV64I-NEXT:    add a0, a0, a1
+; RV64I-NEXT:    ld a0, 0(a0)
+; RV64I-NEXT:    ret
+;
+; RV64ZBA-LABEL: srliw_4_sh3add:
+; RV64ZBA:       # %bb.0:
+; RV64ZBA-NEXT:    srliw a1, a1, 4
+; RV64ZBA-NEXT:    sh3add a0, a1, a0
+; RV64ZBA-NEXT:    ld a0, 0(a0)
+; RV64ZBA-NEXT:    ret
+  %3 = lshr i32 %1, 4
+  %4 = zext i32 %3 to i64
+  %5 = getelementptr inbounds i64, i64* %0, i64 %4
+  %6 = load i64, i64* %5, align 8
+  ret i64 %6
+}
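
For reference, the arithmetic behind the two folds can be checked in isolation. Below is a minimal standalone C++ sketch, not code from the patch: the srliw() helper and the chosen Trailing/C1 values are assumptions for illustration. It checks that (shl (and X, Mask), C1) equals SRLIW X, Trailing shifted left by Trailing + C1, and that (srl (and X, Mask), C1) equals SRLIW X, Trailing shifted left by Trailing - C1, which is what lets the remaining left shift be absorbed into sh1add/sh2add/sh3add.

// Standalone sketch, not part of the patch. Mask has 32 leading zeros and
// `Trailing` trailing zeros, mirroring the Leading == 32 / Trailing checks
// in selectSHXADDOp. srliw() models the RISC-V instruction of that name.
#include <cassert>
#include <cstdint>

// SRLIW: shift the low 32 bits right, then zero-extend to 64 bits.
static uint64_t srliw(uint64_t X, unsigned Shamt) {
  return static_cast<uint32_t>(X) >> Shamt;
}

int main() {
  const unsigned Trailing = 4; // C3: trailing zeros of Mask (assumed value)
  const unsigned C1 = 1;       // original shift amount (assumed value)
  const uint64_t Mask = 0xffffffffULL & ~((1ULL << Trailing) - 1);

  for (uint64_t X : {0x0ULL, 0x12345678ULL, 0xdeadbeefcafef00dULL}) {
    // shl case: the fold fires when Trailing + C1 == ShAmt of the SHXADD.
    assert(((X & Mask) << C1) == (srliw(X, Trailing) << (Trailing + C1)));
    // srl case: the fold fires when Trailing - C1 == ShAmt (needs Trailing > C1).
    assert(((X & Mask) >> C1) == (srliw(X, Trailing) << (Trailing - C1)));
  }
  return 0;
}

With Trailing = 4 and C1 = 1, the srl identity lines up with the new srliw_4_sh3add test: zext(lshr i32 %1, 4) scaled by an i64 GEP presumably reaches the selector as (srl (and X, 0xFFFFFFF0), 1) with ShAmt = 3, so a single SRLIW feeds the sh3add.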