Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 160 additions & 2 deletions mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ struct OpStrings {
std::string opCppName;
SmallVector<std::string> opResultNames;
SmallVector<std::string> opOperandNames;
SmallVector<std::string> opRegionNames;
};

static std::string joinNameList(llvm::ArrayRef<std::string> names) {
Expand Down Expand Up @@ -87,8 +88,8 @@ static TypeStrings getStrings(irdl::TypeOp type) {
/// Generates OpStrings from an OperatioOp
static OpStrings getStrings(irdl::OperationOp op) {
auto operandOp = op.getOp<irdl::OperandsOp>();

auto resultOp = op.getOp<irdl::ResultsOp>();
auto regionsOp = op.getOp<irdl::RegionsOp>();

OpStrings strings;
strings.opName = op.getSymName();
Expand All @@ -108,6 +109,13 @@ static OpStrings getStrings(irdl::OperationOp op) {
}));
}

if (regionsOp) {
strings.opRegionNames = SmallVector<std::string>(
llvm::map_range(regionsOp->getNames(), [](Attribute attr) {
return llvm::formatv("{0}", cast<StringAttr>(attr));
}));
}

return strings;
}

Expand All @@ -122,6 +130,7 @@ static void fillDict(irdl::detail::dictionary &dict,
static void fillDict(irdl::detail::dictionary &dict, const OpStrings &strings) {
const auto operandCount = strings.opOperandNames.size();
const auto resultCount = strings.opResultNames.size();
const auto regionCount = strings.opRegionNames.size();

dict["OP_NAME"] = strings.opName;
dict["OP_CPP_NAME"] = strings.opCppName;
Expand All @@ -131,6 +140,7 @@ static void fillDict(irdl::detail::dictionary &dict, const OpStrings &strings) {
operandCount ? joinNameList(strings.opOperandNames) : "{\"\"}";
dict["OP_RESULT_INITIALIZER_LIST"] =
resultCount ? joinNameList(strings.opResultNames) : "{\"\"}";
dict["OP_REGION_COUNT"] = std::to_string(regionCount);
}

/// Fills a dictionary with values from DialectStrings
Expand Down Expand Up @@ -179,6 +189,8 @@ static void generateOpGetterDeclarations(irdl::detail::dictionary &dict,
const OpStrings &opStrings) {
auto opGetters = std::string{};
auto resGetters = std::string{};
auto regionGetters = std::string{};
auto regionAdaptorGetters = std::string{};

for (size_t i = 0, end = opStrings.opOperandNames.size(); i < end; ++i) {
const auto op =
Expand All @@ -196,8 +208,23 @@ static void generateOpGetterDeclarations(irdl::detail::dictionary &dict,
op, i);
}

for (size_t i = 0, end = opStrings.opRegionNames.size(); i < end; ++i) {
const auto op =
llvm::convertToCamelFromSnakeCase(opStrings.opRegionNames[i], true);
regionAdaptorGetters += llvm::formatv(
R"(::mlir::Region &get{0}() { return *getRegions()[{1}]; }
)",
op, i);
regionGetters += llvm::formatv(
R"(::mlir::Region &get{0}() { return (*this)->getRegion({1}); }
)",
op, i);
}

dict["OP_OPERAND_GETTER_DECLS"] = opGetters;
dict["OP_RESULT_GETTER_DECLS"] = resGetters;
dict["OP_REGION_ADAPTER_GETTER_DECLS"] = regionAdaptorGetters;
dict["OP_REGION_GETTER_DECLS"] = regionGetters;
}

static void generateOpBuilderDeclarations(irdl::detail::dictionary &dict,
Expand Down Expand Up @@ -238,6 +265,22 @@ static void generateOpBuilderDeclarations(irdl::detail::dictionary &dict,
dict["OP_BUILD_DECLS"] = buildDecls;
}

// add traits to the dictionary, return true if any were added
static SmallVector<std::string> generateTraits(irdl::OperationOp op,
const OpStrings &strings) {
SmallVector<std::string> cppTraitNames;
if (!strings.opRegionNames.empty()) {
cppTraitNames.push_back(
llvm::formatv("::mlir::OpTrait::NRegions<{0}>::Impl",
strings.opRegionNames.size())
.str());

// Requires verifyInvariantsImpl is implemented on the op
cppTraitNames.emplace_back("::mlir::OpTrait::OpInvariants");
}
return cppTraitNames;
}

static LogicalResult generateOperationInclude(irdl::OperationOp op,
raw_ostream &output,
irdl::detail::dictionary &dict) {
Expand All @@ -247,6 +290,13 @@ static LogicalResult generateOperationInclude(irdl::OperationOp op,
const auto opStrings = getStrings(op);
fillDict(dict, opStrings);

SmallVector<std::string> traitNames = generateTraits(op, opStrings);
if (traitNames.empty())
dict["OP_TEMPLATE_ARGS"] = opStrings.opCppName;
else
dict["OP_TEMPLATE_ARGS"] = llvm::formatv("{0}, {1}", opStrings.opCppName,
llvm::join(traitNames, ", "));

generateOpGetterDeclarations(dict, opStrings);
generateOpBuilderDeclarations(dict, opStrings);

Expand Down Expand Up @@ -301,6 +351,110 @@ static LogicalResult generateInclude(irdl::DialectOp dialect,
return success();
}

static void generateRegionConstraintVerifiers(
irdl::detail::dictionary &dict, irdl::OperationOp op,
const OpStrings &strings, SmallVectorImpl<std::string> &verifierHelpers,
SmallVectorImpl<std::string> &verifierCalls) {
auto regionsOp = op.getOp<irdl::RegionsOp>();
if (strings.opRegionNames.empty() || !regionsOp)
return;

for (size_t i = 0; i < strings.opRegionNames.size(); ++i) {
std::string regionName = strings.opRegionNames[i];
std::string helperFnName =
llvm::formatv("__mlir_irdl_local_region_constraint_{0}_{1}",
strings.opCppName, regionName)
.str();

// Extract the actual region constraint from the IRDL RegionOp
std::string condition = "true";
std::string textualConditionName = "any region";

if (auto regionDefOp =
dyn_cast<irdl::RegionOp>(regionsOp->getArgs()[i].getDefiningOp())) {
// Generate constraint condition based on RegionOp attributes
SmallVector<std::string> conditionParts;
SmallVector<std::string> descriptionParts;

// Check number of blocks constraint
if (auto blockCount = regionDefOp.getNumberOfBlocks()) {
conditionParts.push_back(
llvm::formatv("region.getBlocks().size() == {0}",
blockCount.value())
.str());
descriptionParts.push_back(
llvm::formatv("exactly {0} block(s)", blockCount.value()).str());
}

// Check entry block arguments constraint
if (regionDefOp.getConstrainedArguments()) {
size_t expectedArgCount = regionDefOp.getEntryBlockArgs().size();
conditionParts.push_back(
llvm::formatv("region.getNumArguments() == {0}", expectedArgCount)
.str());
descriptionParts.push_back(
llvm::formatv("{0} entry block argument(s)", expectedArgCount)
.str());
}

// Combine conditions
if (!conditionParts.empty()) {
condition = llvm::join(conditionParts, " && ");
}

// Generate descriptive error message
if (!descriptionParts.empty()) {
textualConditionName =
llvm::formatv("region with {0}",
llvm::join(descriptionParts, " and "))
.str();
}
}

verifierHelpers.push_back(llvm::formatv(
R"(static ::llvm::LogicalResult {0}(::mlir::Operation *op, ::mlir::Region &region, ::llvm::StringRef regionName, unsigned regionIndex) {{
if (!({1})) {{
return op->emitOpError("region #") << regionIndex
<< (regionName.empty() ? " " : " ('" + regionName + "') ")
<< "failed to verify constraint: {2}";
}
return ::mlir::success();
})",
helperFnName, condition, textualConditionName));

verifierCalls.push_back(llvm::formatv(R"(
if (::mlir::failed({0}(*this, (*this)->getRegion({1}), "{2}", {1})))
return ::mlir::failure();)",
helperFnName, i, regionName)
.str());
}
}

static void generateVerifiers(irdl::detail::dictionary &dict,
irdl::OperationOp op, const OpStrings &strings) {
SmallVector<std::string> verifierHelpers;
SmallVector<std::string> verifierCalls;

generateRegionConstraintVerifiers(dict, op, strings, verifierHelpers,
verifierCalls);

// Add an overall verifier that sequences the helper calls
std::string verifierDef =
llvm::formatv(R"(
::llvm::LogicalResult {0}::verifyInvariantsImpl() {{
if(::mlir::failed(verify()))
return ::mlir::failure();

{1}

return ::mlir::success();
})",
strings.opCppName, llvm::join(verifierCalls, "\n"));

dict["OP_VERIFIER_HELPERS"] = llvm::join(verifierHelpers, "\n");
dict["OP_VERIFIER"] = verifierDef;
}

static std::string generateOpDefinition(irdl::detail::dictionary &dict,
irdl::OperationOp op) {
static const auto perOpDefTemplate = mlir::irdl::detail::Template{
Expand Down Expand Up @@ -370,6 +524,8 @@ void {0}::build(::mlir::OpBuilder &opBuilder, ::mlir::OperationState &opState, {

dict["OP_BUILD_DEFS"] = buildDefinition;

generateVerifiers(dict, op, opStrings);

std::string str;
llvm::raw_string_ostream stream{str};
perOpDefTemplate.render(stream, dict);
Expand Down Expand Up @@ -427,7 +583,7 @@ static LogicalResult generateLib(irdl::DialectOp dialect, raw_ostream &output,
dict["TYPE_PARSER"] = llvm::formatv(
R"(static ::mlir::OptionalParseResult generatedTypeParser(::mlir::AsmParser &parser, ::llvm::StringRef *mnemonic, ::mlir::Type &value) {
return ::mlir::AsmParser::KeywordSwitch<::mlir::OptionalParseResult>(parser)
{0}
{0}
.Default([&](llvm::StringRef keyword, llvm::SMLoc) {{
*mnemonic = keyword;
return std::nullopt;
Expand Down Expand Up @@ -520,6 +676,8 @@ static LogicalResult verifySupported(irdl::DialectOp dialect) {
"IRDL C++ translation does not yet support variadic results");
}))
.Case<irdl::AnyOp>(([](irdl::AnyOp) { return success(); }))
.Case<irdl::RegionOp>(([](irdl::RegionOp) { return success(); }))
.Case<irdl::RegionsOp>(([](irdl::RegionsOp) { return success(); }))
.Default([](mlir::Operation *op) -> LogicalResult {
return op->emitError("IRDL C++ translation does not yet support "
"translation of ")
Expand Down
53 changes: 31 additions & 22 deletions mlir/lib/Target/IRDLToCpp/Templates/PerOperationDecl.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@ public:
struct Properties {
};
public:
__OP_CPP_NAME__GenericAdaptorBase(::mlir::Operation *op)
: odsAttrs(op->getRawDictionaryAttrs()), odsOpName(op->getName()),
odsRegions(op->getRegions())
__OP_CPP_NAME__GenericAdaptorBase(::mlir::Operation *op)
: odsAttrs(op->getRawDictionaryAttrs()), odsOpName(op->getName()),
odsRegions(op->getRegions())
{}

/// Return the unstructured operand index of a structured operand along with
// the amount of unstructured operands it contains.
std::pair<unsigned, unsigned>
getStructuredOperandIndexAndLength (unsigned index,
getStructuredOperandIndexAndLength (unsigned index,
unsigned odsOperandsSize) {
return {index, 1};
}
Expand All @@ -32,6 +32,12 @@ public:
::mlir::DictionaryAttr getAttributes() {
return odsAttrs;
}

__OP_REGION_ADAPTER_GETTER_DECLS__

::mlir::RegionRange getRegions() {
return odsRegions;
}
protected:
::mlir::DictionaryAttr odsAttrs;
::std::optional<::mlir::OperationName> odsOpName;
Expand All @@ -42,28 +48,28 @@ protected:
} // namespace detail

template <typename RangeT>
class __OP_CPP_NAME__GenericAdaptor
class __OP_CPP_NAME__GenericAdaptor
: public detail::__OP_CPP_NAME__GenericAdaptorBase {
using ValueT = ::llvm::detail::ValueOfRange<RangeT>;
using Base = detail::__OP_CPP_NAME__GenericAdaptorBase;
public:
__OP_CPP_NAME__GenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs,
::mlir::OpaqueProperties properties,
::mlir::RegionRange regions = {})
: __OP_CPP_NAME__GenericAdaptor(values, attrs,
(properties ? *properties.as<::mlir::EmptyProperties *>()
::mlir::OpaqueProperties properties,
::mlir::RegionRange regions = {})
: __OP_CPP_NAME__GenericAdaptor(values, attrs,
(properties ? *properties.as<::mlir::EmptyProperties *>()
: ::mlir::EmptyProperties{}), regions) {}

__OP_CPP_NAME__GenericAdaptor(RangeT values,
__OP_CPP_NAME__GenericAdaptor(RangeT values,
const __OP_CPP_NAME__GenericAdaptorBase &base)
: Base(base), odsOperands(values) {}

// This template parameter allows using __OP_CPP_NAME__ which is declared
// This template parameter allows using __OP_CPP_NAME__ which is declared
// later.
template <typename LateInst = __OP_CPP_NAME__,
typename = std::enable_if_t<
std::is_same_v<LateInst, __OP_CPP_NAME__>>>
__OP_CPP_NAME__GenericAdaptor(RangeT values, LateInst op)
__OP_CPP_NAME__GenericAdaptor(RangeT values, LateInst op)
: Base(op), odsOperands(values) {}

/// Return the unstructured operand index of a structured operand along with
Expand All @@ -77,7 +83,7 @@ public:
RangeT getStructuredOperands(unsigned index) {
auto valueRange = getStructuredOperandIndexAndLength(index);
return {std::next(odsOperands.begin(), valueRange.first),
std::next(odsOperands.begin(),
std::next(odsOperands.begin(),
valueRange.first + valueRange.second)};
}

Expand All @@ -91,7 +97,7 @@ private:
RangeT odsOperands;
};

class __OP_CPP_NAME__Adaptor
class __OP_CPP_NAME__Adaptor
: public __OP_CPP_NAME__GenericAdaptor<::mlir::ValueRange> {
public:
using __OP_CPP_NAME__GenericAdaptor::__OP_CPP_NAME__GenericAdaptor;
Expand All @@ -100,7 +106,7 @@ public:
::llvm::LogicalResult verify(::mlir::Location loc);
};

class __OP_CPP_NAME__ : public ::mlir::Op<__OP_CPP_NAME__> {
class __OP_CPP_NAME__ : public ::mlir::Op<__OP_TEMPLATE_ARGS__> {
public:
using Op::Op;
using Op::print;
Expand All @@ -112,6 +118,8 @@ public:
return {};
}

::llvm::LogicalResult verifyInvariantsImpl();

static constexpr ::llvm::StringLiteral getOperationName() {
return ::llvm::StringLiteral("__DIALECT_NAME__.__OP_NAME__");
}
Expand Down Expand Up @@ -147,7 +155,7 @@ public:
::mlir::Operation::operand_range getStructuredOperands(unsigned index) {
auto valueRange = getStructuredOperandIndexAndLength(index);
return {std::next(getOperation()->operand_begin(), valueRange.first),
std::next(getOperation()->operand_begin(),
std::next(getOperation()->operand_begin(),
valueRange.first + valueRange.second)};
}

Expand All @@ -162,18 +170,19 @@ public:
::mlir::Operation::result_range getStructuredResults(unsigned index) {
auto valueRange = getStructuredResultIndexAndLength(index);
return {std::next(getOperation()->result_begin(), valueRange.first),
std::next(getOperation()->result_begin(),
std::next(getOperation()->result_begin(),
valueRange.first + valueRange.second)};
}

__OP_OPERAND_GETTER_DECLS__
__OP_RESULT_GETTER_DECLS__

__OP_REGION_GETTER_DECLS__

__OP_BUILD_DECLS__
static void build(::mlir::OpBuilder &odsBuilder,
::mlir::OperationState &odsState,
::mlir::TypeRange resultTypes,
::mlir::ValueRange operands,
static void build(::mlir::OpBuilder &odsBuilder,
::mlir::OperationState &odsState,
::mlir::TypeRange resultTypes,
::mlir::ValueRange operands,
::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});

static __OP_CPP_NAME__ create(::mlir::OpBuilder &odsBuilder,
Expand Down
Loading