diff --git a/llvm/lib/Target/SPIRV/CMakeLists.txt b/llvm/lib/Target/SPIRV/CMakeLists.txt
index 31dca146d18fb..ab7e850172c2f 100644
--- a/llvm/lib/Target/SPIRV/CMakeLists.txt
+++ b/llvm/lib/Target/SPIRV/CMakeLists.txt
@@ -15,13 +15,17 @@ add_public_tablegen_target(SPIRVCommonTableGen)
 add_llvm_target(SPIRVCodeGen
   SPIRVAsmPrinter.cpp
   SPIRVCallLowering.cpp
+  SPIRVGlobalRegistry.cpp
   SPIRVInstrInfo.cpp
+  SPIRVInstructionSelector.cpp
   SPIRVISelLowering.cpp
+  SPIRVLegalizerInfo.cpp
   SPIRVMCInstLower.cpp
   SPIRVRegisterBankInfo.cpp
   SPIRVRegisterInfo.cpp
   SPIRVSubtarget.cpp
   SPIRVTargetMachine.cpp
+  SPIRVUtils.cpp
 
   LINK_COMPONENTS
   Analysis
diff --git a/llvm/lib/Target/SPIRV/MCTargetDesc/CMakeLists.txt b/llvm/lib/Target/SPIRV/MCTargetDesc/CMakeLists.txt
index fb56ffdd376b3..10cb1d039c63a 100644
--- a/llvm/lib/Target/SPIRV/MCTargetDesc/CMakeLists.txt
+++ b/llvm/lib/Target/SPIRV/MCTargetDesc/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_llvm_component_library(LLVMSPIRVDesc
+  SPIRVBaseInfo.cpp
   SPIRVMCAsmInfo.cpp
   SPIRVMCTargetDesc.cpp
   SPIRVTargetStreamer.cpp
diff --git a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.cpp b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.cpp
new file mode 100644
index 0000000000000..b8bd536c2524e
--- /dev/null
+++ b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.cpp
@@ -0,0 +1,1094 @@
+//===-- SPIRVBaseInfo.cpp - Top level definitions for SPIRV ----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains small standalone helper functions and enum definitions
+// for the SPIRV target useful for the compiler back-end and the MC libraries.
+// As such, it deliberately does not include references to LLVM core
+// code gen types, passes, etc.
+//
+//===----------------------------------------------------------------------===//
+
+#include "SPIRVBaseInfo.h"
+#include "llvm/Support/ErrorHandling.h"
+
+namespace llvm {
+namespace SPIRV {
+
+#define CASE(CLASS, ATTR)                                                      \
+  case CLASS::ATTR:                                                            \
+    return #ATTR;
+#define CASE_SUF(CLASS, SF, ATTR)                                              \
+  case CLASS::SF##_##ATTR:                                                     \
+    return #ATTR;
+
+// Implement getEnumName(Enum e) helper functions.
+// TODO: re-implement all the functions using TableGen.
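+//
+// For illustration (this expansion follows mechanically from the macro
+// definitions above): CASE(Capability, Matrix) expands to
+//   case Capability::Matrix:
+//     return "Matrix";
+// so each helper below is a plain name lookup, e.g.
+//   getCapabilityName(Capability::Shader) yields "Shader".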
+StringRef getCapabilityName(Capability e) {
+  switch (e) {
+    CASE(Capability, Matrix)
+    CASE(Capability, Shader)
+    CASE(Capability, Geometry)
+    CASE(Capability, Tessellation)
+    CASE(Capability, Addresses)
+    CASE(Capability, Linkage)
+    CASE(Capability, Kernel)
+    CASE(Capability, Vector16)
+    CASE(Capability, Float16Buffer)
+    CASE(Capability, Float16)
+    CASE(Capability, Float64)
+    CASE(Capability, Int64)
+    CASE(Capability, Int64Atomics)
+    CASE(Capability, ImageBasic)
+    CASE(Capability, ImageReadWrite)
+    CASE(Capability, ImageMipmap)
+    CASE(Capability, Pipes)
+    CASE(Capability, Groups)
+    CASE(Capability, DeviceEnqueue)
+    CASE(Capability, LiteralSampler)
+    CASE(Capability, AtomicStorage)
+    CASE(Capability, Int16)
+    CASE(Capability, TessellationPointSize)
+    CASE(Capability, GeometryPointSize)
+    CASE(Capability, ImageGatherExtended)
+    CASE(Capability, StorageImageMultisample)
+    CASE(Capability, UniformBufferArrayDynamicIndexing)
+    CASE(Capability, SampledImageArrayDynamicIndexing)
+    CASE(Capability, ClipDistance)
+    CASE(Capability, CullDistance)
+    CASE(Capability, ImageCubeArray)
+    CASE(Capability, SampleRateShading)
+    CASE(Capability, ImageRect)
+    CASE(Capability, SampledRect)
+    CASE(Capability, GenericPointer)
+    CASE(Capability, Int8)
+    CASE(Capability, InputAttachment)
+    CASE(Capability, SparseResidency)
+    CASE(Capability, MinLod)
+    CASE(Capability, Sampled1D)
+    CASE(Capability, Image1D)
+    CASE(Capability, SampledCubeArray)
+    CASE(Capability, SampledBuffer)
+    CASE(Capability, ImageBuffer)
+    CASE(Capability, ImageMSArray)
+    CASE(Capability, StorageImageExtendedFormats)
+    CASE(Capability, ImageQuery)
+    CASE(Capability, DerivativeControl)
+    CASE(Capability, InterpolationFunction)
+    CASE(Capability, TransformFeedback)
+    CASE(Capability, GeometryStreams)
+    CASE(Capability, StorageImageReadWithoutFormat)
+    CASE(Capability, StorageImageWriteWithoutFormat)
+    CASE(Capability, MultiViewport)
+    CASE(Capability, SubgroupDispatch)
+    CASE(Capability, NamedBarrier)
+    CASE(Capability, PipeStorage)
+    CASE(Capability, GroupNonUniform)
+    CASE(Capability, GroupNonUniformVote)
+    CASE(Capability, GroupNonUniformArithmetic)
+    CASE(Capability, GroupNonUniformBallot)
+    CASE(Capability, GroupNonUniformShuffle)
+    CASE(Capability, GroupNonUniformShuffleRelative)
+    CASE(Capability, GroupNonUniformClustered)
+    CASE(Capability, GroupNonUniformQuad)
+    CASE(Capability, SubgroupBallotKHR)
+    CASE(Capability, DrawParameters)
+    CASE(Capability, SubgroupVoteKHR)
+    CASE(Capability, StorageBuffer16BitAccess)
+    CASE(Capability, StorageUniform16)
+    CASE(Capability, StoragePushConstant16)
+    CASE(Capability, StorageInputOutput16)
+    CASE(Capability, DeviceGroup)
+    CASE(Capability, MultiView)
+    CASE(Capability, VariablePointersStorageBuffer)
+    CASE(Capability, VariablePointers)
+    CASE(Capability, AtomicStorageOps)
+    CASE(Capability, SampleMaskPostDepthCoverage)
+    CASE(Capability, StorageBuffer8BitAccess)
+    CASE(Capability, UniformAndStorageBuffer8BitAccess)
+    CASE(Capability, StoragePushConstant8)
+    CASE(Capability, DenormPreserve)
+    CASE(Capability, DenormFlushToZero)
+    CASE(Capability, SignedZeroInfNanPreserve)
+    CASE(Capability, RoundingModeRTE)
+    CASE(Capability, RoundingModeRTZ)
+    CASE(Capability, Float16ImageAMD)
+    CASE(Capability, ImageGatherBiasLodAMD)
+    CASE(Capability, FragmentMaskAMD)
+    CASE(Capability, StencilExportEXT)
+    CASE(Capability, ImageReadWriteLodAMD)
+    CASE(Capability, SampleMaskOverrideCoverageNV)
+    CASE(Capability, GeometryShaderPassthroughNV)
+    CASE(Capability, ShaderViewportIndexLayerEXT)
+    CASE(Capability, ShaderViewportMaskNV)
+    CASE(Capability, ShaderStereoViewNV)
+    CASE(Capability, PerViewAttributesNV)
+    CASE(Capability, FragmentFullyCoveredEXT)
+    CASE(Capability, MeshShadingNV)
+    CASE(Capability, ShaderNonUniformEXT)
+    CASE(Capability, RuntimeDescriptorArrayEXT)
+    CASE(Capability, InputAttachmentArrayDynamicIndexingEXT)
+    CASE(Capability, UniformTexelBufferArrayDynamicIndexingEXT)
+    CASE(Capability, StorageTexelBufferArrayDynamicIndexingEXT)
+    CASE(Capability, UniformBufferArrayNonUniformIndexingEXT)
+    CASE(Capability, SampledImageArrayNonUniformIndexingEXT)
+    CASE(Capability, StorageBufferArrayNonUniformIndexingEXT)
+    CASE(Capability, StorageImageArrayNonUniformIndexingEXT)
+    CASE(Capability, InputAttachmentArrayNonUniformIndexingEXT)
+    CASE(Capability, UniformTexelBufferArrayNonUniformIndexingEXT)
+    CASE(Capability, StorageTexelBufferArrayNonUniformIndexingEXT)
+    CASE(Capability, RayTracingNV)
+    CASE(Capability, SubgroupShuffleINTEL)
+    CASE(Capability, SubgroupBufferBlockIOINTEL)
+    CASE(Capability, SubgroupImageBlockIOINTEL)
+    CASE(Capability, SubgroupImageMediaBlockIOINTEL)
+    CASE(Capability, SubgroupAvcMotionEstimationINTEL)
+    CASE(Capability, SubgroupAvcMotionEstimationIntraINTEL)
+    CASE(Capability, SubgroupAvcMotionEstimationChromaINTEL)
+    CASE(Capability, GroupNonUniformPartitionedNV)
+    CASE(Capability, VulkanMemoryModelKHR)
+    CASE(Capability, VulkanMemoryModelDeviceScopeKHR)
+    CASE(Capability, ImageFootprintNV)
+    CASE(Capability, FragmentBarycentricNV)
+    CASE(Capability, ComputeDerivativeGroupQuadsNV)
+    CASE(Capability, ComputeDerivativeGroupLinearNV)
+    CASE(Capability, FragmentDensityEXT)
+    CASE(Capability, PhysicalStorageBufferAddressesEXT)
+    CASE(Capability, CooperativeMatrixNV)
+  default:
+    break;
+  }
+  llvm_unreachable("Unexpected operand");
+}
+
+StringRef getSourceLanguageName(SourceLanguage e) {
+  switch (e) {
+    CASE(SourceLanguage, Unknown)
+    CASE(SourceLanguage, ESSL)
+    CASE(SourceLanguage, GLSL)
+    CASE(SourceLanguage, OpenCL_C)
+    CASE(SourceLanguage, OpenCL_CPP)
+    CASE(SourceLanguage, HLSL)
+  default:
+    break;
+  }
+  llvm_unreachable("Unexpected operand");
+}
+
+StringRef getExecutionModelName(ExecutionModel e) {
+  switch (e) {
+    CASE(ExecutionModel, Vertex)
+    CASE(ExecutionModel, TessellationControl)
+    CASE(ExecutionModel, TessellationEvaluation)
+    CASE(ExecutionModel, Geometry)
+    CASE(ExecutionModel, Fragment)
+    CASE(ExecutionModel, GLCompute)
+    CASE(ExecutionModel, Kernel)
+    CASE(ExecutionModel, TaskNV)
+    CASE(ExecutionModel, MeshNV)
+    CASE(ExecutionModel, RayGenerationNV)
+    CASE(ExecutionModel, IntersectionNV)
+    CASE(ExecutionModel, AnyHitNV)
+    CASE(ExecutionModel, ClosestHitNV)
+    CASE(ExecutionModel, MissNV)
+    CASE(ExecutionModel, CallableNV)
+  default:
+    break;
+  }
+  llvm_unreachable("Unexpected operand");
+}
+
+StringRef getAddressingModelName(AddressingModel e) {
+  switch (e) {
+    CASE(AddressingModel, Logical)
+    CASE(AddressingModel, Physical32)
+    CASE(AddressingModel, Physical64)
+    CASE(AddressingModel, PhysicalStorageBuffer64EXT)
+  default:
+    break;
+  }
+  llvm_unreachable("Unexpected operand");
+}
+
+StringRef getMemoryModelName(MemoryModel e) {
+  switch (e) {
+    CASE(MemoryModel, Simple)
+    CASE(MemoryModel, GLSL450)
+    CASE(MemoryModel, OpenCL)
+    CASE(MemoryModel, VulkanKHR)
+  default:
+    break;
+  }
+  llvm_unreachable("Unexpected operand");
+}
+
+StringRef getExecutionModeName(ExecutionMode e) {
+  switch (e) {
+    CASE(ExecutionMode, Invocations)
+    CASE(ExecutionMode, SpacingEqual)
+    CASE(ExecutionMode, SpacingFractionalEven)
+    CASE(ExecutionMode, SpacingFractionalOdd)
+    CASE(ExecutionMode, VertexOrderCw)
+    CASE(ExecutionMode, VertexOrderCcw)
+    CASE(ExecutionMode, PixelCenterInteger)
+    CASE(ExecutionMode, OriginUpperLeft)
+    CASE(ExecutionMode, OriginLowerLeft)
+    CASE(ExecutionMode, EarlyFragmentTests)
+    CASE(ExecutionMode, PointMode)
+    CASE(ExecutionMode, Xfb)
+    CASE(ExecutionMode, DepthReplacing)
+    CASE(ExecutionMode, DepthGreater)
+    CASE(ExecutionMode, DepthLess)
+    CASE(ExecutionMode, DepthUnchanged)
+    CASE(ExecutionMode, LocalSize)
+    CASE(ExecutionMode, LocalSizeHint)
+    CASE(ExecutionMode, InputPoints)
+    CASE(ExecutionMode, InputLines)
+    CASE(ExecutionMode, InputLinesAdjacency)
+    CASE(ExecutionMode, Triangles)
+    CASE(ExecutionMode, InputTrianglesAdjacency)
+    CASE(ExecutionMode, Quads)
+    CASE(ExecutionMode, Isolines)
+    CASE(ExecutionMode, OutputVertices)
+    CASE(ExecutionMode, OutputPoints)
+    CASE(ExecutionMode, OutputLineStrip)
+    CASE(ExecutionMode, OutputTriangleStrip)
+    CASE(ExecutionMode, VecTypeHint)
+    CASE(ExecutionMode, ContractionOff)
+    CASE(ExecutionMode, Initializer)
+    CASE(ExecutionMode, Finalizer)
+    CASE(ExecutionMode, SubgroupSize)
+    CASE(ExecutionMode, SubgroupsPerWorkgroup)
+    CASE(ExecutionMode, SubgroupsPerWorkgroupId)
+    CASE(ExecutionMode, LocalSizeId)
+    CASE(ExecutionMode, LocalSizeHintId)
+    CASE(ExecutionMode, PostDepthCoverage)
+    CASE(ExecutionMode, DenormPreserve)
+    CASE(ExecutionMode, DenormFlushToZero)
+    CASE(ExecutionMode, SignedZeroInfNanPreserve)
+    CASE(ExecutionMode, RoundingModeRTE)
+    CASE(ExecutionMode, RoundingModeRTZ)
+    CASE(ExecutionMode, StencilRefReplacingEXT)
+    CASE(ExecutionMode, OutputLinesNV)
+    CASE(ExecutionMode, DerivativeGroupQuadsNV)
+    CASE(ExecutionMode, DerivativeGroupLinearNV)
+    CASE(ExecutionMode, OutputTrianglesNV)
+  default:
+    break;
+  }
+  llvm_unreachable("Unexpected operand");
+}
+
+StringRef getStorageClassName(StorageClass e) {
+  switch (e) {
+    CASE(StorageClass, UniformConstant)
+    CASE(StorageClass, Input)
+    CASE(StorageClass, Uniform)
+    CASE(StorageClass, Output)
+    CASE(StorageClass, Workgroup)
+    CASE(StorageClass, CrossWorkgroup)
+    CASE(StorageClass, Private)
+    CASE(StorageClass, Function)
+    CASE(StorageClass, Generic)
+    CASE(StorageClass, PushConstant)
+    CASE(StorageClass, AtomicCounter)
+    CASE(StorageClass, Image)
+    CASE(StorageClass, StorageBuffer)
+    CASE(StorageClass, CallableDataNV)
+    CASE(StorageClass, IncomingCallableDataNV)
+    CASE(StorageClass, RayPayloadNV)
+    CASE(StorageClass, HitAttributeNV)
+    CASE(StorageClass, IncomingRayPayloadNV)
+    CASE(StorageClass, ShaderRecordBufferNV)
+    CASE(StorageClass, PhysicalStorageBufferEXT)
+  default:
+    break;
+  }
+  llvm_unreachable("Unexpected operand");
+}
+
+StringRef getDimName(Dim dim) {
+  switch (dim) {
+    CASE_SUF(Dim, DIM, 1D)
+    CASE_SUF(Dim, DIM, 2D)
+    CASE_SUF(Dim, DIM, 3D)
+    CASE_SUF(Dim, DIM, Cube)
+    CASE_SUF(Dim, DIM, Rect)
+    CASE_SUF(Dim, DIM, Buffer)
+    CASE_SUF(Dim, DIM, SubpassData)
+  default:
+    break;
+  }
+  llvm_unreachable("Unexpected operand");
+}
+
+StringRef getSamplerAddressingModeName(SamplerAddressingMode e) {
+  switch (e) {
+    CASE(SamplerAddressingMode, None)
+    CASE(SamplerAddressingMode, ClampToEdge)
+    CASE(SamplerAddressingMode, Clamp)
+    CASE(SamplerAddressingMode, Repeat)
+    CASE(SamplerAddressingMode, RepeatMirrored)
+  default:
+    break;
+  }
+  llvm_unreachable("Unexpected operand");
+}
+
+StringRef getSamplerFilterModeName(SamplerFilterMode e) {
+  switch (e) {
+    CASE(SamplerFilterMode, Nearest)
+    CASE(SamplerFilterMode, Linear)
+  default:
+    break;
+  }
+  llvm_unreachable("Unexpected operand");
+}
+
+StringRef getImageFormatName(ImageFormat e) {
+  switch (e) {
+    CASE(ImageFormat, Unknown)
+    CASE(ImageFormat, Rgba32f)
+    CASE(ImageFormat, Rgba16f)
+    CASE(ImageFormat, R32f)
+    CASE(ImageFormat, Rgba8)
+    CASE(ImageFormat, Rgba8Snorm)
+    CASE(ImageFormat, Rg32f)
+    CASE(ImageFormat, Rg16f)
+    CASE(ImageFormat, R11fG11fB10f)
+    CASE(ImageFormat, R16f)
+    CASE(ImageFormat, Rgba16)
+    CASE(ImageFormat, Rgb10A2)
+    CASE(ImageFormat, Rg16)
+    CASE(ImageFormat, Rg8)
+    CASE(ImageFormat, R16)
+    CASE(ImageFormat, R8)
+    CASE(ImageFormat, Rgba16Snorm)
+    CASE(ImageFormat, Rg16Snorm)
+    CASE(ImageFormat, Rg8Snorm)
+    CASE(ImageFormat, R16Snorm)
+    CASE(ImageFormat, R8Snorm)
+    CASE(ImageFormat, Rgba32i)
+    CASE(ImageFormat, Rgba16i)
+    CASE(ImageFormat, Rgba8i)
+    CASE(ImageFormat, R32i)
+    CASE(ImageFormat, Rg32i)
+    CASE(ImageFormat, Rg16i)
+    CASE(ImageFormat, Rg8i)
+    CASE(ImageFormat, R16i)
+    CASE(ImageFormat, R8i)
+    CASE(ImageFormat, Rgba32ui)
+    CASE(ImageFormat, Rgba16ui)
+    CASE(ImageFormat, Rgba8ui)
+    CASE(ImageFormat, R32ui)
+    CASE(ImageFormat, Rgb10a2ui)
+    CASE(ImageFormat, Rg32ui)
+    CASE(ImageFormat, Rg16ui)
+    CASE(ImageFormat, Rg8ui)
+    CASE(ImageFormat, R16ui)
+    CASE(ImageFormat, R8ui)
+  default:
+    break;
+  }
+  llvm_unreachable("Unexpected operand");
+}
+
+StringRef getImageChannelOrderName(ImageChannelOrder e) {
+  switch (e) {
+    CASE(ImageChannelOrder, R)
+    CASE(ImageChannelOrder, A)
+    CASE(ImageChannelOrder, RG)
+    CASE(ImageChannelOrder, RA)
+    CASE(ImageChannelOrder, RGB)
+    CASE(ImageChannelOrder, RGBA)
+    CASE(ImageChannelOrder, BGRA)
+    CASE(ImageChannelOrder, ARGB)
+    CASE(ImageChannelOrder, Intensity)
+    CASE(ImageChannelOrder, Luminance)
+    CASE(ImageChannelOrder, Rx)
+    CASE(ImageChannelOrder, RGx)
+    CASE(ImageChannelOrder, RGBx)
+    CASE(ImageChannelOrder, Depth)
+    CASE(ImageChannelOrder, DepthStencil)
+    CASE(ImageChannelOrder, sRGB)
+    CASE(ImageChannelOrder, sRGBx)
+    CASE(ImageChannelOrder, sRGBA)
+    CASE(ImageChannelOrder, sBGRA)
+    CASE(ImageChannelOrder, ABGR)
+  default:
+    break;
+  }
+  llvm_unreachable("Unexpected operand");
+}
+
+StringRef getImageChannelDataTypeName(ImageChannelDataType e) {
+  switch (e) {
+    CASE(ImageChannelDataType, SnormInt8)
+    CASE(ImageChannelDataType, SnormInt16)
+    CASE(ImageChannelDataType, UnormInt8)
+    CASE(ImageChannelDataType, UnormInt16)
+    CASE(ImageChannelDataType, UnormShort565)
+    CASE(ImageChannelDataType, UnormShort555)
+    CASE(ImageChannelDataType, UnormInt101010)
+    CASE(ImageChannelDataType, SignedInt8)
+    CASE(ImageChannelDataType, SignedInt16)
+    CASE(ImageChannelDataType, SignedInt32)
+    CASE(ImageChannelDataType, UnsignedInt8)
+    CASE(ImageChannelDataType, UnsignedInt16)
+    CASE(ImageChannelDataType, UnsignedInt32)
+    CASE(ImageChannelDataType, HalfFloat)
+    CASE(ImageChannelDataType, Float)
+    CASE(ImageChannelDataType, UnormInt24)
+    CASE(ImageChannelDataType, UnormInt101010_2)
+  default:
+    break;
+  }
+  llvm_unreachable("Unexpected operand");
+}
+
+std::string getImageOperandName(uint32_t e) {
+  std::string nameString = "";
+  std::string sep = "";
+  if (e == static_cast<uint32_t>(ImageOperand::None))
+    return "None";
+  if (e == static_cast<uint32_t>(ImageOperand::Bias))
+    return "Bias";
+  if (e & static_cast<uint32_t>(ImageOperand::Bias)) {
+    nameString += sep + "Bias";
+    sep = "|";
+  }
+  if (e == static_cast<uint32_t>(ImageOperand::Lod))
+    return "Lod";
+  if (e & static_cast<uint32_t>(ImageOperand::Lod)) {
+    nameString += sep + "Lod";
+    sep = "|";
+  }
+  if (e == static_cast<uint32_t>(ImageOperand::Grad))
+    return "Grad";
+  if (e & static_cast<uint32_t>(ImageOperand::Grad)) {
+    nameString += sep + "Grad";
+    sep = "|";
+  }
+  if (e == static_cast<uint32_t>(ImageOperand::ConstOffset))
+    return "ConstOffset";
+  if (e & static_cast<uint32_t>(ImageOperand::ConstOffset)) {
+    nameString += sep + "ConstOffset";
+    sep = "|";
+  }
+  if (e == static_cast<uint32_t>(ImageOperand::Offset))
+    return "Offset";
+  if (e & static_cast<uint32_t>(ImageOperand::Offset)) {
+    nameString += sep + "Offset";
+    sep = "|";
+  }
+  if (e == static_cast<uint32_t>(ImageOperand::ConstOffsets))
+    return "ConstOffsets";
+  if (e & static_cast<uint32_t>(ImageOperand::ConstOffsets)) {
+    nameString += sep + "ConstOffsets";
+    sep = "|";
+  }
+  if (e == static_cast<uint32_t>(ImageOperand::Sample))
+    return "Sample";
+  if (e & static_cast<uint32_t>(ImageOperand::Sample)) {
+    nameString += sep + "Sample";
+    sep = "|";
+  }
+  if (e == static_cast<uint32_t>(ImageOperand::MinLod))
+    return "MinLod";
+  if (e & static_cast<uint32_t>(ImageOperand::MinLod)) {
+    nameString += sep + "MinLod";
+    sep = "|";
+  }
+  if (e == static_cast<uint32_t>(ImageOperand::MakeTexelAvailableKHR))
+    return "MakeTexelAvailableKHR";
+  if (e & static_cast<uint32_t>(ImageOperand::MakeTexelAvailableKHR)) {
+    nameString += sep + "MakeTexelAvailableKHR";
+    sep = "|";
+  }
+  if (e == static_cast<uint32_t>(ImageOperand::MakeTexelVisibleKHR))
+    return "MakeTexelVisibleKHR";
+  if (e & static_cast<uint32_t>(ImageOperand::MakeTexelVisibleKHR)) {
+    nameString += sep + "MakeTexelVisibleKHR";
+    sep = "|";
+  }
+  if (e == static_cast<uint32_t>(ImageOperand::NonPrivateTexelKHR))
+    return "NonPrivateTexelKHR";
+  if (e & static_cast<uint32_t>(ImageOperand::NonPrivateTexelKHR)) {
+    nameString += sep + "NonPrivateTexelKHR";
+    sep = "|";
+  }
+  if (e == static_cast<uint32_t>(ImageOperand::VolatileTexelKHR))
+    return "VolatileTexelKHR";
+  if (e & static_cast<uint32_t>(ImageOperand::VolatileTexelKHR)) {
+    nameString += sep + "VolatileTexelKHR";
+    sep = "|";
+  }
+  if (e == static_cast<uint32_t>(ImageOperand::SignExtend))
+    return "SignExtend";
+  if (e & static_cast<uint32_t>(ImageOperand::SignExtend)) {
+    nameString += sep + "SignExtend";
+    sep = "|";
+  }
+  if (e == static_cast<uint32_t>(ImageOperand::ZeroExtend))
+    return "ZeroExtend";
+  if (e & static_cast<uint32_t>(ImageOperand::ZeroExtend)) {
+    nameString += sep + "ZeroExtend";
+    sep = "|";
+  }
+  return nameString;
+}
+
+std::string getFPFastMathModeName(uint32_t e) {
+  std::string nameString = "";
+  std::string sep = "";
+  if (e == static_cast<uint32_t>(FPFastMathMode::None))
+    return "None";
+  if (e == static_cast<uint32_t>(FPFastMathMode::NotNaN))
+    return "NotNaN";
+  if (e & static_cast<uint32_t>(FPFastMathMode::NotNaN)) {
+    nameString += sep + "NotNaN";
+    sep = "|";
+  }
+  if (e == static_cast<uint32_t>(FPFastMathMode::NotInf))
+    return "NotInf";
+  if (e & static_cast<uint32_t>(FPFastMathMode::NotInf)) {
+    nameString += sep + "NotInf";
+    sep = "|";
+  }
+  if (e == static_cast<uint32_t>(FPFastMathMode::NSZ))
+    return "NSZ";
+  if (e & static_cast<uint32_t>(FPFastMathMode::NSZ)) {
+    nameString += sep + "NSZ";
+    sep = "|";
+  }
+  if (e == static_cast<uint32_t>(FPFastMathMode::AllowRecip))
+    return "AllowRecip";
+  if (e & static_cast<uint32_t>(FPFastMathMode::AllowRecip)) {
+    nameString += sep + "AllowRecip";
+    sep = "|";
+  }
+  if (e == static_cast<uint32_t>(FPFastMathMode::Fast))
+    return "Fast";
+  if (e & static_cast<uint32_t>(FPFastMathMode::Fast)) {
+    nameString += sep + "Fast";
+    sep = "|";
+  }
+  return nameString;
+}
+
+StringRef getFPRoundingModeName(FPRoundingMode e) {
+  switch (e) {
+    CASE(FPRoundingMode, RTE)
+    CASE(FPRoundingMode, RTZ)
+    CASE(FPRoundingMode, RTP)
+    CASE(FPRoundingMode, RTN)
+  default:
+    break;
+  }
+  llvm_unreachable("Unexpected operand");
+}
+
+StringRef getLinkageTypeName(LinkageType e) {
+  switch (e) {
+    CASE(LinkageType, Export)
+    CASE(LinkageType, Import)
+  default:
+    break;
+  }
+  llvm_unreachable("Unexpected operand");
+}
+
+StringRef getAccessQualifierName(AccessQualifier e) {
+  switch (e) {
+    CASE(AccessQualifier, ReadOnly)
+    CASE(AccessQualifier, WriteOnly)
+    CASE(AccessQualifier, ReadWrite)
+  default:
+    break;
+  }
+  llvm_unreachable("Unexpected operand");
+}
+
+StringRef getFunctionParameterAttributeName(FunctionParameterAttribute e) {
+  switch (e) {
+    CASE(FunctionParameterAttribute, Zext)
+    CASE(FunctionParameterAttribute, Sext)
+    CASE(FunctionParameterAttribute, ByVal)
+    CASE(FunctionParameterAttribute, Sret)
+    CASE(FunctionParameterAttribute, NoAlias)
+    CASE(FunctionParameterAttribute, NoCapture)
+    CASE(FunctionParameterAttribute, NoWrite)
+    CASE(FunctionParameterAttribute, NoReadWrite)
+  default:
+    break;
+  }
+  llvm_unreachable("Unexpected operand");
+}
+
+StringRef getDecorationName(Decoration e) {
+  switch (e) {
+    CASE(Decoration, RelaxedPrecision)
+    CASE(Decoration, SpecId)
+    CASE(Decoration, Block)
+    CASE(Decoration, BufferBlock)
+    CASE(Decoration, RowMajor)
+    CASE(Decoration, ColMajor)
+    CASE(Decoration, ArrayStride)
+    CASE(Decoration, MatrixStride)
+    CASE(Decoration, GLSLShared)
+    CASE(Decoration, GLSLPacked)
+    CASE(Decoration, CPacked)
+    CASE(Decoration, BuiltIn)
+    CASE(Decoration, NoPerspective)
+    CASE(Decoration, Flat)
+    CASE(Decoration, Patch)
+    CASE(Decoration, Centroid)
+    CASE(Decoration, Sample)
+    CASE(Decoration, Invariant)
+    CASE(Decoration, Restrict)
+    CASE(Decoration, Aliased)
+    CASE(Decoration, Volatile)
+    CASE(Decoration, Constant)
+    CASE(Decoration, Coherent)
+    CASE(Decoration, NonWritable)
+    CASE(Decoration, NonReadable)
+    CASE(Decoration, Uniform)
+    CASE(Decoration, UniformId)
+    CASE(Decoration, SaturatedConversion)
+    CASE(Decoration, Stream)
+    CASE(Decoration, Location)
+    CASE(Decoration, Component)
+    CASE(Decoration, Index)
+    CASE(Decoration, Binding)
+    CASE(Decoration, DescriptorSet)
+    CASE(Decoration, Offset)
+    CASE(Decoration, XfbBuffer)
+    CASE(Decoration, XfbStride)
+    CASE(Decoration, FuncParamAttr)
+    CASE(Decoration, FPRoundingMode)
+    CASE(Decoration, FPFastMathMode)
+    CASE(Decoration, LinkageAttributes)
+    CASE(Decoration, NoContraction)
+    CASE(Decoration, InputAttachmentIndex)
+    CASE(Decoration, Alignment)
+    CASE(Decoration, MaxByteOffset)
+    CASE(Decoration, AlignmentId)
+    CASE(Decoration, MaxByteOffsetId)
+    CASE(Decoration, NoSignedWrap)
+    CASE(Decoration, NoUnsignedWrap)
+    CASE(Decoration, ExplicitInterpAMD)
+    CASE(Decoration, OverrideCoverageNV)
+    CASE(Decoration, PassthroughNV)
+    CASE(Decoration, ViewportRelativeNV)
+    CASE(Decoration, SecondaryViewportRelativeNV)
+    CASE(Decoration, PerPrimitiveNV)
+    CASE(Decoration, PerViewNV)
+    CASE(Decoration, PerVertexNV)
+    CASE(Decoration, NonUniformEXT)
+    CASE(Decoration, CountBuffer)
+    CASE(Decoration, UserSemantic)
+    CASE(Decoration, RestrictPointerEXT)
+    CASE(Decoration, AliasedPointerEXT)
+  default:
+    break;
+  }
+  llvm_unreachable("Unexpected operand");
+}
+
+StringRef getBuiltInName(BuiltIn e) {
+  switch (e) {
+    CASE(BuiltIn, Position)
+    CASE(BuiltIn, PointSize)
+    CASE(BuiltIn, ClipDistance)
+    CASE(BuiltIn, CullDistance)
+    CASE(BuiltIn, VertexId)
+    CASE(BuiltIn, InstanceId)
+    CASE(BuiltIn, PrimitiveId)
+    CASE(BuiltIn, InvocationId)
+    CASE(BuiltIn, Layer)
+    CASE(BuiltIn, ViewportIndex)
+    CASE(BuiltIn, TessLevelOuter)
+    CASE(BuiltIn, TessLevelInner)
+    CASE(BuiltIn, TessCoord)
+    CASE(BuiltIn, PatchVertices)
+    CASE(BuiltIn, FragCoord)
+    CASE(BuiltIn, PointCoord)
+    CASE(BuiltIn, FrontFacing)
+    CASE(BuiltIn, SampleId)
+    CASE(BuiltIn, SamplePosition)
+    CASE(BuiltIn, SampleMask)
+    CASE(BuiltIn, FragDepth)
+    CASE(BuiltIn, HelperInvocation)
+    CASE(BuiltIn, NumWorkgroups)
+    CASE(BuiltIn, WorkgroupSize)
+    CASE(BuiltIn, WorkgroupId)
+    CASE(BuiltIn, LocalInvocationId)
+    CASE(BuiltIn, GlobalInvocationId)
+    CASE(BuiltIn, LocalInvocationIndex)
+    CASE(BuiltIn, WorkDim)
+    CASE(BuiltIn, GlobalSize)
+    CASE(BuiltIn, EnqueuedWorkgroupSize)
+    CASE(BuiltIn, GlobalOffset)
+    CASE(BuiltIn, GlobalLinearId)
+    CASE(BuiltIn, SubgroupSize)
+    CASE(BuiltIn, SubgroupMaxSize)
+    CASE(BuiltIn, NumSubgroups)
+    CASE(BuiltIn, NumEnqueuedSubgroups)
+    CASE(BuiltIn, SubgroupId)
+    CASE(BuiltIn, SubgroupLocalInvocationId)
+    CASE(BuiltIn, VertexIndex)
+    CASE(BuiltIn, InstanceIndex)
+    CASE(BuiltIn, SubgroupEqMask)
+    CASE(BuiltIn, SubgroupGeMask)
+    CASE(BuiltIn, SubgroupGtMask)
+    CASE(BuiltIn, SubgroupLeMask)
+    CASE(BuiltIn, SubgroupLtMask)
+    CASE(BuiltIn, BaseVertex)
+    CASE(BuiltIn, BaseInstance)
+    CASE(BuiltIn, DrawIndex)
+    CASE(BuiltIn, DeviceIndex)
+    CASE(BuiltIn, ViewIndex)
+    CASE(BuiltIn, BaryCoordNoPerspAMD)
+    CASE(BuiltIn, BaryCoordNoPerspCentroidAMD)
+    CASE(BuiltIn, BaryCoordNoPerspSampleAMD)
+    CASE(BuiltIn, BaryCoordSmoothAMD)
+    CASE(BuiltIn, BaryCoordSmoothCentroid)
+    CASE(BuiltIn, BaryCoordSmoothSample)
+    CASE(BuiltIn, BaryCoordPullModel)
+    CASE(BuiltIn, FragStencilRefEXT)
+    CASE(BuiltIn, ViewportMaskNV)
+    CASE(BuiltIn, SecondaryPositionNV)
+    CASE(BuiltIn, SecondaryViewportMaskNV)
+    CASE(BuiltIn, PositionPerViewNV)
+    CASE(BuiltIn, ViewportMaskPerViewNV)
+    CASE(BuiltIn, FullyCoveredEXT)
+    CASE(BuiltIn, TaskCountNV)
+    CASE(BuiltIn, PrimitiveCountNV)
+    CASE(BuiltIn, PrimitiveIndicesNV)
+    CASE(BuiltIn, ClipDistancePerViewNV)
+    CASE(BuiltIn, CullDistancePerViewNV)
+    CASE(BuiltIn, LayerPerViewNV)
+    CASE(BuiltIn, MeshViewCountNV)
+    CASE(BuiltIn, MeshViewIndices)
+    CASE(BuiltIn, BaryCoordNV)
+    CASE(BuiltIn, BaryCoordNoPerspNV)
+    CASE(BuiltIn, FragSizeEXT)
+    CASE(BuiltIn, FragInvocationCountEXT)
+    CASE(BuiltIn, LaunchIdNV)
+    CASE(BuiltIn, LaunchSizeNV)
+    CASE(BuiltIn, WorldRayOriginNV)
+    CASE(BuiltIn, WorldRayDirectionNV)
+    CASE(BuiltIn, ObjectRayOriginNV)
+    CASE(BuiltIn, ObjectRayDirectionNV)
+    CASE(BuiltIn, RayTminNV)
+    CASE(BuiltIn, RayTmaxNV)
+    CASE(BuiltIn, InstanceCustomIndexNV)
+    CASE(BuiltIn, ObjectToWorldNV)
+    CASE(BuiltIn, WorldToObjectNV)
+    CASE(BuiltIn, HitTNV)
+    CASE(BuiltIn, HitKindNV)
+    CASE(BuiltIn, IncomingRayFlagsNV)
+  default:
+    break;
+  }
+  llvm_unreachable("Unexpected operand");
+}
+
+std::string getSelectionControlName(uint32_t e) {
+  std::string nameString = "";
+  std::string sep = "";
+  if (e == static_cast<uint32_t>(SelectionControl::None))
+    return "None";
+  if (e == static_cast<uint32_t>(SelectionControl::Flatten))
+    return "Flatten";
+  if (e & static_cast<uint32_t>(SelectionControl::Flatten)) {
+    nameString += sep + "Flatten";
+    sep = "|";
+  }
+  if (e == static_cast<uint32_t>(SelectionControl::DontFlatten))
+    return "DontFlatten";
+  if (e & static_cast<uint32_t>(SelectionControl::DontFlatten)) {
+    nameString += sep + "DontFlatten";
+    sep = "|";
+  }
+  return nameString;
+}
+
+std::string getLoopControlName(uint32_t e) {
+  std::string nameString = "";
+  std::string sep = "";
+  if (e == static_cast<uint32_t>(LoopControl::None))
+    return "None";
+  if (e == static_cast<uint32_t>(LoopControl::Unroll))
+    return "Unroll";
+  if (e & static_cast<uint32_t>(LoopControl::Unroll)) {
+    nameString += sep + "Unroll";
+    sep = "|";
+  }
+  if (e == static_cast<uint32_t>(LoopControl::DontUnroll))
+    return "DontUnroll";
+  if (e & static_cast<uint32_t>(LoopControl::DontUnroll)) {
+    nameString += sep + "DontUnroll";
+    sep = "|";
+  }
+  if (e == static_cast<uint32_t>(LoopControl::DependencyInfinite))
+    return "DependencyInfinite";
+  if (e & static_cast<uint32_t>(LoopControl::DependencyInfinite)) {
+    nameString += sep + "DependencyInfinite";
+    sep = "|";
+  }
+  if (e == static_cast<uint32_t>(LoopControl::DependencyLength))
+    return "DependencyLength";
+  if (e & static_cast<uint32_t>(LoopControl::DependencyLength)) {
+    nameString += sep + "DependencyLength";
+    sep = "|";
+  }
+  if (e == static_cast<uint32_t>(LoopControl::MinIterations))
+    return "MinIterations";
+  if (e & static_cast<uint32_t>(LoopControl::MinIterations)) {
+    nameString += sep + "MinIterations";
+    sep = "|";
+  }
+  if (e == static_cast<uint32_t>(LoopControl::MaxIterations))
+    return "MaxIterations";
+  if (e & static_cast<uint32_t>(LoopControl::MaxIterations)) {
+    nameString += sep + "MaxIterations";
+    sep = "|";
+  }
+  if (e == static_cast<uint32_t>(LoopControl::IterationMultiple))
+    return "IterationMultiple";
+  if (e & static_cast<uint32_t>(LoopControl::IterationMultiple)) {
+    nameString += sep + "IterationMultiple";
+    sep = "|";
+  }
+  if (e == static_cast<uint32_t>(LoopControl::PeelCount))
+    return "PeelCount";
+  if (e & static_cast<uint32_t>(LoopControl::PeelCount)) {
+    nameString += sep + "PeelCount";
+    sep = "|";
+  }
+  if (e == static_cast<uint32_t>(LoopControl::PartialCount))
+    return "PartialCount";
+  if (e & static_cast<uint32_t>(LoopControl::PartialCount)) {
+    nameString += sep + "PartialCount";
+    sep = "|";
+  }
+  return nameString;
+}
+
+std::string getFunctionControlName(uint32_t e) {
+  std::string nameString = "";
+  std::string sep = "";
+  if (e == static_cast<uint32_t>(FunctionControl::None))
+    return "None";
+  if (e == static_cast<uint32_t>(FunctionControl::Inline))
+    return "Inline";
+  if (e & static_cast<uint32_t>(FunctionControl::Inline)) {
+    nameString += sep + "Inline";
+    sep = "|";
+  }
+  if (e == static_cast<uint32_t>(FunctionControl::DontInline))
+    return "DontInline";
+  if (e & static_cast<uint32_t>(FunctionControl::DontInline)) {
+    nameString += sep + "DontInline";
+    sep = "|";
+  }
+  if (e == static_cast<uint32_t>(FunctionControl::Pure))
+    return "Pure";
+  if (e & static_cast<uint32_t>(FunctionControl::Pure)) {
+    nameString += sep + "Pure";
+    sep = "|";
+  }
+  if (e == static_cast<uint32_t>(FunctionControl::Const))
+    return "Const";
+  if (e & static_cast<uint32_t>(FunctionControl::Const)) {
+    nameString += sep + "Const";
+    sep = "|";
+  }
+  return nameString;
+}
+
+std::string getMemorySemanticsName(uint32_t e) {
+  std::string nameString = "";
+  std::string sep = "";
+  if (e == static_cast<uint32_t>(MemorySemantics::None))
+    return "None";
+  if (e == static_cast<uint32_t>(MemorySemantics::Acquire))
+    return "Acquire";
+  if (e & static_cast<uint32_t>(MemorySemantics::Acquire)) {
+    nameString += sep + "Acquire";
+    sep = "|";
+  }
+  if (e == static_cast<uint32_t>(MemorySemantics::Release))
+    return "Release";
+  if (e & static_cast<uint32_t>(MemorySemantics::Release)) {
+    nameString += sep + "Release";
+    sep = "|";
+  }
+  if (e == static_cast<uint32_t>(MemorySemantics::AcquireRelease))
+    return "AcquireRelease";
+  if (e & static_cast<uint32_t>(MemorySemantics::AcquireRelease)) {
+    nameString += sep + "AcquireRelease";
+    sep = "|";
+  }
+  if (e == static_cast<uint32_t>(MemorySemantics::SequentiallyConsistent))
+    return "SequentiallyConsistent";
+  if (e & static_cast<uint32_t>(MemorySemantics::SequentiallyConsistent)) {
+    nameString += sep + "SequentiallyConsistent";
+    sep = "|";
+  }
+  if (e == static_cast<uint32_t>(MemorySemantics::UniformMemory))
+    return "UniformMemory";
+  if (e & static_cast<uint32_t>(MemorySemantics::UniformMemory)) {
+    nameString += sep + "UniformMemory";
+    sep = "|";
+  }
+  if (e == static_cast<uint32_t>(MemorySemantics::SubgroupMemory))
+    return "SubgroupMemory";
+  if (e & static_cast<uint32_t>(MemorySemantics::SubgroupMemory)) {
+    nameString += sep + "SubgroupMemory";
+    sep = "|";
+  }
+  if (e == static_cast<uint32_t>(MemorySemantics::WorkgroupMemory))
+    return "WorkgroupMemory";
+  if (e & static_cast<uint32_t>(MemorySemantics::WorkgroupMemory)) {
+    nameString += sep + "WorkgroupMemory";
+    sep = "|";
+  }
+  if (e == static_cast<uint32_t>(MemorySemantics::CrossWorkgroupMemory))
+    return "CrossWorkgroupMemory";
+  if (e & static_cast<uint32_t>(MemorySemantics::CrossWorkgroupMemory)) {
+    nameString += sep + "CrossWorkgroupMemory";
+    sep = "|";
+  }
+  if (e == static_cast<uint32_t>(MemorySemantics::AtomicCounterMemory))
+    return "AtomicCounterMemory";
+  if (e & static_cast<uint32_t>(MemorySemantics::AtomicCounterMemory)) {
+    nameString += sep + "AtomicCounterMemory";
+    sep = "|";
+  }
+  if (e == static_cast<uint32_t>(MemorySemantics::ImageMemory))
+    return "ImageMemory";
+  if (e & static_cast<uint32_t>(MemorySemantics::ImageMemory)) {
+    nameString += sep + "ImageMemory";
+    sep = "|";
+  }
+  if (e == static_cast<uint32_t>(MemorySemantics::OutputMemoryKHR))
+    return "OutputMemoryKHR";
+  if (e & static_cast<uint32_t>(MemorySemantics::OutputMemoryKHR)) {
+    nameString += sep + "OutputMemoryKHR";
+    sep = "|";
+  }
+  if (e == static_cast<uint32_t>(MemorySemantics::MakeAvailableKHR))
+    return "MakeAvailableKHR";
+  if (e & static_cast<uint32_t>(MemorySemantics::MakeAvailableKHR)) {
+    nameString += sep + "MakeAvailableKHR";
+    sep = "|";
+  }
+  if (e == static_cast<uint32_t>(MemorySemantics::MakeVisibleKHR))
+    return "MakeVisibleKHR";
+  if (e & static_cast<uint32_t>(MemorySemantics::MakeVisibleKHR)) {
+    nameString += sep + "MakeVisibleKHR";
+    sep = "|";
+  }
+  return nameString;
+}
+
+std::string getMemoryOperandName(uint32_t e) {
+  std::string nameString = "";
+  std::string sep = "";
+  if (e == static_cast<uint32_t>(MemoryOperand::None))
+    return "None";
+  if (e == static_cast<uint32_t>(MemoryOperand::Volatile))
+    return "Volatile";
+  if (e & static_cast<uint32_t>(MemoryOperand::Volatile)) {
+    nameString += sep + "Volatile";
+    sep = "|";
+  }
+  if (e == static_cast<uint32_t>(MemoryOperand::Aligned))
+    return "Aligned";
+  if (e & static_cast<uint32_t>(MemoryOperand::Aligned)) {
+    nameString += sep + "Aligned";
+    sep = "|";
+  }
+  if (e == static_cast<uint32_t>(MemoryOperand::Nontemporal))
+    return "Nontemporal";
+  if (e & static_cast<uint32_t>(MemoryOperand::Nontemporal)) {
+    nameString += sep + "Nontemporal";
+    sep = "|";
+  }
+  if (e == static_cast<uint32_t>(MemoryOperand::MakePointerAvailableKHR))
+    return "MakePointerAvailableKHR";
+  if (e & static_cast<uint32_t>(MemoryOperand::MakePointerAvailableKHR)) {
+    nameString += sep + "MakePointerAvailableKHR";
+    sep = "|";
+  }
+  if (e == static_cast<uint32_t>(MemoryOperand::MakePointerVisibleKHR))
+    return "MakePointerVisibleKHR";
+  if (e & static_cast<uint32_t>(MemoryOperand::MakePointerVisibleKHR)) {
+    nameString += sep + "MakePointerVisibleKHR";
+    sep = "|";
+  }
+  if (e == static_cast<uint32_t>(MemoryOperand::NonPrivatePointerKHR))
+    return "NonPrivatePointerKHR";
+  if (e & static_cast<uint32_t>(MemoryOperand::NonPrivatePointerKHR)) {
+    nameString += sep + "NonPrivatePointerKHR";
+    sep = "|";
+  }
+  return nameString;
+}
+
+StringRef getScopeName(Scope e) {
+  switch (e) {
+    CASE(Scope, CrossDevice)
+    CASE(Scope, Device)
+    CASE(Scope, Workgroup)
+    CASE(Scope, Subgroup)
+    CASE(Scope, Invocation)
+    CASE(Scope, QueueFamilyKHR)
+  default:
+    break;
+  }
+  llvm_unreachable("Unexpected operand");
+}
+
+StringRef getGroupOperationName(GroupOperation e) {
+  switch (e) {
+    CASE(GroupOperation, Reduce)
+    CASE(GroupOperation, InclusiveScan)
+    CASE(GroupOperation, ExclusiveScan)
+    CASE(GroupOperation, ClusteredReduce)
+    CASE(GroupOperation, PartitionedReduceNV)
+    CASE(GroupOperation, PartitionedInclusiveScanNV)
+    CASE(GroupOperation, PartitionedExclusiveScanNV)
+  default:
+    break;
+  }
+  llvm_unreachable("Unexpected operand");
+}
+
+StringRef getKernelEnqueueFlagsName(KernelEnqueueFlags e) {
+  switch (e) {
+    CASE(KernelEnqueueFlags, NoWait)
+    CASE(KernelEnqueueFlags, WaitKernel)
+    CASE(KernelEnqueueFlags, WaitWorkGroup)
+  default:
+    break;
+  }
+  llvm_unreachable("Unexpected operand");
+}
+
+StringRef getKernelProfilingInfoName(KernelProfilingInfo e) {
+  switch (e) {
+    CASE(KernelProfilingInfo, None)
+    CASE(KernelProfilingInfo, CmdExecTime)
+  default:
+    break;
+  }
+  llvm_unreachable("Unexpected operand");
+}
+} // namespace SPIRV
+} // namespace llvm
diff --git a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h
new file mode 100644
index 0000000000000..2aa9f076c78e2
--- /dev/null
+++ b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h
@@ -0,0 +1,739 @@
+//===-- SPIRVBaseInfo.h - Top level definitions for SPIRV ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains small standalone helper functions and enum definitions
+// for the SPIRV target useful for the compiler back-end and the MC libraries.
+// As such, it deliberately does not include references to LLVM core
+// code gen types, passes, etc.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIB_TARGET_SPIRV_MCTARGETDESC_SPIRVBASEINFO_H
+#define LLVM_LIB_TARGET_SPIRV_MCTARGETDESC_SPIRVBASEINFO_H
+
+#include "llvm/ADT/StringRef.h"
+#include <cstdint>
+
+namespace llvm {
+namespace SPIRV {
+enum class Capability : uint32_t {
+  Matrix = 0,
+  Shader = 1,
+  Geometry = 2,
+  Tessellation = 3,
+  Addresses = 4,
+  Linkage = 5,
+  Kernel = 6,
+  Vector16 = 7,
+  Float16Buffer = 8,
+  Float16 = 9,
+  Float64 = 10,
+  Int64 = 11,
+  Int64Atomics = 12,
+  ImageBasic = 13,
+  ImageReadWrite = 14,
+  ImageMipmap = 15,
+  Pipes = 17,
+  Groups = 18,
+  DeviceEnqueue = 19,
+  LiteralSampler = 20,
+  AtomicStorage = 21,
+  Int16 = 22,
+  TessellationPointSize = 23,
+  GeometryPointSize = 24,
+  ImageGatherExtended = 25,
+  StorageImageMultisample = 27,
+  UniformBufferArrayDynamicIndexing = 28,
+  SampledImageArrayDynamicIndexing = 29,
+  ClipDistance = 32,
+  CullDistance = 33,
+  ImageCubeArray = 34,
+  SampleRateShading = 35,
+  ImageRect = 36,
+  SampledRect = 37,
+  GenericPointer = 38,
+  Int8 = 39,
+  InputAttachment = 40,
+  SparseResidency = 41,
+  MinLod = 42,
+  Sampled1D = 43,
+  Image1D = 44,
+  SampledCubeArray = 45,
+  SampledBuffer = 46,
+  ImageBuffer = 47,
+  ImageMSArray = 48,
+  StorageImageExtendedFormats = 49,
+  ImageQuery = 50,
+  DerivativeControl = 51,
+  InterpolationFunction = 52,
+  TransformFeedback = 53,
+  GeometryStreams = 54,
+  StorageImageReadWithoutFormat = 55,
+  StorageImageWriteWithoutFormat = 56,
+  MultiViewport = 57,
+  SubgroupDispatch = 58,
+  NamedBarrier = 59,
+  PipeStorage = 60,
+  GroupNonUniform = 61,
+  GroupNonUniformVote = 62,
+  GroupNonUniformArithmetic = 63,
+  GroupNonUniformBallot = 64,
+  GroupNonUniformShuffle = 65,
+  GroupNonUniformShuffleRelative = 66,
+  GroupNonUniformClustered = 67,
+  GroupNonUniformQuad = 68,
+  SubgroupBallotKHR = 4423,
+  DrawParameters = 4427,
+  SubgroupVoteKHR = 4431,
+  StorageBuffer16BitAccess = 4433,
+  StorageUniform16 = 4434,
+  StoragePushConstant16 = 4435,
+  StorageInputOutput16 = 4436,
+  DeviceGroup = 4437,
+  MultiView = 4439,
+  VariablePointersStorageBuffer = 4441,
+  VariablePointers = 4442,
+  AtomicStorageOps = 4445,
+  SampleMaskPostDepthCoverage = 4447,
+  StorageBuffer8BitAccess = 4448,
+  UniformAndStorageBuffer8BitAccess = 4449,
+  StoragePushConstant8 = 4450,
+  DenormPreserve = 4464,
+  DenormFlushToZero = 4465,
+  SignedZeroInfNanPreserve = 4466,
+  RoundingModeRTE = 4467,
+  RoundingModeRTZ = 4468,
+  Float16ImageAMD = 5008,
+  ImageGatherBiasLodAMD = 5009,
+  FragmentMaskAMD = 5010,
+  StencilExportEXT = 5013,
+  ImageReadWriteLodAMD = 5015,
+  SampleMaskOverrideCoverageNV = 5249,
+  GeometryShaderPassthroughNV = 5251,
+  ShaderViewportIndexLayerEXT = 5254,
+  ShaderViewportMaskNV = 5255,
+  ShaderStereoViewNV = 5259,
+  PerViewAttributesNV = 5260,
+  FragmentFullyCoveredEXT = 5265,
+  MeshShadingNV = 5266,
+  ShaderNonUniformEXT = 5301,
+  RuntimeDescriptorArrayEXT = 5302,
+  InputAttachmentArrayDynamicIndexingEXT = 5303,
+  UniformTexelBufferArrayDynamicIndexingEXT = 5304,
+  StorageTexelBufferArrayDynamicIndexingEXT = 5305,
+  UniformBufferArrayNonUniformIndexingEXT = 5306,
+  SampledImageArrayNonUniformIndexingEXT = 5307,
+  StorageBufferArrayNonUniformIndexingEXT = 5308,
+  StorageImageArrayNonUniformIndexingEXT = 5309,
+  InputAttachmentArrayNonUniformIndexingEXT = 5310,
+  UniformTexelBufferArrayNonUniformIndexingEXT = 5311,
+  StorageTexelBufferArrayNonUniformIndexingEXT = 5312,
+  RayTracingNV = 5340,
+  SubgroupShuffleINTEL = 5568,
+  SubgroupBufferBlockIOINTEL = 5569,
+  SubgroupImageBlockIOINTEL = 5570,
+  SubgroupImageMediaBlockIOINTEL = 5579,
+  SubgroupAvcMotionEstimationINTEL = 5696,
+  SubgroupAvcMotionEstimationIntraINTEL = 5697,
+  SubgroupAvcMotionEstimationChromaINTEL = 5698,
+  GroupNonUniformPartitionedNV = 5297,
+  VulkanMemoryModelKHR = 5345,
+  VulkanMemoryModelDeviceScopeKHR = 5346,
+  ImageFootprintNV = 5282,
+  FragmentBarycentricNV = 5284,
+  ComputeDerivativeGroupQuadsNV = 5288,
+  ComputeDerivativeGroupLinearNV = 5350,
+  FragmentDensityEXT = 5291,
+  PhysicalStorageBufferAddressesEXT = 5347,
+  CooperativeMatrixNV = 5357,
+};
+StringRef getCapabilityName(Capability e);
+
+enum class SourceLanguage : uint32_t {
+  Unknown = 0,
+  ESSL = 1,
+  GLSL = 2,
+  OpenCL_C = 3,
+  OpenCL_CPP = 4,
+  HLSL = 5,
+};
+StringRef getSourceLanguageName(SourceLanguage e);
+
+enum class AddressingModel : uint32_t {
+  Logical = 0,
+  Physical32 = 1,
+  Physical64 = 2,
+  PhysicalStorageBuffer64EXT = 5348,
+};
+StringRef getAddressingModelName(AddressingModel e);
+
+enum class ExecutionModel : uint32_t {
+  Vertex = 0,
+  TessellationControl = 1,
+  TessellationEvaluation = 2,
+  Geometry = 3,
+  Fragment = 4,
+  GLCompute = 5,
+  Kernel = 6,
+  TaskNV = 5267,
+  MeshNV = 5268,
+  RayGenerationNV = 5313,
+  IntersectionNV = 5314,
+  AnyHitNV = 5315,
+  ClosestHitNV = 5316,
+  MissNV = 5317,
+  CallableNV = 5318,
+};
+StringRef getExecutionModelName(ExecutionModel e);
+
+enum class MemoryModel : uint32_t {
+  Simple = 0,
+  GLSL450 = 1,
+  OpenCL = 2,
+  VulkanKHR = 3,
+};
+StringRef getMemoryModelName(MemoryModel e);
+
+enum class ExecutionMode : uint32_t {
+  Invocations = 0,
+  SpacingEqual = 1,
+  SpacingFractionalEven = 2,
+  SpacingFractionalOdd = 3,
+  VertexOrderCw = 4,
+  VertexOrderCcw = 5,
+  PixelCenterInteger = 6,
+  OriginUpperLeft = 7,
+  OriginLowerLeft = 8,
+  EarlyFragmentTests = 9,
+  PointMode = 10,
+  Xfb = 11,
+  DepthReplacing = 12,
+  DepthGreater = 14,
+  DepthLess = 15,
+  DepthUnchanged = 16,
+  LocalSize = 17,
+  LocalSizeHint = 18,
+  InputPoints = 19,
+  InputLines = 20,
+  InputLinesAdjacency = 21,
+  Triangles = 22,
+  InputTrianglesAdjacency = 23,
+  Quads = 24,
+  Isolines = 25,
+  OutputVertices = 26,
+  OutputPoints = 27,
+  OutputLineStrip = 28,
+  OutputTriangleStrip = 29,
+  VecTypeHint = 30,
+  ContractionOff = 31,
+  Initializer = 33,
+  Finalizer = 34,
+  SubgroupSize = 35,
+  SubgroupsPerWorkgroup = 36,
+  SubgroupsPerWorkgroupId = 37,
+  LocalSizeId = 38,
+  LocalSizeHintId = 39,
+  PostDepthCoverage = 4446,
+  DenormPreserve = 4459,
+  DenormFlushToZero = 4460,
+  SignedZeroInfNanPreserve = 4461,
+  RoundingModeRTE = 4462,
+  RoundingModeRTZ = 4463,
+  StencilRefReplacingEXT = 5027,
+  OutputLinesNV = 5269,
+  DerivativeGroupQuadsNV = 5289,
+  DerivativeGroupLinearNV = 5290,
+  OutputTrianglesNV = 5298,
+};
+StringRef getExecutionModeName(ExecutionMode e);
+
+enum class StorageClass : uint32_t {
+  UniformConstant = 0,
+  Input = 1,
+  Uniform = 2,
+  Output = 3,
+  Workgroup = 4,
+  CrossWorkgroup = 5,
+  Private = 6,
+  Function = 7,
+  Generic = 8,
+  PushConstant = 9,
+  AtomicCounter = 10,
+  Image = 11,
+  StorageBuffer = 12,
+  CallableDataNV = 5328,
+  IncomingCallableDataNV = 5329,
+  RayPayloadNV = 5338,
+  HitAttributeNV = 5339,
+  IncomingRayPayloadNV = 5342,
+  ShaderRecordBufferNV = 5343,
+  PhysicalStorageBufferEXT = 5349,
+};
+StringRef getStorageClassName(StorageClass e);
+
+enum class Dim : uint32_t {
+  DIM_1D = 0,
+  DIM_2D = 1,
+  DIM_3D = 2,
+  DIM_Cube = 3,
+  DIM_Rect = 4,
+  DIM_Buffer = 5,
+  DIM_SubpassData = 6,
+};
+StringRef getDimName(Dim e);
+
+enum class SamplerAddressingMode : uint32_t {
+  None = 0,
+  ClampToEdge = 1,
+  Clamp = 2,
+  Repeat = 3,
+  RepeatMirrored = 4,
+};
+StringRef getSamplerAddressingModeName(SamplerAddressingMode e);
+
+enum class SamplerFilterMode : uint32_t {
+  Nearest = 0,
+  Linear = 1,
+};
+StringRef getSamplerFilterModeName(SamplerFilterMode e);
+
+enum class ImageFormat : uint32_t {
+  Unknown = 0,
+  Rgba32f = 1,
+  Rgba16f = 2,
+  R32f = 3,
+  Rgba8 = 4,
+  Rgba8Snorm = 5,
+  Rg32f = 6,
+  Rg16f = 7,
+  R11fG11fB10f = 8,
+  R16f = 9,
+  Rgba16 = 10,
+  Rgb10A2 = 11,
+  Rg16 = 12,
+  Rg8 = 13,
+  R16 = 14,
+  R8 = 15,
+  Rgba16Snorm = 16,
+  Rg16Snorm = 17,
+  Rg8Snorm = 18,
+  R16Snorm = 19,
+  R8Snorm = 20,
+  Rgba32i = 21,
+  Rgba16i = 22,
+  Rgba8i = 23,
+  R32i = 24,
+  Rg32i = 25,
+  Rg16i = 26,
+  Rg8i = 27,
+  R16i = 28,
+  R8i = 29,
+  Rgba32ui = 30,
+  Rgba16ui = 31,
+  Rgba8ui = 32,
+  R32ui = 33,
+  Rgb10a2ui = 34,
+  Rg32ui = 35,
+  Rg16ui = 36,
+  Rg8ui = 37,
+  R16ui = 38,
+  R8ui = 39,
+};
+StringRef getImageFormatName(ImageFormat e);
+
+enum class ImageChannelOrder : uint32_t {
+  R = 0,
+  A = 1,
+  RG = 2,
+  RA = 3,
+  RGB = 4,
+  RGBA = 5,
+  BGRA = 6,
+  ARGB = 7,
+  Intensity = 8,
+  Luminance = 9,
+  Rx = 10,
+  RGx = 11,
+  RGBx = 12,
+  Depth = 13,
+  DepthStencil = 14,
+  sRGB = 15,
+  sRGBx = 16,
+  sRGBA = 17,
+  sBGRA = 18,
+  ABGR = 19,
+};
+StringRef getImageChannelOrderName(ImageChannelOrder e);
+
+enum class ImageChannelDataType : uint32_t {
+  SnormInt8 = 0,
+  SnormInt16 = 1,
+  UnormInt8 = 2,
+  UnormInt16 = 3,
+  UnormShort565 = 4,
+  UnormShort555 = 5,
+  UnormInt101010 = 6,
+  SignedInt8 = 7,
+  SignedInt16 = 8,
+  SignedInt32 = 9,
+  UnsignedInt8 = 10,
+  UnsignedInt16 = 11,
+  UnsignedInt32 = 12,
+  HalfFloat = 13,
+  Float = 14,
+  UnormInt24 = 15,
+  UnormInt101010_2 = 16,
+};
+StringRef getImageChannelDataTypeName(ImageChannelDataType e);
+
+enum class ImageOperand : uint32_t {
+  None = 0x0,
+  Bias = 0x1,
+  Lod = 0x2,
+  Grad = 0x4,
+  ConstOffset = 0x8,
+  Offset = 0x10,
+  ConstOffsets = 0x20,
+  Sample = 0x40,
+  MinLod = 0x80,
+  MakeTexelAvailableKHR = 0x100,
+  MakeTexelVisibleKHR = 0x200,
+  NonPrivateTexelKHR = 0x400,
+  VolatileTexelKHR = 0x800,
+  SignExtend = 0x1000,
+  ZeroExtend = 0x2000,
+};
+std::string getImageOperandName(uint32_t e);
+
+enum class FPFastMathMode : uint32_t {
+  None = 0x0,
+  NotNaN = 0x1,
+  NotInf = 0x2,
+  NSZ = 0x4,
+  AllowRecip = 0x8,
+  Fast = 0x10,
+};
+std::string getFPFastMathModeName(uint32_t e);
+
+enum class FPRoundingMode : uint32_t {
+  RTE = 0,
+  RTZ = 1,
+  RTP = 2,
+  RTN = 3,
+};
+StringRef getFPRoundingModeName(FPRoundingMode e);
+
+enum class LinkageType : uint32_t {
+  Export = 0,
+  Import = 1,
+};
+StringRef getLinkageTypeName(LinkageType e);
+
+enum class AccessQualifier : uint32_t {
+  ReadOnly = 0,
+  WriteOnly = 1,
+  ReadWrite = 2,
+};
+StringRef getAccessQualifierName(AccessQualifier e);
+
+enum class FunctionParameterAttribute : uint32_t {
+  Zext = 0,
+  Sext = 1,
+  ByVal = 2,
+  Sret = 3,
+  NoAlias = 4,
+  NoCapture = 5,
+  NoWrite = 6,
+  NoReadWrite = 7,
+};
+StringRef getFunctionParameterAttributeName(FunctionParameterAttribute e);
+
+enum class Decoration : uint32_t {
+  RelaxedPrecision = 0,
+  SpecId = 1,
+  Block = 2,
+  BufferBlock = 3,
+  RowMajor = 4,
+  ColMajor = 5,
+  ArrayStride = 6,
+  MatrixStride = 7,
+  GLSLShared = 8,
+  GLSLPacked = 9,
+  CPacked = 10,
+  BuiltIn = 11,
+  NoPerspective = 13,
+  Flat = 14,
+  Patch = 15,
+  Centroid = 16,
+  Sample = 17,
+  Invariant = 18,
+  Restrict = 19,
+  Aliased = 20,
+  Volatile = 21,
+  Constant = 22,
+  Coherent = 23,
+  NonWritable = 24,
+  NonReadable = 25,
+  Uniform = 26,
+  UniformId = 27,
+  SaturatedConversion = 28,
+  Stream = 29,
+  Location = 30,
+  Component = 31,
+  Index = 32,
+  Binding = 33,
+  DescriptorSet = 34,
+  Offset = 35,
+  XfbBuffer = 36,
+  XfbStride = 37,
+  FuncParamAttr = 38,
+  FPRoundingMode = 39,
+  FPFastMathMode = 40,
+  LinkageAttributes = 41,
+  NoContraction = 42,
+  InputAttachmentIndex = 43,
+  Alignment = 44,
+  MaxByteOffset = 45,
+  AlignmentId = 46,
+  MaxByteOffsetId = 47,
+  NoSignedWrap = 4469,
+  NoUnsignedWrap = 4470,
+  ExplicitInterpAMD = 4999,
+  OverrideCoverageNV = 5248,
+  PassthroughNV = 5250,
+  ViewportRelativeNV = 5252,
+  SecondaryViewportRelativeNV = 5256,
+  PerPrimitiveNV = 5271,
+  PerViewNV = 5272,
+  PerVertexNV = 5273,
+  NonUniformEXT = 5300,
+  CountBuffer = 5634,
+  UserSemantic = 5635,
+  RestrictPointerEXT = 5355,
+  AliasedPointerEXT = 5356,
+};
+StringRef getDecorationName(Decoration e);
+
+enum class BuiltIn : uint32_t {
+  Position = 0,
+  PointSize = 1,
+  ClipDistance = 3,
+  CullDistance = 4,
+  VertexId = 5,
+  InstanceId = 6,
+  PrimitiveId = 7,
+  InvocationId = 8,
+  Layer = 9,
+  ViewportIndex = 10,
+  TessLevelOuter = 11,
+  TessLevelInner = 12,
+  TessCoord = 13,
+  PatchVertices = 14,
+  FragCoord = 15,
+  PointCoord = 16,
+  FrontFacing = 17,
+  SampleId = 18,
+  SamplePosition = 19,
+  SampleMask = 20,
+  FragDepth = 22,
+  HelperInvocation = 23,
+  NumWorkgroups = 24,
+  WorkgroupSize = 25,
+  WorkgroupId = 26,
+  LocalInvocationId = 27,
+  GlobalInvocationId = 28,
+  LocalInvocationIndex = 29,
+  WorkDim = 30,
+  GlobalSize = 31,
+  EnqueuedWorkgroupSize = 32,
+  GlobalOffset = 33,
+  GlobalLinearId = 34,
+  SubgroupSize = 36,
+  SubgroupMaxSize = 37,
+  NumSubgroups = 38,
+  NumEnqueuedSubgroups = 39,
+  SubgroupId = 40,
+  SubgroupLocalInvocationId = 41,
+  VertexIndex = 42,
+  InstanceIndex = 43,
+  SubgroupEqMask = 4416,
+  SubgroupGeMask = 4417,
+  SubgroupGtMask = 4418,
+  SubgroupLeMask = 4419,
+  SubgroupLtMask = 4420,
+  BaseVertex = 4424,
+  BaseInstance = 4425,
+  DrawIndex = 4426,
+  DeviceIndex = 4438,
+  ViewIndex = 4440,
+  BaryCoordNoPerspAMD = 4492,
+  BaryCoordNoPerspCentroidAMD = 4493,
+  BaryCoordNoPerspSampleAMD = 4494,
+  BaryCoordSmoothAMD = 4495,
+  BaryCoordSmoothCentroid = 4496,
+  BaryCoordSmoothSample = 4497,
+  BaryCoordPullModel = 4498,
+  FragStencilRefEXT = 5014,
+  ViewportMaskNV = 5253,
+  SecondaryPositionNV = 5257,
+  SecondaryViewportMaskNV = 5258,
+  PositionPerViewNV = 5261,
+  ViewportMaskPerViewNV = 5262,
+  FullyCoveredEXT = 5264,
+  TaskCountNV = 5274,
+  PrimitiveCountNV = 5275,
+  PrimitiveIndicesNV = 5276,
+  ClipDistancePerViewNV = 5277,
+  CullDistancePerViewNV = 5278,
+  LayerPerViewNV = 5279,
+  MeshViewCountNV = 5280,
+  MeshViewIndices = 5281,
+  BaryCoordNV = 5286,
+  BaryCoordNoPerspNV = 5287,
+  FragSizeEXT = 5292,
+  FragInvocationCountEXT = 5293,
+  LaunchIdNV = 5319,
+  LaunchSizeNV = 5320,
+  WorldRayOriginNV = 5321,
+  WorldRayDirectionNV = 5322,
+  ObjectRayOriginNV = 5323,
+  ObjectRayDirectionNV = 5324,
+  RayTminNV = 5325,
+  RayTmaxNV = 5326,
+  InstanceCustomIndexNV = 5327,
+  ObjectToWorldNV = 5330,
+  WorldToObjectNV = 5331,
+  HitTNV = 5332,
+  HitKindNV = 5333,
+  IncomingRayFlagsNV = 5351,
+};
+StringRef getBuiltInName(BuiltIn e);
+
+enum class SelectionControl : uint32_t {
+  None = 0x0,
+  Flatten = 0x1,
+  DontFlatten = 0x2,
+};
+std::string getSelectionControlName(uint32_t e);
+
+enum class LoopControl : uint32_t {
+  None = 0x0,
+  Unroll = 0x1,
+  DontUnroll = 0x2,
+  DependencyInfinite = 0x4,
+  DependencyLength = 0x8,
+  MinIterations = 0x10,
+  MaxIterations = 0x20,
+  IterationMultiple = 0x40,
+  PeelCount = 0x80,
+  PartialCount = 0x100,
+};
+std::string getLoopControlName(uint32_t e);
+
+enum class FunctionControl : uint32_t {
+  None = 0x0,
+  Inline = 0x1,
+  DontInline = 0x2,
+  Pure = 0x4,
+  Const = 0x8,
+};
+std::string getFunctionControlName(uint32_t e);
+
+enum class MemorySemantics : uint32_t {
+  None = 0x0,
+  Acquire = 0x2,
+  Release = 0x4,
+  AcquireRelease = 0x8,
+  SequentiallyConsistent = 0x10,
+  UniformMemory = 0x40,
+  SubgroupMemory = 0x80,
+  WorkgroupMemory = 0x100,
+  CrossWorkgroupMemory = 0x200,
+  AtomicCounterMemory = 0x400,
+  ImageMemory = 0x800,
+  OutputMemoryKHR = 0x1000,
+  MakeAvailableKHR = 0x2000,
+  MakeVisibleKHR = 0x4000,
+};
+std::string getMemorySemanticsName(uint32_t e);
+
+enum class MemoryOperand : uint32_t {
+  None = 0x0,
+  Volatile = 0x1,
+  Aligned = 0x2,
+  Nontemporal = 0x4,
+  MakePointerAvailableKHR = 0x8,
+  MakePointerVisibleKHR = 0x10,
+  NonPrivatePointerKHR = 0x20,
+};
+std::string getMemoryOperandName(uint32_t e);
+
+enum class Scope : uint32_t {
+  CrossDevice = 0,
+  Device = 1,
+  Workgroup = 2,
+  Subgroup = 3,
+  Invocation = 4,
+  QueueFamilyKHR = 5,
+};
+StringRef getScopeName(Scope e);
+
+enum class GroupOperation : uint32_t {
+  Reduce = 0,
+  InclusiveScan = 1,
+  ExclusiveScan = 2,
+  ClusteredReduce = 3,
+  PartitionedReduceNV = 6,
+  PartitionedInclusiveScanNV = 7,
+  PartitionedExclusiveScanNV = 8,
+};
+StringRef getGroupOperationName(GroupOperation e);
+
+enum class KernelEnqueueFlags : uint32_t {
+  NoWait = 0,
+  WaitKernel = 1,
+  WaitWorkGroup = 2,
+};
+StringRef getKernelEnqueueFlagsName(KernelEnqueueFlags e);
+
+enum class KernelProfilingInfo : uint32_t {
+  None = 0x0,
+  CmdExecTime = 0x1,
+};
+StringRef getKernelProfilingInfoName(KernelProfilingInfo e);
+} // namespace SPIRV
+} // namespace llvm
+
+// Return a string representation of the operands from StartIndex onwards.
+// Templated to allow both MachineInstr and MCInst to use the same logic.
+template <typename InstType>
+std::string getSPIRVStringOperand(const InstType &MI, unsigned StartIndex) {
+  std::string s; // Iteratively append to this string.
+
+  const unsigned NumOps = MI.getNumOperands();
+  bool IsFinished = false;
+  for (unsigned i = StartIndex; i < NumOps && !IsFinished; ++i) {
+    const auto &Op = MI.getOperand(i);
+    if (!Op.isImm()) // Stop if we hit a register operand.
+      break;
+    assert((Op.getImm() >> 32) == 0 && "Imm operand should be i32 word");
+    const uint32_t Imm = Op.getImm(); // Each i32 word is up to 4 characters.
+    for (unsigned ShiftAmount = 0; ShiftAmount < 32; ShiftAmount += 8) {
+      char c = (Imm >> ShiftAmount) & 0xff;
+      if (c == 0) { // Stop if we hit a null-terminator character.
+        IsFinished = true;
+        break;
+      } else {
+        s += c; // Otherwise, append the character to the result string.
+      }
+    }
+  }
+  return s;
+}
+
+#endif // LLVM_LIB_TARGET_SPIRV_MCTARGETDESC_SPIRVBASEINFO_H
diff --git a/llvm/lib/Target/SPIRV/SPIRV.h b/llvm/lib/Target/SPIRV/SPIRV.h
index 30b6995d4f702..31e2c39e7b6f6 100644
--- a/llvm/lib/Target/SPIRV/SPIRV.h
+++ b/llvm/lib/Target/SPIRV/SPIRV.h
@@ -16,6 +16,13 @@
 namespace llvm {
 class SPIRVTargetMachine;
 class SPIRVSubtarget;
+class InstructionSelector;
+class RegisterBankInfo;
+
+InstructionSelector *
+createSPIRVInstructionSelector(const SPIRVTargetMachine &TM,
+                               const SPIRVSubtarget &Subtarget,
+                               const RegisterBankInfo &RBI);
 } // namespace llvm
 
 #endif // LLVM_LIB_TARGET_SPIRV_SPIRV_H
diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index 3f57ca0f02c75..df07a126eeead 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -12,17 +12,21 @@
 //===----------------------------------------------------------------------===//
 
 #include "SPIRVCallLowering.h"
+#include "MCTargetDesc/SPIRVBaseInfo.h"
 #include "SPIRV.h"
+#include "SPIRVGlobalRegistry.h"
 #include "SPIRVISelLowering.h"
 #include "SPIRVRegisterInfo.h"
 #include "SPIRVSubtarget.h"
+#include "SPIRVUtils.h"
 #include "llvm/CodeGen/FunctionLoweringInfo.h"
-#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
 
 using namespace llvm;
 
-SPIRVCallLowering::SPIRVCallLowering(const SPIRVTargetLowering &TLI)
-    : CallLowering(&TLI) {}
+SPIRVCallLowering::SPIRVCallLowering(const SPIRVTargetLowering &TLI,
+                                     const SPIRVSubtarget &ST,
+                                     SPIRVGlobalRegistry *GR)
+    : CallLowering(&TLI), ST(ST), GR(GR) {}
 
 bool SPIRVCallLowering::lowerReturn(MachineIRBuilder &MIRBuilder,
                                     const Value *Val, ArrayRef<Register> VRegs,
@@ -32,19 +36,39 @@ bool SPIRVCallLowering::lowerReturn(MachineIRBuilder &MIRBuilder,
   // TODO: handle the case of multiple registers.
   if (VRegs.size() > 1)
     return false;
-  if (Val) {
-    MIRBuilder.buildInstr(SPIRV::OpReturnValue).addUse(VRegs[0]);
-    return true;
-  }
+  if (Val)
+    return MIRBuilder.buildInstr(SPIRV::OpReturnValue)
+        .addUse(VRegs[0])
+        .constrainAllUses(MIRBuilder.getTII(), *ST.getRegisterInfo(),
+                          *ST.getRegBankInfo());
   MIRBuilder.buildInstr(SPIRV::OpReturn);
   return true;
 }
 
+// Based on the LLVM function attributes, get a SPIR-V FunctionControl.
+static uint32_t getFunctionControl(const Function &F) {
+  uint32_t FuncControl = static_cast<uint32_t>(SPIRV::FunctionControl::None);
+  if (F.hasFnAttribute(Attribute::AttrKind::AlwaysInline)) {
+    FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Inline);
+  }
+  if (F.hasFnAttribute(Attribute::AttrKind::ReadNone)) {
+    FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Pure);
+  }
+  if (F.hasFnAttribute(Attribute::AttrKind::ReadOnly)) {
+    FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Const);
+  }
+  if (F.hasFnAttribute(Attribute::AttrKind::NoInline)) {
+    FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::DontInline);
+  }
+  return FuncControl;
+}
+
 bool SPIRVCallLowering::lowerFormalArguments(
     MachineIRBuilder &MIRBuilder, const Function &F,
    ArrayRef<ArrayRef<Register>> VRegs, FunctionLoweringInfo &FLI) const {
-  auto MRI = MIRBuilder.getMRI();
+  assert(GR && "Must initialize the SPIRV type registry before lowering args.");
+
   // Assign types and names to all args, and store their types for later.
   SmallVector<Register, 4> ArgTypeVRegs;
   if (VRegs.size() > 0) {
@@ -54,21 +78,55 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
       // TODO: handle the case of multiple registers.
       if (VRegs[i].size() > 1)
        return false;
-      ArgTypeVRegs.push_back(
-          MRI->createGenericVirtualRegister(LLT::scalar(32)));
+      auto *SpirvTy =
+          GR->assignTypeToVReg(Arg.getType(), VRegs[i][0], MIRBuilder);
+      ArgTypeVRegs.push_back(GR->getSPIRVTypeID(SpirvTy));
+
+      if (Arg.hasName())
+        buildOpName(VRegs[i][0], Arg.getName(), MIRBuilder);
+      if (Arg.getType()->isPointerTy()) {
+        auto DerefBytes = static_cast<unsigned>(Arg.getDereferenceableBytes());
+        if (DerefBytes != 0)
+          buildOpDecorate(VRegs[i][0], MIRBuilder,
+                          SPIRV::Decoration::MaxByteOffset, {DerefBytes});
+      }
+      if (Arg.hasAttribute(Attribute::Alignment)) {
+        buildOpDecorate(VRegs[i][0], MIRBuilder, SPIRV::Decoration::Alignment,
+                        {static_cast<unsigned>(Arg.getParamAlignment())});
+      }
+      if (Arg.hasAttribute(Attribute::ReadOnly)) {
+        auto Attr =
+            static_cast<unsigned>(SPIRV::FunctionParameterAttribute::NoWrite);
+        buildOpDecorate(VRegs[i][0], MIRBuilder,
+                        SPIRV::Decoration::FuncParamAttr, {Attr});
+      }
+      if (Arg.hasAttribute(Attribute::ZExt)) {
+        auto Attr =
+            static_cast<unsigned>(SPIRV::FunctionParameterAttribute::Zext);
+        buildOpDecorate(VRegs[i][0], MIRBuilder,
+                        SPIRV::Decoration::FuncParamAttr, {Attr});
+      }
       ++i;
     }
   }
 
   // Generate a SPIR-V type for the function.
+  auto MRI = MIRBuilder.getMRI();
   Register FuncVReg = MRI->createGenericVirtualRegister(LLT::scalar(32));
   MRI->setRegClass(FuncVReg, &SPIRV::IDRegClass);
+  auto *FTy = F.getFunctionType();
+  auto FuncTy = GR->assignTypeToVReg(FTy, FuncVReg, MIRBuilder);
+
+  // Build the OpTypeFunction declaring it.
+  Register ReturnTypeID = FuncTy->getOperand(1).getReg();
+  uint32_t FuncControl = getFunctionControl(F);
+
   MIRBuilder.buildInstr(SPIRV::OpFunction)
       .addDef(FuncVReg)
-      .addUse(MRI->createGenericVirtualRegister(LLT::scalar(32)))
-      .addImm(0)
-      .addUse(MRI->createGenericVirtualRegister(LLT::scalar(32)));
+      .addUse(ReturnTypeID)
+      .addImm(FuncControl)
+      .addUse(GR->getSPIRVTypeID(FuncTy));
 
   // Add OpFunctionParameters.
   const unsigned NumArgs = ArgTypeVRegs.size();
@@ -79,6 +137,24 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
         .addDef(VRegs[i][0])
         .addUse(ArgTypeVRegs[i]);
   }
+  // Name the function.
+  if (F.hasName())
+    buildOpName(FuncVReg, F.getName(), MIRBuilder);
+
+  // Handle entry points and function linkage.
+ if (F.getCallingConv() == CallingConv::SPIR_KERNEL) { + auto MIB = MIRBuilder.buildInstr(SPIRV::OpEntryPoint) + .addImm(static_cast(SPIRV::ExecutionModel::Kernel)) + .addUse(FuncVReg); + addStringImm(F.getName(), MIB); + } else if (F.getLinkage() == GlobalValue::LinkageTypes::ExternalLinkage || + F.getLinkage() == GlobalValue::LinkOnceODRLinkage) { + auto LnkTy = F.isDeclaration() ? SPIRV::LinkageType::Import + : SPIRV::LinkageType::Export; + buildOpDecorate(FuncVReg, MIRBuilder, SPIRV::Decoration::LinkageAttributes, + {static_cast(LnkTy)}, F.getGlobalIdentifier()); + } + return true; } @@ -91,15 +167,49 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder, Register ResVReg = Info.OrigRet.Regs.empty() ? Register(0) : Info.OrigRet.Regs[0]; + // Emit a regular OpFunctionCall. If it's an externally declared function, + // be sure to emit its type and function declaration here. It will be + // hoisted globally later. + if (Info.Callee.isGlobal()) { + auto *CF = dyn_cast_or_null(Info.Callee.getGlobal()); + // TODO: support constexpr casts and indirect calls. + if (CF == nullptr) + return false; + if (CF->isDeclaration()) { + // Emit the type info and forward function declaration to the first MBB + // to ensure VReg definition dependencies are valid across all MBBs. + MachineBasicBlock::iterator OldII = MIRBuilder.getInsertPt(); + MachineBasicBlock &OldBB = MIRBuilder.getMBB(); + MachineBasicBlock &FirstBB = *MIRBuilder.getMF().getBlockNumbered(0); + MIRBuilder.setInsertPt(FirstBB, FirstBB.instr_end()); + + SmallVector, 8> VRegArgs; + SmallVector, 8> ToInsert; + for (const Argument &Arg : CF->args()) { + if (MIRBuilder.getDataLayout().getTypeStoreSize(Arg.getType()).isZero()) + continue; // Don't handle zero sized types. + ToInsert.push_back({MIRBuilder.getMRI()->createGenericVirtualRegister( + LLT::scalar(32))}); + VRegArgs.push_back(ToInsert.back()); + } + // TODO: Reuse FunctionLoweringInfo. + FunctionLoweringInfo FuncInfo; + lowerFormalArguments(MIRBuilder, *CF, VRegArgs, FuncInfo); + MIRBuilder.setInsertPt(OldBB, OldII); + } + } + // Make sure there's a valid return reg, even for functions returning void. if (!ResVReg.isValid()) { ResVReg = MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::IDRegClass); } + SPIRVType *RetType = + GR->assignTypeToVReg(Info.OrigRet.Ty, ResVReg, MIRBuilder); + // Emit the OpFunctionCall and its args. auto MIB = MIRBuilder.buildInstr(SPIRV::OpFunctionCall) .addDef(ResVReg) - .addUse(MIRBuilder.getMRI()->createVirtualRegister( - &SPIRV::IDRegClass)) + .addUse(GR->getSPIRVTypeID(RetType)) .add(Info.Callee); for (const auto &Arg : Info.OrigArgs) { @@ -108,5 +218,6 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder, return false; MIB.addUse(Arg.Regs[0]); } - return true; + return MIB.constrainAllUses(MIRBuilder.getTII(), *ST.getRegisterInfo(), + *ST.getRegBankInfo()); } diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.h b/llvm/lib/Target/SPIRV/SPIRVCallLowering.h index 702198e3225a3..c179bb35154b8 100644 --- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.h +++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.h @@ -17,12 +17,19 @@ namespace llvm { +class SPIRVGlobalRegistry; +class SPIRVSubtarget; class SPIRVTargetLowering; class SPIRVCallLowering : public CallLowering { private: + const SPIRVSubtarget &ST; + // Used to create and assign function, argument, and return type information. 
+ SPIRVGlobalRegistry *GR; + public: - SPIRVCallLowering(const SPIRVTargetLowering &TLI); + SPIRVCallLowering(const SPIRVTargetLowering &TLI, const SPIRVSubtarget &ST, + SPIRVGlobalRegistry *GR); // Built OpReturn or OpReturnValue. bool lowerReturn(MachineIRBuilder &MIRBuiler, const Value *Val, diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp new file mode 100644 index 0000000000000..07633ce265a3d --- /dev/null +++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp @@ -0,0 +1,453 @@ +//===-- SPIRVGlobalRegistry.cpp - SPIR-V Global Registry --------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains the implementation of the SPIRVGlobalRegistry class, +// which is used to maintain rich type information required for SPIR-V even +// after lowering from LLVM IR to GMIR. It can convert an llvm::Type into +// an OpTypeXXX instruction, and map it to a virtual register. Also it builds +// and supports consistency of constants and global variables. +// +//===----------------------------------------------------------------------===// + +#include "SPIRVGlobalRegistry.h" +#include "SPIRV.h" +#include "SPIRVSubtarget.h" +#include "SPIRVTargetMachine.h" +#include "SPIRVUtils.h" + +using namespace llvm; +SPIRVGlobalRegistry::SPIRVGlobalRegistry(unsigned PointerSize) + : PointerSize(PointerSize) {} + +SPIRVType *SPIRVGlobalRegistry::assignTypeToVReg( + const Type *Type, Register VReg, MachineIRBuilder &MIRBuilder, + SPIRV::AccessQualifier AccessQual, bool EmitIR) { + + SPIRVType *SpirvType = + getOrCreateSPIRVType(Type, MIRBuilder, AccessQual, EmitIR); + assignSPIRVTypeToVReg(SpirvType, VReg, MIRBuilder); + return SpirvType; +} + +void SPIRVGlobalRegistry::assignSPIRVTypeToVReg(SPIRVType *SpirvType, + Register VReg, + MachineIRBuilder &MIRBuilder) { + VRegToTypeMap[&MIRBuilder.getMF()][VReg] = SpirvType; +} + +static Register createTypeVReg(MachineIRBuilder &MIRBuilder) { + auto &MRI = MIRBuilder.getMF().getRegInfo(); + auto Res = MRI.createGenericVirtualRegister(LLT::scalar(32)); + MRI.setRegClass(Res, &SPIRV::TYPERegClass); + return Res; +} + +static Register createTypeVReg(MachineRegisterInfo &MRI) { + auto Res = MRI.createGenericVirtualRegister(LLT::scalar(32)); + MRI.setRegClass(Res, &SPIRV::TYPERegClass); + return Res; +} + +SPIRVType *SPIRVGlobalRegistry::getOpTypeBool(MachineIRBuilder &MIRBuilder) { + return MIRBuilder.buildInstr(SPIRV::OpTypeBool) + .addDef(createTypeVReg(MIRBuilder)); +} + +SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(uint32_t Width, + MachineIRBuilder &MIRBuilder, + bool IsSigned) { + auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeInt) + .addDef(createTypeVReg(MIRBuilder)) + .addImm(Width) + .addImm(IsSigned ? 
1 : 0); + return MIB; +} + +SPIRVType *SPIRVGlobalRegistry::getOpTypeFloat(uint32_t Width, + MachineIRBuilder &MIRBuilder) { + auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeFloat) + .addDef(createTypeVReg(MIRBuilder)) + .addImm(Width); + return MIB; +} + +SPIRVType *SPIRVGlobalRegistry::getOpTypeVoid(MachineIRBuilder &MIRBuilder) { + return MIRBuilder.buildInstr(SPIRV::OpTypeVoid) + .addDef(createTypeVReg(MIRBuilder)); +} + +SPIRVType *SPIRVGlobalRegistry::getOpTypeVector(uint32_t NumElems, + SPIRVType *ElemType, + MachineIRBuilder &MIRBuilder) { + auto EleOpc = ElemType->getOpcode(); + assert((EleOpc == SPIRV::OpTypeInt || EleOpc == SPIRV::OpTypeFloat || + EleOpc == SPIRV::OpTypeBool) && + "Invalid vector element type"); + + auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeVector) + .addDef(createTypeVReg(MIRBuilder)) + .addUse(getSPIRVTypeID(ElemType)) + .addImm(NumElems); + return MIB; +} + +Register SPIRVGlobalRegistry::buildConstantInt(uint64_t Val, + MachineIRBuilder &MIRBuilder, + SPIRVType *SpvType, + bool EmitIR) { + auto &MF = MIRBuilder.getMF(); + Register Res; + const IntegerType *LLVMIntTy; + if (SpvType) + LLVMIntTy = cast(getTypeForSPIRVType(SpvType)); + else + LLVMIntTy = IntegerType::getInt32Ty(MF.getFunction().getContext()); + // Find a constant in DT or build a new one. + const auto ConstInt = + ConstantInt::get(const_cast(LLVMIntTy), Val); + unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32; + Res = MF.getRegInfo().createGenericVirtualRegister(LLT::scalar(BitWidth)); + assignTypeToVReg(LLVMIntTy, Res, MIRBuilder); + if (EmitIR) + MIRBuilder.buildConstant(Res, *ConstInt); + else + MIRBuilder.buildInstr(SPIRV::OpConstantI) + .addDef(Res) + .addImm(ConstInt->getSExtValue()); + return Res; +} + +Register SPIRVGlobalRegistry::buildConstantFP(APFloat Val, + MachineIRBuilder &MIRBuilder, + SPIRVType *SpvType) { + auto &MF = MIRBuilder.getMF(); + Register Res; + const Type *LLVMFPTy; + if (SpvType) { + LLVMFPTy = getTypeForSPIRVType(SpvType); + assert(LLVMFPTy->isFloatingPointTy()); + } else { + LLVMFPTy = IntegerType::getFloatTy(MF.getFunction().getContext()); + } + // Find a constant in DT or build a new one. + const auto ConstFP = ConstantFP::get(LLVMFPTy->getContext(), Val); + unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32; + Res = MF.getRegInfo().createGenericVirtualRegister(LLT::scalar(BitWidth)); + assignTypeToVReg(LLVMFPTy, Res, MIRBuilder); + MIRBuilder.buildFConstant(Res, *ConstFP); + return Res; +} + +Register SPIRVGlobalRegistry::buildGlobalVariable( + Register ResVReg, SPIRVType *BaseType, StringRef Name, + const GlobalValue *GV, SPIRV::StorageClass Storage, + const MachineInstr *Init, bool IsConst, bool HasLinkageTy, + SPIRV::LinkageType LinkageType, MachineIRBuilder &MIRBuilder, + bool IsInstSelector) { + const GlobalVariable *GVar = nullptr; + if (GV) + GVar = cast(GV); + else { + // If GV is not passed explicitly, use the name to find or construct + // the global variable. + Module *M = MIRBuilder.getMF().getFunction().getParent(); + GVar = M->getGlobalVariable(Name); + if (GVar == nullptr) { + const Type *Ty = getTypeForSPIRVType(BaseType); // TODO: check type. 
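+      // The module has no global with this name yet, so synthesize an
+      // external declaration purely to have a GlobalValue to attach the name
+      // and decorations to.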
+ GVar = new GlobalVariable(*M, const_cast(Ty), false, + GlobalValue::ExternalLinkage, nullptr, + Twine(Name)); + } + GV = GVar; + } + Register Reg; + auto MIB = MIRBuilder.buildInstr(SPIRV::OpVariable) + .addDef(ResVReg) + .addUse(getSPIRVTypeID(BaseType)) + .addImm(static_cast(Storage)); + + if (Init != 0) { + MIB.addUse(Init->getOperand(0).getReg()); + } + + // ISel may introduce a new register on this step, so we need to add it to + // DT and correct its type avoiding fails on the next stage. + if (IsInstSelector) { + const auto &Subtarget = CurMF->getSubtarget(); + constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(), + *Subtarget.getRegisterInfo(), + *Subtarget.getRegBankInfo()); + } + Reg = MIB->getOperand(0).getReg(); + + // Set to Reg the same type as ResVReg has. + auto MRI = MIRBuilder.getMRI(); + assert(MRI->getType(ResVReg).isPointer() && "Pointer type is expected"); + if (Reg != ResVReg) { + LLT RegLLTy = LLT::pointer(MRI->getType(ResVReg).getAddressSpace(), 32); + MRI->setType(Reg, RegLLTy); + assignSPIRVTypeToVReg(BaseType, Reg, MIRBuilder); + } + + // If it's a global variable with name, output OpName for it. + if (GVar && GVar->hasName()) + buildOpName(Reg, GVar->getName(), MIRBuilder); + + // Output decorations for the GV. + // TODO: maybe move to GenerateDecorations pass. + if (IsConst) + buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::Constant, {}); + + if (GVar && GVar->getAlign().valueOrOne().value() != 1) + buildOpDecorate( + Reg, MIRBuilder, SPIRV::Decoration::Alignment, + {static_cast(GVar->getAlign().valueOrOne().value())}); + + if (HasLinkageTy) + buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::LinkageAttributes, + {static_cast(LinkageType)}, Name); + return Reg; +} + +SPIRVType *SPIRVGlobalRegistry::getOpTypeArray(uint32_t NumElems, + SPIRVType *ElemType, + MachineIRBuilder &MIRBuilder, + bool EmitIR) { + assert((ElemType->getOpcode() != SPIRV::OpTypeVoid) && + "Invalid array element type"); + Register NumElementsVReg = + buildConstantInt(NumElems, MIRBuilder, nullptr, EmitIR); + auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeArray) + .addDef(createTypeVReg(MIRBuilder)) + .addUse(getSPIRVTypeID(ElemType)) + .addUse(NumElementsVReg); + return MIB; +} + +SPIRVType *SPIRVGlobalRegistry::getOpTypePointer(SPIRV::StorageClass SC, + SPIRVType *ElemType, + MachineIRBuilder &MIRBuilder) { + auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypePointer) + .addDef(createTypeVReg(MIRBuilder)) + .addImm(static_cast(SC)) + .addUse(getSPIRVTypeID(ElemType)); + return MIB; +} + +SPIRVType *SPIRVGlobalRegistry::getOpTypeFunction( + SPIRVType *RetType, const SmallVectorImpl &ArgTypes, + MachineIRBuilder &MIRBuilder) { + auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeFunction) + .addDef(createTypeVReg(MIRBuilder)) + .addUse(getSPIRVTypeID(RetType)); + for (const SPIRVType *ArgType : ArgTypes) + MIB.addUse(getSPIRVTypeID(ArgType)); + return MIB; +} + +SPIRVType *SPIRVGlobalRegistry::createSPIRVType(const Type *Ty, + MachineIRBuilder &MIRBuilder, + SPIRV::AccessQualifier AccQual, + bool EmitIR) { + if (auto IType = dyn_cast(Ty)) { + const unsigned Width = IType->getBitWidth(); + return Width == 1 ? 
getOpTypeBool(MIRBuilder) + : getOpTypeInt(Width, MIRBuilder, false); + } + if (Ty->isFloatingPointTy()) + return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder); + if (Ty->isVoidTy()) + return getOpTypeVoid(MIRBuilder); + if (Ty->isVectorTy()) { + auto El = getOrCreateSPIRVType(cast(Ty)->getElementType(), + MIRBuilder); + return getOpTypeVector(cast(Ty)->getNumElements(), El, + MIRBuilder); + } + if (Ty->isArrayTy()) { + auto *El = getOrCreateSPIRVType(Ty->getArrayElementType(), MIRBuilder); + return getOpTypeArray(Ty->getArrayNumElements(), El, MIRBuilder, EmitIR); + } + assert(!isa(Ty) && "Unsupported StructType"); + if (auto FType = dyn_cast(Ty)) { + SPIRVType *RetTy = getOrCreateSPIRVType(FType->getReturnType(), MIRBuilder); + SmallVector ParamTypes; + for (const auto &t : FType->params()) { + ParamTypes.push_back(getOrCreateSPIRVType(t, MIRBuilder)); + } + return getOpTypeFunction(RetTy, ParamTypes, MIRBuilder); + } + if (auto PType = dyn_cast(Ty)) { + Type *ElemType = PType->getPointerElementType(); + + // Some OpenCL and SPIRV builtins like image2d_t are passed in as pointers, + // but should be treated as custom types like OpTypeImage. + assert(!isa(ElemType) && "Unsupported StructType pointer"); + + // Otherwise, treat it as a regular pointer type. + auto SC = addressSpaceToStorageClass(PType->getAddressSpace()); + SPIRVType *SpvElementType = getOrCreateSPIRVType( + ElemType, MIRBuilder, SPIRV::AccessQualifier::ReadWrite, EmitIR); + return getOpTypePointer(SC, SpvElementType, MIRBuilder); + } + llvm_unreachable("Unable to convert LLVM type to SPIRVType"); +} + +SPIRVType *SPIRVGlobalRegistry::getSPIRVTypeForVReg(Register VReg) const { + auto t = VRegToTypeMap.find(CurMF); + if (t != VRegToTypeMap.end()) { + auto tt = t->second.find(VReg); + if (tt != t->second.end()) + return tt->second; + } + return nullptr; +} + +SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType( + const Type *Type, MachineIRBuilder &MIRBuilder, + SPIRV::AccessQualifier AccessQual, bool EmitIR) { + Register Reg; + SPIRVType *SpirvType = createSPIRVType(Type, MIRBuilder, AccessQual, EmitIR); + VRegToTypeMap[&MIRBuilder.getMF()][getSPIRVTypeID(SpirvType)] = SpirvType; + SPIRVToLLVMType[SpirvType] = Type; + return SpirvType; +} + +bool SPIRVGlobalRegistry::isScalarOfType(Register VReg, + unsigned TypeOpcode) const { + SPIRVType *Type = getSPIRVTypeForVReg(VReg); + assert(Type && "isScalarOfType VReg has no type assigned"); + return Type->getOpcode() == TypeOpcode; +} + +bool SPIRVGlobalRegistry::isScalarOrVectorOfType(Register VReg, + unsigned TypeOpcode) const { + SPIRVType *Type = getSPIRVTypeForVReg(VReg); + assert(Type && "isScalarOrVectorOfType VReg has no type assigned"); + if (Type->getOpcode() == TypeOpcode) + return true; + if (Type->getOpcode() == SPIRV::OpTypeVector) { + Register ScalarTypeVReg = Type->getOperand(1).getReg(); + SPIRVType *ScalarType = getSPIRVTypeForVReg(ScalarTypeVReg); + return ScalarType->getOpcode() == TypeOpcode; + } + return false; +} + +unsigned +SPIRVGlobalRegistry::getScalarOrVectorBitWidth(const SPIRVType *Type) const { + assert(Type && "Invalid Type pointer"); + if (Type->getOpcode() == SPIRV::OpTypeVector) { + auto EleTypeReg = Type->getOperand(1).getReg(); + Type = getSPIRVTypeForVReg(EleTypeReg); + } + if (Type->getOpcode() == SPIRV::OpTypeInt || + Type->getOpcode() == SPIRV::OpTypeFloat) + return Type->getOperand(1).getImm(); + if (Type->getOpcode() == SPIRV::OpTypeBool) + return 1; + llvm_unreachable("Attempting to get bit width of non-integer/float type."); 
+} + +bool SPIRVGlobalRegistry::isScalarOrVectorSigned(const SPIRVType *Type) const { + assert(Type && "Invalid Type pointer"); + if (Type->getOpcode() == SPIRV::OpTypeVector) { + auto EleTypeReg = Type->getOperand(1).getReg(); + Type = getSPIRVTypeForVReg(EleTypeReg); + } + if (Type->getOpcode() == SPIRV::OpTypeInt) + return Type->getOperand(2).getImm() != 0; + llvm_unreachable("Attempting to get sign of non-integer type."); +} + +SPIRV::StorageClass +SPIRVGlobalRegistry::getPointerStorageClass(Register VReg) const { + SPIRVType *Type = getSPIRVTypeForVReg(VReg); + assert(Type && Type->getOpcode() == SPIRV::OpTypePointer && + Type->getOperand(1).isImm() && "Pointer type is expected"); + return static_cast(Type->getOperand(1).getImm()); +} + +SPIRVType * +SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(unsigned BitWidth, + MachineIRBuilder &MIRBuilder) { + return getOrCreateSPIRVType( + IntegerType::get(MIRBuilder.getMF().getFunction().getContext(), BitWidth), + MIRBuilder); +} + +SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(Type *LLVMTy, + MachineInstrBuilder MIB) { + SPIRVType *SpirvType = MIB; + VRegToTypeMap[CurMF][getSPIRVTypeID(SpirvType)] = SpirvType; + SPIRVToLLVMType[SpirvType] = LLVMTy; + return SpirvType; +} + +SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType( + unsigned BitWidth, MachineInstr &I, const SPIRVInstrInfo &TII) { + Type *LLVMTy = IntegerType::get(CurMF->getFunction().getContext(), BitWidth); + MachineBasicBlock &BB = *I.getParent(); + auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeInt)) + .addDef(createTypeVReg(CurMF->getRegInfo())) + .addImm(BitWidth) + .addImm(0); + return restOfCreateSPIRVType(LLVMTy, MIB); +} + +SPIRVType * +SPIRVGlobalRegistry::getOrCreateSPIRVBoolType(MachineIRBuilder &MIRBuilder) { + return getOrCreateSPIRVType( + IntegerType::get(MIRBuilder.getMF().getFunction().getContext(), 1), + MIRBuilder); +} + +SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVVectorType( + SPIRVType *BaseType, unsigned NumElements, MachineIRBuilder &MIRBuilder) { + return getOrCreateSPIRVType( + FixedVectorType::get(const_cast(getTypeForSPIRVType(BaseType)), + NumElements), + MIRBuilder); +} + +SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVVectorType( + SPIRVType *BaseType, unsigned NumElements, MachineInstr &I, + const SPIRVInstrInfo &TII) { + Type *LLVMTy = FixedVectorType::get( + const_cast(getTypeForSPIRVType(BaseType)), NumElements); + MachineBasicBlock &BB = *I.getParent(); + auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeVector)) + .addDef(createTypeVReg(CurMF->getRegInfo())) + .addUse(getSPIRVTypeID(BaseType)) + .addImm(NumElements); + return restOfCreateSPIRVType(LLVMTy, MIB); +} + +SPIRVType * +SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(SPIRVType *BaseType, + MachineIRBuilder &MIRBuilder, + SPIRV::StorageClass SClass) { + return getOrCreateSPIRVType( + PointerType::get(const_cast(getTypeForSPIRVType(BaseType)), + storageClassToAddressSpace(SClass)), + MIRBuilder); +} + +SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType( + SPIRVType *BaseType, MachineInstr &I, const SPIRVInstrInfo &TII, + SPIRV::StorageClass SC) { + Type *LLVMTy = + PointerType::get(const_cast(getTypeForSPIRVType(BaseType)), + storageClassToAddressSpace(SC)); + MachineBasicBlock &BB = *I.getParent(); + auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypePointer)) + .addDef(createTypeVReg(CurMF->getRegInfo())) + .addImm(static_cast(SC)) + .addUse(getSPIRVTypeID(BaseType)); + return restOfCreateSPIRVType(LLVMTy, 
MIB); +} diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h new file mode 100644 index 0000000000000..b6727a6dd73a1 --- /dev/null +++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h @@ -0,0 +1,174 @@ +//===-- SPIRVGlobalRegistry.h - SPIR-V Global Registry ----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// SPIRVGlobalRegistry is used to maintain rich type information required for +// SPIR-V even after lowering from LLVM IR to GMIR. It can convert an llvm::Type +// into an OpTypeXXX instruction, and map it to a virtual register. Also it +// builds and supports consistency of constants and global variables. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_LIB_TARGET_SPIRV_SPIRVTYPEMANAGER_H +#define LLVM_LIB_TARGET_SPIRV_SPIRVTYPEMANAGER_H + +#include "MCTargetDesc/SPIRVBaseInfo.h" +#include "SPIRVInstrInfo.h" +#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h" + +namespace llvm { +using SPIRVType = const MachineInstr; + +class SPIRVGlobalRegistry { + // Registers holding values which have types associated with them. + // Initialized upon VReg definition in IRTranslator. + // Do not confuse this with DuplicatesTracker as DT maps Type* to + // where Reg = OpType... + // while VRegToTypeMap tracks SPIR-V type assigned to other regs (i.e. not + // type-declaring ones) + DenseMap> VRegToTypeMap; + + DenseMap SPIRVToLLVMType; + + // Number of bits pointers and size_t integers require. + const unsigned PointerSize; + + // Add a new OpTypeXXX instruction without checking for duplicates. + SPIRVType * + createSPIRVType(const Type *Type, MachineIRBuilder &MIRBuilder, + SPIRV::AccessQualifier AQ = SPIRV::AccessQualifier::ReadWrite, + bool EmitIR = true); + +public: + SPIRVGlobalRegistry(unsigned PointerSize); + + MachineFunction *CurMF; + + // Get or create a SPIR-V type corresponding the given LLVM IR type, + // and map it to the given VReg by creating an ASSIGN_TYPE instruction. + SPIRVType *assignTypeToVReg( + const Type *Type, Register VReg, MachineIRBuilder &MIRBuilder, + SPIRV::AccessQualifier AQ = SPIRV::AccessQualifier::ReadWrite, + bool EmitIR = true); + + // In cases where the SPIR-V type is already known, this function can be + // used to map it to the given VReg via an ASSIGN_TYPE instruction. + void assignSPIRVTypeToVReg(SPIRVType *Type, Register VReg, + MachineIRBuilder &MIRBuilder); + + // Either generate a new OpTypeXXX instruction or return an existing one + // corresponding to the given LLVM IR type. + // EmitIR controls if we emit GMIR or SPV constants (e.g. for array sizes) + // because this method may be called from InstructionSelector and we don't + // want to emit extra IR instructions there. + SPIRVType *getOrCreateSPIRVType( + const Type *Type, MachineIRBuilder &MIRBuilder, + SPIRV::AccessQualifier AQ = SPIRV::AccessQualifier::ReadWrite, + bool EmitIR = true); + + const Type *getTypeForSPIRVType(const SPIRVType *Ty) const { + auto Res = SPIRVToLLVMType.find(Ty); + assert(Res != SPIRVToLLVMType.end()); + return Res->second; + } + + // Return the SPIR-V type instruction corresponding to the given VReg, or + // nullptr if no such type instruction exists. 
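+  // E.g. for a vreg that was assigned an i32 via assignTypeToVReg, this
+  // returns the machine instruction defining `%ty = OpTypeInt 32 0`
+  // (illustrative).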
+ SPIRVType *getSPIRVTypeForVReg(Register VReg) const; + + // Whether the given VReg has a SPIR-V type mapped to it yet. + bool hasSPIRVTypeForVReg(Register VReg) const { + return getSPIRVTypeForVReg(VReg) != nullptr; + } + + // Return the VReg holding the result of the given OpTypeXXX instruction. + Register getSPIRVTypeID(const SPIRVType *SpirvType) const { + assert(SpirvType && "Attempting to get type id for nullptr type."); + return SpirvType->defs().begin()->getReg(); + } + + void setCurrentFunc(MachineFunction &MF) { CurMF = &MF; } + + // Whether the given VReg has an OpTypeXXX instruction mapped to it with the + // given opcode (e.g. OpTypeFloat). + bool isScalarOfType(Register VReg, unsigned TypeOpcode) const; + + // Return true if the given VReg's assigned SPIR-V type is either a scalar + // matching the given opcode, or a vector with an element type matching that + // opcode (e.g. OpTypeBool, or OpTypeVector %x 4, where %x is OpTypeBool). + bool isScalarOrVectorOfType(Register VReg, unsigned TypeOpcode) const; + + // For vectors or scalars of ints/floats, return the scalar type's bitwidth. + unsigned getScalarOrVectorBitWidth(const SPIRVType *Type) const; + + // For integer vectors or scalars, return whether the integers are signed. + bool isScalarOrVectorSigned(const SPIRVType *Type) const; + + // Gets the storage class of the pointer type assigned to this vreg. + SPIRV::StorageClass getPointerStorageClass(Register VReg) const; + + // Return the number of bits SPIR-V pointers and size_t variables require. + unsigned getPointerSize() const { return PointerSize; } + +private: + SPIRVType *getOpTypeBool(MachineIRBuilder &MIRBuilder); + + SPIRVType *getOpTypeInt(uint32_t Width, MachineIRBuilder &MIRBuilder, + bool IsSigned = false); + + SPIRVType *getOpTypeFloat(uint32_t Width, MachineIRBuilder &MIRBuilder); + + SPIRVType *getOpTypeVoid(MachineIRBuilder &MIRBuilder); + + SPIRVType *getOpTypeVector(uint32_t NumElems, SPIRVType *ElemType, + MachineIRBuilder &MIRBuilder); + + SPIRVType *getOpTypeArray(uint32_t NumElems, SPIRVType *ElemType, + MachineIRBuilder &MIRBuilder, bool EmitIR = true); + + SPIRVType *getOpTypePointer(SPIRV::StorageClass SC, SPIRVType *ElemType, + MachineIRBuilder &MIRBuilder); + + SPIRVType *getOpTypeFunction(SPIRVType *RetType, + const SmallVectorImpl &ArgTypes, + MachineIRBuilder &MIRBuilder); + SPIRVType *restOfCreateSPIRVType(Type *LLVMTy, MachineInstrBuilder MIB); + +public: + Register buildConstantInt(uint64_t Val, MachineIRBuilder &MIRBuilder, + SPIRVType *SpvType = nullptr, bool EmitIR = true); + Register buildConstantFP(APFloat Val, MachineIRBuilder &MIRBuilder, + SPIRVType *SpvType = nullptr); + Register + buildGlobalVariable(Register Reg, SPIRVType *BaseType, StringRef Name, + const GlobalValue *GV, SPIRV::StorageClass Storage, + const MachineInstr *Init, bool IsConst, bool HasLinkageTy, + SPIRV::LinkageType LinkageType, + MachineIRBuilder &MIRBuilder, bool IsInstSelector); + + // Convenient helpers for getting types with check for duplicates. 
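+  // E.g. repeated requests for a 32-bit integer type are meant to yield the
+  // same `OpTypeInt 32 0` instruction rather than emitting a fresh one each
+  // time.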
+ SPIRVType *getOrCreateSPIRVIntegerType(unsigned BitWidth, + MachineIRBuilder &MIRBuilder); + SPIRVType *getOrCreateSPIRVIntegerType(unsigned BitWidth, MachineInstr &I, + const SPIRVInstrInfo &TII); + SPIRVType *getOrCreateSPIRVBoolType(MachineIRBuilder &MIRBuilder); + SPIRVType *getOrCreateSPIRVVectorType(SPIRVType *BaseType, + unsigned NumElements, + MachineIRBuilder &MIRBuilder); + SPIRVType *getOrCreateSPIRVVectorType(SPIRVType *BaseType, + unsigned NumElements, MachineInstr &I, + const SPIRVInstrInfo &TII); + + SPIRVType *getOrCreateSPIRVPointerType( + SPIRVType *BaseType, MachineIRBuilder &MIRBuilder, + SPIRV::StorageClass SClass = SPIRV::StorageClass::Function); + SPIRVType *getOrCreateSPIRVPointerType( + SPIRVType *BaseType, MachineInstr &I, const SPIRVInstrInfo &TII, + SPIRV::StorageClass SClass = SPIRV::StorageClass::Function); +}; +} // end namespace llvm +#endif // LLLVM_LIB_TARGET_SPIRV_SPIRVTYPEMANAGER_H diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp new file mode 100644 index 0000000000000..0334d87a93c20 --- /dev/null +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -0,0 +1,1089 @@ +//===- SPIRVInstructionSelector.cpp ------------------------------*- C++ -*-==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the targeting of the InstructionSelector class for +// SPIRV. +// TODO: This should be generated by TableGen. +// +//===----------------------------------------------------------------------===// + +#include "SPIRV.h" +#include "SPIRVGlobalRegistry.h" +#include "SPIRVInstrInfo.h" +#include "SPIRVRegisterBankInfo.h" +#include "SPIRVRegisterInfo.h" +#include "SPIRVTargetMachine.h" +#include "SPIRVUtils.h" +#include "llvm/ADT/APFloat.h" +#include "llvm/CodeGen/GlobalISel/InstructionSelector.h" +#include "llvm/CodeGen/GlobalISel/InstructionSelectorImpl.h" +#include "llvm/CodeGen/MachineInstrBuilder.h" +#include "llvm/CodeGen/MachineRegisterInfo.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "spirv-isel" + +using namespace llvm; + +namespace { + +#define GET_GLOBALISEL_PREDICATE_BITSET +#include "SPIRVGenGlobalISel.inc" +#undef GET_GLOBALISEL_PREDICATE_BITSET + +class SPIRVInstructionSelector : public InstructionSelector { + const SPIRVSubtarget &STI; + const SPIRVInstrInfo &TII; + const SPIRVRegisterInfo &TRI; + const RegisterBankInfo &RBI; + SPIRVGlobalRegistry &GR; + MachineRegisterInfo *MRI; + +public: + SPIRVInstructionSelector(const SPIRVTargetMachine &TM, + const SPIRVSubtarget &ST, + const RegisterBankInfo &RBI); + void setupMF(MachineFunction &MF, GISelKnownBits *KB, + CodeGenCoverage &CoverageInfo, ProfileSummaryInfo *PSI, + BlockFrequencyInfo *BFI) override; + // Common selection code. Instruction-specific selection occurs in spvSelect. + bool select(MachineInstr &I) override; + static const char *getName() { return DEBUG_TYPE; } + +#define GET_GLOBALISEL_PREDICATES_DECL +#include "SPIRVGenGlobalISel.inc" +#undef GET_GLOBALISEL_PREDICATES_DECL + +#define GET_GLOBALISEL_TEMPORARIES_DECL +#include "SPIRVGenGlobalISel.inc" +#undef GET_GLOBALISEL_TEMPORARIES_DECL + +private: + // tblgen-erated 'select' implementation, used as the initial selector for + // the patterns that don't require complex C++. 
+ bool selectImpl(MachineInstr &I, CodeGenCoverage &CoverageInfo) const; + + // All instruction-specific selection that didn't happen in "select()". + // Is basically a large Switch/Case delegating to all other select method. + bool spvSelect(Register ResVReg, const SPIRVType *ResType, + MachineInstr &I) const; + + bool selectGlobalValue(Register ResVReg, MachineInstr &I, + const MachineInstr *Init = nullptr) const; + + bool selectUnOpWithSrc(Register ResVReg, const SPIRVType *ResType, + MachineInstr &I, Register SrcReg, + unsigned Opcode) const; + bool selectUnOp(Register ResVReg, const SPIRVType *ResType, MachineInstr &I, + unsigned Opcode) const; + + bool selectLoad(Register ResVReg, const SPIRVType *ResType, + MachineInstr &I) const; + bool selectStore(MachineInstr &I) const; + + bool selectMemOperation(Register ResVReg, MachineInstr &I) const; + + bool selectAtomicRMW(Register ResVReg, const SPIRVType *ResType, + MachineInstr &I, unsigned NewOpcode) const; + + bool selectAtomicCmpXchg(Register ResVReg, const SPIRVType *ResType, + MachineInstr &I) const; + + bool selectFence(MachineInstr &I) const; + + bool selectAddrSpaceCast(Register ResVReg, const SPIRVType *ResType, + MachineInstr &I) const; + + bool selectBitreverse(Register ResVReg, const SPIRVType *ResType, + MachineInstr &I) const; + + bool selectConstVector(Register ResVReg, const SPIRVType *ResType, + MachineInstr &I) const; + + bool selectCmp(Register ResVReg, const SPIRVType *ResType, + unsigned comparisonOpcode, MachineInstr &I) const; + + bool selectICmp(Register ResVReg, const SPIRVType *ResType, + MachineInstr &I) const; + bool selectFCmp(Register ResVReg, const SPIRVType *ResType, + MachineInstr &I) const; + + void renderImm32(MachineInstrBuilder &MIB, const MachineInstr &I, + int OpIdx) const; + void renderFImm32(MachineInstrBuilder &MIB, const MachineInstr &I, + int OpIdx) const; + + bool selectConst(Register ResVReg, const SPIRVType *ResType, const APInt &Imm, + MachineInstr &I) const; + + bool selectSelect(Register ResVReg, const SPIRVType *ResType, MachineInstr &I, + bool IsSigned) const; + bool selectIToF(Register ResVReg, const SPIRVType *ResType, MachineInstr &I, + bool IsSigned, unsigned Opcode) const; + bool selectExt(Register ResVReg, const SPIRVType *ResType, MachineInstr &I, + bool IsSigned) const; + + bool selectTrunc(Register ResVReg, const SPIRVType *ResType, + MachineInstr &I) const; + + bool selectIntToBool(Register IntReg, Register ResVReg, + const SPIRVType *intTy, const SPIRVType *boolTy, + MachineInstr &I) const; + + bool selectOpUndef(Register ResVReg, const SPIRVType *ResType, + MachineInstr &I) const; + bool selectIntrinsic(Register ResVReg, const SPIRVType *ResType, + MachineInstr &I) const; + + bool selectFrameIndex(Register ResVReg, const SPIRVType *ResType, + MachineInstr &I) const; + + bool selectBranch(MachineInstr &I) const; + bool selectBranchCond(MachineInstr &I) const; + + bool selectPhi(Register ResVReg, const SPIRVType *ResType, + MachineInstr &I) const; + + Register buildI32Constant(uint32_t Val, MachineInstr &I, + const SPIRVType *ResType = nullptr) const; + + Register buildZerosVal(const SPIRVType *ResType, MachineInstr &I) const; + Register buildOnesVal(bool AllOnes, const SPIRVType *ResType, + MachineInstr &I) const; +}; + +} // end anonymous namespace + +#define GET_GLOBALISEL_IMPL +#include "SPIRVGenGlobalISel.inc" +#undef GET_GLOBALISEL_IMPL + +SPIRVInstructionSelector::SPIRVInstructionSelector(const SPIRVTargetMachine &TM, + const SPIRVSubtarget &ST, + const 
RegisterBankInfo &RBI)
+    : InstructionSelector(), STI(ST), TII(*ST.getInstrInfo()),
+      TRI(*ST.getRegisterInfo()), RBI(RBI), GR(*ST.getSPIRVGlobalRegistry()),
+#define GET_GLOBALISEL_PREDICATES_INIT
+#include "SPIRVGenGlobalISel.inc"
+#undef GET_GLOBALISEL_PREDICATES_INIT
+#define GET_GLOBALISEL_TEMPORARIES_INIT
+#include "SPIRVGenGlobalISel.inc"
+#undef GET_GLOBALISEL_TEMPORARIES_INIT
+{
+}
+
+void SPIRVInstructionSelector::setupMF(MachineFunction &MF, GISelKnownBits *KB,
+                                       CodeGenCoverage &CoverageInfo,
+                                       ProfileSummaryInfo *PSI,
+                                       BlockFrequencyInfo *BFI) {
+  MRI = &MF.getRegInfo();
+  GR.setCurrentFunc(MF);
+  InstructionSelector::setupMF(MF, KB, CoverageInfo, PSI, BFI);
+}
+
+// Defined in SPIRVLegalizerInfo.cpp.
+extern bool isTypeFoldingSupported(unsigned Opcode);
+
+bool SPIRVInstructionSelector::select(MachineInstr &I) {
+  assert(I.getParent() && "Instruction should be in a basic block!");
+  assert(I.getParent()->getParent() && "Instruction should be in a function!");
+
+  unsigned Opcode = I.getOpcode();
+  // If it's not a GMIR instruction, we've selected it already.
+  if (!isPreISelGenericOpcode(Opcode)) {
+    if (Opcode == SPIRV::ASSIGN_TYPE) { // These pseudos aren't needed any more.
+      auto *Def = MRI->getVRegDef(I.getOperand(1).getReg());
+      if (isTypeFoldingSupported(Def->getOpcode())) {
+        auto Res = selectImpl(I, *CoverageInfo);
+        assert(Res || Def->getOpcode() == TargetOpcode::G_CONSTANT);
+        if (Res)
+          return Res;
+      }
+      MRI->replaceRegWith(I.getOperand(1).getReg(), I.getOperand(0).getReg());
+      I.removeFromParent();
+    } else if (I.getNumDefs() == 1) {
+      // Make all vregs 32 bits (for SPIR-V IDs).
+      MRI->setType(I.getOperand(0).getReg(), LLT::scalar(32));
+    }
+    return true;
+  }
+
+  if (I.getNumOperands() != I.getNumExplicitOperands()) {
+    LLVM_DEBUG(errs() << "Generic instr has unexpected implicit operands\n");
+    return false;
+  }
+
+  // Common code for getting return reg+type, and removing selected instr
+  // from parent occurs here. Instr-specific selection happens in spvSelect().
+  bool HasDefs = I.getNumDefs() > 0;
+  Register ResVReg = HasDefs ? I.getOperand(0).getReg() : Register(0);
+  SPIRVType *ResType = HasDefs ? GR.getSPIRVTypeForVReg(ResVReg) : nullptr;
+  assert(!HasDefs || ResType || I.getOpcode() == TargetOpcode::G_GLOBAL_VALUE);
+  if (spvSelect(ResVReg, ResType, I)) {
+    if (HasDefs) // Make all vregs 32 bits (for SPIR-V IDs).
+ MRI->setType(ResVReg, LLT::scalar(32)); + I.removeFromParent(); + return true; + } + return false; +} + +bool SPIRVInstructionSelector::spvSelect(Register ResVReg, + const SPIRVType *ResType, + MachineInstr &I) const { + assert(!isTypeFoldingSupported(I.getOpcode()) || + I.getOpcode() == TargetOpcode::G_CONSTANT); + const unsigned Opcode = I.getOpcode(); + switch (Opcode) { + case TargetOpcode::G_CONSTANT: + return selectConst(ResVReg, ResType, I.getOperand(1).getCImm()->getValue(), + I); + case TargetOpcode::G_GLOBAL_VALUE: + return selectGlobalValue(ResVReg, I); + case TargetOpcode::G_IMPLICIT_DEF: + return selectOpUndef(ResVReg, ResType, I); + + case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS: + return selectIntrinsic(ResVReg, ResType, I); + case TargetOpcode::G_BITREVERSE: + return selectBitreverse(ResVReg, ResType, I); + + case TargetOpcode::G_BUILD_VECTOR: + return selectConstVector(ResVReg, ResType, I); + + case TargetOpcode::G_SHUFFLE_VECTOR: { + MachineBasicBlock &BB = *I.getParent(); + auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpVectorShuffle)) + .addDef(ResVReg) + .addUse(GR.getSPIRVTypeID(ResType)) + .addUse(I.getOperand(1).getReg()) + .addUse(I.getOperand(2).getReg()); + for (auto V : I.getOperand(3).getShuffleMask()) + MIB.addImm(V); + return MIB.constrainAllUses(TII, TRI, RBI); + } + case TargetOpcode::G_MEMMOVE: + case TargetOpcode::G_MEMCPY: + return selectMemOperation(ResVReg, I); + + case TargetOpcode::G_ICMP: + return selectICmp(ResVReg, ResType, I); + case TargetOpcode::G_FCMP: + return selectFCmp(ResVReg, ResType, I); + + case TargetOpcode::G_FRAME_INDEX: + return selectFrameIndex(ResVReg, ResType, I); + + case TargetOpcode::G_LOAD: + return selectLoad(ResVReg, ResType, I); + case TargetOpcode::G_STORE: + return selectStore(I); + + case TargetOpcode::G_BR: + return selectBranch(I); + case TargetOpcode::G_BRCOND: + return selectBranchCond(I); + + case TargetOpcode::G_PHI: + return selectPhi(ResVReg, ResType, I); + + case TargetOpcode::G_FPTOSI: + return selectUnOp(ResVReg, ResType, I, SPIRV::OpConvertFToS); + case TargetOpcode::G_FPTOUI: + return selectUnOp(ResVReg, ResType, I, SPIRV::OpConvertFToU); + + case TargetOpcode::G_SITOFP: + return selectIToF(ResVReg, ResType, I, true, SPIRV::OpConvertSToF); + case TargetOpcode::G_UITOFP: + return selectIToF(ResVReg, ResType, I, false, SPIRV::OpConvertUToF); + + case TargetOpcode::G_CTPOP: + return selectUnOp(ResVReg, ResType, I, SPIRV::OpBitCount); + + case TargetOpcode::G_SEXT: + return selectExt(ResVReg, ResType, I, true); + case TargetOpcode::G_ANYEXT: + case TargetOpcode::G_ZEXT: + return selectExt(ResVReg, ResType, I, false); + case TargetOpcode::G_TRUNC: + return selectTrunc(ResVReg, ResType, I); + case TargetOpcode::G_FPTRUNC: + case TargetOpcode::G_FPEXT: + return selectUnOp(ResVReg, ResType, I, SPIRV::OpFConvert); + + case TargetOpcode::G_PTRTOINT: + return selectUnOp(ResVReg, ResType, I, SPIRV::OpConvertPtrToU); + case TargetOpcode::G_INTTOPTR: + return selectUnOp(ResVReg, ResType, I, SPIRV::OpConvertUToPtr); + case TargetOpcode::G_BITCAST: + return selectUnOp(ResVReg, ResType, I, SPIRV::OpBitcast); + case TargetOpcode::G_ADDRSPACE_CAST: + return selectAddrSpaceCast(ResVReg, ResType, I); + + case TargetOpcode::G_ATOMICRMW_OR: + return selectAtomicRMW(ResVReg, ResType, I, SPIRV::OpAtomicOr); + case TargetOpcode::G_ATOMICRMW_ADD: + return selectAtomicRMW(ResVReg, ResType, I, SPIRV::OpAtomicIAdd); + case TargetOpcode::G_ATOMICRMW_AND: + return selectAtomicRMW(ResVReg, ResType, I, SPIRV::OpAtomicAnd); + 
case TargetOpcode::G_ATOMICRMW_MAX: + return selectAtomicRMW(ResVReg, ResType, I, SPIRV::OpAtomicSMax); + case TargetOpcode::G_ATOMICRMW_MIN: + return selectAtomicRMW(ResVReg, ResType, I, SPIRV::OpAtomicSMin); + case TargetOpcode::G_ATOMICRMW_SUB: + return selectAtomicRMW(ResVReg, ResType, I, SPIRV::OpAtomicISub); + case TargetOpcode::G_ATOMICRMW_XOR: + return selectAtomicRMW(ResVReg, ResType, I, SPIRV::OpAtomicXor); + case TargetOpcode::G_ATOMICRMW_UMAX: + return selectAtomicRMW(ResVReg, ResType, I, SPIRV::OpAtomicUMax); + case TargetOpcode::G_ATOMICRMW_UMIN: + return selectAtomicRMW(ResVReg, ResType, I, SPIRV::OpAtomicUMin); + case TargetOpcode::G_ATOMICRMW_XCHG: + return selectAtomicRMW(ResVReg, ResType, I, SPIRV::OpAtomicExchange); + case TargetOpcode::G_ATOMIC_CMPXCHG: + return selectAtomicCmpXchg(ResVReg, ResType, I); + + case TargetOpcode::G_FENCE: + return selectFence(I); + + default: + return false; + } +} + +bool SPIRVInstructionSelector::selectUnOpWithSrc(Register ResVReg, + const SPIRVType *ResType, + MachineInstr &I, + Register SrcReg, + unsigned Opcode) const { + return BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(Opcode)) + .addDef(ResVReg) + .addUse(GR.getSPIRVTypeID(ResType)) + .addUse(SrcReg) + .constrainAllUses(TII, TRI, RBI); +} + +bool SPIRVInstructionSelector::selectUnOp(Register ResVReg, + const SPIRVType *ResType, + MachineInstr &I, + unsigned Opcode) const { + return selectUnOpWithSrc(ResVReg, ResType, I, I.getOperand(1).getReg(), + Opcode); +} + +static SPIRV::MemorySemantics getMemSemantics(AtomicOrdering Ord) { + switch (Ord) { + case AtomicOrdering::Acquire: + return SPIRV::MemorySemantics::Acquire; + case AtomicOrdering::Release: + return SPIRV::MemorySemantics::Release; + case AtomicOrdering::AcquireRelease: + return SPIRV::MemorySemantics::AcquireRelease; + case AtomicOrdering::SequentiallyConsistent: + return SPIRV::MemorySemantics::SequentiallyConsistent; + case AtomicOrdering::Unordered: + case AtomicOrdering::Monotonic: + case AtomicOrdering::NotAtomic: + default: + return SPIRV::MemorySemantics::None; + } +} + +static SPIRV::Scope getScope(SyncScope::ID Ord) { + switch (Ord) { + case SyncScope::SingleThread: + return SPIRV::Scope::Invocation; + case SyncScope::System: + return SPIRV::Scope::Device; + default: + llvm_unreachable("Unsupported synchronization Scope ID."); + } +} + +static void addMemoryOperands(MachineMemOperand *MemOp, + MachineInstrBuilder &MIB) { + uint32_t SpvMemOp = static_cast(SPIRV::MemoryOperand::None); + if (MemOp->isVolatile()) + SpvMemOp |= static_cast(SPIRV::MemoryOperand::Volatile); + if (MemOp->isNonTemporal()) + SpvMemOp |= static_cast(SPIRV::MemoryOperand::Nontemporal); + if (MemOp->getAlign().value()) + SpvMemOp |= static_cast(SPIRV::MemoryOperand::Aligned); + + if (SpvMemOp != static_cast(SPIRV::MemoryOperand::None)) { + MIB.addImm(SpvMemOp); + if (SpvMemOp & static_cast(SPIRV::MemoryOperand::Aligned)) + MIB.addImm(MemOp->getAlign().value()); + } +} + +static void addMemoryOperands(uint64_t Flags, MachineInstrBuilder &MIB) { + uint32_t SpvMemOp = static_cast(SPIRV::MemoryOperand::None); + if (Flags & MachineMemOperand::Flags::MOVolatile) + SpvMemOp |= static_cast(SPIRV::MemoryOperand::Volatile); + if (Flags & MachineMemOperand::Flags::MONonTemporal) + SpvMemOp |= static_cast(SPIRV::MemoryOperand::Nontemporal); + + if (SpvMemOp != static_cast(SPIRV::MemoryOperand::None)) + MIB.addImm(SpvMemOp); +} + +bool SPIRVInstructionSelector::selectLoad(Register ResVReg, + const SPIRVType *ResType, + MachineInstr &I) const { + 
unsigned OpOffset = + I.getOpcode() == TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS ? 1 : 0; + Register Ptr = I.getOperand(1 + OpOffset).getReg(); + auto MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpLoad)) + .addDef(ResVReg) + .addUse(GR.getSPIRVTypeID(ResType)) + .addUse(Ptr); + if (!I.getNumMemOperands()) { + assert(I.getOpcode() == TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS); + addMemoryOperands(I.getOperand(2 + OpOffset).getImm(), MIB); + } else { + addMemoryOperands(*I.memoperands_begin(), MIB); + } + return MIB.constrainAllUses(TII, TRI, RBI); +} + +bool SPIRVInstructionSelector::selectStore(MachineInstr &I) const { + unsigned OpOffset = + I.getOpcode() == TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS ? 1 : 0; + Register StoreVal = I.getOperand(0 + OpOffset).getReg(); + Register Ptr = I.getOperand(1 + OpOffset).getReg(); + MachineBasicBlock &BB = *I.getParent(); + auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpStore)) + .addUse(Ptr) + .addUse(StoreVal); + if (!I.getNumMemOperands()) { + assert(I.getOpcode() == TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS); + addMemoryOperands(I.getOperand(2 + OpOffset).getImm(), MIB); + } else { + addMemoryOperands(*I.memoperands_begin(), MIB); + } + return MIB.constrainAllUses(TII, TRI, RBI); +} + +bool SPIRVInstructionSelector::selectMemOperation(Register ResVReg, + MachineInstr &I) const { + MachineBasicBlock &BB = *I.getParent(); + auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCopyMemorySized)) + .addDef(I.getOperand(0).getReg()) + .addUse(I.getOperand(1).getReg()) + .addUse(I.getOperand(2).getReg()); + if (I.getNumMemOperands()) + addMemoryOperands(*I.memoperands_begin(), MIB); + bool Result = MIB.constrainAllUses(TII, TRI, RBI); + if (ResVReg.isValid() && ResVReg != MIB->getOperand(0).getReg()) { + BuildMI(BB, I, I.getDebugLoc(), TII.get(TargetOpcode::COPY), ResVReg) + .addUse(MIB->getOperand(0).getReg()); + } + return Result; +} + +bool SPIRVInstructionSelector::selectAtomicRMW(Register ResVReg, + const SPIRVType *ResType, + MachineInstr &I, + unsigned NewOpcode) const { + assert(I.hasOneMemOperand()); + const MachineMemOperand *MemOp = *I.memoperands_begin(); + uint32_t Scope = static_cast(getScope(MemOp->getSyncScopeID())); + Register ScopeReg = buildI32Constant(Scope, I); + + Register Ptr = I.getOperand(1).getReg(); + // TODO: Changed as it's implemented in the translator. 
See test/atomicrmw.ll
+  // auto ScSem =
+  // getMemSemanticsForStorageClass(GR.getPointerStorageClass(Ptr));
+  AtomicOrdering AO = MemOp->getSuccessOrdering();
+  uint32_t MemSem = static_cast<uint32_t>(getMemSemantics(AO));
+  Register MemSemReg = buildI32Constant(MemSem /*| ScSem*/, I);
+
+  return BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(NewOpcode))
+      .addDef(ResVReg)
+      .addUse(GR.getSPIRVTypeID(ResType))
+      .addUse(Ptr)
+      .addUse(ScopeReg)
+      .addUse(MemSemReg)
+      .addUse(I.getOperand(2).getReg())
+      .constrainAllUses(TII, TRI, RBI);
+}
+
+bool SPIRVInstructionSelector::selectFence(MachineInstr &I) const {
+  AtomicOrdering AO = AtomicOrdering(I.getOperand(0).getImm());
+  uint32_t MemSem = static_cast<uint32_t>(getMemSemantics(AO));
+  Register MemSemReg = buildI32Constant(MemSem, I);
+  SyncScope::ID Ord = SyncScope::ID(I.getOperand(1).getImm());
+  uint32_t Scope = static_cast<uint32_t>(getScope(Ord));
+  Register ScopeReg = buildI32Constant(Scope, I);
+  MachineBasicBlock &BB = *I.getParent();
+  return BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpMemoryBarrier))
+      .addUse(ScopeReg)
+      .addUse(MemSemReg)
+      .constrainAllUses(TII, TRI, RBI);
+}
+
+bool SPIRVInstructionSelector::selectAtomicCmpXchg(Register ResVReg,
+                                                   const SPIRVType *ResType,
+                                                   MachineInstr &I) const {
+  assert(I.hasOneMemOperand());
+  const MachineMemOperand *MemOp = *I.memoperands_begin();
+  uint32_t Scope = static_cast<uint32_t>(getScope(MemOp->getSyncScopeID()));
+  Register ScopeReg = buildI32Constant(Scope, I);
+
+  Register Ptr = I.getOperand(2).getReg();
+  Register Cmp = I.getOperand(3).getReg();
+  Register Val = I.getOperand(4).getReg();
+
+  SPIRVType *SpvValTy = GR.getSPIRVTypeForVReg(Val);
+  SPIRV::StorageClass SC = GR.getPointerStorageClass(Ptr);
+  uint32_t ScSem = static_cast<uint32_t>(getMemSemanticsForStorageClass(SC));
+  AtomicOrdering AO = MemOp->getSuccessOrdering();
+  uint32_t MemSemEq = static_cast<uint32_t>(getMemSemantics(AO)) | ScSem;
+  Register MemSemEqReg = buildI32Constant(MemSemEq, I);
+  AtomicOrdering FO = MemOp->getFailureOrdering();
+  uint32_t MemSemNeq = static_cast<uint32_t>(getMemSemantics(FO)) | ScSem;
+  Register MemSemNeqReg =
+      MemSemEq == MemSemNeq ? MemSemEqReg : buildI32Constant(MemSemNeq, I);
+  const DebugLoc &DL = I.getDebugLoc();
+  return BuildMI(*I.getParent(), I, DL, TII.get(SPIRV::OpAtomicCompareExchange))
+      .addDef(ResVReg)
+      .addUse(GR.getSPIRVTypeID(SpvValTy))
+      .addUse(Ptr)
+      .addUse(ScopeReg)
+      .addUse(MemSemEqReg)
+      .addUse(MemSemNeqReg)
+      .addUse(Val)
+      .addUse(Cmp)
+      .constrainAllUses(TII, TRI, RBI);
+}
+
+static bool isGenericCastablePtr(SPIRV::StorageClass SC) {
+  switch (SC) {
+  case SPIRV::StorageClass::Workgroup:
+  case SPIRV::StorageClass::CrossWorkgroup:
+  case SPIRV::StorageClass::Function:
+    return true;
+  default:
+    return false;
+  }
+}
+
+// In SPIR-V, address space casting can only happen to and from the Generic
+// storage class. We can also only cast Workgroup, CrossWorkgroup, or Function
+// pointers to and from Generic pointers. As such, we can convert e.g. from
+// Workgroup to Function by going via a Generic pointer as an intermediary. All
+// other combinations can only be done by a bitcast, and are probably not safe.
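+// E.g. a Workgroup -> Function cast is selected as the illustrative pair:
+//   %tmp = OpPtrCastToGeneric %genericPtrTy %src
+//   %res = OpGenericCastToPtr %dstPtrTy %tmp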
+bool SPIRVInstructionSelector::selectAddrSpaceCast(Register ResVReg,
+                                                   const SPIRVType *ResType,
+                                                   MachineInstr &I) const {
+  Register SrcPtr = I.getOperand(1).getReg();
+  SPIRVType *SrcPtrTy = GR.getSPIRVTypeForVReg(SrcPtr);
+  SPIRV::StorageClass SrcSC = GR.getPointerStorageClass(SrcPtr);
+  SPIRV::StorageClass DstSC = GR.getPointerStorageClass(ResVReg);
+
+  // Casting from an eligible pointer to Generic.
+  if (DstSC == SPIRV::StorageClass::Generic && isGenericCastablePtr(SrcSC))
+    return selectUnOp(ResVReg, ResType, I, SPIRV::OpPtrCastToGeneric);
+  // Casting from Generic to an eligible pointer.
+  if (SrcSC == SPIRV::StorageClass::Generic && isGenericCastablePtr(DstSC))
+    return selectUnOp(ResVReg, ResType, I, SPIRV::OpGenericCastToPtr);
+  // Casting between 2 eligible pointers using Generic as an intermediary.
+  if (isGenericCastablePtr(SrcSC) && isGenericCastablePtr(DstSC)) {
+    Register Tmp = MRI->createVirtualRegister(&SPIRV::IDRegClass);
+    SPIRVType *GenericPtrTy = GR.getOrCreateSPIRVPointerType(
+        SrcPtrTy, I, TII, SPIRV::StorageClass::Generic);
+    MachineBasicBlock &BB = *I.getParent();
+    const DebugLoc &DL = I.getDebugLoc();
+    bool Success = BuildMI(BB, I, DL, TII.get(SPIRV::OpPtrCastToGeneric))
+                       .addDef(Tmp)
+                       .addUse(GR.getSPIRVTypeID(GenericPtrTy))
+                       .addUse(SrcPtr)
+                       .constrainAllUses(TII, TRI, RBI);
+    return Success && BuildMI(BB, I, DL, TII.get(SPIRV::OpGenericCastToPtr))
+                          .addDef(ResVReg)
+                          .addUse(GR.getSPIRVTypeID(ResType))
+                          .addUse(Tmp)
+                          .constrainAllUses(TII, TRI, RBI);
+  }
+  // TODO: Should this case just be disallowed completely?
+  // We're casting 2 other arbitrary address spaces, so have to bitcast.
+  return selectUnOp(ResVReg, ResType, I, SPIRV::OpBitcast);
+}
+
+static unsigned getFCmpOpcode(unsigned PredNum) {
+  auto Pred = static_cast<CmpInst::Predicate>(PredNum);
+  switch (Pred) {
+  case CmpInst::FCMP_OEQ:
+    return SPIRV::OpFOrdEqual;
+  case CmpInst::FCMP_OGE:
+    return SPIRV::OpFOrdGreaterThanEqual;
+  case CmpInst::FCMP_OGT:
+    return SPIRV::OpFOrdGreaterThan;
+  case CmpInst::FCMP_OLE:
+    return SPIRV::OpFOrdLessThanEqual;
+  case CmpInst::FCMP_OLT:
+    return SPIRV::OpFOrdLessThan;
+  case CmpInst::FCMP_ONE:
+    return SPIRV::OpFOrdNotEqual;
+  case CmpInst::FCMP_ORD:
+    return SPIRV::OpOrdered;
+  case CmpInst::FCMP_UEQ:
+    return SPIRV::OpFUnordEqual;
+  case CmpInst::FCMP_UGE:
+    return SPIRV::OpFUnordGreaterThanEqual;
+  case CmpInst::FCMP_UGT:
+    return SPIRV::OpFUnordGreaterThan;
+  case CmpInst::FCMP_ULE:
+    return SPIRV::OpFUnordLessThanEqual;
+  case CmpInst::FCMP_ULT:
+    return SPIRV::OpFUnordLessThan;
+  case CmpInst::FCMP_UNE:
+    return SPIRV::OpFUnordNotEqual;
+  case CmpInst::FCMP_UNO:
+    return SPIRV::OpUnordered;
+  default:
+    llvm_unreachable("Unknown predicate type for FCmp");
+  }
+}
+
+static unsigned getICmpOpcode(unsigned PredNum) {
+  auto Pred = static_cast<CmpInst::Predicate>(PredNum);
+  switch (Pred) {
+  case CmpInst::ICMP_EQ:
+    return SPIRV::OpIEqual;
+  case CmpInst::ICMP_NE:
+    return SPIRV::OpINotEqual;
+  case CmpInst::ICMP_SGE:
+    return SPIRV::OpSGreaterThanEqual;
+  case CmpInst::ICMP_SGT:
+    return SPIRV::OpSGreaterThan;
+  case CmpInst::ICMP_SLE:
+    return SPIRV::OpSLessThanEqual;
+  case CmpInst::ICMP_SLT:
+    return SPIRV::OpSLessThan;
+  case CmpInst::ICMP_UGE:
+    return SPIRV::OpUGreaterThanEqual;
+  case CmpInst::ICMP_UGT:
+    return SPIRV::OpUGreaterThan;
+  case CmpInst::ICMP_ULE:
+    return SPIRV::OpULessThanEqual;
+  case CmpInst::ICMP_ULT:
+    return SPIRV::OpULessThan;
+  default:
+    llvm_unreachable("Unknown predicate type for ICmp");
+  }
+}
+
+static unsigned getPtrCmpOpcode(unsigned Pred) {
switch (static_cast(Pred)) { + case CmpInst::ICMP_EQ: + return SPIRV::OpPtrEqual; + case CmpInst::ICMP_NE: + return SPIRV::OpPtrNotEqual; + default: + llvm_unreachable("Unknown predicate type for pointer comparison"); + } +} + +// Return the logical operation, or abort if none exists. +static unsigned getBoolCmpOpcode(unsigned PredNum) { + auto Pred = static_cast(PredNum); + switch (Pred) { + case CmpInst::ICMP_EQ: + return SPIRV::OpLogicalEqual; + case CmpInst::ICMP_NE: + return SPIRV::OpLogicalNotEqual; + default: + llvm_unreachable("Unknown predicate type for Bool comparison"); + } +} + +bool SPIRVInstructionSelector::selectBitreverse(Register ResVReg, + const SPIRVType *ResType, + MachineInstr &I) const { + MachineBasicBlock &BB = *I.getParent(); + return BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpBitReverse)) + .addDef(ResVReg) + .addUse(GR.getSPIRVTypeID(ResType)) + .addUse(I.getOperand(1).getReg()) + .constrainAllUses(TII, TRI, RBI); +} + +bool SPIRVInstructionSelector::selectConstVector(Register ResVReg, + const SPIRVType *ResType, + MachineInstr &I) const { + // TODO: only const case is supported for now. + assert(std::all_of( + I.operands_begin(), I.operands_end(), [this](const MachineOperand &MO) { + if (MO.isDef()) + return true; + if (!MO.isReg()) + return false; + SPIRVType *ConstTy = this->MRI->getVRegDef(MO.getReg()); + assert(ConstTy && ConstTy->getOpcode() == SPIRV::ASSIGN_TYPE && + ConstTy->getOperand(1).isReg()); + Register ConstReg = ConstTy->getOperand(1).getReg(); + const MachineInstr *Const = this->MRI->getVRegDef(ConstReg); + assert(Const); + return (Const->getOpcode() == TargetOpcode::G_CONSTANT || + Const->getOpcode() == TargetOpcode::G_FCONSTANT); + })); + + auto MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(), + TII.get(SPIRV::OpConstantComposite)) + .addDef(ResVReg) + .addUse(GR.getSPIRVTypeID(ResType)); + for (unsigned i = I.getNumExplicitDefs(); i < I.getNumExplicitOperands(); ++i) + MIB.addUse(I.getOperand(i).getReg()); + return MIB.constrainAllUses(TII, TRI, RBI); +} + +bool SPIRVInstructionSelector::selectCmp(Register ResVReg, + const SPIRVType *ResType, + unsigned CmpOpc, + MachineInstr &I) const { + Register Cmp0 = I.getOperand(2).getReg(); + Register Cmp1 = I.getOperand(3).getReg(); + assert(GR.getSPIRVTypeForVReg(Cmp0)->getOpcode() == + GR.getSPIRVTypeForVReg(Cmp1)->getOpcode() && + "CMP operands should have the same type"); + return BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(CmpOpc)) + .addDef(ResVReg) + .addUse(GR.getSPIRVTypeID(ResType)) + .addUse(Cmp0) + .addUse(Cmp1) + .constrainAllUses(TII, TRI, RBI); +} + +bool SPIRVInstructionSelector::selectICmp(Register ResVReg, + const SPIRVType *ResType, + MachineInstr &I) const { + auto Pred = I.getOperand(1).getPredicate(); + unsigned CmpOpc; + + Register CmpOperand = I.getOperand(2).getReg(); + if (GR.isScalarOfType(CmpOperand, SPIRV::OpTypePointer)) + CmpOpc = getPtrCmpOpcode(Pred); + else if (GR.isScalarOrVectorOfType(CmpOperand, SPIRV::OpTypeBool)) + CmpOpc = getBoolCmpOpcode(Pred); + else + CmpOpc = getICmpOpcode(Pred); + return selectCmp(ResVReg, ResType, CmpOpc, I); +} + +void SPIRVInstructionSelector::renderFImm32(MachineInstrBuilder &MIB, + const MachineInstr &I, + int OpIdx) const { + assert(I.getOpcode() == TargetOpcode::G_FCONSTANT && OpIdx == -1 && + "Expected G_FCONSTANT"); + const ConstantFP *FPImm = I.getOperand(1).getFPImm(); + addNumImm(FPImm->getValueAPF().bitcastToAPInt(), MIB); +} + +void SPIRVInstructionSelector::renderImm32(MachineInstrBuilder &MIB, + const 
MachineInstr &I, + int OpIdx) const { + assert(I.getOpcode() == TargetOpcode::G_CONSTANT && OpIdx == -1 && + "Expected G_CONSTANT"); + addNumImm(I.getOperand(1).getCImm()->getValue(), MIB); +} + +Register +SPIRVInstructionSelector::buildI32Constant(uint32_t Val, MachineInstr &I, + const SPIRVType *ResType) const { + const SPIRVType *SpvI32Ty = + ResType ? ResType : GR.getOrCreateSPIRVIntegerType(32, I, TII); + Register NewReg; + NewReg = MRI->createGenericVirtualRegister(LLT::scalar(32)); + MachineInstr *MI; + MachineBasicBlock &BB = *I.getParent(); + if (Val == 0) { + MI = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantNull)) + .addDef(NewReg) + .addUse(GR.getSPIRVTypeID(SpvI32Ty)); + } else { + MI = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantI)) + .addDef(NewReg) + .addUse(GR.getSPIRVTypeID(SpvI32Ty)) + .addImm(APInt(32, Val).getZExtValue()); + } + constrainSelectedInstRegOperands(*MI, TII, TRI, RBI); + return NewReg; +} + +bool SPIRVInstructionSelector::selectFCmp(Register ResVReg, + const SPIRVType *ResType, + MachineInstr &I) const { + unsigned CmpOp = getFCmpOpcode(I.getOperand(1).getPredicate()); + return selectCmp(ResVReg, ResType, CmpOp, I); +} + +Register SPIRVInstructionSelector::buildZerosVal(const SPIRVType *ResType, + MachineInstr &I) const { + return buildI32Constant(0, I, ResType); +} + +Register SPIRVInstructionSelector::buildOnesVal(bool AllOnes, + const SPIRVType *ResType, + MachineInstr &I) const { + unsigned BitWidth = GR.getScalarOrVectorBitWidth(ResType); + APInt One = AllOnes ? APInt::getAllOnesValue(BitWidth) + : APInt::getOneBitSet(BitWidth, 0); + Register OneReg = buildI32Constant(One.getZExtValue(), I, ResType); + if (ResType->getOpcode() == SPIRV::OpTypeVector) { + const unsigned NumEles = ResType->getOperand(2).getImm(); + Register OneVec = MRI->createVirtualRegister(&SPIRV::IDRegClass); + unsigned Opcode = SPIRV::OpConstantComposite; + auto MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(Opcode)) + .addDef(OneVec) + .addUse(GR.getSPIRVTypeID(ResType)); + for (unsigned i = 0; i < NumEles; ++i) + MIB.addUse(OneReg); + constrainSelectedInstRegOperands(*MIB, TII, TRI, RBI); + return OneVec; + } + return OneReg; +} + +bool SPIRVInstructionSelector::selectSelect(Register ResVReg, + const SPIRVType *ResType, + MachineInstr &I, + bool IsSigned) const { + // To extend a bool, we need to use OpSelect between constants. + Register ZeroReg = buildZerosVal(ResType, I); + Register OneReg = buildOnesVal(IsSigned, ResType, I); + bool IsScalarBool = + GR.isScalarOfType(I.getOperand(1).getReg(), SPIRV::OpTypeBool); + unsigned Opcode = + IsScalarBool ? SPIRV::OpSelectSISCond : SPIRV::OpSelectSIVCond; + return BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(Opcode)) + .addDef(ResVReg) + .addUse(GR.getSPIRVTypeID(ResType)) + .addUse(I.getOperand(1).getReg()) + .addUse(OneReg) + .addUse(ZeroReg) + .constrainAllUses(TII, TRI, RBI); +} + +bool SPIRVInstructionSelector::selectIToF(Register ResVReg, + const SPIRVType *ResType, + MachineInstr &I, bool IsSigned, + unsigned Opcode) const { + Register SrcReg = I.getOperand(1).getReg(); + // We can convert bool value directly to float type without OpConvert*ToF, + // however the translator generates OpSelect+OpConvert*ToF, so we do the same. 
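+  // I.e. a bool %b illustratively becomes `%i = OpSelect %intTy %b %one %zero`
+  // followed by `%f = OpConvertSToF %fpTy %i` (OpConvertUToF when unsigned).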
+  if (GR.isScalarOrVectorOfType(I.getOperand(1).getReg(), SPIRV::OpTypeBool)) {
+    unsigned BitWidth = GR.getScalarOrVectorBitWidth(ResType);
+    SPIRVType *TmpType = GR.getOrCreateSPIRVIntegerType(BitWidth, I, TII);
+    if (ResType->getOpcode() == SPIRV::OpTypeVector) {
+      const unsigned NumElts = ResType->getOperand(2).getImm();
+      TmpType = GR.getOrCreateSPIRVVectorType(TmpType, NumElts, I, TII);
+    }
+    SrcReg = MRI->createVirtualRegister(&SPIRV::IDRegClass);
+    selectSelect(SrcReg, TmpType, I, false);
+  }
+  return selectUnOpWithSrc(ResVReg, ResType, I, SrcReg, Opcode);
+}
+
+bool SPIRVInstructionSelector::selectExt(Register ResVReg,
+                                         const SPIRVType *ResType,
+                                         MachineInstr &I, bool IsSigned) const {
+  if (GR.isScalarOrVectorOfType(I.getOperand(1).getReg(), SPIRV::OpTypeBool))
+    return selectSelect(ResVReg, ResType, I, IsSigned);
+  unsigned Opcode = IsSigned ? SPIRV::OpSConvert : SPIRV::OpUConvert;
+  return selectUnOp(ResVReg, ResType, I, Opcode);
+}
+
+bool SPIRVInstructionSelector::selectIntToBool(Register IntReg,
+                                               Register ResVReg,
+                                               const SPIRVType *IntTy,
+                                               const SPIRVType *BoolTy,
+                                               MachineInstr &I) const {
+  // To truncate to a bool, we use OpBitwiseAnd 1 and OpINotEqual to zero.
+  Register BitIntReg = MRI->createVirtualRegister(&SPIRV::IDRegClass);
+  bool IsVectorTy = IntTy->getOpcode() == SPIRV::OpTypeVector;
+  unsigned Opcode = IsVectorTy ? SPIRV::OpBitwiseAndV : SPIRV::OpBitwiseAndS;
+  Register Zero = buildZerosVal(IntTy, I);
+  Register One = buildOnesVal(false, IntTy, I);
+  MachineBasicBlock &BB = *I.getParent();
+  BuildMI(BB, I, I.getDebugLoc(), TII.get(Opcode))
+      .addDef(BitIntReg)
+      .addUse(GR.getSPIRVTypeID(IntTy))
+      .addUse(IntReg)
+      .addUse(One)
+      .constrainAllUses(TII, TRI, RBI);
+  return BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpINotEqual))
+      .addDef(ResVReg)
+      .addUse(GR.getSPIRVTypeID(BoolTy))
+      .addUse(BitIntReg)
+      .addUse(Zero)
+      .constrainAllUses(TII, TRI, RBI);
+}
+
+bool SPIRVInstructionSelector::selectTrunc(Register ResVReg,
+                                           const SPIRVType *ResType,
+                                           MachineInstr &I) const {
+  if (GR.isScalarOrVectorOfType(ResVReg, SPIRV::OpTypeBool)) {
+    Register IntReg = I.getOperand(1).getReg();
+    const SPIRVType *ArgType = GR.getSPIRVTypeForVReg(IntReg);
+    return selectIntToBool(IntReg, ResVReg, ArgType, ResType, I);
+  }
+  bool IsSigned = GR.isScalarOrVectorSigned(ResType);
+  unsigned Opcode = IsSigned ? SPIRV::OpSConvert : SPIRV::OpUConvert;
+  return selectUnOp(ResVReg, ResType, I, Opcode);
+}
+
+bool SPIRVInstructionSelector::selectConst(Register ResVReg,
+                                           const SPIRVType *ResType,
+                                           const APInt &Imm,
+                                           MachineInstr &I) const {
+  assert(ResType->getOpcode() != SPIRV::OpTypePointer || Imm.isNullValue());
+  MachineBasicBlock &BB = *I.getParent();
+  if (ResType->getOpcode() == SPIRV::OpTypePointer && Imm.isNullValue()) {
+    return BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantNull))
+        .addDef(ResVReg)
+        .addUse(GR.getSPIRVTypeID(ResType))
+        .constrainAllUses(TII, TRI, RBI);
+  }
+  auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantI))
+                 .addDef(ResVReg)
+                 .addUse(GR.getSPIRVTypeID(ResType));
+  // <=32-bit integers should be caught by the sdag pattern.
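+  // E.g. a 64-bit immediate is split into low/high 32-bit words by addNumImm
+  // (see SPIRVUtils.cpp).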
+  assert(Imm.getBitWidth() > 32);
+  addNumImm(Imm, MIB);
+  return MIB.constrainAllUses(TII, TRI, RBI);
+}
+
+bool SPIRVInstructionSelector::selectOpUndef(Register ResVReg,
+                                             const SPIRVType *ResType,
+                                             MachineInstr &I) const {
+  return BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpUndef))
+      .addDef(ResVReg)
+      .addUse(GR.getSPIRVTypeID(ResType))
+      .constrainAllUses(TII, TRI, RBI);
+}
+
+bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
+                                               const SPIRVType *ResType,
+                                               MachineInstr &I) const {
+  llvm_unreachable("Intrinsic selection not implemented");
+}
+
+bool SPIRVInstructionSelector::selectFrameIndex(Register ResVReg,
+                                                const SPIRVType *ResType,
+                                                MachineInstr &I) const {
+  return BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpVariable))
+      .addDef(ResVReg)
+      .addUse(GR.getSPIRVTypeID(ResType))
+      .addImm(static_cast<uint32_t>(SPIRV::StorageClass::Function))
+      .constrainAllUses(TII, TRI, RBI);
+}
+
+bool SPIRVInstructionSelector::selectBranch(MachineInstr &I) const {
+  // InstructionSelector walks backwards through the instructions. We can use
+  // both a G_BR and a G_BRCOND to create an OpBranchConditional. We hit G_BR
+  // first, so we can generate an OpBranchConditional here. If there is no
+  // G_BRCOND, we just use OpBranch for a regular unconditional branch.
+  const MachineInstr *PrevI = I.getPrevNode();
+  MachineBasicBlock &MBB = *I.getParent();
+  if (PrevI != nullptr && PrevI->getOpcode() == TargetOpcode::G_BRCOND) {
+    return BuildMI(MBB, I, I.getDebugLoc(), TII.get(SPIRV::OpBranchConditional))
+        .addUse(PrevI->getOperand(0).getReg())
+        .addMBB(PrevI->getOperand(1).getMBB())
+        .addMBB(I.getOperand(0).getMBB())
+        .constrainAllUses(TII, TRI, RBI);
+  }
+  return BuildMI(MBB, I, I.getDebugLoc(), TII.get(SPIRV::OpBranch))
+      .addMBB(I.getOperand(0).getMBB())
+      .constrainAllUses(TII, TRI, RBI);
+}
+
+bool SPIRVInstructionSelector::selectBranchCond(MachineInstr &I) const {
+  // InstructionSelector walks backwards through the instructions. For an
+  // explicit conditional branch with no fallthrough, we use both a G_BR and a
+  // G_BRCOND to create an OpBranchConditional. We should hit G_BR first, and
+  // generate the OpBranchConditional in selectBranch above.
+  //
+  // If an OpBranchConditional has been generated, we simply return, as the
+  // work is already done. If there is no OpBranchConditional, LLVM must be
+  // relying on implicit fallthrough to the next basic block, so we need to
+  // create an OpBranchConditional with an explicit "false" argument pointing
+  // to the next basic block that LLVM would fall through to.
+  const MachineInstr *NextI = I.getNextNode();
+  // Check if this has already been successfully selected.
+  if (NextI != nullptr && NextI->getOpcode() == SPIRV::OpBranchConditional)
+    return true;
+  // Must be relying on implicit block fallthrough, so generate an
+  // OpBranchConditional with the "next" basic block as the "false" target.
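+  // E.g. for "bb.1: G_BRCOND %cond, %bb.2" followed directly by bb.3, this
+  // emits "OpBranchConditional %cond %bb.2 %bb.3" even though no G_BR exists.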
+  MachineBasicBlock &MBB = *I.getParent();
+  unsigned NextMBBNum = MBB.getNextNode()->getNumber();
+  MachineBasicBlock *NextMBB = I.getMF()->getBlockNumbered(NextMBBNum);
+  return BuildMI(MBB, I, I.getDebugLoc(), TII.get(SPIRV::OpBranchConditional))
+      .addUse(I.getOperand(0).getReg())
+      .addMBB(I.getOperand(1).getMBB())
+      .addMBB(NextMBB)
+      .constrainAllUses(TII, TRI, RBI);
+}
+
+bool SPIRVInstructionSelector::selectPhi(Register ResVReg,
+                                         const SPIRVType *ResType,
+                                         MachineInstr &I) const {
+  auto MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpPhi))
+                 .addDef(ResVReg)
+                 .addUse(GR.getSPIRVTypeID(ResType));
+  const unsigned NumOps = I.getNumOperands();
+  for (unsigned i = 1; i < NumOps; i += 2) {
+    MIB.addUse(I.getOperand(i + 0).getReg());
+    MIB.addMBB(I.getOperand(i + 1).getMBB());
+  }
+  return MIB.constrainAllUses(TII, TRI, RBI);
+}
+
+bool SPIRVInstructionSelector::selectGlobalValue(
+    Register ResVReg, MachineInstr &I, const MachineInstr *Init) const {
+  // FIXME: don't use MachineIRBuilder here, replace it with BuildMI.
+  MachineIRBuilder MIRBuilder(I);
+  const GlobalValue *GV = I.getOperand(1).getGlobal();
+  SPIRVType *ResType = GR.getOrCreateSPIRVType(
+      GV->getType(), MIRBuilder, SPIRV::AccessQualifier::ReadWrite, false);
+
+  std::string GlobalIdent = GV->getGlobalIdentifier();
+  // TODO: support @llvm.global.annotations.
+  auto GlobalVar = cast<GlobalVariable>(GV);
+
+  bool HasInit = GlobalVar->hasInitializer() &&
+                 !isa<UndefValue>(GlobalVar->getInitializer());
+  // Skip empty declarations for GVs with initializers until we get the
+  // declaration with the initializer passed in.
+  if (HasInit && !Init)
+    return true;
+
+  unsigned AddrSpace = GV->getAddressSpace();
+  SPIRV::StorageClass Storage = addressSpaceToStorageClass(AddrSpace);
+  bool HasLnkTy = GV->getLinkage() != GlobalValue::InternalLinkage &&
+                  Storage != SPIRV::StorageClass::Function;
+  SPIRV::LinkageType LnkType =
+      (GV->isDeclaration() || GV->hasAvailableExternallyLinkage())
+          ? SPIRV::LinkageType::Import
+          : SPIRV::LinkageType::Export;
+
+  Register Reg = GR.buildGlobalVariable(ResVReg, ResType, GlobalIdent, GV,
+                                        Storage, Init, GlobalVar->isConstant(),
+                                        HasLnkTy, LnkType, MIRBuilder, true);
+  return Reg.isValid();
+}
+
+namespace llvm {
+InstructionSelector *
+createSPIRVInstructionSelector(const SPIRVTargetMachine &TM,
+                               const SPIRVSubtarget &Subtarget,
+                               const RegisterBankInfo &RBI) {
+  return new SPIRVInstructionSelector(TM, Subtarget, RBI);
+}
+} // namespace llvm
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
new file mode 100644
index 0000000000000..fff277e9ad251
--- /dev/null
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
@@ -0,0 +1,301 @@
+//===- SPIRVLegalizerInfo.cpp --- SPIR-V Legalization Rules ------*- C++ -*-==//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the targeting of the MachineLegalizer class for SPIR-V.
+//
+//===----------------------------------------------------------------------===//
+
+#include "SPIRVLegalizerInfo.h"
+#include "SPIRV.h"
+#include "SPIRVGlobalRegistry.h"
+#include "SPIRVSubtarget.h"
+#include "llvm/CodeGen/GlobalISel/LegalizerHelper.h"
+#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
+#include "llvm/CodeGen/MachineInstr.h"
+#include "llvm/CodeGen/MachineRegisterInfo.h"
+#include "llvm/CodeGen/TargetOpcodes.h"
+
+using namespace llvm;
+using namespace llvm::LegalizeActions;
+using namespace llvm::LegalityPredicates;
+
+static const std::set<unsigned> TypeFoldingSupportingOpcs = {
+    TargetOpcode::G_ADD,
+    TargetOpcode::G_FADD,
+    TargetOpcode::G_SUB,
+    TargetOpcode::G_FSUB,
+    TargetOpcode::G_MUL,
+    TargetOpcode::G_FMUL,
+    TargetOpcode::G_SDIV,
+    TargetOpcode::G_UDIV,
+    TargetOpcode::G_FDIV,
+    TargetOpcode::G_SREM,
+    TargetOpcode::G_UREM,
+    TargetOpcode::G_FREM,
+    TargetOpcode::G_FNEG,
+    TargetOpcode::G_CONSTANT,
+    TargetOpcode::G_FCONSTANT,
+    TargetOpcode::G_AND,
+    TargetOpcode::G_OR,
+    TargetOpcode::G_XOR,
+    TargetOpcode::G_SHL,
+    TargetOpcode::G_ASHR,
+    TargetOpcode::G_LSHR,
+    TargetOpcode::G_SELECT,
+    TargetOpcode::G_EXTRACT_VECTOR_ELT,
+};
+
+bool isTypeFoldingSupported(unsigned Opcode) {
+  return TypeFoldingSupportingOpcs.count(Opcode) > 0;
+}
+
+SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
+  using namespace TargetOpcode;
+
+  this->ST = &ST;
+  GR = ST.getSPIRVGlobalRegistry();
+
+  const LLT s1 = LLT::scalar(1);
+  const LLT s8 = LLT::scalar(8);
+  const LLT s16 = LLT::scalar(16);
+  const LLT s32 = LLT::scalar(32);
+  const LLT s64 = LLT::scalar(64);
+
+  const LLT v16s64 = LLT::fixed_vector(16, 64);
+  const LLT v16s32 = LLT::fixed_vector(16, 32);
+  const LLT v16s16 = LLT::fixed_vector(16, 16);
+  const LLT v16s8 = LLT::fixed_vector(16, 8);
+  const LLT v16s1 = LLT::fixed_vector(16, 1);
+
+  const LLT v8s64 = LLT::fixed_vector(8, 64);
+  const LLT v8s32 = LLT::fixed_vector(8, 32);
+  const LLT v8s16 = LLT::fixed_vector(8, 16);
+  const LLT v8s8 = LLT::fixed_vector(8, 8);
+  const LLT v8s1 = LLT::fixed_vector(8, 1);
+
+  const LLT v4s64 = LLT::fixed_vector(4, 64);
+  const LLT v4s32 = LLT::fixed_vector(4, 32);
+  const LLT v4s16 = LLT::fixed_vector(4, 16);
+  const LLT v4s8 = LLT::fixed_vector(4, 8);
+  const LLT v4s1 = LLT::fixed_vector(4, 1);
+
+  const LLT v3s64 = LLT::fixed_vector(3, 64);
+  const LLT v3s32 = LLT::fixed_vector(3, 32);
+  const LLT v3s16 = LLT::fixed_vector(3, 16);
+  const LLT v3s8 = LLT::fixed_vector(3, 8);
+  const LLT v3s1 = LLT::fixed_vector(3, 1);
+
+  const LLT v2s64 = LLT::fixed_vector(2, 64);
+  const LLT v2s32 = LLT::fixed_vector(2, 32);
+  const LLT v2s16 = LLT::fixed_vector(2, 16);
+  const LLT v2s8 = LLT::fixed_vector(2, 8);
+  const LLT v2s1 = LLT::fixed_vector(2, 1);
+
+  const unsigned PSize = ST.getPointerSize();
+  const LLT p0 = LLT::pointer(0, PSize); // Function
+  const LLT p1 = LLT::pointer(1, PSize); // CrossWorkgroup
+  const LLT p2 = LLT::pointer(2, PSize); // UniformConstant
+  const LLT p3 = LLT::pointer(3, PSize); // Workgroup
+  const LLT p4 = LLT::pointer(4, PSize); // Generic
+  const LLT p5 = LLT::pointer(5, PSize); // Input
+
+  // TODO: remove copy-pasting here by using concatenation in some way.
+  auto allPtrsScalarsAndVectors = {
+      p0,    p1,    p2,    p3,    p4,     p5,    s1,     s8,     s16,
+      s32,   s64,   v2s1,  v2s8,  v2s16,  v2s32, v2s64,  v3s1,   v3s8,
+      v3s16, v3s32, v3s64, v4s1,  v4s8,   v4s16, v4s32,  v4s64,  v8s1,
+      v8s8,  v8s16, v8s32, v8s64, v16s1,  v16s8, v16s16, v16s32, v16s64};
+
+  auto allScalarsAndVectors = {
+      s1,   s8,   s16,   s32,   s64,   v2s1,  v2s8,  v2s16,  v2s32,  v2s64,
+      v3s1, v3s8, v3s16, v3s32, v3s64, v4s1,  v4s8,  v4s16,  v4s32,  v4s64,
+      v8s1, v8s8, v8s16, v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};
+
+  auto allIntScalarsAndVectors = {s8,    s16,   s32,   s64,    v2s8,   v2s16,
+                                  v2s32, v2s64, v3s8,  v3s16,  v3s32,  v3s64,
+                                  v4s8,  v4s16, v4s32, v4s64,  v8s8,   v8s16,
+                                  v8s32, v8s64, v16s8, v16s16, v16s32, v16s64};
+
+  auto allBoolScalarsAndVectors = {s1, v2s1, v3s1, v4s1, v8s1, v16s1};
+
+  auto allIntScalars = {s8, s16, s32, s64};
+
+  auto allFloatScalarsAndVectors = {
+      s16,   s32,   s64,   v2s16, v2s32, v2s64, v3s16,  v3s32,  v3s64,
+      v4s16, v4s32, v4s64, v8s16, v8s32, v8s64, v16s16, v16s32, v16s64};
+
+  auto allFloatAndIntScalars = allIntScalars;
+
+  auto allPtrs = {p0, p1, p2, p3, p4, p5};
+  auto allWritablePtrs = {p0, p1, p3, p4};
+
+  for (auto Opc : TypeFoldingSupportingOpcs)
+    getActionDefinitionsBuilder(Opc).custom();
+
+  getActionDefinitionsBuilder(G_GLOBAL_VALUE).alwaysLegal();
+
+  // TODO: add proper rules for vectors legalization.
+  getActionDefinitionsBuilder({G_BUILD_VECTOR, G_SHUFFLE_VECTOR}).alwaysLegal();
+
+  getActionDefinitionsBuilder({G_MEMCPY, G_MEMMOVE})
+      .legalIf(all(typeInSet(0, allWritablePtrs), typeInSet(1, allPtrs)));
+
+  getActionDefinitionsBuilder(G_ADDRSPACE_CAST)
+      .legalForCartesianProduct(allPtrs, allPtrs);
+
+  getActionDefinitionsBuilder({G_LOAD, G_STORE}).legalIf(typeInSet(1, allPtrs));
+
+  getActionDefinitionsBuilder(G_BITREVERSE).legalFor(allFloatScalarsAndVectors);
+
+  getActionDefinitionsBuilder(G_FMA).legalFor(allFloatScalarsAndVectors);
+
+  getActionDefinitionsBuilder({G_FPTOSI, G_FPTOUI})
+      .legalForCartesianProduct(allIntScalarsAndVectors,
+                                allFloatScalarsAndVectors);
+
+  getActionDefinitionsBuilder({G_SITOFP, G_UITOFP})
+      .legalForCartesianProduct(allFloatScalarsAndVectors,
+                                allScalarsAndVectors);
+
+  getActionDefinitionsBuilder({G_SMIN, G_SMAX, G_UMIN, G_UMAX, G_ABS})
+      .legalFor(allIntScalarsAndVectors);
+
+  getActionDefinitionsBuilder(G_CTPOP).legalForCartesianProduct(
+      allIntScalarsAndVectors, allIntScalarsAndVectors);
+
+  getActionDefinitionsBuilder(G_PHI).legalFor(allPtrsScalarsAndVectors);
+
+  getActionDefinitionsBuilder(G_BITCAST).legalIf(all(
+      typeInSet(0, allPtrsScalarsAndVectors),
+      typeInSet(1, allPtrsScalarsAndVectors),
+      LegalityPredicate(([=](const LegalityQuery &Query) {
+        return Query.Types[0].getSizeInBits() == Query.Types[1].getSizeInBits();
+      }))));
+
+  getActionDefinitionsBuilder(G_IMPLICIT_DEF).alwaysLegal();
+
+  getActionDefinitionsBuilder(G_INTTOPTR)
+      .legalForCartesianProduct(allPtrs, allIntScalars);
+  getActionDefinitionsBuilder(G_PTRTOINT)
+      .legalForCartesianProduct(allIntScalars, allPtrs);
+  getActionDefinitionsBuilder(G_PTR_ADD).legalForCartesianProduct(
+      allPtrs, allIntScalars);
+
+  // ST.canDirectlyComparePointers() for pointer args is supported in
+  // legalizeCustom().
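+  // (When pointers cannot be compared directly, legalizeCustom() below first
+  // rewrites both G_ICMP operands with G_PTRTOINT; see convertPtrToInt().)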
+  getActionDefinitionsBuilder(G_ICMP).customIf(
+      all(typeInSet(0, allBoolScalarsAndVectors),
+          typeInSet(1, allPtrsScalarsAndVectors)));
+
+  getActionDefinitionsBuilder(G_FCMP).legalIf(
+      all(typeInSet(0, allBoolScalarsAndVectors),
+          typeInSet(1, allFloatScalarsAndVectors)));
+
+  getActionDefinitionsBuilder({G_ATOMICRMW_OR, G_ATOMICRMW_ADD, G_ATOMICRMW_AND,
+                               G_ATOMICRMW_MAX, G_ATOMICRMW_MIN,
+                               G_ATOMICRMW_SUB, G_ATOMICRMW_XOR,
+                               G_ATOMICRMW_UMAX, G_ATOMICRMW_UMIN})
+      .legalForCartesianProduct(allIntScalars, allWritablePtrs);
+
+  getActionDefinitionsBuilder(G_ATOMICRMW_XCHG)
+      .legalForCartesianProduct(allFloatAndIntScalars, allWritablePtrs);
+
+  getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG_WITH_SUCCESS).lower();
+  // TODO: add proper legalization rules.
+  getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG).alwaysLegal();
+
+  getActionDefinitionsBuilder({G_UADDO, G_USUBO, G_SMULO, G_UMULO})
+      .alwaysLegal();
+
+  // Extensions.
+  getActionDefinitionsBuilder({G_TRUNC, G_ZEXT, G_SEXT, G_ANYEXT})
+      .legalForCartesianProduct(allScalarsAndVectors);
+
+  // FP conversions.
+  getActionDefinitionsBuilder({G_FPTRUNC, G_FPEXT})
+      .legalForCartesianProduct(allFloatScalarsAndVectors);
+
+  // Pointer-handling.
+  getActionDefinitionsBuilder(G_FRAME_INDEX).legalFor({p0});
+
+  // Control-flow.
+  getActionDefinitionsBuilder(G_BRCOND).legalFor({s1});
+
+  getActionDefinitionsBuilder({G_FPOW,
+                               G_FEXP,
+                               G_FEXP2,
+                               G_FLOG,
+                               G_FLOG2,
+                               G_FABS,
+                               G_FMINNUM,
+                               G_FMAXNUM,
+                               G_FCEIL,
+                               G_FCOS,
+                               G_FSIN,
+                               G_FSQRT,
+                               G_FFLOOR,
+                               G_FRINT,
+                               G_FNEARBYINT,
+                               G_INTRINSIC_ROUND,
+                               G_INTRINSIC_TRUNC,
+                               G_FMINIMUM,
+                               G_FMAXIMUM,
+                               G_INTRINSIC_ROUNDEVEN})
+      .legalFor(allFloatScalarsAndVectors);
+
+  getActionDefinitionsBuilder(G_FCOPYSIGN)
+      .legalForCartesianProduct(allFloatScalarsAndVectors,
+                                allFloatScalarsAndVectors);
+
+  getActionDefinitionsBuilder(G_FPOWI).legalForCartesianProduct(
+      allFloatScalarsAndVectors, allIntScalarsAndVectors);
+
+  getLegacyLegalizerInfo().computeTables();
+  verify(*ST.getInstrInfo());
+}
+
+static Register convertPtrToInt(Register Reg, LLT ConvTy, SPIRVType *SpirvType,
+                                LegalizerHelper &Helper,
+                                MachineRegisterInfo &MRI,
+                                SPIRVGlobalRegistry *GR) {
+  Register ConvReg = MRI.createGenericVirtualRegister(ConvTy);
+  GR->assignSPIRVTypeToVReg(SpirvType, ConvReg, Helper.MIRBuilder);
+  Helper.MIRBuilder.buildInstr(TargetOpcode::G_PTRTOINT)
+      .addDef(ConvReg)
+      .addUse(Reg);
+  return ConvReg;
+}
+
+bool SPIRVLegalizerInfo::legalizeCustom(LegalizerHelper &Helper,
+                                        MachineInstr &MI) const {
+  auto Opc = MI.getOpcode();
+  MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
+  if (!isTypeFoldingSupported(Opc)) {
+    assert(Opc == TargetOpcode::G_ICMP);
+    assert(GR->getSPIRVTypeForVReg(MI.getOperand(0).getReg()));
+    auto &Op0 = MI.getOperand(2);
+    auto &Op1 = MI.getOperand(3);
+    Register Reg0 = Op0.getReg();
+    Register Reg1 = Op1.getReg();
+    CmpInst::Predicate Cond =
+        static_cast<CmpInst::Predicate>(MI.getOperand(1).getPredicate());
+    if ((!ST->canDirectlyComparePointers() ||
+         (Cond != CmpInst::ICMP_EQ && Cond != CmpInst::ICMP_NE)) &&
+        MRI.getType(Reg0).isPointer() && MRI.getType(Reg1).isPointer()) {
+      LLT ConvT = LLT::scalar(ST->getPointerSize());
+      Type *LLVMTy = IntegerType::get(MI.getMF()->getFunction().getContext(),
+                                      ST->getPointerSize());
+      SPIRVType *SpirvTy = GR->getOrCreateSPIRVType(LLVMTy, Helper.MIRBuilder);
+      Op0.setReg(convertPtrToInt(Reg0, ConvT, SpirvTy, Helper, MRI, GR));
+      Op1.setReg(convertPtrToInt(Reg1, ConvT, SpirvTy, Helper, MRI, GR));
+    }
+    return true;
+  }
+  // TODO: implement legalization for other opcodes.
+  return true;
+}
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.h b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.h
new file mode 100644
index 0000000000000..2541ff29edb0f
--- /dev/null
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.h
@@ -0,0 +1,36 @@
+//===- SPIRVLegalizerInfo.h --- SPIR-V Legalization Rules --------*- C++ -*-==//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the targeting of the MachineLegalizer class for SPIR-V.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIB_TARGET_SPIRV_SPIRVMACHINELEGALIZER_H
+#define LLVM_LIB_TARGET_SPIRV_SPIRVMACHINELEGALIZER_H
+
+#include "SPIRVGlobalRegistry.h"
+#include "llvm/CodeGen/GlobalISel/LegalizerInfo.h"
+
+bool isTypeFoldingSupported(unsigned Opcode);
+
+namespace llvm {
+
+class LLVMContext;
+class SPIRVSubtarget;
+
+// This class provides the information for legalizing SPIR-V instructions.
+class SPIRVLegalizerInfo : public LegalizerInfo {
+  const SPIRVSubtarget *ST;
+  SPIRVGlobalRegistry *GR;
+
+public:
+  bool legalizeCustom(LegalizerHelper &Helper, MachineInstr &MI) const override;
+  SPIRVLegalizerInfo(const SPIRVSubtarget &ST);
+};
+} // namespace llvm
+#endif // LLVM_LIB_TARGET_SPIRV_SPIRVMACHINELEGALIZER_H
diff --git a/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp b/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp
index 152af7b9c9c86..cdf3a160f3738 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp
@@ -12,6 +12,8 @@
 #include "SPIRVSubtarget.h"
 #include "SPIRV.h"
+#include "SPIRVGlobalRegistry.h"
+#include "SPIRVLegalizerInfo.h"
 #include "SPIRVRegisterBankInfo.h"
 #include "SPIRVTargetMachine.h"
 #include "llvm/MC/TargetRegistry.h"
@@ -43,8 +45,13 @@ SPIRVSubtarget::SPIRVSubtarget(const Triple &TT, const std::string &CPU,
     : SPIRVGenSubtargetInfo(TT, CPU, /*TuneCPU=*/CPU, FS),
       PointerSize(computePointerSize(TT)), SPIRVVersion(0), InstrInfo(),
       FrameLowering(initSubtargetDependencies(CPU, FS)), TLInfo(TM, *this) {
-  CallLoweringInfo = std::make_unique<SPIRVCallLowering>(TLInfo);
+  GR = std::make_unique<SPIRVGlobalRegistry>(PointerSize);
+  CallLoweringInfo =
+      std::make_unique<SPIRVCallLowering>(TLInfo, *this, GR.get());
+  Legalizer = std::make_unique<SPIRVLegalizerInfo>(*this);
   RegBankInfo = std::make_unique<SPIRVRegisterBankInfo>();
+  InstSelector.reset(
+      createSPIRVInstructionSelector(TM, *this, *RegBankInfo.get()));
 }
 
 SPIRVSubtarget &SPIRVSubtarget::initSubtargetDependencies(StringRef CPU,
diff --git a/llvm/lib/Target/SPIRV/SPIRVSubtarget.h b/llvm/lib/Target/SPIRV/SPIRVSubtarget.h
index 208fb11e3f95c..a6332cfefa8e3 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSubtarget.h
+++ b/llvm/lib/Target/SPIRV/SPIRVSubtarget.h
@@ -30,7 +30,7 @@
 namespace llvm {
 class StringRef;
-
+class SPIRVGlobalRegistry;
 class SPIRVTargetMachine;
 
 class SPIRVSubtarget : public SPIRVGenSubtargetInfo {
@@ -38,6 +38,8 @@ class SPIRVSubtarget : public SPIRVGenSubtargetInfo {
   const unsigned PointerSize;
   uint32_t SPIRVVersion;
 
+  std::unique_ptr<SPIRVGlobalRegistry> GR;
+
   SPIRVInstrInfo InstrInfo;
   SPIRVFrameLowering FrameLowering;
   SPIRVTargetLowering TLInfo;
@@ -45,6 +47,8 @@ class SPIRVSubtarget : public SPIRVGenSubtargetInfo {
   // GlobalISel related APIs.
   std::unique_ptr<CallLowering> CallLoweringInfo;
   std::unique_ptr<RegisterBankInfo> RegBankInfo;
+  std::unique_ptr<LegalizerInfo> Legalizer;
+  std::unique_ptr<InstructionSelector> InstSelector;
 
 public:
   // This constructor initializes the data members to match that
@@ -59,6 +63,7 @@ class SPIRVSubtarget : public SPIRVGenSubtargetInfo {
   unsigned getPointerSize() const { return PointerSize; }
   bool canDirectlyComparePointers() const;
   uint32_t getSPIRVVersion() const { return SPIRVVersion; };
+  SPIRVGlobalRegistry *getSPIRVGlobalRegistry() const { return GR.get(); }
 
   const CallLowering *getCallLowering() const override {
     return CallLoweringInfo.get();
@@ -66,6 +71,12 @@ class SPIRVSubtarget : public SPIRVGenSubtargetInfo {
   const RegisterBankInfo *getRegBankInfo() const override {
     return RegBankInfo.get();
   }
+  const LegalizerInfo *getLegalizerInfo() const override {
+    return Legalizer.get();
+  }
+  InstructionSelector *getInstructionSelector() const override {
+    return InstSelector.get();
+  }
   const SPIRVInstrInfo *getInstrInfo() const override { return &InstrInfo; }
   const SPIRVFrameLowering *getFrameLowering() const override {
     return &FrameLowering;
diff --git a/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp b/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp
index a6ce001ff04f8..a9fd5e2d581c1 100644
--- a/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp
@@ -12,6 +12,9 @@
 #include "SPIRVTargetMachine.h"
 #include "SPIRV.h"
+#include "SPIRVCallLowering.h"
+#include "SPIRVGlobalRegistry.h"
+#include "SPIRVLegalizerInfo.h"
 #include "SPIRVTargetObjectFile.h"
 #include "SPIRVTargetTransformInfo.h"
 #include "TargetInfo/SPIRVTargetInfo.h"
@@ -34,6 +37,9 @@ extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeSPIRVTarget() {
   // Register the target.
   RegisterTargetMachine<SPIRVTargetMachine> X(getTheSPIRV32Target());
   RegisterTargetMachine<SPIRVTargetMachine> Y(getTheSPIRV64Target());
+
+  PassRegistry &PR = *PassRegistry::getPassRegistry();
+  initializeGlobalISel(PR);
 }
 
 static std::string computeDataLayout(const Triple &TT) {
@@ -155,7 +161,19 @@ bool SPIRVPassConfig::addRegBankSelect() {
   return false;
 }
 
+namespace {
+// A custom subclass of InstructionSelect, which is mostly the same except for
+// not requiring RegBankSelect to occur previously.
+class SPIRVInstructionSelect : public InstructionSelect {
+  // We don't use register banks, so unset the requirement for them.
+  MachineFunctionProperties getRequiredProperties() const override {
+    return InstructionSelect::getRequiredProperties().reset(
+        MachineFunctionProperties::Property::RegBankSelected);
+  }
+};
+} // namespace
+
 bool SPIRVPassConfig::addGlobalInstructionSelect() {
-  addPass(new InstructionSelect(getOptLevel()));
+  addPass(new SPIRVInstructionSelect());
   return false;
 }
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
new file mode 100644
index 0000000000000..aa1933fc4f02e
--- /dev/null
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
@@ -0,0 +1,182 @@
+//===--- SPIRVUtils.cpp ---- SPIR-V Utility Functions -----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains miscellaneous utility functions.
+//
+//===----------------------------------------------------------------------===//
+
+#include "SPIRVUtils.h"
+#include "MCTargetDesc/SPIRVBaseInfo.h"
+#include "SPIRV.h"
+#include "SPIRVInstrInfo.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
+#include "llvm/CodeGen/MachineInstr.h"
+#include "llvm/CodeGen/MachineInstrBuilder.h"
+
+using namespace llvm;
+
+// The following functions are used to add these string literals as a series of
+// 32-bit integer operands with the correct format, and unpack them if
+// necessary when making string comparisons in compiler passes.
+// SPIR-V requires null-terminated UTF-8 strings padded to 32-bit alignment.
+static uint32_t convertCharsToWord(const StringRef &Str, unsigned i) {
+  uint32_t Word = 0u; // Build up this 32-bit word from 4 8-bit chars.
+  for (unsigned WordIndex = 0; WordIndex < 4; ++WordIndex) {
+    unsigned StrIndex = i + WordIndex;
+    uint8_t CharToAdd = 0;       // Initialize char as padding/null.
+    if (StrIndex < Str.size()) { // If it's within the string, get a real char.
+      CharToAdd = Str[StrIndex];
+    }
+    Word |= (CharToAdd << (WordIndex * 8));
+  }
+  return Word;
+}
+
+// Get length including padding and null terminator.
+static size_t getPaddedLen(const StringRef &Str) {
+  const size_t Len = Str.size() + 1;
+  return (Len % 4 == 0) ? Len : Len + (4 - (Len % 4));
+}
+
+void addStringImm(const StringRef &Str, MachineInstrBuilder &MIB) {
+  const size_t PaddedLen = getPaddedLen(Str);
+  for (unsigned i = 0; i < PaddedLen; i += 4) {
+    // Add an operand for the 32-bits of chars or padding.
+    MIB.addImm(convertCharsToWord(Str, i));
+  }
+}
+
+void addStringImm(const StringRef &Str, IRBuilder<> &B,
+                  std::vector<Value *> &Args) {
+  const size_t PaddedLen = getPaddedLen(Str);
+  for (unsigned i = 0; i < PaddedLen; i += 4) {
+    // Add a vector element for the 32-bits of chars or padding.
+    Args.push_back(B.getInt32(convertCharsToWord(Str, i)));
+  }
+}
+
+std::string getStringImm(const MachineInstr &MI, unsigned StartIndex) {
+  return getSPIRVStringOperand(MI, StartIndex);
+}
+
+void addNumImm(const APInt &Imm, MachineInstrBuilder &MIB) {
+  const auto Bitwidth = Imm.getBitWidth();
+  switch (Bitwidth) {
+  case 1:
+    break; // Already handled.
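+    // Note: boolean constants contribute no literal words here; a bool
+    // constant is presumably materialized as OpConstantTrue/OpConstantFalse,
+    // whose encoding carries no immediate operand.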
+  case 8:
+  case 16:
+  case 32:
+    MIB.addImm(Imm.getZExtValue());
+    break;
+  case 64: {
+    uint64_t FullImm = Imm.getZExtValue();
+    uint32_t LowBits = FullImm & 0xffffffff;
+    uint32_t HighBits = (FullImm >> 32) & 0xffffffff;
+    MIB.addImm(LowBits).addImm(HighBits);
+    break;
+  }
+  default:
+    report_fatal_error("Unsupported constant bitwidth");
+  }
+}
+
+void buildOpName(Register Target, const StringRef &Name,
+                 MachineIRBuilder &MIRBuilder) {
+  if (!Name.empty()) {
+    auto MIB = MIRBuilder.buildInstr(SPIRV::OpName).addUse(Target);
+    addStringImm(Name, MIB);
+  }
+}
+
+static void finishBuildOpDecorate(MachineInstrBuilder &MIB,
+                                  const std::vector<uint32_t> &DecArgs,
+                                  StringRef StrImm) {
+  if (!StrImm.empty())
+    addStringImm(StrImm, MIB);
+  for (const auto &DecArg : DecArgs)
+    MIB.addImm(DecArg);
+}
+
+void buildOpDecorate(Register Reg, MachineIRBuilder &MIRBuilder,
+                     llvm::SPIRV::Decoration Dec,
+                     const std::vector<uint32_t> &DecArgs, StringRef StrImm) {
+  auto MIB = MIRBuilder.buildInstr(SPIRV::OpDecorate)
+                 .addUse(Reg)
+                 .addImm(static_cast<uint32_t>(Dec));
+  finishBuildOpDecorate(MIB, DecArgs, StrImm);
+}
+
+void buildOpDecorate(Register Reg, MachineInstr &I, const SPIRVInstrInfo &TII,
+                     llvm::SPIRV::Decoration Dec,
+                     const std::vector<uint32_t> &DecArgs, StringRef StrImm) {
+  MachineBasicBlock &MBB = *I.getParent();
+  auto MIB = BuildMI(MBB, I, I.getDebugLoc(), TII.get(SPIRV::OpDecorate))
+                 .addUse(Reg)
+                 .addImm(static_cast<uint32_t>(Dec));
+  finishBuildOpDecorate(MIB, DecArgs, StrImm);
+}
+
+// TODO: maybe the following two functions should be handled in the subtarget
+// to allow for different OpenCL vs Vulkan handling.
+unsigned storageClassToAddressSpace(SPIRV::StorageClass SC) {
+  switch (SC) {
+  case SPIRV::StorageClass::Function:
+    return 0;
+  case SPIRV::StorageClass::CrossWorkgroup:
+    return 1;
+  case SPIRV::StorageClass::UniformConstant:
+    return 2;
+  case SPIRV::StorageClass::Workgroup:
+    return 3;
+  case SPIRV::StorageClass::Generic:
+    return 4;
+  case SPIRV::StorageClass::Input:
+    return 7;
+  default:
+    llvm_unreachable("Unable to get address space id");
+  }
+}
+
+SPIRV::StorageClass addressSpaceToStorageClass(unsigned AddrSpace) {
+  switch (AddrSpace) {
+  case 0:
+    return SPIRV::StorageClass::Function;
+  case 1:
+    return SPIRV::StorageClass::CrossWorkgroup;
+  case 2:
+    return SPIRV::StorageClass::UniformConstant;
+  case 3:
+    return SPIRV::StorageClass::Workgroup;
+  case 4:
+    return SPIRV::StorageClass::Generic;
+  case 7:
+    return SPIRV::StorageClass::Input;
+  default:
+    llvm_unreachable("Unknown address space");
+  }
+}
+
+SPIRV::MemorySemantics getMemSemanticsForStorageClass(SPIRV::StorageClass SC) {
+  switch (SC) {
+  case SPIRV::StorageClass::StorageBuffer:
+  case SPIRV::StorageClass::Uniform:
+    return SPIRV::MemorySemantics::UniformMemory;
+  case SPIRV::StorageClass::Workgroup:
+    return SPIRV::MemorySemantics::WorkgroupMemory;
+  case SPIRV::StorageClass::CrossWorkgroup:
+    return SPIRV::MemorySemantics::CrossWorkgroupMemory;
+  case SPIRV::StorageClass::AtomicCounter:
+    return SPIRV::MemorySemantics::AtomicCounterMemory;
+  case SPIRV::StorageClass::Image:
+    return SPIRV::MemorySemantics::ImageMemory;
+  default:
+    return SPIRV::MemorySemantics::None;
+  }
+}
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h
new file mode 100644
index 0000000000000..3cb1db7795751
--- /dev/null
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h
@@ -0,0 +1,69 @@
+//===--- SPIRVUtils.h ---- SPIR-V Utility Functions -------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains miscellaneous utility functions.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIB_TARGET_SPIRV_SPIRVUTILS_H
+#define LLVM_LIB_TARGET_SPIRV_SPIRVUTILS_H
+
+#include "MCTargetDesc/SPIRVBaseInfo.h"
+#include "llvm/IR/IRBuilder.h"
+#include <string>
+
+namespace llvm {
+class MCInst;
+class MachineFunction;
+class MachineInstr;
+class MachineInstrBuilder;
+class MachineIRBuilder;
+class MachineRegisterInfo;
+class Register;
+class StringRef;
+class SPIRVInstrInfo;
+} // namespace llvm
+
+// Add the given string as a series of integer operands, inserting null
+// terminators and padding to make sure the operands all have 32-bit
+// little-endian words.
+void addStringImm(const llvm::StringRef &Str, llvm::MachineInstrBuilder &MIB);
+void addStringImm(const llvm::StringRef &Str, llvm::IRBuilder<> &B,
+                  std::vector<llvm::Value *> &Args);
+
+// Read the series of integer operands back as a null-terminated string using
+// the reverse of the logic in addStringImm.
+std::string getStringImm(const llvm::MachineInstr &MI, unsigned StartIndex);
+
+// Add the given numerical immediate to MIB.
+void addNumImm(const llvm::APInt &Imm, llvm::MachineInstrBuilder &MIB);
+
+// Add an OpName instruction for the given target register.
+void buildOpName(llvm::Register Target, const llvm::StringRef &Name,
+                 llvm::MachineIRBuilder &MIRBuilder);
+
+// Add an OpDecorate instruction for the given Reg.
+void buildOpDecorate(llvm::Register Reg, llvm::MachineIRBuilder &MIRBuilder,
+                     llvm::SPIRV::Decoration Dec,
+                     const std::vector<uint32_t> &DecArgs,
+                     llvm::StringRef StrImm = "");
+void buildOpDecorate(llvm::Register Reg, llvm::MachineInstr &I,
+                     const llvm::SPIRVInstrInfo &TII,
+                     llvm::SPIRV::Decoration Dec,
+                     const std::vector<uint32_t> &DecArgs,
+                     llvm::StringRef StrImm = "");
+
+// Convert a SPIR-V storage class to the corresponding LLVM IR address space.
+unsigned storageClassToAddressSpace(llvm::SPIRV::StorageClass SC);
+
+// Convert an LLVM IR address space to a SPIR-V storage class.
+llvm::SPIRV::StorageClass addressSpaceToStorageClass(unsigned AddrSpace);
+
+llvm::SPIRV::MemorySemantics
+getMemSemanticsForStorageClass(llvm::SPIRV::StorageClass SC);
+#endif // LLVM_LIB_TARGET_SPIRV_SPIRVUTILS_H
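As a cross-check of the word-packing convention implemented by getPaddedLen and convertCharsToWord in SPIRVUtils.cpp above, the following standalone snippet reproduces the null termination, 4-byte padding, and little-endian byte order. It is a minimal sketch, independent of LLVM; packString is a hypothetical helper written for this illustration, not an API from the patch.

#include <cstdint>
#include <cstdio>
#include <string>
#include <vector>

// Mirror of getPaddedLen/convertCharsToWord: pack Str into null-terminated,
// 4-byte-padded little-endian 32-bit words.
static std::vector<uint32_t> packString(const std::string &Str) {
  size_t PaddedLen = (Str.size() + 1 + 3) & ~size_t(3); // +1 for '\0'.
  std::vector<uint32_t> Words;
  for (size_t I = 0; I < PaddedLen; I += 4) {
    uint32_t Word = 0;
    for (unsigned B = 0; B < 4; ++B)
      if (I + B < Str.size()) // Bytes past the end stay 0 (null/padding).
        Word |= uint32_t(uint8_t(Str[I + B])) << (B * 8);
    Words.push_back(Word);
  }
  return Words;
}

int main() {
  // "abc" packs into a single word 0x00636261: 'a' (0x61) in the low byte,
  // then 'b' and 'c', with a zero byte that doubles as the null terminator.
  for (uint32_t W : packString("abc"))
    std::printf("0x%08x\n", W);
  return 0;
}

The printed words match what addStringImm emits as instruction operands, and what getStringImm reverses when passes need to compare the strings again.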