
Subject: [PATCH] [AArch64ISelLowering] Optimize rounding shift and saturation truncation #74325

Open
wants to merge 1 commit into
base: main

Conversation

JohnLee1243
Contributor

This patch performs two kinds of instruction simplification and adds a pattern for RSHRN:

  1. A rounding shift such as SHIFT(Add(OpA, 1<<(imm-1)), imm) can be simplified to srshr(OpA, imm) (urshr for a logical shift); a minimal scalar sketch of this identity follows the list.
  2. A rounding shift followed by saturating truncation, such as Trunc(min(max(Shift(Add(OpA, 1<<(imm-1)), imm), 0), maxValue)), can be simplified to uqrshrn(OpA, imm) or sqrshrun(OpA, imm).
  3. A pattern for RSHRN.

These optimizations are performed in the backend, after legalization.
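For reference, a minimal scalar C++ sketch of the rounding-shift identity behind item 1 (an illustration only; urshr_scalar is a hypothetical name, not code from this patch):

#include <cassert>
#include <cstdint>

// Illustrative scalar model: adding 1 << (imm - 1) before a right shift by
// imm rounds the quotient to nearest instead of truncating it toward zero.
static uint32_t urshr_scalar(uint32_t x, unsigned imm) {
  assert(imm >= 1 && imm <= 31 && "shift amount must fit the element width");
  return (static_cast<uint64_t>(x) + (1ULL << (imm - 1))) >> imm;
}

int main() {
  assert((104u >> 4) == 6);          // a plain shift truncates 104/16 = 6.5
  assert(urshr_scalar(104, 4) == 7); // the rounding shift rounds it to 7
  return 0;
}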

@llvmbot
Collaborator

llvmbot commented Dec 4, 2023

@llvm/pr-subscribers-backend-aarch64

Author: None (JohnLee1243)

Changes

This patch performs two kinds of instruction simplification and adds a pattern for RSHRN:

  1. A rounding shift such as SHIFT(Add(OpA, 1<<(imm-1)), imm) can be simplified to srshr(OpA, imm) (urshr for a logical shift).
  2. A rounding shift followed by saturating truncation, such as Trunc(min(max(Shift(Add(OpA, 1<<(imm-1)), imm), 0), maxValue)), can be simplified to uqrshrn(OpA, imm) or sqrshrun(OpA, imm).
  3. A pattern for RSHRN.

These optimizations are performed in the backend, after legalization.

Patch is 41.77 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/74325.diff

5 Files Affected:

  • (modified) llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp (+11)
  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+595-1)
  • (modified) llvm/lib/Target/AArch64/AArch64InstrInfo.td (+1-1)
  • (added) llvm/test/CodeGen/AArch64/isel_rounding_opt.ll (+327)
  • (modified) llvm/test/CodeGen/AArch64/neon-rshrn.ll (+11-19)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
index 2f49e9a6b37cc..b559a470ed3f2 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
@@ -177,6 +177,17 @@ class AArch64DAGToDAGISel : public SelectionDAGISel {
   }
 
   bool SelectRoundingVLShr(SDValue N, SDValue &Res1, SDValue &Res2) {
+    if (N.getOpcode() == AArch64ISD::URSHR_I) {
+      EVT VT = N.getValueType();
+      unsigned ShtAmt = N->getConstantOperandVal(1);
+      if (ShtAmt >= VT.getScalarSizeInBits() / 2)
+        return false;
+
+      Res1 = N.getOperand(0);
+      Res2 = CurDAG->getTargetConstant(ShtAmt, SDLoc(N), MVT::i32);
+      return true;
+    }
+
     if (N.getOpcode() != AArch64ISD::VLSHR)
       return false;
     SDValue Op = N->getOperand(0);
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index b6a16217dfae3..17da58f892baa 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -111,6 +111,28 @@ STATISTIC(NumTailCalls, "Number of tail calls");
 STATISTIC(NumShiftInserts, "Number of vector shift inserts");
 STATISTIC(NumOptimizedImms, "Number of times immediates were optimized");
 
+static cl::opt<bool> EnableAArch64RoundingOpt(
+    "aarch64-optimize-rounding", cl::Hidden,
+    cl::desc("Enable AArch64 rounding optimization"),
+    cl::init(true));
+
+static cl::opt<bool> EnableAArch64RoundingSatNarrowOpt(
+    "aarch64-optimize-rounding-saturation", cl::Hidden,
+    cl::desc("Enable AArch64 rounding and saturation narrow optimization"),
+    cl::init(true));
+
+static cl::opt<bool> EnableAArch64ExtractVecElementCombine(
+    "aarch64-extract-vector-element-trunc-combine", cl::Hidden,
+    cl::desc("Allow AArch64 extract vector element combination with "
+             "truncation"),
+    cl::init(true));
+
+static cl::opt<int> RoundingSearchMaxDepth(
+    "aarch64-rounding-search-max-depth", cl::Hidden,
+    cl::desc("Maximum depth to bfs search rounding value in rounding "
+             "optimization"),
+    cl::init(4));
+
 // FIXME: The necessary dtprel relocations don't seem to be supported
 // well in the GNU bfd and gold linkers at the moment. Therefore, by
 // default, for now, fall back to GeneralDynamic code generation.
@@ -995,6 +1017,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
 
   setTargetDAGCombine(ISD::INTRINSIC_WO_CHAIN);
 
+  setTargetDAGCombine(ISD::TRUNCATE);
+
   setTargetDAGCombine({ISD::ANY_EXTEND, ISD::ZERO_EXTEND, ISD::SIGN_EXTEND,
                        ISD::VECTOR_SPLICE, ISD::SIGN_EXTEND_INREG,
                        ISD::CONCAT_VECTORS, ISD::EXTRACT_SUBVECTOR,
@@ -17681,6 +17705,482 @@ static SDValue performANDCombine(SDNode *N,
   return SDValue();
 }
 
+// BFS the add tree for an operand that equals the rounding value
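+// For example (illustrative values), with RoundingValue = 8 the splat-8
+// operand is found inside ((a + 8) + b), and the remaining addends a and b
+// are collected into AddOperands.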
+static bool searchRoundingValueBFS(
+    uint64_t RoundingValue, const SDValue &OperandIn,
+    SmallVectorImpl<std::pair<SDValue, SDNodeFlags>> &AddOperands, int Level) {
+  SmallVector<SDValue, 4> WorkList;
+  WorkList.emplace_back(OperandIn);
+  while (Level > 0 && !WorkList.empty()) {
+    auto Operand = WorkList.front();
+    SmallVector<SDValue>::iterator k = WorkList.begin();
+    WorkList.erase(k);
+    Level--;
+    SDValue Operand0 = Operand.getOperand(0);
+    SDValue Operand1 = Operand.getOperand(1);
+    BuildVectorSDNode *AddOp0 = dyn_cast<BuildVectorSDNode>(Operand0);
+    BuildVectorSDNode *AddOp1 = dyn_cast<BuildVectorSDNode>(Operand1);
+    auto foundRounding = [&](BuildVectorSDNode *AddOp, SDValue &OtherOperand) {
+      APInt SplatBitsAdd, SplatUndefAdd;
+      unsigned SplatBitSizeAdd = 0;
+      bool HasAnyUndefsAnd = false;
+      if (AddOp &&
+          AddOp->isConstantSplat(SplatBitsAdd, SplatUndefAdd, SplatBitSizeAdd,
+                                 HasAnyUndefsAnd) &&
+          (SplatBitsAdd == RoundingValue)) {
+        AddOperands.emplace_back(
+            std::make_pair(OtherOperand, OtherOperand.getNode()->getFlags()));
+        while (!WorkList.empty()) {
+          SDValue TempVal = WorkList.front();
+          SmallVector<SDValue>::iterator k = WorkList.begin();
+          WorkList.erase(k);
+          AddOperands.emplace_back(
+              std::make_pair(TempVal, TempVal.getNode()->getFlags()));
+        }
+        return true;
+      }
+      return false;
+    };
+    if (foundRounding(AddOp0, Operand1))
+      return true;
+    if (foundRounding(AddOp1, Operand0))
+      return true;
+    if (Operand0.getOpcode() == ISD::ADD)
+      WorkList.emplace_back(Operand0);
+    else
+      AddOperands.emplace_back(
+          std::make_pair(Operand0, Operand.getNode()->getFlags()));
+    if (Operand1.getOpcode() == ISD::ADD)
+      WorkList.emplace_back(Operand1);
+    else
+      AddOperands.emplace_back(
+          std::make_pair(Operand1, Operand.getNode()->getFlags()));
+  }
+
+  return false;
+}
+
+// Try to match pattern "OpB = SHIFT(Add(OpA, 1<<(imm-1)), imm) ", where
+// shift must be an immediate number
+static SDValue matchShiftRounding(const SDValue &ShiftOp0,
+                                  const SDValue &ShiftOp1, SelectionDAG &DAG,
+                                  int64_t &ShiftAmount) {
+  ShiftAmount = 0;
+  // For illegal type, do nothing. Wait until type is legalized.
+  EVT VT0 = ShiftOp0.getValueType();
+  if (!VT0.isVector() || !DAG.getTargetLoweringInfo().isTypeLegal(VT0))
+    return SDValue();
+
+  BuildVectorSDNode *ShiftOperand1 = dyn_cast<BuildVectorSDNode>(ShiftOp1);
+
+  // Shift value must be an immediate, either a constant splat vector form or
+  // scalar form.
+  int64_t TempAmount;
+  EVT VT = ShiftOp1.getValueType();
+  if (ShiftOperand1 && VT.isVector() &&
+      isVShiftRImm(ShiftOp1, VT, false, TempAmount))
+    ShiftAmount = TempAmount;
+
+  if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(ShiftOp1))
+    ShiftAmount = C->getSExtValue();
+
+  // For a shift amount of 1, XHADD is matched with higher priority, since it
+  // computes (a + b + 1) >> 1
+  if (ShiftOp0.getOpcode() == ISD::ADD && ShiftAmount > 1 &&
+      ShiftAmount <= 64) {
+    uint64_t RoundingValue = 1ULL << static_cast<uint64_t>(ShiftAmount - 1);
+    SmallVector<std::pair<SDValue, SDNodeFlags>, 4> AddOperands;
+
+    // In case the expression has the pattern
+    // "(a + roundingValue + b + c + d) >> shift", where the rounding value is
+    // not a direct operand of the shift, the rounding value is searched for
+    // from the root to the leaves. In this case the expression is matched as
+    // "RoundingShift(a + b + c + d, shift)".
+    if (searchRoundingValueBFS(RoundingValue, ShiftOp0, AddOperands,
+                               RoundingSearchMaxDepth)) {
+      SDLoc DL(ShiftOp0);
+      EVT VTAdd = AddOperands[0].first.getValueType();
+
+      for (size_t i = 1; i < AddOperands.size(); i++) {
+        AddOperands[i].first =
+            DAG.getNode(ISD::ADD, DL, VTAdd, AddOperands[i].first,
+                        AddOperands[i - 1].first, AddOperands[i].second);
+      }
+      return AddOperands[AddOperands.size() - 1].first;
+    }
+  }
+
+  return SDValue();
+}
+
+// Attempt to form XRSHR(OpA, imm) from "SRX(Add(OpA, 1<<(imm-1)), imm)".
+// Depending on whether the shift is logical or arithmetic, 'XRSHR' is
+// 'URSHR' or 'SRSHR' and 'SRX' is 'SRL' or 'SRA'.
+
+static SDValue matchXRSHR(SDNode *N, SelectionDAG &DAG,
+                          unsigned ShiftRoundingOpc) {
+  EVT VT = N->getValueType(0);
+  int64_t ShiftAmount;
+  SDValue AddOperand =
+      matchShiftRounding(N->getOperand(0), N->getOperand(1), DAG, ShiftAmount);
+  if (!AddOperand)
+    return SDValue();
+  SDLoc DL(N);
+  SDValue ResultRounding = DAG.getNode(
+      ShiftRoundingOpc, DL, VT, AddOperand,
+      DAG.getConstant(static_cast<uint64_t>(ShiftAmount), DL, MVT::i32));
+  return ResultRounding;
+}
+
+static SDValue performSRLCombine(SDNode *N,
+                                 TargetLowering::DAGCombinerInfo &DCI) {
+  SelectionDAG &DAG = DCI.DAG;
+  EVT VT = N->getValueType(0);
+
+  // Attempt to form URSHR(OpA, imm) from "SRL(Add(OpA, 1<<(imm-1)), imm)"
+  SDValue ResultRounding;
+  if (EnableAArch64RoundingOpt && VT.isVector() &&
+      static_cast<const AArch64Subtarget &>(DAG.getSubtarget()).hasNEON())
+    ResultRounding = matchXRSHR(N, DAG, AArch64ISD::URSHR_I);
+
+  if (ResultRounding)
+    return ResultRounding;
+
+  if (VT != MVT::i32 && VT != MVT::i64)
+    return SDValue();
+
+  // Canonicalize (srl (bswap i32 x), 16) to (rotr (bswap i32 x), 16), if the
+  // high 16-bits of x are zero. Similarly, canonicalize (srl (bswap i64 x), 32)
+  // to (rotr (bswap i64 x), 32), if the high 32-bits of x are zero.
+  SDValue N0 = N->getOperand(0);
+  if (N0.getOpcode() == ISD::BSWAP) {
+    SDLoc DL(N);
+    SDValue N1 = N->getOperand(1);
+    SDValue N00 = N0.getOperand(0);
+    if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(N1)) {
+      uint64_t ShiftAmt = C->getZExtValue();
+      if (VT == MVT::i32 && ShiftAmt == 16 &&
+          DAG.MaskedValueIsZero(N00, APInt::getHighBitsSet(32, 16)))
+        return DAG.getNode(ISD::ROTR, DL, VT, N0, N1);
+      if (VT == MVT::i64 && ShiftAmt == 32 &&
+          DAG.MaskedValueIsZero(N00, APInt::getHighBitsSet(64, 32)))
+        return DAG.getNode(ISD::ROTR, DL, VT, N0, N1);
+    }
+  }
+  return SDValue();
+}
+
+static SDValue performSRACombine(SDNode *N,
+                                 TargetLowering::DAGCombinerInfo &DCI) {
+  SelectionDAG &DAG = DCI.DAG;
+  EVT VT = N->getValueType(0);
+
+  // Attempt to form SRSHR(OpA, imm) from "SRA(Add(OpA, 1<<(imm-1)), imm)"
+  SDValue ResultRounding;
+  if (EnableAArch64RoundingOpt && VT.isVector() &&
+      static_cast<const AArch64Subtarget &>(DAG.getSubtarget()).hasNEON())
+    ResultRounding = matchXRSHR(N, DAG, AArch64ISD::SRSHR_I);
+
+  return ResultRounding;
+}
+
+// Try to match pattern "Trunc(min(max(Shift(Add(OpA, 1<<(imm-1),
+// imm)),0),maxValue)) " Here min(max(x,0),maxValue) is an unsigned
+// saturation operation, shift is signed right scalar or vector shift.
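+// For example (illustrative values), umin(smax((x + 8) >> 4, 0), 255) with
+// MaxScalarValue = 255 matches, returning x with ShiftAmount = 4.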
+static SDValue matchTruncSatRounding(const SDValue &UminOp, SelectionDAG &DAG,
+                                     uint64_t MaxScalarValue,
+                                     int64_t &ShiftAmount) {
+  if (UminOp.getOpcode() != ISD::UMIN)
+    return SDValue();
+
+  APInt SplatBitsRound, SplatUndefRound;
+  unsigned SplatBitSizeRound = 0;
+  bool HasAnyUndefsRound = false;
+  uint64_t UminOp1VecVal = 0;
+
+  BuildVectorSDNode *UminOp1Vec =
+      dyn_cast<BuildVectorSDNode>(UminOp.getOperand(1));
+
+  if (UminOp1Vec &&
+      UminOp1Vec->isConstantSplat(SplatBitsRound, SplatUndefRound,
+                                  SplatBitSizeRound, HasAnyUndefsRound))
+    UminOp1VecVal = SplatBitsRound.getZExtValue();
+
+  if (UminOp1VecVal != MaxScalarValue)
+    return SDValue();
+
+  SDValue SmaxOp = UminOp.getOperand(0);
+  if (SmaxOp.getOpcode() != ISD::SMAX ||
+      !isNullOrNullSplat(SmaxOp.getOperand(1)))
+    return SDValue();
+
+  SDValue RoundingOp = SmaxOp.getOperand(0);
+  unsigned int RoundingOpCode = RoundingOp.getOpcode();
+  if (RoundingOpCode == AArch64ISD::VASHR ||
+      RoundingOpCode == AArch64ISD::VLSHR || RoundingOpCode == ISD::SRA ||
+      RoundingOpCode == ISD::SRL) {
+    SDValue AddOperand = matchShiftRounding(
+        RoundingOp.getOperand(0), RoundingOp.getOperand(1), DAG, ShiftAmount);
+    // The rounding+truncation instructions do not support shift amounts
+    // greater than half the input element width
+    int64_t InSize = UminOp.getValueType().getScalarSizeInBits();
+    if (ShiftAmount > InSize / 2)
+      return SDValue();
+    return AddOperand;
+  }
+
+  return SDValue();
+}
+
+// A helper function to build an sqrshrun or uqrshrn intrinsic node
+static SDValue generateQrshrnInstruction(unsigned ShiftRoundingOpc,
+                                         SelectionDAG &DAG,
+                                         SDValue &AddOperand,
+                                         int64_t ShiftAmount) {
+  SDLoc DL(AddOperand);
+  EVT InVT = AddOperand.getValueType();
+  EVT HalvedVT = InVT.changeVectorElementType(
+      InVT.getVectorElementType().getHalfSizedIntegerVT(*DAG.getContext()));
+
+  SDValue ResultRounding = DAG.getNode(
+      ISD::INTRINSIC_WO_CHAIN, DL, HalvedVT,
+      DAG.getConstant(ShiftRoundingOpc, DL, MVT::i64), AddOperand,
+      DAG.getConstant(static_cast<uint64_t>(ShiftAmount), DL, MVT::i32));
+  return ResultRounding;
+}
+
+// Attempt to match
+// Trunc(Concat(Trunc(min(max(Shift(Add(OpA, 1<<(imm1-1)), imm1), 0), maxValue)),
+//              Trunc(min(max(Shift(Add(OpB, 1<<(imm2-1)), imm2), 0), maxValue))))
+// The pattern in tree is shown below
+//  OpA     1<<(imm1-1)                         OpB     1<<(imm2-1)
+//    \       /                                    \       /
+//     \     /                                      \     /
+//      Add      imm1                       imm2      Add
+//        \      /                              \      /
+//         \    /                                \    /
+//         Shift     0                    0      Shift
+//            \     /                      \      /
+//             \   /                        \    /
+//              max     maxVal      maxVal   max
+//                \     /             \      /
+//                 \   /               \    /
+//                  min                 min
+//                    \                 /
+//                     \               /
+//                    Trunc         Trunc
+//                        \         /
+//                         \       /
+//                          Concat
+//                             |
+//                             |
+//                           Trunc
+// The pattern is matched as uqxtn(concat(qrshrn(OpA, imm1),
+// qrshrn(OpB, imm2))), where uqxtn is a saturating truncation and qrshrn is
+// sqrshrun or uqrshrn.
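+// For example (illustrative types and immediates), with v4i32 inputs
+// narrowed to v8i8 (maxValue = 255), imm1 = 6 and imm2 = 4:
+//   Trunc(Concat(Trunc(umin(smax((OpA + 32) >> 6, 0), 255)),
+//                Trunc(umin(smax((OpB + 8) >> 4, 0), 255))))
+//   ==> uqxtn(concat(sqrshrun(OpA, 6), sqrshrun(OpB, 4)))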
+static SDValue matchQrshrn2(SDNode *N, SelectionDAG &DAG) {
+  TypeSize OutSizeTyped = N->getValueSizeInBits(0);
+  if (OutSizeTyped.isScalable())
+    return SDValue();
+  uint64_t OutSize = OutSizeTyped;
+  SDValue Operand0 = N->getOperand(0);
+  if (Operand0.getOpcode() != ISD::CONCAT_VECTORS || OutSize != 64)
+    return SDValue();
+
+  EVT VT = N->getValueType(0);
+  uint64_t OutScalarSize = VT.getScalarSizeInBits();
+  uint64_t MaxScalarValue = (1ULL << OutScalarSize) - 1;
+  SDLoc DL(N);
+  SDValue TruncOp0 = Operand0.getOperand(0);
+  SDValue TruncOp1 = Operand0.getOperand(1);
+  if (TruncOp0.getOpcode() != ISD::TRUNCATE ||
+      TruncOp1.getOpcode() != ISD::TRUNCATE)
+    return SDValue();
+
+  int64_t ShiftAmount0 = -1, ShiftAmount1 = -1;
+  SDValue AddOperand0 = matchTruncSatRounding(TruncOp0.getOperand(0), DAG,
+                                              MaxScalarValue, ShiftAmount0);
+  SDValue AddOperand1 = matchTruncSatRounding(TruncOp1.getOperand(0), DAG,
+                                              MaxScalarValue, ShiftAmount1);
+  if (!AddOperand0 || !AddOperand1)
+    return SDValue();
+
+  auto getShiftVal = [&](SDValue &TruncOp) {
+    if (SDValue Operand0 = TruncOp.getOperand(0))
+      if (SDValue Operand00 = Operand0.getOperand(0))
+        return Operand00.getOperand(0);
+
+    return SDValue();
+  };
+
+  SDValue Shift0Val = getShiftVal(TruncOp0);
+  SDValue Shift1Val = getShiftVal(TruncOp1);
+  if (!Shift0Val || !Shift1Val)
+    return SDValue();
+
+  unsigned Shift0Opc = Shift0Val.getOpcode();
+  unsigned Shift1Opc = Shift1Val.getOpcode();
+  unsigned TruncOpc0 = (Shift0Opc == AArch64ISD::VLSHR || Shift0Opc == ISD::SRL)
+                           ? Intrinsic::aarch64_neon_uqrshrn
+                           : Intrinsic::aarch64_neon_sqrshrun;
+  unsigned TruncOpc1 = (Shift1Opc == AArch64ISD::VLSHR || Shift1Opc == ISD::SRL)
+                           ? Intrinsic::aarch64_neon_uqrshrn
+                           : Intrinsic::aarch64_neon_sqrshrun;
+  SDValue ResultTrunc0 =
+      generateQrshrnInstruction(TruncOpc0, DAG, AddOperand0, ShiftAmount0);
+  SDValue ResultTrunc1 =
+      generateQrshrnInstruction(TruncOpc1, DAG, AddOperand1, ShiftAmount1);
+
+  EVT ConcatVT = ResultTrunc1.getValueType().getDoubleNumVectorElementsVT(
+      *DAG.getContext());
+  SDValue ConcatOp = DAG.getNode(ISD::CONCAT_VECTORS, DL, ConcatVT,
+                                 ResultTrunc0, ResultTrunc1);
+  // Note that MaxScalarValue is the max value of the final truncated type, so
+  // saturating twice is needed; uqxtn is a saturating truncation.
+  SDValue ResultOp = DAG.getNode(
+      ISD::INTRINSIC_WO_CHAIN, DL, VT,
+      DAG.getConstant(Intrinsic::aarch64_neon_uqxtn, DL, MVT::i64), ConcatOp);
+
+  return ResultOp;
+}
+
+// To match SQRSHRUN from trunc(umin(smax((a + b + (1<<(imm-1))) >> imm, 0),
+// max))
+static SDValue matchQrshrn(SDNode *N, SelectionDAG &DAG) {
+  int64_t ShiftAmount = -1;
+  SDValue AddOperand;
+  TypeSize OutSizeTyped = N->getValueSizeInBits(0);
+  if (OutSizeTyped.isScalable())
+    return SDValue();
+  uint64_t OutSize = OutSizeTyped;
+  EVT VT = N->getValueType(0);
+
+  uint64_t OutScalarSize = VT.getScalarSizeInBits();
+  uint64_t MaxScalarValue = (1ULL << OutScalarSize) - 1;
+  SDLoc DL(N);
+
+  if (OutSize <= 64 && OutSize >= 32)
+    AddOperand = matchTruncSatRounding(N->getOperand(0), DAG, MaxScalarValue,
+                                       ShiftAmount);
+
+  if (!AddOperand)
+    return SDValue();
+
+  uint64_t UminOpScalarSize =
+      N->getOperand(0).getValueType().getScalarSizeInBits();
+  unsigned ShiftOpc = N->getOperand(0).getOperand(0).getOperand(0).getOpcode();
+
+  unsigned TruncOpc = (ShiftOpc == AArch64ISD::VLSHR || ShiftOpc == ISD::SRL)
+                          ? Intrinsic::aarch64_neon_uqrshrn
+                          : Intrinsic::aarch64_neon_sqrshrun;
+
+  SDValue ResultTrunc =
+      generateQrshrnInstruction(TruncOpc, DAG, AddOperand, ShiftAmount);
+  // For pattern "trunc(trunc(umin(smax((a + b + (1<<(imm-1)))
+  // >>imm,0),max))", where max is the outer truncated type max value, so
+  // another saturation trunction is needed.
+  // Because final truncated type may be illegal, and there is no method to
+  // legalize intrinsic, so add redundant operation `CONCAT_VECTORS` and
+  // `EXTRACT_SUBVECTOR` to automatically legalize the operation.
+  // But generated asm code is not so optimized in this way.
+  // For example, generated asm:
+  // sqrshrun  v1.4h, v0.4s, #6
+  // sqrshrun2 v1.8h, v0.4s, #6
+  // uqxtn   v0.8b, v1.8h
+  // zip1    v0.16b, v0.16b, v0.16b
+  // xtn     v0.8b, v0.8h
+  //
+  // ideal Optimized code:
+  //
+  // sqrshrun  v1.4h, v0.4s, #6
+  // uqxtn   v0.8b, v1.8h
+  // To solve this issue, Function `hasUselessTrunc` is introduced which can
+  // optimize code like above.
+  if (ResultTrunc && OutScalarSize == UminOpScalarSize / 4) {
+    EVT ConcatVT = ResultTrunc.getValueType().getDoubleNumVectorElementsVT(
+        *DAG.getContext());
+    SDValue ConcatOp = DAG.getNode(
+        ISD::CONCAT_VECTORS, DL, ConcatVT, ResultTrunc,
+        DAG.getNode(ISD::UNDEF, SDLoc(), ResultTrunc.getValueType()));
+    EVT HalvedVT = ConcatVT.changeVectorElementType(
+        ConcatVT.getVectorElementType().getHalfSizedIntegerVT(
+            *DAG.getContext()));
+
+    ResultTrunc = DAG.getNode(
+        ISD::INTRINSIC_WO_CHAIN, DL, HalvedVT,
+        DAG.getConstant(Intrinsic::aarch64_neon_uqxtn, DL, MVT::i64), ConcatOp);
+
+    return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, ResultTrunc,
+                       DAG.getVectorIdxConstant(0, DL));
+  }
+
+  // Only truncated types of 1/2 or 1/4 the input element size are supported
+  if (ResultTrunc)
+    assert(OutScalarSize == UminOpScalarSize / 2 &&
+           "Invalid Truncation Type Size!");
+  return ResultTrunc;
+}
+
+// Try to match the pattern (truncate (BUILD_VECTOR(a[i], a[i+1], ..., x, x, ...)))
+static bool hasUselessTrunc(SDValue &N, SDValue &UsefullValue,
+                            unsigned FinalEleNum, uint64_t OldExtractIndex,
+                            int &NewExtractIndex) {
+  if (N.getOpcode() != ISD::TRUNCATE)
+    return false;
+
+  SDValue N0 = N.getOperand(0);
+  if (N0.getOpcode() != ISD::BUILD_VECTOR)
+    return false;
+
+  SmallVector<SDValue, 8> ExtractValues;
+  SmallVector<int, 8> Indices;
+  unsigned ElementNum = N.getValueType().getVectorNumElements();
+
+  // For example, if v8i8 is bitcast to v2i32, then `TransformedSize` is 4
+  unsigned TransformedSize = 0;
+  if (FinalEleNum != 0)
+    TransformedSize = ElementNum / FinalEleNum;
+  for...
[truncated]
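
Likewise, a minimal scalar C++ sketch of the saturating narrow that matchTruncSatRounding targets (an illustration only; sqrshrun_scalar is a hypothetical name, and i32-to-u16 is just one of the supported widths). The umin(smax(v, 0), maxValue) in the matched DAG is the clamp below, and the add-then-shift is the rounding:

#include <algorithm>
#include <cassert>
#include <cstdint>

// Illustrative scalar model of sqrshrun narrowing an i32 lane to u16:
// rounding right shift, then clamp into [0, 0xFFFF] (unsigned saturation
// of a signed value).
static uint16_t sqrshrun_scalar(int32_t x, unsigned imm) {
  assert(imm >= 1 && imm <= 16 && "narrowing shift is limited to half width");
  int64_t rounded = (static_cast<int64_t>(x) + (1LL << (imm - 1))) >> imm;
  return static_cast<uint16_t>(std::clamp<int64_t>(rounded, 0, 0xFFFF));
}

int main() {
  assert(sqrshrun_scalar(-100, 4) == 0);           // negatives saturate to 0
  assert(sqrshrun_scalar(INT32_MAX, 4) == 0xFFFF); // overflow saturates high
  assert(sqrshrun_scalar(104, 4) == 7);            // otherwise a rounding shift
  return 0;
}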


github-actions bot commented Dec 4, 2023

✅ With the latest revision this PR passed the C/C++ code formatter.

…turation truncation

This patch performs two kinds of instruction simplification and adds a
pattern for RSHRN:
1. A rounding shift such as SHIFT(Add(OpA, 1<<(imm-1)), imm)
can be simplified to srshr(OpA, imm).
2. A rounding shift followed by saturating truncation, such as
Trunc(min(max(Shift(Add(OpA, 1<<(imm-1)), imm), 0), maxValue)), can be
simplified to uqrshrn(OpA, imm) or sqrshrun(OpA, imm).
3. Add a pattern for RSHRN.
These optimizations are performed in the backend after legalization.
; CHECK-NEXT: movi v2.2d, #0000000000000000
; CHECK-NEXT: raddhn v0.8b, v0.8h, v2.8h
; CHECK-NEXT: raddhn2 v0.16b, v1.8h, v2.8h
; CHECK-NEXT: urshr v1.8h, v1.8h, #8
Contributor


On the surface this looks worse than before. raddhn has a latency of 2, throughput of 4 on neoverse-v1, whereas urshr has a latency of 4 and throughput of 2. I think the original code would likely be faster. Not sure if there is an easy way of keeping the old version here?
