-
Notifications
You must be signed in to change notification settings - Fork 15.3k
[acc][flang] Implement acc interface for tracking type descriptors #168982
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[acc][flang] Implement acc interface for tracking type descriptors #168982
Conversation
FIR operations that use derived types need to have type descriptor globals available on device when offloading. Examples of this can be seen in `CUFDeviceGlobal` which ensures that such type descriptor uses work on device for CUF. Similarly, this is needed for OpenACC. This change introduces a new interface to the OpenACC dialect named `IndirectGlobalAccessOpInterface` which can be attached to operations that may result in generation of accesses that use type descriptor globals. This functionality is needed for the `ACCImplicitDeclare` pass that is coming in a follow-up change which implicitly ensures that all referenced globals are available in OpenACC compute contexts. The interface provides a `getReferencedSymbols` method that collects all global symbols referenced by an operation. When a symbol table is provided, the implementation for FIR recursively walks type descriptor globals to find all transitively referenced symbols. Note that alternately this could have been implemented in different ways: - Codegen could implicitly generate such type globals as needed by changing the technique that relies on populating them during lowering (eg generate them directly in gpu.module during codegen). - This interface could attach to types instead of operations for a potentially more conservative implementation which maps all type descriptors even if the underlying implementation using it won't necessarily need such mapping. The technique chosen here is consistent with `CUFDeviceGlobal` (which walks operations inside `prepareImplicitDeviceGlobals`) and avoids conservative mapping of all type descriptors.
|
@llvm/pr-subscribers-openacc @llvm/pr-subscribers-flang-fir-hlfir Author: Razvan Lupusoru (razvanlupusoru) ChangesFIR operations that use derived types need to have type descriptor globals available on device when offloading. Examples of this can be seen in Similarly, this is needed for OpenACC. This change introduces a new interface to the OpenACC dialect named The interface provides a Note that alternately this could have been implemented in different ways:
The technique chosen here is consistent with Full diff: https://github.com/llvm/llvm-project/pull/168982.diff 4 Files Affected:
diff --git a/flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.h b/flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.h
index bf87654979cc9..87d60d489ba13 100644
--- a/flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.h
+++ b/flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.h
@@ -67,6 +67,16 @@ struct GlobalVariableModel
bool isConstant(mlir::Operation *op) const;
};
+template <typename Op>
+struct IndirectGlobalAccessModel
+ : public mlir::acc::IndirectGlobalAccessOpInterface::ExternalModel<
+ IndirectGlobalAccessModel<Op>, Op> {
+ void getReferencedSymbols(
+ mlir::Operation *op,
+ llvm::SmallVectorImpl<mlir::SymbolRefAttr> &symbols,
+ mlir::SymbolTable *symbolTable) const;
+};
+
} // namespace fir::acc
#endif // FLANG_OPTIMIZER_OPENACC_FIROPENACC_OPS_INTERFACES_H_
diff --git a/flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp b/flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp
index 11fbaf2dc2bb8..aa62b5a9820ee 100644
--- a/flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp
+++ b/flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp
@@ -14,6 +14,9 @@
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
+#include "flang/Optimizer/Support/InternalNames.h"
+#include "mlir/IR/SymbolTable.h"
+#include "llvm/ADT/SmallSet.h"
namespace fir::acc {
@@ -68,4 +71,101 @@ bool GlobalVariableModel::isConstant(mlir::Operation *op) const {
return globalOp.getConstant().has_value();
}
+// Helper to recursively process address-of operations in derived type
+// descriptors and collect all needed fir.globals.
+static void processAddrOfOpInDerivedTypeDescriptor(fir::AddrOfOp addrOfOp,
+ mlir::SymbolTable &symTab,
+ llvm::SmallSet<mlir::Operation *, 16> &globalsSet,
+ llvm::SmallVectorImpl<mlir::SymbolRefAttr> &symbols) {
+ if (auto globalOp = symTab.lookup<fir::GlobalOp>(
+ addrOfOp.getSymbol().getLeafReference().getValue())) {
+ if (globalsSet.contains(globalOp)) {
+ return;
+ }
+ globalsSet.insert(globalOp);
+ symbols.push_back(addrOfOp.getSymbolAttr());
+ globalOp.walk([&](fir::AddrOfOp op) {
+ processAddrOfOpInDerivedTypeDescriptor(
+ op, symTab, globalsSet, symbols);
+ });
+ }
+}
+
+// Utility to collect referenced symbols for type descriptors of derived types.
+// This is the common logic for operations that may require type descriptor
+// globals.
+static void collectReferencedSymbolsForType(mlir::Type ty, mlir::Operation *op,
+ llvm::SmallVectorImpl<mlir::SymbolRefAttr> &symbols,
+ mlir::SymbolTable *symbolTable) {
+ ty = fir::getDerivedType(fir::unwrapRefType(ty));
+
+ // Look for type descriptor globals only if it's a derived (record) type
+ if (auto recTy = mlir::dyn_cast_if_present<fir::RecordType>(ty)) {
+ // If no symbol table provided, simply add the type descriptor name
+ if (!symbolTable) {
+ symbols.push_back(mlir::SymbolRefAttr::get(op->getContext(),
+ fir::NameUniquer::getTypeDescriptorName(recTy.getName())));
+ return;
+ }
+
+ // Otherwise, do full lookup and recursive processing
+ llvm::SmallSet<mlir::Operation *, 16> globalsSet;
+
+ fir::GlobalOp globalOp = symbolTable->lookup<fir::GlobalOp>(
+ fir::NameUniquer::getTypeDescriptorName(recTy.getName()));
+ if (!globalOp) {
+ globalOp = symbolTable->lookup<fir::GlobalOp>(
+ fir::NameUniquer::getTypeDescriptorAssemblyName(recTy.getName()));
+ }
+ if (globalOp) {
+ globalsSet.insert(globalOp);
+ symbols.push_back(mlir::SymbolRefAttr::get(
+ op->getContext(), globalOp.getSymName()));
+ globalOp.walk([&](fir::AddrOfOp addrOp) {
+ processAddrOfOpInDerivedTypeDescriptor(
+ addrOp, *symbolTable, globalsSet, symbols);
+ });
+ }
+ }
+}
+
+template <>
+void IndirectGlobalAccessModel<fir::AllocaOp>::getReferencedSymbols(
+ mlir::Operation *op,
+ llvm::SmallVectorImpl<mlir::SymbolRefAttr> &symbols,
+ mlir::SymbolTable *symbolTable) const {
+ auto allocaOp = mlir::cast<fir::AllocaOp>(op);
+ collectReferencedSymbolsForType(allocaOp.getType(), op, symbols, symbolTable);
+}
+
+template <>
+void IndirectGlobalAccessModel<fir::EmboxOp>::getReferencedSymbols(
+ mlir::Operation *op,
+ llvm::SmallVectorImpl<mlir::SymbolRefAttr> &symbols,
+ mlir::SymbolTable *symbolTable) const {
+ auto emboxOp = mlir::cast<fir::EmboxOp>(op);
+ collectReferencedSymbolsForType(emboxOp.getMemref().getType(), op, symbols,
+ symbolTable);
+}
+
+template <>
+void IndirectGlobalAccessModel<fir::ReboxOp>::getReferencedSymbols(
+ mlir::Operation *op,
+ llvm::SmallVectorImpl<mlir::SymbolRefAttr> &symbols,
+ mlir::SymbolTable *symbolTable) const {
+ auto reboxOp = mlir::cast<fir::ReboxOp>(op);
+ collectReferencedSymbolsForType(reboxOp.getBox().getType(), op, symbols,
+ symbolTable);
+}
+
+template <>
+void IndirectGlobalAccessModel<fir::TypeDescOp>::getReferencedSymbols(
+ mlir::Operation *op,
+ llvm::SmallVectorImpl<mlir::SymbolRefAttr> &symbols,
+ mlir::SymbolTable *symbolTable) const {
+ auto typeDescOp = mlir::cast<fir::TypeDescOp>(op);
+ collectReferencedSymbolsForType(typeDescOp.getInType(), op, symbols,
+ symbolTable);
+}
+
} // namespace fir::acc
diff --git a/flang/lib/Optimizer/OpenACC/Support/RegisterOpenACCExtensions.cpp b/flang/lib/Optimizer/OpenACC/Support/RegisterOpenACCExtensions.cpp
index 5c7f9985d41ca..915518c8de6c7 100644
--- a/flang/lib/Optimizer/OpenACC/Support/RegisterOpenACCExtensions.cpp
+++ b/flang/lib/Optimizer/OpenACC/Support/RegisterOpenACCExtensions.cpp
@@ -52,6 +52,15 @@ void registerOpenACCExtensions(mlir::DialectRegistry ®istry) {
fir::AddrOfOp::attachInterface<AddressOfGlobalModel>(*ctx);
fir::GlobalOp::attachInterface<GlobalVariableModel>(*ctx);
+
+ fir::AllocaOp::attachInterface<IndirectGlobalAccessModel<fir::AllocaOp>>(
+ *ctx);
+ fir::EmboxOp::attachInterface<IndirectGlobalAccessModel<fir::EmboxOp>>(
+ *ctx);
+ fir::ReboxOp::attachInterface<IndirectGlobalAccessModel<fir::ReboxOp>>(
+ *ctx);
+ fir::TypeDescOp::attachInterface<IndirectGlobalAccessModel<fir::TypeDescOp>>(
+ *ctx);
});
// Register HLFIR operation interfaces
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOpsInterfaces.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOpsInterfaces.td
index 6b0c84d31d1ba..ec41826b2bbc8 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOpsInterfaces.td
@@ -75,4 +75,27 @@ def GlobalVariableOpInterface : OpInterface<"GlobalVariableOpInterface"> {
];
}
+def IndirectGlobalAccessOpInterface : OpInterface<"IndirectGlobalAccessOpInterface"> {
+ let cppNamespace = "::mlir::acc";
+
+ let description = [{
+ An interface for operations that indirectly access global symbols.
+ This interface provides a way to query which global symbols are referenced
+ by an operation, which is useful for tracking dependencies and performing
+ analysis on global variable usage.
+
+ The symbolTable parameter is optional. If null, implementations will look up
+ their own symbol table. This allows callers to pass a pre-existing symbol
+ table for efficiency when querying multiple operations.
+ }];
+
+ let methods = [
+ InterfaceMethod<"Get the symbols referenced by this operation",
+ "void",
+ "getReferencedSymbols",
+ (ins "::llvm::SmallVectorImpl<::mlir::SymbolRefAttr>&":$symbols,
+ "::mlir::SymbolTable *":$symbolTable)>,
+ ];
+}
+
#endif // OPENACC_OPS_INTERFACES
|
|
@llvm/pr-subscribers-mlir Author: Razvan Lupusoru (razvanlupusoru) ChangesFIR operations that use derived types need to have type descriptor globals available on device when offloading. Examples of this can be seen in Similarly, this is needed for OpenACC. This change introduces a new interface to the OpenACC dialect named The interface provides a Note that alternately this could have been implemented in different ways:
The technique chosen here is consistent with Full diff: https://github.com/llvm/llvm-project/pull/168982.diff 4 Files Affected:
diff --git a/flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.h b/flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.h
index bf87654979cc9..87d60d489ba13 100644
--- a/flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.h
+++ b/flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.h
@@ -67,6 +67,16 @@ struct GlobalVariableModel
bool isConstant(mlir::Operation *op) const;
};
+template <typename Op>
+struct IndirectGlobalAccessModel
+ : public mlir::acc::IndirectGlobalAccessOpInterface::ExternalModel<
+ IndirectGlobalAccessModel<Op>, Op> {
+ void getReferencedSymbols(
+ mlir::Operation *op,
+ llvm::SmallVectorImpl<mlir::SymbolRefAttr> &symbols,
+ mlir::SymbolTable *symbolTable) const;
+};
+
} // namespace fir::acc
#endif // FLANG_OPTIMIZER_OPENACC_FIROPENACC_OPS_INTERFACES_H_
diff --git a/flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp b/flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp
index 11fbaf2dc2bb8..aa62b5a9820ee 100644
--- a/flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp
+++ b/flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp
@@ -14,6 +14,9 @@
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
+#include "flang/Optimizer/Support/InternalNames.h"
+#include "mlir/IR/SymbolTable.h"
+#include "llvm/ADT/SmallSet.h"
namespace fir::acc {
@@ -68,4 +71,101 @@ bool GlobalVariableModel::isConstant(mlir::Operation *op) const {
return globalOp.getConstant().has_value();
}
+// Helper to recursively process address-of operations in derived type
+// descriptors and collect all needed fir.globals.
+static void processAddrOfOpInDerivedTypeDescriptor(fir::AddrOfOp addrOfOp,
+ mlir::SymbolTable &symTab,
+ llvm::SmallSet<mlir::Operation *, 16> &globalsSet,
+ llvm::SmallVectorImpl<mlir::SymbolRefAttr> &symbols) {
+ if (auto globalOp = symTab.lookup<fir::GlobalOp>(
+ addrOfOp.getSymbol().getLeafReference().getValue())) {
+ if (globalsSet.contains(globalOp)) {
+ return;
+ }
+ globalsSet.insert(globalOp);
+ symbols.push_back(addrOfOp.getSymbolAttr());
+ globalOp.walk([&](fir::AddrOfOp op) {
+ processAddrOfOpInDerivedTypeDescriptor(
+ op, symTab, globalsSet, symbols);
+ });
+ }
+}
+
+// Utility to collect referenced symbols for type descriptors of derived types.
+// This is the common logic for operations that may require type descriptor
+// globals.
+static void collectReferencedSymbolsForType(mlir::Type ty, mlir::Operation *op,
+ llvm::SmallVectorImpl<mlir::SymbolRefAttr> &symbols,
+ mlir::SymbolTable *symbolTable) {
+ ty = fir::getDerivedType(fir::unwrapRefType(ty));
+
+ // Look for type descriptor globals only if it's a derived (record) type
+ if (auto recTy = mlir::dyn_cast_if_present<fir::RecordType>(ty)) {
+ // If no symbol table provided, simply add the type descriptor name
+ if (!symbolTable) {
+ symbols.push_back(mlir::SymbolRefAttr::get(op->getContext(),
+ fir::NameUniquer::getTypeDescriptorName(recTy.getName())));
+ return;
+ }
+
+ // Otherwise, do full lookup and recursive processing
+ llvm::SmallSet<mlir::Operation *, 16> globalsSet;
+
+ fir::GlobalOp globalOp = symbolTable->lookup<fir::GlobalOp>(
+ fir::NameUniquer::getTypeDescriptorName(recTy.getName()));
+ if (!globalOp) {
+ globalOp = symbolTable->lookup<fir::GlobalOp>(
+ fir::NameUniquer::getTypeDescriptorAssemblyName(recTy.getName()));
+ }
+ if (globalOp) {
+ globalsSet.insert(globalOp);
+ symbols.push_back(mlir::SymbolRefAttr::get(
+ op->getContext(), globalOp.getSymName()));
+ globalOp.walk([&](fir::AddrOfOp addrOp) {
+ processAddrOfOpInDerivedTypeDescriptor(
+ addrOp, *symbolTable, globalsSet, symbols);
+ });
+ }
+ }
+}
+
+template <>
+void IndirectGlobalAccessModel<fir::AllocaOp>::getReferencedSymbols(
+ mlir::Operation *op,
+ llvm::SmallVectorImpl<mlir::SymbolRefAttr> &symbols,
+ mlir::SymbolTable *symbolTable) const {
+ auto allocaOp = mlir::cast<fir::AllocaOp>(op);
+ collectReferencedSymbolsForType(allocaOp.getType(), op, symbols, symbolTable);
+}
+
+template <>
+void IndirectGlobalAccessModel<fir::EmboxOp>::getReferencedSymbols(
+ mlir::Operation *op,
+ llvm::SmallVectorImpl<mlir::SymbolRefAttr> &symbols,
+ mlir::SymbolTable *symbolTable) const {
+ auto emboxOp = mlir::cast<fir::EmboxOp>(op);
+ collectReferencedSymbolsForType(emboxOp.getMemref().getType(), op, symbols,
+ symbolTable);
+}
+
+template <>
+void IndirectGlobalAccessModel<fir::ReboxOp>::getReferencedSymbols(
+ mlir::Operation *op,
+ llvm::SmallVectorImpl<mlir::SymbolRefAttr> &symbols,
+ mlir::SymbolTable *symbolTable) const {
+ auto reboxOp = mlir::cast<fir::ReboxOp>(op);
+ collectReferencedSymbolsForType(reboxOp.getBox().getType(), op, symbols,
+ symbolTable);
+}
+
+template <>
+void IndirectGlobalAccessModel<fir::TypeDescOp>::getReferencedSymbols(
+ mlir::Operation *op,
+ llvm::SmallVectorImpl<mlir::SymbolRefAttr> &symbols,
+ mlir::SymbolTable *symbolTable) const {
+ auto typeDescOp = mlir::cast<fir::TypeDescOp>(op);
+ collectReferencedSymbolsForType(typeDescOp.getInType(), op, symbols,
+ symbolTable);
+}
+
} // namespace fir::acc
diff --git a/flang/lib/Optimizer/OpenACC/Support/RegisterOpenACCExtensions.cpp b/flang/lib/Optimizer/OpenACC/Support/RegisterOpenACCExtensions.cpp
index 5c7f9985d41ca..915518c8de6c7 100644
--- a/flang/lib/Optimizer/OpenACC/Support/RegisterOpenACCExtensions.cpp
+++ b/flang/lib/Optimizer/OpenACC/Support/RegisterOpenACCExtensions.cpp
@@ -52,6 +52,15 @@ void registerOpenACCExtensions(mlir::DialectRegistry ®istry) {
fir::AddrOfOp::attachInterface<AddressOfGlobalModel>(*ctx);
fir::GlobalOp::attachInterface<GlobalVariableModel>(*ctx);
+
+ fir::AllocaOp::attachInterface<IndirectGlobalAccessModel<fir::AllocaOp>>(
+ *ctx);
+ fir::EmboxOp::attachInterface<IndirectGlobalAccessModel<fir::EmboxOp>>(
+ *ctx);
+ fir::ReboxOp::attachInterface<IndirectGlobalAccessModel<fir::ReboxOp>>(
+ *ctx);
+ fir::TypeDescOp::attachInterface<IndirectGlobalAccessModel<fir::TypeDescOp>>(
+ *ctx);
});
// Register HLFIR operation interfaces
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOpsInterfaces.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOpsInterfaces.td
index 6b0c84d31d1ba..ec41826b2bbc8 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOpsInterfaces.td
@@ -75,4 +75,27 @@ def GlobalVariableOpInterface : OpInterface<"GlobalVariableOpInterface"> {
];
}
+def IndirectGlobalAccessOpInterface : OpInterface<"IndirectGlobalAccessOpInterface"> {
+ let cppNamespace = "::mlir::acc";
+
+ let description = [{
+ An interface for operations that indirectly access global symbols.
+ This interface provides a way to query which global symbols are referenced
+ by an operation, which is useful for tracking dependencies and performing
+ analysis on global variable usage.
+
+ The symbolTable parameter is optional. If null, implementations will look up
+ their own symbol table. This allows callers to pass a pre-existing symbol
+ table for efficiency when querying multiple operations.
+ }];
+
+ let methods = [
+ InterfaceMethod<"Get the symbols referenced by this operation",
+ "void",
+ "getReferencedSymbols",
+ (ins "::llvm::SmallVectorImpl<::mlir::SymbolRefAttr>&":$symbols,
+ "::mlir::SymbolTable *":$symbolTable)>,
+ ];
+}
+
#endif // OPENACC_OPS_INTERFACES
|
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
vzakhari
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great! Thank you, Razvan!
flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp
Outdated
Show resolved
Hide resolved
clementval
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just one brace comment otherwise LGTM
🐧 Linux x64 Test Results
|
FIR operations that use derived types need to have type descriptor globals available on device when offloading. Examples of this can be seen in
CUFDeviceGlobalwhich ensures that such type descriptor uses work on device for CUF.Similarly, this is needed for OpenACC. This change introduces a new interface to the OpenACC dialect named
IndirectGlobalAccessOpInterfacewhich can be attached to operations that may result in generation of accesses that use type descriptor globals. This functionality is needed for theACCImplicitDeclarepass that is coming in a follow-up change which implicitly ensures that all referenced globals are available in OpenACC compute contexts.The interface provides a
getReferencedSymbolsmethod that collects all global symbols referenced by an operation. When a symbol table is provided, the implementation for FIR recursively walks type descriptor globals to find all transitively referenced symbols.Note that alternately this could have been implemented in different ways:
The technique chosen here is consistent with
CUFDeviceGlobal(which walks operations insideprepareImplicitDeviceGlobals) and avoids conservative mapping of all type descriptors.