[DX] Support pipeline state masks #66425

Merged: 2 commits, Sep 15, 2023


llvm-beanz
Collaborator

The DXContainer pipeline state information encodes a bunch of mask vectors that are used to track things about the inputs and outputs from each shader.

This adds support for reading and writing them through the YAML test interfaces. The writing logic in MC is extremely primitive and we'll want to revisit the API for that, but since I'm not sure how we'll want to generate the mask bits from DXIL during code generation, I didn't want to spend too much time on the API.

Fixes #59479
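
As a rough illustration of what generating mask bits could involve, here is a minimal sketch that packs one boolean per signature component into 32-bit mask words. `packMaskBits` is a hypothetical helper, not part of this patch, and the bit order (component `I` in bit `I % 32` of word `I / 32`) is an assumption rather than something this PR specifies.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper, not part of the patch: packs one bool per signature
// component into 32-bit mask words.
// Assumption: component I maps to bit (I % 32) of word (I / 32), i.e.
// least-significant bit first. The real encoding may differ.
std::vector<uint32_t> packMaskBits(const std::vector<bool> &Components) {
  // Round the bit count up to a whole number of 32-bit words.
  std::vector<uint32_t> Words((Components.size() + 31) / 32, 0u);
  for (size_t I = 0; I < Components.size(); ++I)
    if (Components[I])
      Words[I / 32] |= uint32_t(1) << (I % 32);
  return Words;
}
```

With four components per vector, the word count this produces for N vectors agrees with the `(N + 7) >> 3` sizing used by the reader in this patch.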

@llvmbot
Collaborator

llvmbot commented Sep 14, 2023

@llvm/pr-subscribers-mc
@llvm/pr-subscribers-objectyaml
@llvm/pr-subscribers-llvm-binary-utilities

@llvm/pr-subscribers-backend-directx

Patch is 61.27 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/66425.diff

29 Files Affected:

  • (modified) llvm/include/llvm/MC/DXContainerPSVInfo.h (+12)
  • (modified) llvm/include/llvm/Object/DXContainer.h (+75-2)
  • (modified) llvm/include/llvm/ObjectYAML/DXContainerYAML.h (+8)
  • (modified) llvm/include/llvm/Support/EndianStream.h (+9)
  • (modified) llvm/lib/MC/DXContainerPSVInfo.cpp (+13)
  • (modified) llvm/lib/Object/DXContainer.cpp (+62)
  • (modified) llvm/lib/ObjectYAML/DXContainerEmitter.cpp (+19)
  • (modified) llvm/lib/ObjectYAML/DXContainerYAML.cpp (+18)
  • (added) llvm/test/ObjectYAML/DXContainer/DomainMaskVectors.yaml (+200)
  • (added) llvm/test/ObjectYAML/DXContainer/GeometryMaskVectors.yaml (+174)
  • (added) llvm/test/ObjectYAML/DXContainer/HullMaskVectors.yaml (+181)
  • (modified) llvm/test/ObjectYAML/DXContainer/PSVv1-amplification.yaml (+14-4)
  • (modified) llvm/test/ObjectYAML/DXContainer/PSVv1-compute.yaml (+14-4)
  • (modified) llvm/test/ObjectYAML/DXContainer/PSVv1-domain.yaml (+20-8)
  • (modified) llvm/test/ObjectYAML/DXContainer/PSVv1-geometry.yaml (+14-4)
  • (modified) llvm/test/ObjectYAML/DXContainer/PSVv1-hull.yaml (+20-8)
  • (modified) llvm/test/ObjectYAML/DXContainer/PSVv1-mesh.yaml (+14-4)
  • (modified) llvm/test/ObjectYAML/DXContainer/PSVv1-pixel.yaml (+14-4)
  • (modified) llvm/test/ObjectYAML/DXContainer/PSVv1-vertex.yaml (+14-4)
  • (modified) llvm/test/ObjectYAML/DXContainer/PSVv2-amplification.yaml (+14-4)
  • (modified) llvm/test/ObjectYAML/DXContainer/PSVv2-compute.yaml (+14-4)
  • (modified) llvm/test/ObjectYAML/DXContainer/PSVv2-domain.yaml (+20-8)
  • (modified) llvm/test/ObjectYAML/DXContainer/PSVv2-geometry.yaml (+14-4)
  • (modified) llvm/test/ObjectYAML/DXContainer/PSVv2-hull.yaml (+20-8)
  • (modified) llvm/test/ObjectYAML/DXContainer/PSVv2-mesh.yaml (+14-4)
  • (modified) llvm/test/ObjectYAML/DXContainer/PSVv2-pixel.yaml (+14-4)
  • (modified) llvm/test/ObjectYAML/DXContainer/PSVv2-vertex.yaml (+14-4)
  • (modified) llvm/test/ObjectYAML/DXContainer/SigElements.yaml (+7-2)
  • (modified) llvm/tools/obj2yaml/dxcontainer2yaml.cpp (+19)
diff --git a/llvm/include/llvm/MC/DXContainerPSVInfo.h b/llvm/include/llvm/MC/DXContainerPSVInfo.h
index dc6bf2fa40d065a..76e3e498029c4a1 100644
--- a/llvm/include/llvm/MC/DXContainerPSVInfo.h
+++ b/llvm/include/llvm/MC/DXContainerPSVInfo.h
@@ -51,6 +51,18 @@ struct PSVRuntimeInfo {
   SmallVector<PSVSignatureElement> OutputElements;
   SmallVector<PSVSignatureElement> PatchOrPrimElements;
 
+  // The interface here is bad, and we'll want to change this in the future. We
+  // probably will want to build out these mask vectors as vectors of bools and
+  // have this utility object convert them to the bit masks. I don't want to
+  // over-engineer this API now since we don't know what the data coming in to
+  // feed it will look like, so I kept it extremely simple for the immediate use
+  // case.
+  SmallVector<uint32_t> OutputVectorMasks[4];
+  SmallVector<uint32_t> PatchOrPrimMasks;
+  SmallVector<uint32_t> InputOutputMap[4];
+  SmallVector<uint32_t> InputPatchMap;
+  SmallVector<uint32_t> PatchOutputMap;
+
   // Serialize PSVInfo into the provided raw_ostream. The version field
   // specifies the data version to encode, the default value specifies encoding
   // the highest supported version.
diff --git a/llvm/include/llvm/Object/DXContainer.h b/llvm/include/llvm/Object/DXContainer.h
index 2aae0a199f8c1c0..9ea8a48f9430d74 100644
--- a/llvm/include/llvm/Object/DXContainer.h
+++ b/llvm/include/llvm/Object/DXContainer.h
@@ -27,6 +27,18 @@ namespace llvm {
 namespace object {
 
 namespace DirectX {
+
+namespace detail {
+template <typename T>
+std::enable_if_t<std::is_arithmetic<T>::value, void> swapBytes(T &value) {
+  sys::swapByteOrder(value);
+}
+
+template <typename T>
+std::enable_if_t<std::is_class<T>::value, void> swapBytes(T &value) {
+  value.swapBytes();
+}
+} // namespace detail
 class PSVRuntimeInfo {
 
   // This class provides a view into the underlying resource array. The Resource
@@ -35,7 +47,7 @@ class PSVRuntimeInfo {
   // swaps it as appropriate.
   template <typename T> struct ViewArray {
     StringRef Data;
-    uint32_t Stride; // size of each element in the list.
+    uint32_t Stride = sizeof(T); // size of each element in the list.
 
     ViewArray() = default;
     ViewArray(StringRef D, size_t S) : Data(D), Stride(S) {}
@@ -65,7 +77,7 @@ class PSVRuntimeInfo {
         memcpy(static_cast<void *>(&Val), Current,
                std::min(Stride, MaxStride()));
         if (sys::IsBigEndianHost)
-          Val.swapBytes();
+          detail::swapBytes(Val);
         return Val;
       }
 
@@ -120,6 +132,12 @@ class PSVRuntimeInfo {
   SigElementArray SigOutputElements;
   SigElementArray SigPatchOrPrimElements;
 
+  ViewArray<uint32_t> OutputVectorMasks[4];
+  ViewArray<uint32_t> PatchOrPrimMasks;
+  ViewArray<uint32_t> InputOutputMap[4];
+  ViewArray<uint32_t> InputPatchMap;
+  ViewArray<uint32_t> PatchOutputMap;
+
 public:
   PSVRuntimeInfo(StringRef D) : Data(D), Size(0) {}
 
@@ -140,6 +158,22 @@ class PSVRuntimeInfo {
 
   const InfoStruct &getInfo() const { return BasicInfo; }
 
+  template <typename T> const T *getInfoAs() const {
+    if (const auto *P = std::get_if<dxbc::PSV::v2::RuntimeInfo>(&BasicInfo))
+      return static_cast<const T *>(P);
+    if (std::is_same<T, dxbc::PSV::v2::RuntimeInfo>::value)
+      return nullptr;
+
+    if (const auto *P = std::get_if<dxbc::PSV::v1::RuntimeInfo>(&BasicInfo))
+      return static_cast<const T *>(P);
+    if (std::is_same<T, dxbc::PSV::v1::RuntimeInfo>::value)
+      return nullptr;
+
+    if (const auto *P = std::get_if<dxbc::PSV::v0::RuntimeInfo>(&BasicInfo))
+      return static_cast<const T *>(P);
+    return nullptr;
+  }
+
   StringRef getStringTable() const { return StringTable; }
   ArrayRef<uint32_t> getSemanticIndexTable() const {
     return SemanticIndexTable;
@@ -155,7 +189,46 @@ class PSVRuntimeInfo {
     return SigPatchOrPrimElements;
   }
 
+  ViewArray<uint32_t> getOutputVectorMasks(size_t Idx) const {
+    assert(Idx < 4);
+    return OutputVectorMasks[Idx];
+  }
+
+  ViewArray<uint32_t> getPatchOrPrimMasks() const { return PatchOrPrimMasks; }
+
+  ViewArray<uint32_t> getInputOutputMap(size_t Idx) const {
+    assert(Idx < 4);
+    return InputOutputMap[Idx];
+  }
+
+  ViewArray<uint32_t> getInputPatchMap() const { return InputPatchMap; }
+  ViewArray<uint32_t> getPatchOutputMap() const { return PatchOutputMap; }
+
   uint32_t getSigElementStride() const { return SigInputElements.Stride; }
+
+  bool usesViewID() const {
+    if (const auto *P = getInfoAs<dxbc::PSV::v1::RuntimeInfo>())
+      return P->UsesViewID != 0;
+    return false;
+  }
+
+  uint8_t getInputVectorCount() const {
+    if (const auto *P = getInfoAs<dxbc::PSV::v1::RuntimeInfo>())
+      return P->SigInputVectors;
+    return 0;
+  }
+
+  ArrayRef<uint8_t> getOutputVectorCounts() const {
+    if (const auto *P = getInfoAs<dxbc::PSV::v1::RuntimeInfo>())
+      return ArrayRef<uint8_t>(P->SigOutputVectors);
+    return ArrayRef<uint8_t>();
+  }
+
+  uint8_t getPatchConstOrPrimVectorCount() const {
+    if (const auto *P = getInfoAs<dxbc::PSV::v1::RuntimeInfo>())
+      return P->GeomData.SigPatchConstOrPrimVectors;
+    return 0;
+  }
 };
 
 } // namespace DirectX
diff --git a/llvm/include/llvm/ObjectYAML/DXContainerYAML.h b/llvm/include/llvm/ObjectYAML/DXContainerYAML.h
index 1ab979fd0dfeaca..bce6fa8475bdd48 100644
--- a/llvm/include/llvm/ObjectYAML/DXContainerYAML.h
+++ b/llvm/include/llvm/ObjectYAML/DXContainerYAML.h
@@ -113,6 +113,13 @@ struct PSVInfo {
   SmallVector<SignatureElement> SigOutputElements;
   SmallVector<SignatureElement> SigPatchOrPrimElements;
 
+  using MaskVector = SmallVector<llvm::yaml::Hex32>;
+  MaskVector OutputVectorMasks[4];
+  MaskVector PatchOrPrimMasks;
+  MaskVector InputOutputMap[4];
+  MaskVector InputPatchMap;
+  MaskVector PatchOutputMap;
+
   void mapInfoForVersion(yaml::IO &IO);
 
   PSVInfo();
@@ -143,6 +150,7 @@ struct Object {
 LLVM_YAML_IS_SEQUENCE_VECTOR(llvm::DXContainerYAML::Part)
 LLVM_YAML_IS_SEQUENCE_VECTOR(llvm::DXContainerYAML::ResourceBindInfo)
 LLVM_YAML_IS_SEQUENCE_VECTOR(llvm::DXContainerYAML::SignatureElement)
+LLVM_YAML_IS_SEQUENCE_VECTOR(llvm::DXContainerYAML::PSVInfo::MaskVector)
 LLVM_YAML_DECLARE_ENUM_TRAITS(llvm::dxbc::PSV::SemanticKind)
 LLVM_YAML_DECLARE_ENUM_TRAITS(llvm::dxbc::PSV::ComponentType)
 LLVM_YAML_DECLARE_ENUM_TRAITS(llvm::dxbc::PSV::InterpolationMode)
diff --git a/llvm/include/llvm/Support/EndianStream.h b/llvm/include/llvm/Support/EndianStream.h
index 8ff87d23e83b145..a1dac5ad9f42a74 100644
--- a/llvm/include/llvm/Support/EndianStream.h
+++ b/llvm/include/llvm/Support/EndianStream.h
@@ -25,6 +25,15 @@ namespace support {
 
 namespace endian {
 
+template <typename value_type>
+inline void write_array(raw_ostream &os, ArrayRef<value_type> values,
+                        endianness endian) {
+  for (const auto orig : values) {
+    value_type value = byte_swap<value_type>(orig, endian);
+    os.write((const char *)&value, sizeof(value_type));
+  }
+}
+
 template <typename value_type>
 inline void write(raw_ostream &os, value_type value, endianness endian) {
   value = byte_swap<value_type>(value, endian);
diff --git a/llvm/lib/MC/DXContainerPSVInfo.cpp b/llvm/lib/MC/DXContainerPSVInfo.cpp
index 03df6be41a21c21..533659053c36f3d 100644
--- a/llvm/lib/MC/DXContainerPSVInfo.cpp
+++ b/llvm/lib/MC/DXContainerPSVInfo.cpp
@@ -147,4 +147,17 @@ void PSVRuntimeInfo::write(raw_ostream &OS, uint32_t Version) const {
     OS.write(reinterpret_cast<const char *>(&SignatureElements[0]),
              SignatureElements.size() * sizeof(v0::SignatureElement));
   }
+
+  for (const auto &MaskVector : OutputVectorMasks)
+    support::endian::write_array(OS, ArrayRef<uint32_t>(MaskVector),
+                                 support::little);
+  support::endian::write_array(OS, ArrayRef<uint32_t>(PatchOrPrimMasks),
+                               support::little);
+  for (const auto &MaskVector : InputOutputMap)
+    support::endian::write_array(OS, ArrayRef<uint32_t>(MaskVector),
+                                 support::little);
+  support::endian::write_array(OS, ArrayRef<uint32_t>(InputPatchMap),
+                               support::little);
+  support::endian::write_array(OS, ArrayRef<uint32_t>(PatchOutputMap),
+                               support::little);
 }
diff --git a/llvm/lib/Object/DXContainer.cpp b/llvm/lib/Object/DXContainer.cpp
index df1f98213a9951f..16433f11eef625f 100644
--- a/llvm/lib/Object/DXContainer.cpp
+++ b/llvm/lib/Object/DXContainer.cpp
@@ -321,6 +321,68 @@ Error DirectX::PSVRuntimeInfo::parse(uint16_t ShaderKind) {
     Current += PSize;
   }
 
+  ArrayRef<uint8_t> OutputVectorCounts = getOutputVectorCounts();
+  uint8_t PatchConstOrPrimVectorCount = getPatchConstOrPrimVectorCount();
+  uint8_t InputVectorCount = getInputVectorCount();
+
+  auto maskDwordSize = [](uint8_t Vector) {
+    return (static_cast<uint32_t>(Vector) + 7) >> 3;
+  };
+
+  auto mapTableSize = [maskDwordSize](uint8_t X, uint8_t Y) {
+    return maskDwordSize(Y) * X * 4;
+  };
+
+  if (usesViewID()) {
+    for (uint32_t I = 0; I < 4; ++I) {
+      // The vector mask is one bit per component and 4 components per vector.
+      // We can compute the number of dwords required by rounding up to the next
+      // multiple of 8.
+      uint32_t NumDwords =
+          maskDwordSize(static_cast<uint32_t>(OutputVectorCounts[I]));
+      size_t NumBytes = NumDwords * sizeof(uint32_t);
+      OutputVectorMasks[I].Data = Data.substr(Current - Data.begin(), NumBytes);
+      Current += NumBytes;
+    }
+
+    if (ShaderStage == Triple::Hull && PatchConstOrPrimVectorCount > 0) {
+      uint32_t NumDwords = maskDwordSize(PatchConstOrPrimVectorCount);
+      size_t NumBytes = NumDwords * sizeof(uint32_t);
+      PatchOrPrimMasks.Data = Data.substr(Current - Data.begin(), NumBytes);
+      Current += NumBytes;
+    }
+  }
+
+  // Input/Output mapping table
+  for (uint32_t I = 0; I < 4; ++I) {
+    if (InputVectorCount == 0 || OutputVectorCounts[I] == 0)
+      continue;
+    uint32_t NumDwords = mapTableSize(InputVectorCount, OutputVectorCounts[I]);
+    size_t NumBytes = NumDwords * sizeof(uint32_t);
+    InputOutputMap[I].Data = Data.substr(Current - Data.begin(), NumBytes);
+    Current += NumBytes;
+  }
+
+  // Hull shader: Input/Patch mapping table
+  if (ShaderStage == Triple::Hull && PatchConstOrPrimVectorCount > 0 &&
+      InputVectorCount > 0) {
+    uint32_t NumDwords =
+        mapTableSize(InputVectorCount, PatchConstOrPrimVectorCount);
+    size_t NumBytes = NumDwords * sizeof(uint32_t);
+    InputPatchMap.Data = Data.substr(Current - Data.begin(), NumBytes);
+    Current += NumBytes;
+  }
+
+  // Domain Shader: Patch/Output mapping table
+  if (ShaderStage == Triple::Domain && PatchConstOrPrimVectorCount > 0 &&
+      OutputVectorCounts[0] > 0) {
+    uint32_t NumDwords =
+        mapTableSize(PatchConstOrPrimVectorCount, OutputVectorCounts[0]);
+    size_t NumBytes = NumDwords * sizeof(uint32_t);
+    PatchOutputMap.Data = Data.substr(Current - Data.begin(), NumBytes);
+    Current += NumBytes;
+  }
+
   return Error::success();
 }
 
diff --git a/llvm/lib/ObjectYAML/DXContainerEmitter.cpp b/llvm/lib/ObjectYAML/DXContainerEmitter.cpp
index c5d5d6551d401c5..de8b0f59844bfff 100644
--- a/llvm/lib/ObjectYAML/DXContainerEmitter.cpp
+++ b/llvm/lib/ObjectYAML/DXContainerEmitter.cpp
@@ -219,6 +219,25 @@ void DXContainerWriter::writeParts(raw_ostream &OS) {
             El.Allocated, El.Kind, El.Type, El.Mode, El.DynamicMask,
             El.Stream});
 
+      for (int I = 0; I < 4; ++I) {
+        PSV.OutputVectorMasks[I].insert(PSV.OutputVectorMasks[I].begin(),
+                                        P.Info->OutputVectorMasks[I].begin(),
+                                        P.Info->OutputVectorMasks[I].end());
+        PSV.InputOutputMap[I].insert(PSV.InputOutputMap[I].begin(),
+                                     P.Info->InputOutputMap[I].begin(),
+                                     P.Info->InputOutputMap[I].end());
+      }
+
+      PSV.PatchOrPrimMasks.insert(PSV.PatchOrPrimMasks.begin(),
+                                  P.Info->PatchOrPrimMasks.begin(),
+                                  P.Info->PatchOrPrimMasks.end());
+      PSV.InputPatchMap.insert(PSV.InputPatchMap.begin(),
+                               P.Info->InputPatchMap.begin(),
+                               P.Info->InputPatchMap.end());
+      PSV.PatchOutputMap.insert(PSV.PatchOutputMap.begin(),
+                                P.Info->PatchOutputMap.begin(),
+                                P.Info->PatchOutputMap.end());
+
       PSV.finalize(static_cast<Triple::EnvironmentType>(
           Triple::Pixel + P.Info->Info.ShaderStage));
       PSV.write(OS, P.Info->Version);
diff --git a/llvm/lib/ObjectYAML/DXContainerYAML.cpp b/llvm/lib/ObjectYAML/DXContainerYAML.cpp
index 2b03098d7a5d08b..c7cf1ec9afc1f66 100644
--- a/llvm/lib/ObjectYAML/DXContainerYAML.cpp
+++ b/llvm/lib/ObjectYAML/DXContainerYAML.cpp
@@ -139,6 +139,24 @@ void MappingTraits<DXContainerYAML::PSVInfo>::mapping(
   IO.mapRequired("SigInputElements", PSV.SigInputElements);
   IO.mapRequired("SigOutputElements", PSV.SigOutputElements);
   IO.mapRequired("SigPatchOrPrimElements", PSV.SigPatchOrPrimElements);
+
+  Triple::EnvironmentType Stage = dxbc::getShaderStage(PSV.Info.ShaderStage);
+  if (PSV.Info.UsesViewID) {
+    MutableArrayRef<SmallVector<llvm::yaml::Hex32>> MutableOutMasks(
+        PSV.OutputVectorMasks);
+    IO.mapRequired("OutputVectorMasks", MutableOutMasks);
+    if (Stage == Triple::EnvironmentType::Hull)
+      IO.mapRequired("PatchOrPrimMasks", PSV.PatchOrPrimMasks);
+  }
+  MutableArrayRef<SmallVector<llvm::yaml::Hex32>> MutableIOMap(
+      PSV.InputOutputMap);
+  IO.mapRequired("InputOutputMap", MutableIOMap);
+
+  if (Stage == Triple::EnvironmentType::Hull)
+    IO.mapRequired("InputPatchMap", PSV.InputPatchMap);
+
+  if (Stage == Triple::EnvironmentType::Domain)
+    IO.mapRequired("PatchOutputMap", PSV.PatchOutputMap);
 }
 
 void MappingTraits<DXContainerYAML::Part>::mapping(IO &IO,
diff --git a/llvm/test/ObjectYAML/DXContainer/DomainMaskVectors.yaml b/llvm/test/ObjectYAML/DXContainer/DomainMaskVectors.yaml
new file mode 100644
index 000000000000000..713fbc61e094b5a
--- /dev/null
+++ b/llvm/test/ObjectYAML/DXContainer/DomainMaskVectors.yaml
@@ -0,0 +1,200 @@
+# RUN: yaml2obj %s | obj2yaml | FileCheck %s
+--- !dxcontainer
+Header:
+  Hash:            [ 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 
+                     0x0, 0x0, 0x0, 0x0, 0x0, 0x0 ]
+  Version:
+    Major:           1
+    Minor:           0
+  FileSize:        4616
+  PartCount:       8
+  PartOffsets:     [ 64, 80, 140, 200, 580, 952, 2756, 2784 ]
+Parts:
+  - Name:            SFI0
+    Size:            8
+    Flags:
+      Doubles:         false
+      ComputeShadersPlusRawAndStructuredBuffers: false
+      UAVsAtEveryStage: false
+      Max64UAVs:       false
+      MinimumPrecision: false
+      DX11_1_DoubleExtensions: false
+      DX11_1_ShaderExtensions: false
+      LEVEL9ComparisonFiltering: false
+      TiledResources:  false
+      StencilRef:      false
+      InnerCoverage:   false
+      TypedUAVLoadAdditionalFormats: false
+      ROVs:            false
+      ViewportAndRTArrayIndexFromAnyShaderFeedingRasterizer: false
+      WaveOps:         false
+      Int64Ops:        false
+      ViewID:          true
+      Barycentrics:    false
+      NativeLowPrecision: false
+      ShadingRate:     false
+      Raytracing_Tier_1_1: false
+      SamplerFeedback: false
+      AtomicInt64OnTypedResource: false
+      AtomicInt64OnGroupShared: false
+      DerivativesInMeshAndAmpShaders: false
+      ResourceDescriptorHeapIndexing: false
+      SamplerDescriptorHeapIndexing: false
+      RESERVED:        false
+      AtomicInt64OnHeapResource: false
+      AdvancedTextureOps: false
+      WriteableMSAATextures: false
+      NextUnusedBit:   false
+  - Name:            ISG1
+    Size:            52
+  - Name:            OSG1
+    Size:            52
+  - Name:            PSG1
+    Size:            372
+  - Name:            PSV0
+    Size:            364
+    PSVInfo:
+      Version:         2
+      ShaderStage:     4
+      InputControlPointCount: 16
+      OutputPositionPresent: 1
+      TessellatorDomain: 3
+      MinimumWaveLaneCount: 0
+      MaximumWaveLaneCount: 4294967295
+      UsesViewID:      1
+      SigPatchConstOrPrimVectors: 7
+      SigInputVectors: 1
+      SigOutputVectors: [ 1, 0, 0, 0 ]
+      NumThreadsX:     0
+      NumThreadsY:     0
+      NumThreadsZ:     0
+      ResourceStride:  24
+      Resources:
+        - Type:            2
+          Space:           0
+          LowerBound:      0
+          UpperBound:      0
+          Kind:            13
+          Flags:           0
+      SigInputElements:
+        - Name:       ...
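
The sizing arithmetic in the `parse()` hunk above can be checked in isolation. The sketch below restates the two lambdas as free `constexpr` functions (a restructuring for illustration, not how the patch writes them) and evaluates them for the vector counts in `DomainMaskVectors.yaml`: `SigInputVectors: 1`, `SigOutputVectors: [ 1, 0, 0, 0 ]`, `SigPatchConstOrPrimVectors: 7`.

```cpp
#include <cstdint>

// Restatement of the maskDwordSize lambda: one bit per component and four
// components per vector means eight vectors fill one 32-bit dword, so this
// is ceil(Vectors / 8).
constexpr uint32_t maskDwordSize(uint8_t Vectors) {
  return (static_cast<uint32_t>(Vectors) + 7) >> 3;
}

// Restatement of the mapTableSize lambda. It reads as one mask of
// maskDwordSize(Y) dwords for each of the X * 4 source components; that
// interpretation is inferred from the arithmetic, not stated in the patch.
constexpr uint32_t mapTableSize(uint8_t X, uint8_t Y) {
  return maskDwordSize(Y) * X * 4;
}

// Rounding behaviour around the 8-vector boundary.
static_assert(maskDwordSize(0) == 0, "no vectors, no mask");
static_assert(maskDwordSize(7) == 1, "7 vectors fit in one dword");
static_assert(maskDwordSize(8) == 1, "8 vectors exactly fill one dword");
static_assert(maskDwordSize(9) == 2, "9 vectors spill into a second dword");

// Counts taken from DomainMaskVectors.yaml.
static_assert(mapTableSize(1, 1) == 4, "InputOutputMap[0]: 1 input, 1 output");
static_assert(mapTableSize(7, 1) == 28, "PatchOutputMap: 7 patch, 1 output");
```

For that domain shader test input the formulas give 1 dword for `OutputVectorMasks[0]`, 4 for `InputOutputMap[0]`, and 28 for `PatchOutputMap`.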

Contributor

bogner left a comment

Looks basically good, though I have a few stylistic/abstraction-type comments.

@@ -51,6 +51,18 @@ struct PSVRuntimeInfo {
SmallVector<PSVSignatureElement> OutputElements;
SmallVector<PSVSignatureElement> PatchOrPrimElements;

// The interface here is bad, and we'll want to change this in the future. We
Contributor

Probably worth putting TODO: in this comment

// over-engineer this API now since we don't know what the data coming in to
// feed it will look like, so I kept it extremely simple for the immediate use
// case.
SmallVector<uint32_t> OutputVectorMasks[4];
Contributor

This leaves a lot of "4"s lying around the patch. Might be cleaner to use std::array<SmallVector<uint32_t>, 4> for OutputVectorMasks and InputOutputMap, and then just use .size() in the various places we iterate through these and their derivative types?
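
A minimal sketch of that suggestion, for illustration only. `PSVMasksSketch` and `totalMaskDwords` are hypothetical names, and `std::vector` stands in for `llvm::SmallVector` so the snippet compiles without LLVM headers.

```cpp
#include <array>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for the MC-side struct. The element count lives in
// the array type, so loops can ask the container instead of repeating 4.
struct PSVMasksSketch {
  std::array<std::vector<uint32_t>, 4> OutputVectorMasks;
  std::array<std::vector<uint32_t>, 4> InputOutputMap;
};

// Sums the dwords held in both arrays without a literal loop bound.
size_t totalMaskDwords(const PSVMasksSketch &PSV) {
  size_t Total = 0;
  // Range-based iteration needs no bound at all.
  for (const auto &Mask : PSV.OutputVectorMasks)
    Total += Mask.size();
  // Indexed iteration takes its bound from the container.
  for (size_t I = 0, E = PSV.InputOutputMap.size(); I != E; ++I)
    Total += PSV.InputOutputMap[I].size();
  return Total;
}
```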

@@ -219,6 +219,25 @@ void DXContainerWriter::writeParts(raw_ostream &OS) {
El.Allocated, El.Kind, El.Type, El.Mode, El.DynamicMask,
El.Stream});

for (int I = 0; I < 4; ++I) {
Contributor

If you do change to using std::array as my other comment suggests, it'd be good to keep a static assert here:

      static_assert(PSV.OutputVectorMasks.size() == PSV.InputOutputMap.size());
      for (size_t I = 0, E = PSV.OutputVectorMasks.size(); I != E; ++I) {
        ...
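
One possible variant of the check above, sketched under assumptions: it compares the array types through `std::tuple_size` instead of calling `.size()` on an object, which sidesteps any question of whether that call is usable in a constant expression for a given object. `PSVInfoSketch` and `copyMasks` are hypothetical, and `std::vector` again stands in for `SmallVector`.

```cpp
#include <array>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in holding only the two fixed-size members.
struct PSVInfoSketch {
  std::array<std::vector<uint32_t>, 4> OutputVectorMasks;
  std::array<std::vector<uint32_t>, 4> InputOutputMap;
};

// Copies both arrays in a single loop, which is only valid if they have the
// same element count; the static_assert enforces that at compile time.
void copyMasks(PSVInfoSketch &Dst, const PSVInfoSketch &Src) {
  static_assert(std::tuple_size<decltype(Dst.OutputVectorMasks)>::value ==
                    std::tuple_size<decltype(Dst.InputOutputMap)>::value,
                "mask arrays must have the same element count");
  for (size_t I = 0, E = Dst.OutputVectorMasks.size(); I != E; ++I) {
    Dst.OutputVectorMasks[I] = Src.OutputVectorMasks[I];
    Dst.InputOutputMap[I] = Src.InputOutputMap[I];
  }
}
```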


// Domain Shader: Patch/Output mapping table
if (ShaderStage == Triple::Domain && PatchConstOrPrimVectorCount > 0 &&
OutputVectorCounts[0] > 0) {
Contributor

It's a little awkward to be checking the shader stages directly here. Do you think it'd be better to introduce helpers like PSVRuntimeInfo::hasInputPatchMapping(Triple::EnvironmentType) and hasPatchOutputMapping?

Relatedly, is it / should it be an error to have PatchConstOrPrimVectorCount > 0 if this is some other shader stage?

Collaborator Author

I like the refactoring idea here. I could also make an accessor that does the check and returns 0 for the incorrect stages. It can't be an error if the value is non-zero because the value is in a union that stores other data in other stages.

I would very much like to reconsider how all of this is encoded in the future because it goes to great lengths to save a few bits here and there, but is really complex as a result.
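
A rough sketch of the helpers and the stage-checking accessor discussed in this thread. Everything here is hypothetical: `Stage` stands in for `Triple::EnvironmentType`, `InfoSketch` flattens the real runtime info, and the accessor only recognizes Hull and Domain because those are the stages the `parse()` hunk checks. Other stages may need to be included in a real implementation.

```cpp
#include <cstdint>

// Hypothetical stand-in for Triple::EnvironmentType.
enum class Stage { Pixel, Vertex, Geometry, Hull, Domain, Compute };

// Hypothetical flattened view of the runtime info.
struct InfoSketch {
  Stage ShaderStage;
  uint8_t SigInputVectors;
  uint8_t SigOutputVectors[4];
  // In the real structure this byte sits in a union and carries other data
  // for other stages, which is why a non-zero value cannot be an error.
  uint8_t RawPatchConstOrPrimVectors;
};

// Accessor that does the stage check and returns 0 for other stages.
// Assumption: only Hull and Domain are treated as valid here.
uint8_t patchConstOrPrimVectorCount(const InfoSketch &I) {
  if (I.ShaderStage == Stage::Hull || I.ShaderStage == Stage::Domain)
    return I.RawPatchConstOrPrimVectors;
  return 0;
}

// Mirrors the Hull condition guarding the Input/Patch mapping table.
bool hasInputPatchMapping(const InfoSketch &I) {
  return I.ShaderStage == Stage::Hull && patchConstOrPrimVectorCount(I) > 0 &&
         I.SigInputVectors > 0;
}

// Mirrors the Domain condition guarding the Patch/Output mapping table.
bool hasPatchOutputMapping(const InfoSketch &I) {
  return I.ShaderStage == Stage::Domain && patchConstOrPrimVectorCount(I) > 0 &&
         I.SigOutputVectors[0] > 0;
}
```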

@@ -140,6 +159,22 @@ class PSVRuntimeInfo {

const InfoStruct &getInfo() const { return BasicInfo; }

template <typename T> const T *getInfoAs() const {
if (const auto *P = std::get_if<dxbc::PSV::v2::RuntimeInfo>(&BasicInfo))
Contributor

Could we just return std::get_if<T>(&BasicInfo) for getInfoAs()?

Collaborator Author

That wouldn’t do the same thing. This allows us to do getInfoAs<dxbc::PSV::v1::RuntimeInfo>() and get back a v1 pointer even when the underlying object is a v2 object, or get a v0 object regardless of the underlying type.

Contributor

Why do we want to get a v0::RuntimeInfo ptr from a v2 object?
Because v0::RuntimeInfo::swapBytes() is not included in v2::RuntimeInfo::swapBytes()?

Collaborator Author

This is useful for cases where you want to read a field from the v0 or v1 data, but you don't actually care if the underlying data is a different version. You can see it in use in the usesViewID() method. That bit is part of the v1 structure, so I want to return its value for v1 or v2 and return false for v0.

This helper allows me to not need to replicate the std::get_if blocks throughout the class, and since std::is_same will be compile-time resolved, the instantiations should be stripped down based on the statically resolvable structure versions.
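
To make the behaviour described above concrete, here is a self-contained sketch of the same fallback chain. The `v0`/`v1`/`v2` structs are hypothetical stand-ins for the `dxbc::PSV` types (each version deriving from the previous one is what makes the casts valid), the variant layout is an assumption, and the helpers are free functions rather than members.

```cpp
#include <cstdint>
#include <type_traits>
#include <variant>

// Hypothetical stand-ins for dxbc::PSV::v0/v1/v2::RuntimeInfo.
namespace v0 { struct RuntimeInfo { uint32_t MinimumWaveLaneCount = 0; }; }
namespace v1 { struct RuntimeInfo : v0::RuntimeInfo { uint8_t UsesViewID = 0; }; }
namespace v2 { struct RuntimeInfo : v1::RuntimeInfo { uint32_t NumThreadsX = 0; }; }

// Assumed shape of the stored info: exactly one version is active.
using InfoStruct = std::variant<std::monostate, v0::RuntimeInfo,
                                v1::RuntimeInfo, v2::RuntimeInfo>;

// Returns a pointer to the stored object viewed as T when the stored version
// is T or newer, and nullptr when it is older than T.
template <typename T> const T *getInfoAs(const InfoStruct &Info) {
  if (const auto *P = std::get_if<v2::RuntimeInfo>(&Info))
    return static_cast<const T *>(P);
  if (std::is_same<T, v2::RuntimeInfo>::value)
    return nullptr; // asked for v2, but something older is stored
  if (const auto *P = std::get_if<v1::RuntimeInfo>(&Info))
    return static_cast<const T *>(P);
  if (std::is_same<T, v1::RuntimeInfo>::value)
    return nullptr; // asked for v1, but only v0 (or nothing) is stored
  if (const auto *P = std::get_if<v0::RuntimeInfo>(&Info))
    return static_cast<const T *>(P);
  return nullptr;
}

// UsesViewID is a v1 field: report it for v1 or v2 data, false for v0.
bool usesViewID(const InfoStruct &Info) {
  if (const auto *P = getInfoAs<v1::RuntimeInfo>(Info))
    return P->UsesViewID != 0;
  return false;
}
```

By contrast, a plain `std::get_if<v1::RuntimeInfo>` on a variant holding a `v2::RuntimeInfo` returns `nullptr`, since `std::get_if` only matches the exact alternative.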

@llvm-beanz llvm-beanz merged commit b799e9d into llvm:main Sep 15, 2023
1 of 2 checks passed
ZijunZhaoCCK pushed a commit to ZijunZhaoCCK/llvm-project that referenced this pull request Sep 19, 2023
Projects
Status: Done
Development

Successfully merging this pull request may close these issues.

[DirectX] DXContainer PSV0 part Object support
4 participants