diff --git a/flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.h b/flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.h index bf87654979cc9..0020e1ab21a56 100644 --- a/flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.h +++ b/flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.h @@ -67,6 +67,15 @@ struct GlobalVariableModel bool isConstant(mlir::Operation *op) const; }; +template +struct IndirectGlobalAccessModel + : public mlir::acc::IndirectGlobalAccessOpInterface::ExternalModel< + IndirectGlobalAccessModel, Op> { + void getReferencedSymbols(mlir::Operation *op, + llvm::SmallVectorImpl &symbols, + mlir::SymbolTable *symbolTable) const; +}; + } // namespace fir::acc #endif // FLANG_OPTIMIZER_OPENACC_FIROPENACC_OPS_INTERFACES_H_ diff --git a/flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp b/flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp index 11fbaf2dc2bb8..902a2ecdec35f 100644 --- a/flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp +++ b/flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp @@ -14,6 +14,9 @@ #include "flang/Optimizer/Dialect/FIROps.h" #include "flang/Optimizer/HLFIR/HLFIROps.h" +#include "flang/Optimizer/Support/InternalNames.h" +#include "mlir/IR/SymbolTable.h" +#include "llvm/ADT/SmallSet.h" namespace fir::acc { @@ -68,4 +71,97 @@ bool GlobalVariableModel::isConstant(mlir::Operation *op) const { return globalOp.getConstant().has_value(); } +// Helper to recursively process address-of operations in derived type +// descriptors and collect all needed fir.globals. +static void processAddrOfOpInDerivedTypeDescriptor( + fir::AddrOfOp addrOfOp, mlir::SymbolTable &symTab, + llvm::SmallSet &globalsSet, + llvm::SmallVectorImpl &symbols) { + if (auto globalOp = symTab.lookup( + addrOfOp.getSymbol().getLeafReference().getValue())) { + if (globalsSet.contains(globalOp)) + return; + globalsSet.insert(globalOp); + symbols.push_back(addrOfOp.getSymbolAttr()); + globalOp.walk([&](fir::AddrOfOp op) { + processAddrOfOpInDerivedTypeDescriptor(op, symTab, globalsSet, symbols); + }); + } +} + +// Utility to collect referenced symbols for type descriptors of derived types. +// This is the common logic for operations that may require type descriptor +// globals. +static void collectReferencedSymbolsForType( + mlir::Type ty, mlir::Operation *op, + llvm::SmallVectorImpl &symbols, + mlir::SymbolTable *symbolTable) { + ty = fir::getDerivedType(fir::unwrapRefType(ty)); + + // Look for type descriptor globals only if it's a derived (record) type + if (auto recTy = mlir::dyn_cast_if_present(ty)) { + // If no symbol table provided, simply add the type descriptor name + if (!symbolTable) { + symbols.push_back(mlir::SymbolRefAttr::get( + op->getContext(), + fir::NameUniquer::getTypeDescriptorName(recTy.getName()))); + return; + } + + // Otherwise, do full lookup and recursive processing + llvm::SmallSet globalsSet; + + fir::GlobalOp globalOp = symbolTable->lookup( + fir::NameUniquer::getTypeDescriptorName(recTy.getName())); + if (!globalOp) + globalOp = symbolTable->lookup( + fir::NameUniquer::getTypeDescriptorAssemblyName(recTy.getName())); + + if (globalOp) { + globalsSet.insert(globalOp); + symbols.push_back( + mlir::SymbolRefAttr::get(op->getContext(), globalOp.getSymName())); + globalOp.walk([&](fir::AddrOfOp addrOp) { + processAddrOfOpInDerivedTypeDescriptor(addrOp, *symbolTable, globalsSet, + symbols); + }); + } + } +} + +template <> +void IndirectGlobalAccessModel::getReferencedSymbols( + mlir::Operation *op, llvm::SmallVectorImpl &symbols, + mlir::SymbolTable *symbolTable) const { + auto allocaOp = mlir::cast(op); + collectReferencedSymbolsForType(allocaOp.getType(), op, symbols, symbolTable); +} + +template <> +void IndirectGlobalAccessModel::getReferencedSymbols( + mlir::Operation *op, llvm::SmallVectorImpl &symbols, + mlir::SymbolTable *symbolTable) const { + auto emboxOp = mlir::cast(op); + collectReferencedSymbolsForType(emboxOp.getMemref().getType(), op, symbols, + symbolTable); +} + +template <> +void IndirectGlobalAccessModel::getReferencedSymbols( + mlir::Operation *op, llvm::SmallVectorImpl &symbols, + mlir::SymbolTable *symbolTable) const { + auto reboxOp = mlir::cast(op); + collectReferencedSymbolsForType(reboxOp.getBox().getType(), op, symbols, + symbolTable); +} + +template <> +void IndirectGlobalAccessModel::getReferencedSymbols( + mlir::Operation *op, llvm::SmallVectorImpl &symbols, + mlir::SymbolTable *symbolTable) const { + auto typeDescOp = mlir::cast(op); + collectReferencedSymbolsForType(typeDescOp.getInType(), op, symbols, + symbolTable); +} + } // namespace fir::acc diff --git a/flang/lib/Optimizer/OpenACC/Support/RegisterOpenACCExtensions.cpp b/flang/lib/Optimizer/OpenACC/Support/RegisterOpenACCExtensions.cpp index 5c7f9985d41ca..acd1d01ef1e87 100644 --- a/flang/lib/Optimizer/OpenACC/Support/RegisterOpenACCExtensions.cpp +++ b/flang/lib/Optimizer/OpenACC/Support/RegisterOpenACCExtensions.cpp @@ -52,6 +52,15 @@ void registerOpenACCExtensions(mlir::DialectRegistry ®istry) { fir::AddrOfOp::attachInterface(*ctx); fir::GlobalOp::attachInterface(*ctx); + + fir::AllocaOp::attachInterface>( + *ctx); + fir::EmboxOp::attachInterface>( + *ctx); + fir::ReboxOp::attachInterface>( + *ctx); + fir::TypeDescOp::attachInterface< + IndirectGlobalAccessModel>(*ctx); }); // Register HLFIR operation interfaces diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOpsInterfaces.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOpsInterfaces.td index 6b0c84d31d1ba..ec41826b2bbc8 100644 --- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOpsInterfaces.td +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOpsInterfaces.td @@ -75,4 +75,27 @@ def GlobalVariableOpInterface : OpInterface<"GlobalVariableOpInterface"> { ]; } +def IndirectGlobalAccessOpInterface : OpInterface<"IndirectGlobalAccessOpInterface"> { + let cppNamespace = "::mlir::acc"; + + let description = [{ + An interface for operations that indirectly access global symbols. + This interface provides a way to query which global symbols are referenced + by an operation, which is useful for tracking dependencies and performing + analysis on global variable usage. + + The symbolTable parameter is optional. If null, implementations will look up + their own symbol table. This allows callers to pass a pre-existing symbol + table for efficiency when querying multiple operations. + }]; + + let methods = [ + InterfaceMethod<"Get the symbols referenced by this operation", + "void", + "getReferencedSymbols", + (ins "::llvm::SmallVectorImpl<::mlir::SymbolRefAttr>&":$symbols, + "::mlir::SymbolTable *":$symbolTable)>, + ]; +} + #endif // OPENACC_OPS_INTERFACES