diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index 204e1f0c75e00..558c5a0390228 100644 --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -12994,13 +12994,31 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) { SDValue Op1 = N->getOperand(1); SDValue Op2 = N->getOperand(2); - APInt C; - if (Op1->getOpcode() != ISD::MUL || - !ISD::isConstantSplatVector(Op2.getNode(), C) || !C.isOne()) + unsigned Opc = Op1->getOpcode(); + if (Opc != ISD::MUL && Opc != ISD::SHL) return SDValue(); SDValue LHS = Op1->getOperand(0); SDValue RHS = Op1->getOperand(1); + + // Try to treat (shl %a, %c) as (mul %a, (1 << %c)) for constant %c. + if (Opc == ISD::SHL) { + APInt C; + if (!ISD::isConstantSplatVector(RHS.getNode(), C)) + return SDValue(); + + RHS = + DAG.getSplatVector(RHS.getValueType(), DL, + DAG.getConstant(APInt(C.getBitWidth(), 1).shl(C), DL, + RHS.getValueType().getScalarType())); + Opc = ISD::MUL; + } + + APInt C; + if (Opc != ISD::MUL || !ISD::isConstantSplatVector(Op2.getNode(), C) || + !C.isOne()) + return SDValue(); + unsigned LHSOpcode = LHS->getOpcode(); if (!ISD::isExtOpcode(LHSOpcode)) return SDValue(); diff --git a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll index d60c870003e4d..428750740fc56 100644 --- a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll +++ b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll @@ -1257,21 +1257,55 @@ entry: } define <4 x i32> @partial_reduce_shl_sext_const_rhs6(<16 x i8> %l, <4 x i32> %part) { -; CHECK-COMMON-LABEL: partial_reduce_shl_sext_const_rhs6: +; CHECK-NODOT-LABEL: partial_reduce_shl_sext_const_rhs6: +; CHECK-NODOT: // %bb.0: +; CHECK-NODOT-NEXT: sshll v2.8h, v0.8b, #0 +; CHECK-NODOT-NEXT: sshll2 v0.8h, v0.16b, #0 +; CHECK-NODOT-NEXT: sshll v3.4s, v0.4h, #6 +; CHECK-NODOT-NEXT: sshll2 v4.4s, v2.8h, #6 +; CHECK-NODOT-NEXT: sshll v2.4s, v2.4h, #6 +; CHECK-NODOT-NEXT: sshll2 v0.4s, v0.8h, #6 +; CHECK-NODOT-NEXT: add v1.4s, v1.4s, v2.4s +; CHECK-NODOT-NEXT: add v2.4s, v4.4s, v3.4s +; CHECK-NODOT-NEXT: add v1.4s, v1.4s, v2.4s +; CHECK-NODOT-NEXT: add v0.4s, v1.4s, v0.4s +; CHECK-NODOT-NEXT: ret +; +; CHECK-DOT-LABEL: partial_reduce_shl_sext_const_rhs6: +; CHECK-DOT: // %bb.0: +; CHECK-DOT-NEXT: movi v2.16b, #64 +; CHECK-DOT-NEXT: sdot v1.4s, v0.16b, v2.16b +; CHECK-DOT-NEXT: mov v0.16b, v1.16b +; CHECK-DOT-NEXT: ret +; +; CHECK-DOT-I8MM-LABEL: partial_reduce_shl_sext_const_rhs6: +; CHECK-DOT-I8MM: // %bb.0: +; CHECK-DOT-I8MM-NEXT: movi v2.16b, #64 +; CHECK-DOT-I8MM-NEXT: sdot v1.4s, v0.16b, v2.16b +; CHECK-DOT-I8MM-NEXT: mov v0.16b, v1.16b +; CHECK-DOT-I8MM-NEXT: ret + %ext = sext <16 x i8> %l to <16 x i32> + %shift = shl nsw <16 x i32> %ext, splat (i32 6) + %red = tail call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %part, <16 x i32> %shift) + ret <4 x i32> %red +} + +define <4 x i32> @partial_reduce_shl_sext_const_rhs7(<16 x i8> %l, <4 x i32> %part) { +; CHECK-COMMON-LABEL: partial_reduce_shl_sext_const_rhs7: ; CHECK-COMMON: // %bb.0: ; CHECK-COMMON-NEXT: sshll v2.8h, v0.8b, #0 ; CHECK-COMMON-NEXT: sshll2 v0.8h, v0.16b, #0 -; CHECK-COMMON-NEXT: sshll v3.4s, v0.4h, #6 -; CHECK-COMMON-NEXT: sshll2 v4.4s, v2.8h, #6 -; CHECK-COMMON-NEXT: sshll v2.4s, v2.4h, #6 -; CHECK-COMMON-NEXT: sshll2 v0.4s, v0.8h, #6 +; CHECK-COMMON-NEXT: sshll v3.4s, v0.4h, #7 +; CHECK-COMMON-NEXT: sshll2 v4.4s, v2.8h, #7 +; CHECK-COMMON-NEXT: sshll v2.4s, v2.4h, #7 +; CHECK-COMMON-NEXT: sshll2 v0.4s, v0.8h, #7 ; CHECK-COMMON-NEXT: add v1.4s, v1.4s, v2.4s ; CHECK-COMMON-NEXT: add v2.4s, v4.4s, v3.4s ; CHECK-COMMON-NEXT: add v1.4s, v1.4s, v2.4s ; CHECK-COMMON-NEXT: add v0.4s, v1.4s, v0.4s ; CHECK-COMMON-NEXT: ret %ext = sext <16 x i8> %l to <16 x i32> - %shift = shl nsw <16 x i32> %ext, splat (i32 6) + %shift = shl nsw <16 x i32> %ext, splat (i32 7) %red = tail call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %part, <16 x i32> %shift) ret <4 x i32> %red } @@ -1331,19 +1365,33 @@ define <4 x i32> @partial_reduce_shl_sext_non_const_rhs(<16 x i8> %l, <4 x i32> } define <4 x i32> @partial_reduce_shl_zext_const_rhs6(<16 x i8> %l, <4 x i32> %part) { -; CHECK-COMMON-LABEL: partial_reduce_shl_zext_const_rhs6: -; CHECK-COMMON: // %bb.0: -; CHECK-COMMON-NEXT: ushll v2.8h, v0.8b, #0 -; CHECK-COMMON-NEXT: ushll2 v0.8h, v0.16b, #0 -; CHECK-COMMON-NEXT: ushll v3.4s, v0.4h, #6 -; CHECK-COMMON-NEXT: ushll2 v4.4s, v2.8h, #6 -; CHECK-COMMON-NEXT: ushll v2.4s, v2.4h, #6 -; CHECK-COMMON-NEXT: ushll2 v0.4s, v0.8h, #6 -; CHECK-COMMON-NEXT: add v1.4s, v1.4s, v2.4s -; CHECK-COMMON-NEXT: add v2.4s, v4.4s, v3.4s -; CHECK-COMMON-NEXT: add v1.4s, v1.4s, v2.4s -; CHECK-COMMON-NEXT: add v0.4s, v1.4s, v0.4s -; CHECK-COMMON-NEXT: ret +; CHECK-NODOT-LABEL: partial_reduce_shl_zext_const_rhs6: +; CHECK-NODOT: // %bb.0: +; CHECK-NODOT-NEXT: ushll v2.8h, v0.8b, #0 +; CHECK-NODOT-NEXT: ushll2 v0.8h, v0.16b, #0 +; CHECK-NODOT-NEXT: ushll v3.4s, v0.4h, #6 +; CHECK-NODOT-NEXT: ushll2 v4.4s, v2.8h, #6 +; CHECK-NODOT-NEXT: ushll v2.4s, v2.4h, #6 +; CHECK-NODOT-NEXT: ushll2 v0.4s, v0.8h, #6 +; CHECK-NODOT-NEXT: add v1.4s, v1.4s, v2.4s +; CHECK-NODOT-NEXT: add v2.4s, v4.4s, v3.4s +; CHECK-NODOT-NEXT: add v1.4s, v1.4s, v2.4s +; CHECK-NODOT-NEXT: add v0.4s, v1.4s, v0.4s +; CHECK-NODOT-NEXT: ret +; +; CHECK-DOT-LABEL: partial_reduce_shl_zext_const_rhs6: +; CHECK-DOT: // %bb.0: +; CHECK-DOT-NEXT: movi v2.16b, #64 +; CHECK-DOT-NEXT: udot v1.4s, v0.16b, v2.16b +; CHECK-DOT-NEXT: mov v0.16b, v1.16b +; CHECK-DOT-NEXT: ret +; +; CHECK-DOT-I8MM-LABEL: partial_reduce_shl_zext_const_rhs6: +; CHECK-DOT-I8MM: // %bb.0: +; CHECK-DOT-I8MM-NEXT: movi v2.16b, #64 +; CHECK-DOT-I8MM-NEXT: udot v1.4s, v0.16b, v2.16b +; CHECK-DOT-I8MM-NEXT: mov v0.16b, v1.16b +; CHECK-DOT-I8MM-NEXT: ret %ext = zext <16 x i8> %l to <16 x i32> %shift = shl nsw <16 x i32> %ext, splat (i32 6) %red = tail call <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %part, <16 x i32> %shift)