Skip to content

Commit

Permalink
Reapply "[OpenMP] Add the ompx_attribute clause for target directives"
Browse files Browse the repository at this point in the history
This reverts commit 0d12683 and
reapplies ef9ec4b with an extension to
fix the Flang build.

Differential Revision: https://reviews.llvm.org/D156184
  • Loading branch information
jdoerfert committed Jul 25, 2023
1 parent 76c22b1 commit 08a2207
Show file tree
Hide file tree
Showing 23 changed files with 490 additions and 105 deletions.
48 changes: 48 additions & 0 deletions clang/include/clang/AST/OpenMPClause.h
Original file line number Diff line number Diff line change
Expand Up @@ -9172,6 +9172,54 @@ class OMPDoacrossClause final
}
};

/// Models the 'ompx_attribute' clause, which may appear on a directive that
/// might generate an outlined function. For example:
///
/// \code
/// #pragma omp target [...] ompx_attribute(flatten)
/// \endcode
class OMPXAttributeClause
    : public OMPNoChildClause<llvm::omp::OMPC_ompx_attribute> {
  friend class OMPClauseReader;

  /// Source position of the opening parenthesis.
  SourceLocation LParenLoc;

  /// Attributes that were given as the arguments of the clause.
  SmallVector<const Attr *> Attrs;

  /// Drops the currently stored attributes and stores \p NewAttrs instead.
  /// Private; reachable only from within the class and from the friend
  /// OMPClauseReader.
  void setAttrs(ArrayRef<Attr *> NewAttrs) {
    Attrs.assign(NewAttrs.begin(), NewAttrs.end());
  }

public:
  /// Constructs an 'ompx_attribute' clause.
  ///
  /// \param Attrs The attributes parsed as the clause arguments.
  /// \param StartLoc Location where the clause begins.
  /// \param LParenLoc Location of the '('.
  /// \param EndLoc Location where the clause ends.
  OMPXAttributeClause(ArrayRef<const Attr *> Attrs, SourceLocation StartLoc,
                      SourceLocation LParenLoc, SourceLocation EndLoc)
      : OMPNoChildClause(StartLoc, EndLoc), LParenLoc(LParenLoc),
        Attrs(Attrs) {}

  /// Constructs a clause that holds no attributes.
  OMPXAttributeClause() : OMPNoChildClause() {}

  /// Returns the location of the '('.
  SourceLocation getLParenLoc() const { return LParenLoc; }

  /// Records \p Loc as the location of the '('.
  void setLParenLoc(SourceLocation Loc) { LParenLoc = Loc; }

  /// Returns the attributes stored in this clause.
  ArrayRef<const Attr *> getAttrs() const { return Attrs; }
};

} // namespace clang

#endif // LLVM_CLANG_AST_OPENMPCLAUSE_H
6 changes: 6 additions & 0 deletions clang/include/clang/AST/RecursiveASTVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -3875,6 +3875,12 @@ bool RecursiveASTVisitor<Derived>::VisitOMPDoacrossClause(
return true;
}

/// Visitor hook for 'ompx_attribute' clauses. Nothing reachable from \p C is
/// visited here; the function unconditionally returns true.
template <typename Derived>
bool RecursiveASTVisitor<Derived>::VisitOMPXAttributeClause(
OMPXAttributeClause *C) {
return true;
}

// FIXME: look at the following tricky-seeming exprs to see if we
// need to recurse on anything. These are ones that have methods
// returning decls or qualtypes or nestednamespecifier -- though I'm
Expand Down
3 changes: 2 additions & 1 deletion clang/include/clang/Basic/DiagnosticGroups.td
Original file line number Diff line number Diff line change
Expand Up @@ -1278,9 +1278,10 @@ def OpenMPMapping : DiagGroup<"openmp-mapping">;
def OpenMPTarget : DiagGroup<"openmp-target", [OpenMPMapping]>;
def OpenMPPre51Compat : DiagGroup<"pre-openmp-51-compat">;
def OpenMP51Ext : DiagGroup<"openmp-51-extensions">;
def OpenMPExtensions : DiagGroup<"openmp-extensions">;
def OpenMP : DiagGroup<"openmp", [
SourceUsesOpenMP, OpenMPClauses, OpenMPLoopForm, OpenMPTarget,
OpenMPMapping, OpenMP51Ext
OpenMPMapping, OpenMP51Ext, OpenMPExtensions
]>;

// Backend warnings.
Expand Down
3 changes: 3 additions & 0 deletions clang/include/clang/Basic/DiagnosticParseKinds.td
Original file line number Diff line number Diff line change
Expand Up @@ -1540,6 +1540,9 @@ def warn_omp_more_one_omp_all_memory : Warning<
InGroup<OpenMPClauses>;
def warn_omp_depend_in_ordered_deprecated : Warning<"'depend' clause for"
" 'ordered' is deprecated; use 'doacross' instead">, InGroup<Deprecated>;
// Warning in the OpenMPExtensions group for an 'ompx_attribute' clause argument
// that is not one of the attributes listed in the message text.
// NOTE(review): %0 is presumably the rejected attribute -- confirm against the
// call site that emits this diagnostic.
def warn_omp_invalid_attribute_for_ompx_attributes : Warning<"'ompx_attribute' clause only allows "
"'amdgpu_flat_work_group_size', 'amdgpu_waves_per_eu', and 'launch_bounds'; "
"%0 is ignored">, InGroup<OpenMPExtensions>;

// Pragma loop support.
def err_pragma_loop_missing_argument : Error<
Expand Down
7 changes: 7 additions & 0 deletions clang/include/clang/Parse/Parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -3490,6 +3490,13 @@ class Parser : public CodeCompletionHandler {
//
OMPClause *ParseOpenMPInteropClause(OpenMPClauseKind Kind, bool ParseOnly);

/// Parses an 'ompx_attribute' clause.
///
/// \param ParseOnly true to skip the clause's semantic actions and return
/// nullptr.
OMPClause *ParseOpenMPOMPXAttributesClause(bool ParseOnly);

public:
/// Parses simple expression in parens for single-expression clauses of OpenMP
/// constructs.
Expand Down
21 changes: 21 additions & 0 deletions clang/include/clang/Sema/Sema.h
Original file line number Diff line number Diff line change
Expand Up @@ -10988,6 +10988,11 @@ class Sema final {
bool ConstantFoldAttrArgs(const AttributeCommonInfo &CI,
MutableArrayRef<Expr *> Args);

/// Create a CUDALaunchBoundsAttr attribute from the \p MaxThreads and
/// \p MinBlocks expressions.
/// NOTE(review): code generation treats a null min-blocks expression as "not
/// specified"; presumably \p MinBlocks may be null here as well -- confirm
/// against the definition.
CUDALaunchBoundsAttr *CreateLaunchBoundsAttr(const AttributeCommonInfo &CI,
Expr *MaxThreads,
Expr *MinBlocks);

/// AddLaunchBoundsAttr - Adds a launch_bounds attribute to a particular
/// declaration.
void AddLaunchBoundsAttr(Decl *D, const AttributeCommonInfo &CI,
Expand All @@ -11004,11 +11009,21 @@ class Sema final {
void AddXConsumedAttr(Decl *D, const AttributeCommonInfo &CI,
RetainOwnershipKind K, bool IsTemplateInstantiation);

/// Create an AMDGPUFlatWorkGroupSizeAttr attribute from the \p Min and \p Max
/// expressions.
AMDGPUFlatWorkGroupSizeAttr *
CreateAMDGPUFlatWorkGroupSizeAttr(const AttributeCommonInfo &CI, Expr *Min,
Expr *Max);

/// addAMDGPUFlatWorkGroupSizeAttr - Adds an amdgpu_flat_work_group_size
/// attribute to a particular declaration.
void addAMDGPUFlatWorkGroupSizeAttr(Decl *D, const AttributeCommonInfo &CI,
Expr *Min, Expr *Max);

/// Create an AMDGPUWavesPerEUAttr attribute from the \p Min and \p Max
/// expressions.
/// NOTE(review): code generation checks the attribute's max expression for
/// null before evaluating it, so \p Max is presumably optional -- confirm
/// against the definition.
AMDGPUWavesPerEUAttr *
CreateAMDGPUWavesPerEUAttr(const AttributeCommonInfo &CI, Expr *Min,
Expr *Max);

/// addAMDGPUWavesPerEUAttr - Adds an amdgpu_waves_per_eu attribute to a
/// particular declaration.
void addAMDGPUWavesPerEUAttr(Decl *D, const AttributeCommonInfo &CI,
Expand Down Expand Up @@ -12341,6 +12356,12 @@ class Sema final {
ArrayRef<Expr *> VarList, SourceLocation StartLoc,
SourceLocation LParenLoc, SourceLocation EndLoc);

/// Called on a well-formed 'ompx_attribute' clause.
///
/// \param Attrs The attributes given as the clause arguments.
/// \param StartLoc Starting location of the clause.
/// \param LParenLoc Location of '('.
/// \param EndLoc Ending location of the clause.
OMPClause *ActOnOpenMPXAttributeClause(ArrayRef<const Attr *> Attrs,
SourceLocation StartLoc,
SourceLocation LParenLoc,
SourceLocation EndLoc);

/// The kind of conversion being performed.
enum CheckedConversionKind {
/// An implicit conversion.
Expand Down
12 changes: 12 additions & 0 deletions clang/lib/AST/OpenMPClause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2534,6 +2534,18 @@ void OMPClausePrinter::VisitOMPDoacrossClause(OMPDoacrossClause *Node) {
OS << ")";
}

/// Prints \p Node as "ompx_attribute(" followed by each stored attribute
/// pretty-printed with the printer's policy, separated by ", ", and a closing
/// ")".
void OMPClausePrinter::VisitOMPXAttributeClause(OMPXAttributeClause *Node) {
  OS << "ompx_attribute(";
  // Empty before the first attribute, ", " before every later one.
  const char *Separator = "";
  for (const Attr *A : Node->getAttrs()) {
    OS << Separator;
    A->printPretty(OS, Policy);
    Separator = ", ";
  }
  OS << ")";
}

void OMPTraitInfo::getAsVariantMatchInfo(ASTContext &ASTCtx,
VariantMatchInfo &VMI) const {
for (const OMPTraitSet &Set : Sets) {
Expand Down
2 changes: 2 additions & 0 deletions clang/lib/AST/StmtProfile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -928,6 +928,8 @@ void OMPClauseProfiler::VisitOMPXDynCGroupMemClause(
void OMPClauseProfiler::VisitOMPDoacrossClause(const OMPDoacrossClause *C) {
VisitOMPClauseList(C);
}
// Profiling hook for 'ompx_attribute' clauses. The body is empty, so nothing
// from the clause is added to the profile.
// NOTE(review): this means the clause's attributes do not contribute to the
// profile -- confirm that is intended.
void OMPClauseProfiler::VisitOMPXAttributeClause(const OMPXAttributeClause *C) {
}
} // namespace

void
Expand Down
19 changes: 17 additions & 2 deletions clang/lib/CodeGen/CGOpenMPRuntime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6110,8 +6110,23 @@ void CGOpenMPRuntime::emitTargetOutlinedFunctionHelper(
DefaultValTeams, DefaultValThreads,
IsOffloadEntry, OutlinedFn, OutlinedFnID);

if (OutlinedFn != nullptr)
CGM.getTargetCodeGenInfo().setTargetAttributes(nullptr, OutlinedFn, CGM);
if (!OutlinedFn)
return;

CGM.getTargetCodeGenInfo().setTargetAttributes(nullptr, OutlinedFn, CGM);

for (auto *C : D.getClausesOfKind<OMPXAttributeClause>()) {
for (auto *A : C->getAttrs()) {
if (auto *Attr = dyn_cast<CUDALaunchBoundsAttr>(A))
CGM.handleCUDALaunchBoundsAttr(OutlinedFn, Attr);
else if (auto *Attr = dyn_cast<AMDGPUFlatWorkGroupSizeAttr>(A))
CGM.handleAMDGPUFlatWorkGroupSizeAttr(OutlinedFn, Attr);
else if (auto *Attr = dyn_cast<AMDGPUWavesPerEUAttr>(A))
CGM.handleAMDGPUWavesPerEUAttr(OutlinedFn, Attr);
else
llvm_unreachable("Unexpected attribute kind");
}
}
}

/// Checks if the expression is constant or does not have non-trivial function
Expand Down
15 changes: 15 additions & 0 deletions clang/lib/CodeGen/CodeGenModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -1557,6 +1557,21 @@ class CodeGenModule : public CodeGenTypeCache {
/// because we'll lose all important information after each repl.
void moveLazyEmissionStates(CodeGenModule *NewBuilder);

/// Emit the IR encoding to attach the CUDA launch bounds attribute to \p F.
/// The definition dereferences \p A unconditionally, so it must not be null.
void handleCUDALaunchBoundsAttr(llvm::Function *F,
const CUDALaunchBoundsAttr *A);

/// Emit the IR encoding to attach the AMD GPU flat-work-group-size attribute
/// to \p F. \p A may be null. When \p A is null or yields zero for both
/// bounds, and \p ReqdWGS is non-null, the product of the X, Y and Z
/// dimensions of \p ReqdWGS is used as both the minimum and the maximum.
void handleAMDGPUFlatWorkGroupSizeAttr(
llvm::Function *F, const AMDGPUFlatWorkGroupSizeAttr *A,
const ReqdWorkGroupSizeAttr *ReqdWGS = nullptr);

/// Emit the IR encoding to attach the AMD GPU waves-per-eu attribute to \p F.
/// The definition dereferences \p A unconditionally, so it must not be null.
void handleAMDGPUWavesPerEUAttr(llvm::Function *F,
const AMDGPUWavesPerEUAttr *A);

private:
llvm::Constant *GetOrCreateLLVMFunction(
StringRef MangledName, llvm::Type *Ty, GlobalDecl D, bool ForVTable,
Expand Down
82 changes: 44 additions & 38 deletions clang/lib/CodeGen/Targets/AMDGPU.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -317,26 +317,7 @@ void AMDGPUTargetCodeGenInfo::setFunctionDeclAttributes(

const auto *FlatWGS = FD->getAttr<AMDGPUFlatWorkGroupSizeAttr>();
if (ReqdWGS || FlatWGS) {
unsigned Min = 0;
unsigned Max = 0;
if (FlatWGS) {
Min = FlatWGS->getMin()
->EvaluateKnownConstInt(M.getContext())
.getExtValue();
Max = FlatWGS->getMax()
->EvaluateKnownConstInt(M.getContext())
.getExtValue();
}
if (ReqdWGS && Min == 0 && Max == 0)
Min = Max = ReqdWGS->getXDim() * ReqdWGS->getYDim() * ReqdWGS->getZDim();

if (Min != 0) {
assert(Min <= Max && "Min must be less than or equal Max");

std::string AttrVal = llvm::utostr(Min) + "," + llvm::utostr(Max);
F->addFnAttr("amdgpu-flat-work-group-size", AttrVal);
} else
assert(Max == 0 && "Max must be zero");
M.handleAMDGPUFlatWorkGroupSizeAttr(F, FlatWGS, ReqdWGS);
} else if (IsOpenCLKernel || IsHIPKernel) {
// By default, restrict the maximum size to a value specified by
// --gpu-max-threads-per-block=n or its default value for HIP.
Expand All @@ -349,24 +330,8 @@ void AMDGPUTargetCodeGenInfo::setFunctionDeclAttributes(
F->addFnAttr("amdgpu-flat-work-group-size", AttrVal);
}

if (const auto *Attr = FD->getAttr<AMDGPUWavesPerEUAttr>()) {
unsigned Min =
Attr->getMin()->EvaluateKnownConstInt(M.getContext()).getExtValue();
unsigned Max = Attr->getMax() ? Attr->getMax()
->EvaluateKnownConstInt(M.getContext())
.getExtValue()
: 0;

if (Min != 0) {
assert((Max == 0 || Min <= Max) && "Min must be less than or equal Max");

std::string AttrVal = llvm::utostr(Min);
if (Max != 0)
AttrVal = AttrVal + "," + llvm::utostr(Max);
F->addFnAttr("amdgpu-waves-per-eu", AttrVal);
} else
assert(Max == 0 && "Max must be zero");
}
if (const auto *Attr = FD->getAttr<AMDGPUWavesPerEUAttr>())
M.handleAMDGPUWavesPerEUAttr(F, Attr);

if (const auto *Attr = FD->getAttr<AMDGPUNumSGPRAttr>()) {
unsigned NumSGPR = Attr->getNumSGPR();
Expand Down Expand Up @@ -595,6 +560,47 @@ llvm::Value *AMDGPUTargetCodeGenInfo::createEnqueuedBlockKernel(
return F;
}

/// Attaches the "amdgpu-flat-work-group-size" function attribute to \p F.
///
/// The bounds are the constant-evaluated min and max of \p FlatWGS when it is
/// non-null. If that leaves both bounds at zero and \p ReqdWGS is non-null,
/// the product of its X, Y and Z dimensions is used for both bounds instead.
/// No attribute is added when the resulting minimum is zero.
void CodeGenModule::handleAMDGPUFlatWorkGroupSizeAttr(
    llvm::Function *F, const AMDGPUFlatWorkGroupSizeAttr *FlatWGS,
    const ReqdWorkGroupSizeAttr *ReqdWGS) {
  unsigned Lower = 0;
  unsigned Upper = 0;
  if (FlatWGS) {
    Lower =
        FlatWGS->getMin()->EvaluateKnownConstInt(getContext()).getExtValue();
    Upper =
        FlatWGS->getMax()->EvaluateKnownConstInt(getContext()).getExtValue();
  }

  // Fall back to the required work group size only when the flat attribute
  // contributed no bounds at all.
  const bool NoFlatBounds = Lower == 0 && Upper == 0;
  if (ReqdWGS && NoFlatBounds)
    Lower = Upper =
        ReqdWGS->getXDim() * ReqdWGS->getYDim() * ReqdWGS->getZDim();

  if (Lower == 0) {
    assert(Upper == 0 && "Max must be zero");
    return;
  }

  assert(Lower <= Upper && "Min must be less than or equal Max");
  F->addFnAttr("amdgpu-flat-work-group-size",
               llvm::utostr(Lower) + "," + llvm::utostr(Upper));
}

/// Attaches the "amdgpu-waves-per-eu" function attribute to \p F.
///
/// The minimum is the constant-evaluated min expression of \p Attr. The
/// maximum is the constant-evaluated max expression, or zero when \p Attr has
/// no max expression. The attribute value is "<min>" or "<min>,<max>", with
/// the maximum included only when it is non-zero. No attribute is added when
/// the minimum is zero.
void CodeGenModule::handleAMDGPUWavesPerEUAttr(
    llvm::Function *F, const AMDGPUWavesPerEUAttr *Attr) {
  const unsigned Lower =
      Attr->getMin()->EvaluateKnownConstInt(getContext()).getExtValue();

  unsigned Upper = 0;
  if (const auto *MaxExpr = Attr->getMax())
    Upper = MaxExpr->EvaluateKnownConstInt(getContext()).getExtValue();

  if (Lower == 0) {
    assert(Upper == 0 && "Max must be zero");
    return;
  }

  assert((Upper == 0 || Lower <= Upper) &&
         "Min must be less than or equal Max");

  std::string AttrVal = llvm::utostr(Lower);
  if (Upper != 0) {
    AttrVal += ",";
    AttrVal += llvm::utostr(Upper);
  }
  F->addFnAttr("amdgpu-waves-per-eu", AttrVal);
}

std::unique_ptr<TargetCodeGenInfo>
CodeGen::createAMDGPUTargetCodeGenInfo(CodeGenModule &CGM) {
return std::make_unique<AMDGPUTargetCodeGenInfo>(CGM.getTypes());
Expand Down
44 changes: 25 additions & 19 deletions clang/lib/CodeGen/Targets/NVPTX.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,12 @@ class NVPTXTargetCodeGenInfo : public TargetCodeGenInfo {
return true;
}

private:
// Adds a NamedMDNode with GV, Name, and Operand as operands, and adds the
// resulting MDNode to the nvvm.annotations MDNode.
static void addNVVMMetadata(llvm::GlobalValue *GV, StringRef Name,
int Operand);

private:
static void emitBuiltinSurfTexDeviceCopy(CodeGenFunction &CGF, LValue Dst,
LValue Src) {
llvm::Value *Handle = nullptr;
Expand Down Expand Up @@ -256,24 +256,8 @@ void NVPTXTargetCodeGenInfo::setTargetAttributes(
// Create !{<func-ref>, metadata !"kernel", i32 1} node
addNVVMMetadata(F, "kernel", 1);
}
if (CUDALaunchBoundsAttr *Attr = FD->getAttr<CUDALaunchBoundsAttr>()) {
// Create !{<func-ref>, metadata !"maxntidx", i32 <val>} node
llvm::APSInt MaxThreads(32);
MaxThreads = Attr->getMaxThreads()->EvaluateKnownConstInt(M.getContext());
if (MaxThreads > 0)
addNVVMMetadata(F, "maxntidx", MaxThreads.getExtValue());

// min blocks is an optional argument for CUDALaunchBoundsAttr. If it was
// not specified in __launch_bounds__ or if the user specified a 0 value,
// we don't have to add a PTX directive.
if (Attr->getMinBlocks()) {
llvm::APSInt MinBlocks(32);
MinBlocks = Attr->getMinBlocks()->EvaluateKnownConstInt(M.getContext());
if (MinBlocks > 0)
// Create !{<func-ref>, metadata !"minctasm", i32 <val>} node
addNVVMMetadata(F, "minctasm", MinBlocks.getExtValue());
}
}
if (CUDALaunchBoundsAttr *Attr = FD->getAttr<CUDALaunchBoundsAttr>())
M.handleCUDALaunchBoundsAttr(F, Attr);
}

// Attach kernel metadata directly if compiling for NVPTX.
Expand Down Expand Up @@ -303,6 +287,28 @@ bool NVPTXTargetCodeGenInfo::shouldEmitStaticExternCAliases() const {
}
}

/// Records the launch bounds of \p Attr on \p F via
/// NVPTXTargetCodeGenInfo::addNVVMMetadata.
///
/// A "maxntidx" entry is added when the max-threads expression evaluates to a
/// value greater than zero. A "minctasm" entry is added when the min-blocks
/// expression is present and evaluates to a value greater than zero.
void CodeGenModule::handleCUDALaunchBoundsAttr(
    llvm::Function *F, const CUDALaunchBoundsAttr *Attr) {
  // Evaluates E and, if the result is positive, creates a
  // !{<func-ref>, metadata !"<Name>", i32 <val>} node.
  auto AnnotateIfPositive = [&](const auto *E, llvm::StringRef Name) {
    const llvm::APSInt Value = E->EvaluateKnownConstInt(getContext());
    if (Value > 0)
      NVPTXTargetCodeGenInfo::addNVVMMetadata(F, Name, Value.getExtValue());
  };

  AnnotateIfPositive(Attr->getMaxThreads(), "maxntidx");

  // min blocks is an optional argument for CUDALaunchBoundsAttr. If it was
  // not specified in __launch_bounds__ or if the user specified a 0 value,
  // we don't have to add a PTX directive.
  if (const auto *MinBlocksExpr = Attr->getMinBlocks())
    AnnotateIfPositive(MinBlocksExpr, "minctasm");
}

std::unique_ptr<TargetCodeGenInfo>
CodeGen::createNVPTXTargetCodeGenInfo(CodeGenModule &CGM) {
return std::make_unique<NVPTXTargetCodeGenInfo>(CGM.getTypes());
Expand Down

0 comments on commit 08a2207

Please sign in to comment.