From c52255d26a23df6ecf09f60ca3e3615467f16bbe Mon Sep 17 00:00:00 2001 From: David Green Date: Fri, 10 Feb 2023 18:09:11 +0000 Subject: [PATCH] [AArch64] Reassociate sub(x, add(m1, m2)) to sub(sub(x, m1), m2) The mid end will reassociate sub(sub(x, m1), m2) to sub(x, add(m1, m2)). This reassociates it back to allow the creation of more mls instructions. Differential Revision: https://reviews.llvm.org/D143143 --- .../Target/AArch64/AArch64ISelLowering.cpp | 28 ++++++++++ llvm/test/CodeGen/AArch64/arm64-vmul.ll | 40 +++++++------- llvm/test/CodeGen/AArch64/reassocmls.ll | 52 ++++++++----------- 3 files changed, 67 insertions(+), 53 deletions(-) diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 6dd4c9b9044d5..1b51bf389f353 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -17702,6 +17702,32 @@ static SDValue performAddCombineForShiftedOperands(SDNode *N, return SDValue(); } +// The mid end will reassociate sub(sub(x, m1), m2) to sub(x, add(m1, m2)) +// This reassociates it back to allow the creation of more mls instructions. +static SDValue performSubAddMULCombine(SDNode *N, SelectionDAG &DAG) { + if (N->getOpcode() != ISD::SUB) + return SDValue(); + SDValue Add = N->getOperand(1); + if (Add.getOpcode() != ISD::ADD) + return SDValue(); + + SDValue X = N->getOperand(0); + if (isa(X)) + return SDValue(); + SDValue M1 = Add.getOperand(0); + SDValue M2 = Add.getOperand(1); + if (M1.getOpcode() != ISD::MUL && M1.getOpcode() != AArch64ISD::SMULL && + M1.getOpcode() != AArch64ISD::UMULL) + return SDValue(); + if (M2.getOpcode() != ISD::MUL && M2.getOpcode() != AArch64ISD::SMULL && + M2.getOpcode() != AArch64ISD::UMULL) + return SDValue(); + + EVT VT = N->getValueType(0); + SDValue Sub = DAG.getNode(ISD::SUB, SDLoc(N), VT, X, M1); + return DAG.getNode(ISD::SUB, SDLoc(N), VT, Sub, M2); +} + static SDValue performAddSubCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, SelectionDAG &DAG) { @@ -17718,6 +17744,8 @@ static SDValue performAddSubCombine(SDNode *N, return Val; if (SDValue Val = performAddCombineForShiftedOperands(N, DAG)) return Val; + if (SDValue Val = performSubAddMULCombine(N, DAG)) + return Val; return performAddSubLongCombine(N, DCI, DAG); } diff --git a/llvm/test/CodeGen/AArch64/arm64-vmul.ll b/llvm/test/CodeGen/AArch64/arm64-vmul.ll index 7f743f605f255..3a9f0319b06e0 100644 --- a/llvm/test/CodeGen/AArch64/arm64-vmul.ll +++ b/llvm/test/CodeGen/AArch64/arm64-vmul.ll @@ -457,12 +457,11 @@ define <2 x i64> @smlsl2d(ptr %A, ptr %B, ptr %C) nounwind { define void @smlsl8h_chain_with_constant(ptr %dst, <8 x i8> %v1, <8 x i8> %v2, <8 x i8> %v3) { ; CHECK-LABEL: smlsl8h_chain_with_constant: ; CHECK: // %bb.0: -; CHECK-NEXT: smull.8h v0, v0, v2 -; CHECK-NEXT: mvn.8b v2, v2 ; CHECK-NEXT: movi.16b v3, #1 -; CHECK-NEXT: smlal.8h v0, v1, v2 -; CHECK-NEXT: sub.8h v0, v3, v0 -; CHECK-NEXT: str q0, [x0] +; CHECK-NEXT: smlsl.8h v3, v0, v2 +; CHECK-NEXT: mvn.8b v0, v2 +; CHECK-NEXT: smlsl.8h v3, v1, v0 +; CHECK-NEXT: str q3, [x0] ; CHECK-NEXT: ret %xor = xor <8 x i8> %v3, %smull.1 = tail call <8 x i16> @llvm.aarch64.neon.smull.v8i16(<8 x i8> %v1, <8 x i8> %v3) @@ -476,13 +475,12 @@ define void @smlsl8h_chain_with_constant(ptr %dst, <8 x i8> %v1, <8 x i8> %v2, < define void @smlsl2d_chain_with_constant(ptr %dst, <2 x i32> %v1, <2 x i32> %v2, <2 x i32> %v3) { ; CHECK-LABEL: smlsl2d_chain_with_constant: ; CHECK: // %bb.0: -; CHECK-NEXT: smull.2d v0, v0, v2 ; CHECK-NEXT: mov w8, #257 -; CHECK-NEXT: mvn.8b v2, v2 -; CHECK-NEXT: smlal.2d v0, v1, v2 -; CHECK-NEXT: dup.2d v1, x8 -; CHECK-NEXT: sub.2d v0, v1, v0 -; CHECK-NEXT: str q0, [x0] +; CHECK-NEXT: dup.2d v3, x8 +; CHECK-NEXT: smlsl.2d v3, v0, v2 +; CHECK-NEXT: mvn.8b v0, v2 +; CHECK-NEXT: smlsl.2d v3, v1, v0 +; CHECK-NEXT: str q3, [x0] ; CHECK-NEXT: ret %xor = xor <2 x i32> %v3, %smull.1 = tail call <2 x i64> @llvm.aarch64.neon.smull.v2i64(<2 x i32> %v1, <2 x i32> %v3) @@ -738,12 +736,11 @@ define <2 x i64> @umlsl2d(ptr %A, ptr %B, ptr %C) nounwind { define void @umlsl8h_chain_with_constant(ptr %dst, <8 x i8> %v1, <8 x i8> %v2, <8 x i8> %v3) { ; CHECK-LABEL: umlsl8h_chain_with_constant: ; CHECK: // %bb.0: -; CHECK-NEXT: umull.8h v0, v0, v2 -; CHECK-NEXT: mvn.8b v2, v2 ; CHECK-NEXT: movi.16b v3, #1 -; CHECK-NEXT: umlal.8h v0, v1, v2 -; CHECK-NEXT: sub.8h v0, v3, v0 -; CHECK-NEXT: str q0, [x0] +; CHECK-NEXT: umlsl.8h v3, v0, v2 +; CHECK-NEXT: mvn.8b v0, v2 +; CHECK-NEXT: umlsl.8h v3, v1, v0 +; CHECK-NEXT: str q3, [x0] ; CHECK-NEXT: ret %xor = xor <8 x i8> %v3, %umull.1 = tail call <8 x i16> @llvm.aarch64.neon.umull.v8i16(<8 x i8> %v1, <8 x i8> %v3) @@ -757,13 +754,12 @@ define void @umlsl8h_chain_with_constant(ptr %dst, <8 x i8> %v1, <8 x i8> %v2, < define void @umlsl2d_chain_with_constant(ptr %dst, <2 x i32> %v1, <2 x i32> %v2, <2 x i32> %v3) { ; CHECK-LABEL: umlsl2d_chain_with_constant: ; CHECK: // %bb.0: -; CHECK-NEXT: umull.2d v0, v0, v2 ; CHECK-NEXT: mov w8, #257 -; CHECK-NEXT: mvn.8b v2, v2 -; CHECK-NEXT: umlal.2d v0, v1, v2 -; CHECK-NEXT: dup.2d v1, x8 -; CHECK-NEXT: sub.2d v0, v1, v0 -; CHECK-NEXT: str q0, [x0] +; CHECK-NEXT: dup.2d v3, x8 +; CHECK-NEXT: umlsl.2d v3, v0, v2 +; CHECK-NEXT: mvn.8b v0, v2 +; CHECK-NEXT: umlsl.2d v3, v1, v0 +; CHECK-NEXT: str q3, [x0] ; CHECK-NEXT: ret %xor = xor <2 x i32> %v3, %umull.1 = tail call <2 x i64> @llvm.aarch64.neon.umull.v2i64(<2 x i32> %v1, <2 x i32> %v3) diff --git a/llvm/test/CodeGen/AArch64/reassocmls.ll b/llvm/test/CodeGen/AArch64/reassocmls.ll index cf201caac4abd..3e0a67d610c12 100644 --- a/llvm/test/CodeGen/AArch64/reassocmls.ll +++ b/llvm/test/CodeGen/AArch64/reassocmls.ll @@ -4,9 +4,8 @@ define i64 @smlsl_i64(i64 %a, i32 %b, i32 %c, i32 %d, i32 %e) { ; CHECK-LABEL: smlsl_i64: ; CHECK: // %bb.0: -; CHECK-NEXT: smull x8, w4, w3 -; CHECK-NEXT: smaddl x8, w2, w1, x8 -; CHECK-NEXT: sub x0, x0, x8 +; CHECK-NEXT: smsubl x8, w4, w3, x0 +; CHECK-NEXT: smsubl x0, w2, w1, x8 ; CHECK-NEXT: ret %be = sext i32 %b to i64 %ce = sext i32 %c to i64 @@ -22,9 +21,8 @@ define i64 @smlsl_i64(i64 %a, i32 %b, i32 %c, i32 %d, i32 %e) { define i64 @umlsl_i64(i64 %a, i32 %b, i32 %c, i32 %d, i32 %e) { ; CHECK-LABEL: umlsl_i64: ; CHECK: // %bb.0: -; CHECK-NEXT: umull x8, w4, w3 -; CHECK-NEXT: umaddl x8, w2, w1, x8 -; CHECK-NEXT: sub x0, x0, x8 +; CHECK-NEXT: umsubl x8, w4, w3, x0 +; CHECK-NEXT: umsubl x0, w2, w1, x8 ; CHECK-NEXT: ret %be = zext i32 %b to i64 %ce = zext i32 %c to i64 @@ -40,9 +38,8 @@ define i64 @umlsl_i64(i64 %a, i32 %b, i32 %c, i32 %d, i32 %e) { define i64 @mls_i64(i64 %a, i64 %b, i64 %c, i64 %d, i64 %e) { ; CHECK-LABEL: mls_i64: ; CHECK: // %bb.0: -; CHECK-NEXT: mul x8, x2, x1 -; CHECK-NEXT: madd x8, x4, x3, x8 -; CHECK-NEXT: sub x0, x0, x8 +; CHECK-NEXT: msub x8, x4, x3, x0 +; CHECK-NEXT: msub x0, x2, x1, x8 ; CHECK-NEXT: ret %m1.neg = mul i64 %c, %b %m2.neg = mul i64 %e, %d @@ -54,9 +51,8 @@ define i64 @mls_i64(i64 %a, i64 %b, i64 %c, i64 %d, i64 %e) { define i16 @mls_i16(i16 %a, i16 %b, i16 %c, i16 %d, i16 %e) { ; CHECK-LABEL: mls_i16: ; CHECK: // %bb.0: -; CHECK-NEXT: mul w8, w2, w1 -; CHECK-NEXT: madd w8, w4, w3, w8 -; CHECK-NEXT: sub w0, w0, w8 +; CHECK-NEXT: msub w8, w4, w3, w0 +; CHECK-NEXT: msub w0, w2, w1, w8 ; CHECK-NEXT: ret %m1.neg = mul i16 %c, %b %m2.neg = mul i16 %e, %d @@ -97,9 +93,8 @@ define i64 @mls_i64_C(i64 %a, i64 %b, i64 %c, i64 %d, i64 %e) { define <8 x i16> @smlsl_v8i16(<8 x i16> %a, <8 x i8> %b, <8 x i8> %c, <8 x i8> %d, <8 x i8> %e) { ; CHECK-LABEL: smlsl_v8i16: ; CHECK: // %bb.0: -; CHECK-NEXT: smull v3.8h, v4.8b, v3.8b -; CHECK-NEXT: smlal v3.8h, v2.8b, v1.8b -; CHECK-NEXT: sub v0.8h, v0.8h, v3.8h +; CHECK-NEXT: smlsl v0.8h, v4.8b, v3.8b +; CHECK-NEXT: smlsl v0.8h, v2.8b, v1.8b ; CHECK-NEXT: ret %be = sext <8 x i8> %b to <8 x i16> %ce = sext <8 x i8> %c to <8 x i16> @@ -115,9 +110,8 @@ define <8 x i16> @smlsl_v8i16(<8 x i16> %a, <8 x i8> %b, <8 x i8> %c, <8 x i8> % define <8 x i16> @umlsl_v8i16(<8 x i16> %a, <8 x i8> %b, <8 x i8> %c, <8 x i8> %d, <8 x i8> %e) { ; CHECK-LABEL: umlsl_v8i16: ; CHECK: // %bb.0: -; CHECK-NEXT: umull v3.8h, v4.8b, v3.8b -; CHECK-NEXT: umlal v3.8h, v2.8b, v1.8b -; CHECK-NEXT: sub v0.8h, v0.8h, v3.8h +; CHECK-NEXT: umlsl v0.8h, v4.8b, v3.8b +; CHECK-NEXT: umlsl v0.8h, v2.8b, v1.8b ; CHECK-NEXT: ret %be = zext <8 x i8> %b to <8 x i16> %ce = zext <8 x i8> %c to <8 x i16> @@ -133,9 +127,8 @@ define <8 x i16> @umlsl_v8i16(<8 x i16> %a, <8 x i8> %b, <8 x i8> %c, <8 x i8> % define <8 x i16> @mls_v8i16(<8 x i16> %a, <8 x i16> %b, <8 x i16> %c, <8 x i16> %d, <8 x i16> %e) { ; CHECK-LABEL: mls_v8i16: ; CHECK: // %bb.0: -; CHECK-NEXT: mul v1.8h, v2.8h, v1.8h -; CHECK-NEXT: mla v1.8h, v4.8h, v3.8h -; CHECK-NEXT: sub v0.8h, v0.8h, v1.8h +; CHECK-NEXT: mls v0.8h, v4.8h, v3.8h +; CHECK-NEXT: mls v0.8h, v2.8h, v1.8h ; CHECK-NEXT: ret %m1.neg = mul <8 x i16> %c, %b %m2.neg = mul <8 x i16> %e, %d @@ -166,9 +159,8 @@ define @smlsl_nxv8i16( %a, %b to %ce = sext %c to @@ -184,14 +176,13 @@ define @smlsl_nxv8i16( %a, @umlsl_nxv8i16( %a, %b, %c, %d, %e) { ; CHECK-LABEL: umlsl_nxv8i16: ; CHECK: // %bb.0: +; CHECK-NEXT: ptrue p0.h ; CHECK-NEXT: and z3.h, z3.h, #0xff ; CHECK-NEXT: and z4.h, z4.h, #0xff -; CHECK-NEXT: ptrue p0.h ; CHECK-NEXT: and z1.h, z1.h, #0xff ; CHECK-NEXT: and z2.h, z2.h, #0xff -; CHECK-NEXT: mul z3.h, z4.h, z3.h -; CHECK-NEXT: mla z3.h, p0/m, z2.h, z1.h -; CHECK-NEXT: sub z0.h, z0.h, z3.h +; CHECK-NEXT: mls z0.h, p0/m, z4.h, z3.h +; CHECK-NEXT: mls z0.h, p0/m, z2.h, z1.h ; CHECK-NEXT: ret %be = zext %b to %ce = zext %c to @@ -208,9 +199,8 @@ define @mls_nxv8i16( %a, ; CHECK-LABEL: mls_nxv8i16: ; CHECK: // %bb.0: ; CHECK-NEXT: ptrue p0.h -; CHECK-NEXT: mul z3.h, z4.h, z3.h -; CHECK-NEXT: mla z3.h, p0/m, z2.h, z1.h -; CHECK-NEXT: sub z0.h, z0.h, z3.h +; CHECK-NEXT: mls z0.h, p0/m, z4.h, z3.h +; CHECK-NEXT: mls z0.h, p0/m, z2.h, z1.h ; CHECK-NEXT: ret %m1.neg = mul %c, %b %m2.neg = mul %e, %d