Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RISCV] Provide a more efficient lowering for experimental.cttz.elts. #88552

Merged
merged 6 commits into from
Apr 16, 2024

Conversation

topperc
Copy link
Collaborator

@topperc topperc commented Apr 12, 2024

For experimental.cttz.elts, we can use a vfirst instruction, but we need to correct the result if input vector can be 0. cttz.elts returns the vector length while vfirst returns -1.

For experimental.cttz.elts, we can use a vfirst instruction, but we
need to correct the result if input vector can be 0. cttz.elts returns
the vector length while vfirst returns -1.
@llvmbot
Copy link
Collaborator

llvmbot commented Apr 12, 2024

@llvm/pr-subscribers-backend-risc-v

Author: Craig Topper (topperc)

Changes

For experimental.cttz.elts, we can use a vfirst instruction, but we need to correct the result if input vector can be 0. cttz.elts returns the vector length while vfirst returns -1.


Full diff: https://github.com/llvm/llvm-project/pull/88552.diff

3 Files Affected:

  • (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+39)
  • (modified) llvm/lib/Target/RISCV/RISCVISelLowering.h (+2)
  • (modified) llvm/test/CodeGen/RISCV/intrinsic-cttz-elts-vscale.ll (+67-22)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 5a572002091ff3..fa4e9bf002e967 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -1484,6 +1484,11 @@ bool RISCVTargetLowering::shouldExpandGetVectorLength(EVT TripCountVT,
   return VF > MaxVF || !isPowerOf2_32(VF);
 }
 
+bool RISCVTargetLowering::shouldExpandCttzElements(EVT VT) const {
+  return !Subtarget.hasVInstructions() ||
+         VT.getVectorElementType() != MVT::i1 || !isTypeLegal(VT);
+}
+
 bool RISCVTargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info,
                                              const CallInst &I,
                                              MachineFunction &MF,
@@ -8718,6 +8723,33 @@ static SDValue lowerGetVectorLength(SDNode *N, SelectionDAG &DAG,
   return DAG.getNode(ISD::TRUNCATE, DL, N->getValueType(0), Res);
 }
 
+static SDValue lowerCttzElts(SDNode *N, SelectionDAG &DAG,
+                             const RISCVSubtarget &Subtarget) {
+  SDValue Op0 = N->getOperand(1);
+  MVT OpVT = Op0.getSimpleValueType();
+  MVT ContainerVT = OpVT;
+  if (OpVT.isFixedLengthVector()) {
+    ContainerVT = getContainerForFixedLengthVector(DAG, OpVT, Subtarget);
+    Op0 = convertToScalableVector(ContainerVT, Op0, DAG, Subtarget);
+  }
+  MVT XLenVT = Subtarget.getXLenVT();
+  SDLoc DL(N);
+  auto [Mask, VL] = getDefaultVLOps(OpVT, ContainerVT, DL, DAG, Subtarget);
+  SDValue Res = DAG.getNode(RISCVISD::VFIRST_VL, DL, XLenVT, Op0, Mask, VL);
+  if (isOneConstant(N->getOperand(2)))
+    return Res;
+
+  // Convert -1 to VL.
+  SDValue Setcc =
+      DAG.getSetCC(DL, XLenVT, Res, DAG.getConstant(0, DL, XLenVT), ISD::SETLT);
+  // We need to use vscale rather than X0 for scalable vectors.
+  if (!OpVT.isFixedLengthVector())
+    VL = DAG.getVScale(
+        DL, XLenVT,
+        APInt(XLenVT.getSizeInBits(), OpVT.getVectorMinNumElements()));
+  return DAG.getSelect(DL, XLenVT, Setcc, VL, Res);
+}
+
 static inline void promoteVCIXScalar(const SDValue &Op,
                                      SmallVectorImpl<SDValue> &Operands,
                                      SelectionDAG &DAG) {
@@ -8913,6 +8945,8 @@ SDValue RISCVTargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
   }
   case Intrinsic::experimental_get_vector_length:
     return lowerGetVectorLength(Op.getNode(), DAG, Subtarget);
+  case Intrinsic::experimental_cttz_elts:
+    return lowerCttzElts(Op.getNode(), DAG, Subtarget);
   case Intrinsic::riscv_vmv_x_s: {
     SDValue Res = DAG.getNode(RISCVISD::VMV_X_S, DL, XLenVT, Op.getOperand(1));
     return DAG.getNode(ISD::TRUNCATE, DL, Op.getValueType(), Res);
@@ -12336,6 +12370,11 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
       Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, Res));
       return;
     }
+    case Intrinsic::experimental_cttz_elts: {
+      SDValue Res = lowerCttzElts(N, DAG, Subtarget);
+      Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, Res));
+      return;
+    }
     case Intrinsic::riscv_orc_b:
     case Intrinsic::riscv_brev8:
     case Intrinsic::riscv_sha256sig0:
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index ace5b3fd2b95b4..e2633733c31b19 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -986,6 +986,8 @@ class RISCVTargetLowering : public TargetLowering {
   bool shouldExpandGetVectorLength(EVT TripCountVT, unsigned VF,
                                    bool IsScalable) const override;
 
+  bool shouldExpandCttzElements(EVT VT) const override;
+
   /// RVV code generation for fixed length vectors does not lower all
   /// BUILD_VECTORs. This makes BUILD_VECTOR legalisation a source of stores to
   /// merge. However, merging them creates a BUILD_VECTOR that is just as
diff --git a/llvm/test/CodeGen/RISCV/intrinsic-cttz-elts-vscale.ll b/llvm/test/CodeGen/RISCV/intrinsic-cttz-elts-vscale.ll
index 65d0768c60885d..8fe6ef0ab52a06 100644
--- a/llvm/test/CodeGen/RISCV/intrinsic-cttz-elts-vscale.ll
+++ b/llvm/test/CodeGen/RISCV/intrinsic-cttz-elts-vscale.ll
@@ -128,43 +128,88 @@ define i64 @ctz_nxv8i1_no_range(<vscale x 8 x i16> %a) {
 define i32 @ctz_nxv16i1(<vscale x 16 x i1> %pg, <vscale x 16 x i1> %a) {
 ; RV32-LABEL: ctz_nxv16i1:
 ; RV32:       # %bb.0:
-; RV32-NEXT:    vmv1r.v v0, v8
+; RV32-NEXT:    vsetvli a0, zero, e8, m2, ta, ma
+; RV32-NEXT:    vfirst.m a0, v8
+; RV32-NEXT:    bgez a0, .LBB2_2
+; RV32-NEXT:  # %bb.1:
 ; RV32-NEXT:    csrr a0, vlenb
 ; RV32-NEXT:    slli a0, a0, 1
-; RV32-NEXT:    vsetvli a1, zero, e32, m8, ta, ma
-; RV32-NEXT:    vmv.v.x v8, a0
-; RV32-NEXT:    vid.v v16
-; RV32-NEXT:    li a1, -1
-; RV32-NEXT:    vmadd.vx v16, a1, v8
-; RV32-NEXT:    vmv.v.i v8, 0
-; RV32-NEXT:    vmerge.vvm v8, v8, v16, v0
-; RV32-NEXT:    vredmaxu.vs v8, v8, v8
-; RV32-NEXT:    vmv.x.s a1, v8
-; RV32-NEXT:    sub a0, a0, a1
+; RV32-NEXT:  .LBB2_2:
 ; RV32-NEXT:    ret
 ;
 ; RV64-LABEL: ctz_nxv16i1:
 ; RV64:       # %bb.0:
-; RV64-NEXT:    vmv1r.v v0, v8
+; RV64-NEXT:    vsetvli a0, zero, e8, m2, ta, ma
+; RV64-NEXT:    vfirst.m a0, v8
+; RV64-NEXT:    bgez a0, .LBB2_2
+; RV64-NEXT:  # %bb.1:
 ; RV64-NEXT:    csrr a0, vlenb
 ; RV64-NEXT:    slli a0, a0, 1
-; RV64-NEXT:    vsetvli a1, zero, e32, m8, ta, ma
-; RV64-NEXT:    vmv.v.x v8, a0
-; RV64-NEXT:    vid.v v16
-; RV64-NEXT:    li a1, -1
-; RV64-NEXT:    vmadd.vx v16, a1, v8
-; RV64-NEXT:    vmv.v.i v8, 0
-; RV64-NEXT:    vmerge.vvm v8, v8, v16, v0
-; RV64-NEXT:    vredmaxu.vs v8, v8, v8
-; RV64-NEXT:    vmv.x.s a1, v8
-; RV64-NEXT:    subw a0, a0, a1
+; RV64-NEXT:  .LBB2_2:
 ; RV64-NEXT:    ret
   %res = call i32 @llvm.experimental.cttz.elts.i32.nxv16i1(<vscale x 16 x i1> %a, i1 0)
   ret i32 %res
 }
 
+define i32 @ctz_nxv16i1_poison(<vscale x 16 x i1> %pg, <vscale x 16 x i1> %a) {
+; RV32-LABEL: ctz_nxv16i1_poison:
+; RV32:       # %bb.0:
+; RV32-NEXT:    vsetvli a0, zero, e8, m2, ta, ma
+; RV32-NEXT:    vfirst.m a0, v8
+; RV32-NEXT:    ret
+;
+; RV64-LABEL: ctz_nxv16i1_poison:
+; RV64:       # %bb.0:
+; RV64-NEXT:    vsetvli a0, zero, e8, m2, ta, ma
+; RV64-NEXT:    vfirst.m a0, v8
+; RV64-NEXT:    ret
+  %res = call i32 @llvm.experimental.cttz.elts.i32.nxv16i1(<vscale x 16 x i1> %a, i1 1)
+  ret i32 %res
+}
+
+define i32 @ctz_v16i1(<16 x i1> %pg, <16 x i1> %a) {
+; RV32-LABEL: ctz_v16i1:
+; RV32:       # %bb.0:
+; RV32-NEXT:    vsetivli zero, 16, e8, m1, ta, ma
+; RV32-NEXT:    vfirst.m a0, v8
+; RV32-NEXT:    bgez a0, .LBB4_2
+; RV32-NEXT:  # %bb.1:
+; RV32-NEXT:    li a0, 16
+; RV32-NEXT:  .LBB4_2:
+; RV32-NEXT:    ret
+;
+; RV64-LABEL: ctz_v16i1:
+; RV64:       # %bb.0:
+; RV64-NEXT:    vsetivli zero, 16, e8, m1, ta, ma
+; RV64-NEXT:    vfirst.m a0, v8
+; RV64-NEXT:    bgez a0, .LBB4_2
+; RV64-NEXT:  # %bb.1:
+; RV64-NEXT:    li a0, 16
+; RV64-NEXT:  .LBB4_2:
+; RV64-NEXT:    ret
+  %res = call i32 @llvm.experimental.cttz.elts.i32.v16i1(<16 x i1> %a, i1 0)
+  ret i32 %res
+}
+
+define i32 @ctz_v16i1_poison(<16 x i1> %pg, <16 x i1> %a) {
+; RV32-LABEL: ctz_v16i1_poison:
+; RV32:       # %bb.0:
+; RV32-NEXT:    vsetivli zero, 16, e8, m1, ta, ma
+; RV32-NEXT:    vfirst.m a0, v8
+; RV32-NEXT:    ret
+;
+; RV64-LABEL: ctz_v16i1_poison:
+; RV64:       # %bb.0:
+; RV64-NEXT:    vsetivli zero, 16, e8, m1, ta, ma
+; RV64-NEXT:    vfirst.m a0, v8
+; RV64-NEXT:    ret
+  %res = call i32 @llvm.experimental.cttz.elts.i32.v16i1(<16 x i1> %a, i1 1)
+  ret i32 %res
+}
+
 declare i64 @llvm.experimental.cttz.elts.i64.nxv8i16(<vscale x 8 x i16>, i1)
 declare i32 @llvm.experimental.cttz.elts.i32.nxv16i1(<vscale x 16 x i1>, i1)
 declare i32 @llvm.experimental.cttz.elts.i32.nxv4i32(<vscale x 4 x i32>, i1)
+declare i32 @llvm.experimental.cttz.elts.i32.v16i1(<16 x i1>, i1)
 
 attributes #0 = { vscale_range(2,1024) }

Copy link
Contributor

@lukel97 lukel97 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Comment on lines 8746 to 8749
if (!OpVT.isFixedLengthVector())
VL = DAG.getVScale(
DL, XLenVT,
APInt(XLenVT.getSizeInBits(), OpVT.getVectorMinNumElements()));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you can also do

Suggested change
if (!OpVT.isFixedLengthVector())
VL = DAG.getVScale(
DL, XLenVT,
APInt(XLenVT.getSizeInBits(), OpVT.getVectorMinNumElements()));
SDValue VL = DAG.getElementCount(DL, XLenVT, OpVT.getVectorElementCount())

Copy link
Collaborator

@preames preames left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM w/minor comment and a possible follow up.

@@ -1484,6 +1484,11 @@ bool RISCVTargetLowering::shouldExpandGetVectorLength(EVT TripCountVT,
return VF > MaxVF || !isPowerOf2_32(VF);
}

bool RISCVTargetLowering::shouldExpandCttzElements(EVT VT) const {
return !Subtarget.hasVInstructions() ||
VT.getVectorElementType() != MVT::i1 || !isTypeLegal(VT);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could handle the legal non-i1 case via s simple vector comparison against zero - producing a mask result that then flows through this code. From the tests, this looks like a pretty huge improvement in codegen quality.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's true. I was just trying to cover the new usage from #88385

@@ -12336,6 +12370,11 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, Res));
return;
}
case Intrinsic::experimental_cttz_elts: {
SDValue Res = lowerCttzElts(N, DAG, Subtarget);
Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, Res));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The i32 seems slightly weird here - I'd expect XLenVT unless there's something I didn't spot on a first glance about when this code is reached?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is called by type legalization. It needs to return the original type so that it can plug in to the original users that haven't been legalized yet. Looking at it again, it should probably be N->getValuetype(0). I think we just lack test coverage for return types other than i32.

Copy link
Collaborator

@preames preames left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@@ -12336,6 +12366,12 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, Res));
return;
}
case Intrinsic::experimental_cttz_elts: {
EVT VT = N->getValueType(0);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Single use variable.

@topperc topperc merged commit 5b9af38 into llvm:main Apr 16, 2024
3 of 4 checks passed
@topperc topperc deleted the pr/cttz_elts branch April 16, 2024 01:38
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants