[AArch64][SVE2] Generate urshr rounding shift rights #78374

Merged: 5 commits into llvm:main, Jan 31, 2024

Conversation

@UsmanNadeem (Contributor) commented Jan 17, 2024

The matching code is taken from rshrnb, except that the immediate
shift value has a larger range and signed shifts are also supported. The rshrnb
code now checks for the new AArch64ISD::URSHR_I_PRED node when converting
to the narrowing rounding op.

Please let me know if there is a better way/location to do this!
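
For reference, this is the basic pattern being matched and the intended SVE2 lowering, lifted from the urshr_i64 test added in this PR (the asm in the comment is the SVE2 CHECK output shown in the diff below; treat this as an illustrative sketch, not extra test coverage):

define <vscale x 2 x i64> @urshr_i64(<vscale x 2 x i64> %x) {
  ; add the rounding constant 1 << (6 - 1) = 32, then logical shift right by 6
  %add = add <vscale x 2 x i64> %x, splat (i64 32)
  %sh = lshr <vscale x 2 x i64> %add, splat (i64 6)
  ; with +sve2 this now selects:  ptrue p0.d ; urshr z0.d, p0/m, z0.d, #6
  ret <vscale x 2 x i64> %sh
}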

@llvmbot (Collaborator) commented Jan 17, 2024

@llvm/pr-subscribers-backend-aarch64

Author: Usman Nadeem (UsmanNadeem)

Changes

The matching code is similar to that for rshrnb, except that the immediate
shift value has a larger range and signed shifts are also supported. rshrnb
now uses the new AArch64ISD node for uniform rounding.

Change-Id: Idbb811f318d33c7637371cf7bb00285d20e1771d


Full diff: https://github.com/llvm/llvm-project/pull/78374.diff

5 Files Affected:

  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+58-23)
  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.h (+2)
  • (modified) llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td (+6-4)
  • (modified) llvm/test/CodeGen/AArch64/sve2-intrinsics-combine-rshrnb.ll (+7-10)
  • (added) llvm/test/CodeGen/AArch64/sve2-rsh.ll (+203)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 91b36161ab46e89..d1731fcaabf8664 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2649,6 +2649,8 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(AArch64ISD::MSRR)
     MAKE_CASE(AArch64ISD::RSHRNB_I)
     MAKE_CASE(AArch64ISD::CTTZ_ELTS)
+    MAKE_CASE(AArch64ISD::SRSHR_I_PRED)
+    MAKE_CASE(AArch64ISD::URSHR_I_PRED)
   }
 #undef MAKE_CASE
   return nullptr;
@@ -2933,6 +2935,7 @@ static SDValue convertToScalableVector(SelectionDAG &DAG, EVT VT, SDValue V);
 static SDValue convertFromScalableVector(SelectionDAG &DAG, EVT VT, SDValue V);
 static SDValue convertFixedMaskToScalableVector(SDValue Mask,
                                                 SelectionDAG &DAG);
+static SDValue getPredicateForVector(SelectionDAG &DAG, SDLoc &DL, EVT VT);
 static SDValue getPredicateForScalableVector(SelectionDAG &DAG, SDLoc &DL,
                                              EVT VT);
 
@@ -13713,6 +13716,42 @@ SDValue AArch64TargetLowering::LowerTRUNCATE(SDValue Op,
   return SDValue();
 }
 
+static SDValue tryLowerToRoundingShiftRightByImm(SDValue Shift,
+                                                 SelectionDAG &DAG) {
+  if (Shift->getOpcode() != ISD::SRL && Shift->getOpcode() != ISD::SRA)
+    return SDValue();
+
+  EVT ResVT = Shift.getValueType();
+  assert(ResVT.isScalableVT());
+
+  auto ShiftOp1 =
+      dyn_cast_or_null<ConstantSDNode>(DAG.getSplatValue(Shift->getOperand(1)));
+  if (!ShiftOp1)
+    return SDValue();
+  unsigned ShiftValue = ShiftOp1->getZExtValue();
+
+  if (ShiftValue < 1 || ShiftValue > ResVT.getScalarSizeInBits())
+    return SDValue();
+
+  SDValue Add = Shift->getOperand(0);
+  if (Add->getOpcode() != ISD::ADD || !Add->hasOneUse())
+    return SDValue();
+  auto AddOp1 =
+      dyn_cast_or_null<ConstantSDNode>(DAG.getSplatValue(Add->getOperand(1)));
+  if (!AddOp1)
+    return SDValue();
+  uint64_t AddValue = AddOp1->getZExtValue();
+  if (AddValue != 1ULL << (ShiftValue - 1))
+    return SDValue();
+
+  SDLoc DL(Shift);
+  unsigned Opc = Shift->getOpcode() == ISD::SRA ? AArch64ISD::SRSHR_I_PRED
+                                                : AArch64ISD::URSHR_I_PRED;
+  return DAG.getNode(Opc, DL, ResVT, getPredicateForVector(DAG, DL, ResVT),
+                     Add->getOperand(0),
+                     DAG.getTargetConstant(ShiftValue, DL, MVT::i32));
+}
+
 SDValue AArch64TargetLowering::LowerVectorSRA_SRL_SHL(SDValue Op,
                                                       SelectionDAG &DAG) const {
   EVT VT = Op.getValueType();
@@ -13738,6 +13777,10 @@ SDValue AArch64TargetLowering::LowerVectorSRA_SRL_SHL(SDValue Op,
                        Op.getOperand(0), Op.getOperand(1));
   case ISD::SRA:
   case ISD::SRL:
+    if (VT.isScalableVector() && Subtarget->hasSVE2orSME())
+      if (SDValue RSH = tryLowerToRoundingShiftRightByImm(Op, DAG))
+        return RSH;
+
     if (VT.isScalableVector() ||
         useSVEForFixedLengthVectorVT(VT, !Subtarget->isNeonAvailable())) {
       unsigned Opc = Op.getOpcode() == ISD::SRA ? AArch64ISD::SRA_PRED
@@ -20025,6 +20068,12 @@ static SDValue performIntrinsicCombine(SDNode *N,
   case Intrinsic::aarch64_sve_uqsub_x:
     return DAG.getNode(ISD::USUBSAT, SDLoc(N), N->getValueType(0),
                        N->getOperand(1), N->getOperand(2));
+  case Intrinsic::aarch64_sve_srshr:
+    return DAG.getNode(AArch64ISD::SRSHR_I_PRED, SDLoc(N), N->getValueType(0),
+                       N->getOperand(1), N->getOperand(2), N->getOperand(3));
+  case Intrinsic::aarch64_sve_urshr:
+    return DAG.getNode(AArch64ISD::URSHR_I_PRED, SDLoc(N), N->getValueType(0),
+                       N->getOperand(1), N->getOperand(2), N->getOperand(3));
   case Intrinsic::aarch64_sve_asrd:
     return DAG.getNode(AArch64ISD::SRAD_MERGE_OP1, SDLoc(N), N->getValueType(0),
                        N->getOperand(1), N->getOperand(2), N->getOperand(3));
@@ -20652,12 +20701,13 @@ static SDValue performUnpackCombine(SDNode *N, SelectionDAG &DAG,
 // a uzp1 or a truncating store.
 static SDValue trySimplifySrlAddToRshrnb(SDValue Srl, SelectionDAG &DAG,
                                          const AArch64Subtarget *Subtarget) {
-  EVT VT = Srl->getValueType(0);
+  if (Srl->getOpcode() != AArch64ISD::URSHR_I_PRED)
+    return SDValue();
 
-  if (!VT.isScalableVector() || !Subtarget->hasSVE2() ||
-      Srl->getOpcode() != ISD::SRL)
+  if (!isAllActivePredicate(DAG, Srl.getOperand(0)))
     return SDValue();
 
+  EVT VT = Srl->getValueType(0);
   EVT ResVT;
   if (VT == MVT::nxv8i16)
     ResVT = MVT::nxv16i8;
@@ -20668,29 +20718,14 @@ static SDValue trySimplifySrlAddToRshrnb(SDValue Srl, SelectionDAG &DAG,
   else
     return SDValue();
 
-  auto SrlOp1 =
-      dyn_cast_or_null<ConstantSDNode>(DAG.getSplatValue(Srl->getOperand(1)));
-  if (!SrlOp1)
-    return SDValue();
-  unsigned ShiftValue = SrlOp1->getZExtValue();
-  if (ShiftValue < 1 || ShiftValue > ResVT.getScalarSizeInBits())
-    return SDValue();
-
-  SDValue Add = Srl->getOperand(0);
-  if (Add->getOpcode() != ISD::ADD || !Add->hasOneUse())
-    return SDValue();
-  auto AddOp1 =
-      dyn_cast_or_null<ConstantSDNode>(DAG.getSplatValue(Add->getOperand(1)));
-  if (!AddOp1)
-    return SDValue();
-  uint64_t AddValue = AddOp1->getZExtValue();
-  if (AddValue != 1ULL << (ShiftValue - 1))
+  unsigned ShiftValue =
+      cast<ConstantSDNode>(Srl->getOperand(2))->getZExtValue();
+  if (ShiftValue > ResVT.getScalarSizeInBits())
     return SDValue();
 
   SDLoc DL(Srl);
-  SDValue Rshrnb = DAG.getNode(
-      AArch64ISD::RSHRNB_I, DL, ResVT,
-      {Add->getOperand(0), DAG.getTargetConstant(ShiftValue, DL, MVT::i32)});
+  SDValue Rshrnb = DAG.getNode(AArch64ISD::RSHRNB_I, DL, ResVT,
+                               {Srl->getOperand(1), Srl->getOperand(2)});
   return DAG.getNode(ISD::BITCAST, DL, VT, Rshrnb);
 }
 
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 6ddbcd41dcb7696..e1ecd3f4e36be03 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -210,7 +210,9 @@ enum NodeType : unsigned {
   UQSHL_I,
   SQSHLU_I,
   SRSHR_I,
+  SRSHR_I_PRED,
   URSHR_I,
+  URSHR_I_PRED,
 
   // Vector narrowing shift by immediate (bottom)
   RSHRNB_I,
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index c4d69232c9e30ea..516ab36464379dd 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -232,6 +232,8 @@ def SDT_AArch64Arith_Imm : SDTypeProfile<1, 3, [
 ]>;
 
 def AArch64asrd_m1 : SDNode<"AArch64ISD::SRAD_MERGE_OP1", SDT_AArch64Arith_Imm>;
+def AArch64urshri_p : SDNode<"AArch64ISD::URSHR_I_PRED", SDT_AArch64Arith_Imm>;
+def AArch64srshri_p : SDNode<"AArch64ISD::SRSHR_I_PRED", SDT_AArch64Arith_Imm>;
 
 def SDT_AArch64IntExtend : SDTypeProfile<1, 4, [
   SDTCisVec<0>, SDTCisVec<1>, SDTCisVec<2>, SDTCisVT<3, OtherVT>, SDTCisVec<4>,
@@ -3538,8 +3540,8 @@ let Predicates = [HasSVE2orSME] in {
   // SVE2 predicated shifts
   defm SQSHL_ZPmI  : sve_int_bin_pred_shift_imm_left_dup<0b0110, "sqshl",  "SQSHL_ZPZI",  int_aarch64_sve_sqshl>;
   defm UQSHL_ZPmI  : sve_int_bin_pred_shift_imm_left_dup<0b0111, "uqshl",  "UQSHL_ZPZI",  int_aarch64_sve_uqshl>;
-  defm SRSHR_ZPmI  : sve_int_bin_pred_shift_imm_right<   0b1100, "srshr",  "SRSHR_ZPZI",  int_aarch64_sve_srshr>;
-  defm URSHR_ZPmI  : sve_int_bin_pred_shift_imm_right<   0b1101, "urshr",  "URSHR_ZPZI",  int_aarch64_sve_urshr>;
+  defm SRSHR_ZPmI  : sve_int_bin_pred_shift_imm_right<   0b1100, "srshr",  "SRSHR_ZPZI",  AArch64srshri_p>;
+  defm URSHR_ZPmI  : sve_int_bin_pred_shift_imm_right<   0b1101, "urshr",  "URSHR_ZPZI",  AArch64urshri_p>;
   defm SQSHLU_ZPmI : sve_int_bin_pred_shift_imm_left<    0b1111, "sqshlu", "SQSHLU_ZPZI", int_aarch64_sve_sqshlu>;
 
   // SVE2 integer add/subtract long
@@ -3583,8 +3585,8 @@ let Predicates = [HasSVE2orSME] in {
   // SVE2 bitwise shift right and accumulate
   defm SSRA_ZZI  : sve2_int_bin_accum_shift_imm_right<0b00, "ssra",  AArch64ssra>;
   defm USRA_ZZI  : sve2_int_bin_accum_shift_imm_right<0b01, "usra",  AArch64usra>;
-  defm SRSRA_ZZI : sve2_int_bin_accum_shift_imm_right<0b10, "srsra", int_aarch64_sve_srsra, int_aarch64_sve_srshr>;
-  defm URSRA_ZZI : sve2_int_bin_accum_shift_imm_right<0b11, "ursra", int_aarch64_sve_ursra, int_aarch64_sve_urshr>;
+  defm SRSRA_ZZI : sve2_int_bin_accum_shift_imm_right<0b10, "srsra", int_aarch64_sve_srsra, AArch64srshri_p>;
+  defm URSRA_ZZI : sve2_int_bin_accum_shift_imm_right<0b11, "ursra", int_aarch64_sve_ursra, AArch64urshri_p>;
 
   // SVE2 complex integer add
   defm CADD_ZZI   : sve2_int_cadd<0b0, "cadd",   int_aarch64_sve_cadd_x>;
diff --git a/llvm/test/CodeGen/AArch64/sve2-intrinsics-combine-rshrnb.ll b/llvm/test/CodeGen/AArch64/sve2-intrinsics-combine-rshrnb.ll
index 0afd11d098a0009..58ef846a3172381 100644
--- a/llvm/test/CodeGen/AArch64/sve2-intrinsics-combine-rshrnb.ll
+++ b/llvm/test/CodeGen/AArch64/sve2-intrinsics-combine-rshrnb.ll
@@ -184,16 +184,14 @@ define void @wide_add_shift_add_rshrnb_d(ptr %dest, i64 %index, <vscale x 4 x i6
 define void @neg_wide_add_shift_add_rshrnb_d(ptr %dest, i64 %index, <vscale x 4 x i64> %arg1){
 ; CHECK-LABEL: neg_wide_add_shift_add_rshrnb_d:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    mov z2.d, #0x800000000000
-; CHECK-NEXT:    ptrue p0.s
-; CHECK-NEXT:    add z0.d, z0.d, z2.d
-; CHECK-NEXT:    add z1.d, z1.d, z2.d
-; CHECK-NEXT:    lsr z1.d, z1.d, #48
-; CHECK-NEXT:    lsr z0.d, z0.d, #48
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    ptrue p1.s
+; CHECK-NEXT:    urshr z1.d, p0/m, z1.d, #48
+; CHECK-NEXT:    urshr z0.d, p0/m, z0.d, #48
 ; CHECK-NEXT:    uzp1 z0.s, z0.s, z1.s
-; CHECK-NEXT:    ld1w { z1.s }, p0/z, [x0, x1, lsl #2]
+; CHECK-NEXT:    ld1w { z1.s }, p1/z, [x0, x1, lsl #2]
 ; CHECK-NEXT:    add z0.s, z1.s, z0.s
-; CHECK-NEXT:    st1w { z0.s }, p0, [x0, x1, lsl #2]
+; CHECK-NEXT:    st1w { z0.s }, p1, [x0, x1, lsl #2]
 ; CHECK-NEXT:    ret
   %1 = add <vscale x 4 x i64> %arg1, shufflevector (<vscale x 4 x i64> insertelement (<vscale x 4 x i64> poison, i64 140737488355328, i64 0), <vscale x 4 x i64> poison, <vscale x 4 x i32> zeroinitializer)
   %2 = lshr <vscale x 4 x i64> %1, shufflevector (<vscale x 4 x i64> insertelement (<vscale x 4 x i64> poison, i64 48, i64 0), <vscale x 4 x i64> poison, <vscale x 4 x i32> zeroinitializer)
@@ -286,8 +284,7 @@ define void @neg_add_lshr_rshrnb_s(ptr %ptr, ptr %dst, i64 %index){
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    ptrue p0.d
 ; CHECK-NEXT:    ld1d { z0.d }, p0/z, [x0]
-; CHECK-NEXT:    add z0.d, z0.d, #32 // =0x20
-; CHECK-NEXT:    lsr z0.d, z0.d, #6
+; CHECK-NEXT:    urshr z0.d, p0/m, z0.d, #6
 ; CHECK-NEXT:    st1h { z0.d }, p0, [x1, x2, lsl #1]
 ; CHECK-NEXT:    ret
   %load = load <vscale x 2 x i64>, ptr %ptr, align 2
diff --git a/llvm/test/CodeGen/AArch64/sve2-rsh.ll b/llvm/test/CodeGen/AArch64/sve2-rsh.ll
new file mode 100644
index 000000000000000..2bdfc1931cdc2f3
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve2-rsh.ll
@@ -0,0 +1,203 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
+; RUN: llc -mtriple=aarch64 -mattr=+sve < %s -o - | FileCheck --check-prefixes=CHECK,SVE %s
+; RUN: llc -mtriple=aarch64 -mattr=+sve2 < %s -o - | FileCheck --check-prefixes=CHECK,SVE2 %s
+
+; Wrong add/shift amount. Should be 32 for shift of 6.
+define <vscale x 2 x i64> @neg_urshr_1(<vscale x 2 x i64> %x) {
+; CHECK-LABEL: neg_urshr_1:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    add z0.d, z0.d, #16 // =0x10
+; CHECK-NEXT:    lsr z0.d, z0.d, #6
+; CHECK-NEXT:    ret
+  %add = add <vscale x 2 x i64> %x, splat (i64 16)
+  %sh = lshr <vscale x 2 x i64> %add, splat (i64 6)
+  ret <vscale x 2 x i64> %sh
+}
+
+; Vector Shift.
+define <vscale x 2 x i64> @neg_urshr_2(<vscale x 2 x i64> %x, <vscale x 2 x i64> %y) {
+; CHECK-LABEL: neg_urshr_2:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    add z0.d, z0.d, #32 // =0x20
+; CHECK-NEXT:    lsr z0.d, p0/m, z0.d, z1.d
+; CHECK-NEXT:    ret
+  %add = add <vscale x 2 x i64> %x, splat (i64 32)
+  %sh = lshr <vscale x 2 x i64> %add, %y
+  ret <vscale x 2 x i64> %sh
+}
+
+; Vector Add.
+define <vscale x 2 x i64> @neg_urshr_3(<vscale x 2 x i64> %x, <vscale x 2 x i64> %y) {
+; CHECK-LABEL: neg_urshr_3:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    add z0.d, z0.d, z1.d
+; CHECK-NEXT:    lsr z0.d, z0.d, #6
+; CHECK-NEXT:    ret
+  %add = add <vscale x 2 x i64> %x, %y
+  %sh = lshr <vscale x 2 x i64> %add, splat (i64 6)
+  ret <vscale x 2 x i64> %sh
+}
+
+; Add has two uses.
+define <vscale x 2 x i64> @neg_urshr_4(<vscale x 2 x i64> %x) {
+; CHECK-LABEL: neg_urshr_4:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    stp x29, x30, [sp, #-16]! // 16-byte Folded Spill
+; CHECK-NEXT:    addvl sp, sp, #-1
+; CHECK-NEXT:    str z8, [sp] // 16-byte Folded Spill
+; CHECK-NEXT:    .cfi_escape 0x0f, 0x0c, 0x8f, 0x00, 0x11, 0x10, 0x22, 0x11, 0x08, 0x92, 0x2e, 0x00, 0x1e, 0x22 // sp + 16 + 8 * VG
+; CHECK-NEXT:    .cfi_offset w30, -8
+; CHECK-NEXT:    .cfi_offset w29, -16
+; CHECK-NEXT:    .cfi_escape 0x10, 0x48, 0x0a, 0x11, 0x70, 0x22, 0x11, 0x78, 0x92, 0x2e, 0x00, 0x1e, 0x22 // $d8 @ cfa - 16 - 8 * VG
+; CHECK-NEXT:    add z0.d, z0.d, #32 // =0x20
+; CHECK-NEXT:    lsr z8.d, z0.d, #6
+; CHECK-NEXT:    bl use
+; CHECK-NEXT:    mov z0.d, z8.d
+; CHECK-NEXT:    ldr z8, [sp] // 16-byte Folded Reload
+; CHECK-NEXT:    addvl sp, sp, #1
+; CHECK-NEXT:    ldp x29, x30, [sp], #16 // 16-byte Folded Reload
+; CHECK-NEXT:    ret
+  %add = add <vscale x 2 x i64> %x, splat (i64 32)
+  %sh = lshr <vscale x 2 x i64> %add, splat (i64 6)
+  call void @use(<vscale x 2 x i64> %add)
+  ret <vscale x 2 x i64> %sh
+}
+
+define <vscale x 16 x i8> @urshr_i8(<vscale x 16 x i8> %x) {
+; SVE-LABEL: urshr_i8:
+; SVE:       // %bb.0:
+; SVE-NEXT:    add z0.b, z0.b, #32 // =0x20
+; SVE-NEXT:    lsr z0.b, z0.b, #6
+; SVE-NEXT:    ret
+;
+; SVE2-LABEL: urshr_i8:
+; SVE2:       // %bb.0:
+; SVE2-NEXT:    ptrue p0.b
+; SVE2-NEXT:    urshr z0.b, p0/m, z0.b, #6
+; SVE2-NEXT:    ret
+  %add = add <vscale x 16 x i8> %x, splat (i8 32)
+  %sh = lshr <vscale x 16 x i8> %add, splat (i8 6)
+  ret <vscale x 16 x i8> %sh
+}
+
+define <vscale x 8 x i16> @urshr_i16(<vscale x 8 x i16> %x) {
+; SVE-LABEL: urshr_i16:
+; SVE:       // %bb.0:
+; SVE-NEXT:    add z0.h, z0.h, #32 // =0x20
+; SVE-NEXT:    lsr z0.h, z0.h, #6
+; SVE-NEXT:    ret
+;
+; SVE2-LABEL: urshr_i16:
+; SVE2:       // %bb.0:
+; SVE2-NEXT:    ptrue p0.h
+; SVE2-NEXT:    urshr z0.h, p0/m, z0.h, #6
+; SVE2-NEXT:    ret
+  %add = add <vscale x 8 x i16> %x, splat (i16 32)
+  %sh = lshr <vscale x 8 x i16> %add, splat (i16 6)
+  ret <vscale x 8 x i16> %sh
+}
+
+define <vscale x 4 x i32> @urshr_i32(<vscale x 4 x i32> %x) {
+; SVE-LABEL: urshr_i32:
+; SVE:       // %bb.0:
+; SVE-NEXT:    add z0.s, z0.s, #32 // =0x20
+; SVE-NEXT:    lsr z0.s, z0.s, #6
+; SVE-NEXT:    ret
+;
+; SVE2-LABEL: urshr_i32:
+; SVE2:       // %bb.0:
+; SVE2-NEXT:    ptrue p0.s
+; SVE2-NEXT:    urshr z0.s, p0/m, z0.s, #6
+; SVE2-NEXT:    ret
+  %add = add <vscale x 4 x i32> %x, splat (i32 32)
+  %sh = lshr <vscale x 4 x i32> %add, splat (i32 6)
+  ret <vscale x 4 x i32> %sh
+}
+
+define <vscale x 2 x i64> @urshr_i64(<vscale x 2 x i64> %x) {
+; SVE-LABEL: urshr_i64:
+; SVE:       // %bb.0:
+; SVE-NEXT:    add z0.d, z0.d, #32 // =0x20
+; SVE-NEXT:    lsr z0.d, z0.d, #6
+; SVE-NEXT:    ret
+;
+; SVE2-LABEL: urshr_i64:
+; SVE2:       // %bb.0:
+; SVE2-NEXT:    ptrue p0.d
+; SVE2-NEXT:    urshr z0.d, p0/m, z0.d, #6
+; SVE2-NEXT:    ret
+  %add = add <vscale x 2 x i64> %x, splat (i64 32)
+  %sh = lshr <vscale x 2 x i64> %add, splat (i64 6)
+  ret <vscale x 2 x i64> %sh
+}
+
+define <vscale x 16 x i8> @srshr_i8(<vscale x 16 x i8> %x) {
+; SVE-LABEL: srshr_i8:
+; SVE:       // %bb.0:
+; SVE-NEXT:    add z0.b, z0.b, #32 // =0x20
+; SVE-NEXT:    asr z0.b, z0.b, #6
+; SVE-NEXT:    ret
+;
+; SVE2-LABEL: srshr_i8:
+; SVE2:       // %bb.0:
+; SVE2-NEXT:    ptrue p0.b
+; SVE2-NEXT:    srshr z0.b, p0/m, z0.b, #6
+; SVE2-NEXT:    ret
+  %add = add <vscale x 16 x i8> %x, splat (i8 32)
+  %sh = ashr <vscale x 16 x i8> %add, splat (i8 6)
+  ret <vscale x 16 x i8> %sh
+}
+
+define <vscale x 8 x i16> @srshr_i16(<vscale x 8 x i16> %x) {
+; SVE-LABEL: srshr_i16:
+; SVE:       // %bb.0:
+; SVE-NEXT:    add z0.h, z0.h, #32 // =0x20
+; SVE-NEXT:    asr z0.h, z0.h, #6
+; SVE-NEXT:    ret
+;
+; SVE2-LABEL: srshr_i16:
+; SVE2:       // %bb.0:
+; SVE2-NEXT:    ptrue p0.h
+; SVE2-NEXT:    srshr z0.h, p0/m, z0.h, #6
+; SVE2-NEXT:    ret
+  %add = add <vscale x 8 x i16> %x, splat (i16 32)
+  %sh = ashr <vscale x 8 x i16> %add, splat (i16 6)
+  ret <vscale x 8 x i16> %sh
+}
+
+define <vscale x 4 x i32> @srshr_i32(<vscale x 4 x i32> %x) {
+; SVE-LABEL: srshr_i32:
+; SVE:       // %bb.0:
+; SVE-NEXT:    add z0.s, z0.s, #32 // =0x20
+; SVE-NEXT:    asr z0.s, z0.s, #6
+; SVE-NEXT:    ret
+;
+; SVE2-LABEL: srshr_i32:
+; SVE2:       // %bb.0:
+; SVE2-NEXT:    ptrue p0.s
+; SVE2-NEXT:    srshr z0.s, p0/m, z0.s, #6
+; SVE2-NEXT:    ret
+  %add = add <vscale x 4 x i32> %x, splat (i32 32)
+  %sh = ashr <vscale x 4 x i32> %add, splat (i32 6)
+  ret <vscale x 4 x i32> %sh
+}
+
+define <vscale x 2 x i64> @srshr_i64(<vscale x 2 x i64> %x) {
+; SVE-LABEL: srshr_i64:
+; SVE:       // %bb.0:
+; SVE-NEXT:    add z0.d, z0.d, #32 // =0x20
+; SVE-NEXT:    asr z0.d, z0.d, #6
+; SVE-NEXT:    ret
+;
+; SVE2-LABEL: srshr_i64:
+; SVE2:       // %bb.0:
+; SVE2-NEXT:    ptrue p0.d
+; SVE2-NEXT:    srshr z0.d, p0/m, z0.d, #6
+; SVE2-NEXT:    ret
+  %add = add <vscale x 2 x i64> %x, splat (i64 32)
+  %sh = ashr <vscale x 2 x i64> %add, splat (i64 6)
+  ret <vscale x 2 x i64> %sh
+}
+
+declare void @use(<vscale x 2 x i64>)

@UsmanNadeem UsmanNadeem changed the title [AArch64][SVE2] Generate signed/unsigned rounding shift rights [AArch64][SVE2] Generate urshr/srshr rounding shift rights Jan 20, 2024
@davemgreen (Collaborator) left a comment

Hello. Some of these instructions have some awkward edge cases. Is this correct if the add overflows?

@UsmanNadeem (Contributor, Author) commented Jan 22, 2024

Hello. Some of these instructions have some awkward edge cases. Is this correct if the add overflows?

Yes, the result will be correct, meaning that if we operate on the max value for the respective data type, the result will not be affected by an overflow.
E.g. for uint8:
255 r/ 4 == (255 + (1 << (2 - 1))) >> 2 == urshr z0.b, p0/m, z0.b, #2, and the result will be 64.

@davemgreen (Collaborator)

I agree that is what urshr will produce but not the shr+add. The add would overflow, producing zero, so the shifted value is zero.
I think it needs to be something like trunc(shr(add(ext))) or maybe we check that at least one of the top bits is known to not be 1 (so the add doesn't overflow).
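
A minimal illustration of the overflow case being discussed (a hypothetical snippet, not part of the patch or its tests; the wrap is spelled out in the comments):

; For an i8 lane holding 255 with shift #2: the rounding add wraps,
; 255 + 2 == 1 (mod 256), so %sh is 0, whereas urshr z0.b, p0/m, z0.b, #2
; rounds the true value and produces round(255/4) == 64.
define <vscale x 16 x i8> @overflow_sketch(<vscale x 16 x i8> %x) {
  %add = add <vscale x 16 x i8> %x, splat (i8 2)   ; no nuw: may wrap
  %sh = lshr <vscale x 16 x i8> %add, splat (i8 2)
  ret <vscale x 16 x i8> %sh
}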

@UsmanNadeem (Contributor, Author)

I agree that is what urshr will produce but not the shr+add. The add would overflow, producing zero, so the shifted value is zero. I think it needs to be something like trunc(shr(add(ext))) or maybe we check that at least one of the top bits is known to not be 1 (so the add doesn't overflow).

Right, I guess in this case the existing code for rshrnb also has this bug. I'll work on it.

The matching code is similar to that for rshrnb, except that the immediate
shift value has a larger range and signed shifts are also supported. rshrnb
now uses the new AArch64ISD node for uniform rounding.

Change-Id: Idbb811f318d33c7637371cf7bb00285d20e1771d
Change-Id: I7450629fa43bb3ac1bc40daaa760255eed483c10
@UsmanNadeem (Contributor, Author)

  • Handled overflows.
  • Added the following combine for wide shifts: uzp1(urshr(uunpklo(X),C), urshr(uunpkhi(X), C)) -> urshr(X, C)
  • Removed code for signed RSHR for now.

@UsmanNadeem UsmanNadeem changed the title [AArch64][SVE2] Generate urshr/srshr rounding shift rights [AArch64][SVE2] Generate urshr rounding shift rights Jan 27, 2024
@davemgreen (Collaborator) left a comment

Using nuw sounds like a smart idea, and should handle cases like https://godbolt.org/z/rf1oaMvd4 as the nuw gets automatically added. (It does make some of the tools I have for trying to analyze equivalence less useful, but as far as I can tell it seems OK.) There might be an extension too where it checks whether the demanded bits would make the overflow unimportant, but I'm not sure where to make that happen.

Change-Id: Id6dceead02c7473ed5c3635c2b56c7f367315563
@UsmanNadeem (Contributor, Author)

  • canLowerSRLToRoundingShiftForVT returns true if the add has NUW or if the number of bits used in the return value allows us to not care about the overflow (tested by the rshrnb test cases).

  • srl(add(X, 1 << (ShiftValue - 1)), ShiftValue) is transformed to urshr, or to rshrnb if the result is truncated (see the IR sketch after this list).

  • uzp1(rshrnb(uunpklo(X),C), rshrnb(uunpkhi(X), C)) is converted to urshr(X, C), tested by the wide_trunc tests.
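
A minimal IR sketch of the NUW case described above (hypothetical function name, not one of the tests in the PR): the nuw flag guarantees the rounding add cannot wrap, so the add+lshr pair can be folded to urshr without relying on the demanded-bits reasoning.

define <vscale x 2 x i64> @urshr_i64_nuw_sketch(<vscale x 2 x i64> %x) {
  ; nuw: %x + 32 cannot overflow, so folding to urshr #6 preserves the result
  %add = add nuw <vscale x 2 x i64> %x, splat (i64 32)
  %sh = lshr <vscale x 2 x i64> %add, splat (i64 6)
  ret <vscale x 2 x i64> %sh
}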

  if (SDValue Rshrnb = trySimplifySrlAddToRshrnb(Op0, DAG, Subtarget))
    return DAG.getNode(AArch64ISD::UZP1, DL, ResVT, Rshrnb, Op1);

  if (SDValue Rshrnb = trySimplifySrlAddToRshrnb(Op1, DAG, Subtarget))
    return DAG.getNode(AArch64ISD::UZP1, DL, ResVT, Op0, Rshrnb);

  // uzp1(bitcast(x), bitcast(y)) -> uzp1(x, y)
  if (isHalvingTruncateAndConcatOfLegalIntScalableType(N) &&
      Op0.getOpcode() == ISD::BITCAST && Op1.getOpcode() == ISD::BITCAST) {
Review comment from a Collaborator:

It is not obvious to me why it is valid to remove the BITCAST in all cases. Is it because the instruction is entirely defined by the output type, and so the input types do not matter? We can just remove the bitcasts, and doing so leads to simpler code?

Under big endian a BITCAST will actually swap the order of certain lanes (they are defined in terms of storing in one type and reloading in another, so are lowered to a REV). BE isn't supported for SVE yet for some reason, but we should limit this to LE.

@UsmanNadeem (Contributor, Author) commented Jan 31, 2024

This is the transform:

     nxv4i32 = AArch64ISD::UZP1 bitcast(nxv4i32 to nxv2i64), bitcast(nxv4i32 to nxv2i64)
i.e. nxv4i32 = AArch64ISD::UZP1 nxv2i64 ..., nxv2i64 ...
     =>
     nxv4i32 = AArch64ISD::UZP1 nxv4i32 x, nxv4i32 y

Both get lowered to uzp1.s and removing the bitcast makes the code simpler.

An example here: https://godbolt.org/z/b7hsqc1Ev

I limited it to little endian.

(Outdated review comment on llvm/test/CodeGen/AArch64/sve2-rsh.ll — resolved)
Change-Id: I076f19c947696100ec469c8407b6d235d6444145
Change-Id: Idaed857e4978185a74488651a5cc216c8b5eddd3
@davemgreen (Collaborator) left a comment

Thanks. LGTM

@UsmanNadeem UsmanNadeem merged commit 1d14323 into llvm:main Jan 31, 2024
3 of 4 checks passed
Comment on lines +20251 to +20253
  case Intrinsic::aarch64_sve_urshr:
    return DAG.getNode(AArch64ISD::URSHR_I_PRED, SDLoc(N), N->getValueType(0),
                       N->getOperand(1), N->getOperand(2), N->getOperand(3));
@paulwalker-arm (Collaborator) commented Feb 1, 2024

This doesn't look sound to me. Either the naming of the AArch64ISD node is wrong or there needs to be a PatFrags that contains this node and the intrinsic. I say this because the _PRED nodes have no requirement when it comes to the result of inactive lanes, whereas the aarch64_sve_urshr intrinsic has a very specific requirement.

@UsmanNadeem (Contributor, Author)

For naming, I looked at a few other instructions that have similar behavior to urshr, i.e. the inactive elements in the destination vector remain unmodified, and they were also named _PRED.

I am quite new to the isel backend. Can you please explain what difference having a PatFrag would make compared to the code above?

@paulwalker-arm (Collaborator) commented Feb 6, 2024

Can you point to an example of where a _PRED node expects the results of inactive lanes to take a known value? Because that really shouldn't be the case (there's a comment at the top of AArch64SVEInstrInfo.td and AArch64ISelLowering.h that details the naming strategy). The intent of the _PRED nodes is to allow predication to be represented at the DAG level rather than waiting until instruction selection. They have no requirement for the results of inactive lanes, which frees up instruction selection to make the best use of unpredicated and/or reversed instructions.

The naming is important because people will assume the documented rules when implementing DAG combines or making changes to instruction selection, and thus if the rules are not followed it is very likely to introduce bugs. If it's important for the ISD node to model the results of the inactive lanes in accordance with the underlying SVE instruction then it should be named as such (e.g. URSHR_I_MERGE_OP1).

This is generally not the case: typically at the ISD level the result of inactive lanes is not important (often because an all-active predicate is passed in), and thus the _PRED suffix is used. When this is the case we still want to minimise the number of ISel patterns, and so a PatFrags is created to match both the ISD node and the intrinsic to the same instruction (e.g. AArch64mla_m1).

@UsmanNadeem (Contributor, Author) commented Feb 9, 2024

Thanks, I get your point now. I will post a follow-up fix.

carlosgalvezp pushed a commit to carlosgalvezp/llvm-project that referenced this pull request Feb 1, 2024
Add a new node `AArch64ISD::URSHR_I_PRED`.

`srl(add(X, 1 << (ShiftValue - 1)), ShiftValue)` is transformed to
`urshr`, or to `rshrnb` (as before) if the result is truncated.

`uzp1(rshrnb(uunpklo(X),C), rshrnb(uunpkhi(X), C))` is converted to
`urshr(X, C)` (tested by the wide_trunc tests).

Pattern matching code in `canLowerSRLToRoundingShiftForVT` is taken
from prior code in rshrnb. It returns true if the add has NUW or if the
number of bits used in the return value allows us to not care about the
overflow (tested by the rshrnb test cases).
UsmanNadeem added a commit to UsmanNadeem/llvm-project that referenced this pull request Feb 9, 2024
Follow-up for llvm#78374

Change-Id: Ib39b60725f508343fd7fc0f9160f0cf8ad8d7f7f
UsmanNadeem added a commit that referenced this pull request Feb 12, 2024