Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 36 additions & 1 deletion mlir/include/mlir/Dialect/OpenACC/Analysis/OpenACCSupport.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,10 @@
namespace mlir {
namespace acc {

// Forward declaration for RecipeKind enum
// Forward declarations
enum class RecipeKind : uint32_t;
bool isValidSymbolUse(Operation *user, SymbolRefAttr symbol,
Operation **definingOpPtr);

namespace detail {
/// This class contains internal trait classes used by OpenACCSupport.
Expand All @@ -79,11 +81,27 @@ struct OpenACCSupportTraits {

// Used to report a case that is not supported by the implementation.
virtual InFlightDiagnostic emitNYI(Location loc, const Twine &message) = 0;

/// Check if a symbol use is valid for use in an OpenACC region.
virtual bool isValidSymbolUse(Operation *user, SymbolRefAttr symbol,
Operation **definingOpPtr) = 0;
};

/// SFINAE helpers to detect if implementation has optional methods
template <typename ImplT, typename... Args>
using isValidSymbolUse_t =
decltype(std::declval<ImplT>().isValidSymbolUse(std::declval<Args>()...));

template <typename ImplT>
using has_isValidSymbolUse =
llvm::is_detected<isValidSymbolUse_t, ImplT, Operation *, SymbolRefAttr,
Operation **>;

/// This class wraps a concrete OpenACCSupport implementation and forwards
/// interface calls to it. This provides type erasure, allowing different
/// implementation types to be used interchangeably without inheritance.
/// Methods can be optionally implemented; if not present, default behavior
/// is used.
template <typename ImplT>
class Model final : public Concept {
public:
Expand All @@ -102,6 +120,14 @@ struct OpenACCSupportTraits {
return impl.emitNYI(loc, message);
}

bool isValidSymbolUse(Operation *user, SymbolRefAttr symbol,
Operation **definingOpPtr) final {
if constexpr (has_isValidSymbolUse<ImplT>::value)
return impl.isValidSymbolUse(user, symbol, definingOpPtr);
else
return acc::isValidSymbolUse(user, symbol, definingOpPtr);
}

private:
ImplT impl;
};
Expand Down Expand Up @@ -154,6 +180,15 @@ class OpenACCSupport {
/// unsupported case.
InFlightDiagnostic emitNYI(Location loc, const Twine &message);

/// Check if a symbol use is valid for use in an OpenACC region.
///
/// \param user The operation using the symbol.
/// \param symbol The symbol reference being used.
/// \param definingOpPtr Optional output parameter to receive the defining op.
/// \return true if the symbol use is valid, false otherwise.
bool isValidSymbolUse(Operation *user, SymbolRefAttr symbol,
Operation **definingOpPtr = nullptr);

/// Signal that this analysis should always be preserved so that
/// underlying implementation registration is not lost.
bool isInvalidated(const AnalysisManager::PreservedAnalyses &pa) {
Expand Down
10 changes: 10 additions & 0 deletions mlir/include/mlir/Dialect/OpenACC/OpenACCUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,16 @@ std::string getRecipeName(mlir::acc::RecipeKind kind, mlir::Type type);
// base `array` from an operation that only accesses a subarray.
mlir::Value getBaseEntity(mlir::Value val);

/// Check if a symbol use is valid for use in an OpenACC region.
/// This includes looking for various attributes such as `acc.routine_info`
/// and `acc.declare` attributes.
/// \param user The operation using the symbol
/// \param symbol The symbol reference being used
/// \param definingOpPtr Optional output parameter to receive the defining op
/// \return true if the symbol use is valid, false otherwise
bool isValidSymbolUse(mlir::Operation *user, mlir::SymbolRefAttr symbol,
mlir::Operation **definingOpPtr = nullptr);

} // namespace acc
} // namespace mlir

Expand Down
7 changes: 7 additions & 0 deletions mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,5 +41,12 @@ InFlightDiagnostic OpenACCSupport::emitNYI(Location loc, const Twine &message) {
return mlir::emitError(loc, "not yet implemented: " + message);
}

bool OpenACCSupport::isValidSymbolUse(Operation *user, SymbolRefAttr symbol,
Operation **definingOpPtr) {
if (impl)
return impl->isValidSymbolUse(user, symbol, definingOpPtr);
return acc::isValidSymbolUse(user, symbol, definingOpPtr);
}

} // namespace acc
} // namespace mlir
50 changes: 50 additions & 0 deletions mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,11 @@
#include "mlir/Dialect/OpenACC/OpenACCUtils.h"

#include "mlir/Dialect/OpenACC/OpenACC.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/Support/Casting.h"

mlir::Operation *mlir::acc::getEnclosingComputeOp(mlir::Region &region) {
Expand Down Expand Up @@ -155,3 +158,50 @@ mlir::Value mlir::acc::getBaseEntity(mlir::Value val) {

return val;
}

bool mlir::acc::isValidSymbolUse(mlir::Operation *user,
mlir::SymbolRefAttr symbol,
mlir::Operation **definingOpPtr) {
mlir::Operation *definingOp =
mlir::SymbolTable::lookupNearestSymbolFrom(user, symbol);

// If there are no defining ops, we have no way to ensure validity because
// we cannot check for any attributes.
if (!definingOp)
return false;

if (definingOpPtr)
*definingOpPtr = definingOp;

// Check if the defining op is a recipe (private, reduction, firstprivate).
// Recipes are valid as they get materialized before being offloaded to
// device. They are only instructions for how to materialize.
if (mlir::isa<mlir::acc::PrivateRecipeOp, mlir::acc::ReductionRecipeOp,
mlir::acc::FirstprivateRecipeOp>(definingOp))
return true;

// Check if the defining op is a function
if (auto func =
mlir::dyn_cast_if_present<mlir::FunctionOpInterface>(definingOp)) {
// If this symbol is actually an acc routine - then it is expected for it
// to be offloaded - therefore it is valid.
if (func->hasAttr(mlir::acc::getRoutineInfoAttrName()))
return true;

// If this symbol is a call to an LLVM intrinsic, then it is likely valid.
// Check the following:
// 1. The function is private
// 2. The function has no body
// 3. Name starts with "llvm."
// 4. The function's name is a valid LLVM intrinsic name
if (func.getVisibility() == mlir::SymbolTable::Visibility::Private &&
func.getFunctionBody().empty() && func.getName().starts_with("llvm.") &&
llvm::Intrinsic::lookupIntrinsicID(func.getName()) !=
llvm::Intrinsic::not_intrinsic)
return true;
}

// A declare attribute is needed for symbol references.
bool hasDeclare = definingOp->hasAttr(mlir::acc::getDeclareAttrName());
return hasDeclare;
}
Loading
Loading