Skip to content

Commit

Permalink
[AArch64] Reassociate sub(x, add(m1, m2)) to sub(sub(x, m1), m2)
Browse files Browse the repository at this point in the history
The mid end will reassociate sub(sub(x, m1), m2) to sub(x, add(m1, m2)). This
reassociates it back to allow the creation of more mls instructions.

Differential Revision: https://reviews.llvm.org/D143143
  • Loading branch information
davemgreen committed Feb 10, 2023
1 parent d37a31c commit c52255d
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 53 deletions.
28 changes: 28 additions & 0 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Expand Up @@ -17702,6 +17702,32 @@ static SDValue performAddCombineForShiftedOperands(SDNode *N,
return SDValue();
}

// The mid end will reassociate sub(sub(x, m1), m2) to sub(x, add(m1, m2)).
// This reassociates it back to sub(sub(x, m1), m2), which lets instruction
// selection form a multiply-subtract (mls/msub/smlsl/umlsl) for each multiply.
static SDValue performSubAddMULCombine(SDNode *N, SelectionDAG &DAG) {
  if (N->getOpcode() != ISD::SUB)
    return SDValue();
  SDValue Add = N->getOperand(1);
  if (Add.getOpcode() != ISD::ADD)
    return SDValue();

  // Only reassociate if the ADD has no other users. Otherwise the ADD must
  // stay live anyway, and splitting this SUB into two would create extra
  // nodes instead of replacing the existing ones.
  if (!Add.hasOneUse())
    return SDValue();

  // If the LHS is a constant, leave it alone — other folds (e.g. constant
  // materialization patterns) handle sub(C, ...) better than two chained subs.
  SDValue X = N->getOperand(0);
  if (isa<ConstantSDNode>(X))
    return SDValue();

  // Both addends must be multiplies (plain MUL or the AArch64 widening
  // multiplies) so that each resulting sub(_, mul) can become a
  // multiply-subtract instruction.
  SDValue M1 = Add.getOperand(0);
  SDValue M2 = Add.getOperand(1);
  if (M1.getOpcode() != ISD::MUL && M1.getOpcode() != AArch64ISD::SMULL &&
      M1.getOpcode() != AArch64ISD::UMULL)
    return SDValue();
  if (M2.getOpcode() != ISD::MUL && M2.getOpcode() != AArch64ISD::SMULL &&
      M2.getOpcode() != AArch64ISD::UMULL)
    return SDValue();

  // Rebuild as sub(sub(x, m1), m2).
  EVT VT = N->getValueType(0);
  SDValue Sub = DAG.getNode(ISD::SUB, SDLoc(N), VT, X, M1);
  return DAG.getNode(ISD::SUB, SDLoc(N), VT, Sub, M2);
}

static SDValue performAddSubCombine(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI,
SelectionDAG &DAG) {
Expand All @@ -17718,6 +17744,8 @@ static SDValue performAddSubCombine(SDNode *N,
return Val;
if (SDValue Val = performAddCombineForShiftedOperands(N, DAG))
return Val;
if (SDValue Val = performSubAddMULCombine(N, DAG))
return Val;

return performAddSubLongCombine(N, DCI, DAG);
}
Expand Down
40 changes: 18 additions & 22 deletions llvm/test/CodeGen/AArch64/arm64-vmul.ll
Expand Up @@ -457,12 +457,11 @@ define <2 x i64> @smlsl2d(ptr %A, ptr %B, ptr %C) nounwind {
define void @smlsl8h_chain_with_constant(ptr %dst, <8 x i8> %v1, <8 x i8> %v2, <8 x i8> %v3) {
; CHECK-LABEL: smlsl8h_chain_with_constant:
; CHECK: // %bb.0:
; CHECK-NEXT: smull.8h v0, v0, v2
; CHECK-NEXT: mvn.8b v2, v2
; CHECK-NEXT: movi.16b v3, #1
; CHECK-NEXT: smlal.8h v0, v1, v2
; CHECK-NEXT: sub.8h v0, v3, v0
; CHECK-NEXT: str q0, [x0]
; CHECK-NEXT: smlsl.8h v3, v0, v2
; CHECK-NEXT: mvn.8b v0, v2
; CHECK-NEXT: smlsl.8h v3, v1, v0
; CHECK-NEXT: str q3, [x0]
; CHECK-NEXT: ret
%xor = xor <8 x i8> %v3, <i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1>
%smull.1 = tail call <8 x i16> @llvm.aarch64.neon.smull.v8i16(<8 x i8> %v1, <8 x i8> %v3)
Expand All @@ -476,13 +475,12 @@ define void @smlsl8h_chain_with_constant(ptr %dst, <8 x i8> %v1, <8 x i8> %v2, <
define void @smlsl2d_chain_with_constant(ptr %dst, <2 x i32> %v1, <2 x i32> %v2, <2 x i32> %v3) {
; CHECK-LABEL: smlsl2d_chain_with_constant:
; CHECK: // %bb.0:
; CHECK-NEXT: smull.2d v0, v0, v2
; CHECK-NEXT: mov w8, #257
; CHECK-NEXT: mvn.8b v2, v2
; CHECK-NEXT: smlal.2d v0, v1, v2
; CHECK-NEXT: dup.2d v1, x8
; CHECK-NEXT: sub.2d v0, v1, v0
; CHECK-NEXT: str q0, [x0]
; CHECK-NEXT: dup.2d v3, x8
; CHECK-NEXT: smlsl.2d v3, v0, v2
; CHECK-NEXT: mvn.8b v0, v2
; CHECK-NEXT: smlsl.2d v3, v1, v0
; CHECK-NEXT: str q3, [x0]
; CHECK-NEXT: ret
%xor = xor <2 x i32> %v3, <i32 -1, i32 -1>
%smull.1 = tail call <2 x i64> @llvm.aarch64.neon.smull.v2i64(<2 x i32> %v1, <2 x i32> %v3)
Expand Down Expand Up @@ -738,12 +736,11 @@ define <2 x i64> @umlsl2d(ptr %A, ptr %B, ptr %C) nounwind {
define void @umlsl8h_chain_with_constant(ptr %dst, <8 x i8> %v1, <8 x i8> %v2, <8 x i8> %v3) {
; CHECK-LABEL: umlsl8h_chain_with_constant:
; CHECK: // %bb.0:
; CHECK-NEXT: umull.8h v0, v0, v2
; CHECK-NEXT: mvn.8b v2, v2
; CHECK-NEXT: movi.16b v3, #1
; CHECK-NEXT: umlal.8h v0, v1, v2
; CHECK-NEXT: sub.8h v0, v3, v0
; CHECK-NEXT: str q0, [x0]
; CHECK-NEXT: umlsl.8h v3, v0, v2
; CHECK-NEXT: mvn.8b v0, v2
; CHECK-NEXT: umlsl.8h v3, v1, v0
; CHECK-NEXT: str q3, [x0]
; CHECK-NEXT: ret
%xor = xor <8 x i8> %v3, <i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1>
%umull.1 = tail call <8 x i16> @llvm.aarch64.neon.umull.v8i16(<8 x i8> %v1, <8 x i8> %v3)
Expand All @@ -757,13 +754,12 @@ define void @umlsl8h_chain_with_constant(ptr %dst, <8 x i8> %v1, <8 x i8> %v2, <
define void @umlsl2d_chain_with_constant(ptr %dst, <2 x i32> %v1, <2 x i32> %v2, <2 x i32> %v3) {
; CHECK-LABEL: umlsl2d_chain_with_constant:
; CHECK: // %bb.0:
; CHECK-NEXT: umull.2d v0, v0, v2
; CHECK-NEXT: mov w8, #257
; CHECK-NEXT: mvn.8b v2, v2
; CHECK-NEXT: umlal.2d v0, v1, v2
; CHECK-NEXT: dup.2d v1, x8
; CHECK-NEXT: sub.2d v0, v1, v0
; CHECK-NEXT: str q0, [x0]
; CHECK-NEXT: dup.2d v3, x8
; CHECK-NEXT: umlsl.2d v3, v0, v2
; CHECK-NEXT: mvn.8b v0, v2
; CHECK-NEXT: umlsl.2d v3, v1, v0
; CHECK-NEXT: str q3, [x0]
; CHECK-NEXT: ret
%xor = xor <2 x i32> %v3, <i32 -1, i32 -1>
%umull.1 = tail call <2 x i64> @llvm.aarch64.neon.umull.v2i64(<2 x i32> %v1, <2 x i32> %v3)
Expand Down
52 changes: 21 additions & 31 deletions llvm/test/CodeGen/AArch64/reassocmls.ll
Expand Up @@ -4,9 +4,8 @@
define i64 @smlsl_i64(i64 %a, i32 %b, i32 %c, i32 %d, i32 %e) {
; CHECK-LABEL: smlsl_i64:
; CHECK: // %bb.0:
; CHECK-NEXT: smull x8, w4, w3
; CHECK-NEXT: smaddl x8, w2, w1, x8
; CHECK-NEXT: sub x0, x0, x8
; CHECK-NEXT: smsubl x8, w4, w3, x0
; CHECK-NEXT: smsubl x0, w2, w1, x8
; CHECK-NEXT: ret
%be = sext i32 %b to i64
%ce = sext i32 %c to i64
Expand All @@ -22,9 +21,8 @@ define i64 @smlsl_i64(i64 %a, i32 %b, i32 %c, i32 %d, i32 %e) {
define i64 @umlsl_i64(i64 %a, i32 %b, i32 %c, i32 %d, i32 %e) {
; CHECK-LABEL: umlsl_i64:
; CHECK: // %bb.0:
; CHECK-NEXT: umull x8, w4, w3
; CHECK-NEXT: umaddl x8, w2, w1, x8
; CHECK-NEXT: sub x0, x0, x8
; CHECK-NEXT: umsubl x8, w4, w3, x0
; CHECK-NEXT: umsubl x0, w2, w1, x8
; CHECK-NEXT: ret
%be = zext i32 %b to i64
%ce = zext i32 %c to i64
Expand All @@ -40,9 +38,8 @@ define i64 @umlsl_i64(i64 %a, i32 %b, i32 %c, i32 %d, i32 %e) {
define i64 @mls_i64(i64 %a, i64 %b, i64 %c, i64 %d, i64 %e) {
; CHECK-LABEL: mls_i64:
; CHECK: // %bb.0:
; CHECK-NEXT: mul x8, x2, x1
; CHECK-NEXT: madd x8, x4, x3, x8
; CHECK-NEXT: sub x0, x0, x8
; CHECK-NEXT: msub x8, x4, x3, x0
; CHECK-NEXT: msub x0, x2, x1, x8
; CHECK-NEXT: ret
%m1.neg = mul i64 %c, %b
%m2.neg = mul i64 %e, %d
Expand All @@ -54,9 +51,8 @@ define i64 @mls_i64(i64 %a, i64 %b, i64 %c, i64 %d, i64 %e) {
define i16 @mls_i16(i16 %a, i16 %b, i16 %c, i16 %d, i16 %e) {
; CHECK-LABEL: mls_i16:
; CHECK: // %bb.0:
; CHECK-NEXT: mul w8, w2, w1
; CHECK-NEXT: madd w8, w4, w3, w8
; CHECK-NEXT: sub w0, w0, w8
; CHECK-NEXT: msub w8, w4, w3, w0
; CHECK-NEXT: msub w0, w2, w1, w8
; CHECK-NEXT: ret
%m1.neg = mul i16 %c, %b
%m2.neg = mul i16 %e, %d
Expand Down Expand Up @@ -97,9 +93,8 @@ define i64 @mls_i64_C(i64 %a, i64 %b, i64 %c, i64 %d, i64 %e) {
define <8 x i16> @smlsl_v8i16(<8 x i16> %a, <8 x i8> %b, <8 x i8> %c, <8 x i8> %d, <8 x i8> %e) {
; CHECK-LABEL: smlsl_v8i16:
; CHECK: // %bb.0:
; CHECK-NEXT: smull v3.8h, v4.8b, v3.8b
; CHECK-NEXT: smlal v3.8h, v2.8b, v1.8b
; CHECK-NEXT: sub v0.8h, v0.8h, v3.8h
; CHECK-NEXT: smlsl v0.8h, v4.8b, v3.8b
; CHECK-NEXT: smlsl v0.8h, v2.8b, v1.8b
; CHECK-NEXT: ret
%be = sext <8 x i8> %b to <8 x i16>
%ce = sext <8 x i8> %c to <8 x i16>
Expand All @@ -115,9 +110,8 @@ define <8 x i16> @smlsl_v8i16(<8 x i16> %a, <8 x i8> %b, <8 x i8> %c, <8 x i8> %
define <8 x i16> @umlsl_v8i16(<8 x i16> %a, <8 x i8> %b, <8 x i8> %c, <8 x i8> %d, <8 x i8> %e) {
; CHECK-LABEL: umlsl_v8i16:
; CHECK: // %bb.0:
; CHECK-NEXT: umull v3.8h, v4.8b, v3.8b
; CHECK-NEXT: umlal v3.8h, v2.8b, v1.8b
; CHECK-NEXT: sub v0.8h, v0.8h, v3.8h
; CHECK-NEXT: umlsl v0.8h, v4.8b, v3.8b
; CHECK-NEXT: umlsl v0.8h, v2.8b, v1.8b
; CHECK-NEXT: ret
%be = zext <8 x i8> %b to <8 x i16>
%ce = zext <8 x i8> %c to <8 x i16>
Expand All @@ -133,9 +127,8 @@ define <8 x i16> @umlsl_v8i16(<8 x i16> %a, <8 x i8> %b, <8 x i8> %c, <8 x i8> %
define <8 x i16> @mls_v8i16(<8 x i16> %a, <8 x i16> %b, <8 x i16> %c, <8 x i16> %d, <8 x i16> %e) {
; CHECK-LABEL: mls_v8i16:
; CHECK: // %bb.0:
; CHECK-NEXT: mul v1.8h, v2.8h, v1.8h
; CHECK-NEXT: mla v1.8h, v4.8h, v3.8h
; CHECK-NEXT: sub v0.8h, v0.8h, v1.8h
; CHECK-NEXT: mls v0.8h, v4.8h, v3.8h
; CHECK-NEXT: mls v0.8h, v2.8h, v1.8h
; CHECK-NEXT: ret
%m1.neg = mul <8 x i16> %c, %b
%m2.neg = mul <8 x i16> %e, %d
Expand Down Expand Up @@ -166,9 +159,8 @@ define <vscale x 8 x i16> @smlsl_nxv8i16(<vscale x 8 x i16> %a, <vscale x 8 x i8
; CHECK-NEXT: sxtb z4.h, p0/m, z4.h
; CHECK-NEXT: sxtb z1.h, p0/m, z1.h
; CHECK-NEXT: sxtb z2.h, p0/m, z2.h
; CHECK-NEXT: mul z3.h, z4.h, z3.h
; CHECK-NEXT: mla z3.h, p0/m, z2.h, z1.h
; CHECK-NEXT: sub z0.h, z0.h, z3.h
; CHECK-NEXT: mls z0.h, p0/m, z4.h, z3.h
; CHECK-NEXT: mls z0.h, p0/m, z2.h, z1.h
; CHECK-NEXT: ret
%be = sext <vscale x 8 x i8> %b to <vscale x 8 x i16>
%ce = sext <vscale x 8 x i8> %c to <vscale x 8 x i16>
Expand All @@ -184,14 +176,13 @@ define <vscale x 8 x i16> @smlsl_nxv8i16(<vscale x 8 x i16> %a, <vscale x 8 x i8
define <vscale x 8 x i16> @umlsl_nxv8i16(<vscale x 8 x i16> %a, <vscale x 8 x i8> %b, <vscale x 8 x i8> %c, <vscale x 8 x i8> %d, <vscale x 8 x i8> %e) {
; CHECK-LABEL: umlsl_nxv8i16:
; CHECK: // %bb.0:
; CHECK-NEXT: ptrue p0.h
; CHECK-NEXT: and z3.h, z3.h, #0xff
; CHECK-NEXT: and z4.h, z4.h, #0xff
; CHECK-NEXT: ptrue p0.h
; CHECK-NEXT: and z1.h, z1.h, #0xff
; CHECK-NEXT: and z2.h, z2.h, #0xff
; CHECK-NEXT: mul z3.h, z4.h, z3.h
; CHECK-NEXT: mla z3.h, p0/m, z2.h, z1.h
; CHECK-NEXT: sub z0.h, z0.h, z3.h
; CHECK-NEXT: mls z0.h, p0/m, z4.h, z3.h
; CHECK-NEXT: mls z0.h, p0/m, z2.h, z1.h
; CHECK-NEXT: ret
%be = zext <vscale x 8 x i8> %b to <vscale x 8 x i16>
%ce = zext <vscale x 8 x i8> %c to <vscale x 8 x i16>
Expand All @@ -208,9 +199,8 @@ define <vscale x 8 x i16> @mls_nxv8i16(<vscale x 8 x i16> %a, <vscale x 8 x i16>
; CHECK-LABEL: mls_nxv8i16:
; CHECK: // %bb.0:
; CHECK-NEXT: ptrue p0.h
; CHECK-NEXT: mul z3.h, z4.h, z3.h
; CHECK-NEXT: mla z3.h, p0/m, z2.h, z1.h
; CHECK-NEXT: sub z0.h, z0.h, z3.h
; CHECK-NEXT: mls z0.h, p0/m, z4.h, z3.h
; CHECK-NEXT: mls z0.h, p0/m, z2.h, z1.h
; CHECK-NEXT: ret
%m1.neg = mul <vscale x 8 x i16> %c, %b
%m2.neg = mul <vscale x 8 x i16> %e, %d
Expand Down

0 comments on commit c52255d

Please sign in to comment.