Fold SVE mul and mul_u to neg during isel #160828
Conversation
Replace mul and mul_u ops with a neg operation when their second operand is a splat of -1. Also apply the fold to mul_u ops whose first operand is a splat of -1, since mul_u is commutative.
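For illustration, a minimal before/after sketch of the fold on the nxv4i32 case, using the same intrinsics as the tests added below (the function name is just a placeholder):

; Input: a predicated multiply whose second operand is a splat of -1.
define <vscale x 4 x i32> @mul_by_minus_one(<vscale x 4 x i1> %pg, <vscale x 4 x i32> %a) {
  %splat = call <vscale x 4 x i32> @llvm.aarch64.sve.dup.x.nxv4i32(i32 -1)
  %res = call <vscale x 4 x i32> @llvm.aarch64.sve.mul.nxv4i32(<vscale x 4 x i1> %pg, <vscale x 4 x i32> %a, <vscale x 4 x i32> %splat)
  ret <vscale x 4 x i32> %res
}
; With this patch, isel selects a single predicated negate instead of a mov + mul pair:
;   neg z0.s, p0/m, z0.s
;   ret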
@llvm/pr-subscribers-backend-aarch64

Author: Martin Wehking (MartinWehking)

Changes

Replace mul and mul_u ops with a neg operation when their second operand is a splat of -1. Also apply the fold to mul_u ops whose first operand is a splat of -1, since mul_u is commutative.

Full diff: https://github.com/llvm/llvm-project/pull/160828.diff

3 Files Affected:
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index 1e30735b7a56a..27f4e8125f067 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -730,6 +730,21 @@ let Predicates = [HasSVE_or_SME] in {
defm ABS_ZPmZ : sve_int_un_pred_arit< 0b110, "abs", AArch64abs_mt>;
defm NEG_ZPmZ : sve_int_un_pred_arit< 0b111, "neg", AArch64neg_mt>;
+ // mul x (splat -1) -> neg x
+ let Predicates = [HasSVE_or_SME] in {
+ def : SVE_2_Op_Neg_One_Passthru_Pat<nxv16i8, AArch64mul_m1, nxv16i1, NEG_ZPmZ_B , i32>;
+ def : SVE_2_Op_Neg_One_Passthru_Pat<nxv8i16, AArch64mul_m1, nxv8i1, NEG_ZPmZ_H , i32>;
+ def : SVE_2_Op_Neg_One_Passthru_Pat<nxv4i32, AArch64mul_m1, nxv4i1, NEG_ZPmZ_S , i32>;
+ def : SVE_2_Op_Neg_One_Passthru_Pat<nxv2i64, AArch64mul_m1, nxv2i1, NEG_ZPmZ_D , i64>;
+
+ let AddedComplexity = 5 in {
+ defm : SVE_2_Op_Neg_One_Passthru_Pat_Comm<nxv16i8, AArch64mul_p, nxv16i1, NEG_ZPmZ_B , i32>;
+ defm : SVE_2_Op_Neg_One_Passthru_Pat_Comm<nxv8i16, AArch64mul_p, nxv8i1, NEG_ZPmZ_H , i32>;
+ defm : SVE_2_Op_Neg_One_Passthru_Pat_Comm<nxv4i32, AArch64mul_p, nxv4i1, NEG_ZPmZ_S , i32>;
+ defm : SVE_2_Op_Neg_One_Passthru_Pat_Comm<nxv2i64, AArch64mul_p, nxv2i1, NEG_ZPmZ_D , i64>;
+ }
+ }
+
defm CLS_ZPmZ : sve_int_un_pred_arit_bitwise< 0b000, "cls", AArch64cls_mt>;
defm CLZ_ZPmZ : sve_int_un_pred_arit_bitwise< 0b001, "clz", AArch64clz_mt>;
defm CNT_ZPmZ : sve_int_un_pred_arit_bitwise< 0b010, "cnt", AArch64cnt_mt>;
diff --git a/llvm/lib/Target/AArch64/SVEInstrFormats.td b/llvm/lib/Target/AArch64/SVEInstrFormats.td
index 9a23c35766cac..4fdc80c2d749b 100644
--- a/llvm/lib/Target/AArch64/SVEInstrFormats.td
+++ b/llvm/lib/Target/AArch64/SVEInstrFormats.td
@@ -723,6 +723,19 @@ class SVE2p1_Cvt_VG2_Pat<string name, SDPatternOperator intrinsic, ValueType out
: Pat<(out_vt (intrinsic in_vt:$Zn1, in_vt:$Zn2)),
(!cast<Instruction>(name) (REG_SEQUENCE ZPR2Mul2, in_vt:$Zn1, zsub0, in_vt:$Zn2, zsub1))>;
+class SVE_2_Op_Neg_One_Passthru_Pat<ValueType vt, SDPatternOperator op, ValueType pt,
+ Instruction inst, ValueType immT>
+: Pat<(vt (op pt:$Op1, vt:$Op2, (vt (splat_vector (immT -1))))),
+ (inst $Op2, $Op1, $Op2)>;
+
+// Same as above, but commutative
+multiclass SVE_2_Op_Neg_One_Passthru_Pat_Comm<ValueType vt, SDPatternOperator op, ValueType pt,
+ Instruction inst, ValueType immT> {
+def : Pat<(vt (op pt:$Op1, vt:$Op2, (vt (splat_vector (immT -1))))),
+ (inst $Op2, $Op1, $Op2)>;
+def : Pat<(vt (op pt:$Op1, (vt (splat_vector (immT -1))), vt:$Op2)),
+ (inst $Op2, $Op1, $Op2)>;
+}
//===----------------------------------------------------------------------===//
// SVE pattern match helpers.
//===----------------------------------------------------------------------===//
diff --git a/llvm/test/CodeGen/AArch64/sve-int-mul-neg.ll b/llvm/test/CodeGen/AArch64/sve-int-mul-neg.ll
new file mode 100644
index 0000000000000..8db9eac027506
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve-int-mul-neg.ll
@@ -0,0 +1,213 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mtriple=aarch64-linux-gnu -verify-machineinstrs -mattr=+sve < %s | FileCheck %s
+
+target triple = "aarch64-unknown-linux-gnu"
+
+; Muls with (-1) as operand should fold to neg.
+define <vscale x 16 x i8> @mul_neg_fold_i8(<vscale x 16 x i1> %pg, <vscale x 16 x i8> %a) {
+; CHECK-LABEL: mul_neg_fold_i8:
+; CHECK: // %bb.0:
+; CHECK-NEXT: neg z0.b, p0/m, z0.b
+; CHECK-NEXT: ret
+ %1 = call <vscale x 16 x i8> @llvm.aarch64.sve.dup.x.nxv16i8(i8 -1)
+ %2 = call <vscale x 16 x i8> @llvm.aarch64.sve.mul.nxv16i8(<vscale x 16 x i1> %pg, <vscale x 16 x i8> %a, <vscale x 16 x i8> %1)
+ ret <vscale x 16 x i8> %2
+}
+
+define <vscale x 8 x i16> @mul_neg_fold_i16(<vscale x 8 x i1> %pg, <vscale x 8 x i16> %a) {
+; CHECK-LABEL: mul_neg_fold_i16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: neg z0.h, p0/m, z0.h
+; CHECK-NEXT: ret
+ %1 = call <vscale x 8 x i16> @llvm.aarch64.sve.dup.x.nxv8i16(i16 -1)
+ %2 = call <vscale x 8 x i16> @llvm.aarch64.sve.mul.nxv8i16(<vscale x 8 x i1> %pg, <vscale x 8 x i16> %a, <vscale x 8 x i16> %1)
+ ret <vscale x 8 x i16> %2
+}
+
+define <vscale x 4 x i32> @mul_neg_fold_i32(<vscale x 4 x i1> %pg, <vscale x 4 x i32> %a) {
+; CHECK-LABEL: mul_neg_fold_i32:
+; CHECK: // %bb.0:
+; CHECK-NEXT: neg z0.s, p0/m, z0.s
+; CHECK-NEXT: ret
+ %1 = call <vscale x 4 x i32> @llvm.aarch64.sve.dup.x.nxv4i32(i32 -1)
+ %2 = call <vscale x 4 x i32> @llvm.aarch64.sve.mul.nxv4i32(<vscale x 4 x i1> %pg, <vscale x 4 x i32> %a, <vscale x 4 x i32> %1)
+ ret <vscale x 4 x i32> %2
+}
+
+define <vscale x 2 x i64> @mul_neg_fold_i64(<vscale x 2 x i1> %pg, <vscale x 2 x i64> %a) {
+; CHECK-LABEL: mul_neg_fold_i64:
+; CHECK: // %bb.0:
+; CHECK-NEXT: neg z0.d, p0/m, z0.d
+; CHECK-NEXT: ret
+ %1 = call <vscale x 2 x i64> @llvm.aarch64.sve.dup.x.nxv2i64(i64 -1)
+ %2 = call <vscale x 2 x i64> @llvm.aarch64.sve.mul.nxv2i64(<vscale x 2 x i1> %pg, <vscale x 2 x i64> %a, <vscale x 2 x i64> %1)
+ ret <vscale x 2 x i64> %2
+}
+
+define <vscale x 8 x i16> @mul_neg_fold_two_dups(<vscale x 8 x i1> %pg, <vscale x 8 x i16> %a) {
+ ; Edge case -- make sure that the case where we're multiplying two dups
+ ; together is sane.
+; CHECK-LABEL: mul_neg_fold_two_dups:
+; CHECK: // %bb.0:
+; CHECK-NEXT: mov z0.h, #-1 // =0xffffffffffffffff
+; CHECK-NEXT: neg z0.h, p0/m, z0.h
+; CHECK-NEXT: ret
+ %1 = call <vscale x 8 x i16> @llvm.aarch64.sve.dup.x.nxv8i16(i16 -1)
+ %2 = call <vscale x 8 x i16> @llvm.aarch64.sve.dup.x.nxv8i16(i16 -1)
+ %3 = call <vscale x 8 x i16> @llvm.aarch64.sve.mul.nxv8i16(<vscale x 8 x i1> %pg, <vscale x 8 x i16> %1, <vscale x 8 x i16> %2)
+ ret <vscale x 8 x i16> %3
+}
+
+define <vscale x 16 x i8> @mul_neg_fold_u_i8(<vscale x 16 x i1> %pg, <vscale x 16 x i8> %a) {
+; CHECK-LABEL: mul_neg_fold_u_i8:
+; CHECK: // %bb.0:
+; CHECK-NEXT: neg z0.b, p0/m, z0.b
+; CHECK-NEXT: ret
+ %1 = call <vscale x 16 x i8> @llvm.aarch64.sve.dup.x.nxv16i8(i8 -1)
+ %2 = call <vscale x 16 x i8> @llvm.aarch64.sve.mul.u.nxv16i8(<vscale x 16 x i1> %pg, <vscale x 16 x i8> %a, <vscale x 16 x i8> %1)
+ ret <vscale x 16 x i8> %2
+}
+
+define <vscale x 8 x i16> @mul_neg_fold_u_i16(<vscale x 8 x i1> %pg, <vscale x 8 x i16> %a) {
+; CHECK-LABEL: mul_neg_fold_u_i16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: neg z0.h, p0/m, z0.h
+; CHECK-NEXT: ret
+ %1 = call <vscale x 8 x i16> @llvm.aarch64.sve.dup.x.nxv8i16(i16 -1)
+ %2 = call <vscale x 8 x i16> @llvm.aarch64.sve.mul.u.nxv8i16(<vscale x 8 x i1> %pg, <vscale x 8 x i16> %a, <vscale x 8 x i16> %1)
+ ret <vscale x 8 x i16> %2
+}
+
+define <vscale x 4 x i32> @mul_neg_fold_u_i32(<vscale x 4 x i1> %pg, <vscale x 4 x i32> %a) {
+; CHECK-LABEL: mul_neg_fold_u_i32:
+; CHECK: // %bb.0:
+; CHECK-NEXT: neg z0.s, p0/m, z0.s
+; CHECK-NEXT: ret
+ %1 = call <vscale x 4 x i32> @llvm.aarch64.sve.dup.x.nxv4i32(i32 -1)
+ %2 = call <vscale x 4 x i32> @llvm.aarch64.sve.mul.u.nxv4i32(<vscale x 4 x i1> %pg, <vscale x 4 x i32> %a, <vscale x 4 x i32> %1)
+ ret <vscale x 4 x i32> %2
+}
+
+define <vscale x 2 x i64> @mul_neg_fold_u_i64(<vscale x 2 x i1> %pg, <vscale x 2 x i64> %a) {
+; CHECK-LABEL: mul_neg_fold_u_i64:
+; CHECK: // %bb.0:
+; CHECK-NEXT: neg z0.d, p0/m, z0.d
+; CHECK-NEXT: ret
+ %1 = call <vscale x 2 x i64> @llvm.aarch64.sve.dup.x.nxv2i64(i64 -1)
+ %2 = call <vscale x 2 x i64> @llvm.aarch64.sve.mul.u.nxv2i64(<vscale x 2 x i1> %pg, <vscale x 2 x i64> %a, <vscale x 2 x i64> %1)
+ ret <vscale x 2 x i64> %2
+}
+
+define <vscale x 8 x i16> @mul_neg_fold_u_two_dups(<vscale x 8 x i1> %pg, <vscale x 8 x i16> %a) {
+; CHECK-LABEL: mul_neg_fold_u_two_dups:
+; CHECK: // %bb.0:
+; CHECK-NEXT: mov z0.h, #-1 // =0xffffffffffffffff
+; CHECK-NEXT: neg z0.h, p0/m, z0.h
+; CHECK-NEXT: ret
+ %1 = call <vscale x 8 x i16> @llvm.aarch64.sve.dup.x.nxv8i16(i16 -1)
+ %2 = call <vscale x 8 x i16> @llvm.aarch64.sve.dup.x.nxv8i16(i16 -1)
+ %3 = call <vscale x 8 x i16> @llvm.aarch64.sve.mul.u.nxv8i16(<vscale x 8 x i1> %pg, <vscale x 8 x i16> %1, <vscale x 8 x i16> %2)
+ ret <vscale x 8 x i16> %3
+}
+
+; Undefined mul is commutative
+define <vscale x 2 x i64> @mul_neg_fold_u_different_argument_order(<vscale x 2 x i1> %pg, <vscale x 2 x i64> %a) {
+; CHECK-LABEL: mul_neg_fold_u_different_argument_order:
+; CHECK: // %bb.0:
+; CHECK-NEXT: neg z0.d, p0/m, z0.d
+; CHECK-NEXT: ret
+ %1 = call <vscale x 2 x i64> @llvm.aarch64.sve.dup.x.nxv2i64(i64 -1)
+ %2 = call <vscale x 2 x i64> @llvm.aarch64.sve.mul.u.nxv2i64(<vscale x 2 x i1> %pg, <vscale x 2 x i64> %1, <vscale x 2 x i64> %a)
+ ret <vscale x 2 x i64> %2
+}
+; Non foldable muls -- we don't expect these to be optimised out.
+define <vscale x 8 x i16> @no_mul_neg_fold_i16(<vscale x 8 x i1> %pg, <vscale x 8 x i16> %a) {
+; CHECK-LABEL: no_mul_neg_fold_i16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: mov z1.h, #-2 // =0xfffffffffffffffe
+; CHECK-NEXT: mul z0.h, p0/m, z0.h, z1.h
+; CHECK-NEXT: ret
+ %1 = call <vscale x 8 x i16> @llvm.aarch64.sve.dup.x.nxv8i16(i16 -2)
+ %2 = call <vscale x 8 x i16> @llvm.aarch64.sve.mul.nxv8i16(<vscale x 8 x i1> %pg, <vscale x 8 x i16> %a, <vscale x 8 x i16> %1)
+ ret <vscale x 8 x i16> %2
+}
+
+define <vscale x 4 x i32> @no_mul_neg_fold_i32(<vscale x 4 x i1> %pg, <vscale x 4 x i32> %a) {
+; CHECK-LABEL: no_mul_neg_fold_i32:
+; CHECK: // %bb.0:
+; CHECK-NEXT: mov z1.s, #-2 // =0xfffffffffffffffe
+; CHECK-NEXT: mul z0.s, p0/m, z0.s, z1.s
+; CHECK-NEXT: ret
+ %1 = call <vscale x 4 x i32> @llvm.aarch64.sve.dup.x.nxv4i32(i32 -2)
+ %2 = call <vscale x 4 x i32> @llvm.aarch64.sve.mul.nxv4i32(<vscale x 4 x i1> %pg, <vscale x 4 x i32> %a, <vscale x 4 x i32> %1)
+ ret <vscale x 4 x i32> %2
+}
+
+define <vscale x 2 x i64> @no_mul_neg_fold_i64(<vscale x 2 x i1> %pg, <vscale x 2 x i64> %a) {
+; CHECK-LABEL: no_mul_neg_fold_i64:
+; CHECK: // %bb.0:
+; CHECK-NEXT: mov z1.d, #-2 // =0xfffffffffffffffe
+; CHECK-NEXT: mul z0.d, p0/m, z0.d, z1.d
+; CHECK-NEXT: ret
+ %1 = call <vscale x 2 x i64> @llvm.aarch64.sve.dup.x.nxv2i64(i64 -2)
+ %2 = call <vscale x 2 x i64> @llvm.aarch64.sve.mul.nxv2i64(<vscale x 2 x i1> %pg, <vscale x 2 x i64> %a, <vscale x 2 x i64> %1)
+ ret <vscale x 2 x i64> %2
+}
+
+; Merge mul is non commutative
+define <vscale x 2 x i64> @no_mul_neg_fold_different_argument_order(<vscale x 2 x i1> %pg, <vscale x 2 x i64> %a) {
+; CHECK-LABEL: no_mul_neg_fold_different_argument_order:
+; CHECK: // %bb.0:
+; CHECK-NEXT: mov z1.d, #-1 // =0xffffffffffffffff
+; CHECK-NEXT: mul z1.d, p0/m, z1.d, z0.d
+; CHECK-NEXT: mov z0.d, z1.d
+; CHECK-NEXT: ret
+ %1 = call <vscale x 2 x i64> @llvm.aarch64.sve.dup.x.nxv2i64(i64 -1)
+ %2 = call <vscale x 2 x i64> @llvm.aarch64.sve.mul.nxv2i64(<vscale x 2 x i1> %pg, <vscale x 2 x i64> %1, <vscale x 2 x i64> %a)
+ ret <vscale x 2 x i64> %2
+}
+
+define <vscale x 8 x i16> @no_mul_neg_fold_u_i16(<vscale x 8 x i1> %pg, <vscale x 8 x i16> %a) {
+; CHECK-LABEL: no_mul_neg_fold_u_i16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: mul z0.h, z0.h, #-2
+; CHECK-NEXT: ret
+ %1 = call <vscale x 8 x i16> @llvm.aarch64.sve.dup.x.nxv8i16(i16 -2)
+ %2 = call <vscale x 8 x i16> @llvm.aarch64.sve.mul.u.nxv8i16(<vscale x 8 x i1> %pg, <vscale x 8 x i16> %a, <vscale x 8 x i16> %1)
+ ret <vscale x 8 x i16> %2
+}
+
+define <vscale x 4 x i32> @no_mul_neg_fold_u_i32(<vscale x 4 x i1> %pg, <vscale x 4 x i32> %a) {
+; CHECK-LABEL: no_mul_neg_fold_u_i32:
+; CHECK: // %bb.0:
+; CHECK-NEXT: mul z0.s, z0.s, #-2
+; CHECK-NEXT: ret
+ %1 = call <vscale x 4 x i32> @llvm.aarch64.sve.dup.x.nxv4i32(i32 -2)
+ %2 = call <vscale x 4 x i32> @llvm.aarch64.sve.mul.u.nxv4i32(<vscale x 4 x i1> %pg, <vscale x 4 x i32> %a, <vscale x 4 x i32> %1)
+ ret <vscale x 4 x i32> %2
+}
+
+define <vscale x 2 x i64> @no_mul_neg_fold_u_i64(<vscale x 2 x i1> %pg, <vscale x 2 x i64> %a) {
+; CHECK-LABEL: no_mul_neg_fold_u_i64:
+; CHECK: // %bb.0:
+; CHECK-NEXT: mul z0.d, z0.d, #-2
+; CHECK-NEXT: ret
+ %1 = call <vscale x 2 x i64> @llvm.aarch64.sve.dup.x.nxv2i64(i64 -2)
+ %2 = call <vscale x 2 x i64> @llvm.aarch64.sve.mul.u.nxv2i64(<vscale x 2 x i1> %pg, <vscale x 2 x i64> %a, <vscale x 2 x i64> %1)
+ ret <vscale x 2 x i64> %2
+}
+
+declare <vscale x 16 x i8> @llvm.aarch64.sve.dup.x.nxv16i8(i8)
+declare <vscale x 8 x i16> @llvm.aarch64.sve.dup.x.nxv8i16(i16)
+declare <vscale x 4 x i32> @llvm.aarch64.sve.dup.x.nxv4i32(i32)
+declare <vscale x 2 x i64> @llvm.aarch64.sve.dup.x.nxv2i64(i64)
+
+declare <vscale x 16 x i8> @llvm.aarch64.sve.mul.u.nxv16i8(<vscale x 16 x i1>, <vscale x 16 x i8>, <vscale x 16 x i8>)
+declare <vscale x 8 x i16> @llvm.aarch64.sve.mul.u.nxv8i16(<vscale x 8 x i1>, <vscale x 8 x i16>, <vscale x 8 x i16>)
+declare <vscale x 4 x i32> @llvm.aarch64.sve.mul.u.nxv4i32(<vscale x 4 x i1>, <vscale x 4 x i32>, <vscale x 4 x i32>)
+declare <vscale x 2 x i64> @llvm.aarch64.sve.mul.u.nxv2i64(<vscale x 2 x i1>, <vscale x 2 x i64>, <vscale x 2 x i64>)
+
+declare <vscale x 16 x i8> @llvm.aarch64.sve.mul.nxv16i8(<vscale x 16 x i1>, <vscale x 16 x i8>, <vscale x 16 x i8>)
+declare <vscale x 8 x i16> @llvm.aarch64.sve.mul.nxv8i16(<vscale x 8 x i1>, <vscale x 8 x i16>, <vscale x 8 x i16>)
+declare <vscale x 4 x i32> @llvm.aarch64.sve.mul.nxv4i32(<vscale x 4 x i1>, <vscale x 4 x i32>, <vscale x 4 x i32>)
+declare <vscale x 2 x i64> @llvm.aarch64.sve.mul.nxv2i64(<vscale x 2 x i1>, <vscale x 2 x i64>, <vscale x 2 x i64>)
Target the undef variants of mul, where the inactive lanes do not matter (mul.u intrinsics). Also create test cases for all data types when checking the commutativity of mul.u. Remove the target guards, since only optimization patterns are added.
Optimize mul for the merge variants where the splat value is the first operand. This optimization is equivalent to a neg that takes a (mov -1) and op2 as operands. Remove a redundant commutativity pattern for the undef variants of mul, since the splat value should only appear as the second operand after instcombine. Adapt the tests to cover the new patterns and remove redundant ones for non-optimization cases.
Do not use helper classes; create the patterns directly. Move the pattern definitions into a similar spot.
LGTM
A few suggestions but otherwise this looks good to me.
Replace some dups with splat_vectors and remove trailing whitespace. Remove a redundant argument from the RUN line.