[AArch64] Generalize bfdotq_lane patterns to work for f32/i32 duplanes #171146
base: main
Conversation
@llvm/pr-subscribers-backend-aarch64

Author: Benjamin Maxwell (MacDue)

Changes

This also removes an overly specific pattern that is redundant with this change.

Fixes #170883

Full diff: https://github.com/llvm/llvm-project/pull/171146.diff

3 Files Affected:
diff --git a/llvm/lib/Target/AArch64/AArch64InstrFormats.td b/llvm/lib/Target/AArch64/AArch64InstrFormats.td
index 4d2e740779961..821dfbd8e9191 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrFormats.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrFormats.td
@@ -10,6 +10,20 @@
// Describe AArch64 instructions format here
//
+// Helper class to convert vector element types to integers.
+class ChangeElementTypeToInteger<ValueType InVT> {
+ ValueType VT = !cond(
+ !eq(InVT, v2f32): v2i32,
+ !eq(InVT, v4f32): v4i32,
+ // TODO: Other types.
+ true : untyped);
+}
+
+class VTPair<ValueType A, ValueType B> {
+ ValueType VT0 = A;
+ ValueType VT1 = B;
+}
+
// Format specifies the encoding used by the instruction. This is part of the
// ad-hoc solution used to emit machine instruction encodings by our machine
// code emitter.
@@ -8952,36 +8966,6 @@ multiclass SIMDThreeSameVectorBFDot<bit U, string asm> {
v4f32, v8bf16>;
}
-class BaseSIMDThreeSameVectorBF16DotI<bit Q, bit U, string asm,
- string dst_kind, string lhs_kind,
- string rhs_kind,
- RegisterOperand RegType,
- ValueType AccumType,
- ValueType InputType>
- : BaseSIMDIndexedTied<Q, U, 0b0, 0b01, 0b1111,
- RegType, RegType, V128, VectorIndexS,
- asm, "", dst_kind, lhs_kind, rhs_kind,
- [(set (AccumType RegType:$dst),
- (AccumType (int_aarch64_neon_bfdot
- (AccumType RegType:$Rd),
- (InputType RegType:$Rn),
- (InputType (bitconvert (AccumType
- (AArch64duplane32 (v4f32 V128:$Rm),
- VectorIndexS:$idx)))))))]> {
-
- bits<2> idx;
- let Inst{21} = idx{0}; // L
- let Inst{11} = idx{1}; // H
-}
-
-multiclass SIMDThreeSameVectorBF16DotI<bit U, string asm> {
-
- def v4bf16 : BaseSIMDThreeSameVectorBF16DotI<0, U, asm, ".2s", ".4h",
- ".2h", V64, v2f32, v4bf16>;
- def v8bf16 : BaseSIMDThreeSameVectorBF16DotI<1, U, asm, ".4s", ".8h",
- ".2h", V128, v4f32, v8bf16>;
-}
-
let mayRaiseFPException = 1, Uses = [FPCR] in
class SIMDBF16MLAL<bit Q, string asm, SDPatternOperator OpNode>
: BaseSIMDThreeSameVectorTied<Q, 0b1, 0b110, 0b11111, V128, asm, ".4s",
@@ -9054,6 +9038,39 @@ class BF16ToSinglePrecision<string asm>
}
} // End of let mayStore = 0, mayLoad = 0, hasSideEffects = 0
+multiclass BaseSIMDThreeSameVectorBF16DotI<bit Q, bit U, string asm,
+ string dst_kind, string lhs_kind,
+ string rhs_kind,
+ RegisterOperand RegType,
+ ValueType AccumType,
+ ValueType InputType> {
+ let mayLoad = 0, mayStore = 0, hasSideEffects = 0 in {
+ def NAME : BaseSIMDIndexedTied<Q, U, 0b0, 0b01, 0b1111, RegType, RegType, V128, VectorIndexS,
+ asm, "", dst_kind, lhs_kind, rhs_kind, []>
+ {
+ bits<2> idx;
+ let Inst{21} = idx{0}; // L
+ let Inst{11} = idx{1}; // H
+ }
+ }
+
+ foreach DupTypes = [VTPair<AccumType, v4f32>,
+ VTPair<ChangeElementTypeToInteger<AccumType>.VT, v4i32>] in {
+ def : Pat<(AccumType (int_aarch64_neon_bfdot
+ (AccumType RegType:$Rd), (InputType RegType:$Rn),
+ (InputType (bitconvert
+ (DupTypes.VT0 (AArch64duplane32 (DupTypes.VT1 V128:$Rm), VectorIndexS:$Idx)))))),
+ (!cast<Instruction>(NAME) $Rd, $Rn, $Rm, VectorIndexS:$Idx)>;
+ }
+}
+
+multiclass SIMDThreeSameVectorBF16DotI<bit U, string asm> {
+ defm v4bf16 : BaseSIMDThreeSameVectorBF16DotI<0, U, asm, ".2s", ".4h",
+ ".2h", V64, v2f32, v4bf16>;
+ defm v8bf16 : BaseSIMDThreeSameVectorBF16DotI<1, U, asm, ".4s", ".8h",
+ ".2h", V128, v4f32, v8bf16>;
+}
+
//----------------------------------------------------------------------------
class BaseSIMDThreeSameVectorIndexB<bit Q, bit U, bits<2> sz, bits<4> opc,
string asm, string dst_kind,
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index 64017d7cafca3..d2e34219d40aa 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -1729,23 +1729,6 @@ def BFCVTN2 : SIMD_BFCVTN2;
def : Pat<(concat_vectors (v4bf16 V64:$Rd), (any_fpround (v4f32 V128:$Rn))),
(BFCVTN2 (v8bf16 (INSERT_SUBREG (IMPLICIT_DEF), V64:$Rd, dsub)), V128:$Rn)>;
-
-// Vector-scalar BFDOT:
-// The second source operand of the 64-bit variant of BF16DOTlane is a 128-bit
-// register (the instruction uses a single 32-bit lane from it), so the pattern
-// is a bit tricky.
-def : Pat<(v2f32 (int_aarch64_neon_bfdot
- (v2f32 V64:$Rd), (v4bf16 V64:$Rn),
- (v4bf16 (bitconvert
- (v2i32 (AArch64duplane32
- (v4i32 (bitconvert
- (v8bf16 (insert_subvector undef,
- (v4bf16 V64:$Rm),
- (i64 0))))),
- VectorIndexS:$idx)))))),
- (BF16DOTlanev4bf16 (v2f32 V64:$Rd), (v4bf16 V64:$Rn),
- (SUBREG_TO_REG (i32 0), V64:$Rm, dsub),
- VectorIndexS:$idx)>;
}
let Predicates = [HasNEONandIsStreamingSafe, HasBF16] in {
diff --git a/llvm/test/CodeGen/AArch64/aarch64-bf16-dotprod-intrinsics.ll b/llvm/test/CodeGen/AArch64/aarch64-bf16-dotprod-intrinsics.ll
index 52b542790e82d..ca3cd6bbae549 100644
--- a/llvm/test/CodeGen/AArch64/aarch64-bf16-dotprod-intrinsics.ll
+++ b/llvm/test/CodeGen/AArch64/aarch64-bf16-dotprod-intrinsics.ll
@@ -151,6 +151,37 @@ entry:
ret <4 x float> %vbfmlaltq_v3.i
}
+define <4 x float> @test_vbfdotq_laneq_f32_v4i32_shufflevector(<8 x bfloat> %a, <8 x bfloat> %b) {
+; CHECK-LABEL: test_vbfdotq_laneq_f32_v4i32_shufflevector:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: movi v2.2d, #0000000000000000
+; CHECK-NEXT: bfdot v2.4s, v0.8h, v1.2h[0]
+; CHECK-NEXT: mov v0.16b, v2.16b
+; CHECK-NEXT: ret
+entry:
+ %0 = bitcast <8 x bfloat> %b to <4 x i32>
+ %1 = shufflevector <4 x i32> %0, <4 x i32> poison, <4 x i32> zeroinitializer
+ %2 = bitcast <4 x i32> %1 to <8 x bfloat>
+ %vbfdotq = call <4 x float> @llvm.aarch64.neon.bfdot.v4f32.v8bf16(<4 x float> zeroinitializer, <8 x bfloat> %a, <8 x bfloat> %2)
+ ret <4 x float> %vbfdotq
+}
+
+define <2 x float> @test_vbfdotq_laneq_f32_v2i32_shufflevector(<4 x bfloat> %a, <4 x bfloat> %b) {
+; CHECK-LABEL: test_vbfdotq_laneq_f32_v2i32_shufflevector:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: movi d2, #0000000000000000
+; CHECK-NEXT: // kill: def $d1 killed $d1 def $q1
+; CHECK-NEXT: bfdot v2.2s, v0.4h, v1.2h[0]
+; CHECK-NEXT: fmov d0, d2
+; CHECK-NEXT: ret
+entry:
+ %0 = bitcast <4 x bfloat> %b to <2 x i32>
+ %1 = shufflevector <2 x i32> %0, <2 x i32> poison, <2 x i32> zeroinitializer
+ %2 = bitcast <2 x i32> %1 to <4 x bfloat>
+ %vbfdotq = call <2 x float> @llvm.aarch64.neon.bfdot.v2f32.v4bf16(<2 x float> zeroinitializer, <4 x bfloat> %a, <4 x bfloat> %2)
+ ret <2 x float> %vbfdotq
+}
+
declare <2 x float> @llvm.aarch64.neon.bfdot.v2f32.v4bf16(<2 x float>, <4 x bfloat>, <4 x bfloat>)
declare <4 x float> @llvm.aarch64.neon.bfdot.v4f32.v8bf16(<4 x float>, <8 x bfloat>, <8 x bfloat>)
declare <4 x float> @llvm.aarch64.neon.bfmmla(<4 x float>, <8 x bfloat>, <8 x bfloat>)
def : Pat<(AccumType (int_aarch64_neon_bfdot
    (AccumType RegType:$Rd), (InputType RegType:$Rn),
    (InputType (bitconvert
      (DupTypes.VT0 (AArch64duplane32 (DupTypes.VT1 V128:$Rm), VectorIndexS:$Idx)))))),
  (!cast<Instruction>(NAME) $Rd, $Rn, $Rm, VectorIndexS:$Idx)>;
This doesn't look correct for big-endian targets, where mixed-element-size bitconverts are not no-ops.
The removed pattern matched bitconverts on either side of the AArch64duplane32, which is OK because they're effectively back to back and so collapse to a no-op.
I appreciate some of this is existing code, but if we're increasing the applicability of the pattern then I'd rather not make things worse. I'm wondering if there's a post-legalisation combine you can add to replace bitconvert->duplane->bitconvert with nvcast->duplane->nvcast, so that the pattern above can work for both cases by matching a single nvcast?
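For illustration, the double-bitconvert shape being discussed would look roughly like this for the 128-bit variant. This is a hypothetical sketch, not code from the PR: the pattern actually removed in AArch64InstrInfo.td above was the 64-bit variant, which additionally needs an insert_subvector to widen $Rm, and the BF16DOTlanev8bf16 name is inferred from the BF16DOTlanev4bf16 reference in the diff.

// Both bitconverts are matched, so on big-endian targets the two byte-level
// reshuffles sit either side of a lane-preserving duplane and cancel out.
def : Pat<(v4f32 (int_aarch64_neon_bfdot
              (v4f32 V128:$Rd), (v8bf16 V128:$Rn),
              (v8bf16 (bitconvert
                (v4i32 (AArch64duplane32
                  (v4i32 (bitconvert (v8bf16 V128:$Rm))),
                  VectorIndexS:$idx)))))),
          (BF16DOTlanev8bf16 V128:$Rd, V128:$Rn, V128:$Rm, VectorIndexS:$idx)>;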
I see. I decided just to match both bitconverts here (which does not regress any tests). Theoretically, it would regress some LE code that did not have a bitconvert as the input to the AArch64duplane32, but I think that's pretty unlikely (given what clang emits for these intrinsics).
I did think about adding the combine but it's a little annoying as some bitconverts have multiple users, and I'm not sure it's worth the extra complexity.
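Concretely, matching both bitconverts in the generalized multiclass might look something like the following. This is a rough sketch assuming the parameter and helper names from the diff above, and it glosses over the insert_subvector handling the 64-bit variant needs to widen $Rm; the final commit may differ.

// Matching the inner bitconvert as well keeps the byte layout correct on
// big-endian targets, while the foreach still covers both the f32 and i32
// dup-lane forms.
foreach DupTypes = [VTPair<AccumType, v4f32>,
                    VTPair<ChangeElementTypeToInteger<AccumType>.VT, v4i32>] in {
  def : Pat<(AccumType (int_aarch64_neon_bfdot
                (AccumType RegType:$Rd), (InputType RegType:$Rn),
                (InputType (bitconvert
                  (DupTypes.VT0 (AArch64duplane32
                    (DupTypes.VT1 (bitconvert (v8bf16 V128:$Rm))),
                    VectorIndexS:$Idx)))))),
            (!cast<Instruction>(NAME) $Rd, $Rn, $Rm, VectorIndexS:$Idx)>;
}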