Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DXIL] Model DXIL Class specification of DXIL Ops in DXIL.td #87803

Open
wants to merge 10 commits into
base: main
Choose a base branch
from

Conversation

bharadwajy
Copy link
Contributor

@bharadwajy bharadwajy commented Apr 5, 2024

Add specification of DXIL Op class.

Each DXIL Op belongs to a class. A DXIL class represents DXIL Ops
with the same function prototype (or signature). This changeset adds
specification of a TableGen class representing DXIL class. It facilitates usage
of the prototype information of the DXIL class that a DXIL Op belongs to instead
of inheriting the return and parameter type information from LLVM Intrinsic.
Using DXIL class avoids the currently implemented definitions of new
narrow DXIL types such as llvm_halforfloat_ty (hence deleted), is more
accurate and precise.

@llvmbot
Copy link
Collaborator

llvmbot commented Apr 5, 2024

@llvm/pr-subscribers-llvm-support

@llvm/pr-subscribers-backend-directx

Author: S. Bharadwaj Yadavalli (bharadwajy)

Changes
  • Add specification of DXIL Op class and Shader Model.

    Each DXIL Op belongs to a class. A DXIL class represents DXIL Ops
    with the same function prototype (or signature). This changeset adds
    specification of DXIL op TableGen class. This facilitates usage of the
    prototype information of the DXIL class that a DXIL Op belongs to instead
    of inheriting the return and parameter type information from LLVM Intrinsic.
    Using DXIL class avoids the currently implemented definitions of new
    narrow DXIL types such as llvm_halforfloat_ty (hence deleted), is more
    accurate and precise.

  • Add specification Shader Model version.

    Each DXIL Op has a set of valid overloads. Validity of overload types depends
    on minimum shader model version. Expressing such constraints in DXIL Ops
    records is needed to ensure valid code generation by DXIL Lowering pass.

    This changeset implements a specification mechanism that associates DXIL Ops
    with the classes they belong to and associates minimum shader mode version
    with valid overload types.

  • Restructure test of lowering llvm.sin.*

    This pattern of tests is expected to facilitate use of same test sources to
    test lowering of various combinations of options.


Patch is 67.21 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/87803.diff

48 Files Affected:

  • (modified) llvm/lib/Target/DirectX/DXIL.td (+177-308)
  • (modified) llvm/lib/Target/DirectX/DXILOpBuilder.cpp (+42-6)
  • (modified) llvm/lib/Target/DirectX/DXILOpBuilder.h (+10-5)
  • (modified) llvm/lib/Target/DirectX/DXILOpLowering.cpp (+20-3)
  • (renamed) llvm/test/CodeGen/DirectX/Inputs/sin/double.ll (-4)
  • (added) llvm/test/CodeGen/DirectX/Inputs/sin/float.ll (+9)
  • (added) llvm/test/CodeGen/DirectX/Inputs/sin/half.ll (+9)
  • (modified) llvm/test/CodeGen/DirectX/abs.ll (+1-1)
  • (modified) llvm/test/CodeGen/DirectX/ceil.ll (+2-1)
  • (modified) llvm/test/CodeGen/DirectX/ceil_error.ll (+1-1)
  • (modified) llvm/test/CodeGen/DirectX/clamp.ll (+1-1)
  • (modified) llvm/test/CodeGen/DirectX/cos.ll (+1-1)
  • (modified) llvm/test/CodeGen/DirectX/cos_error.ll (+1-1)
  • (modified) llvm/test/CodeGen/DirectX/dot2_error.ll (+1-1)
  • (modified) llvm/test/CodeGen/DirectX/dot3_error.ll (+1-1)
  • (modified) llvm/test/CodeGen/DirectX/dot4_error.ll (+1-1)
  • (modified) llvm/test/CodeGen/DirectX/exp.ll (+1-1)
  • (modified) llvm/test/CodeGen/DirectX/exp2_error.ll (+1-1)
  • (modified) llvm/test/CodeGen/DirectX/fabs.ll (+1-2)
  • (modified) llvm/test/CodeGen/DirectX/fdot.ll (+1-1)
  • (modified) llvm/test/CodeGen/DirectX/floor.ll (+1-1)
  • (modified) llvm/test/CodeGen/DirectX/floor_error.ll (+1-1)
  • (modified) llvm/test/CodeGen/DirectX/fmax.ll (+1-1)
  • (modified) llvm/test/CodeGen/DirectX/fmin.ll (+1-1)
  • (modified) llvm/test/CodeGen/DirectX/frac_error.ll (+1-1)
  • (modified) llvm/test/CodeGen/DirectX/idot.ll (+1-1)
  • (modified) llvm/test/CodeGen/DirectX/isinf.ll (+1-1)
  • (modified) llvm/test/CodeGen/DirectX/isinf_error.ll (+1-1)
  • (modified) llvm/test/CodeGen/DirectX/log.ll (+1-1)
  • (modified) llvm/test/CodeGen/DirectX/log10.ll (+1-1)
  • (modified) llvm/test/CodeGen/DirectX/log2.ll (+1-1)
  • (modified) llvm/test/CodeGen/DirectX/log2_error.ll (+1-1)
  • (modified) llvm/test/CodeGen/DirectX/pow.ll (+1-1)
  • (modified) llvm/test/CodeGen/DirectX/reversebits.ll (+1-1)
  • (modified) llvm/test/CodeGen/DirectX/round.ll (+1-1)
  • (modified) llvm/test/CodeGen/DirectX/round_error.ll (+1-1)
  • (modified) llvm/test/CodeGen/DirectX/rsqrt.ll (+1-1)
  • (modified) llvm/test/CodeGen/DirectX/rsqrt_error.ll (+1-1)
  • (modified) llvm/test/CodeGen/DirectX/sin.ll (+20-22)
  • (modified) llvm/test/CodeGen/DirectX/smax.ll (+1-1)
  • (modified) llvm/test/CodeGen/DirectX/smin.ll (+1-1)
  • (modified) llvm/test/CodeGen/DirectX/sqrt.ll (+1-1)
  • (modified) llvm/test/CodeGen/DirectX/sqrt_error.ll (+1-1)
  • (modified) llvm/test/CodeGen/DirectX/trunc.ll (+1-1)
  • (modified) llvm/test/CodeGen/DirectX/trunc_error.ll (+1-1)
  • (modified) llvm/test/CodeGen/DirectX/umax.ll (+1-1)
  • (modified) llvm/test/CodeGen/DirectX/umin.ll (+1-1)
  • (modified) llvm/utils/TableGen/DXILEmitter.cpp (+87-54)
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index cd388ed3e3191b..ce8a14cda94a9d 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -13,331 +13,200 @@
 
 include "llvm/IR/Intrinsics.td"
 
-class DXILOpClass;
+// Abstract class to demarcate minimum Shader model version required
+// to support DXIL Op
+class DXILShaderModel<int major, int minor> {
+  int MajorAndMinor = !add(!mul(major, 10), minor);
+}
 
-// Following is a set of DXIL Operation classes whose names appear to be
-// arbitrary, yet need to be a substring of the function name used during
-// lowering to DXIL Operation calls. These class name strings are specified
-// as the third argument of add_dixil_op in utils/hct/hctdb.py and case converted
-// in utils/hct/hctdb_instrhelp.py of DirectXShaderCompiler repo. The function
-// name has the format "dx.op.<class-name>.<return-type>".
+// Valid minimum Shader model version records
 
-defset list<DXILOpClass> OpClasses = {
-  def acceptHitAndEndSearch : DXILOpClass;
-  def allocateNodeOutputRecords : DXILOpClass;
-  def allocateRayQuery : DXILOpClass;
-  def annotateHandle : DXILOpClass;
-  def annotateNodeHandle : DXILOpClass;
-  def annotateNodeRecordHandle : DXILOpClass;
-  def atomicBinOp : DXILOpClass;
-  def atomicCompareExchange : DXILOpClass;
-  def attributeAtVertex : DXILOpClass;
-  def barrier : DXILOpClass;
-  def barrierByMemoryHandle : DXILOpClass;
-  def barrierByMemoryType : DXILOpClass;
-  def barrierByNodeRecordHandle : DXILOpClass;
-  def binary : DXILOpClass;
-  def binaryWithCarryOrBorrow : DXILOpClass;
-  def binaryWithTwoOuts : DXILOpClass;
-  def bitcastF16toI16 : DXILOpClass;
-  def bitcastF32toI32 : DXILOpClass;
-  def bitcastF64toI64 : DXILOpClass;
-  def bitcastI16toF16 : DXILOpClass;
-  def bitcastI32toF32 : DXILOpClass;
-  def bitcastI64toF64 : DXILOpClass;
-  def bufferLoad : DXILOpClass;
-  def bufferStore : DXILOpClass;
-  def bufferUpdateCounter : DXILOpClass;
-  def calculateLOD : DXILOpClass;
-  def callShader : DXILOpClass;
-  def cbufferLoad : DXILOpClass;
-  def cbufferLoadLegacy : DXILOpClass;
-  def checkAccessFullyMapped : DXILOpClass;
-  def coverage : DXILOpClass;
-  def createHandle : DXILOpClass;
-  def createHandleForLib : DXILOpClass;
-  def createHandleFromBinding : DXILOpClass;
-  def createHandleFromHeap : DXILOpClass;
-  def createNodeInputRecordHandle : DXILOpClass;
-  def createNodeOutputHandle : DXILOpClass;
-  def cutStream : DXILOpClass;
-  def cycleCounterLegacy : DXILOpClass;
-  def discard : DXILOpClass;
-  def dispatchMesh : DXILOpClass;
-  def dispatchRaysDimensions : DXILOpClass;
-  def dispatchRaysIndex : DXILOpClass;
-  def domainLocation : DXILOpClass;
-  def dot2 : DXILOpClass;
-  def dot2AddHalf : DXILOpClass;
-  def dot3 : DXILOpClass;
-  def dot4 : DXILOpClass;
-  def dot4AddPacked : DXILOpClass;
-  def emitIndices : DXILOpClass;
-  def emitStream : DXILOpClass;
-  def emitThenCutStream : DXILOpClass;
-  def evalCentroid : DXILOpClass;
-  def evalSampleIndex : DXILOpClass;
-  def evalSnapped : DXILOpClass;
-  def finishedCrossGroupSharing : DXILOpClass;
-  def flattenedThreadIdInGroup : DXILOpClass;
-  def geometryIndex : DXILOpClass;
-  def getDimensions : DXILOpClass;
-  def getInputRecordCount : DXILOpClass;
-  def getMeshPayload : DXILOpClass;
-  def getNodeRecordPtr : DXILOpClass;
-  def getRemainingRecursionLevels : DXILOpClass;
-  def groupId : DXILOpClass;
-  def gsInstanceID : DXILOpClass;
-  def hitKind : DXILOpClass;
-  def ignoreHit : DXILOpClass;
-  def incrementOutputCount : DXILOpClass;
-  def indexNodeHandle : DXILOpClass;
-  def innerCoverage : DXILOpClass;
-  def instanceID : DXILOpClass;
-  def instanceIndex : DXILOpClass;
-  def isHelperLane : DXILOpClass;
-  def isSpecialFloat : DXILOpClass;
-  def legacyDoubleToFloat : DXILOpClass;
-  def legacyDoubleToSInt32 : DXILOpClass;
-  def legacyDoubleToUInt32 : DXILOpClass;
-  def legacyF16ToF32 : DXILOpClass;
-  def legacyF32ToF16 : DXILOpClass;
-  def loadInput : DXILOpClass;
-  def loadOutputControlPoint : DXILOpClass;
-  def loadPatchConstant : DXILOpClass;
-  def makeDouble : DXILOpClass;
-  def minPrecXRegLoad : DXILOpClass;
-  def minPrecXRegStore : DXILOpClass;
-  def nodeOutputIsValid : DXILOpClass;
-  def objectRayDirection : DXILOpClass;
-  def objectRayOrigin : DXILOpClass;
-  def objectToWorld : DXILOpClass;
-  def outputComplete : DXILOpClass;
-  def outputControlPointID : DXILOpClass;
-  def pack4x8 : DXILOpClass;
-  def primitiveID : DXILOpClass;
-  def primitiveIndex : DXILOpClass;
-  def quadOp : DXILOpClass;
-  def quadReadLaneAt : DXILOpClass;
-  def quadVote : DXILOpClass;
-  def quaternary : DXILOpClass;
-  def rawBufferLoad : DXILOpClass;
-  def rawBufferStore : DXILOpClass;
-  def rayFlags : DXILOpClass;
-  def rayQuery_Abort : DXILOpClass;
-  def rayQuery_CommitNonOpaqueTriangleHit : DXILOpClass;
-  def rayQuery_CommitProceduralPrimitiveHit : DXILOpClass;
-  def rayQuery_Proceed : DXILOpClass;
-  def rayQuery_StateMatrix : DXILOpClass;
-  def rayQuery_StateScalar : DXILOpClass;
-  def rayQuery_StateVector : DXILOpClass;
-  def rayQuery_TraceRayInline : DXILOpClass;
-  def rayTCurrent : DXILOpClass;
-  def rayTMin : DXILOpClass;
-  def renderTargetGetSampleCount : DXILOpClass;
-  def renderTargetGetSamplePosition : DXILOpClass;
-  def reportHit : DXILOpClass;
-  def sample : DXILOpClass;
-  def sampleBias : DXILOpClass;
-  def sampleCmp : DXILOpClass;
-  def sampleCmpBias : DXILOpClass;
-  def sampleCmpGrad : DXILOpClass;
-  def sampleCmpLevel : DXILOpClass;
-  def sampleCmpLevelZero : DXILOpClass;
-  def sampleGrad : DXILOpClass;
-  def sampleIndex : DXILOpClass;
-  def sampleLevel : DXILOpClass;
-  def setMeshOutputCounts : DXILOpClass;
-  def splitDouble : DXILOpClass;
-  def startInstanceLocation : DXILOpClass;
-  def startVertexLocation : DXILOpClass;
-  def storeOutput : DXILOpClass;
-  def storePatchConstant : DXILOpClass;
-  def storePrimitiveOutput : DXILOpClass;
-  def storeVertexOutput : DXILOpClass;
-  def tempRegLoad : DXILOpClass;
-  def tempRegStore : DXILOpClass;
-  def tertiary : DXILOpClass;
-  def texture2DMSGetSamplePosition : DXILOpClass;
-  def textureGather : DXILOpClass;
-  def textureGatherCmp : DXILOpClass;
-  def textureGatherRaw : DXILOpClass;
-  def textureLoad : DXILOpClass;
-  def textureStore : DXILOpClass;
-  def textureStoreSample : DXILOpClass;
-  def threadId : DXILOpClass;
-  def threadIdInGroup : DXILOpClass;
-  def traceRay : DXILOpClass;
-  def unary : DXILOpClass;
-  def unaryBits : DXILOpClass;
-  def unpack4x8 : DXILOpClass;
-  def viewID : DXILOpClass;
-  def waveActiveAllEqual : DXILOpClass;
-  def waveActiveBallot : DXILOpClass;
-  def waveActiveBit : DXILOpClass;
-  def waveActiveOp : DXILOpClass;
-  def waveAllOp : DXILOpClass;
-  def waveAllTrue : DXILOpClass;
-  def waveAnyTrue : DXILOpClass;
-  def waveGetLaneCount : DXILOpClass;
-  def waveGetLaneIndex : DXILOpClass;
-  def waveIsFirstLane : DXILOpClass;
-  def waveMatch : DXILOpClass;
-  def waveMatrix_Accumulate : DXILOpClass;
-  def waveMatrix_Annotate : DXILOpClass;
-  def waveMatrix_Depth : DXILOpClass;
-  def waveMatrix_Fill : DXILOpClass;
-  def waveMatrix_LoadGroupShared : DXILOpClass;
-  def waveMatrix_LoadRawBuf : DXILOpClass;
-  def waveMatrix_Multiply : DXILOpClass;
-  def waveMatrix_ScalarOp : DXILOpClass;
-  def waveMatrix_StoreGroupShared : DXILOpClass;
-  def waveMatrix_StoreRawBuf : DXILOpClass;
-  def waveMultiPrefixBitCount : DXILOpClass;
-  def waveMultiPrefixOp : DXILOpClass;
-  def wavePrefixOp : DXILOpClass;
-  def waveReadLaneAt : DXILOpClass;
-  def waveReadLaneFirst : DXILOpClass;
-  def worldRayDirection : DXILOpClass;
-  def worldRayOrigin : DXILOpClass;
-  def worldToObject : DXILOpClass;
-  def writeSamplerFeedback : DXILOpClass;
-  def writeSamplerFeedbackBias : DXILOpClass;
-  def writeSamplerFeedbackGrad : DXILOpClass;
-  def writeSamplerFeedbackLevel: DXILOpClass;
+// Shader Mode 6.x
+foreach i = 0...9 in {
+  def SM6_#i : DXILShaderModel<6, i>;
+}
+// Shader Mode 7.x - for now 7.0 is defined. Extend as needed
+foreach i = 0 in {
+  def SM7_#i : DXILShaderModel<7, i>;
+}
 
-  // This is a sentinel definition. Hence placed at the end of the list
-  // and not as part of the above alphabetically sorted valid definitions.
-  // Additionally it is capitalized unlike all the others.
-  def UnknownOpClass: DXILOpClass;
+// Abstraction of class mapping valid DXIL Op overloads the minimum
+// version of Shader Model they are supported
+class DXILOpOverload<DXILShaderModel minsm, list<LLVMType> overloads> {
+  DXILShaderModel ShaderModel = minsm;
+  list<LLVMType> OpOverloads = overloads;
 }
 
-// Several of the overloaded DXIL Operations support for data types
-// that are a subset of the overloaded LLVM intrinsics that they map to.
-// For e.g., llvm.sin.* intrinsic operates on any floating-point type and
-// maps for lowering to DXIL Op Sin. However, valid overloads of DXIL Sin
-// operation overloads are half (f16) and float (f32) only.
-//
-// The following abstracts overload types specific to DXIL operations.
+// Abstraction of DXIL Operation class.
+// It encapsulates an associated function signature viz.,
+// returnTy(param1Ty, param2Ty, ...) represented as a list of LLVMTypes.
+// DXIL Ops that belong to a DXILOpClass record the signature of that
+// DXILOpClass
 
-class DXILType : LLVMType<OtherVT> {
-  let isAny = 1;
-  int isI16OrI32 = 0;
-  int isHalfOrFloat = 0;
+class DXILOpClass<list<LLVMType> OpSig> {
+  list<LLVMType> OpSignature = OpSig;
 }
 
-// Concrete records for various overload types supported specifically by
-// DXIL Operations.
-let isI16OrI32 = 1 in
-  def llvm_i16ori32_ty : DXILType;
+// Concrete definitions of DXIL Op Classes
+// Note that these class name strings are specified as the third argument
+// of add_dixil_op in utils/hct/hctdb.py and case converted in
+// utils/hct/hctdb_instrhelp.py of DirectXShaderCompiler repo. The function
+// name has the format "dx.op.<class-name>.<return-type>", in most cases.
 
-let isHalfOrFloat = 1 in
-  def llvm_halforfloat_ty : DXILType;
+// NOTE: The following list is not complete. Classes need to be defined as new DXIL Ops
+// are added.
+defset list<DXILOpClass> OpClasses = {
+  def acceptHitAndEndSearch : DXILOpClass<[llvm_void_ty]>;
+  def allocateRayQuery : DXILOpClass<[llvm_i32_ty, llvm_i32_ty]>;
+  def attributeAtVertex : DXILOpClass<[llvm_any_ty, llvm_i32_ty, llvm_i32_ty, llvm_i8_ty, llvm_i8_ty]>;
+  def barrier : DXILOpClass<[llvm_void_ty, llvm_i32_ty]>;
+  def barrierByMemoryType : DXILOpClass<[llvm_void_ty, llvm_i32_ty, llvm_i32_ty]>;
+  def binary : DXILOpClass<[llvm_any_ty, LLVMMatchType<0>, LLVMMatchType<0>]>;
+  def binaryWithCarryOrBorrow : DXILOpClass<[llvm_i32_ty, llvm_any_ty, LLVMMatchType<0>]>;
+  def dot2 : DXILOpClass<!listsplat(llvm_anyfloat_ty, 5)>;
+  def dot3 : DXILOpClass<!listsplat(llvm_anyfloat_ty, 7)>;
+  def dot4 : DXILOpClass<!listsplat(llvm_anyfloat_ty, 9)>;
+  def flattenedThreadIdInGroup : DXILOpClass<[llvm_i32_ty]>;
+  def groupId : DXILOpClass<[llvm_i32_ty, llvm_i32_ty]>;
+  def isSpecialFloat : DXILOpClass<[llvm_i1_ty, llvm_anyfloat_ty]>;
+  def tertiary : DXILOpClass<[llvm_any_ty, LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
+  def threadId : DXILOpClass<[llvm_i32_ty, llvm_i32_ty]>;
+  def threadIdInGroup : DXILOpClass<[llvm_i32_ty, llvm_i32_ty]>;
+  def unary : DXILOpClass<[llvm_any_ty, LLVMMatchType<0>]>;
+
+  // This is a sentinel definition. Hence placed at the end of the list
+  // and not as part of the above alphabetically sorted valid definitions.
+  // Additionally it is capitalized unlike all the others.
+  def UnknownOpClass: DXILOpClass<[]>;
+}
 
 // Abstraction DXIL Operation to LLVM intrinsic
 class DXILOpMappingBase {
   int OpCode = 0;                      // Opcode of DXIL Operation
   DXILOpClass OpClass = UnknownOpClass;// Class of DXIL Operation.
   Intrinsic LLVMIntrinsic = ?;         // LLVM Intrinsic DXIL Operation maps to
+  list<DXILOpOverload> OpOverloadTypes = ?; // Valid overload type
+                                       // of DXIL Operation
   string Doc = "";                     // A short description of the operation
-  list<LLVMType> OpTypes = ?;          // Valid types of DXIL Operation in the
-                                       // format [returnTy, param1ty, ...]
 }
 
-class DXILOpMapping<int opCode, DXILOpClass opClass,
-                    Intrinsic intrinsic, string doc,
-                    list<LLVMType> opTys = []> : DXILOpMappingBase {
-  int OpCode = opCode;                 // Opcode corresponding to DXIL Operation
-  DXILOpClass OpClass = opClass;       // Class of DXIL Operation.
-  Intrinsic LLVMIntrinsic = intrinsic; // LLVM Intrinsic the DXIL Operation maps
-  string Doc = doc;                    // to a short description of the operation
-  list<LLVMType> OpTypes = !if(!eq(!size(opTys), 0), LLVMIntrinsic.Types, opTys);
+class DXILOpMapping<int opCode,
+                    Intrinsic intrinsic,
+                    list<DXILOpOverload> overloadTypes,
+                    string doc> : DXILOpMappingBase {
+  int OpCode = opCode;
+  Intrinsic LLVMIntrinsic = intrinsic;
+  list<DXILOpOverload> OpOverloadTypes = overloadTypes;
+  string Doc = doc;
+}
+
+// Concrete definitions of DXIL Operation mapping to corresponding LLVM intrinsic
+
+// IsSpecialFloat Class
+let OpClass = isSpecialFloat in {
+  def IsInf : DXILOpMapping<9,  int_dx_isinf, [DXILOpOverload<SM6_0, [llvm_half_ty, llvm_float_ty]>],
+                           "Determines if the specified value is infinite.">;
+}
+
+// Unary Class
+let OpClass = unary in {
+  def Abs : DXILOpMapping<6, int_fabs, [DXILOpOverload<SM6_0, [llvm_half_ty, llvm_float_ty, llvm_double_ty]>],
+                          "Returns the absolute value of the input.">;
+
+  def Cos  : DXILOpMapping<12, int_cos, [DXILOpOverload<SM6_0, [llvm_half_ty, llvm_float_ty]>],
+                          "Returns cosine(theta) for theta in radians.">;
+  def Sin  : DXILOpMapping<13, int_sin, [DXILOpOverload<SM6_3, [llvm_half_ty, llvm_float_ty]>,
+                                         DXILOpOverload<SM6_0, [llvm_float_ty]>],
+                           "Returns sine(theta) for theta in radians.">;
+  def Exp2 : DXILOpMapping<21, int_exp2, [DXILOpOverload<SM6_0, [llvm_half_ty, llvm_float_ty]>],
+                           "Returns the base 2 exponential, or 2**x, of the"
+                           " specified value. exp2(x) = 2**x.">;
+  def Frac : DXILOpMapping<22, int_dx_frac, [DXILOpOverload<SM6_0, [llvm_half_ty, llvm_float_ty]>],
+                            "Returns a fraction from 0 to 1 that represents the"
+                            " decimal part of the input.">;
+  def Log2 : DXILOpMapping<23, int_log2, [DXILOpOverload<SM6_0, [llvm_half_ty, llvm_float_ty]>],
+                           "Returns the base-2 logarithm of the specified value.">;
+  def Sqrt : DXILOpMapping<24, int_sqrt, [DXILOpOverload<SM6_0, [llvm_half_ty, llvm_float_ty]>],
+                           "Returns the square root of the specified floating-point"
+                           "value, per component.">;
+  def RSqrt : DXILOpMapping<25, int_dx_rsqrt, [DXILOpOverload<SM6_0, [llvm_half_ty, llvm_float_ty]>],
+                            "Returns the reciprocal of the square root of the"
+                            " specified value. rsqrt(x) = 1 / sqrt(x).">;
+  def Round : DXILOpMapping<26, int_roundeven, [DXILOpOverload<SM6_0, [llvm_half_ty, llvm_float_ty]>],
+                            "Returns the input rounded to the nearest integer"
+                            "within a floating-point type.">;
+  def Floor : DXILOpMapping<27, int_floor, [DXILOpOverload<SM6_0, [llvm_half_ty, llvm_float_ty]>],
+                            "Returns the largest integer that is less than or equal to the input.">;
+  def Ceil  : DXILOpMapping<28, int_ceil, [DXILOpOverload<SM6_0, [llvm_half_ty, llvm_float_ty]>],
+                            "Returns the smallest integer that is greater than or equal to the input.">;
+  def Trunc : DXILOpMapping<29, int_trunc, [DXILOpOverload<SM6_0, [llvm_half_ty, llvm_float_ty]>],
+                            "Returns the specified value truncated to the integer component.">;
+  def Rbits : DXILOpMapping<30, int_bitreverse, [DXILOpOverload<SM6_0, [llvm_i16_ty, llvm_i32_ty, llvm_i64_ty]>],
+                            "Returns the specified value with its bits reversed.">;
+}
+
+// Binary Class
+let OpClass = binary in {
+// Float overloads
+  def FMax : DXILOpMapping<35, int_maxnum, [DXILOpOverload<SM6_0, [llvm_half_ty, llvm_float_ty, llvm_double_ty]>],
+                           "Float maximum. FMax(a,b) = a > b ? a : b">;
+  def FMin : DXILOpMapping<36, int_minnum, [DXILOpOverload<SM6_0, [llvm_half_ty, llvm_float_ty, llvm_double_ty]>],
+                           "Float minimum. FMin(a,b) = a < b ? a : b">;
+// Int overloads
+  def SMax : DXILOpMapping<37, int_smax, [DXILOpOverload<SM6_0,[llvm_i16_ty, llvm_i32_ty, llvm_i64_ty]>],
+                           "Signed integer maximum. SMax(a,b) = a > b ? a : b">;
+  def SMin : DXILOpMapping<38, int_smin, [DXILOpOverload<SM6_0,[llvm_i16_ty, llvm_i32_ty, llvm_i64_ty]>],
+                           "Signed integer minimum. SMin(a,b) = a < b ? a : b">;
+  def UMax : DXILOpMapping<39, int_umax, [DXILOpOverload<SM6_0,[llvm_i16_ty, llvm_i32_ty, llvm_i64_ty]>],
+                           "Unsigned integer maximum. UMax(a,b) = a > b ? a : b">;
+  def UMin : DXILOpMapping<40, int_umin, [DXILOpOverload<SM6_0,[llvm_i16_ty, llvm_i32_ty, llvm_i64_ty]>],
+                           "Unsigned integer minimum. UMin(a,b) = a < b ? a : b">;
 }
 
-// Concrete definition of DXIL Operation mapping to corresponding LLVM intrinsic
-def Abs : DXILOpMapping<6, unary, int_fabs,
-                         "Returns the absolute value of the input.">;
-def IsInf : DXILOpMapping<9, isSpecialFloat, int_dx_isinf,
-                         "Determines if the specified value is infinite.",
-                         [llvm_i1_ty, llvm_halforfloat_ty]>;
-def Cos  : DXILOpMapping<12, unary, int_cos,
-                         "Returns cosine(theta) for theta in radians.",
-                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
-def Sin  : DXILOpMapping<13, unary, int_sin,
-                         "Returns sine(theta) for theta in radians.",
-                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
-def Exp2 : DXILOpMapping<21, unary, int_exp2,
-                         "Returns the base 2 exponential, or 2**x, of the specified value."
-                         "exp2(x) = 2**x.",
-                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
-def Frac : DXILOpMapping<22, unary, int_dx_frac,
-                         "Returns a fraction from 0 to 1 that represents the "
-                         "decimal part of the input.",
-                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
-def Log2 : DXILOpMapping<23, unary, int_log2,
-                         "Returns the base-2 logarithm of the specified value.",
-                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
-def Sqrt : DXILOpMapping<24, unary, int_sqrt,
-                         "Returns the square root of the specified floating-point"
-                         "value, per component.",
-                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
-def RSqrt : DXILOpMapping<25, unary, int_dx_rsqrt,
-                         "Returns the reciprocal of the square root of the specified value."
-                         "rsqrt(x) = 1 / sqrt(x).",
-                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
-def Round : DXILOpMapping<26, unary, int_roundeven,
-                         "Returns the input rounded to the nearest integer"
-                         "within a floating-point type.",
-                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
-def Floor : DXILOpMapping<27, unary, int_floor,
-                         "Returns the largest integer that is less than or equal to the input.",
-                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
-def Ceil : DXILOpMapping<28, unary, int_ceil,
-                         "Returns the smallest integer that is greater than or equal to the input.",
-                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
-def Trunc : DXILOpMapping<29, unary, ...
[truncated]

llvm/lib/Target/DirectX/DXIL.td Outdated Show resolved Hide resolved
llvm/lib/Target/DirectX/DXIL.td Outdated Show resolved Hide resolved
llvm/lib/Target/DirectX/DXIL.td Outdated Show resolved Hide resolved
llvm/lib/Target/DirectX/DXILOpLowering.cpp Outdated Show resolved Hide resolved
llvm/test/CodeGen/DirectX/abs.ll Show resolved Hide resolved
llvm/utils/TableGen/DXILEmitter.cpp Outdated Show resolved Hide resolved
llvm/utils/TableGen/DXILEmitter.cpp Outdated Show resolved Hide resolved
llvm/include/llvm/Support/DXILABI.h Outdated Show resolved Hide resolved
llvm/lib/Target/DirectX/DXIL.td Outdated Show resolved Hide resolved
llvm/lib/Target/DirectX/DXIL.td Outdated Show resolved Hide resolved
llvm/lib/Target/DirectX/DXILOpLowering.cpp Outdated Show resolved Hide resolved
llvm/lib/Target/DirectX/DXILOpLowering.cpp Outdated Show resolved Hide resolved
llvm/test/CodeGen/DirectX/sin.ll Outdated Show resolved Hide resolved
llvm/lib/Target/DirectX/DXILOpBuilder.cpp Outdated Show resolved Hide resolved
llvm/lib/Target/DirectX/DXIL.td Outdated Show resolved Hide resolved
llvm/utils/TableGen/DXILEmitter.cpp Outdated Show resolved Hide resolved
llvm/utils/TableGen/DXILEmitter.cpp Show resolved Hide resolved
llvm/utils/TableGen/DXILEmitter.cpp Outdated Show resolved Hide resolved
llvm/utils/TableGen/DXILEmitter.cpp Show resolved Hide resolved
llvm/utils/TableGen/DXILEmitter.cpp Outdated Show resolved Hide resolved
Copy link
Contributor Author

@bharadwajy bharadwajy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Addressed PR feedback.

llvm/include/llvm/Support/DXILABI.h Outdated Show resolved Hide resolved
llvm/lib/Target/DirectX/DXIL.td Outdated Show resolved Hide resolved
llvm/lib/Target/DirectX/DXIL.td Outdated Show resolved Hide resolved
llvm/lib/Target/DirectX/DXIL.td Outdated Show resolved Hide resolved
llvm/lib/Target/DirectX/DXIL.td Outdated Show resolved Hide resolved
llvm/test/CodeGen/DirectX/sin.ll Outdated Show resolved Hide resolved
llvm_unreachable("Support for specified DXIL Type not yet implemented");
report_fatal_error(
"Support for specified parameter type not yet implemented",
/*gen_crash_diag*/ false);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should still be llvm_unreachable - if we get here we made a mistake in DXIL.td or otherwise broke our invariants.

Changed.

report_fatal_error("Specification of multiple differing overload "
"parameter types not yet supported",
false);
/*gen_crash_diag*/ false);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It isn't new, but this should be an assert, not a report_fatal_error. As described in the TODO above hitting this means we've made a mistake.

Changed to assert.

llvm/utils/TableGen/DXILEmitter.cpp Outdated Show resolved Hide resolved
llvm/utils/TableGen/DXILEmitter.cpp Outdated Show resolved Hide resolved
@bharadwajy bharadwajy force-pushed the dxil_td/precise-overload-specification branch from 7f64891 to aee3f5b Compare April 11, 2024 16:17
@bharadwajy bharadwajy self-assigned this May 28, 2024
Each DXIL OpClass represents DXIL Ops with the same function protitype
(signature). Represent this property in a TableGen class and add
an overload types field with the DXILOpMapping to denote valid
overload types of a DXIL Op record being defined.
for TableGen records of DXIL Opeartions.
This pattern of tests will facilitate use of same test sources to
test lowering of various combinations of options.
class defs when adding DXIL Ops of a new class.
version Shader Model in a single place to be used both
by DXILEmitter and by the lowering pass.
 - Use VersionTriple to deal with Shader Model version.
 - Undo sin test reorganization.
to create this branch

Reverted classification of valid overload specification based on
Shader Model version for sin. Will specify minimum DXIL Version
for DXIL Ops.
@bharadwajy bharadwajy force-pushed the dxil_td/precise-overload-specification branch from aee3f5b to 560d94f Compare May 29, 2024 23:02
@bharadwajy bharadwajy changed the title [DXIL] Model DXIL Class and Shader Model association of DXIL Ops in DXIL.td [DXIL] Model DXIL Class specification of DXIL Ops in DXIL.td May 29, 2024
break;
}
}
if (!knownType) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see you added an assert( ) here. Did this changed the behavior? Before it would report an error using report_fatal_error( ) and have text. Is the assert( ) doing the same thing?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There seem to be a few places where error reporting or llvm_unreachable's have been turned into asserts. Is there an overall strategy here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see you added an assert( ) here. Did this changed the behavior? Before it would report an error using report_fatal_error( ) and have text. Is the assert( ) doing the same thing?

Yes, the error text would be printed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There seem to be a few places where error reporting or llvm_unreachable's have been turned into asserts. Is there an overall strategy here?

The overall strategy suggested in earlier feedback on this PR was to use assert to flag violations of assumptions made in code - for example, in this case, the assumption specified in the TODO comment a few lines above.

Copy link
Contributor

@damyanp damyanp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It'd be good to be clear on how errors are handled/reported. Is assert the right thing to use in all cases here?

// Concrete definitions of DXIL Op Classes
// Refer to the design document
// (https://github.com/llvm/llvm-project/blob/main/llvm/docs/DirectX/DXILOpTableGenDesign.rst)
// for details about the use of DXIL Op Class name.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment isn't very helpful, as it just points at a whole design document without any real information about what you would want to read in it. If I search for "DXIL Op Class name" in that document I find nothing, and from a brief perusal of the doc the only thing I can find that sounds like details about the use of the name is this sentence:

This string is an integral part of the DXIL Op function name and is constructed in the format dx.op.<class-name>.<overload-type>

If that's all we want to say maybe it's better to repeat it here? Otherwise, if there's something more complicated going on that needs more in depth explanation than is reasonable in a comment, then pointing to a more specific part of the document would be more appropriate.

int OpCode = opCode;
Intrinsic LLVMIntrinsic = intrinsic;
list<LLVMType> OpOverloadTypes = overloadTypes;
string Doc = doc;
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this split into DXILOpPropertiesBase and DXILOpProperties? Will there be other derived definitions of some kind?

}

let OpClass = binary in {
// Float overloads
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure these "float overloads" and "int overloads" comments are all that helpful here, we can probably drop them.

" imad(m,a,b) = m * a + b.">;
def UMad : DXILOpProperties<49, int_dx_umad, [llvm_i16_ty, llvm_i32_ty, llvm_i64_ty],
"Unsigned integer arithmetic multiply/add operation."
" umad(m,a, = m * a + b.">;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

More formatting (These are within the opClass = tertiary scope, they should be indented appropriately.)

// format [returnTy, param1ty, ...]
let OpClass = unary in {
def Abs : DXILOpProperties<6, int_fabs, [llvm_half_ty, llvm_float_ty, llvm_double_ty],
"Returns the absolute value of the input.">;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The indentation on all of these looks a little funny, probably because of one of the renames that ended up on DXILOpProperties. Apparently clang-format works on tablegen now (There was a talk recently), so that might be worth a try.

@@ -1,5 +1,5 @@
; RUN: opt -S -dxil-intrinsic-expansion < %s | FileCheck %s --check-prefixes=CHECK,EXPCHECK
; RUN: opt -S -dxil-op-lower < %s | FileCheck %s --check-prefixes=CHECK,DOPCHECK
; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s --check-prefixes=CHECK,DOPCHECK
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to add all of these triples to the tests? The current implementation doesn't key off of shader model or dxil versions at all.

%0 = load half, ptr %a.addr, align 2
%1 = call half @llvm.sin.f16(half %0)
ret half %1
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did something actually change in this test or is it just reordered for some reason?

; CHECK: LLVM ERROR: Invalid Overload
; Double is not valid in any Shader Model version
; SM6_0_DOUBLE: LLVM ERROR: Invalid Overload
; SM6_3_DOUBLE: LLVM ERROR: Invalid Overload
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change seems unnecessary.

// Populate OpOverloads with
for (unsigned I = 0; I < OverloadTypeRecsSize; I++) {
OpOverloads.emplace_back(OverloadTypeRecs[I]);
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why copy the vector out of R->getValueAsListOfDefs at all? Isn't this just:

for (Record *R : R->getValueAsListOfDefs("OpOverloadTypes"))
  OpOverloads.push_back(R);

static std::string getOverloadKindStrs(const SmallVector<Record *> OLTys) {
std::string OverloadString = "";
std::string Prefix = "";
for (auto OLTy : OLTys) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Better to spell this as const Record *OLTy

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
Status: No status
Development

Successfully merging this pull request may close these issues.

None yet

7 participants