Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 21 additions & 3 deletions llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12994,13 +12994,31 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
SDValue Op1 = N->getOperand(1);
SDValue Op2 = N->getOperand(2);

APInt C;
if (Op1->getOpcode() != ISD::MUL ||
!ISD::isConstantSplatVector(Op2.getNode(), C) || !C.isOne())
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we no longer need to allow constant splats in place of the mul or shl?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIUC, that just moved down the file a little.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah yes, I mis-read where the brackets were.

unsigned Opc = Op1->getOpcode();
if (Opc != ISD::MUL && Opc != ISD::SHL)
return SDValue();

SDValue LHS = Op1->getOperand(0);
SDValue RHS = Op1->getOperand(1);

// Try to treat (shl %a, %c) as (mul %a, (1 << %c)) for constant %c.
if (Opc == ISD::SHL) {
APInt C;
if (!ISD::isConstantSplatVector(RHS.getNode(), C))
return SDValue();

RHS =
DAG.getSplatVector(RHS.getValueType(), DL,
DAG.getConstant(APInt(C.getBitWidth(), 1).shl(C), DL,
RHS.getValueType().getScalarType()));
Opc = ISD::MUL;
}

APInt C;
if (Opc != ISD::MUL || !ISD::isConstantSplatVector(Op2.getNode(), C) ||
!C.isOne())
return SDValue();

unsigned LHSOpcode = LHS->getOpcode();
if (!ISD::isExtOpcode(LHSOpcode))
return SDValue();
Expand Down
86 changes: 67 additions & 19 deletions llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
Original file line number Diff line number Diff line change
Expand Up @@ -1257,21 +1257,55 @@ entry:
}

define <4 x i32> @partial_reduce_shl_sext_const_rhs6(<16 x i8> %l, <4 x i32> %part) {
; CHECK-COMMON-LABEL: partial_reduce_shl_sext_const_rhs6:
; CHECK-NODOT-LABEL: partial_reduce_shl_sext_const_rhs6:
; CHECK-NODOT: // %bb.0:
; CHECK-NODOT-NEXT: sshll v2.8h, v0.8b, #0
; CHECK-NODOT-NEXT: sshll2 v0.8h, v0.16b, #0
; CHECK-NODOT-NEXT: sshll v3.4s, v0.4h, #6
; CHECK-NODOT-NEXT: sshll2 v4.4s, v2.8h, #6
; CHECK-NODOT-NEXT: sshll v2.4s, v2.4h, #6
; CHECK-NODOT-NEXT: sshll2 v0.4s, v0.8h, #6
; CHECK-NODOT-NEXT: add v1.4s, v1.4s, v2.4s
; CHECK-NODOT-NEXT: add v2.4s, v4.4s, v3.4s
; CHECK-NODOT-NEXT: add v1.4s, v1.4s, v2.4s
; CHECK-NODOT-NEXT: add v0.4s, v1.4s, v0.4s
; CHECK-NODOT-NEXT: ret
;
; CHECK-DOT-LABEL: partial_reduce_shl_sext_const_rhs6:
; CHECK-DOT: // %bb.0:
; CHECK-DOT-NEXT: movi v2.16b, #64
; CHECK-DOT-NEXT: sdot v1.4s, v0.16b, v2.16b
; CHECK-DOT-NEXT: mov v0.16b, v1.16b
; CHECK-DOT-NEXT: ret
;
; CHECK-DOT-I8MM-LABEL: partial_reduce_shl_sext_const_rhs6:
; CHECK-DOT-I8MM: // %bb.0:
; CHECK-DOT-I8MM-NEXT: movi v2.16b, #64
; CHECK-DOT-I8MM-NEXT: sdot v1.4s, v0.16b, v2.16b
; CHECK-DOT-I8MM-NEXT: mov v0.16b, v1.16b
; CHECK-DOT-I8MM-NEXT: ret
%ext = sext <16 x i8> %l to <16 x i32>
%shift = shl nsw <16 x i32> %ext, splat (i32 6)
%red = tail call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %part, <16 x i32> %shift)
ret <4 x i32> %red
}

define <4 x i32> @partial_reduce_shl_sext_const_rhs7(<16 x i8> %l, <4 x i32> %part) {
; CHECK-COMMON-LABEL: partial_reduce_shl_sext_const_rhs7:
; CHECK-COMMON: // %bb.0:
; CHECK-COMMON-NEXT: sshll v2.8h, v0.8b, #0
; CHECK-COMMON-NEXT: sshll2 v0.8h, v0.16b, #0
; CHECK-COMMON-NEXT: sshll v3.4s, v0.4h, #6
; CHECK-COMMON-NEXT: sshll2 v4.4s, v2.8h, #6
; CHECK-COMMON-NEXT: sshll v2.4s, v2.4h, #6
; CHECK-COMMON-NEXT: sshll2 v0.4s, v0.8h, #6
; CHECK-COMMON-NEXT: sshll v3.4s, v0.4h, #7
; CHECK-COMMON-NEXT: sshll2 v4.4s, v2.8h, #7
; CHECK-COMMON-NEXT: sshll v2.4s, v2.4h, #7
; CHECK-COMMON-NEXT: sshll2 v0.4s, v0.8h, #7
; CHECK-COMMON-NEXT: add v1.4s, v1.4s, v2.4s
; CHECK-COMMON-NEXT: add v2.4s, v4.4s, v3.4s
; CHECK-COMMON-NEXT: add v1.4s, v1.4s, v2.4s
; CHECK-COMMON-NEXT: add v0.4s, v1.4s, v0.4s
; CHECK-COMMON-NEXT: ret
%ext = sext <16 x i8> %l to <16 x i32>
%shift = shl nsw <16 x i32> %ext, splat (i32 6)
%shift = shl nsw <16 x i32> %ext, splat (i32 7)
%red = tail call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %part, <16 x i32> %shift)
ret <4 x i32> %red
}
Expand Down Expand Up @@ -1331,19 +1365,33 @@ define <4 x i32> @partial_reduce_shl_sext_non_const_rhs(<16 x i8> %l, <4 x i32>
}

define <4 x i32> @partial_reduce_shl_zext_const_rhs6(<16 x i8> %l, <4 x i32> %part) {
; CHECK-COMMON-LABEL: partial_reduce_shl_zext_const_rhs6:
; CHECK-COMMON: // %bb.0:
; CHECK-COMMON-NEXT: ushll v2.8h, v0.8b, #0
; CHECK-COMMON-NEXT: ushll2 v0.8h, v0.16b, #0
; CHECK-COMMON-NEXT: ushll v3.4s, v0.4h, #6
; CHECK-COMMON-NEXT: ushll2 v4.4s, v2.8h, #6
; CHECK-COMMON-NEXT: ushll v2.4s, v2.4h, #6
; CHECK-COMMON-NEXT: ushll2 v0.4s, v0.8h, #6
; CHECK-COMMON-NEXT: add v1.4s, v1.4s, v2.4s
; CHECK-COMMON-NEXT: add v2.4s, v4.4s, v3.4s
; CHECK-COMMON-NEXT: add v1.4s, v1.4s, v2.4s
; CHECK-COMMON-NEXT: add v0.4s, v1.4s, v0.4s
; CHECK-COMMON-NEXT: ret
; CHECK-NODOT-LABEL: partial_reduce_shl_zext_const_rhs6:
; CHECK-NODOT: // %bb.0:
; CHECK-NODOT-NEXT: ushll v2.8h, v0.8b, #0
; CHECK-NODOT-NEXT: ushll2 v0.8h, v0.16b, #0
; CHECK-NODOT-NEXT: ushll v3.4s, v0.4h, #6
; CHECK-NODOT-NEXT: ushll2 v4.4s, v2.8h, #6
; CHECK-NODOT-NEXT: ushll v2.4s, v2.4h, #6
; CHECK-NODOT-NEXT: ushll2 v0.4s, v0.8h, #6
; CHECK-NODOT-NEXT: add v1.4s, v1.4s, v2.4s
; CHECK-NODOT-NEXT: add v2.4s, v4.4s, v3.4s
; CHECK-NODOT-NEXT: add v1.4s, v1.4s, v2.4s
; CHECK-NODOT-NEXT: add v0.4s, v1.4s, v0.4s
; CHECK-NODOT-NEXT: ret
;
; CHECK-DOT-LABEL: partial_reduce_shl_zext_const_rhs6:
; CHECK-DOT: // %bb.0:
; CHECK-DOT-NEXT: movi v2.16b, #64
; CHECK-DOT-NEXT: udot v1.4s, v0.16b, v2.16b
; CHECK-DOT-NEXT: mov v0.16b, v1.16b
; CHECK-DOT-NEXT: ret
;
; CHECK-DOT-I8MM-LABEL: partial_reduce_shl_zext_const_rhs6:
; CHECK-DOT-I8MM: // %bb.0:
; CHECK-DOT-I8MM-NEXT: movi v2.16b, #64
; CHECK-DOT-I8MM-NEXT: udot v1.4s, v0.16b, v2.16b
; CHECK-DOT-I8MM-NEXT: mov v0.16b, v1.16b
; CHECK-DOT-I8MM-NEXT: ret
%ext = zext <16 x i8> %l to <16 x i32>
%shift = shl nsw <16 x i32> %ext, splat (i32 6)
%red = tail call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %part, <16 x i32> %shift)
Expand Down