[X86][BF16] Add subvec_zero_lowering patterns #76507

Merged: 2 commits, Dec 31, 2023

Conversation

phoebewang
Contributor

No description provided.

@llvmbot
Collaborator

llvmbot commented Dec 28, 2023

@llvm/pr-subscribers-backend-x86

Author: Phoebe Wang (phoebewang)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/76507.diff

3 Files Affected:

  • (modified) llvm/lib/Target/X86/X86ISelLowering.cpp (+1-1)
  • (modified) llvm/lib/Target/X86/X86InstrVecCompiler.td (+9)
  • (modified) llvm/test/CodeGen/X86/bfloat.ll (+14)
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 35e54ebd5129f7..57ada84bbeffad 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -3743,7 +3743,7 @@ static SDValue getZeroVector(MVT VT, const X86Subtarget &Subtarget,
   SDValue Vec;
   if (!Subtarget.hasSSE2() && VT.is128BitVector()) {
     Vec = DAG.getConstantFP(+0.0, dl, MVT::v4f32);
-  } else if (VT.isFloatingPoint()) {
+  } else if (VT.isFloatingPoint() && VT.getVectorElementType() != MVT::bf16) {
     Vec = DAG.getConstantFP(+0.0, dl, VT);
   } else if (VT.getVectorElementType() == MVT::i1) {
     assert((Subtarget.hasBWI() || VT.getVectorNumElements() <= 16) &&
diff --git a/llvm/lib/Target/X86/X86InstrVecCompiler.td b/llvm/lib/Target/X86/X86InstrVecCompiler.td
index 70bd77bba03ab3..bbd19cf8d5b25e 100644
--- a/llvm/lib/Target/X86/X86InstrVecCompiler.td
+++ b/llvm/lib/Target/X86/X86InstrVecCompiler.td
@@ -130,6 +130,9 @@ let Predicates = [HasAVX, NoVLX] in {
   defm : subvec_zero_lowering<"DQA", VR128, v32i8, v16i8, sub_xmm>;
 }
 
+let Predicates = [HasAVXNECONVERT, NoVLX] in
+  defm : subvec_zero_lowering<"DQA", VR128, v16bf16, v8bf16, sub_xmm>;
+
 let Predicates = [HasVLX] in {
   defm : subvec_zero_lowering<"APDZ128", VR128X, v4f64, v2f64, sub_xmm>;
   defm : subvec_zero_lowering<"APSZ128", VR128X, v8f32, v4f32, sub_xmm>;
@@ -175,6 +178,12 @@ let Predicates = [HasFP16, HasVLX] in {
   defm : subvec_zero_lowering<"APSZ256", VR256X, v32f16, v16f16, sub_ymm>;
 }
 
+let Predicates = [HasBF16, HasVLX] in {
+  defm : subvec_zero_lowering<"APSZ128", VR128X, v16bf16, v8bf16, sub_xmm>;
+  defm : subvec_zero_lowering<"APSZ128", VR128X, v32bf16, v8bf16, sub_xmm>;
+  defm : subvec_zero_lowering<"APSZ256", VR256X, v32bf16, v16bf16, sub_ymm>;
+}
+
 class maskzeroupper<ValueType vt, RegisterClass RC> :
   PatLeaf<(vt RC:$src), [{
     return isMaskZeroExtended(N);
diff --git a/llvm/test/CodeGen/X86/bfloat.ll b/llvm/test/CodeGen/X86/bfloat.ll
index 674a0eacb0ca98..9c65310f79d7ec 100644
--- a/llvm/test/CodeGen/X86/bfloat.ll
+++ b/llvm/test/CodeGen/X86/bfloat.ll
@@ -2529,3 +2529,17 @@ define <8 x bfloat> @extract_v32bf16_v8bf16(<32 x bfloat> %x) {
   %a = shufflevector <32 x bfloat> %x, <32 x bfloat> undef, <8 x i32> <i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
   ret <8 x bfloat> %a
 }
+
+define <16 x bfloat> @concat_zero_v8bf16(<8 x bfloat> %x, <8 x bfloat> %y) {
+; SSE2-LABEL: concat_zero_v8bf16:
+; SSE2:       # %bb.0:
+; SSE2-NEXT:    xorps %xmm1, %xmm1
+; SSE2-NEXT:    retq
+;
+; AVX-LABEL: concat_zero_v8bf16:
+; AVX:       # %bb.0:
+; AVX-NEXT:    vmovaps %xmm0, %xmm0
+; AVX-NEXT:    retq
+  %a = shufflevector <8 x bfloat> %x, <8 x bfloat> zeroinitializer, <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
+  ret <16 x bfloat> %a
+}
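
As a reading aid for the new test, here is a minimal C++ sketch of what `concat_zero_v8bf16` computes. It is a toy model, not LLVM code: `BF16Bits` and `concatZeroV8BF16` are hypothetical names, and each bfloat lane is represented only by its raw 16-bit pattern. The shuffle mask `0..15` takes lanes 0-7 from `%x` and lanes 8-15 from the zero vector, so the result is `%x` in the low half and zeros in the high half.

```cpp
#include <array>
#include <cstddef>
#include <cstdint>

// Hypothetical stand-in for one bfloat lane: only its raw 16-bit pattern.
using BF16Bits = std::uint16_t;

// Models `shufflevector <8 x bfloat> %x, <8 x bfloat> zeroinitializer, <0..15>`:
// mask indices 0-7 select from the first operand, 8-15 from the second.
std::array<BF16Bits, 16> concatZeroV8BF16(const std::array<BF16Bits, 8> &x) {
  const std::array<BF16Bits, 8> zero{}; // zeroinitializer: every lane is 0x0000
  std::array<BF16Bits, 16> out{};
  for (std::size_t i = 0; i < 16; ++i)
    out[i] = (i < 8) ? x[i] : zero[i - 8];
  return out;
}
```

This is why a single `vmovaps %xmm0, %xmm0` is enough in the AVX check lines: a VEX-encoded 128-bit move clears the upper bits of the destination YMM register, so the zero upper half comes for free. In the SSE2 output the 256-bit result appears to be split across two XMM registers, which would explain why only `%xmm1` is cleared with `xorps`.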

@@ -3743,7 +3743,7 @@ static SDValue getZeroVector(MVT VT, const X86Subtarget &Subtarget,
   SDValue Vec;
   if (!Subtarget.hasSSE2() && VT.is128BitVector()) {
     Vec = DAG.getConstantFP(+0.0, dl, MVT::v4f32);
-  } else if (VT.isFloatingPoint()) {
+  } else if (VT.isFloatingPoint() && VT.getVectorElementType() != MVT::bf16) {
Collaborator

Why is this necessary? Are we missing bf16 handling somewhere?

Contributor Author

Because scalar bf16 is an illegal type even when the vXbf16 vector types are legal. We worked around this in BUILDVECTOR, and we should likewise avoid building the zero vector with getConstantFP here.

Collaborator

Could we test for isTypeLegal(VT.getVectorElementType()) instead?

Contributor Author

Done.
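
For readers following the thread, a rough self-contained model of the condition discussed above. Everything in it is hypothetical (`Elt`, `ZeroKind`, `ToySubtarget`, `chooseZeroLowering` are illustrative names, not LLVM APIs), and the final committed form of the check is not shown in this conversation. It also assumes that the non-FP fallback in `getZeroVector` builds an integer zero and bitcasts it to the requested type, which this excerpt does not show.

```cpp
#include <cstdint>
#include <cstring>

// Hypothetical, simplified stand-ins; none of these are LLVM types.
enum class Elt { F32, F64, BF16, I32 };
enum class ZeroKind { FPConstant, IntegerZeroBitcast };

struct ToySubtarget {
  // Assumption: mirrors the statement that scalar bf16 is an illegal type.
  bool scalarBF16Legal = false;
  bool isTypeLegal(Elt e) const { return e != Elt::BF16 || scalarBF16Legal; }
};

bool isFloatingPoint(Elt e) { return e != Elt::I32; }

// Sketch of the suggested check: build an FP constant only when the element
// type is floating point AND legal as a scalar; otherwise fall back to an
// integer zero that is bitcast to the requested vector type.
ZeroKind chooseZeroLowering(Elt e, const ToySubtarget &st) {
  if (isFloatingPoint(e) && st.isTypeLegal(e))
    return ZeroKind::FPConstant;
  return ZeroKind::IntegerZeroBitcast;
}

// bf16 keeps the top 16 bits of an IEEE-754 binary32 value, so +0.0 has the
// all-zero bit pattern and an integer zero vector is bit-identical to it.
std::uint16_t bf16BitsOfPositiveZero() {
  float f = 0.0f;
  std::uint32_t bits = 0;
  std::memcpy(&bits, &f, sizeof bits);
  return static_cast<std::uint16_t>(bits >> 16);
}
```

Compared with testing `!= MVT::bf16` directly, a legality check of this shape would not need updating if scalar bf16 later became legal on some subtarget; in this toy model that corresponds to setting `scalarBF16Legal` to true.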

Collaborator

@RKSimon RKSimon left a comment

LGTM with one (optional) minor


@phoebewang phoebewang merged commit a384cd5 into llvm:main Dec 31, 2023
3 of 4 checks passed
@phoebewang phoebewang deleted the bf16 branch December 31, 2023 03:14
qiaojbao pushed a commit to GPUOpen-Drivers/llvm-project that referenced this pull request Jan 26, 2024
…ffd43ea2e

Local branch amd-gfx 5aaffd4 Merged main:a1f1371fdc7d9af9edf32339dcfebada96d937a5 into amd-gfx:45668192a2fc
Remote branch main a384cd5 [X86][BF16] Add subvec_zero_lowering patterns (llvm#76507)

3 participants