Skip to content

[HLSL][NFC] Remove RegisterBindingFlags struct #108924

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Sep 18, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
308 changes: 119 additions & 189 deletions clang/lib/Sema/SemaHLSL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,48 @@
#include <utility>

using namespace clang;
using llvm::dxil::ResourceClass;

enum class RegisterType { SRV, UAV, CBuffer, Sampler, C, I, Invalid };

static RegisterType getRegisterType(ResourceClass RC) {
switch (RC) {
case ResourceClass::SRV:
return RegisterType::SRV;
case ResourceClass::UAV:
return RegisterType::UAV;
case ResourceClass::CBuffer:
return RegisterType::CBuffer;
case ResourceClass::Sampler:
return RegisterType::Sampler;
}
llvm_unreachable("unexpected ResourceClass value");
}

static RegisterType getRegisterType(StringRef Slot) {
switch (Slot[0]) {
case 't':
case 'T':
return RegisterType::SRV;
case 'u':
case 'U':
return RegisterType::UAV;
case 'b':
case 'B':
return RegisterType::CBuffer;
case 's':
case 'S':
return RegisterType::Sampler;
case 'c':
case 'C':
return RegisterType::C;
case 'i':
case 'I':
return RegisterType::I;
default:
return RegisterType::Invalid;
}
}

SemaHLSL::SemaHLSL(Sema &S) : SemaBase(S) {}

Expand Down Expand Up @@ -586,8 +628,7 @@ bool clang::CreateHLSLAttributedResourceType(
LocEnd = A->getRange().getEnd();
switch (A->getKind()) {
case attr::HLSLResourceClass: {
llvm::dxil::ResourceClass RC =
cast<HLSLResourceClassAttr>(A)->getResourceClass();
ResourceClass RC = cast<HLSLResourceClassAttr>(A)->getResourceClass();
if (HasResourceClass) {
S.Diag(A->getLocation(), ResAttrs.ResourceClass == RC
? diag::warn_duplicate_attribute_exact
Expand Down Expand Up @@ -672,7 +713,7 @@ bool SemaHLSL::handleResourceTypeAttr(const ParsedAttr &AL) {
SourceLocation ArgLoc = Loc->Loc;

// Validate resource class value
llvm::dxil::ResourceClass RC;
ResourceClass RC;
if (!HLSLResourceClassAttr::ConvertStrToResourceClass(Identifier, RC)) {
Diag(ArgLoc, diag::warn_attribute_type_not_supported)
<< "ResourceClass" << Identifier;
Expand Down Expand Up @@ -750,28 +791,6 @@ SemaHLSL::TakeLocForHLSLAttribute(const HLSLAttributedResourceType *RT) {
return LocInfo;
}

struct RegisterBindingFlags {
bool Resource = false;
bool UDT = false;
bool Other = false;
bool Basic = false;

bool SRV = false;
bool UAV = false;
bool CBV = false;
bool Sampler = false;

bool ContainsNumeric = false;
bool DefaultGlobals = false;

// used only when Resource == true
std::optional<llvm::dxil::ResourceClass> ResourceClass;
};

static bool isDeclaredWithinCOrTBuffer(const Decl *TheDecl) {
return TheDecl && isa<HLSLBufferDecl>(TheDecl->getDeclContext());
}

// get the record decl from a var decl that we expect
// represents a resource
static CXXRecordDecl *getRecordDeclFromVarDecl(VarDecl *VD) {
Expand All @@ -786,24 +805,6 @@ static CXXRecordDecl *getRecordDeclFromVarDecl(VarDecl *VD) {
return TheRecordDecl;
}

static void updateResourceClassFlagsFromDeclResourceClass(
RegisterBindingFlags &Flags, llvm::hlsl::ResourceClass DeclResourceClass) {
switch (DeclResourceClass) {
case llvm::hlsl::ResourceClass::SRV:
Flags.SRV = true;
break;
case llvm::hlsl::ResourceClass::UAV:
Flags.UAV = true;
break;
case llvm::hlsl::ResourceClass::CBuffer:
Flags.CBV = true;
break;
case llvm::hlsl::ResourceClass::Sampler:
Flags.Sampler = true;
break;
}
}

const HLSLAttributedResourceType *
findAttributedResourceTypeOnField(VarDecl *VD) {
assert(VD != nullptr && "expected VarDecl");
Expand All @@ -817,8 +818,10 @@ findAttributedResourceTypeOnField(VarDecl *VD) {
return nullptr;
}

static void updateResourceClassFlagsFromRecordType(RegisterBindingFlags &Flags,
const RecordType *RT) {
// Iterate over RecordType fields and return true if any of them matched the
// register type
static bool ContainsResourceForRegisterType(Sema &S, const RecordType *RT,
RegisterType RegType) {
llvm::SmallVector<const Type *> TypesToScan;
TypesToScan.emplace_back(RT);

Expand All @@ -827,8 +830,8 @@ static void updateResourceClassFlagsFromRecordType(RegisterBindingFlags &Flags,
while (T->isArrayType())
T = T->getArrayElementTypeNoTypeQual();
if (T->isIntegralOrEnumerationType() || T->isFloatingType()) {
Flags.ContainsNumeric = true;
continue;
if (RegType == RegisterType::C)
return true;
}
const RecordType *RT = T->getAs<RecordType>();
if (!RT)
Expand All @@ -839,100 +842,84 @@ static void updateResourceClassFlagsFromRecordType(RegisterBindingFlags &Flags,
const Type *FieldTy = FD->getType().getTypePtr();
if (const HLSLAttributedResourceType *AttrResType =
dyn_cast<HLSLAttributedResourceType>(FieldTy)) {
updateResourceClassFlagsFromDeclResourceClass(
Flags, AttrResType->getAttrs().ResourceClass);
continue;
ResourceClass RC = AttrResType->getAttrs().ResourceClass;
if (getRegisterType(RC) == RegType)
return true;
} else {
TypesToScan.emplace_back(FD->getType().getTypePtr());
}
TypesToScan.emplace_back(FD->getType().getTypePtr());
}
}
return false;
}

static RegisterBindingFlags HLSLFillRegisterBindingFlags(Sema &S,
Decl *TheDecl) {
RegisterBindingFlags Flags;
static void CheckContainsResourceForRegisterType(Sema &S,
SourceLocation &ArgLoc,
Decl *D, RegisterType RegType,
bool SpecifiedSpace) {
int RegTypeNum = static_cast<int>(RegType);

// check if the decl type is groupshared
if (TheDecl->hasAttr<HLSLGroupSharedAddressSpaceAttr>()) {
Flags.Other = true;
return Flags;
if (D->hasAttr<HLSLGroupSharedAddressSpaceAttr>()) {
S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
return;
}

// Cbuffers and Tbuffers are HLSLBufferDecl types
if (HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(TheDecl)) {
Flags.Resource = true;
Flags.ResourceClass = CBufferOrTBuffer->isCBuffer()
? llvm::dxil::ResourceClass::CBuffer
: llvm::dxil::ResourceClass::SRV;
if (HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(D)) {
ResourceClass RC = CBufferOrTBuffer->isCBuffer() ? ResourceClass::CBuffer
: ResourceClass::SRV;
if (RegType != getRegisterType(RC))
S.Diag(D->getLocation(), diag::err_hlsl_binding_type_mismatch)
<< RegTypeNum;
return;
}

// Samplers, UAVs, and SRVs are VarDecl types
else if (VarDecl *TheVarDecl = dyn_cast<VarDecl>(TheDecl)) {
if (const HLSLAttributedResourceType *AttrResType =
findAttributedResourceTypeOnField(TheVarDecl)) {
Flags.Resource = true;
Flags.ResourceClass = AttrResType->getAttrs().ResourceClass;
} else {
const clang::Type *TheBaseType = TheVarDecl->getType().getTypePtr();
while (TheBaseType->isArrayType())
TheBaseType = TheBaseType->getArrayElementTypeNoTypeQual();

if (TheBaseType->isArithmeticType()) {
Flags.Basic = true;
if (!isDeclaredWithinCOrTBuffer(TheDecl) &&
(TheBaseType->isIntegralType(S.getASTContext()) ||
TheBaseType->isFloatingType()))
Flags.DefaultGlobals = true;
} else if (TheBaseType->isRecordType()) {
Flags.UDT = true;
const RecordType *TheRecordTy = TheBaseType->getAs<RecordType>();
updateResourceClassFlagsFromRecordType(Flags, TheRecordTy);
} else
Flags.Other = true;
}
} else {
llvm_unreachable("expected be VarDecl or HLSLBufferDecl");
assert(isa<VarDecl>(D) && "D is expected to be VarDecl or HLSLBufferDecl");
VarDecl *VD = cast<VarDecl>(D);
Comment on lines +879 to +880
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FWIW cast<> will assert if the cast fails. It's debatable whether or not the slight amount of extra information from the specific message is worth its own assert, so this doesn't necessarily need to be changed, but I figured I'd point that out.


// Resource
if (const HLSLAttributedResourceType *AttrResType =
findAttributedResourceTypeOnField(VD)) {
if (RegType != getRegisterType(AttrResType->getAttrs().ResourceClass))
S.Diag(D->getLocation(), diag::err_hlsl_binding_type_mismatch)
<< RegTypeNum;
return;
}
return Flags;
}

enum class RegisterType { SRV, UAV, CBuffer, Sampler, C, I, Invalid };
const clang::Type *Ty = VD->getType().getTypePtr();
while (Ty->isArrayType())
Ty = Ty->getArrayElementTypeNoTypeQual();

static RegisterType getRegisterType(llvm::dxil::ResourceClass RC) {
switch (RC) {
case llvm::dxil::ResourceClass::SRV:
return RegisterType::SRV;
case llvm::dxil::ResourceClass::UAV:
return RegisterType::UAV;
case llvm::dxil::ResourceClass::CBuffer:
return RegisterType::CBuffer;
case llvm::dxil::ResourceClass::Sampler:
return RegisterType::Sampler;
}
llvm_unreachable("unexpected ResourceClass value");
}
// Basic types
if (Ty->isArithmeticType()) {
bool DeclaredInCOrTBuffer = isa<HLSLBufferDecl>(D->getDeclContext());
if (SpecifiedSpace && !DeclaredInCOrTBuffer)
S.Diag(ArgLoc, diag::err_hlsl_space_on_global_constant);

static RegisterType getRegisterType(StringRef Slot) {
switch (Slot[0]) {
case 't':
case 'T':
return RegisterType::SRV;
case 'u':
case 'U':
return RegisterType::UAV;
case 'b':
case 'B':
return RegisterType::CBuffer;
case 's':
case 'S':
return RegisterType::Sampler;
case 'c':
case 'C':
return RegisterType::C;
case 'i':
case 'I':
return RegisterType::I;
default:
return RegisterType::Invalid;
if (!DeclaredInCOrTBuffer &&
(Ty->isIntegralType(S.getASTContext()) || Ty->isFloatingType())) {
// Default Globals
if (RegType == RegisterType::CBuffer)
S.Diag(ArgLoc, diag::warn_hlsl_deprecated_register_type_b);
else if (RegType != RegisterType::C)
S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
} else {
if (RegType == RegisterType::C)
S.Diag(ArgLoc, diag::warn_hlsl_register_type_c_packoffset);
else
S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
}
} else if (Ty->isRecordType()) {
// Class/struct types - walk the declaration and check each field and
// subclass
if (!ContainsResourceForRegisterType(S, Ty->getAs<RecordType>(), RegType))
S.Diag(D->getLocation(), diag::warn_hlsl_user_defined_type_missing_member)
<< RegTypeNum;
} else {
// Anything else is an error
S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
}
}

Expand Down Expand Up @@ -969,76 +956,19 @@ static void ValidateMultipleRegisterAnnotations(Sema &S, Decl *TheDecl,
}

static void DiagnoseHLSLRegisterAttribute(Sema &S, SourceLocation &ArgLoc,
Decl *TheDecl, RegisterType RegType,
const bool SpecifiedSpace) {
Decl *D, RegisterType RegType,
bool SpecifiedSpace) {

// exactly one of these two types should be set
assert(((isa<VarDecl>(TheDecl) && !isa<HLSLBufferDecl>(TheDecl)) ||
(!isa<VarDecl>(TheDecl) && isa<HLSLBufferDecl>(TheDecl))) &&
assert(((isa<VarDecl>(D) && !isa<HLSLBufferDecl>(D)) ||
(!isa<VarDecl>(D) && isa<HLSLBufferDecl>(D))) &&
"expecting VarDecl or HLSLBufferDecl");

RegisterBindingFlags Flags = HLSLFillRegisterBindingFlags(S, TheDecl);
assert((int)Flags.Other + (int)Flags.Resource + (int)Flags.Basic +
(int)Flags.UDT ==
1 &&
"only one resource analysis result should be expected");

int RegTypeNum = static_cast<int>(RegType);

// first, if "other" is set, emit an error
if (Flags.Other) {
S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
return;
}
// check if the declaration contains resource matching the register type
CheckContainsResourceForRegisterType(S, ArgLoc, D, RegType, SpecifiedSpace);

// next, if multiple register annotations exist, check that none conflict.
ValidateMultipleRegisterAnnotations(S, TheDecl, RegType);

// next, if resource is set, make sure the register type in the register
// annotation is compatible with the variable's resource type.
if (Flags.Resource) {
RegisterType ExpRegType = getRegisterType(Flags.ResourceClass.value());
if (RegType != ExpRegType) {
S.Diag(TheDecl->getLocation(), diag::err_hlsl_binding_type_mismatch)
<< RegTypeNum;
}

return;
}

// next, handle diagnostics for when the "basic" flag is set
if (Flags.Basic) {
if (SpecifiedSpace && !isDeclaredWithinCOrTBuffer(TheDecl))
S.Diag(ArgLoc, diag::err_hlsl_space_on_global_constant);

if (Flags.DefaultGlobals) {
if (RegType == RegisterType::CBuffer)
S.Diag(ArgLoc, diag::warn_hlsl_deprecated_register_type_b);
else if (RegType != RegisterType::C)
S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;
return;
}

if (RegType == RegisterType::C)
S.Diag(ArgLoc, diag::warn_hlsl_register_type_c_packoffset);
else
S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum;

return;
}

// finally, we handle the udt case
if (Flags.UDT) {
const bool ExpectedRegisterTypesForUDT[] = {
Flags.SRV, Flags.UAV, Flags.CBV, Flags.Sampler, Flags.ContainsNumeric};
assert((size_t)RegTypeNum < std::size(ExpectedRegisterTypesForUDT) &&
"regType has unexpected value");

if (!ExpectedRegisterTypesForUDT[RegTypeNum])
S.Diag(TheDecl->getLocation(),
diag::warn_hlsl_user_defined_type_missing_member)
<< RegTypeNum;
}
ValidateMultipleRegisterAnnotations(S, D, RegType);
}

void SemaHLSL::handleResourceBindingAttr(Decl *TheDecl, const ParsedAttr &AL) {
Expand Down
Loading