Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[HLSL] add loop unroll #93879

Merged
merged 5 commits into from
Jul 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions clang/include/clang/Basic/Attr.td
Original file line number Diff line number Diff line change
Expand Up @@ -4115,6 +4115,18 @@ def LoopHint : Attr {
let HasCustomParsing = 1;
}

/// The HLSL loop attributes
def HLSLLoopHint: StmtAttr {
/// [unroll(directive)]
/// [loop]
let Spellings = [Microsoft<"unroll">, Microsoft<"loop">];
let Args = [UnsignedArgument<"directive", /*opt*/1>];
let Subjects = SubjectList<[ForStmt, WhileStmt, DoStmt],
ErrorDiag, "'for', 'while', and 'do' statements">;
let LangOpts = [HLSL];
let Documentation = [HLSLLoopHintDocs, HLSLUnrollHintDocs];
}

def CapturedRecord : InheritableAttr {
// This attribute has no spellings as it is only ever created implicitly.
let Spellings = [];
Expand Down
99 changes: 97 additions & 2 deletions clang/include/clang/Basic/AttrDocs.td
Original file line number Diff line number Diff line change
Expand Up @@ -7345,6 +7345,100 @@ where shaders must be compiled into a library and linked at runtime.
}];
}

def HLSLLoopHintDocs : Documentation {
let Category = DocCatStmt;
let Heading = "[loop]";
let Content = [{
The ``[loop]`` directive allows loop optimization hints to be
specified for the subsequent loop. The directive allows unrolling to
be disabled and is not compatible with [unroll(x)].

Specifying the parameter, ``[loop]``, directs the
unroller to not unroll the loop.

.. code-block:: hlsl

[loop]
for (...) {
...
}

.. code-block:: hlsl

[loop]
while (...) {
...
}

.. code-block:: hlsl

[loop]
do {
...
} while (...)

See `hlsl loop extensions <https://learn.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-for>`_
for details.
}];
}

def HLSLUnrollHintDocs : Documentation {
let Category = DocCatStmt;
let Heading = "[unroll(x)], [unroll]";
let Content = [{
Loop unrolling optimization hints can be specified with ``[unroll(x)]``
. The attribute is placed immediately before a for, while,
or do-while.
Specifying the parameter, ``[unroll(_value_)]``, directs the
unroller to unroll the loop ``_value_`` times. Note: [unroll(x)] is not compatible with [loop].

.. code-block:: hlsl

[unroll(4)]
for (...) {
...
}

.. code-block:: hlsl

[unroll]
for (...) {
...
}

.. code-block:: hlsl

[unroll(4)]
while (...) {
...
}

.. code-block:: hlsl

[unroll]
while (...) {
...
}

.. code-block:: hlsl

[unroll(4)]
do {
...
} while (...)

.. code-block:: hlsl

[unroll]
do {
...
} while (...)

See `hlsl loop extensions <https://learn.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-for>`_
for details.
}];
}

def ClangRandomizeLayoutDocs : Documentation {
let Category = DocCatDecl;
let Heading = "randomize_layout, no_randomize_layout";
Expand Down Expand Up @@ -7404,7 +7498,8 @@ b for constant buffer views (CBV).

Register space is specified in the format ``space[number]`` and defaults to ``space0`` if omitted.
Here're resource binding examples with and without space:
.. code-block:: c++

.. code-block:: hlsl
llvm-beanz marked this conversation as resolved.
Show resolved Hide resolved

RWBuffer<float> Uav : register(u3, space1);
Buffer<float> Buf : register(t1);
Expand All @@ -7422,7 +7517,7 @@ A subcomponent is a register number, which is an integer. A component is in the

Examples:

.. code-block:: c++
.. code-block:: hlsl

cbuffer A {
float3 a : packoffset(c0.y);
Expand Down
15 changes: 13 additions & 2 deletions clang/lib/CodeGen/CGLoopInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -612,9 +612,9 @@ void LoopInfoStack::push(BasicBlock *Header, clang::ASTContext &Ctx,
const LoopHintAttr *LH = dyn_cast<LoopHintAttr>(Attr);
const OpenCLUnrollHintAttr *OpenCLHint =
dyn_cast<OpenCLUnrollHintAttr>(Attr);

const HLSLLoopHintAttr *HLSLLoopHint = dyn_cast<HLSLLoopHintAttr>(Attr);
// Skip non loop hint attributes
if (!LH && !OpenCLHint) {
if (!LH && !OpenCLHint && !HLSLLoopHint) {
continue;
}

Expand All @@ -635,6 +635,17 @@ void LoopInfoStack::push(BasicBlock *Header, clang::ASTContext &Ctx,
Option = LoopHintAttr::UnrollCount;
State = LoopHintAttr::Numeric;
}
} else if (HLSLLoopHint) {
ValueInt = HLSLLoopHint->getDirective();
if (HLSLLoopHint->getSemanticSpelling() ==
HLSLLoopHintAttr::Spelling::Microsoft_unroll) {
if (ValueInt == 0)
State = LoopHintAttr::Enable;
if (ValueInt > 0) {
Option = LoopHintAttr::UnrollCount;
farzonl marked this conversation as resolved.
Show resolved Hide resolved
State = LoopHintAttr::Numeric;
}
}
} else if (LH) {
auto *ValueExpr = LH->getValue();
if (ValueExpr) {
Expand Down
11 changes: 7 additions & 4 deletions clang/lib/Parse/ParseStmt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,18 +114,21 @@ Parser::ParseStatementOrDeclaration(StmtVector &Stmts,
// here because we don't want to allow arbitrary orderings.
ParsedAttributes CXX11Attrs(AttrFactory);
MaybeParseCXX11Attributes(CXX11Attrs, /*MightBeObjCMessageSend*/ true);
ParsedAttributes GNUAttrs(AttrFactory);
ParsedAttributes GNUOrMSAttrs(AttrFactory);
if (getLangOpts().OpenCL)
MaybeParseGNUAttributes(GNUAttrs);
MaybeParseGNUAttributes(GNUOrMSAttrs);

if (getLangOpts().HLSL)
MaybeParseMicrosoftAttributes(GNUOrMSAttrs);

StmtResult Res = ParseStatementOrDeclarationAfterAttributes(
Stmts, StmtCtx, TrailingElseLoc, CXX11Attrs, GNUAttrs);
Stmts, StmtCtx, TrailingElseLoc, CXX11Attrs, GNUOrMSAttrs);
MaybeDestroyTemplateIds();

// Attributes that are left should all go on the statement, so concatenate the
// two lists.
ParsedAttributes Attrs(AttrFactory);
takeAndConcatenateAttrs(CXX11Attrs, GNUAttrs, Attrs);
takeAndConcatenateAttrs(CXX11Attrs, GNUOrMSAttrs, Attrs);

assert((Attrs.empty() || Res.isInvalid() || Res.isUsable()) &&
"attributes on empty statement");
Expand Down
36 changes: 36 additions & 0 deletions clang/lib/Sema/SemaStmtAttr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "clang/Basic/TargetInfo.h"
#include "clang/Sema/DelayedDiagnostic.h"
#include "clang/Sema/Lookup.h"
#include "clang/Sema/ParsedAttr.h"
#include "clang/Sema/ScopeInfo.h"
#include "clang/Sema/SemaInternal.h"
#include "llvm/ADT/StringExtras.h"
Expand Down Expand Up @@ -584,6 +585,39 @@ static Attr *handleOpenCLUnrollHint(Sema &S, Stmt *St, const ParsedAttr &A,
return ::new (S.Context) OpenCLUnrollHintAttr(S.Context, A, UnrollFactor);
}

static Attr *handleHLSLLoopHintAttr(Sema &S, Stmt *St, const ParsedAttr &A,
SourceRange Range) {

if (A.getSemanticSpelling() == HLSLLoopHintAttr::Spelling::Microsoft_loop &&
!A.checkAtMostNumArgs(S, 0))
return nullptr;

unsigned UnrollFactor = 0;
if (A.getNumArgs() == 1) {

if (A.isArgIdent(0)) {
S.Diag(A.getLoc(), diag::err_attribute_argument_type)
farzonl marked this conversation as resolved.
Show resolved Hide resolved
<< A << AANT_ArgumentIntegerConstant << A.getRange();
return nullptr;
}

Expr *E = A.getArgAsExpr(0);

if (S.CheckLoopHintExpr(E, St->getBeginLoc(),
/*AllowZero=*/false))
return nullptr;

std::optional<llvm::APSInt> ArgVal = E->getIntegerConstantExpr(S.Context);
// CheckLoopHintExpr handles non int const cases
assert(ArgVal != std::nullopt && "ArgVal should be an integer constant.");
int Val = ArgVal->getSExtValue();
// CheckLoopHintExpr handles negative and zero cases
assert(Val > 0 && "Val should be a positive integer greater than zero.");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As stated in your comment, CheckLoopHintExpr will check each of these asserts, and return true in each case, so nullptr will be returned by this function on line 608. That said, is there really a purpose for these asserts? I don't see how they can trigger when CheckLoopHintExpr confirms these conditions.

Copy link
Member Author

@farzonl farzonl Jun 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good question. You are correct it is not needed, I added it as a form of future proofing. Let me explain my thinking.

CheckLoopHintExpr is the sema checks initially created for clang loop unrolling. We are just leveraging it. There are typically two ways of enforcing a contract on expected behavior, unit tests and asserts. I typically like to include both asserts and unit tests. of the two asserts are my preference because they encode your expectations in the code base and are not tangential to it like unit tests.

As a thought experiment lets imagine a change to CheckLoopHintExpr. This is unlikely to happen, but for the sake of argument lets say the behavior of CheckLoopHintExpr changes to one day support non constant ints. Then the nullptr would not return and we would get a fall through. We want alerts to this regression in behavior on the HLSL portion and asserts are my prefered way of knowing what happend for the reasons mentioned above but also because you get a callstack. I use to write a bunch of c# code, we had code contracts to define expectations. Thats kind of what i'm trying to do here.

UnrollFactor = static_cast<unsigned>(Val);
}
return ::new (S.Context) HLSLLoopHintAttr(S.Context, A, UnrollFactor);
}

static Attr *ProcessStmtAttribute(Sema &S, Stmt *St, const ParsedAttr &A,
SourceRange Range) {
if (A.isInvalid() || A.getKind() == ParsedAttr::IgnoredAttribute)
Expand Down Expand Up @@ -618,6 +652,8 @@ static Attr *ProcessStmtAttribute(Sema &S, Stmt *St, const ParsedAttr &A,
return handleFallThroughAttr(S, St, A, Range);
case ParsedAttr::AT_LoopHint:
return handleLoopHintAttr(S, St, A, Range);
case ParsedAttr::AT_HLSLLoopHint:
return handleHLSLLoopHintAttr(S, St, A, Range);
case ParsedAttr::AT_OpenCLUnrollHint:
return handleOpenCLUnrollHint(S, St, A, Range);
case ParsedAttr::AT_Suppress:
Expand Down
130 changes: 130 additions & 0 deletions clang/test/CodeGenHLSL/loops/unroll.hlsl
farzonl marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
// RUN: dxil-pc-shadermodel6.3-library -disable-llvm-passes %s -emit-llvm -o - | FileCheck %s

/*** for ***/
void for_count()
{
// CHECK-LABEL: for_count
farzonl marked this conversation as resolved.
Show resolved Hide resolved
[unroll(8)]
for( int i = 0; i < 1000; ++i);
// CHECK: br label %{{.*}}, !llvm.loop ![[FOR_DISTINCT:.*]]
}

void for_disable()
{
// CHECK-LABEL: for_disable
[loop]
for( int i = 0; i < 1000; ++i);
// CHECK: br label %{{.*}}, !llvm.loop ![[FOR_DISABLE:.*]]
}

void for_enable()
{
// CHECK-LABEL: for_enable
[unroll]
for( int i = 0; i < 1000; ++i);
// CHECK: br label %{{.*}}, !llvm.loop ![[FOR_ENABLE:.*]]
}

void for_nested_one_unroll_enable()
{
// CHECK-LABEL: for_nested_one_unroll_enable
int s = 0;
[unroll]
for( int i = 0; i < 1000; ++i) {
for( int j = 0; j < 10; ++j)
s += i + j;
}
// CHECK: br label %{{.*}}, !llvm.loop ![[FOR_NESTED_ENABLE:.*]]
// CHECK-NOT: br label %{{.*}}, !llvm.loop ![[FOR_NESTED_1_ENABLE:.*]]
}

void for_nested_two_unroll_enable()
{
// CHECK-LABEL: for_nested_two_unroll_enable
int s = 0;
[unroll]
for( int i = 0; i < 1000; ++i) {
[unroll]
for( int j = 0; j < 10; ++j)
s += i + j;
}
// CHECK: br label %{{.*}}, !llvm.loop ![[FOR_NESTED2_ENABLE:.*]]
// CHECK: br label %{{.*}}, !llvm.loop ![[FOR_NESTED2_1_ENABLE:.*]]
}


/*** while ***/
void while_count()
{
// CHECK-LABEL: while_count
int i = 1000;
[unroll(4)]
while(i-->0);
// CHECK: br label %{{.*}}, !llvm.loop ![[WHILE_DISTINCT:.*]]
}

void while_disable()
{
// CHECK-LABEL: while_disable
int i = 1000;
[loop]
while(i-->0);
// CHECK: br label %{{.*}}, !llvm.loop ![[WHILE_DISABLE:.*]]
}

void while_enable()
{
// CHECK-LABEL: while_enable
int i = 1000;
[unroll]
while(i-->0);
// CHECK: br label %{{.*}}, !llvm.loop ![[WHILE_ENABLE:.*]]
}

/*** do ***/
void do_count()
{
// CHECK-LABEL: do_count
int i = 1000;
[unroll(16)]
do {} while(i--> 0);
// CHECK: br i1 %{{.*}}, label %{{.*}}, label %{{.*}}, !llvm.loop ![[DO_DISTINCT:.*]]
}

void do_disable()
{
// CHECK-LABEL: do_disable
int i = 1000;
[loop]
do {} while(i--> 0);
// CHECK: br i1 %{{.*}}, label %{{.*}}, label %{{.*}}, !llvm.loop ![[DO_DISABLE:.*]]
}

void do_enable()
{
// CHECK-LABEL: do_enable
int i = 1000;
[unroll]
do {} while(i--> 0);
// CHECK: br i1 %{{.*}}, label %{{.*}}, label %{{.*}}, !llvm.loop ![[DO_ENABLE:.*]]
}


// CHECK: ![[FOR_DISTINCT]] = distinct !{![[FOR_DISTINCT]], ![[FOR_COUNT:.*]]}
// CHECK: ![[FOR_COUNT]] = !{!"llvm.loop.unroll.count", i32 8}
// CHECK: ![[FOR_DISABLE]] = distinct !{![[FOR_DISABLE]], ![[DISABLE:.*]]}
// CHECK: ![[DISABLE]] = !{!"llvm.loop.unroll.disable"}
// CHECK: ![[FOR_ENABLE]] = distinct !{![[FOR_ENABLE]], ![[ENABLE:.*]]}
// CHECK: ![[ENABLE]] = !{!"llvm.loop.unroll.enable"}
// CHECK: ![[FOR_NESTED_ENABLE]] = distinct !{![[FOR_NESTED_ENABLE]], ![[ENABLE]]}
// CHECK: ![[FOR_NESTED2_ENABLE]] = distinct !{![[FOR_NESTED2_ENABLE]], ![[ENABLE]]}
// CHECK: ![[FOR_NESTED2_1_ENABLE]] = distinct !{![[FOR_NESTED2_1_ENABLE]], ![[ENABLE]]}
// CHECK: ![[WHILE_DISTINCT]] = distinct !{![[WHILE_DISTINCT]], ![[WHILE_COUNT:.*]]}
// CHECK: ![[WHILE_COUNT]] = !{!"llvm.loop.unroll.count", i32 4}
// CHECK: ![[WHILE_DISABLE]] = distinct !{![[WHILE_DISABLE]], ![[DISABLE]]}
// CHECK: ![[WHILE_ENABLE]] = distinct !{![[WHILE_ENABLE]], ![[ENABLE]]}
// CHECK: ![[DO_DISTINCT]] = distinct !{![[DO_DISTINCT]], ![[DO_COUNT:.*]]}
// CHECK: ![[DO_COUNT]] = !{!"llvm.loop.unroll.count", i32 16}
// CHECK: ![[DO_DISABLE]] = distinct !{![[DO_DISABLE]], ![[DISABLE]]}
// CHECK: ![[DO_ENABLE]] = distinct !{![[DO_ENABLE]], ![[ENABLE]]}
Loading
Loading