AMDGPU: Make v8bf16/v16bf16 legal types #76678

Merged: 1 commit merged into llvm:main on Jan 8, 2024

Conversation

arsenm (Contributor) commented Jan 1, 2024

Depends on #76217

llvmbot added the llvm:SelectionDAG (SelectionDAGISel as well) label on Jan 1, 2024
llvmbot (Collaborator) commented Jan 1, 2024

@llvm/pr-subscribers-llvm-selectiondag

@llvm/pr-subscribers-backend-amdgpu

Author: Matt Arsenault (arsenm)

Changes

Patch is 1.38 MiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/76678.diff

27 Files Affected:

  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp (+21-5)
  • (modified) llvm/lib/Target/AMDGPU/AMDGPUCallingConv.td (+13-13)
  • (modified) llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp (+1)
  • (modified) llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp (+44-14)
  • (modified) llvm/lib/Target/AMDGPU/SIISelLowering.cpp (+135-32)
  • (modified) llvm/lib/Target/AMDGPU/SIISelLowering.h (+1)
  • (modified) llvm/lib/Target/AMDGPU/SIInstructions.td (+115-6)
  • (modified) llvm/lib/Target/AMDGPU/SIRegisterInfo.td (+31-31)
  • (modified) llvm/lib/Target/AMDGPU/VOP3PInstructions.td (+11-11)
  • (modified) llvm/test/CodeGen/AMDGPU/bf16.ll (+7250-9808)
  • (modified) llvm/test/CodeGen/AMDGPU/fcopysign.f32.ll (+5-7)
  • (modified) llvm/test/CodeGen/AMDGPU/fmed3-cast-combine.ll (+14-2)
  • (modified) llvm/test/CodeGen/AMDGPU/fneg-modifier-casting.ll (+5-5)
  • (modified) llvm/test/CodeGen/AMDGPU/function-args-inreg.ll (+2-2)
  • (modified) llvm/test/CodeGen/AMDGPU/function-args.ll (+70-88)
  • (modified) llvm/test/CodeGen/AMDGPU/function-returns.ll (+19-46)
  • (modified) llvm/test/CodeGen/AMDGPU/gfx-callable-argument-types.ll (+48-80)
  • (modified) llvm/test/CodeGen/AMDGPU/isel-amdgpu-cs-chain-preserve-cc.ll (+6-10)
  • (modified) llvm/test/CodeGen/AMDGPU/llvm.exp.ll (+13-12)
  • (modified) llvm/test/CodeGen/AMDGPU/llvm.exp10.ll (+13-12)
  • (modified) llvm/test/CodeGen/AMDGPU/llvm.exp2.ll (+43-15)
  • (modified) llvm/test/CodeGen/AMDGPU/llvm.is.fpclass.bf16.ll (+157-346)
  • (modified) llvm/test/CodeGen/AMDGPU/llvm.log.ll (+7-4)
  • (modified) llvm/test/CodeGen/AMDGPU/llvm.log10.ll (+7-4)
  • (modified) llvm/test/CodeGen/AMDGPU/llvm.log2.ll (+46-16)
  • (modified) llvm/test/CodeGen/AMDGPU/local-atomics-fp.ll (+2-2)
  • (modified) llvm/test/CodeGen/AMDGPU/vector_shuffle.packed.ll (+362-629)
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
index 4e317062cec497..296ed3a3c3dc11 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -3199,7 +3199,16 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
       return true;
     }
     break;
-  case ISD::FP_ROUND:
+  case ISD::FP_ROUND: {
+    EVT VT = Node->getValueType(0);
+    if (VT.getScalarType() == MVT::bf16) {
+      Results.push_back(
+          DAG.getNode(ISD::FP_TO_BF16, SDLoc(Node), VT, Node->getOperand(0)));
+      break;
+    }
+
+    LLVM_FALLTHROUGH;
+  }
   case ISD::BITCAST:
     if ((Tmp1 = EmitStackConvert(Node->getOperand(0), Node->getValueType(0),
                                  Node->getValueType(0), dl)))
@@ -3226,12 +3235,19 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
       return true;
     }
     break;
-  case ISD::FP_EXTEND:
-    if ((Tmp1 = EmitStackConvert(Node->getOperand(0),
-                                 Node->getOperand(0).getValueType(),
-                                 Node->getValueType(0), dl)))
+  case ISD::FP_EXTEND: {
+    SDValue Op = Node->getOperand(0);
+    EVT SrcVT = Op.getValueType();
+    EVT DstVT = Node->getValueType(0);
+    if (SrcVT.getScalarType() == MVT::bf16) {
+      Results.push_back(DAG.getNode(ISD::BF16_TO_FP, SDLoc(Node), DstVT, Op));
+      break;
+    }
+
+    if ((Tmp1 = EmitStackConvert(Op, SrcVT, DstVT, dl)))
       Results.push_back(Tmp1);
     break;
+  }
   case ISD::BF16_TO_FP: {
     // Always expand bf16 to f32 casts, they lower to ext + shift.
     //
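For context, the comment above is the whole story of the extend direction: bf16 holds the top 16 bits of an f32, so BF16_TO_FP is a zero-extend of the stored bits followed by a 16-bit shift into the high half. A minimal standalone sketch of that bit manipulation (illustrative only, not code from this patch):

#include <cstdint>
#include <cstring>

// bf16 -> f32: the "ext + shift" lowering. The 16 stored bits become the
// high half of the f32 bit pattern; the low 16 bits are zero.
float BF16ToF32(uint16_t Bits) {
  uint32_t Wide = static_cast<uint32_t>(Bits) << 16; // ext + shift
  float F;
  std::memcpy(&F, &Wide, sizeof(F)); // bitcast to float
  return F;
}

The FP_ROUND path routed to FP_TO_BF16 above is the inverse: round the f32 and keep the high 16 bits.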
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUCallingConv.td b/llvm/lib/Target/AMDGPU/AMDGPUCallingConv.td
index 9036b26a6f6bcb..c5207228dc913f 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUCallingConv.td
+++ b/llvm/lib/Target/AMDGPU/AMDGPUCallingConv.td
@@ -22,28 +22,28 @@ def CC_SI_Gfx : CallingConv<[
   // 32 is reserved for the stack pointer
   // 33 is reserved for the frame pointer
   // 34 is reserved for the base pointer
-  CCIfInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16] , CCAssignToReg<[
+  CCIfInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16, bf16, v2bf16] , CCAssignToReg<[
     SGPR4, SGPR5, SGPR6, SGPR7,
     SGPR8, SGPR9, SGPR10, SGPR11, SGPR12, SGPR13, SGPR14, SGPR15,
     SGPR16, SGPR17, SGPR18, SGPR19, SGPR20, SGPR21, SGPR22, SGPR23,
     SGPR24, SGPR25, SGPR26, SGPR27, SGPR28, SGPR29
   ]>>>,
 
-  CCIfNotInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16] , CCAssignToReg<[
+  CCIfNotInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16, bf16, v2bf16] , CCAssignToReg<[
     VGPR0, VGPR1, VGPR2, VGPR3, VGPR4, VGPR5, VGPR6, VGPR7,
     VGPR8, VGPR9, VGPR10, VGPR11, VGPR12, VGPR13, VGPR14, VGPR15,
     VGPR16, VGPR17, VGPR18, VGPR19, VGPR20, VGPR21, VGPR22, VGPR23,
     VGPR24, VGPR25, VGPR26, VGPR27, VGPR28, VGPR29, VGPR30, VGPR31
   ]>>>,
 
-  CCIfType<[i32, f32, v2i16, v2f16, i16, f16, i1], CCAssignToStack<4, 4>>
+  CCIfType<[i32, f32, v2i16, v2f16, i16, f16, i1, bf16, v2bf16], CCAssignToStack<4, 4>>
 ]>;
 
 def RetCC_SI_Gfx : CallingConv<[
   CCIfType<[i1], CCPromoteToType<i32>>,
   CCIfType<[i1, i16], CCIfExtend<CCPromoteToType<i32>>>,
 
-  CCIfNotInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16] , CCAssignToReg<[
+  CCIfNotInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16, bf16, v2bf16] , CCAssignToReg<[
     VGPR0, VGPR1, VGPR2, VGPR3, VGPR4, VGPR5, VGPR6, VGPR7,
     VGPR8, VGPR9, VGPR10, VGPR11, VGPR12, VGPR13, VGPR14, VGPR15,
     VGPR16, VGPR17, VGPR18, VGPR19, VGPR20, VGPR21, VGPR22, VGPR23,
@@ -66,7 +66,7 @@ def RetCC_SI_Gfx : CallingConv<[
 
 def CC_SI_SHADER : CallingConv<[
 
-  CCIfInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16] , CCAssignToReg<[
+  CCIfInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16, bf16, v2bf16] , CCAssignToReg<[
     SGPR0, SGPR1, SGPR2, SGPR3, SGPR4, SGPR5, SGPR6, SGPR7,
     SGPR8, SGPR9, SGPR10, SGPR11, SGPR12, SGPR13, SGPR14, SGPR15,
     SGPR16, SGPR17, SGPR18, SGPR19, SGPR20, SGPR21, SGPR22, SGPR23,
@@ -76,7 +76,7 @@ def CC_SI_SHADER : CallingConv<[
   ]>>>,
 
   // 32*4 + 4 is the minimum for a fetch shader consumer with 32 inputs.
-  CCIfNotInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16] , CCAssignToReg<[
+  CCIfNotInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16, bf16, v2bf16] , CCAssignToReg<[
     VGPR0, VGPR1, VGPR2, VGPR3, VGPR4, VGPR5, VGPR6, VGPR7,
     VGPR8, VGPR9, VGPR10, VGPR11, VGPR12, VGPR13, VGPR14, VGPR15,
     VGPR16, VGPR17, VGPR18, VGPR19, VGPR20, VGPR21, VGPR22, VGPR23,
@@ -109,7 +109,7 @@ def RetCC_SI_Shader : CallingConv<[
   ]>>,
 
   // 32*4 + 4 is the minimum for a fetch shader with 32 outputs.
-  CCIfType<[f32, f16, v2f16] , CCAssignToReg<[
+  CCIfType<[f32, f16, v2f16, bf16, v2bf16] , CCAssignToReg<[
     VGPR0, VGPR1, VGPR2, VGPR3, VGPR4, VGPR5, VGPR6, VGPR7,
     VGPR8, VGPR9, VGPR10, VGPR11, VGPR12, VGPR13, VGPR14, VGPR15,
     VGPR16, VGPR17, VGPR18, VGPR19, VGPR20, VGPR21, VGPR22, VGPR23,
@@ -188,23 +188,23 @@ def CC_AMDGPU_Func : CallingConv<[
   CCIfType<[i1], CCPromoteToType<i32>>,
   CCIfType<[i8, i16], CCIfExtend<CCPromoteToType<i32>>>,
 
-  CCIfInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16] , CCAssignToReg<
+  CCIfInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16, bf16, v2bf16] , CCAssignToReg<
     !foreach(i, !range(0, 30), !cast<Register>("SGPR"#i))  // SGPR0-29
   >>>,
 
-  CCIfType<[i32, f32, i16, f16, v2i16, v2f16, i1], CCAssignToReg<[
+  CCIfType<[i32, f32, i16, f16, v2i16, v2f16, i1, bf16, v2bf16], CCAssignToReg<[
     VGPR0, VGPR1, VGPR2, VGPR3, VGPR4, VGPR5, VGPR6, VGPR7,
     VGPR8, VGPR9, VGPR10, VGPR11, VGPR12, VGPR13, VGPR14, VGPR15,
     VGPR16, VGPR17, VGPR18, VGPR19, VGPR20, VGPR21, VGPR22, VGPR23,
     VGPR24, VGPR25, VGPR26, VGPR27, VGPR28, VGPR29, VGPR30, VGPR31]>>,
-  CCIfType<[i32, f32, v2i16, v2f16, i16, f16, i1], CCAssignToStack<4, 4>>
+  CCIfType<[i32, f32, v2i16, v2f16, i16, f16, i1, bf16, v2bf16], CCAssignToStack<4, 4>>
 ]>;
 
 // Calling convention for leaf functions
 def RetCC_AMDGPU_Func : CallingConv<[
   CCIfType<[i1], CCPromoteToType<i32>>,
   CCIfType<[i1, i16], CCIfExtend<CCPromoteToType<i32>>>,
-  CCIfType<[i32, f32, i16, f16, v2i16, v2f16], CCAssignToReg<[
+  CCIfType<[i32, f32, i16, f16, v2i16, v2f16, bf16, v2bf16], CCAssignToReg<[
     VGPR0, VGPR1, VGPR2, VGPR3, VGPR4, VGPR5, VGPR6, VGPR7,
     VGPR8, VGPR9, VGPR10, VGPR11, VGPR12, VGPR13, VGPR14, VGPR15,
     VGPR16, VGPR17, VGPR18, VGPR19, VGPR20, VGPR21, VGPR22, VGPR23,
@@ -223,11 +223,11 @@ def CC_AMDGPU : CallingConv<[
 ]>;
 
 def CC_AMDGPU_CS_CHAIN : CallingConv<[
-  CCIfInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16] , CCAssignToReg<
+  CCIfInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16, bf16, v2bf16] , CCAssignToReg<
     !foreach(i, !range(105), !cast<Register>("SGPR"#i))
   >>>,
 
-  CCIfNotInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16] , CCAssignToReg<
+  CCIfNotInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16, bf16, v2bf16] , CCAssignToReg<
     !foreach(i, !range(8, 255), !cast<Register>("VGPR"#i))
   >>>
 ]>;
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp b/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp
index b0eac567ec9f18..40a49cbe3f518f 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp
@@ -303,6 +303,7 @@ void AMDGPUDAGToDAGISel::PreprocessISelDAG() {
 
     switch (N->getOpcode()) {
     case ISD::BUILD_VECTOR:
+      // TODO: Match load d16 from shl (extload:i16), 16
       MadeChange |= matchLoadD16FromBuildVector(N);
       break;
     default:
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
index 541a5b62450ddf..630910bdfd29c2 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
@@ -387,17 +387,20 @@ AMDGPUTargetLowering::AMDGPUTargetLowering(const TargetMachine &TM,
                       MVT::v9i32,  MVT::v9f32,  MVT::v10i32, MVT::v10f32,
                       MVT::v11i32, MVT::v11f32, MVT::v12i32, MVT::v12f32},
                      Custom);
+
+  // FIXME: Why is v8f16/v8bf16 missing?
   setOperationAction(
       ISD::EXTRACT_SUBVECTOR,
-      {MVT::v2f16,  MVT::v2i16,  MVT::v4f16,  MVT::v4i16,  MVT::v2f32,
-       MVT::v2i32,  MVT::v3f32,  MVT::v3i32,  MVT::v4f32,  MVT::v4i32,
-       MVT::v5f32,  MVT::v5i32,  MVT::v6f32,  MVT::v6i32,  MVT::v7f32,
-       MVT::v7i32,  MVT::v8f32,  MVT::v8i32,  MVT::v9f32,  MVT::v9i32,
-       MVT::v10i32, MVT::v10f32, MVT::v11i32, MVT::v11f32, MVT::v12i32,
-       MVT::v12f32, MVT::v16f16, MVT::v16i16, MVT::v16f32, MVT::v16i32,
-       MVT::v32f32, MVT::v32i32, MVT::v2f64,  MVT::v2i64,  MVT::v3f64,
-       MVT::v3i64,  MVT::v4f64,  MVT::v4i64,  MVT::v8f64,  MVT::v8i64,
-       MVT::v16f64, MVT::v16i64, MVT::v32i16, MVT::v32f16},
+      {MVT::v2f16,  MVT::v2bf16, MVT::v2i16,  MVT::v4f16,  MVT::v4bf16,
+       MVT::v4i16,  MVT::v2f32,  MVT::v2i32,  MVT::v3f32,  MVT::v3i32,
+       MVT::v4f32,  MVT::v4i32,  MVT::v5f32,  MVT::v5i32,  MVT::v6f32,
+       MVT::v6i32,  MVT::v7f32,  MVT::v7i32,  MVT::v8f32,  MVT::v8i32,
+       MVT::v9f32,  MVT::v9i32,  MVT::v10i32, MVT::v10f32, MVT::v11i32,
+       MVT::v11f32, MVT::v12i32, MVT::v12f32, MVT::v16f16, MVT::v16bf16,
+       MVT::v16i16, MVT::v16f32, MVT::v16i32, MVT::v32f32, MVT::v32i32,
+       MVT::v2f64,  MVT::v2i64,  MVT::v3f64,  MVT::v3i64,  MVT::v4f64,
+       MVT::v4i64,  MVT::v8f64,  MVT::v8i64,  MVT::v16f64, MVT::v16i64,
+       MVT::v32i16, MVT::v32f16, MVT::v32bf16},
       Custom);
 
   setOperationAction(ISD::FP16_TO_FP, MVT::f64, Expand);
@@ -3282,7 +3285,15 @@ SDValue AMDGPUTargetLowering::LowerUINT_TO_FP(SDValue Op,
     return DAG.getNode(ISD::UINT_TO_FP, DL, DestVT, Ext);
   }
 
-  assert(SrcVT == MVT::i64 && "operation should be legal");
+  if (DestVT == MVT::bf16) {
+    SDLoc SL(Op);
+    SDValue ToF32 = DAG.getNode(ISD::UINT_TO_FP, SL, MVT::f32, Src);
+    SDValue FPRoundFlag = DAG.getIntPtrConstant(0, SL, /*isTarget=*/true);
+    return DAG.getNode(ISD::FP_ROUND, SL, MVT::bf16, ToF32, FPRoundFlag);
+  }
+
+  if (SrcVT != MVT::i64)
+    return Op;
 
   if (Subtarget->has16BitInsts() && DestVT == MVT::f16) {
     SDLoc DL(Op);
@@ -3320,7 +3331,15 @@ SDValue AMDGPUTargetLowering::LowerSINT_TO_FP(SDValue Op,
     return DAG.getNode(ISD::SINT_TO_FP, DL, DestVT, Ext);
   }
 
-  assert(SrcVT == MVT::i64 && "operation should be legal");
+  if (DestVT == MVT::bf16) {
+    SDLoc SL(Op);
+    SDValue ToF32 = DAG.getNode(ISD::SINT_TO_FP, SL, MVT::f32, Src);
+    SDValue FPRoundFlag = DAG.getIntPtrConstant(0, SL, /*isTarget=*/true);
+    return DAG.getNode(ISD::FP_ROUND, SL, MVT::bf16, ToF32, FPRoundFlag);
+  }
+
+  if (SrcVT != MVT::i64)
+    return Op;
 
   // TODO: Factor out code common with LowerUINT_TO_FP.
 
@@ -3518,7 +3537,7 @@ SDValue AMDGPUTargetLowering::LowerFP_TO_FP16(SDValue Op, SelectionDAG &DAG) con
   return DAG.getZExtOrTrunc(V, DL, Op.getValueType());
 }
 
-SDValue AMDGPUTargetLowering::LowerFP_TO_INT(SDValue Op,
+SDValue AMDGPUTargetLowering::LowerFP_TO_INT(const SDValue Op,
                                              SelectionDAG &DAG) const {
   SDValue Src = Op.getOperand(0);
   unsigned OpOpcode = Op.getOpcode();
@@ -3529,6 +3548,12 @@ SDValue AMDGPUTargetLowering::LowerFP_TO_INT(SDValue Op,
   if (SrcVT == MVT::f16 && DestVT == MVT::i16)
     return Op;
 
+  if (SrcVT == MVT::bf16) {
+    SDLoc DL(Op);
+    SDValue PromotedSrc = DAG.getNode(ISD::FP_EXTEND, DL, MVT::f32, Src);
+    return DAG.getNode(Op.getOpcode(), DL, DestVT, PromotedSrc);
+  }
+
   // Promote i16 to i32
   if (DestVT == MVT::i16 && (SrcVT == MVT::f32 || SrcVT == MVT::f64)) {
     SDLoc DL(Op);
@@ -3537,6 +3562,9 @@ SDValue AMDGPUTargetLowering::LowerFP_TO_INT(SDValue Op,
     return DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, FpToInt32);
   }
 
+  if (DestVT != MVT::i64)
+    return Op;
+
   if (SrcVT == MVT::f16 ||
       (SrcVT == MVT::f32 && Src.getOpcode() == ISD::FP16_TO_FP)) {
     SDLoc DL(Op);
@@ -3547,7 +3575,7 @@ SDValue AMDGPUTargetLowering::LowerFP_TO_INT(SDValue Op,
     return DAG.getNode(Ext, DL, MVT::i64, FpToInt32);
   }
 
-  if (DestVT == MVT::i64 && (SrcVT == MVT::f32 || SrcVT == MVT::f64))
+  if (SrcVT == MVT::f32 || SrcVT == MVT::f64)
     return LowerFP_TO_INT64(Op, DAG, OpOpcode == ISD::FP_TO_SINT);
 
   return SDValue();
@@ -4948,7 +4976,9 @@ SDValue AMDGPUTargetLowering::PerformDAGCombine(SDNode *N,
     //   vnt1 = build_vector (t1 (bitcast t0:x)), (t1 (bitcast t0:y))
     if (DestVT.isVector()) {
       SDValue Src = N->getOperand(0);
-      if (Src.getOpcode() == ISD::BUILD_VECTOR) {
+      if (Src.getOpcode() == ISD::BUILD_VECTOR &&
+          (DCI.getDAGCombineLevel() < AfterLegalizeDAG ||
+           isOperationLegal(ISD::BUILD_VECTOR, DestVT))) {
         EVT SrcVT = Src.getValueType();
         unsigned NElts = DestVT.getVectorNumElements();
 
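The LowerUINT_TO_FP/LowerSINT_TO_FP and LowerFP_TO_INT changes above all take the same detour: there is no direct integer <-> bf16 conversion, so the value is converted through f32 first. A scalar sketch of the intended semantics (an illustration of that reading, not the DAG code; truncating bf16 rounding is used for brevity where the real lowering rounds properly):

#include <cstdint>
#include <cstring>

// int -> bf16: UINT_TO_FP/SINT_TO_FP to f32, then FP_ROUND to bf16.
// bf16 -> int: FP_EXTEND to f32, then FP_TO_SINT/FP_TO_UINT.
static float BF16ToF32(uint16_t B) {
  uint32_t W = static_cast<uint32_t>(B) << 16;
  float F;
  std::memcpy(&F, &W, sizeof(F));
  return F;
}

static uint16_t F32ToBF16(float F) {
  uint32_t W;
  std::memcpy(&W, &F, sizeof(W));
  return static_cast<uint16_t>(W >> 16); // truncate; real lowering rounds
}

uint16_t UIntToBF16(uint32_t X) { return F32ToBF16(static_cast<float>(X)); }
int32_t BF16ToSInt(uint16_t B) { return static_cast<int32_t>(BF16ToF32(B)); }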
diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
index f3547db9e9bd94..c4af1fc23afe18 100644
--- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
@@ -151,20 +151,26 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
     if (Subtarget->useRealTrue16Insts()) {
       addRegisterClass(MVT::i16, &AMDGPU::VGPR_16RegClass);
       addRegisterClass(MVT::f16, &AMDGPU::VGPR_16RegClass);
+      addRegisterClass(MVT::bf16, &AMDGPU::VGPR_16RegClass);
     } else {
       addRegisterClass(MVT::i16, &AMDGPU::SReg_32RegClass);
       addRegisterClass(MVT::f16, &AMDGPU::SReg_32RegClass);
+      addRegisterClass(MVT::bf16, &AMDGPU::SReg_32RegClass);
     }
 
     // Unless there are also VOP3P operations, not operations are really legal.
     addRegisterClass(MVT::v2i16, &AMDGPU::SReg_32RegClass);
     addRegisterClass(MVT::v2f16, &AMDGPU::SReg_32RegClass);
+    addRegisterClass(MVT::v2bf16, &AMDGPU::SReg_32RegClass);
     addRegisterClass(MVT::v4i16, &AMDGPU::SReg_64RegClass);
     addRegisterClass(MVT::v4f16, &AMDGPU::SReg_64RegClass);
+    addRegisterClass(MVT::v4bf16, &AMDGPU::SReg_64RegClass);
     addRegisterClass(MVT::v8i16, &AMDGPU::SGPR_128RegClass);
     addRegisterClass(MVT::v8f16, &AMDGPU::SGPR_128RegClass);
+    addRegisterClass(MVT::v8bf16, &AMDGPU::SGPR_128RegClass);
     addRegisterClass(MVT::v16i16, &AMDGPU::SGPR_256RegClass);
     addRegisterClass(MVT::v16f16, &AMDGPU::SGPR_256RegClass);
+    addRegisterClass(MVT::v16bf16, &AMDGPU::SGPR_256RegClass);
     addRegisterClass(MVT::v32i16, &AMDGPU::SGPR_512RegClass);
     addRegisterClass(MVT::v32f16, &AMDGPU::SGPR_512RegClass);
   }
@@ -196,6 +202,41 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
                       MVT::i1,     MVT::v32i32},
                      Custom);
 
+  if (isTypeLegal(MVT::bf16)) {
+    for (unsigned Opc :
+         {ISD::FADD,     ISD::FSUB,       ISD::FMUL,    ISD::FDIV,
+          ISD::FREM,     ISD::FMA,        ISD::FMINNUM, ISD::FMAXNUM,
+          ISD::FMINIMUM, ISD::FMAXIMUM,   ISD::FSQRT,   ISD::FCBRT,
+          ISD::FSIN,     ISD::FCOS,       ISD::FPOW,    ISD::FPOWI,
+          ISD::FLDEXP,   ISD::FFREXP,     ISD::FLOG,    ISD::FLOG2,
+          ISD::FLOG10,   ISD::FEXP,       ISD::FEXP2,   ISD::FEXP10,
+          ISD::FCEIL,    ISD::FTRUNC,     ISD::FRINT,   ISD::FNEARBYINT,
+          ISD::FROUND,   ISD::FROUNDEVEN, ISD::FFLOOR,  ISD::FCANONICALIZE,
+          ISD::SETCC}) {
+      // FIXME: The promoted to type shouldn't need to be explicit
+      setOperationAction(Opc, MVT::bf16, Promote);
+      AddPromotedToType(Opc, MVT::bf16, MVT::f32);
+    }
+
+    setOperationAction(ISD::FP_ROUND, MVT::bf16, Expand);
+
+    setOperationAction(ISD::SELECT, MVT::bf16, Promote);
+    AddPromotedToType(ISD::SELECT, MVT::bf16, MVT::i16);
+
+    // TODO: Could make these legal
+    setOperationAction(ISD::FABS, MVT::bf16, Expand);
+    setOperationAction(ISD::FNEG, MVT::bf16, Expand);
+    setOperationAction(ISD::FCOPYSIGN, MVT::bf16, Expand);
+
+    // We only need to custom lower because we can't specify an action for bf16
+    // sources.
+    setOperationAction(ISD::FP_TO_SINT, MVT::i32, Custom);
+    setOperationAction(ISD::FP_TO_UINT, MVT::i32, Custom);
+
+    setOperationAction(ISD::BUILD_VECTOR, MVT::v2bf16, Promote);
+    AddPromotedToType(ISD::BUILD_VECTOR, MVT::v2bf16, MVT::v2i16);
+  }
+
   setTruncStoreAction(MVT::v2i32, MVT::v2i16, Expand);
   setTruncStoreAction(MVT::v3i32, MVT::v3i16, Expand);
   setTruncStoreAction(MVT::v4i32, MVT::v4i16, Expand);
@@ -271,13 +312,14 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
   // We only support LOAD/STORE and vector manipulation ops for vectors
   // with > 4 elements.
   for (MVT VT :
-       {MVT::v8i32,  MVT::v8f32,  MVT::v9i32,  MVT::v9f32,  MVT::v10i32,
-        MVT::v10f32, MVT::v11i32, MVT::v11f32, MVT::v12i32, MVT::v12f32,
-        MVT::v16i32, MVT::v16f32, MVT::v2i64,  MVT::v2f64,  MVT::v4i16,
-        MVT::v4f16,  MVT::v3i64,  MVT::v3f64,  MVT::v6i32,  MVT::v6f32,
-        MVT::v4i64,  MVT::v4f64,  MVT::v8i64,  MVT::v8f64,  MVT::v8i16,
-        MVT::v8f16,  MVT::v16i16, MVT::v16f16, MVT::v16i64, MVT::v16f64,
-        MVT::v32i32, MVT::v32f32, MVT::v32i16, MVT::v32f16}) {
+       {MVT::v8i32,   MVT::v8f32,  MVT::v9i32,  MVT::v9f32,  MVT::v10i32,
+        MVT::v10f32,  MVT::v11i32, MVT::v11f32, MVT::v12i32, MVT::v12f32,
+        MVT::v16i32,  MVT::v16f32, MVT::v2i64,  MVT::v2f64,  MVT::v4i16,
+        MVT::v4f16,   MVT::v4bf16, MVT::v3i64,  MVT::v3f64,  MVT::v6i32,
+        MVT::v6f32,   MVT::v4i64,  MVT::v4f64,  MVT::v8i64,  MVT::v8f64,
+        MVT::v8i16,   MVT::v8f16,  MVT::v8bf16, MVT::v16i16, MVT::v16f16,
+        MVT::v16bf16, MVT::v16i64, MVT::v16f64, MVT::v32i32, MVT::v32f32,
+        MVT::v32i16,  MVT::v32f16, MVT::v32bf16}) {
     for (unsigned Op = 0; Op < ISD::BUILTIN_OP_END; ++Op) {
       switch (Op) {
       case ISD::LOAD:
@@ -383,13 +425,14 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
                      {MVT::v8i32, MVT::v8f32, MVT::v16i32, MVT::v16f32},
                      Expand);
 
-  setOperationAction(ISD::BUILD_VECTOR, {MVT::v4f16, MVT::v4i16}, Custom);
+  setOperationAction(ISD::BUILD_VECTOR, {MVT::v4f16, MVT::v4i16, MVT::v4bf16},
+                     Custom);
 
   // Avoid stack access for these.
   // TODO: Generalize to more vector types.
   setOperationAction({ISD::EXTRACT_VECTOR_ELT, ISD::INSERT_VECTOR_ELT},
-                     {MVT::v2i16, MVT::v2f16, MVT::v2i8, MVT::v4i8, MVT::v8i8,
-                      MVT::v4i16, MVT::v4f16},
+                     {MVT::v2i16, MVT::v2f16, MVT::v2bf16, MVT::v2i8, MVT::v4i8,
+                      MVT::v8i8, MVT::v4i16, MVT::v4f16, MVT::v4bf16},
                      Custom);
 
   // Deal with vec3 vector operations when widened to vec4.
@@ -498,6 +541,11 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
   setOperationAction(ISD::BF16_TO_FP, {MVT::i16, MVT::f32, MVT::f64}, Expand);
   setOperationAction(ISD::FP_TO_BF16, {MVT::i16, MVT::f32, MVT::f64}, Expand);
 
+  // Custom lower these because we can't specify a rule based on an illegal
+  // source bf16.
+  setOperationAction({ISD::FP_EXTEND, ISD::STRICT_FP_EXTEND}, MVT::f32, Custom);
+  setOperationAction({ISD::FP_EXTEND, ISD::STRICT_FP_EXTEND}, MVT::f64, Custom);
+
   if (Subtarget->has16BitInsts()) {
     setOperationAction({ISD::Constant, ISD::SMIN, ISD::SMAX, ISD::UMIN,
                         ISD::UMAX, ISD::UADDSAT, ISD::USUBSAT},
@@ -524,9 +572,14 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
     AddPromotedToType(ISD::FP_TO_FP16, MVT::i16, MVT::i32);
 
     setOperationAction({ISD::FP_TO_SINT, ISD::FP_TO_UINT}, MVT::i16, Custom);
+    setOperationAction({ISD::SINT_TO_FP, ISD::UINT_TO_FP}, MVT::i16, Custom);
+    setOperationAction({ISD::SINT_TO_FP, ISD::UINT_TO_FP}, MVT::i16, Custom);
+
+    setOperationAction({ISD::SINT_TO_FP, ISD::UINT_TO_FP}, MVT::i32, Custom);
 
     // F16 - Constant Actions.
     setOperationAction(ISD::ConstantFP, MVT::f16, Legal);
+    setOperationAction(ISD::ConstantFP, MVT::bf16, Legal);
 
     // F16 - Load/Store Actions.
     setOperationAction(ISD::LOAD, MVT::f16, Promote);
@@ -534,16 +587,23 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
     setOperationAction(ISD::STORE, MVT::f16, Promote);
     AddPromotedToType(ISD::STORE, MVT::f16, MVT::i16);
 
+    // BF16 - Load/Store Actions.
+    setOperationAction(ISD::L...
[truncated]
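At the source level, the point of making these types legal is that whole bf16 vectors can flow through argument passing and instruction selection without being scalarized. A hedged illustration using Clang's vector extension (hypothetical example, not from this PR; the typedef and function name are made up, and it assumes the target accepts __bf16 vectors):

// Hypothetical: a <8 x bfloat> passed and returned by value. With v8bf16
// legal it maps to one 128-bit value (the SGPR_128 register class added in
// SIISelLowering.cpp above) instead of eight scalar pieces.
typedef __bf16 bf16x8 __attribute__((ext_vector_type(8)));

bf16x8 passthrough(bf16x8 V) {
  return V;
}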

github-actions bot commented Jan 1, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

arsenm merged commit bdbaf6e into llvm:main on Jan 8, 2024
3 of 4 checks passed
arsenm deleted the bf16/legal-v8bf16-v16bf16 branch on January 8, 2024 at 11:59
arsenm added a commit that referenced this pull request Jan 9, 2024
justinfargnoli pushed a commit to justinfargnoli/llvm-project that referenced this pull request Jan 28, 2024
justinfargnoli pushed a commit to justinfargnoli/llvm-project that referenced this pull request Jan 28, 2024
Labels: backend:AMDGPU, llvm:SelectionDAG (SelectionDAGISel as well)
2 participants