-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[DAGCombine] Support (shl %x, constant) in foldPartialReduceMLAMulOp. #160663
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Support shifts in foldPartialReduceMLAMulOp by treating (shl %x, %c) as (mul %x, (shl 1, %c)).
@llvm/pr-subscribers-llvm-selectiondag Author: Florian Hahn (fhahn) Changes: Support shifts in foldPartialReduceMLAMulOp by treating (shl %x, %c) as (mul %x, (shl 1, %c)). Full diff: https://github.com/llvm/llvm-project/pull/160663.diff 2 Files Affected:
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index a6ba6e518899f..5794ce06a0fa3 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -12996,13 +12996,31 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
SDValue Op1 = N->getOperand(1);
SDValue Op2 = N->getOperand(2);
- APInt C;
- if (Op1->getOpcode() != ISD::MUL ||
- !ISD::isConstantSplatVector(Op2.getNode(), C) || !C.isOne())
+ unsigned Opc = Op1->getOpcode();
+ if (Opc != ISD::MUL && Opc != ISD::SHL)
return SDValue();
SDValue LHS = Op1->getOperand(0);
SDValue RHS = Op1->getOperand(1);
+
+ // Try to treat (shl %a, %c) as (mul %a, (1 << %c)) for constant %c.
+ if (Opc == ISD::SHL) {
+ APInt C;
+ if (!ISD::isConstantSplatVector(RHS.getNode(), C))
+ return SDValue();
+
+ RHS =
+ DAG.getSplatVector(RHS.getValueType(), DL,
+ DAG.getConstant(APInt(C.getBitWidth(), 1).shl(C), DL,
+ RHS.getValueType().getScalarType()));
+ Opc = ISD::MUL;
+ }
+
+ APInt C;
+ if (Opc != ISD::MUL || !ISD::isConstantSplatVector(Op2.getNode(), C) ||
+ !C.isOne())
+ return SDValue();
+
unsigned LHSOpcode = LHS->getOpcode();
if (!ISD::isExtOpcode(LHSOpcode))
return SDValue();
diff --git a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
index d60c870003e4d..428750740fc56 100644
--- a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
@@ -1257,21 +1257,55 @@ entry:
}
define <4 x i32> @partial_reduce_shl_sext_const_rhs6(<16 x i8> %l, <4 x i32> %part) {
-; CHECK-COMMON-LABEL: partial_reduce_shl_sext_const_rhs6:
+; CHECK-NODOT-LABEL: partial_reduce_shl_sext_const_rhs6:
+; CHECK-NODOT: // %bb.0:
+; CHECK-NODOT-NEXT: sshll v2.8h, v0.8b, #0
+; CHECK-NODOT-NEXT: sshll2 v0.8h, v0.16b, #0
+; CHECK-NODOT-NEXT: sshll v3.4s, v0.4h, #6
+; CHECK-NODOT-NEXT: sshll2 v4.4s, v2.8h, #6
+; CHECK-NODOT-NEXT: sshll v2.4s, v2.4h, #6
+; CHECK-NODOT-NEXT: sshll2 v0.4s, v0.8h, #6
+; CHECK-NODOT-NEXT: add v1.4s, v1.4s, v2.4s
+; CHECK-NODOT-NEXT: add v2.4s, v4.4s, v3.4s
+; CHECK-NODOT-NEXT: add v1.4s, v1.4s, v2.4s
+; CHECK-NODOT-NEXT: add v0.4s, v1.4s, v0.4s
+; CHECK-NODOT-NEXT: ret
+;
+; CHECK-DOT-LABEL: partial_reduce_shl_sext_const_rhs6:
+; CHECK-DOT: // %bb.0:
+; CHECK-DOT-NEXT: movi v2.16b, #64
+; CHECK-DOT-NEXT: sdot v1.4s, v0.16b, v2.16b
+; CHECK-DOT-NEXT: mov v0.16b, v1.16b
+; CHECK-DOT-NEXT: ret
+;
+; CHECK-DOT-I8MM-LABEL: partial_reduce_shl_sext_const_rhs6:
+; CHECK-DOT-I8MM: // %bb.0:
+; CHECK-DOT-I8MM-NEXT: movi v2.16b, #64
+; CHECK-DOT-I8MM-NEXT: sdot v1.4s, v0.16b, v2.16b
+; CHECK-DOT-I8MM-NEXT: mov v0.16b, v1.16b
+; CHECK-DOT-I8MM-NEXT: ret
+ %ext = sext <16 x i8> %l to <16 x i32>
+ %shift = shl nsw <16 x i32> %ext, splat (i32 6)
+ %red = tail call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %part, <16 x i32> %shift)
+ ret <4 x i32> %red
+}
+
+define <4 x i32> @partial_reduce_shl_sext_const_rhs7(<16 x i8> %l, <4 x i32> %part) {
+; CHECK-COMMON-LABEL: partial_reduce_shl_sext_const_rhs7:
; CHECK-COMMON: // %bb.0:
; CHECK-COMMON-NEXT: sshll v2.8h, v0.8b, #0
; CHECK-COMMON-NEXT: sshll2 v0.8h, v0.16b, #0
-; CHECK-COMMON-NEXT: sshll v3.4s, v0.4h, #6
-; CHECK-COMMON-NEXT: sshll2 v4.4s, v2.8h, #6
-; CHECK-COMMON-NEXT: sshll v2.4s, v2.4h, #6
-; CHECK-COMMON-NEXT: sshll2 v0.4s, v0.8h, #6
+; CHECK-COMMON-NEXT: sshll v3.4s, v0.4h, #7
+; CHECK-COMMON-NEXT: sshll2 v4.4s, v2.8h, #7
+; CHECK-COMMON-NEXT: sshll v2.4s, v2.4h, #7
+; CHECK-COMMON-NEXT: sshll2 v0.4s, v0.8h, #7
; CHECK-COMMON-NEXT: add v1.4s, v1.4s, v2.4s
; CHECK-COMMON-NEXT: add v2.4s, v4.4s, v3.4s
; CHECK-COMMON-NEXT: add v1.4s, v1.4s, v2.4s
; CHECK-COMMON-NEXT: add v0.4s, v1.4s, v0.4s
; CHECK-COMMON-NEXT: ret
%ext = sext <16 x i8> %l to <16 x i32>
- %shift = shl nsw <16 x i32> %ext, splat (i32 6)
+ %shift = shl nsw <16 x i32> %ext, splat (i32 7)
%red = tail call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %part, <16 x i32> %shift)
ret <4 x i32> %red
}
@@ -1331,19 +1365,33 @@ define <4 x i32> @partial_reduce_shl_sext_non_const_rhs(<16 x i8> %l, <4 x i32>
}
define <4 x i32> @partial_reduce_shl_zext_const_rhs6(<16 x i8> %l, <4 x i32> %part) {
-; CHECK-COMMON-LABEL: partial_reduce_shl_zext_const_rhs6:
-; CHECK-COMMON: // %bb.0:
-; CHECK-COMMON-NEXT: ushll v2.8h, v0.8b, #0
-; CHECK-COMMON-NEXT: ushll2 v0.8h, v0.16b, #0
-; CHECK-COMMON-NEXT: ushll v3.4s, v0.4h, #6
-; CHECK-COMMON-NEXT: ushll2 v4.4s, v2.8h, #6
-; CHECK-COMMON-NEXT: ushll v2.4s, v2.4h, #6
-; CHECK-COMMON-NEXT: ushll2 v0.4s, v0.8h, #6
-; CHECK-COMMON-NEXT: add v1.4s, v1.4s, v2.4s
-; CHECK-COMMON-NEXT: add v2.4s, v4.4s, v3.4s
-; CHECK-COMMON-NEXT: add v1.4s, v1.4s, v2.4s
-; CHECK-COMMON-NEXT: add v0.4s, v1.4s, v0.4s
-; CHECK-COMMON-NEXT: ret
+; CHECK-NODOT-LABEL: partial_reduce_shl_zext_const_rhs6:
+; CHECK-NODOT: // %bb.0:
+; CHECK-NODOT-NEXT: ushll v2.8h, v0.8b, #0
+; CHECK-NODOT-NEXT: ushll2 v0.8h, v0.16b, #0
+; CHECK-NODOT-NEXT: ushll v3.4s, v0.4h, #6
+; CHECK-NODOT-NEXT: ushll2 v4.4s, v2.8h, #6
+; CHECK-NODOT-NEXT: ushll v2.4s, v2.4h, #6
+; CHECK-NODOT-NEXT: ushll2 v0.4s, v0.8h, #6
+; CHECK-NODOT-NEXT: add v1.4s, v1.4s, v2.4s
+; CHECK-NODOT-NEXT: add v2.4s, v4.4s, v3.4s
+; CHECK-NODOT-NEXT: add v1.4s, v1.4s, v2.4s
+; CHECK-NODOT-NEXT: add v0.4s, v1.4s, v0.4s
+; CHECK-NODOT-NEXT: ret
+;
+; CHECK-DOT-LABEL: partial_reduce_shl_zext_const_rhs6:
+; CHECK-DOT: // %bb.0:
+; CHECK-DOT-NEXT: movi v2.16b, #64
+; CHECK-DOT-NEXT: udot v1.4s, v0.16b, v2.16b
+; CHECK-DOT-NEXT: mov v0.16b, v1.16b
+; CHECK-DOT-NEXT: ret
+;
+; CHECK-DOT-I8MM-LABEL: partial_reduce_shl_zext_const_rhs6:
+; CHECK-DOT-I8MM: // %bb.0:
+; CHECK-DOT-I8MM-NEXT: movi v2.16b, #64
+; CHECK-DOT-I8MM-NEXT: udot v1.4s, v0.16b, v2.16b
+; CHECK-DOT-I8MM-NEXT: mov v0.16b, v1.16b
+; CHECK-DOT-I8MM-NEXT: ret
%ext = zext <16 x i8> %l to <16 x i32>
%shift = shl nsw <16 x i32> %ext, splat (i32 6)
%red = tail call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %part, <16 x i32> %shift)
|
@llvm/pr-subscribers-backend-aarch64 Author: Florian Hahn (fhahn) Changes: Support shifts in foldPartialReduceMLAMulOp by treating (shl %x, %c) as (mul %x, (shl 1, %c)). Full diff: https://github.com/llvm/llvm-project/pull/160663.diff 2 Files Affected:
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index a6ba6e518899f..5794ce06a0fa3 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -12996,13 +12996,31 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
SDValue Op1 = N->getOperand(1);
SDValue Op2 = N->getOperand(2);
- APInt C;
- if (Op1->getOpcode() != ISD::MUL ||
- !ISD::isConstantSplatVector(Op2.getNode(), C) || !C.isOne())
+ unsigned Opc = Op1->getOpcode();
+ if (Opc != ISD::MUL && Opc != ISD::SHL)
return SDValue();
SDValue LHS = Op1->getOperand(0);
SDValue RHS = Op1->getOperand(1);
+
+ // Try to treat (shl %a, %c) as (mul %a, (1 << %c)) for constant %c.
+ if (Opc == ISD::SHL) {
+ APInt C;
+ if (!ISD::isConstantSplatVector(RHS.getNode(), C))
+ return SDValue();
+
+ RHS =
+ DAG.getSplatVector(RHS.getValueType(), DL,
+ DAG.getConstant(APInt(C.getBitWidth(), 1).shl(C), DL,
+ RHS.getValueType().getScalarType()));
+ Opc = ISD::MUL;
+ }
+
+ APInt C;
+ if (Opc != ISD::MUL || !ISD::isConstantSplatVector(Op2.getNode(), C) ||
+ !C.isOne())
+ return SDValue();
+
unsigned LHSOpcode = LHS->getOpcode();
if (!ISD::isExtOpcode(LHSOpcode))
return SDValue();
diff --git a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
index d60c870003e4d..428750740fc56 100644
--- a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
@@ -1257,21 +1257,55 @@ entry:
}
define <4 x i32> @partial_reduce_shl_sext_const_rhs6(<16 x i8> %l, <4 x i32> %part) {
-; CHECK-COMMON-LABEL: partial_reduce_shl_sext_const_rhs6:
+; CHECK-NODOT-LABEL: partial_reduce_shl_sext_const_rhs6:
+; CHECK-NODOT: // %bb.0:
+; CHECK-NODOT-NEXT: sshll v2.8h, v0.8b, #0
+; CHECK-NODOT-NEXT: sshll2 v0.8h, v0.16b, #0
+; CHECK-NODOT-NEXT: sshll v3.4s, v0.4h, #6
+; CHECK-NODOT-NEXT: sshll2 v4.4s, v2.8h, #6
+; CHECK-NODOT-NEXT: sshll v2.4s, v2.4h, #6
+; CHECK-NODOT-NEXT: sshll2 v0.4s, v0.8h, #6
+; CHECK-NODOT-NEXT: add v1.4s, v1.4s, v2.4s
+; CHECK-NODOT-NEXT: add v2.4s, v4.4s, v3.4s
+; CHECK-NODOT-NEXT: add v1.4s, v1.4s, v2.4s
+; CHECK-NODOT-NEXT: add v0.4s, v1.4s, v0.4s
+; CHECK-NODOT-NEXT: ret
+;
+; CHECK-DOT-LABEL: partial_reduce_shl_sext_const_rhs6:
+; CHECK-DOT: // %bb.0:
+; CHECK-DOT-NEXT: movi v2.16b, #64
+; CHECK-DOT-NEXT: sdot v1.4s, v0.16b, v2.16b
+; CHECK-DOT-NEXT: mov v0.16b, v1.16b
+; CHECK-DOT-NEXT: ret
+;
+; CHECK-DOT-I8MM-LABEL: partial_reduce_shl_sext_const_rhs6:
+; CHECK-DOT-I8MM: // %bb.0:
+; CHECK-DOT-I8MM-NEXT: movi v2.16b, #64
+; CHECK-DOT-I8MM-NEXT: sdot v1.4s, v0.16b, v2.16b
+; CHECK-DOT-I8MM-NEXT: mov v0.16b, v1.16b
+; CHECK-DOT-I8MM-NEXT: ret
+ %ext = sext <16 x i8> %l to <16 x i32>
+ %shift = shl nsw <16 x i32> %ext, splat (i32 6)
+ %red = tail call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %part, <16 x i32> %shift)
+ ret <4 x i32> %red
+}
+
+define <4 x i32> @partial_reduce_shl_sext_const_rhs7(<16 x i8> %l, <4 x i32> %part) {
+; CHECK-COMMON-LABEL: partial_reduce_shl_sext_const_rhs7:
; CHECK-COMMON: // %bb.0:
; CHECK-COMMON-NEXT: sshll v2.8h, v0.8b, #0
; CHECK-COMMON-NEXT: sshll2 v0.8h, v0.16b, #0
-; CHECK-COMMON-NEXT: sshll v3.4s, v0.4h, #6
-; CHECK-COMMON-NEXT: sshll2 v4.4s, v2.8h, #6
-; CHECK-COMMON-NEXT: sshll v2.4s, v2.4h, #6
-; CHECK-COMMON-NEXT: sshll2 v0.4s, v0.8h, #6
+; CHECK-COMMON-NEXT: sshll v3.4s, v0.4h, #7
+; CHECK-COMMON-NEXT: sshll2 v4.4s, v2.8h, #7
+; CHECK-COMMON-NEXT: sshll v2.4s, v2.4h, #7
+; CHECK-COMMON-NEXT: sshll2 v0.4s, v0.8h, #7
; CHECK-COMMON-NEXT: add v1.4s, v1.4s, v2.4s
; CHECK-COMMON-NEXT: add v2.4s, v4.4s, v3.4s
; CHECK-COMMON-NEXT: add v1.4s, v1.4s, v2.4s
; CHECK-COMMON-NEXT: add v0.4s, v1.4s, v0.4s
; CHECK-COMMON-NEXT: ret
%ext = sext <16 x i8> %l to <16 x i32>
- %shift = shl nsw <16 x i32> %ext, splat (i32 6)
+ %shift = shl nsw <16 x i32> %ext, splat (i32 7)
%red = tail call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %part, <16 x i32> %shift)
ret <4 x i32> %red
}
@@ -1331,19 +1365,33 @@ define <4 x i32> @partial_reduce_shl_sext_non_const_rhs(<16 x i8> %l, <4 x i32>
}
define <4 x i32> @partial_reduce_shl_zext_const_rhs6(<16 x i8> %l, <4 x i32> %part) {
-; CHECK-COMMON-LABEL: partial_reduce_shl_zext_const_rhs6:
-; CHECK-COMMON: // %bb.0:
-; CHECK-COMMON-NEXT: ushll v2.8h, v0.8b, #0
-; CHECK-COMMON-NEXT: ushll2 v0.8h, v0.16b, #0
-; CHECK-COMMON-NEXT: ushll v3.4s, v0.4h, #6
-; CHECK-COMMON-NEXT: ushll2 v4.4s, v2.8h, #6
-; CHECK-COMMON-NEXT: ushll v2.4s, v2.4h, #6
-; CHECK-COMMON-NEXT: ushll2 v0.4s, v0.8h, #6
-; CHECK-COMMON-NEXT: add v1.4s, v1.4s, v2.4s
-; CHECK-COMMON-NEXT: add v2.4s, v4.4s, v3.4s
-; CHECK-COMMON-NEXT: add v1.4s, v1.4s, v2.4s
-; CHECK-COMMON-NEXT: add v0.4s, v1.4s, v0.4s
-; CHECK-COMMON-NEXT: ret
+; CHECK-NODOT-LABEL: partial_reduce_shl_zext_const_rhs6:
+; CHECK-NODOT: // %bb.0:
+; CHECK-NODOT-NEXT: ushll v2.8h, v0.8b, #0
+; CHECK-NODOT-NEXT: ushll2 v0.8h, v0.16b, #0
+; CHECK-NODOT-NEXT: ushll v3.4s, v0.4h, #6
+; CHECK-NODOT-NEXT: ushll2 v4.4s, v2.8h, #6
+; CHECK-NODOT-NEXT: ushll v2.4s, v2.4h, #6
+; CHECK-NODOT-NEXT: ushll2 v0.4s, v0.8h, #6
+; CHECK-NODOT-NEXT: add v1.4s, v1.4s, v2.4s
+; CHECK-NODOT-NEXT: add v2.4s, v4.4s, v3.4s
+; CHECK-NODOT-NEXT: add v1.4s, v1.4s, v2.4s
+; CHECK-NODOT-NEXT: add v0.4s, v1.4s, v0.4s
+; CHECK-NODOT-NEXT: ret
+;
+; CHECK-DOT-LABEL: partial_reduce_shl_zext_const_rhs6:
+; CHECK-DOT: // %bb.0:
+; CHECK-DOT-NEXT: movi v2.16b, #64
+; CHECK-DOT-NEXT: udot v1.4s, v0.16b, v2.16b
+; CHECK-DOT-NEXT: mov v0.16b, v1.16b
+; CHECK-DOT-NEXT: ret
+;
+; CHECK-DOT-I8MM-LABEL: partial_reduce_shl_zext_const_rhs6:
+; CHECK-DOT-I8MM: // %bb.0:
+; CHECK-DOT-I8MM-NEXT: movi v2.16b, #64
+; CHECK-DOT-I8MM-NEXT: udot v1.4s, v0.16b, v2.16b
+; CHECK-DOT-I8MM-NEXT: mov v0.16b, v1.16b
+; CHECK-DOT-I8MM-NEXT: ret
%ext = zext <16 x i8> %l to <16 x i32>
%shift = shl nsw <16 x i32> %ext, splat (i32 6)
%red = tail call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %part, <16 x i32> %shift)
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are there any shifts in vplan that we can allow in place of a mul? If so, that would of course be a follow-up PR.
|
||
APInt C; | ||
if (Op1->getOpcode() != ISD::MUL || | ||
!ISD::isConstantSplatVector(Op2.getNode(), C) || !C.isOne()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do we no longer need to allow constant splats in place of the mul or shl?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IIUC, that just moved down the file a little.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah yes, I mis-read where the brackets were.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
|
||
APInt C; | ||
if (Op1->getOpcode() != ISD::MUL || | ||
!ISD::isConstantSplatVector(Op2.getNode(), C) || !C.isOne()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IIUC, that just moved down the file a little.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are there any shifts in vplan that we can allow in place of a mul? If so, that would of course be a follow-up PR.
Yep, but first we need to support constant operands at all: #161092
[DAGCombine] Support (shl %x, constant) in foldPartialReduceMLAMulOp. (#160663) Support shifts in foldPartialReduceMLAMulOp by treating (shl %x, %c) as (mul %x, (shl 1, %c)). PR: llvm/llvm-project#160663
[DAGCombine] Support (shl %x, constant) in foldPartialReduceMLAMulOp. (llvm#160663) Support shifts in foldPartialReduceMLAMulOp by treating (shl %x, %c) as (mul %x, (shl 1, %c)). PR: llvm#160663
Support shifts in foldPartialReduceMLAMulOp by treating (shl %x, %c) as (mul %x, (shl 1, %c)).