Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DirectX][NFC] Change specification of overload types and attribute in DXIL.td #81184

Merged
merged 3 commits into from
Feb 13, 2024

Conversation

bharadwajy
Copy link
Contributor

  • Specify overload types of DXIL Operation as list of types instead of a string.
  • Add supported DXIL type record definitions to DXIL.td leveraging LLVMType to avoid duplicate definitions.
  • Spell out DXIL Operation Attribute specification string.
  • Make corresponding changes to process the records in DXILEmitter.cpp

@llvmbot
Copy link
Collaborator

llvmbot commented Feb 8, 2024

@llvm/pr-subscribers-backend-directx

Author: S. Bharadwaj Yadavalli (bharadwajy)

Changes
  • Specify overload types of DXIL Operation as list of types instead of a string.
  • Add supported DXIL type record definitions to DXIL.td leveraging LLVMType to avoid duplicate definitions.
  • Spell out DXIL Operation Attribute specification string.
  • Make corresponding changes to process the records in DXILEmitter.cpp

Full diff: https://github.com/llvm/llvm-project/pull/81184.diff

2 Files Affected:

  • (modified) llvm/lib/Target/DirectX/DXIL.td (+42-11)
  • (modified) llvm/utils/TableGen/DXILEmitter.cpp (+90-53)
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 3f3ace5a1a3a36..4b09c9597e2228 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -34,12 +34,42 @@ def BinaryUintCategory : DXILOpCategory<"Binary uint">;
 def UnaryFloatCategory : DXILOpCategory<"Unary float">;
 def ComputeIDCategory : DXILOpCategory<"Compute/Mesh/Amplification shader">;
 
+// ValueTypes specific to DXIL
+// Define Overload value type as an entity with no size and an arbitrary value
+// of 1024 - assuming that all currently defined values are less than 1024
+def overloadVal       : ValueType<0, 1024>;
+def resourceRetVal    : ValueType<0, 1025>;
+def cbufferRetVal     : ValueType<0, 1026>;
+def handleVal         : ValueType<0, 1027>;
+
+// Following are the scalar types supported by DXIL operations and are synonymous
+// to llvm_*_ty defined for readability and ease of use in the context of this file.
+
+def voidTy  : LLVMType<isVoid>;
+
+// Floating point types
+def f16Ty   : LLVMType<f16>;
+def f32Ty   : LLVMType<f32>;
+def f64Ty   : LLVMType<f64>;
+
+// Integer types
+def i1Ty   : LLVMType<i1>;
+def i8Ty   : LLVMType<i8>;
+def i16Ty  : LLVMType<i16>;
+def i32Ty  : LLVMType<i32>;
+def i64Ty  : LLVMType<i64>;
+
+def overloadTy        : LLVMType<overloadVal>;
+def resourceRetTy     : LLVMType<resourceRetVal>;
+def cbufferRetTy      : LLVMType<cbufferRetVal>;
+def handleTy          : LLVMType<handleVal>;
+
 // The parameter description for a DXIL operation
 class DXILOpParameter<int pos, string type, string name, string doc,
                  bit isConstant = 0, string enumName = "",
                  int maxValue = 0> {
   int Pos = pos;               // Position in parameter list
-  string LLVMType = type;      // LLVM type name, $o for overload, $r for resource
+  string Type = type;          // LLVM type name, $o for overload, $r for resource
                                // type, $cb for legacy cbuffer, $u4 for u4 struct
   string Name = name;          // Short, unique parameter name
   string Doc = doc;            // Description of this parameter
@@ -56,9 +86,10 @@ class DXILOperationDesc {
   DXILOpCategory OpCategory;  // Category of the operation
   string Doc = "";            // Description of the operation
   list<DXILOpParameter> Params = []; // Parameter list of the operation
-  string OverloadTypes = "";  // Overload types, if applicable
-  string Attributes = "";     // Attribute shorthands: rn=does not access
-                              // memory,ro=only reads from memory,
+  list<LLVMType> OverloadTypes = [];  // Overload types, if applicable
+  string Attributes = "";     // Operation Attribute
+                              // "NoReadMemory" - does not read memory
+                              // "ReadMemory"   - reads memory
   bit IsDerivative = 0;       // Whether this is some kind of derivative
   bit IsGradient = 0;         // Whether this requires a gradient calculation
   bit IsFeedback = 0;         // Whether this is a sampler feedback operation
@@ -71,7 +102,7 @@ class DXILOperationDesc {
 }
 
 class DXILOperation<string name, int opCode, DXILOpClass opClass, DXILOpCategory opCategory, string doc,
-              string oloadTypes, string attrs, list<DXILOpParameter> params,
+              list<LLVMType> oloadTypes, string attrs, list<DXILOpParameter> params,
               list<string> statsGroup = []> : DXILOperationDesc {
   let OpName = name;
   let OpCode = opCode;
@@ -88,7 +119,7 @@ class DXILOperation<string name, int opCode, DXILOpClass opClass, DXILOpCategory
 class LLVMIntrinsic<Intrinsic llvm_intrinsic_> { Intrinsic llvm_intrinsic = llvm_intrinsic_; }
 
 def Sin : DXILOperation<"Sin", 13, UnaryClass, UnaryFloatCategory, "returns sine(theta) for theta in radians.",
-  "half;float;", "rn",
+  [f16Ty,f32Ty], "NoReadMemory",
   [
     DXILOpParameter<0, "$o", "", "operation result">,
     DXILOpParameter<1, "i32", "opcode", "DXIL opcode">,
@@ -98,7 +129,7 @@ def Sin : DXILOperation<"Sin", 13, UnaryClass, UnaryFloatCategory, "returns sine
   LLVMIntrinsic<int_sin>;
 
 def UMax : DXILOperation< "UMax", 39,  BinaryClass,  BinaryUintCategory, "unsigned integer maximum. UMax(a,b) = a > b ? a : b",
-    "i16;i32;i64;",  "rn",
+    [i16Ty,i32Ty,i64Ty],  "NoReadMemory",
   [
     DXILOpParameter<0,  "$o",  "",  "operation result">,
     DXILOpParameter<1,  "i32",  "opcode",  "DXIL opcode">,
@@ -108,7 +139,7 @@ def UMax : DXILOperation< "UMax", 39,  BinaryClass,  BinaryUintCategory, "unsign
   ["uints"]>,
   LLVMIntrinsic<int_umax>;
 
-def ThreadId : DXILOperation< "ThreadId", 93,  ThreadIdClass, ComputeIDCategory, "reads the thread ID", "i32;",  "rn",
+def ThreadId : DXILOperation< "ThreadId", 93,  ThreadIdClass, ComputeIDCategory, "reads the thread ID", [i32Ty],  "NoReadMemory",
   [
     DXILOpParameter<0,  "i32",  "",  "thread ID component">,
     DXILOpParameter<1,  "i32",  "opcode",  "DXIL opcode">,
@@ -116,7 +147,7 @@ def ThreadId : DXILOperation< "ThreadId", 93,  ThreadIdClass, ComputeIDCategory,
   ]>,
   LLVMIntrinsic<int_dx_thread_id>;
 
-def GroupId : DXILOperation< "GroupId", 94,  GroupIdClass, ComputeIDCategory, "reads the group ID (SV_GroupID)", "i32;",  "rn",
+def GroupId : DXILOperation< "GroupId", 94,  GroupIdClass, ComputeIDCategory, "reads the group ID (SV_GroupID)", [i32Ty],  "NoReadMemory",
   [
     DXILOpParameter<0,  "i32",  "",  "group ID component">,
     DXILOpParameter<1,  "i32",  "opcode",  "DXIL opcode">,
@@ -125,7 +156,7 @@ def GroupId : DXILOperation< "GroupId", 94,  GroupIdClass, ComputeIDCategory, "r
   LLVMIntrinsic<int_dx_group_id>;
 
 def ThreadIdInGroup : DXILOperation< "ThreadIdInGroup", 95,  ThreadIdInGroupClass, ComputeIDCategory,
-  "reads the thread ID within the group (SV_GroupThreadID)", "i32;",  "rn",
+  "reads the thread ID within the group (SV_GroupThreadID)", [i32Ty],  "NoReadMemory",
   [
     DXILOpParameter<0,  "i32",  "",  "thread ID in group component">,
     DXILOpParameter<1,  "i32",  "opcode",  "DXIL opcode">,
@@ -134,7 +165,7 @@ def ThreadIdInGroup : DXILOperation< "ThreadIdInGroup", 95,  ThreadIdInGroupClas
   LLVMIntrinsic<int_dx_thread_id_in_group>;
 
 def FlattenedThreadIdInGroup : DXILOperation< "FlattenedThreadIdInGroup", 96,  FlattenedThreadIdInGroupClass, ComputeIDCategory,
-   "provides a flattened index for a given thread within a given group (SV_GroupIndex)", "i32;",  "rn",
+   "provides a flattened index for a given thread within a given group (SV_GroupIndex)", [i32Ty],  "NoReadMemory",
   [
     DXILOpParameter<0,  "i32",  "",  "result">,
     DXILOpParameter<1,  "i32",  "opcode",  "DXIL opcode">
diff --git a/llvm/utils/TableGen/DXILEmitter.cpp b/llvm/utils/TableGen/DXILEmitter.cpp
index cb9f9c6b03c636..32c5f4ff16f004 100644
--- a/llvm/utils/TableGen/DXILEmitter.cpp
+++ b/llvm/utils/TableGen/DXILEmitter.cpp
@@ -49,7 +49,7 @@ struct DXILOperationDesc {
   StringRef Doc;       // the documentation description of this instruction
 
   SmallVector<DXILParameter> Params; // the operands that this instruction takes
-  StringRef OverloadTypes;       // overload types if applicable
+  SmallVector<ParameterKind> OverloadTypes; // overload types if applicable
   StringRef FnAttr;              // attribute shorthands: rn=does not access
                                  // memory,ro=only reads from memory
   StringRef Intrinsic; // The llvm intrinsic map to OpName. Default is "" which
@@ -69,37 +69,31 @@ struct DXILOperationDesc {
   int OverloadParamIndex; // parameter index which control the overload.
                           // When < 0, should be only 1 overload type.
   SmallVector<StringRef, 4> counters; // counters for this inst.
-  DXILOperationDesc(const Record *R) {
-    OpName = R->getValueAsString("OpName");
-    OpCode = R->getValueAsInt("OpCode");
-    OpClass = R->getValueAsDef("OpClass")->getValueAsString("Name");
-    Category = R->getValueAsDef("OpCategory")->getValueAsString("Name");
-
-    if (R->getValue("llvm_intrinsic")) {
-      auto *IntrinsicDef = R->getValueAsDef("llvm_intrinsic");
-      auto DefName = IntrinsicDef->getName();
-      assert(DefName.starts_with("int_") && "invalid intrinsic name");
-      // Remove the int_ from intrinsic name.
-      Intrinsic = DefName.substr(4);
-    }
-
-    Doc = R->getValueAsString("Doc");
-
-    ListInit *ParamList = R->getValueAsListInit("Params");
-    OverloadParamIndex = -1;
-    for (unsigned I = 0; I < ParamList->size(); ++I) {
-      Record *Param = ParamList->getElementAsRecord(I);
-      Params.emplace_back(DXILParameter(Param));
-      auto &CurParam = Params.back();
-      if (CurParam.Kind >= ParameterKind::OVERLOAD)
-        OverloadParamIndex = I;
-    }
-    OverloadTypes = R->getValueAsString("OverloadTypes");
-    FnAttr = R->getValueAsString("Attributes");
-  }
+  DXILOperationDesc(const Record *);
 };
 } // end anonymous namespace
 
+// Convert DXIL type name string to dxil::ParameterKind
+// @param typeNameStr Type name string
+// @return ParameterKind as defined in llvm/Support/DXILABI.h
+static ParameterKind getDXILTypeNameToKind(StringRef typeNameStr) {
+  return StringSwitch<ParameterKind>(typeNameStr)
+      .Case("voidTy", ParameterKind::VOID)
+      .Case("f16Ty", ParameterKind::HALF)
+      .Case("f32Ty", ParameterKind::FLOAT)
+      .Case("f64Ty", ParameterKind::DOUBLE)
+      .Case("i1Ty", ParameterKind::I1)
+      .Case("i8Ty", ParameterKind::I8)
+      .Case("i16Ty", ParameterKind::I16)
+      .Case("i32Ty", ParameterKind::I32)
+      .Case("i64Ty", ParameterKind::I64)
+      .Case("overloadTy", ParameterKind::OVERLOAD)
+      .Case("handleTy", ParameterKind::DXIL_HANDLE)
+      .Case("cbufferRetTy", ParameterKind::CBUFFER_RET)
+      .Case("resourceRetTy", ParameterKind::RESOURCE_RET)
+      .Default(ParameterKind::INVALID);
+}
+
 static ParameterKind parameterTypeNameToKind(StringRef Name) {
   return StringSwitch<ParameterKind>(Name)
       .Case("void", ParameterKind::VOID)
@@ -118,10 +112,44 @@ static ParameterKind parameterTypeNameToKind(StringRef Name) {
       .Default(ParameterKind::INVALID);
 }
 
+DXILOperationDesc::DXILOperationDesc(const Record *R) {
+  OpName = R->getValueAsString("OpName");
+  OpCode = R->getValueAsInt("OpCode");
+  OpClass = R->getValueAsDef("OpClass")->getValueAsString("Name");
+  Category = R->getValueAsDef("OpCategory")->getValueAsString("Name");
+
+  if (R->getValue("llvm_intrinsic")) {
+    auto *IntrinsicDef = R->getValueAsDef("llvm_intrinsic");
+    auto DefName = IntrinsicDef->getName();
+    assert(DefName.starts_with("int_") && "invalid intrinsic name");
+    // Remove the int_ from intrinsic name.
+    Intrinsic = DefName.substr(4);
+  }
+
+  Doc = R->getValueAsString("Doc");
+
+  ListInit *ParamList = R->getValueAsListInit("Params");
+  OverloadParamIndex = -1;
+  for (unsigned I = 0; I < ParamList->size(); ++I) {
+    Record *Param = ParamList->getElementAsRecord(I);
+    Params.emplace_back(DXILParameter(Param));
+    auto &CurParam = Params.back();
+    if (CurParam.Kind >= ParameterKind::OVERLOAD)
+      OverloadParamIndex = I;
+  }
+  ListInit *OverloadTypeList = R->getValueAsListInit("OverloadTypes");
+
+  for (unsigned I = 0; I < OverloadTypeList->size(); ++I) {
+    Record *R = OverloadTypeList->getElementAsRecord(I);
+    OverloadTypes.emplace_back(getDXILTypeNameToKind(R->getNameInitAsString()));
+  }
+  FnAttr = R->getValueAsString("Attributes");
+}
+
 DXILParameter::DXILParameter(const Record *R) {
   Name = R->getValueAsString("Name");
   Pos = R->getValueAsInt("Pos");
-  Kind = parameterTypeNameToKind(R->getValueAsString("LLVMType"));
+  Kind = parameterTypeNameToKind(R->getValueAsString("Type"));
   if (R->getValue("Doc"))
     Doc = R->getValueAsString("Doc");
   IsConst = R->getValueAsBit("IsConstant");
@@ -268,36 +296,45 @@ static void emitDXILIntrinsicMap(std::vector<DXILOperationDesc> &Ops,
 
 static std::string emitDXILOperationFnAttr(StringRef FnAttr) {
   return StringSwitch<std::string>(FnAttr)
-      .Case("rn", "Attribute::ReadNone")
-      .Case("ro", "Attribute::ReadOnly")
+      .Case("NoReadMemory", "Attribute::ReadNone")
+      .Case("ReadMemory", "Attribute::ReadOnly")
       .Default("Attribute::None");
 }
 
-static std::string getOverloadKind(StringRef Overload) {
-  return StringSwitch<std::string>(Overload)
-      .Case("half", "OverloadKind::HALF")
-      .Case("float", "OverloadKind::FLOAT")
-      .Case("double", "OverloadKind::DOUBLE")
-      .Case("i1", "OverloadKind::I1")
-      .Case("i16", "OverloadKind::I16")
-      .Case("i32", "OverloadKind::I32")
-      .Case("i64", "OverloadKind::I64")
-      .Case("udt", "OverloadKind::UserDefineType")
-      .Case("obj", "OverloadKind::ObjectType")
-      .Default("OverloadKind::VOID");
+static std::string overloadKindStr(ParameterKind Overload) {
+  switch (Overload) {
+  case ParameterKind::HALF:
+    return "OverloadKind::HALF";
+  case ParameterKind::FLOAT:
+    return "OverloadKind::FLOAT";
+  case ParameterKind::DOUBLE:
+    return "OverloadKind::DOUBLE";
+  case ParameterKind::I1:
+    return "OverloadKind::I1";
+  case ParameterKind::I8:
+    return "OverloadKind::I8";
+  case ParameterKind::I16:
+    return "OverloadKind::I16";
+  case ParameterKind::I32:
+    return "OverloadKind::I32";
+  case ParameterKind::I64:
+    return "OverloadKind::I64";
+  case ParameterKind::VOID:
+    return "OverloadKind::VOID";
+  default:
+    return "OverloadKind::UNKNOWN";
+  }
 }
 
-static std::string getDXILOperationOverload(StringRef Overloads) {
-  SmallVector<StringRef> OverloadStrs;
-  Overloads.split(OverloadStrs, ';', /*MaxSplit*/ -1, /*KeepEmpty*/ false);
+static std::string
+getDXILOperationOverloads(SmallVector<ParameterKind> Overloads) {
   // Format is: OverloadKind::FLOAT | OverloadKind::HALF
-  assert(!OverloadStrs.empty() && "Invalid overloads");
-  auto It = OverloadStrs.begin();
+  auto It = Overloads.begin();
   std::string Result;
   raw_string_ostream OS(Result);
-  OS << getOverloadKind(*It);
-  for (++It; It != OverloadStrs.end(); ++It) {
-    OS << " | " << getOverloadKind(*It);
+  OS << overloadKindStr(*It);
+  for (++It; It != Overloads.end(); ++It) {
+    OS << " | " << overloadKindStr(*It);
   }
   return OS.str();
 }
@@ -367,7 +404,7 @@ static void emitDXILOperationTable(std::vector<DXILOperationDesc> &Ops,
     OS << "  { dxil::OpCode::" << Op.OpName << ", "
        << OpStrings.get(Op.OpName.str()) << ", OpCodeClass::" << Op.OpClass
        << ", " << OpClassStrings.get(getDXILOpClassName(Op.OpClass)) << ", "
-       << getDXILOperationOverload(Op.OverloadTypes) << ", "
+       << getDXILOperationOverloads(Op.OverloadTypes) << ", "
        << emitDXILOperationFnAttr(Op.FnAttr) << ", " << Op.OverloadParamIndex
        << ", " << Op.Params.size() << ", "
        << Parameters.get(ParameterMap[Op.OpClass]) << " },\n";

@bharadwajy
Copy link
Contributor Author

Copy link

github-actions bot commented Feb 9, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

…in DXIL.td

 - Specify overload types of DXIL Operation as list of types instead of a string.
 - Add supported DXIL type record definitions to `DXIL.td` leveraging `LLVMType` to
   avoid duplicate definitions.
 - Spell out DXIL Operation Attribute specification string.
 - Make corresponding changes to process the records in DXILEmitter.cpp
Leverage EnumAttr class defined in llvm/IR/Attributes.td instead of
using an arbitrary string or defining a new one. This leverages
validation of acceptable attribute specification of DXIL operation
records added at compile-time.
llvm/lib/Target/DirectX/DXIL.td Outdated Show resolved Hide resolved
llvm/utils/TableGen/DXILEmitter.cpp Show resolved Hide resolved
@python3kgae python3kgae merged commit 8ba4ff3 into llvm:main Feb 13, 2024
3 of 4 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants