Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AMDGPU] Use bf16 instead of i16 for bfloat #80908

Merged
merged 1 commit into from
Feb 16, 2024
Merged

Conversation

shiltian
Copy link
Contributor

@shiltian shiltian commented Feb 6, 2024

Currently we generally use i16 to represent bf16 in those tablegen files. This patch is trying to use bf16 directly.

Fix #79369.

Copy link

github-actions bot commented Feb 6, 2024

⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
git-clang-format --diff 97434cb318d170a14914126f33181b759197ba41 5b66bb22a91690078a955cea6c02b6b746b6502b -- llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.h llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUMCCodeEmitter.cpp llvm/lib/Target/AMDGPU/SIDefines.h llvm/lib/Target/AMDGPU/SIInstrInfo.cpp llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.cpp llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.h
View the diff from clang-format here.
diff --git a/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp b/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
index 883b30562e..e45379f4de 100644
--- a/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
+++ b/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
@@ -617,7 +617,9 @@ public:
 
   bool isVSrc_f64() const { return isVCSrcF64() || isLiteralImm(MVT::f64); }
 
-  bool isVSrcT_bf16() const { return isVCSrcTBF16() || isLiteralImm(MVT::bf16); }
+  bool isVSrcT_bf16() const {
+    return isVCSrcTBF16() || isLiteralImm(MVT::bf16);
+  }
 
   bool isVSrcT_f16() const { return isVCSrcTF16() || isLiteralImm(MVT::f16); }
 
@@ -2361,7 +2363,7 @@ void AMDGPUOperand::addLiteralImmOperand(MCInst &Inst, int64_t Val, bool ApplyMo
   case AMDGPU::OPERAND_REG_INLINE_AC_BF16:
     if (isSafeTruncation(Val, 16) &&
         AMDGPU::isInlinableLiteralBF16(static_cast<int16_t>(Val),
-                                     AsmParser->hasInv2PiInlineImm())) {
+                                       AsmParser->hasInv2PiInlineImm())) {
       Inst.addOperand(MCOperand::createImm(Val));
       setImmKindConst();
       return;

@shiltian shiltian requested a review from t-tye February 6, 2024 22:11
llvm/include/llvm/IR/IntrinsicsAMDGPU.td Show resolved Hide resolved
llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp Outdated Show resolved Hide resolved
llvm/include/llvm/IR/IntrinsicsAMDGPU.td Outdated Show resolved Hide resolved
llvm/test/MC/AMDGPU/bf16_imm.s Show resolved Hide resolved
@shiltian shiltian force-pushed the PR79369 branch 2 times, most recently from 672fd3c to d14668f Compare February 8, 2024 17:12
@shiltian shiltian changed the title [RFC][WIP][AMDGPU] Use bf16 instead of i16 for bfloat [RFC][AMDGPU] Use bf16 instead of i16 for bfloat Feb 8, 2024
@shiltian shiltian marked this pull request as ready for review February 8, 2024 17:13
@llvmbot llvmbot added clang Clang issues not falling into any other category backend:AMDGPU clang:codegen mc Machine (object) code llvm:globalisel llvm:ir labels Feb 8, 2024
@llvmbot
Copy link
Collaborator

llvmbot commented Feb 8, 2024

@llvm/pr-subscribers-clang-codegen
@llvm/pr-subscribers-clang
@llvm/pr-subscribers-mc

@llvm/pr-subscribers-llvm-ir

Author: Shilei Tian (shiltian)

Changes

Currently it looks like we generally use i16 to represent bf16 in those tablegen
files. I'm not sure of the reason behind it. My wild guess is the type bf16 was
not available when we enabled the support. This patch is trying to use bf16
directly in those tablegen files, aiming at fixing #79369.


Patch is 32.53 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/80908.diff

14 Files Affected:

  • (modified) clang/lib/CodeGen/CGBuiltin.cpp (-4)
  • (modified) llvm/include/llvm/IR/IntrinsicsAMDGPU.td (+4-4)
  • (modified) llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp (+3-2)
  • (modified) llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp (+66)
  • (modified) llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp (+10)
  • (modified) llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUMCCodeEmitter.cpp (+7)
  • (modified) llvm/lib/Target/AMDGPU/SIDefines.h (+7)
  • (modified) llvm/lib/Target/AMDGPU/SIInstrInfo.cpp (+7)
  • (modified) llvm/lib/Target/AMDGPU/SIInstrInfo.td (+32-26)
  • (modified) llvm/lib/Target/AMDGPU/SIRegisterInfo.td (+21-1)
  • (modified) llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.h (+7)
  • (modified) llvm/lib/Target/AMDGPU/VOP3Instructions.td (+1-1)
  • (modified) llvm/test/CodeGen/AMDGPU/llvm.amdgcn.fdot2.bf16.bf16.ll (+18-18)
  • (added) llvm/test/MC/AMDGPU/bf16_imm.s (+8)
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index a7a410dab1a018..daf651917f2a96 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -5908,8 +5908,6 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
           }
         }
 
-        assert(ArgValue->getType()->canLosslesslyBitCastTo(PTy) &&
-               "Must be able to losslessly bit cast to param");
         // Cast vector type (e.g., v256i32) to x86_amx, this only happen
         // in amx intrinsics.
         if (PTy->isX86_AMXTy())
@@ -5939,8 +5937,6 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
         }
       }
 
-      assert(V->getType()->canLosslesslyBitCastTo(RetTy) &&
-             "Must be able to losslessly bit cast result type");
       // Cast x86_amx to vector type (e.g., v256i32), this only happen
       // in amx intrinsics.
       if (V->getType()->isX86_AMXTy())
diff --git a/llvm/include/llvm/IR/IntrinsicsAMDGPU.td b/llvm/include/llvm/IR/IntrinsicsAMDGPU.td
index 202fa4e8f4ea81..6795fb7aa0edb8 100644
--- a/llvm/include/llvm/IR/IntrinsicsAMDGPU.td
+++ b/llvm/include/llvm/IR/IntrinsicsAMDGPU.td
@@ -2819,11 +2819,11 @@ def int_amdgcn_fdot2_f16_f16 :
 def int_amdgcn_fdot2_bf16_bf16 :
   ClangBuiltin<"__builtin_amdgcn_fdot2_bf16_bf16">,
   DefaultAttrsIntrinsic<
-    [llvm_i16_ty],   // %r
+    [llvm_bfloat_ty],   // %r
     [
-      llvm_v2i16_ty, // %a
-      llvm_v2i16_ty, // %b
-      llvm_i16_ty    // %c
+      llvm_v2bf16_ty, // %a
+      llvm_v2bf16_ty, // %b
+      llvm_bfloat_ty    // %c
     ],
     [IntrNoMem, IntrSpeculatable]
   >;
diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
index c1d8e890a66edb..828229f3e569e3 100644
--- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
@@ -1562,8 +1562,9 @@ bool IRTranslator::translateBitCast(const User &U,
 
 bool IRTranslator::translateCast(unsigned Opcode, const User &U,
                                  MachineIRBuilder &MIRBuilder) {
-  if (U.getType()->getScalarType()->isBFloatTy() ||
-      U.getOperand(0)->getType()->getScalarType()->isBFloatTy())
+  if (Opcode != TargetOpcode::G_BITCAST &&
+      (U.getType()->getScalarType()->isBFloatTy() ||
+       U.getOperand(0)->getType()->getScalarType()->isBFloatTy()))
     return false;
   Register Op = getOrCreateVReg(*U.getOperand(0));
   Register Res = getOrCreateVReg(U);
diff --git a/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp b/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
index a94da992b33859..d6d96c251f7e30 100644
--- a/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
+++ b/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
@@ -475,6 +475,8 @@ class AMDGPUOperand : public MCParsedAsmOperand {
 
   bool isSSrcF64() const { return isSCSrc_b64() || isLiteralImm(MVT::f64); }
 
+  bool isSSrc_bf16() const { return isSCSrcB16() || isLiteralImm(MVT::bf16); }
+
   bool isSSrc_f16() const { return isSCSrcB16() || isLiteralImm(MVT::f16); }
 
   bool isSSrcV2F16() const {
@@ -541,22 +543,40 @@ class AMDGPUOperand : public MCParsedAsmOperand {
     return isRegOrInlineNoMods(AMDGPU::VS_64RegClassID, MVT::f64);
   }
 
+  bool isVCSrcTBF16() const {
+    return isRegOrInlineNoMods(AMDGPU::VS_16RegClassID, MVT::bf16);
+  }
+
   bool isVCSrcTF16() const {
     return isRegOrInlineNoMods(AMDGPU::VS_16RegClassID, MVT::f16);
   }
 
+  bool isVCSrcTBF16_Lo128() const {
+    return isRegOrInlineNoMods(AMDGPU::VS_16_Lo128RegClassID, MVT::bf16);
+  }
+
   bool isVCSrcTF16_Lo128() const {
     return isRegOrInlineNoMods(AMDGPU::VS_16_Lo128RegClassID, MVT::f16);
   }
 
+  bool isVCSrcFake16BF16_Lo128() const {
+    return isRegOrInlineNoMods(AMDGPU::VS_32_Lo128RegClassID, MVT::bf16);
+  }
+
   bool isVCSrcFake16F16_Lo128() const {
     return isRegOrInlineNoMods(AMDGPU::VS_32_Lo128RegClassID, MVT::f16);
   }
 
+  bool isVCSrc_bf16() const {
+    return isRegOrInlineNoMods(AMDGPU::VS_32RegClassID, MVT::bf16);
+  }
+
   bool isVCSrc_f16() const {
     return isRegOrInlineNoMods(AMDGPU::VS_32RegClassID, MVT::f16);
   }
 
+  bool isVCSrc_v2bf16() const { return isVCSrc_bf16(); }
+
   bool isVCSrc_v2f16() const { return isVCSrc_f16(); }
 
   bool isVSrc_b32() const {
@@ -597,18 +617,34 @@ class AMDGPUOperand : public MCParsedAsmOperand {
 
   bool isVSrc_f64() const { return isVCSrcF64() || isLiteralImm(MVT::f64); }
 
+  bool isVSrcT_bf16() const { return isVCSrcTBF16() || isLiteralImm(MVT::bf16); }
+
   bool isVSrcT_f16() const { return isVCSrcTF16() || isLiteralImm(MVT::f16); }
 
+  bool isVSrcT_bf16_Lo128() const {
+    return isVCSrcTBF16_Lo128() || isLiteralImm(MVT::bf16);
+  }
+
   bool isVSrcT_f16_Lo128() const {
     return isVCSrcTF16_Lo128() || isLiteralImm(MVT::f16);
   }
 
+  bool isVSrcFake16_bf16_Lo128() const {
+    return isVCSrcFake16BF16_Lo128() || isLiteralImm(MVT::bf16);
+  }
+
   bool isVSrcFake16_f16_Lo128() const {
     return isVCSrcFake16F16_Lo128() || isLiteralImm(MVT::f16);
   }
 
+  bool isVSrc_bf16() const { return isVCSrc_bf16() || isLiteralImm(MVT::bf16); }
+
   bool isVSrc_f16() const { return isVCSrc_f16() || isLiteralImm(MVT::f16); }
 
+  bool isVSrc_v2bf16() const {
+    return isVSrc_bf16() || isLiteralImm(MVT::v2bf16);
+  }
+
   bool isVSrc_v2f16() const { return isVSrc_f16() || isLiteralImm(MVT::v2f16); }
 
   bool isVISrcB32() const {
@@ -635,6 +671,10 @@ class AMDGPUOperand : public MCParsedAsmOperand {
     return isVISrcF16() || isVISrcB32();
   }
 
+  bool isVISrc_64_bf16() const {
+    return isRegOrInlineNoMods(AMDGPU::VReg_64RegClassID, MVT::bf16);
+  }
+
   bool isVISrc_64_f16() const {
     return isRegOrInlineNoMods(AMDGPU::VReg_64RegClassID, MVT::f16);
   }
@@ -803,6 +843,10 @@ class AMDGPUOperand : public MCParsedAsmOperand {
     return isAISrc_128F16() || isAISrc_128_b32();
   }
 
+  bool isVISrc_128_bf16() const {
+    return isRegOrInlineNoMods(AMDGPU::VReg_128RegClassID, MVT::bf16);
+  }
+
   bool isVISrc_128_f16() const {
     return isRegOrInlineNoMods(AMDGPU::VReg_128RegClassID, MVT::f16);
   }
@@ -1890,6 +1934,14 @@ static const fltSemantics *getOpFltSemantics(uint8_t OperandType) {
   case AMDGPU::OPERAND_REG_IMM_V2FP16:
   case AMDGPU::OPERAND_KIMM16:
     return &APFloat::IEEEhalf();
+  case AMDGPU::OPERAND_REG_IMM_BF16:
+  case AMDGPU::OPERAND_REG_IMM_BF16_DEFERRED:
+  case AMDGPU::OPERAND_REG_INLINE_C_BF16:
+  case AMDGPU::OPERAND_REG_INLINE_C_V2BF16:
+  case AMDGPU::OPERAND_REG_INLINE_AC_BF16:
+  case AMDGPU::OPERAND_REG_INLINE_AC_V2BF16:
+  case AMDGPU::OPERAND_REG_IMM_V2BF16:
+    return &APFloat::BFloat();
   default:
     llvm_unreachable("unsupported fp type");
   }
@@ -2186,17 +2238,24 @@ void AMDGPUOperand::addLiteralImmOperand(MCInst &Inst, int64_t Val, bool ApplyMo
     case AMDGPU::OPERAND_REG_INLINE_AC_INT32:
     case AMDGPU::OPERAND_REG_INLINE_AC_FP32:
     case AMDGPU::OPERAND_REG_IMM_INT16:
+    case AMDGPU::OPERAND_REG_IMM_BF16:
     case AMDGPU::OPERAND_REG_IMM_FP16:
+    case AMDGPU::OPERAND_REG_IMM_BF16_DEFERRED:
     case AMDGPU::OPERAND_REG_IMM_FP16_DEFERRED:
     case AMDGPU::OPERAND_REG_INLINE_C_INT16:
+    case AMDGPU::OPERAND_REG_INLINE_C_BF16:
     case AMDGPU::OPERAND_REG_INLINE_C_FP16:
     case AMDGPU::OPERAND_REG_INLINE_C_V2INT16:
+    case AMDGPU::OPERAND_REG_INLINE_C_V2BF16:
     case AMDGPU::OPERAND_REG_INLINE_C_V2FP16:
     case AMDGPU::OPERAND_REG_INLINE_AC_INT16:
+    case AMDGPU::OPERAND_REG_INLINE_AC_BF16:
     case AMDGPU::OPERAND_REG_INLINE_AC_FP16:
     case AMDGPU::OPERAND_REG_INLINE_AC_V2INT16:
+    case AMDGPU::OPERAND_REG_INLINE_AC_V2BF16:
     case AMDGPU::OPERAND_REG_INLINE_AC_V2FP16:
     case AMDGPU::OPERAND_REG_IMM_V2INT16:
+    case AMDGPU::OPERAND_REG_IMM_V2BF16:
     case AMDGPU::OPERAND_REG_IMM_V2FP16:
     case AMDGPU::OPERAND_REG_INLINE_C_V2FP32:
     case AMDGPU::OPERAND_REG_IMM_V2FP32:
@@ -2240,6 +2299,7 @@ void AMDGPUOperand::addLiteralImmOperand(MCInst &Inst, int64_t Val, bool ApplyMo
   case AMDGPU::OPERAND_REG_INLINE_AC_INT32:
   case AMDGPU::OPERAND_REG_INLINE_AC_FP32:
   case AMDGPU::OPERAND_REG_IMM_V2INT16:
+  case AMDGPU::OPERAND_REG_IMM_V2BF16:
   case AMDGPU::OPERAND_REG_IMM_V2FP16:
   case AMDGPU::OPERAND_REG_IMM_V2FP32:
   case AMDGPU::OPERAND_REG_INLINE_C_V2FP32:
@@ -2277,11 +2337,15 @@ void AMDGPUOperand::addLiteralImmOperand(MCInst &Inst, int64_t Val, bool ApplyMo
     return;
 
   case AMDGPU::OPERAND_REG_IMM_INT16:
+  case AMDGPU::OPERAND_REG_IMM_BF16:
   case AMDGPU::OPERAND_REG_IMM_FP16:
+  case AMDGPU::OPERAND_REG_IMM_BF16_DEFERRED:
   case AMDGPU::OPERAND_REG_IMM_FP16_DEFERRED:
   case AMDGPU::OPERAND_REG_INLINE_C_INT16:
+  case AMDGPU::OPERAND_REG_INLINE_C_BF16:
   case AMDGPU::OPERAND_REG_INLINE_C_FP16:
   case AMDGPU::OPERAND_REG_INLINE_AC_INT16:
+  case AMDGPU::OPERAND_REG_INLINE_AC_BF16:
   case AMDGPU::OPERAND_REG_INLINE_AC_FP16:
     if (isSafeTruncation(Val, 16) &&
         AMDGPU::isInlinableLiteral16(static_cast<int16_t>(Val),
@@ -2296,8 +2360,10 @@ void AMDGPUOperand::addLiteralImmOperand(MCInst &Inst, int64_t Val, bool ApplyMo
     return;
 
   case AMDGPU::OPERAND_REG_INLINE_C_V2INT16:
+  case AMDGPU::OPERAND_REG_INLINE_C_V2BF16:
   case AMDGPU::OPERAND_REG_INLINE_C_V2FP16:
   case AMDGPU::OPERAND_REG_INLINE_AC_V2INT16:
+  case AMDGPU::OPERAND_REG_INLINE_AC_V2BF16:
   case AMDGPU::OPERAND_REG_INLINE_AC_V2FP16: {
     assert(isSafeTruncation(Val, 16));
     assert(AMDGPU::isInlinableLiteral16(static_cast<int16_t>(Val),
diff --git a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp
index abfa4a3531e8e1..96a0168f37e405 100644
--- a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp
+++ b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp
@@ -521,8 +521,11 @@ void AMDGPUInstPrinter::printImmediateV216(uint32_t Imm, uint8_t OpType,
     if (printImmediateFloat32(Imm, STI, O))
       return;
     break;
+  case AMDGPU::OPERAND_REG_IMM_V2BF16:
   case AMDGPU::OPERAND_REG_IMM_V2FP16:
+  case AMDGPU::OPERAND_REG_INLINE_C_V2BF16:
   case AMDGPU::OPERAND_REG_INLINE_C_V2FP16:
+  case AMDGPU::OPERAND_REG_INLINE_AC_V2BF16:
   case AMDGPU::OPERAND_REG_INLINE_AC_V2FP16:
     if (isUInt<16>(Imm) &&
         printImmediateFloat16(static_cast<uint16_t>(Imm), STI, O))
@@ -792,17 +795,24 @@ void AMDGPUInstPrinter::printRegularOperand(const MCInst *MI, unsigned OpNo,
     case AMDGPU::OPERAND_REG_IMM_INT16:
       printImmediateInt16(Op.getImm(), STI, O);
       break;
+    case AMDGPU::OPERAND_REG_INLINE_C_BF16:
     case AMDGPU::OPERAND_REG_INLINE_C_FP16:
+    case AMDGPU::OPERAND_REG_INLINE_AC_BF16:
     case AMDGPU::OPERAND_REG_INLINE_AC_FP16:
+    case AMDGPU::OPERAND_REG_IMM_BF16:
     case AMDGPU::OPERAND_REG_IMM_FP16:
+    case AMDGPU::OPERAND_REG_IMM_BF16_DEFERRED:
     case AMDGPU::OPERAND_REG_IMM_FP16_DEFERRED:
       printImmediate16(Op.getImm(), STI, O);
       break;
     case AMDGPU::OPERAND_REG_IMM_V2INT16:
+    case AMDGPU::OPERAND_REG_IMM_V2BF16:
     case AMDGPU::OPERAND_REG_IMM_V2FP16:
     case AMDGPU::OPERAND_REG_INLINE_C_V2INT16:
     case AMDGPU::OPERAND_REG_INLINE_AC_V2INT16:
+    case AMDGPU::OPERAND_REG_INLINE_C_V2BF16:
     case AMDGPU::OPERAND_REG_INLINE_C_V2FP16:
+    case AMDGPU::OPERAND_REG_INLINE_AC_V2BF16:
     case AMDGPU::OPERAND_REG_INLINE_AC_V2FP16:
       printImmediateV216(Op.getImm(), OpTy, STI, O);
       break;
diff --git a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUMCCodeEmitter.cpp b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUMCCodeEmitter.cpp
index 11f5e456e8d348..9ec174ba56c242 100644
--- a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUMCCodeEmitter.cpp
+++ b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUMCCodeEmitter.cpp
@@ -276,9 +276,13 @@ AMDGPUMCCodeEmitter::getLitEncoding(const MCOperand &MO,
   case AMDGPU::OPERAND_REG_INLINE_C_INT16:
   case AMDGPU::OPERAND_REG_INLINE_AC_INT16:
     return getLit16IntEncoding(static_cast<uint16_t>(Imm), STI);
+  case AMDGPU::OPERAND_REG_IMM_BF16:
   case AMDGPU::OPERAND_REG_IMM_FP16:
+  case AMDGPU::OPERAND_REG_IMM_BF16_DEFERRED:
   case AMDGPU::OPERAND_REG_IMM_FP16_DEFERRED:
+  case AMDGPU::OPERAND_REG_INLINE_C_BF16:
   case AMDGPU::OPERAND_REG_INLINE_C_FP16:
+  case AMDGPU::OPERAND_REG_INLINE_AC_BF16:
   case AMDGPU::OPERAND_REG_INLINE_AC_FP16:
     // FIXME Is this correct? What do inline immediates do on SI for f16 src
     // which does not have f16 support?
@@ -288,8 +292,11 @@ AMDGPUMCCodeEmitter::getLitEncoding(const MCOperand &MO,
   case AMDGPU::OPERAND_REG_INLINE_AC_V2INT16:
     return AMDGPU::getInlineEncodingV2I16(static_cast<uint32_t>(Imm))
         .value_or(255);
+  case AMDGPU::OPERAND_REG_IMM_V2BF16:
   case AMDGPU::OPERAND_REG_IMM_V2FP16:
+  case AMDGPU::OPERAND_REG_INLINE_C_V2BF16:
   case AMDGPU::OPERAND_REG_INLINE_C_V2FP16:
+  case AMDGPU::OPERAND_REG_INLINE_AC_V2BF16:
   case AMDGPU::OPERAND_REG_INLINE_AC_V2FP16:
     return AMDGPU::getInlineEncodingV2F16(static_cast<uint32_t>(Imm))
         .value_or(255);
diff --git a/llvm/lib/Target/AMDGPU/SIDefines.h b/llvm/lib/Target/AMDGPU/SIDefines.h
index 19596d53b45328..66b997eb180613 100644
--- a/llvm/lib/Target/AMDGPU/SIDefines.h
+++ b/llvm/lib/Target/AMDGPU/SIDefines.h
@@ -196,9 +196,12 @@ enum OperandType : unsigned {
   OPERAND_REG_IMM_INT16,
   OPERAND_REG_IMM_FP32,
   OPERAND_REG_IMM_FP64,
+  OPERAND_REG_IMM_BF16,
   OPERAND_REG_IMM_FP16,
+  OPERAND_REG_IMM_BF16_DEFERRED,
   OPERAND_REG_IMM_FP16_DEFERRED,
   OPERAND_REG_IMM_FP32_DEFERRED,
+  OPERAND_REG_IMM_V2BF16,
   OPERAND_REG_IMM_V2FP16,
   OPERAND_REG_IMM_V2INT16,
   OPERAND_REG_IMM_V2INT32,
@@ -208,10 +211,12 @@ enum OperandType : unsigned {
   OPERAND_REG_INLINE_C_INT16,
   OPERAND_REG_INLINE_C_INT32,
   OPERAND_REG_INLINE_C_INT64,
+  OPERAND_REG_INLINE_C_BF16,
   OPERAND_REG_INLINE_C_FP16,
   OPERAND_REG_INLINE_C_FP32,
   OPERAND_REG_INLINE_C_FP64,
   OPERAND_REG_INLINE_C_V2INT16,
+  OPERAND_REG_INLINE_C_V2BF16,
   OPERAND_REG_INLINE_C_V2FP16,
   OPERAND_REG_INLINE_C_V2INT32,
   OPERAND_REG_INLINE_C_V2FP32,
@@ -226,10 +231,12 @@ enum OperandType : unsigned {
   /// Operands with an AccVGPR register or inline constant
   OPERAND_REG_INLINE_AC_INT16,
   OPERAND_REG_INLINE_AC_INT32,
+  OPERAND_REG_INLINE_AC_BF16,
   OPERAND_REG_INLINE_AC_FP16,
   OPERAND_REG_INLINE_AC_FP32,
   OPERAND_REG_INLINE_AC_FP64,
   OPERAND_REG_INLINE_AC_V2INT16,
+  OPERAND_REG_INLINE_AC_V2BF16,
   OPERAND_REG_INLINE_AC_V2FP16,
   OPERAND_REG_INLINE_AC_V2INT32,
   OPERAND_REG_INLINE_AC_V2FP32,
diff --git a/llvm/lib/Target/AMDGPU/SIInstrInfo.cpp b/llvm/lib/Target/AMDGPU/SIInstrInfo.cpp
index c7628bd354309c..fcb2a6f1f3d75d 100644
--- a/llvm/lib/Target/AMDGPU/SIInstrInfo.cpp
+++ b/llvm/lib/Target/AMDGPU/SIInstrInfo.cpp
@@ -4181,13 +4181,20 @@ bool SIInstrInfo::isInlineConstant(const MachineOperand &MO,
   case AMDGPU::OPERAND_REG_INLINE_C_V2INT16:
   case AMDGPU::OPERAND_REG_INLINE_AC_V2INT16:
     return AMDGPU::isInlinableLiteralV2I16(Imm);
+  case AMDGPU::OPERAND_REG_IMM_V2BF16:
   case AMDGPU::OPERAND_REG_IMM_V2FP16:
+  case AMDGPU::OPERAND_REG_INLINE_C_V2BF16:
   case AMDGPU::OPERAND_REG_INLINE_C_V2FP16:
+  case AMDGPU::OPERAND_REG_INLINE_AC_V2BF16:
   case AMDGPU::OPERAND_REG_INLINE_AC_V2FP16:
     return AMDGPU::isInlinableLiteralV2F16(Imm);
+  case AMDGPU::OPERAND_REG_IMM_BF16:
   case AMDGPU::OPERAND_REG_IMM_FP16:
+  case AMDGPU::OPERAND_REG_IMM_BF16_DEFERRED:
   case AMDGPU::OPERAND_REG_IMM_FP16_DEFERRED:
+  case AMDGPU::OPERAND_REG_INLINE_C_BF16:
   case AMDGPU::OPERAND_REG_INLINE_C_FP16:
+  case AMDGPU::OPERAND_REG_INLINE_AC_BF16:
   case AMDGPU::OPERAND_REG_INLINE_AC_FP16: {
     if (isInt<16>(Imm) || isUInt<16>(Imm)) {
       // A few special case instructions have 16-bit operands on subtargets
diff --git a/llvm/lib/Target/AMDGPU/SIInstrInfo.td b/llvm/lib/Target/AMDGPU/SIInstrInfo.td
index 22599773d562cb..b0daec4a350eb3 100644
--- a/llvm/lib/Target/AMDGPU/SIInstrInfo.td
+++ b/llvm/lib/Target/AMDGPU/SIInstrInfo.td
@@ -1497,20 +1497,17 @@ class getVOPSrc0ForVT<ValueType VT, bit IsTrue16, bit IsFake16 = 1> {
   RegisterOperand ret =
     !if(VT.isFP,
       !if(!eq(VT.Size, 64),
-         VSrc_f64,
-         !if(!or(!eq(VT.Value, f16.Value), !eq(VT.Value, bf16.Value)),
-            !if(IsTrue16,
-              !if(IsFake16, VSrcFake16_f16_Lo128, VSrcT_f16_Lo128),
-              VSrc_f16
-            ),
-            !if(!or(!eq(VT.Value, v2f16.Value), !eq(VT.Value, v2bf16.Value)),
-               VSrc_v2f16,
-               !if(!or(!eq(VT.Value, v4f16.Value), !eq(VT.Value, v4bf16.Value)),
-                 AVSrc_64,
-                 VSrc_f32
+          VSrc_f64,
+          !if(!eq(VT.Value, f16.Value),
+              !if(IsTrue16, !if(IsFake16, VSrcFake16_f16_Lo128, VSrcT_f16_Lo128), VSrc_f16),
+              !if(!eq(VT.Value, bf16.Value),
+                 !if(IsTrue16, !if(IsFake16, VSrcFake16_bf16_Lo128, VSrcT_bf16_Lo128), VSrc_bf16),
+                 !if(!or(!eq(VT.Value, v2f16.Value), !eq(VT.Value, v2bf16.Value)),
+                     !if(!eq(VT.Value, v2f16.Value), VSrc_v2f16, VSrc_v2bf16),
+                     !if(!or(!eq(VT.Value, v4f16.Value), !eq(VT.Value, v4bf16.Value)), AVSrc_64, VSrc_f32)
+                  )
                )
-            )
-         )
+           )
        ),
        !if(!eq(VT.Size, 64),
           VSrc_b64,
@@ -1569,16 +1566,20 @@ class getVOP3SrcForVT<ValueType VT, bit IsTrue16 = 0> {
         !if(!eq(VT.Value, i1.Value),
            SSrc_i1,
            !if(VT.isFP,
-              !if(!or(!eq(VT.Value, f16.Value), !eq(VT.Value, bf16.Value)),
-                 !if(IsTrue16, VSrcT_f16, VSrc_f16),
-                 !if(!or(!eq(VT.Value, v2f16.Value), !eq(VT.Value, v2bf16.Value)),
-                    VSrc_v2f16,
-                    !if(!or(!eq(VT.Value, v4f16.Value), !eq(VT.Value, v4bf16.Value)),
-                      AVSrc_64,
-                      VSrc_f32
-                    )
-                 )
-              ),
+               !if(!eq(VT.Value, f16.Value),
+                   !if(IsTrue16, VSrcT_f16, VSrc_f16),
+                   !if(!eq(VT.Value, bf16.Value),
+                       !if(IsTrue16, VSrcT_bf16, VSrc_bf16),
+                       !if(!eq(VT.Value, v2f16.Value),
+                           VSrc_v2f16,
+                           !if(!eq(VT.Value, v2bf16.Value),
+                               VSrc_v2bf16,
+                               !if(!or(!eq(VT.Value, v4f16.Value), !eq(VT.Value, v4bf16.Value)),
+                                   AVSrc_64, VSrc_f32)
+                           )
+                       )
+                   )
+               ),
               !if(!eq(VT.Value, i16.Value),
                  !if(IsTrue16, VSrcT_b16, VSrc_b16),
                  !if(!eq(VT.Value, v2i16.Value),
@@ -1597,8 +1598,13 @@ class getVOP3DPPSrcForVT<ValueType VT> {
   RegisterOperand ret =
       !if (!eq(VT.Value, i1.Value), SSrc_i1,
            !if (VT.isFP,
-                !if (!or(!eq(VT.Value, f16.Value), !eq(VT.Value, bf16.Value)), VCSrc_f16,
-                     !if (!or(!eq(VT.Value, v2f16.Value), !eq(VT.Value, v2bf16.Value)), VCSrc_v2f16, VCSrc_f32)),
+                !if(!eq(VT.Value, f16.Value), VCSrc_f16,
+                    !if(!eq(VT.Value, bf16.Value), VCSrc_bf16,
+                        !if(!eq(VT.Value, v2f16.Value), VCSrc_v2f16,
+                            !if(!eq(VT.Value, v2bf16.Value), VCSrc_v2bf16, VCSrc_f32)
+                        )
+                    )
+                ),
                 !if (!eq(VT.Value, i16.Value), VCSrc_b16,
                      !if (!eq(VT.Value, v2i16.Value), VCSrc_v2b16,
                           VCSrc_b32))));
@@ -2528,7 +2534,7 @@ def VOP_V2I16_F32_F32 : VOPProfile <[v2i16, f32, f32, untyped]>;
 def VOP_V2I16_I32_I32 : VOPProfile <[v2i16, i32, i32, untyped]>;
 
 def VOP_F16_V2F16_V2F16_F16 : VOPProfile <[f16, v2f16, v2f16, f16]>;
-def VOP_I16_V2I16_V2I16_I16 : VOPProfile <[i16, v2i16, v2i16, i16]>;
+def VOP_BF16_V2BF16_V2BF16_BF16: VOPProfile <[bf16, v2bf16, v2bf16, bf16]>;
 def VOP_F32_V2I16_V2I16_F32 : VOPProfile <[f32, v2i16, v2i16, f32]>;
 
 def VOP_F32_V2F16_V2F16_V2F16 : VOPProfile <[f32, v2f16, v2f16, v2f16]>;
diff --git a/llvm/lib/Target/AMDGPU/SIRegisterInfo.td b/llvm/lib/Target/AMDGPU/SIRegisterInfo.td
index aabb6c29062114..f24e65304d2052 100644
--- a/llvm/lib/Target/AMDGPU/SIRegisterInfo.td
+++ b/llvm/lib/Target/AMDGPU/SIRegisterInfo.td
@@ -1...
[truncated]

@llvmbot
Copy link
Collaborator

llvmbot commented Feb 8, 2024

@llvm/pr-subscribers-llvm-globalisel

Author: Shilei Tian (shiltian)

Changes

Currently it looks like we generally use i16 to represent bf16 in those tablegen
files. I'm not sure of the reason behind it. My wild guess is the type bf16 was
not available when we enabled the support. This patch is trying to use bf16
directly in those tablegen files, aiming at fixing #79369.


Patch is 32.53 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/80908.diff

14 Files Affected:

  • (modified) clang/lib/CodeGen/CGBuiltin.cpp (-4)
  • (modified) llvm/include/llvm/IR/IntrinsicsAMDGPU.td (+4-4)
  • (modified) llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp (+3-2)
  • (modified) llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp (+66)
  • (modified) llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp (+10)
  • (modified) llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUMCCodeEmitter.cpp (+7)
  • (modified) llvm/lib/Target/AMDGPU/SIDefines.h (+7)
  • (modified) llvm/lib/Target/AMDGPU/SIInstrInfo.cpp (+7)
  • (modified) llvm/lib/Target/AMDGPU/SIInstrInfo.td (+32-26)
  • (modified) llvm/lib/Target/AMDGPU/SIRegisterInfo.td (+21-1)
  • (modified) llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.h (+7)
  • (modified) llvm/lib/Target/AMDGPU/VOP3Instructions.td (+1-1)
  • (modified) llvm/test/CodeGen/AMDGPU/llvm.amdgcn.fdot2.bf16.bf16.ll (+18-18)
  • (added) llvm/test/MC/AMDGPU/bf16_imm.s (+8)
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index a7a410dab1a018..daf651917f2a96 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -5908,8 +5908,6 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
           }
         }
 
-        assert(ArgValue->getType()->canLosslesslyBitCastTo(PTy) &&
-               "Must be able to losslessly bit cast to param");
         // Cast vector type (e.g., v256i32) to x86_amx, this only happen
         // in amx intrinsics.
         if (PTy->isX86_AMXTy())
@@ -5939,8 +5937,6 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
         }
       }
 
-      assert(V->getType()->canLosslesslyBitCastTo(RetTy) &&
-             "Must be able to losslessly bit cast result type");
       // Cast x86_amx to vector type (e.g., v256i32), this only happen
       // in amx intrinsics.
       if (V->getType()->isX86_AMXTy())
diff --git a/llvm/include/llvm/IR/IntrinsicsAMDGPU.td b/llvm/include/llvm/IR/IntrinsicsAMDGPU.td
index 202fa4e8f4ea81..6795fb7aa0edb8 100644
--- a/llvm/include/llvm/IR/IntrinsicsAMDGPU.td
+++ b/llvm/include/llvm/IR/IntrinsicsAMDGPU.td
@@ -2819,11 +2819,11 @@ def int_amdgcn_fdot2_f16_f16 :
 def int_amdgcn_fdot2_bf16_bf16 :
   ClangBuiltin<"__builtin_amdgcn_fdot2_bf16_bf16">,
   DefaultAttrsIntrinsic<
-    [llvm_i16_ty],   // %r
+    [llvm_bfloat_ty],   // %r
     [
-      llvm_v2i16_ty, // %a
-      llvm_v2i16_ty, // %b
-      llvm_i16_ty    // %c
+      llvm_v2bf16_ty, // %a
+      llvm_v2bf16_ty, // %b
+      llvm_bfloat_ty    // %c
     ],
     [IntrNoMem, IntrSpeculatable]
   >;
diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
index c1d8e890a66edb..828229f3e569e3 100644
--- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
@@ -1562,8 +1562,9 @@ bool IRTranslator::translateBitCast(const User &U,
 
 bool IRTranslator::translateCast(unsigned Opcode, const User &U,
                                  MachineIRBuilder &MIRBuilder) {
-  if (U.getType()->getScalarType()->isBFloatTy() ||
-      U.getOperand(0)->getType()->getScalarType()->isBFloatTy())
+  if (Opcode != TargetOpcode::G_BITCAST &&
+      (U.getType()->getScalarType()->isBFloatTy() ||
+       U.getOperand(0)->getType()->getScalarType()->isBFloatTy()))
     return false;
   Register Op = getOrCreateVReg(*U.getOperand(0));
   Register Res = getOrCreateVReg(U);
diff --git a/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp b/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
index a94da992b33859..d6d96c251f7e30 100644
--- a/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
+++ b/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
@@ -475,6 +475,8 @@ class AMDGPUOperand : public MCParsedAsmOperand {
 
   bool isSSrcF64() const { return isSCSrc_b64() || isLiteralImm(MVT::f64); }
 
+  bool isSSrc_bf16() const { return isSCSrcB16() || isLiteralImm(MVT::bf16); }
+
   bool isSSrc_f16() const { return isSCSrcB16() || isLiteralImm(MVT::f16); }
 
   bool isSSrcV2F16() const {
@@ -541,22 +543,40 @@ class AMDGPUOperand : public MCParsedAsmOperand {
     return isRegOrInlineNoMods(AMDGPU::VS_64RegClassID, MVT::f64);
   }
 
+  bool isVCSrcTBF16() const {
+    return isRegOrInlineNoMods(AMDGPU::VS_16RegClassID, MVT::bf16);
+  }
+
   bool isVCSrcTF16() const {
     return isRegOrInlineNoMods(AMDGPU::VS_16RegClassID, MVT::f16);
   }
 
+  bool isVCSrcTBF16_Lo128() const {
+    return isRegOrInlineNoMods(AMDGPU::VS_16_Lo128RegClassID, MVT::bf16);
+  }
+
   bool isVCSrcTF16_Lo128() const {
     return isRegOrInlineNoMods(AMDGPU::VS_16_Lo128RegClassID, MVT::f16);
   }
 
+  bool isVCSrcFake16BF16_Lo128() const {
+    return isRegOrInlineNoMods(AMDGPU::VS_32_Lo128RegClassID, MVT::bf16);
+  }
+
   bool isVCSrcFake16F16_Lo128() const {
     return isRegOrInlineNoMods(AMDGPU::VS_32_Lo128RegClassID, MVT::f16);
   }
 
+  bool isVCSrc_bf16() const {
+    return isRegOrInlineNoMods(AMDGPU::VS_32RegClassID, MVT::bf16);
+  }
+
   bool isVCSrc_f16() const {
     return isRegOrInlineNoMods(AMDGPU::VS_32RegClassID, MVT::f16);
   }
 
+  bool isVCSrc_v2bf16() const { return isVCSrc_bf16(); }
+
   bool isVCSrc_v2f16() const { return isVCSrc_f16(); }
 
   bool isVSrc_b32() const {
@@ -597,18 +617,34 @@ class AMDGPUOperand : public MCParsedAsmOperand {
 
   bool isVSrc_f64() const { return isVCSrcF64() || isLiteralImm(MVT::f64); }
 
+  bool isVSrcT_bf16() const { return isVCSrcTBF16() || isLiteralImm(MVT::bf16); }
+
   bool isVSrcT_f16() const { return isVCSrcTF16() || isLiteralImm(MVT::f16); }
 
+  bool isVSrcT_bf16_Lo128() const {
+    return isVCSrcTBF16_Lo128() || isLiteralImm(MVT::bf16);
+  }
+
   bool isVSrcT_f16_Lo128() const {
     return isVCSrcTF16_Lo128() || isLiteralImm(MVT::f16);
   }
 
+  bool isVSrcFake16_bf16_Lo128() const {
+    return isVCSrcFake16BF16_Lo128() || isLiteralImm(MVT::bf16);
+  }
+
   bool isVSrcFake16_f16_Lo128() const {
     return isVCSrcFake16F16_Lo128() || isLiteralImm(MVT::f16);
   }
 
+  bool isVSrc_bf16() const { return isVCSrc_bf16() || isLiteralImm(MVT::bf16); }
+
   bool isVSrc_f16() const { return isVCSrc_f16() || isLiteralImm(MVT::f16); }
 
+  bool isVSrc_v2bf16() const {
+    return isVSrc_bf16() || isLiteralImm(MVT::v2bf16);
+  }
+
   bool isVSrc_v2f16() const { return isVSrc_f16() || isLiteralImm(MVT::v2f16); }
 
   bool isVISrcB32() const {
@@ -635,6 +671,10 @@ class AMDGPUOperand : public MCParsedAsmOperand {
     return isVISrcF16() || isVISrcB32();
   }
 
+  bool isVISrc_64_bf16() const {
+    return isRegOrInlineNoMods(AMDGPU::VReg_64RegClassID, MVT::bf16);
+  }
+
   bool isVISrc_64_f16() const {
     return isRegOrInlineNoMods(AMDGPU::VReg_64RegClassID, MVT::f16);
   }
@@ -803,6 +843,10 @@ class AMDGPUOperand : public MCParsedAsmOperand {
     return isAISrc_128F16() || isAISrc_128_b32();
   }
 
+  bool isVISrc_128_bf16() const {
+    return isRegOrInlineNoMods(AMDGPU::VReg_128RegClassID, MVT::bf16);
+  }
+
   bool isVISrc_128_f16() const {
     return isRegOrInlineNoMods(AMDGPU::VReg_128RegClassID, MVT::f16);
   }
@@ -1890,6 +1934,14 @@ static const fltSemantics *getOpFltSemantics(uint8_t OperandType) {
   case AMDGPU::OPERAND_REG_IMM_V2FP16:
   case AMDGPU::OPERAND_KIMM16:
     return &APFloat::IEEEhalf();
+  case AMDGPU::OPERAND_REG_IMM_BF16:
+  case AMDGPU::OPERAND_REG_IMM_BF16_DEFERRED:
+  case AMDGPU::OPERAND_REG_INLINE_C_BF16:
+  case AMDGPU::OPERAND_REG_INLINE_C_V2BF16:
+  case AMDGPU::OPERAND_REG_INLINE_AC_BF16:
+  case AMDGPU::OPERAND_REG_INLINE_AC_V2BF16:
+  case AMDGPU::OPERAND_REG_IMM_V2BF16:
+    return &APFloat::BFloat();
   default:
     llvm_unreachable("unsupported fp type");
   }
@@ -2186,17 +2238,24 @@ void AMDGPUOperand::addLiteralImmOperand(MCInst &Inst, int64_t Val, bool ApplyMo
     case AMDGPU::OPERAND_REG_INLINE_AC_INT32:
     case AMDGPU::OPERAND_REG_INLINE_AC_FP32:
     case AMDGPU::OPERAND_REG_IMM_INT16:
+    case AMDGPU::OPERAND_REG_IMM_BF16:
     case AMDGPU::OPERAND_REG_IMM_FP16:
+    case AMDGPU::OPERAND_REG_IMM_BF16_DEFERRED:
     case AMDGPU::OPERAND_REG_IMM_FP16_DEFERRED:
     case AMDGPU::OPERAND_REG_INLINE_C_INT16:
+    case AMDGPU::OPERAND_REG_INLINE_C_BF16:
     case AMDGPU::OPERAND_REG_INLINE_C_FP16:
     case AMDGPU::OPERAND_REG_INLINE_C_V2INT16:
+    case AMDGPU::OPERAND_REG_INLINE_C_V2BF16:
     case AMDGPU::OPERAND_REG_INLINE_C_V2FP16:
     case AMDGPU::OPERAND_REG_INLINE_AC_INT16:
+    case AMDGPU::OPERAND_REG_INLINE_AC_BF16:
     case AMDGPU::OPERAND_REG_INLINE_AC_FP16:
     case AMDGPU::OPERAND_REG_INLINE_AC_V2INT16:
+    case AMDGPU::OPERAND_REG_INLINE_AC_V2BF16:
     case AMDGPU::OPERAND_REG_INLINE_AC_V2FP16:
     case AMDGPU::OPERAND_REG_IMM_V2INT16:
+    case AMDGPU::OPERAND_REG_IMM_V2BF16:
     case AMDGPU::OPERAND_REG_IMM_V2FP16:
     case AMDGPU::OPERAND_REG_INLINE_C_V2FP32:
     case AMDGPU::OPERAND_REG_IMM_V2FP32:
@@ -2240,6 +2299,7 @@ void AMDGPUOperand::addLiteralImmOperand(MCInst &Inst, int64_t Val, bool ApplyMo
   case AMDGPU::OPERAND_REG_INLINE_AC_INT32:
   case AMDGPU::OPERAND_REG_INLINE_AC_FP32:
   case AMDGPU::OPERAND_REG_IMM_V2INT16:
+  case AMDGPU::OPERAND_REG_IMM_V2BF16:
   case AMDGPU::OPERAND_REG_IMM_V2FP16:
   case AMDGPU::OPERAND_REG_IMM_V2FP32:
   case AMDGPU::OPERAND_REG_INLINE_C_V2FP32:
@@ -2277,11 +2337,15 @@ void AMDGPUOperand::addLiteralImmOperand(MCInst &Inst, int64_t Val, bool ApplyMo
     return;
 
   case AMDGPU::OPERAND_REG_IMM_INT16:
+  case AMDGPU::OPERAND_REG_IMM_BF16:
   case AMDGPU::OPERAND_REG_IMM_FP16:
+  case AMDGPU::OPERAND_REG_IMM_BF16_DEFERRED:
   case AMDGPU::OPERAND_REG_IMM_FP16_DEFERRED:
   case AMDGPU::OPERAND_REG_INLINE_C_INT16:
+  case AMDGPU::OPERAND_REG_INLINE_C_BF16:
   case AMDGPU::OPERAND_REG_INLINE_C_FP16:
   case AMDGPU::OPERAND_REG_INLINE_AC_INT16:
+  case AMDGPU::OPERAND_REG_INLINE_AC_BF16:
   case AMDGPU::OPERAND_REG_INLINE_AC_FP16:
     if (isSafeTruncation(Val, 16) &&
         AMDGPU::isInlinableLiteral16(static_cast<int16_t>(Val),
@@ -2296,8 +2360,10 @@ void AMDGPUOperand::addLiteralImmOperand(MCInst &Inst, int64_t Val, bool ApplyMo
     return;
 
   case AMDGPU::OPERAND_REG_INLINE_C_V2INT16:
+  case AMDGPU::OPERAND_REG_INLINE_C_V2BF16:
   case AMDGPU::OPERAND_REG_INLINE_C_V2FP16:
   case AMDGPU::OPERAND_REG_INLINE_AC_V2INT16:
+  case AMDGPU::OPERAND_REG_INLINE_AC_V2BF16:
   case AMDGPU::OPERAND_REG_INLINE_AC_V2FP16: {
     assert(isSafeTruncation(Val, 16));
     assert(AMDGPU::isInlinableLiteral16(static_cast<int16_t>(Val),
diff --git a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp
index abfa4a3531e8e1..96a0168f37e405 100644
--- a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp
+++ b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp
@@ -521,8 +521,11 @@ void AMDGPUInstPrinter::printImmediateV216(uint32_t Imm, uint8_t OpType,
     if (printImmediateFloat32(Imm, STI, O))
       return;
     break;
+  case AMDGPU::OPERAND_REG_IMM_V2BF16:
   case AMDGPU::OPERAND_REG_IMM_V2FP16:
+  case AMDGPU::OPERAND_REG_INLINE_C_V2BF16:
   case AMDGPU::OPERAND_REG_INLINE_C_V2FP16:
+  case AMDGPU::OPERAND_REG_INLINE_AC_V2BF16:
   case AMDGPU::OPERAND_REG_INLINE_AC_V2FP16:
     if (isUInt<16>(Imm) &&
         printImmediateFloat16(static_cast<uint16_t>(Imm), STI, O))
@@ -792,17 +795,24 @@ void AMDGPUInstPrinter::printRegularOperand(const MCInst *MI, unsigned OpNo,
     case AMDGPU::OPERAND_REG_IMM_INT16:
       printImmediateInt16(Op.getImm(), STI, O);
       break;
+    case AMDGPU::OPERAND_REG_INLINE_C_BF16:
     case AMDGPU::OPERAND_REG_INLINE_C_FP16:
+    case AMDGPU::OPERAND_REG_INLINE_AC_BF16:
     case AMDGPU::OPERAND_REG_INLINE_AC_FP16:
+    case AMDGPU::OPERAND_REG_IMM_BF16:
     case AMDGPU::OPERAND_REG_IMM_FP16:
+    case AMDGPU::OPERAND_REG_IMM_BF16_DEFERRED:
     case AMDGPU::OPERAND_REG_IMM_FP16_DEFERRED:
       printImmediate16(Op.getImm(), STI, O);
       break;
     case AMDGPU::OPERAND_REG_IMM_V2INT16:
+    case AMDGPU::OPERAND_REG_IMM_V2BF16:
     case AMDGPU::OPERAND_REG_IMM_V2FP16:
     case AMDGPU::OPERAND_REG_INLINE_C_V2INT16:
     case AMDGPU::OPERAND_REG_INLINE_AC_V2INT16:
+    case AMDGPU::OPERAND_REG_INLINE_C_V2BF16:
     case AMDGPU::OPERAND_REG_INLINE_C_V2FP16:
+    case AMDGPU::OPERAND_REG_INLINE_AC_V2BF16:
     case AMDGPU::OPERAND_REG_INLINE_AC_V2FP16:
       printImmediateV216(Op.getImm(), OpTy, STI, O);
       break;
diff --git a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUMCCodeEmitter.cpp b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUMCCodeEmitter.cpp
index 11f5e456e8d348..9ec174ba56c242 100644
--- a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUMCCodeEmitter.cpp
+++ b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUMCCodeEmitter.cpp
@@ -276,9 +276,13 @@ AMDGPUMCCodeEmitter::getLitEncoding(const MCOperand &MO,
   case AMDGPU::OPERAND_REG_INLINE_C_INT16:
   case AMDGPU::OPERAND_REG_INLINE_AC_INT16:
     return getLit16IntEncoding(static_cast<uint16_t>(Imm), STI);
+  case AMDGPU::OPERAND_REG_IMM_BF16:
   case AMDGPU::OPERAND_REG_IMM_FP16:
+  case AMDGPU::OPERAND_REG_IMM_BF16_DEFERRED:
   case AMDGPU::OPERAND_REG_IMM_FP16_DEFERRED:
+  case AMDGPU::OPERAND_REG_INLINE_C_BF16:
   case AMDGPU::OPERAND_REG_INLINE_C_FP16:
+  case AMDGPU::OPERAND_REG_INLINE_AC_BF16:
   case AMDGPU::OPERAND_REG_INLINE_AC_FP16:
     // FIXME Is this correct? What do inline immediates do on SI for f16 src
     // which does not have f16 support?
@@ -288,8 +292,11 @@ AMDGPUMCCodeEmitter::getLitEncoding(const MCOperand &MO,
   case AMDGPU::OPERAND_REG_INLINE_AC_V2INT16:
     return AMDGPU::getInlineEncodingV2I16(static_cast<uint32_t>(Imm))
         .value_or(255);
+  case AMDGPU::OPERAND_REG_IMM_V2BF16:
   case AMDGPU::OPERAND_REG_IMM_V2FP16:
+  case AMDGPU::OPERAND_REG_INLINE_C_V2BF16:
   case AMDGPU::OPERAND_REG_INLINE_C_V2FP16:
+  case AMDGPU::OPERAND_REG_INLINE_AC_V2BF16:
   case AMDGPU::OPERAND_REG_INLINE_AC_V2FP16:
     return AMDGPU::getInlineEncodingV2F16(static_cast<uint32_t>(Imm))
         .value_or(255);
diff --git a/llvm/lib/Target/AMDGPU/SIDefines.h b/llvm/lib/Target/AMDGPU/SIDefines.h
index 19596d53b45328..66b997eb180613 100644
--- a/llvm/lib/Target/AMDGPU/SIDefines.h
+++ b/llvm/lib/Target/AMDGPU/SIDefines.h
@@ -196,9 +196,12 @@ enum OperandType : unsigned {
   OPERAND_REG_IMM_INT16,
   OPERAND_REG_IMM_FP32,
   OPERAND_REG_IMM_FP64,
+  OPERAND_REG_IMM_BF16,
   OPERAND_REG_IMM_FP16,
+  OPERAND_REG_IMM_BF16_DEFERRED,
   OPERAND_REG_IMM_FP16_DEFERRED,
   OPERAND_REG_IMM_FP32_DEFERRED,
+  OPERAND_REG_IMM_V2BF16,
   OPERAND_REG_IMM_V2FP16,
   OPERAND_REG_IMM_V2INT16,
   OPERAND_REG_IMM_V2INT32,
@@ -208,10 +211,12 @@ enum OperandType : unsigned {
   OPERAND_REG_INLINE_C_INT16,
   OPERAND_REG_INLINE_C_INT32,
   OPERAND_REG_INLINE_C_INT64,
+  OPERAND_REG_INLINE_C_BF16,
   OPERAND_REG_INLINE_C_FP16,
   OPERAND_REG_INLINE_C_FP32,
   OPERAND_REG_INLINE_C_FP64,
   OPERAND_REG_INLINE_C_V2INT16,
+  OPERAND_REG_INLINE_C_V2BF16,
   OPERAND_REG_INLINE_C_V2FP16,
   OPERAND_REG_INLINE_C_V2INT32,
   OPERAND_REG_INLINE_C_V2FP32,
@@ -226,10 +231,12 @@ enum OperandType : unsigned {
   /// Operands with an AccVGPR register or inline constant
   OPERAND_REG_INLINE_AC_INT16,
   OPERAND_REG_INLINE_AC_INT32,
+  OPERAND_REG_INLINE_AC_BF16,
   OPERAND_REG_INLINE_AC_FP16,
   OPERAND_REG_INLINE_AC_FP32,
   OPERAND_REG_INLINE_AC_FP64,
   OPERAND_REG_INLINE_AC_V2INT16,
+  OPERAND_REG_INLINE_AC_V2BF16,
   OPERAND_REG_INLINE_AC_V2FP16,
   OPERAND_REG_INLINE_AC_V2INT32,
   OPERAND_REG_INLINE_AC_V2FP32,
diff --git a/llvm/lib/Target/AMDGPU/SIInstrInfo.cpp b/llvm/lib/Target/AMDGPU/SIInstrInfo.cpp
index c7628bd354309c..fcb2a6f1f3d75d 100644
--- a/llvm/lib/Target/AMDGPU/SIInstrInfo.cpp
+++ b/llvm/lib/Target/AMDGPU/SIInstrInfo.cpp
@@ -4181,13 +4181,20 @@ bool SIInstrInfo::isInlineConstant(const MachineOperand &MO,
   case AMDGPU::OPERAND_REG_INLINE_C_V2INT16:
   case AMDGPU::OPERAND_REG_INLINE_AC_V2INT16:
     return AMDGPU::isInlinableLiteralV2I16(Imm);
+  case AMDGPU::OPERAND_REG_IMM_V2BF16:
   case AMDGPU::OPERAND_REG_IMM_V2FP16:
+  case AMDGPU::OPERAND_REG_INLINE_C_V2BF16:
   case AMDGPU::OPERAND_REG_INLINE_C_V2FP16:
+  case AMDGPU::OPERAND_REG_INLINE_AC_V2BF16:
   case AMDGPU::OPERAND_REG_INLINE_AC_V2FP16:
     return AMDGPU::isInlinableLiteralV2F16(Imm);
+  case AMDGPU::OPERAND_REG_IMM_BF16:
   case AMDGPU::OPERAND_REG_IMM_FP16:
+  case AMDGPU::OPERAND_REG_IMM_BF16_DEFERRED:
   case AMDGPU::OPERAND_REG_IMM_FP16_DEFERRED:
+  case AMDGPU::OPERAND_REG_INLINE_C_BF16:
   case AMDGPU::OPERAND_REG_INLINE_C_FP16:
+  case AMDGPU::OPERAND_REG_INLINE_AC_BF16:
   case AMDGPU::OPERAND_REG_INLINE_AC_FP16: {
     if (isInt<16>(Imm) || isUInt<16>(Imm)) {
       // A few special case instructions have 16-bit operands on subtargets
diff --git a/llvm/lib/Target/AMDGPU/SIInstrInfo.td b/llvm/lib/Target/AMDGPU/SIInstrInfo.td
index 22599773d562cb..b0daec4a350eb3 100644
--- a/llvm/lib/Target/AMDGPU/SIInstrInfo.td
+++ b/llvm/lib/Target/AMDGPU/SIInstrInfo.td
@@ -1497,20 +1497,17 @@ class getVOPSrc0ForVT<ValueType VT, bit IsTrue16, bit IsFake16 = 1> {
   RegisterOperand ret =
     !if(VT.isFP,
       !if(!eq(VT.Size, 64),
-         VSrc_f64,
-         !if(!or(!eq(VT.Value, f16.Value), !eq(VT.Value, bf16.Value)),
-            !if(IsTrue16,
-              !if(IsFake16, VSrcFake16_f16_Lo128, VSrcT_f16_Lo128),
-              VSrc_f16
-            ),
-            !if(!or(!eq(VT.Value, v2f16.Value), !eq(VT.Value, v2bf16.Value)),
-               VSrc_v2f16,
-               !if(!or(!eq(VT.Value, v4f16.Value), !eq(VT.Value, v4bf16.Value)),
-                 AVSrc_64,
-                 VSrc_f32
+          VSrc_f64,
+          !if(!eq(VT.Value, f16.Value),
+              !if(IsTrue16, !if(IsFake16, VSrcFake16_f16_Lo128, VSrcT_f16_Lo128), VSrc_f16),
+              !if(!eq(VT.Value, bf16.Value),
+                 !if(IsTrue16, !if(IsFake16, VSrcFake16_bf16_Lo128, VSrcT_bf16_Lo128), VSrc_bf16),
+                 !if(!or(!eq(VT.Value, v2f16.Value), !eq(VT.Value, v2bf16.Value)),
+                     !if(!eq(VT.Value, v2f16.Value), VSrc_v2f16, VSrc_v2bf16),
+                     !if(!or(!eq(VT.Value, v4f16.Value), !eq(VT.Value, v4bf16.Value)), AVSrc_64, VSrc_f32)
+                  )
                )
-            )
-         )
+           )
        ),
        !if(!eq(VT.Size, 64),
           VSrc_b64,
@@ -1569,16 +1566,20 @@ class getVOP3SrcForVT<ValueType VT, bit IsTrue16 = 0> {
         !if(!eq(VT.Value, i1.Value),
            SSrc_i1,
            !if(VT.isFP,
-              !if(!or(!eq(VT.Value, f16.Value), !eq(VT.Value, bf16.Value)),
-                 !if(IsTrue16, VSrcT_f16, VSrc_f16),
-                 !if(!or(!eq(VT.Value, v2f16.Value), !eq(VT.Value, v2bf16.Value)),
-                    VSrc_v2f16,
-                    !if(!or(!eq(VT.Value, v4f16.Value), !eq(VT.Value, v4bf16.Value)),
-                      AVSrc_64,
-                      VSrc_f32
-                    )
-                 )
-              ),
+               !if(!eq(VT.Value, f16.Value),
+                   !if(IsTrue16, VSrcT_f16, VSrc_f16),
+                   !if(!eq(VT.Value, bf16.Value),
+                       !if(IsTrue16, VSrcT_bf16, VSrc_bf16),
+                       !if(!eq(VT.Value, v2f16.Value),
+                           VSrc_v2f16,
+                           !if(!eq(VT.Value, v2bf16.Value),
+                               VSrc_v2bf16,
+                               !if(!or(!eq(VT.Value, v4f16.Value), !eq(VT.Value, v4bf16.Value)),
+                                   AVSrc_64, VSrc_f32)
+                           )
+                       )
+                   )
+               ),
               !if(!eq(VT.Value, i16.Value),
                  !if(IsTrue16, VSrcT_b16, VSrc_b16),
                  !if(!eq(VT.Value, v2i16.Value),
@@ -1597,8 +1598,13 @@ class getVOP3DPPSrcForVT<ValueType VT> {
   RegisterOperand ret =
       !if (!eq(VT.Value, i1.Value), SSrc_i1,
            !if (VT.isFP,
-                !if (!or(!eq(VT.Value, f16.Value), !eq(VT.Value, bf16.Value)), VCSrc_f16,
-                     !if (!or(!eq(VT.Value, v2f16.Value), !eq(VT.Value, v2bf16.Value)), VCSrc_v2f16, VCSrc_f32)),
+                !if(!eq(VT.Value, f16.Value), VCSrc_f16,
+                    !if(!eq(VT.Value, bf16.Value), VCSrc_bf16,
+                        !if(!eq(VT.Value, v2f16.Value), VCSrc_v2f16,
+                            !if(!eq(VT.Value, v2bf16.Value), VCSrc_v2bf16, VCSrc_f32)
+                        )
+                    )
+                ),
                 !if (!eq(VT.Value, i16.Value), VCSrc_b16,
                      !if (!eq(VT.Value, v2i16.Value), VCSrc_v2b16,
                           VCSrc_b32))));
@@ -2528,7 +2534,7 @@ def VOP_V2I16_F32_F32 : VOPProfile <[v2i16, f32, f32, untyped]>;
 def VOP_V2I16_I32_I32 : VOPProfile <[v2i16, i32, i32, untyped]>;
 
 def VOP_F16_V2F16_V2F16_F16 : VOPProfile <[f16, v2f16, v2f16, f16]>;
-def VOP_I16_V2I16_V2I16_I16 : VOPProfile <[i16, v2i16, v2i16, i16]>;
+def VOP_BF16_V2BF16_V2BF16_BF16: VOPProfile <[bf16, v2bf16, v2bf16, bf16]>;
 def VOP_F32_V2I16_V2I16_F32 : VOPProfile <[f32, v2i16, v2i16, f32]>;
 
 def VOP_F32_V2F16_V2F16_V2F16 : VOPProfile <[f32, v2f16, v2f16, v2f16]>;
diff --git a/llvm/lib/Target/AMDGPU/SIRegisterInfo.td b/llvm/lib/Target/AMDGPU/SIRegisterInfo.td
index aabb6c29062114..f24e65304d2052 100644
--- a/llvm/lib/Target/AMDGPU/SIRegisterInfo.td
+++ b/llvm/lib/Target/AMDGPU/SIRegisterInfo.td
@@ -1...
[truncated]

@shiltian
Copy link
Contributor Author

The patch is in a good shape now. I have made two other prime patches (#81674 and #81669). I'll rebase this one once they are landed.

This patch only changes one bf16 instruction with the necessary infrastructure for others. I'll update all of them once this patch is landed.

However, I don't think isInlinableLiteral16 works correctly because the encoding of the floating point inline literals are different for fp16 and bf16, but apparently for now it can only recognize fp16 encoding. This patch at least makes the asm printer work properly. #81345 is trying to fix it correctly, but that is unrelated to this patch.

@shiltian shiltian force-pushed the PR79369 branch 4 times, most recently from 47b96d2 to 7a517ee Compare February 14, 2024 03:12
@shiltian
Copy link
Contributor Author

shiltian commented Feb 16, 2024

I'll create a ticket about the decoder after this patch is landed.

@shiltian shiltian changed the title [RFC][AMDGPU] Use bf16 instead of i16 for bfloat [AMDGPU] Use bf16 instead of i16 for bfloat Feb 16, 2024
return IntImm;

// clang-format off
switch (Val) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this call getInlineEncodingV2BF16?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In theory, yes, but for now we can't because getInlineEncodingV2BF16 can't handle some cases (that I didn't dig yet). It looks like in the conversion between uint16_t and uint32_t that makes some test cases fail. IMO we need to unify them (not only for 16-bit) in one place instead of having almost the same logic at least in three places.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I really don't like having 4 different copies of this list of hex values (0x3f00, 0xbf00...).

@@ -2652,6 +2652,23 @@ bool isInlinableLiteral32(int32_t Literal, bool HasInv2Pi) {
(Val == 0x3e22f983 && HasInv2Pi);
}

bool isInlinableLiteralBF16(int16_t Literal, bool HasInv2Pi) {
if (!HasInv2Pi)
return false;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It does not change the behavior, but generally it shall only matter when you compare value to 0x3E22.

llvm/test/MC/AMDGPU/bf16_imm.s Show resolved Hide resolved
Currently it looks like we generally use `i16` to represent `bf16` in those tablegen
files. I'm not sure of the reason behind it. My wild guess is the type `bf16` was
not available when we enabled the support. This patch is trying to use `bf16`
directly in those tablegen files, aiming at fixing llvm#79369. Of course for llvm#79369
a workaround can be to treat all `INT16` variants as `BFloat` in `getOpFltSemantics`,
but it doesn't look good IMHO.

Since I'm fairly new to AMDGPU backend, I'd appreciate it if you can point out
where I don't understand correctly.
Copy link
Collaborator

@rampitec rampitec left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. There are definitely at least 2 outstanding problems, but it seems there are no regressions comparing to what we have now. LGTM.

@shiltian shiltian merged commit 46734aa into llvm:main Feb 16, 2024
2 of 4 checks passed
@shiltian shiltian deleted the PR79369 branch February 16, 2024 20:58
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backend:AMDGPU clang:codegen clang Clang issues not falling into any other category llvm:globalisel llvm:ir mc Machine (object) code
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[AMDGPU] Incorrect parsing of bf16 literals
5 participants