From adb9b5c516d2fd7dba02d71c3e19284069f9ca07 Mon Sep 17 00:00:00 2001 From: Razvan Lupusoru Date: Thu, 20 Nov 2025 17:38:42 -0800 Subject: [PATCH 1/3] [acc][flang] Implement acc interface for tracking type descriptors FIR operations that use derived types need to have type descriptor globals available on device when offloading. Examples of this can be seen in `CUFDeviceGlobal` which ensures that such type descriptor uses work on device for CUF. Similarly, this is needed for OpenACC. This change introduces a new interface to the OpenACC dialect named `IndirectGlobalAccessOpInterface` which can be attached to operations that may result in generation of accesses that use type descriptor globals. This functionality is needed for the `ACCImplicitDeclare` pass that is coming in a follow-up change which implicitly ensures that all referenced globals are available in OpenACC compute contexts. The interface provides a `getReferencedSymbols` method that collects all global symbols referenced by an operation. When a symbol table is provided, the implementation for FIR recursively walks type descriptor globals to find all transitively referenced symbols. Note that alternately this could have been implemented in different ways: - Codegen could implicitly generate such type globals as needed by changing the technique that relies on populating them during lowering (eg generate them directly in gpu.module during codegen). - This interface could attach to types instead of operations for a potentially more conservative implementation which maps all type descriptors even if the underlying implementation using it won't necessarily need such mapping. The technique chosen here is consistent with `CUFDeviceGlobal` (which walks operations inside `prepareImplicitDeviceGlobals`) and avoids conservative mapping of all type descriptors. --- .../OpenACC/Support/FIROpenACCOpsInterfaces.h | 10 ++ .../Support/FIROpenACCOpsInterfaces.cpp | 100 ++++++++++++++++++ .../Support/RegisterOpenACCExtensions.cpp | 9 ++ .../Dialect/OpenACC/OpenACCOpsInterfaces.td | 23 ++++ 4 files changed, 142 insertions(+) diff --git a/flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.h b/flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.h index bf87654979cc9..87d60d489ba13 100644 --- a/flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.h +++ b/flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.h @@ -67,6 +67,16 @@ struct GlobalVariableModel bool isConstant(mlir::Operation *op) const; }; +template +struct IndirectGlobalAccessModel + : public mlir::acc::IndirectGlobalAccessOpInterface::ExternalModel< + IndirectGlobalAccessModel, Op> { + void getReferencedSymbols( + mlir::Operation *op, + llvm::SmallVectorImpl &symbols, + mlir::SymbolTable *symbolTable) const; +}; + } // namespace fir::acc #endif // FLANG_OPTIMIZER_OPENACC_FIROPENACC_OPS_INTERFACES_H_ diff --git a/flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp b/flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp index 11fbaf2dc2bb8..aa62b5a9820ee 100644 --- a/flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp +++ b/flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp @@ -14,6 +14,9 @@ #include "flang/Optimizer/Dialect/FIROps.h" #include "flang/Optimizer/HLFIR/HLFIROps.h" +#include "flang/Optimizer/Support/InternalNames.h" +#include "mlir/IR/SymbolTable.h" +#include "llvm/ADT/SmallSet.h" namespace fir::acc { @@ -68,4 +71,101 @@ bool GlobalVariableModel::isConstant(mlir::Operation *op) const { return globalOp.getConstant().has_value(); } +// Helper to recursively process address-of operations in derived type +// descriptors and collect all needed fir.globals. +static void processAddrOfOpInDerivedTypeDescriptor(fir::AddrOfOp addrOfOp, + mlir::SymbolTable &symTab, + llvm::SmallSet &globalsSet, + llvm::SmallVectorImpl &symbols) { + if (auto globalOp = symTab.lookup( + addrOfOp.getSymbol().getLeafReference().getValue())) { + if (globalsSet.contains(globalOp)) { + return; + } + globalsSet.insert(globalOp); + symbols.push_back(addrOfOp.getSymbolAttr()); + globalOp.walk([&](fir::AddrOfOp op) { + processAddrOfOpInDerivedTypeDescriptor( + op, symTab, globalsSet, symbols); + }); + } +} + +// Utility to collect referenced symbols for type descriptors of derived types. +// This is the common logic for operations that may require type descriptor +// globals. +static void collectReferencedSymbolsForType(mlir::Type ty, mlir::Operation *op, + llvm::SmallVectorImpl &symbols, + mlir::SymbolTable *symbolTable) { + ty = fir::getDerivedType(fir::unwrapRefType(ty)); + + // Look for type descriptor globals only if it's a derived (record) type + if (auto recTy = mlir::dyn_cast_if_present(ty)) { + // If no symbol table provided, simply add the type descriptor name + if (!symbolTable) { + symbols.push_back(mlir::SymbolRefAttr::get(op->getContext(), + fir::NameUniquer::getTypeDescriptorName(recTy.getName()))); + return; + } + + // Otherwise, do full lookup and recursive processing + llvm::SmallSet globalsSet; + + fir::GlobalOp globalOp = symbolTable->lookup( + fir::NameUniquer::getTypeDescriptorName(recTy.getName())); + if (!globalOp) { + globalOp = symbolTable->lookup( + fir::NameUniquer::getTypeDescriptorAssemblyName(recTy.getName())); + } + if (globalOp) { + globalsSet.insert(globalOp); + symbols.push_back(mlir::SymbolRefAttr::get( + op->getContext(), globalOp.getSymName())); + globalOp.walk([&](fir::AddrOfOp addrOp) { + processAddrOfOpInDerivedTypeDescriptor( + addrOp, *symbolTable, globalsSet, symbols); + }); + } + } +} + +template <> +void IndirectGlobalAccessModel::getReferencedSymbols( + mlir::Operation *op, + llvm::SmallVectorImpl &symbols, + mlir::SymbolTable *symbolTable) const { + auto allocaOp = mlir::cast(op); + collectReferencedSymbolsForType(allocaOp.getType(), op, symbols, symbolTable); +} + +template <> +void IndirectGlobalAccessModel::getReferencedSymbols( + mlir::Operation *op, + llvm::SmallVectorImpl &symbols, + mlir::SymbolTable *symbolTable) const { + auto emboxOp = mlir::cast(op); + collectReferencedSymbolsForType(emboxOp.getMemref().getType(), op, symbols, + symbolTable); +} + +template <> +void IndirectGlobalAccessModel::getReferencedSymbols( + mlir::Operation *op, + llvm::SmallVectorImpl &symbols, + mlir::SymbolTable *symbolTable) const { + auto reboxOp = mlir::cast(op); + collectReferencedSymbolsForType(reboxOp.getBox().getType(), op, symbols, + symbolTable); +} + +template <> +void IndirectGlobalAccessModel::getReferencedSymbols( + mlir::Operation *op, + llvm::SmallVectorImpl &symbols, + mlir::SymbolTable *symbolTable) const { + auto typeDescOp = mlir::cast(op); + collectReferencedSymbolsForType(typeDescOp.getInType(), op, symbols, + symbolTable); +} + } // namespace fir::acc diff --git a/flang/lib/Optimizer/OpenACC/Support/RegisterOpenACCExtensions.cpp b/flang/lib/Optimizer/OpenACC/Support/RegisterOpenACCExtensions.cpp index 5c7f9985d41ca..915518c8de6c7 100644 --- a/flang/lib/Optimizer/OpenACC/Support/RegisterOpenACCExtensions.cpp +++ b/flang/lib/Optimizer/OpenACC/Support/RegisterOpenACCExtensions.cpp @@ -52,6 +52,15 @@ void registerOpenACCExtensions(mlir::DialectRegistry ®istry) { fir::AddrOfOp::attachInterface(*ctx); fir::GlobalOp::attachInterface(*ctx); + + fir::AllocaOp::attachInterface>( + *ctx); + fir::EmboxOp::attachInterface>( + *ctx); + fir::ReboxOp::attachInterface>( + *ctx); + fir::TypeDescOp::attachInterface>( + *ctx); }); // Register HLFIR operation interfaces diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOpsInterfaces.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOpsInterfaces.td index 6b0c84d31d1ba..ec41826b2bbc8 100644 --- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOpsInterfaces.td +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOpsInterfaces.td @@ -75,4 +75,27 @@ def GlobalVariableOpInterface : OpInterface<"GlobalVariableOpInterface"> { ]; } +def IndirectGlobalAccessOpInterface : OpInterface<"IndirectGlobalAccessOpInterface"> { + let cppNamespace = "::mlir::acc"; + + let description = [{ + An interface for operations that indirectly access global symbols. + This interface provides a way to query which global symbols are referenced + by an operation, which is useful for tracking dependencies and performing + analysis on global variable usage. + + The symbolTable parameter is optional. If null, implementations will look up + their own symbol table. This allows callers to pass a pre-existing symbol + table for efficiency when querying multiple operations. + }]; + + let methods = [ + InterfaceMethod<"Get the symbols referenced by this operation", + "void", + "getReferencedSymbols", + (ins "::llvm::SmallVectorImpl<::mlir::SymbolRefAttr>&":$symbols, + "::mlir::SymbolTable *":$symbolTable)>, + ]; +} + #endif // OPENACC_OPS_INTERFACES From e2ca6975219ae8d3940cb33f583a43160e3a46f1 Mon Sep 17 00:00:00 2001 From: Razvan Lupusoru Date: Thu, 20 Nov 2025 17:44:02 -0800 Subject: [PATCH 2/3] Fix formatting --- .../OpenACC/Support/FIROpenACCOpsInterfaces.h | 7 ++-- .../Support/FIROpenACCOpsInterfaces.cpp | 39 +++++++++---------- .../Support/RegisterOpenACCExtensions.cpp | 4 +- 3 files changed, 23 insertions(+), 27 deletions(-) diff --git a/flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.h b/flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.h index 87d60d489ba13..0020e1ab21a56 100644 --- a/flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.h +++ b/flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.h @@ -71,10 +71,9 @@ template struct IndirectGlobalAccessModel : public mlir::acc::IndirectGlobalAccessOpInterface::ExternalModel< IndirectGlobalAccessModel, Op> { - void getReferencedSymbols( - mlir::Operation *op, - llvm::SmallVectorImpl &symbols, - mlir::SymbolTable *symbolTable) const; + void getReferencedSymbols(mlir::Operation *op, + llvm::SmallVectorImpl &symbols, + mlir::SymbolTable *symbolTable) const; }; } // namespace fir::acc diff --git a/flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp b/flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp index aa62b5a9820ee..2e5d8a61b5b32 100644 --- a/flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp +++ b/flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp @@ -73,8 +73,8 @@ bool GlobalVariableModel::isConstant(mlir::Operation *op) const { // Helper to recursively process address-of operations in derived type // descriptors and collect all needed fir.globals. -static void processAddrOfOpInDerivedTypeDescriptor(fir::AddrOfOp addrOfOp, - mlir::SymbolTable &symTab, +static void processAddrOfOpInDerivedTypeDescriptor( + fir::AddrOfOp addrOfOp, mlir::SymbolTable &symTab, llvm::SmallSet &globalsSet, llvm::SmallVectorImpl &symbols) { if (auto globalOp = symTab.lookup( @@ -85,8 +85,7 @@ static void processAddrOfOpInDerivedTypeDescriptor(fir::AddrOfOp addrOfOp, globalsSet.insert(globalOp); symbols.push_back(addrOfOp.getSymbolAttr()); globalOp.walk([&](fir::AddrOfOp op) { - processAddrOfOpInDerivedTypeDescriptor( - op, symTab, globalsSet, symbols); + processAddrOfOpInDerivedTypeDescriptor(op, symTab, globalsSet, symbols); }); } } @@ -94,7 +93,8 @@ static void processAddrOfOpInDerivedTypeDescriptor(fir::AddrOfOp addrOfOp, // Utility to collect referenced symbols for type descriptors of derived types. // This is the common logic for operations that may require type descriptor // globals. -static void collectReferencedSymbolsForType(mlir::Type ty, mlir::Operation *op, +static void collectReferencedSymbolsForType( + mlir::Type ty, mlir::Operation *op, llvm::SmallVectorImpl &symbols, mlir::SymbolTable *symbolTable) { ty = fir::getDerivedType(fir::unwrapRefType(ty)); @@ -103,7 +103,8 @@ static void collectReferencedSymbolsForType(mlir::Type ty, mlir::Operation *op, if (auto recTy = mlir::dyn_cast_if_present(ty)) { // If no symbol table provided, simply add the type descriptor name if (!symbolTable) { - symbols.push_back(mlir::SymbolRefAttr::get(op->getContext(), + symbols.push_back(mlir::SymbolRefAttr::get( + op->getContext(), fir::NameUniquer::getTypeDescriptorName(recTy.getName()))); return; } @@ -119,11 +120,11 @@ static void collectReferencedSymbolsForType(mlir::Type ty, mlir::Operation *op, } if (globalOp) { globalsSet.insert(globalOp); - symbols.push_back(mlir::SymbolRefAttr::get( - op->getContext(), globalOp.getSymName())); + symbols.push_back( + mlir::SymbolRefAttr::get(op->getContext(), globalOp.getSymName())); globalOp.walk([&](fir::AddrOfOp addrOp) { - processAddrOfOpInDerivedTypeDescriptor( - addrOp, *symbolTable, globalsSet, symbols); + processAddrOfOpInDerivedTypeDescriptor(addrOp, *symbolTable, globalsSet, + symbols); }); } } @@ -131,8 +132,7 @@ static void collectReferencedSymbolsForType(mlir::Type ty, mlir::Operation *op, template <> void IndirectGlobalAccessModel::getReferencedSymbols( - mlir::Operation *op, - llvm::SmallVectorImpl &symbols, + mlir::Operation *op, llvm::SmallVectorImpl &symbols, mlir::SymbolTable *symbolTable) const { auto allocaOp = mlir::cast(op); collectReferencedSymbolsForType(allocaOp.getType(), op, symbols, symbolTable); @@ -140,32 +140,29 @@ void IndirectGlobalAccessModel::getReferencedSymbols( template <> void IndirectGlobalAccessModel::getReferencedSymbols( - mlir::Operation *op, - llvm::SmallVectorImpl &symbols, + mlir::Operation *op, llvm::SmallVectorImpl &symbols, mlir::SymbolTable *symbolTable) const { auto emboxOp = mlir::cast(op); collectReferencedSymbolsForType(emboxOp.getMemref().getType(), op, symbols, - symbolTable); + symbolTable); } template <> void IndirectGlobalAccessModel::getReferencedSymbols( - mlir::Operation *op, - llvm::SmallVectorImpl &symbols, + mlir::Operation *op, llvm::SmallVectorImpl &symbols, mlir::SymbolTable *symbolTable) const { auto reboxOp = mlir::cast(op); collectReferencedSymbolsForType(reboxOp.getBox().getType(), op, symbols, - symbolTable); + symbolTable); } template <> void IndirectGlobalAccessModel::getReferencedSymbols( - mlir::Operation *op, - llvm::SmallVectorImpl &symbols, + mlir::Operation *op, llvm::SmallVectorImpl &symbols, mlir::SymbolTable *symbolTable) const { auto typeDescOp = mlir::cast(op); collectReferencedSymbolsForType(typeDescOp.getInType(), op, symbols, - symbolTable); + symbolTable); } } // namespace fir::acc diff --git a/flang/lib/Optimizer/OpenACC/Support/RegisterOpenACCExtensions.cpp b/flang/lib/Optimizer/OpenACC/Support/RegisterOpenACCExtensions.cpp index 915518c8de6c7..acd1d01ef1e87 100644 --- a/flang/lib/Optimizer/OpenACC/Support/RegisterOpenACCExtensions.cpp +++ b/flang/lib/Optimizer/OpenACC/Support/RegisterOpenACCExtensions.cpp @@ -59,8 +59,8 @@ void registerOpenACCExtensions(mlir::DialectRegistry ®istry) { *ctx); fir::ReboxOp::attachInterface>( *ctx); - fir::TypeDescOp::attachInterface>( - *ctx); + fir::TypeDescOp::attachInterface< + IndirectGlobalAccessModel>(*ctx); }); // Register HLFIR operation interfaces From 234a531e863ddbf7303979e73301afa2d70c2d94 Mon Sep 17 00:00:00 2001 From: Razvan Lupusoru Date: Thu, 20 Nov 2025 20:03:18 -0800 Subject: [PATCH 3/3] Fix braces --- .../Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp b/flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp index 2e5d8a61b5b32..902a2ecdec35f 100644 --- a/flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp +++ b/flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp @@ -79,9 +79,8 @@ static void processAddrOfOpInDerivedTypeDescriptor( llvm::SmallVectorImpl &symbols) { if (auto globalOp = symTab.lookup( addrOfOp.getSymbol().getLeafReference().getValue())) { - if (globalsSet.contains(globalOp)) { + if (globalsSet.contains(globalOp)) return; - } globalsSet.insert(globalOp); symbols.push_back(addrOfOp.getSymbolAttr()); globalOp.walk([&](fir::AddrOfOp op) { @@ -114,10 +113,10 @@ static void collectReferencedSymbolsForType( fir::GlobalOp globalOp = symbolTable->lookup( fir::NameUniquer::getTypeDescriptorName(recTy.getName())); - if (!globalOp) { + if (!globalOp) globalOp = symbolTable->lookup( fir::NameUniquer::getTypeDescriptorAssemblyName(recTy.getName())); - } + if (globalOp) { globalsSet.insert(globalOp); symbols.push_back(