
[AArch64][SVE2] Use rshrnb for masked stores #70026

Merged: 2 commits merged into llvm:main on Oct 26, 2023

Conversation

MDevereau (Contributor)

This patch is a follow-up to https://reviews.llvm.org/D155299. It combines add+lsr into rshrnb when 'B' in:

C = A + B
D = C >> Shift

is equal to (1 << (Shift - 1)), and the bits in the top half of each vector element are zeroed or ignored, such as in a truncating masked store.
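
To make the transform concrete, here is a small, purely illustrative scalar C++ sketch (not part of the patch; roundingShiftRight is a made-up helper name): adding 1 << (Shift - 1) before a logical shift right rounds the result to nearest instead of truncating, which is the per-element behaviour the SVE2 rshrnb instruction provides once the value is narrowed to the low half of each lane.

#include <cassert>
#include <cstdint>

// Rounding logical shift right on a 16-bit lane value, i.e. the
// (A + (1 << (Shift - 1))) >> Shift pattern matched by the combine.
static uint16_t roundingShiftRight(uint16_t A, unsigned Shift) {
  return static_cast<uint16_t>((A + (1u << (Shift - 1))) >> Shift);
}

int main() {
  // 100 / 64 is roughly 1.56: a plain lsr truncates to 1,
  // while the rounding form gives 2.
  assert((100u >> 6) == 1);
  assert(roundingShiftRight(100, 6) == 2);
  return 0;
}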

llvmbot (Collaborator) commented Oct 24, 2023

@llvm/pr-subscribers-backend-aarch64

Author: Matthew Devereau (MDevereau)

Changes

This patch is a follow-up to https://reviews.llvm.org/D155299. It combines add+lsr into rshrnb when 'B' in:

C = A + B
D = C >> Shift

is equal to (1 << (Shift - 1)), and the bits in the top half of each vector element are zeroed or ignored, such as in a truncating masked store.


Full diff: https://github.com/llvm/llvm-project/pull/70026.diff

2 Files Affected:

  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+13)
  • (modified) llvm/test/CodeGen/AArch64/sve2-intrinsics-combine-rshrnb.ll (+19)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index a16a102e472e709..09ab4ddacddf138 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -21017,6 +21017,19 @@ static SDValue performMSTORECombine(SDNode *N,
     }
   }
 
+  if (MST->isTruncatingStore()) {
+    if (SDValue Rshrnb = trySimplifySrlAddToRshrnb(Value, DAG, Subtarget)) {
+      EVT ValueVT = Value->getValueType(0);
+      EVT MemVT = MST->getMemoryVT();
+      if ((ValueVT == MVT::nxv8i16 && MemVT == MVT::nxv8i8) ||
+          (ValueVT == MVT::nxv4i32 && MemVT == MVT::nxv4i16) ||
+          (ValueVT == MVT::nxv2i64 && MemVT == MVT::nxv2i32)) {
+        return DAG.getMaskedStore(MST->getChain(), DL, Rshrnb, MST->getBasePtr(), MST->getOffset(), MST->getMask(),
+                                  MST->getMemoryVT(), MST->getMemOperand(), MST->getAddressingMode(), true);
+      }
+    }
+  }
+
   return SDValue();
 }
 
diff --git a/llvm/test/CodeGen/AArch64/sve2-intrinsics-combine-rshrnb.ll b/llvm/test/CodeGen/AArch64/sve2-intrinsics-combine-rshrnb.ll
index a913177623df9ec..0afd11d098a0009 100644
--- a/llvm/test/CodeGen/AArch64/sve2-intrinsics-combine-rshrnb.ll
+++ b/llvm/test/CodeGen/AArch64/sve2-intrinsics-combine-rshrnb.ll
@@ -298,3 +298,22 @@ define void @neg_add_lshr_rshrnb_s(ptr %ptr, ptr %dst, i64 %index){
   store <vscale x 2 x i16> %3, ptr %4, align 1
   ret void
 }
+
+define void @masked_store_rshrnb(ptr %ptr, ptr %dst, i64 %index, <vscale x 8 x i1> %mask) {
+; CHECK-LABEL: masked_store_rshrnb:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ld1h { z0.h }, p0/z, [x0]
+; CHECK-NEXT:    rshrnb z0.b, z0.h, #6
+; CHECK-NEXT:    st1b { z0.h }, p0, [x1, x2]
+; CHECK-NEXT:    ret
+  %wide.masked.load = tail call <vscale x 8 x i16> @llvm.masked.load.nxv8i16.p0(ptr %ptr, i32 2, <vscale x 8 x i1> %mask, <vscale x 8 x i16> poison)
+  %1 = add <vscale x 8 x i16> %wide.masked.load, trunc (<vscale x 8 x i32> shufflevector (<vscale x 8 x i32> insertelement (<vscale x 8 x i32> poison, i32 32, i64 0), <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer) to <vscale x 8 x i16>)
+  %2 = lshr <vscale x 8 x i16> %1, trunc (<vscale x 8 x i32> shufflevector (<vscale x 8 x i32> insertelement (<vscale x 8 x i32> poison, i32 6, i64 0), <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer) to <vscale x 8 x i16>)
+  %3 = trunc <vscale x 8 x i16> %2 to <vscale x 8 x i8>
+  %4 = getelementptr inbounds i8, ptr %dst, i64 %index
+  tail call void @llvm.masked.store.nxv8i8.p0(<vscale x 8 x i8> %3, ptr %4, i32 1, <vscale x 8 x i1> %mask)
+  ret void
+}
+
+declare void @llvm.masked.store.nxv8i8.p0(<vscale x 8 x i8>, ptr, i32, <vscale x 8 x i1>)
+declare <vscale x 8 x i16> @llvm.masked.load.nxv8i16.p0(ptr, i32, <vscale x 8 x i1>, <vscale x 8 x i16>)

github-actions bot commented Oct 24, 2023

✅ With the latest revision this PR passed the C/C++ code formatter.

@david-arm (Contributor) left a comment


Thanks for this @MDevereau - looks like another nice improvement! I just have one comment about potentially refactoring the code ...

  if (SDValue Rshrnb = trySimplifySrlAddToRshrnb(Value, DAG, Subtarget)) {
    EVT ValueVT = Value->getValueType(0);
    EVT MemVT = MST->getMemoryVT();
    if ((ValueVT == MVT::nxv8i16 && MemVT == MVT::nxv8i8) ||
Contributor


I think you have similar checks in performSTORECombine. It would be nice to combine them somehow into a helper function and reuse the logic, perhaps something like

bool isHalvingTruncateOfLegalScalableType(MVT SrcVT, MVT DstVT) {
  return (SrcVT == MVT::nxv8i16 && DstVT == MVT::nxv8i8) ||
      (SrcVT == MVT::nxv4i32 && DstVT == MVT::nxv4i16) ||
      (SrcVT == MVT::nxv2i64 && DstVT == MVT::nxv2i32);
}

Also, I think I'd prefer we do this check before calling trySimplifySrlAddToRshrnb to prevent us doing unnecessary extra work and potentially creating nodes that we throw away. I realise we also do this in performSTORECombine, but it would be great to re-order the checks in that function too!
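
For reference, here is a rough sketch (an assumption of how the refactor could look, not the committed code) of the masked-store path using a helper like the one suggested above, with the type check hoisted ahead of trySimplifySrlAddToRshrnb; the helper is written with EVT parameters to match the surrounding code, and the other names are taken from the diff above.

// Sketch only: helper suggested in the review, adjusted to take EVT.
static bool isHalvingTruncateOfLegalScalableType(EVT SrcVT, EVT DstVT) {
  return (SrcVT == MVT::nxv8i16 && DstVT == MVT::nxv8i8) ||
         (SrcVT == MVT::nxv4i32 && DstVT == MVT::nxv4i16) ||
         (SrcVT == MVT::nxv2i64 && DstVT == MVT::nxv2i32);
}

// ...inside performMSTORECombine: check legality first so that
// trySimplifySrlAddToRshrnb never creates nodes that get thrown away.
if (MST->isTruncatingStore()) {
  EVT ValueVT = Value->getValueType(0);
  EVT MemVT = MST->getMemoryVT();
  if (isHalvingTruncateOfLegalScalableType(ValueVT, MemVT))
    if (SDValue Rshrnb = trySimplifySrlAddToRshrnb(Value, DAG, Subtarget))
      return DAG.getMaskedStore(MST->getChain(), DL, Rshrnb, MST->getBasePtr(),
                                MST->getOffset(), MST->getMask(), MemVT,
                                MST->getMemOperand(), MST->getAddressingMode(),
                                /*IsTruncating=*/true);
}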

Contributor Author


Done!

@david-arm (Contributor) left a comment


LGTM! Thanks for the changes @MDevereau. :)

@MDevereau MDevereau merged commit 18775a4 into llvm:main Oct 26, 2023
3 checks passed
@MDevereau MDevereau deleted the rshrnb branch October 26, 2023 07:42
zahiraam pushed a commit to zahiraam/llvm-project that referenced this pull request Oct 26, 2023