diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 1e276749045ed..2df5f091a9db7 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -416,12 +416,76 @@ static Optional<Instruction *> processPhiNode(InstCombiner &IC,
   return IC.replaceInstUsesWith(II, NPN);
 }
 
+// (from_svbool (binop (to_svbool pred) (svbool_t _) (svbool_t _)))
+// => (binop (pred) (from_svbool _) (from_svbool _))
+//
+// The above transformation eliminates a `to_svbool` in the predicate
+// operand of bitwise operation `binop` by narrowing the vector width of
+// the operation. For example, it would convert a `<vscale x 16 x i1>
+// and` into a `<vscale x 4 x i1> and`. This is profitable because
+// to_svbool must zero the new lanes during widening, whereas
+// from_svbool is free.
+static Optional<Instruction *> tryCombineFromSVBoolBinOp(InstCombiner &IC,
+                                                         IntrinsicInst &II) {
+  auto BinOp = dyn_cast<IntrinsicInst>(II.getOperand(0));
+  if (!BinOp)
+    return None;
+
+  auto IntrinsicID = BinOp->getIntrinsicID();
+  switch (IntrinsicID) {
+  case Intrinsic::aarch64_sve_and_z:
+  case Intrinsic::aarch64_sve_bic_z:
+  case Intrinsic::aarch64_sve_eor_z:
+  case Intrinsic::aarch64_sve_nand_z:
+  case Intrinsic::aarch64_sve_nor_z:
+  case Intrinsic::aarch64_sve_orn_z:
+  case Intrinsic::aarch64_sve_orr_z:
+    break;
+  default:
+    return None;
+  }
+
+  auto BinOpPred = BinOp->getOperand(0);
+  auto BinOpOp1 = BinOp->getOperand(1);
+  auto BinOpOp2 = BinOp->getOperand(2);
+
+  auto PredIntr = dyn_cast<IntrinsicInst>(BinOpPred);
+  if (!PredIntr ||
+      PredIntr->getIntrinsicID() != Intrinsic::aarch64_sve_convert_to_svbool)
+    return None;
+
+  auto PredOp = PredIntr->getOperand(0);
+  auto PredOpTy = cast<VectorType>(PredOp->getType());
+  if (PredOpTy != II.getType())
+    return None;
+
+  IRBuilder<> Builder(II.getContext());
+  Builder.SetInsertPoint(&II);
+
+  SmallVector<Value *> NarrowedBinOpArgs = {PredOp};
+  auto NarrowBinOpOp1 = Builder.CreateIntrinsic(
+      Intrinsic::aarch64_sve_convert_from_svbool, {PredOpTy}, {BinOpOp1});
+  NarrowedBinOpArgs.push_back(NarrowBinOpOp1);
+  if (BinOpOp1 == BinOpOp2)
+    NarrowedBinOpArgs.push_back(NarrowBinOpOp1);
+  else
+    NarrowedBinOpArgs.push_back(Builder.CreateIntrinsic(
+        Intrinsic::aarch64_sve_convert_from_svbool, {PredOpTy}, {BinOpOp2}));
+
+  auto NarrowedBinOp =
+      Builder.CreateIntrinsic(IntrinsicID, {PredOpTy}, NarrowedBinOpArgs);
+  return IC.replaceInstUsesWith(II, NarrowedBinOp);
+}
+
 static Optional<Instruction *> instCombineConvertFromSVBool(InstCombiner &IC,
                                                             IntrinsicInst &II) {
   // If the reinterpret instruction operand is a PHI Node
   if (isa<PHINode>(II.getArgOperand(0)))
     return processPhiNode(IC, II);
 
+  if (auto BinOpCombine = tryCombineFromSVBoolBinOp(IC, II))
+    return BinOpCombine;
+
   SmallVector<Instruction *, 32> CandidatesForRemoval;
   Value *Cursor = II.getOperand(0), *EarliestReplacement = nullptr;
 
diff --git a/llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-to-svbool-binops.ll b/llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-to-svbool-binops.ll
new file mode 100644
index 0000000000000..29a5f777728c8
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-to-svbool-binops.ll
@@ -0,0 +1,141 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -S -instcombine < %s | FileCheck %s
+
+target triple = "aarch64-unknown-linux-gnu"
+
+define <vscale x 4 x i1> @try_combine_svbool_binop_and_0(<vscale x 4 x i1> %a, <vscale x 16 x i1> %b, <vscale x 16 x i1> %c) {
+; CHECK-LABEL: @try_combine_svbool_binop_and_0(
+; CHECK-NEXT:    [[TMP1:%.*]] = call <vscale x 4 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv4i1(<vscale x 16 x i1> [[B:%.*]])
+; CHECK-NEXT:    [[TMP2:%.*]] = call <vscale x 4 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv4i1(<vscale x 16 x i1> [[C:%.*]])
+; CHECK-NEXT:    [[TMP3:%.*]] = call <vscale x 4 x i1> @llvm.aarch64.sve.and.z.nxv4i1(<vscale x 4 x i1> [[A:%.*]], <vscale x 4 x i1> [[TMP1]], <vscale x 4 x i1> [[TMP2]])
+; CHECK-NEXT:    ret <vscale x 4 x i1> [[TMP3]]
+;
+  %t1 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1> %a)
+  %t2 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.and.z.nxv16i1(<vscale x 16 x i1> %t1, <vscale x 16 x i1> %b, <vscale x 16 x i1> %c)
+  %t3 = tail call <vscale x 4 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv4i1(<vscale x 16 x i1> %t2)
+  ret <vscale x 4 x i1> %t3
+}
+
+define <vscale x 8 x i1> @try_combine_svbool_binop_and_1(<vscale x 8 x i1> %a, <vscale x 16 x i1> %b) {
+; CHECK-LABEL: @try_combine_svbool_binop_and_1(
+; CHECK-NEXT:    [[TMP1:%.*]] = call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[B:%.*]])
+; CHECK-NEXT:    [[TMP2:%.*]] = call <vscale x 8 x i1> @llvm.aarch64.sve.and.z.nxv8i1(<vscale x 8 x i1> [[A:%.*]], <vscale x 8 x i1> [[TMP1]], <vscale x 8 x i1> [[TMP1]])
+; CHECK-NEXT:    ret <vscale x 8 x i1> [[TMP2]]
+;
+  %t1 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1> %a)
+  %t2 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.and.z.nxv16i1(<vscale x 16 x i1> %t1, <vscale x 16 x i1> %b, <vscale x 16 x i1> %b)
+  %t3 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %t2)
+  ret <vscale x 8 x i1> %t3
+}
+
+define <vscale x 4 x i1> @try_combine_svbool_binop_and_2(<vscale x 4 x i1> %a, <vscale x 16 x i1> %b) {
+; CHECK-LABEL: @try_combine_svbool_binop_and_2(
+; CHECK-NEXT:    [[TMP1:%.*]] = call <vscale x 4 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv4i1(<vscale x 16 x i1> [[B:%.*]])
+; CHECK-NEXT:    [[TMP2:%.*]] = call <vscale x 4 x i1> @llvm.aarch64.sve.and.z.nxv4i1(<vscale x 4 x i1> [[A:%.*]], <vscale x 4 x i1> [[TMP1]], <vscale x 4 x i1> [[TMP1]])
+; CHECK-NEXT:    ret <vscale x 4 x i1> [[TMP2]]
+;
+  %t1 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1> %a)
+  %t2 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.and.z.nxv16i1(<vscale x 16 x i1> %t1, <vscale x 16 x i1> %b, <vscale x 16 x i1> %b)
+  %t3 = tail call <vscale x 4 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv4i1(<vscale x 16 x i1> %t2)
+  ret <vscale x 4 x i1> %t3
+}
+
+define <vscale x 2 x i1> @try_combine_svbool_binop_and_3(<vscale x 2 x i1> %a, <vscale x 16 x i1> %b) {
+; CHECK-LABEL: @try_combine_svbool_binop_and_3(
+; CHECK-NEXT:    [[TMP1:%.*]] = call <vscale x 2 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv2i1(<vscale x 16 x i1> [[B:%.*]])
+; CHECK-NEXT:    [[TMP2:%.*]] = call <vscale x 2 x i1> @llvm.aarch64.sve.and.z.nxv2i1(<vscale x 2 x i1> [[A:%.*]], <vscale x 2 x i1> [[TMP1]], <vscale x 2 x i1> [[TMP1]])
+; CHECK-NEXT:    ret <vscale x 2 x i1> [[TMP2]]
+;
+  %t1 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv2i1(<vscale x 2 x i1> %a)
+  %t2 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.and.z.nxv16i1(<vscale x 16 x i1> %t1, <vscale x 16 x i1> %b, <vscale x 16 x i1> %b)
+  %t3 = tail call <vscale x 2 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv2i1(<vscale x 16 x i1> %t2)
+  ret <vscale x 2 x i1> %t3
+}
+
+define <vscale x 8 x i1> @try_combine_svbool_binop_bic(<vscale x 8 x i1> %a, <vscale x 16 x i1> %b) {
+; CHECK-LABEL: @try_combine_svbool_binop_bic(
+; CHECK-NEXT:    [[TMP1:%.*]] = call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[B:%.*]])
+; CHECK-NEXT:    [[TMP2:%.*]] = call <vscale x 8 x i1> @llvm.aarch64.sve.bic.z.nxv8i1(<vscale x 8 x i1> [[A:%.*]], <vscale x 8 x i1> [[TMP1]], <vscale x 8 x i1> [[TMP1]])
+; CHECK-NEXT:    ret <vscale x 8 x i1> [[TMP2]]
+;
+  %t1 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1> %a)
+  %t2 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.bic.z.nxv16i1(<vscale x 16 x i1> %t1, <vscale x 16 x i1> %b, <vscale x 16 x i1> %b)
+  %t3 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %t2)
+  ret <vscale x 8 x i1> %t3
+}
+
+define <vscale x 8 x i1> @try_combine_svbool_binop_eor(<vscale x 8 x i1> %a, <vscale x 16 x i1> %b) {
+; CHECK-LABEL: @try_combine_svbool_binop_eor(
+; CHECK-NEXT:    [[TMP1:%.*]] = call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[B:%.*]])
+; CHECK-NEXT:    [[TMP2:%.*]] = call <vscale x 8 x i1> @llvm.aarch64.sve.eor.z.nxv8i1(<vscale x 8 x i1> [[A:%.*]], <vscale x 8 x i1> [[TMP1]], <vscale x 8 x i1> [[TMP1]])
+; CHECK-NEXT:    ret <vscale x 8 x i1> [[TMP2]]
+;
+  %t1 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1> %a)
+  %t2 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.eor.z.nxv16i1(<vscale x 16 x i1> %t1, <vscale x 16 x i1> %b, <vscale x 16 x i1> %b)
+  %t3 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %t2)
+  ret <vscale x 8 x i1> %t3
+}
+
+define <vscale x 8 x i1> @try_combine_svbool_binop_nand(<vscale x 8 x i1> %a, <vscale x 16 x i1> %b) {
+; CHECK-LABEL: @try_combine_svbool_binop_nand(
+; CHECK-NEXT:    [[TMP1:%.*]] = call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[B:%.*]])
+; CHECK-NEXT:    [[TMP2:%.*]] = call <vscale x 8 x i1> @llvm.aarch64.sve.nand.z.nxv8i1(<vscale x 8 x i1> [[A:%.*]], <vscale x 8 x i1> [[TMP1]], <vscale x 8 x i1> [[TMP1]])
+; CHECK-NEXT:    ret <vscale x 8 x i1> [[TMP2]]
+;
+  %t1 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1> %a)
+  %t2 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.nand.z.nxv16i1(<vscale x 16 x i1> %t1, <vscale x 16 x i1> %b, <vscale x 16 x i1> %b)
+  %t3 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %t2)
+  ret <vscale x 8 x i1> %t3
+}
+
+define <vscale x 8 x i1> @try_combine_svbool_binop_nor(<vscale x 8 x i1> %a, <vscale x 16 x i1> %b) {
+; CHECK-LABEL: @try_combine_svbool_binop_nor(
+; CHECK-NEXT:    [[TMP1:%.*]] = call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[B:%.*]])
+; CHECK-NEXT:    [[TMP2:%.*]] = call <vscale x 8 x i1> @llvm.aarch64.sve.nor.z.nxv8i1(<vscale x 8 x i1> [[A:%.*]], <vscale x 8 x i1> [[TMP1]], <vscale x 8 x i1> [[TMP1]])
+; CHECK-NEXT:    ret <vscale x 8 x i1> [[TMP2]]
+;
+  %t1 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1> %a)
+  %t2 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.nor.z.nxv16i1(<vscale x 16 x i1> %t1, <vscale x 16 x i1> %b, <vscale x 16 x i1> %b)
+  %t3 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %t2)
+  ret <vscale x 8 x i1> %t3
+}
+
+define <vscale x 8 x i1> @try_combine_svbool_binop_orn(<vscale x 8 x i1> %a, <vscale x 16 x i1> %b) {
+; CHECK-LABEL: @try_combine_svbool_binop_orn(
+; CHECK-NEXT:    [[TMP1:%.*]] = call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[B:%.*]])
+; CHECK-NEXT:    [[TMP2:%.*]] = call <vscale x 8 x i1> @llvm.aarch64.sve.orn.z.nxv8i1(<vscale x 8 x i1> [[A:%.*]], <vscale x 8 x i1> [[TMP1]], <vscale x 8 x i1> [[TMP1]])
+; CHECK-NEXT:    ret <vscale x 8 x i1> [[TMP2]]
+;
+  %t1 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1> %a)
+  %t2 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.orn.z.nxv16i1(<vscale x 16 x i1> %t1, <vscale x 16 x i1> %b, <vscale x 16 x i1> %b)
+  %t3 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %t2)
+  ret <vscale x 8 x i1> %t3
+}
+
+define <vscale x 8 x i1> @try_combine_svbool_binop_orr(<vscale x 8 x i1> %a, <vscale x 16 x i1> %b) {
+; CHECK-LABEL: @try_combine_svbool_binop_orr(
+; CHECK-NEXT:    [[TMP1:%.*]] = call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[B:%.*]])
+; CHECK-NEXT:    [[TMP2:%.*]] = call <vscale x 8 x i1> @llvm.aarch64.sve.orr.z.nxv8i1(<vscale x 8 x i1> [[A:%.*]], <vscale x 8 x i1> [[TMP1]], <vscale x 8 x i1> [[TMP1]])
+; CHECK-NEXT:    ret <vscale x 8 x i1> [[TMP2]]
+;
+  %t1 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1> %a)
+  %t2 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.orr.z.nxv16i1(<vscale x 16 x i1> %t1, <vscale x 16 x i1> %b, <vscale x 16 x i1> %b)
+  %t3 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %t2)
+  ret <vscale x 8 x i1> %t3
+}
+
+declare <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1>)
+declare <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1>)
+declare <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv2i1(<vscale x 2 x i1>)
+declare <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1>)
+declare <vscale x 4 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv4i1(<vscale x 16 x i1>)
+declare <vscale x 2 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv2i1(<vscale x 16 x i1>)
+declare <vscale x 16 x i1> @llvm.aarch64.sve.and.z.nxv16i1(<vscale x 16 x i1>, <vscale x 16 x i1>, <vscale x 16 x i1>)
+declare <vscale x 16 x i1> @llvm.aarch64.sve.bic.z.nxv16i1(<vscale x 16 x i1>, <vscale x 16 x i1>, <vscale x 16 x i1>)
+declare <vscale x 16 x i1> @llvm.aarch64.sve.eor.z.nxv16i1(<vscale x 16 x i1>, <vscale x 16 x i1>, <vscale x 16 x i1>)
+declare <vscale x 16 x i1> @llvm.aarch64.sve.nand.z.nxv16i1(<vscale x 16 x i1>, <vscale x 16 x i1>, <vscale x 16 x i1>)
+declare <vscale x 16 x i1> @llvm.aarch64.sve.nor.z.nxv16i1(<vscale x 16 x i1>, <vscale x 16 x i1>, <vscale x 16 x i1>)
+declare <vscale x 16 x i1> @llvm.aarch64.sve.orn.z.nxv16i1(<vscale x 16 x i1>, <vscale x 16 x i1>, <vscale x 16 x i1>)
+declare <vscale x 16 x i1> @llvm.aarch64.sve.orr.z.nxv16i1(<vscale x 16 x i1>, <vscale x 16 x i1>, <vscale x 16 x i1>)
+
+attributes #0 = { "target-features"="+sve" }
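
Note for reviewers: the widen/narrow round trip targeted by this combine falls out naturally from ACLE code that mixes width-specific predicates with svbool_t mask arithmetic. Below is a minimal, hypothetical C sketch; the function name combine_masks is ours, and the exact IR depends on the Clang version and optimization level, so treat it as illustrative rather than definitive.

#include <arm_sve.h>

/* pg is produced at 32-bit element width (<vscale x 4 x i1> in IR) and is
   widened to svbool_t via convert.to.svbool to predicate the mask AND; the
   AND result is then narrowed back via convert.from.svbool when it
   predicates the 32-bit add. The combine above removes that round trip. */
svint32_t combine_masks(int64_t i, int64_t n, svbool_t a, svbool_t b,
                        svint32_t x, svint32_t y) {
  svbool_t pg = svwhilelt_b32_s64(i, n); /* 32-bit-element predicate */
  svbool_t m = svand_b_z(pg, a, b);      /* and.z at <vscale x 16 x i1> */
  return svadd_s32_m(m, x, y);           /* narrows m to <vscale x 4 x i1> */
}

With the combine applied, the and.z is performed directly at <vscale x 4 x i1>, matching the CHECK lines of try_combine_svbool_binop_and_0 above.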