Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[clang][NFC] Move more functions to SemaHLSL #88354

Merged
merged 1 commit into from
Apr 12, 2024

Conversation

Endilll
Copy link
Contributor

@Endilll Endilll commented Apr 11, 2024

A follow-up to #87912. I'm moving more HLSL-related functions from Sema to SemaHLSL. I'm also dropping HLSL from their names in the process.

@Endilll Endilll added the clang:frontend Language frontend issues, e.g. anything involving "Sema" label Apr 11, 2024
@llvmbot llvmbot added clang Clang issues not falling into any other category HLSL HLSL Language Support labels Apr 11, 2024
@llvmbot
Copy link
Collaborator

llvmbot commented Apr 11, 2024

@llvm/pr-subscribers-hlsl

Author: Vlad Serebrennikov (Endilll)

Changes

A follow-up to #87912. I'm moving more HLSL-related functions from Sema to SemaHLSL. I'm also dropping HLSL from their names in the process.


Patch is 24.36 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/88354.diff

6 Files Affected:

  • (modified) clang/include/clang/Sema/Sema.h (-15)
  • (modified) clang/include/clang/Sema/SemaHLSL.h (+23-4)
  • (modified) clang/lib/Parse/ParseHLSL.cpp (+5-5)
  • (modified) clang/lib/Sema/SemaDecl.cpp (+6-124)
  • (modified) clang/lib/Sema/SemaDeclAttr.cpp (+4-50)
  • (modified) clang/lib/Sema/SemaHLSL.cpp (+180-6)
diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
index e3e255a0dd76f8..e904cd3ad13fb7 100644
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -2940,13 +2940,6 @@ class Sema final : public SemaBase {
                                       QualType NewT, QualType OldT);
   void CheckMain(FunctionDecl *FD, const DeclSpec &D);
   void CheckMSVCRTEntryPoint(FunctionDecl *FD);
-  void ActOnHLSLTopLevelFunction(FunctionDecl *FD);
-  void CheckHLSLEntryPoint(FunctionDecl *FD);
-  void CheckHLSLSemanticAnnotation(FunctionDecl *EntryPoint, const Decl *Param,
-                                   const HLSLAnnotationAttr *AnnotationAttr);
-  void DiagnoseHLSLAttrStageMismatch(
-      const Attr *A, HLSLShaderAttr::ShaderType Stage,
-      std::initializer_list<HLSLShaderAttr::ShaderType> AllowedStages);
   Attr *getImplicitCodeSegOrSectionAttrForFunction(const FunctionDecl *FD,
                                                    bool IsDefinition);
   void CheckFunctionOrTemplateParamDeclarator(Scope *S, Declarator &D);
@@ -3707,14 +3700,6 @@ class Sema final : public SemaBase {
                           StringRef UuidAsWritten, MSGuidDecl *GuidDecl);
 
   BTFDeclTagAttr *mergeBTFDeclTagAttr(Decl *D, const BTFDeclTagAttr &AL);
-  HLSLNumThreadsAttr *mergeHLSLNumThreadsAttr(Decl *D,
-                                              const AttributeCommonInfo &AL,
-                                              int X, int Y, int Z);
-  HLSLShaderAttr *mergeHLSLShaderAttr(Decl *D, const AttributeCommonInfo &AL,
-                                      HLSLShaderAttr::ShaderType ShaderType);
-  HLSLParamModifierAttr *
-  mergeHLSLParamModifierAttr(Decl *D, const AttributeCommonInfo &AL,
-                             HLSLParamModifierAttr::Spelling Spelling);
 
   WebAssemblyImportNameAttr *
   mergeImportNameAttr(Decl *D, const WebAssemblyImportNameAttr &AL);
diff --git a/clang/include/clang/Sema/SemaHLSL.h b/clang/include/clang/Sema/SemaHLSL.h
index acc675963c23a5..34acaf19517f2a 100644
--- a/clang/include/clang/Sema/SemaHLSL.h
+++ b/clang/include/clang/Sema/SemaHLSL.h
@@ -13,12 +13,16 @@
 #ifndef LLVM_CLANG_SEMA_SEMAHLSL_H
 #define LLVM_CLANG_SEMA_SEMAHLSL_H
 
+#include "clang/AST/Attr.h"
+#include "clang/AST/Decl.h"
 #include "clang/AST/DeclBase.h"
 #include "clang/AST/Expr.h"
+#include "clang/Basic/AttributeCommonInfo.h"
 #include "clang/Basic/IdentifierTable.h"
 #include "clang/Basic/SourceLocation.h"
 #include "clang/Sema/Scope.h"
 #include "clang/Sema/SemaBase.h"
+#include <initializer_list>
 
 namespace clang {
 
@@ -26,10 +30,25 @@ class SemaHLSL : public SemaBase {
 public:
   SemaHLSL(Sema &S);
 
-  Decl *ActOnStartHLSLBuffer(Scope *BufferScope, bool CBuffer,
-                             SourceLocation KwLoc, IdentifierInfo *Ident,
-                             SourceLocation IdentLoc, SourceLocation LBrace);
-  void ActOnFinishHLSLBuffer(Decl *Dcl, SourceLocation RBrace);
+  Decl *ActOnStartBuffer(Scope *BufferScope, bool CBuffer, SourceLocation KwLoc,
+                         IdentifierInfo *Ident, SourceLocation IdentLoc,
+                         SourceLocation LBrace);
+  void ActOnFinishBuffer(Decl *Dcl, SourceLocation RBrace);
+  HLSLNumThreadsAttr *mergeNumThreadsAttr(Decl *D,
+                                          const AttributeCommonInfo &AL, int X,
+                                          int Y, int Z);
+  HLSLShaderAttr *mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL,
+                                  HLSLShaderAttr::ShaderType ShaderType);
+  HLSLParamModifierAttr *
+  mergeParamModifierAttr(Decl *D, const AttributeCommonInfo &AL,
+                         HLSLParamModifierAttr::Spelling Spelling);
+  void ActOnTopLevelFunction(FunctionDecl *FD);
+  void CheckEntryPoint(FunctionDecl *FD);
+  void CheckSemanticAnnotation(FunctionDecl *EntryPoint, const Decl *Param,
+                               const HLSLAnnotationAttr *AnnotationAttr);
+  void DiagnoseAttrStageMismatch(
+      const Attr *A, HLSLShaderAttr::ShaderType Stage,
+      std::initializer_list<HLSLShaderAttr::ShaderType> AllowedStages);
 };
 
 } // namespace clang
diff --git a/clang/lib/Parse/ParseHLSL.cpp b/clang/lib/Parse/ParseHLSL.cpp
index 5afc958600fa55..d97985d42369ad 100644
--- a/clang/lib/Parse/ParseHLSL.cpp
+++ b/clang/lib/Parse/ParseHLSL.cpp
@@ -72,9 +72,9 @@ Decl *Parser::ParseHLSLBuffer(SourceLocation &DeclEnd) {
     return nullptr;
   }
 
-  Decl *D = Actions.HLSL().ActOnStartHLSLBuffer(
-      getCurScope(), IsCBuffer, BufferLoc, Identifier, IdentifierLoc,
-      T.getOpenLocation());
+  Decl *D = Actions.HLSL().ActOnStartBuffer(getCurScope(), IsCBuffer, BufferLoc,
+                                            Identifier, IdentifierLoc,
+                                            T.getOpenLocation());
 
   while (Tok.isNot(tok::r_brace) && Tok.isNot(tok::eof)) {
     // FIXME: support attribute on constants inside cbuffer/tbuffer.
@@ -88,7 +88,7 @@ Decl *Parser::ParseHLSLBuffer(SourceLocation &DeclEnd) {
       T.skipToEnd();
       DeclEnd = T.getCloseLocation();
       BufferScope.Exit();
-      Actions.HLSL().ActOnFinishHLSLBuffer(D, DeclEnd);
+      Actions.HLSL().ActOnFinishBuffer(D, DeclEnd);
       return nullptr;
     }
   }
@@ -96,7 +96,7 @@ Decl *Parser::ParseHLSLBuffer(SourceLocation &DeclEnd) {
   T.consumeClose();
   DeclEnd = T.getCloseLocation();
   BufferScope.Exit();
-  Actions.HLSL().ActOnFinishHLSLBuffer(D, DeclEnd);
+  Actions.HLSL().ActOnFinishBuffer(D, DeclEnd);
 
   Actions.ProcessDeclAttributeList(Actions.CurScope, D, Attrs);
   return D;
diff --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp
index 8472aaeb6bad97..3beb4fb1f8c733 100644
--- a/clang/lib/Sema/SemaDecl.cpp
+++ b/clang/lib/Sema/SemaDecl.cpp
@@ -45,6 +45,7 @@
 #include "clang/Sema/ParsedTemplate.h"
 #include "clang/Sema/Scope.h"
 #include "clang/Sema/ScopeInfo.h"
+#include "clang/Sema/SemaHLSL.h"
 #include "clang/Sema/SemaInternal.h"
 #include "clang/Sema/Template.h"
 #include "llvm/ADT/SmallString.h"
@@ -2972,10 +2973,10 @@ static bool mergeDeclAttribute(Sema &S, NamedDecl *D,
   else if (const auto *BTFA = dyn_cast<BTFDeclTagAttr>(Attr))
     NewAttr = S.mergeBTFDeclTagAttr(D, *BTFA);
   else if (const auto *NT = dyn_cast<HLSLNumThreadsAttr>(Attr))
-    NewAttr =
-        S.mergeHLSLNumThreadsAttr(D, *NT, NT->getX(), NT->getY(), NT->getZ());
+    NewAttr = S.HLSL().mergeNumThreadsAttr(D, *NT, NT->getX(), NT->getY(),
+                                           NT->getZ());
   else if (const auto *SA = dyn_cast<HLSLShaderAttr>(Attr))
-    NewAttr = S.mergeHLSLShaderAttr(D, *SA, SA->getType());
+    NewAttr = S.HLSL().mergeShaderAttr(D, *SA, SA->getType());
   else if (isa<SuppressAttr>(Attr))
     // Do nothing. Each redeclaration should be suppressed separately.
     NewAttr = nullptr;
@@ -10809,10 +10810,10 @@ Sema::ActOnFunctionDeclarator(Scope *S, Declarator &D, DeclContext *DC,
   if (getLangOpts().HLSL && D.isFunctionDefinition()) {
     // Any top level function could potentially be specified as an entry.
     if (!NewFD->isInvalidDecl() && S->getDepth() == 0 && Name.isIdentifier())
-      ActOnHLSLTopLevelFunction(NewFD);
+      HLSL().ActOnTopLevelFunction(NewFD);
 
     if (NewFD->hasAttr<HLSLShaderAttr>())
-      CheckHLSLEntryPoint(NewFD);
+      HLSL().CheckEntryPoint(NewFD);
   }
 
   // If this is the first declaration of a library builtin function, add
@@ -12660,125 +12661,6 @@ void Sema::CheckMSVCRTEntryPoint(FunctionDecl *FD) {
   }
 }
 
-void Sema::ActOnHLSLTopLevelFunction(FunctionDecl *FD) {
-  auto &TargetInfo = getASTContext().getTargetInfo();
-
-  if (FD->getName() != TargetInfo.getTargetOpts().HLSLEntry)
-    return;
-
-  StringRef Env = TargetInfo.getTriple().getEnvironmentName();
-  HLSLShaderAttr::ShaderType ShaderType;
-  if (HLSLShaderAttr::ConvertStrToShaderType(Env, ShaderType)) {
-    if (const auto *Shader = FD->getAttr<HLSLShaderAttr>()) {
-      // The entry point is already annotated - check that it matches the
-      // triple.
-      if (Shader->getType() != ShaderType) {
-        Diag(Shader->getLocation(), diag::err_hlsl_entry_shader_attr_mismatch)
-            << Shader;
-        FD->setInvalidDecl();
-      }
-    } else {
-      // Implicitly add the shader attribute if the entry function isn't
-      // explicitly annotated.
-      FD->addAttr(HLSLShaderAttr::CreateImplicit(Context, ShaderType,
-                                                 FD->getBeginLoc()));
-    }
-  } else {
-    switch (TargetInfo.getTriple().getEnvironment()) {
-    case llvm::Triple::UnknownEnvironment:
-    case llvm::Triple::Library:
-      break;
-    default:
-      llvm_unreachable("Unhandled environment in triple");
-    }
-  }
-}
-
-void Sema::CheckHLSLEntryPoint(FunctionDecl *FD) {
-  const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>();
-  assert(ShaderAttr && "Entry point has no shader attribute");
-  HLSLShaderAttr::ShaderType ST = ShaderAttr->getType();
-
-  switch (ST) {
-  case HLSLShaderAttr::Pixel:
-  case HLSLShaderAttr::Vertex:
-  case HLSLShaderAttr::Geometry:
-  case HLSLShaderAttr::Hull:
-  case HLSLShaderAttr::Domain:
-  case HLSLShaderAttr::RayGeneration:
-  case HLSLShaderAttr::Intersection:
-  case HLSLShaderAttr::AnyHit:
-  case HLSLShaderAttr::ClosestHit:
-  case HLSLShaderAttr::Miss:
-  case HLSLShaderAttr::Callable:
-    if (const auto *NT = FD->getAttr<HLSLNumThreadsAttr>()) {
-      DiagnoseHLSLAttrStageMismatch(NT, ST,
-                                    {HLSLShaderAttr::Compute,
-                                     HLSLShaderAttr::Amplification,
-                                     HLSLShaderAttr::Mesh});
-      FD->setInvalidDecl();
-    }
-    break;
-
-  case HLSLShaderAttr::Compute:
-  case HLSLShaderAttr::Amplification:
-  case HLSLShaderAttr::Mesh:
-    if (!FD->hasAttr<HLSLNumThreadsAttr>()) {
-      Diag(FD->getLocation(), diag::err_hlsl_missing_numthreads)
-          << HLSLShaderAttr::ConvertShaderTypeToStr(ST);
-      FD->setInvalidDecl();
-    }
-    break;
-  }
-
-  for (ParmVarDecl *Param : FD->parameters()) {
-    if (const auto *AnnotationAttr = Param->getAttr<HLSLAnnotationAttr>()) {
-      CheckHLSLSemanticAnnotation(FD, Param, AnnotationAttr);
-    } else {
-      // FIXME: Handle struct parameters where annotations are on struct fields.
-      // See: https://github.com/llvm/llvm-project/issues/57875
-      Diag(FD->getLocation(), diag::err_hlsl_missing_semantic_annotation);
-      Diag(Param->getLocation(), diag::note_previous_decl) << Param;
-      FD->setInvalidDecl();
-    }
-  }
-  // FIXME: Verify return type semantic annotation.
-}
-
-void Sema::CheckHLSLSemanticAnnotation(
-    FunctionDecl *EntryPoint, const Decl *Param,
-    const HLSLAnnotationAttr *AnnotationAttr) {
-  auto *ShaderAttr = EntryPoint->getAttr<HLSLShaderAttr>();
-  assert(ShaderAttr && "Entry point has no shader attribute");
-  HLSLShaderAttr::ShaderType ST = ShaderAttr->getType();
-
-  switch (AnnotationAttr->getKind()) {
-  case attr::HLSLSV_DispatchThreadID:
-  case attr::HLSLSV_GroupIndex:
-    if (ST == HLSLShaderAttr::Compute)
-      return;
-    DiagnoseHLSLAttrStageMismatch(AnnotationAttr, ST,
-                                  {HLSLShaderAttr::Compute});
-    break;
-  default:
-    llvm_unreachable("Unknown HLSLAnnotationAttr");
-  }
-}
-
-void Sema::DiagnoseHLSLAttrStageMismatch(
-    const Attr *A, HLSLShaderAttr::ShaderType Stage,
-    std::initializer_list<HLSLShaderAttr::ShaderType> AllowedStages) {
-  SmallVector<StringRef, 8> StageStrings;
-  llvm::transform(AllowedStages, std::back_inserter(StageStrings),
-                  [](HLSLShaderAttr::ShaderType ST) {
-                    return StringRef(
-                        HLSLShaderAttr::ConvertShaderTypeToStr(ST));
-                  });
-  Diag(A->getLoc(), diag::err_hlsl_attr_unsupported_in_stage)
-      << A << HLSLShaderAttr::ConvertShaderTypeToStr(Stage)
-      << (AllowedStages.size() != 1) << join(StageStrings, ", ");
-}
-
 bool Sema::CheckForConstantInitializer(Expr *Init, QualType DclT) {
   // FIXME: Need strict checking.  In C89, we need to check for
   // any assignment, increment, decrement, function-calls, or
diff --git a/clang/lib/Sema/SemaDeclAttr.cpp b/clang/lib/Sema/SemaDeclAttr.cpp
index 8bce04640e748e..b91064e28e4153 100644
--- a/clang/lib/Sema/SemaDeclAttr.cpp
+++ b/clang/lib/Sema/SemaDeclAttr.cpp
@@ -39,6 +39,7 @@
 #include "clang/Sema/ParsedAttr.h"
 #include "clang/Sema/Scope.h"
 #include "clang/Sema/ScopeInfo.h"
+#include "clang/Sema/SemaHLSL.h"
 #include "clang/Sema/SemaInternal.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/StringExtras.h"
@@ -7238,24 +7239,11 @@ static void handleHLSLNumThreadsAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
     return;
   }
 
-  HLSLNumThreadsAttr *NewAttr = S.mergeHLSLNumThreadsAttr(D, AL, X, Y, Z);
+  HLSLNumThreadsAttr *NewAttr = S.HLSL().mergeNumThreadsAttr(D, AL, X, Y, Z);
   if (NewAttr)
     D->addAttr(NewAttr);
 }
 
-HLSLNumThreadsAttr *Sema::mergeHLSLNumThreadsAttr(Decl *D,
-                                                  const AttributeCommonInfo &AL,
-                                                  int X, int Y, int Z) {
-  if (HLSLNumThreadsAttr *NT = D->getAttr<HLSLNumThreadsAttr>()) {
-    if (NT->getX() != X || NT->getY() != Y || NT->getZ() != Z) {
-      Diag(NT->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
-      Diag(AL.getLoc(), diag::note_conflicting_attribute);
-    }
-    return nullptr;
-  }
-  return ::new (Context) HLSLNumThreadsAttr(Context, AL, X, Y, Z);
-}
-
 static bool isLegalTypeForHLSLSV_DispatchThreadID(QualType T) {
   if (!T->hasUnsignedIntegerRepresentation())
     return false;
@@ -7299,24 +7287,11 @@ static void handleHLSLShaderAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
 
   // FIXME: check function match the shader stage.
 
-  HLSLShaderAttr *NewAttr = S.mergeHLSLShaderAttr(D, AL, ShaderType);
+  HLSLShaderAttr *NewAttr = S.HLSL().mergeShaderAttr(D, AL, ShaderType);
   if (NewAttr)
     D->addAttr(NewAttr);
 }
 
-HLSLShaderAttr *
-Sema::mergeHLSLShaderAttr(Decl *D, const AttributeCommonInfo &AL,
-                          HLSLShaderAttr::ShaderType ShaderType) {
-  if (HLSLShaderAttr *NT = D->getAttr<HLSLShaderAttr>()) {
-    if (NT->getType() != ShaderType) {
-      Diag(NT->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
-      Diag(AL.getLoc(), diag::note_conflicting_attribute);
-    }
-    return nullptr;
-  }
-  return HLSLShaderAttr::Create(Context, ShaderType, AL);
-}
-
 static void handleHLSLResourceBindingAttr(Sema &S, Decl *D,
                                           const ParsedAttr &AL) {
   StringRef Space = "space0";
@@ -7391,34 +7366,13 @@ static void handleHLSLResourceBindingAttr(Sema &S, Decl *D,
 
 static void handleHLSLParamModifierAttr(Sema &S, Decl *D,
                                         const ParsedAttr &AL) {
-  HLSLParamModifierAttr *NewAttr = S.mergeHLSLParamModifierAttr(
+  HLSLParamModifierAttr *NewAttr = S.HLSL().mergeParamModifierAttr(
       D, AL,
       static_cast<HLSLParamModifierAttr::Spelling>(AL.getSemanticSpelling()));
   if (NewAttr)
     D->addAttr(NewAttr);
 }
 
-HLSLParamModifierAttr *
-Sema::mergeHLSLParamModifierAttr(Decl *D, const AttributeCommonInfo &AL,
-                                 HLSLParamModifierAttr::Spelling Spelling) {
-  // We can only merge an `in` attribute with an `out` attribute. All other
-  // combinations of duplicated attributes are ill-formed.
-  if (HLSLParamModifierAttr *PA = D->getAttr<HLSLParamModifierAttr>()) {
-    if ((PA->isIn() && Spelling == HLSLParamModifierAttr::Keyword_out) ||
-        (PA->isOut() && Spelling == HLSLParamModifierAttr::Keyword_in)) {
-      D->dropAttr<HLSLParamModifierAttr>();
-      SourceRange AdjustedRange = {PA->getLocation(), AL.getRange().getEnd()};
-      return HLSLParamModifierAttr::Create(
-          Context, /*MergedSpelling=*/true, AdjustedRange,
-          HLSLParamModifierAttr::Keyword_inout);
-    }
-    Diag(AL.getLoc(), diag::err_hlsl_duplicate_parameter_modifier) << AL;
-    Diag(PA->getLocation(), diag::note_conflicting_attribute);
-    return nullptr;
-  }
-  return HLSLParamModifierAttr::Create(Context, AL);
-}
-
 static void handleMSInheritanceAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
   if (!S.LangOpts.CPlusPlus) {
     S.Diag(AL.getLoc(), diag::err_attribute_not_supported_in_lang)
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 681849d6e6c8a2..bb9e37f18d370c 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -9,17 +9,25 @@
 //===----------------------------------------------------------------------===//
 
 #include "clang/Sema/SemaHLSL.h"
+#include "clang/Basic/DiagnosticSema.h"
+#include "clang/Basic/LLVM.h"
+#include "clang/Basic/TargetInfo.h"
 #include "clang/Sema/Sema.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/TargetParser/Triple.h"
+#include <iterator>
 
 using namespace clang;
 
 SemaHLSL::SemaHLSL(Sema &S) : SemaBase(S) {}
 
-Decl *SemaHLSL::ActOnStartHLSLBuffer(Scope *BufferScope, bool CBuffer,
-                                     SourceLocation KwLoc,
-                                     IdentifierInfo *Ident,
-                                     SourceLocation IdentLoc,
-                                     SourceLocation LBrace) {
+Decl *SemaHLSL::ActOnStartBuffer(Scope *BufferScope, bool CBuffer,
+                                 SourceLocation KwLoc, IdentifierInfo *Ident,
+                                 SourceLocation IdentLoc,
+                                 SourceLocation LBrace) {
   // For anonymous namespace, take the location of the left brace.
   DeclContext *LexicalParent = SemaRef.getCurLexicalContext();
   HLSLBufferDecl *Result = HLSLBufferDecl::Create(
@@ -31,8 +39,174 @@ Decl *SemaHLSL::ActOnStartHLSLBuffer(Scope *BufferScope, bool CBuffer,
   return Result;
 }
 
-void SemaHLSL::ActOnFinishHLSLBuffer(Decl *Dcl, SourceLocation RBrace) {
+void SemaHLSL::ActOnFinishBuffer(Decl *Dcl, SourceLocation RBrace) {
   auto *BufDecl = cast<HLSLBufferDecl>(Dcl);
   BufDecl->setRBraceLoc(RBrace);
   SemaRef.PopDeclContext();
 }
+
+HLSLNumThreadsAttr *SemaHLSL::mergeNumThreadsAttr(Decl *D,
+                                                  const AttributeCommonInfo &AL,
+                                                  int X, int Y, int Z) {
+  if (HLSLNumThreadsAttr *NT = D->getAttr<HLSLNumThreadsAttr>()) {
+    if (NT->getX() != X || NT->getY() != Y || NT->getZ() != Z) {
+      Diag(NT->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
+      Diag(AL.getLoc(), diag::note_conflicting_attribute);
+    }
+    return nullptr;
+  }
+  return ::new (getASTContext())
+      HLSLNumThreadsAttr(getASTContext(), AL, X, Y, Z);
+}
+
+HLSLShaderAttr *
+SemaHLSL::mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL,
+                          HLSLShaderAttr::ShaderType ShaderType) {
+  if (HLSLShaderAttr *NT = D->getAttr<HLSLShaderAttr>()) {
+    if (NT->getType() != ShaderType) {
+      Diag(NT->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
+      Diag(AL.getLoc(), diag::note_conflicting_attribute);
+    }
+    return nullptr;
+  }
+  return HLSLShaderAttr::Create(getASTContext(), ShaderType, AL);
+}
+
+HLSLParamModifierAttr *
+SemaHLSL::mergeParamModifierAttr(Decl *D, const AttributeCommonInfo &AL,
+                                 HLSLParamModifierAttr::Spelling Spelling) {
+  // We can only merge an `in` attribute with an `out` attribute. All other
+  // combinations of duplicated attributes are ill-formed.
+  if (HLSLParamModifierAttr *PA = D->getAttr<HLSLParamModifierAttr>()) {
+    if ((PA->isIn() && Spelling == HLSLParamModifierAttr::Keyword_out) ||
+        (PA->isOut() && Spelling == HLSLParamModifierAttr::Keyword_in)) {
+      D->dropAttr<HLSLParamModifierAttr>();
+      SourceRange AdjustedRange = {PA->getLocation(), AL.getRange().getEnd()};
+      return HLSLParamModifierAttr::Create(
+          getASTContext(), /*MergedSpelling=*/true, AdjustedRange,
+          HLSLParamModifierAttr::Keyword_inout);
+    }
+    Diag(AL.getLoc(), diag::err_hlsl_duplicate_parameter_modifier) << AL;
+    Diag(PA->getLocation(), diag::note_conflicting_attribute);
+    return nullptr;
+  }
+  return HLSLParamModifierAttr::Cr...
[truncated]

@llvmbot
Copy link
Collaborator

llvmbot commented Apr 11, 2024

@llvm/pr-subscribers-clang

Author: Vlad Serebrennikov (Endilll)

Changes

A follow-up to #87912. I'm moving more HLSL-related functions from Sema to SemaHLSL. I'm also dropping HLSL from their names in the process.


Patch is 24.36 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/88354.diff

6 Files Affected:

  • (modified) clang/include/clang/Sema/Sema.h (-15)
  • (modified) clang/include/clang/Sema/SemaHLSL.h (+23-4)
  • (modified) clang/lib/Parse/ParseHLSL.cpp (+5-5)
  • (modified) clang/lib/Sema/SemaDecl.cpp (+6-124)
  • (modified) clang/lib/Sema/SemaDeclAttr.cpp (+4-50)
  • (modified) clang/lib/Sema/SemaHLSL.cpp (+180-6)
diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
index e3e255a0dd76f8..e904cd3ad13fb7 100644
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -2940,13 +2940,6 @@ class Sema final : public SemaBase {
                                       QualType NewT, QualType OldT);
   void CheckMain(FunctionDecl *FD, const DeclSpec &D);
   void CheckMSVCRTEntryPoint(FunctionDecl *FD);
-  void ActOnHLSLTopLevelFunction(FunctionDecl *FD);
-  void CheckHLSLEntryPoint(FunctionDecl *FD);
-  void CheckHLSLSemanticAnnotation(FunctionDecl *EntryPoint, const Decl *Param,
-                                   const HLSLAnnotationAttr *AnnotationAttr);
-  void DiagnoseHLSLAttrStageMismatch(
-      const Attr *A, HLSLShaderAttr::ShaderType Stage,
-      std::initializer_list<HLSLShaderAttr::ShaderType> AllowedStages);
   Attr *getImplicitCodeSegOrSectionAttrForFunction(const FunctionDecl *FD,
                                                    bool IsDefinition);
   void CheckFunctionOrTemplateParamDeclarator(Scope *S, Declarator &D);
@@ -3707,14 +3700,6 @@ class Sema final : public SemaBase {
                           StringRef UuidAsWritten, MSGuidDecl *GuidDecl);
 
   BTFDeclTagAttr *mergeBTFDeclTagAttr(Decl *D, const BTFDeclTagAttr &AL);
-  HLSLNumThreadsAttr *mergeHLSLNumThreadsAttr(Decl *D,
-                                              const AttributeCommonInfo &AL,
-                                              int X, int Y, int Z);
-  HLSLShaderAttr *mergeHLSLShaderAttr(Decl *D, const AttributeCommonInfo &AL,
-                                      HLSLShaderAttr::ShaderType ShaderType);
-  HLSLParamModifierAttr *
-  mergeHLSLParamModifierAttr(Decl *D, const AttributeCommonInfo &AL,
-                             HLSLParamModifierAttr::Spelling Spelling);
 
   WebAssemblyImportNameAttr *
   mergeImportNameAttr(Decl *D, const WebAssemblyImportNameAttr &AL);
diff --git a/clang/include/clang/Sema/SemaHLSL.h b/clang/include/clang/Sema/SemaHLSL.h
index acc675963c23a5..34acaf19517f2a 100644
--- a/clang/include/clang/Sema/SemaHLSL.h
+++ b/clang/include/clang/Sema/SemaHLSL.h
@@ -13,12 +13,16 @@
 #ifndef LLVM_CLANG_SEMA_SEMAHLSL_H
 #define LLVM_CLANG_SEMA_SEMAHLSL_H
 
+#include "clang/AST/Attr.h"
+#include "clang/AST/Decl.h"
 #include "clang/AST/DeclBase.h"
 #include "clang/AST/Expr.h"
+#include "clang/Basic/AttributeCommonInfo.h"
 #include "clang/Basic/IdentifierTable.h"
 #include "clang/Basic/SourceLocation.h"
 #include "clang/Sema/Scope.h"
 #include "clang/Sema/SemaBase.h"
+#include <initializer_list>
 
 namespace clang {
 
@@ -26,10 +30,25 @@ class SemaHLSL : public SemaBase {
 public:
   SemaHLSL(Sema &S);
 
-  Decl *ActOnStartHLSLBuffer(Scope *BufferScope, bool CBuffer,
-                             SourceLocation KwLoc, IdentifierInfo *Ident,
-                             SourceLocation IdentLoc, SourceLocation LBrace);
-  void ActOnFinishHLSLBuffer(Decl *Dcl, SourceLocation RBrace);
+  Decl *ActOnStartBuffer(Scope *BufferScope, bool CBuffer, SourceLocation KwLoc,
+                         IdentifierInfo *Ident, SourceLocation IdentLoc,
+                         SourceLocation LBrace);
+  void ActOnFinishBuffer(Decl *Dcl, SourceLocation RBrace);
+  HLSLNumThreadsAttr *mergeNumThreadsAttr(Decl *D,
+                                          const AttributeCommonInfo &AL, int X,
+                                          int Y, int Z);
+  HLSLShaderAttr *mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL,
+                                  HLSLShaderAttr::ShaderType ShaderType);
+  HLSLParamModifierAttr *
+  mergeParamModifierAttr(Decl *D, const AttributeCommonInfo &AL,
+                         HLSLParamModifierAttr::Spelling Spelling);
+  void ActOnTopLevelFunction(FunctionDecl *FD);
+  void CheckEntryPoint(FunctionDecl *FD);
+  void CheckSemanticAnnotation(FunctionDecl *EntryPoint, const Decl *Param,
+                               const HLSLAnnotationAttr *AnnotationAttr);
+  void DiagnoseAttrStageMismatch(
+      const Attr *A, HLSLShaderAttr::ShaderType Stage,
+      std::initializer_list<HLSLShaderAttr::ShaderType> AllowedStages);
 };
 
 } // namespace clang
diff --git a/clang/lib/Parse/ParseHLSL.cpp b/clang/lib/Parse/ParseHLSL.cpp
index 5afc958600fa55..d97985d42369ad 100644
--- a/clang/lib/Parse/ParseHLSL.cpp
+++ b/clang/lib/Parse/ParseHLSL.cpp
@@ -72,9 +72,9 @@ Decl *Parser::ParseHLSLBuffer(SourceLocation &DeclEnd) {
     return nullptr;
   }
 
-  Decl *D = Actions.HLSL().ActOnStartHLSLBuffer(
-      getCurScope(), IsCBuffer, BufferLoc, Identifier, IdentifierLoc,
-      T.getOpenLocation());
+  Decl *D = Actions.HLSL().ActOnStartBuffer(getCurScope(), IsCBuffer, BufferLoc,
+                                            Identifier, IdentifierLoc,
+                                            T.getOpenLocation());
 
   while (Tok.isNot(tok::r_brace) && Tok.isNot(tok::eof)) {
     // FIXME: support attribute on constants inside cbuffer/tbuffer.
@@ -88,7 +88,7 @@ Decl *Parser::ParseHLSLBuffer(SourceLocation &DeclEnd) {
       T.skipToEnd();
       DeclEnd = T.getCloseLocation();
       BufferScope.Exit();
-      Actions.HLSL().ActOnFinishHLSLBuffer(D, DeclEnd);
+      Actions.HLSL().ActOnFinishBuffer(D, DeclEnd);
       return nullptr;
     }
   }
@@ -96,7 +96,7 @@ Decl *Parser::ParseHLSLBuffer(SourceLocation &DeclEnd) {
   T.consumeClose();
   DeclEnd = T.getCloseLocation();
   BufferScope.Exit();
-  Actions.HLSL().ActOnFinishHLSLBuffer(D, DeclEnd);
+  Actions.HLSL().ActOnFinishBuffer(D, DeclEnd);
 
   Actions.ProcessDeclAttributeList(Actions.CurScope, D, Attrs);
   return D;
diff --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp
index 8472aaeb6bad97..3beb4fb1f8c733 100644
--- a/clang/lib/Sema/SemaDecl.cpp
+++ b/clang/lib/Sema/SemaDecl.cpp
@@ -45,6 +45,7 @@
 #include "clang/Sema/ParsedTemplate.h"
 #include "clang/Sema/Scope.h"
 #include "clang/Sema/ScopeInfo.h"
+#include "clang/Sema/SemaHLSL.h"
 #include "clang/Sema/SemaInternal.h"
 #include "clang/Sema/Template.h"
 #include "llvm/ADT/SmallString.h"
@@ -2972,10 +2973,10 @@ static bool mergeDeclAttribute(Sema &S, NamedDecl *D,
   else if (const auto *BTFA = dyn_cast<BTFDeclTagAttr>(Attr))
     NewAttr = S.mergeBTFDeclTagAttr(D, *BTFA);
   else if (const auto *NT = dyn_cast<HLSLNumThreadsAttr>(Attr))
-    NewAttr =
-        S.mergeHLSLNumThreadsAttr(D, *NT, NT->getX(), NT->getY(), NT->getZ());
+    NewAttr = S.HLSL().mergeNumThreadsAttr(D, *NT, NT->getX(), NT->getY(),
+                                           NT->getZ());
   else if (const auto *SA = dyn_cast<HLSLShaderAttr>(Attr))
-    NewAttr = S.mergeHLSLShaderAttr(D, *SA, SA->getType());
+    NewAttr = S.HLSL().mergeShaderAttr(D, *SA, SA->getType());
   else if (isa<SuppressAttr>(Attr))
     // Do nothing. Each redeclaration should be suppressed separately.
     NewAttr = nullptr;
@@ -10809,10 +10810,10 @@ Sema::ActOnFunctionDeclarator(Scope *S, Declarator &D, DeclContext *DC,
   if (getLangOpts().HLSL && D.isFunctionDefinition()) {
     // Any top level function could potentially be specified as an entry.
     if (!NewFD->isInvalidDecl() && S->getDepth() == 0 && Name.isIdentifier())
-      ActOnHLSLTopLevelFunction(NewFD);
+      HLSL().ActOnTopLevelFunction(NewFD);
 
     if (NewFD->hasAttr<HLSLShaderAttr>())
-      CheckHLSLEntryPoint(NewFD);
+      HLSL().CheckEntryPoint(NewFD);
   }
 
   // If this is the first declaration of a library builtin function, add
@@ -12660,125 +12661,6 @@ void Sema::CheckMSVCRTEntryPoint(FunctionDecl *FD) {
   }
 }
 
-void Sema::ActOnHLSLTopLevelFunction(FunctionDecl *FD) {
-  auto &TargetInfo = getASTContext().getTargetInfo();
-
-  if (FD->getName() != TargetInfo.getTargetOpts().HLSLEntry)
-    return;
-
-  StringRef Env = TargetInfo.getTriple().getEnvironmentName();
-  HLSLShaderAttr::ShaderType ShaderType;
-  if (HLSLShaderAttr::ConvertStrToShaderType(Env, ShaderType)) {
-    if (const auto *Shader = FD->getAttr<HLSLShaderAttr>()) {
-      // The entry point is already annotated - check that it matches the
-      // triple.
-      if (Shader->getType() != ShaderType) {
-        Diag(Shader->getLocation(), diag::err_hlsl_entry_shader_attr_mismatch)
-            << Shader;
-        FD->setInvalidDecl();
-      }
-    } else {
-      // Implicitly add the shader attribute if the entry function isn't
-      // explicitly annotated.
-      FD->addAttr(HLSLShaderAttr::CreateImplicit(Context, ShaderType,
-                                                 FD->getBeginLoc()));
-    }
-  } else {
-    switch (TargetInfo.getTriple().getEnvironment()) {
-    case llvm::Triple::UnknownEnvironment:
-    case llvm::Triple::Library:
-      break;
-    default:
-      llvm_unreachable("Unhandled environment in triple");
-    }
-  }
-}
-
-void Sema::CheckHLSLEntryPoint(FunctionDecl *FD) {
-  const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>();
-  assert(ShaderAttr && "Entry point has no shader attribute");
-  HLSLShaderAttr::ShaderType ST = ShaderAttr->getType();
-
-  switch (ST) {
-  case HLSLShaderAttr::Pixel:
-  case HLSLShaderAttr::Vertex:
-  case HLSLShaderAttr::Geometry:
-  case HLSLShaderAttr::Hull:
-  case HLSLShaderAttr::Domain:
-  case HLSLShaderAttr::RayGeneration:
-  case HLSLShaderAttr::Intersection:
-  case HLSLShaderAttr::AnyHit:
-  case HLSLShaderAttr::ClosestHit:
-  case HLSLShaderAttr::Miss:
-  case HLSLShaderAttr::Callable:
-    if (const auto *NT = FD->getAttr<HLSLNumThreadsAttr>()) {
-      DiagnoseHLSLAttrStageMismatch(NT, ST,
-                                    {HLSLShaderAttr::Compute,
-                                     HLSLShaderAttr::Amplification,
-                                     HLSLShaderAttr::Mesh});
-      FD->setInvalidDecl();
-    }
-    break;
-
-  case HLSLShaderAttr::Compute:
-  case HLSLShaderAttr::Amplification:
-  case HLSLShaderAttr::Mesh:
-    if (!FD->hasAttr<HLSLNumThreadsAttr>()) {
-      Diag(FD->getLocation(), diag::err_hlsl_missing_numthreads)
-          << HLSLShaderAttr::ConvertShaderTypeToStr(ST);
-      FD->setInvalidDecl();
-    }
-    break;
-  }
-
-  for (ParmVarDecl *Param : FD->parameters()) {
-    if (const auto *AnnotationAttr = Param->getAttr<HLSLAnnotationAttr>()) {
-      CheckHLSLSemanticAnnotation(FD, Param, AnnotationAttr);
-    } else {
-      // FIXME: Handle struct parameters where annotations are on struct fields.
-      // See: https://github.com/llvm/llvm-project/issues/57875
-      Diag(FD->getLocation(), diag::err_hlsl_missing_semantic_annotation);
-      Diag(Param->getLocation(), diag::note_previous_decl) << Param;
-      FD->setInvalidDecl();
-    }
-  }
-  // FIXME: Verify return type semantic annotation.
-}
-
-void Sema::CheckHLSLSemanticAnnotation(
-    FunctionDecl *EntryPoint, const Decl *Param,
-    const HLSLAnnotationAttr *AnnotationAttr) {
-  auto *ShaderAttr = EntryPoint->getAttr<HLSLShaderAttr>();
-  assert(ShaderAttr && "Entry point has no shader attribute");
-  HLSLShaderAttr::ShaderType ST = ShaderAttr->getType();
-
-  switch (AnnotationAttr->getKind()) {
-  case attr::HLSLSV_DispatchThreadID:
-  case attr::HLSLSV_GroupIndex:
-    if (ST == HLSLShaderAttr::Compute)
-      return;
-    DiagnoseHLSLAttrStageMismatch(AnnotationAttr, ST,
-                                  {HLSLShaderAttr::Compute});
-    break;
-  default:
-    llvm_unreachable("Unknown HLSLAnnotationAttr");
-  }
-}
-
-void Sema::DiagnoseHLSLAttrStageMismatch(
-    const Attr *A, HLSLShaderAttr::ShaderType Stage,
-    std::initializer_list<HLSLShaderAttr::ShaderType> AllowedStages) {
-  SmallVector<StringRef, 8> StageStrings;
-  llvm::transform(AllowedStages, std::back_inserter(StageStrings),
-                  [](HLSLShaderAttr::ShaderType ST) {
-                    return StringRef(
-                        HLSLShaderAttr::ConvertShaderTypeToStr(ST));
-                  });
-  Diag(A->getLoc(), diag::err_hlsl_attr_unsupported_in_stage)
-      << A << HLSLShaderAttr::ConvertShaderTypeToStr(Stage)
-      << (AllowedStages.size() != 1) << join(StageStrings, ", ");
-}
-
 bool Sema::CheckForConstantInitializer(Expr *Init, QualType DclT) {
   // FIXME: Need strict checking.  In C89, we need to check for
   // any assignment, increment, decrement, function-calls, or
diff --git a/clang/lib/Sema/SemaDeclAttr.cpp b/clang/lib/Sema/SemaDeclAttr.cpp
index 8bce04640e748e..b91064e28e4153 100644
--- a/clang/lib/Sema/SemaDeclAttr.cpp
+++ b/clang/lib/Sema/SemaDeclAttr.cpp
@@ -39,6 +39,7 @@
 #include "clang/Sema/ParsedAttr.h"
 #include "clang/Sema/Scope.h"
 #include "clang/Sema/ScopeInfo.h"
+#include "clang/Sema/SemaHLSL.h"
 #include "clang/Sema/SemaInternal.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/StringExtras.h"
@@ -7238,24 +7239,11 @@ static void handleHLSLNumThreadsAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
     return;
   }
 
-  HLSLNumThreadsAttr *NewAttr = S.mergeHLSLNumThreadsAttr(D, AL, X, Y, Z);
+  HLSLNumThreadsAttr *NewAttr = S.HLSL().mergeNumThreadsAttr(D, AL, X, Y, Z);
   if (NewAttr)
     D->addAttr(NewAttr);
 }
 
-HLSLNumThreadsAttr *Sema::mergeHLSLNumThreadsAttr(Decl *D,
-                                                  const AttributeCommonInfo &AL,
-                                                  int X, int Y, int Z) {
-  if (HLSLNumThreadsAttr *NT = D->getAttr<HLSLNumThreadsAttr>()) {
-    if (NT->getX() != X || NT->getY() != Y || NT->getZ() != Z) {
-      Diag(NT->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
-      Diag(AL.getLoc(), diag::note_conflicting_attribute);
-    }
-    return nullptr;
-  }
-  return ::new (Context) HLSLNumThreadsAttr(Context, AL, X, Y, Z);
-}
-
 static bool isLegalTypeForHLSLSV_DispatchThreadID(QualType T) {
   if (!T->hasUnsignedIntegerRepresentation())
     return false;
@@ -7299,24 +7287,11 @@ static void handleHLSLShaderAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
 
   // FIXME: check function match the shader stage.
 
-  HLSLShaderAttr *NewAttr = S.mergeHLSLShaderAttr(D, AL, ShaderType);
+  HLSLShaderAttr *NewAttr = S.HLSL().mergeShaderAttr(D, AL, ShaderType);
   if (NewAttr)
     D->addAttr(NewAttr);
 }
 
-HLSLShaderAttr *
-Sema::mergeHLSLShaderAttr(Decl *D, const AttributeCommonInfo &AL,
-                          HLSLShaderAttr::ShaderType ShaderType) {
-  if (HLSLShaderAttr *NT = D->getAttr<HLSLShaderAttr>()) {
-    if (NT->getType() != ShaderType) {
-      Diag(NT->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
-      Diag(AL.getLoc(), diag::note_conflicting_attribute);
-    }
-    return nullptr;
-  }
-  return HLSLShaderAttr::Create(Context, ShaderType, AL);
-}
-
 static void handleHLSLResourceBindingAttr(Sema &S, Decl *D,
                                           const ParsedAttr &AL) {
   StringRef Space = "space0";
@@ -7391,34 +7366,13 @@ static void handleHLSLResourceBindingAttr(Sema &S, Decl *D,
 
 static void handleHLSLParamModifierAttr(Sema &S, Decl *D,
                                         const ParsedAttr &AL) {
-  HLSLParamModifierAttr *NewAttr = S.mergeHLSLParamModifierAttr(
+  HLSLParamModifierAttr *NewAttr = S.HLSL().mergeParamModifierAttr(
       D, AL,
       static_cast<HLSLParamModifierAttr::Spelling>(AL.getSemanticSpelling()));
   if (NewAttr)
     D->addAttr(NewAttr);
 }
 
-HLSLParamModifierAttr *
-Sema::mergeHLSLParamModifierAttr(Decl *D, const AttributeCommonInfo &AL,
-                                 HLSLParamModifierAttr::Spelling Spelling) {
-  // We can only merge an `in` attribute with an `out` attribute. All other
-  // combinations of duplicated attributes are ill-formed.
-  if (HLSLParamModifierAttr *PA = D->getAttr<HLSLParamModifierAttr>()) {
-    if ((PA->isIn() && Spelling == HLSLParamModifierAttr::Keyword_out) ||
-        (PA->isOut() && Spelling == HLSLParamModifierAttr::Keyword_in)) {
-      D->dropAttr<HLSLParamModifierAttr>();
-      SourceRange AdjustedRange = {PA->getLocation(), AL.getRange().getEnd()};
-      return HLSLParamModifierAttr::Create(
-          Context, /*MergedSpelling=*/true, AdjustedRange,
-          HLSLParamModifierAttr::Keyword_inout);
-    }
-    Diag(AL.getLoc(), diag::err_hlsl_duplicate_parameter_modifier) << AL;
-    Diag(PA->getLocation(), diag::note_conflicting_attribute);
-    return nullptr;
-  }
-  return HLSLParamModifierAttr::Create(Context, AL);
-}
-
 static void handleMSInheritanceAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
   if (!S.LangOpts.CPlusPlus) {
     S.Diag(AL.getLoc(), diag::err_attribute_not_supported_in_lang)
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 681849d6e6c8a2..bb9e37f18d370c 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -9,17 +9,25 @@
 //===----------------------------------------------------------------------===//
 
 #include "clang/Sema/SemaHLSL.h"
+#include "clang/Basic/DiagnosticSema.h"
+#include "clang/Basic/LLVM.h"
+#include "clang/Basic/TargetInfo.h"
 #include "clang/Sema/Sema.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/TargetParser/Triple.h"
+#include <iterator>
 
 using namespace clang;
 
 SemaHLSL::SemaHLSL(Sema &S) : SemaBase(S) {}
 
-Decl *SemaHLSL::ActOnStartHLSLBuffer(Scope *BufferScope, bool CBuffer,
-                                     SourceLocation KwLoc,
-                                     IdentifierInfo *Ident,
-                                     SourceLocation IdentLoc,
-                                     SourceLocation LBrace) {
+Decl *SemaHLSL::ActOnStartBuffer(Scope *BufferScope, bool CBuffer,
+                                 SourceLocation KwLoc, IdentifierInfo *Ident,
+                                 SourceLocation IdentLoc,
+                                 SourceLocation LBrace) {
   // For anonymous namespace, take the location of the left brace.
   DeclContext *LexicalParent = SemaRef.getCurLexicalContext();
   HLSLBufferDecl *Result = HLSLBufferDecl::Create(
@@ -31,8 +39,174 @@ Decl *SemaHLSL::ActOnStartHLSLBuffer(Scope *BufferScope, bool CBuffer,
   return Result;
 }
 
-void SemaHLSL::ActOnFinishHLSLBuffer(Decl *Dcl, SourceLocation RBrace) {
+void SemaHLSL::ActOnFinishBuffer(Decl *Dcl, SourceLocation RBrace) {
   auto *BufDecl = cast<HLSLBufferDecl>(Dcl);
   BufDecl->setRBraceLoc(RBrace);
   SemaRef.PopDeclContext();
 }
+
+HLSLNumThreadsAttr *SemaHLSL::mergeNumThreadsAttr(Decl *D,
+                                                  const AttributeCommonInfo &AL,
+                                                  int X, int Y, int Z) {
+  if (HLSLNumThreadsAttr *NT = D->getAttr<HLSLNumThreadsAttr>()) {
+    if (NT->getX() != X || NT->getY() != Y || NT->getZ() != Z) {
+      Diag(NT->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
+      Diag(AL.getLoc(), diag::note_conflicting_attribute);
+    }
+    return nullptr;
+  }
+  return ::new (getASTContext())
+      HLSLNumThreadsAttr(getASTContext(), AL, X, Y, Z);
+}
+
+HLSLShaderAttr *
+SemaHLSL::mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL,
+                          HLSLShaderAttr::ShaderType ShaderType) {
+  if (HLSLShaderAttr *NT = D->getAttr<HLSLShaderAttr>()) {
+    if (NT->getType() != ShaderType) {
+      Diag(NT->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
+      Diag(AL.getLoc(), diag::note_conflicting_attribute);
+    }
+    return nullptr;
+  }
+  return HLSLShaderAttr::Create(getASTContext(), ShaderType, AL);
+}
+
+HLSLParamModifierAttr *
+SemaHLSL::mergeParamModifierAttr(Decl *D, const AttributeCommonInfo &AL,
+                                 HLSLParamModifierAttr::Spelling Spelling) {
+  // We can only merge an `in` attribute with an `out` attribute. All other
+  // combinations of duplicated attributes are ill-formed.
+  if (HLSLParamModifierAttr *PA = D->getAttr<HLSLParamModifierAttr>()) {
+    if ((PA->isIn() && Spelling == HLSLParamModifierAttr::Keyword_out) ||
+        (PA->isOut() && Spelling == HLSLParamModifierAttr::Keyword_in)) {
+      D->dropAttr<HLSLParamModifierAttr>();
+      SourceRange AdjustedRange = {PA->getLocation(), AL.getRange().getEnd()};
+      return HLSLParamModifierAttr::Create(
+          getASTContext(), /*MergedSpelling=*/true, AdjustedRange,
+          HLSLParamModifierAttr::Keyword_inout);
+    }
+    Diag(AL.getLoc(), diag::err_hlsl_duplicate_parameter_modifier) << AL;
+    Diag(PA->getLocation(), diag::note_conflicting_attribute);
+    return nullptr;
+  }
+  return HLSLParamModifierAttr::Cr...
[truncated]

@Endilll Endilll merged commit b45c9c3 into llvm:main Apr 12, 2024
8 checks passed
@Endilll Endilll deleted the semahlsl-functions branch April 12, 2024 03:45
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
clang:frontend Language frontend issues, e.g. anything involving "Sema" clang Clang issues not falling into any other category HLSL HLSL Language Support
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants