Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPIR-V] Add support for the SPIR-V extension SPV_INTEL_bfloat16_conversion #83443

Conversation

VyacheslavLevytskyy
Copy link
Contributor

This PR is to add support for the SPIR-V extension SPV_INTEL_bfloat16_conversion (https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/INTEL/SPV_INTEL_bfloat16_conversion.asciidoc) and OpenCL extension cl_intel_bfloat16_conversions (https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_bfloat16_conversions.html).

@llvmbot
Copy link
Collaborator

llvmbot commented Feb 29, 2024

@llvm/pr-subscribers-backend-spir-v

Author: Vyacheslav Levytskyy (VyacheslavLevytskyy)

Changes

This PR is to add support for the SPIR-V extension SPV_INTEL_bfloat16_conversion (https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/INTEL/SPV_INTEL_bfloat16_conversion.asciidoc) and OpenCL extension cl_intel_bfloat16_conversions (https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_bfloat16_conversions.html).


Patch is 22.55 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/83443.diff

13 Files Affected:

  • (modified) llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp (+49-8)
  • (modified) llvm/lib/Target/SPIRV/SPIRVBuiltins.td (+23-1)
  • (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp (+9)
  • (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h (+4)
  • (modified) llvm/lib/Target/SPIRV/SPIRVInstrInfo.td (+4)
  • (modified) llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp (+7)
  • (modified) llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp (+4)
  • (modified) llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td (+2)
  • (added) llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_conversion/bfloat16-conv-negative1.ll (+12)
  • (added) llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_conversion/bfloat16-conv-negative2.ll (+12)
  • (added) llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_conversion/bfloat16-conv-negative3.ll (+12)
  • (added) llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_conversion/bfloat16-conv-negative4.ll (+13)
  • (added) llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_conversion/bfloat16-conv.ll (+96)
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index c1bb27322443ff..5652ab5bcd9462 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -134,6 +134,7 @@ struct ConvertBuiltin {
   bool IsDestinationSigned;
   bool IsSaturated;
   bool IsRounded;
+  bool IsBfloat16;
   FPRoundingMode::FPRoundingMode RoundingMode;
 };
 
@@ -1986,6 +1987,8 @@ static bool generateConvertInst(const StringRef DemangledCall,
                     SPIRV::Decoration::FPRoundingMode,
                     {(unsigned)Builtin->RoundingMode});
 
+  std::string NeedExtMsg;              // no errors if empty
+  bool IsRightComponentsNumber = true; // check if input/output accepts vectors
   unsigned Opcode = SPIRV::OpNop;
   if (GR->isScalarOrVectorOfType(Call->Arguments[0], SPIRV::OpTypeInt)) {
     // Int -> ...
@@ -2000,23 +2003,61 @@ static bool generateConvertInst(const StringRef DemangledCall,
     } else if (GR->isScalarOrVectorOfType(Call->ReturnRegister,
                                           SPIRV::OpTypeFloat)) {
       // Int -> Float
-      bool IsSourceSigned =
-          DemangledCall[DemangledCall.find_first_of('(') + 1] != 'u';
-      Opcode = IsSourceSigned ? SPIRV::OpConvertSToF : SPIRV::OpConvertUToF;
+      if (Builtin->IsBfloat16) {
+        const auto *ST = static_cast<const SPIRVSubtarget *>(
+            &MIRBuilder.getMF().getSubtarget());
+        if (!ST->canUseExtension(
+                SPIRV::Extension::SPV_INTEL_bfloat16_conversion))
+          NeedExtMsg = "SPV_INTEL_bfloat16_conversion";
+        IsRightComponentsNumber =
+            GR->getScalarOrVectorComponentCount(Call->Arguments[0]) ==
+            GR->getScalarOrVectorComponentCount(Call->ReturnRegister);
+        Opcode = SPIRV::OpConvertBF16ToFINTEL;
+      } else {
+        bool IsSourceSigned =
+            DemangledCall[DemangledCall.find_first_of('(') + 1] != 'u';
+        Opcode = IsSourceSigned ? SPIRV::OpConvertSToF : SPIRV::OpConvertUToF;
+      }
     }
   } else if (GR->isScalarOrVectorOfType(Call->Arguments[0],
                                         SPIRV::OpTypeFloat)) {
     // Float -> ...
-    if (GR->isScalarOrVectorOfType(Call->ReturnRegister, SPIRV::OpTypeInt))
+    if (GR->isScalarOrVectorOfType(Call->ReturnRegister, SPIRV::OpTypeInt)) {
       // Float -> Int
-      Opcode = Builtin->IsDestinationSigned ? SPIRV::OpConvertFToS
-                                            : SPIRV::OpConvertFToU;
-    else if (GR->isScalarOrVectorOfType(Call->ReturnRegister,
-                                        SPIRV::OpTypeFloat))
+      if (Builtin->IsBfloat16) {
+        const auto *ST = static_cast<const SPIRVSubtarget *>(
+            &MIRBuilder.getMF().getSubtarget());
+        if (!ST->canUseExtension(
+                SPIRV::Extension::SPV_INTEL_bfloat16_conversion))
+          NeedExtMsg = "SPV_INTEL_bfloat16_conversion";
+        IsRightComponentsNumber =
+            GR->getScalarOrVectorComponentCount(Call->Arguments[0]) ==
+            GR->getScalarOrVectorComponentCount(Call->ReturnRegister);
+        Opcode = SPIRV::OpConvertFToBF16INTEL;
+      } else {
+        Opcode = Builtin->IsDestinationSigned ? SPIRV::OpConvertFToS
+                                              : SPIRV::OpConvertFToU;
+      }
+    } else if (GR->isScalarOrVectorOfType(Call->ReturnRegister,
+                                          SPIRV::OpTypeFloat)) {
       // Float -> Float
       Opcode = SPIRV::OpFConvert;
+    }
   }
 
+  if (!NeedExtMsg.empty()) {
+    std::string DiagMsg = std::string(Builtin->Name) +
+                          ": the builtin requires the following SPIR-V "
+                          "extension: " +
+                          NeedExtMsg;
+    report_fatal_error(DiagMsg.c_str(), false);
+  }
+  if (!IsRightComponentsNumber) {
+    std::string DiagMsg =
+        std::string(Builtin->Name) +
+        ": result and argument must have the same number of components";
+    report_fatal_error(DiagMsg.c_str(), false);
+  }
   assert(Opcode != SPIRV::OpNop &&
          "Conversion between the types not implemented!");
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
index 28a63b93b43b6e..eb26f70b1861f2 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
@@ -1177,6 +1177,8 @@ class ConvertBuiltin<string name, InstructionSet set> {
   bit IsDestinationSigned = !eq(!find(name, "convert_u"), -1);
   bit IsSaturated = !not(!eq(!find(name, "_sat"), -1));
   bit IsRounded = !not(!eq(!find(name, "_rt"), -1));
+  bit IsBfloat16 = !or(!not(!eq(!find(name, "BF16"), -1)),
+                       !not(!eq(!find(name, "bfloat16"), -1)));
   FPRoundingMode RoundingMode = !cond(!not(!eq(!find(name, "_rte"), -1)) : RTE,
                                   !not(!eq(!find(name, "_rtz"), -1)) : RTZ,
                                   !not(!eq(!find(name, "_rtp"), -1)) : RTP,
@@ -1187,7 +1189,8 @@ class ConvertBuiltin<string name, InstructionSet set> {
 // Table gathering all the convert builtins.
 def ConvertBuiltins : GenericTable {
   let FilterClass = "ConvertBuiltin";
-  let Fields = ["Name", "Set", "IsDestinationSigned", "IsSaturated", "IsRounded", "RoundingMode"];
+  let Fields = ["Name", "Set", "IsDestinationSigned", "IsSaturated",
+                "IsRounded", "IsBfloat16", "RoundingMode"];
   string TypeOf_Set = "InstructionSet";
   string TypeOf_RoundingMode = "FPRoundingMode";
 }
@@ -1229,6 +1232,25 @@ defm : DemangledConvertBuiltin<"convert_long", OpenCL_std>;
 defm : DemangledConvertBuiltin<"convert_ulong", OpenCL_std>;
 defm : DemangledConvertBuiltin<"convert_float", OpenCL_std>;
 
+// cl_intel_bfloat16_conversions / SPV_INTEL_bfloat16_conversion
+// Multiclass used to define at the same time both a demangled builtin records
+// and a corresponding convert builtin records.
+multiclass DemangledBF16ConvertBuiltin<string name1, string name2> {
+  // Create records for scalar and vector conversions.
+  foreach i = ["", "2", "3", "4", "8", "16"] in {
+    def : DemangledBuiltin<!strconcat("intel_convert_", name1, i, name2, i), OpenCL_std, Convert, 1, 1>;
+    def : ConvertBuiltin<!strconcat("intel_convert_", name1, i, name2, i), OpenCL_std>;
+  }
+}
+
+defm : DemangledBF16ConvertBuiltin<"bfloat16", "_as_ushort">;
+defm : DemangledBF16ConvertBuiltin<"as_bfloat16", "_float">;
+
+foreach conv = ["FToBF16INTEL", "BF16ToFINTEL"] in {
+  def : DemangledBuiltin<!strconcat("__spirv_Convert", conv), OpenCL_std, Convert, 1, 1>;
+  def : ConvertBuiltin<!strconcat("__spirv_Convert", conv), OpenCL_std>;
+}
+
 //===----------------------------------------------------------------------===//
 // Class defining a vector data load/store builtin record used for lowering
 // into OpExtInst instruction.
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index a1cb630f1aa477..21cba98ca8b6b7 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -819,6 +819,15 @@ bool SPIRVGlobalRegistry::isScalarOrVectorOfType(Register VReg,
   return false;
 }
 
+unsigned
+SPIRVGlobalRegistry::getScalarOrVectorComponentCount(Register VReg) const {
+  if (SPIRVType *Type = getSPIRVTypeForVReg(VReg))
+    return Type->getOpcode() == SPIRV::OpTypeVector
+               ? static_cast<unsigned>(Type->getOperand(2).getImm())
+               : 1;
+  return 0;
+}
+
 unsigned
 SPIRVGlobalRegistry::getScalarOrVectorBitWidth(const SPIRVType *Type) const {
   assert(Type && "Invalid Type pointer");
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 792a00786f0aaf..965d5b848fcb87 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -197,6 +197,10 @@ class SPIRVGlobalRegistry {
   // opcode (e.g. OpTypeBool, or OpTypeVector %x 4, where %x is OpTypeBool).
   bool isScalarOrVectorOfType(Register VReg, unsigned TypeOpcode) const;
 
+  // Return number of elements in a vector if the given VReg is associated with
+  // a vector type. Return 1 for a scalar type, and 0 for a missing type.
+  unsigned getScalarOrVectorComponentCount(Register VReg) const;
+
   // For vectors or scalars of ints/floats, return the scalar type's bitwidth.
   unsigned getScalarOrVectorBitWidth(const SPIRVType *Type) const;
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
index fe8c909236cde3..99c57dac4141d8 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
@@ -443,6 +443,10 @@ def OpBitcast : UnOp<"OpBitcast", 124>;
 def OpPtrCastToCrossWorkgroupINTEL : UnOp<"OpPtrCastToCrossWorkgroupINTEL", 5934>;
 def OpCrossWorkgroupCastToPtrINTEL : UnOp<"OpCrossWorkgroupCastToPtrINTEL", 5938>;
 
+// SPV_INTEL_bfloat16_conversion
+def OpConvertFToBF16INTEL : UnOp<"OpConvertFToBF16INTEL", 6116>;
+def OpConvertBF16ToFINTEL : UnOp<"OpConvertBF16ToFINTEL", 6117>;
+
 // 3.42.12 Composite Instructions
 
 def OpVectorExtractDynamic: Op<77, (outs ID:$res), (ins TYPE:$type, vID:$vec, ID:$idx),
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index ac3d6b362d350b..b7be7ffd3f0c61 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1110,6 +1110,13 @@ void addInstrRequirements(const MachineInstr &MI,
   case SPIRV::OpAtomicFMaxEXT:
     AddAtomicFloatRequirements(MI, Reqs, ST);
     break;
+  case SPIRV::OpConvertBF16ToFINTEL:
+  case SPIRV::OpConvertFToBF16INTEL:
+    if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_bfloat16_conversion)) {
+      Reqs.addExtension(SPIRV::Extension::SPV_INTEL_bfloat16_conversion);
+      Reqs.addCapability(SPIRV::Capability::BFloat16ConversionINTEL);
+    }
+    break;
   case SPIRV::OpVariableLengthArrayINTEL:
   case SPIRV::OpSaveMemoryINTEL:
   case SPIRV::OpRestoreMemoryINTEL:
diff --git a/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp b/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp
index 0e8952dc6a9c9f..b866def589853f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp
@@ -81,6 +81,10 @@ cl::list<SPIRV::Extension::Extension> Extensions(
             "Allows to use the LinkOnceODR linkage type that is to let "
             "a function or global variable to be merged with other functions "
             "or global variables of the same name when linkage occurs."),
+        clEnumValN(SPIRV::Extension::SPV_INTEL_bfloat16_conversion,
+                   "SPV_INTEL_bfloat16_conversion",
+                   "Adds instructions to convert between single-precision "
+                   "32-bit floating-point values and 16-bit bfloat16 values."),
         clEnumValN(SPIRV::Extension::SPV_KHR_subgroup_rotate,
                    "SPV_KHR_subgroup_rotate",
                    "Adds a new instruction that enables rotating values across "
diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
index 211c22340eb82c..8dbbd9049844c8 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
+++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
@@ -297,6 +297,7 @@ defm SPV_INTEL_fpga_argument_interfaces : ExtensionOperand<102>;
 defm SPV_INTEL_optnone : ExtensionOperand<103>;
 defm SPV_INTEL_function_pointers : ExtensionOperand<104>;
 defm SPV_INTEL_variable_length_array : ExtensionOperand<105>;
+defm SPV_INTEL_bfloat16_conversion : ExtensionOperand<106>;
 
 //===----------------------------------------------------------------------===//
 // Multiclass used to define Capabilities enum values and at the same time
@@ -466,6 +467,7 @@ defm AtomicFloat64MinMaxEXT : CapabilityOperand<5613, 0, 0, [SPV_EXT_shader_atom
 defm VariableLengthArrayINTEL : CapabilityOperand<5817, 0, 0, [SPV_INTEL_variable_length_array], []>;
 defm GroupUniformArithmeticKHR : CapabilityOperand<6400, 0, 0, [SPV_KHR_uniform_group_instructions], []>;
 defm USMStorageClassesINTEL : CapabilityOperand<5935, 0, 0, [SPV_INTEL_usm_storage_classes], [Kernel]>;
+defm BFloat16ConversionINTEL : CapabilityOperand<6115, 0, 0, [SPV_INTEL_bfloat16_conversion], []>;
 
 //===----------------------------------------------------------------------===//
 // Multiclass used to define SourceLanguage enum values and at the same time
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_conversion/bfloat16-conv-negative1.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_conversion/bfloat16-conv-negative1.ll
new file mode 100644
index 00000000000000..2f3c859db346df
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_conversion/bfloat16-conv-negative1.ll
@@ -0,0 +1,12 @@
+; RUN: not llc -O0 -mtriple=spirv32-unknown-unknown --spirv-extensions=SPV_INTEL_bfloat16_conversion %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
+; CHECK-ERROR: result and argument must have the same number of components
+
+target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
+target triple = "spir64-unknown-unknown"
+
+define spir_func void @test(<8 x float> %in) {
+  %res = tail call spir_func i16 @_Z27__spirv_ConvertFToBF16INTELDv8_f(<8 x float> %in)
+  ret void
+}
+
+declare spir_func i16 @_Z27__spirv_ConvertFToBF16INTELDv8_f(<8 x float>)
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_conversion/bfloat16-conv-negative2.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_conversion/bfloat16-conv-negative2.ll
new file mode 100644
index 00000000000000..c02d50cfab21d8
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_conversion/bfloat16-conv-negative2.ll
@@ -0,0 +1,12 @@
+; RUN: not llc -O0 -mtriple=spirv32-unknown-unknown --spirv-extensions=SPV_INTEL_bfloat16_conversion %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
+; CHECK-ERROR: result and argument must have the same number of components
+
+target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
+target triple = "spir64-unknown-unknown"
+
+define spir_func void @test(<8 x float> %in) {
+  %res = tail call spir_func <4 x i16> @_Z27__spirv_ConvertFToBF16INTELDv8_f(<8 x float> %in)
+  ret void
+}
+
+declare spir_func <4 x i16> @_Z27__spirv_ConvertFToBF16INTELDv8_f(<8 x float>)
\ No newline at end of file
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_conversion/bfloat16-conv-negative3.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_conversion/bfloat16-conv-negative3.ll
new file mode 100644
index 00000000000000..20a8042ad9c297
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_conversion/bfloat16-conv-negative3.ll
@@ -0,0 +1,12 @@
+; RUN: not llc -O0 -mtriple=spirv32-unknown-unknown --spirv-extensions=SPV_INTEL_bfloat16_conversion %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
+; CHECK-ERROR: result and argument must have the same number of components
+
+target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
+target triple = "spir64-unknown-unknown"
+
+define spir_func void @test(<8 x i16> %in) {
+  %res = tail call spir_func <4 x float> @_Z27__spirv_ConvertBF16ToFINTELDv8_s(<8 x i16> %in)
+  ret void
+}
+
+declare spir_func <4 x float> @_Z27__spirv_ConvertBF16ToFINTELDv8_s(<8 x i16>)
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_conversion/bfloat16-conv-negative4.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_conversion/bfloat16-conv-negative4.ll
new file mode 100644
index 00000000000000..87d26472a4eeb6
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_conversion/bfloat16-conv-negative4.ll
@@ -0,0 +1,13 @@
+; RUN: not llc -O0 -mtriple=spirv32-unknown-unknown --spirv-extensions=SPV_INTEL_bfloat16_conversion %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
+; CHECK-ERROR: result and argument must have the same number of components
+
+target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
+target triple = "spir64-unknown-unknown"
+
+define spir_func void @test(<8 x i16> %in) {
+  %res = tail call spir_func float @_Z27__spirv_ConvertBF16ToFINTELDv8_s(<8 x i16> %in)
+  ret void
+}
+
+declare spir_func float @_Z27__spirv_ConvertBF16ToFINTELDv8_s(<8 x i16>)
+
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_conversion/bfloat16-conv.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_conversion/bfloat16-conv.ll
new file mode 100644
index 00000000000000..2bd59b22322ffd
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_conversion/bfloat16-conv.ll
@@ -0,0 +1,96 @@
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown --spirv-extensions=SPV_INTEL_bfloat16_conversion %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-extensions=SPV_INTEL_bfloat16_conversion %s -o - -filetype=obj | spirv-val %}
+
+; RUN: not llc -O0 -mtriple=spirv32-unknown-unknown %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
+; CHECK-ERROR: the builtin requires the following SPIR-V extension: SPV_INTEL_bfloat16_conversion
+
+; CHECK: OpCapability BFloat16ConversionINTEL
+; CHECK: OpExtension "SPV_INTEL_bfloat16_conversion"
+
+; CHECK-DAG: %[[VoidTy:.*]] = OpTypeVoid
+; CHECK-DAG: %[[Int16Ty:.*]] = OpTypeInt 16 0
+; CHECK-DAG: %[[FP32Ty:.*]] = OpTypeFloat 32
+; CHECK-DAG: %[[VecFloat2:.*]] = OpTypeVector %[[FP32Ty]] 2
+; CHECK-DAG: %[[VecInt162:.*]] = OpTypeVector %[[Int16Ty]] 2
+; CHECK-DAG: %[[VecFloat3:.*]] = OpTypeVector %[[FP32Ty]] 3
+; CHECK-DAG: %[[VecInt163:.*]] = OpTypeVector %[[Int16Ty]] 3
+; CHECK-DAG: %[[VecFloat4:.*]] = OpTypeVector %[[FP32Ty]] 4
+; CHECK-DAG: %[[VecInt164:.*]] = OpTypeVector %[[Int16Ty]] 4
+; CHECK-DAG: %[[VecFloat8:.*]] = OpTypeVector %[[FP32Ty]] 8
+; CHECK-DAG: %[[VecInt168:.*]] = OpTypeVector %[[Int16Ty]] 8
+; CHECK-DAG: %[[VecFloat16:.*]] = OpTypeVector %[[FP32Ty]] 16
+; CHECK-DAG: %[[VecInt1616:.*]] = OpTypeVector %[[Int16Ty]] 16
+; CHECK-DAG: %[[IntConstId:.*]] = OpConstant %[[Int16Ty]] 67
+; CHECK-DAG: %[[FloatConstId:.*]] = OpConstant %[[FP32Ty]] 1.5
+
+; CHECK: OpFunction %[[VoidTy]]
+; CHECK: %[[FP32ValId:.*]] = OpFunctionParameter %[[FP32Ty]]
+; CHECK: %[[FP32v8ValId:.*]] = OpFunctionParameter %[[VecFloat8]]
+
+; CHECK: %[[Int16ValId:.*]] = OpConvertFToBF16INTEL %[[Int16Ty]] %[[FP32ValId]]
+; CHECK: OpConvertBF16ToFINTEL %[[FP32Ty]] %[[Int16ValId]]
+; CHECK: %[[Int16v8ValId:.*]] = OpConvertFToBF16INTEL %[[VecInt168]] %[[FP32v8ValId]]
+; CHECK: OpConvertBF16ToFINTEL %[[VecFloat8]] %[[Int16v8ValId]]
+; CHECK: OpConvertFToBF16INTEL %[[Int16Ty]] %[[FloatConstId]]
+; CHECK: OpConvertBF16ToFINTEL %[[FP32Ty]] %[[IntConstId]]
+
+; CHECK: OpConvertFToBF16INTEL %[[Int16Ty]]
+; CHECK: OpConvertFToBF16INTEL %[[VecInt162]]
+; CHECK: OpConvertFToBF16INTEL %[[VecInt163]]
+; CHECK: OpConvertFToBF16INTEL %[[VecInt164]]
+; CHECK: OpConvertFToBF16INTEL %[[VecInt168]]
+; CHECK: OpConvertFToBF16INTEL %[[VecInt1616]]
+; CHECK: OpConvertBF16ToFINTEL %[[FP32Ty]]
+; CHECK: OpConvertBF16ToFINTEL %[[VecFloat2]]
+; CHECK: OpConvertBF16ToFINTEL %[[VecFloat3]]
+; CHECK: OpConvertBF16ToFINTEL %[[VecFloat4]]
+; CHECK: OpConvertBF16ToFINTEL %[[VecFloat8]]
+; CHECK: OpConvertBF16ToFINTEL %[[VecFloat16]]
+
+target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
+target triple = "spir64-unknown-unknown"
+
+define spir_func void @test(float %a, <8 x float> %in) {
+  %res1 = tail call spir_func zeroext i16 @_Z27__spirv_ConvertFToBF16INTELf(float %a)
+  %res2 = tail call spir_func float @_Z27__spirv_ConvertBF16ToFINTELs(i16 zeroext %res1)
+  %res3 = tail call spir_func <8 x i16> @_Z27__spirv_ConvertFToBF16INTELDv8_f(<8 x float> %in)
+  %res4 = tail call spir_func <8 x float> @_Z27__spirv_ConvertBF16ToFINTELDv8_s(<8 x i16> %res3)
+  %res5 = tail call spir_func zeroext i16 @_Z27__spirv_ConvertFToBF16INTELf(float 1.500000e+00)
+  %res6 = tail call spir_func float @_Z27__spirv_ConvertBF16ToFINTELs(i16 67)
+  ret void
+}
+
+declare spir_func zeroext i16 @_Z27__spirv_ConvertFToBF16INTELf(float)
+declare spir_f...
[truncated]

Copy link
Member

@michalpaszkowski michalpaszkowski left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you! LGTM!

@VyacheslavLevytskyy VyacheslavLevytskyy merged commit 8f30b62 into llvm:main Mar 4, 2024
3 of 5 checks passed
@MaskRay
Copy link
Member

MaskRay commented Mar 5, 2024

This commit is associated with a users.noreply.github.com email address. Per
https://discourse.llvm.org/t/hidden-emails-on-github-should-we-do-something-about-it/74223/32 , proper email addresses are preferred.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants