Skip to content

Conversation

mshockwave
Copy link
Member

@mshockwave mshockwave commented Apr 30, 2024

Given VFirst = (vfirst_vl ..., EVL) and (seteq (riscv_selectLT VFirst, 0, EVL, VFirst), EVL), we can replace it with (setLT VFirst, 0). Similar replacements are done for variants w/ setne and riscv_selectGE.

@llvmbot
Copy link
Member

llvmbot commented Apr 30, 2024

@llvm/pr-subscribers-llvm-ir
@llvm/pr-subscribers-llvm-selectiondag

@llvm/pr-subscribers-backend-risc-v

Author: Min-Yih Hsu (mshockwave)

Changes

Given (seteq (riscv_selectcc LHS, RHS, CC, X, Y), X), we can turn it into (setCC LHS, RHS).
I think we can generalize this into ISD::SELECT_CC as well.


Right now this PR is stacked on top of #90502 -- but it doesn't have to -- to show changes on the test. The main patch is 018b08a. I'll add more tests tomorrow.


Patch is 33.18 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/90538.diff

14 Files Affected:

  • (modified) llvm/docs/LangRef.rst (+48)
  • (modified) llvm/include/llvm/CodeGen/TargetLowering.h (+5)
  • (modified) llvm/include/llvm/IR/Intrinsics.td (+6)
  • (modified) llvm/include/llvm/IR/VPIntrinsics.def (+9)
  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp (+9)
  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp (+10)
  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h (+3)
  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp (+7)
  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp (+42)
  • (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp (+8-1)
  • (modified) llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp (+33)
  • (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+109-10)
  • (modified) llvm/lib/Target/RISCV/RISCVISelLowering.h (+1)
  • (added) llvm/test/CodeGen/RISCV/rvv/vp-cttz-elts.ll (+234)
diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst
index 37662f79145d67..f79c1fd9278de3 100644
--- a/llvm/docs/LangRef.rst
+++ b/llvm/docs/LangRef.rst
@@ -24001,6 +24001,54 @@ Examples:
       %also.r = select <4 x i1> %mask, <4 x i32> %t, <4 x i32> poison
 
 
+.. _int_vp_cttz_elts:
+
+'``llvm.vp.cttz.elts.*``' Intrinsics
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Syntax:
+"""""""
+This is an overloaded intrinsic. You can use ```llvm.vp.cttz.elts``` on any
+vector of integer elements, both fixed width and scalable.
+
+::
+
+      declare i32  @llvm.vp.cttz.elts.i32.v16i32 (<16 x i32> <op>, i1 <is_zero_poison>, <16 x i1> <mask>, i32 <vector_length>)
+      declare i64  @llvm.vp.cttz.elts.i64.nxv4i32 (<vscale x 4 x i32> <op>, i1 <is_zero_poison>, <vscale x 4 x i1> <mask>, i32 <vector_length>)
+      declare i64  @llvm.vp.cttz.elts.i64.v256i1 (<256 x i1> <op>, i1 <is_zero_poison>, <256 x i1> <mask>, i32 <vector_length>)
+
+Overview:
+"""""""""
+
+This '```llvm.vp.cttz.elts```' intrinsic counts the number of trailing zero
+elements of a vector. This is basically the vector-predicated version of
+'```llvm.experimental.cttz.elts```'.
+
+Arguments:
+""""""""""
+
+The first argument is the vector to be counted. This argument must be a vector
+with integer element type. The return type must also be an integer type which is
+wide enough to hold the maximum number of elements of the source vector. The
+behavior of this intrinsic is undefined if the return type is not wide enough
+for the number of elements in the input vector.
+
+The second argument is a constant flag that indicates whether the intrinsic
+returns a valid result if the first argument is all zero.
+
+The third operand is the vector mask and has the same number of elements as the
+input vector type. The fourth operand is the explicit vector length of the
+operation.
+
+Semantics:
+""""""""""
+
+The '``llvm.vp.cttz.elts``' intrinsic counts the trailing (least
+significant / lowest-numbered) zero elements in the first operand on each
+enabled lane. If the first argument is all zero and the second argument is true,
+the result is poison. Otherwise, it returns the explicit vector length (i.e. the
+fourth operand).
+
 .. _int_vp_sadd_sat:
 
 '``llvm.vp.sadd.sat.*``' Intrinsics
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 661b2841c6ac72..7ed08cfa8a2022 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -5307,6 +5307,11 @@ class TargetLowering : public TargetLoweringBase {
   /// \returns The expansion result or SDValue() if it fails.
   SDValue expandVPCTTZ(SDNode *N, SelectionDAG &DAG) const;
 
+  /// Expand VP_CTTZ_ELTS/VP_CTTZ_ELTS_ZERO_UNDEF nodes.
+  /// \param N Node to expand
+  /// \returns The expansion result or SDValue() if it fails.
+  SDValue expandVPCTTZElements(SDNode *N, SelectionDAG &DAG) const;
+
   /// Expand ABS nodes. Expands vector/scalar ABS nodes,
   /// vector nodes can only succeed if all operations are legal/custom.
   /// (ABS x) -> (XOR (ADD x, (SRA x, type_size)), (SRA x, type_size))
diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td
index a2678d69ce4062..28116e5316c96b 100644
--- a/llvm/include/llvm/IR/Intrinsics.td
+++ b/llvm/include/llvm/IR/Intrinsics.td
@@ -2255,6 +2255,12 @@ let IntrProperties = [IntrNoMem, IntrNoSync, IntrWillReturn, ImmArg<ArgIndex<1>>
                                llvm_i1_ty,
                                LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>,
                                llvm_i32_ty]>;
+
+  def int_vp_cttz_elts : DefaultAttrsIntrinsic<[ llvm_anyint_ty ],
+                                  [ llvm_anyvector_ty,
+                                    llvm_i1_ty,
+                                    LLVMScalarOrSameVectorWidth<1, llvm_i1_ty>,
+                                    llvm_i32_ty]>;
 }
 
 def int_get_active_lane_mask:
diff --git a/llvm/include/llvm/IR/VPIntrinsics.def b/llvm/include/llvm/IR/VPIntrinsics.def
index 1c2708a9e85437..f1cc8bcae467be 100644
--- a/llvm/include/llvm/IR/VPIntrinsics.def
+++ b/llvm/include/llvm/IR/VPIntrinsics.def
@@ -282,6 +282,15 @@ BEGIN_REGISTER_VP_SDNODE(VP_CTTZ_ZERO_UNDEF, -1, vp_cttz_zero_undef, 1, 2)
 END_REGISTER_VP_SDNODE(VP_CTTZ_ZERO_UNDEF)
 END_REGISTER_VP_INTRINSIC(vp_cttz)
 
+// llvm.vp.cttz.elts(x,is_zero_poison,mask,vl)
+BEGIN_REGISTER_VP_INTRINSIC(vp_cttz_elts, 2, 3)
+VP_PROPERTY_NO_FUNCTIONAL
+BEGIN_REGISTER_VP_SDNODE(VP_CTTZ_ELTS, 0, vp_cttz_elts, 1, 2)
+END_REGISTER_VP_SDNODE(VP_CTTZ_ELTS)
+BEGIN_REGISTER_VP_SDNODE(VP_CTTZ_ELTS_ZERO_UNDEF, 0, vp_cttz_elts_zero_undef, 1, 2)
+END_REGISTER_VP_SDNODE(VP_CTTZ_ELTS_ZERO_UNDEF)
+END_REGISTER_VP_INTRINSIC(vp_cttz_elts)
+
 // llvm.vp.fshl(x,y,z,mask,vlen)
 BEGIN_REGISTER_VP(vp_fshl, 3, 4, VP_FSHL, -1)
 VP_PROPERTY_FUNCTIONAL_INTRINSIC(fshl)
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
index 46e54b5366d66a..5322ea3b6a2d97 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -1220,6 +1220,11 @@ void SelectionDAGLegalize::LegalizeOp(SDNode *Node) {
     Action = TLI.getOperationAction(
         Node->getOpcode(), Node->getOperand(1).getValueType());
     break;
+  case ISD::VP_CTTZ_ELTS:
+  case ISD::VP_CTTZ_ELTS_ZERO_UNDEF:
+    Action = TLI.getOperationAction(Node->getOpcode(),
+                                    Node->getOperand(0).getValueType());
+    break;
   default:
     if (Node->getOpcode() >= ISD::BUILTIN_OP_END) {
       Action = TLI.getCustomOperationAction(*Node);
@@ -4234,6 +4239,10 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
   case ISD::VECREDUCE_FMINIMUM:
     Results.push_back(TLI.expandVecReduce(Node, DAG));
     break;
+  case ISD::VP_CTTZ_ELTS:
+  case ISD::VP_CTTZ_ELTS_ZERO_UNDEF:
+    Results.push_back(TLI.expandVPCTTZElements(Node, DAG));
+    break;
   case ISD::GLOBAL_OFFSET_TABLE:
   case ISD::GlobalAddress:
   case ISD::GlobalTLSAddress:
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index 55f9737bc94dd5..0aa36deda79dcc 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -76,6 +76,10 @@ void DAGTypeLegalizer::PromoteIntegerResult(SDNode *N, unsigned ResNo) {
   case ISD::VP_CTTZ:
   case ISD::CTTZ_ZERO_UNDEF:
   case ISD::CTTZ:        Res = PromoteIntRes_CTTZ(N); break;
+  case ISD::VP_CTTZ_ELTS_ZERO_UNDEF:
+  case ISD::VP_CTTZ_ELTS:
+    Res = PromoteIntRes_VP_CttzElements(N);
+    break;
   case ISD::EXTRACT_VECTOR_ELT:
                          Res = PromoteIntRes_EXTRACT_VECTOR_ELT(N); break;
   case ISD::LOAD:        Res = PromoteIntRes_LOAD(cast<LoadSDNode>(N)); break;
@@ -724,6 +728,12 @@ SDValue DAGTypeLegalizer::PromoteIntRes_CTTZ(SDNode *N) {
                      N->getOperand(2));
 }
 
+SDValue DAGTypeLegalizer::PromoteIntRes_VP_CttzElements(SDNode *N) {
+  SDLoc DL(N);
+  EVT NewVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));
+  return DAG.getNode(N->getOpcode(), DL, NewVT, N->ops());
+}
+
 SDValue DAGTypeLegalizer::PromoteIntRes_EXTRACT_VECTOR_ELT(SDNode *N) {
   SDLoc dl(N);
   EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index 4a2c7b355eb528..49be824deb5134 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -309,6 +309,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   SDValue PromoteIntRes_CTLZ(SDNode *N);
   SDValue PromoteIntRes_CTPOP_PARITY(SDNode *N);
   SDValue PromoteIntRes_CTTZ(SDNode *N);
+  SDValue PromoteIntRes_VP_CttzElements(SDNode *N);
   SDValue PromoteIntRes_EXTRACT_VECTOR_ELT(SDNode *N);
   SDValue PromoteIntRes_FP_TO_XINT(SDNode *N);
   SDValue PromoteIntRes_FP_TO_XINT_SAT(SDNode *N);
@@ -912,6 +913,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   SDValue SplitVecOp_FP_ROUND(SDNode *N);
   SDValue SplitVecOp_FPOpDifferentTypes(SDNode *N);
   SDValue SplitVecOp_FP_TO_XINT_SAT(SDNode *N);
+  SDValue SplitVecOp_VP_CttzElements(SDNode *N);
 
   //===--------------------------------------------------------------------===//
   // Vector Widening Support: LegalizeVectorTypes.cpp
@@ -1019,6 +1021,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   SDValue WidenVecOp_VECREDUCE_SEQ(SDNode *N);
   SDValue WidenVecOp_VP_REDUCE(SDNode *N);
   SDValue WidenVecOp_ExpOp(SDNode *N);
+  SDValue WidenVecOp_VP_CttzElements(SDNode *N);
 
   /// Helper function to generate a set of operations to perform
   /// a vector operation for a wider type.
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
index 8f87ee8e09393a..26cd5482168f9f 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -510,6 +510,13 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
       if (Action != TargetLowering::Legal)                                     \
         break;                                                                 \
     }                                                                          \
+    /* Defer non-vector results to LegalizeDAG. */                             \
+    /* Remove this after #90522 is landed */                                   \
+    if (ISD::VPID == ISD::VP_CTTZ_ELTS ||                                      \
+        ISD::VPID == ISD::VP_CTTZ_ELTS_ZERO_UNDEF) {                           \
+      Action = TargetLowering::Legal;                                          \
+      break;                                                                   \
+    }                                                                          \
     Action = TLI.getOperationAction(Node->getOpcode(), LegalizeVT);            \
   } break;
 #include "llvm/IR/VPIntrinsics.def"
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index 985c9f16ab97cd..cab4dc5f3c1565 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -3098,6 +3098,10 @@ bool DAGTypeLegalizer::SplitVectorOperand(SDNode *N, unsigned OpNo) {
   case ISD::VP_REDUCE_FMIN:
     Res = SplitVecOp_VP_REDUCE(N, OpNo);
     break;
+  case ISD::VP_CTTZ_ELTS:
+  case ISD::VP_CTTZ_ELTS_ZERO_UNDEF:
+    Res = SplitVecOp_VP_CttzElements(N);
+    break;
   }
 
   // If the result is null, the sub-method took care of registering results etc.
@@ -4056,6 +4060,29 @@ SDValue DAGTypeLegalizer::SplitVecOp_FP_TO_XINT_SAT(SDNode *N) {
   return DAG.getNode(ISD::CONCAT_VECTORS, dl, ResVT, Lo, Hi);
 }
 
+SDValue DAGTypeLegalizer::SplitVecOp_VP_CttzElements(SDNode *N) {
+  SDLoc DL(N);
+  EVT ResVT = N->getValueType(0);
+
+  SDValue Lo, Hi;
+  SDValue VecOp = N->getOperand(0);
+  GetSplitVector(VecOp, Lo, Hi);
+
+  auto [MaskLo, MaskHi] = SplitMask(N->getOperand(1));
+  auto [EVLLo, EVLHi] =
+      DAG.SplitEVL(N->getOperand(2), VecOp.getValueType(), DL);
+  SDValue VLo = DAG.getZExtOrTrunc(EVLLo, DL, ResVT);
+
+  // if VP_CTTZ_ELTS(Lo) != EVLLo => VP_CTTZ_ELTS(Lo).
+  // else => EVLLo + (VP_CTTZ_ELTS(Hi) or VP_CTTZ_ELTS_ZERO_UNDEF(Hi)).
+  SDValue ResLo = DAG.getNode(ISD::VP_CTTZ_ELTS, DL, ResVT, Lo, MaskLo, EVLLo);
+  SDValue ResLoNotEVL =
+      DAG.getSetCC(DL, getSetCCResultType(ResVT), ResLo, VLo, ISD::SETNE);
+  SDValue ResHi = DAG.getNode(N->getOpcode(), DL, ResVT, Hi, MaskHi, EVLHi);
+  return DAG.getSelect(DL, ResVT, ResLoNotEVL, ResLo,
+                       DAG.getNode(ISD::ADD, DL, ResVT, VLo, ResHi));
+}
+
 //===----------------------------------------------------------------------===//
 //  Result Vector Widening
 //===----------------------------------------------------------------------===//
@@ -6161,6 +6188,10 @@ bool DAGTypeLegalizer::WidenVectorOperand(SDNode *N, unsigned OpNo) {
   case ISD::VP_REDUCE_FMIN:
     Res = WidenVecOp_VP_REDUCE(N);
     break;
+  case ISD::VP_CTTZ_ELTS:
+  case ISD::VP_CTTZ_ELTS_ZERO_UNDEF:
+    Res = WidenVecOp_VP_CttzElements(N);
+    break;
   }
 
   // If Res is null, the sub-method took care of registering the result.
@@ -6924,6 +6955,17 @@ SDValue DAGTypeLegalizer::WidenVecOp_VSELECT(SDNode *N) {
                      DAG.getVectorIdxConstant(0, DL));
 }
 
+SDValue DAGTypeLegalizer::WidenVecOp_VP_CttzElements(SDNode *N) {
+  SDLoc DL(N);
+  SDValue Source = GetWidenedVector(N->getOperand(0));
+  EVT SrcVT = Source.getValueType();
+  SDValue Mask =
+      GetWidenedMask(N->getOperand(1), SrcVT.getVectorElementCount());
+
+  return DAG.getNode(N->getOpcode(), DL, N->getValueType(0),
+                     {Source, Mask, N->getOperand(2)}, N->getFlags());
+}
+
 //===----------------------------------------------------------------------===//
 // Vector Widening Utilities
 //===----------------------------------------------------------------------===//
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 5caf868c83a296..cfd82a342433fa 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -8076,6 +8076,11 @@ static unsigned getISDForVPIntrinsic(const VPIntrinsic &VPIntrin) {
     ResOPC = IsZeroUndef ? ISD::VP_CTTZ_ZERO_UNDEF : ISD::VP_CTTZ;
     break;
   }
+  case Intrinsic::vp_cttz_elts: {
+    bool IsZeroPoison = cast<ConstantInt>(VPIntrin.getArgOperand(1))->isOne();
+    ResOPC = IsZeroPoison ? ISD::VP_CTTZ_ELTS_ZERO_UNDEF : ISD::VP_CTTZ_ELTS;
+    break;
+  }
 #define HELPER_MAP_VPID_TO_VPSD(VPID, VPSD)                                    \
   case Intrinsic::VPID:                                                        \
     ResOPC = ISD::VPSD;                                                        \
@@ -8428,7 +8433,9 @@ void SelectionDAGBuilder::visitVectorPredicationIntrinsic(
   case ISD::VP_CTLZ:
   case ISD::VP_CTLZ_ZERO_UNDEF:
   case ISD::VP_CTTZ:
-  case ISD::VP_CTTZ_ZERO_UNDEF: {
+  case ISD::VP_CTTZ_ZERO_UNDEF:
+  case ISD::VP_CTTZ_ELTS_ZERO_UNDEF:
+  case ISD::VP_CTTZ_ELTS: {
     SDValue Result =
         DAG.getNode(Opcode, DL, VTs, {OpValues[0], OpValues[2], OpValues[3]});
     setValue(&VPIntrin, Result);
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index cdc1227fd572dc..336d89fbcf638e 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -9074,6 +9074,39 @@ SDValue TargetLowering::expandVPCTTZ(SDNode *Node, SelectionDAG &DAG) const {
   return DAG.getNode(ISD::VP_CTPOP, dl, VT, Tmp, Mask, VL);
 }
 
+SDValue TargetLowering::expandVPCTTZElements(SDNode *N,
+                                             SelectionDAG &DAG) const {
+  // %cond = to_bool_vec %source
+  // %splat = splat /*val=*/VL
+  // %tz = step_vector
+  // %v = vp.select %cond, /*true=*/tz, /*false=*/%splat
+  // %r = vp.reduce.umin %v
+  SDLoc DL(N);
+  SDValue Source = N->getOperand(0);
+  SDValue Mask = N->getOperand(1);
+  SDValue EVL = N->getOperand(2);
+  EVT SrcVT = Source.getValueType();
+  EVT ResVT = N->getValueType(0);
+  EVT ResVecVT =
+      EVT::getVectorVT(*DAG.getContext(), ResVT, SrcVT.getVectorElementCount());
+
+  // Convert to boolean vector.
+  if (SrcVT.getScalarType() != MVT::i1) {
+    SDValue AllZero = DAG.getConstant(0, DL, SrcVT);
+    SrcVT = EVT::getVectorVT(*DAG.getContext(), MVT::i1,
+                             SrcVT.getVectorElementCount());
+    Source = DAG.getNode(ISD::VP_SETCC, DL, SrcVT, Source, AllZero,
+                         DAG.getCondCode(ISD::SETNE), Mask, EVL);
+  }
+
+  SDValue ExtEVL = DAG.getZExtOrTrunc(EVL, DL, ResVT);
+  SDValue Splat = DAG.getSplat(ResVecVT, DL, ExtEVL);
+  SDValue StepVec = DAG.getStepVector(DL, ResVecVT);
+  SDValue Select =
+      DAG.getNode(ISD::VP_SELECT, DL, ResVecVT, Source, StepVec, Splat, EVL);
+  return DAG.getNode(ISD::VP_REDUCE_UMIN, DL, ResVT, ExtEVL, Select, Mask, EVL);
+}
+
 SDValue TargetLowering::expandABS(SDNode *N, SelectionDAG &DAG,
                                   bool IsNegative) const {
   SDLoc dl(N);
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 68f4ec5ef49f31..c61e477d79e110 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -28,6 +28,7 @@
 #include "llvm/CodeGen/MachineInstrBuilder.h"
 #include "llvm/CodeGen/MachineJumpTableInfo.h"
 #include "llvm/CodeGen/MachineRegisterInfo.h"
+#include "llvm/CodeGen/SDPatternMatch.h"
 #include "llvm/CodeGen/SelectionDAGAddressAnalysis.h"
 #include "llvm/CodeGen/TargetLoweringObjectFileImpl.h"
 #include "llvm/CodeGen/ValueTypes.h"
@@ -698,7 +699,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
         ISD::VP_SMAX,        ISD::VP_UMIN,        ISD::VP_UMAX,
         ISD::VP_ABS, ISD::EXPERIMENTAL_VP_REVERSE, ISD::EXPERIMENTAL_VP_SPLICE,
         ISD::VP_SADDSAT,     ISD::VP_UADDSAT,     ISD::VP_SSUBSAT,
-        ISD::VP_USUBSAT};
+        ISD::VP_USUBSAT,     ISD::VP_CTTZ_ELTS,   ISD::VP_CTTZ_ELTS_ZERO_UNDEF};
 
     static const unsigned FloatingPointVPOps[] = {
         ISD::VP_FADD,        ISD::VP_FSUB,        ISD::VP_FMUL,
@@ -759,6 +760,9 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
           {ISD::SELECT_CC, ISD::VSELECT, ISD::VP_MERGE, ISD::VP_SELECT}, VT,
           Expand);
 
+      setOperationAction({ISD::VP_CTTZ_ELTS, ISD::VP_CTTZ_ELTS_ZERO_UNDEF}, VT,
+                         Custom);
+
       setOperationAction({ISD::VP_AND, ISD::VP_OR, ISD::VP_XOR}, VT, Custom);
 
       setOperationAction(
@@ -5341,6 +5345,44 @@ RISCVTargetLowering::lowerCTLZ_CTTZ_ZERO_UNDEF(SDValue Op,
   return Res;
 }
 
+SDValue RISCVTargetLowering::lowerVPCttzElements(SDValue Op,
+                                                 SelectionDAG &DAG) const {
+  SDLoc DL(Op);
+  MVT XLenVT = Subtarget.getXLenVT();
+  SDValue Source = Op->getOperand(0);
+  MVT SrcVT = Source.getSimpleValueType();
+  SDValue Mask = Op->getOperand(1);
+  SDValue EVL = Op->getOperand(2);
+
+  if (SrcVT.isFixedLengthVector()) {
+    MVT ContainerVT = getContainerForFixedLengthVector(SrcVT);
+    Source = convertToScalableVector(ContainerVT, Source, DAG, Subtarget);
+    Mask = convertToScalableVector(getMaskTypeFor(ContainerVT), Mask, DAG,
+                                   Subtarget);
+    SrcVT = ContainerVT;
+  }
+
+  // Convert to boolean vector.
+  if (SrcVT.getScalarType() != MVT::i1) {
+    SDValue AllZero = DAG.getConstant(0, DL, SrcVT);
+    SrcVT = MVT::getVectorVT(MVT::i1, SrcVT.getVectorElementCount());
+    Source = DAG.getNode(RISCVISD::SETCC_VL, DL, SrcVT,
+                         {Source, AllZero, DAG.getCondCode(ISD::SETNE),
+                          DAG.getUNDEF(SrcVT), Mask, EVL});
+  }
+
+  SDValue Res = DAG.getNode(RISCVISD::VFIRST_VL, DL, XLenVT, Source, Mask, EVL);
+  if (Op->getOpcode() == ISD::VP_CTTZ_ELTS_ZERO_UNDEF)
+    // In this case, we can interpret poison as -1, so nothing to do further.
+    return Res;
+
+  // Convert -1 to VL.
+  SDValue SetCC =
+      DAG.getSetCC(DL, XLenVT, Res, DAG.getConstant(0, DL, XLenVT), ISD::SETLT);
+  Res = DAG.getSelect(DL, XLenVT, SetCC, EVL, Res);
+  return DAG.getNode(ISD::TRUNCATE, DL, Op.getValueType(), Res);
+}
+
 // While RVV has alignment restrictions, we should always be able to load as a
 // legal equivalently-sized byte-typed vector instead. This method is
 // responsible for re-expressing a ISD::LOAD via a correctly-aligned type. If
@@ -6595,6 +6637,9 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
     if (Op.getOperand(1).getValueType().getVectorElementType() == MVT::i1)
       return lowerVectorMaskVecReduction(Op, DAG, /*IsVP*/ true);
     return lowerVPREDUCE(Op, DAG);
+  case ISD::VP_CTTZ_ELTS:
+  case ISD::VP_CTTZ_ELTS_ZERO_UNDEF:
+    return lowerVPCttzElements(Op, DAG);
   case ISD::UNDEF: {
     MVT ContainerVT = getContainerForFixedLengthVector(Op.getSimpleValueType());
     return convertFromScalableVector(Op.getSimpleValueType(),
@@ -13634,9 +13679,69 @@ st...
[truncated]

Copy link

github-actions bot commented Apr 30, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

ISD::CondCode NewCC = cast<CondCodeSDNode>(Select->getOperand(2))->get();
if (Inverse)
NewCC = ISD::getSetCCInverse(NewCC, OpVT);
return DAG.getNode(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use getSetCC.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.

…cc A, B)` when possible

Given `(seteq (riscv_selectcc LHS, RHS, CC, X, Y), X)`, we can turn it
into `(setCC LHS, RHS)`.

I think we can generalize this into ISD::SELECT_CC as well.
@mshockwave mshockwave force-pushed the patch/riscv-setcc-selectcc branch from 018b08a to 9416591 Compare April 30, 2024 17:55
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
; RUN: llc -mtriple=riscv32 < %s | FileCheck %s

define i1 @eq(i32 %a, i32 %b, i32 %c, i32 %d) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mshockwave mshockwave changed the title [RISCV] Combine (setcc (riscv_selectcc A, B, ...), Y) to just (setcc A, B) when possible [RISCV] Optimize pattern (setcc (selectLT (vfirst_vl ...) , 0, EVL, ...), EVL) Apr 30, 2024
@mshockwave
Copy link
Member Author

I have limited the scope of this optimization to handle only the pattern associated with RISCVISD::VFIRST_VL.

// a llvm.vp.cttz.elts that doesn't return XLen type. Due to non-trivial
// number of sext(_inreg) and zext interleaving between the nodes.
// That said, this is not really problem as long as we always generate the
// said intrinsics with XLen return type.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The vectorizer will very likely use i32 not XLen.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed. Now it recognizes vp.cttz.elts with non-XLen return type.

@mshockwave mshockwave requested a review from dtcxzyw May 1, 2024 15:49
SDValue &Select) -> bool {
// Remove any sext or zext
auto ExtPattern =
m_AnyOf(m_Opc(ISD::SIGN_EXTEND_INREG), m_And(m_Value(), m_AllOnes()));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

m_And(m_Value(), m_AllOnes()) isn't a zero extend and should be removed by DAG combine

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure m_AllOnes is implemented correctly. It uses APInt::isSameValue which zero extends when bit widths don't match. m_AllOnes should be checking that all bits are one. If m_AllOnes is used to compare against a value that is less than 64 bits, that value will be zero extended to 64 bits and will fail to match.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's fixed now. In this particular case I think VL will almost certain be zext from i32 so I'm pattern matching (and X, <32 trailing ones>) here instead.

// against) is zext from i32.
auto ZExtVL = m_And(m_Value(), m_SpecificInt(APInt::getLowBitsSet(64, 32)));

// Remove any sext or zext
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm nervous about peeking through zext and sexts without checking of types. I think we need to know that the compare is only using the same bits from EVL that the vfirst was using.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree. It's fixed now.

SDValue VLCandVTNode;
EVT VLCandVT = VLCandidate.getValueType();
// Remove any sext.
if (sd_match(Op, m_Opc(ISD::SIGN_EXTEND_INREG)))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we also need to make sure this doesn't drop any bits? If this sign_extends from a type that doesn't fit the EVL then we may be losing some information.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought about this before, but the intrinsic states that it's an undefined behavior if the user assigns a (return) type that can't fit EVL.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The sign extend wasn't created by the intrinsic. It was likely created by type legalization, but that's not the only way it can be created. So we can't assume the type in sign_extend_inreg represents the original type of the intrinsic.

I'm not even sure how to implement this combine to handle all the cases we need to handle. You'll probably get different results if you add the zeroext attribute to the EVL function argument in your test. That would remove the AND between the EVL and the vfirst, but it would not remove the sign_extend_inreg between the evl and the compare.

@dtcxzyw dtcxzyw requested review from lukel97 and wangpc-pp May 11, 2024 00:42
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants