diff --git a/flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.h b/flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.h index 7afe97aac57e8..bf87654979cc9 100644 --- a/flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.h +++ b/flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.h @@ -16,7 +16,9 @@ #include "mlir/Dialect/OpenACC/OpenACC.h" namespace fir { +class AddrOfOp; class DeclareOp; +class GlobalOp; } // namespace fir namespace hlfir { @@ -53,6 +55,18 @@ struct PartialEntityAccessModel bool isCompleteView(mlir::Operation *op) const; }; +struct AddressOfGlobalModel + : public mlir::acc::AddressOfGlobalOpInterface::ExternalModel< + AddressOfGlobalModel, fir::AddrOfOp> { + mlir::SymbolRefAttr getSymbol(mlir::Operation *op) const; +}; + +struct GlobalVariableModel + : public mlir::acc::GlobalVariableOpInterface::ExternalModel< + GlobalVariableModel, fir::GlobalOp> { + bool isConstant(mlir::Operation *op) const; +}; + } // namespace fir::acc #endif // FLANG_OPTIMIZER_OPENACC_FIROPENACC_OPS_INTERFACES_H_ diff --git a/flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp b/flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp index c1734be5185f4..11fbaf2dc2bb8 100644 --- a/flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp +++ b/flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp @@ -59,4 +59,13 @@ bool PartialEntityAccessModel::isCompleteView( return !getBaseEntity(op); } +mlir::SymbolRefAttr AddressOfGlobalModel::getSymbol(mlir::Operation *op) const { + return mlir::cast(op).getSymbolAttr(); +} + +bool GlobalVariableModel::isConstant(mlir::Operation *op) const { + auto globalOp = mlir::cast(op); + return globalOp.getConstant().has_value(); +} + } // namespace fir::acc diff --git a/flang/lib/Optimizer/OpenACC/Support/RegisterOpenACCExtensions.cpp b/flang/lib/Optimizer/OpenACC/Support/RegisterOpenACCExtensions.cpp index d71c40dfac03c..5c7f9985d41ca 100644 --- a/flang/lib/Optimizer/OpenACC/Support/RegisterOpenACCExtensions.cpp +++ b/flang/lib/Optimizer/OpenACC/Support/RegisterOpenACCExtensions.cpp @@ -49,6 +49,9 @@ void registerOpenACCExtensions(mlir::DialectRegistry ®istry) { PartialEntityAccessModel>(*ctx); fir::DeclareOp::attachInterface>( *ctx); + + fir::AddrOfOp::attachInterface(*ctx); + fir::GlobalOp::attachInterface(*ctx); }); // Register HLFIR operation interfaces diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOpsInterfaces.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOpsInterfaces.td index 054c13a88a552..6b0c84d31d1ba 100644 --- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOpsInterfaces.td +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOpsInterfaces.td @@ -44,4 +44,35 @@ def PartialEntityAccessOpInterface : OpInterface<"PartialEntityAccessOpInterface ]; } +def AddressOfGlobalOpInterface : OpInterface<"AddressOfGlobalOpInterface"> { + let cppNamespace = "::mlir::acc"; + + let description = [{ + An interface for operations that compute the address of a global variable + or symbol. + }]; + + let methods = [ + InterfaceMethod<"Get the symbol reference to the global", "::mlir::SymbolRefAttr", + "getSymbol", (ins)>, + ]; +} + +def GlobalVariableOpInterface : OpInterface<"GlobalVariableOpInterface"> { + let cppNamespace = "::mlir::acc"; + + let description = [{ + An interface for operations that define global variables. This interface + provides a uniform way to query properties of global variables across + different dialects. + }]; + + let methods = [ + InterfaceMethod<"Check if the global variable is constant", "bool", + "isConstant", (ins), [{ + return false; + }]>, + ]; +} + #endif // OPENACC_OPS_INTERFACES diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp index 8c9c137b8aebb..5749e6ded73ba 100644 --- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp +++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp @@ -211,6 +211,24 @@ struct LLVMPointerPointerLikeModel Type getElementType(Type pointer) const { return Type(); } }; +struct MemrefAddressOfGlobalModel + : public AddressOfGlobalOpInterface::ExternalModel< + MemrefAddressOfGlobalModel, memref::GetGlobalOp> { + SymbolRefAttr getSymbol(Operation *op) const { + auto getGlobalOp = cast(op); + return getGlobalOp.getNameAttr(); + } +}; + +struct MemrefGlobalVariableModel + : public GlobalVariableOpInterface::ExternalModel { + bool isConstant(Operation *op) const { + auto globalOp = cast(op); + return globalOp.getConstant(); + } +}; + /// Helper function for any of the times we need to modify an ArrayAttr based on /// a device type list. Returns a new ArrayAttr with all of the /// existingDeviceTypes, plus the effective new ones(or an added none if hte new @@ -302,6 +320,11 @@ void OpenACCDialect::initialize() { MemRefPointerLikeModel>(*getContext()); LLVM::LLVMPointerType::attachInterface( *getContext()); + + // Attach operation interfaces + memref::GetGlobalOp::attachInterface( + *getContext()); + memref::GlobalOp::attachInterface(*getContext()); } //===----------------------------------------------------------------------===// diff --git a/mlir/unittests/Dialect/OpenACC/CMakeLists.txt b/mlir/unittests/Dialect/OpenACC/CMakeLists.txt index 177c8680b0040..c8c2bb96b0539 100644 --- a/mlir/unittests/Dialect/OpenACC/CMakeLists.txt +++ b/mlir/unittests/Dialect/OpenACC/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_unittest(MLIROpenACCTests OpenACCOpsTest.cpp + OpenACCOpsInterfacesTest.cpp OpenACCUtilsTest.cpp ) mlir_target_link_libraries(MLIROpenACCTests diff --git a/mlir/unittests/Dialect/OpenACC/OpenACCOpsInterfacesTest.cpp b/mlir/unittests/Dialect/OpenACC/OpenACCOpsInterfacesTest.cpp new file mode 100644 index 0000000000000..261f5c513ea24 --- /dev/null +++ b/mlir/unittests/Dialect/OpenACC/OpenACCOpsInterfacesTest.cpp @@ -0,0 +1,95 @@ +//===- OpenACCOpsInterfacesTest.cpp - Unit tests for OpenACC interfaces --===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/OpenACC/OpenACC.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OwningOpRef.h" +#include "gtest/gtest.h" + +using namespace mlir; +using namespace mlir::acc; + +//===----------------------------------------------------------------------===// +// Test Fixture +//===----------------------------------------------------------------------===// + +class OpenACCOpsInterfacesTest : public ::testing::Test { +protected: + OpenACCOpsInterfacesTest() + : context(), builder(&context), loc(UnknownLoc::get(&context)) { + context.loadDialect(); + } + + MLIRContext context; + OpBuilder builder; + Location loc; +}; + +//===----------------------------------------------------------------------===// +// GlobalVariableOpInterface Tests +//===----------------------------------------------------------------------===// + +TEST_F(OpenACCOpsInterfacesTest, GlobalVariableOpInterfaceNonConstant) { + // Test that a non-constant global returns false for isConstant() + + auto memrefType = MemRefType::get({10}, builder.getF32Type()); + OwningOpRef globalOp = memref::GlobalOp::create( + builder, loc, + /*sym_name=*/builder.getStringAttr("mutable_global"), + /*sym_visibility=*/builder.getStringAttr("private"), + /*type=*/TypeAttr::get(memrefType), + /*initial_value=*/Attribute(), + /*constant=*/UnitAttr(), + /*alignment=*/IntegerAttr()); + + auto globalVarIface = + dyn_cast(globalOp->getOperation()); + ASSERT_TRUE(globalVarIface != nullptr); + EXPECT_FALSE(globalVarIface.isConstant()); +} + +TEST_F(OpenACCOpsInterfacesTest, GlobalVariableOpInterfaceConstant) { + // Test that a constant global returns true for isConstant() + + auto memrefType = MemRefType::get({5}, builder.getI32Type()); + OwningOpRef constantGlobalOp = memref::GlobalOp::create( + builder, loc, + /*sym_name=*/builder.getStringAttr("constant_global"), + /*sym_visibility=*/builder.getStringAttr("public"), + /*type=*/TypeAttr::get(memrefType), + /*initial_value=*/Attribute(), + /*constant=*/builder.getUnitAttr(), + /*alignment=*/IntegerAttr()); + + auto globalVarIface = + dyn_cast(constantGlobalOp->getOperation()); + ASSERT_TRUE(globalVarIface != nullptr); + EXPECT_TRUE(globalVarIface.isConstant()); +} + +//===----------------------------------------------------------------------===// +// AddressOfGlobalOpInterface Tests +//===----------------------------------------------------------------------===// + +TEST_F(OpenACCOpsInterfacesTest, AddressOfGlobalOpInterfaceGetSymbol) { + // Test that getSymbol() returns the correct symbol reference + + auto memrefType = MemRefType::get({5}, builder.getI32Type()); + const auto *symbolName = "test_global_symbol"; + + OwningOpRef getGlobalOp = memref::GetGlobalOp::create( + builder, loc, memrefType, FlatSymbolRefAttr::get(&context, symbolName)); + + auto addrOfGlobalIface = + dyn_cast(getGlobalOp->getOperation()); + ASSERT_TRUE(addrOfGlobalIface != nullptr); + EXPECT_EQ(addrOfGlobalIface.getSymbol().getLeafReference(), symbolName); +}