Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 48 additions & 30 deletions llvm/lib/Target/AArch64/AArch64InstrFormats.td
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,20 @@
// Describe AArch64 instructions format here
//

// Helper class to convert vector element types to integers.
class ChangeElementTypeToInteger<ValueType InVT> {
ValueType VT = !cond(
!eq(InVT, v2f32): v2i32,
!eq(InVT, v4f32): v4i32,
// TODO: Other types.
true : untyped);
}

class VTPair<ValueType A, ValueType B> {
ValueType VT0 = A;
ValueType VT1 = B;
}

// Format specifies the encoding used by the instruction. This is part of the
// ad-hoc solution used to emit machine instruction encodings by our machine
// code emitter.
Expand Down Expand Up @@ -8952,36 +8966,6 @@ multiclass SIMDThreeSameVectorBFDot<bit U, string asm> {
v4f32, v8bf16>;
}

class BaseSIMDThreeSameVectorBF16DotI<bit Q, bit U, string asm,
string dst_kind, string lhs_kind,
string rhs_kind,
RegisterOperand RegType,
ValueType AccumType,
ValueType InputType>
: BaseSIMDIndexedTied<Q, U, 0b0, 0b01, 0b1111,
RegType, RegType, V128, VectorIndexS,
asm, "", dst_kind, lhs_kind, rhs_kind,
[(set (AccumType RegType:$dst),
(AccumType (int_aarch64_neon_bfdot
(AccumType RegType:$Rd),
(InputType RegType:$Rn),
(InputType (bitconvert (AccumType
(AArch64duplane32 (v4f32 V128:$Rm),
VectorIndexS:$idx)))))))]> {

bits<2> idx;
let Inst{21} = idx{0}; // L
let Inst{11} = idx{1}; // H
}

multiclass SIMDThreeSameVectorBF16DotI<bit U, string asm> {

def v4bf16 : BaseSIMDThreeSameVectorBF16DotI<0, U, asm, ".2s", ".4h",
".2h", V64, v2f32, v4bf16>;
def v8bf16 : BaseSIMDThreeSameVectorBF16DotI<1, U, asm, ".4s", ".8h",
".2h", V128, v4f32, v8bf16>;
}

let mayRaiseFPException = 1, Uses = [FPCR] in
class SIMDBF16MLAL<bit Q, string asm, SDPatternOperator OpNode>
: BaseSIMDThreeSameVectorTied<Q, 0b1, 0b110, 0b11111, V128, asm, ".4s",
Expand Down Expand Up @@ -9054,6 +9038,40 @@ class BF16ToSinglePrecision<string asm>
}
} // End of let mayStore = 0, mayLoad = 0, hasSideEffects = 0

multiclass BaseSIMDThreeSameVectorBF16DotI<bit Q, bit U, string asm,
string dst_kind, string lhs_kind,
string rhs_kind,
RegisterOperand RegType,
ValueType AccumType,
ValueType InputType> {
let mayLoad = 0, mayStore = 0, hasSideEffects = 0 in {
def NAME : BaseSIMDIndexedTied<Q, U, 0b0, 0b01, 0b1111, RegType, RegType, V128, VectorIndexS,
asm, "", dst_kind, lhs_kind, rhs_kind, []>
{
bits<2> idx;
let Inst{21} = idx{0}; // L
let Inst{11} = idx{1}; // H
}
}

foreach DupTypes = [VTPair<AccumType, v4f32>,
VTPair<ChangeElementTypeToInteger<AccumType>.VT, v4i32>] in {
def : Pat<(AccumType (int_aarch64_neon_bfdot
(AccumType RegType:$Rd), (InputType RegType:$Rn),
(InputType (bitconvert
(DupTypes.VT0 (AArch64duplane32 (DupTypes.VT1
(bitconvert (v8bf16 V128:$Rm))), VectorIndexS:$Idx)))))),
(!cast<Instruction>(NAME) $Rd, $Rn, $Rm, VectorIndexS:$Idx)>;
}
}

multiclass SIMDThreeSameVectorBF16DotI<bit U, string asm> {
defm v4bf16 : BaseSIMDThreeSameVectorBF16DotI<0, U, asm, ".2s", ".4h",
".2h", V64, v2f32, v4bf16>;
defm v8bf16 : BaseSIMDThreeSameVectorBF16DotI<1, U, asm, ".4s", ".8h",
".2h", V128, v4f32, v8bf16>;
}

//----------------------------------------------------------------------------
class BaseSIMDThreeSameVectorIndexB<bit Q, bit U, bits<2> sz, bits<4> opc,
string asm, string dst_kind,
Expand Down
17 changes: 0 additions & 17 deletions llvm/lib/Target/AArch64/AArch64InstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -1735,23 +1735,6 @@ def BFCVTN2 : SIMD_BFCVTN2;

def : Pat<(concat_vectors (v4bf16 V64:$Rd), (any_fpround (v4f32 V128:$Rn))),
(BFCVTN2 (v8bf16 (INSERT_SUBREG (IMPLICIT_DEF), V64:$Rd, dsub)), V128:$Rn)>;

// Vector-scalar BFDOT:
// The second source operand of the 64-bit variant of BF16DOTlane is a 128-bit
// register (the instruction uses a single 32-bit lane from it), so the pattern
// is a bit tricky.
def : Pat<(v2f32 (int_aarch64_neon_bfdot
(v2f32 V64:$Rd), (v4bf16 V64:$Rn),
(v4bf16 (bitconvert
(v2i32 (AArch64duplane32
(v4i32 (bitconvert
(v8bf16 (insert_subvector undef,
(v4bf16 V64:$Rm),
(i64 0))))),
VectorIndexS:$idx)))))),
(BF16DOTlanev4bf16 (v2f32 V64:$Rd), (v4bf16 V64:$Rn),
(SUBREG_TO_REG (i32 0), V64:$Rm, dsub),
VectorIndexS:$idx)>;
}

let Predicates = [HasNEONandIsStreamingSafe, HasBF16] in {
Expand Down
31 changes: 31 additions & 0 deletions llvm/test/CodeGen/AArch64/aarch64-bf16-dotprod-intrinsics.ll
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,37 @@ entry:
ret <4 x float> %vbfmlaltq_v3.i
}

define <4 x float> @test_vbfdotq_laneq_f32_v4i32_shufflevector(<8 x bfloat> %a, <8 x bfloat> %b) {
; CHECK-LABEL: test_vbfdotq_laneq_f32_v4i32_shufflevector:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: movi v2.2d, #0000000000000000
; CHECK-NEXT: bfdot v2.4s, v0.8h, v1.2h[0]
; CHECK-NEXT: mov v0.16b, v2.16b
; CHECK-NEXT: ret
entry:
%0 = bitcast <8 x bfloat> %b to <4 x i32>
%1 = shufflevector <4 x i32> %0, <4 x i32> poison, <4 x i32> zeroinitializer
%2 = bitcast <4 x i32> %1 to <8 x bfloat>
%vbfdotq = call <4 x float> @llvm.aarch64.neon.bfdot.v4f32.v8bf16(<4 x float> zeroinitializer, <8 x bfloat> %a, <8 x bfloat> %2)
ret <4 x float> %vbfdotq
}

define <2 x float> @test_vbfdotq_laneq_f32_v2i32_shufflevector(<4 x bfloat> %a, <4 x bfloat> %b) {
; CHECK-LABEL: test_vbfdotq_laneq_f32_v2i32_shufflevector:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: movi d2, #0000000000000000
; CHECK-NEXT: // kill: def $d1 killed $d1 def $q1
; CHECK-NEXT: bfdot v2.2s, v0.4h, v1.2h[0]
; CHECK-NEXT: fmov d0, d2
; CHECK-NEXT: ret
entry:
%0 = bitcast <4 x bfloat> %b to <2 x i32>
%1 = shufflevector <2 x i32> %0, <2 x i32> poison, <2 x i32> zeroinitializer
%2 = bitcast <2 x i32> %1 to <4 x bfloat>
%vbfdotq = call <2 x float> @llvm.aarch64.neon.bfdot.v2f32.v4bf16(<2 x float> zeroinitializer, <4 x bfloat> %a, <4 x bfloat> %2)
ret <2 x float> %vbfdotq
}

declare <2 x float> @llvm.aarch64.neon.bfdot.v2f32.v4bf16(<2 x float>, <4 x bfloat>, <4 x bfloat>)
declare <4 x float> @llvm.aarch64.neon.bfdot.v4f32.v8bf16(<4 x float>, <8 x bfloat>, <8 x bfloat>)
declare <4 x float> @llvm.aarch64.neon.bfmmla(<4 x float>, <8 x bfloat>, <8 x bfloat>)
Expand Down