[AArch64] Add support for efficient bitcast in vector truncate store.
Following the changes in D145301, we now also support the efficient bitcast
when storing a bool vector. Previously, such stores were expanded.

Differential Revision: https://reviews.llvm.org/D148316
lawben authored and davemgreen committed Apr 28, 2023
1 parent 84539a2 commit cd68e17
Showing 6 changed files with 413 additions and 37 deletions.
55 changes: 47 additions & 8 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -19775,20 +19775,25 @@ static EVT tryGetOriginalBoolVectorType(SDValue Op, int Depth = 0) {
 static SDValue vectorToScalarBitmask(SDNode *N, SelectionDAG &DAG) {
   SDLoc DL(N);
   SDValue ComparisonResult(N, 0);
-  EVT BoolVecVT = ComparisonResult.getValueType();
-  assert(BoolVecVT.isVector() && "Must be a vector type");
+  EVT VecVT = ComparisonResult.getValueType();
+  assert(VecVT.isVector() && "Must be a vector type");
 
-  unsigned NumElts = BoolVecVT.getVectorNumElements();
+  unsigned NumElts = VecVT.getVectorNumElements();
   if (NumElts != 2 && NumElts != 4 && NumElts != 8 && NumElts != 16)
     return SDValue();
 
+  if (VecVT.getVectorElementType() != MVT::i1 &&
+      !DAG.getTargetLoweringInfo().isTypeLegal(VecVT))
+    return SDValue();
+
   // If we can find the original types to work on instead of a vector of i1,
   // we can avoid extend/extract conversion instructions.
-  EVT VecVT = tryGetOriginalBoolVectorType(ComparisonResult);
-  if (!VecVT.isSimple()) {
-    unsigned BitsPerElement = std::max(64 / NumElts, 8u); // min. 64-bit vector
-    VecVT =
-        BoolVecVT.changeVectorElementType(MVT::getIntegerVT(BitsPerElement));
+  if (VecVT.getVectorElementType() == MVT::i1) {
+    VecVT = tryGetOriginalBoolVectorType(ComparisonResult);
+    if (!VecVT.isSimple()) {
+      unsigned BitsPerElement = std::max(64 / NumElts, 8u); // >= 64-bit vector
+      VecVT = MVT::getVectorVT(MVT::getIntegerVT(BitsPerElement), NumElts);
+    }
   }
   VecVT = VecVT.changeVectorElementTypeToInteger();
 
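For reference, the fallback path above picks the narrowest element width that keeps every lane at least a byte wide while still filling a 64-bit NEON register. A minimal scalar sketch of that arithmetic (helper name hypothetical, not part of the patch):

#include <algorithm>

// Fallback element width when no original bool vector type can be recovered:
//   NumElts = 2  -> 32-bit lanes (v2i32, 64-bit vector)
//   NumElts = 4  -> 16-bit lanes (v4i16, 64-bit vector)
//   NumElts = 8  ->  8-bit lanes (v8i8,  64-bit vector)
//   NumElts = 16 ->  8-bit lanes (v16i8, 128-bit vector)
unsigned bitsPerElement(unsigned NumElts) {
  return std::max(64 / NumElts, 8u); // never narrower than a byte
}

A lane count such as 6 never reaches this computation; the {2, 4, 8, 16} guard above makes the combine bail out, as the new no_combine_illegal_num_elements test below verifies.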
@@ -19849,6 +19854,37 @@ static SDValue vectorToScalarBitmask(SDNode *N, SelectionDAG &DAG) {
   return DAG.getNode(ISD::VECREDUCE_ADD, DL, ResultVT, RepresentativeBits);
 }
 
+static SDValue combineBoolVectorAndTruncateStore(SelectionDAG &DAG,
+                                                 StoreSDNode *Store) {
+  if (!Store->isTruncatingStore())
+    return SDValue();
+
+  SDLoc DL(Store);
+  SDValue VecOp = Store->getValue();
+  EVT VT = VecOp.getValueType();
+  EVT MemVT = Store->getMemoryVT();
+
+  if (!MemVT.isVector() || !VT.isVector() ||
+      MemVT.getVectorElementType() != MVT::i1)
+    return SDValue();
+
+  // If we are storing a vector that we are currently building, let
+  // `scalarizeVectorStore()` handle this more efficiently.
+  if (VecOp.getOpcode() == ISD::BUILD_VECTOR)
+    return SDValue();
+
+  VecOp = DAG.getNode(ISD::TRUNCATE, DL, MemVT, VecOp);
+  SDValue VectorBits = vectorToScalarBitmask(VecOp.getNode(), DAG);
+  if (!VectorBits)
+    return SDValue();
+
+  EVT StoreVT =
+      EVT::getIntegerVT(*DAG.getContext(), MemVT.getStoreSizeInBits());
+  SDValue ExtendedBits = DAG.getZExtOrTrunc(VectorBits, DL, StoreVT);
+  return DAG.getStore(Store->getChain(), DL, ExtendedBits, Store->getBasePtr(),
+                      Store->getMemOperand());
+}
+
 static SDValue performSTORECombine(SDNode *N,
                                    TargetLowering::DAGCombinerInfo &DCI,
                                    SelectionDAG &DAG,
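In effect, the new combine turns a truncating store of an i1 vector into one scalar store of MemVT.getStoreSizeInBits() bits, so a v4i1 memory type is written out as a whole i8. A scalar model of the packing (illustrative only; assumes the little-endian lane-to-bit layout of a bitcast from <8 x i1> to i8):

#include <cstdint>

// Pack eight boolean lanes into the byte that the combined store writes.
// Lane i maps to bit i, as in `bitcast <8 x i1> %v to i8` on little-endian.
uint8_t packBoolVector(const bool (&Lanes)[8]) {
  uint8_t Mask = 0;
  for (int I = 0; I < 8; ++I)
    Mask |= static_cast<uint8_t>(Lanes[I]) << I;
  return Mask;
}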
@@ -19887,6 +19923,9 @@ static SDValue performSTORECombine(SDNode *N,
   if (SDValue Store = foldTruncStoreOfExt(DAG, N))
     return Store;
 
+  if (SDValue Store = combineBoolVectorAndTruncateStore(DAG, ST))
+    return Store;
+
   return SDValue();
 }

15 changes: 13 additions & 2 deletions llvm/test/CodeGen/AArch64/setcc-type-mismatch.ll
@@ -1,9 +1,20 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
 ; RUN: llc -mtriple=aarch64-linux-gnu %s -o - | FileCheck %s
 
 define void @test_mismatched_setcc(<4 x i22> %l, <4 x i22> %r, ptr %addr) {
 ; CHECK-LABEL: test_mismatched_setcc:
-; CHECK: cmeq [[CMP128:v[0-9]+]].4s, {{v[0-9]+}}.4s, {{v[0-9]+}}.4s
-; CHECK: xtn {{v[0-9]+}}.4h, [[CMP128]].4s
+; CHECK: // %bb.0:
+; CHECK-NEXT: movi v2.4s, #63, msl #16
+; CHECK-NEXT: adrp x8, .LCPI0_0
+; CHECK-NEXT: ldr q3, [x8, :lo12:.LCPI0_0]
+; CHECK-NEXT: and v1.16b, v1.16b, v2.16b
+; CHECK-NEXT: and v0.16b, v0.16b, v2.16b
+; CHECK-NEXT: cmeq v0.4s, v0.4s, v1.4s
+; CHECK-NEXT: and v0.16b, v0.16b, v3.16b
+; CHECK-NEXT: addv s0, v0.4s
+; CHECK-NEXT: fmov w8, s0
+; CHECK-NEXT: strb w8, [x0]
+; CHECK-NEXT: ret
 
   %tst = icmp eq <4 x i22> %l, %r
   store <4 x i1> %tst, ptr %addr
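The regenerated checks show the new lowering: both <4 x i22> inputs are masked to 22 bits (movi #63, msl #16 builds 0x003FFFFF), compared, ANDed with per-lane bit weights loaded from .LCPI0_0 (contents not shown above; presumably {1, 2, 4, 8}), and horizontally summed so the 4-bit mask can be stored as one byte by strb. A scalar model under that assumption (function name hypothetical):

#include <cstdint>

// CmpLanes[i] is all-ones if the lanes compared equal, zero otherwise.
// The return value is the byte written by the final strb.
uint8_t storeMismatchedSetcc(const uint32_t (&CmpLanes)[4]) {
  const uint32_t Weights[4] = {1, 2, 4, 8}; // assumed .LCPI0_0 contents
  uint32_t Sum = 0;
  for (int I = 0; I < 4; ++I)
    Sum += CmpLanes[I] & Weights[I]; // and + addv
  return static_cast<uint8_t>(Sum);
}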
60 changes: 55 additions & 5 deletions llvm/test/CodeGen/AArch64/vec-combine-compare-to-bitmask.ll
@@ -418,18 +418,59 @@ define i4 @convert_to_bitmask_float(<4 x float> %vec) {
   ret i4 %bitmask
 }
 
-; TODO(lawben): Change this in follow-up patch to #D145301, as truncating stores fix this.
-; Larger vector types don't map directly.
-define i8 @no_convert_large_vector(<8 x i32> %vec) {
-; CHECK-LABEL: no_convert_large_vector:
-; CHECK: cmeq.4s v1, v1, #0
-; CHECK-NOT: addv
+; Larger vector types don't map directly, but they can be split/truncated and then converted.
+; After the comparison against 0, this is truncated to <8 x i16>, which is valid again.
+define i8 @convert_large_vector(<8 x i32> %vec) {
+; CHECK-LABEL: lCPI15_0:
+; CHECK-NEXT: .short 1
+; CHECK-NEXT: .short 2
+; CHECK-NEXT: .short 4
+; CHECK-NEXT: .short 8
+; CHECK-NEXT: .short 16
+; CHECK-NEXT: .short 32
+; CHECK-NEXT: .short 64
+; CHECK-NEXT: .short 128
+
+; CHECK-LABEL: convert_large_vector:
+; CHECK: Lloh30:
+; CHECK-NEXT: adrp x8, lCPI15_0@PAGE
+; CHECK-NEXT: cmeq.4s v1, v1, #0
+; CHECK-NEXT: cmeq.4s v0, v0, #0
+; CHECK-NEXT: uzp1.8h v0, v0, v1
+; CHECK-NEXT: Lloh31:
+; CHECK-NEXT: ldr q1, [x8, lCPI15_0@PAGEOFF]
+; CHECK-NEXT: bic.16b v0, v1, v0
+; CHECK-NEXT: addv.8h h0, v0
+; CHECK-NEXT: fmov w8, s0
+; CHECK-NEXT: and w0, w8, #0xff
+; CHECK-NEXT: add sp, sp, #16
+; CHECK-NEXT: ret
 
   %cmp_result = icmp ne <8 x i32> %vec, zeroinitializer
   %bitmask = bitcast <8 x i1> %cmp_result to i8
   ret i8 %bitmask
 }
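Here the two cmeq-against-zero halves are narrowed into a single v8i16 by uzp1, and bic keeps a lane's weight from lCPI15_0 only where the lane compared not-equal (weights & ~eq) before addv sums the mask. A scalar model of the same computation (function name hypothetical):

#include <cstdint>

// EqLanes[i] is 0xFFFF if vec[i] == 0, else 0 (the narrowed cmeq results).
// bic computes Weights[i] & ~EqLanes[i], keeping weights of nonzero lanes.
uint8_t convertLargeVectorModel(const uint16_t (&EqLanes)[8]) {
  const uint16_t Weights[8] = {1, 2, 4, 8, 16, 32, 64, 128}; // lCPI15_0
  uint16_t Sum = 0;
  for (int I = 0; I < 8; ++I)
    Sum += Weights[I] & static_cast<uint16_t>(~EqLanes[I]); // bic + addv
  return static_cast<uint8_t>(Sum); // and w0, w8, #0xff
}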

+define i4 @convert_legalized_illegal_element_size(<4 x i22> %vec) {
+; CHECK-LABEL: convert_legalized_illegal_element_size
+; CHECK: ; %bb.0:
+; CHECK-NEXT: movi.4s v1, #63, msl #16
+; CHECK-NEXT: Lloh32:
+; CHECK-NEXT: adrp x8, lCPI16_0@PAGE
+; CHECK-NEXT: cmtst.4s v0, v0, v1
+; CHECK-NEXT: Lloh33:
+; CHECK-NEXT: ldr d1, [x8, lCPI16_0@PAGEOFF]
+; CHECK-NEXT: xtn.4h v0, v0
+; CHECK-NEXT: and.8b v0, v0, v1
+; CHECK-NEXT: addv.4h h0, v0
+; CHECK-NEXT: fmov w0, s0
+; CHECK-NEXT: ret
+
+  %cmp_result = icmp ne <4 x i22> %vec, zeroinitializer
+  %bitmask = bitcast <4 x i1> %cmp_result to i4
+  ret i4 %bitmask
+}
+
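In this test, movi.4s #63, msl #16 materializes 0x003FFFFF, the mask for a 22-bit element, so cmtst implements the icmp ne against zero on the lanes after <4 x i22> is promoted to v4i32; the weights in lCPI16_0 (not shown) are then applied as in the other tests. A one-line model of the lane predicate (name hypothetical):

#include <cstdint>

// A promoted lane is "true" iff any of its low 22 bits is set:
// movi #63, msl #16 builds the mask, cmtst tests (Lane & Mask) != 0.
bool laneIsNonZero22(uint32_t Lane) {
  const uint32_t Mask22 = (63u << 16) | 0xFFFFu; // 0x003FFFFF
  return (Lane & Mask22) != 0;
}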
 ; This may still be converted as a v8i8 after the vector concat (but not as v4iX).
 define i8 @no_direct_convert_for_bad_concat(<4 x i32> %vec) {
 ; CHECK-LABEL: no_direct_convert_for_bad_concat:
@@ -450,3 +491,12 @@ define <8 x i1> @no_convert_without_direct_bitcast(<8 x i16> %vec) {
   %cmp_result = icmp ne <8 x i16> %vec, zeroinitializer
   ret <8 x i1> %cmp_result
 }
+
+define i6 @no_combine_illegal_num_elements(<6 x i32> %vec) {
+; CHECK-LABEL: no_combine_illegal_num_elements
+; CHECK-NOT: addv
+
+  %cmp_result = icmp ne <6 x i32> %vec, zeroinitializer
+  %bitmask = bitcast <6 x i1> %cmp_result to i6
+  ret i6 %bitmask
+}
