[AArch64][SVE2] Use rshrnb for masked stores
This patch is a follow-up to https://reviews.llvm.org/D155299.
This patch combines add+lsr to rshrnb when 'B' in:

  C = A + B
  D = C >> Shift

is equal to (1 << (Shift-1)), and the bits in the top half
of each vector element are zeroed or ignored, such as in a
truncating masked store.
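
For background (not part of the commit): the reason B == (1 << (Shift-1)) is the interesting case is that adding that constant before a logical shift right by Shift rounds the result to nearest rather than truncating, which is the per-lane behaviour the SVE2 rshrnb (rounding shift right narrow, bottom) instruction provides. A minimal scalar sketch of the matched pattern for a 16-bit lane narrowed to 8 bits, with illustrative names and values:

// rounded_narrow_shift is a hypothetical helper modelling one lane of the
// IR pattern (add + lshr + trunc); the add wraps in the narrow type, as in IR.
#include <cassert>
#include <cstdint>

static uint8_t rounded_narrow_shift(uint16_t A, unsigned Shift) {
  uint16_t B = (uint16_t)(1u << (Shift - 1)); // rounding constant
  uint16_t C = A + B;                         // C = A + B
  uint16_t D = C >> Shift;                    // D = C >> Shift
  return (uint8_t)D; // top half dropped by the truncating masked store
}

int main() {
  // 300 / 64 is roughly 4.69: a plain lshr by 6 truncates to 4,
  // while the rounded form gives 5.
  assert((300u >> 6) == 4);
  assert(rounded_narrow_shift(300, 6) == 5);
  return 0;
}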
MDevereau committed Oct 24, 2023
1 parent febf5c9 commit 512193b
Showing 2 changed files with 34 additions and 0 deletions.
15 changes: 15 additions & 0 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -21017,6 +21017,21 @@ static SDValue performMSTORECombine(SDNode *N,
    }
  }

  if (MST->isTruncatingStore()) {
    if (SDValue Rshrnb = trySimplifySrlAddToRshrnb(Value, DAG, Subtarget)) {
      EVT ValueVT = Value->getValueType(0);
      EVT MemVT = MST->getMemoryVT();
      if ((ValueVT == MVT::nxv8i16 && MemVT == MVT::nxv8i8) ||
          (ValueVT == MVT::nxv4i32 && MemVT == MVT::nxv4i16) ||
          (ValueVT == MVT::nxv2i64 && MemVT == MVT::nxv2i32)) {
        return DAG.getMaskedStore(
            MST->getChain(), DL, Rshrnb, MST->getBasePtr(), MST->getOffset(),
            MST->getMask(), MST->getMemoryVT(), MST->getMemOperand(),
            MST->getAddressingMode(), true);
      }
    }
  }

  return SDValue();
}

19 changes: 19 additions & 0 deletions llvm/test/CodeGen/AArch64/sve2-intrinsics-combine-rshrnb.ll
@@ -298,3 +298,22 @@ define void @neg_add_lshr_rshrnb_s(ptr %ptr, ptr %dst, i64 %index){
  store <vscale x 2 x i16> %3, ptr %4, align 1
  ret void
}

define void @masked_store_rshrnb(ptr %ptr, ptr %dst, i64 %index, <vscale x 8 x i1> %mask) { ; preds = %vector.body, %vector.ph
; CHECK-LABEL: masked_store_rshrnb:
; CHECK: // %bb.0:
; CHECK-NEXT: ld1h { z0.h }, p0/z, [x0]
; CHECK-NEXT: rshrnb z0.b, z0.h, #6
; CHECK-NEXT: st1b { z0.h }, p0, [x1, x2]
; CHECK-NEXT: ret
  %wide.masked.load = tail call <vscale x 8 x i16> @llvm.masked.load.nxv8i16.p0(ptr %ptr, i32 2, <vscale x 8 x i1> %mask, <vscale x 8 x i16> poison)
  %1 = add <vscale x 8 x i16> %wide.masked.load, trunc (<vscale x 8 x i32> shufflevector (<vscale x 8 x i32> insertelement (<vscale x 8 x i32> poison, i32 32, i64 0), <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer) to <vscale x 8 x i16>)
  %2 = lshr <vscale x 8 x i16> %1, trunc (<vscale x 8 x i32> shufflevector (<vscale x 8 x i32> insertelement (<vscale x 8 x i32> poison, i32 6, i64 0), <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer) to <vscale x 8 x i16>)
  %3 = trunc <vscale x 8 x i16> %2 to <vscale x 8 x i8>
  %4 = getelementptr inbounds i8, ptr %dst, i64 %index
  tail call void @llvm.masked.store.nxv8i8.p0(<vscale x 8 x i8> %3, ptr %4, i32 1, <vscale x 8 x i1> %mask)
  ret void
}

declare void @llvm.masked.store.nxv8i8.p0(<vscale x 8 x i8>, ptr, i32, <vscale x 8 x i1>)
declare <vscale x 8 x i16> @llvm.masked.load.nxv8i16.p0(ptr, i32, <vscale x 8 x i1>, <vscale x 8 x i16>)
