Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DirectX][NFC] Model precise overload type specification of DXIL Ops #83917

Merged
merged 1 commit into from
Mar 12, 2024

Conversation

bharadwajy
Copy link
Contributor

@bharadwajy bharadwajy commented Mar 4, 2024

Implement an abstraction to specify precise overload types supported by DXIL ops. These overload types are typically a subset of LLVM intrinsics.

Implement the corresponding changes in DXILEmitter backend.

Add tests to verify expected errors for unsupported overload types at code generation time.

Add tests to check for correct overload error output.

@llvmbot
Copy link
Collaborator

llvmbot commented Mar 4, 2024

@llvm/pr-subscribers-backend-directx

Author: S. Bharadwaj Yadavalli (bharadwajy)

Changes

Implements an abstraction to specify precise overload types supported by DXIL ops. These overload types are typically a subset of LLVM intrinsics.

Implements the corresponding changes in DXILEmitter backend.

Adds tests to verify expected errors for unsupported overload types at code generation time.

Add tests to check for correct overrload error output.


Full diff: https://github.com/llvm/llvm-project/pull/83917.diff

8 Files Affected:

  • (modified) llvm/lib/Target/DirectX/DXIL.td (+43-5)
  • (modified) llvm/lib/Target/DirectX/DXILOpBuilder.cpp (+1-1)
  • (modified) llvm/test/CodeGen/DirectX/frac.ll (-3)
  • (added) llvm/test/CodeGen/DirectX/frac_error.ll (+14)
  • (added) llvm/test/CodeGen/DirectX/round_error.ll (+13)
  • (modified) llvm/test/CodeGen/DirectX/sin.ll (+2-17)
  • (added) llvm/test/CodeGen/DirectX/sin_error.ll (+14)
  • (modified) llvm/utils/TableGen/DXILEmitter.cpp (+107-43)
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 33b08ed93e3d0a..762664176048e4 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -205,25 +205,63 @@ defset list<DXILOpClass> OpClasses = {
   def writeSamplerFeedbackBias : DXILOpClass;
   def writeSamplerFeedbackGrad : DXILOpClass;
   def writeSamplerFeedbackLevel: DXILOpClass;
+
+  // This is a sentinel definition. Hence placed at the end of the list
+  // and not as part of the above alphabetically sorted valid definitions.
+  // Additionally it is capitalized unlike all the others.
+  def UnknownOpClass: DXILOpClass;
+}
+
+// Several of the overloaded DXIL Operations support for data types
+// that are a subset of the overloaded LLVM intrinsics that they map to.
+// For e.g., llvm.sin.* intrinsic operates on any floating-point type and
+// maps for lowering to DXIL Op Sin. However, valid overloads of DXIL Sin
+// operation overloads are half (f16) and float (f32) only.
+//
+// The following abstracts overload types specific to DXIL operations.
+
+class DXILType : LLVMType<OtherVT> {
+  let isAny = 1;
 }
 
+// Concrete records for various overload types supported specifically by
+// DXIL Operations.
+
+def llvm_i16ori32_ty : DXILType;
+def llvm_halforfloat_ty : DXILType;
+
 // Abstraction DXIL Operation to LLVM intrinsic
-class DXILOpMapping<int opCode, DXILOpClass opClass, Intrinsic intrinsic, string doc> {
+class DXILOpMappingBase {
+  int OpCode = 0;                      // Opcode of DXIL Operation
+  DXILOpClass OpClass = UnknownOpClass;// Class of DXIL Operation.
+  Intrinsic LLVMIntrinsic = ?;         // LLVM Intrinsic DXIL Operation maps to
+  string Doc = "";                     // A short description of the operation
+  list<LLVMType> OpTypes = ?;          // Valid types of DXIL Operation in the
+                                       // format [returnTy, param1ty, ...]
+}
+
+class DXILOpMapping<int opCode, DXILOpClass opClass,
+                    Intrinsic intrinsic, string doc,
+                    list<LLVMType> opTys = []> : DXILOpMappingBase {
   int OpCode = opCode;                 // Opcode corresponding to DXIL Operation
-  DXILOpClass OpClass = opClass;             // Class of DXIL Operation.
+  DXILOpClass OpClass = opClass;       // Class of DXIL Operation.
   Intrinsic LLVMIntrinsic = intrinsic; // LLVM Intrinsic the DXIL Operation maps
   string Doc = doc;                    // to a short description of the operation
+  list<LLVMType> OpTypes = !if(!eq(!size(opTys), 0), LLVMIntrinsic.Types, opTys);
 }
 
 // Concrete definition of DXIL Operation mapping to corresponding LLVM intrinsic
 def Sin  : DXILOpMapping<13, unary, int_sin,
-                         "Returns sine(theta) for theta in radians.">;
+                         "Returns sine(theta) for theta in radians.",
+                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
 def Frac : DXILOpMapping<22, unary, int_dx_frac,
                          "Returns a fraction from 0 to 1 that represents the "
-                         "decimal part of the input.">;
+                         "decimal part of the input.",
+                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
 def Round : DXILOpMapping<26, unary, int_round,
                          "Returns the input rounded to the nearest integer"
-                         "within a floating-point type.">;
+                         "within a floating-point type.",
+                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
 def UMax : DXILOpMapping<39, binary, int_umax,
                          "Unsigned integer maximum. UMax(a,b) = a > b ? a : b">;
 def ThreadId : DXILOpMapping<93, threadId, int_dx_thread_id,
diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
index 21a20d45b922d9..99bc5c214db298 100644
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
@@ -257,7 +257,7 @@ static FunctionCallee getOrCreateDXILOpFunction(dxil::OpCode DXILOp,
   // FIXME: find the issue and report error in clang instead of check it in
   // backend.
   if ((Prop->OverloadTys & (uint16_t)Kind) == 0) {
-    llvm_unreachable("invalid overload");
+    report_fatal_error("Invalid Overload Type", false);
   }
 
   std::string FnName = constructOverloadName(Kind, OverloadTy, *Prop);
diff --git a/llvm/test/CodeGen/DirectX/frac.ll b/llvm/test/CodeGen/DirectX/frac.ll
index ab605ed6084aa4..ae86fe06654da1 100644
--- a/llvm/test/CodeGen/DirectX/frac.ll
+++ b/llvm/test/CodeGen/DirectX/frac.ll
@@ -29,6 +29,3 @@ entry:
   %dx.frac = call half @llvm.dx.frac.f16(half %0)
   ret half %dx.frac
 }
-
-; Function Attrs: nocallback nofree nosync nounwind readnone speculatable willreturn
-declare half @llvm.dx.frac.f16(half) #1
diff --git a/llvm/test/CodeGen/DirectX/frac_error.ll b/llvm/test/CodeGen/DirectX/frac_error.ll
new file mode 100644
index 00000000000000..ad51bb0b6dee71
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/frac_error.ll
@@ -0,0 +1,14 @@
+; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s
+
+; This test is expected to fail with the following error
+; CHECK: LLVM ERROR: Invalid Overload Type
+
+; Function Attrs: noinline nounwind optnone
+define noundef double @frac_double(double noundef %a) #0 {
+entry:
+  %a.addr = alloca double, align 8
+  store double %a, ptr %a.addr, align 8
+  %0 = load double, ptr %a.addr, align 8
+  %dx.frac = call double @llvm.dx.frac.f64(double %0)
+  ret double %dx.frac
+}
diff --git a/llvm/test/CodeGen/DirectX/round_error.ll b/llvm/test/CodeGen/DirectX/round_error.ll
new file mode 100644
index 00000000000000..3bd87b2bbf0200
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/round_error.ll
@@ -0,0 +1,13 @@
+; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s
+
+; This test is expected to fail with the following error
+; CHECK: LLVM ERROR: Invalid Overload Type
+
+define noundef double @round_double(double noundef %a) #0 {
+entry:
+  %a.addr = alloca double, align 8
+  store double %a, ptr %a.addr, align 8
+  %0 = load double, ptr %a.addr, align 8
+  %elt.round = call double @llvm.round.f64(double %0)
+  ret double %elt.round
+}
diff --git a/llvm/test/CodeGen/DirectX/sin.ll b/llvm/test/CodeGen/DirectX/sin.ll
index bb31d28bfcfee6..ea6f4d6200ee49 100644
--- a/llvm/test/CodeGen/DirectX/sin.ll
+++ b/llvm/test/CodeGen/DirectX/sin.ll
@@ -4,11 +4,8 @@
 ; CHECK:call float @dx.op.unary.f32(i32 13, float %{{.*}})
 ; CHECK:call half @dx.op.unary.f16(i32 13, half %{{.*}})
 
-target datalayout = "e-m:e-p:32:32-i1:32-i8:8-i16:16-i32:32-i64:64-f16:16-f32:32-f64:64-n8:16:32:64"
-target triple = "dxil-pc-shadermodel6.7-library"
-
 ; Function Attrs: noinline nounwind optnone
-define noundef float @_Z3foof(float noundef %a) #0 {
+define noundef float @sin_float(float noundef %a) #0 {
 entry:
   %a.addr = alloca float, align 4
   store float %a, ptr %a.addr, align 4
@@ -21,7 +18,7 @@ entry:
 declare float @llvm.sin.f32(float) #1
 
 ; Function Attrs: noinline nounwind optnone
-define noundef half @_Z3barDh(half noundef %a) #0 {
+define noundef half @sin_half(half noundef %a) #0 {
 entry:
   %a.addr = alloca half, align 2
   store half %a, ptr %a.addr, align 2
@@ -29,15 +26,3 @@ entry:
   %1 = call half @llvm.sin.f16(half %0)
   ret half %1
 }
-
-; Function Attrs: nocallback nofree nosync nounwind readnone speculatable willreturn
-declare half @llvm.sin.f16(half) #1
-
-attributes #0 = { noinline nounwind optnone "frame-pointer"="none" "min-legal-vector-width"="0" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
-attributes #1 = { nocallback nofree nosync nounwind readnone speculatable willreturn }
-
-!llvm.module.flags = !{!0}
-!llvm.ident = !{!1}
-
-!0 = !{i32 1, !"wchar_size", i32 4}
-!1 = !{!"clang version 15.0.0 (https://github.com/llvm/llvm-project.git 73417c517644db5c419c85c0b3cb6750172fcab5)"}
diff --git a/llvm/test/CodeGen/DirectX/sin_error.ll b/llvm/test/CodeGen/DirectX/sin_error.ll
new file mode 100644
index 00000000000000..2a839faff00fe7
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/sin_error.ll
@@ -0,0 +1,14 @@
+; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s
+
+; This test is expected to fail with the following error
+; CHECK: LLVM ERROR: Invalid Overload
+
+define noundef double @sin_double(double noundef %a) #0 {
+entry:
+  %a.addr = alloca double, align 8
+  store double %a, ptr %a.addr, align 8
+  %0 = load double, ptr %a.addr, align 8
+  %1 = call double @llvm.sin.f64(double %0)
+  ret double %1
+}
+
diff --git a/llvm/utils/TableGen/DXILEmitter.cpp b/llvm/utils/TableGen/DXILEmitter.cpp
index fc958f5328736c..ac355464fe6ca5 100644
--- a/llvm/utils/TableGen/DXILEmitter.cpp
+++ b/llvm/utils/TableGen/DXILEmitter.cpp
@@ -22,6 +22,7 @@
 #include "llvm/Support/DXILABI.h"
 #include "llvm/TableGen/Record.h"
 #include "llvm/TableGen/TableGenBackend.h"
+#include <string>
 
 using namespace llvm;
 using namespace llvm::dxil;
@@ -38,8 +39,8 @@ struct DXILOperationDesc {
   int OpCode;         // ID of DXIL operation
   StringRef OpClass;  // name of the opcode class
   StringRef Doc;      // the documentation description of this instruction
-  SmallVector<MVT::SimpleValueType> OpTypes; // Vector of operand types -
-                                             // return type is at index 0
+  SmallVector<Record *> OpTypes; // Vector of operand type records -
+                                 // return type is at index 0
   SmallVector<std::string>
       OpAttributes;     // operation attribute represented as strings
   StringRef Intrinsic;  // The llvm intrinsic map to OpName. Default is "" which
@@ -57,20 +58,21 @@ struct DXILOperationDesc {
   DXILShaderModel ShaderModel;           // minimum shader model required
   DXILShaderModel ShaderModelTranslated; // minimum shader model required with
                                          // translation by linker
-  int OverloadParamIndex; // parameter index which control the overload.
-                          // When < 0, should be only 1 overload type.
+  int OverloadParamIndex;             // Index of parameter with overload type.
+                                      //   -1 : no overload types
   SmallVector<StringRef, 4> counters; // counters for this inst.
   DXILOperationDesc(const Record *);
 };
 } // end anonymous namespace
 
-/// Convert DXIL type name string to dxil::ParameterKind
+/// Return dxil::ParameterKind corresponding to input LLVMType record
 ///
-/// \param VT Simple Value Type
+/// \param R TableGen def record of class LLVMType
 /// \return ParameterKind As defined in llvm/Support/DXILABI.h
 
-static ParameterKind getParameterKind(MVT::SimpleValueType VT) {
-  switch (VT) {
+static ParameterKind getParameterKind(const Record *R) {
+  auto VTRec = R->getValueAsDef("VT");
+  switch (getValueType(VTRec)) {
   case MVT::isVoid:
     return ParameterKind::VOID;
   case MVT::f16:
@@ -90,6 +92,18 @@ static ParameterKind getParameterKind(MVT::SimpleValueType VT) {
   case MVT::fAny:
   case MVT::iAny:
     return ParameterKind::OVERLOAD;
+  case MVT::Other:
+    // Handle DXIL-specific overload types
+    {
+      auto RetKind = StringSwitch<ParameterKind>(R->getNameInitAsString())
+                         .Cases("llvm_i16ori32_ty", "llvm_halforfloat_ty",
+                                ParameterKind::OVERLOAD)
+                         .Default(ParameterKind::INVALID);
+      if (RetKind != ParameterKind::INVALID) {
+        return RetKind;
+      }
+    }
+    LLVM_FALLTHROUGH;
   default:
     llvm_unreachable("Support for specified DXIL Type not yet implemented");
   }
@@ -106,45 +120,80 @@ DXILOperationDesc::DXILOperationDesc(const Record *R) {
 
   Doc = R->getValueAsString("Doc");
 
+  auto TypeRecs = R->getValueAsListOfDefs("OpTypes");
+  unsigned TypeRecsSize = TypeRecs.size();
+  // Populate OpTypes with return type and parameter types
+
+  // Parameter indices of overloaded parameters.
+  // This vector contains overload parameters in the order order used to
+  // resolve an LLVMMatchType in accordance with  convention outlined in
+  // the comment before the definition of class LLVMMatchType in
+  // llvm/IR/Intrinsics.td
+  SmallVector<int> OverloadParamIndices;
+  for (unsigned i = 0; i < TypeRecsSize; i++) {
+    auto TR = TypeRecs[i];
+    // Track operation parameter indices of any overload types
+    auto isAny = TR->getValueAsInt("isAny");
+    if (isAny == 1) {
+      // TODO: At present it is expected that all overload types in a DXIL Op
+      // are of the same type. Hence, OverloadParamIndices will have only one
+      // element. This implies we do not need a vector. However, until more
+      // (all?) DXIL Ops are added in DXIL.td, a vector is being used to flag
+      // cases this assumption would not hold.
+      if (!OverloadParamIndices.empty()) {
+        bool knownType = true;
+        // Ensure that the same overload type registered earlier is being used
+        for (auto Idx : OverloadParamIndices) {
+          if (TR != TypeRecs[Idx]) {
+            knownType = false;
+            break;
+          }
+        }
+        if (!knownType) {
+          report_fatal_error("Specification of multiple differing overload "
+                             "parameter types not yet supported",
+                             false);
+        }
+      } else {
+        OverloadParamIndices.push_back(i);
+      }
+    }
+    // Populate OpTypes array according to the type specification
+    if (TR->isAnonymous()) {
+      // Check prior overload types exist
+      assert(!OverloadParamIndices.empty() &&
+             "No prior overloaded parameter found to match.");
+      // Get the parameter index of anonymous type, TR, references
+      auto OLParamIndex = TR->getValueAsInt("Number");
+      // Resolve and insert the type to that at OLParamIndex
+      OpTypes.emplace_back(TypeRecs[OLParamIndex]);
+    } else {
+      // A non-anonymous type. Just record it in OpTypes
+      OpTypes.emplace_back(TR);
+    }
+  }
+
+  // Set the index of the overload parameter, if any.
+  OverloadParamIndex = -1; // default; indicating none
+  if (!OverloadParamIndices.empty()) {
+    if (OverloadParamIndices.size() > 1)
+      report_fatal_error("Multiple overload type specification not supported",
+                         false);
+    OverloadParamIndex = OverloadParamIndices[0];
+  }
+  // Get the operation class
+  OpClass = R->getValueAsDef("OpClass")->getName();
+
   if (R->getValue("LLVMIntrinsic")) {
     auto *IntrinsicDef = R->getValueAsDef("LLVMIntrinsic");
     auto DefName = IntrinsicDef->getName();
     assert(DefName.starts_with("int_") && "invalid intrinsic name");
     // Remove the int_ from intrinsic name.
     Intrinsic = DefName.substr(4);
-    // TODO: It is expected that return type and parameter types of
-    // DXIL Operation are the same as that of the intrinsic. Deviations
-    // are expected to be encoded in TableGen record specification and
-    // handled accordingly here. Support to be added later, as needed.
-    // Get parameter type list of the intrinsic. Types attribute contains
-    // the list of as [returnType, param1Type,, param2Type, ...]
-
-    OverloadParamIndex = -1;
-    auto TypeRecs = IntrinsicDef->getValueAsListOfDefs("Types");
-    unsigned TypeRecsSize = TypeRecs.size();
-    // Populate return type and parameter type names
-    for (unsigned i = 0; i < TypeRecsSize; i++) {
-      auto TR = TypeRecs[i];
-      OpTypes.emplace_back(getValueType(TR->getValueAsDef("VT")));
-      // Get the overload parameter index.
-      // TODO : Seems hacky. Is it possible that more than one parameter can
-      // be of overload kind??
-      // TODO: Check for any additional constraints specified for DXIL operation
-      // restricting return type.
-      if (i > 0) {
-        auto &CurParam = OpTypes.back();
-        if (getParameterKind(CurParam) >= ParameterKind::OVERLOAD) {
-          OverloadParamIndex = i;
-        }
-      }
-    }
-    // Get the operation class
-    OpClass = R->getValueAsDef("OpClass")->getName();
-
-    // NOTE: For now, assume that attributes of DXIL Operation are the same as
+    // TODO: For now, assume that attributes of DXIL Operation are the same as
     // that of the intrinsic. Deviations are expected to be encoded in TableGen
     // record specification and handled accordingly here. Support to be added
-    // later.
+    // as needed.
     auto IntrPropList = IntrinsicDef->getValueAsListInit("IntrProperties");
     auto IntrPropListSize = IntrPropList->size();
     for (unsigned i = 0; i < IntrPropListSize; i++) {
@@ -191,12 +240,13 @@ static std::string getParameterKindStr(ParameterKind Kind) {
 }
 
 /// Return a string representation of OverloadKind enum that maps to
-/// input Simple Value Type enum
-/// \param VT Simple Value Type enum
+/// input LLVMType record
+/// \param R TableGen def record of class LLVMType
 /// \return std::string string representation of OverloadKind
 
-static std::string getOverloadKindStr(MVT::SimpleValueType VT) {
-  switch (VT) {
+static std::string getOverloadKindStr(const Record *R) {
+  auto VTRec = R->getValueAsDef("VT");
+  switch (getValueType(VTRec)) {
   case MVT::isVoid:
     return "OverloadKind::VOID";
   case MVT::f16:
@@ -219,6 +269,20 @@ static std::string getOverloadKindStr(MVT::SimpleValueType VT) {
     return "OverloadKind::I16 | OverloadKind::I32 | OverloadKind::I64";
   case MVT::fAny:
     return "OverloadKind::HALF | OverloadKind::FLOAT | OverloadKind::DOUBLE";
+  case MVT::Other:
+    // Handle DXIL-specific overload types
+    {
+      auto RetStr =
+          StringSwitch<std::string>(R->getNameInitAsString())
+              .Case("llvm_i16ori32_ty", "OverloadKind::I16 | OverloadKind::I32")
+              .Case("llvm_halforfloat_ty",
+                    "OverloadKind::HALF | OverloadKind::FLOAT")
+              .Default("");
+      if (RetStr != "") {
+        return RetStr;
+      }
+    }
+    LLVM_FALLTHROUGH;
   default:
     llvm_unreachable(
         "Support for specified parameter OverloadKind not yet implemented");

def Frac : DXILOpMapping<22, unary, int_dx_frac,
"Returns a fraction from 0 to 1 that represents the "
"decimal part of the input.">;
"decimal part of the input.",
[llvm_halforfloat_ty, LLVMMatchType<0>]>;
Copy link
Member

@farzonl farzonl Mar 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For frac we created our own intrinsic. Do we have to duplicate the type definition or could we just update in intrinsicsDirectx.td and then inherit type? Since in this case its our own intrinsic could we just prevent generation of double instead of erroring?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For frac we created our own intrinsic. Do we have to duplicate the type definition or could we just update in intrinsicsDirectx.td and then inherit type? Since in this case its our own intrinsic could we just prevent generation of double instead of erroring?

The type used is llvm_halforfloat_ty - which is defined in DXIL.td. Using it to specify int_dx_frac will require the type definition to be moved to IntrinsicsDirectX.td. I also believe, as a result, the TableGen backends need to be taught to handle that type appropriately - all of which can be done but not (yet) sure of the impact (and need) in the standard Tablegen backends since definitions in IntrisicsDirectX.td appear to be globally available.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

More relevantly, the valid overload type of DirectX intrinsic int_dx_frac, viz., llvm_anyfloat_ty, which is float of any format and hence includes any scalar and vector floating-point types. OTOH, valid overload type for DXIL operation Frac that maps to int_dx_frac is one of the two scalar types half and double- denoted by llvm_halforfloat_ty defined in DXIL.td, to represent the more restrictive overload type. So, inheriting the type specification of int_dx_frac for Frac would not accurately model valid overload types of Frac.

@@ -257,7 +257,7 @@ static FunctionCallee getOrCreateDXILOpFunction(dxil::OpCode DXILOp,
// FIXME: find the issue and report error in clang instead of check it in
// backend.
if ((Prop->OverloadTys & (uint16_t)Kind) == 0) {
llvm_unreachable("invalid overload");
report_fatal_error("Invalid Overload Type", false);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Generally best to include a comment for boolean flag args:

    report_fatal_error("Invalid Overload Type", /*gen_crash_diag=*/false);

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

@@ -0,0 +1,14 @@
; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s

; This test is expected to fail with the following error
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment doesn't really add any information to the check for "LLVM ERROR". Maybe a comment like "DXIL does not support overloads of sin for doubles" would be more helpful.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed.

StringSwitch<std::string>(R->getNameInitAsString())
.Case("llvm_i16ori32_ty", "OverloadKind::I16 | OverloadKind::I32")
.Case("llvm_halforfloat_ty",
"OverloadKind::HALF | OverloadKind::FLOAT")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

String matching on these two types doesn't feel very robust or scalable - I think we want to do this slightly differently.

One option that keeps this simple but encodes the information a bit more clearly would be to add bits to LLVMType, like so:

let isHalfOrFloat = 0

This still isn't very extensible but it accomplishes the approach you have here a little bit more cleanly.

The other option is of course to support lists of types and embed them in some kind of LLVMAlternativeType or something, but that's quite a bit more complex and might be overkill.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

String matching on these two types doesn't feel very robust or scalable - I think we want to do this slightly differently.

One option that keeps this simple but encodes the information a bit more clearly would be to add bits to LLVMType, like so:

let isHalfOrFloat = 0

This still isn't very extensible but it accomplishes the approach you have here a little bit more cleanly.

OK. Changed.

The other option is of course to support lists of types and embed them in some kind of LLVMAlternativeType or something, but that's quite a bit more complex and might be overkill.

I had considered something like this but refrained for the reasons you cited.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm, I'm not sure that really worked out that well. We're still switching over the names in "getParameterKind" so we haven't really improved the abstraction, and the if chain over the known set of bools is probably just as error prone.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are these the only two instances of these sets of types that are needed for all of the DXIL ops? If there are others, are they always two types or could it be a larger set?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm, I'm not sure that really worked out that well. We're still switching over the names in "getParameterKind" so we haven't really improved the abstraction, and the if chain over the known set of bools is probably just as error prone.

I missed changing this usage. I pushed a change. Thanks!

Are these the only two instances of these sets of types that are needed for all of the DXIL ops?

No. It appears that there are other instances of valid overload types of some DXIL Ops that would be over-specified using LLVM's *Any value type. For example, valid overload types of WaveMatrix_Fill are half, float and int. However, I have not yet formulated how (if?) such DXIL ops should be specified in DXIL.td.

If there are others, are they always two types or could it be a larger set?

Such overload types can include more than two fixed types.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

WaveMatrix is a preview feature, and the existing WaveMatrix* ops are almost certainly going to change. Feel free to ignore those for now.

@bharadwajy bharadwajy requested a review from bogner March 7, 2024 19:15
@bharadwajy bharadwajy force-pushed the dxil_td/precise_dxil_type_spec branch from 6477eb1 to 27563e6 Compare March 11, 2024 19:21
@@ -257,7 +257,7 @@ static FunctionCallee getOrCreateDXILOpFunction(dxil::OpCode DXILOp,
// FIXME: find the issue and report error in clang instead of check it in
// backend.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment isn't accurate and we should remove it. Your change below does surface the error as we need it to.

StringSwitch<std::string>(R->getNameInitAsString())
.Case("llvm_i16ori32_ty", "OverloadKind::I16 | OverloadKind::I32")
.Case("llvm_halforfloat_ty",
"OverloadKind::HALF | OverloadKind::FLOAT")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

WaveMatrix is a preview feature, and the existing WaveMatrix* ops are almost certainly going to change. Feel free to ignore those for now.

Copy link
Contributor

@bogner bogner left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM with the caveat that I think we need to iterate on the design for how we handle things like isHalfOrFloat to be a bit more general. I'm okay with that happening in tree as you expand this to handle more operations and get a better idea of the total space.

(R->getValueAsInt("isI16OrI32"))) {
return ParameterKind::OVERLOAD;
}
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The extra scope here isn't necessary, and we don't need parens around the getValueAsInt calls. Scope comment applies to the other switch as well.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The extra scope here isn't necessary, and we don't need parens around the getValueAsInt calls.

Fixed.

Scope comment applies to the other switch as well.

The comment is specifically to highlight the case that refers to the DXIL-specific types isHalforFloat and isI16OrI32 only, which are of MVT::Other type. The other cases handle normal LLVM MVTs.

Implement an abstraction to specify precise overload types supported by
DXIL ops. These overload types are typically a subset of LLVM
intrinsic types.

Implement the corresponding changes in DXILEmitter backend.

Add tests to verify expected errors for unsupported overload types
at code generation time.

Add tests to check for correct overload error output.
@bharadwajy bharadwajy force-pushed the dxil_td/precise_dxil_type_spec branch from 7485a4c to 558c504 Compare March 12, 2024 20:31
@bharadwajy bharadwajy merged commit 54f631d into llvm:main Mar 12, 2024
3 of 4 checks passed
@bharadwajy bharadwajy deleted the dxil_td/precise_dxil_type_spec branch March 12, 2024 20:51
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

6 participants