Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AMDGPU: Make v32bf16 a legal type #76679

Merged
merged 2 commits into from
Jan 9, 2024
Merged

Conversation

arsenm
Copy link
Contributor

@arsenm arsenm commented Jan 1, 2024

Depends #76678

@llvmbot llvmbot added the llvm:SelectionDAG SelectionDAGISel as well label Jan 1, 2024
@llvmbot
Copy link
Collaborator

llvmbot commented Jan 1, 2024

@llvm/pr-subscribers-llvm-selectiondag

@llvm/pr-subscribers-backend-amdgpu

Author: Matt Arsenault (arsenm)

Changes

Depends #76678


Patch is 1.38 MiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/76679.diff

27 Files Affected:

  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp (+21-5)
  • (modified) llvm/lib/Target/AMDGPU/AMDGPUCallingConv.td (+13-13)
  • (modified) llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp (+1)
  • (modified) llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp (+44-14)
  • (modified) llvm/lib/Target/AMDGPU/SIISelLowering.cpp (+140-32)
  • (modified) llvm/lib/Target/AMDGPU/SIISelLowering.h (+1)
  • (modified) llvm/lib/Target/AMDGPU/SIInstructions.td (+147-6)
  • (modified) llvm/lib/Target/AMDGPU/SIRegisterInfo.td (+33-33)
  • (modified) llvm/lib/Target/AMDGPU/VOP3PInstructions.td (+11-11)
  • (modified) llvm/test/CodeGen/AMDGPU/bf16.ll (+7124-9724)
  • (modified) llvm/test/CodeGen/AMDGPU/fcopysign.f32.ll (+5-7)
  • (modified) llvm/test/CodeGen/AMDGPU/fmed3-cast-combine.ll (+14-2)
  • (modified) llvm/test/CodeGen/AMDGPU/fneg-modifier-casting.ll (+5-5)
  • (modified) llvm/test/CodeGen/AMDGPU/function-args-inreg.ll (+2-2)
  • (modified) llvm/test/CodeGen/AMDGPU/function-args.ll (+70-88)
  • (modified) llvm/test/CodeGen/AMDGPU/function-returns.ll (+19-46)
  • (modified) llvm/test/CodeGen/AMDGPU/gfx-callable-argument-types.ll (+48-80)
  • (modified) llvm/test/CodeGen/AMDGPU/isel-amdgpu-cs-chain-preserve-cc.ll (+6-10)
  • (modified) llvm/test/CodeGen/AMDGPU/llvm.exp.ll (+13-12)
  • (modified) llvm/test/CodeGen/AMDGPU/llvm.exp10.ll (+13-12)
  • (modified) llvm/test/CodeGen/AMDGPU/llvm.exp2.ll (+43-15)
  • (modified) llvm/test/CodeGen/AMDGPU/llvm.is.fpclass.bf16.ll (+157-346)
  • (modified) llvm/test/CodeGen/AMDGPU/llvm.log.ll (+7-4)
  • (modified) llvm/test/CodeGen/AMDGPU/llvm.log10.ll (+7-4)
  • (modified) llvm/test/CodeGen/AMDGPU/llvm.log2.ll (+46-16)
  • (modified) llvm/test/CodeGen/AMDGPU/local-atomics-fp.ll (+2-2)
  • (modified) llvm/test/CodeGen/AMDGPU/vector_shuffle.packed.ll (+362-629)
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
index 4e317062cec497..296ed3a3c3dc11 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -3199,7 +3199,16 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
       return true;
     }
     break;
-  case ISD::FP_ROUND:
+  case ISD::FP_ROUND: {
+    EVT VT = Node->getValueType(0);
+    if (VT.getScalarType() == MVT::bf16) {
+      Results.push_back(
+          DAG.getNode(ISD::FP_TO_BF16, SDLoc(Node), VT, Node->getOperand(0)));
+      break;
+    }
+
+    LLVM_FALLTHROUGH;
+  }
   case ISD::BITCAST:
     if ((Tmp1 = EmitStackConvert(Node->getOperand(0), Node->getValueType(0),
                                  Node->getValueType(0), dl)))
@@ -3226,12 +3235,19 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
       return true;
     }
     break;
-  case ISD::FP_EXTEND:
-    if ((Tmp1 = EmitStackConvert(Node->getOperand(0),
-                                 Node->getOperand(0).getValueType(),
-                                 Node->getValueType(0), dl)))
+  case ISD::FP_EXTEND: {
+    SDValue Op = Node->getOperand(0);
+    EVT SrcVT = Op.getValueType();
+    EVT DstVT = Node->getValueType(0);
+    if (SrcVT.getScalarType() == MVT::bf16) {
+      Results.push_back(DAG.getNode(ISD::BF16_TO_FP, SDLoc(Node), DstVT, Op));
+      break;
+    }
+
+    if ((Tmp1 = EmitStackConvert(Op, SrcVT, DstVT, dl)))
       Results.push_back(Tmp1);
     break;
+  }
   case ISD::BF16_TO_FP: {
     // Always expand bf16 to f32 casts, they lower to ext + shift.
     //
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUCallingConv.td b/llvm/lib/Target/AMDGPU/AMDGPUCallingConv.td
index 9036b26a6f6bcb..c5207228dc913f 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUCallingConv.td
+++ b/llvm/lib/Target/AMDGPU/AMDGPUCallingConv.td
@@ -22,28 +22,28 @@ def CC_SI_Gfx : CallingConv<[
   // 32 is reserved for the stack pointer
   // 33 is reserved for the frame pointer
   // 34 is reserved for the base pointer
-  CCIfInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16] , CCAssignToReg<[
+  CCIfInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16, bf16, v2bf16] , CCAssignToReg<[
     SGPR4, SGPR5, SGPR6, SGPR7,
     SGPR8, SGPR9, SGPR10, SGPR11, SGPR12, SGPR13, SGPR14, SGPR15,
     SGPR16, SGPR17, SGPR18, SGPR19, SGPR20, SGPR21, SGPR22, SGPR23,
     SGPR24, SGPR25, SGPR26, SGPR27, SGPR28, SGPR29
   ]>>>,
 
-  CCIfNotInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16] , CCAssignToReg<[
+  CCIfNotInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16, bf16, v2bf16] , CCAssignToReg<[
     VGPR0, VGPR1, VGPR2, VGPR3, VGPR4, VGPR5, VGPR6, VGPR7,
     VGPR8, VGPR9, VGPR10, VGPR11, VGPR12, VGPR13, VGPR14, VGPR15,
     VGPR16, VGPR17, VGPR18, VGPR19, VGPR20, VGPR21, VGPR22, VGPR23,
     VGPR24, VGPR25, VGPR26, VGPR27, VGPR28, VGPR29, VGPR30, VGPR31
   ]>>>,
 
-  CCIfType<[i32, f32, v2i16, v2f16, i16, f16, i1], CCAssignToStack<4, 4>>
+  CCIfType<[i32, f32, v2i16, v2f16, i16, f16, i1, bf16, v2bf16], CCAssignToStack<4, 4>>
 ]>;
 
 def RetCC_SI_Gfx : CallingConv<[
   CCIfType<[i1], CCPromoteToType<i32>>,
   CCIfType<[i1, i16], CCIfExtend<CCPromoteToType<i32>>>,
 
-  CCIfNotInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16] , CCAssignToReg<[
+  CCIfNotInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16, bf16, v2bf16] , CCAssignToReg<[
     VGPR0, VGPR1, VGPR2, VGPR3, VGPR4, VGPR5, VGPR6, VGPR7,
     VGPR8, VGPR9, VGPR10, VGPR11, VGPR12, VGPR13, VGPR14, VGPR15,
     VGPR16, VGPR17, VGPR18, VGPR19, VGPR20, VGPR21, VGPR22, VGPR23,
@@ -66,7 +66,7 @@ def RetCC_SI_Gfx : CallingConv<[
 
 def CC_SI_SHADER : CallingConv<[
 
-  CCIfInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16] , CCAssignToReg<[
+  CCIfInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16, bf16, v2bf16] , CCAssignToReg<[
     SGPR0, SGPR1, SGPR2, SGPR3, SGPR4, SGPR5, SGPR6, SGPR7,
     SGPR8, SGPR9, SGPR10, SGPR11, SGPR12, SGPR13, SGPR14, SGPR15,
     SGPR16, SGPR17, SGPR18, SGPR19, SGPR20, SGPR21, SGPR22, SGPR23,
@@ -76,7 +76,7 @@ def CC_SI_SHADER : CallingConv<[
   ]>>>,
 
   // 32*4 + 4 is the minimum for a fetch shader consumer with 32 inputs.
-  CCIfNotInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16] , CCAssignToReg<[
+  CCIfNotInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16, bf16, v2bf16] , CCAssignToReg<[
     VGPR0, VGPR1, VGPR2, VGPR3, VGPR4, VGPR5, VGPR6, VGPR7,
     VGPR8, VGPR9, VGPR10, VGPR11, VGPR12, VGPR13, VGPR14, VGPR15,
     VGPR16, VGPR17, VGPR18, VGPR19, VGPR20, VGPR21, VGPR22, VGPR23,
@@ -109,7 +109,7 @@ def RetCC_SI_Shader : CallingConv<[
   ]>>,
 
   // 32*4 + 4 is the minimum for a fetch shader with 32 outputs.
-  CCIfType<[f32, f16, v2f16] , CCAssignToReg<[
+  CCIfType<[f32, f16, v2f16, bf16, v2bf16] , CCAssignToReg<[
     VGPR0, VGPR1, VGPR2, VGPR3, VGPR4, VGPR5, VGPR6, VGPR7,
     VGPR8, VGPR9, VGPR10, VGPR11, VGPR12, VGPR13, VGPR14, VGPR15,
     VGPR16, VGPR17, VGPR18, VGPR19, VGPR20, VGPR21, VGPR22, VGPR23,
@@ -188,23 +188,23 @@ def CC_AMDGPU_Func : CallingConv<[
   CCIfType<[i1], CCPromoteToType<i32>>,
   CCIfType<[i8, i16], CCIfExtend<CCPromoteToType<i32>>>,
 
-  CCIfInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16] , CCAssignToReg<
+  CCIfInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16, bf16, v2bf16] , CCAssignToReg<
     !foreach(i, !range(0, 30), !cast<Register>("SGPR"#i))  // SGPR0-29
   >>>,
 
-  CCIfType<[i32, f32, i16, f16, v2i16, v2f16, i1], CCAssignToReg<[
+  CCIfType<[i32, f32, i16, f16, v2i16, v2f16, i1, bf16, v2bf16], CCAssignToReg<[
     VGPR0, VGPR1, VGPR2, VGPR3, VGPR4, VGPR5, VGPR6, VGPR7,
     VGPR8, VGPR9, VGPR10, VGPR11, VGPR12, VGPR13, VGPR14, VGPR15,
     VGPR16, VGPR17, VGPR18, VGPR19, VGPR20, VGPR21, VGPR22, VGPR23,
     VGPR24, VGPR25, VGPR26, VGPR27, VGPR28, VGPR29, VGPR30, VGPR31]>>,
-  CCIfType<[i32, f32, v2i16, v2f16, i16, f16, i1], CCAssignToStack<4, 4>>
+  CCIfType<[i32, f32, v2i16, v2f16, i16, f16, i1, bf16, v2bf16], CCAssignToStack<4, 4>>
 ]>;
 
 // Calling convention for leaf functions
 def RetCC_AMDGPU_Func : CallingConv<[
   CCIfType<[i1], CCPromoteToType<i32>>,
   CCIfType<[i1, i16], CCIfExtend<CCPromoteToType<i32>>>,
-  CCIfType<[i32, f32, i16, f16, v2i16, v2f16], CCAssignToReg<[
+  CCIfType<[i32, f32, i16, f16, v2i16, v2f16, bf16, v2bf16], CCAssignToReg<[
     VGPR0, VGPR1, VGPR2, VGPR3, VGPR4, VGPR5, VGPR6, VGPR7,
     VGPR8, VGPR9, VGPR10, VGPR11, VGPR12, VGPR13, VGPR14, VGPR15,
     VGPR16, VGPR17, VGPR18, VGPR19, VGPR20, VGPR21, VGPR22, VGPR23,
@@ -223,11 +223,11 @@ def CC_AMDGPU : CallingConv<[
 ]>;
 
 def CC_AMDGPU_CS_CHAIN : CallingConv<[
-  CCIfInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16] , CCAssignToReg<
+  CCIfInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16, bf16, v2bf16] , CCAssignToReg<
     !foreach(i, !range(105), !cast<Register>("SGPR"#i))
   >>>,
 
-  CCIfNotInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16] , CCAssignToReg<
+  CCIfNotInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16, bf16, v2bf16] , CCAssignToReg<
     !foreach(i, !range(8, 255), !cast<Register>("VGPR"#i))
   >>>
 ]>;
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp b/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp
index b0eac567ec9f18..40a49cbe3f518f 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp
@@ -303,6 +303,7 @@ void AMDGPUDAGToDAGISel::PreprocessISelDAG() {
 
     switch (N->getOpcode()) {
     case ISD::BUILD_VECTOR:
+      // TODO: Match load d16 from shl (extload:i16), 16
       MadeChange |= matchLoadD16FromBuildVector(N);
       break;
     default:
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
index 541a5b62450ddf..630910bdfd29c2 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
@@ -387,17 +387,20 @@ AMDGPUTargetLowering::AMDGPUTargetLowering(const TargetMachine &TM,
                       MVT::v9i32,  MVT::v9f32,  MVT::v10i32, MVT::v10f32,
                       MVT::v11i32, MVT::v11f32, MVT::v12i32, MVT::v12f32},
                      Custom);
+
+  // FIXME: Why is v8f16/v8bf16 missing?
   setOperationAction(
       ISD::EXTRACT_SUBVECTOR,
-      {MVT::v2f16,  MVT::v2i16,  MVT::v4f16,  MVT::v4i16,  MVT::v2f32,
-       MVT::v2i32,  MVT::v3f32,  MVT::v3i32,  MVT::v4f32,  MVT::v4i32,
-       MVT::v5f32,  MVT::v5i32,  MVT::v6f32,  MVT::v6i32,  MVT::v7f32,
-       MVT::v7i32,  MVT::v8f32,  MVT::v8i32,  MVT::v9f32,  MVT::v9i32,
-       MVT::v10i32, MVT::v10f32, MVT::v11i32, MVT::v11f32, MVT::v12i32,
-       MVT::v12f32, MVT::v16f16, MVT::v16i16, MVT::v16f32, MVT::v16i32,
-       MVT::v32f32, MVT::v32i32, MVT::v2f64,  MVT::v2i64,  MVT::v3f64,
-       MVT::v3i64,  MVT::v4f64,  MVT::v4i64,  MVT::v8f64,  MVT::v8i64,
-       MVT::v16f64, MVT::v16i64, MVT::v32i16, MVT::v32f16},
+      {MVT::v2f16,  MVT::v2bf16, MVT::v2i16,  MVT::v4f16,  MVT::v4bf16,
+       MVT::v4i16,  MVT::v2f32,  MVT::v2i32,  MVT::v3f32,  MVT::v3i32,
+       MVT::v4f32,  MVT::v4i32,  MVT::v5f32,  MVT::v5i32,  MVT::v6f32,
+       MVT::v6i32,  MVT::v7f32,  MVT::v7i32,  MVT::v8f32,  MVT::v8i32,
+       MVT::v9f32,  MVT::v9i32,  MVT::v10i32, MVT::v10f32, MVT::v11i32,
+       MVT::v11f32, MVT::v12i32, MVT::v12f32, MVT::v16f16, MVT::v16bf16,
+       MVT::v16i16, MVT::v16f32, MVT::v16i32, MVT::v32f32, MVT::v32i32,
+       MVT::v2f64,  MVT::v2i64,  MVT::v3f64,  MVT::v3i64,  MVT::v4f64,
+       MVT::v4i64,  MVT::v8f64,  MVT::v8i64,  MVT::v16f64, MVT::v16i64,
+       MVT::v32i16, MVT::v32f16, MVT::v32bf16},
       Custom);
 
   setOperationAction(ISD::FP16_TO_FP, MVT::f64, Expand);
@@ -3282,7 +3285,15 @@ SDValue AMDGPUTargetLowering::LowerUINT_TO_FP(SDValue Op,
     return DAG.getNode(ISD::UINT_TO_FP, DL, DestVT, Ext);
   }
 
-  assert(SrcVT == MVT::i64 && "operation should be legal");
+  if (DestVT == MVT::bf16) {
+    SDLoc SL(Op);
+    SDValue ToF32 = DAG.getNode(ISD::UINT_TO_FP, SL, MVT::f32, Src);
+    SDValue FPRoundFlag = DAG.getIntPtrConstant(0, SL, /*isTarget=*/true);
+    return DAG.getNode(ISD::FP_ROUND, SL, MVT::bf16, ToF32, FPRoundFlag);
+  }
+
+  if (SrcVT != MVT::i64)
+    return Op;
 
   if (Subtarget->has16BitInsts() && DestVT == MVT::f16) {
     SDLoc DL(Op);
@@ -3320,7 +3331,15 @@ SDValue AMDGPUTargetLowering::LowerSINT_TO_FP(SDValue Op,
     return DAG.getNode(ISD::SINT_TO_FP, DL, DestVT, Ext);
   }
 
-  assert(SrcVT == MVT::i64 && "operation should be legal");
+  if (DestVT == MVT::bf16) {
+    SDLoc SL(Op);
+    SDValue ToF32 = DAG.getNode(ISD::SINT_TO_FP, SL, MVT::f32, Src);
+    SDValue FPRoundFlag = DAG.getIntPtrConstant(0, SL, /*isTarget=*/true);
+    return DAG.getNode(ISD::FP_ROUND, SL, MVT::bf16, ToF32, FPRoundFlag);
+  }
+
+  if (SrcVT != MVT::i64)
+    return Op;
 
   // TODO: Factor out code common with LowerUINT_TO_FP.
 
@@ -3518,7 +3537,7 @@ SDValue AMDGPUTargetLowering::LowerFP_TO_FP16(SDValue Op, SelectionDAG &DAG) con
   return DAG.getZExtOrTrunc(V, DL, Op.getValueType());
 }
 
-SDValue AMDGPUTargetLowering::LowerFP_TO_INT(SDValue Op,
+SDValue AMDGPUTargetLowering::LowerFP_TO_INT(const SDValue Op,
                                              SelectionDAG &DAG) const {
   SDValue Src = Op.getOperand(0);
   unsigned OpOpcode = Op.getOpcode();
@@ -3529,6 +3548,12 @@ SDValue AMDGPUTargetLowering::LowerFP_TO_INT(SDValue Op,
   if (SrcVT == MVT::f16 && DestVT == MVT::i16)
     return Op;
 
+  if (SrcVT == MVT::bf16) {
+    SDLoc DL(Op);
+    SDValue PromotedSrc = DAG.getNode(ISD::FP_EXTEND, DL, MVT::f32, Src);
+    return DAG.getNode(Op.getOpcode(), DL, DestVT, PromotedSrc);
+  }
+
   // Promote i16 to i32
   if (DestVT == MVT::i16 && (SrcVT == MVT::f32 || SrcVT == MVT::f64)) {
     SDLoc DL(Op);
@@ -3537,6 +3562,9 @@ SDValue AMDGPUTargetLowering::LowerFP_TO_INT(SDValue Op,
     return DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, FpToInt32);
   }
 
+  if (DestVT != MVT::i64)
+    return Op;
+
   if (SrcVT == MVT::f16 ||
       (SrcVT == MVT::f32 && Src.getOpcode() == ISD::FP16_TO_FP)) {
     SDLoc DL(Op);
@@ -3547,7 +3575,7 @@ SDValue AMDGPUTargetLowering::LowerFP_TO_INT(SDValue Op,
     return DAG.getNode(Ext, DL, MVT::i64, FpToInt32);
   }
 
-  if (DestVT == MVT::i64 && (SrcVT == MVT::f32 || SrcVT == MVT::f64))
+  if (SrcVT == MVT::f32 || SrcVT == MVT::f64)
     return LowerFP_TO_INT64(Op, DAG, OpOpcode == ISD::FP_TO_SINT);
 
   return SDValue();
@@ -4948,7 +4976,9 @@ SDValue AMDGPUTargetLowering::PerformDAGCombine(SDNode *N,
     //   vnt1 = build_vector (t1 (bitcast t0:x)), (t1 (bitcast t0:y))
     if (DestVT.isVector()) {
       SDValue Src = N->getOperand(0);
-      if (Src.getOpcode() == ISD::BUILD_VECTOR) {
+      if (Src.getOpcode() == ISD::BUILD_VECTOR &&
+          (DCI.getDAGCombineLevel() < AfterLegalizeDAG ||
+           isOperationLegal(ISD::BUILD_VECTOR, DestVT))) {
         EVT SrcVT = Src.getValueType();
         unsigned NElts = DestVT.getVectorNumElements();
 
diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
index f3547db9e9bd94..a408f939b7a49e 100644
--- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
@@ -151,22 +151,29 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
     if (Subtarget->useRealTrue16Insts()) {
       addRegisterClass(MVT::i16, &AMDGPU::VGPR_16RegClass);
       addRegisterClass(MVT::f16, &AMDGPU::VGPR_16RegClass);
+      addRegisterClass(MVT::bf16, &AMDGPU::VGPR_16RegClass);
     } else {
       addRegisterClass(MVT::i16, &AMDGPU::SReg_32RegClass);
       addRegisterClass(MVT::f16, &AMDGPU::SReg_32RegClass);
+      addRegisterClass(MVT::bf16, &AMDGPU::SReg_32RegClass);
     }
 
     // Unless there are also VOP3P operations, not operations are really legal.
     addRegisterClass(MVT::v2i16, &AMDGPU::SReg_32RegClass);
     addRegisterClass(MVT::v2f16, &AMDGPU::SReg_32RegClass);
+    addRegisterClass(MVT::v2bf16, &AMDGPU::SReg_32RegClass);
     addRegisterClass(MVT::v4i16, &AMDGPU::SReg_64RegClass);
     addRegisterClass(MVT::v4f16, &AMDGPU::SReg_64RegClass);
+    addRegisterClass(MVT::v4bf16, &AMDGPU::SReg_64RegClass);
     addRegisterClass(MVT::v8i16, &AMDGPU::SGPR_128RegClass);
     addRegisterClass(MVT::v8f16, &AMDGPU::SGPR_128RegClass);
+    addRegisterClass(MVT::v8bf16, &AMDGPU::SGPR_128RegClass);
     addRegisterClass(MVT::v16i16, &AMDGPU::SGPR_256RegClass);
     addRegisterClass(MVT::v16f16, &AMDGPU::SGPR_256RegClass);
+    addRegisterClass(MVT::v16bf16, &AMDGPU::SGPR_256RegClass);
     addRegisterClass(MVT::v32i16, &AMDGPU::SGPR_512RegClass);
     addRegisterClass(MVT::v32f16, &AMDGPU::SGPR_512RegClass);
+    addRegisterClass(MVT::v32bf16, &AMDGPU::SGPR_512RegClass);
   }
 
   addRegisterClass(MVT::v32i32, &AMDGPU::VReg_1024RegClass);
@@ -196,6 +203,41 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
                       MVT::i1,     MVT::v32i32},
                      Custom);
 
+  if (isTypeLegal(MVT::bf16)) {
+    for (unsigned Opc :
+         {ISD::FADD,     ISD::FSUB,       ISD::FMUL,    ISD::FDIV,
+          ISD::FREM,     ISD::FMA,        ISD::FMINNUM, ISD::FMAXNUM,
+          ISD::FMINIMUM, ISD::FMAXIMUM,   ISD::FSQRT,   ISD::FCBRT,
+          ISD::FSIN,     ISD::FCOS,       ISD::FPOW,    ISD::FPOWI,
+          ISD::FLDEXP,   ISD::FFREXP,     ISD::FLOG,    ISD::FLOG2,
+          ISD::FLOG10,   ISD::FEXP,       ISD::FEXP2,   ISD::FEXP10,
+          ISD::FCEIL,    ISD::FTRUNC,     ISD::FRINT,   ISD::FNEARBYINT,
+          ISD::FROUND,   ISD::FROUNDEVEN, ISD::FFLOOR,  ISD::FCANONICALIZE,
+          ISD::SETCC}) {
+      // FIXME: The promoted to type shouldn't need to be explicit
+      setOperationAction(Opc, MVT::bf16, Promote);
+      AddPromotedToType(Opc, MVT::bf16, MVT::f32);
+    }
+
+    setOperationAction(ISD::FP_ROUND, MVT::bf16, Expand);
+
+    setOperationAction(ISD::SELECT, MVT::bf16, Promote);
+    AddPromotedToType(ISD::SELECT, MVT::bf16, MVT::i16);
+
+    // TODO: Could make these legal
+    setOperationAction(ISD::FABS, MVT::bf16, Expand);
+    setOperationAction(ISD::FNEG, MVT::bf16, Expand);
+    setOperationAction(ISD::FCOPYSIGN, MVT::bf16, Expand);
+
+    // We only need to custom lower because we can't specify an action for bf16
+    // sources.
+    setOperationAction(ISD::FP_TO_SINT, MVT::i32, Custom);
+    setOperationAction(ISD::FP_TO_UINT, MVT::i32, Custom);
+
+    setOperationAction(ISD::BUILD_VECTOR, MVT::v2bf16, Promote);
+    AddPromotedToType(ISD::BUILD_VECTOR, MVT::v2bf16, MVT::v2i16);
+  }
+
   setTruncStoreAction(MVT::v2i32, MVT::v2i16, Expand);
   setTruncStoreAction(MVT::v3i32, MVT::v3i16, Expand);
   setTruncStoreAction(MVT::v4i32, MVT::v4i16, Expand);
@@ -271,13 +313,14 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
   // We only support LOAD/STORE and vector manipulation ops for vectors
   // with > 4 elements.
   for (MVT VT :
-       {MVT::v8i32,  MVT::v8f32,  MVT::v9i32,  MVT::v9f32,  MVT::v10i32,
-        MVT::v10f32, MVT::v11i32, MVT::v11f32, MVT::v12i32, MVT::v12f32,
-        MVT::v16i32, MVT::v16f32, MVT::v2i64,  MVT::v2f64,  MVT::v4i16,
-        MVT::v4f16,  MVT::v3i64,  MVT::v3f64,  MVT::v6i32,  MVT::v6f32,
-        MVT::v4i64,  MVT::v4f64,  MVT::v8i64,  MVT::v8f64,  MVT::v8i16,
-        MVT::v8f16,  MVT::v16i16, MVT::v16f16, MVT::v16i64, MVT::v16f64,
-        MVT::v32i32, MVT::v32f32, MVT::v32i16, MVT::v32f16}) {
+       {MVT::v8i32,   MVT::v8f32,  MVT::v9i32,  MVT::v9f32,  MVT::v10i32,
+        MVT::v10f32,  MVT::v11i32, MVT::v11f32, MVT::v12i32, MVT::v12f32,
+        MVT::v16i32,  MVT::v16f32, MVT::v2i64,  MVT::v2f64,  MVT::v4i16,
+        MVT::v4f16,   MVT::v4bf16, MVT::v3i64,  MVT::v3f64,  MVT::v6i32,
+        MVT::v6f32,   MVT::v4i64,  MVT::v4f64,  MVT::v8i64,  MVT::v8f64,
+        MVT::v8i16,   MVT::v8f16,  MVT::v8bf16, MVT::v16i16, MVT::v16f16,
+        MVT::v16bf16, MVT::v16i64, MVT::v16f64, MVT::v32i32, MVT::v32f32,
+        MVT::v32i16,  MVT::v32f16, MVT::v32bf16}) {
     for (unsigned Op = 0; Op < ISD::BUILTIN_OP_END; ++Op) {
       switch (Op) {
       case ISD::LOAD:
@@ -383,13 +426,14 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
                      {MVT::v8i32, MVT::v8f32, MVT::v16i32, MVT::v16f32},
                      Expand);
 
-  setOperationAction(ISD::BUILD_VECTOR, {MVT::v4f16, MVT::v4i16}, Custom);
+  setOperationAction(ISD::BUILD_VECTOR, {MVT::v4f16, MVT::v4i16, MVT::v4bf16},
+                     Custom);
 
   // Avoid stack access for these.
   // TODO: Generalize to more vector types.
   setOperationAction({ISD::EXTRACT_VECTOR_ELT, ISD::INSERT_VECTOR_ELT},
-                     {MVT::v2i16, MVT::v2f16, MVT::v2i8, MVT::v4i8, MVT::v8i8,
-                      MVT::v4i16, MVT::v4f16},
+                     {MVT::v2i16, MVT::v2f16, MVT::v2bf16, MVT::v2i8, MVT::v4i8,
+                      MVT::v8i8, MVT::v4i16, MVT::v4f16, MVT::v4bf16},
                      Custom);
 
   // Deal with vec3 vector operations when widened to vec4.
@@ -498,6 +542,11 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
   setOperationAction(ISD::BF16_TO_FP, {MVT::i16, MVT::f32, MVT::f64}, Expand);
   setOperationAction(ISD::FP_TO_BF16, {MVT::i16, MVT::f32, MVT::f64}, Expand);
 
+  // Custom lower these because we can't specify a rule based on an illegal
+  // source bf16.
+  setOperationAction({ISD::FP_EXTEND, ISD::STRICT_FP_EXTEND}, MVT::f32, Custom);
+  setOperationAction({ISD::FP_EXTEND, ISD::STRICT_FP_EXTEND}, MVT::f64, Custom);
+
   if (Subtarget->has16BitInsts()) {
     setOperationAction({ISD::Constant, ISD::SMIN, ISD::SMAX, ISD::UMIN,
                         ISD::UMAX, ISD::UADDSAT, ISD::USUBSAT},
@@ -524,9 +573,14 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
     AddPromotedToType(ISD::FP_TO_FP16, MVT::i16, MVT::i32);
 
     setOperationAction({ISD::FP_TO_SINT, ISD::FP_TO_UINT}, MVT::i16, Custom);
+    setOperationAction({ISD::SINT_TO_FP, ISD::UINT_TO_FP}, MVT::i16, Custom);
+    setOperationAction({ISD::SINT_TO_FP, ISD::UINT_TO_FP}, MVT::i16, Custom);
+
+    setOperationAction({ISD::SINT_TO_FP, ISD::UINT_TO_FP}, MVT::i32, Custom);
 
     // F16 - Constant Actions.
     setOperationAction(ISD::ConstantFP, MVT::f16, Legal);
+    setOperationAction(ISD::ConstantFP, MVT::bf16, Legal);
 
     // F16 - Load/Store Actions.
     setOperationAction(ISD::LOAD, MVT::f16, Promote);
@@ -534,16 +588,23 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
     setOperationAction(ISD::STORE, MVT::f16, Promo...
[truncated]

Copy link

github-actions bot commented Jan 1, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@arsenm arsenm merged commit f9fec40 into llvm:main Jan 9, 2024
4 checks passed
@arsenm arsenm deleted the bf16/legal-v32bf16 branch January 9, 2024 10:02
@jayfoad
Copy link
Contributor

jayfoad commented Jan 9, 2024

This is causing:

FAIL: LLVM :: CodeGen/AMDGPU/bf16.ll (1 of 3422)
******************** TEST 'LLVM :: CodeGen/AMDGPU/bf16.ll' FAILED ********************
Exit Code: 1

Command Output (stderr):
--
RUN: at line 2: /home/jayfoad2/llvm-release/bin/llc < /home/jayfoad2/git/llvm-project/llvm/test/CodeGen/AMDGPU/bf16.ll -mtriple=amdgcn | /home/jayfoad2/llvm-release/bin/FileCheck /home/jayfoad2/git/llvm-project/llvm/test/CodeGen/AMDGPU/bf16.ll -check-prefixes=GCN
+ /home/jayfoad2/llvm-release/bin/llc -mtriple=amdgcn
+ /home/jayfoad2/llvm-release/bin/FileCheck /home/jayfoad2/git/llvm-project/llvm/test/CodeGen/AMDGPU/bf16.ll -check-prefixes=GCN
RUN: at line 3: /home/jayfoad2/llvm-release/bin/llc < /home/jayfoad2/git/llvm-project/llvm/test/CodeGen/AMDGPU/bf16.ll -mtriple=amdgcn -mcpu=hawaii | /home/jayfoad2/llvm-release/bin/FileCheck /home/jayfoad2/git/llvm-project/llvm/test/CodeGen/AMDGPU/bf16.ll -check-prefixes=GFX7
+ /home/jayfoad2/llvm-release/bin/FileCheck /home/jayfoad2/git/llvm-project/llvm/test/CodeGen/AMDGPU/bf16.ll -check-prefixes=GFX7
+ /home/jayfoad2/llvm-release/bin/llc -mtriple=amdgcn -mcpu=hawaii
RUN: at line 4: /home/jayfoad2/llvm-release/bin/llc < /home/jayfoad2/git/llvm-project/llvm/test/CodeGen/AMDGPU/bf16.ll -mtriple=amdgcn -mcpu=tonga | /home/jayfoad2/llvm-release/bin/FileCheck /home/jayfoad2/git/llvm-project/llvm/test/CodeGen/AMDGPU/bf16.ll -check-prefixes=GFX8
+ /home/jayfoad2/llvm-release/bin/llc -mtriple=amdgcn -mcpu=tonga
+ /home/jayfoad2/llvm-release/bin/FileCheck /home/jayfoad2/git/llvm-project/llvm/test/CodeGen/AMDGPU/bf16.ll -check-prefixes=GFX8
/home/jayfoad2/git/llvm-project/llvm/test/CodeGen/AMDGPU/bf16.ll:9836:14: error: GFX8-NEXT: expected string not found in input
; GFX8-NEXT: v_lshlrev_b32_e32 v31, 16, v16
             ^
<stdin>:3227:41: note: scanning from here
 s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
                                        ^
<stdin>:3228:2: note: possible intended match here
 v_lshlrev_b32_e32 v31, 16, v30
 ^
/home/jayfoad2/git/llvm-project/llvm/test/CodeGen/AMDGPU/bf16.ll:12151:14: error: GFX8-NEXT: expected string not found in input
; GFX8-NEXT: v_lshlrev_b32_e32 v31, 16, v16
             ^
<stdin>:3916:41: note: scanning from here
 s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
                                        ^
<stdin>:3917:2: note: possible intended match here
 v_lshlrev_b32_e32 v31, 16, v30
 ^
/home/jayfoad2/git/llvm-project/llvm/test/CodeGen/AMDGPU/bf16.ll:14689:14: error: GFX8-NEXT: expected string not found in input
; GFX8-NEXT: v_lshlrev_b32_e32 v31, 16, v16
             ^
<stdin>:4674:41: note: scanning from here
 s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
                                        ^
<stdin>:4675:2: note: possible intended match here
 v_lshlrev_b32_e32 v31, 16, v30
 ^
/home/jayfoad2/git/llvm-project/llvm/test/CodeGen/AMDGPU/bf16.ll:16816:14: error: GFX8-NEXT: expected string not found in input
; GFX8-NEXT: v_lshlrev_b32_e32 v31, 16, v16
             ^
<stdin>:5143:41: note: scanning from here
 s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
                                        ^
<stdin>:5144:2: note: possible intended match here
 v_lshlrev_b32_e32 v31, 16, v30
 ^
/home/jayfoad2/git/llvm-project/llvm/test/CodeGen/AMDGPU/bf16.ll:28231:14: error: GFX8-NEXT: is not on the line after the previous match
; GFX8-NEXT: v_writelane_b32 v31, s35, 3
             ^
<stdin>:9660:2: note: 'next' match was here
 v_writelane_b32 v31, s35, 3
 ^
<stdin>:9658:29: note: previous match ended here
 v_writelane_b32 v31, s34, 2
                            ^
<stdin>:9659:1: note: non-matching line after previous match is here
 v_and_b32_e32 v0, 1, v0
^

Input file: <stdin>
Check file: /home/jayfoad2/git/llvm-project/llvm/test/CodeGen/AMDGPU/bf16.ll

-dump-input=help explains the following input dump.

Input was:
<<<<<<
              .
              .
              .
           3222:  .globl v_fadd_v32bf16 ; -- Begin function v_fadd_v32bf16 
           3223:  .p2align 2 
           3224:  .type v_fadd_v32bf16,@function 
           3225: v_fadd_v32bf16: ; @v_fadd_v32bf16 
           3226: ; %bb.0: 
           3227:  s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) 
next:9836'0                                              X error: no match found
           3228:  v_lshlrev_b32_e32 v31, 16, v30 
next:9836'0      ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
next:9836'1       ?                               possible intended match
           3229:  v_lshlrev_b32_e32 v32, 16, v14 
next:9836'0      ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
           3230:  v_and_b32_e32 v30, 0xffff0000, v30 
next:9836'0      ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
           3231:  v_and_b32_e32 v14, 0xffff0000, v14 
next:9836'0      ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
           3232:  v_add_f32_e32 v31, v32, v31 
next:9836'0      ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
           3233:  v_add_f32_e32 v30, v14, v30 
next:9836'0      ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
              .
              .
              .
           3911:  .globl v_fmul_v32bf16 ; -- Begin function v_fmul_v32bf16 
           3912:  .p2align 2 
           3913:  .type v_fmul_v32bf16,@function 
           3914: v_fmul_v32bf16: ; @v_fmul_v32bf16 
           3915: ; %bb.0: 
           3916:  s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) 
next:12151'0                                             X error: no match found
           3917:  v_lshlrev_b32_e32 v31, 16, v30 
next:12151'0     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
next:12151'1      ?                               possible intended match
           3918:  v_lshlrev_b32_e32 v32, 16, v14 
next:12151'0     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
           3919:  v_and_b32_e32 v30, 0xffff0000, v30 
next:12151'0     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
           3920:  v_and_b32_e32 v14, 0xffff0000, v14 
next:12151'0     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
           3921:  v_mul_f32_e32 v31, v32, v31 
next:12151'0     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
           3922:  v_mul_f32_e32 v30, v14, v30 
next:12151'0     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
              .
              .
              .
           4669:  .globl v_minnum_v32bf16 ; -- Begin function v_minnum_v32bf16 
           4670:  .p2align 2 
           4671:  .type v_minnum_v32bf16,@function 
           4672: v_minnum_v32bf16: ; @v_minnum_v32bf16 
           4673: ; %bb.0: 
           4674:  s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) 
next:14689'0                                             X error: no match found
           4675:  v_lshlrev_b32_e32 v31, 16, v30 
next:14689'0     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
next:14689'1      ?                               possible intended match
           4676:  v_lshlrev_b32_e32 v32, 16, v14 
next:14689'0     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
           4677:  v_and_b32_e32 v30, 0xffff0000, v30 
next:14689'0     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
           4678:  v_and_b32_e32 v14, 0xffff0000, v14 
next:14689'0     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
           4679:  v_min_f32_e32 v31, v32, v31 
next:14689'0     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
           4680:  v_min_f32_e32 v30, v14, v30 
next:14689'0     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
              .
              .
              .
           5138:  .globl v_maxnum_v32bf16 ; -- Begin function v_maxnum_v32bf16 
           5139:  .p2align 2 
           5140:  .type v_maxnum_v32bf16,@function 
           5141: v_maxnum_v32bf16: ; @v_maxnum_v32bf16 
           5142: ; %bb.0: 
           5143:  s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) 
next:16816'0                                             X error: no match found
           5144:  v_lshlrev_b32_e32 v31, 16, v30 
next:16816'0     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
next:16816'1      ?                               possible intended match
           5145:  v_lshlrev_b32_e32 v32, 16, v14 
next:16816'0     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
           5146:  v_and_b32_e32 v30, 0xffff0000, v30 
next:16816'0     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
           5147:  v_and_b32_e32 v14, 0xffff0000, v14 
next:16816'0     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
           5148:  v_max_f32_e32 v31, v32, v31 
next:16816'0     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
           5149:  v_max_f32_e32 v30, v14, v30 
next:16816'0     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
              .
              .
              .
           9655:  s_mov_b64 exec, s[4:5] 
           9656:  v_writelane_b32 v31, s30, 0 
           9657:  v_writelane_b32 v31, s31, 1 
           9658:  v_writelane_b32 v31, s34, 2 
           9659:  v_and_b32_e32 v0, 1, v0 
           9660:  v_writelane_b32 v31, s35, 3 
next:28231        !~~~~~~~~~~~~~~~~~~~~~~~~~~  error: match on wrong line
           9661:  v_cmp_eq_u32_e32 vcc, 1, v0 
           9662:  v_and_b32_e32 v0, 1, v1 
           9663:  v_writelane_b32 v31, s36, 4 
           9664:  v_cmp_eq_u32_e64 s[4:5], 1, v0 
           9665:  v_and_b32_e32 v0, 1, v2 
              .
              .
              .
>>>>>>

--

********************

justinfargnoli pushed a commit to justinfargnoli/llvm-project that referenced this pull request Jan 28, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backend:AMDGPU llvm:SelectionDAG SelectionDAGISel as well
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants