diff --git a/llvm/include/llvm/CodeGenTypes/LowLevelType.h b/llvm/include/llvm/CodeGenTypes/LowLevelType.h index d8e0848aff84d..0aec3d2537d9c 100644 --- a/llvm/include/llvm/CodeGenTypes/LowLevelType.h +++ b/llvm/include/llvm/CodeGenTypes/LowLevelType.h @@ -30,6 +30,7 @@ #include "llvm/CodeGenTypes/MachineValueType.h" #include "llvm/Support/Compiler.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" #include namespace llvm { @@ -41,7 +42,13 @@ class LLT { public: /// Get a low-level scalar or aggregate "bag of bits". static constexpr LLT scalar(unsigned SizeInBits) { - return LLT{/*isPointer=*/false, /*isVector=*/false, /*isScalar=*/true, + return LLT{/*isPointer=*/false, /*isVector=*/false, /*isScalar=*/true, /*isBfloat=*/false, + ElementCount::getFixed(0), SizeInBits, + /*AddressSpace=*/0}; + } + + static constexpr LLT scalar_bfloat(unsigned SizeInBits) { + return LLT{/*isPointer=*/false, /*isVector=*/false, /*isScalar=*/true, /*isBfloat=*/true, ElementCount::getFixed(0), SizeInBits, /*AddressSpace=*/0}; } @@ -49,7 +56,7 @@ class LLT { /// Get a low-level token; just a scalar with zero bits (or no size). static constexpr LLT token() { return LLT{/*isPointer=*/false, /*isVector=*/false, - /*isScalar=*/true, ElementCount::getFixed(0), + /*isScalar=*/true, /*isBfloat=*/false, ElementCount::getFixed(0), /*SizeInBits=*/0, /*AddressSpace=*/0}; } @@ -57,14 +64,14 @@ class LLT { /// Get a low-level pointer in the given address space. static constexpr LLT pointer(unsigned AddressSpace, unsigned SizeInBits) { assert(SizeInBits > 0 && "invalid pointer size"); - return LLT{/*isPointer=*/true, /*isVector=*/false, /*isScalar=*/false, + return LLT{/*isPointer=*/true, /*isVector=*/false, /*isScalar=*/false, /*isBfloat=*/false, ElementCount::getFixed(0), SizeInBits, AddressSpace}; } /// Get a low-level vector of some number of elements and element width. static constexpr LLT vector(ElementCount EC, unsigned ScalarSizeInBits) { assert(!EC.isScalar() && "invalid number of vector elements"); - return LLT{/*isPointer=*/false, /*isVector=*/true, /*isScalar=*/false, + return LLT{/*isPointer=*/false, /*isVector=*/true, /*isScalar=*/false, /*isBfloat=*/false, EC, ScalarSizeInBits, /*AddressSpace=*/0}; } @@ -75,11 +82,17 @@ class LLT { return LLT{ScalarTy.isPointer(), /*isVector=*/true, /*isScalar=*/false, + /*isBfloat=*/false, EC, ScalarTy.getSizeInBits().getFixedValue(), ScalarTy.isPointer() ? ScalarTy.getAddressSpace() : 0}; } + // Get a 16-bit brain float value. + static constexpr LLT bfloat16() { + return scalar_bfloat(16); + } + /// Get a 16-bit IEEE half value. /// TODO: Add IEEE semantics to type - This currently returns a simple `scalar(16)`. static constexpr LLT float16() { @@ -132,14 +145,14 @@ class LLT { return scalarOrVector(EC, LLT::scalar(static_cast(ScalarSize))); } - explicit constexpr LLT(bool isPointer, bool isVector, bool isScalar, + explicit constexpr LLT(bool isPointer, bool isVector, bool isScalar, bool isBfloat, ElementCount EC, uint64_t SizeInBits, unsigned AddressSpace) : LLT() { - init(isPointer, isVector, isScalar, EC, SizeInBits, AddressSpace); + init(isPointer, isVector, isScalar, isBfloat, EC, SizeInBits, AddressSpace); } explicit constexpr LLT() - : IsScalar(false), IsPointer(false), IsVector(false), RawData(0) {} + : IsScalar(false), IsPointer(false), IsVector(false), IsBfloat(false), RawData(0) {} LLVM_ABI explicit LLT(MVT VT); @@ -154,6 +167,7 @@ class LLT { constexpr bool isPointerOrPointerVector() const { return IsPointer && isValid(); } + constexpr bool isBfloat() const { return IsBfloat; } /// Returns the number of elements in a vector LLT. Must only be called on /// vector types. @@ -304,32 +318,35 @@ class LLT { /// isScalar : 1 /// isPointer : 1 /// isVector : 1 - /// with 61 bits remaining for Kind-specific data, packed in bitfields + /// isBfloat : 1 + /// with 60 bits remaining for Kind-specific data, packed in bitfields /// as described below. As there isn't a simple portable way to pack bits /// into bitfields, here the different fields in the packed structure is /// described in static const *Field variables. Each of these variables /// is a 2-element array, with the first element describing the bitfield size /// and the second element describing the bitfield offset. /// - /// +--------+---------+--------+----------+----------------------+ - /// |isScalar|isPointer|isVector| RawData |Notes | - /// +--------+---------+--------+----------+----------------------+ - /// | 0 | 0 | 0 | 0 |Invalid | - /// +--------+---------+--------+----------+----------------------+ - /// | 0 | 0 | 1 | 0 |Tombstone Key | - /// +--------+---------+--------+----------+----------------------+ - /// | 0 | 1 | 0 | 0 |Empty Key | - /// +--------+---------+--------+----------+----------------------+ - /// | 1 | 0 | 0 | 0 |Token | - /// +--------+---------+--------+----------+----------------------+ - /// | 1 | 0 | 0 | non-zero |Scalar | - /// +--------+---------+--------+----------+----------------------+ - /// | 0 | 1 | 0 | non-zero |Pointer | - /// +--------+---------+--------+----------+----------------------+ - /// | 0 | 0 | 1 | non-zero |Vector of non-pointer | - /// +--------+---------+--------+----------+----------------------+ - /// | 0 | 1 | 1 | non-zero |Vector of pointer | - /// +--------+---------+--------+----------+----------------------+ + /// +--------+---------+--------+----------+----------+----------------------+ + /// |isScalar|isPointer|isVector| isBfloat | RawData |Notes | + /// +--------+---------+--------+----------+----------+----------------------+ + /// | 0 | 0 | 0 | 0 | 0 |Invalid | + /// +--------+---------+--------+----------+----------+----------------------+ + /// | 0 | 0 | 1 | 0 | 0 |Tombstone Key | + /// +--------+---------+--------+----------+----------+----------------------+ + /// | 0 | 1 | 0 | 0 | 0 |Empty Key | + /// +--------+---------+--------+----------+----------+----------------------+ + /// | 1 | 0 | 0 | 0 | 0 |Token | + /// +--------+---------+--------+----------+----------+----------------------+ + /// | 1 | 0 | 0 | 0 | non-zero |Scalar | + /// +--------+---------+--------+----------+----------+----------------------+ + /// | 1 | 0 | 0 | 1 | non-zero |Scalar (Bfloat 16) | + /// +--------+---------+--------+----------+----------+----------------------+ + /// | 0 | 1 | 0 | 0 | non-zero |Pointer | + /// +--------+---------+--------+----------+----------+----------------------+ + /// | 0 | 0 | 1 | 0 | non-zero |Vector of non-pointer | + /// +--------+---------+--------+----------+----------+----------------------+ + /// | 0 | 1 | 1 0 | non-zero |Vector of pointer | + /// +--------+---------+--------+----------+----------+----------------------+ /// /// Everything else is reserved. typedef int BitFieldInfo[2]; @@ -340,12 +357,12 @@ class LLT { /// valid encodings, SizeInBits/SizeOfElement must be larger than 0. /// * Non-pointer scalar (isPointer == 0 && isVector == 0): /// SizeInBits: 32; - static const constexpr BitFieldInfo ScalarSizeFieldInfo{32, 29}; + static const constexpr BitFieldInfo ScalarSizeFieldInfo{32, 28}; /// * Pointer (isPointer == 1 && isVector == 0): /// SizeInBits: 16; /// AddressSpace: 24; - static const constexpr BitFieldInfo PointerSizeFieldInfo{16, 45}; - static const constexpr BitFieldInfo PointerAddressSpaceFieldInfo{24, 21}; + static const constexpr BitFieldInfo PointerSizeFieldInfo{16, 44}; + static const constexpr BitFieldInfo PointerAddressSpaceFieldInfo{24, 20}; /// * Vector-of-non-pointer (isPointer == 0 && isVector == 1): /// NumElements: 16; /// SizeOfElement: 32; @@ -361,7 +378,8 @@ class LLT { uint64_t IsScalar : 1; uint64_t IsPointer : 1; uint64_t IsVector : 1; - uint64_t RawData : 61; + uint64_t IsBfloat : 1; + uint64_t RawData : 60; static constexpr uint64_t getMask(const BitFieldInfo FieldInfo) { const int FieldSizeInBits = FieldInfo[0]; @@ -381,7 +399,7 @@ class LLT { return getMask(FieldInfo) & (RawData >> FieldInfo[1]); } - constexpr void init(bool IsPointer, bool IsVector, bool IsScalar, + constexpr void init(bool IsPointer, bool IsVector, bool IsScalar, bool IsBfloat, ElementCount EC, uint64_t SizeInBits, unsigned AddressSpace) { assert(SizeInBits <= std::numeric_limits::max() && @@ -389,6 +407,7 @@ class LLT { this->IsPointer = IsPointer; this->IsVector = IsVector; this->IsScalar = IsScalar; + this->IsBfloat = IsBfloat; if (IsPointer) { RawData = maskAndShift(SizeInBits, PointerSizeFieldInfo) | maskAndShift(AddressSpace, PointerAddressSpaceFieldInfo); @@ -403,7 +422,7 @@ class LLT { public: constexpr uint64_t getUniqueRAWLLTData() const { - return ((uint64_t)RawData) << 3 | ((uint64_t)IsScalar) << 2 | + return ((uint64_t)RawData) << 4 | ((uint64_t)IsBfloat) << 3 | ((uint64_t)IsScalar) << 2 | ((uint64_t)IsPointer) << 1 | ((uint64_t)IsVector); } }; diff --git a/llvm/lib/CodeGen/LowLevelTypeUtils.cpp b/llvm/lib/CodeGen/LowLevelTypeUtils.cpp index 936c9fbb2fff0..03a97eacb049e 100644 --- a/llvm/lib/CodeGen/LowLevelTypeUtils.cpp +++ b/llvm/lib/CodeGen/LowLevelTypeUtils.cpp @@ -36,6 +36,9 @@ LLT llvm::getLLTForType(Type &Ty, const DataLayout &DL) { // concerned. auto SizeInBits = DL.getTypeSizeInBits(&Ty); assert(SizeInBits != 0 && "invalid zero-sized type"); + if (Ty.isBFloatTy()) { + return LLT::scalar_bfloat(SizeInBits); + } return LLT::scalar(SizeInBits); } diff --git a/llvm/lib/CodeGenTypes/LowLevelType.cpp b/llvm/lib/CodeGenTypes/LowLevelType.cpp index 4785f2652b00e..8828135fcbb27 100644 --- a/llvm/lib/CodeGenTypes/LowLevelType.cpp +++ b/llvm/lib/CodeGenTypes/LowLevelType.cpp @@ -19,18 +19,21 @@ using namespace llvm; LLT::LLT(MVT VT) { if (VT.isVector()) { bool asVector = VT.getVectorMinNumElements() > 1 || VT.isScalableVector(); - init(/*IsPointer=*/false, asVector, /*IsScalar=*/!asVector, + init(/*IsPointer=*/false, asVector, /*IsScalar=*/!asVector, /*isBfloat=*/false, VT.getVectorElementCount(), VT.getVectorElementType().getSizeInBits(), /*AddressSpace=*/0); } else if (VT.isValid() && !VT.isScalableTargetExtVT()) { // Aggregates are no different from real scalars as far as GlobalISel is // concerned. - init(/*IsPointer=*/false, /*IsVector=*/false, /*IsScalar=*/true, + MVT ElemVT = VT.getVectorElementType(); + bool isElemBfloat = (ElemVT == MVT::bf16); + init(/*IsPointer=*/false, /*IsVector=*/false, /*IsScalar=*/true, /*isBfloat=*/false, ElementCount::getFixed(0), VT.getSizeInBits(), /*AddressSpace=*/0); } else { IsScalar = false; IsPointer = false; IsVector = false; + IsBfloat = false; RawData = 0; } } diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp index 8039cf0c432fa..dc00d97e2d8d2 100644 --- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp @@ -1262,6 +1262,17 @@ void addInstrRequirements(const MachineInstr &MI, break; } case SPIRV::OpTypeFloat: { + const MachineBasicBlock *MBB = MI.getParent(); + const MachineFunction *MF = MBB->getParent(); + const MachineRegisterInfo &MRI = MF->getRegInfo(); + const MachineOperand &MO = MI.getOperand(1); + if (MO.isReg()) { + LLT Ty = MRI.getType(MO.getReg()); + if(Ty.isBfloat()) { + assert(1 && "hola, ur wrong"); + } + } + unsigned BitWidth = MI.getOperand(1).getImm(); if (BitWidth == 64) Reqs.addCapability(SPIRV::Capability::Float64); diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll new file mode 100644 index 0000000000000..336b2b013bc60 --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll @@ -0,0 +1,17 @@ +; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - | FileCheck %s +; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - -filetype=obj | spirv-val %} +; XFAIL: * +; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-ERROR + +; CHECK-ERROR: BFloat16TypeKHR requires the following SPIR-V extension: SPV_KHR_subgroup_rotate + +; CHECK-DAG: OpCapability BFloat16TypeKHR +; CHECK-DAG: OpExtension "SPV_KHR_bfloat16" +; CHECK: %[[#BFLOAT:]] = OpTypeFloat 16 0 +; CHECK: %[[#]] = OpTypeVector %[[#BFLOAT]] 2 + +define spir_kernel void @test() { +entry: + %addr1 = alloca bfloat + ret void +}