Skip to content

Commit 9736c86

Browse files
committed
[HLSL] [SPIR-V] Add initial support for typed buffer counters
This is part 1 of implementing the typed buffer counters proposal: https://github.com/llvm/wg-hlsl/blob/main/proposals/0023-typed-buffer-counters.md This patch adds the initial plumbing for supporting counter variables associated with structured buffers for the SPIR-V backend. It introduces an `IsCounter` attribute to `HLSLAttributedResourceType` and threads it through the AST, type printing, and mangling. It also adds a `__counter_handle` member to the relevant buffer types in `HLSLBuiltinTypeDeclBuilder`.
1 parent dcd0a2e commit 9736c86

17 files changed

+247
-59
lines changed

clang/include/clang/AST/TypeBase.h

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6700,15 +6700,21 @@ class HLSLAttributedResourceType : public Type, public llvm::FoldingSetNode {
67006700
LLVM_PREFERRED_TYPE(bool)
67016701
uint8_t RawBuffer : 1;
67026702

6703+
LLVM_PREFERRED_TYPE(bool)
6704+
uint8_t IsCounter : 1;
6705+
67036706
Attributes(llvm::dxil::ResourceClass ResourceClass, bool IsROV = false,
6704-
bool RawBuffer = false)
6705-
: ResourceClass(ResourceClass), IsROV(IsROV), RawBuffer(RawBuffer) {}
6707+
bool RawBuffer = false, bool IsCounter = false)
6708+
: ResourceClass(ResourceClass), IsROV(IsROV), RawBuffer(RawBuffer),
6709+
IsCounter(IsCounter) {}
67066710

6707-
Attributes() : Attributes(llvm::dxil::ResourceClass::UAV, false, false) {}
6711+
Attributes()
6712+
: Attributes(llvm::dxil::ResourceClass::UAV, false, false, false) {}
67086713

67096714
friend bool operator==(const Attributes &LHS, const Attributes &RHS) {
6710-
return std::tie(LHS.ResourceClass, LHS.IsROV, LHS.RawBuffer) ==
6711-
std::tie(RHS.ResourceClass, RHS.IsROV, RHS.RawBuffer);
6715+
return std::tie(LHS.ResourceClass, LHS.IsROV, LHS.RawBuffer,
6716+
LHS.IsCounter) == std::tie(RHS.ResourceClass, RHS.IsROV,
6717+
RHS.RawBuffer, RHS.IsCounter);
67126718
}
67136719
friend bool operator!=(const Attributes &LHS, const Attributes &RHS) {
67146720
return !(LHS == RHS);
@@ -6749,6 +6755,7 @@ class HLSLAttributedResourceType : public Type, public llvm::FoldingSetNode {
67496755
ID.AddInteger(static_cast<uint32_t>(Attrs.ResourceClass));
67506756
ID.AddBoolean(Attrs.IsROV);
67516757
ID.AddBoolean(Attrs.RawBuffer);
6758+
ID.AddBoolean(Attrs.IsCounter);
67526759
}
67536760

67546761
static bool classof(const Type *T) {

clang/include/clang/AST/TypeProperties.td

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -662,14 +662,17 @@ let Class = HLSLAttributedResourceType in {
662662
def : Property<"rawBuffer", Bool> {
663663
let Read = [{ node->getAttrs().RawBuffer }];
664664
}
665+
def : Property<"isCounter", Bool> {
666+
let Read = [{ node->getAttrs().IsCounter }];
667+
}
665668
def : Property<"wrappedTy", QualType> {
666669
let Read = [{ node->getWrappedType() }];
667670
}
668671
def : Property<"containedTy", QualType> {
669672
let Read = [{ node->getContainedType() }];
670673
}
671674
def : Creator<[{
672-
HLSLAttributedResourceType::Attributes attrs(static_cast<llvm::dxil::ResourceClass>(resClass), isROV, rawBuffer);
675+
HLSLAttributedResourceType::Attributes attrs(static_cast<llvm::dxil::ResourceClass>(resClass), isROV, rawBuffer, isCounter);
673676
return ctx.getHLSLAttributedResourceType(wrappedTy, containedTy, attrs);
674677
}]>;
675678
}

clang/include/clang/Basic/Attr.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5059,6 +5059,12 @@ def HLSLRawBuffer : TypeAttr {
50595059
let Documentation = [InternalOnly];
50605060
}
50615061

5062+
def HLSLIsCounter : TypeAttr {
5063+
let Spellings = [CXX11<"hlsl", "is_counter">];
5064+
let LangOpts = [HLSL];
5065+
let Documentation = [InternalOnly];
5066+
}
5067+
50625068
def HLSLGroupSharedAddressSpace : TypeAttr {
50635069
let Spellings = [CustomKeyword<"groupshared">];
50645070
let Subjects = SubjectList<[Var]>;

clang/lib/AST/ItaniumMangle.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4624,6 +4624,8 @@ void CXXNameMangler::mangleType(const HLSLAttributedResourceType *T) {
46244624
Str += "_ROV";
46254625
if (Attrs.RawBuffer)
46264626
Str += "_Raw";
4627+
if (Attrs.IsCounter)
4628+
Str += "_Counter";
46274629
if (T->hasContainedType())
46284630
Str += "_CT";
46294631
mangleVendorQualifier(Str);

clang/lib/AST/TypePrinter.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2033,6 +2033,7 @@ void TypePrinter::printAttributedAfter(const AttributedType *T,
20332033
case attr::HLSLROV:
20342034
case attr::HLSLRawBuffer:
20352035
case attr::HLSLContainedType:
2036+
case attr::HLSLIsCounter:
20362037
llvm_unreachable("HLSL resource type attributes handled separately");
20372038

20382039
case attr::OpenCLPrivateAddressSpace:
@@ -2181,6 +2182,8 @@ void TypePrinter::printHLSLAttributedResourceAfter(
21812182
OS << " [[hlsl::is_rov]]";
21822183
if (Attrs.RawBuffer)
21832184
OS << " [[hlsl::raw_buffer]]";
2185+
if (Attrs.IsCounter)
2186+
OS << " [[hlsl::is_counter]]";
21842187

21852188
QualType ContainedTy = T->getContainedType();
21862189
if (!ContainedTy.isNull()) {

clang/lib/CodeGen/Targets/SPIR.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -486,6 +486,12 @@ llvm::Type *CommonSPIRTargetCodeGenInfo::getHLSLType(
486486
return getSPIRVImageTypeFromHLSLResource(ResAttrs, ContainedTy, CGM);
487487
}
488488

489+
if (ResAttrs.IsCounter) {
490+
llvm::Type *ElemType = llvm::Type::getInt32Ty(Ctx);
491+
uint32_t StorageClass = /* StorageBuffer storage class */ 12;
492+
return llvm::TargetExtType::get(Ctx, "spirv.VulkanBuffer", {ElemType},
493+
{StorageClass, true});
494+
}
489495
llvm::Type *ElemType = CGM.getTypes().ConvertTypeForMem(ContainedTy);
490496
llvm::ArrayType *RuntimeArrayType = llvm::ArrayType::get(ElemType, 0);
491497
uint32_t StorageClass = /* StorageBuffer storage class */ 12;

clang/lib/Sema/HLSLBuiltinTypeDeclBuilder.cpp

Lines changed: 139 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,15 @@ CXXConstructorDecl *lookupCopyConstructor(QualType ResTy) {
5757
return CD;
5858
return nullptr;
5959
}
60+
61+
/// Set up common members and attributes for buffer types
62+
static bool resourceHasCounter(const CXXRecordDecl *Decl) {
63+
StringRef Name = Decl->getName();
64+
return Name == "RWStructuredBuffer" || Name == "AppendStructuredBuffer" ||
65+
Name == "ConsumeStructuredBuffer" ||
66+
Name == "RasterizerOrderedStructuredBuffer";
67+
}
68+
6069
} // namespace
6170

6271
// Builder for template arguments of builtin types. Used internally
@@ -138,7 +147,16 @@ struct BuiltinTypeMethodBuilder {
138147
// LastStmt - refers to the last statement in the method body; referencing
139148
// LastStmt will remove the statement from the method body since
140149
// it will be linked from the new expression being constructed.
141-
enum class PlaceHolder { _0, _1, _2, _3, _4, Handle = 128, LastStmt };
150+
enum class PlaceHolder {
151+
_0,
152+
_1,
153+
_2,
154+
_3,
155+
_4,
156+
Handle = 128,
157+
CounterHandle,
158+
LastStmt
159+
};
142160

143161
Expr *convertPlaceholder(PlaceHolder PH);
144162
Expr *convertPlaceholder(LocalVar &Var);
@@ -178,10 +196,17 @@ struct BuiltinTypeMethodBuilder {
178196
template <typename ResourceT, typename ValueT>
179197
BuiltinTypeMethodBuilder &setHandleFieldOnResource(ResourceT ResourceRecord,
180198
ValueT HandleValue);
199+
template <typename T>
200+
BuiltinTypeMethodBuilder &
201+
accessCounterHandleFieldOnResource(T ResourceRecord);
202+
template <typename ResourceT, typename ValueT>
203+
BuiltinTypeMethodBuilder &
204+
setCounterHandleFieldOnResource(ResourceT ResourceRecord, ValueT HandleValue);
181205
template <typename T> BuiltinTypeMethodBuilder &returnValue(T ReturnValue);
182206
BuiltinTypeMethodBuilder &returnThis();
183207
BuiltinTypeDeclBuilder &finalize();
184208
Expr *getResourceHandleExpr();
209+
Expr *getResourceCounterHandleExpr();
185210

186211
private:
187212
void createDecl();
@@ -346,6 +371,8 @@ TemplateParameterListBuilder::finalizeTemplateArgs(ConceptDecl *CD) {
346371
Expr *BuiltinTypeMethodBuilder::convertPlaceholder(PlaceHolder PH) {
347372
if (PH == PlaceHolder::Handle)
348373
return getResourceHandleExpr();
374+
if (PH == PlaceHolder::CounterHandle)
375+
return getResourceCounterHandleExpr();
349376

350377
if (PH == PlaceHolder::LastStmt) {
351378
assert(!StmtsList.empty() && "no statements in the list");
@@ -467,6 +494,18 @@ Expr *BuiltinTypeMethodBuilder::getResourceHandleExpr() {
467494
OK_Ordinary);
468495
}
469496

497+
Expr *BuiltinTypeMethodBuilder::getResourceCounterHandleExpr() {
498+
ensureCompleteDecl();
499+
500+
ASTContext &AST = DeclBuilder.SemaRef.getASTContext();
501+
CXXThisExpr *This = CXXThisExpr::Create(
502+
AST, SourceLocation(), Method->getFunctionObjectParameterType(), true);
503+
FieldDecl *HandleField = DeclBuilder.getResourceCounterHandleField();
504+
return MemberExpr::CreateImplicit(AST, This, false, HandleField,
505+
HandleField->getType(), VK_LValue,
506+
OK_Ordinary);
507+
}
508+
470509
BuiltinTypeMethodBuilder &
471510
BuiltinTypeMethodBuilder::declareLocalVar(LocalVar &Var) {
472511
ensureCompleteDecl();
@@ -583,6 +622,44 @@ BuiltinTypeMethodBuilder::setHandleFieldOnResource(ResourceT ResourceRecord,
583622
return *this;
584623
}
585624

625+
template <typename T>
626+
BuiltinTypeMethodBuilder &
627+
BuiltinTypeMethodBuilder::accessCounterHandleFieldOnResource(T ResourceRecord) {
628+
ensureCompleteDecl();
629+
630+
Expr *ResourceExpr = convertPlaceholder(ResourceRecord);
631+
632+
ASTContext &AST = DeclBuilder.SemaRef.getASTContext();
633+
FieldDecl *HandleField = DeclBuilder.getResourceCounterHandleField();
634+
MemberExpr *HandleExpr = MemberExpr::CreateImplicit(
635+
AST, ResourceExpr, false, HandleField, HandleField->getType(), VK_LValue,
636+
OK_Ordinary);
637+
StmtsList.push_back(HandleExpr);
638+
return *this;
639+
}
640+
641+
template <typename ResourceT, typename ValueT>
642+
BuiltinTypeMethodBuilder &
643+
BuiltinTypeMethodBuilder::setCounterHandleFieldOnResource(
644+
ResourceT ResourceRecord, ValueT HandleValue) {
645+
ensureCompleteDecl();
646+
647+
Expr *ResourceExpr = convertPlaceholder(ResourceRecord);
648+
Expr *HandleValueExpr = convertPlaceholder(HandleValue);
649+
650+
ASTContext &AST = DeclBuilder.SemaRef.getASTContext();
651+
FieldDecl *HandleField = DeclBuilder.getResourceCounterHandleField();
652+
MemberExpr *HandleMemberExpr = MemberExpr::CreateImplicit(
653+
AST, ResourceExpr, false, HandleField, HandleField->getType(), VK_LValue,
654+
OK_Ordinary);
655+
Stmt *AssignStmt = BinaryOperator::Create(
656+
DeclBuilder.SemaRef.getASTContext(), HandleMemberExpr, HandleValueExpr,
657+
BO_Assign, HandleMemberExpr->getType(), ExprValueKind::VK_PRValue,
658+
ExprObjectKind::OK_Ordinary, SourceLocation(), FPOptionsOverride());
659+
StmtsList.push_back(AssignStmt);
660+
return *this;
661+
}
662+
586663
template <typename T>
587664
BuiltinTypeMethodBuilder &BuiltinTypeMethodBuilder::returnValue(T ReturnValue) {
588665
ensureCompleteDecl();
@@ -722,6 +799,15 @@ BuiltinTypeDeclBuilder::addMemberVariable(StringRef Name, QualType Type,
722799
return *this;
723800
}
724801

802+
BuiltinTypeDeclBuilder &BuiltinTypeDeclBuilder::addHandleMembers(
803+
ResourceClass RC, bool IsROV, bool RawBuffer, AccessSpecifier Access) {
804+
addHandleMember(RC, IsROV, RawBuffer, Access);
805+
if (resourceHasCounter(Record)) {
806+
addCounterHandleMember(RC, IsROV, RawBuffer, Access);
807+
}
808+
return *this;
809+
}
810+
725811
BuiltinTypeDeclBuilder &BuiltinTypeDeclBuilder::addHandleMember(
726812
ResourceClass RC, bool IsROV, bool RawBuffer, AccessSpecifier Access) {
727813
assert(!Record->isCompleteDefinition() && "record is already complete");
@@ -745,6 +831,30 @@ BuiltinTypeDeclBuilder &BuiltinTypeDeclBuilder::addHandleMember(
745831
return *this;
746832
}
747833

834+
BuiltinTypeDeclBuilder &BuiltinTypeDeclBuilder::addCounterHandleMember(
835+
ResourceClass RC, bool IsROV, bool RawBuffer, AccessSpecifier Access) {
836+
assert(!Record->isCompleteDefinition() && "record is already complete");
837+
838+
ASTContext &Ctx = SemaRef.getASTContext();
839+
TypeSourceInfo *ElementTypeInfo =
840+
Ctx.getTrivialTypeSourceInfo(getHandleElementType(), SourceLocation());
841+
842+
// add handle member with resource type attributes
843+
QualType AttributedResTy = QualType();
844+
SmallVector<const Attr *> Attrs = {
845+
HLSLResourceClassAttr::CreateImplicit(Ctx, RC),
846+
IsROV ? HLSLROVAttr::CreateImplicit(Ctx) : nullptr,
847+
RawBuffer ? HLSLRawBufferAttr::CreateImplicit(Ctx) : nullptr,
848+
ElementTypeInfo
849+
? HLSLContainedTypeAttr::CreateImplicit(Ctx, ElementTypeInfo)
850+
: nullptr,
851+
HLSLIsCounterAttr::CreateImplicit(Ctx)};
852+
if (CreateHLSLAttributedResourceType(SemaRef, Ctx.HLSLResourceTy, Attrs,
853+
AttributedResTy))
854+
addMemberVariable("__counter_handle", AttributedResTy, {}, Access);
855+
return *this;
856+
}
857+
748858
// Adds default constructor to the resource class:
749859
// Resource::Resource()
750860
BuiltinTypeDeclBuilder &BuiltinTypeDeclBuilder::addDefaultHandleConstructor() {
@@ -848,12 +958,18 @@ BuiltinTypeDeclBuilder &BuiltinTypeDeclBuilder::addCopyConstructor() {
848958

849959
using PH = BuiltinTypeMethodBuilder::PlaceHolder;
850960

851-
return BuiltinTypeMethodBuilder(*this, /*Name=*/"", AST.VoidTy,
852-
/*IsConst=*/false, /*IsCtor=*/true)
853-
.addParam("other", ConstRecordRefType)
961+
BuiltinTypeMethodBuilder MMB(*this, /*Name=*/"", AST.VoidTy,
962+
/*IsConst=*/false, /*IsCtor=*/true);
963+
MMB.addParam("other", ConstRecordRefType)
854964
.accessHandleFieldOnResource(PH::_0)
855-
.assign(PH::Handle, PH::LastStmt)
856-
.finalize();
965+
.assign(PH::Handle, PH::LastStmt);
966+
967+
if (getResourceCounterHandleField()) {
968+
MMB.accessCounterHandleFieldOnResource(PH::_0).assign(PH::CounterHandle,
969+
PH::LastStmt);
970+
}
971+
972+
return MMB.finalize();
857973
}
858974

859975
BuiltinTypeDeclBuilder &BuiltinTypeDeclBuilder::addCopyAssignmentOperator() {
@@ -868,12 +984,17 @@ BuiltinTypeDeclBuilder &BuiltinTypeDeclBuilder::addCopyAssignmentOperator() {
868984

869985
using PH = BuiltinTypeMethodBuilder::PlaceHolder;
870986
DeclarationName Name = AST.DeclarationNames.getCXXOperatorName(OO_Equal);
871-
return BuiltinTypeMethodBuilder(*this, Name, RecordRefType)
872-
.addParam("other", ConstRecordRefType)
987+
BuiltinTypeMethodBuilder MMB(*this, Name, RecordRefType);
988+
MMB.addParam("other", ConstRecordRefType)
873989
.accessHandleFieldOnResource(PH::_0)
874-
.assign(PH::Handle, PH::LastStmt)
875-
.returnThis()
876-
.finalize();
990+
.assign(PH::Handle, PH::LastStmt);
991+
992+
if (getResourceCounterHandleField()) {
993+
MMB.accessCounterHandleFieldOnResource(PH::_0).assign(PH::CounterHandle,
994+
PH::LastStmt);
995+
}
996+
997+
return MMB.returnThis().finalize();
877998
}
878999

8791000
BuiltinTypeDeclBuilder &BuiltinTypeDeclBuilder::addArraySubscriptOperators() {
@@ -909,6 +1030,13 @@ FieldDecl *BuiltinTypeDeclBuilder::getResourceHandleField() const {
9091030
return I->second;
9101031
}
9111032

1033+
FieldDecl *BuiltinTypeDeclBuilder::getResourceCounterHandleField() const {
1034+
auto I = Fields.find("__counter_handle");
1035+
if (I == Fields.end())
1036+
return nullptr;
1037+
return I->second;
1038+
}
1039+
9121040
QualType BuiltinTypeDeclBuilder::getFirstTemplateTypeParam() {
9131041
assert(Template && "record it not a template");
9141042
if (const auto *TTD = dyn_cast<TemplateTypeParmDecl>(

clang/lib/Sema/HLSLBuiltinTypeDeclBuilder.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,14 @@ class BuiltinTypeDeclBuilder {
7272
AccessSpecifier Access = AccessSpecifier::AS_private);
7373

7474
BuiltinTypeDeclBuilder &
75+
addHandleMembers(ResourceClass RC, bool IsROV, bool RawBuffer,
76+
AccessSpecifier Access = AccessSpecifier::AS_private);
77+
BuiltinTypeDeclBuilder &
7578
addHandleMember(ResourceClass RC, bool IsROV, bool RawBuffer,
7679
AccessSpecifier Access = AccessSpecifier::AS_private);
80+
BuiltinTypeDeclBuilder &
81+
addCounterHandleMember(ResourceClass RC, bool IsROV, bool RawBuffer,
82+
AccessSpecifier Access = AccessSpecifier::AS_private);
7783
BuiltinTypeDeclBuilder &addArraySubscriptOperators();
7884

7985
// Builtin types constructors
@@ -96,6 +102,7 @@ class BuiltinTypeDeclBuilder {
96102

97103
private:
98104
FieldDecl *getResourceHandleField() const;
105+
FieldDecl *getResourceCounterHandleField() const;
99106
QualType getFirstTemplateTypeParam();
100107
QualType getHandleElementType();
101108
Expr *getConstantIntExpr(int value);

clang/lib/Sema/HLSLExternalSemaSource.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ static BuiltinTypeDeclBuilder setupBufferType(CXXRecordDecl *Decl, Sema &S,
130130
ResourceClass RC, bool IsROV,
131131
bool RawBuffer) {
132132
return BuiltinTypeDeclBuilder(S, Decl)
133-
.addHandleMember(RC, IsROV, RawBuffer)
133+
.addHandleMembers(RC, IsROV, RawBuffer)
134134
.addDefaultHandleConstructor()
135135
.addCopyConstructor()
136136
.addCopyAssignmentOperator()

clang/lib/Sema/SemaHLSL.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1811,6 +1811,13 @@ bool clang::CreateHLSLAttributedResourceType(
18111811
}
18121812
ResAttrs.RawBuffer = true;
18131813
break;
1814+
case attr::HLSLIsCounter:
1815+
if (ResAttrs.IsCounter) {
1816+
S.Diag(A->getLocation(), diag::warn_duplicate_attribute_exact) << A;
1817+
return false;
1818+
}
1819+
ResAttrs.IsCounter = true;
1820+
break;
18141821
case attr::HLSLContainedType: {
18151822
const HLSLContainedTypeAttr *CTAttr = cast<HLSLContainedTypeAttr>(A);
18161823
QualType Ty = CTAttr->getType();
@@ -1903,6 +1910,10 @@ bool SemaHLSL::handleResourceTypeAttr(QualType T, const ParsedAttr &AL) {
19031910
A = HLSLRawBufferAttr::Create(getASTContext(), ACI);
19041911
break;
19051912

1913+
case ParsedAttr::AT_HLSLIsCounter:
1914+
A = HLSLIsCounterAttr::Create(getASTContext(), ACI);
1915+
break;
1916+
19061917
case ParsedAttr::AT_HLSLContainedType: {
19071918
if (AL.getNumArgs() != 1 && !AL.hasParsedType()) {
19081919
Diag(AL.getLoc(), diag::err_attribute_wrong_number_arguments) << AL << 1;

0 commit comments

Comments
 (0)