Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
#include "mlir/Dialect/OpenACC/OpenACC.h"

namespace fir {
class AddrOfOp;
class DeclareOp;
class GlobalOp;
} // namespace fir

namespace hlfir {
Expand Down Expand Up @@ -53,6 +55,18 @@ struct PartialEntityAccessModel<hlfir::DeclareOp>
bool isCompleteView(mlir::Operation *op) const;
};

struct AddressOfGlobalModel
: public mlir::acc::AddressOfGlobalOpInterface::ExternalModel<
AddressOfGlobalModel, fir::AddrOfOp> {
mlir::SymbolRefAttr getSymbol(mlir::Operation *op) const;
};

struct GlobalVariableModel
: public mlir::acc::GlobalVariableOpInterface::ExternalModel<
GlobalVariableModel, fir::GlobalOp> {
bool isConstant(mlir::Operation *op) const;
};

} // namespace fir::acc

#endif // FLANG_OPTIMIZER_OPENACC_FIROPENACC_OPS_INTERFACES_H_
Original file line number Diff line number Diff line change
Expand Up @@ -59,4 +59,13 @@ bool PartialEntityAccessModel<hlfir::DeclareOp>::isCompleteView(
return !getBaseEntity(op);
}

mlir::SymbolRefAttr AddressOfGlobalModel::getSymbol(mlir::Operation *op) const {
return mlir::cast<fir::AddrOfOp>(op).getSymbolAttr();
}

bool GlobalVariableModel::isConstant(mlir::Operation *op) const {
auto globalOp = mlir::cast<fir::GlobalOp>(op);
return globalOp.getConstant().has_value();
}

} // namespace fir::acc
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ void registerOpenACCExtensions(mlir::DialectRegistry &registry) {
PartialEntityAccessModel<fir::CoordinateOp>>(*ctx);
fir::DeclareOp::attachInterface<PartialEntityAccessModel<fir::DeclareOp>>(
*ctx);

fir::AddrOfOp::attachInterface<AddressOfGlobalModel>(*ctx);
fir::GlobalOp::attachInterface<GlobalVariableModel>(*ctx);
});

// Register HLFIR operation interfaces
Expand Down
31 changes: 31 additions & 0 deletions mlir/include/mlir/Dialect/OpenACC/OpenACCOpsInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,35 @@ def PartialEntityAccessOpInterface : OpInterface<"PartialEntityAccessOpInterface
];
}

def AddressOfGlobalOpInterface : OpInterface<"AddressOfGlobalOpInterface"> {
let cppNamespace = "::mlir::acc";

let description = [{
An interface for operations that compute the address of a global variable
or symbol.
}];

let methods = [
InterfaceMethod<"Get the symbol reference to the global", "::mlir::SymbolRefAttr",
"getSymbol", (ins)>,
];
}

def GlobalVariableOpInterface : OpInterface<"GlobalVariableOpInterface"> {
let cppNamespace = "::mlir::acc";

let description = [{
An interface for operations that define global variables. This interface
provides a uniform way to query properties of global variables across
different dialects.
}];

let methods = [
InterfaceMethod<"Check if the global variable is constant", "bool",
"isConstant", (ins), [{
return false;
}]>,
];
}

#endif // OPENACC_OPS_INTERFACES
23 changes: 23 additions & 0 deletions mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,24 @@ struct LLVMPointerPointerLikeModel
Type getElementType(Type pointer) const { return Type(); }
};

struct MemrefAddressOfGlobalModel
: public AddressOfGlobalOpInterface::ExternalModel<
MemrefAddressOfGlobalModel, memref::GetGlobalOp> {
SymbolRefAttr getSymbol(Operation *op) const {
auto getGlobalOp = cast<memref::GetGlobalOp>(op);
return getGlobalOp.getNameAttr();
}
};

struct MemrefGlobalVariableModel
: public GlobalVariableOpInterface::ExternalModel<MemrefGlobalVariableModel,
memref::GlobalOp> {
bool isConstant(Operation *op) const {
auto globalOp = cast<memref::GlobalOp>(op);
return globalOp.getConstant();
}
};

/// Helper function for any of the times we need to modify an ArrayAttr based on
/// a device type list. Returns a new ArrayAttr with all of the
/// existingDeviceTypes, plus the effective new ones(or an added none if hte new
Expand Down Expand Up @@ -302,6 +320,11 @@ void OpenACCDialect::initialize() {
MemRefPointerLikeModel<UnrankedMemRefType>>(*getContext());
LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
*getContext());

// Attach operation interfaces
memref::GetGlobalOp::attachInterface<MemrefAddressOfGlobalModel>(
*getContext());
memref::GlobalOp::attachInterface<MemrefGlobalVariableModel>(*getContext());
}

//===----------------------------------------------------------------------===//
Expand Down
1 change: 1 addition & 0 deletions mlir/unittests/Dialect/OpenACC/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
add_mlir_unittest(MLIROpenACCTests
OpenACCOpsTest.cpp
OpenACCOpsInterfacesTest.cpp
OpenACCUtilsTest.cpp
)
mlir_target_link_libraries(MLIROpenACCTests
Expand Down
95 changes: 95 additions & 0 deletions mlir/unittests/Dialect/OpenACC/OpenACCOpsInterfacesTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
//===- OpenACCOpsInterfacesTest.cpp - Unit tests for OpenACC interfaces --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/OpenACC/OpenACC.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OwningOpRef.h"
#include "gtest/gtest.h"

using namespace mlir;
using namespace mlir::acc;

//===----------------------------------------------------------------------===//
// Test Fixture
//===----------------------------------------------------------------------===//

class OpenACCOpsInterfacesTest : public ::testing::Test {
protected:
OpenACCOpsInterfacesTest()
: context(), builder(&context), loc(UnknownLoc::get(&context)) {
context.loadDialect<acc::OpenACCDialect, memref::MemRefDialect>();
}

MLIRContext context;
OpBuilder builder;
Location loc;
};

//===----------------------------------------------------------------------===//
// GlobalVariableOpInterface Tests
//===----------------------------------------------------------------------===//

TEST_F(OpenACCOpsInterfacesTest, GlobalVariableOpInterfaceNonConstant) {
// Test that a non-constant global returns false for isConstant()

auto memrefType = MemRefType::get({10}, builder.getF32Type());
OwningOpRef<memref::GlobalOp> globalOp = memref::GlobalOp::create(
builder, loc,
/*sym_name=*/builder.getStringAttr("mutable_global"),
/*sym_visibility=*/builder.getStringAttr("private"),
/*type=*/TypeAttr::get(memrefType),
/*initial_value=*/Attribute(),
/*constant=*/UnitAttr(),
/*alignment=*/IntegerAttr());

auto globalVarIface =
dyn_cast<GlobalVariableOpInterface>(globalOp->getOperation());
ASSERT_TRUE(globalVarIface != nullptr);
EXPECT_FALSE(globalVarIface.isConstant());
}

TEST_F(OpenACCOpsInterfacesTest, GlobalVariableOpInterfaceConstant) {
// Test that a constant global returns true for isConstant()

auto memrefType = MemRefType::get({5}, builder.getI32Type());
OwningOpRef<memref::GlobalOp> constantGlobalOp = memref::GlobalOp::create(
builder, loc,
/*sym_name=*/builder.getStringAttr("constant_global"),
/*sym_visibility=*/builder.getStringAttr("public"),
/*type=*/TypeAttr::get(memrefType),
/*initial_value=*/Attribute(),
/*constant=*/builder.getUnitAttr(),
/*alignment=*/IntegerAttr());

auto globalVarIface =
dyn_cast<GlobalVariableOpInterface>(constantGlobalOp->getOperation());
ASSERT_TRUE(globalVarIface != nullptr);
EXPECT_TRUE(globalVarIface.isConstant());
}

//===----------------------------------------------------------------------===//
// AddressOfGlobalOpInterface Tests
//===----------------------------------------------------------------------===//

TEST_F(OpenACCOpsInterfacesTest, AddressOfGlobalOpInterfaceGetSymbol) {
// Test that getSymbol() returns the correct symbol reference

auto memrefType = MemRefType::get({5}, builder.getI32Type());
const auto *symbolName = "test_global_symbol";

OwningOpRef<memref::GetGlobalOp> getGlobalOp = memref::GetGlobalOp::create(
builder, loc, memrefType, FlatSymbolRefAttr::get(&context, symbolName));

auto addrOfGlobalIface =
dyn_cast<AddressOfGlobalOpInterface>(getGlobalOp->getOperation());
ASSERT_TRUE(addrOfGlobalIface != nullptr);
EXPECT_EQ(addrOfGlobalIface.getSymbol().getLeafReference(), symbolName);
}