Skip to content

Commit

Permalink
[AArch64][SVE] Implement missing lowering for extract_subvector for p…
Browse files Browse the repository at this point in the history
…redicates.

Reviewed By: efriedma

Differential Revision: https://reviews.llvm.org/D118057
  • Loading branch information
sdesmalen-arm committed Jan 27, 2022
1 parent fdd3e2c commit c9da81d
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 2 deletions.
27 changes: 25 additions & 2 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Expand Up @@ -1248,6 +1248,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::SELECT_CC, VT, Expand);
setOperationAction(ISD::EXTRACT_VECTOR_ELT, VT, Custom);
setOperationAction(ISD::INSERT_VECTOR_ELT, VT, Custom);
setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);

// There are no legal MVT::nxv16f## based types.
if (VT != MVT::nxv16i1) {
Expand Down Expand Up @@ -11038,6 +11039,28 @@ SDValue AArch64TargetLowering::LowerINSERT_SUBVECTOR(SDValue Op,
if (!isTypeLegal(VT))
return SDValue();

// Break down insert_subvector into simpler parts.
if (VT.getVectorElementType() == MVT::i1) {
unsigned NumElts = VT.getVectorMinNumElements();
EVT HalfVT = VT.getHalfNumVectorElementsVT(*DAG.getContext());

SDValue Lo, Hi;
Lo = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, HalfVT, Vec0,
DAG.getVectorIdxConstant(0, DL));
Hi = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, HalfVT, Vec0,
DAG.getVectorIdxConstant(NumElts / 2, DL));
if (Idx < (NumElts / 2)) {
SDValue NewLo = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, HalfVT, Lo, Vec1,
DAG.getVectorIdxConstant(Idx, DL));
return DAG.getNode(AArch64ISD::UZP1, DL, VT, NewLo, Hi);
} else {
SDValue NewHi =
DAG.getNode(ISD::INSERT_SUBVECTOR, DL, HalfVT, Hi, Vec1,
DAG.getVectorIdxConstant(Idx - (NumElts / 2), DL));
return DAG.getNode(AArch64ISD::UZP1, DL, VT, Lo, NewHi);
}
}

// Ensure the subvector is half the size of the main vector.
if (VT.getVectorElementCount() != (InVT.getVectorElementCount() * 2))
return SDValue();
Expand Down Expand Up @@ -12961,7 +12984,7 @@ bool AArch64TargetLowering::isExtractSubvectorCheap(EVT ResVT, EVT SrcVT,
if (!isOperationLegalOrCustom(ISD::EXTRACT_SUBVECTOR, ResVT))
return false;

return (Index == 0 || Index == ResVT.getVectorNumElements());
return (Index == 0 || Index == ResVT.getVectorMinNumElements());
}

/// Turn vector tests of the signbit in the form of:
Expand Down Expand Up @@ -14321,6 +14344,7 @@ static SDValue performConcatVectorsCombine(SDNode *N,
static SDValue
performInsertSubvectorCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
SelectionDAG &DAG) {
SDLoc DL(N);
SDValue Vec = N->getOperand(0);
SDValue SubVec = N->getOperand(1);
uint64_t IdxVal = N->getConstantOperandVal(2);
Expand All @@ -14346,7 +14370,6 @@ performInsertSubvectorCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
// Fold insert_subvector -> concat_vectors
// insert_subvector(Vec,Sub,lo) -> concat_vectors(Sub,extract(Vec,hi))
// insert_subvector(Vec,Sub,hi) -> concat_vectors(extract(Vec,lo),Sub)
SDLoc DL(N);
SDValue Lo, Hi;
if (IdxVal == 0) {
Lo = SubVec;
Expand Down
77 changes: 77 additions & 0 deletions llvm/test/CodeGen/AArch64/sve-insert-vector.ll
Expand Up @@ -501,6 +501,80 @@ define <vscale x 8 x bfloat> @insert_nxv8bf16_v8bf16(<vscale x 8 x bfloat> %sv0,
ret <vscale x 8 x bfloat> %v0
}

; Test predicate inserts of half size.
; Inserts a <vscale x 8 x i1> subvector into a <vscale x 16 x i1> predicate at
; index 0, i.e. a subvector with half the element count of the destination.
; Expected output: one punpkhi followed by one uzp1 on .b elements.
define <vscale x 16 x i1> @insert_nxv16i1_nxv8i1_0(<vscale x 16 x i1> %vec, <vscale x 8 x i1> %sv) {
; CHECK-LABEL: insert_nxv16i1_nxv8i1_0:
; CHECK: // %bb.0:
; CHECK-NEXT: punpkhi p0.h, p0.b
; CHECK-NEXT: uzp1 p0.b, p1.b, p0.b
; CHECK-NEXT: ret
%v0 = call <vscale x 16 x i1> @llvm.experimental.vector.insert.nx16i1.nxv8i1(<vscale x 16 x i1> %vec, <vscale x 8 x i1> %sv, i64 0)
ret <vscale x 16 x i1> %v0
}

; Inserts a <vscale x 8 x i1> subvector into a <vscale x 16 x i1> predicate at
; index 8, i.e. a subvector with half the element count of the destination.
; Expected output: one punpklo followed by one uzp1 on .b elements, with the
; uzp1 source operands in the opposite order to the index-0 variant.
define <vscale x 16 x i1> @insert_nxv16i1_nxv8i1_8(<vscale x 16 x i1> %vec, <vscale x 8 x i1> %sv) {
; CHECK-LABEL: insert_nxv16i1_nxv8i1_8:
; CHECK: // %bb.0:
; CHECK-NEXT: punpklo p0.h, p0.b
; CHECK-NEXT: uzp1 p0.b, p0.b, p1.b
; CHECK-NEXT: ret
%v0 = call <vscale x 16 x i1> @llvm.experimental.vector.insert.nx16i1.nxv8i1(<vscale x 16 x i1> %vec, <vscale x 8 x i1> %sv, i64 8)
ret <vscale x 16 x i1> %v0
}

; Test predicate inserts of less than half the size.
; Inserts a <vscale x 4 x i1> subvector into a <vscale x 16 x i1> predicate at
; index 0, i.e. a subvector with a quarter of the destination's element count.
; Expected output: three unpacks (punpklo, punpkhi, punpkhi) followed by two
; uzp1 instructions, first on .h elements and then on .b elements.
define <vscale x 16 x i1> @insert_nxv16i1_nxv4i1_0(<vscale x 16 x i1> %vec, <vscale x 4 x i1> %sv) {
; CHECK-LABEL: insert_nxv16i1_nxv4i1_0:
; CHECK: // %bb.0:
; CHECK-NEXT: punpklo p2.h, p0.b
; CHECK-NEXT: punpkhi p0.h, p0.b
; CHECK-NEXT: punpkhi p2.h, p2.b
; CHECK-NEXT: uzp1 p1.h, p1.h, p2.h
; CHECK-NEXT: uzp1 p0.b, p1.b, p0.b
; CHECK-NEXT: ret
%v0 = call <vscale x 16 x i1> @llvm.experimental.vector.insert.nx16i1.nxv4i1(<vscale x 16 x i1> %vec, <vscale x 4 x i1> %sv, i64 0)
ret <vscale x 16 x i1> %v0
}

; Inserts a <vscale x 4 x i1> subvector into a <vscale x 16 x i1> predicate at
; index 12, i.e. a subvector with a quarter of the destination's element count.
; Expected output: three unpacks (punpkhi, punpklo, punpklo) followed by two
; uzp1 instructions, first on .h elements and then on .b elements.
define <vscale x 16 x i1> @insert_nxv16i1_nxv4i1_12(<vscale x 16 x i1> %vec, <vscale x 4 x i1> %sv) {
; CHECK-LABEL: insert_nxv16i1_nxv4i1_12:
; CHECK: // %bb.0:
; CHECK-NEXT: punpkhi p2.h, p0.b
; CHECK-NEXT: punpklo p0.h, p0.b
; CHECK-NEXT: punpklo p2.h, p2.b
; CHECK-NEXT: uzp1 p1.h, p2.h, p1.h
; CHECK-NEXT: uzp1 p0.b, p0.b, p1.b
; CHECK-NEXT: ret
%v0 = call <vscale x 16 x i1> @llvm.experimental.vector.insert.nx16i1.nxv4i1(<vscale x 16 x i1> %vec, <vscale x 4 x i1> %sv, i64 12)
ret <vscale x 16 x i1> %v0
}

; Test predicate insert into undef/zero
; Inserts a <vscale x 4 x i1> subvector at index 0 into a zeroinitializer
; <vscale x 16 x i1> predicate.
; Expected output: a pfalse, then the same unpack/uzp1 sequence shape as the
; non-constant index-0 case (punpklo, punpkhi, punpkhi, uzp1 .h, uzp1 .b).
define <vscale x 16 x i1> @insert_nxv16i1_nxv4i1_into_zero(<vscale x 4 x i1> %sv) {
; CHECK-LABEL: insert_nxv16i1_nxv4i1_into_zero:
; CHECK: // %bb.0:
; CHECK-NEXT: pfalse p1.b
; CHECK-NEXT: punpklo p2.h, p1.b
; CHECK-NEXT: punpkhi p1.h, p1.b
; CHECK-NEXT: punpkhi p2.h, p2.b
; CHECK-NEXT: uzp1 p0.h, p0.h, p2.h
; CHECK-NEXT: uzp1 p0.b, p0.b, p1.b
; CHECK-NEXT: ret
%v0 = call <vscale x 16 x i1> @llvm.experimental.vector.insert.nx16i1.nxv4i1(<vscale x 16 x i1> zeroinitializer, <vscale x 4 x i1> %sv, i64 0)
ret <vscale x 16 x i1> %v0
}

; Inserts a <vscale x 4 x i1> subvector at index 0 into a poison
; <vscale x 16 x i1> predicate.
; Expected output: only two uzp1 instructions (on .h, then on .b) and no
; punpklo/punpkhi unpacks of the destination.
define <vscale x 16 x i1> @insert_nxv16i1_nxv4i1_into_poison(<vscale x 4 x i1> %sv) {
; CHECK-LABEL: insert_nxv16i1_nxv4i1_into_poison:
; CHECK: // %bb.0:
; CHECK-NEXT: uzp1 p0.h, p0.h, p0.h
; CHECK-NEXT: uzp1 p0.b, p0.b, p0.b
; CHECK-NEXT: ret
%v0 = call <vscale x 16 x i1> @llvm.experimental.vector.insert.nx16i1.nxv4i1(<vscale x 16 x i1> poison, <vscale x 4 x i1> %sv, i64 0)
ret <vscale x 16 x i1> %v0
}


declare <vscale x 3 x i32> @llvm.experimental.vector.insert.nxv3i32.nxv2i32(<vscale x 3 x i32>, <vscale x 2 x i32>, i64)
declare <vscale x 3 x float> @llvm.experimental.vector.insert.nxv3f32.nxv2f32(<vscale x 3 x float>, <vscale x 2 x float>, i64)
declare <vscale x 6 x i32> @llvm.experimental.vector.insert.nxv6i32.nxv2i32(<vscale x 6 x i32>, <vscale x 2 x i32>, i64)
Expand All @@ -511,3 +585,6 @@ declare <vscale x 8 x bfloat> @llvm.experimental.vector.insert.nxv8bf16.v8bf16(<
declare <vscale x 4 x bfloat> @llvm.experimental.vector.insert.nxv4bf16.nxv4bf16(<vscale x 4 x bfloat>, <vscale x 4 x bfloat>, i64)
declare <vscale x 4 x bfloat> @llvm.experimental.vector.insert.nxv4bf16.v4bf16(<vscale x 4 x bfloat>, <4 x bfloat>, i64)
declare <vscale x 2 x bfloat> @llvm.experimental.vector.insert.nxv2bf16.nxv2bf16(<vscale x 2 x bfloat>, <vscale x 2 x bfloat>, i64)

declare <vscale x 16 x i1> @llvm.experimental.vector.insert.nx16i1.nxv4i1(<vscale x 16 x i1>, <vscale x 4 x i1>, i64)
declare <vscale x 16 x i1> @llvm.experimental.vector.insert.nx16i1.nxv8i1(<vscale x 16 x i1>, <vscale x 8 x i1>, i64)

0 comments on commit c9da81d

Please sign in to comment.