diff --git a/llvm/lib/Target/AArch64/AArch64InstrFormats.td b/llvm/lib/Target/AArch64/AArch64InstrFormats.td
index 4d2e740779961..ed5410d26d924 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrFormats.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrFormats.td
@@ -10,6 +10,20 @@
 // Describe AArch64 instructions format here
 //
 
+// Helper class to convert vector element types to integers.
+class ChangeElementTypeToInteger<ValueType InVT> {
+  ValueType VT = !cond(
+    !eq(InVT, v2f32): v2i32,
+    !eq(InVT, v4f32): v4i32,
+    // TODO: Other types.
+    true : untyped);
+}
+
+class VTPair<ValueType A, ValueType B> {
+  ValueType VT0 = A;
+  ValueType VT1 = B;
+}
+
 // Format specifies the encoding used by the instruction. This is part of the
 // ad-hoc solution used to emit machine instruction encodings by our machine
 // code emitter.
@@ -8952,36 +8966,6 @@ multiclass SIMDThreeSameVectorBFDot<bit U, string asm> {
                                            v4f32, v8bf16>;
 }
 
-class BaseSIMDThreeSameVectorBF16DotI<bit Q, bit U, string asm,
-                                      string dst_kind, string lhs_kind,
-                                      string rhs_kind, RegisterOperand RegType,
-                                      ValueType AccumType,
-                                      ValueType InputType>
-  : BaseSIMDIndexedTied<Q, U, 0b0, 0b01, 0b1111,
-                        RegType, RegType, V128,
-                        VectorIndexS, asm, "", dst_kind, lhs_kind, rhs_kind,
-        [(set (AccumType RegType:$dst),
-              (AccumType (int_aarch64_neon_bfdot
-                                 (AccumType RegType:$Rd),
-                                 (InputType RegType:$Rn),
-                                 (InputType (bitconvert (AccumType
-                                    (AArch64duplane32
-                                     (v4f32 (bitconvert (v8bf16 V128:$Rm))),
-                                      VectorIndexS:$idx)))))))]> {
-
-  bits<2> idx;
-  let Inst{21} = idx{0}; // L
-  let Inst{11} = idx{1}; // H
-}
-
-multiclass SIMDThreeSameVectorBF16DotI<bit U, string asm> {
-
-  def v4bf16 : BaseSIMDThreeSameVectorBF16DotI<0, U, asm, ".2s", ".4h",
-                                               ".2h", V64, v2f32, v4bf16>;
-  def v8bf16 : BaseSIMDThreeSameVectorBF16DotI<1, U, asm, ".4s", ".8h",
-                                               ".2h", V128, v4f32, v8bf16>;
-}
-
 let mayRaiseFPException = 1, Uses = [FPCR] in
 class SIMDBF16MLAL<bit Q, string asm, SDPatternOperator OpNode>
   : BaseSIMDThreeSameVectorTied<Q, 0b1, 0b110, 0b11111, V128, asm, ".4s",
@@ -9123,6 +9107,41 @@
   }
 }
 } // End of let mayStore = 0, mayLoad = 0, hasSideEffects = 0
 
+multiclass BaseSIMDThreeSameVectorBF16DotI<bit Q, bit U, string asm,
+                                           string dst_kind, string lhs_kind,
+                                           string rhs_kind,
+                                           RegisterOperand RegType,
+                                           ValueType AccumType,
+                                           ValueType InputType> {
+  let mayLoad = 0, mayStore = 0, hasSideEffects = 0 in {
+    def NAME : BaseSIMDIndexedTied<Q, U, 0b0, 0b01, 0b1111,
+                                   RegType, RegType, V128, VectorIndexS,
+                                   asm, "", dst_kind, lhs_kind, rhs_kind, []>
+    {
+      bits<2> idx;
+      let Inst{21} = idx{0}; // L
+      let Inst{11} = idx{1}; // H
+    }
+  }
+
+  foreach DupTypes = [VTPair<AccumType, v4f32>,
+                      VTPair<ChangeElementTypeToInteger<AccumType>.VT, v4i32>] in {
+    def : Pat<(AccumType (int_aarch64_neon_bfdot
+                  (AccumType RegType:$Rd), (InputType RegType:$Rn),
+                  (InputType (bitconvert
+                      (DupTypes.VT0 (AArch64duplane32 (DupTypes.VT1
+                          (bitconvert (v8bf16 V128:$Rm))), VectorIndexS:$Idx)))))),
+              (!cast<Instruction>(NAME) $Rd, $Rn, $Rm, VectorIndexS:$Idx)>;
+  }
+}
+
+multiclass SIMDThreeSameVectorBF16DotI<bit U, string asm> {
+  defm v4bf16 : BaseSIMDThreeSameVectorBF16DotI<0, U, asm, ".2s", ".4h",
+                                                ".2h", V64, v2f32, v4bf16>;
+  defm v8bf16 : BaseSIMDThreeSameVectorBF16DotI<1, U, asm, ".4s", ".8h",
+                                                ".2h", V128, v4f32, v8bf16>;
+}
+
 //----------------------------------------------------------------------------
 class BaseSIMDThreeSameVectorIndexB<bit Q, bit U, bits<2> sz, bits<4> opc,
                                     string asm, string dst_kind,
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index d1d352dcc5f1f..835fd0575b5e2 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -1735,23 +1735,6 @@ def BFCVTN2 : SIMD_BFCVTN2;
 def : Pat<(concat_vectors (v4bf16 V64:$Rd), (any_fpround (v4f32 V128:$Rn))),
           (BFCVTN2 (v8bf16 (INSERT_SUBREG (IMPLICIT_DEF), V64:$Rd, dsub)), V128:$Rn)>;
-
-// Vector-scalar BFDOT:
-// The second source operand of the 64-bit variant of BF16DOTlane is a 128-bit
-// register (the instruction uses a single 32-bit lane from it), so the pattern
-// is a bit tricky.
-def : Pat<(v2f32 (int_aarch64_neon_bfdot - (v2f32 V64:$Rd), (v4bf16 V64:$Rn), - (v4bf16 (bitconvert - (v2i32 (AArch64duplane32 - (v4i32 (bitconvert - (v8bf16 (insert_subvector undef, - (v4bf16 V64:$Rm), - (i64 0))))), - VectorIndexS:$idx)))))), - (BF16DOTlanev4bf16 (v2f32 V64:$Rd), (v4bf16 V64:$Rn), - (SUBREG_TO_REG (i32 0), V64:$Rm, dsub), - VectorIndexS:$idx)>; } let Predicates = [HasNEONandIsStreamingSafe, HasBF16] in { diff --git a/llvm/test/CodeGen/AArch64/aarch64-bf16-dotprod-intrinsics.ll b/llvm/test/CodeGen/AArch64/aarch64-bf16-dotprod-intrinsics.ll index 52b542790e82d..ca3cd6bbae549 100644 --- a/llvm/test/CodeGen/AArch64/aarch64-bf16-dotprod-intrinsics.ll +++ b/llvm/test/CodeGen/AArch64/aarch64-bf16-dotprod-intrinsics.ll @@ -151,6 +151,37 @@ entry: ret <4 x float> %vbfmlaltq_v3.i } +define <4 x float> @test_vbfdotq_laneq_f32_v4i32_shufflevector(<8 x bfloat> %a, <8 x bfloat> %b) { +; CHECK-LABEL: test_vbfdotq_laneq_f32_v4i32_shufflevector: +; CHECK: // %bb.0: // %entry +; CHECK-NEXT: movi v2.2d, #0000000000000000 +; CHECK-NEXT: bfdot v2.4s, v0.8h, v1.2h[0] +; CHECK-NEXT: mov v0.16b, v2.16b +; CHECK-NEXT: ret +entry: + %0 = bitcast <8 x bfloat> %b to <4 x i32> + %1 = shufflevector <4 x i32> %0, <4 x i32> poison, <4 x i32> zeroinitializer + %2 = bitcast <4 x i32> %1 to <8 x bfloat> + %vbfdotq = call <4 x float> @llvm.aarch64.neon.bfdot.v4f32.v8bf16(<4 x float> zeroinitializer, <8 x bfloat> %a, <8 x bfloat> %2) + ret <4 x float> %vbfdotq +} + +define <2 x float> @test_vbfdotq_laneq_f32_v2i32_shufflevector(<4 x bfloat> %a, <4 x bfloat> %b) { +; CHECK-LABEL: test_vbfdotq_laneq_f32_v2i32_shufflevector: +; CHECK: // %bb.0: // %entry +; CHECK-NEXT: movi d2, #0000000000000000 +; CHECK-NEXT: // kill: def $d1 killed $d1 def $q1 +; CHECK-NEXT: bfdot v2.2s, v0.4h, v1.2h[0] +; CHECK-NEXT: fmov d0, d2 +; CHECK-NEXT: ret +entry: + %0 = bitcast <4 x bfloat> %b to <2 x i32> + %1 = shufflevector <2 x i32> %0, <2 x i32> poison, <2 x i32> zeroinitializer + %2 = bitcast <2 x 
i32> %1 to <4 x bfloat> + %vbfdotq = call <2 x float> @llvm.aarch64.neon.bfdot.v2f32.v4bf16(<2 x float> zeroinitializer, <4 x bfloat> %a, <4 x bfloat> %2) + ret <2 x float> %vbfdotq +} + declare <2 x float> @llvm.aarch64.neon.bfdot.v2f32.v4bf16(<2 x float>, <4 x bfloat>, <4 x bfloat>) declare <4 x float> @llvm.aarch64.neon.bfdot.v4f32.v8bf16(<4 x float>, <8 x bfloat>, <8 x bfloat>) declare <4 x float> @llvm.aarch64.neon.bfmmla(<4 x float>, <8 x bfloat>, <8 x bfloat>)