From 6c92c23bd62c9ff262e035a135bb4bfac3406f3a Mon Sep 17 00:00:00 2001 From: Farzon Lotfi Date: Thu, 30 May 2024 13:51:35 -0400 Subject: [PATCH] [HLSL] add loop unroll - `Attr.td` - Define the HLSL loop attribute hints (unroll and loop) - `AttrDocs.td` - Add documentation for unroll and loop - `CGLoopInfo.cpp` - Add codegen for HLSL unroll that maps to clang unroll expectations - `ParseStmt.cpp` - For statements if HLSL define DeclSpecAttrs via MaybeParseMicrosoftAttributes - `SemaStmtAttr.cpp` - Add the HLSL loop unroll handeling --- clang/include/clang/Basic/Attr.td | 13 ++++ clang/include/clang/Basic/AttrDocs.td | 47 +++++++++++ clang/lib/CodeGen/CGLoopInfo.cpp | 15 +++- clang/lib/Parse/ParseStmt.cpp | 11 ++- clang/lib/Sema/SemaStmtAttr.cpp | 29 +++++++ clang/test/CodeGenHLSL/loops/unroll.hlsl | 99 ++++++++++++++++++++++++ 6 files changed, 208 insertions(+), 6 deletions(-) create mode 100644 clang/test/CodeGenHLSL/loops/unroll.hlsl diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td index 2665b7353ca4a..e319db08ba168 100644 --- a/clang/include/clang/Basic/Attr.td +++ b/clang/include/clang/Basic/Attr.td @@ -4114,6 +4114,19 @@ def LoopHint : Attr { let HasCustomParsing = 1; } +/// The HLSL loop attributes +def HLSLLoopHint: StmtAttr { + /// [unroll(directive)] + /// [loop] + let Spellings = [Microsoft<"unroll">, Microsoft<"loop">]; + let Args = [UnsignedArgument<"directive">]; + let Subjects = SubjectList<[ForStmt, WhileStmt, DoStmt], + ErrorDiag, "'for', 'while', and 'do' statements">; + let LangOpts = [HLSL]; + let Documentation = [HLSLLoopHintDocs, HLSLUnrollHintDocs]; + let HasCustomParsing = 1; +} + def CapturedRecord : InheritableAttr { // This attribute has no spellings as it is only ever created implicitly. let Spellings = []; diff --git a/clang/include/clang/Basic/AttrDocs.td b/clang/include/clang/Basic/AttrDocs.td index a313e811c9d21..d84af9402d6db 100644 --- a/clang/include/clang/Basic/AttrDocs.td +++ b/clang/include/clang/Basic/AttrDocs.td @@ -7342,6 +7342,53 @@ where shaders must be compiled into a library and linked at runtime. }]; } +def HLSLLoopHintDocs : Documentation { + let Category = DocCatStmt; + let Heading = "#[loop]"; + let Content = [{ +The ``[loop]`` directive allows loop optimization hints to be +specified for the subsequent loop. The directive allows unrolling to +be disabled and is not compatible with [unroll(x)]. + +Specifying the parameter, ``[loop]``, directs the +unroller to not unroll the loop. + +.. code-block:: hlsl + + [loop] + for (...) { + ... + } + +See `hlsl loop extensions +`_ +for details. + }]; +} + +def HLSLUnrollHintDocs : Documentation { + let Category = DocCatStmt; + let Heading = "[unroll(x)]"; + let Content = [{ +Loop unrolling optimization hints can be specified with ``[unroll(x)]`` +. The attribute is placed immediately before a for, while, +or do-while. +Specifying the parameter, ``[unroll(_value_)]``, directs the +unroller to unroll the loop ``_value_`` times. Note: [unroll(x)] is not compatible with [loop]. + +.. code-block:: hlsl + + [unroll(4)] + for (...) { + ... + } +See +`hlsl loop extensions +`_ +for details. + }]; +} + def ClangRandomizeLayoutDocs : Documentation { let Category = DocCatDecl; let Heading = "randomize_layout, no_randomize_layout"; diff --git a/clang/lib/CodeGen/CGLoopInfo.cpp b/clang/lib/CodeGen/CGLoopInfo.cpp index 0d4800b90a2f2..6b886bd6b6d2c 100644 --- a/clang/lib/CodeGen/CGLoopInfo.cpp +++ b/clang/lib/CodeGen/CGLoopInfo.cpp @@ -612,9 +612,9 @@ void LoopInfoStack::push(BasicBlock *Header, clang::ASTContext &Ctx, const LoopHintAttr *LH = dyn_cast(Attr); const OpenCLUnrollHintAttr *OpenCLHint = dyn_cast(Attr); - + const HLSLLoopHintAttr *HLSLLoopHint = dyn_cast(Attr); // Skip non loop hint attributes - if (!LH && !OpenCLHint) { + if (!LH && !OpenCLHint && !HLSLLoopHint) { continue; } @@ -635,6 +635,17 @@ void LoopInfoStack::push(BasicBlock *Header, clang::ASTContext &Ctx, Option = LoopHintAttr::UnrollCount; State = LoopHintAttr::Numeric; } + } else if (HLSLLoopHint) { + ValueInt = HLSLLoopHint->getDirective(); + if (HLSLLoopHint->getSemanticSpelling() == + HLSLLoopHintAttr::Spelling::Microsoft_unroll) { + if (ValueInt == 0) + State = LoopHintAttr::Enable; + if (ValueInt > 0) { + Option = LoopHintAttr::UnrollCount; + State = LoopHintAttr::Numeric; + } + } } else if (LH) { auto *ValueExpr = LH->getValue(); if (ValueExpr) { diff --git a/clang/lib/Parse/ParseStmt.cpp b/clang/lib/Parse/ParseStmt.cpp index c25203243ee49..b8a717820c418 100644 --- a/clang/lib/Parse/ParseStmt.cpp +++ b/clang/lib/Parse/ParseStmt.cpp @@ -114,18 +114,21 @@ Parser::ParseStatementOrDeclaration(StmtVector &Stmts, // here because we don't want to allow arbitrary orderings. ParsedAttributes CXX11Attrs(AttrFactory); MaybeParseCXX11Attributes(CXX11Attrs, /*MightBeObjCMessageSend*/ true); - ParsedAttributes GNUAttrs(AttrFactory); + ParsedAttributes DeclSpecAttrs(AttrFactory); if (getLangOpts().OpenCL) - MaybeParseGNUAttributes(GNUAttrs); + MaybeParseGNUAttributes(DeclSpecAttrs); + + if (getLangOpts().HLSL) + MaybeParseMicrosoftAttributes(DeclSpecAttrs); StmtResult Res = ParseStatementOrDeclarationAfterAttributes( - Stmts, StmtCtx, TrailingElseLoc, CXX11Attrs, GNUAttrs); + Stmts, StmtCtx, TrailingElseLoc, CXX11Attrs, DeclSpecAttrs); MaybeDestroyTemplateIds(); // Attributes that are left should all go on the statement, so concatenate the // two lists. ParsedAttributes Attrs(AttrFactory); - takeAndConcatenateAttrs(CXX11Attrs, GNUAttrs, Attrs); + takeAndConcatenateAttrs(CXX11Attrs, DeclSpecAttrs, Attrs); assert((Attrs.empty() || Res.isInvalid() || Res.isUsable()) && "attributes on empty statement"); diff --git a/clang/lib/Sema/SemaStmtAttr.cpp b/clang/lib/Sema/SemaStmtAttr.cpp index 6f538ed55cb72..35a20c113f628 100644 --- a/clang/lib/Sema/SemaStmtAttr.cpp +++ b/clang/lib/Sema/SemaStmtAttr.cpp @@ -16,6 +16,7 @@ #include "clang/Basic/TargetInfo.h" #include "clang/Sema/DelayedDiagnostic.h" #include "clang/Sema/Lookup.h" +#include "clang/Sema/ParsedAttr.h" #include "clang/Sema/ScopeInfo.h" #include "clang/Sema/SemaInternal.h" #include "llvm/ADT/StringExtras.h" @@ -584,6 +585,32 @@ static Attr *handleOpenCLUnrollHint(Sema &S, Stmt *St, const ParsedAttr &A, return ::new (S.Context) OpenCLUnrollHintAttr(S.Context, A, UnrollFactor); } +static Attr *handleHLSLLoopHintAttr(Sema &S, Stmt *St, const ParsedAttr &A, + SourceRange Range) { + unsigned UnrollFactor = 0; + if (A.getNumArgs() == 1) { + Expr *E = A.getArgAsExpr(0); + std::optional ArgVal; + + if (!(ArgVal = E->getIntegerConstantExpr(S.Context))) { + S.Diag(A.getLoc(), diag::err_attribute_argument_type) + << A << AANT_ArgumentIntegerConstant << E->getSourceRange(); + return nullptr; + } + + int Val = ArgVal->getSExtValue(); + if (Val <= 0) { + S.Diag(A.getRange().getBegin(), + diag::err_attribute_requires_positive_integer) + << A << /* positive */ 0; + return nullptr; + } + UnrollFactor = static_cast(Val); + } + + return ::new (S.Context) HLSLLoopHintAttr(S.Context, A, UnrollFactor); +} + static Attr *ProcessStmtAttribute(Sema &S, Stmt *St, const ParsedAttr &A, SourceRange Range) { if (A.isInvalid() || A.getKind() == ParsedAttr::IgnoredAttribute) @@ -618,6 +645,8 @@ static Attr *ProcessStmtAttribute(Sema &S, Stmt *St, const ParsedAttr &A, return handleFallThroughAttr(S, St, A, Range); case ParsedAttr::AT_LoopHint: return handleLoopHintAttr(S, St, A, Range); + case ParsedAttr::AT_HLSLLoopHint: + return handleHLSLLoopHintAttr(S, St, A, Range); case ParsedAttr::AT_OpenCLUnrollHint: return handleOpenCLUnrollHint(S, St, A, Range); case ParsedAttr::AT_Suppress: diff --git a/clang/test/CodeGenHLSL/loops/unroll.hlsl b/clang/test/CodeGenHLSL/loops/unroll.hlsl new file mode 100644 index 0000000000000..0ebe5b0e847e4 --- /dev/null +++ b/clang/test/CodeGenHLSL/loops/unroll.hlsl @@ -0,0 +1,99 @@ +// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \ +// RUN: dxil-pc-shadermodel6.3-library %s -emit-llvm -o - | FileCheck %s + +/*** for ***/ +void for_count() +{ +// CHECK-LABEL: for_count + [unroll(8)] + for( int i = 0; i < 1000; ++i); +// CHECK: br label %{{.*}}, !llvm.loop ![[FOR_DISTINCT:.*]] +} + +void for_disable() +{ +// CHECK-LABEL: for_disable + [loop] + for( int i = 0; i < 1000; ++i); +// CHECK: br label %{{.*}}, !llvm.loop ![[FOR_DISABLE:.*]] +} + +void for_enable() +{ +// CHECK-LABEL: for_enable + [unroll] + for( int i = 0; i < 1000; ++i); +// CHECK: br label %{{.*}}, !llvm.loop ![[FOR_ENABLE:.*]] +} + +/*** while ***/ +void while_count() +{ +// CHECK-LABEL: while_count + int i = 1000; + [unroll(4)] + while(i-->0); +// CHECK: br label %{{.*}}, !llvm.loop ![[WHILE_DISTINCT:.*]] +} + +void while_disable() +{ +// CHECK-LABEL: while_disable + int i = 1000; + [loop] + while(i-->0); +// CHECK: br label %{{.*}}, !llvm.loop ![[WHILE_DISABLE:.*]] +} + +void while_enable() +{ +// CHECK-LABEL: while_enable + int i = 1000; + [unroll] + while(i-->0); +// CHECK: br label %{{.*}}, !llvm.loop ![[WHILE_ENABLE:.*]] +} + +/*** do ***/ +void do_count() +{ +// CHECK-LABEL: do_count + int i = 1000; + [unroll(16)] + do {} while(i--> 0); +// CHECK: br i1 %{{.*}}, label %{{.*}}, label %{{.*}}, !llvm.loop ![[DO_DISTINCT:.*]] +} + +void do_disable() +{ +// CHECK-LABEL: do_disable + int i = 1000; + [loop] + do {} while(i--> 0); +// CHECK: br i1 %{{.*}}, label %{{.*}}, label %{{.*}}, !llvm.loop ![[DO_DISABLE:.*]] +} + +void do_enable() +{ +// CHECK-LABEL: do_enable + int i = 1000; + [unroll] + do {} while(i--> 0); +// CHECK: br i1 %{{.*}}, label %{{.*}}, label %{{.*}}, !llvm.loop ![[DO_ENABLE:.*]] +} + + +// CHECK: ![[FOR_DISTINCT]] = distinct !{![[FOR_DISTINCT]], ![[FOR_COUNT:.*]]} +// CHECK: ![[FOR_COUNT]] = !{!"llvm.loop.unroll.count", i32 8} +// CHECK: ![[FOR_DISABLE]] = distinct !{![[FOR_DISABLE]], ![[DISABLE:.*]]} +// CHECK: ![[DISABLE]] = !{!"llvm.loop.unroll.disable"} +// CHECK: ![[FOR_ENABLE]] = distinct !{![[FOR_ENABLE]], ![[ENABLE:.*]]} +// CHECK: ![[ENABLE]] = !{!"llvm.loop.unroll.enable"} +// CHECK: ![[WHILE_DISTINCT]] = distinct !{![[WHILE_DISTINCT]], ![[WHILE_COUNT:.*]]} +// CHECK: ![[WHILE_COUNT]] = !{!"llvm.loop.unroll.count", i32 4} +// CHECK: ![[WHILE_DISABLE]] = distinct !{![[WHILE_DISABLE]], ![[DISABLE]]} +// CHECK: ![[WHILE_ENABLE]] = distinct !{![[WHILE_ENABLE]], ![[ENABLE]]} +// CHECK: ![[DO_DISTINCT]] = distinct !{![[DO_DISTINCT]], ![[DO_COUNT:.*]]} +// CHECK: ![[DO_COUNT]] = !{!"llvm.loop.unroll.count", i32 16} +// CHECK: ![[DO_DISABLE]] = distinct !{![[DO_DISABLE]], ![[DISABLE]]} +// CHECK: ![[DO_ENABLE]] = distinct !{![[DO_ENABLE]], ![[ENABLE]]}