diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 690a86bd4606c1..92e18a4b630e91 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -2586,6 +2586,17 @@ bool TargetLowering::SimplifyDemandedBits(
         break;
 
       if (Src.getNode()->hasOneUse()) {
+        if (isTruncateFree(Src, VT) &&
+            !isTruncateFree(Src.getValueType(), VT)) {
+          // If the truncate is only free in the trunc(srl) form, do not
+          // turn it into srl(trunc). We first check that the truncate is
+          // free for Src's opcode (srl), and then that it is not already
+          // free by type (e.g. as a sub-register extraction). Testing shows
+          // that when both forms are free, srl(trunc) performs better; when
+          // only trunc(srl) is free, trunc(srl) is better.
+          break;
+        }
+
         std::optional<uint64_t> ShAmtC =
             TLO.DAG.getValidShiftAmount(Src, DemandedElts, Depth + 2);
         if (!ShAmtC || *ShAmtC >= BitWidth)
@@ -2596,7 +2607,6 @@ bool TargetLowering::SimplifyDemandedBits(
             APInt::getHighBitsSet(OperandBitWidth, OperandBitWidth - BitWidth);
         HighBits.lshrInPlace(ShVal);
         HighBits = HighBits.trunc(BitWidth);
-
         if (!(HighBits & DemandedBits)) {
           // None of the shifted in bits are needed. Add a truncate of the
           // shift input, then shift it.
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index b8ba25df9910bb..caa4ebacc41da6 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -1894,6 +1894,21 @@ bool RISCVTargetLowering::isTruncateFree(EVT SrcVT, EVT DstVT) const {
   return (SrcBits == 64 && DestBits == 32);
 }
 
+bool RISCVTargetLowering::isTruncateFree(SDValue Val, EVT VT2) const {
+  EVT SrcVT = Val.getValueType();
+  // A truncate of a vector srl/sra is free: it folds into vnsrl/vnsra.
+  if (Subtarget.hasStdExtV() &&
+      (Val.getOpcode() == ISD::SRL || Val.getOpcode() == ISD::SRA) &&
+      SrcVT.isVector() && VT2.isVector()) {
+    unsigned SrcBits = SrcVT.getVectorElementType().getSizeInBits();
+    unsigned DestBits = VT2.getVectorElementType().getSizeInBits();
+    if (SrcBits == DestBits * 2) {
+      return true;
+    }
+  }
+  return TargetLowering::isTruncateFree(Val, VT2);
+}
+
 bool RISCVTargetLowering::isZExtFree(SDValue Val, EVT VT2) const {
   // Zexts are free if they can be combined with a load.
   // Don't advertise i32->i64 zextload as being free for RV64.
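Taken together, the two changes above mean SimplifyDemandedBits now consults the value-based isTruncateFree hook before rewriting trunc(srl x, c) into srl(trunc x), c, and the RISC-V implementation reports the truncate as free exactly when the combined trunc(srl) can be selected as a single vnsrl/vnsra. As a minimal sketch of the guarded shape (illustrative IR, not part of the patch; the function name is made up):

    ; With llc -mtriple=riscv64 -mattr=+v, the shift-then-truncate below
    ; should select to a single vnsrl.wi, rather than a separate narrowing
    ; (vnsrl.wi ..., 0) followed by a vsrl.vi for the rewritten form.
    define <8 x i16> @trunc_of_srl(<8 x i32> %x) {
      %s = lshr <8 x i32> %x, <i32 6, i32 6, i32 6, i32 6, i32 6, i32 6, i32 6, i32 6>
      %t = trunc <8 x i32> %s to <8 x i16>
      ret <8 x i16> %t
    }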
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index 7d8bceb5cb3417..2642a188820e14 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -497,6 +497,7 @@ class RISCVTargetLowering : public TargetLowering {
   bool isLegalAddImmediate(int64_t Imm) const override;
   bool isTruncateFree(Type *SrcTy, Type *DstTy) const override;
   bool isTruncateFree(EVT SrcVT, EVT DstVT) const override;
+  bool isTruncateFree(SDValue Val, EVT VT2) const override;
   bool isZExtFree(SDValue Val, EVT VT2) const override;
   bool isSExtCheaperThanZExt(EVT SrcVT, EVT DstVT) const override;
   bool signExtendConstant(const ConstantInt *CI) const override;
diff --git a/llvm/test/CodeGen/RISCV/pr94265.ll b/llvm/test/CodeGen/RISCV/pr94265.ll
new file mode 100644
index 00000000000000..cb41e22381d19d
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/pr94265.ll
@@ -0,0 +1,31 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc < %s -mtriple=riscv32-- -mattr=+v | FileCheck -check-prefix=RV32I %s
+; RUN: llc < %s -mtriple=riscv64-- -mattr=+v | FileCheck -check-prefix=RV64I %s
+
+define <8 x i16> @PR94265(<8 x i32> %a0) {
+; RV32I-LABEL: PR94265:
+; RV32I:       # %bb.0:
+; RV32I-NEXT:    vsetivli zero, 8, e32, m2, ta, ma
+; RV32I-NEXT:    vsra.vi v10, v8, 31
+; RV32I-NEXT:    vsrl.vi v10, v10, 26
+; RV32I-NEXT:    vadd.vv v8, v8, v10
+; RV32I-NEXT:    vsetvli zero, zero, e16, m1, ta, ma
+; RV32I-NEXT:    vnsrl.wi v10, v8, 6
+; RV32I-NEXT:    vsll.vi v8, v10, 10
+; RV32I-NEXT:    ret
+;
+; RV64I-LABEL: PR94265:
+; RV64I:       # %bb.0:
+; RV64I-NEXT:    vsetivli zero, 8, e32, m2, ta, ma
+; RV64I-NEXT:    vsra.vi v10, v8, 31
+; RV64I-NEXT:    vsrl.vi v10, v10, 26
+; RV64I-NEXT:    vadd.vv v8, v8, v10
+; RV64I-NEXT:    vsetvli zero, zero, e16, m1, ta, ma
+; RV64I-NEXT:    vnsrl.wi v10, v8, 6
+; RV64I-NEXT:    vsll.vi v8, v10, 10
+; RV64I-NEXT:    ret
+  %t1 = sdiv <8 x i32> %a0, <i32 64, i32 64, i32 64, i32 64, i32 64, i32 64, i32 64, i32 64>
+  %t2 = trunc <8 x i32> %t1 to <8 x i16>
+  %t3 = shl <8 x i16> %t2, <i16 10, i16 10, i16 10, i16 10, i16 10, i16 10, i16 10, i16 10>
+  ret <8 x i16> %t3
+}
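As the NOTE line records, the CHECK lines were generated with utils/update_llc_test_checks.py; the vnsrl.wi v10, v8, 6 they expect is the single narrowing shift this patch preserves. If codegen changes later, the assertions can be regenerated with something like the following (the build directory path is an assumption):

    llvm/utils/update_llc_test_checks.py --llc-binary build/bin/llc \
        llvm/test/CodeGen/RISCV/pr94265.ll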