Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion flang/include/flang/Optimizer/HLFIR/HLFIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,9 @@ def hlfir_DesignateOp : hlfir_Op<"designate", [AttrSizedOperandSegments,
using Triplet = std::tuple<mlir::Value, mlir::Value, mlir::Value>;
using Subscript = std::variant<mlir::Value, Triplet>;
using Subscripts = llvm::SmallVector<Subscript, 8>;
void setFortranAttrs(fir::FortranVariableFlagsEnum flags) {
this->setFortranAttrs(std::optional<fir::FortranVariableFlagsEnum>(flags));
}
}];

let builders = [
Expand Down Expand Up @@ -319,7 +322,7 @@ def hlfir_ParentComponentOp : hlfir_Op<"parent_comp", [AttrSizedOperandSegments,
// Implement FortranVariableInterface interface. Parent components have
// no attributes (pointer, allocatable or contiguous can only be added
// to regular components).
std::optional<fir::FortranVariableFlagsEnum> getFortranAttrs() const {
std::optional<fir::FortranVariableFlagsEnum> getFortranAttrs() {
return std::nullopt;
}
void setFortranAttrs(fir::FortranVariableFlagsEnum flags) {}
Expand Down Expand Up @@ -882,6 +885,10 @@ def hlfir_AssociateOp : hlfir_Op<"associate", [AttrSizedOperandSegments,
CArg<"llvm::ArrayRef<mlir::NamedAttribute>", "{}">:$attributes)>];

let extraClassDeclaration = [{
void setFortranAttrs(fir::FortranVariableFlagsEnum flags) {
this->setFortranAttrs(std::optional<fir::FortranVariableFlagsEnum>(flags));
}

/// Override FortranVariableInterface default implementation
mlir::Value getBase() {
return getResult(0);
Expand Down
9 changes: 8 additions & 1 deletion mlir/include/mlir/TableGen/Interfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,17 @@ class InterfaceMethod {
StringRef name;
};

explicit InterfaceMethod(const llvm::Record *def);
explicit InterfaceMethod(const llvm::Record *def, std::string uniqueName);

// Return the return type of this method.
StringRef getReturnType() const;

// Return the name of this method.
StringRef getName() const;

// Return the dedup name of this method.
StringRef getUniqueName() const;

// Return if this method is static.
bool isStatic() const;

Expand All @@ -62,6 +65,10 @@ class InterfaceMethod {

// The arguments of this method.
SmallVector<Argument, 2> arguments;

// The unique name of this method, to distinguish it from other methods with
// the same name (overloaded methods)
std::string uniqueName;
};

//===----------------------------------------------------------------------===//
Expand Down
21 changes: 18 additions & 3 deletions mlir/lib/TableGen/Interfaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ using llvm::StringInit;
// InterfaceMethod
//===----------------------------------------------------------------------===//

InterfaceMethod::InterfaceMethod(const Record *def) : def(def) {
InterfaceMethod::InterfaceMethod(const Record *def, std::string uniqueName)
: def(def), uniqueName(uniqueName) {
const DagInit *args = def->getValueAsDag("arguments");
for (unsigned i = 0, e = args->getNumArgs(); i != e; ++i) {
arguments.push_back({cast<StringInit>(args->getArg(i))->getValue(),
Expand All @@ -42,6 +43,9 @@ StringRef InterfaceMethod::getName() const {
return def->getValueAsString("name");
}

// Return the name of this method.
StringRef InterfaceMethod::getUniqueName() const { return uniqueName; }

// Return if this method is static.
bool InterfaceMethod::isStatic() const {
return def->isSubClassOf("StaticInterfaceMethod");
Expand Down Expand Up @@ -83,8 +87,19 @@ Interface::Interface(const Record *def) : def(def) {

// Initialize the interface methods.
auto *listInit = dyn_cast<ListInit>(def->getValueInit("methods"));
for (const Init *init : listInit->getElements())
methods.emplace_back(cast<DefInit>(init)->getDef());
// In case of overloaded methods, we need to find a unique name for each for
// the internal function pointer in the "vtable" we generate. This is an
// internal name, we could use a randomly generated name as long as there are
// no collisions.
StringSet<> uniqueNames;
for (const Init *init : listInit->getElements()) {
std::string name =
cast<DefInit>(init)->getDef()->getValueAsString("name").str();
while (!uniqueNames.insert(name).second) {
name = name + "_" + std::to_string(uniqueNames.size());
}
methods.emplace_back(cast<DefInit>(init)->getDef(), name);
}

// Initialize the interface base classes.
auto *basesInit = dyn_cast<ListInit>(def->getValueInit("baseInterfaces"));
Expand Down
10 changes: 10 additions & 0 deletions mlir/test/lib/Dialect/Test/TestInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,16 @@ def TestTypeInterface
InterfaceMethod<"Prints the type name.",
"void", "printTypeC", (ins "::mlir::Location":$loc)
>,
// Check that we can have multiple method with the same name.
InterfaceMethod<"Prints the type name, with a value prefixed.",
"void", "printTypeC", (ins "::mlir::Location":$loc, "int":$value)
>,
InterfaceMethod<"Prints the type name, with a value prefixed.",
"void", "printTypeC", (ins "::mlir::Location":$loc, "float":$value),
[{}], /*defaultImplementation=*/[{
emitRemark(loc) << $_type << " - " << value << " - Float TestC";
}]
>,
// It should be possible to use the interface type name as result type
// as well as in the implementation.
InterfaceMethod<"Prints the type name and returns the type as interface.",
Expand Down
4 changes: 4 additions & 0 deletions mlir/test/lib/Dialect/Test/TestTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,10 @@ void TestType::printTypeC(Location loc) const {
emitRemark(loc) << *this << " - TestC";
}

void TestType::printTypeC(Location loc, int value) const {
emitRemark(loc) << *this << " - " << value << " - Int TestC";
}

//===----------------------------------------------------------------------===//
// TestTypeWithLayout
//===----------------------------------------------------------------------===//
Expand Down
2 changes: 2 additions & 0 deletions mlir/test/lib/IR/TestInterfaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ struct TestTypeInterfaces
testInterface.printTypeA(op->getLoc());
testInterface.printTypeB(op->getLoc());
testInterface.printTypeC(op->getLoc());
testInterface.printTypeC(op->getLoc(), 42);
testInterface.printTypeC(op->getLoc(), 3.14f);
testInterface.printTypeD(op->getLoc());
// Just check that we can assign the result to a variable of interface
// type.
Expand Down
2 changes: 2 additions & 0 deletions mlir/test/mlir-tblgen/interfaces.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
// expected-remark@below {{'!test.test_type' - TestA}}
// expected-remark@below {{'!test.test_type' - TestB}}
// expected-remark@below {{'!test.test_type' - TestC}}
// expected-remark@below {{'!test.test_type' - 42 - Int TestC}}
// expected-remark@below {{'!test.test_type' - 3.140000e+00 - Float TestC}}
// expected-remark@below {{'!test.test_type' - TestD}}
// expected-remark@below {{'!test.test_type' - TestRet}}
// expected-remark@below {{'!test.test_type' - TestE}}
Expand Down
19 changes: 18 additions & 1 deletion mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,9 @@ class DefGen {
void emitTraitMethods(const InterfaceTrait &trait);
/// Emit a trait method.
void emitTraitMethod(const InterfaceMethod &method);
/// Generate a using declaration for a trait method.
void genTraitMethodUsingDecl(const InterfaceTrait &trait,
const InterfaceMethod &method);

//===--------------------------------------------------------------------===//
// OpAsm{Type,Attr}Interface Default Method Emission
Expand Down Expand Up @@ -176,6 +179,9 @@ class DefGen {
StringRef valueType;
/// The prefix/suffix of the TableGen def name, either "Attr" or "Type".
StringRef defType;

/// The set of using declarations for trait methods.
llvm::StringSet<> interfaceUsingNames;
};
} // namespace

Expand Down Expand Up @@ -632,8 +638,10 @@ void DefGen::emitTraitMethods(const InterfaceTrait &trait) {
// Don't declare if the method has a body. Or if the method has a default
// implementation and the def didn't request that it always be declared.
if (method.getBody() || (method.getDefaultImplementation() &&
!alwaysDeclared.count(method.getName())))
!alwaysDeclared.count(method.getName()))) {
genTraitMethodUsingDecl(trait, method);
continue;
}
emitTraitMethod(method);
}
}
Expand All @@ -649,6 +657,15 @@ void DefGen::emitTraitMethod(const InterfaceMethod &method) {
std::move(params));
}

void DefGen::genTraitMethodUsingDecl(const InterfaceTrait &trait,
const InterfaceMethod &method) {
std::string name = (llvm::Twine(trait.getFullyQualifiedTraitName()) + "<" +
def.getCppClassName() + ">::" + method.getName())
.str();
if (interfaceUsingNames.insert(name).second)
defCls.declare<UsingDeclaration>(std::move(name));
}

//===----------------------------------------------------------------------===//
// OpAsm{Type,Attr}Interface Default Method Emission

Expand Down
27 changes: 26 additions & 1 deletion mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -789,6 +789,14 @@ class OpEmitter {
Method *genOpInterfaceMethod(const tblgen::InterfaceMethod &method,
bool declaration = true);

// Generate a `using` declaration for the op interface method to include
// the default implementation from the interface trait.
// This is needed when the interface defines multiple methods with the same
// name, but some have a default implementation and some don't.
UsingDeclaration *
genOpInterfaceMethodUsingDecl(const tblgen::InterfaceTrait *opTrait,
const tblgen::InterfaceMethod &method);

// Generate the side effect interface methods.
void genSideEffectInterfaceMethods();

Expand All @@ -815,6 +823,10 @@ class OpEmitter {

// Helper for emitting op code.
OpOrAdaptorHelper emitHelper;

// Keep track of the interface using declarations that have been generated to
// avoid duplicates.
llvm::StringSet<> interfaceUsingNames;
};

} // namespace
Expand Down Expand Up @@ -3672,8 +3684,10 @@ void OpEmitter::genOpInterfaceMethods(const tblgen::InterfaceTrait *opTrait) {
// Don't declare if the method has a default implementation and the op
// didn't request that it always be declared.
if (method.getDefaultImplementation() &&
!alwaysDeclaredMethods.count(method.getName()))
!alwaysDeclaredMethods.count(method.getName())) {
genOpInterfaceMethodUsingDecl(opTrait, method);
continue;
}
// Interface methods are allowed to overlap with existing methods, so don't
// check if pruned.
(void)genOpInterfaceMethod(method);
Expand All @@ -3692,6 +3706,17 @@ Method *OpEmitter::genOpInterfaceMethod(const InterfaceMethod &method,
std::move(paramList));
}

UsingDeclaration *
OpEmitter::genOpInterfaceMethodUsingDecl(const tblgen::InterfaceTrait *opTrait,
const InterfaceMethod &method) {
std::string name = (llvm::Twine(opTrait->getFullyQualifiedTraitName()) + "<" +
op.getCppClassName() + ">::" + method.getName())
.str();
if (interfaceUsingNames.insert(name).second)
return opClass.declare<UsingDeclaration>(std::move(name));
return nullptr;
}

void OpEmitter::genOpInterfaceMethods() {
for (const auto &trait : op.getTraits()) {
if (const auto *opTrait = dyn_cast<tblgen::InterfaceTrait>(&trait))
Expand Down
32 changes: 18 additions & 14 deletions mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,10 @@ static raw_ostream &emitCPPType(StringRef type, raw_ostream &os) {
/// Emit the method name and argument list for the given method. If 'addThisArg'
/// is true, then an argument is added to the beginning of the argument list for
/// the concrete value.
static void emitMethodNameAndArgs(const InterfaceMethod &method,
static void emitMethodNameAndArgs(const InterfaceMethod &method, StringRef name,
raw_ostream &os, StringRef valueType,
bool addThisArg, bool addConst) {
os << method.getName() << '(';
os << name << '(';
if (addThisArg) {
if (addConst)
os << "const ";
Expand Down Expand Up @@ -183,11 +183,13 @@ static void emitInterfaceDefMethods(StringRef interfaceQualName,
emitInterfaceMethodDoc(method, os);
emitCPPType(method.getReturnType(), os);
os << interfaceQualName << "::";
emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false,
emitMethodNameAndArgs(method, method.getName(), os, valueType,
/*addThisArg=*/false,
/*addConst=*/!isOpInterface);

// Forward to the method on the concrete operation type.
os << " {\n return " << implValue << "->" << method.getName() << '(';
os << " {\n return " << implValue << "->" << method.getUniqueName()
<< '(';
if (!method.isStatic()) {
os << implValue << ", ";
os << (isOpInterface ? "getOperation()" : "*this");
Expand Down Expand Up @@ -239,7 +241,7 @@ void InterfaceGenerator::emitConceptDecl(const Interface &interface) {
for (auto &method : interface.getMethods()) {
os << " ";
emitCPPType(method.getReturnType(), os);
os << "(*" << method.getName() << ")(";
os << "(*" << method.getUniqueName() << ")(";
if (!method.isStatic()) {
os << "const Concept *impl, ";
emitCPPType(valueType, os) << (method.arg_empty() ? "" : ", ");
Expand Down Expand Up @@ -289,13 +291,13 @@ void InterfaceGenerator::emitModelDecl(const Interface &interface) {
os << " " << modelClass << "() : Concept{";
llvm::interleaveComma(
interface.getMethods(), os,
[&](const InterfaceMethod &method) { os << method.getName(); });
[&](const InterfaceMethod &method) { os << method.getUniqueName(); });
os << "} {}\n\n";

// Insert each of the virtual method overrides.
for (auto &method : interface.getMethods()) {
emitCPPType(method.getReturnType(), os << " static inline ");
emitMethodNameAndArgs(method, os, valueType,
emitMethodNameAndArgs(method, method.getUniqueName(), os, valueType,
/*addThisArg=*/!method.isStatic(),
/*addConst=*/false);
os << ";\n";
Expand All @@ -319,7 +321,7 @@ void InterfaceGenerator::emitModelDecl(const Interface &interface) {
if (method.isStatic())
os << "static ";
emitCPPType(method.getReturnType(), os);
os << method.getName() << "(";
os << method.getUniqueName() << "(";
if (!method.isStatic()) {
emitCPPType(valueType, os);
os << "tablegen_opaque_val";
Expand Down Expand Up @@ -350,7 +352,7 @@ void InterfaceGenerator::emitModelMethodsDef(const Interface &interface) {
emitCPPType(method.getReturnType(), os);
os << "detail::" << interface.getName() << "InterfaceTraits::Model<"
<< valueTemplate << ">::";
emitMethodNameAndArgs(method, os, valueType,
emitMethodNameAndArgs(method, method.getUniqueName(), os, valueType,
/*addThisArg=*/!method.isStatic(),
/*addConst=*/false);
os << " {\n ";
Expand Down Expand Up @@ -384,7 +386,7 @@ void InterfaceGenerator::emitModelMethodsDef(const Interface &interface) {
emitCPPType(method.getReturnType(), os);
os << "detail::" << interface.getName() << "InterfaceTraits::FallbackModel<"
<< valueTemplate << ">::";
emitMethodNameAndArgs(method, os, valueType,
emitMethodNameAndArgs(method, method.getUniqueName(), os, valueType,
/*addThisArg=*/!method.isStatic(),
/*addConst=*/false);
os << " {\n ";
Expand All @@ -396,7 +398,7 @@ void InterfaceGenerator::emitModelMethodsDef(const Interface &interface) {
os << "return static_cast<const " << valueTemplate << " *>(impl)->";

// Add the arguments to the call.
os << method.getName() << '(';
os << method.getUniqueName() << '(';
if (!method.isStatic())
os << "tablegen_opaque_val" << (method.arg_empty() ? "" : ", ");
llvm::interleaveComma(
Expand All @@ -416,7 +418,7 @@ void InterfaceGenerator::emitModelMethodsDef(const Interface &interface) {
<< "InterfaceTraits::ExternalModel<ConcreteModel, " << valueTemplate
<< ">::";

os << method.getName() << "(";
os << method.getUniqueName() << "(";
if (!method.isStatic()) {
emitCPPType(valueType, os);
os << "tablegen_opaque_val";
Expand Down Expand Up @@ -477,7 +479,8 @@ void InterfaceGenerator::emitInterfaceTraitDecl(const Interface &interface) {
emitInterfaceMethodDoc(method, os, " ");
os << " " << (method.isStatic() ? "static " : "");
emitCPPType(method.getReturnType(), os);
emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false,
emitMethodNameAndArgs(method, method.getName(), os, valueType,
/*addThisArg=*/false,
/*addConst=*/!isOpInterface && !method.isStatic());
os << " {\n " << tblgen::tgfmt(defaultImpl->trim(), &traitMethodFmt)
<< "\n }\n";
Expand Down Expand Up @@ -514,7 +517,8 @@ static void emitInterfaceDeclMethods(const Interface &interface,
for (auto &method : interface.getMethods()) {
emitInterfaceMethodDoc(method, os, " ");
emitCPPType(method.getReturnType(), os << " ");
emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false,
emitMethodNameAndArgs(method, method.getName(), os, valueType,
/*addThisArg=*/false,
/*addConst=*/!isOpInterface);
os << ";\n";
}
Expand Down