Skip to content

Commit b0b6c16

Browse files
authored
[Clang][OpenMP][Tile] Allow non-constant tile sizes. (#91345)
Allow non-constants in the `sizes` clause such as ``` #pragma omp tile sizes(a) for (int i = 0; i < n; ++i) ``` This is permitted since tile was introduced in [OpenMP 5.1](https://www.openmp.org/spec-html/5.1/openmpsu53.html#x78-860002.11.9). It is possible to sneak-in negative numbers at runtime as in ``` int a = -1; #pragma omp tile sizes(a) ``` Even though it is not well-formed, it should still result in every loop iteration to be executed exactly once, an invariant of the tile construct that we should ensure. `ParseOpenMPExprListClause` is extracted-out to be reused by the `permutation` clause of the `interchange` construct. Some care was put into ensuring correct behavior in template contexts.
1 parent c4e9e41 commit b0b6c16

File tree

9 files changed

+748
-70
lines changed

9 files changed

+748
-70
lines changed

clang/include/clang/Parse/Parser.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3553,6 +3553,23 @@ class Parser : public CodeCompletionHandler {
35533553
OMPClause *ParseOpenMPVarListClause(OpenMPDirectiveKind DKind,
35543554
OpenMPClauseKind Kind, bool ParseOnly);
35553555

3556+
/// Parses a clause consisting of a list of expressions.
3557+
///
3558+
/// \param Kind The clause to parse.
3559+
/// \param ClauseNameLoc [out] The location of the clause name.
3560+
/// \param OpenLoc [out] The location of '('.
3561+
/// \param CloseLoc [out] The location of ')'.
3562+
/// \param Exprs [out] The parsed expressions.
3563+
/// \param ReqIntConst If true, each expression must be an integer constant.
3564+
///
3565+
/// \return Whether the clause was parsed successfully.
3566+
bool ParseOpenMPExprListClause(OpenMPClauseKind Kind,
3567+
SourceLocation &ClauseNameLoc,
3568+
SourceLocation &OpenLoc,
3569+
SourceLocation &CloseLoc,
3570+
SmallVectorImpl<Expr *> &Exprs,
3571+
bool ReqIntConst = false);
3572+
35563573
/// Parses and creates OpenMP 5.0 iterators expression:
35573574
/// <iterators> = 'iterator' '(' { [ <iterator-type> ] identifier =
35583575
/// <range-specification> }+ ')'

clang/lib/Parse/ParseOpenMP.cpp

Lines changed: 40 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -3107,34 +3107,14 @@ bool Parser::ParseOpenMPSimpleVarList(
31073107
}
31083108

31093109
OMPClause *Parser::ParseOpenMPSizesClause() {
3110-
SourceLocation ClauseNameLoc = ConsumeToken();
3110+
SourceLocation ClauseNameLoc, OpenLoc, CloseLoc;
31113111
SmallVector<Expr *, 4> ValExprs;
3112-
3113-
BalancedDelimiterTracker T(*this, tok::l_paren, tok::annot_pragma_openmp_end);
3114-
if (T.consumeOpen()) {
3115-
Diag(Tok, diag::err_expected) << tok::l_paren;
3112+
if (ParseOpenMPExprListClause(OMPC_sizes, ClauseNameLoc, OpenLoc, CloseLoc,
3113+
ValExprs))
31163114
return nullptr;
3117-
}
3118-
3119-
while (true) {
3120-
ExprResult Val = ParseConstantExpression();
3121-
if (!Val.isUsable()) {
3122-
T.skipToEnd();
3123-
return nullptr;
3124-
}
3125-
3126-
ValExprs.push_back(Val.get());
3127-
3128-
if (Tok.is(tok::r_paren) || Tok.is(tok::annot_pragma_openmp_end))
3129-
break;
3130-
3131-
ExpectAndConsume(tok::comma);
3132-
}
3133-
3134-
T.consumeClose();
31353115

3136-
return Actions.OpenMP().ActOnOpenMPSizesClause(
3137-
ValExprs, ClauseNameLoc, T.getOpenLocation(), T.getCloseLocation());
3116+
return Actions.OpenMP().ActOnOpenMPSizesClause(ValExprs, ClauseNameLoc,
3117+
OpenLoc, CloseLoc);
31383118
}
31393119

31403120
OMPClause *Parser::ParseOpenMPUsesAllocatorClause(OpenMPDirectiveKind DKind) {
@@ -4991,3 +4971,38 @@ OMPClause *Parser::ParseOpenMPVarListClause(OpenMPDirectiveKind DKind,
49914971
OMPVarListLocTy Locs(Loc, LOpen, Data.RLoc);
49924972
return Actions.OpenMP().ActOnOpenMPVarListClause(Kind, Vars, Locs, Data);
49934973
}
4974+
4975+
bool Parser::ParseOpenMPExprListClause(OpenMPClauseKind Kind,
4976+
SourceLocation &ClauseNameLoc,
4977+
SourceLocation &OpenLoc,
4978+
SourceLocation &CloseLoc,
4979+
SmallVectorImpl<Expr *> &Exprs,
4980+
bool ReqIntConst) {
4981+
assert(getOpenMPClauseName(Kind) == PP.getSpelling(Tok) &&
4982+
"Expected parsing to start at clause name");
4983+
ClauseNameLoc = ConsumeToken();
4984+
4985+
// Parse inside of '(' and ')'.
4986+
BalancedDelimiterTracker T(*this, tok::l_paren, tok::annot_pragma_openmp_end);
4987+
if (T.consumeOpen()) {
4988+
Diag(Tok, diag::err_expected) << tok::l_paren;
4989+
return true;
4990+
}
4991+
4992+
// Parse the list with interleaved commas.
4993+
do {
4994+
ExprResult Val =
4995+
ReqIntConst ? ParseConstantExpression() : ParseAssignmentExpression();
4996+
if (!Val.isUsable()) {
4997+
// Encountered something other than an expression; abort to ')'.
4998+
T.skipToEnd();
4999+
return true;
5000+
}
5001+
Exprs.push_back(Val.get());
5002+
} while (TryConsumeToken(tok::comma));
5003+
5004+
bool Result = T.consumeClose();
5005+
OpenLoc = T.getOpenLocation();
5006+
CloseLoc = T.getCloseLocation();
5007+
return Result;
5008+
}

clang/lib/Sema/SemaOpenMP.cpp

Lines changed: 95 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -15111,13 +15111,11 @@ StmtResult SemaOpenMP::ActOnOpenMPTileDirective(ArrayRef<OMPClause *> Clauses,
1511115111
ASTContext &Context = getASTContext();
1511215112
Scope *CurScope = SemaRef.getCurScope();
1511315113

15114-
auto SizesClauses =
15115-
OMPExecutableDirective::getClausesOfKind<OMPSizesClause>(Clauses);
15116-
if (SizesClauses.empty()) {
15117-
// A missing 'sizes' clause is already reported by the parser.
15114+
const auto *SizesClause =
15115+
OMPExecutableDirective::getSingleClause<OMPSizesClause>(Clauses);
15116+
if (!SizesClause ||
15117+
llvm::any_of(SizesClause->getSizesRefs(), [](Expr *E) { return !E; }))
1511815118
return StmtError();
15119-
}
15120-
const OMPSizesClause *SizesClause = *SizesClauses.begin();
1512115119
unsigned NumLoops = SizesClause->getNumSizes();
1512215120

1512315121
// Empty statement should only be possible if there already was an error.
@@ -15138,6 +15136,13 @@ StmtResult SemaOpenMP::ActOnOpenMPTileDirective(ArrayRef<OMPClause *> Clauses,
1513815136
return OMPTileDirective::Create(Context, StartLoc, EndLoc, Clauses,
1513915137
NumLoops, AStmt, nullptr, nullptr);
1514015138

15139+
assert(LoopHelpers.size() == NumLoops &&
15140+
"Expecting loop iteration space dimensionality to match number of "
15141+
"affected loops");
15142+
assert(OriginalInits.size() == NumLoops &&
15143+
"Expecting loop iteration space dimensionality to match number of "
15144+
"affected loops");
15145+
1514115146
SmallVector<Decl *, 4> PreInits;
1514215147
CaptureVars CopyTransformer(SemaRef);
1514315148

@@ -15197,6 +15202,44 @@ StmtResult SemaOpenMP::ActOnOpenMPTileDirective(ArrayRef<OMPClause *> Clauses,
1519715202
// Once the original iteration values are set, append the innermost body.
1519815203
Stmt *Inner = Body;
1519915204

15205+
auto MakeDimTileSize = [&SemaRef = this->SemaRef, &CopyTransformer, &Context,
15206+
SizesClause, CurScope](int I) -> Expr * {
15207+
Expr *DimTileSizeExpr = SizesClause->getSizesRefs()[I];
15208+
if (isa<ConstantExpr>(DimTileSizeExpr))
15209+
return AssertSuccess(CopyTransformer.TransformExpr(DimTileSizeExpr));
15210+
15211+
// When the tile size is not a constant but a variable, it is possible to
15212+
// pass non-positive numbers. For instance:
15213+
// \code{c}
15214+
// int a = 0;
15215+
// #pragma omp tile sizes(a)
15216+
// for (int i = 0; i < 42; ++i)
15217+
// body(i);
15218+
// \endcode
15219+
// Although there is no meaningful interpretation of the tile size, the body
15220+
// should still be executed 42 times to avoid surprises. To preserve the
15221+
// invariant that every loop iteration is executed exactly once and not
15222+
// cause an infinite loop, apply a minimum tile size of one.
15223+
// Build expr:
15224+
// \code{c}
15225+
// (TS <= 0) ? 1 : TS
15226+
// \endcode
15227+
QualType DimTy = DimTileSizeExpr->getType();
15228+
uint64_t DimWidth = Context.getTypeSize(DimTy);
15229+
IntegerLiteral *Zero = IntegerLiteral::Create(
15230+
Context, llvm::APInt::getZero(DimWidth), DimTy, {});
15231+
IntegerLiteral *One =
15232+
IntegerLiteral::Create(Context, llvm::APInt(DimWidth, 1), DimTy, {});
15233+
Expr *Cond = AssertSuccess(SemaRef.BuildBinOp(
15234+
CurScope, {}, BO_LE,
15235+
AssertSuccess(CopyTransformer.TransformExpr(DimTileSizeExpr)), Zero));
15236+
Expr *MinOne = new (Context) ConditionalOperator(
15237+
Cond, {}, One, {},
15238+
AssertSuccess(CopyTransformer.TransformExpr(DimTileSizeExpr)), DimTy,
15239+
VK_PRValue, OK_Ordinary);
15240+
return MinOne;
15241+
};
15242+
1520015243
// Create tile loops from the inside to the outside.
1520115244
for (int I = NumLoops - 1; I >= 0; --I) {
1520215245
OMPLoopBasedDirective::HelperExprs &LoopHelper = LoopHelpers[I];
@@ -15207,10 +15250,6 @@ StmtResult SemaOpenMP::ActOnOpenMPTileDirective(ArrayRef<OMPClause *> Clauses,
1520715250
// Commonly used variables. One of the constraints of an AST is that every
1520815251
// node object must appear at most once, hence we define lamdas that create
1520915252
// a new AST node at every use.
15210-
auto MakeDimTileSize = [&CopyTransformer, I, SizesClause]() -> Expr * {
15211-
Expr *DimTileSize = SizesClause->getSizesRefs()[I];
15212-
return AssertSuccess(CopyTransformer.TransformExpr(DimTileSize));
15213-
};
1521415253
auto MakeTileIVRef = [&SemaRef = this->SemaRef, &TileIndVars, I, CntTy,
1521515254
OrigCntVar]() {
1521615255
return buildDeclRefExpr(SemaRef, TileIndVars[I], CntTy,
@@ -15237,7 +15276,7 @@ StmtResult SemaOpenMP::ActOnOpenMPTileDirective(ArrayRef<OMPClause *> Clauses,
1523715276
// .tile.iv < min(.floor.iv + DimTileSize, NumIterations)
1523815277
ExprResult EndOfTile =
1523915278
SemaRef.BuildBinOp(CurScope, LoopHelper.Cond->getExprLoc(), BO_Add,
15240-
MakeFloorIVRef(), MakeDimTileSize());
15279+
MakeFloorIVRef(), MakeDimTileSize(I));
1524115280
if (!EndOfTile.isUsable())
1524215281
return StmtError();
1524315282
ExprResult IsPartialTile =
@@ -15297,10 +15336,6 @@ StmtResult SemaOpenMP::ActOnOpenMPTileDirective(ArrayRef<OMPClause *> Clauses,
1529715336
QualType CntTy = OrigCntVar->getType();
1529815337

1529915338
// Commonly used variables.
15300-
auto MakeDimTileSize = [&CopyTransformer, I, SizesClause]() -> Expr * {
15301-
Expr *DimTileSize = SizesClause->getSizesRefs()[I];
15302-
return AssertSuccess(CopyTransformer.TransformExpr(DimTileSize));
15303-
};
1530415339
auto MakeFloorIVRef = [&SemaRef = this->SemaRef, &FloorIndVars, I, CntTy,
1530515340
OrigCntVar]() {
1530615341
return buildDeclRefExpr(SemaRef, FloorIndVars[I], CntTy,
@@ -15329,7 +15364,7 @@ StmtResult SemaOpenMP::ActOnOpenMPTileDirective(ArrayRef<OMPClause *> Clauses,
1532915364
// For incr-statement: .floor.iv += DimTileSize
1533015365
ExprResult IncrStmt =
1533115366
SemaRef.BuildBinOp(CurScope, LoopHelper.Inc->getExprLoc(), BO_AddAssign,
15332-
MakeFloorIVRef(), MakeDimTileSize());
15367+
MakeFloorIVRef(), MakeDimTileSize(I));
1533315368
if (!IncrStmt.isUsable())
1533415369
return StmtError();
1533515370

@@ -17430,16 +17465,53 @@ OMPClause *SemaOpenMP::ActOnOpenMPSizesClause(ArrayRef<Expr *> SizeExprs,
1743017465
SourceLocation StartLoc,
1743117466
SourceLocation LParenLoc,
1743217467
SourceLocation EndLoc) {
17433-
for (Expr *SizeExpr : SizeExprs) {
17434-
ExprResult NumForLoopsResult = VerifyPositiveIntegerConstantInClause(
17435-
SizeExpr, OMPC_sizes, /*StrictlyPositive=*/true);
17436-
if (!NumForLoopsResult.isUsable())
17437-
return nullptr;
17468+
SmallVector<Expr *> SanitizedSizeExprs(SizeExprs);
17469+
17470+
for (Expr *&SizeExpr : SanitizedSizeExprs) {
17471+
// Skip if already sanitized, e.g. during a partial template instantiation.
17472+
if (!SizeExpr)
17473+
continue;
17474+
17475+
bool IsValid = isNonNegativeIntegerValue(SizeExpr, SemaRef, OMPC_sizes,
17476+
/*StrictlyPositive=*/true);
17477+
17478+
// isNonNegativeIntegerValue returns true for non-integral types (but still
17479+
// emits error diagnostic), so check for the expected type explicitly.
17480+
QualType SizeTy = SizeExpr->getType();
17481+
if (!SizeTy->isIntegerType())
17482+
IsValid = false;
17483+
17484+
// Handling in templates is tricky. There are four possibilities to
17485+
// consider:
17486+
//
17487+
// 1a. The expression is valid and we are in a instantiated template or not
17488+
// in a template:
17489+
// Pass valid expression to be further analysed later in Sema.
17490+
// 1b. The expression is valid and we are in a template (including partial
17491+
// instantiation):
17492+
// isNonNegativeIntegerValue skipped any checks so there is no
17493+
// guarantee it will be correct after instantiation.
17494+
// ActOnOpenMPSizesClause will be called again at instantiation when
17495+
// it is not in a dependent context anymore. This may cause warnings
17496+
// to be emitted multiple times.
17497+
// 2a. The expression is invalid and we are in an instantiated template or
17498+
// not in a template:
17499+
// Invalidate the expression with a clearly wrong value (nullptr) so
17500+
// later in Sema we do not have to do the same validity analysis again
17501+
// or crash from unexpected data. Error diagnostics have already been
17502+
// emitted.
17503+
// 2b. The expression is invalid and we are in a template (including partial
17504+
// instantiation):
17505+
// Pass the invalid expression as-is, template instantiation may
17506+
// replace unexpected types/values with valid ones. The directives
17507+
// with this clause must not try to use these expressions in dependent
17508+
// contexts, but delay analysis until full instantiation.
17509+
if (!SizeExpr->isInstantiationDependent() && !IsValid)
17510+
SizeExpr = nullptr;
1743817511
}
1743917512

17440-
DSAStack->setAssociatedLoops(SizeExprs.size());
1744117513
return OMPSizesClause::Create(getASTContext(), StartLoc, LParenLoc, EndLoc,
17442-
SizeExprs);
17514+
SanitizedSizeExprs);
1744317515
}
1744417516

1744517517
OMPClause *SemaOpenMP::ActOnOpenMPFullClause(SourceLocation StartLoc,

clang/test/OpenMP/tile_ast_print.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,4 +183,21 @@ void tfoo7() {
183183
}
184184

185185

186+
// PRINT-LABEL: void foo8(
187+
// DUMP-LABEL: FunctionDecl {{.*}} foo8
188+
void foo8(int a) {
189+
// PRINT: #pragma omp tile sizes(a)
190+
// DUMP: OMPTileDirective
191+
// DUMP-NEXT: OMPSizesClause
192+
// DUMP-NEXT: ImplicitCastExpr
193+
// DUMP-NEXT: DeclRefExpr {{.*}} 'a'
194+
#pragma omp tile sizes(a)
195+
// PRINT-NEXT: for (int i = 7; i < 19; i += 3)
196+
// DUMP-NEXT: ForStmt
197+
for (int i = 7; i < 19; i += 3)
198+
// PRINT: body(i);
199+
// DUMP: CallExpr
200+
body(i);
201+
}
202+
186203
#endif

0 commit comments

Comments
 (0)