diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 8457f6178fdc2..e36396c7bdf2b 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1921,6 +1921,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
     if (Subtarget->hasSVE2p1() || Subtarget->hasSME2()) {
       setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_FMLA, MVT::nxv4f32,
                                 MVT::nxv8f16, Legal);
+      // We can use SVE2p1 fdot to emulate the fixed-length variant.
+      setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_FMLA, MVT::v4f32,
+                                MVT::v8f16, Custom);
     }
   }
 
diff --git a/llvm/test/CodeGen/AArch64/sve2p1-fixed-length-fdot.ll b/llvm/test/CodeGen/AArch64/sve2p1-fixed-length-fdot.ll
index 89216ce2cb72b..864c66caf5f6c 100644
--- a/llvm/test/CodeGen/AArch64/sve2p1-fixed-length-fdot.ll
+++ b/llvm/test/CodeGen/AArch64/sve2p1-fixed-length-fdot.ll
@@ -4,6 +4,43 @@
 
 target triple = "aarch64-linux-gnu"
 
+define void @fdot_v4f32(ptr %accptr, ptr %aptr, ptr %bptr) {
+; SVE2-LABEL: fdot_v4f32:
+; SVE2:       // %bb.0: // %entry
+; SVE2-NEXT:    ldr q0, [x1]
+; SVE2-NEXT:    ldr q1, [x2]
+; SVE2-NEXT:    fcvtl v2.4s, v0.4h
+; SVE2-NEXT:    fcvtl v3.4s, v1.4h
+; SVE2-NEXT:    fcvtl2 v0.4s, v0.8h
+; SVE2-NEXT:    fcvtl2 v1.4s, v1.8h
+; SVE2-NEXT:    fmul v2.4s, v2.4s, v3.4s
+; SVE2-NEXT:    ldr q3, [x0]
+; SVE2-NEXT:    fmul v0.4s, v0.4s, v1.4s
+; SVE2-NEXT:    fadd v1.4s, v3.4s, v2.4s
+; SVE2-NEXT:    fadd v0.4s, v1.4s, v0.4s
+; SVE2-NEXT:    str q0, [x0]
+; SVE2-NEXT:    ret
+;
+; SVE2P1-LABEL: fdot_v4f32:
+; SVE2P1:       // %bb.0: // %entry
+; SVE2P1-NEXT:    ldr q0, [x0]
+; SVE2P1-NEXT:    ldr q1, [x1]
+; SVE2P1-NEXT:    ldr q2, [x2]
+; SVE2P1-NEXT:    fdot z0.s, z1.h, z2.h
+; SVE2P1-NEXT:    str q0, [x0]
+; SVE2P1-NEXT:    ret
+entry:
+  %acc = load <4 x float>, ptr %accptr
+  %a = load <8 x half>, ptr %aptr
+  %b = load <8 x half>, ptr %bptr
+  %a.wide = fpext <8 x half> %a to <8 x float>
+  %b.wide = fpext <8 x half> %b to <8 x float>
+  %mult = fmul <8 x float> %a.wide, %b.wide
+  %partial.reduce = call <4 x float> @llvm.vector.partial.reduce.fadd(<4 x float> %acc, <8 x float> %mult)
+  store <4 x float> %partial.reduce, ptr %accptr
+  ret void
+}
+
 define void @fdot_wide_v8f32(ptr %accptr, ptr %aptr, ptr %bptr) vscale_range(2,0) {
 ; SVE2-LABEL: fdot_wide_v8f32:
 ; SVE2:       // %bb.0: // %entry
@@ -177,17 +214,26 @@ entry:
 }
 
 define <4 x float> @fixed_fdot_wide(<4 x float> %acc, <8 x half> %a, <8 x half> %b) {
-; CHECK-LABEL: fixed_fdot_wide:
-; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    fcvtl v3.4s, v1.4h
-; CHECK-NEXT:    fcvtl v4.4s, v2.4h
-; CHECK-NEXT:    fcvtl2 v1.4s, v1.8h
-; CHECK-NEXT:    fcvtl2 v2.4s, v2.8h
-; CHECK-NEXT:    fmul v3.4s, v3.4s, v4.4s
-; CHECK-NEXT:    fmul v1.4s, v1.4s, v2.4s
-; CHECK-NEXT:    fadd v0.4s, v0.4s, v3.4s
-; CHECK-NEXT:    fadd v0.4s, v0.4s, v1.4s
-; CHECK-NEXT:    ret
+; SVE2-LABEL: fixed_fdot_wide:
+; SVE2:       // %bb.0: // %entry
+; SVE2-NEXT:    fcvtl v3.4s, v1.4h
+; SVE2-NEXT:    fcvtl v4.4s, v2.4h
+; SVE2-NEXT:    fcvtl2 v1.4s, v1.8h
+; SVE2-NEXT:    fcvtl2 v2.4s, v2.8h
+; SVE2-NEXT:    fmul v3.4s, v3.4s, v4.4s
+; SVE2-NEXT:    fmul v1.4s, v1.4s, v2.4s
+; SVE2-NEXT:    fadd v0.4s, v0.4s, v3.4s
+; SVE2-NEXT:    fadd v0.4s, v0.4s, v1.4s
+; SVE2-NEXT:    ret
+;
+; SVE2P1-LABEL: fixed_fdot_wide:
+; SVE2P1:       // %bb.0: // %entry
+; SVE2P1-NEXT:    // kill: def $q0 killed $q0 def $z0
+; SVE2P1-NEXT:    // kill: def $q2 killed $q2 def $z2
+; SVE2P1-NEXT:    // kill: def $q1 killed $q1 def $z1
+; SVE2P1-NEXT:    fdot z0.s, z1.h, z2.h
+; SVE2P1-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; SVE2P1-NEXT:    ret
 entry:
   %a.wide = fpext <8 x half> %a to <8 x float>
   %b.wide = fpext <8 x half> %b to <8 x float>
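
Note: the ISelLowering hunk above only marks the v4f32/v8f16 combination as Custom; the custom expansion itself is outside this excerpt. The sketch below is an illustration of the usual AArch64 fixed-length-to-SVE pattern, not the actual patch code: convertToScalableVector, convertFromScalableVector, and getContainerForFixedLengthVector are existing static helpers in AArch64ISelLowering.cpp, while the function name and exact wiring here are assumptions. The idea is to place the 128-bit operands in the low parts of SVE registers (the "kill" pseudos in the SVE2P1 check lines reflect exactly this q-to-z re-registration), emit the already-legal scalable PARTIAL_REDUCE_FMLA so it selects to fdot, and extract the fixed-length result.

// Hypothetical sketch of the Custom lowering; names and wiring are
// assumptions, only the named helpers are real.
SDValue
AArch64TargetLowering::LowerFixedLengthPartialReduceToSVE(SDValue Op,
                                                          SelectionDAG &DAG) const {
  SDLoc DL(Op);
  SDValue Acc = Op.getOperand(0); // v4f32 accumulator
  SDValue LHS = Op.getOperand(1); // v8f16 multiplicand
  SDValue RHS = Op.getOperand(2); // v8f16 multiplier

  // Widen the 128-bit fixed-length vectors into the low parts of
  // scalable containers (nxv4f32 / nxv8f16).
  EVT AccVT = getContainerForFixedLengthVector(DAG, Acc.getValueType());
  EVT OpVT = getContainerForFixedLengthVector(DAG, LHS.getValueType());
  Acc = convertToScalableVector(DAG, AccVT, Acc);
  LHS = convertToScalableVector(DAG, OpVT, LHS);
  RHS = convertToScalableVector(DAG, OpVT, RHS);

  // nxv4f32 PARTIAL_REDUCE_FMLA is Legal on SVE2p1/SME2, so this node
  // selects directly to fdot z0.s, z1.h, z2.h.
  SDValue Res =
      DAG.getNode(ISD::PARTIAL_REDUCE_FMLA, DL, AccVT, Acc, LHS, RHS);

  // Extract the v4f32 result back out of the scalable register.
  return convertFromScalableVector(DAG, Op.getValueType(), Res);
}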