diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index cdd2750521d2c..27343fd1f9439 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -1630,6 +1630,36 @@ static std::optional<Instruction *> instCombineSVETBL(InstCombiner &IC,
   return IC.replaceInstUsesWith(II, VectorSplat);
 }
 
+static std::optional<Instruction *> instCombineSVEUzp1(InstCombiner &IC,
+                                                       IntrinsicInst &II) {
+  Value *A, *B;
+  Type *RetTy = II.getType();
+  constexpr Intrinsic::ID FromSVB = Intrinsic::aarch64_sve_convert_from_svbool;
+  constexpr Intrinsic::ID ToSVB = Intrinsic::aarch64_sve_convert_to_svbool;
+
+  // uzp1(to_svbool(A), to_svbool(B)) --> <A, B>
+  // uzp1(from_svbool(to_svbool(A)), from_svbool(to_svbool(B))) --> <A, B>
+  if ((match(II.getArgOperand(0),
+             m_Intrinsic<FromSVB>(m_Intrinsic<ToSVB>(m_Value(A)))) &&
+       match(II.getArgOperand(1),
+             m_Intrinsic<FromSVB>(m_Intrinsic<ToSVB>(m_Value(B))))) ||
+      (match(II.getArgOperand(0), m_Intrinsic<ToSVB>(m_Value(A))) &&
+       match(II.getArgOperand(1), m_Intrinsic<ToSVB>(m_Value(B))))) {
+    auto *TyA = cast<ScalableVectorType>(A->getType());
+    if (TyA == B->getType() &&
+        RetTy == ScalableVectorType::getDoubleElementsVectorType(TyA)) {
+      auto *SubVec = IC.Builder.CreateInsertVector(
+          RetTy, PoisonValue::get(RetTy), A, IC.Builder.getInt64(0));
+      auto *ConcatVec = IC.Builder.CreateInsertVector(
+          RetTy, SubVec, B, IC.Builder.getInt64(TyA->getMinNumElements()));
+      ConcatVec->takeName(&II);
+      return IC.replaceInstUsesWith(II, ConcatVec);
+    }
+  }
+
+  return std::nullopt;
+}
+
 static std::optional<Instruction *> instCombineSVEZip(InstCombiner &IC,
                                                       IntrinsicInst &II) {
   // zip1(uzp1(A, B), uzp2(A, B)) --> A
@@ -2012,6 +2042,8 @@ AArch64TTIImpl::instCombineIntrinsic(InstCombiner &IC,
   case Intrinsic::aarch64_sve_sunpkhi:
   case Intrinsic::aarch64_sve_sunpklo:
     return instCombineSVEUnpack(IC, II);
+  case Intrinsic::aarch64_sve_uzp1:
+    return instCombineSVEUzp1(IC, II);
   case Intrinsic::aarch64_sve_zip1:
   case Intrinsic::aarch64_sve_zip2:
     return instCombineSVEZip(IC, II);
diff --git a/llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-opts-uzp1.ll b/llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-opts-uzp1.ll
new file mode 100644
index 0000000000000..0233e0ae57029
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-opts-uzp1.ll
@@ -0,0 +1,140 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
+; RUN: opt -S -passes=instcombine -mtriple=aarch64 < %s | FileCheck %s
+
+; Transform the SVE idiom used to concatenate two vectors into target agnostic IR.
+
+define <vscale x 8 x i1> @reinterpt_uzp1_1(<vscale x 4 x i1> %cmp0, <vscale x 4 x i1> %cmp1) {
+; CHECK-LABEL: define <vscale x 8 x i1> @reinterpt_uzp1_1(
+; CHECK-SAME: <vscale x 4 x i1> [[CMP0:%.*]], <vscale x 4 x i1> [[CMP1:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = call <vscale x 8 x i1> @llvm.vector.insert.nxv8i1.nxv4i1(<vscale x 8 x i1> poison, <vscale x 4 x i1> [[CMP0]], i64 0)
+; CHECK-NEXT:    [[UZ1:%.*]] = call <vscale x 8 x i1> @llvm.vector.insert.nxv8i1.nxv4i1(<vscale x 8 x i1> [[TMP1]], <vscale x 4 x i1> [[CMP1]], i64 4)
+; CHECK-NEXT:    ret <vscale x 8 x i1> [[UZ1]]
+;
+  %1 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1> %cmp0)
+  %2 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %1)
+  %3 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1> %cmp1)
+  %4 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %3)
+  %uz1 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.uzp1.nxv8i1(<vscale x 8 x i1> %2, <vscale x 8 x i1> %4)
+  ret <vscale x 8 x i1> %uz1
+}
+
+define <vscale x 8 x i1> @reinterpt_uzp1_2(<vscale x 2 x i1> %cmp0, <vscale x 2 x i1> %cmp1) {
+; CHECK-LABEL: define <vscale x 8 x i1> @reinterpt_uzp1_2(
+; CHECK-SAME: <vscale x 2 x i1> [[CMP0:%.*]], <vscale x 2 x i1> [[CMP1:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv2i1(<vscale x 2 x i1> [[CMP0]])
+; CHECK-NEXT:    [[TMP2:%.*]] = tail call <vscale x 4 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv4i1(<vscale x 16 x i1> [[TMP1]])
+; CHECK-NEXT:    [[TMP3:%.*]] = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv2i1(<vscale x 2 x i1> [[CMP1]])
+; CHECK-NEXT:    [[TMP4:%.*]] = tail call <vscale x 4 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv4i1(<vscale x 16 x i1> [[TMP3]])
+; CHECK-NEXT:    [[TMP5:%.*]] = call <vscale x 8 x i1> @llvm.vector.insert.nxv8i1.nxv4i1(<vscale x 8 x i1> poison, <vscale x 4 x i1> [[TMP2]], i64 0)
+; CHECK-NEXT:    [[UZ1:%.*]] = call <vscale x 8 x i1> @llvm.vector.insert.nxv8i1.nxv4i1(<vscale x 8 x i1> [[TMP5]], <vscale x 4 x i1> [[TMP4]], i64 4)
+; CHECK-NEXT:    ret <vscale x 8 x i1> [[UZ1]]
+;
+  %1 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv2i1(<vscale x 2 x i1> %cmp0)
+  %2 = tail call <vscale x 4 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv4i1(<vscale x 16 x i1> %1)
+  %3 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv2i1(<vscale x 2 x i1> %cmp1)
+  %4 = tail call <vscale x 4 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv4i1(<vscale x 16 x i1> %3)
+  %5 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1> %2)
+  %6 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %5)
+  %7 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1> %4)
+  %8 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %7)
+  %uz1 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.uzp1.nxv8i1(<vscale x 8 x i1> %6, <vscale x 8 x i1> %8)
+  ret <vscale x 8 x i1> %uz1
+}
+
+define <vscale x 16 x i1> @reinterpt_uzp1_3(<vscale x 4 x i1> %cmp0, <vscale x 4 x i1> %cmp1, <vscale x 4 x i1> %cmp2, <vscale x 4 x i1> %cmp3) {
+; CHECK-LABEL: define <vscale x 16 x i1> @reinterpt_uzp1_3(
+; CHECK-SAME: <vscale x 4 x i1> [[CMP0:%.*]], <vscale x 4 x i1> [[CMP1:%.*]], <vscale x 4 x i1> [[CMP2:%.*]], <vscale x 4 x i1> [[CMP3:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = call <vscale x 8 x i1> @llvm.vector.insert.nxv8i1.nxv4i1(<vscale x 8 x i1> poison, <vscale x 4 x i1> [[CMP0]], i64 0)
+; CHECK-NEXT:    [[UZ1_1:%.*]] = call <vscale x 8 x i1> @llvm.vector.insert.nxv8i1.nxv4i1(<vscale x 8 x i1> [[TMP1]], <vscale x 4 x i1> [[CMP1]], i64 4)
+; CHECK-NEXT:    [[TMP2:%.*]] = call <vscale x 8 x i1> @llvm.vector.insert.nxv8i1.nxv4i1(<vscale x 8 x i1> poison, <vscale x 4 x i1> [[CMP2]], i64 0)
+; CHECK-NEXT:    [[UZ1_2:%.*]] = call <vscale x 8 x i1> @llvm.vector.insert.nxv8i1.nxv4i1(<vscale x 8 x i1> [[TMP2]], <vscale x 4 x i1> [[CMP3]], i64 4)
+; CHECK-NEXT:    [[TMP3:%.*]] = call <vscale x 16 x i1> @llvm.vector.insert.nxv16i1.nxv8i1(<vscale x 16 x i1> poison, <vscale x 8 x i1> [[UZ1_1]], i64 0)
+; CHECK-NEXT:    [[UZ3:%.*]] = call <vscale x 16 x i1> @llvm.vector.insert.nxv16i1.nxv8i1(<vscale x 16 x i1> [[TMP3]], <vscale x 8 x i1> [[UZ1_2]], i64 8)
+; CHECK-NEXT:    ret <vscale x 16 x i1> [[UZ3]]
+;
+  %1 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1> %cmp0)
+  %2 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %1)
+  %3 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1> %cmp1)
+  %4 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %3)
+  %uz1_1 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.uzp1.nxv8i1(<vscale x 8 x i1> %2, <vscale x 8 x i1> %4)
+  %5 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1> %cmp2)
+  %6 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %5)
+  %7 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1> %cmp3)
+  %8 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %7)
+  %uz1_2 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.uzp1.nxv8i1(<vscale x 8 x i1> %6, <vscale x 8 x i1> %8)
+  %9 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1> %uz1_1)
+  %10 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1> %uz1_2)
+  %uz3 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.uzp1.nxv16i1(<vscale x 16 x i1> %9, <vscale x 16 x i1> %10)
+  ret <vscale x 16 x i1> %uz3
+}
+
+define <vscale x 4 x i1> @neg1(<vscale x 4 x i1> %cmp0, <vscale x 4 x i1> %cmp1) {
+; CHECK-LABEL: define <vscale x 4 x i1> @neg1(
+; CHECK-SAME: <vscale x 4 x i1> [[CMP0:%.*]], <vscale x 4 x i1> [[CMP1:%.*]]) {
+; CHECK-NEXT:    [[UZ1:%.*]] = tail call <vscale x 4 x i1> @llvm.aarch64.sve.uzp1.nxv4i1(<vscale x 4 x i1> [[CMP0]], <vscale x 4 x i1> [[CMP1]])
+; CHECK-NEXT:    ret <vscale x 4 x i1> [[UZ1]]
+;
+  %uz1 = tail call <vscale x 4 x i1> @llvm.aarch64.sve.uzp1.nxv4i1(<vscale x 4 x i1> %cmp0, <vscale x 4 x i1> %cmp1)
+  ret <vscale x 4 x i1> %uz1
+}
+
+define <vscale x 8 x i1> @neg2(<vscale x 2 x i1> %cmp0, <vscale x 4 x i1> %cmp1) {
+; CHECK-LABEL: define <vscale x 8 x i1> @neg2(
+; CHECK-SAME: <vscale x 2 x i1> [[CMP0:%.*]], <vscale x 4 x i1> [[CMP1:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv2i1(<vscale x 2 x i1> [[CMP0]])
+; CHECK-NEXT:    [[TMP2:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[TMP1]])
+; CHECK-NEXT:    [[TMP3:%.*]] = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1> [[CMP1]])
+; CHECK-NEXT:    [[TMP4:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[TMP3]])
+; CHECK-NEXT:    [[UZ1:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.uzp1.nxv8i1(<vscale x 8 x i1> [[TMP2]], <vscale x 8 x i1> [[TMP4]])
+; CHECK-NEXT:    ret <vscale x 8 x i1> [[UZ1]]
+;
+  %1 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv2i1(<vscale x 2 x i1> %cmp0)
+  %2 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %1)
+  %3 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1> %cmp1)
+  %4 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %3)
+  %uz1 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.uzp1.nxv8i1(<vscale x 8 x i1> %2, <vscale x 8 x i1> %4)
+  ret <vscale x 8 x i1> %uz1
+}
+
+define <vscale x 8 x i1> @neg3(<vscale x 8 x i1> %cmp0, <vscale x 4 x i1> %cmp1) {
+; CHECK-LABEL: define <vscale x 8 x i1> @neg3(
+; CHECK-SAME: <vscale x 8 x i1> [[CMP0:%.*]], <vscale x 4 x i1> [[CMP1:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1> [[CMP1]])
+; CHECK-NEXT:    [[TMP2:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[TMP1]])
+; CHECK-NEXT:    [[UZ1:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.uzp1.nxv8i1(<vscale x 8 x i1> [[CMP0]], <vscale x 8 x i1> [[TMP2]])
+; CHECK-NEXT:    ret <vscale x 8 x i1> [[UZ1]]
+;
+  %1 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1> %cmp1)
+  %2 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %1)
+  %uz1 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.uzp1.nxv8i1(<vscale x 8 x i1> %cmp0, <vscale x 8 x i1> %2)
+  ret <vscale x 8 x i1> %uz1
+}
+
+define <vscale x 8 x i1> @neg4(<vscale x 2 x i1> %cmp0, <vscale x 2 x i1> %cmp1) {
+; CHECK-LABEL: define <vscale x 8 x i1> @neg4(
+; CHECK-SAME: <vscale x 2 x i1> [[CMP0:%.*]], <vscale x 2 x i1> [[CMP1:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv2i1(<vscale x 2 x i1> [[CMP0]])
+; CHECK-NEXT:    [[TMP2:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[TMP1]])
+; CHECK-NEXT:    [[TMP3:%.*]] = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv2i1(<vscale x 2 x i1> [[CMP1]])
+; CHECK-NEXT:    [[TMP4:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[TMP3]])
+; CHECK-NEXT:    [[UZ1:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.uzp1.nxv8i1(<vscale x 8 x i1> [[TMP2]], <vscale x 8 x i1> [[TMP4]])
+; CHECK-NEXT:    ret <vscale x 8 x i1> [[UZ1]]
+;
+  %1 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv2i1(<vscale x 2 x i1> %cmp0)
+  %2 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %1)
+  %3 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv2i1(<vscale x 2 x i1> %cmp1)
+  %4 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %3)
+  %uz1 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.uzp1.nxv8i1(<vscale x 8 x i1> %2, <vscale x 8 x i1> %4)
+  ret <vscale x 8 x i1> %uz1
+}
+
+declare <vscale x 4 x i1> @llvm.aarch64.sve.uzp1.nxv4i1(<vscale x 4 x i1>, <vscale x 4 x i1>)
+declare <vscale x 8 x i1> @llvm.aarch64.sve.uzp1.nxv8i1(<vscale x 8 x i1>, <vscale x 8 x i1>)
+declare <vscale x 16 x i1> @llvm.aarch64.sve.uzp1.nxv16i1(<vscale x 16 x i1>, <vscale x 16 x i1>)
+
+declare <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1>)
+declare <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1>)
+declare <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv2i1(<vscale x 2 x i1>)
+declare <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1>)
+declare <vscale x 4 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv4i1(<vscale x 16 x i1>)
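
The net effect of the new combine is easiest to see on the simplest positive test above (reinterpt_uzp1_1). The sketch below restates that test's input and expected output as standalone IR, assuming <vscale x 4 x i1> inputs concatenated into a <vscale x 8 x i1> result; the function and value names (@concat_idiom_before, @concat_idiom_after, %p0, %p1) are hypothetical, intrinsic declarations are the same as in the test file, and nothing here goes beyond what the CHECK lines already encode.

; Before: the SVE-specific idiom. Each nxv4i1 predicate, reinterpreted through svbool
; into an nxv8i1 view, has its bits in the even lanes, so uzp1 of the two views is
; effectively a concatenation of %p0 and %p1.
define <vscale x 8 x i1> @concat_idiom_before(<vscale x 4 x i1> %p0, <vscale x 4 x i1> %p1) {
  %w0 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1> %p0)
  %h0 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %w0)
  %w1 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1> %p1)
  %h1 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %w1)
  %concat = tail call <vscale x 8 x i1> @llvm.aarch64.sve.uzp1.nxv8i1(<vscale x 8 x i1> %h0, <vscale x 8 x i1> %h1)
  ret <vscale x 8 x i1> %concat
}

; After: the target-agnostic form instCombineSVEUzp1 produces, inserting each half into
; a poison vector of double the element count (as checked by reinterpt_uzp1_1).
define <vscale x 8 x i1> @concat_idiom_after(<vscale x 4 x i1> %p0, <vscale x 4 x i1> %p1) {
  %lo = call <vscale x 8 x i1> @llvm.vector.insert.nxv8i1.nxv4i1(<vscale x 8 x i1> poison, <vscale x 4 x i1> %p0, i64 0)
  %concat = call <vscale x 8 x i1> @llvm.vector.insert.nxv8i1.nxv4i1(<vscale x 8 x i1> %lo, <vscale x 4 x i1> %p1, i64 4)
  ret <vscale x 8 x i1> %concat
}

Keeping the concatenation in llvm.vector.insert form, rather than behind AArch64-specific intrinsics, is what the test-file comment means by lowering the idiom to target agnostic IR.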