Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ISel: introduce vector ISD::LRINT, ISD::LLRINT; custom RISCV lowering #66924

Merged
merged 5 commits into from
Oct 19, 2023

Conversation

artagnon
Copy link
Contributor

The issue #55208 noticed that std::rint is vectorized by the SLPVectorizer, but a very similar function, std::lrint, is not. std::lrint corresponds to ISD::LRINT in the SelectionDAG, and std::llrint is a familiar cousin corresponding to ISD::LLRINT. Now, neither ISD::LRINT nor ISD::LLRINT have a corresponding vector variant, and the LangRef makes this clear in the documentation of llvm.lrint.* and llvm.llrint.*.

This patch extends the LangRef to include vector variants of llvm.lrint.* and llvm.llrint.*, and lays the necessary ground-work of scalarizing it for all targets. However, this patch would be devoid of motivation unless we show the utility of these new vector variants. Hence, the RISCV target has been chosen to implement a custom lowering to the vfcvt.x.f.v instruction. The patch also includes a CostModel for RISCV, and a trivial follow-up can potentially enable the SLPVectorizer to vectorize std::lrint and std::llrint, fixing #55208.

The patch includes tests, obviously for the RISCV target, but also for the X86, AArch64, and PowerPC targets to justify the addition of the vector variants to the LangRef.

@llvmbot
Copy link
Collaborator

llvmbot commented Sep 20, 2023

@llvm/pr-subscribers-llvm-ir
@llvm/pr-subscribers-backend-x86
@llvm/pr-subscribers-llvm-analysis
@llvm/pr-subscribers-llvm-selectiondag
@llvm/pr-subscribers-backend-aarch64

@llvm/pr-subscribers-backend-risc-v

Changes

The issue #55208 noticed that std::rint is vectorized by the SLPVectorizer, but a very similar function, std::lrint, is not. std::lrint corresponds to ISD::LRINT in the SelectionDAG, and std::llrint is a familiar cousin corresponding to ISD::LLRINT. Now, neither ISD::LRINT nor ISD::LLRINT have a corresponding vector variant, and the LangRef makes this clear in the documentation of llvm.lrint.* and llvm.llrint.*.

This patch extends the LangRef to include vector variants of llvm.lrint.* and llvm.llrint.*, and lays the necessary ground-work of scalarizing it for all targets. However, this patch would be devoid of motivation unless we show the utility of these new vector variants. Hence, the RISCV target has been chosen to implement a custom lowering to the vfcvt.x.f.v instruction. The patch also includes a CostModel for RISCV, and a trivial follow-up can potentially enable the SLPVectorizer to vectorize std::lrint and std::llrint, fixing #55208.

The patch includes tests, obviously for the RISCV target, but also for the X86, AArch64, and PowerPC targets to justify the addition of the vector variants to the LangRef.


Patch is 439.49 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/66924.diff

22 Files Affected:

  • (modified) llvm/docs/LangRef.rst (+4-2)
  • (modified) llvm/include/llvm/CodeGen/BasicTTIImpl.h (+6)
  • (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+19)
  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp (+9-5)
  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp (+17)
  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h (+3-2)
  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp (+2)
  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp (+15-1)
  • (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp (+2)
  • (modified) llvm/lib/CodeGen/TargetLoweringBase.cpp (+6-6)
  • (modified) llvm/lib/IR/Verifier.cpp (+1-3)
  • (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+57-40)
  • (modified) llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp (+27)
  • (modified) llvm/test/Analysis/CostModel/RISCV/fround.ll (+130)
  • (added) llvm/test/CodeGen/AArch64/vector-llrint.ll (+621)
  • (added) llvm/test/CodeGen/AArch64/vector-lrint.ll (+622)
  • (added) llvm/test/CodeGen/PowerPC/vector-llrint.ll (+4847)
  • (added) llvm/test/CodeGen/PowerPC/vector-lrint.ll (+4858)
  • (added) llvm/test/CodeGen/RISCV/rvv/llrint-sdnode.ll (+108)
  • (added) llvm/test/CodeGen/RISCV/rvv/lrint-sdnode.ll (+155)
  • (added) llvm/test/CodeGen/X86/vector-llrint.ll (+290)
  • (added) llvm/test/CodeGen/X86/vector-lrint.ll (+429)
diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst
index d3b0cb0cc50cec2..fdff5da1451cf3d 100644
--- a/llvm/docs/LangRef.rst
+++ b/llvm/docs/LangRef.rst
@@ -15802,7 +15802,8 @@ Syntax:
 """""""
 
 This is an overloaded intrinsic. You can use ``llvm.lrint`` on any
-floating-point type. Not all targets support all types however.
+floating-point type or vector of floating-point type. Not all targets
+support all types however.
 
 ::
 
@@ -15846,7 +15847,8 @@ Syntax:
 """""""
 
 This is an overloaded intrinsic. You can use ``llvm.llrint`` on any
-floating-point type. Not all targets support all types however.
+floating-point type or vector of floating-point type. Not all targets
+support all types however.
 
 ::
 
diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index c11d558a73e9d09..eeca2f4060ed200 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -1843,6 +1843,12 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
     case Intrinsic::rint:
       ISD = ISD::FRINT;
       break;
+    case Intrinsic::lrint:
+      ISD = ISD::LRINT;
+      break;
+    case Intrinsic::llrint:
+      ISD = ISD::LLRINT;
+      break;
     case Intrinsic::round:
       ISD = ISD::FROUND;
       break;
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 5088d59147492f1..4f3a9ac79965308 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -504,6 +504,7 @@ namespace {
     SDValue visitUINT_TO_FP(SDNode *N);
     SDValue visitFP_TO_SINT(SDNode *N);
     SDValue visitFP_TO_UINT(SDNode *N);
+    SDValue visitXRINT(SDNode *N);
     SDValue visitFP_ROUND(SDNode *N);
     SDValue visitFP_EXTEND(SDNode *N);
     SDValue visitFNEG(SDNode *N);
@@ -2005,6 +2006,9 @@ SDValue DAGCombiner::visit(SDNode *N) {
   case ISD::UINT_TO_FP:         return visitUINT_TO_FP(N);
   case ISD::FP_TO_SINT:         return visitFP_TO_SINT(N);
   case ISD::FP_TO_UINT:         return visitFP_TO_UINT(N);
+  case ISD::LRINT:
+  case ISD::LLRINT:
+    return visitXRINT(N);
   case ISD::FP_ROUND:           return visitFP_ROUND(N);
   case ISD::FP_EXTEND:          return visitFP_EXTEND(N);
   case ISD::FNEG:               return visitFNEG(N);
@@ -17206,6 +17210,21 @@ SDValue DAGCombiner::visitFP_TO_UINT(SDNode *N) {
   return FoldIntToFPToInt(N, DAG);
 }
 
+SDValue DAGCombiner::visitXRINT(SDNode *N) {
+  SDValue N0 = N->getOperand(0);
+  EVT VT = N->getValueType(0);
+
+  // fold (lrint|llrint undef) -> undef
+  if (N0.isUndef())
+    return DAG.getUNDEF(VT);
+
+  // fold (lrint|llrint c1fp) -> c1
+  if (DAG.isConstantFPBuildVectorOrConstantFP(N0))
+    return DAG.getNode(N->getOpcode(), SDLoc(N), VT, N0);
+
+  return SDValue();
+}
+
 SDValue DAGCombiner::visitFP_ROUND(SDNode *N) {
   SDValue N0 = N->getOperand(0);
   SDValue N1 = N->getOperand(1);
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
index 95f181217803515..2c8d23fbb674907 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
@@ -2209,10 +2209,15 @@ bool DAGTypeLegalizer::PromoteFloatOperand(SDNode *N, unsigned OpNo) {
     case ISD::BITCAST:    R = PromoteFloatOp_BITCAST(N, OpNo); break;
     case ISD::FCOPYSIGN:  R = PromoteFloatOp_FCOPYSIGN(N, OpNo); break;
     case ISD::FP_TO_SINT:
-    case ISD::FP_TO_UINT: R = PromoteFloatOp_FP_TO_XINT(N, OpNo); break;
+    case ISD::FP_TO_UINT:
+    case ISD::LRINT:
+    case ISD::LLRINT:
+      R = PromoteFloatOp_UnaryOp(N, OpNo);
+      break;
     case ISD::FP_TO_SINT_SAT:
     case ISD::FP_TO_UINT_SAT:
-                          R = PromoteFloatOp_FP_TO_XINT_SAT(N, OpNo); break;
+      R = PromoteFloatOp_BinOp(N, OpNo);
+      break;
     case ISD::FP_EXTEND:  R = PromoteFloatOp_FP_EXTEND(N, OpNo); break;
     case ISD::SELECT_CC:  R = PromoteFloatOp_SELECT_CC(N, OpNo); break;
     case ISD::SETCC:      R = PromoteFloatOp_SETCC(N, OpNo); break;
@@ -2251,13 +2256,12 @@ SDValue DAGTypeLegalizer::PromoteFloatOp_FCOPYSIGN(SDNode *N, unsigned OpNo) {
 }
 
 // Convert the promoted float value to the desired integer type
-SDValue DAGTypeLegalizer::PromoteFloatOp_FP_TO_XINT(SDNode *N, unsigned OpNo) {
+SDValue DAGTypeLegalizer::PromoteFloatOp_UnaryOp(SDNode *N, unsigned OpNo) {
   SDValue Op = GetPromotedFloat(N->getOperand(0));
   return DAG.getNode(N->getOpcode(), SDLoc(N), N->getValueType(0), Op);
 }
 
-SDValue DAGTypeLegalizer::PromoteFloatOp_FP_TO_XINT_SAT(SDNode *N,
-                                                        unsigned OpNo) {
+SDValue DAGTypeLegalizer::PromoteFloatOp_BinOp(SDNode *N, unsigned OpNo) {
   SDValue Op = GetPromotedFloat(N->getOperand(0));
   return DAG.getNode(N->getOpcode(), SDLoc(N), N->getValueType(0), Op,
                      N->getOperand(1));
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index fc9e3ff3734989d..70e85523b9540b5 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -162,6 +162,11 @@ void DAGTypeLegalizer::PromoteIntegerResult(SDNode *N, unsigned ResNo) {
   case ISD::FP_TO_UINT_SAT:
                          Res = PromoteIntRes_FP_TO_XINT_SAT(N); break;
 
+  case ISD::LRINT:
+  case ISD::LLRINT:
+    Res = PromoteIntRes_XRINT(N);
+    break;
+
   case ISD::FP_TO_BF16:
   case ISD::FP_TO_FP16:
     Res = PromoteIntRes_FP_TO_FP16_BF16(N);
@@ -539,6 +544,18 @@ SDValue DAGTypeLegalizer::PromoteIntRes_BSWAP(SDNode *N) {
                      Mask, EVL);
 }
 
+SDValue DAGTypeLegalizer::PromoteIntRes_XRINT(SDNode *N) {
+  SDValue Op = GetPromotedInteger(N->getOperand(0));
+  EVT OVT = N->getValueType(0);
+  EVT NVT = Op.getValueType();
+  SDLoc dl(N);
+
+  unsigned DiffBits = NVT.getScalarSizeInBits() - OVT.getScalarSizeInBits();
+  SDValue ShAmt = DAG.getShiftAmountConstant(DiffBits, NVT, dl);
+  return DAG.getNode(ISD::SRL, dl, NVT,
+                     DAG.getNode(N->getOpcode(), dl, NVT, Op), ShAmt);
+}
+
 SDValue DAGTypeLegalizer::PromoteIntRes_BITREVERSE(SDNode *N) {
   SDValue Op = GetPromotedInteger(N->getOperand(0));
   EVT OVT = N->getValueType(0);
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index c802604a3470e13..4f7fb9f12289052 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -324,6 +324,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   SDValue PromoteIntRes_CTTZ(SDNode *N);
   SDValue PromoteIntRes_EXTRACT_VECTOR_ELT(SDNode *N);
   SDValue PromoteIntRes_FP_TO_XINT(SDNode *N);
+  SDValue PromoteIntRes_XRINT(SDNode *N);
   SDValue PromoteIntRes_FP_TO_XINT_SAT(SDNode *N);
   SDValue PromoteIntRes_FP_TO_FP16_BF16(SDNode *N);
   SDValue PromoteIntRes_FREEZE(SDNode *N);
@@ -711,8 +712,8 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   SDValue PromoteFloatOp_BITCAST(SDNode *N, unsigned OpNo);
   SDValue PromoteFloatOp_FCOPYSIGN(SDNode *N, unsigned OpNo);
   SDValue PromoteFloatOp_FP_EXTEND(SDNode *N, unsigned OpNo);
-  SDValue PromoteFloatOp_FP_TO_XINT(SDNode *N, unsigned OpNo);
-  SDValue PromoteFloatOp_FP_TO_XINT_SAT(SDNode *N, unsigned OpNo);
+  SDValue PromoteFloatOp_UnaryOp(SDNode *N, unsigned OpNo);
+  SDValue PromoteFloatOp_BinOp(SDNode *N, unsigned OpNo);
   SDValue PromoteFloatOp_STORE(SDNode *N, unsigned OpNo);
   SDValue PromoteFloatOp_SELECT_CC(SDNode *N, unsigned OpNo);
   SDValue PromoteFloatOp_SETCC(SDNode *N, unsigned OpNo);
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
index dec81475f3a88fc..2bc9c5f0844541d 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -396,6 +396,8 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
   case ISD::FCEIL:
   case ISD::FTRUNC:
   case ISD::FRINT:
+  case ISD::LRINT:
+  case ISD::LLRINT:
   case ISD::FNEARBYINT:
   case ISD::FROUND:
   case ISD::FROUNDEVEN:
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index 1bb6fbbf064b931..2c5343c3c4b160e 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -101,6 +101,8 @@ void DAGTypeLegalizer::ScalarizeVectorResult(SDNode *N, unsigned ResNo) {
   case ISD::FP_TO_SINT:
   case ISD::FP_TO_UINT:
   case ISD::FRINT:
+  case ISD::LRINT:
+  case ISD::LLRINT:
   case ISD::FROUND:
   case ISD::FROUNDEVEN:
   case ISD::FSIN:
@@ -681,6 +683,8 @@ bool DAGTypeLegalizer::ScalarizeVectorOperand(SDNode *N, unsigned OpNo) {
   case ISD::FP_TO_UINT:
   case ISD::SINT_TO_FP:
   case ISD::UINT_TO_FP:
+  case ISD::LRINT:
+  case ISD::LLRINT:
     Res = ScalarizeVecOp_UnaryOp(N);
     break;
   case ISD::STRICT_SINT_TO_FP:
@@ -1097,6 +1101,8 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
   case ISD::VP_FP_TO_UINT:
   case ISD::FRINT:
   case ISD::VP_FRINT:
+  case ISD::LRINT:
+  case ISD::LLRINT:
   case ISD::FROUND:
   case ISD::VP_FROUND:
   case ISD::FROUNDEVEN:
@@ -2974,6 +2980,8 @@ bool DAGTypeLegalizer::SplitVectorOperand(SDNode *N, unsigned OpNo) {
   case ISD::ZERO_EXTEND:
   case ISD::ANY_EXTEND:
   case ISD::FTRUNC:
+  case ISD::LRINT:
+  case ISD::LLRINT:
     Res = SplitVecOp_UnaryOp(N);
     break;
   case ISD::FLDEXP:
@@ -4209,6 +4217,8 @@ void DAGTypeLegalizer::WidenVectorResult(SDNode *N, unsigned ResNo) {
   case ISD::FLOG2:
   case ISD::FNEARBYINT:
   case ISD::FRINT:
+  case ISD::LRINT:
+  case ISD::LLRINT:
   case ISD::FROUND:
   case ISD::FROUNDEVEN:
   case ISD::FSIN:
@@ -5958,7 +5968,11 @@ bool DAGTypeLegalizer::WidenVectorOperand(SDNode *N, unsigned OpNo) {
   case ISD::STRICT_FSETCCS:     Res = WidenVecOp_STRICT_FSETCC(N); break;
   case ISD::VSELECT:            Res = WidenVecOp_VSELECT(N); break;
   case ISD::FLDEXP:
-  case ISD::FCOPYSIGN:          Res = WidenVecOp_UnrollVectorOp(N); break;
+  case ISD::FCOPYSIGN:
+  case ISD::LRINT:
+  case ISD::LLRINT:
+    Res = WidenVecOp_UnrollVectorOp(N);
+    break;
   case ISD::IS_FPCLASS:         Res = WidenVecOp_IS_FPCLASS(N); break;
 
   case ISD::ANY_EXTEND:
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 7fcd1f4f898911a..a695c1d1f874b76 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -5111,6 +5111,8 @@ bool SelectionDAG::isKnownNeverNaN(SDValue Op, bool SNaN, unsigned Depth) const
   case ISD::FROUND:
   case ISD::FROUNDEVEN:
   case ISD::FRINT:
+  case ISD::LRINT:
+  case ISD::LLRINT:
   case ISD::FNEARBYINT:
   case ISD::FLDEXP: {
     if (SNaN)
diff --git a/llvm/lib/CodeGen/TargetLoweringBase.cpp b/llvm/lib/CodeGen/TargetLoweringBase.cpp
index 3e4bff5ddce1264..99eadf4bb9d578b 100644
--- a/llvm/lib/CodeGen/TargetLoweringBase.cpp
+++ b/llvm/lib/CodeGen/TargetLoweringBase.cpp
@@ -873,13 +873,13 @@ void TargetLoweringBase::initActions() {
 
     // These operations default to expand for vector types.
     if (VT.isVector())
-      setOperationAction({ISD::FCOPYSIGN, ISD::SIGN_EXTEND_INREG,
-                          ISD::ANY_EXTEND_VECTOR_INREG,
-                          ISD::SIGN_EXTEND_VECTOR_INREG,
-                          ISD::ZERO_EXTEND_VECTOR_INREG, ISD::SPLAT_VECTOR},
-                         VT, Expand);
+      setOperationAction(
+          {ISD::FCOPYSIGN, ISD::SIGN_EXTEND_INREG, ISD::ANY_EXTEND_VECTOR_INREG,
+           ISD::SIGN_EXTEND_VECTOR_INREG, ISD::ZERO_EXTEND_VECTOR_INREG,
+           ISD::SPLAT_VECTOR, ISD::LRINT, ISD::LLRINT},
+          VT, Expand);
 
-    // Constrained floating-point operations default to expand.
+      // Constrained floating-point operations default to expand.
 #define DAG_INSTRUCTION(NAME, NARG, ROUND_MODE, INTRINSIC, DAGN)               \
     setOperationAction(ISD::STRICT_##DAGN, VT, Expand);
 #include "llvm/IR/ConstrainedOps.def"
diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp
index 5dac691e17cd6ef..7459c7d84874588 100644
--- a/llvm/lib/IR/Verifier.cpp
+++ b/llvm/lib/IR/Verifier.cpp
@@ -5659,9 +5659,7 @@ void Verifier::visitIntrinsicCall(Intrinsic::ID ID, CallBase &Call) {
     break;
   }
   case Intrinsic::lround:
-  case Intrinsic::llround:
-  case Intrinsic::lrint:
-  case Intrinsic::llrint: {
+  case Intrinsic::llround: {
     Type *ValTy = Call.getArgOperand(0)->getType();
     Type *ResultTy = Call.getType();
     Check(!ValTy->isVectorTy() && !ResultTy->isVectorTy(),
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index f1cea6c6756f4fc..e10002b82e06edf 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -732,7 +732,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
                          VT, Custom);
       setOperationAction({ISD::FP_TO_SINT_SAT, ISD::FP_TO_UINT_SAT}, VT,
                          Custom);
-
+      setOperationAction({ISD::LRINT, ISD::LLRINT}, VT, Custom);
       setOperationAction(
           {ISD::SADDSAT, ISD::UADDSAT, ISD::SSUBSAT, ISD::USUBSAT}, VT, Legal);
 
@@ -2687,13 +2687,14 @@ static RISCVFPRndMode::RoundingMode matchRoundingOp(unsigned Opc) {
   return RISCVFPRndMode::Invalid;
 }
 
-// Expand vector FTRUNC, FCEIL, FFLOOR, FROUND, VP_FCEIL, VP_FFLOOR, VP_FROUND
-// VP_FROUNDEVEN, VP_FROUNDTOZERO, VP_FRINT and VP_FNEARBYINT by converting to
-// the integer domain and back. Taking care to avoid converting values that are
+// Expand vector FTRUNC, FCEIL, FFLOOR, FROUND, FRINT, VP_FCEIL, VP_FFLOOR,
+// VP_FROUND, VP_FROUNDEVEN, VP_FROUNDTOZERO, VP_FRINT and VP_FNEARBYINT by
+// converting to the integer domain and back. Expand LRINT and LLRINT by
+// converting to integer domain. Take care to avoid converting values that are
 // nan or already correct.
 static SDValue
-lowerVectorFTRUNC_FCEIL_FFLOOR_FROUND(SDValue Op, SelectionDAG &DAG,
-                                      const RISCVSubtarget &Subtarget) {
+lowerVectorFTRUNC_FCEIL_FFLOOR_FROUND_XRINT(SDValue Op, SelectionDAG &DAG,
+                                            const RISCVSubtarget &Subtarget) {
   MVT VT = Op.getSimpleValueType();
   assert(VT.isVector() && "Unexpected type");
 
@@ -2721,28 +2722,34 @@ lowerVectorFTRUNC_FCEIL_FFLOOR_FROUND(SDValue Op, SelectionDAG &DAG,
   // Freeze the source since we are increasing the number of uses.
   Src = DAG.getFreeze(Src);
 
-  // We do the conversion on the absolute value and fix the sign at the end.
-  SDValue Abs = DAG.getNode(RISCVISD::FABS_VL, DL, ContainerVT, Src, Mask, VL);
-
-  // Determine the largest integer that can be represented exactly. This and
-  // values larger than it don't have any fractional bits so don't need to
-  // be converted.
-  const fltSemantics &FltSem = DAG.EVTToAPFloatSemantics(ContainerVT);
-  unsigned Precision = APFloat::semanticsPrecision(FltSem);
-  APFloat MaxVal = APFloat(FltSem);
-  MaxVal.convertFromAPInt(APInt::getOneBitSet(Precision, Precision - 1),
-                          /*IsSigned*/ false, APFloat::rmNearestTiesToEven);
-  SDValue MaxValNode =
-      DAG.getConstantFP(MaxVal, DL, ContainerVT.getVectorElementType());
-  SDValue MaxValSplat = DAG.getNode(RISCVISD::VFMV_V_F_VL, DL, ContainerVT,
-                                    DAG.getUNDEF(ContainerVT), MaxValNode, VL);
-
-  // If abs(Src) was larger than MaxVal or nan, keep it.
-  MVT SetccVT = MVT::getVectorVT(MVT::i1, ContainerVT.getVectorElementCount());
-  Mask =
-      DAG.getNode(RISCVISD::SETCC_VL, DL, SetccVT,
-                  {Abs, MaxValSplat, DAG.getCondCode(ISD::SETOLT),
-                   Mask, Mask, VL});
+  // Skip all this logic for LRINT and LLRINT, since ContainerVT isn't a FP
+  // type.
+  if (Op.getOpcode() != ISD::LRINT && Op.getOpcode() != ISD::LLRINT) {
+    // We do the conversion on the absolute value and fix the sign at the end.
+    SDValue Abs =
+        DAG.getNode(RISCVISD::FABS_VL, DL, ContainerVT, Src, Mask, VL);
+
+    // Determine the largest integer that can be represented exactly. This and
+    // values larger than it don't have any fractional bits so don't need to
+    // be converted.
+    const fltSemantics &FltSem = DAG.EVTToAPFloatSemantics(ContainerVT);
+    unsigned Precision = APFloat::semanticsPrecision(FltSem);
+    APFloat MaxVal = APFloat(FltSem);
+    MaxVal.convertFromAPInt(APInt::getOneBitSet(Precision, Precision - 1),
+                            /*IsSigned*/ false, APFloat::rmNearestTiesToEven);
+    SDValue MaxValNode =
+        DAG.getConstantFP(MaxVal, DL, ContainerVT.getVectorElementType());
+    SDValue MaxValSplat =
+        DAG.getNode(RISCVISD::VFMV_V_F_VL, DL, ContainerVT,
+                    DAG.getUNDEF(ContainerVT), MaxValNode, VL);
+
+    // If abs(Src) was larger than MaxVal or nan, keep it.
+    MVT SetccVT =
+        MVT::getVectorVT(MVT::i1, ContainerVT.getVectorElementCount());
+    Mask = DAG.getNode(
+        RISCVISD::SETCC_VL, DL, SetccVT,
+        {Abs, MaxValSplat, DAG.getCondCode(ISD::SETOLT), Mask, Mask, VL});
+  }
 
   // Truncate to integer and convert back to FP.
   MVT IntVT = ContainerVT.changeVectorElementTypeToInteger();
@@ -2773,6 +2780,8 @@ lowerVectorFTRUNC_FCEIL_FFLOOR_FROUND(SDValue Op, SelectionDAG &DAG,
     break;
   case ISD::FRINT:
   case ISD::VP_FRINT:
+  case ISD::LRINT:
+  case ISD::LLRINT:
     Truncated = DAG.getNode(RISCVISD::VFCVT_X_F_VL, DL, IntVT, Src, Mask, VL);
     break;
   case ISD::FNEARBYINT:
@@ -2782,14 +2791,17 @@ lowerVectorFTRUNC_FCEIL_FFLOOR_FROUND(SDValue Op, SelectionDAG &DAG,
     break;
   }
 
-  // VFROUND_NOEXCEPT_VL includes SINT_TO_FP_VL.
-  if (Truncated.getOpcode() != RISCVISD::VFROUND_NOEXCEPT_VL)
-    Truncated = DAG.getNode(RISCVISD::SINT_TO_FP_VL, DL, ContainerVT, Truncated,
-                            Mask, VL);
+  // For LRINT and LLRINT, we're done: don't convert back to float.
+  if (Op.getOpcode() != ISD::LRINT && Op.getOpcode() != ISD::LLRINT) {
+    // VFROUND_NOEXCEPT_VL includes SINT_TO_FP_VL.
+    if (Truncated.getOpcode() != RISCVISD::VFROUND_NOEXCEPT_VL)
+      Truncated = DAG.getNode(RISCVISD::SINT_TO_FP_VL, DL, ContainerVT,
+                              Truncated, Mask, VL);
 
-  // Restore the original sign so that -0.0 is preserved.
-  Truncated = DAG.getNode(RISCVISD::FCOPYSIGN_VL, DL, ContainerVT, Truncated,
-                          Src, Src, Mask, VL);
+    // Restore the original sign so that -0.0 is preserved.
+    Truncated = DAG.getNode(RISCVISD::FCOPYSIGN_VL, DL, ContainerVT, Truncated,
+                            Src, Src, Mask, VL);
+  }
 
   if (!VT.isFixedLengthVector())
     return Truncated;
@@ -2902,11 +2914,14 @@ lowerVectorStrictFTRUNC_FCEIL_FFLOOR_FROUND(SDValue Op, SelectionDAG &DAG,
 }
 
 static SDValue
-lowerFTRUNC_FCEIL_FFLOOR_FROUND(SDValue Op, SelectionDAG &DAG,
-                                const RISCVSubtarget &Subtarget) {
+lowerFTRUNC_FCEIL_FFLOOR_FROUND_XRINT(SDValue Op, SelectionDAG &DAG,
+                                      const RISCVSubtarget &Subtarget) {
   MVT VT = Op.getSimpleValueType();
+  if (Op.getOpcode() == ISD::LRINT || Op.getOpcode() == ISD::LLRINT)
+    assert(VT.isVector() &&
+           "Expected vector in custom lowering of LRINT and LLRINT");
   if (VT.isVector())
-    return lowerVectorFTRUNC_FCEIL_FFLOOR_FROUND(Op, DAG, Subtarget);
+    return lowerVectorFTRUNC_FCEIL_FFLOOR_FROUND_XRINT(Op, DAG, Subtarget);
 
   if (DAG.shouldOptForSize())
     return SDValue();
@@ -5921,9 +5936,11 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
   case ISD::FFLOOR:
   case ISD::FNEARBYINT:
   case ISD::FRINT:
+  case ISD::LRINT:
+  case ISD::LLRINT:
   case ISD::FROUND:
   case ISD::FROUNDEVEN:
-    return lowerFTRUNC_FCEIL_FFLOOR_FROUND(Op, DAG, Subtarget);
+    return lowerFTRUNC_FCEIL_FFLOOR_FROUND_XRINT(Op, DAG, Subtarget);
   case ISD::VECREDUCE_ADD:
   case ISD::VECREDUCE_UMAX:
   case ISD::VECREDUCE_SMAX:
@@ -6287,7 +6304,7 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
         (Subtarget.hasVInstructionsF16Minimal() &&
          !Subtarget.hasVIns...
[truncated]

Copy link
Collaborator

@RKSimon RKSimon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please can you move the tests to a stacked pre-commit so we can see the codegen diff?

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp Outdated Show resolved Hide resolved
llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp Outdated Show resolved Hide resolved
llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp Outdated Show resolved Hide resolved
llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp Outdated Show resolved Hide resolved
llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp Outdated Show resolved Hide resolved
llvm/lib/CodeGen/TargetLoweringBase.cpp Show resolved Hide resolved
@artagnon
Copy link
Contributor Author

Please can you move the tests to a stacked pre-commit so we can see the codegen diff?

Um, wouldn't the tests assert, as I'm using a vector variant of [l]lrint that wouldn't be legal yet?

@RKSimon
Copy link
Collaborator

RKSimon commented Sep 20, 2023

Maybe the costs tests at least?

@artagnon
Copy link
Contributor Author

Maybe the costs tests at least?

Same issue. The module is broken, and the cost model tests can't work, since the vector versions of [l]lrint haven't yet been introduced.

@github-actions
Copy link

github-actions bot commented Sep 22, 2023

✅ With the latest revision this PR passed the C/C++ code formatter.

@artagnon
Copy link
Contributor Author

I've realized the limitations of this patch:

  • rv32 can only do i32, and rv64 can only do i64.
  • I don't know if the ISD::LLRINT lowering on rv32 is correct (it's identical to the rv64 output, but I haven't added a RUN line for it).

Should I add asserts to guard against these?

@artagnon
Copy link
Contributor Author

artagnon commented Oct 4, 2023

Gentle ping.

@artagnon
Copy link
Contributor Author

Gentle ping.

The issue llvm#55208 noticed that std::rint is vectorized by the
SLPVectorizer, but a very similar function, std::lrint, is not.
std::lrint corresponds to ISD::LRINT in the SelectionDAG, and
std::llrint is a familiar cousin corresponding to ISD::LLRINT. Now,
neither ISD::LRINT nor ISD::LLRINT have a corresponding vector variant,
and the LangRef makes this clear in the documentation of llvm.lrint.*
and llvm.llrint.*.

This patch extends the LangRef to include vector variants of
llvm.lrint.* and llvm.llrint.*, and lays the necessary ground-work of
scalarizing it for all targets. However, this patch would be devoid of
motivation unless we show the utility of these new vector variants.
Hence, the RISCV target has been chosen to implement a custom lowering
to the vfcvt.x.f.v instruction. The patch also includes a CostModel for
RISCV, and a trivial follow-up can potentially enable the SLPVectorizer
to vectorize std::lrint and std::llrint, fixing llvm#55208.

The patch includes tests, obviously for the RISCV target, but also for
the X86, AArch64, and PowerPC targets to justify the addition of the
vector variants to the LangRef.
Copy link
Collaborator

@topperc topperc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants