Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DirectX][DXIL] Align type spec of TableGen DXIL Op and LLVM Intrinsic #86311

Conversation

bharadwajy
Copy link
Contributor

Align specification of return type and parameter type fields of DXIL Op mapping with those of TableGan class Intrinsic.

A void return type of LLVM Intrinsic is represented as [] in its TableGen description record. Currently, a void return type of DXIL Operation is represented as [llvm_void_ty]. In addition, return and parameter types are recorded as a single list with an understanding that element at index 0 is the return type.

These changes leverage and align DXIL Op type specification with the type specification of the LLVM Intrinsic. As a result, return and parameter types are now specified as two separate lists no longer requiring a different representation for void return type. Additionally, type specification would be more succinct yet equally informative for DXIL Op records for which the same LLVM Intrinsics types are also valid.

Added a test to verify lowering of LLVM intrinsic with void return.

Barrier intrinsic has a void return type. Specification of its DXIL Op can inherit the types of this intrinsic. The test verifies the changes.

Move OverloadKind to DXILABI.h.

Fixes #86229

@llvmbot
Copy link
Collaborator

llvmbot commented Mar 22, 2024

@llvm/pr-subscribers-backend-directx

@llvm/pr-subscribers-llvm-support

Author: S. Bharadwaj Yadavalli (bharadwajy)

Changes

Align specification of return type and parameter type fields of DXIL Op mapping with those of TableGan class Intrinsic.

A void return type of LLVM Intrinsic is represented as [] in its TableGen description record. Currently, a void return type of DXIL Operation is represented as [llvm_void_ty]. In addition, return and parameter types are recorded as a single list with an understanding that element at index 0 is the return type.

These changes leverage and align DXIL Op type specification with the type specification of the LLVM Intrinsic. As a result, return and parameter types are now specified as two separate lists no longer requiring a different representation for void return type. Additionally, type specification would be more succinct yet equally informative for DXIL Op records for which the same LLVM Intrinsics types are also valid.

Added a test to verify lowering of LLVM intrinsic with void return.

Barrier intrinsic has a void return type. Specification of its DXIL Op can inherit the types of this intrinsic. The test verifies the changes.

Move OverloadKind to DXILABI.h.

Fixes #86229


Patch is 22.49 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/86311.diff

6 Files Affected:

  • (modified) llvm/include/llvm/IR/IntrinsicsDirectX.td (+1)
  • (modified) llvm/include/llvm/Support/DXILABI.h (+15)
  • (modified) llvm/lib/Target/DirectX/DXIL.td (+19-12)
  • (modified) llvm/lib/Target/DirectX/DXILOpBuilder.cpp (+4-19)
  • (added) llvm/test/CodeGen/DirectX/barrier.ll (+11)
  • (modified) llvm/utils/TableGen/DXILEmitter.cpp (+173-114)
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index 1164b241ba7b0d..292aa4497916f1 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -16,6 +16,7 @@ def int_dx_thread_id : Intrinsic<[llvm_i32_ty], [llvm_i32_ty], [IntrNoMem, IntrW
 def int_dx_group_id : Intrinsic<[llvm_i32_ty], [llvm_i32_ty], [IntrNoMem, IntrWillReturn]>;
 def int_dx_thread_id_in_group : Intrinsic<[llvm_i32_ty], [llvm_i32_ty], [IntrNoMem, IntrWillReturn]>;
 def int_dx_flattened_thread_id_in_group : Intrinsic<[llvm_i32_ty], [], [IntrNoMem, IntrWillReturn]>;
+def int_dx_barrier  : Intrinsic<[], [llvm_i32_ty], [IntrNoDuplicate, IntrWillReturn]>;
 
 def int_dx_create_handle : ClangBuiltin<"__builtin_hlsl_create_handle">,
     Intrinsic<[ llvm_ptr_ty ], [llvm_i8_ty], [IntrWillReturn]>;
diff --git a/llvm/include/llvm/Support/DXILABI.h b/llvm/include/llvm/Support/DXILABI.h
index c1d81775b6711e..75cc17bad992d9 100644
--- a/llvm/include/llvm/Support/DXILABI.h
+++ b/llvm/include/llvm/Support/DXILABI.h
@@ -39,6 +39,21 @@ enum class ParameterKind : uint8_t {
   DXIL_HANDLE,
 };
 
+enum OverloadKind : uint16_t {
+  INVALID = 0,
+  VOID = 1,
+  HALF = 1 << 1,
+  FLOAT = 1 << 2,
+  DOUBLE = 1 << 3,
+  I1 = 1 << 4,
+  I8 = 1 << 5,
+  I16 = 1 << 6,
+  I32 = 1 << 7,
+  I64 = 1 << 8,
+  UserDefineType = 1 << 9,
+  ObjectType = 1 << 10,
+};
+
 /// The kind of resource for an SRV or UAV resource. Sometimes referred to as
 /// "Shape" in the DXIL docs.
 enum class ResourceKind : uint32_t {
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index f7e69ebae15b6c..aa98b74c1ffe56 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -240,18 +240,23 @@ class DXILOpMappingBase {
   DXILOpClass OpClass = UnknownOpClass;// Class of DXIL Operation.
   Intrinsic LLVMIntrinsic = ?;         // LLVM Intrinsic DXIL Operation maps to
   string Doc = "";                     // A short description of the operation
-  list<LLVMType> OpTypes = ?;          // Valid types of DXIL Operation in the
-                                       // format [returnTy, param1ty, ...]
+  // The following fields denote the same semantics as those of Intrinsic class
+  // and are initialized with the same values as those of LLVMIntrinsic unless
+  // overridden in the definition of a record.
+  list<LLVMType> OpRetTypes = ?;    // Valid return types of DXIL Operation
+  list<LLVMType> OpParamTypes = ?;     // Valid parameter types of DXIL Operation
 }
 
 class DXILOpMapping<int opCode, DXILOpClass opClass,
                     Intrinsic intrinsic, string doc,
-                    list<LLVMType> opTys = []> : DXILOpMappingBase {
+                    list<LLVMType> retTys = [],
+                    list<LLVMType> paramTys = []> : DXILOpMappingBase {
   int OpCode = opCode;                 // Opcode corresponding to DXIL Operation
   DXILOpClass OpClass = opClass;       // Class of DXIL Operation.
   Intrinsic LLVMIntrinsic = intrinsic; // LLVM Intrinsic the DXIL Operation maps
   string Doc = doc;                    // to a short description of the operation
-  list<LLVMType> OpTypes = !if(!eq(!size(opTys), 0), LLVMIntrinsic.Types, opTys);
+  list<LLVMType> OpRetTypes = !if(!eq(!size(retTys), 0), LLVMIntrinsic.RetTypes, retTys);
+  list<LLVMType> OpParamTypes = !if(!eq(!size(paramTys), 0), LLVMIntrinsic.ParamTypes, paramTys);
 }
 
 // Concrete definition of DXIL Operation mapping to corresponding LLVM intrinsic
@@ -259,32 +264,32 @@ def Abs : DXILOpMapping<6, unary, int_fabs,
                          "Returns the absolute value of the input.">;
 def IsInf : DXILOpMapping<9, isSpecialFloat, int_dx_isinf,
                          "Determines if the specified value is infinite.",
-                         [llvm_i1_ty, llvm_halforfloat_ty]>;
+                         [llvm_i1_ty], [llvm_halforfloat_ty]>;
 def Cos  : DXILOpMapping<12, unary, int_cos,
                          "Returns cosine(theta) for theta in radians.",
-                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
+                         [llvm_halforfloat_ty], [LLVMMatchType<0>]>;
 def Sin  : DXILOpMapping<13, unary, int_sin,
                          "Returns sine(theta) for theta in radians.",
-                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
+                         [llvm_halforfloat_ty], [LLVMMatchType<0>]>;
 def Exp2 : DXILOpMapping<21, unary, int_exp2,
                          "Returns the base 2 exponential, or 2**x, of the specified value."
                          "exp2(x) = 2**x.",
-                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
+                         [llvm_halforfloat_ty], [LLVMMatchType<0>]>;
 def Frac : DXILOpMapping<22, unary, int_dx_frac,
                          "Returns a fraction from 0 to 1 that represents the "
                          "decimal part of the input.",
-                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
+                         [llvm_halforfloat_ty], [LLVMMatchType<0>]>;
 def RSqrt : DXILOpMapping<25, unary, int_dx_rsqrt,
                          "Returns the reciprocal of the square root of the specified value."
                          "rsqrt(x) = 1 / sqrt(x).",
-                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
+                         [llvm_halforfloat_ty], [LLVMMatchType<0>]>;
 def Round : DXILOpMapping<26, unary, int_round,
                          "Returns the input rounded to the nearest integer"
                          "within a floating-point type.",
-                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
+                         [llvm_halforfloat_ty], [LLVMMatchType<0>]>;
 def Floor : DXILOpMapping<27, unary, int_floor,
                          "Returns the largest integer that is less than or equal to the input.",
-                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
+                         [llvm_halforfloat_ty], [LLVMMatchType<0>]>;
 def FMax : DXILOpMapping<35, binary, int_maxnum,
                          "Float maximum. FMax(a,b) = a > b ? a : b">;
 def FMin : DXILOpMapping<36, binary, int_minnum,
@@ -303,6 +308,8 @@ def IMad : DXILOpMapping<48, tertiary, int_dx_imad,
                          "Signed integer arithmetic multiply/add operation. imad(m,a,b) = m * a + b.">;
 def UMad : DXILOpMapping<49, tertiary, int_dx_umad,
                          "Unsigned integer arithmetic multiply/add operation. umad(m,a,b) = m * a + b.">;
+def Barrier : DXILOpMapping<80, barrier, int_dx_barrier,
+                          "Inserts a memory barrier in the shader">;
 def ThreadId : DXILOpMapping<93, threadId, int_dx_thread_id,
                              "Reads the thread ID">;
 def GroupId  : DXILOpMapping<94, groupId, int_dx_group_id,
diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
index a1eacc2d48009c..f56155312aee50 100644
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
@@ -21,24 +21,6 @@ using namespace llvm::dxil;
 
 constexpr StringLiteral DXILOpNamePrefix = "dx.op.";
 
-namespace {
-
-enum OverloadKind : uint16_t {
-  VOID = 1,
-  HALF = 1 << 1,
-  FLOAT = 1 << 2,
-  DOUBLE = 1 << 3,
-  I1 = 1 << 4,
-  I8 = 1 << 5,
-  I16 = 1 << 6,
-  I32 = 1 << 7,
-  I64 = 1 << 8,
-  UserDefineType = 1 << 9,
-  ObjectType = 1 << 10,
-};
-
-} // namespace
-
 static const char *getOverloadTypeName(OverloadKind Kind) {
   switch (Kind) {
   case OverloadKind::HALF:
@@ -61,8 +43,11 @@ static const char *getOverloadTypeName(OverloadKind Kind) {
   case OverloadKind::ObjectType:
   case OverloadKind::UserDefineType:
     break;
+  case OverloadKind::INVALID:
+    report_fatal_error("Invalid Overload Type for type name lookup",
+                       /* gen_crash_diag=*/false);
   }
-  llvm_unreachable("invalid overload type for name");
+  llvm_unreachable("Unhandled Overload Type specified for type name lookup");
   return "void";
 }
 
diff --git a/llvm/test/CodeGen/DirectX/barrier.ll b/llvm/test/CodeGen/DirectX/barrier.ll
new file mode 100644
index 00000000000000..8be4aac1f782b5
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/barrier.ll
@@ -0,0 +1,11 @@
+; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
+
+; Argument of llvm.dx.barrier is expected to be a mask of 
+; DXIL::BarrierMode values. Chose an int value for testing.
+
+define void @test_barrier() #0 {
+entry:
+  ; CHECK: call void @dx.op.barrier.i32(i32 80, i32 9)
+  call void @llvm.dx.barrier(i32 noundef 9)
+  ret void
+}
diff --git a/llvm/utils/TableGen/DXILEmitter.cpp b/llvm/utils/TableGen/DXILEmitter.cpp
index af1efb8aa99f73..5f678a527381ca 100644
--- a/llvm/utils/TableGen/DXILEmitter.cpp
+++ b/llvm/utils/TableGen/DXILEmitter.cpp
@@ -39,8 +39,8 @@ struct DXILOperationDesc {
   int OpCode;         // ID of DXIL operation
   StringRef OpClass;  // name of the opcode class
   StringRef Doc;      // the documentation description of this instruction
-  SmallVector<Record *> OpTypes; // Vector of operand type records -
-                                 // return type is at index 0
+  SmallVector<OverloadKind> OpOverloadTys; // Vector of operand overload types -
+                                           // return type is at index 0
   SmallVector<std::string>
       OpAttributes;     // operation attribute represented as strings
   StringRef Intrinsic;  // The llvm intrinsic map to OpName. Default is "" which
@@ -65,41 +65,167 @@ struct DXILOperationDesc {
 };
 } // end anonymous namespace
 
-/// Return dxil::ParameterKind corresponding to input LLVMType record
+/// Return dxil::ParameterKind corresponding to input Overload Kind
 ///
-/// \param R TableGen def record of class LLVMType
+/// \param OLKind Overload Kind
 /// \return ParameterKind As defined in llvm/Support/DXILABI.h
 
-static ParameterKind getParameterKind(const Record *R) {
+static ParameterKind getParameterKind(const dxil::OverloadKind OLKind) {
+  switch (OLKind) {
+  case OverloadKind::VOID:
+    return ParameterKind::VOID;
+  case OverloadKind::HALF:
+    return ParameterKind::HALF;
+  case OverloadKind::FLOAT:
+    return ParameterKind::FLOAT;
+  case OverloadKind::DOUBLE:
+    return ParameterKind::DOUBLE;
+  case OverloadKind::I1:
+    return ParameterKind::I1;
+  case OverloadKind::I8:
+    return ParameterKind::I8;
+  case OverloadKind::I16:
+    return ParameterKind::I16;
+  case OverloadKind::I32:
+    return ParameterKind::I32;
+  case OverloadKind::I64:
+    return ParameterKind::I64;
+  default:
+    if ((OLKind ==
+         (OverloadKind::HALF | OverloadKind::FLOAT | OverloadKind::DOUBLE)) ||
+        (OLKind == (OverloadKind::HALF | OverloadKind::FLOAT)) ||
+        (OLKind == (OverloadKind::I1 | OverloadKind::I8 | OverloadKind::I16 |
+                    OverloadKind::I32 | OverloadKind::I64)) ||
+        (OLKind == (OverloadKind::I16 | OverloadKind::I32))) {
+      return ParameterKind::OVERLOAD;
+    } else {
+      report_fatal_error("Unsupported Overload Type encountered",
+                         /* gen_crash_diag=*/false);
+    }
+  }
+}
+
+/// Return a string representation of ParameterKind enum
+/// \param Kind Parameter Kind enum value
+/// \return std::string string representation of input Kind
+static std::string getParameterKindStr(ParameterKind Kind) {
+  switch (Kind) {
+  case ParameterKind::INVALID:
+    return "INVALID";
+  case ParameterKind::VOID:
+    return "VOID";
+  case ParameterKind::HALF:
+    return "HALF";
+  case ParameterKind::FLOAT:
+    return "FLOAT";
+  case ParameterKind::DOUBLE:
+    return "DOUBLE";
+  case ParameterKind::I1:
+    return "I1";
+  case ParameterKind::I8:
+    return "I8";
+  case ParameterKind::I16:
+    return "I16";
+  case ParameterKind::I32:
+    return "I32";
+  case ParameterKind::I64:
+    return "I64";
+  case ParameterKind::OVERLOAD:
+    return "OVERLOAD";
+  case ParameterKind::CBUFFER_RET:
+    return "CBUFFER_RET";
+  case ParameterKind::RESOURCE_RET:
+    return "RESOURCE_RET";
+  case ParameterKind::DXIL_HANDLE:
+    return "DXIL_HANDLE";
+  }
+  llvm_unreachable("Unknown llvm::dxil::ParameterKind enum");
+}
+
+static dxil::OverloadKind getOverloadKind(const Record *R) {
   auto VTRec = R->getValueAsDef("VT");
   switch (getValueType(VTRec)) {
   case MVT::isVoid:
-    return ParameterKind::VOID;
+    return OverloadKind::VOID;
   case MVT::f16:
-    return ParameterKind::HALF;
+    return OverloadKind::HALF;
   case MVT::f32:
-    return ParameterKind::FLOAT;
+    return OverloadKind::FLOAT;
   case MVT::f64:
-    return ParameterKind::DOUBLE;
+    return OverloadKind::DOUBLE;
   case MVT::i1:
-    return ParameterKind::I1;
+    return OverloadKind::I1;
   case MVT::i8:
-    return ParameterKind::I8;
+    return OverloadKind::I8;
   case MVT::i16:
-    return ParameterKind::I16;
+    return OverloadKind::I16;
   case MVT::i32:
-    return ParameterKind::I32;
-  case MVT::fAny:
+    return OverloadKind::I32;
+  case MVT::i64:
+    return OverloadKind::I64;
   case MVT::iAny:
-    return ParameterKind::OVERLOAD;
+    return static_cast<dxil::OverloadKind>(
+        OverloadKind::I16 | OverloadKind::I32 | OverloadKind::I64);
+  case MVT::fAny:
+    return static_cast<dxil::OverloadKind>(
+        OverloadKind::HALF | OverloadKind::FLOAT | OverloadKind::DOUBLE);
   case MVT::Other:
     // Handle DXIL-specific overload types
-    if (R->getValueAsInt("isHalfOrFloat") || R->getValueAsInt("isI16OrI32")) {
-      return ParameterKind::OVERLOAD;
+    {
+      if (R->getValueAsInt("isHalfOrFloat")) {
+        return static_cast<dxil::OverloadKind>(OverloadKind::HALF |
+                                               OverloadKind::FLOAT);
+      } else if (R->getValueAsInt("isI16OrI32")) {
+        return static_cast<dxil::OverloadKind>(OverloadKind::I16 |
+                                               OverloadKind::I32);
+      }
     }
     LLVM_FALLTHROUGH;
   default:
-    llvm_unreachable("Support for specified DXIL Type not yet implemented");
+    report_fatal_error(
+        "Support for specified parameter OverloadKind not yet implemented",
+        /* gen_crash_diag=*/false);
+  }
+}
+
+/// Return a string representation of OverloadKind enum
+/// \param OLKind Overload Kind
+/// \return std::string string representation of OverloadKind
+
+static std::string getOverloadKindStr(const dxil::OverloadKind OLKind) {
+  switch (OLKind) {
+  case OverloadKind::VOID:
+    return "OverloadKind::VOID";
+  case OverloadKind::HALF:
+    return "OverloadKind::HALF";
+  case OverloadKind::FLOAT:
+    return "OverloadKind::FLOAT";
+  case OverloadKind::DOUBLE:
+    return "OverloadKind::DOUBLE";
+  case OverloadKind::I1:
+    return "OverloadKind::I1";
+  case OverloadKind::I8:
+    return "OverloadKind::I8";
+  case OverloadKind::I16:
+    return "OverloadKind::I16";
+  case OverloadKind::I32:
+    return "OverloadKind::I32";
+  case OverloadKind::I64:
+    return "OverloadKind::I64";
+  default:
+    if (OLKind == (OverloadKind::I16 | OverloadKind::I32 | OverloadKind::I64)) {
+      return "OverloadKind::I16 | OverloadKind::I32 | OverloadKind::I64";
+    } else if (OLKind == (OverloadKind::HALF | OverloadKind::FLOAT |
+                          OverloadKind::DOUBLE)) {
+      return "OverloadKind::HALF | OverloadKind::FLOAT | OverloadKind::DOUBLE";
+    } else if (OLKind == (OverloadKind::HALF | OverloadKind::FLOAT)) {
+      return "OverloadKind::HALF | OverloadKind::FLOAT";
+    } else if (OLKind == (OverloadKind::I16 | OverloadKind::I32)) {
+      return "OverloadKind::I16 | OverloadKind::I32";
+    } else {
+      report_fatal_error("Unsupported OverloadKind specified",
+                         /* gen_crash_diag=*/false);
+    }
   }
 }
 
@@ -114,9 +240,25 @@ DXILOperationDesc::DXILOperationDesc(const Record *R) {
 
   Doc = R->getValueAsString("Doc");
 
-  auto TypeRecs = R->getValueAsListOfDefs("OpTypes");
+  // Populate OpOverloadTys with return type and parameter types
+  auto RetTypeRecs = R->getValueAsListOfDefs("OpRetTypes");
+  auto ParamTypeRecs = R->getValueAsListOfDefs("OpParamTypes");
+  unsigned RetTypeRecSize = RetTypeRecs.size();
+  unsigned ParamTypeRecSize = ParamTypeRecs.size();
+  // A vector with return type and parameter type records
+  std::vector<Record *> TypeRecs;
+  TypeRecs.reserve(RetTypeRecSize + ParamTypeRecSize);
+  // If return type lust is empty, the return type is void
+  if (RetTypeRecSize == 0) {
+    OpOverloadTys.emplace_back(OverloadKind::VOID);
+  } else {
+    // Append RetTypeRecs to TypeRecs
+    TypeRecs.insert(TypeRecs.end(), RetTypeRecs.begin(), RetTypeRecs.end());
+  }
+  // Append RetTypeRecs to TypeRecs
+  TypeRecs.insert(TypeRecs.end(), ParamTypeRecs.begin(), ParamTypeRecs.end());
+
   unsigned TypeRecsSize = TypeRecs.size();
-  // Populate OpTypes with return type and parameter types
 
   // Parameter indices of overloaded parameters.
   // This vector contains overload parameters in the order used to
@@ -146,13 +288,13 @@ DXILOperationDesc::DXILOperationDesc(const Record *R) {
         if (!knownType) {
           report_fatal_error("Specification of multiple differing overload "
                              "parameter types not yet supported",
-                             false);
+                             /* gen_crash_diag=*/false);
         }
       } else {
         OverloadParamIndices.push_back(i);
       }
     }
-    // Populate OpTypes array according to the type specification
+    // Populate OpOverloadTys array according to the type specification
     if (TR->isAnonymous()) {
       // Check prior overload types exist
       assert(!OverloadParamIndices.empty() &&
@@ -160,10 +302,10 @@ DXILOperationDesc::DXILOperationDesc(const Record *R) {
       // Get the parameter index of anonymous type, TR, references
       auto OLParamIndex = TR->getValueAsInt("Number");
       // Resolve and insert the type to that at OLParamIndex
-      OpTypes.emplace_back(TypeRecs[OLParamIndex]);
+      OpOverloadTys.emplace_back(getOverloadKind(TypeRecs[OLParamIndex]));
     } else {
-      // A non-anonymous type. Just record it in OpTypes
-      OpTypes.emplace_back(TR);
+      // A non-anonymous type. Just record it in OpOverloadTys
+      OpOverloadTys.emplace_back(getOverloadKind(TR));
     }
   }
 
@@ -172,7 +314,7 @@ DXILOperationDesc::DXILOperationDesc(const Record *R) {
   if (!OverloadParamIndices.empty()) {
     if (OverloadParamIndices.size() > 1)
       report_fatal_error("Multiple overload type specification not supported",
-                         false);
+                         /* gen_crash_diag=*/false);
     OverloadParamIndex = OverloadParamIndices[0];
   }
   // Get the operation class
@@ -196,89 +338,6 @@ DXILOperationDesc::DXILOperationDesc(const Record *R) {
   }
 }
 
-/// Return a string representation of ParameterKind enum
-/// \param Kind Parameter Kind enum value
-/// \return std::string string representation of input Kind
-static std::string getParameterKindStr(ParameterKind Kind) {
-  switch (Kind) {
-  case ParameterKind::INVALID:
-    return "INVALID";
-  case ParameterKind::VOID:
-    return "VOID";
-  case ParameterKind::HALF:
-    return "HALF";
-  case ParameterKind::FLOAT:
-    return "FLOAT";
-  case ParameterKind::DOUBLE:
-    return "DOUBLE";
-  case ParameterKind::I1:
-    return "I1";
-  case ParameterKind::I8:
-    return "I8";
-  case ParameterKind::I16:
-    return "I16";
-  case ParameterKind::I32:
-    return "I32";
-  case ParameterKind::I64:
-    return "I64";
-  case ParameterKind::OVERLOAD:
-    return "OVERLOAD";
-  case ParameterKind::CBUFFER_RET:
-    return "CBUFFER_RET";
-  case ParameterKind::RESOURCE_RET:
-    return "RESOURCE_RET";
-  case ParameterKind::DXIL_HANDLE:
-    return "DXIL_HANDLE";
-  }
-  llvm_unreachable("Unknown llvm::dxil::ParameterKind enum");
-}
-
-/// Return a string representation of OverloadKind enum that maps to
-/// input LLVMType record
-/// \param R TableGen def record of class LLVMType
-/// \return std::string string representation of OverloadKind
-
-static std::string getOverloadKindStr(const Record *R) {
-  auto VTRec = R->getValueAsDef("VT");
-  switch (getValueType(VTRec)) {
-  case MVT::isVoid:
-    return "OverloadKind::VOID";
-  case MVT::f16:
-    return "OverloadKind::HALF";
-  case MVT::f32:
-    return "OverloadKind::FLOAT";
-  case MVT::f64:
-    return "OverloadKin...
[truncated]

@llvmbot
Copy link
Collaborator

llvmbot commented Mar 22, 2024

@llvm/pr-subscribers-llvm-ir

Author: S. Bharadwaj Yadavalli (bharadwajy)

Changes

Align specification of return type and parameter type fields of DXIL Op mapping with those of TableGan class Intrinsic.

A void return type of LLVM Intrinsic is represented as [] in its TableGen description record. Currently, a void return type of DXIL Operation is represented as [llvm_void_ty]. In addition, return and parameter types are recorded as a single list with an understanding that element at index 0 is the return type.

These changes leverage and align DXIL Op type specification with the type specification of the LLVM Intrinsic. As a result, return and parameter types are now specified as two separate lists no longer requiring a different representation for void return type. Additionally, type specification would be more succinct yet equally informative for DXIL Op records for which the same LLVM Intrinsics types are also valid.

Added a test to verify lowering of LLVM intrinsic with void return.

Barrier intrinsic has a void return type. Specification of its DXIL Op can inherit the types of this intrinsic. The test verifies the changes.

Move OverloadKind to DXILABI.h.

Fixes #86229


Patch is 22.49 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/86311.diff

6 Files Affected:

  • (modified) llvm/include/llvm/IR/IntrinsicsDirectX.td (+1)
  • (modified) llvm/include/llvm/Support/DXILABI.h (+15)
  • (modified) llvm/lib/Target/DirectX/DXIL.td (+19-12)
  • (modified) llvm/lib/Target/DirectX/DXILOpBuilder.cpp (+4-19)
  • (added) llvm/test/CodeGen/DirectX/barrier.ll (+11)
  • (modified) llvm/utils/TableGen/DXILEmitter.cpp (+173-114)
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index 1164b241ba7b0d..292aa4497916f1 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -16,6 +16,7 @@ def int_dx_thread_id : Intrinsic<[llvm_i32_ty], [llvm_i32_ty], [IntrNoMem, IntrW
 def int_dx_group_id : Intrinsic<[llvm_i32_ty], [llvm_i32_ty], [IntrNoMem, IntrWillReturn]>;
 def int_dx_thread_id_in_group : Intrinsic<[llvm_i32_ty], [llvm_i32_ty], [IntrNoMem, IntrWillReturn]>;
 def int_dx_flattened_thread_id_in_group : Intrinsic<[llvm_i32_ty], [], [IntrNoMem, IntrWillReturn]>;
+def int_dx_barrier  : Intrinsic<[], [llvm_i32_ty], [IntrNoDuplicate, IntrWillReturn]>;
 
 def int_dx_create_handle : ClangBuiltin<"__builtin_hlsl_create_handle">,
     Intrinsic<[ llvm_ptr_ty ], [llvm_i8_ty], [IntrWillReturn]>;
diff --git a/llvm/include/llvm/Support/DXILABI.h b/llvm/include/llvm/Support/DXILABI.h
index c1d81775b6711e..75cc17bad992d9 100644
--- a/llvm/include/llvm/Support/DXILABI.h
+++ b/llvm/include/llvm/Support/DXILABI.h
@@ -39,6 +39,21 @@ enum class ParameterKind : uint8_t {
   DXIL_HANDLE,
 };
 
+enum OverloadKind : uint16_t {
+  INVALID = 0,
+  VOID = 1,
+  HALF = 1 << 1,
+  FLOAT = 1 << 2,
+  DOUBLE = 1 << 3,
+  I1 = 1 << 4,
+  I8 = 1 << 5,
+  I16 = 1 << 6,
+  I32 = 1 << 7,
+  I64 = 1 << 8,
+  UserDefineType = 1 << 9,
+  ObjectType = 1 << 10,
+};
+
 /// The kind of resource for an SRV or UAV resource. Sometimes referred to as
 /// "Shape" in the DXIL docs.
 enum class ResourceKind : uint32_t {
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index f7e69ebae15b6c..aa98b74c1ffe56 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -240,18 +240,23 @@ class DXILOpMappingBase {
   DXILOpClass OpClass = UnknownOpClass;// Class of DXIL Operation.
   Intrinsic LLVMIntrinsic = ?;         // LLVM Intrinsic DXIL Operation maps to
   string Doc = "";                     // A short description of the operation
-  list<LLVMType> OpTypes = ?;          // Valid types of DXIL Operation in the
-                                       // format [returnTy, param1ty, ...]
+  // The following fields denote the same semantics as those of Intrinsic class
+  // and are initialized with the same values as those of LLVMIntrinsic unless
+  // overridden in the definition of a record.
+  list<LLVMType> OpRetTypes = ?;    // Valid return types of DXIL Operation
+  list<LLVMType> OpParamTypes = ?;     // Valid parameter types of DXIL Operation
 }
 
 class DXILOpMapping<int opCode, DXILOpClass opClass,
                     Intrinsic intrinsic, string doc,
-                    list<LLVMType> opTys = []> : DXILOpMappingBase {
+                    list<LLVMType> retTys = [],
+                    list<LLVMType> paramTys = []> : DXILOpMappingBase {
   int OpCode = opCode;                 // Opcode corresponding to DXIL Operation
   DXILOpClass OpClass = opClass;       // Class of DXIL Operation.
   Intrinsic LLVMIntrinsic = intrinsic; // LLVM Intrinsic the DXIL Operation maps
   string Doc = doc;                    // to a short description of the operation
-  list<LLVMType> OpTypes = !if(!eq(!size(opTys), 0), LLVMIntrinsic.Types, opTys);
+  list<LLVMType> OpRetTypes = !if(!eq(!size(retTys), 0), LLVMIntrinsic.RetTypes, retTys);
+  list<LLVMType> OpParamTypes = !if(!eq(!size(paramTys), 0), LLVMIntrinsic.ParamTypes, paramTys);
 }
 
 // Concrete definition of DXIL Operation mapping to corresponding LLVM intrinsic
@@ -259,32 +264,32 @@ def Abs : DXILOpMapping<6, unary, int_fabs,
                          "Returns the absolute value of the input.">;
 def IsInf : DXILOpMapping<9, isSpecialFloat, int_dx_isinf,
                          "Determines if the specified value is infinite.",
-                         [llvm_i1_ty, llvm_halforfloat_ty]>;
+                         [llvm_i1_ty], [llvm_halforfloat_ty]>;
 def Cos  : DXILOpMapping<12, unary, int_cos,
                          "Returns cosine(theta) for theta in radians.",
-                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
+                         [llvm_halforfloat_ty], [LLVMMatchType<0>]>;
 def Sin  : DXILOpMapping<13, unary, int_sin,
                          "Returns sine(theta) for theta in radians.",
-                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
+                         [llvm_halforfloat_ty], [LLVMMatchType<0>]>;
 def Exp2 : DXILOpMapping<21, unary, int_exp2,
                          "Returns the base 2 exponential, or 2**x, of the specified value."
                          "exp2(x) = 2**x.",
-                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
+                         [llvm_halforfloat_ty], [LLVMMatchType<0>]>;
 def Frac : DXILOpMapping<22, unary, int_dx_frac,
                          "Returns a fraction from 0 to 1 that represents the "
                          "decimal part of the input.",
-                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
+                         [llvm_halforfloat_ty], [LLVMMatchType<0>]>;
 def RSqrt : DXILOpMapping<25, unary, int_dx_rsqrt,
                          "Returns the reciprocal of the square root of the specified value."
                          "rsqrt(x) = 1 / sqrt(x).",
-                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
+                         [llvm_halforfloat_ty], [LLVMMatchType<0>]>;
 def Round : DXILOpMapping<26, unary, int_round,
                          "Returns the input rounded to the nearest integer"
                          "within a floating-point type.",
-                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
+                         [llvm_halforfloat_ty], [LLVMMatchType<0>]>;
 def Floor : DXILOpMapping<27, unary, int_floor,
                          "Returns the largest integer that is less than or equal to the input.",
-                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
+                         [llvm_halforfloat_ty], [LLVMMatchType<0>]>;
 def FMax : DXILOpMapping<35, binary, int_maxnum,
                          "Float maximum. FMax(a,b) = a > b ? a : b">;
 def FMin : DXILOpMapping<36, binary, int_minnum,
@@ -303,6 +308,8 @@ def IMad : DXILOpMapping<48, tertiary, int_dx_imad,
                          "Signed integer arithmetic multiply/add operation. imad(m,a,b) = m * a + b.">;
 def UMad : DXILOpMapping<49, tertiary, int_dx_umad,
                          "Unsigned integer arithmetic multiply/add operation. umad(m,a,b) = m * a + b.">;
+def Barrier : DXILOpMapping<80, barrier, int_dx_barrier,
+                          "Inserts a memory barrier in the shader">;
 def ThreadId : DXILOpMapping<93, threadId, int_dx_thread_id,
                              "Reads the thread ID">;
 def GroupId  : DXILOpMapping<94, groupId, int_dx_group_id,
diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
index a1eacc2d48009c..f56155312aee50 100644
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
@@ -21,24 +21,6 @@ using namespace llvm::dxil;
 
 constexpr StringLiteral DXILOpNamePrefix = "dx.op.";
 
-namespace {
-
-enum OverloadKind : uint16_t {
-  VOID = 1,
-  HALF = 1 << 1,
-  FLOAT = 1 << 2,
-  DOUBLE = 1 << 3,
-  I1 = 1 << 4,
-  I8 = 1 << 5,
-  I16 = 1 << 6,
-  I32 = 1 << 7,
-  I64 = 1 << 8,
-  UserDefineType = 1 << 9,
-  ObjectType = 1 << 10,
-};
-
-} // namespace
-
 static const char *getOverloadTypeName(OverloadKind Kind) {
   switch (Kind) {
   case OverloadKind::HALF:
@@ -61,8 +43,11 @@ static const char *getOverloadTypeName(OverloadKind Kind) {
   case OverloadKind::ObjectType:
   case OverloadKind::UserDefineType:
     break;
+  case OverloadKind::INVALID:
+    report_fatal_error("Invalid Overload Type for type name lookup",
+                       /* gen_crash_diag=*/false);
   }
-  llvm_unreachable("invalid overload type for name");
+  llvm_unreachable("Unhandled Overload Type specified for type name lookup");
   return "void";
 }
 
diff --git a/llvm/test/CodeGen/DirectX/barrier.ll b/llvm/test/CodeGen/DirectX/barrier.ll
new file mode 100644
index 00000000000000..8be4aac1f782b5
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/barrier.ll
@@ -0,0 +1,11 @@
+; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
+
+; Argument of llvm.dx.barrier is expected to be a mask of 
+; DXIL::BarrierMode values. Chose an int value for testing.
+
+define void @test_barrier() #0 {
+entry:
+  ; CHECK: call void @dx.op.barrier.i32(i32 80, i32 9)
+  call void @llvm.dx.barrier(i32 noundef 9)
+  ret void
+}
diff --git a/llvm/utils/TableGen/DXILEmitter.cpp b/llvm/utils/TableGen/DXILEmitter.cpp
index af1efb8aa99f73..5f678a527381ca 100644
--- a/llvm/utils/TableGen/DXILEmitter.cpp
+++ b/llvm/utils/TableGen/DXILEmitter.cpp
@@ -39,8 +39,8 @@ struct DXILOperationDesc {
   int OpCode;         // ID of DXIL operation
   StringRef OpClass;  // name of the opcode class
   StringRef Doc;      // the documentation description of this instruction
-  SmallVector<Record *> OpTypes; // Vector of operand type records -
-                                 // return type is at index 0
+  SmallVector<OverloadKind> OpOverloadTys; // Vector of operand overload types -
+                                           // return type is at index 0
   SmallVector<std::string>
       OpAttributes;     // operation attribute represented as strings
   StringRef Intrinsic;  // The llvm intrinsic map to OpName. Default is "" which
@@ -65,41 +65,167 @@ struct DXILOperationDesc {
 };
 } // end anonymous namespace
 
-/// Return dxil::ParameterKind corresponding to input LLVMType record
+/// Return dxil::ParameterKind corresponding to input Overload Kind
 ///
-/// \param R TableGen def record of class LLVMType
+/// \param OLKind Overload Kind
 /// \return ParameterKind As defined in llvm/Support/DXILABI.h
 
-static ParameterKind getParameterKind(const Record *R) {
+static ParameterKind getParameterKind(const dxil::OverloadKind OLKind) {
+  switch (OLKind) {
+  case OverloadKind::VOID:
+    return ParameterKind::VOID;
+  case OverloadKind::HALF:
+    return ParameterKind::HALF;
+  case OverloadKind::FLOAT:
+    return ParameterKind::FLOAT;
+  case OverloadKind::DOUBLE:
+    return ParameterKind::DOUBLE;
+  case OverloadKind::I1:
+    return ParameterKind::I1;
+  case OverloadKind::I8:
+    return ParameterKind::I8;
+  case OverloadKind::I16:
+    return ParameterKind::I16;
+  case OverloadKind::I32:
+    return ParameterKind::I32;
+  case OverloadKind::I64:
+    return ParameterKind::I64;
+  default:
+    if ((OLKind ==
+         (OverloadKind::HALF | OverloadKind::FLOAT | OverloadKind::DOUBLE)) ||
+        (OLKind == (OverloadKind::HALF | OverloadKind::FLOAT)) ||
+        (OLKind == (OverloadKind::I1 | OverloadKind::I8 | OverloadKind::I16 |
+                    OverloadKind::I32 | OverloadKind::I64)) ||
+        (OLKind == (OverloadKind::I16 | OverloadKind::I32))) {
+      return ParameterKind::OVERLOAD;
+    } else {
+      report_fatal_error("Unsupported Overload Type encountered",
+                         /* gen_crash_diag=*/false);
+    }
+  }
+}
+
+/// Return a string representation of ParameterKind enum
+/// \param Kind Parameter Kind enum value
+/// \return std::string string representation of input Kind
+static std::string getParameterKindStr(ParameterKind Kind) {
+  switch (Kind) {
+  case ParameterKind::INVALID:
+    return "INVALID";
+  case ParameterKind::VOID:
+    return "VOID";
+  case ParameterKind::HALF:
+    return "HALF";
+  case ParameterKind::FLOAT:
+    return "FLOAT";
+  case ParameterKind::DOUBLE:
+    return "DOUBLE";
+  case ParameterKind::I1:
+    return "I1";
+  case ParameterKind::I8:
+    return "I8";
+  case ParameterKind::I16:
+    return "I16";
+  case ParameterKind::I32:
+    return "I32";
+  case ParameterKind::I64:
+    return "I64";
+  case ParameterKind::OVERLOAD:
+    return "OVERLOAD";
+  case ParameterKind::CBUFFER_RET:
+    return "CBUFFER_RET";
+  case ParameterKind::RESOURCE_RET:
+    return "RESOURCE_RET";
+  case ParameterKind::DXIL_HANDLE:
+    return "DXIL_HANDLE";
+  }
+  llvm_unreachable("Unknown llvm::dxil::ParameterKind enum");
+}
+
+static dxil::OverloadKind getOverloadKind(const Record *R) {
   auto VTRec = R->getValueAsDef("VT");
   switch (getValueType(VTRec)) {
   case MVT::isVoid:
-    return ParameterKind::VOID;
+    return OverloadKind::VOID;
   case MVT::f16:
-    return ParameterKind::HALF;
+    return OverloadKind::HALF;
   case MVT::f32:
-    return ParameterKind::FLOAT;
+    return OverloadKind::FLOAT;
   case MVT::f64:
-    return ParameterKind::DOUBLE;
+    return OverloadKind::DOUBLE;
   case MVT::i1:
-    return ParameterKind::I1;
+    return OverloadKind::I1;
   case MVT::i8:
-    return ParameterKind::I8;
+    return OverloadKind::I8;
   case MVT::i16:
-    return ParameterKind::I16;
+    return OverloadKind::I16;
   case MVT::i32:
-    return ParameterKind::I32;
-  case MVT::fAny:
+    return OverloadKind::I32;
+  case MVT::i64:
+    return OverloadKind::I64;
   case MVT::iAny:
-    return ParameterKind::OVERLOAD;
+    return static_cast<dxil::OverloadKind>(
+        OverloadKind::I16 | OverloadKind::I32 | OverloadKind::I64);
+  case MVT::fAny:
+    return static_cast<dxil::OverloadKind>(
+        OverloadKind::HALF | OverloadKind::FLOAT | OverloadKind::DOUBLE);
   case MVT::Other:
     // Handle DXIL-specific overload types
-    if (R->getValueAsInt("isHalfOrFloat") || R->getValueAsInt("isI16OrI32")) {
-      return ParameterKind::OVERLOAD;
+    {
+      if (R->getValueAsInt("isHalfOrFloat")) {
+        return static_cast<dxil::OverloadKind>(OverloadKind::HALF |
+                                               OverloadKind::FLOAT);
+      } else if (R->getValueAsInt("isI16OrI32")) {
+        return static_cast<dxil::OverloadKind>(OverloadKind::I16 |
+                                               OverloadKind::I32);
+      }
     }
     LLVM_FALLTHROUGH;
   default:
-    llvm_unreachable("Support for specified DXIL Type not yet implemented");
+    report_fatal_error(
+        "Support for specified parameter OverloadKind not yet implemented",
+        /* gen_crash_diag=*/false);
+  }
+}
+
+/// Return a string representation of OverloadKind enum
+/// \param OLKind Overload Kind
+/// \return std::string string representation of OverloadKind
+
+static std::string getOverloadKindStr(const dxil::OverloadKind OLKind) {
+  switch (OLKind) {
+  case OverloadKind::VOID:
+    return "OverloadKind::VOID";
+  case OverloadKind::HALF:
+    return "OverloadKind::HALF";
+  case OverloadKind::FLOAT:
+    return "OverloadKind::FLOAT";
+  case OverloadKind::DOUBLE:
+    return "OverloadKind::DOUBLE";
+  case OverloadKind::I1:
+    return "OverloadKind::I1";
+  case OverloadKind::I8:
+    return "OverloadKind::I8";
+  case OverloadKind::I16:
+    return "OverloadKind::I16";
+  case OverloadKind::I32:
+    return "OverloadKind::I32";
+  case OverloadKind::I64:
+    return "OverloadKind::I64";
+  default:
+    if (OLKind == (OverloadKind::I16 | OverloadKind::I32 | OverloadKind::I64)) {
+      return "OverloadKind::I16 | OverloadKind::I32 | OverloadKind::I64";
+    } else if (OLKind == (OverloadKind::HALF | OverloadKind::FLOAT |
+                          OverloadKind::DOUBLE)) {
+      return "OverloadKind::HALF | OverloadKind::FLOAT | OverloadKind::DOUBLE";
+    } else if (OLKind == (OverloadKind::HALF | OverloadKind::FLOAT)) {
+      return "OverloadKind::HALF | OverloadKind::FLOAT";
+    } else if (OLKind == (OverloadKind::I16 | OverloadKind::I32)) {
+      return "OverloadKind::I16 | OverloadKind::I32";
+    } else {
+      report_fatal_error("Unsupported OverloadKind specified",
+                         /* gen_crash_diag=*/false);
+    }
   }
 }
 
@@ -114,9 +240,25 @@ DXILOperationDesc::DXILOperationDesc(const Record *R) {
 
   Doc = R->getValueAsString("Doc");
 
-  auto TypeRecs = R->getValueAsListOfDefs("OpTypes");
+  // Populate OpOverloadTys with return type and parameter types
+  auto RetTypeRecs = R->getValueAsListOfDefs("OpRetTypes");
+  auto ParamTypeRecs = R->getValueAsListOfDefs("OpParamTypes");
+  unsigned RetTypeRecSize = RetTypeRecs.size();
+  unsigned ParamTypeRecSize = ParamTypeRecs.size();
+  // A vector with return type and parameter type records
+  std::vector<Record *> TypeRecs;
+  TypeRecs.reserve(RetTypeRecSize + ParamTypeRecSize);
+  // If return type lust is empty, the return type is void
+  if (RetTypeRecSize == 0) {
+    OpOverloadTys.emplace_back(OverloadKind::VOID);
+  } else {
+    // Append RetTypeRecs to TypeRecs
+    TypeRecs.insert(TypeRecs.end(), RetTypeRecs.begin(), RetTypeRecs.end());
+  }
+  // Append RetTypeRecs to TypeRecs
+  TypeRecs.insert(TypeRecs.end(), ParamTypeRecs.begin(), ParamTypeRecs.end());
+
   unsigned TypeRecsSize = TypeRecs.size();
-  // Populate OpTypes with return type and parameter types
 
   // Parameter indices of overloaded parameters.
   // This vector contains overload parameters in the order used to
@@ -146,13 +288,13 @@ DXILOperationDesc::DXILOperationDesc(const Record *R) {
         if (!knownType) {
           report_fatal_error("Specification of multiple differing overload "
                              "parameter types not yet supported",
-                             false);
+                             /* gen_crash_diag=*/false);
         }
       } else {
         OverloadParamIndices.push_back(i);
       }
     }
-    // Populate OpTypes array according to the type specification
+    // Populate OpOverloadTys array according to the type specification
     if (TR->isAnonymous()) {
       // Check prior overload types exist
       assert(!OverloadParamIndices.empty() &&
@@ -160,10 +302,10 @@ DXILOperationDesc::DXILOperationDesc(const Record *R) {
       // Get the parameter index of anonymous type, TR, references
       auto OLParamIndex = TR->getValueAsInt("Number");
       // Resolve and insert the type to that at OLParamIndex
-      OpTypes.emplace_back(TypeRecs[OLParamIndex]);
+      OpOverloadTys.emplace_back(getOverloadKind(TypeRecs[OLParamIndex]));
     } else {
-      // A non-anonymous type. Just record it in OpTypes
-      OpTypes.emplace_back(TR);
+      // A non-anonymous type. Just record it in OpOverloadTys
+      OpOverloadTys.emplace_back(getOverloadKind(TR));
     }
   }
 
@@ -172,7 +314,7 @@ DXILOperationDesc::DXILOperationDesc(const Record *R) {
   if (!OverloadParamIndices.empty()) {
     if (OverloadParamIndices.size() > 1)
       report_fatal_error("Multiple overload type specification not supported",
-                         false);
+                         /* gen_crash_diag=*/false);
     OverloadParamIndex = OverloadParamIndices[0];
   }
   // Get the operation class
@@ -196,89 +338,6 @@ DXILOperationDesc::DXILOperationDesc(const Record *R) {
   }
 }
 
-/// Return a string representation of ParameterKind enum
-/// \param Kind Parameter Kind enum value
-/// \return std::string string representation of input Kind
-static std::string getParameterKindStr(ParameterKind Kind) {
-  switch (Kind) {
-  case ParameterKind::INVALID:
-    return "INVALID";
-  case ParameterKind::VOID:
-    return "VOID";
-  case ParameterKind::HALF:
-    return "HALF";
-  case ParameterKind::FLOAT:
-    return "FLOAT";
-  case ParameterKind::DOUBLE:
-    return "DOUBLE";
-  case ParameterKind::I1:
-    return "I1";
-  case ParameterKind::I8:
-    return "I8";
-  case ParameterKind::I16:
-    return "I16";
-  case ParameterKind::I32:
-    return "I32";
-  case ParameterKind::I64:
-    return "I64";
-  case ParameterKind::OVERLOAD:
-    return "OVERLOAD";
-  case ParameterKind::CBUFFER_RET:
-    return "CBUFFER_RET";
-  case ParameterKind::RESOURCE_RET:
-    return "RESOURCE_RET";
-  case ParameterKind::DXIL_HANDLE:
-    return "DXIL_HANDLE";
-  }
-  llvm_unreachable("Unknown llvm::dxil::ParameterKind enum");
-}
-
-/// Return a string representation of OverloadKind enum that maps to
-/// input LLVMType record
-/// \param R TableGen def record of class LLVMType
-/// \return std::string string representation of OverloadKind
-
-static std::string getOverloadKindStr(const Record *R) {
-  auto VTRec = R->getValueAsDef("VT");
-  switch (getValueType(VTRec)) {
-  case MVT::isVoid:
-    return "OverloadKind::VOID";
-  case MVT::f16:
-    return "OverloadKind::HALF";
-  case MVT::f32:
-    return "OverloadKind::FLOAT";
-  case MVT::f64:
-    return "OverloadKin...
[truncated]

@bharadwajy bharadwajy changed the title [DirectX][DXIL] Align type spec of TableGen DXIL Op mapping with that of LLVM Intrinsic [DirectX][DXIL] Align type spec of TableGen DXIL Op mapping and LLVM Intrinsic Mar 22, 2024
@bharadwajy bharadwajy changed the title [DirectX][DXIL] Align type spec of TableGen DXIL Op mapping and LLVM Intrinsic [DirectX][DXIL] Align type spec of TableGen DXIL Op and LLVM Intrinsic Mar 22, 2024
@bharadwajy bharadwajy force-pushed the dxil_td/improve-overload-type-handling branch from f52ed26 to 79e4778 Compare March 22, 2024 17:15
Copy link
Contributor

@coopp coopp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah.. This looks like it is now matching Justin's PR.

llvm/include/llvm/Support/DXILABI.h Show resolved Hide resolved
@bharadwajy bharadwajy force-pushed the dxil_td/improve-overload-type-handling branch from b3f04ed to 7e271f5 Compare March 25, 2024 20:26
Copy link

✅ With the latest revision this PR passed the Python code formatter.

Copy link

✅ With the latest revision this PR passed the C/C++ code formatter.

Copy link
Collaborator

@llvm-beanz llvm-beanz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I want to put the breaks on this PR.

We have a document that is documenting this design which is inconsistent with this PR. Before we merge this we should make the documentation consistent. The doc PR is here: #85170.

DXIL Op mapping with those of TableGan class Intrinsic.

A void return type of LLVM Intrinsic is represented as [] in its
TableGen description record. Currently, a void return type of
DXIL Operation is represented as [llvm_void_ty]. In addition,
return and parameter types are recorded as a single list with
an understanding that element at index `0` is the return type.

These changes leverage and align DXIL Op type specification with
the type specification of the LLVM Intrinsic. As a result, return
and parameter types are now specified as two separate lists no
longer requiring a different representation for void return type.
Additionally, type specification would be more succinct yet
equally informative for DXIL Op records for which the same LLVM
Intrinsics types are also valid.

Added a test to verify lowering of LLVM intrinsic with void return.

Barrier intrinsic has a void return type. Specification of its
DXIL Op can inherit the types of this intrinsic. The test verifies
the changes.

Move OverloadKind to DXILABI.h.

Update definition names of enum Overload to follow naming conventions.
@bharadwajy bharadwajy force-pushed the dxil_td/improve-overload-type-handling branch from 7e271f5 to 9724728 Compare March 27, 2024 21:24
@bharadwajy
Copy link
Contributor Author

I want to put the breaks on this PR.

We have a document that is documenting this design which is inconsistent with this PR. Before we merge this we should make the documentation consistent. The doc PR is here: #85170.

Updated the document

  1. to reflect the changes proposed in this PR.
  2. with details of another option - along with a link to corresponding changes.

Thanks!

@llvm-beanz llvm-beanz self-requested a review March 29, 2024 22:48
Copy link
Collaborator

@llvm-beanz llvm-beanz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not going to block this, but I'm still unsure that the documentation is clear and covers all the cases we need. I'll defer to @bogner.

@bharadwajy
Copy link
Contributor Author

I'm not going to block this, but I'm still unsure that the documentation is clear and covers all the cases we need. I'll defer to @bogner.

I think I prefer the Option 3 (Leverage the existing classification of DXIL Operations) in the updated design document (#85170) compared to the one proposed in this PR for the reasons specified there. I'll explore its current implementation (changeset link in the design doc) a bit more and submit a PR that potentially overrides this one.

Thanks!

@bharadwajy
Copy link
Contributor Author

Taking a different design direction - see PR #87803. Closing.

@bharadwajy bharadwajy closed this Apr 9, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
Status: Done
Development

Successfully merging this pull request may close these issues.

[DirectX][DXIL] Align TableGen DXIL Op type spec with that of LLVM Intrinsic
6 participants