-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[RISCV] Optimize pattern (setcc (selectLT (vfirst_vl ...) , 0, EVL, ...), EVL)
#90538
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
@llvm/pr-subscribers-llvm-ir @llvm/pr-subscribers-backend-risc-v Author: Min-Yih Hsu (mshockwave) ChangesGiven Right now this PR is stacked on top of #90502 -- but it doesn't have to -- to show changes on the test. The main patch is 018b08a. I'll add more tests tomorrow. Patch is 33.18 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/90538.diff 14 Files Affected:
diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst
index 37662f79145d67..f79c1fd9278de3 100644
--- a/llvm/docs/LangRef.rst
+++ b/llvm/docs/LangRef.rst
@@ -24001,6 +24001,54 @@ Examples:
%also.r = select <4 x i1> %mask, <4 x i32> %t, <4 x i32> poison
+.. _int_vp_cttz_elts:
+
+'``llvm.vp.cttz.elts.*``' Intrinsics
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Syntax:
+"""""""
+This is an overloaded intrinsic. You can use ```llvm.vp.cttz.elts``` on any
+vector of integer elements, both fixed width and scalable.
+
+::
+
+ declare i32 @llvm.vp.cttz.elts.i32.v16i32 (<16 x i32> <op>, i1 <is_zero_poison>, <16 x i1> <mask>, i32 <vector_length>)
+ declare i64 @llvm.vp.cttz.elts.i64.nxv4i32 (<vscale x 4 x i32> <op>, i1 <is_zero_poison>, <vscale x 4 x i1> <mask>, i32 <vector_length>)
+ declare i64 @llvm.vp.cttz.elts.i64.v256i1 (<256 x i1> <op>, i1 <is_zero_poison>, <256 x i1> <mask>, i32 <vector_length>)
+
+Overview:
+"""""""""
+
+This '```llvm.vp.cttz.elts```' intrinsic counts the number of trailing zero
+elements of a vector. This is basically the vector-predicated version of
+'```llvm.experimental.cttz.elts```'.
+
+Arguments:
+""""""""""
+
+The first argument is the vector to be counted. This argument must be a vector
+with integer element type. The return type must also be an integer type which is
+wide enough to hold the maximum number of elements of the source vector. The
+behavior of this intrinsic is undefined if the return type is not wide enough
+for the number of elements in the input vector.
+
+The second argument is a constant flag that indicates whether the intrinsic
+returns a valid result if the first argument is all zero.
+
+The third operand is the vector mask and has the same number of elements as the
+input vector type. The fourth operand is the explicit vector length of the
+operation.
+
+Semantics:
+""""""""""
+
+The '``llvm.vp.cttz.elts``' intrinsic counts the trailing (least
+significant / lowest-numbered) zero elements in the first operand on each
+enabled lane. If the first argument is all zero and the second argument is true,
+the result is poison. Otherwise, it returns the explicit vector length (i.e. the
+fourth operand).
+
.. _int_vp_sadd_sat:
'``llvm.vp.sadd.sat.*``' Intrinsics
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 661b2841c6ac72..7ed08cfa8a2022 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -5307,6 +5307,11 @@ class TargetLowering : public TargetLoweringBase {
/// \returns The expansion result or SDValue() if it fails.
SDValue expandVPCTTZ(SDNode *N, SelectionDAG &DAG) const;
+ /// Expand VP_CTTZ_ELTS/VP_CTTZ_ELTS_ZERO_UNDEF nodes.
+ /// \param N Node to expand
+ /// \returns The expansion result or SDValue() if it fails.
+ SDValue expandVPCTTZElements(SDNode *N, SelectionDAG &DAG) const;
+
/// Expand ABS nodes. Expands vector/scalar ABS nodes,
/// vector nodes can only succeed if all operations are legal/custom.
/// (ABS x) -> (XOR (ADD x, (SRA x, type_size)), (SRA x, type_size))
diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td
index a2678d69ce4062..28116e5316c96b 100644
--- a/llvm/include/llvm/IR/Intrinsics.td
+++ b/llvm/include/llvm/IR/Intrinsics.td
@@ -2255,6 +2255,12 @@ let IntrProperties = [IntrNoMem, IntrNoSync, IntrWillReturn, ImmArg<ArgIndex<1>>
llvm_i1_ty,
LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>,
llvm_i32_ty]>;
+
+ def int_vp_cttz_elts : DefaultAttrsIntrinsic<[ llvm_anyint_ty ],
+ [ llvm_anyvector_ty,
+ llvm_i1_ty,
+ LLVMScalarOrSameVectorWidth<1, llvm_i1_ty>,
+ llvm_i32_ty]>;
}
def int_get_active_lane_mask:
diff --git a/llvm/include/llvm/IR/VPIntrinsics.def b/llvm/include/llvm/IR/VPIntrinsics.def
index 1c2708a9e85437..f1cc8bcae467be 100644
--- a/llvm/include/llvm/IR/VPIntrinsics.def
+++ b/llvm/include/llvm/IR/VPIntrinsics.def
@@ -282,6 +282,15 @@ BEGIN_REGISTER_VP_SDNODE(VP_CTTZ_ZERO_UNDEF, -1, vp_cttz_zero_undef, 1, 2)
END_REGISTER_VP_SDNODE(VP_CTTZ_ZERO_UNDEF)
END_REGISTER_VP_INTRINSIC(vp_cttz)
+// llvm.vp.cttz.elts(x,is_zero_poison,mask,vl)
+BEGIN_REGISTER_VP_INTRINSIC(vp_cttz_elts, 2, 3)
+VP_PROPERTY_NO_FUNCTIONAL
+BEGIN_REGISTER_VP_SDNODE(VP_CTTZ_ELTS, 0, vp_cttz_elts, 1, 2)
+END_REGISTER_VP_SDNODE(VP_CTTZ_ELTS)
+BEGIN_REGISTER_VP_SDNODE(VP_CTTZ_ELTS_ZERO_UNDEF, 0, vp_cttz_elts_zero_undef, 1, 2)
+END_REGISTER_VP_SDNODE(VP_CTTZ_ELTS_ZERO_UNDEF)
+END_REGISTER_VP_INTRINSIC(vp_cttz_elts)
+
// llvm.vp.fshl(x,y,z,mask,vlen)
BEGIN_REGISTER_VP(vp_fshl, 3, 4, VP_FSHL, -1)
VP_PROPERTY_FUNCTIONAL_INTRINSIC(fshl)
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
index 46e54b5366d66a..5322ea3b6a2d97 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -1220,6 +1220,11 @@ void SelectionDAGLegalize::LegalizeOp(SDNode *Node) {
Action = TLI.getOperationAction(
Node->getOpcode(), Node->getOperand(1).getValueType());
break;
+ case ISD::VP_CTTZ_ELTS:
+ case ISD::VP_CTTZ_ELTS_ZERO_UNDEF:
+ Action = TLI.getOperationAction(Node->getOpcode(),
+ Node->getOperand(0).getValueType());
+ break;
default:
if (Node->getOpcode() >= ISD::BUILTIN_OP_END) {
Action = TLI.getCustomOperationAction(*Node);
@@ -4234,6 +4239,10 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
case ISD::VECREDUCE_FMINIMUM:
Results.push_back(TLI.expandVecReduce(Node, DAG));
break;
+ case ISD::VP_CTTZ_ELTS:
+ case ISD::VP_CTTZ_ELTS_ZERO_UNDEF:
+ Results.push_back(TLI.expandVPCTTZElements(Node, DAG));
+ break;
case ISD::GLOBAL_OFFSET_TABLE:
case ISD::GlobalAddress:
case ISD::GlobalTLSAddress:
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index 55f9737bc94dd5..0aa36deda79dcc 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -76,6 +76,10 @@ void DAGTypeLegalizer::PromoteIntegerResult(SDNode *N, unsigned ResNo) {
case ISD::VP_CTTZ:
case ISD::CTTZ_ZERO_UNDEF:
case ISD::CTTZ: Res = PromoteIntRes_CTTZ(N); break;
+ case ISD::VP_CTTZ_ELTS_ZERO_UNDEF:
+ case ISD::VP_CTTZ_ELTS:
+ Res = PromoteIntRes_VP_CttzElements(N);
+ break;
case ISD::EXTRACT_VECTOR_ELT:
Res = PromoteIntRes_EXTRACT_VECTOR_ELT(N); break;
case ISD::LOAD: Res = PromoteIntRes_LOAD(cast<LoadSDNode>(N)); break;
@@ -724,6 +728,12 @@ SDValue DAGTypeLegalizer::PromoteIntRes_CTTZ(SDNode *N) {
N->getOperand(2));
}
+SDValue DAGTypeLegalizer::PromoteIntRes_VP_CttzElements(SDNode *N) {
+ SDLoc DL(N);
+ EVT NewVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));
+ return DAG.getNode(N->getOpcode(), DL, NewVT, N->ops());
+}
+
SDValue DAGTypeLegalizer::PromoteIntRes_EXTRACT_VECTOR_ELT(SDNode *N) {
SDLoc dl(N);
EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index 4a2c7b355eb528..49be824deb5134 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -309,6 +309,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
SDValue PromoteIntRes_CTLZ(SDNode *N);
SDValue PromoteIntRes_CTPOP_PARITY(SDNode *N);
SDValue PromoteIntRes_CTTZ(SDNode *N);
+ SDValue PromoteIntRes_VP_CttzElements(SDNode *N);
SDValue PromoteIntRes_EXTRACT_VECTOR_ELT(SDNode *N);
SDValue PromoteIntRes_FP_TO_XINT(SDNode *N);
SDValue PromoteIntRes_FP_TO_XINT_SAT(SDNode *N);
@@ -912,6 +913,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
SDValue SplitVecOp_FP_ROUND(SDNode *N);
SDValue SplitVecOp_FPOpDifferentTypes(SDNode *N);
SDValue SplitVecOp_FP_TO_XINT_SAT(SDNode *N);
+ SDValue SplitVecOp_VP_CttzElements(SDNode *N);
//===--------------------------------------------------------------------===//
// Vector Widening Support: LegalizeVectorTypes.cpp
@@ -1019,6 +1021,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
SDValue WidenVecOp_VECREDUCE_SEQ(SDNode *N);
SDValue WidenVecOp_VP_REDUCE(SDNode *N);
SDValue WidenVecOp_ExpOp(SDNode *N);
+ SDValue WidenVecOp_VP_CttzElements(SDNode *N);
/// Helper function to generate a set of operations to perform
/// a vector operation for a wider type.
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
index 8f87ee8e09393a..26cd5482168f9f 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -510,6 +510,13 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
if (Action != TargetLowering::Legal) \
break; \
} \
+ /* Defer non-vector results to LegalizeDAG. */ \
+ /* Remove this after #90522 is landed */ \
+ if (ISD::VPID == ISD::VP_CTTZ_ELTS || \
+ ISD::VPID == ISD::VP_CTTZ_ELTS_ZERO_UNDEF) { \
+ Action = TargetLowering::Legal; \
+ break; \
+ } \
Action = TLI.getOperationAction(Node->getOpcode(), LegalizeVT); \
} break;
#include "llvm/IR/VPIntrinsics.def"
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index 985c9f16ab97cd..cab4dc5f3c1565 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -3098,6 +3098,10 @@ bool DAGTypeLegalizer::SplitVectorOperand(SDNode *N, unsigned OpNo) {
case ISD::VP_REDUCE_FMIN:
Res = SplitVecOp_VP_REDUCE(N, OpNo);
break;
+ case ISD::VP_CTTZ_ELTS:
+ case ISD::VP_CTTZ_ELTS_ZERO_UNDEF:
+ Res = SplitVecOp_VP_CttzElements(N);
+ break;
}
// If the result is null, the sub-method took care of registering results etc.
@@ -4056,6 +4060,29 @@ SDValue DAGTypeLegalizer::SplitVecOp_FP_TO_XINT_SAT(SDNode *N) {
return DAG.getNode(ISD::CONCAT_VECTORS, dl, ResVT, Lo, Hi);
}
+SDValue DAGTypeLegalizer::SplitVecOp_VP_CttzElements(SDNode *N) {
+ SDLoc DL(N);
+ EVT ResVT = N->getValueType(0);
+
+ SDValue Lo, Hi;
+ SDValue VecOp = N->getOperand(0);
+ GetSplitVector(VecOp, Lo, Hi);
+
+ auto [MaskLo, MaskHi] = SplitMask(N->getOperand(1));
+ auto [EVLLo, EVLHi] =
+ DAG.SplitEVL(N->getOperand(2), VecOp.getValueType(), DL);
+ SDValue VLo = DAG.getZExtOrTrunc(EVLLo, DL, ResVT);
+
+ // if VP_CTTZ_ELTS(Lo) != EVLLo => VP_CTTZ_ELTS(Lo).
+ // else => EVLLo + (VP_CTTZ_ELTS(Hi) or VP_CTTZ_ELTS_ZERO_UNDEF(Hi)).
+ SDValue ResLo = DAG.getNode(ISD::VP_CTTZ_ELTS, DL, ResVT, Lo, MaskLo, EVLLo);
+ SDValue ResLoNotEVL =
+ DAG.getSetCC(DL, getSetCCResultType(ResVT), ResLo, VLo, ISD::SETNE);
+ SDValue ResHi = DAG.getNode(N->getOpcode(), DL, ResVT, Hi, MaskHi, EVLHi);
+ return DAG.getSelect(DL, ResVT, ResLoNotEVL, ResLo,
+ DAG.getNode(ISD::ADD, DL, ResVT, VLo, ResHi));
+}
+
//===----------------------------------------------------------------------===//
// Result Vector Widening
//===----------------------------------------------------------------------===//
@@ -6161,6 +6188,10 @@ bool DAGTypeLegalizer::WidenVectorOperand(SDNode *N, unsigned OpNo) {
case ISD::VP_REDUCE_FMIN:
Res = WidenVecOp_VP_REDUCE(N);
break;
+ case ISD::VP_CTTZ_ELTS:
+ case ISD::VP_CTTZ_ELTS_ZERO_UNDEF:
+ Res = WidenVecOp_VP_CttzElements(N);
+ break;
}
// If Res is null, the sub-method took care of registering the result.
@@ -6924,6 +6955,17 @@ SDValue DAGTypeLegalizer::WidenVecOp_VSELECT(SDNode *N) {
DAG.getVectorIdxConstant(0, DL));
}
+SDValue DAGTypeLegalizer::WidenVecOp_VP_CttzElements(SDNode *N) {
+ SDLoc DL(N);
+ SDValue Source = GetWidenedVector(N->getOperand(0));
+ EVT SrcVT = Source.getValueType();
+ SDValue Mask =
+ GetWidenedMask(N->getOperand(1), SrcVT.getVectorElementCount());
+
+ return DAG.getNode(N->getOpcode(), DL, N->getValueType(0),
+ {Source, Mask, N->getOperand(2)}, N->getFlags());
+}
+
//===----------------------------------------------------------------------===//
// Vector Widening Utilities
//===----------------------------------------------------------------------===//
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 5caf868c83a296..cfd82a342433fa 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -8076,6 +8076,11 @@ static unsigned getISDForVPIntrinsic(const VPIntrinsic &VPIntrin) {
ResOPC = IsZeroUndef ? ISD::VP_CTTZ_ZERO_UNDEF : ISD::VP_CTTZ;
break;
}
+ case Intrinsic::vp_cttz_elts: {
+ bool IsZeroPoison = cast<ConstantInt>(VPIntrin.getArgOperand(1))->isOne();
+ ResOPC = IsZeroPoison ? ISD::VP_CTTZ_ELTS_ZERO_UNDEF : ISD::VP_CTTZ_ELTS;
+ break;
+ }
#define HELPER_MAP_VPID_TO_VPSD(VPID, VPSD) \
case Intrinsic::VPID: \
ResOPC = ISD::VPSD; \
@@ -8428,7 +8433,9 @@ void SelectionDAGBuilder::visitVectorPredicationIntrinsic(
case ISD::VP_CTLZ:
case ISD::VP_CTLZ_ZERO_UNDEF:
case ISD::VP_CTTZ:
- case ISD::VP_CTTZ_ZERO_UNDEF: {
+ case ISD::VP_CTTZ_ZERO_UNDEF:
+ case ISD::VP_CTTZ_ELTS_ZERO_UNDEF:
+ case ISD::VP_CTTZ_ELTS: {
SDValue Result =
DAG.getNode(Opcode, DL, VTs, {OpValues[0], OpValues[2], OpValues[3]});
setValue(&VPIntrin, Result);
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index cdc1227fd572dc..336d89fbcf638e 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -9074,6 +9074,39 @@ SDValue TargetLowering::expandVPCTTZ(SDNode *Node, SelectionDAG &DAG) const {
return DAG.getNode(ISD::VP_CTPOP, dl, VT, Tmp, Mask, VL);
}
+SDValue TargetLowering::expandVPCTTZElements(SDNode *N,
+ SelectionDAG &DAG) const {
+ // %cond = to_bool_vec %source
+ // %splat = splat /*val=*/VL
+ // %tz = step_vector
+ // %v = vp.select %cond, /*true=*/tz, /*false=*/%splat
+ // %r = vp.reduce.umin %v
+ SDLoc DL(N);
+ SDValue Source = N->getOperand(0);
+ SDValue Mask = N->getOperand(1);
+ SDValue EVL = N->getOperand(2);
+ EVT SrcVT = Source.getValueType();
+ EVT ResVT = N->getValueType(0);
+ EVT ResVecVT =
+ EVT::getVectorVT(*DAG.getContext(), ResVT, SrcVT.getVectorElementCount());
+
+ // Convert to boolean vector.
+ if (SrcVT.getScalarType() != MVT::i1) {
+ SDValue AllZero = DAG.getConstant(0, DL, SrcVT);
+ SrcVT = EVT::getVectorVT(*DAG.getContext(), MVT::i1,
+ SrcVT.getVectorElementCount());
+ Source = DAG.getNode(ISD::VP_SETCC, DL, SrcVT, Source, AllZero,
+ DAG.getCondCode(ISD::SETNE), Mask, EVL);
+ }
+
+ SDValue ExtEVL = DAG.getZExtOrTrunc(EVL, DL, ResVT);
+ SDValue Splat = DAG.getSplat(ResVecVT, DL, ExtEVL);
+ SDValue StepVec = DAG.getStepVector(DL, ResVecVT);
+ SDValue Select =
+ DAG.getNode(ISD::VP_SELECT, DL, ResVecVT, Source, StepVec, Splat, EVL);
+ return DAG.getNode(ISD::VP_REDUCE_UMIN, DL, ResVT, ExtEVL, Select, Mask, EVL);
+}
+
SDValue TargetLowering::expandABS(SDNode *N, SelectionDAG &DAG,
bool IsNegative) const {
SDLoc dl(N);
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 68f4ec5ef49f31..c61e477d79e110 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -28,6 +28,7 @@
#include "llvm/CodeGen/MachineInstrBuilder.h"
#include "llvm/CodeGen/MachineJumpTableInfo.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
+#include "llvm/CodeGen/SDPatternMatch.h"
#include "llvm/CodeGen/SelectionDAGAddressAnalysis.h"
#include "llvm/CodeGen/TargetLoweringObjectFileImpl.h"
#include "llvm/CodeGen/ValueTypes.h"
@@ -698,7 +699,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
ISD::VP_SMAX, ISD::VP_UMIN, ISD::VP_UMAX,
ISD::VP_ABS, ISD::EXPERIMENTAL_VP_REVERSE, ISD::EXPERIMENTAL_VP_SPLICE,
ISD::VP_SADDSAT, ISD::VP_UADDSAT, ISD::VP_SSUBSAT,
- ISD::VP_USUBSAT};
+ ISD::VP_USUBSAT, ISD::VP_CTTZ_ELTS, ISD::VP_CTTZ_ELTS_ZERO_UNDEF};
static const unsigned FloatingPointVPOps[] = {
ISD::VP_FADD, ISD::VP_FSUB, ISD::VP_FMUL,
@@ -759,6 +760,9 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
{ISD::SELECT_CC, ISD::VSELECT, ISD::VP_MERGE, ISD::VP_SELECT}, VT,
Expand);
+ setOperationAction({ISD::VP_CTTZ_ELTS, ISD::VP_CTTZ_ELTS_ZERO_UNDEF}, VT,
+ Custom);
+
setOperationAction({ISD::VP_AND, ISD::VP_OR, ISD::VP_XOR}, VT, Custom);
setOperationAction(
@@ -5341,6 +5345,44 @@ RISCVTargetLowering::lowerCTLZ_CTTZ_ZERO_UNDEF(SDValue Op,
return Res;
}
+SDValue RISCVTargetLowering::lowerVPCttzElements(SDValue Op,
+ SelectionDAG &DAG) const {
+ SDLoc DL(Op);
+ MVT XLenVT = Subtarget.getXLenVT();
+ SDValue Source = Op->getOperand(0);
+ MVT SrcVT = Source.getSimpleValueType();
+ SDValue Mask = Op->getOperand(1);
+ SDValue EVL = Op->getOperand(2);
+
+ if (SrcVT.isFixedLengthVector()) {
+ MVT ContainerVT = getContainerForFixedLengthVector(SrcVT);
+ Source = convertToScalableVector(ContainerVT, Source, DAG, Subtarget);
+ Mask = convertToScalableVector(getMaskTypeFor(ContainerVT), Mask, DAG,
+ Subtarget);
+ SrcVT = ContainerVT;
+ }
+
+ // Convert to boolean vector.
+ if (SrcVT.getScalarType() != MVT::i1) {
+ SDValue AllZero = DAG.getConstant(0, DL, SrcVT);
+ SrcVT = MVT::getVectorVT(MVT::i1, SrcVT.getVectorElementCount());
+ Source = DAG.getNode(RISCVISD::SETCC_VL, DL, SrcVT,
+ {Source, AllZero, DAG.getCondCode(ISD::SETNE),
+ DAG.getUNDEF(SrcVT), Mask, EVL});
+ }
+
+ SDValue Res = DAG.getNode(RISCVISD::VFIRST_VL, DL, XLenVT, Source, Mask, EVL);
+ if (Op->getOpcode() == ISD::VP_CTTZ_ELTS_ZERO_UNDEF)
+ // In this case, we can interpret poison as -1, so nothing to do further.
+ return Res;
+
+ // Convert -1 to VL.
+ SDValue SetCC =
+ DAG.getSetCC(DL, XLenVT, Res, DAG.getConstant(0, DL, XLenVT), ISD::SETLT);
+ Res = DAG.getSelect(DL, XLenVT, SetCC, EVL, Res);
+ return DAG.getNode(ISD::TRUNCATE, DL, Op.getValueType(), Res);
+}
+
// While RVV has alignment restrictions, we should always be able to load as a
// legal equivalently-sized byte-typed vector instead. This method is
// responsible for re-expressing a ISD::LOAD via a correctly-aligned type. If
@@ -6595,6 +6637,9 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
if (Op.getOperand(1).getValueType().getVectorElementType() == MVT::i1)
return lowerVectorMaskVecReduction(Op, DAG, /*IsVP*/ true);
return lowerVPREDUCE(Op, DAG);
+ case ISD::VP_CTTZ_ELTS:
+ case ISD::VP_CTTZ_ELTS_ZERO_UNDEF:
+ return lowerVPCttzElements(Op, DAG);
case ISD::UNDEF: {
MVT ContainerVT = getContainerForFixedLengthVector(Op.getSimpleValueType());
return convertFromScalableVector(Op.getSimpleValueType(),
@@ -13634,9 +13679,69 @@ st...
[truncated]
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
ISD::CondCode NewCC = cast<CondCodeSDNode>(Select->getOperand(2))->get(); | ||
if (Inverse) | ||
NewCC = ISD::getSetCCInverse(NewCC, OpVT); | ||
return DAG.getNode( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use getSetCC.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed.
…cc A, B)` when possible Given `(seteq (riscv_selectcc LHS, RHS, CC, X, Y), X)`, we can turn it into `(setCC LHS, RHS)`. I think we can generalize this into ISD::SELECT_CC as well.
018b08a
to
9416591
Compare
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4 | ||
; RUN: llc -mtriple=riscv32 < %s | FileCheck %s | ||
|
||
define i1 @eq(i32 %a, i32 %b, i32 %c, i32 %d) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is incorrect. See https://alive2.llvm.org/ce/z/xjQp8n
(setcc (riscv_selectcc A, B, ...), Y)
to just (setcc A, B)
when possible(setcc (selectLT (vfirst_vl ...) , 0, EVL, ...), EVL)
I have limited the scope of this optimization to handle only the pattern associated with |
// a llvm.vp.cttz.elts that doesn't return XLen type. Due to non-trivial | ||
// number of sext(_inreg) and zext interleaving between the nodes. | ||
// That said, this is not really problem as long as we always generate the | ||
// said intrinsics with XLen return type. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The vectorizer will very likely use i32 not XLen.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed. Now it recognizes vp.cttz.elts with non-XLen return type.
SDValue &Select) -> bool { | ||
// Remove any sext or zext | ||
auto ExtPattern = | ||
m_AnyOf(m_Opc(ISD::SIGN_EXTEND_INREG), m_And(m_Value(), m_AllOnes())); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
m_And(m_Value(), m_AllOnes()) isn't a zero extend and should be removed by DAG combine
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure m_AllOnes is implemented correctly. It uses APInt::isSameValue
which zero extends when bit widths don't match. m_AllOnes
should be checking that all bits are one. If m_AllOnes
is used to compare against a value that is less than 64 bits, that value will be zero extended to 64 bits and will fail to match.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's fixed now. In this particular case I think VL will almost certain be zext from i32 so I'm pattern matching (and X, <32 trailing ones>)
here instead.
// against) is zext from i32. | ||
auto ZExtVL = m_And(m_Value(), m_SpecificInt(APInt::getLowBitsSet(64, 32))); | ||
|
||
// Remove any sext or zext |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm nervous about peeking through zext and sexts without checking of types. I think we need to know that the compare is only using the same bits from EVL that the vfirst was using.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree. It's fixed now.
As the result of RISCV::SELECT_CC.
SDValue VLCandVTNode; | ||
EVT VLCandVT = VLCandidate.getValueType(); | ||
// Remove any sext. | ||
if (sd_match(Op, m_Opc(ISD::SIGN_EXTEND_INREG))) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we also need to make sure this doesn't drop any bits? If this sign_extends from a type that doesn't fit the EVL then we may be losing some information.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I thought about this before, but the intrinsic states that it's an undefined behavior if the user assigns a (return) type that can't fit EVL.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The sign extend wasn't created by the intrinsic. It was likely created by type legalization, but that's not the only way it can be created. So we can't assume the type in sign_extend_inreg represents the original type of the intrinsic.
I'm not even sure how to implement this combine to handle all the cases we need to handle. You'll probably get different results if you add the zeroext
attribute to the EVL function argument in your test. That would remove the AND between the EVL and the vfirst, but it would not remove the sign_extend_inreg between the evl and the compare.
Given
VFirst = (vfirst_vl ..., EVL)
and(seteq (riscv_selectLT VFirst, 0, EVL, VFirst), EVL)
, we can replace it with(setLT VFirst, 0)
. Similar replacements are done for variants w/setne
andriscv_selectGE
.