Skip to content

Commit

Permalink
[mlir] Generate parser/printers for enums
Browse files Browse the repository at this point in the history
This greatly simplifies composing enums in attribute/type printers,
which currently reimplement these functions as needed.

Differential Revision: https://reviews.llvm.org/D136407
  • Loading branch information
River707 committed Oct 21, 2022
1 parent b525392 commit 29bb0b5
Show file tree
Hide file tree
Showing 6 changed files with 151 additions and 55 deletions.
4 changes: 2 additions & 2 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def LinkageAttr : LLVM_Attr<"Linkage"> {
let parameters = (ins
"linkage::Linkage":$linkage
);
let hasCustomAssemblyFormat = 1;
let assemblyFormat = "`<` $linkage `>`";
}

// Attribute definition for the LLVM Linkage enum.
Expand All @@ -30,7 +30,7 @@ def CConvAttr : LLVM_Attr<"CConv"> {
let parameters = (ins
"CConv":$CallingConv
);
let hasCustomAssemblyFormat = 1;
let assemblyFormat = "`<` $CallingConv `>`";
}

def LoopOptionsAttr : LLVM_Attr<"LoopOptions"> {
Expand Down
48 changes: 0 additions & 48 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2797,54 +2797,6 @@ bool mlir::LLVM::satisfiesLLVMModule(Operation *op) {
op->hasTrait<OpTrait::IsIsolatedFromAbove>();
}

void LinkageAttr::print(AsmPrinter &printer) const {
printer << "<";
if (static_cast<uint64_t>(getLinkage()) <= getMaxEnumValForLinkage())
printer << stringifyEnum(getLinkage());
else
printer << static_cast<uint64_t>(getLinkage());
printer << ">";
}

Attribute LinkageAttr::parse(AsmParser &parser, Type type) {
StringRef elemName;
if (parser.parseLess() || parser.parseKeyword(&elemName) ||
parser.parseGreater())
return {};
auto elem = linkage::symbolizeLinkage(elemName);
if (!elem) {
parser.emitError(parser.getNameLoc(), "Unknown linkage: ") << elemName;
return {};
}
Linkage linkage = *elem;
return LinkageAttr::get(parser.getContext(), linkage);
}

void CConvAttr::print(AsmPrinter &printer) const {
printer << "<";
if (static_cast<uint64_t>(getCallingConv()) <= cconv::getMaxEnumValForCConv())
printer << stringifyEnum(getCallingConv());
else
printer << "INVALID_cc_" << static_cast<uint64_t>(getCallingConv());
printer << ">";
}

Attribute CConvAttr::parse(AsmParser &parser, Type type) {
StringRef convName;

if (parser.parseLess() || parser.parseKeyword(&convName) ||
parser.parseGreater())
return {};
auto cconv = cconv::symbolizeCConv(convName);
if (!cconv) {
parser.emitError(parser.getNameLoc(), "unknown calling convention: ")
<< convName;
return {};
}
CConv cconvVal = *cconv;
return CConvAttr::get(parser.getContext(), cconvVal);
}

LoopOptionsAttrBuilder::LoopOptionsAttrBuilder(LoopOptionsAttr attr)
: options(attr.getOptions().begin(), attr.getOptions().end()) {}

Expand Down
3 changes: 2 additions & 1 deletion mlir/test/Dialect/LLVMIR/func.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -273,8 +273,9 @@ module {
// -----

module {
// expected-error@+2 {{unknown calling convention: cc_12}}
"llvm.func"() ({
// expected-error @below {{invalid Calling Conventions specification: cc_12}}
// expected-error @below {{failed to parse CConvAttr parameter 'CallingConv' which is to be a `CConv`}}
}) {sym_name = "generic_unknown_calling_convention", CConv = #llvm.cconv<cc_12>, function_type = !llvm.func<i64 (i64, i64)>} : () -> ()
}

Expand Down
49 changes: 49 additions & 0 deletions mlir/test/mlir-tblgen/enums-gen.td
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,24 @@ def MyBitEnum: I32BitEnumAttr<"MyBitEnum", "An example bit enum",
// DECL: std::string stringifyMyBitEnum(MyBitEnum);
// DECL: ::llvm::Optional<MyBitEnum> symbolizeMyBitEnum(::llvm::StringRef);

// DECL: struct FieldParser<::MyBitEnum, ::MyBitEnum> {
// DECL: template <typename ParserT>
// DECL: static FailureOr<::MyBitEnum> parse(ParserT &parser) {
// DECL: // Parse the keyword/string containing the enum.
// DECL: std::string enumKeyword;
// DECL: auto loc = parser.getCurrentLocation();
// DECL: if (failed(parser.parseOptionalKeywordOrString(&enumKeyword)))
// DECL: return parser.emitError(loc, "expected keyword for An example bit enum");
// DECL: // Symbolize the keyword.
// DECL: if (::llvm::Optional<::MyBitEnum> attr = ::symbolizeEnum<::MyBitEnum>(enumKeyword))
// DECL: return *attr;
// DECL: return parser.emitError(loc, "invalid An example bit enum specification: ") << enumKeyword;
// DECL: }

// DECL: inline ::llvm::raw_ostream &operator<<(::llvm::raw_ostream &p, ::MyBitEnum value) {
// DECL: auto valueStr = stringifyEnum(value);
// DECL: return p << valueStr;

// DEF-LABEL: std::string stringifyMyBitEnum
// DEF: auto val = static_cast<uint32_t>
// DEF: if (val == 0) return "None";
Expand All @@ -40,3 +58,34 @@ def MyBitEnum: I32BitEnumAttr<"MyBitEnum", "An example bit enum",
// DEF: if (str == "None") return MyBitEnum::None;
// DEF: .Case("tagged", 1)
// DEF: .Case("Bit1", 2)

// Test enum printer generation for non non-keyword enums.

def NonKeywordBit: I32BitEnumAttrCaseBit<"Bit0", 0, "tag-ged">;
def MyMixedNonKeywordBitEnum: I32BitEnumAttr<"MyMixedNonKeywordBitEnum", "An example bit enum", [
NonKeywordBit,
Bit1
]> {
let genSpecializedAttr = 0;
}

def MyNonKeywordBitEnum: I32BitEnumAttr<"MyNonKeywordBitEnum", "An example bit enum", [
NonKeywordBit
]> {
let genSpecializedAttr = 0;
}

// DECL: inline ::llvm::raw_ostream &operator<<(::llvm::raw_ostream &p, ::MyMixedNonKeywordBitEnum value) {
// DECL: auto valueStr = stringifyEnum(value);
// DECL: switch (value) {
// DECL: case ::MyMixedNonKeywordBitEnum::Bit1:
// DECL: break;
// DECL: default:
// DECL: return p << '"' << valueStr << '"';
// DECL: }
// DECL: return p << valueStr;
// DECL: }

// DECL: inline ::llvm::raw_ostream &operator<<(::llvm::raw_ostream &p, ::MyNonKeywordBitEnum value) {
// DECL: auto valueStr = stringifyEnum(value);
// DECL: return p << '"' << valueStr << '"';
97 changes: 93 additions & 4 deletions mlir/tools/mlir-tblgen/EnumsGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@
//
//===----------------------------------------------------------------------===//

#include "FormatGen.h"
#include "mlir/TableGen/Attribute.h"
#include "mlir/TableGen/Format.h"
#include "mlir/TableGen/GenInfo.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/FormatVariadic.h"
Expand Down Expand Up @@ -65,10 +67,92 @@ static void emitEnumClass(const Record &enumDef, StringRef enumName,
os << "};\n\n";
}

static void emitDenseMapInfo(StringRef enumName, std::string underlyingType,
static void emitParserPrinter(const EnumAttr &enumAttr, StringRef qualName,
StringRef cppNamespace, raw_ostream &os) {
if (enumAttr.getUnderlyingType().empty() ||
enumAttr.getConstBuilderTemplate().empty())
return;
auto cases = enumAttr.getAllCases();

// Check which cases shouldn't be printed using a keyword.
llvm::BitVector nonKeywordCases(cases.size());
for (auto [index, caseVal] : llvm::enumerate(cases))
if (!mlir::tblgen::canFormatStringAsKeyword(caseVal.getStr()))
nonKeywordCases.set(index);

// If this is a bit enum attribute, don't allow cases that may overlap with
// other cases. For simplicity sake, only allow cases with a single bit value.
if (enumAttr.isBitEnum()) {
for (auto [index, caseVal] : llvm::enumerate(cases)) {
int64_t value = caseVal.getValue();
if (value < 0 || (value != 0 && !llvm::isPowerOf2_64(value)))
nonKeywordCases.set(index);
}
}

// Generate the parser and the start of the printer for the enum.
const char *parsedAndPrinterStart = R"(
namespace mlir {
template <typename T, typename>
struct FieldParser;
template<>
struct FieldParser<{0}, {0}> {{
template <typename ParserT>
static FailureOr<{0}> parse(ParserT &parser) {{
// Parse the keyword/string containing the enum.
std::string enumKeyword;
auto loc = parser.getCurrentLocation();
if (failed(parser.parseOptionalKeywordOrString(&enumKeyword)))
return parser.emitError(loc, "expected keyword for {2}");
// Symbolize the keyword.
if (::llvm::Optional<{0}> attr = {1}::symbolizeEnum<{0}>(enumKeyword))
return *attr;
return parser.emitError(loc, "invalid {2} specification: ") << enumKeyword;
}
};
} // namespace mlir
namespace llvm {
inline ::llvm::raw_ostream &operator<<(::llvm::raw_ostream &p, {0} value) {{
auto valueStr = stringifyEnum(value);
)";
os << formatv(parsedAndPrinterStart, qualName, cppNamespace,
enumAttr.getSummary());

// If all cases require a string, always wrap.
if (nonKeywordCases.all()) {
os << " return p << '\"' << valueStr << '\"';\n"
"}\n"
"} // namespace llvm\n";
return;
}

// If there are any cases that can't be used with a keyword, switch on the
// case value to determine when to print in the string form.
if (nonKeywordCases.any()) {
os << " switch (value) {\n";
for (auto &it : llvm::enumerate(cases)) {
if (nonKeywordCases.test(it.index()))
continue;
StringRef symbol = it.value().getSymbol();
os << llvm::formatv(" case {0}::{1}:\n", qualName,
llvm::isDigit(symbol.front()) ? ("_" + symbol)
: symbol);
}
os << " break;\n"
" default:\n"
" return p << '\"' << valueStr << '\"';\n"
" }\n";
}
os << " return p << valueStr;\n"
"}\n"
"} // namespace llvm\n";
}

static void emitDenseMapInfo(StringRef qualName, std::string underlyingType,
StringRef cppNamespace, raw_ostream &os) {
std::string qualName =
std::string(formatv("{0}::{1}", cppNamespace, enumName));
if (underlyingType.empty())
underlyingType =
std::string(formatv("std::underlying_type_t<{0}>", qualName));
Expand Down Expand Up @@ -529,8 +613,13 @@ class {1} : public ::mlir::{2} {
for (auto ns : llvm::reverse(namespaces))
os << "} // namespace " << ns << "\n";

// Generate a generic parser and printer for the enum.
std::string qualName =
std::string(formatv("{0}::{1}", cppNamespace, enumName));
emitParserPrinter(enumAttr, qualName, cppNamespace, os);

// Emit DenseMapInfo for this enum class
emitDenseMapInfo(enumName, underlyingType, cppNamespace, os);
emitDenseMapInfo(qualName, underlyingType, cppNamespace, os);
}

static bool emitEnumDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
Expand Down
5 changes: 5 additions & 0 deletions mlir/tools/mlir-tblgen/FormatGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,11 @@ bool mlir::tblgen::shouldEmitSpaceBefore(StringRef value,

bool mlir::tblgen::canFormatStringAsKeyword(
StringRef value, function_ref<void(Twine)> emitError) {
if (value.empty()) {
if (emitError)
emitError("keywords cannot be empty");
return false;
}
if (!isalpha(value.front()) && value.front() != '_') {
if (emitError)
emitError("valid keyword starts with a letter or '_'");
Expand Down

0 comments on commit 29bb0b5

Please sign in to comment.