[InstCombine] Extend SVEVectorFuseMulAddSub to support newly added "undef" intrinsics.

D143767 will change the intrinsics used to lower the floating-point
svadd_x, svmul_x and svsub_x builtins. As a result, the combines added
as part of D140200 will no longer fire in all cases. This patch
extends the existing contraction combines to cover the fadd_u, fmul_u
and fsub_u intrinsics.

Differential Revision: https://reviews.llvm.org/D144413
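
For illustration, a minimal ACLE-level sketch of the source pattern the
combine targets (assumptions: <arm_sve.h> is available, the target has SVE
enabled, and fast-math is in effect so the lowered calls carry contraction
flags; the function name is made up):

#include <arm_sve.h>

// c + (a * b): with D143767 in place this would lower to fmul_u followed
// by fadd_u, which the extended combine contracts into a single fmla_u.
svfloat16_t fused_muladd(svbool_t pg, svfloat16_t a, svfloat16_t b,
                         svfloat16_t c) {
  return svadd_x(pg, c, svmul_x(pg, a, b));
}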
paulwalker-arm committed Mar 12, 2023
1 parent 4989779 commit 2f887c9
Showing 2 changed files with 35 additions and 0 deletions.
llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp (8 additions, 0 deletions)
@@ -1619,9 +1619,17 @@ AArch64TTIImpl::instCombineIntrinsic(InstCombiner &IC,
  case Intrinsic::aarch64_sve_fadd:
  case Intrinsic::aarch64_sve_add:
    return instCombineSVEVectorAdd(IC, II);
  case Intrinsic::aarch64_sve_fadd_u:
    return instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul_u,
                                             Intrinsic::aarch64_sve_fmla_u>(
        IC, II, true);
  case Intrinsic::aarch64_sve_fsub:
  case Intrinsic::aarch64_sve_sub:
    return instCombineSVEVectorSub(IC, II);
  case Intrinsic::aarch64_sve_fsub_u:
    return instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul_u,
                                             Intrinsic::aarch64_sve_fmls_u>(
        IC, II, true);
  case Intrinsic::aarch64_sve_tbl:
    return instCombineSVETBL(IC, II);
  case Intrinsic::aarch64_sve_uunpkhi:
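
For context, a simplified and purely illustrative sketch of the kind of
pattern a fuse combine of this shape matches; it is not the in-tree
instCombineSVEVectorFuseMulAddSub helper, and the function and variable
names are hypothetical:

#include <optional>
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/Transforms/InstCombine/InstCombiner.h"
using namespace llvm;

// Fold fadd_u(pg, addend, fmul_u(pg, a, b)) -> fmla_u(pg, addend, a, b)
// when both calls permit contraction. The fsub_u case has the same shape
// with fmls_u as the fused opcode.
static std::optional<Instruction *>
sketchFuseMulAdd(InstCombiner &IC, IntrinsicInst &II, Intrinsic::ID MulOpc,
                 Intrinsic::ID FuseOpc) {
  Value *Pred = II.getArgOperand(0);
  Value *Addend = II.getArgOperand(1);
  auto *Mul = dyn_cast<IntrinsicInst>(II.getArgOperand(2));

  // The add/sub and the multiply must both allow contraction, share the
  // same predicate, and the multiply must have no other users.
  if (!II.hasAllowContract() || !Mul || Mul->getIntrinsicID() != MulOpc ||
      !Mul->hasAllowContract() || Mul->getArgOperand(0) != Pred ||
      !Mul->hasOneUse())
    return std::nullopt;

  IRBuilder<> Builder(&II);
  CallInst *Fused = Builder.CreateIntrinsic(
      FuseOpc, {II.getType()},
      {Pred, Addend, Mul->getArgOperand(1), Mul->getArgOperand(2)},
      /*FMFSource=*/&II);
  return IC.replaceInstUsesWith(II, Fused);
}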
@@ -15,6 +15,18 @@ define <vscale x 8 x half> @combine_fmla(<vscale x 16 x i1> %p, <vscale x 8 x ha
ret <vscale x 8 x half> %3
}

define <vscale x 8 x half> @combine_fmla_u(<vscale x 16 x i1> %p, <vscale x 8 x half> %a, <vscale x 8 x half> %b, <vscale x 8 x half> %c) #0 {
; CHECK-LABEL: @combine_fmla_u(
; CHECK-NEXT: [[TMP1:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[P:%.*]])
; CHECK-NEXT: [[TMP2:%.*]] = call fast <vscale x 8 x half> @llvm.aarch64.sve.fmla.u.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[C:%.*]], <vscale x 8 x half> [[A:%.*]], <vscale x 8 x half> [[B:%.*]])
; CHECK-NEXT: ret <vscale x 8 x half> [[TMP2]]
;
%1 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %p)
%2 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.u.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %a, <vscale x 8 x half> %b)
%3 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fadd.u.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %c, <vscale x 8 x half> %2)
ret <vscale x 8 x half> %3
}

define <vscale x 16 x i8> @combine_mla_i8(<vscale x 16 x i1> %p, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b, <vscale x 16 x i8> %c) #0 {
; CHECK-LABEL: @combine_mla_i8(
; CHECK-NEXT: [[TMP1:%.*]] = call <vscale x 16 x i8> @llvm.aarch64.sve.mla.nxv16i8(<vscale x 16 x i1> [[P:%.*]], <vscale x 16 x i8> [[C:%.*]], <vscale x 16 x i8> [[A:%.*]], <vscale x 16 x i8> [[B:%.*]])
@@ -59,6 +71,18 @@ define <vscale x 8 x half> @combine_fmls(<vscale x 16 x i1> %p, <vscale x 8 x ha
ret <vscale x 8 x half> %3
}

define <vscale x 8 x half> @combine_fmls_u(<vscale x 16 x i1> %p, <vscale x 8 x half> %a, <vscale x 8 x half> %b, <vscale x 8 x half> %c) #0 {
; CHECK-LABEL: @combine_fmls_u(
; CHECK-NEXT: [[TMP1:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[P:%.*]])
; CHECK-NEXT: [[TMP2:%.*]] = call fast <vscale x 8 x half> @llvm.aarch64.sve.fmls.u.nxv8f16(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x half> [[C:%.*]], <vscale x 8 x half> [[A:%.*]], <vscale x 8 x half> [[B:%.*]])
; CHECK-NEXT: ret <vscale x 8 x half> [[TMP2]]
;
%1 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %p)
%2 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fmul.u.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %a, <vscale x 8 x half> %b)
%3 = tail call fast <vscale x 8 x half> @llvm.aarch64.sve.fsub.u.nxv8f16(<vscale x 8 x i1> %1, <vscale x 8 x half> %c, <vscale x 8 x half> %2)
ret <vscale x 8 x half> %3
}

define <vscale x 16 x i8> @combine_mls_i8(<vscale x 16 x i1> %p, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b, <vscale x 16 x i8> %c) #0 {
; CHECK-LABEL: @combine_mls_i8(
; CHECK-NEXT: [[TMP1:%.*]] = call <vscale x 16 x i8> @llvm.aarch64.sve.mls.nxv16i8(<vscale x 16 x i1> [[P:%.*]], <vscale x 16 x i8> [[C:%.*]], <vscale x 16 x i8> [[A:%.*]], <vscale x 16 x i8> [[B:%.*]])
@@ -173,6 +197,9 @@ declare <vscale x 2 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv2i1(<vscale x
declare <vscale x 8 x half> @llvm.aarch64.sve.fmul.nxv8f16(<vscale x 8 x i1>, <vscale x 8 x half>, <vscale x 8 x half>)
declare <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1>, <vscale x 8 x half>, <vscale x 8 x half>)
declare <vscale x 8 x half> @llvm.aarch64.sve.fsub.nxv8f16(<vscale x 8 x i1>, <vscale x 8 x half>, <vscale x 8 x half>)
declare <vscale x 8 x half> @llvm.aarch64.sve.fmul.u.nxv8f16(<vscale x 8 x i1>, <vscale x 8 x half>, <vscale x 8 x half>)
declare <vscale x 8 x half> @llvm.aarch64.sve.fadd.u.nxv8f16(<vscale x 8 x i1>, <vscale x 8 x half>, <vscale x 8 x half>)
declare <vscale x 8 x half> @llvm.aarch64.sve.fsub.u.nxv8f16(<vscale x 8 x i1>, <vscale x 8 x half>, <vscale x 8 x half>)
declare <vscale x 16 x i8> @llvm.aarch64.sve.mul.nxv16i8(<vscale x 16 x i1>, <vscale x 16 x i8>, <vscale x 16 x i8>)
declare <vscale x 16 x i8> @llvm.aarch64.sve.add.nxv16i8(<vscale x 16 x i1>, <vscale x 16 x i8>, <vscale x 16 x i8>)
declare <vscale x 16 x i8> @llvm.aarch64.sve.sub.nxv16i8(<vscale x 16 x i1>, <vscale x 16 x i8>, <vscale x 16 x i8>)
