diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index 7604ffdc9f646..7fe4f7acdbd49 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -4238,6 +4238,13 @@ defm UDOT_ZZZ_HtoS : sve2p1_two_way_dot_vv<"udot", 0b1, int_aarch64_sve_udot_x2
 defm SDOT_ZZZI_HtoS : sve2p1_two_way_dot_vvi<"sdot", 0b0, int_aarch64_sve_sdot_lane_x2>;
 defm UDOT_ZZZI_HtoS : sve2p1_two_way_dot_vvi<"udot", 0b1, int_aarch64_sve_udot_lane_x2>;
 
+let Predicates = [HasSVE2p1_or_SME2] in {
+def : Pat<(nxv4i32 (partial_reduce_umla nxv4i32:$Acc, nxv8i16:$MulLHS, nxv8i16:$MulRHS)),
+          (UDOT_ZZZ_HtoS $Acc, $MulLHS, $MulRHS)>;
+def : Pat<(nxv4i32 (partial_reduce_smla nxv4i32:$Acc, nxv8i16:$MulLHS, nxv8i16:$MulRHS)),
+          (SDOT_ZZZ_HtoS $Acc, $MulLHS, $MulRHS)>;
+} // End HasSVE2p1_or_SME2
+
 defm SQCVTN_Z2Z_StoH : sve2p1_multi_vec_extract_narrow<"sqcvtn", 0b00, int_aarch64_sve_sqcvtn_x2>;
 defm UQCVTN_Z2Z_StoH : sve2p1_multi_vec_extract_narrow<"uqcvtn", 0b01, int_aarch64_sve_uqcvtn_x2>;
 defm SQCVTUN_Z2Z_StoH : sve2p1_multi_vec_extract_narrow<"sqcvtun", 0b10, int_aarch64_sve_sqcvtun_x2>;
diff --git a/llvm/test/CodeGen/AArch64/sve2p1-dots-partial-reduction.ll b/llvm/test/CodeGen/AArch64/sve2p1-dots-partial-reduction.ll
new file mode 100644
index 0000000000000..d9ba613931982
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve2p1-dots-partial-reduction.ll
@@ -0,0 +1,105 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve2p1 < %s | FileCheck %s
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sme2 -force-streaming < %s | FileCheck %s
+
+define <vscale x 4 x i32> @udot_vl128(<vscale x 4 x i32> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
+; CHECK-LABEL: udot_vl128:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    udot z0.s, z1.h, z2.h
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = zext <vscale x 8 x i16> %a to <vscale x 8 x i32>
+  %b.wide = zext <vscale x 8 x i16> %b to <vscale x 8 x i32>
+  %mult = mul nuw nsw <vscale x 8 x i32> %a.wide, %b.wide
+  %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add(<vscale x 4 x i32> %acc, <vscale x 8 x i32> %mult)
+  ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 4 x i32> @sdot_vl128(<vscale x 4 x i32> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
+; CHECK-LABEL: sdot_vl128:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    sdot z0.s, z1.h, z2.h
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = sext <vscale x 8 x i16> %a to <vscale x 8 x i32>
+  %b.wide = sext <vscale x 8 x i16> %b to <vscale x 8 x i32>
+  %mult = mul nuw nsw <vscale x 8 x i32> %a.wide, %b.wide
+  %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add(<vscale x 4 x i32> %acc, <vscale x 8 x i32> %mult)
+  ret <vscale x 4 x i32> %partial.reduce
+}
+
+define void @udot_vl256(ptr %accptr, ptr %aptr, ptr %bptr) vscale_range(2,2) {
+; CHECK-LABEL: udot_vl256:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ldr z0, [x0]
+; CHECK-NEXT:    ldr z1, [x1]
+; CHECK-NEXT:    ldr z2, [x2]
+; CHECK-NEXT:    udot z0.s, z1.h, z2.h
+; CHECK-NEXT:    str z0, [x0]
+; CHECK-NEXT:    ret
+entry:
+  %acc = load <8 x i32>, ptr %accptr
+  %a = load <16 x i16>, ptr %aptr
+  %b = load <16 x i16>, ptr %bptr
+  %a.wide = zext <16 x i16> %a to <16 x i32>
+  %b.wide = zext <16 x i16> %b to <16 x i32>
+  %mult = mul nuw nsw <16 x i32> %a.wide, %b.wide
+  %partial.reduce = tail call <8 x i32> @llvm.experimental.vector.partial.reduce.add(<8 x i32> %acc, <16 x i32> %mult)
+  store <8 x i32> %partial.reduce, ptr %accptr
+  ret void
+}
+
+define void @sdot_vl256(ptr %accptr, ptr %aptr, ptr %bptr) vscale_range(2,2) {
+; CHECK-LABEL: sdot_vl256:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ldr z0, [x0]
+; CHECK-NEXT:    ldr z1, [x1]
+; CHECK-NEXT:    ldr z2, [x2]
+; CHECK-NEXT:    sdot z0.s, z1.h, z2.h
+; CHECK-NEXT:    str z0, [x0]
+; CHECK-NEXT:    ret
+entry:
+  %acc = load <8 x i32>, ptr %accptr
+  %a = load <16 x i16>, ptr %aptr
+  %b = load <16 x i16>, ptr %bptr
+  %a.wide = sext <16 x i16> %a to <16 x i32>
+  %b.wide = sext <16 x i16> %b to <16 x i32>
+  %mult = mul nuw nsw <16 x i32> %a.wide, %b.wide
+  %partial.reduce = tail call <8 x i32> @llvm.experimental.vector.partial.reduce.add(<8 x i32> %acc, <16 x i32> %mult)
+  store <8 x i32> %partial.reduce, ptr %accptr
+  ret void
+}
+
+define <4 x i32> @fixed_udot_s_h(<4 x i32> %acc, <8 x i16> %a, <8 x i16> %b) {
+; CHECK-LABEL: fixed_udot_s_h:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT:    // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT:    // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT:    udot z0.s, z1.h, z2.h
+; CHECK-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = zext <8 x i16> %a to <8 x i32>
+  %b.wide = zext <8 x i16> %b to <8 x i32>
+  %mult = mul nuw nsw <8 x i32> %a.wide, %b.wide
+  %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add(<4 x i32> %acc, <8 x i32> %mult)
+  ret <4 x i32> %partial.reduce
+}
+
+define <4 x i32> @fixed_sdot_s_h(<4 x i32> %acc, <8 x i16> %a, <8 x i16> %b) {
+; CHECK-LABEL: fixed_sdot_s_h:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT:    // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT:    // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT:    sdot z0.s, z1.h, z2.h
+; CHECK-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = sext <8 x i16> %a to <8 x i32>
+  %b.wide = sext <8 x i16> %b to <8 x i32>
+  %mult = mul nuw nsw <8 x i32> %a.wide, %b.wide
+  %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add(<4 x i32> %acc, <8 x i32> %mult)
+  ret <4 x i32> %partial.reduce
+}
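
For context, the zext/sext + mul + llvm.experimental.vector.partial.reduce.add shape exercised by these tests is what the loop vectorizer can form for a widening dot-product reduction. The C sketch below is illustrative only and not part of the patch: the function name and types are assumptions, and getting a single udot out of it depends on the vectorizer choosing partial reductions for the target.

#include <stdint.h>

/* Widening dot product: u16 x u16 products accumulated into u32.
   If vectorized via llvm.experimental.vector.partial.reduce.add, the
   new partial_reduce_umla pattern selects udot z.s, z.h, z.h; the
   signed variant (int16_t operands) maps to sdot the same way. */
uint32_t dot_u16(const uint16_t *a, const uint16_t *b, int n) {
  uint32_t sum = 0;
  for (int i = 0; i < n; ++i)
    sum += (uint32_t)a[i] * (uint32_t)b[i];
  return sum;
}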