[RISCV] Construct constants via instructions if materialization is costly #86926

Open · wangpc-pp wants to merge 1 commit into main from main-riscv-popcount-constants

Conversation

@wangpc-pp (Contributor) commented Mar 28, 2024

For RISC-V, it is costly to materialize the constants used in lowering
ISD::CTPOP/ISD::VP_CTPOP.

We can query the materialization cost via RISCVMatInt::getIntMatCost, and
if the cost is larger than 2, we construct the constant via two
instructions instead.

This PR contains only the vector part. We don't apply the optimization when
LMUL is LMUL_8, as that increases register pressure.
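
For illustration, the identities this relies on (deriving each of the other splat masks from 0x0F0F... with a single shift plus a single logic operation) can be checked with a small standalone C++ snippet. This is a reader-facing sketch, not code from the patch:

```cpp
// Standalone illustration (not part of the patch): the splat masks used by
// the SWAR popcount can each be derived from 0x0F0F... with one shift and
// one logic instruction, which is what the lowering emits when
// RISCVMatInt::getIntMatCost reports a materialization cost above 2.
#include <cassert>
#include <cstdint>

int main() {
  const uint64_t Mask0F = 0x0F0F0F0F0F0F0F0FULL;
  const uint64_t Mask33 = Mask0F ^ (Mask0F << 2); // 0x3333333333333333
  const uint64_t Mask55 = Mask33 ^ (Mask33 << 1); // 0x5555555555555555
  const uint64_t Mask01 = Mask0F & (Mask0F >> 3); // 0x0101010101010101
  assert(Mask33 == 0x3333333333333333ULL);
  assert(Mask55 == 0x5555555555555555ULL);
  assert(Mask01 == 0x0101010101010101ULL);
  return 0;
}
```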

@llvmbot (Collaborator) commented Mar 28, 2024

@llvm/pr-subscribers-backend-risc-v

Author: Wang Pengcheng (wangpc-pp)

Changes

For RISC-V, it is costly to materialize the constants used in lowering
ISD::CTPOP/ISD::VP_CTPOP.

We can query the materialization cost via RISCVMatInt::getIntMatCost, and
if the cost is larger than 2, we construct the constant via two
instructions instead.

This fixes #86207.


Patch is 952.85 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/86926.diff

21 Files Affected:

  • (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+195-6)
  • (modified) llvm/lib/Target/RISCV/RISCVISelLowering.h (+3)
  • (modified) llvm/test/CodeGen/RISCV/ctlz-cttz-ctpop.ll (+1090-518)
  • (modified) llvm/test/CodeGen/RISCV/ctz_zero_return_test.ll (+63-44)
  • (modified) llvm/test/CodeGen/RISCV/pr56457.ll (+10-14)
  • (modified) llvm/test/CodeGen/RISCV/rv64-legal-i32/rv64xtheadbb.ll (+35-22)
  • (modified) llvm/test/CodeGen/RISCV/rv64-legal-i32/rv64zbb.ll (+52-43)
  • (modified) llvm/test/CodeGen/RISCV/rv64xtheadbb.ll (+110-77)
  • (modified) llvm/test/CodeGen/RISCV/rv64zbb.ll (+211-172)
  • (modified) llvm/test/CodeGen/RISCV/rvv/ctlz-sdnode.ll (+324-340)
  • (modified) llvm/test/CodeGen/RISCV/rvv/ctpop-sdnode.ll (+180-192)
  • (modified) llvm/test/CodeGen/RISCV/rvv/ctpop-vp.ll (+1167-779)
  • (modified) llvm/test/CodeGen/RISCV/rvv/cttz-sdnode.ll (+396-440)
  • (modified) llvm/test/CodeGen/RISCV/rvv/cttz-vp.ll (+1246-888)
  • (modified) llvm/test/CodeGen/RISCV/rvv/fixed-vectors-ctlz-vp.ll (+2515-1871)
  • (modified) llvm/test/CodeGen/RISCV/rvv/fixed-vectors-ctlz.ll (+116-164)
  • (modified) llvm/test/CodeGen/RISCV/rvv/fixed-vectors-ctpop-vp.ll (+1220-839)
  • (modified) llvm/test/CodeGen/RISCV/rvv/fixed-vectors-ctpop.ll (+70-94)
  • (modified) llvm/test/CodeGen/RISCV/rvv/fixed-vectors-cttz-vp.ll (+2605-1961)
  • (modified) llvm/test/CodeGen/RISCV/rvv/fixed-vectors-cttz.ll (+172-220)
  • (modified) llvm/test/CodeGen/RISCV/sextw-removal.ll (+29-29)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index e6814c5f71a09b..031030990d4405 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -391,7 +391,9 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
         setOperationAction({ISD::CTTZ, ISD::CTTZ_ZERO_UNDEF}, MVT::i32, Custom);
     }
   } else if (!Subtarget.hasVendorXCVbitmanip()) {
-    setOperationAction({ISD::CTTZ, ISD::CTPOP}, XLenVT, Expand);
+    setOperationAction(ISD::CTTZ, XLenVT, Expand);
+    setOperationAction(ISD::CTPOP, XLenVT,
+                       Subtarget.is64Bit() ? Custom : Expand);
     if (RV64LegalI32 && Subtarget.is64Bit())
       setOperationAction({ISD::CTTZ, ISD::CTPOP}, MVT::i32, Expand);
   }
@@ -901,11 +903,13 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
                            VT, Custom);
       } else {
         setOperationAction({ISD::BITREVERSE, ISD::VP_BITREVERSE}, VT, Expand);
-        setOperationAction({ISD::CTLZ, ISD::CTTZ, ISD::CTPOP}, VT, Expand);
+        setOperationAction({ISD::CTLZ, ISD::CTTZ}, VT, Expand);
         setOperationAction({ISD::VP_CTLZ, ISD::VP_CTLZ_ZERO_UNDEF, ISD::VP_CTTZ,
-                            ISD::VP_CTTZ_ZERO_UNDEF, ISD::VP_CTPOP},
+                            ISD::VP_CTTZ_ZERO_UNDEF},
                            VT, Expand);
 
+        setOperationAction({ISD::CTPOP, ISD::VP_CTPOP}, VT, Custom);
+
         // Lower CTLZ_ZERO_UNDEF and CTTZ_ZERO_UNDEF if element of VT in the
         // range of f32.
         EVT FloatVT = MVT::getVectorVT(MVT::f32, VT.getVectorElementCount());
@@ -1238,6 +1242,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
                               ISD::CTTZ, ISD::CTTZ_ZERO_UNDEF, ISD::CTPOP},
                              VT, Custom);
         } else {
+          setOperationAction({ISD::CTPOP, ISD::VP_CTPOP}, VT, Custom);
           // Lower CTLZ_ZERO_UNDEF and CTTZ_ZERO_UNDEF if element of VT in the
           // range of f32.
           EVT FloatVT = MVT::getVectorVT(MVT::f32, VT.getVectorElementCount());
@@ -6746,8 +6751,18 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
   case ISD::UDIV:
   case ISD::UREM:
   case ISD::BSWAP:
-  case ISD::CTPOP:
     return lowerToScalableOp(Op, DAG);
+  case ISD::CTPOP: {
+    if (Op.getValueType().isScalarInteger())
+      return lowerScalarCTPOP(Op, DAG);
+    if (Subtarget.hasStdExtZvbb())
+      return lowerToScalableOp(Op, DAG);
+    return lowerVectorCTPOP(Op, DAG);
+  }
+  case ISD::VP_CTPOP:
+    if (Subtarget.hasStdExtZvbb())
+      return lowerVPOp(Op, DAG);
+    return lowerVectorCTPOP(Op, DAG);
   case ISD::SHL:
   case ISD::SRA:
   case ISD::SRL:
@@ -6972,8 +6987,6 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
     if (Subtarget.hasStdExtZvbb())
       return lowerVPOp(Op, DAG);
     return lowerCTLZ_CTTZ_ZERO_UNDEF(Op, DAG);
-  case ISD::VP_CTPOP:
-    return lowerVPOp(Op, DAG);
   case ISD::EXPERIMENTAL_VP_STRIDED_LOAD:
     return lowerVPStridedLoad(Op, DAG);
   case ISD::EXPERIMENTAL_VP_STRIDED_STORE:
@@ -10755,6 +10768,182 @@ SDValue RISCVTargetLowering::lowerABS(SDValue Op, SelectionDAG &DAG) const {
   return Max;
 }
 
+SDValue RISCVTargetLowering::lowerScalarCTPOP(SDValue Op,
+                                              SelectionDAG &DAG) const {
+  MVT VT = Op.getSimpleValueType();
+  SDLoc DL(Op);
+  MVT ShVT = getShiftAmountTy(VT, DAG.getDataLayout()).getSimpleVT();
+  unsigned Len = VT.getScalarSizeInBits();
+  assert(VT.isInteger() && "lowerScalarCTPOP not implemented for this type.");
+
+  SDValue V = Op.getOperand(0);
+
+  // This is the same algorithm as TargetLowering::expandCTPOP, from
+  // http://graphics.stanford.edu/~seander/bithacks.html#CountBitsSetParallel
+  // 0x0F0F0F0F...
+  const APInt &Constant0F = APInt::getSplat(Len, APInt(8, 0x0F));
+  SDValue Mask0F = DAG.getConstant(Constant0F, DL, VT, false, true);
+  // 0x33333333... = (0x0F0F0F0F... ^ (0x0F0F0F0F... << 2))
+  const APInt &Constant33 = APInt::getSplat(Len, APInt(8, 0x33));
+  SDValue Mask33 =
+      RISCVMatInt::getIntMatCost(Constant33, VT.getScalarSizeInBits(),
+                                 Subtarget) > 2
+          ? DAG.getNode(ISD::XOR, DL, VT, Mask0F,
+                        DAG.getNode(ISD::SHL, DL, VT, Mask0F,
+                                    DAG.getShiftAmountConstant(2, VT, DL)))
+          : DAG.getConstant(Constant33, DL, VT);
+  // 0x55555555... = (0x33333333... ^ (0x33333333... << 1))
+  const APInt &Constant55 = APInt::getSplat(Len, APInt(8, 0x55));
+  SDValue Mask55 =
+      RISCVMatInt::getIntMatCost(Constant55, VT.getScalarSizeInBits(),
+                                 Subtarget) > 2
+          ? DAG.getNode(ISD::XOR, DL, VT, Mask33,
+                        DAG.getNode(ISD::SHL, DL, VT, Mask33,
+                                    DAG.getShiftAmountConstant(1, VT, DL)))
+          : DAG.getConstant(Constant55, DL, VT);
+
+  // v = v - ((v >> 1) & 0x55555555...)
+  V = DAG.getNode(ISD::SUB, DL, VT, V,
+                  DAG.getNode(ISD::AND, DL, VT,
+                              DAG.getNode(ISD::SRL, DL, VT, V,
+                                          DAG.getConstant(1, DL, ShVT)),
+                              Mask55));
+  // v = (v & 0x33333333...) + ((v >> 2) & 0x33333333...)
+  V = DAG.getNode(ISD::ADD, DL, VT, DAG.getNode(ISD::AND, DL, VT, V, Mask33),
+                  DAG.getNode(ISD::AND, DL, VT,
+                              DAG.getNode(ISD::SRL, DL, VT, V,
+                                          DAG.getConstant(2, DL, ShVT)),
+                              Mask33));
+  // v = (v + (v >> 4)) & 0x0F0F0F0F...
+  V = DAG.getNode(ISD::AND, DL, VT,
+                  DAG.getNode(ISD::ADD, DL, VT, V,
+                              DAG.getNode(ISD::SRL, DL, VT, V,
+                                          DAG.getConstant(4, DL, ShVT))),
+                  Mask0F);
+
+  // v = (v * 0x01010101...) >> (Len - 8)
+  // 0x01010101... == (0x0F0F0F0F... & (0x0F0F0F0F... >> 3))
+  const APInt &Constant01 = APInt::getSplat(Len, APInt(8, 0x01));
+  SDValue Mask01 =
+      RISCVMatInt::getIntMatCost(Constant01, VT.getScalarSizeInBits(),
+                                 Subtarget) > 2
+          ? DAG.getNode(ISD::AND, DL, VT, Mask0F,
+                        DAG.getNode(ISD::SRL, DL, VT, Mask0F,
+                                    DAG.getShiftAmountConstant(3, VT, DL)))
+          : DAG.getConstant(Constant01, DL, VT);
+  return DAG.getNode(ISD::SRL, DL, VT, DAG.getNode(ISD::MUL, DL, VT, V, Mask01),
+                     DAG.getConstant(Len - 8, DL, ShVT));
+}
+
+SDValue RISCVTargetLowering::lowerVectorCTPOP(SDValue Op,
+                                              SelectionDAG &DAG) const {
+  SDLoc DL(Op);
+  MVT VT = Op.getSimpleValueType();
+  unsigned Len = VT.getScalarSizeInBits();
+  assert(VT.isInteger() && "lowerVectorCTPOP not implemented for this type.");
+
+  SDValue V = Op.getOperand(0);
+  MVT ContainerVT = VT;
+  if (VT.isFixedLengthVector()) {
+    ContainerVT = getContainerForFixedLengthVector(VT);
+    V = convertToScalableVector(ContainerVT, V, DAG, Subtarget);
+  }
+
+  SDValue Mask, VL;
+  if (Op->getOpcode() == ISD::VP_CTPOP) {
+    Mask = Op->getOperand(1);
+    if (VT.isFixedLengthVector())
+      Mask = convertToScalableVector(getMaskTypeFor(ContainerVT), Mask, DAG,
+                                     Subtarget);
+    VL = Op->getOperand(2);
+  } else
+    std::tie(Mask, VL) = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
+
+  // This is the same algorithm as TargetLowering::expandVPCTPOP, from
+  // http://graphics.stanford.edu/~seander/bithacks.html#CountBitsSetParallel
+
+  // 0x0F0F0F0F...
+  const APInt &Constant0F = APInt::getSplat(Len, APInt(8, 0x0F));
+  SDValue Mask0F = DAG.getConstant(Constant0F, DL, ContainerVT);
+  // 0x33333333... = (0x0F0F0F0F... ^ (0x0F0F0F0F... << 2))
+  const APInt &Constant33 = APInt::getSplat(Len, APInt(8, 0x33));
+  SDValue Mask33 =
+      RISCVMatInt::getIntMatCost(Constant33, ContainerVT.getScalarSizeInBits(),
+                                 Subtarget) > 2
+          ? DAG.getNode(RISCVISD::XOR_VL, DL, ContainerVT, Mask0F,
+                        DAG.getNode(RISCVISD::SHL_VL, DL, ContainerVT, Mask0F,
+                                    DAG.getConstant(2, DL, ContainerVT),
+                                    DAG.getUNDEF(ContainerVT), Mask, VL),
+                        DAG.getUNDEF(ContainerVT), Mask, VL)
+          : DAG.getConstant(Constant33, DL, ContainerVT);
+  // 0x55555555... = (0x33333333... ^ (0x33333333... << 1))
+  const APInt &Constant55 = APInt::getSplat(Len, APInt(8, 0x55));
+  SDValue Mask55 =
+      RISCVMatInt::getIntMatCost(Constant55, ContainerVT.getScalarSizeInBits(),
+                                 Subtarget) > 2
+          ? DAG.getNode(RISCVISD::XOR_VL, DL, ContainerVT, Mask33,
+                        DAG.getNode(RISCVISD::SHL_VL, DL, ContainerVT, Mask33,
+                                    DAG.getConstant(1, DL, ContainerVT),
+                                    DAG.getUNDEF(ContainerVT), Mask, VL),
+                        DAG.getUNDEF(ContainerVT), Mask, VL)
+          : DAG.getConstant(Constant55, DL, ContainerVT);
+
+  SDValue Tmp1, Tmp2, Tmp3, Tmp4, Tmp5;
+
+  // v = v - ((v >> 1) & 0x55555555...)
+  Tmp1 = DAG.getNode(RISCVISD::AND_VL, DL, ContainerVT,
+                     DAG.getNode(RISCVISD::SRL_VL, DL, ContainerVT, V,
+                                 DAG.getConstant(1, DL, ContainerVT),
+                                 DAG.getUNDEF(ContainerVT), Mask, VL),
+                     Mask55, DAG.getUNDEF(ContainerVT), Mask, VL);
+  V = DAG.getNode(RISCVISD::SUB_VL, DL, ContainerVT, V, Tmp1,
+                  DAG.getUNDEF(ContainerVT), Mask, VL);
+
+  // v = (v & 0x33333333...) + ((v >> 2) & 0x33333333...)
+  Tmp2 = DAG.getNode(RISCVISD::AND_VL, DL, ContainerVT, V, Mask33,
+                     DAG.getUNDEF(ContainerVT), Mask, VL);
+  Tmp3 = DAG.getNode(RISCVISD::AND_VL, DL, ContainerVT,
+                     DAG.getNode(RISCVISD::SRL_VL, DL, ContainerVT, V,
+                                 DAG.getConstant(2, DL, ContainerVT),
+                                 DAG.getUNDEF(ContainerVT), Mask, VL),
+                     Mask33, DAG.getUNDEF(ContainerVT), Mask, VL);
+  V = DAG.getNode(RISCVISD::ADD_VL, DL, ContainerVT, Tmp2, Tmp3,
+                  DAG.getUNDEF(ContainerVT), Mask, VL);
+
+  // v = (v + (v >> 4)) & 0x0F0F0F0F...
+  Tmp4 = DAG.getNode(RISCVISD::SRL_VL, DL, ContainerVT, V,
+                     DAG.getConstant(4, DL, ContainerVT),
+                     DAG.getUNDEF(ContainerVT), Mask, VL);
+  Tmp5 = DAG.getNode(RISCVISD::ADD_VL, DL, ContainerVT, V, Tmp4,
+                     DAG.getUNDEF(ContainerVT), Mask, VL);
+  V = DAG.getNode(RISCVISD::AND_VL, DL, ContainerVT, Tmp5, Mask0F,
+                  DAG.getUNDEF(ContainerVT), Mask, VL);
+
+  if (Len > 8) {
+    // v = (v * 0x01010101...) >> (Len - 8)
+    // 0x01010101... == (0x0F0F0F0F... & (0x0F0F0F0F... >> 3))
+    const APInt &Constant01 = APInt::getSplat(Len, APInt(8, 0x01));
+    SDValue Mask01 =
+        RISCVMatInt::getIntMatCost(
+            Constant01, ContainerVT.getScalarSizeInBits(), Subtarget) > 2
+            ? DAG.getNode(RISCVISD::AND_VL, DL, ContainerVT, Mask0F,
+                          DAG.getNode(RISCVISD::SRL_VL, DL, ContainerVT, Mask0F,
+                                      DAG.getConstant(3, DL, ContainerVT),
+                                      DAG.getUNDEF(ContainerVT), Mask, VL),
+                          DAG.getUNDEF(ContainerVT), Mask, VL)
+            : DAG.getConstant(Constant01, DL, ContainerVT);
+    V = DAG.getNode(RISCVISD::SRL_VL, DL, ContainerVT,
+                    DAG.getNode(RISCVISD::MUL_VL, DL, ContainerVT, V, Mask01,
+                                DAG.getUNDEF(ContainerVT), Mask, VL),
+                    DAG.getConstant(Len - 8, DL, ContainerVT),
+                    DAG.getUNDEF(ContainerVT), Mask, VL);
+  }
+
+  if (VT.isFixedLengthVector())
+    V = convertFromScalableVector(VT, V, DAG, Subtarget);
+  return V;
+}
+
 SDValue RISCVTargetLowering::lowerFixedLengthVectorFCOPYSIGNToRVV(
     SDValue Op, SelectionDAG &DAG) const {
   SDLoc DL(Op);
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index c11b1464757c7f..cc8a18d9088106 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -959,6 +959,9 @@ class RISCVTargetLowering : public TargetLowering {
   SDValue lowerGET_ROUNDING(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerSET_ROUNDING(SDValue Op, SelectionDAG &DAG) const;
 
+  SDValue lowerScalarCTPOP(SDValue Op, SelectionDAG &DAG) const;
+  SDValue lowerVectorCTPOP(SDValue Op, SelectionDAG &DAG) const;
+
   SDValue lowerEH_DWARF_CFA(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerCTLZ_CTTZ_ZERO_UNDEF(SDValue Op, SelectionDAG &DAG) const;
 
diff --git a/llvm/test/CodeGen/RISCV/ctlz-cttz-ctpop.ll b/llvm/test/CodeGen/RISCV/ctlz-cttz-ctpop.ll
index 455e6e54c9b396..1eaf91096336f3 100644
--- a/llvm/test/CodeGen/RISCV/ctlz-cttz-ctpop.ll
+++ b/llvm/test/CodeGen/RISCV/ctlz-cttz-ctpop.ll
@@ -53,28 +53,77 @@ define i8 @test_cttz_i8(i8 %a) nounwind {
 ; RV32_NOZBB-NEXT:    li a0, 8
 ; RV32_NOZBB-NEXT:    ret
 ;
-; RV64NOZBB-LABEL: test_cttz_i8:
-; RV64NOZBB:       # %bb.0:
-; RV64NOZBB-NEXT:    andi a1, a0, 255
-; RV64NOZBB-NEXT:    beqz a1, .LBB0_2
-; RV64NOZBB-NEXT:  # %bb.1: # %cond.false
-; RV64NOZBB-NEXT:    addi a1, a0, -1
-; RV64NOZBB-NEXT:    not a0, a0
-; RV64NOZBB-NEXT:    and a0, a0, a1
-; RV64NOZBB-NEXT:    srli a1, a0, 1
-; RV64NOZBB-NEXT:    andi a1, a1, 85
-; RV64NOZBB-NEXT:    subw a0, a0, a1
-; RV64NOZBB-NEXT:    andi a1, a0, 51
-; RV64NOZBB-NEXT:    srli a0, a0, 2
-; RV64NOZBB-NEXT:    andi a0, a0, 51
-; RV64NOZBB-NEXT:    add a0, a1, a0
-; RV64NOZBB-NEXT:    srli a1, a0, 4
-; RV64NOZBB-NEXT:    add a0, a0, a1
-; RV64NOZBB-NEXT:    andi a0, a0, 15
-; RV64NOZBB-NEXT:    ret
-; RV64NOZBB-NEXT:  .LBB0_2:
-; RV64NOZBB-NEXT:    li a0, 8
-; RV64NOZBB-NEXT:    ret
+; RV64I-LABEL: test_cttz_i8:
+; RV64I:       # %bb.0:
+; RV64I-NEXT:    andi a1, a0, 255
+; RV64I-NEXT:    beqz a1, .LBB0_2
+; RV64I-NEXT:  # %bb.1: # %cond.false
+; RV64I-NEXT:    addi sp, sp, -16
+; RV64I-NEXT:    sd ra, 8(sp) # 8-byte Folded Spill
+; RV64I-NEXT:    lui a1, 61681
+; RV64I-NEXT:    addiw a1, a1, -241
+; RV64I-NEXT:    slli a2, a1, 32
+; RV64I-NEXT:    add a1, a1, a2
+; RV64I-NEXT:    slli a2, a1, 2
+; RV64I-NEXT:    xor a2, a2, a1
+; RV64I-NEXT:    addi a3, a0, -1
+; RV64I-NEXT:    not a0, a0
+; RV64I-NEXT:    and a0, a0, a3
+; RV64I-NEXT:    andi a3, a0, 255
+; RV64I-NEXT:    srli a0, a0, 1
+; RV64I-NEXT:    andi a0, a0, 85
+; RV64I-NEXT:    sub a3, a3, a0
+; RV64I-NEXT:    and a0, a3, a2
+; RV64I-NEXT:    srli a3, a3, 2
+; RV64I-NEXT:    and a2, a3, a2
+; RV64I-NEXT:    add a0, a0, a2
+; RV64I-NEXT:    srli a2, a0, 4
+; RV64I-NEXT:    add a0, a0, a2
+; RV64I-NEXT:    and a0, a0, a1
+; RV64I-NEXT:    srli a2, a1, 3
+; RV64I-NEXT:    and a1, a2, a1
+; RV64I-NEXT:    call __muldi3
+; RV64I-NEXT:    srli a0, a0, 56
+; RV64I-NEXT:    ld ra, 8(sp) # 8-byte Folded Reload
+; RV64I-NEXT:    addi sp, sp, 16
+; RV64I-NEXT:    ret
+; RV64I-NEXT:  .LBB0_2:
+; RV64I-NEXT:    li a0, 8
+; RV64I-NEXT:    ret
+;
+; RV64M-LABEL: test_cttz_i8:
+; RV64M:       # %bb.0:
+; RV64M-NEXT:    andi a1, a0, 255
+; RV64M-NEXT:    beqz a1, .LBB0_2
+; RV64M-NEXT:  # %bb.1: # %cond.false
+; RV64M-NEXT:    lui a1, 61681
+; RV64M-NEXT:    addiw a1, a1, -241
+; RV64M-NEXT:    slli a2, a1, 32
+; RV64M-NEXT:    add a1, a1, a2
+; RV64M-NEXT:    slli a2, a1, 2
+; RV64M-NEXT:    xor a2, a2, a1
+; RV64M-NEXT:    addi a3, a0, -1
+; RV64M-NEXT:    not a0, a0
+; RV64M-NEXT:    and a0, a0, a3
+; RV64M-NEXT:    andi a3, a0, 255
+; RV64M-NEXT:    srli a0, a0, 1
+; RV64M-NEXT:    andi a0, a0, 85
+; RV64M-NEXT:    sub a3, a3, a0
+; RV64M-NEXT:    and a0, a3, a2
+; RV64M-NEXT:    srli a3, a3, 2
+; RV64M-NEXT:    and a2, a3, a2
+; RV64M-NEXT:    add a0, a0, a2
+; RV64M-NEXT:    srli a2, a0, 4
+; RV64M-NEXT:    add a0, a0, a2
+; RV64M-NEXT:    and a0, a0, a1
+; RV64M-NEXT:    srli a2, a1, 3
+; RV64M-NEXT:    and a1, a2, a1
+; RV64M-NEXT:    mul a0, a0, a1
+; RV64M-NEXT:    srli a0, a0, 56
+; RV64M-NEXT:    ret
+; RV64M-NEXT:  .LBB0_2:
+; RV64M-NEXT:    li a0, 8
+; RV64M-NEXT:    ret
 ;
 ; RV32ZBB-LABEL: test_cttz_i8:
 ; RV32ZBB:       # %bb.0:
@@ -154,35 +203,83 @@ define i16 @test_cttz_i16(i16 %a) nounwind {
 ; RV32_NOZBB-NEXT:    li a0, 16
 ; RV32_NOZBB-NEXT:    ret
 ;
-; RV64NOZBB-LABEL: test_cttz_i16:
-; RV64NOZBB:       # %bb.0:
-; RV64NOZBB-NEXT:    slli a1, a0, 48
-; RV64NOZBB-NEXT:    beqz a1, .LBB1_2
-; RV64NOZBB-NEXT:  # %bb.1: # %cond.false
-; RV64NOZBB-NEXT:    addi a1, a0, -1
-; RV64NOZBB-NEXT:    not a0, a0
-; RV64NOZBB-NEXT:    and a0, a0, a1
-; RV64NOZBB-NEXT:    srli a1, a0, 1
-; RV64NOZBB-NEXT:    lui a2, 5
-; RV64NOZBB-NEXT:    addiw a2, a2, 1365
-; RV64NOZBB-NEXT:    and a1, a1, a2
-; RV64NOZBB-NEXT:    sub a0, a0, a1
-; RV64NOZBB-NEXT:    lui a1, 3
-; RV64NOZBB-NEXT:    addiw a1, a1, 819
-; RV64NOZBB-NEXT:    and a2, a0, a1
-; RV64NOZBB-NEXT:    srli a0, a0, 2
-; RV64NOZBB-NEXT:    and a0, a0, a1
-; RV64NOZBB-NEXT:    add a0, a2, a0
-; RV64NOZBB-NEXT:    srli a1, a0, 4
-; RV64NOZBB-NEXT:    add a0, a0, a1
-; RV64NOZBB-NEXT:    andi a1, a0, 15
-; RV64NOZBB-NEXT:    slli a0, a0, 52
-; RV64NOZBB-NEXT:    srli a0, a0, 60
-; RV64NOZBB-NEXT:    add a0, a1, a0
-; RV64NOZBB-NEXT:    ret
-; RV64NOZBB-NEXT:  .LBB1_2:
-; RV64NOZBB-NEXT:    li a0, 16
-; RV64NOZBB-NEXT:    ret
+; RV64I-LABEL: test_cttz_i16:
+; RV64I:       # %bb.0:
+; RV64I-NEXT:    slli a1, a0, 48
+; RV64I-NEXT:    beqz a1, .LBB1_2
+; RV64I-NEXT:  # %bb.1: # %cond.false
+; RV64I-NEXT:    addi sp, sp, -16
+; RV64I-NEXT:    sd ra, 8(sp) # 8-byte Folded Spill
+; RV64I-NEXT:    lui a1, 61681
+; RV64I-NEXT:    addiw a1, a1, -241
+; RV64I-NEXT:    slli a2, a1, 32
+; RV64I-NEXT:    add a1, a1, a2
+; RV64I-NEXT:    slli a2, a1, 2
+; RV64I-NEXT:    xor a2, a2, a1
+; RV64I-NEXT:    addi a3, a0, -1
+; RV64I-NEXT:    not a0, a0
+; RV64I-NEXT:    and a0, a0, a3
+; RV64I-NEXT:    srli a3, a0, 1
+; RV64I-NEXT:    lui a4, 5
+; RV64I-NEXT:    addiw a4, a4, 1365
+; RV64I-NEXT:    and a3, a3, a4
+; RV64I-NEXT:    slli a0, a0, 48
+; RV64I-NEXT:    srli a0, a0, 48
+; RV64I-NEXT:    sub a0, a0, a3
+; RV64I-NEXT:    and a3, a0, a2
+; RV64I-NEXT:    srli a0, a0, 2
+; RV64I-NEXT:    and a0, a0, a2
+; RV64I-NEXT:    add a0, a3, a0
+; RV64I-NEXT:    srli a2, a0, 4
+; RV64I-NEXT:    add a0, a0, a2
+; RV64I-NEXT:    and a0, a0, a1
+; RV64I-NEXT:    srli a2, a1, 3
+; RV64I-NEXT:    and a1, a2, a1
+; RV64I-NEXT:    call __muldi3
+; RV64I-NEXT:    srli a0, a0, 56
+; RV64I-NEXT:    ld ra, 8(sp) # 8-byte Folded Reload
+; RV64I-NEXT:    addi sp, sp, 16
+; RV64I-NEXT:    ret
+; RV64I-NEXT:  .LBB1_2:
+; RV64I-NEXT:    li a0, 16
+; RV64I-NEXT:    ret
+;
+; RV64M-LABEL: test_cttz_i16:
+; RV64M:       # %bb.0:
+; RV64M-NEXT:    slli a1, a0, 48
+; RV64M-NEXT:    beqz a1, .LBB1_2
+; RV64M-NEXT:  # %bb.1: # %cond.false
+; RV64M-NEXT:    lui a1, 61681
+; RV64M-NEXT:    addiw a1, a1, -241
+; RV64M-NEXT:    slli a2, a1, 32
+; RV64M-NEXT:    add a1, a1, a2
+; RV64M-NEXT:    slli a2, a1, 2
+; RV64M-NEXT:    xor a2, a2, a1
+; RV64M-NEXT:    addi a3, a0, -1
+; RV64M-NEXT:    not a0, a0
+; RV64M-NEXT:    and a0, a0, a3
+; RV64M-NEXT:    srli a3, a0, 1
+; RV64M-NEXT:    lui a4, 5
+; RV64M-NEXT:    addiw a4, a4, 1365
+; RV64M-NEXT:    and a3, a3, a4
+; RV64M-NEXT:    slli a0, a0, 48
+; RV64M-NEXT:    srli a0, a0, 48
+; RV64M-NEXT:    sub a0, a0, a3
+; RV64M-NEXT:    and a3, a0, a2
+; RV64M-NEXT:    srli a0, a0, 2
+; RV64M-NEXT:    and a0, a0, a2
+; RV64M-NEXT:    add a0, a3, a0
+; RV64M-NEXT:    srli a2, a0, 4
+; RV64M-NEXT:    add a0, a0, a2
+; RV64M-NEXT:    and a0, a0, a1
+; RV64M-NEXT:    srli a2, a1, 3
+; RV64M-NEXT:    and a1, a2, a1
+; RV64M-NEXT:    mul a0, a0, a1
+; RV64M-NEXT:    srli a0, a0, 56
+; RV64M-NEXT:    ret
+; RV64M-NEXT:  .LBB1_2:
+; RV64M-NEXT:    li a0, 16
+; RV64M-NEXT:    ret
 ;
 ; RV32ZBB-LABEL: test_cttz_i16:
 ; RV32ZBB:       # %bb.0:
@@ -422,16 +519,33 @@ define i64 @test_cttz_i64(i64 %a) nounwind {
 ; RV64I-NEXT:  # %bb.1: # %cond.false
 ; RV64I-NEXT:    addi sp, sp, -16
 ; RV64I-NEXT:    sd ra, 8(sp) # 8-byte Folded Spill
-; RV64I-NEXT:    n...
[truncated]
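
For readers skimming the truncated assembly above: the scalar sequence the new lowering emits corresponds to the classic SWAR popcount, sketched here in plain C++ as a reference (an illustration, not code from this patch):

```cpp
// Reference sketch of the SWAR popcount that lowerScalarCTPOP emits for i64.
// In the patch, the three non-0x0F masks are built from Mask0F with a
// shift plus a logic op whenever materializing them directly would cost
// more than two instructions.
#include <cstdint>

uint64_t swarPopcount(uint64_t V) {
  const uint64_t Mask0F = 0x0F0F0F0F0F0F0F0FULL;
  const uint64_t Mask33 = 0x3333333333333333ULL;
  const uint64_t Mask55 = 0x5555555555555555ULL;
  const uint64_t Mask01 = 0x0101010101010101ULL;

  V = V - ((V >> 1) & Mask55);            // count bits in each 2-bit group
  V = (V & Mask33) + ((V >> 2) & Mask33); // sum into 4-bit groups
  V = (V + (V >> 4)) & Mask0F;            // sum into bytes
  return (V * Mask01) >> 56;              // add all byte counts via multiply
}
```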

@wangpc-pp (Contributor, Author) commented

There is a problem: we will lower ISD::CTLZ/ISD::CTTZ to ISD::CTPOP when ISD::CTPOP is Legal or Custom, which causes some regressions. I am still fixing it.

@efriedma-quic (Collaborator) commented

For the vector case, I think it ends up cheaper if you always start by bitcasting the operand to an i8 vector, and doing an i8 popcount. That makes the constants involved significantly cheaper. (The width you use to do the arithmetic doesn't matter for correctness, but doing it in i8 means you need fewer vsetvli, I think). Then you can bitcast back to the wider type to do the multiply/shift.
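
A scalar C++ emulation of that suggestion follows (my illustration of the idea, not code from this PR; memcpy stands in for the vector bitcast, and popcountViaBytes is a hypothetical name): the per-byte steps need only 8-bit masks, and the wide multiply/shift at the end sums the byte counts.

```cpp
// Scalar emulation of the i8-popcount suggestion (illustrative only):
// the byte-granularity steps use only 8-bit immediates, then a single
// wide multiply/shift sums the eight byte counts of each 64-bit lane.
#include <cstdint>
#include <cstring>

uint64_t popcountViaBytes(uint64_t Lane) {
  uint8_t Bytes[8];
  std::memcpy(Bytes, &Lane, sizeof(Bytes)); // "bitcast" the lane to 8 x i8
  for (uint8_t &B : Bytes) {                // i8 popcount; masks fit in 8 bits
    B = B - ((B >> 1) & 0x55);
    B = (B & 0x33) + ((B >> 2) & 0x33);
    B = (B + (B >> 4)) & 0x0F;
  }
  uint64_t Wide;
  std::memcpy(&Wide, Bytes, sizeof(Wide));  // "bitcast" back to i64
  return (Wide * 0x0101010101010101ULL) >> 56; // sum the byte counts
}
```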

@wangpc-pp (Contributor, Author) commented

> For the vector case, I think it ends up cheaper if you always start by bitcasting the operand to an i8 vector, and doing an i8 popcount. That makes the constants involved significantly cheaper. (The width you use to do the arithmetic doesn't matter for correctness, but doing it in i8 means you need fewer vsetvli, I think). Then you can bitcast back to the wider type to do the multiply/shift.

I don't know if this is beneficial. Using a smaller EEW means we can simplify the materialization, but I think we may execute more uops, which may result in worse performance.

@topperc (Collaborator) commented Mar 29, 2024

> > For the vector case, I think it ends up cheaper if you always start by bitcasting the operand to an i8 vector, and doing an i8 popcount. That makes the constants involved significantly cheaper. (The width you use to do the arithmetic doesn't matter for correctness, but doing it in i8 means you need fewer vsetvli, I think). Then you can bitcast back to the wider type to do the multiply/shift.
>
> I don't know if this is beneficial. Using a smaller EEW means we can simplify the materialization, but I think we may execute more uops, which may result in worse performance.

Hopefully simple bitwise operations like shifts, ands, and adds don't get split based on element width.

@wangpc-pp (Contributor, Author) commented

> > > For the vector case, I think it ends up cheaper if you always start by bitcasting the operand to an i8 vector, and doing an i8 popcount. That makes the constants involved significantly cheaper. (The width you use to do the arithmetic doesn't matter for correctness, but doing it in i8 means you need fewer vsetvli, I think). Then you can bitcast back to the wider type to do the multiply/shift.
> >
> > I don't know if this is beneficial. Using a smaller EEW means we can simplify the materialization, but I think we may execute more uops, which may result in worse performance.
>
> Hopefully simple bitwise operations like shifts, ands, and adds don't get split based on element width.

Yeah, most instructions will be split into chunks of the datapath width in most implementations, but I do know of an implementation that doesn't.

@wangpc-pp force-pushed the main-riscv-popcount-constants branch from dabc3a5 to 2073e7b on March 29, 2024 at 09:19
…stly

For RISCV, it is costly to materialize constants used in lowering
`ISD::CTPOP`/`ISD::VP_CTPOP`.

We can query the materialization cost via `RISCVMatInt::getIntMatCost`
and if the cost is larger than 2, we should construct the constant
via two instructions.

This PR contains only the vector part. We don't apply the optimization when
LMUL is LMUL_8, as that increases register pressure.
@wangpc-pp force-pushed the main-riscv-popcount-constants branch from 2073e7b to 2db829e on March 29, 2024 at 09:20
@wangpc-pp (Contributor, Author) commented

Only the vector part is included, and we don't optimize when LMUL == M8, as it increases register pressure. Now there is no noticeable regression.

@topperc (Collaborator) commented Mar 29, 2024

> > > > For the vector case, I think it ends up cheaper if you always start by bitcasting the operand to an i8 vector, and doing an i8 popcount. That makes the constants involved significantly cheaper. (The width you use to do the arithmetic doesn't matter for correctness, but doing it in i8 means you need fewer vsetvli, I think). Then you can bitcast back to the wider type to do the multiply/shift.
> > >
> > > I don't know if this is beneficial. Using a smaller EEW means we can simplify the materialization, but I think we may execute more uops, which may result in worse performance.
> >
> > Hopefully simple bitwise operations like shifts, ands, and adds don't get split based on element width.
>
> Yeah, most instructions will be split into chunks of the datapath width in most implementations, but I do know of an implementation that doesn't.

Can you share what implementation that is?

@efriedma-quic (Collaborator) commented

Using the smaller EEW reduces both the total number of instructions and the number of vector uops on implementations that split based on the size of the vector. And on an implementation that splits based on the number of elements, we'd need a bunch of other tweaks to generate good code anyway. (For example, we'd want to rewrite i8 vector popcount to use i64 operations.)

@wangpc-pp (Contributor, Author) commented

> > > > > For the vector case, I think it ends up cheaper if you always start by bitcasting the operand to an i8 vector, and doing an i8 popcount. That makes the constants involved significantly cheaper. (The width you use to do the arithmetic doesn't matter for correctness, but doing it in i8 means you need fewer vsetvli, I think). Then you can bitcast back to the wider type to do the multiply/shift.
> > > >
> > > > I don't know if this is beneficial. Using a smaller EEW means we can simplify the materialization, but I think we may execute more uops, which may result in worse performance.
> > >
> > > Hopefully simple bitwise operations like shifts, ands, and adds don't get split based on element width.
> >
> > Yeah, most instructions will be split into chunks of the datapath width in most implementations, but I do know of an implementation that doesn't.
>
> Can you share what implementation that is?

Just checked: I was wrong, as the implementation doesn't split shift/and/add based on EEW (it's a small core targeting low-end IoT scenarios, and its VLEN is really small). Sorry for that.
Besides, if the assumption that these simple arithmetic operations don't get split based on EEW holds, maybe we can apply similar optimizations to more operations?
