Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,15 @@ struct GlobalVariableModel
bool isConstant(mlir::Operation *op) const;
};

template <typename Op>
struct IndirectGlobalAccessModel
: public mlir::acc::IndirectGlobalAccessOpInterface::ExternalModel<
IndirectGlobalAccessModel<Op>, Op> {
void getReferencedSymbols(mlir::Operation *op,
llvm::SmallVectorImpl<mlir::SymbolRefAttr> &symbols,
mlir::SymbolTable *symbolTable) const;
};

} // namespace fir::acc

#endif // FLANG_OPTIMIZER_OPENACC_FIROPENACC_OPS_INTERFACES_H_
96 changes: 96 additions & 0 deletions flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@

#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "flang/Optimizer/Support/InternalNames.h"
#include "mlir/IR/SymbolTable.h"
#include "llvm/ADT/SmallSet.h"

namespace fir::acc {

Expand Down Expand Up @@ -68,4 +71,97 @@ bool GlobalVariableModel::isConstant(mlir::Operation *op) const {
return globalOp.getConstant().has_value();
}

// Helper to recursively process address-of operations in derived type
// descriptors and collect all needed fir.globals.
static void processAddrOfOpInDerivedTypeDescriptor(
fir::AddrOfOp addrOfOp, mlir::SymbolTable &symTab,
llvm::SmallSet<mlir::Operation *, 16> &globalsSet,
llvm::SmallVectorImpl<mlir::SymbolRefAttr> &symbols) {
if (auto globalOp = symTab.lookup<fir::GlobalOp>(
addrOfOp.getSymbol().getLeafReference().getValue())) {
if (globalsSet.contains(globalOp))
return;
globalsSet.insert(globalOp);
symbols.push_back(addrOfOp.getSymbolAttr());
globalOp.walk([&](fir::AddrOfOp op) {
processAddrOfOpInDerivedTypeDescriptor(op, symTab, globalsSet, symbols);
});
}
}

// Utility to collect referenced symbols for type descriptors of derived types.
// This is the common logic for operations that may require type descriptor
// globals.
static void collectReferencedSymbolsForType(
mlir::Type ty, mlir::Operation *op,
llvm::SmallVectorImpl<mlir::SymbolRefAttr> &symbols,
mlir::SymbolTable *symbolTable) {
ty = fir::getDerivedType(fir::unwrapRefType(ty));

// Look for type descriptor globals only if it's a derived (record) type
if (auto recTy = mlir::dyn_cast_if_present<fir::RecordType>(ty)) {
// If no symbol table provided, simply add the type descriptor name
if (!symbolTable) {
symbols.push_back(mlir::SymbolRefAttr::get(
op->getContext(),
fir::NameUniquer::getTypeDescriptorName(recTy.getName())));
return;
}

// Otherwise, do full lookup and recursive processing
llvm::SmallSet<mlir::Operation *, 16> globalsSet;

fir::GlobalOp globalOp = symbolTable->lookup<fir::GlobalOp>(
fir::NameUniquer::getTypeDescriptorName(recTy.getName()));
if (!globalOp)
globalOp = symbolTable->lookup<fir::GlobalOp>(
fir::NameUniquer::getTypeDescriptorAssemblyName(recTy.getName()));

if (globalOp) {
globalsSet.insert(globalOp);
symbols.push_back(
mlir::SymbolRefAttr::get(op->getContext(), globalOp.getSymName()));
globalOp.walk([&](fir::AddrOfOp addrOp) {
processAddrOfOpInDerivedTypeDescriptor(addrOp, *symbolTable, globalsSet,
symbols);
});
}
}
}

template <>
void IndirectGlobalAccessModel<fir::AllocaOp>::getReferencedSymbols(
mlir::Operation *op, llvm::SmallVectorImpl<mlir::SymbolRefAttr> &symbols,
mlir::SymbolTable *symbolTable) const {
auto allocaOp = mlir::cast<fir::AllocaOp>(op);
collectReferencedSymbolsForType(allocaOp.getType(), op, symbols, symbolTable);
}

template <>
void IndirectGlobalAccessModel<fir::EmboxOp>::getReferencedSymbols(
mlir::Operation *op, llvm::SmallVectorImpl<mlir::SymbolRefAttr> &symbols,
mlir::SymbolTable *symbolTable) const {
auto emboxOp = mlir::cast<fir::EmboxOp>(op);
collectReferencedSymbolsForType(emboxOp.getMemref().getType(), op, symbols,
symbolTable);
}

template <>
void IndirectGlobalAccessModel<fir::ReboxOp>::getReferencedSymbols(
mlir::Operation *op, llvm::SmallVectorImpl<mlir::SymbolRefAttr> &symbols,
mlir::SymbolTable *symbolTable) const {
auto reboxOp = mlir::cast<fir::ReboxOp>(op);
collectReferencedSymbolsForType(reboxOp.getBox().getType(), op, symbols,
symbolTable);
}

template <>
void IndirectGlobalAccessModel<fir::TypeDescOp>::getReferencedSymbols(
mlir::Operation *op, llvm::SmallVectorImpl<mlir::SymbolRefAttr> &symbols,
mlir::SymbolTable *symbolTable) const {
auto typeDescOp = mlir::cast<fir::TypeDescOp>(op);
collectReferencedSymbolsForType(typeDescOp.getInType(), op, symbols,
symbolTable);
}

} // namespace fir::acc
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,15 @@ void registerOpenACCExtensions(mlir::DialectRegistry &registry) {

fir::AddrOfOp::attachInterface<AddressOfGlobalModel>(*ctx);
fir::GlobalOp::attachInterface<GlobalVariableModel>(*ctx);

fir::AllocaOp::attachInterface<IndirectGlobalAccessModel<fir::AllocaOp>>(
*ctx);
fir::EmboxOp::attachInterface<IndirectGlobalAccessModel<fir::EmboxOp>>(
*ctx);
fir::ReboxOp::attachInterface<IndirectGlobalAccessModel<fir::ReboxOp>>(
*ctx);
fir::TypeDescOp::attachInterface<
IndirectGlobalAccessModel<fir::TypeDescOp>>(*ctx);
});

// Register HLFIR operation interfaces
Expand Down
23 changes: 23 additions & 0 deletions mlir/include/mlir/Dialect/OpenACC/OpenACCOpsInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -75,4 +75,27 @@ def GlobalVariableOpInterface : OpInterface<"GlobalVariableOpInterface"> {
];
}

def IndirectGlobalAccessOpInterface : OpInterface<"IndirectGlobalAccessOpInterface"> {
let cppNamespace = "::mlir::acc";

let description = [{
An interface for operations that indirectly access global symbols.
This interface provides a way to query which global symbols are referenced
by an operation, which is useful for tracking dependencies and performing
analysis on global variable usage.

The symbolTable parameter is optional. If null, implementations will look up
their own symbol table. This allows callers to pass a pre-existing symbol
table for efficiency when querying multiple operations.
}];

let methods = [
InterfaceMethod<"Get the symbols referenced by this operation",
"void",
"getReferencedSymbols",
(ins "::llvm::SmallVectorImpl<::mlir::SymbolRefAttr>&":$symbols,
"::mlir::SymbolTable *":$symbolTable)>,
];
}

#endif // OPENACC_OPS_INTERFACES