diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 6dd4c9b9044d5d..1b51bf389f3532 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -17702,6 +17702,32 @@ static SDValue performAddCombineForShiftedOperands(SDNode *N,
   return SDValue();
 }
 
+// The mid end will reassociate sub(sub(x, m1), m2) to sub(x, add(m1, m2))
+// This reassociates it back to allow the creation of more mls instructions.
+static SDValue performSubAddMULCombine(SDNode *N, SelectionDAG &DAG) {
+  if (N->getOpcode() != ISD::SUB)
+    return SDValue();
+  SDValue Add = N->getOperand(1);
+  if (Add.getOpcode() != ISD::ADD)
+    return SDValue();
+
+  SDValue X = N->getOperand(0);
+  if (isa<ConstantSDNode>(X))
+    return SDValue();
+  SDValue M1 = Add.getOperand(0);
+  SDValue M2 = Add.getOperand(1);
+  if (M1.getOpcode() != ISD::MUL && M1.getOpcode() != AArch64ISD::SMULL &&
+      M1.getOpcode() != AArch64ISD::UMULL)
+    return SDValue();
+  if (M2.getOpcode() != ISD::MUL && M2.getOpcode() != AArch64ISD::SMULL &&
+      M2.getOpcode() != AArch64ISD::UMULL)
+    return SDValue();
+
+  EVT VT = N->getValueType(0);
+  SDValue Sub = DAG.getNode(ISD::SUB, SDLoc(N), VT, X, M1);
+  return DAG.getNode(ISD::SUB, SDLoc(N), VT, Sub, M2);
+}
+
 static SDValue performAddSubCombine(SDNode *N,
                                     TargetLowering::DAGCombinerInfo &DCI,
                                     SelectionDAG &DAG) {
@@ -17718,6 +17744,8 @@ static SDValue performAddSubCombine(SDNode *N,
     return Val;
   if (SDValue Val = performAddCombineForShiftedOperands(N, DAG))
     return Val;
+  if (SDValue Val = performSubAddMULCombine(N, DAG))
+    return Val;
 
   return performAddSubLongCombine(N, DCI, DAG);
 }
diff --git a/llvm/test/CodeGen/AArch64/arm64-vmul.ll b/llvm/test/CodeGen/AArch64/arm64-vmul.ll
index 7f743f605f255d..3a9f0319b06e0f 100644
--- a/llvm/test/CodeGen/AArch64/arm64-vmul.ll
+++ b/llvm/test/CodeGen/AArch64/arm64-vmul.ll
@@ -457,12 +457,11 @@ define <2 x i64> @smlsl2d(ptr %A, ptr %B, ptr %C) nounwind {
 define void @smlsl8h_chain_with_constant(ptr %dst, <8 x i8> %v1, <8 x i8> %v2, <8 x i8> %v3) {
 ; CHECK-LABEL: smlsl8h_chain_with_constant:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    smull.8h v0, v0, v2
-; CHECK-NEXT:    mvn.8b v2, v2
 ; CHECK-NEXT:    movi.16b v3, #1
-; CHECK-NEXT:    smlal.8h v0, v1, v2
-; CHECK-NEXT:    sub.8h v0, v3, v0
-; CHECK-NEXT:    str q0, [x0]
+; CHECK-NEXT:    smlsl.8h v3, v0, v2
+; CHECK-NEXT:    mvn.8b v0, v2
+; CHECK-NEXT:    smlsl.8h v3, v1, v0
+; CHECK-NEXT:    str q3, [x0]
 ; CHECK-NEXT:    ret
   %xor = xor <8 x i8> %v3, <i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1>
   %smull.1 = tail call <8 x i16> @llvm.aarch64.neon.smull.v8i16(<8 x i8> %v1, <8 x i8> %v3)
@@ -476,13 +475,12 @@ define void @smlsl8h_chain_with_constant(ptr %dst, <8 x i8> %v1, <8 x i8> %v2, <
 define void @smlsl2d_chain_with_constant(ptr %dst, <2 x i32> %v1, <2 x i32> %v2, <2 x i32> %v3) {
 ; CHECK-LABEL: smlsl2d_chain_with_constant:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    smull.2d v0, v0, v2
 ; CHECK-NEXT:    mov w8, #257
-; CHECK-NEXT:    mvn.8b v2, v2
-; CHECK-NEXT:    smlal.2d v0, v1, v2
-; CHECK-NEXT:    dup.2d v1, x8
-; CHECK-NEXT:    sub.2d v0, v1, v0
-; CHECK-NEXT:    str q0, [x0]
+; CHECK-NEXT:    dup.2d v3, x8
+; CHECK-NEXT:    smlsl.2d v3, v0, v2
+; CHECK-NEXT:    mvn.8b v0, v2
+; CHECK-NEXT:    smlsl.2d v3, v1, v0
+; CHECK-NEXT:    str q3, [x0]
 ; CHECK-NEXT:    ret
   %xor = xor <2 x i32> %v3, <i32 -1, i32 -1>
   %smull.1 = tail call <2 x i64> @llvm.aarch64.neon.smull.v2i64(<2 x i32> %v1, <2 x i32> %v3)
@@ -738,12 +736,11 @@ define <2 x i64> @umlsl2d(ptr %A, ptr %B, ptr %C) nounwind {
 define void @umlsl8h_chain_with_constant(ptr %dst, <8 x i8> %v1, <8 x i8> %v2, <8 x i8> %v3) {
 ; CHECK-LABEL: umlsl8h_chain_with_constant:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    umull.8h v0, v0, v2
-; CHECK-NEXT:    mvn.8b v2, v2
 ; CHECK-NEXT:    movi.16b v3, #1
-; CHECK-NEXT:    umlal.8h v0, v1, v2
-; CHECK-NEXT:    sub.8h v0, v3, v0
-; CHECK-NEXT:    str q0, [x0]
+; CHECK-NEXT:    umlsl.8h v3, v0, v2
+; CHECK-NEXT:    mvn.8b v0, v2
+; CHECK-NEXT:    umlsl.8h v3, v1, v0
+; CHECK-NEXT:    str q3, [x0]
 ; CHECK-NEXT:    ret
   %xor = xor <8 x i8> %v3, <i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1>
   %umull.1 = tail call <8 x i16> @llvm.aarch64.neon.umull.v8i16(<8 x i8> %v1, <8 x i8> %v3)
@@ -757,13 +754,12 @@ define void @umlsl8h_chain_with_constant(ptr %dst, <8 x i8> %v1, <8 x i8> %v2, <
 define void @umlsl2d_chain_with_constant(ptr %dst, <2 x i32> %v1, <2 x i32> %v2, <2 x i32> %v3) {
 ; CHECK-LABEL: umlsl2d_chain_with_constant:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    umull.2d v0, v0, v2
 ; CHECK-NEXT:    mov w8, #257
-; CHECK-NEXT:    mvn.8b v2, v2
-; CHECK-NEXT:    umlal.2d v0, v1, v2
-; CHECK-NEXT:    dup.2d v1, x8
-; CHECK-NEXT:    sub.2d v0, v1, v0
-; CHECK-NEXT:    str q0, [x0]
+; CHECK-NEXT:    dup.2d v3, x8
+; CHECK-NEXT:    umlsl.2d v3, v0, v2
+; CHECK-NEXT:    mvn.8b v0, v2
+; CHECK-NEXT:    umlsl.2d v3, v1, v0
+; CHECK-NEXT:    str q3, [x0]
 ; CHECK-NEXT:    ret
   %xor = xor <2 x i32> %v3, <i32 -1, i32 -1>
   %umull.1 = tail call <2 x i64> @llvm.aarch64.neon.umull.v2i64(<2 x i32> %v1, <2 x i32> %v3)
diff --git a/llvm/test/CodeGen/AArch64/reassocmls.ll b/llvm/test/CodeGen/AArch64/reassocmls.ll
index cf201caac4abda..3e0a67d610c123 100644
--- a/llvm/test/CodeGen/AArch64/reassocmls.ll
+++ b/llvm/test/CodeGen/AArch64/reassocmls.ll
@@ -4,9 +4,8 @@
 define i64 @smlsl_i64(i64 %a, i32 %b, i32 %c, i32 %d, i32 %e) {
 ; CHECK-LABEL: smlsl_i64:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    smull x8, w4, w3
-; CHECK-NEXT:    smaddl x8, w2, w1, x8
-; CHECK-NEXT:    sub x0, x0, x8
+; CHECK-NEXT:    smsubl x8, w4, w3, x0
+; CHECK-NEXT:    smsubl x0, w2, w1, x8
 ; CHECK-NEXT:    ret
   %be = sext i32 %b to i64
   %ce = sext i32 %c to i64
@@ -22,9 +21,8 @@ define i64 @smlsl_i64(i64 %a, i32 %b, i32 %c, i32 %d, i32 %e) {
 define i64 @umlsl_i64(i64 %a, i32 %b, i32 %c, i32 %d, i32 %e) {
 ; CHECK-LABEL: umlsl_i64:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    umull x8, w4, w3
-; CHECK-NEXT:    umaddl x8, w2, w1, x8
-; CHECK-NEXT:    sub x0, x0, x8
+; CHECK-NEXT:    umsubl x8, w4, w3, x0
+; CHECK-NEXT:    umsubl x0, w2, w1, x8
 ; CHECK-NEXT:    ret
   %be = zext i32 %b to i64
   %ce = zext i32 %c to i64
@@ -40,9 +38,8 @@ define i64 @umlsl_i64(i64 %a, i32 %b, i32 %c, i32 %d, i32 %e) {
 define i64 @mls_i64(i64 %a, i64 %b, i64 %c, i64 %d, i64 %e) {
 ; CHECK-LABEL: mls_i64:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    mul x8, x2, x1
-; CHECK-NEXT:    madd x8, x4, x3, x8
-; CHECK-NEXT:    sub x0, x0, x8
+; CHECK-NEXT:    msub x8, x4, x3, x0
+; CHECK-NEXT:    msub x0, x2, x1, x8
 ; CHECK-NEXT:    ret
   %m1.neg = mul i64 %c, %b
   %m2.neg = mul i64 %e, %d
@@ -54,9 +51,8 @@ define i64 @mls_i64(i64 %a, i64 %b, i64 %c, i64 %d, i64 %e) {
 define i16 @mls_i16(i16 %a, i16 %b, i16 %c, i16 %d, i16 %e) {
 ; CHECK-LABEL: mls_i16:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    mul w8, w2, w1
-; CHECK-NEXT:    madd w8, w4, w3, w8
-; CHECK-NEXT:    sub w0, w0, w8
+; CHECK-NEXT:    msub w8, w4, w3, w0
+; CHECK-NEXT:    msub w0, w2, w1, w8
 ; CHECK-NEXT:    ret
   %m1.neg = mul i16 %c, %b
   %m2.neg = mul i16 %e, %d
@@ -97,9 +93,8 @@ define i64 @mls_i64_C(i64 %a, i64 %b, i64 %c, i64 %d, i64 %e) {
 define <8 x i16> @smlsl_v8i16(<8 x i16> %a, <8 x i8> %b, <8 x i8> %c, <8 x i8> %d, <8 x i8> %e) {
 ; CHECK-LABEL: smlsl_v8i16:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    smull v3.8h, v4.8b, v3.8b
-; CHECK-NEXT:    smlal v3.8h, v2.8b, v1.8b
-; CHECK-NEXT:    sub v0.8h, v0.8h, v3.8h
+; CHECK-NEXT:    smlsl v0.8h, v4.8b, v3.8b
+; CHECK-NEXT:    smlsl v0.8h, v2.8b, v1.8b
 ; CHECK-NEXT:    ret
   %be = sext <8 x i8> %b to <8 x i16>
   %ce = sext <8 x i8> %c to <8 x i16>
@@ -115,9 +110,8 @@ define <8 x i16> @smlsl_v8i16(<8 x i16> %a, <8 x i8> %b, <8 x i8> %c, <8 x i8> %
 define <8 x i16> @umlsl_v8i16(<8 x i16> %a, <8 x i8> %b, <8 x i8> %c, <8 x i8> %d, <8 x i8> %e) {
 ; CHECK-LABEL: umlsl_v8i16:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    umull v3.8h, v4.8b, v3.8b
-; CHECK-NEXT:    umlal v3.8h, v2.8b, v1.8b
-; CHECK-NEXT:    sub v0.8h, v0.8h, v3.8h
+; CHECK-NEXT:    umlsl v0.8h, v4.8b, v3.8b
+; CHECK-NEXT:    umlsl v0.8h, v2.8b, v1.8b
 ; CHECK-NEXT:    ret
   %be = zext <8 x i8> %b to <8 x i16>
   %ce = zext <8 x i8> %c to <8 x i16>
@@ -133,9 +127,8 @@ define <8 x i16> @umlsl_v8i16(<8 x i16> %a, <8 x i8> %b, <8 x i8> %c, <8 x i8> %
 define <8 x i16> @mls_v8i16(<8 x i16> %a, <8 x i16> %b, <8 x i16> %c, <8 x i16> %d, <8 x i16> %e) {
 ; CHECK-LABEL: mls_v8i16:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    mul v1.8h, v2.8h, v1.8h
-; CHECK-NEXT:    mla v1.8h, v4.8h, v3.8h
-; CHECK-NEXT:    sub v0.8h, v0.8h, v1.8h
+; CHECK-NEXT:    mls v0.8h, v4.8h, v3.8h
+; CHECK-NEXT:    mls v0.8h, v2.8h, v1.8h
 ; CHECK-NEXT:    ret
   %m1.neg = mul <8 x i16> %c, %b
   %m2.neg = mul <8 x i16> %e, %d
@@ -166,9 +159,8 @@ define <vscale x 8 x i16> @smlsl_nxv8i16(<vscale x 8 x i16> %a, <vscale x 8 x i8
 ; CHECK-NEXT:    sxtb z4.h, p0/m, z4.h
 ; CHECK-NEXT:    sxtb z1.h, p0/m, z1.h
 ; CHECK-NEXT:    sxtb z2.h, p0/m, z2.h
-; CHECK-NEXT:    mul z3.h, z4.h, z3.h
-; CHECK-NEXT:    mla z3.h, p0/m, z2.h, z1.h
-; CHECK-NEXT:    sub z0.h, z0.h, z3.h
+; CHECK-NEXT:    mls z0.h, p0/m, z4.h, z3.h
+; CHECK-NEXT:    mls z0.h, p0/m, z2.h, z1.h
 ; CHECK-NEXT:    ret
   %be = sext <vscale x 8 x i8> %b to <vscale x 8 x i16>
   %ce = sext <vscale x 8 x i8> %c to <vscale x 8 x i16>
@@ -184,14 +176,13 @@ define <vscale x 8 x i16> @smlsl_nxv8i16(<vscale x 8 x i16> %a, <vscale x 8 x i8
 define <vscale x 8 x i16> @umlsl_nxv8i16(<vscale x 8 x i16> %a, <vscale x 8 x i8> %b, <vscale x 8 x i8> %c, <vscale x 8 x i8> %d, <vscale x 8 x i8> %e) {
 ; CHECK-LABEL: umlsl_nxv8i16:
 ; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.h
 ; CHECK-NEXT:    and z3.h, z3.h, #0xff
 ; CHECK-NEXT:    and z4.h, z4.h, #0xff
-; CHECK-NEXT:    ptrue p0.h
 ; CHECK-NEXT:    and z1.h, z1.h, #0xff
 ; CHECK-NEXT:    and z2.h, z2.h, #0xff
-; CHECK-NEXT:    mul z3.h, z4.h, z3.h
-; CHECK-NEXT:    mla z3.h, p0/m, z2.h, z1.h
-; CHECK-NEXT:    sub z0.h, z0.h, z3.h
+; CHECK-NEXT:    mls z0.h, p0/m, z4.h, z3.h
+; CHECK-NEXT:    mls z0.h, p0/m, z2.h, z1.h
 ; CHECK-NEXT:    ret
   %be = zext <vscale x 8 x i8> %b to <vscale x 8 x i16>
   %ce = zext <vscale x 8 x i8> %c to <vscale x 8 x i16>
@@ -208,9 +199,8 @@ define <vscale x 8 x i16> @mls_nxv8i16(<vscale x 8 x i16> %a, <vscale x 8 x i16>
 ; CHECK-LABEL: mls_nxv8i16:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    ptrue p0.h
-; CHECK-NEXT:    mul z3.h, z4.h, z3.h
-; CHECK-NEXT:    mla z3.h, p0/m, z2.h, z1.h
-; CHECK-NEXT:    sub z0.h, z0.h, z3.h
+; CHECK-NEXT:    mls z0.h, p0/m, z4.h, z3.h
+; CHECK-NEXT:    mls z0.h, p0/m, z2.h, z1.h
 ; CHECK-NEXT:    ret
   %m1.neg = mul <vscale x 8 x i16> %c, %b
   %m2.neg = mul <vscale x 8 x i16> %e, %d
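
Note for reviewers: the combine fires on the reassociated form that the mid end hands to ISel. A minimal sketch of the scalar case, not part of the patch (the function name and the llc invocation are illustrative; it mirrors the existing mls_i64 test above):

; Assumed invocation: llc -mtriple=aarch64 < repro.ll
define i64 @repro(i64 %a, i64 %b, i64 %c, i64 %d, i64 %e) {
  %m1 = mul i64 %c, %b
  %m2 = mul i64 %e, %d
  ; mid-end reassociated form: a - (m1 + m2)
  %reass.add = add i64 %m2, %m1
  %s = sub i64 %a, %reass.add
  ; with this patch the DAG is rewritten back to sub(sub(a, m1), m2),
  ; which selects as two msub instructions instead of mul + madd + sub
  ret i64 %s
}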