Skip to content

[HLSL][NFC] Remove RegisterBindingFlags struct #108924

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Sep 18, 2024

Conversation

hekota
Copy link
Member

@hekota hekota commented Sep 17, 2024

When diagnosing register bindings we just need to make sure there is a resource that matches the provided register type. We can emit the diagnostics right away instead of collecting flags in the RegisterBindingFlags struct. That also enables early exit when scanning user defined types because we can return as soon as we find a matching resource for the given register type.

When diagnosing register bindings we just need to make sure there is
a resource that matches the provided register type. We can emit the
diagnostics right away instead of collecting flags in the
RegisterBindingFlags struct. That also enables early exit when scanning
user defined types because we can return as soon as we find a matching
resource for the given register type.
@llvmbot llvmbot added clang Clang issues not falling into any other category clang:frontend Language frontend issues, e.g. anything involving "Sema" HLSL HLSL Language Support labels Sep 17, 2024
@llvmbot
Copy link
Member

llvmbot commented Sep 17, 2024

@llvm/pr-subscribers-clang

Author: Helena Kotas (hekota)

Changes

When diagnosing register bindings we just need to make sure there is a resource that matches the provided register type. We can emit the diagnostics right away instead of collecting flags in the RegisterBindingFlags struct. That also enables early exit when scanning user defined types because we can return as soon as we find a matching resource for the given register type.


Full diff: https://github.com/llvm/llvm-project/pull/108924.diff

1 Files Affected:

  • (modified) clang/lib/Sema/SemaHLSL.cpp (+113-181)
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 26de9a986257c5..45a8d75ce099cc 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -41,6 +41,47 @@
 
 using namespace clang;
 
+enum class RegisterType { SRV, UAV, CBuffer, Sampler, C, I, Invalid };
+
+static RegisterType getRegisterType(llvm::dxil::ResourceClass RC) {
+  switch (RC) {
+  case llvm::dxil::ResourceClass::SRV:
+    return RegisterType::SRV;
+  case llvm::dxil::ResourceClass::UAV:
+    return RegisterType::UAV;
+  case llvm::dxil::ResourceClass::CBuffer:
+    return RegisterType::CBuffer;
+  case llvm::dxil::ResourceClass::Sampler:
+    return RegisterType::Sampler;
+  }
+  llvm_unreachable("unexpected ResourceClass value");
+}
+
+static RegisterType getRegisterType(StringRef Slot) {
+  switch (Slot[0]) {
+  case 't':
+  case 'T':
+    return RegisterType::SRV;
+  case 'u':
+  case 'U':
+    return RegisterType::UAV;
+  case 'b':
+  case 'B':
+    return RegisterType::CBuffer;
+  case 's':
+  case 'S':
+    return RegisterType::Sampler;
+  case 'c':
+  case 'C':
+    return RegisterType::C;
+  case 'i':
+  case 'I':
+    return RegisterType::I;
+  default:
+    return RegisterType::Invalid;
+  }
+}
+
 SemaHLSL::SemaHLSL(Sema &S) : SemaBase(S) {}
 
 Decl *SemaHLSL::ActOnStartBuffer(Scope *BufferScope, bool CBuffer,
@@ -739,28 +780,6 @@ SemaHLSL::TakeLocForHLSLAttribute(const HLSLAttributedResourceType *RT) {
   return LocInfo;
 }
 
-struct RegisterBindingFlags {
-  bool Resource = false;
-  bool UDT = false;
-  bool Other = false;
-  bool Basic = false;
-
-  bool SRV = false;
-  bool UAV = false;
-  bool CBV = false;
-  bool Sampler = false;
-
-  bool ContainsNumeric = false;
-  bool DefaultGlobals = false;
-
-  // used only when Resource == true
-  std::optional<llvm::dxil::ResourceClass> ResourceClass;
-};
-
-static bool isDeclaredWithinCOrTBuffer(const Decl *TheDecl) {
-  return TheDecl && isa<HLSLBufferDecl>(TheDecl->getDeclContext());
-}
-
 // get the record decl from a var decl that we expect
 // represents a resource
 static CXXRecordDecl *getRecordDeclFromVarDecl(VarDecl *VD) {
@@ -775,24 +794,6 @@ static CXXRecordDecl *getRecordDeclFromVarDecl(VarDecl *VD) {
   return TheRecordDecl;
 }
 
-static void updateResourceClassFlagsFromDeclResourceClass(
-    RegisterBindingFlags &Flags, llvm::hlsl::ResourceClass DeclResourceClass) {
-  switch (DeclResourceClass) {
-  case llvm::hlsl::ResourceClass::SRV:
-    Flags.SRV = true;
-    break;
-  case llvm::hlsl::ResourceClass::UAV:
-    Flags.UAV = true;
-    break;
-  case llvm::hlsl::ResourceClass::CBuffer:
-    Flags.CBV = true;
-    break;
-  case llvm::hlsl::ResourceClass::Sampler:
-    Flags.Sampler = true;
-    break;
-  }
-}
-
 const HLSLAttributedResourceType *
 findAttributedResourceTypeOnField(VarDecl *VD) {
   assert(VD != nullptr && "expected VarDecl");
@@ -806,8 +807,10 @@ findAttributedResourceTypeOnField(VarDecl *VD) {
   return nullptr;
 }
 
-static void updateResourceClassFlagsFromRecordType(RegisterBindingFlags &Flags,
-                                                   const RecordType *RT) {
+// Iterate over RecordType fields and return true if any of them matched the
+// register type
+static bool ContainsResourceForRegisterType(Sema &S, const RecordType *RT,
+                                            RegisterType RegType) {
   llvm::SmallVector<const Type *> TypesToScan;
   TypesToScan.emplace_back(RT);
 
@@ -816,8 +819,8 @@ static void updateResourceClassFlagsFromRecordType(RegisterBindingFlags &Flags,
     while (T->isArrayType())
       T = T->getArrayElementTypeNoTypeQual();
     if (T->isIntegralOrEnumerationType() || T->isFloatingType()) {
-      Flags.ContainsNumeric = true;
-      continue;
+      if (RegType == RegisterType::C)
+        return true;
     }
     const RecordType *RT = T->getAs<RecordType>();
     if (!RT)
@@ -828,101 +831,85 @@ static void updateResourceClassFlagsFromRecordType(RegisterBindingFlags &Flags,
       const Type *FieldTy = FD->getType().getTypePtr();
       if (const HLSLAttributedResourceType *AttrResType =
               dyn_cast<HLSLAttributedResourceType>(FieldTy)) {
-        updateResourceClassFlagsFromDeclResourceClass(
-            Flags, AttrResType->getAttrs().ResourceClass);
+        llvm::dxil::ResourceClass RC = AttrResType->getAttrs().ResourceClass;
+        if (getRegisterType(RC) == RegType)
+          return true;
         continue;
       }
       TypesToScan.emplace_back(FD->getType().getTypePtr());
     }
   }
+  return false;
 }
 
-static RegisterBindingFlags HLSLFillRegisterBindingFlags(Sema &S,
-                                                         Decl *TheDecl) {
-  RegisterBindingFlags Flags;
+static void CheckContainsResourceForRegisterType(Sema &S,
+                                                 SourceLocation &ArgLoc,
+                                                 Decl *D,
+                                                 RegisterType RegType) {
+  int RegTypeNum = static_cast<int>(RegType);
 
   // check if the decl type is groupshared
-  if (TheDecl->hasAttr<HLSLGroupSharedAddressSpaceAttr>()) {
-    Flags.Other = true;
-    return Flags;
+  if (D->hasAttr<HLSLGroupSharedAddressSpaceAttr>()) {
+    S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
+    return;
   }
 
   // Cbuffers and Tbuffers are HLSLBufferDecl types
-  if (HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(TheDecl)) {
-    Flags.Resource = true;
-    Flags.ResourceClass = CBufferOrTBuffer->isCBuffer()
-                              ? llvm::dxil::ResourceClass::CBuffer
-                              : llvm::dxil::ResourceClass::SRV;
+  if (HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(D)) {
+    llvm::dxil::ResourceClass RC = CBufferOrTBuffer->isCBuffer()
+                                       ? llvm::dxil::ResourceClass::CBuffer
+                                       : llvm::dxil::ResourceClass::SRV;
+    if (RegType != getRegisterType(RC))
+      S.Diag(D->getLocation(), diag::err_hlsl_binding_type_mismatch)
+          << RegTypeNum;
+    return;
   }
   // Samplers, UAVs, and SRVs are VarDecl types
-  else if (VarDecl *TheVarDecl = dyn_cast<VarDecl>(TheDecl)) {
+  if (VarDecl *TheVarDecl = dyn_cast<VarDecl>(D)) {
+    // Resource
     if (const HLSLAttributedResourceType *AttrResType =
             findAttributedResourceTypeOnField(TheVarDecl)) {
-      Flags.Resource = true;
-      Flags.ResourceClass = AttrResType->getAttrs().ResourceClass;
-    } else {
-      const clang::Type *TheBaseType = TheVarDecl->getType().getTypePtr();
-      while (TheBaseType->isArrayType())
-        TheBaseType = TheBaseType->getArrayElementTypeNoTypeQual();
-
-      if (TheBaseType->isArithmeticType()) {
-        Flags.Basic = true;
-        if (!isDeclaredWithinCOrTBuffer(TheDecl) &&
-            (TheBaseType->isIntegralType(S.getASTContext()) ||
-             TheBaseType->isFloatingType()))
-          Flags.DefaultGlobals = true;
-      } else if (TheBaseType->isRecordType()) {
-        Flags.UDT = true;
-        const RecordType *TheRecordTy = TheBaseType->getAs<RecordType>();
-        updateResourceClassFlagsFromRecordType(Flags, TheRecordTy);
-      } else
-        Flags.Other = true;
+      if (RegType != getRegisterType(AttrResType->getAttrs().ResourceClass))
+        S.Diag(D->getLocation(), diag::err_hlsl_binding_type_mismatch)
+            << RegTypeNum;
+      return;
     }
-  } else {
-    llvm_unreachable("expected be VarDecl or HLSLBufferDecl");
-  }
-  return Flags;
-}
 
-enum class RegisterType { SRV, UAV, CBuffer, Sampler, C, I, Invalid };
-
-static RegisterType getRegisterType(llvm::dxil::ResourceClass RC) {
-  switch (RC) {
-  case llvm::dxil::ResourceClass::SRV:
-    return RegisterType::SRV;
-  case llvm::dxil::ResourceClass::UAV:
-    return RegisterType::UAV;
-  case llvm::dxil::ResourceClass::CBuffer:
-    return RegisterType::CBuffer;
-  case llvm::dxil::ResourceClass::Sampler:
-    return RegisterType::Sampler;
-  }
-  llvm_unreachable("unexpected ResourceClass value");
-}
-
-static RegisterType getRegisterType(StringRef Slot) {
-  switch (Slot[0]) {
-  case 't':
-  case 'T':
-    return RegisterType::SRV;
-  case 'u':
-  case 'U':
-    return RegisterType::UAV;
-  case 'b':
-  case 'B':
-    return RegisterType::CBuffer;
-  case 's':
-  case 'S':
-    return RegisterType::Sampler;
-  case 'c':
-  case 'C':
-    return RegisterType::C;
-  case 'i':
-  case 'I':
-    return RegisterType::I;
-  default:
-    return RegisterType::Invalid;
+    const clang::Type *TheBaseType = TheVarDecl->getType().getTypePtr();
+    while (TheBaseType->isArrayType())
+      TheBaseType = TheBaseType->getArrayElementTypeNoTypeQual();
+
+    // Basic types
+    if (TheBaseType->isArithmeticType()) {
+      if (!isa<HLSLBufferDecl>(D->getDeclContext()) &&
+          (TheBaseType->isIntegralType(S.getASTContext()) ||
+           TheBaseType->isFloatingType())) {
+        // Default Globals
+        if (RegType == RegisterType::CBuffer)
+          S.Diag(ArgLoc, diag::warn_hlsl_deprecated_register_type_b);
+        else if (RegType != RegisterType::C)
+          S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
+      } else {
+        if (RegType == RegisterType::C)
+          S.Diag(ArgLoc, diag::warn_hlsl_register_type_c_packoffset);
+        else
+          S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
+      }
+    } else if (TheBaseType->isRecordType()) {
+      // Class/struct types - walk the declaration and check each field and
+      // subclass
+      if (!ContainsResourceForRegisterType(S, TheBaseType->getAs<RecordType>(),
+                                           RegType))
+        S.Diag(D->getLocation(),
+               diag::warn_hlsl_user_defined_type_missing_member)
+            << RegTypeNum;
+    } else {
+      // Anything else is an error
+      S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
+    }
+    return;
   }
+  llvm_unreachable("expected be VarDecl or HLSLBufferDecl");
 }
 
 static void ValidateMultipleRegisterAnnotations(Sema &S, Decl *TheDecl,
@@ -958,73 +945,18 @@ static void ValidateMultipleRegisterAnnotations(Sema &S, Decl *TheDecl,
 }
 
 static void DiagnoseHLSLRegisterAttribute(Sema &S, SourceLocation &ArgLoc,
-                                          Decl *TheDecl, RegisterType regType) {
+                                          Decl *D, RegisterType RegType) {
 
   // exactly one of these two types should be set
-  assert(((isa<VarDecl>(TheDecl) && !isa<HLSLBufferDecl>(TheDecl)) ||
-          (!isa<VarDecl>(TheDecl) && isa<HLSLBufferDecl>(TheDecl))) &&
+  assert(((isa<VarDecl>(D) && !isa<HLSLBufferDecl>(D)) ||
+          (!isa<VarDecl>(D) && isa<HLSLBufferDecl>(D))) &&
          "expecting VarDecl or HLSLBufferDecl");
 
-  RegisterBindingFlags Flags = HLSLFillRegisterBindingFlags(S, TheDecl);
-  assert((int)Flags.Other + (int)Flags.Resource + (int)Flags.Basic +
-                 (int)Flags.UDT ==
-             1 &&
-         "only one resource analysis result should be expected");
-
-  int regTypeNum = static_cast<int>(regType);
-
-  // first, if "other" is set, emit an error
-  if (Flags.Other) {
-    S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << regTypeNum;
-    return;
-  }
+  // check if the declaration contains resource matching the register type
+  CheckContainsResourceForRegisterType(S, ArgLoc, D, RegType);
 
   // next, if multiple register annotations exist, check that none conflict.
-  ValidateMultipleRegisterAnnotations(S, TheDecl, regType);
-
-  // next, if resource is set, make sure the register type in the register
-  // annotation is compatible with the variable's resource type.
-  if (Flags.Resource) {
-    RegisterType expRegType = getRegisterType(Flags.ResourceClass.value());
-    if (regType != expRegType) {
-      S.Diag(TheDecl->getLocation(), diag::err_hlsl_binding_type_mismatch)
-          << regTypeNum;
-    }
-    return;
-  }
-
-  // next, handle diagnostics for when the "basic" flag is set
-  if (Flags.Basic) {
-    if (Flags.DefaultGlobals) {
-      if (regType == RegisterType::CBuffer)
-        S.Diag(ArgLoc, diag::warn_hlsl_deprecated_register_type_b);
-      else if (regType != RegisterType::C)
-        S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << regTypeNum;
-      return;
-    }
-
-    if (regType == RegisterType::C)
-      S.Diag(ArgLoc, diag::warn_hlsl_register_type_c_packoffset);
-    else
-      S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << regTypeNum;
-
-    return;
-  }
-
-  // finally, we handle the udt case
-  if (Flags.UDT) {
-    const bool ExpectedRegisterTypesForUDT[] = {
-        Flags.SRV, Flags.UAV, Flags.CBV, Flags.Sampler, Flags.ContainsNumeric};
-    assert((size_t)regTypeNum < std::size(ExpectedRegisterTypesForUDT) &&
-           "regType has unexpected value");
-
-    if (!ExpectedRegisterTypesForUDT[regTypeNum])
-      S.Diag(TheDecl->getLocation(),
-             diag::warn_hlsl_user_defined_type_missing_member)
-          << regTypeNum;
-
-    return;
-  }
+  ValidateMultipleRegisterAnnotations(S, D, RegType);
 }
 
 void SemaHLSL::handleResourceBindingAttr(Decl *TheDecl, const ParsedAttr &AL) {

@llvmbot
Copy link
Member

llvmbot commented Sep 17, 2024

@llvm/pr-subscribers-hlsl

Author: Helena Kotas (hekota)

Changes

When diagnosing register bindings we just need to make sure there is a resource that matches the provided register type. We can emit the diagnostics right away instead of collecting flags in the RegisterBindingFlags struct. That also enables early exit when scanning user defined types because we can return as soon as we find a matching resource for the given register type.


Full diff: https://github.com/llvm/llvm-project/pull/108924.diff

1 Files Affected:

  • (modified) clang/lib/Sema/SemaHLSL.cpp (+113-181)
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 26de9a986257c5..45a8d75ce099cc 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -41,6 +41,47 @@
 
 using namespace clang;
 
+enum class RegisterType { SRV, UAV, CBuffer, Sampler, C, I, Invalid };
+
+static RegisterType getRegisterType(llvm::dxil::ResourceClass RC) {
+  switch (RC) {
+  case llvm::dxil::ResourceClass::SRV:
+    return RegisterType::SRV;
+  case llvm::dxil::ResourceClass::UAV:
+    return RegisterType::UAV;
+  case llvm::dxil::ResourceClass::CBuffer:
+    return RegisterType::CBuffer;
+  case llvm::dxil::ResourceClass::Sampler:
+    return RegisterType::Sampler;
+  }
+  llvm_unreachable("unexpected ResourceClass value");
+}
+
+static RegisterType getRegisterType(StringRef Slot) {
+  switch (Slot[0]) {
+  case 't':
+  case 'T':
+    return RegisterType::SRV;
+  case 'u':
+  case 'U':
+    return RegisterType::UAV;
+  case 'b':
+  case 'B':
+    return RegisterType::CBuffer;
+  case 's':
+  case 'S':
+    return RegisterType::Sampler;
+  case 'c':
+  case 'C':
+    return RegisterType::C;
+  case 'i':
+  case 'I':
+    return RegisterType::I;
+  default:
+    return RegisterType::Invalid;
+  }
+}
+
 SemaHLSL::SemaHLSL(Sema &S) : SemaBase(S) {}
 
 Decl *SemaHLSL::ActOnStartBuffer(Scope *BufferScope, bool CBuffer,
@@ -739,28 +780,6 @@ SemaHLSL::TakeLocForHLSLAttribute(const HLSLAttributedResourceType *RT) {
   return LocInfo;
 }
 
-struct RegisterBindingFlags {
-  bool Resource = false;
-  bool UDT = false;
-  bool Other = false;
-  bool Basic = false;
-
-  bool SRV = false;
-  bool UAV = false;
-  bool CBV = false;
-  bool Sampler = false;
-
-  bool ContainsNumeric = false;
-  bool DefaultGlobals = false;
-
-  // used only when Resource == true
-  std::optional<llvm::dxil::ResourceClass> ResourceClass;
-};
-
-static bool isDeclaredWithinCOrTBuffer(const Decl *TheDecl) {
-  return TheDecl && isa<HLSLBufferDecl>(TheDecl->getDeclContext());
-}
-
 // get the record decl from a var decl that we expect
 // represents a resource
 static CXXRecordDecl *getRecordDeclFromVarDecl(VarDecl *VD) {
@@ -775,24 +794,6 @@ static CXXRecordDecl *getRecordDeclFromVarDecl(VarDecl *VD) {
   return TheRecordDecl;
 }
 
-static void updateResourceClassFlagsFromDeclResourceClass(
-    RegisterBindingFlags &Flags, llvm::hlsl::ResourceClass DeclResourceClass) {
-  switch (DeclResourceClass) {
-  case llvm::hlsl::ResourceClass::SRV:
-    Flags.SRV = true;
-    break;
-  case llvm::hlsl::ResourceClass::UAV:
-    Flags.UAV = true;
-    break;
-  case llvm::hlsl::ResourceClass::CBuffer:
-    Flags.CBV = true;
-    break;
-  case llvm::hlsl::ResourceClass::Sampler:
-    Flags.Sampler = true;
-    break;
-  }
-}
-
 const HLSLAttributedResourceType *
 findAttributedResourceTypeOnField(VarDecl *VD) {
   assert(VD != nullptr && "expected VarDecl");
@@ -806,8 +807,10 @@ findAttributedResourceTypeOnField(VarDecl *VD) {
   return nullptr;
 }
 
-static void updateResourceClassFlagsFromRecordType(RegisterBindingFlags &Flags,
-                                                   const RecordType *RT) {
+// Iterate over RecordType fields and return true if any of them matched the
+// register type
+static bool ContainsResourceForRegisterType(Sema &S, const RecordType *RT,
+                                            RegisterType RegType) {
   llvm::SmallVector<const Type *> TypesToScan;
   TypesToScan.emplace_back(RT);
 
@@ -816,8 +819,8 @@ static void updateResourceClassFlagsFromRecordType(RegisterBindingFlags &Flags,
     while (T->isArrayType())
       T = T->getArrayElementTypeNoTypeQual();
     if (T->isIntegralOrEnumerationType() || T->isFloatingType()) {
-      Flags.ContainsNumeric = true;
-      continue;
+      if (RegType == RegisterType::C)
+        return true;
     }
     const RecordType *RT = T->getAs<RecordType>();
     if (!RT)
@@ -828,101 +831,85 @@ static void updateResourceClassFlagsFromRecordType(RegisterBindingFlags &Flags,
       const Type *FieldTy = FD->getType().getTypePtr();
       if (const HLSLAttributedResourceType *AttrResType =
               dyn_cast<HLSLAttributedResourceType>(FieldTy)) {
-        updateResourceClassFlagsFromDeclResourceClass(
-            Flags, AttrResType->getAttrs().ResourceClass);
+        llvm::dxil::ResourceClass RC = AttrResType->getAttrs().ResourceClass;
+        if (getRegisterType(RC) == RegType)
+          return true;
         continue;
       }
       TypesToScan.emplace_back(FD->getType().getTypePtr());
     }
   }
+  return false;
 }
 
-static RegisterBindingFlags HLSLFillRegisterBindingFlags(Sema &S,
-                                                         Decl *TheDecl) {
-  RegisterBindingFlags Flags;
+static void CheckContainsResourceForRegisterType(Sema &S,
+                                                 SourceLocation &ArgLoc,
+                                                 Decl *D,
+                                                 RegisterType RegType) {
+  int RegTypeNum = static_cast<int>(RegType);
 
   // check if the decl type is groupshared
-  if (TheDecl->hasAttr<HLSLGroupSharedAddressSpaceAttr>()) {
-    Flags.Other = true;
-    return Flags;
+  if (D->hasAttr<HLSLGroupSharedAddressSpaceAttr>()) {
+    S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
+    return;
   }
 
   // Cbuffers and Tbuffers are HLSLBufferDecl types
-  if (HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(TheDecl)) {
-    Flags.Resource = true;
-    Flags.ResourceClass = CBufferOrTBuffer->isCBuffer()
-                              ? llvm::dxil::ResourceClass::CBuffer
-                              : llvm::dxil::ResourceClass::SRV;
+  if (HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(D)) {
+    llvm::dxil::ResourceClass RC = CBufferOrTBuffer->isCBuffer()
+                                       ? llvm::dxil::ResourceClass::CBuffer
+                                       : llvm::dxil::ResourceClass::SRV;
+    if (RegType != getRegisterType(RC))
+      S.Diag(D->getLocation(), diag::err_hlsl_binding_type_mismatch)
+          << RegTypeNum;
+    return;
   }
   // Samplers, UAVs, and SRVs are VarDecl types
-  else if (VarDecl *TheVarDecl = dyn_cast<VarDecl>(TheDecl)) {
+  if (VarDecl *TheVarDecl = dyn_cast<VarDecl>(D)) {
+    // Resource
     if (const HLSLAttributedResourceType *AttrResType =
             findAttributedResourceTypeOnField(TheVarDecl)) {
-      Flags.Resource = true;
-      Flags.ResourceClass = AttrResType->getAttrs().ResourceClass;
-    } else {
-      const clang::Type *TheBaseType = TheVarDecl->getType().getTypePtr();
-      while (TheBaseType->isArrayType())
-        TheBaseType = TheBaseType->getArrayElementTypeNoTypeQual();
-
-      if (TheBaseType->isArithmeticType()) {
-        Flags.Basic = true;
-        if (!isDeclaredWithinCOrTBuffer(TheDecl) &&
-            (TheBaseType->isIntegralType(S.getASTContext()) ||
-             TheBaseType->isFloatingType()))
-          Flags.DefaultGlobals = true;
-      } else if (TheBaseType->isRecordType()) {
-        Flags.UDT = true;
-        const RecordType *TheRecordTy = TheBaseType->getAs<RecordType>();
-        updateResourceClassFlagsFromRecordType(Flags, TheRecordTy);
-      } else
-        Flags.Other = true;
+      if (RegType != getRegisterType(AttrResType->getAttrs().ResourceClass))
+        S.Diag(D->getLocation(), diag::err_hlsl_binding_type_mismatch)
+            << RegTypeNum;
+      return;
     }
-  } else {
-    llvm_unreachable("expected be VarDecl or HLSLBufferDecl");
-  }
-  return Flags;
-}
 
-enum class RegisterType { SRV, UAV, CBuffer, Sampler, C, I, Invalid };
-
-static RegisterType getRegisterType(llvm::dxil::ResourceClass RC) {
-  switch (RC) {
-  case llvm::dxil::ResourceClass::SRV:
-    return RegisterType::SRV;
-  case llvm::dxil::ResourceClass::UAV:
-    return RegisterType::UAV;
-  case llvm::dxil::ResourceClass::CBuffer:
-    return RegisterType::CBuffer;
-  case llvm::dxil::ResourceClass::Sampler:
-    return RegisterType::Sampler;
-  }
-  llvm_unreachable("unexpected ResourceClass value");
-}
-
-static RegisterType getRegisterType(StringRef Slot) {
-  switch (Slot[0]) {
-  case 't':
-  case 'T':
-    return RegisterType::SRV;
-  case 'u':
-  case 'U':
-    return RegisterType::UAV;
-  case 'b':
-  case 'B':
-    return RegisterType::CBuffer;
-  case 's':
-  case 'S':
-    return RegisterType::Sampler;
-  case 'c':
-  case 'C':
-    return RegisterType::C;
-  case 'i':
-  case 'I':
-    return RegisterType::I;
-  default:
-    return RegisterType::Invalid;
+    const clang::Type *TheBaseType = TheVarDecl->getType().getTypePtr();
+    while (TheBaseType->isArrayType())
+      TheBaseType = TheBaseType->getArrayElementTypeNoTypeQual();
+
+    // Basic types
+    if (TheBaseType->isArithmeticType()) {
+      if (!isa<HLSLBufferDecl>(D->getDeclContext()) &&
+          (TheBaseType->isIntegralType(S.getASTContext()) ||
+           TheBaseType->isFloatingType())) {
+        // Default Globals
+        if (RegType == RegisterType::CBuffer)
+          S.Diag(ArgLoc, diag::warn_hlsl_deprecated_register_type_b);
+        else if (RegType != RegisterType::C)
+          S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
+      } else {
+        if (RegType == RegisterType::C)
+          S.Diag(ArgLoc, diag::warn_hlsl_register_type_c_packoffset);
+        else
+          S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
+      }
+    } else if (TheBaseType->isRecordType()) {
+      // Class/struct types - walk the declaration and check each field and
+      // subclass
+      if (!ContainsResourceForRegisterType(S, TheBaseType->getAs<RecordType>(),
+                                           RegType))
+        S.Diag(D->getLocation(),
+               diag::warn_hlsl_user_defined_type_missing_member)
+            << RegTypeNum;
+    } else {
+      // Anything else is an error
+      S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
+    }
+    return;
   }
+  llvm_unreachable("expected be VarDecl or HLSLBufferDecl");
 }
 
 static void ValidateMultipleRegisterAnnotations(Sema &S, Decl *TheDecl,
@@ -958,73 +945,18 @@ static void ValidateMultipleRegisterAnnotations(Sema &S, Decl *TheDecl,
 }
 
 static void DiagnoseHLSLRegisterAttribute(Sema &S, SourceLocation &ArgLoc,
-                                          Decl *TheDecl, RegisterType regType) {
+                                          Decl *D, RegisterType RegType) {
 
   // exactly one of these two types should be set
-  assert(((isa<VarDecl>(TheDecl) && !isa<HLSLBufferDecl>(TheDecl)) ||
-          (!isa<VarDecl>(TheDecl) && isa<HLSLBufferDecl>(TheDecl))) &&
+  assert(((isa<VarDecl>(D) && !isa<HLSLBufferDecl>(D)) ||
+          (!isa<VarDecl>(D) && isa<HLSLBufferDecl>(D))) &&
          "expecting VarDecl or HLSLBufferDecl");
 
-  RegisterBindingFlags Flags = HLSLFillRegisterBindingFlags(S, TheDecl);
-  assert((int)Flags.Other + (int)Flags.Resource + (int)Flags.Basic +
-                 (int)Flags.UDT ==
-             1 &&
-         "only one resource analysis result should be expected");
-
-  int regTypeNum = static_cast<int>(regType);
-
-  // first, if "other" is set, emit an error
-  if (Flags.Other) {
-    S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << regTypeNum;
-    return;
-  }
+  // check if the declaration contains resource matching the register type
+  CheckContainsResourceForRegisterType(S, ArgLoc, D, RegType);
 
   // next, if multiple register annotations exist, check that none conflict.
-  ValidateMultipleRegisterAnnotations(S, TheDecl, regType);
-
-  // next, if resource is set, make sure the register type in the register
-  // annotation is compatible with the variable's resource type.
-  if (Flags.Resource) {
-    RegisterType expRegType = getRegisterType(Flags.ResourceClass.value());
-    if (regType != expRegType) {
-      S.Diag(TheDecl->getLocation(), diag::err_hlsl_binding_type_mismatch)
-          << regTypeNum;
-    }
-    return;
-  }
-
-  // next, handle diagnostics for when the "basic" flag is set
-  if (Flags.Basic) {
-    if (Flags.DefaultGlobals) {
-      if (regType == RegisterType::CBuffer)
-        S.Diag(ArgLoc, diag::warn_hlsl_deprecated_register_type_b);
-      else if (regType != RegisterType::C)
-        S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << regTypeNum;
-      return;
-    }
-
-    if (regType == RegisterType::C)
-      S.Diag(ArgLoc, diag::warn_hlsl_register_type_c_packoffset);
-    else
-      S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << regTypeNum;
-
-    return;
-  }
-
-  // finally, we handle the udt case
-  if (Flags.UDT) {
-    const bool ExpectedRegisterTypesForUDT[] = {
-        Flags.SRV, Flags.UAV, Flags.CBV, Flags.Sampler, Flags.ContainsNumeric};
-    assert((size_t)regTypeNum < std::size(ExpectedRegisterTypesForUDT) &&
-           "regType has unexpected value");
-
-    if (!ExpectedRegisterTypesForUDT[regTypeNum])
-      S.Diag(TheDecl->getLocation(),
-             diag::warn_hlsl_user_defined_type_missing_member)
-          << regTypeNum;
-
-    return;
-  }
+  ValidateMultipleRegisterAnnotations(S, D, RegType);
 }
 
 void SemaHLSL::handleResourceBindingAttr(Decl *TheDecl, const ParsedAttr &AL) {

@hekota hekota requested a review from damyanp September 17, 2024 04:05
Copy link
Collaborator

@llvm-beanz llvm-beanz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One minor nit inline, but I think this is fine either way.

Copy link
Contributor

@damyanp damyanp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM - just a couple of nits.

This refactor is leaning heavily on the tests, it might be good to get a review from @bob80905 just in case there's something that the tests didn't cover that may need to be considered.

@hekota hekota requested a review from bob80905 September 17, 2024 17:19
+ couple of local variables names shortened for easier debugging when manually invoking dump() in debugger
@hekota hekota merged commit f212826 into llvm:main Sep 18, 2024
8 checks passed
Comment on lines +879 to +880
assert(isa<VarDecl>(D) && "D is expected to be VarDecl or HLSLBufferDecl");
VarDecl *VD = cast<VarDecl>(D);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FWIW cast<> will assert if the cast fails. It's debatable whether or not the slight amount of extra information from the specific message is worth its own assert, so this doesn't necessarily need to be changed, but I figured I'd point that out.

tmsri pushed a commit to tmsri/llvm-project that referenced this pull request Sep 19, 2024
When diagnosing register bindings we just need to make sure there is a
resource that matches the provided register type. We can emit the
diagnostics right away instead of collecting flags in the
RegisterBindingFlags struct. That also enables early exit when scanning
user defined types because we can return as soon as we find a matching
resource for the given register type.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
clang:frontend Language frontend issues, e.g. anything involving "Sema" clang Clang issues not falling into any other category HLSL HLSL Language Support
Projects
Archived in project
Development

Successfully merging this pull request may close these issues.

7 participants