From 43d069f024922cad0602239e8e1bdf7eebce0783 Mon Sep 17 00:00:00 2001 From: Razvan Lupusoru Date: Thu, 13 Nov 2025 13:29:55 -0800 Subject: [PATCH] [mlir][acc] Check legality of symbols in acc regions This PR adds a new utility function to check whether symbols used in OpenACC regions are legal for offloading. Functions must be marked with `acc routine` or be built-in intrinsics. Global symbols must be marked with `acc declare`. The utility is designed to be extensible, and the OpenACCSupport analysis has been updated to allow handling of additional symbols that do not necessarily use OpenACC attributes but are marked in a way that still guarantees the symbol will be available when offloading. For example, in the Flang implementation, CUF attributes can be validated as legal symbols. --- .../Dialect/OpenACC/Analysis/OpenACCSupport.h | 37 ++- .../mlir/Dialect/OpenACC/OpenACCUtils.h | 10 + .../OpenACC/Analysis/OpenACCSupport.cpp | 7 + .../Dialect/OpenACC/Utils/OpenACCUtils.cpp | 50 ++++ .../Dialect/OpenACC/OpenACCUtilsTest.cpp | 239 ++++++++++++++++++ 5 files changed, 342 insertions(+), 1 deletion(-) diff --git a/mlir/include/mlir/Dialect/OpenACC/Analysis/OpenACCSupport.h b/mlir/include/mlir/Dialect/OpenACC/Analysis/OpenACCSupport.h index d9b2646b753f3..7be525e87a695 100644 --- a/mlir/include/mlir/Dialect/OpenACC/Analysis/OpenACCSupport.h +++ b/mlir/include/mlir/Dialect/OpenACC/Analysis/OpenACCSupport.h @@ -58,8 +58,10 @@ namespace mlir { namespace acc { -// Forward declaration for RecipeKind enum +// Forward declarations enum class RecipeKind : uint32_t; +bool isValidSymbolUse(Operation *user, SymbolRefAttr symbol, + Operation **definingOpPtr); namespace detail { /// This class contains internal trait classes used by OpenACCSupport. @@ -79,11 +81,27 @@ struct OpenACCSupportTraits { // Used to report a case that is not supported by the implementation. virtual InFlightDiagnostic emitNYI(Location loc, const Twine &message) = 0; + + /// Check if a symbol use is valid for use in an OpenACC region. + virtual bool isValidSymbolUse(Operation *user, SymbolRefAttr symbol, + Operation **definingOpPtr) = 0; }; + /// SFINAE helpers to detect if implementation has optional methods + template + using isValidSymbolUse_t = + decltype(std::declval().isValidSymbolUse(std::declval()...)); + + template + using has_isValidSymbolUse = + llvm::is_detected; + /// This class wraps a concrete OpenACCSupport implementation and forwards /// interface calls to it. This provides type erasure, allowing different /// implementation types to be used interchangeably without inheritance. + /// Methods can be optionally implemented; if not present, default behavior + /// is used. template class Model final : public Concept { public: @@ -102,6 +120,14 @@ struct OpenACCSupportTraits { return impl.emitNYI(loc, message); } + bool isValidSymbolUse(Operation *user, SymbolRefAttr symbol, + Operation **definingOpPtr) final { + if constexpr (has_isValidSymbolUse::value) + return impl.isValidSymbolUse(user, symbol, definingOpPtr); + else + return acc::isValidSymbolUse(user, symbol, definingOpPtr); + } + private: ImplT impl; }; @@ -154,6 +180,15 @@ class OpenACCSupport { /// unsupported case. InFlightDiagnostic emitNYI(Location loc, const Twine &message); + /// Check if a symbol use is valid for use in an OpenACC region. + /// + /// \param user The operation using the symbol. + /// \param symbol The symbol reference being used. + /// \param definingOpPtr Optional output parameter to receive the defining op. + /// \return true if the symbol use is valid, false otherwise. + bool isValidSymbolUse(Operation *user, SymbolRefAttr symbol, + Operation **definingOpPtr = nullptr); + /// Signal that this analysis should always be preserved so that /// underlying implementation registration is not lost. bool isInvalidated(const AnalysisManager::PreservedAnalyses &pa) { diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCUtils.h b/mlir/include/mlir/Dialect/OpenACC/OpenACCUtils.h index 964735755c4a3..2852e0917c3fb 100644 --- a/mlir/include/mlir/Dialect/OpenACC/OpenACCUtils.h +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCUtils.h @@ -52,6 +52,16 @@ std::string getRecipeName(mlir::acc::RecipeKind kind, mlir::Type type); // base `array` from an operation that only accesses a subarray. mlir::Value getBaseEntity(mlir::Value val); +/// Check if a symbol use is valid for use in an OpenACC region. +/// This includes looking for various attributes such as `acc.routine_info` +/// and `acc.declare` attributes. +/// \param user The operation using the symbol +/// \param symbol The symbol reference being used +/// \param definingOpPtr Optional output parameter to receive the defining op +/// \return true if the symbol use is valid, false otherwise +bool isValidSymbolUse(mlir::Operation *user, mlir::SymbolRefAttr symbol, + mlir::Operation **definingOpPtr = nullptr); + } // namespace acc } // namespace mlir diff --git a/mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp b/mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp index 40e769e7068cf..1d775fb975738 100644 --- a/mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp +++ b/mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp @@ -41,5 +41,12 @@ InFlightDiagnostic OpenACCSupport::emitNYI(Location loc, const Twine &message) { return mlir::emitError(loc, "not yet implemented: " + message); } +bool OpenACCSupport::isValidSymbolUse(Operation *user, SymbolRefAttr symbol, + Operation **definingOpPtr) { + if (impl) + return impl->isValidSymbolUse(user, symbol, definingOpPtr); + return acc::isValidSymbolUse(user, symbol, definingOpPtr); +} + } // namespace acc } // namespace mlir diff --git a/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp b/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp index fbac28e740750..aebc248e02ea0 100644 --- a/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp +++ b/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp @@ -9,8 +9,11 @@ #include "mlir/Dialect/OpenACC/OpenACCUtils.h" #include "mlir/Dialect/OpenACC/OpenACC.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Interfaces/ViewLikeInterface.h" #include "llvm/ADT/TypeSwitch.h" +#include "llvm/IR/Intrinsics.h" #include "llvm/Support/Casting.h" mlir::Operation *mlir::acc::getEnclosingComputeOp(mlir::Region ®ion) { @@ -155,3 +158,50 @@ mlir::Value mlir::acc::getBaseEntity(mlir::Value val) { return val; } + +bool mlir::acc::isValidSymbolUse(mlir::Operation *user, + mlir::SymbolRefAttr symbol, + mlir::Operation **definingOpPtr) { + mlir::Operation *definingOp = + mlir::SymbolTable::lookupNearestSymbolFrom(user, symbol); + + // If there are no defining ops, we have no way to ensure validity because + // we cannot check for any attributes. + if (!definingOp) + return false; + + if (definingOpPtr) + *definingOpPtr = definingOp; + + // Check if the defining op is a recipe (private, reduction, firstprivate). + // Recipes are valid as they get materialized before being offloaded to + // device. They are only instructions for how to materialize. + if (mlir::isa(definingOp)) + return true; + + // Check if the defining op is a function + if (auto func = + mlir::dyn_cast_if_present(definingOp)) { + // If this symbol is actually an acc routine - then it is expected for it + // to be offloaded - therefore it is valid. + if (func->hasAttr(mlir::acc::getRoutineInfoAttrName())) + return true; + + // If this symbol is a call to an LLVM intrinsic, then it is likely valid. + // Check the following: + // 1. The function is private + // 2. The function has no body + // 3. Name starts with "llvm." + // 4. The function's name is a valid LLVM intrinsic name + if (func.getVisibility() == mlir::SymbolTable::Visibility::Private && + func.getFunctionBody().empty() && func.getName().starts_with("llvm.") && + llvm::Intrinsic::lookupIntrinsicID(func.getName()) != + llvm::Intrinsic::not_intrinsic) + return true; + } + + // A declare attribute is needed for symbol references. + bool hasDeclare = definingOp->hasAttr(mlir::acc::getDeclareAttrName()); + return hasDeclare; +} diff --git a/mlir/unittests/Dialect/OpenACC/OpenACCUtilsTest.cpp b/mlir/unittests/Dialect/OpenACC/OpenACCUtilsTest.cpp index 6f4e30585b2c9..8b1f532bbe5c0 100644 --- a/mlir/unittests/Dialect/OpenACC/OpenACCUtilsTest.cpp +++ b/mlir/unittests/Dialect/OpenACC/OpenACCUtilsTest.cpp @@ -674,3 +674,242 @@ TEST_F(OpenACCUtilsTest, getBaseEntityChainedSubviews) { Value ultimateBase = getBaseEntity(baseEntity); EXPECT_EQ(ultimateBase, baseMemref); } + +//===----------------------------------------------------------------------===// +// isValidSymbolUse Tests +//===----------------------------------------------------------------------===// + +TEST_F(OpenACCUtilsTest, isValidSymbolUseNoDefiningOp) { + // Create a memref.get_global that references a non-existent global + auto memrefType = MemRefType::get({10}, b.getI32Type()); + llvm::StringRef globalName = "nonexistent_global"; + SymbolRefAttr nonExistentSymbol = SymbolRefAttr::get(&context, globalName); + + OwningOpRef getGlobalOp = + memref::GetGlobalOp::create(b, loc, memrefType, globalName); + + Operation *definingOp = nullptr; + bool result = + isValidSymbolUse(getGlobalOp.get(), nonExistentSymbol, &definingOp); + + EXPECT_FALSE(result); + EXPECT_EQ(definingOp, nullptr); +} + +TEST_F(OpenACCUtilsTest, isValidSymbolUseRecipe) { + // Create a module to hold the recipe + OwningOpRef module = ModuleOp::create(loc); + Block *moduleBlock = module->getBody(); + + OpBuilder::InsertionGuard guard(b); + b.setInsertionPointToStart(moduleBlock); + + // Create a private recipe (any recipe type would work) + auto i32Type = b.getI32Type(); + llvm::StringRef recipeName = "test_recipe"; + OwningOpRef recipeOp = + PrivateRecipeOp::create(b, loc, recipeName, i32Type); + + // Create a value to privatize + auto memrefTy = MemRefType::get({10}, b.getI32Type()); + OwningOpRef allocOp = + memref::AllocaOp::create(b, loc, memrefTy); + TypedValue varPtr = + cast>(allocOp->getResult()); + + // Create a private op as the user operation + OwningOpRef privateOp = PrivateOp::create( + b, loc, varPtr, /*structured=*/true, /*implicit=*/false); + + // Create a symbol reference to the recipe + SymbolRefAttr recipeSymbol = SymbolRefAttr::get(&context, recipeName); + + Operation *definingOp = nullptr; + bool result = isValidSymbolUse(privateOp.get(), recipeSymbol, &definingOp); + + EXPECT_TRUE(result); + EXPECT_EQ(definingOp, recipeOp.get()); +} + +TEST_F(OpenACCUtilsTest, isValidSymbolUseFunctionWithRoutineInfo) { + // Create a module to hold the function + OwningOpRef module = ModuleOp::create(loc); + Block *moduleBlock = module->getBody(); + + OpBuilder::InsertionGuard guard(b); + b.setInsertionPointToStart(moduleBlock); + + // Create a function with routine_info attribute + auto funcType = b.getFunctionType({}, {}); + llvm::StringRef funcName = "routine_func"; + OwningOpRef funcOp = + func::FuncOp::create(b, loc, funcName, funcType); + + // Add routine_info attribute with a reference to a routine + SmallVector routineRefs = { + SymbolRefAttr::get(&context, "acc_routine")}; + funcOp.get()->setAttr(getRoutineInfoAttrName(), + RoutineInfoAttr::get(&context, routineRefs)); + + // Create a call operation that uses the function symbol + SymbolRefAttr funcSymbol = SymbolRefAttr::get(&context, funcName); + OwningOpRef callOp = func::CallOp::create( + b, loc, funcSymbol, funcType.getResults(), ValueRange{}); + + Operation *definingOp = nullptr; + bool result = isValidSymbolUse(callOp.get(), funcSymbol, &definingOp); + + EXPECT_TRUE(result); + EXPECT_NE(definingOp, nullptr); +} + +TEST_F(OpenACCUtilsTest, isValidSymbolUseLLVMIntrinsic) { + // Create a module to hold the function + OwningOpRef module = ModuleOp::create(loc); + Block *moduleBlock = module->getBody(); + + OpBuilder::InsertionGuard guard(b); + b.setInsertionPointToStart(moduleBlock); + + // Create a private function with LLVM intrinsic name + auto funcType = b.getFunctionType({b.getF32Type()}, {b.getF32Type()}); + llvm::StringRef intrinsicName = "llvm.sqrt.f32"; + OwningOpRef funcOp = + func::FuncOp::create(b, loc, intrinsicName, funcType); + + // Set visibility to private (required for intrinsics) + funcOp->setPrivate(); + + // Create a call operation that uses the intrinsic + SymbolRefAttr funcSymbol = SymbolRefAttr::get(&context, intrinsicName); + OwningOpRef callOp = func::CallOp::create( + b, loc, funcSymbol, funcType.getResults(), ValueRange{}); + + Operation *definingOp = nullptr; + bool result = isValidSymbolUse(callOp.get(), funcSymbol, &definingOp); + + EXPECT_TRUE(result); + EXPECT_NE(definingOp, nullptr); +} + +TEST_F(OpenACCUtilsTest, isValidSymbolUseFunctionNotIntrinsic) { + // Create a module to hold the function + OwningOpRef module = ModuleOp::create(loc); + Block *moduleBlock = module->getBody(); + + OpBuilder::InsertionGuard guard(b); + b.setInsertionPointToStart(moduleBlock); + + // Create a private function that looks like intrinsic but isn't + auto funcType = b.getFunctionType({}, {}); + llvm::StringRef funcName = "llvm.not_a_real_intrinsic"; + OwningOpRef funcOp = + func::FuncOp::create(b, loc, funcName, funcType); + funcOp->setPrivate(); + + // Create a call operation that uses the function + SymbolRefAttr funcSymbol = SymbolRefAttr::get(&context, funcName); + OwningOpRef callOp = func::CallOp::create( + b, loc, funcSymbol, funcType.getResults(), ValueRange{}); + + Operation *definingOp = nullptr; + bool result = isValidSymbolUse(callOp.get(), funcSymbol, &definingOp); + + // Should be false because it's not a valid intrinsic and has no + // acc.routine_info attr + EXPECT_FALSE(result); + EXPECT_NE(definingOp, nullptr); +} + +TEST_F(OpenACCUtilsTest, isValidSymbolUseWithDeclareAttr) { + // Create a module to hold a function + OwningOpRef module = ModuleOp::create(loc); + Block *moduleBlock = module->getBody(); + + OpBuilder::InsertionGuard guard(b); + b.setInsertionPointToStart(moduleBlock); + + // Create a function with declare attribute + auto funcType = b.getFunctionType({}, {}); + llvm::StringRef funcName = "declared_func"; + OwningOpRef funcOp = + func::FuncOp::create(b, loc, funcName, funcType); + + // Add declare attribute + funcOp.get()->setAttr( + getDeclareAttrName(), + DeclareAttr::get(&context, + DataClauseAttr::get(&context, DataClause::acc_copy))); + + // Create a call operation that uses the function + SymbolRefAttr funcSymbol = SymbolRefAttr::get(&context, funcName); + OwningOpRef callOp = func::CallOp::create( + b, loc, funcSymbol, funcType.getResults(), ValueRange{}); + + Operation *definingOp = nullptr; + bool result = isValidSymbolUse(callOp.get(), funcSymbol, &definingOp); + + EXPECT_TRUE(result); + EXPECT_NE(definingOp, nullptr); +} + +TEST_F(OpenACCUtilsTest, isValidSymbolUseWithoutValidAttributes) { + // Create a module to hold a function + OwningOpRef module = ModuleOp::create(loc); + Block *moduleBlock = module->getBody(); + + OpBuilder::InsertionGuard guard(b); + b.setInsertionPointToStart(moduleBlock); + + // Create a function without any special attributes + auto funcType = b.getFunctionType({}, {}); + llvm::StringRef funcName = "regular_func"; + OwningOpRef funcOp = + func::FuncOp::create(b, loc, funcName, funcType); + + // Create a call operation that uses the function + SymbolRefAttr funcSymbol = SymbolRefAttr::get(&context, funcName); + OwningOpRef callOp = func::CallOp::create( + b, loc, funcSymbol, funcType.getResults(), ValueRange{}); + + Operation *definingOp = nullptr; + bool result = isValidSymbolUse(callOp.get(), funcSymbol, &definingOp); + + // Should be false - no routine_info, not an intrinsic, no declare attribute + EXPECT_FALSE(result); + EXPECT_NE(definingOp, nullptr); +} + +TEST_F(OpenACCUtilsTest, isValidSymbolUseNullDefiningOpPtr) { + // Create a module to hold a recipe + OwningOpRef module = ModuleOp::create(loc); + Block *moduleBlock = module->getBody(); + + OpBuilder::InsertionGuard guard(b); + b.setInsertionPointToStart(moduleBlock); + + // Create a private recipe + auto i32Type = b.getI32Type(); + llvm::StringRef recipeName = "test_recipe"; + OwningOpRef recipeOp = + PrivateRecipeOp::create(b, loc, recipeName, i32Type); + + // Create a value to privatize + auto memrefTy = MemRefType::get({10}, b.getI32Type()); + OwningOpRef allocOp = + memref::AllocaOp::create(b, loc, memrefTy); + TypedValue varPtr = + cast>(allocOp->getResult()); + + // Create a private op as the user operation + OwningOpRef privateOp = PrivateOp::create( + b, loc, varPtr, /*structured=*/true, /*implicit=*/false); + + // Create a symbol reference to the recipe + SymbolRefAttr recipeSymbol = SymbolRefAttr::get(&context, recipeName); + + // Call without definingOpPtr (nullptr) + bool result = isValidSymbolUse(privateOp.get(), recipeSymbol, nullptr); + + EXPECT_TRUE(result); +}