diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 50fccd5aeefe0..8bf75e4aee5f0 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -22391,6 +22391,152 @@ static SDValue performDupLane128Combine(SDNode *N, SelectionDAG &DAG) {
   return DAG.getNode(ISD::BITCAST, DL, VT, NewDuplane128);
 }
 
+// Try to combine mull with uzp1.
+static SDValue tryCombineMULLWithUZP1(SDNode *N,
+                                      TargetLowering::DAGCombinerInfo &DCI,
+                                      SelectionDAG &DAG) {
+  if (DCI.isBeforeLegalizeOps())
+    return SDValue();
+
+  SDValue LHS = N->getOperand(0);
+  SDValue RHS = N->getOperand(1);
+
+  SDValue ExtractHigh;
+  SDValue ExtractLow;
+  SDValue TruncHigh;
+  SDValue TruncLow;
+  SDLoc DL(N);
+
+  // Check the operands are trunc and extract_high.
+  if (isEssentiallyExtractHighSubvector(LHS) &&
+      RHS.getOpcode() == ISD::TRUNCATE) {
+    TruncHigh = RHS;
+    if (LHS.getOpcode() == ISD::BITCAST)
+      ExtractHigh = LHS.getOperand(0);
+    else
+      ExtractHigh = LHS;
+  } else if (isEssentiallyExtractHighSubvector(RHS) &&
+             LHS.getOpcode() == ISD::TRUNCATE) {
+    TruncHigh = LHS;
+    if (RHS.getOpcode() == ISD::BITCAST)
+      ExtractHigh = RHS.getOperand(0);
+    else
+      ExtractHigh = RHS;
+  } else
+    return SDValue();
+
+  // If the truncate's operand is a splat (a DUP or a splat BUILD_VECTOR), do
+  // not combine the op with uzp1.
+  // You can see the regressions in test/CodeGen/AArch64/aarch64-smull.ll.
+  SDValue TruncHighOp = TruncHigh.getOperand(0);
+  EVT TruncHighOpVT = TruncHighOp.getValueType();
+  if (TruncHighOp.getOpcode() == AArch64ISD::DUP ||
+      DAG.isSplatValue(TruncHighOp, false))
+    return SDValue();
+
+  // Check there is another extract_high with the same source vector.
+  // For example,
+  //
+  //    t18: v4i16 = extract_subvector t2, Constant:i64<0>
+  //      t12: v4i16 = truncate t11
+  //    t31: v4i32 = AArch64ISD::SMULL t18, t12
+  //    t23: v4i16 = extract_subvector t2, Constant:i64<4>
+  //      t16: v4i16 = truncate t15
+  //    t30: v4i32 = AArch64ISD::SMULL t23, t16
+  //
+  // This DAG combine assumes the two extract_high nodes use the same source
+  // vector in order to detect the pair of MULLs. If they use different source
+  // vectors, the combine does not apply.
+  bool HasFoundMULLow = true;
+  SDValue ExtractHighSrcVec = ExtractHigh.getOperand(0);
+  if (ExtractHighSrcVec->use_size() != 2)
+    HasFoundMULLow = false;
+
+  // Find ExtractLow.
+  for (SDNode *User : ExtractHighSrcVec.getNode()->uses()) {
+    if (User == ExtractHigh.getNode())
+      continue;
+
+    if (User->getOpcode() != ISD::EXTRACT_SUBVECTOR ||
+        !isNullConstant(User->getOperand(1))) {
+      HasFoundMULLow = false;
+      break;
+    }
+
+    ExtractLow.setNode(User);
+  }
+
+  if (!ExtractLow || !ExtractLow->hasOneUse())
+    HasFoundMULLow = false;
+
+  // Check ExtractLow's user.
+  if (HasFoundMULLow) {
+    SDNode *ExtractLowUser = *ExtractLow.getNode()->use_begin();
+    if (ExtractLowUser->getOpcode() != N->getOpcode())
+      HasFoundMULLow = false;
+
+    if (ExtractLowUser->getOperand(0) == ExtractLow) {
+      if (ExtractLowUser->getOperand(1).getOpcode() == ISD::TRUNCATE)
+        TruncLow = ExtractLowUser->getOperand(1);
+      else
+        HasFoundMULLow = false;
+    } else {
+      if (ExtractLowUser->getOperand(0).getOpcode() == ISD::TRUNCATE)
+        TruncLow = ExtractLowUser->getOperand(0);
+      else
+        HasFoundMULLow = false;
+    }
+  }
+
+  // Likewise, if the low truncate's operand is a splat, do not combine the op
+  // with uzp1.
+  // You can see the regressions in test/CodeGen/AArch64/aarch64-smull.ll.
+  EVT TruncHighVT = TruncHigh.getValueType();
+  EVT UZP1VT = TruncHighVT.getDoubleNumVectorElementsVT(*DAG.getContext());
+  SDValue TruncLowOp =
+      HasFoundMULLow ? TruncLow.getOperand(0) : DAG.getUNDEF(UZP1VT);
+  EVT TruncLowOpVT = TruncLowOp.getValueType();
+  if (HasFoundMULLow && (TruncLowOp.getOpcode() == AArch64ISD::DUP ||
+                         DAG.isSplatValue(TruncLowOp, false)))
+    return SDValue();
+
+  // Create uzp1, extract_high and extract_low.
+  if (TruncHighOpVT != UZP1VT)
+    TruncHighOp = DAG.getNode(ISD::BITCAST, DL, UZP1VT, TruncHighOp);
+  if (TruncLowOpVT != UZP1VT)
+    TruncLowOp = DAG.getNode(ISD::BITCAST, DL, UZP1VT, TruncLowOp);
+
+  SDValue UZP1 =
+      DAG.getNode(AArch64ISD::UZP1, DL, UZP1VT, TruncLowOp, TruncHighOp);
+  SDValue HighIdxCst =
+      DAG.getConstant(TruncHighVT.getVectorNumElements(), DL, MVT::i64);
+  SDValue NewTruncHigh =
+      DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, TruncHighVT, UZP1, HighIdxCst);
+  DAG.ReplaceAllUsesWith(TruncHigh, NewTruncHigh);
+
+  if (HasFoundMULLow) {
+    EVT TruncLowVT = TruncLow.getValueType();
+    SDValue NewTruncLow = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, TruncLowVT,
+                                      UZP1, ExtractLow.getOperand(1));
+    DAG.ReplaceAllUsesWith(TruncLow, NewTruncLow);
+  }
+
+  return SDValue(N, 0);
+}
+
+static SDValue performMULLCombine(SDNode *N,
+                                  TargetLowering::DAGCombinerInfo &DCI,
+                                  SelectionDAG &DAG) {
+  if (SDValue Val =
+          tryCombineLongOpWithDup(Intrinsic::not_intrinsic, N, DCI, DAG))
+    return Val;
+
+  if (SDValue Val = tryCombineMULLWithUZP1(N, DCI, DAG))
+    return Val;
+
+  return SDValue();
+}
+
 SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
                                                  DAGCombinerInfo &DCI) const {
   SelectionDAG &DAG = DCI.DAG;
@@ -22535,7 +22681,7 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
   case AArch64ISD::SMULL:
   case AArch64ISD::UMULL:
   case AArch64ISD::PMULL:
-    return tryCombineLongOpWithDup(Intrinsic::not_intrinsic, N, DCI, DAG);
+    return performMULLCombine(N, DCI, DAG);
   case ISD::INTRINSIC_VOID:
   case ISD::INTRINSIC_W_CHAIN:
     switch (cast<ConstantSDNode>(N->getOperand(1))->getZExtValue()) {
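Note (not part of the patch): at the source level, the DAG shape this combine targets comes from two XTN (vmovn) results feeding a low/high pair of widening multiplies over the same vector. Below is a minimal C++ sketch using ACLE NEON intrinsics; the helper name and signature are invented for illustration, and the expected before/after codegen (xtn + xtn + ext + umull + umull becoming uzp1 + umull + umull2) is inferred from the updated umull_and_v8i32 checks in the test diff that follows, not verified from this snippet.

#include <arm_neon.h>

// Hypothetical helper: widening-multiply the low and high halves of 'a' by two
// independently narrowed operands. Without the combine, the second multiply
// cannot use umull2 because its narrowed operand exists only as a 64-bit XTN
// result, so the high half of 'a' has to be materialized with an EXT. With the
// combine, the two XTNs are expected to merge into one UZP1 so the high-half
// multiply can be selected as umull2 on full Q registers.
void widen_mul_lo_hi(uint16x8_t a, uint32x4_t b_lo, uint32x4_t b_hi,
                     uint32x4_t *out_lo, uint32x4_t *out_hi) {
  uint16x4_t lo = vmovn_u32(b_lo);           // xtn
  uint16x4_t hi = vmovn_u32(b_hi);           // xtn
  *out_lo = vmull_u16(vget_low_u16(a), lo);  // umull (low half of 'a')
  *out_hi = vmull_u16(vget_high_u16(a), hi); // umull2 candidate (high half)
}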
diff --git a/llvm/test/CodeGen/AArch64/aarch64-smull.ll b/llvm/test/CodeGen/AArch64/aarch64-smull.ll
index e4d733fd7c1c6..cf2bc9c2e5896 100644
--- a/llvm/test/CodeGen/AArch64/aarch64-smull.ll
+++ b/llvm/test/CodeGen/AArch64/aarch64-smull.ll
@@ -1033,13 +1033,11 @@ define <8 x i32> @umull_and_v8i32(<8 x i16> %src1, <8 x i32> %src2) {
 ; CHECK-LABEL: umull_and_v8i32:
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    movi v3.2d, #0x0000ff000000ff
-; CHECK-NEXT:    ext v4.16b, v0.16b, v0.16b, #8
 ; CHECK-NEXT:    and v2.16b, v2.16b, v3.16b
 ; CHECK-NEXT:    and v1.16b, v1.16b, v3.16b
-; CHECK-NEXT:    xtn v1.4h, v1.4s
-; CHECK-NEXT:    xtn v2.4h, v2.4s
-; CHECK-NEXT:    umull v0.4s, v0.4h, v1.4h
-; CHECK-NEXT:    umull v1.4s, v4.4h, v2.4h
+; CHECK-NEXT:    uzp1 v2.8h, v1.8h, v2.8h
+; CHECK-NEXT:    umull2 v1.4s, v0.8h, v2.8h
+; CHECK-NEXT:    umull v0.4s, v0.4h, v2.4h
 ; CHECK-NEXT:    ret
 entry:
   %in1 = zext <8 x i16> %src1 to <8 x i32>
@@ -1084,13 +1082,11 @@ define <4 x i64> @umull_and_v4i64(<4 x i32> %src1, <4 x i64> %src2) {
 ; CHECK-LABEL: umull_and_v4i64:
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    movi v3.2d, #0x000000000000ff
-; CHECK-NEXT:    ext v4.16b, v0.16b, v0.16b, #8
 ; CHECK-NEXT:    and v2.16b, v2.16b, v3.16b
 ; CHECK-NEXT:    and v1.16b, v1.16b, v3.16b
-; CHECK-NEXT:    xtn v1.2s, v1.2d
-; CHECK-NEXT:    xtn v2.2s, v2.2d
-; CHECK-NEXT:    umull v0.2d, v0.2s, v1.2s
-; CHECK-NEXT:    umull v1.2d, v4.2s, v2.2s
+; CHECK-NEXT:    uzp1 v2.4s, v1.4s, v2.4s
+; CHECK-NEXT:    umull2 v1.2d, v0.4s, v2.4s
+; CHECK-NEXT:    umull v0.2d, v0.2s, v2.2s
 ; CHECK-NEXT:    ret
 entry:
   %in1 = zext <4 x i32> %src1 to <4 x i64>
@@ -1115,3 +1111,227 @@ entry:
   %out = mul nsw <4 x i64> %in1, %broadcast.splat
   ret <4 x i64> %out
 }
+
+define void @pmlsl2_v8i16_uzp1(<16 x i8> %0, <8 x i16> %1, ptr %2, ptr %3) {
+; CHECK-LABEL: pmlsl2_v8i16_uzp1:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ldr q2, [x1, #16]
+; CHECK-NEXT:    uzp1 v2.16b, v0.16b, v2.16b
+; CHECK-NEXT:    pmull2 v0.8h, v0.16b, v2.16b
+; CHECK-NEXT:    sub v0.8h, v1.8h, v0.8h
+; CHECK-NEXT:    str q0, [x0]
+; CHECK-NEXT:    ret
+  %5 = getelementptr inbounds i32, ptr %3, i64 4
+  %6 = load <8 x i16>, ptr %5, align 4
+  %7 = trunc <8 x i16> %6 to <8 x i8>
+  %8 = shufflevector <16 x i8> %0, <16 x i8> poison, <8 x i32> <i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
+  %9 = tail call <8 x i16> @llvm.aarch64.neon.pmull.v8i16(<8 x i8> %8, <8 x i8> %7)
+  %10 = sub <8 x i16> %1, %9
+  store <8 x i16> %10, ptr %2, align 16
+  ret void
+}
+
+define void @smlsl2_v8i16_uzp1(<16 x i8> %0, <8 x i16> %1, ptr %2, ptr %3) {
+; CHECK-LABEL: smlsl2_v8i16_uzp1:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ldr q2, [x1, #16]
+; CHECK-NEXT:    uzp1 v2.16b, v0.16b, v2.16b
+; CHECK-NEXT:    smlsl2 v1.8h, v0.16b, v2.16b
+; CHECK-NEXT:    str q1, [x0]
+; CHECK-NEXT:    ret
+  %5 = getelementptr inbounds i32, ptr %3, i64 4
+  %6 = load <8 x i16>, ptr %5, align 4
+  %7 = trunc <8 x i16> %6 to <8 x i8>
+  %8 = shufflevector <16 x i8> %0, <16 x i8> poison, <8 x i32> <i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
+  %9 = tail call <8 x i16> @llvm.aarch64.neon.smull.v8i16(<8 x i8> %8, <8 x i8> %7)
+  %10 = sub <8 x i16> %1, %9
+  store <8 x i16> %10, ptr %2, align 16
+  ret void
+}
+
+define void @umlsl2_v8i16_uzp1(<16 x i8> %0, <8 x i16> %1, ptr %2, ptr %3) {
+; CHECK-LABEL: umlsl2_v8i16_uzp1:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ldr q2, [x1, #16]
+; CHECK-NEXT:    uzp1 v2.16b, v0.16b, v2.16b
+; CHECK-NEXT:    umlsl2 v1.8h, v0.16b, v2.16b
+; CHECK-NEXT:    str q1, [x0]
+; CHECK-NEXT:    ret
+  %5 = getelementptr inbounds i32, ptr %3, i64 4
+  %6 = load <8 x i16>, ptr %5, align 4
+  %7 = trunc <8 x i16> %6 to <8 x i8>
+  %8 = shufflevector <16 x i8> %0, <16 x i8> poison, <8 x i32> <i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
+  %9 = tail call <8 x i16> @llvm.aarch64.neon.umull.v8i16(<8 x i8> %8, <8 x i8> %7)
+  %10 = sub <8 x i16> %1, %9
+  store <8 x i16> %10, ptr %2, align 16
+  ret void
+}
+
+define void @smlsl2_v4i32_uzp1(<8 x i16> %0, <4 x i32> %1, ptr %2, ptr %3) {
+; CHECK-LABEL: smlsl2_v4i32_uzp1:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ldr q2, [x1, #16]
+; CHECK-NEXT:    uzp1 v2.8h, v0.8h, v2.8h
+; CHECK-NEXT:    smlsl2 v1.4s, v0.8h, v2.8h
+; CHECK-NEXT:    str q1, [x0]
+; CHECK-NEXT:    ret
+  %5 = getelementptr inbounds i32, ptr %3, i64 4
+  %6 = load <4 x i32>, ptr %5, align 4
+  %7 = trunc <4 x i32> %6 to <4 x i16>
+  %8 = shufflevector <8 x i16> %0, <8 x i16> poison, <4 x i32> <i32 4, i32 5, i32 6, i32 7>
+  %9 = tail call <4 x i32> @llvm.aarch64.neon.smull.v4i32(<4 x i16> %8, <4 x i16> %7)
+  %10 = sub <4 x i32> %1, %9
+  store <4 x i32> %10, ptr %2, align 16
+  ret void
+}
+
+define void @umlsl2_v4i32_uzp1(<8 x i16> %0, <4 x i32> %1, ptr %2, ptr %3) {
+; CHECK-LABEL: umlsl2_v4i32_uzp1:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ldr q2, [x1, #16]
+; CHECK-NEXT:    uzp1 v2.8h, v0.8h, v2.8h
+; CHECK-NEXT:    umlsl2 v1.4s, v0.8h, v2.8h
+; CHECK-NEXT:    str q1, [x0]
+; CHECK-NEXT:    ret
+  %5 = getelementptr inbounds i32, ptr %3, i64 4
+  %6 = load <4 x i32>, ptr %5, align 4
+  %7 = trunc <4 x i32> %6 to <4 x i16>
+  %8 = shufflevector <8 x i16> %0, <8 x i16> poison, <4 x i32> <i32 4, i32 5, i32 6, i32 7>
+  %9 = tail call <4 x i32> @llvm.aarch64.neon.umull.v4i32(<4 x i16> %8, <4 x i16> %7)
+  %10 = sub <4 x i32> %1, %9
+  store <4 x i32> %10, ptr %2, align 16
+  ret void
+}
+
+define void @pmlsl_pmlsl2_v8i16_uzp1(<16 x i8> %0, <8 x i16> %1, ptr %2, ptr %3, i32 %4) {
+; CHECK-LABEL: pmlsl_pmlsl2_v8i16_uzp1:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ldp q2, q3, [x1]
+; CHECK-NEXT:    uzp1 v2.16b, v2.16b, v3.16b
+; CHECK-NEXT:    pmull v3.8h, v0.8b, v2.8b
+; CHECK-NEXT:    pmull2 v0.8h, v0.16b, v2.16b
+; CHECK-NEXT:    add v0.8h, v3.8h, v0.8h
+; CHECK-NEXT:    sub v0.8h, v1.8h, v0.8h
+; CHECK-NEXT:    str q0, [x0]
+; CHECK-NEXT:    ret
+entry:
+  %5 = load <8 x i16>, ptr %3, align 4
+  %6 = trunc <8 x i16> %5 to <8 x i8>
+  %7 = getelementptr inbounds i32, ptr %3, i64 4
+  %8 = load <8 x i16>, ptr %7, align 4
+  %9 = trunc <8 x i16> %8 to <8 x i8>
+  %10 = shufflevector <16 x i8> %0, <16 x i8> poison, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+  %11 = tail call <8 x i16> @llvm.aarch64.neon.pmull.v8i16(<8 x i8> %10, <8 x i8> %6)
+  %12 = shufflevector <16 x i8> %0, <16 x i8> poison, <8 x i32> <i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
+  %13 = tail call <8 x i16> @llvm.aarch64.neon.pmull.v8i16(<8 x i8> %12, <8 x i8> %9)
+  %14 = add <8 x i16> %11, %13
+  %15 = sub <8 x i16> %1, %14
+  store <8 x i16> %15, ptr %2, align 16
+  ret void
+}
+
+define void @smlsl_smlsl2_v8i16_uzp1(<16 x i8> %0, <8 x i16> %1, ptr %2, ptr %3, i32 %4) {
+; CHECK-LABEL: smlsl_smlsl2_v8i16_uzp1:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ldp q2, q3, [x1]
+; CHECK-NEXT:    uzp1 v2.16b, v2.16b, v3.16b
+; CHECK-NEXT:    smlsl v1.8h, v0.8b, v2.8b
+; CHECK-NEXT:    smlsl2 v1.8h, v0.16b, v2.16b
+; CHECK-NEXT:    str q1, [x0]
+; CHECK-NEXT:    ret
+entry:
+  %5 = load <8 x i16>, ptr %3, align 4
+  %6 = trunc <8 x i16> %5 to <8 x i8>
+  %7 = getelementptr inbounds i32, ptr %3, i64 4
+  %8 = load <8 x i16>, ptr %7, align 4
+  %9 = trunc <8 x i16> %8 to <8 x i8>
+  %10 = shufflevector <16 x i8> %0, <16 x i8> poison, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+  %11 = tail call <8 x i16> @llvm.aarch64.neon.smull.v8i16(<8 x i8> %10, <8 x i8> %6)
+  %12 = shufflevector <16 x i8> %0, <16 x i8> poison, <8 x i32> <i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
+  %13 = tail call <8 x i16> @llvm.aarch64.neon.smull.v8i16(<8 x i8> %12, <8 x i8> %9)
+  %14 = add <8 x i16> %11, %13
+  %15 = sub <8 x i16> %1, %14
+  store <8 x i16> %15, ptr %2, align 16
+  ret void
+}
+
+define void @umlsl_umlsl2_v8i16_uzp1(<16 x i8> %0, <8 x i16> %1, ptr %2, ptr %3, i32 %4) {
+; CHECK-LABEL: umlsl_umlsl2_v8i16_uzp1:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ldp q2, q3, [x1]
+; CHECK-NEXT:    uzp1 v2.16b, v2.16b, v3.16b
+; CHECK-NEXT:    umlsl v1.8h, v0.8b, v2.8b
+; CHECK-NEXT:    umlsl2 v1.8h, v0.16b, v2.16b
+; CHECK-NEXT:    str q1, [x0]
+; CHECK-NEXT:    ret
+entry:
+  %5 = load <8 x i16>, ptr %3, align 4
+  %6 = trunc <8 x i16> %5 to <8 x i8>
+  %7 = getelementptr inbounds i32, ptr %3, i64 4
+  %8 = load <8 x i16>, ptr %7, align 4
+  %9 = trunc <8 x i16> %8 to <8 x i8>
+  %10 = shufflevector <16 x i8> %0, <16 x i8> poison, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+  %11 = tail call <8 x i16> @llvm.aarch64.neon.umull.v8i16(<8 x i8> %10, <8 x i8> %6)
+  %12 = shufflevector <16 x i8> %0, <16 x i8> poison, <8 x i32> <i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
+  %13 = tail call <8 x i16> @llvm.aarch64.neon.umull.v8i16(<8 x i8> %12, <8 x i8> %9)
+  %14 = add <8 x i16> %11, %13
+  %15 = sub <8 x i16> %1, %14
+  store <8 x i16> %15, ptr %2, align 16
+  ret void
+}
+
+define void @smlsl_smlsl2_v4i32_uzp1(<8 x i16> %0, <4 x i32> %1, ptr %2, ptr %3, i32 %4) {
+; CHECK-LABEL: smlsl_smlsl2_v4i32_uzp1:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ldp q2, q3, [x1]
+; CHECK-NEXT:    uzp1 v2.8h, v2.8h, v3.8h
+; CHECK-NEXT:    smlsl v1.4s, v0.4h, v2.4h
+; CHECK-NEXT:    smlsl2 v1.4s, v0.8h, v2.8h
+; CHECK-NEXT:    str q1, [x0]
+; CHECK-NEXT:    ret
+entry:
+  %5 = load <4 x i32>, ptr %3, align 4
+  %6 = trunc <4 x i32> %5 to <4 x i16>
+  %7 = getelementptr inbounds i32, ptr %3, i64 4
+  %8 = load <4 x i32>, ptr %7, align 4
+  %9 = trunc <4 x i32> %8 to <4 x i16>
+  %10 = shufflevector <8 x i16> %0, <8 x i16> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  %11 = tail call <4 x i32> @llvm.aarch64.neon.smull.v4i32(<4 x i16> %10, <4 x i16> %6)
+  %12 = shufflevector <8 x i16> %0, <8 x i16> poison, <4 x i32> <i32 4, i32 5, i32 6, i32 7>
+  %13 = tail call <4 x i32> @llvm.aarch64.neon.smull.v4i32(<4 x i16> %12, <4 x i16> %9)
+  %14 = add <4 x i32> %11, %13
+  %15 = sub <4 x i32> %1, %14
+  store <4 x i32> %15, ptr %2, align 16
+  ret void
+}
+
+define void @umlsl_umlsl2_v4i32_uzp1(<8 x i16> %0, <4 x i32> %1, ptr %2, ptr %3, i32 %4) {
+; CHECK-LABEL: umlsl_umlsl2_v4i32_uzp1:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ldp q2, q3, [x1]
+; CHECK-NEXT:    uzp1 v2.8h, v2.8h, v3.8h
+; CHECK-NEXT:    umlsl v1.4s, v0.4h, v2.4h
+; CHECK-NEXT:    umlsl2 v1.4s, v0.8h, v2.8h
+; CHECK-NEXT:    str q1, [x0]
+; CHECK-NEXT:    ret
+entry:
+  %5 = load <4 x i32>, ptr %3, align 4
+  %6 = trunc <4 x i32> %5 to <4 x i16>
+  %7 = getelementptr inbounds i32, ptr %3, i64 4
+  %8 = load <4 x i32>, ptr %7, align 4
+  %9 = trunc <4 x i32> %8 to <4 x i16>
+  %10 = shufflevector <8 x i16> %0, <8 x i16> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  %11 = tail call <4 x i32> @llvm.aarch64.neon.umull.v4i32(<4 x i16> %10, <4 x i16> %6)
+  %12 = shufflevector <8 x i16> %0, <8 x i16> poison, <4 x i32> <i32 4, i32 5, i32 6, i32 7>
+  %13 = tail call <4 x i32> @llvm.aarch64.neon.umull.v4i32(<4 x i16> %12, <4 x i16> %9)
+  %14 = add <4 x i32> %11, %13
+  %15 = sub <4 x i32> %1, %14
+  store <4 x i32> %15, ptr %2, align 16
+  ret void
+}
+
+declare <8 x i16> @llvm.aarch64.neon.pmull.v8i16(<8 x i8>, <8 x i8>)
+declare <8 x i16> @llvm.aarch64.neon.smull.v8i16(<8 x i8>, <8 x i8>)
+declare <8 x i16> @llvm.aarch64.neon.umull.v8i16(<8 x i8>, <8 x i8>)
+declare <4 x i32> @llvm.aarch64.neon.smull.v4i32(<4 x i16>, <4 x i16>)
+declare <4 x i32> @llvm.aarch64.neon.umull.v4i32(<4 x i16>, <4 x i16>)
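Note (not part of the patch): the paired *mlsl_*mlsl2 tests above exercise the HasFoundMULLow path, where both multiplies read halves of the same source vector and the products are accumulated. A hypothetical intrinsics-level equivalent is sketched below for illustration only; the helper name and signature are invented, and the expected umlsl/umlsl2 selection is taken from the checks above rather than verified from this snippet.

#include <arm_neon.h>

// Hypothetical helper mirroring umlsl_umlsl2_v4i32_uzp1: two XTNs feed the
// low- and high-half multiplies of the same vector 'a', and the summed
// products are subtracted from an accumulator. With the combine, the two XTNs
// are expected to become a single UZP1, and the subtraction of the products is
// expected to select as umlsl followed by umlsl2.
uint32x4_t mul_sub_lo_hi(uint16x8_t a, uint32x4_t acc,
                         uint32x4_t b_lo, uint32x4_t b_hi) {
  uint16x4_t lo = vmovn_u32(b_lo);                   // xtn
  uint16x4_t hi = vmovn_u32(b_hi);                   // xtn
  uint32x4_t m_lo = vmull_u16(vget_low_u16(a), lo);  // umull
  uint32x4_t m_hi = vmull_u16(vget_high_u16(a), hi); // umull2
  return vsubq_u32(acc, vaddq_u32(m_lo, m_hi));      // umlsl/umlsl2 candidate
}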