diff --git a/mlir/include/mlir/Dialect/OpenACC/Analysis/OpenACCSupport.h b/mlir/include/mlir/Dialect/OpenACC/Analysis/OpenACCSupport.h index d9b2646b753f3..7be525e87a695 100644 --- a/mlir/include/mlir/Dialect/OpenACC/Analysis/OpenACCSupport.h +++ b/mlir/include/mlir/Dialect/OpenACC/Analysis/OpenACCSupport.h @@ -58,8 +58,10 @@ namespace mlir { namespace acc { -// Forward declaration for RecipeKind enum +// Forward declarations enum class RecipeKind : uint32_t; +bool isValidSymbolUse(Operation *user, SymbolRefAttr symbol, + Operation **definingOpPtr); namespace detail { /// This class contains internal trait classes used by OpenACCSupport. @@ -79,11 +81,27 @@ struct OpenACCSupportTraits { // Used to report a case that is not supported by the implementation. virtual InFlightDiagnostic emitNYI(Location loc, const Twine &message) = 0; + + /// Check if a symbol use is valid for use in an OpenACC region. + virtual bool isValidSymbolUse(Operation *user, SymbolRefAttr symbol, + Operation **definingOpPtr) = 0; }; + /// SFINAE helpers to detect if implementation has optional methods + template + using isValidSymbolUse_t = + decltype(std::declval().isValidSymbolUse(std::declval()...)); + + template + using has_isValidSymbolUse = + llvm::is_detected; + /// This class wraps a concrete OpenACCSupport implementation and forwards /// interface calls to it. This provides type erasure, allowing different /// implementation types to be used interchangeably without inheritance. + /// Methods can be optionally implemented; if not present, default behavior + /// is used. template class Model final : public Concept { public: @@ -102,6 +120,14 @@ struct OpenACCSupportTraits { return impl.emitNYI(loc, message); } + bool isValidSymbolUse(Operation *user, SymbolRefAttr symbol, + Operation **definingOpPtr) final { + if constexpr (has_isValidSymbolUse::value) + return impl.isValidSymbolUse(user, symbol, definingOpPtr); + else + return acc::isValidSymbolUse(user, symbol, definingOpPtr); + } + private: ImplT impl; }; @@ -154,6 +180,15 @@ class OpenACCSupport { /// unsupported case. InFlightDiagnostic emitNYI(Location loc, const Twine &message); + /// Check if a symbol use is valid for use in an OpenACC region. + /// + /// \param user The operation using the symbol. + /// \param symbol The symbol reference being used. + /// \param definingOpPtr Optional output parameter to receive the defining op. + /// \return true if the symbol use is valid, false otherwise. + bool isValidSymbolUse(Operation *user, SymbolRefAttr symbol, + Operation **definingOpPtr = nullptr); + /// Signal that this analysis should always be preserved so that /// underlying implementation registration is not lost. bool isInvalidated(const AnalysisManager::PreservedAnalyses &pa) { diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCUtils.h b/mlir/include/mlir/Dialect/OpenACC/OpenACCUtils.h index 964735755c4a3..2852e0917c3fb 100644 --- a/mlir/include/mlir/Dialect/OpenACC/OpenACCUtils.h +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCUtils.h @@ -52,6 +52,16 @@ std::string getRecipeName(mlir::acc::RecipeKind kind, mlir::Type type); // base `array` from an operation that only accesses a subarray. mlir::Value getBaseEntity(mlir::Value val); +/// Check if a symbol use is valid for use in an OpenACC region. +/// This includes looking for various attributes such as `acc.routine_info` +/// and `acc.declare` attributes. +/// \param user The operation using the symbol +/// \param symbol The symbol reference being used +/// \param definingOpPtr Optional output parameter to receive the defining op +/// \return true if the symbol use is valid, false otherwise +bool isValidSymbolUse(mlir::Operation *user, mlir::SymbolRefAttr symbol, + mlir::Operation **definingOpPtr = nullptr); + } // namespace acc } // namespace mlir diff --git a/mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp b/mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp index 40e769e7068cf..1d775fb975738 100644 --- a/mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp +++ b/mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp @@ -41,5 +41,12 @@ InFlightDiagnostic OpenACCSupport::emitNYI(Location loc, const Twine &message) { return mlir::emitError(loc, "not yet implemented: " + message); } +bool OpenACCSupport::isValidSymbolUse(Operation *user, SymbolRefAttr symbol, + Operation **definingOpPtr) { + if (impl) + return impl->isValidSymbolUse(user, symbol, definingOpPtr); + return acc::isValidSymbolUse(user, symbol, definingOpPtr); +} + } // namespace acc } // namespace mlir diff --git a/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp b/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp index fbac28e740750..aebc248e02ea0 100644 --- a/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp +++ b/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp @@ -9,8 +9,11 @@ #include "mlir/Dialect/OpenACC/OpenACCUtils.h" #include "mlir/Dialect/OpenACC/OpenACC.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Interfaces/ViewLikeInterface.h" #include "llvm/ADT/TypeSwitch.h" +#include "llvm/IR/Intrinsics.h" #include "llvm/Support/Casting.h" mlir::Operation *mlir::acc::getEnclosingComputeOp(mlir::Region ®ion) { @@ -155,3 +158,50 @@ mlir::Value mlir::acc::getBaseEntity(mlir::Value val) { return val; } + +bool mlir::acc::isValidSymbolUse(mlir::Operation *user, + mlir::SymbolRefAttr symbol, + mlir::Operation **definingOpPtr) { + mlir::Operation *definingOp = + mlir::SymbolTable::lookupNearestSymbolFrom(user, symbol); + + // If there are no defining ops, we have no way to ensure validity because + // we cannot check for any attributes. + if (!definingOp) + return false; + + if (definingOpPtr) + *definingOpPtr = definingOp; + + // Check if the defining op is a recipe (private, reduction, firstprivate). + // Recipes are valid as they get materialized before being offloaded to + // device. They are only instructions for how to materialize. + if (mlir::isa(definingOp)) + return true; + + // Check if the defining op is a function + if (auto func = + mlir::dyn_cast_if_present(definingOp)) { + // If this symbol is actually an acc routine - then it is expected for it + // to be offloaded - therefore it is valid. + if (func->hasAttr(mlir::acc::getRoutineInfoAttrName())) + return true; + + // If this symbol is a call to an LLVM intrinsic, then it is likely valid. + // Check the following: + // 1. The function is private + // 2. The function has no body + // 3. Name starts with "llvm." + // 4. The function's name is a valid LLVM intrinsic name + if (func.getVisibility() == mlir::SymbolTable::Visibility::Private && + func.getFunctionBody().empty() && func.getName().starts_with("llvm.") && + llvm::Intrinsic::lookupIntrinsicID(func.getName()) != + llvm::Intrinsic::not_intrinsic) + return true; + } + + // A declare attribute is needed for symbol references. + bool hasDeclare = definingOp->hasAttr(mlir::acc::getDeclareAttrName()); + return hasDeclare; +} diff --git a/mlir/unittests/Dialect/OpenACC/OpenACCUtilsTest.cpp b/mlir/unittests/Dialect/OpenACC/OpenACCUtilsTest.cpp index 6f4e30585b2c9..8b1f532bbe5c0 100644 --- a/mlir/unittests/Dialect/OpenACC/OpenACCUtilsTest.cpp +++ b/mlir/unittests/Dialect/OpenACC/OpenACCUtilsTest.cpp @@ -674,3 +674,242 @@ TEST_F(OpenACCUtilsTest, getBaseEntityChainedSubviews) { Value ultimateBase = getBaseEntity(baseEntity); EXPECT_EQ(ultimateBase, baseMemref); } + +//===----------------------------------------------------------------------===// +// isValidSymbolUse Tests +//===----------------------------------------------------------------------===// + +TEST_F(OpenACCUtilsTest, isValidSymbolUseNoDefiningOp) { + // Create a memref.get_global that references a non-existent global + auto memrefType = MemRefType::get({10}, b.getI32Type()); + llvm::StringRef globalName = "nonexistent_global"; + SymbolRefAttr nonExistentSymbol = SymbolRefAttr::get(&context, globalName); + + OwningOpRef getGlobalOp = + memref::GetGlobalOp::create(b, loc, memrefType, globalName); + + Operation *definingOp = nullptr; + bool result = + isValidSymbolUse(getGlobalOp.get(), nonExistentSymbol, &definingOp); + + EXPECT_FALSE(result); + EXPECT_EQ(definingOp, nullptr); +} + +TEST_F(OpenACCUtilsTest, isValidSymbolUseRecipe) { + // Create a module to hold the recipe + OwningOpRef module = ModuleOp::create(loc); + Block *moduleBlock = module->getBody(); + + OpBuilder::InsertionGuard guard(b); + b.setInsertionPointToStart(moduleBlock); + + // Create a private recipe (any recipe type would work) + auto i32Type = b.getI32Type(); + llvm::StringRef recipeName = "test_recipe"; + OwningOpRef recipeOp = + PrivateRecipeOp::create(b, loc, recipeName, i32Type); + + // Create a value to privatize + auto memrefTy = MemRefType::get({10}, b.getI32Type()); + OwningOpRef allocOp = + memref::AllocaOp::create(b, loc, memrefTy); + TypedValue varPtr = + cast>(allocOp->getResult()); + + // Create a private op as the user operation + OwningOpRef privateOp = PrivateOp::create( + b, loc, varPtr, /*structured=*/true, /*implicit=*/false); + + // Create a symbol reference to the recipe + SymbolRefAttr recipeSymbol = SymbolRefAttr::get(&context, recipeName); + + Operation *definingOp = nullptr; + bool result = isValidSymbolUse(privateOp.get(), recipeSymbol, &definingOp); + + EXPECT_TRUE(result); + EXPECT_EQ(definingOp, recipeOp.get()); +} + +TEST_F(OpenACCUtilsTest, isValidSymbolUseFunctionWithRoutineInfo) { + // Create a module to hold the function + OwningOpRef module = ModuleOp::create(loc); + Block *moduleBlock = module->getBody(); + + OpBuilder::InsertionGuard guard(b); + b.setInsertionPointToStart(moduleBlock); + + // Create a function with routine_info attribute + auto funcType = b.getFunctionType({}, {}); + llvm::StringRef funcName = "routine_func"; + OwningOpRef funcOp = + func::FuncOp::create(b, loc, funcName, funcType); + + // Add routine_info attribute with a reference to a routine + SmallVector routineRefs = { + SymbolRefAttr::get(&context, "acc_routine")}; + funcOp.get()->setAttr(getRoutineInfoAttrName(), + RoutineInfoAttr::get(&context, routineRefs)); + + // Create a call operation that uses the function symbol + SymbolRefAttr funcSymbol = SymbolRefAttr::get(&context, funcName); + OwningOpRef callOp = func::CallOp::create( + b, loc, funcSymbol, funcType.getResults(), ValueRange{}); + + Operation *definingOp = nullptr; + bool result = isValidSymbolUse(callOp.get(), funcSymbol, &definingOp); + + EXPECT_TRUE(result); + EXPECT_NE(definingOp, nullptr); +} + +TEST_F(OpenACCUtilsTest, isValidSymbolUseLLVMIntrinsic) { + // Create a module to hold the function + OwningOpRef module = ModuleOp::create(loc); + Block *moduleBlock = module->getBody(); + + OpBuilder::InsertionGuard guard(b); + b.setInsertionPointToStart(moduleBlock); + + // Create a private function with LLVM intrinsic name + auto funcType = b.getFunctionType({b.getF32Type()}, {b.getF32Type()}); + llvm::StringRef intrinsicName = "llvm.sqrt.f32"; + OwningOpRef funcOp = + func::FuncOp::create(b, loc, intrinsicName, funcType); + + // Set visibility to private (required for intrinsics) + funcOp->setPrivate(); + + // Create a call operation that uses the intrinsic + SymbolRefAttr funcSymbol = SymbolRefAttr::get(&context, intrinsicName); + OwningOpRef callOp = func::CallOp::create( + b, loc, funcSymbol, funcType.getResults(), ValueRange{}); + + Operation *definingOp = nullptr; + bool result = isValidSymbolUse(callOp.get(), funcSymbol, &definingOp); + + EXPECT_TRUE(result); + EXPECT_NE(definingOp, nullptr); +} + +TEST_F(OpenACCUtilsTest, isValidSymbolUseFunctionNotIntrinsic) { + // Create a module to hold the function + OwningOpRef module = ModuleOp::create(loc); + Block *moduleBlock = module->getBody(); + + OpBuilder::InsertionGuard guard(b); + b.setInsertionPointToStart(moduleBlock); + + // Create a private function that looks like intrinsic but isn't + auto funcType = b.getFunctionType({}, {}); + llvm::StringRef funcName = "llvm.not_a_real_intrinsic"; + OwningOpRef funcOp = + func::FuncOp::create(b, loc, funcName, funcType); + funcOp->setPrivate(); + + // Create a call operation that uses the function + SymbolRefAttr funcSymbol = SymbolRefAttr::get(&context, funcName); + OwningOpRef callOp = func::CallOp::create( + b, loc, funcSymbol, funcType.getResults(), ValueRange{}); + + Operation *definingOp = nullptr; + bool result = isValidSymbolUse(callOp.get(), funcSymbol, &definingOp); + + // Should be false because it's not a valid intrinsic and has no + // acc.routine_info attr + EXPECT_FALSE(result); + EXPECT_NE(definingOp, nullptr); +} + +TEST_F(OpenACCUtilsTest, isValidSymbolUseWithDeclareAttr) { + // Create a module to hold a function + OwningOpRef module = ModuleOp::create(loc); + Block *moduleBlock = module->getBody(); + + OpBuilder::InsertionGuard guard(b); + b.setInsertionPointToStart(moduleBlock); + + // Create a function with declare attribute + auto funcType = b.getFunctionType({}, {}); + llvm::StringRef funcName = "declared_func"; + OwningOpRef funcOp = + func::FuncOp::create(b, loc, funcName, funcType); + + // Add declare attribute + funcOp.get()->setAttr( + getDeclareAttrName(), + DeclareAttr::get(&context, + DataClauseAttr::get(&context, DataClause::acc_copy))); + + // Create a call operation that uses the function + SymbolRefAttr funcSymbol = SymbolRefAttr::get(&context, funcName); + OwningOpRef callOp = func::CallOp::create( + b, loc, funcSymbol, funcType.getResults(), ValueRange{}); + + Operation *definingOp = nullptr; + bool result = isValidSymbolUse(callOp.get(), funcSymbol, &definingOp); + + EXPECT_TRUE(result); + EXPECT_NE(definingOp, nullptr); +} + +TEST_F(OpenACCUtilsTest, isValidSymbolUseWithoutValidAttributes) { + // Create a module to hold a function + OwningOpRef module = ModuleOp::create(loc); + Block *moduleBlock = module->getBody(); + + OpBuilder::InsertionGuard guard(b); + b.setInsertionPointToStart(moduleBlock); + + // Create a function without any special attributes + auto funcType = b.getFunctionType({}, {}); + llvm::StringRef funcName = "regular_func"; + OwningOpRef funcOp = + func::FuncOp::create(b, loc, funcName, funcType); + + // Create a call operation that uses the function + SymbolRefAttr funcSymbol = SymbolRefAttr::get(&context, funcName); + OwningOpRef callOp = func::CallOp::create( + b, loc, funcSymbol, funcType.getResults(), ValueRange{}); + + Operation *definingOp = nullptr; + bool result = isValidSymbolUse(callOp.get(), funcSymbol, &definingOp); + + // Should be false - no routine_info, not an intrinsic, no declare attribute + EXPECT_FALSE(result); + EXPECT_NE(definingOp, nullptr); +} + +TEST_F(OpenACCUtilsTest, isValidSymbolUseNullDefiningOpPtr) { + // Create a module to hold a recipe + OwningOpRef module = ModuleOp::create(loc); + Block *moduleBlock = module->getBody(); + + OpBuilder::InsertionGuard guard(b); + b.setInsertionPointToStart(moduleBlock); + + // Create a private recipe + auto i32Type = b.getI32Type(); + llvm::StringRef recipeName = "test_recipe"; + OwningOpRef recipeOp = + PrivateRecipeOp::create(b, loc, recipeName, i32Type); + + // Create a value to privatize + auto memrefTy = MemRefType::get({10}, b.getI32Type()); + OwningOpRef allocOp = + memref::AllocaOp::create(b, loc, memrefTy); + TypedValue varPtr = + cast>(allocOp->getResult()); + + // Create a private op as the user operation + OwningOpRef privateOp = PrivateOp::create( + b, loc, varPtr, /*structured=*/true, /*implicit=*/false); + + // Create a symbol reference to the recipe + SymbolRefAttr recipeSymbol = SymbolRefAttr::get(&context, recipeName); + + // Call without definingOpPtr (nullptr) + bool result = isValidSymbolUse(privateOp.get(), recipeSymbol, nullptr); + + EXPECT_TRUE(result); +}