Skip to content

Commit

Permalink
[mlir][ods] Add tablegen field for concise printing of BitEnum attrib…
Browse files Browse the repository at this point in the history
…utes

This diff introduces a tablegen field for bit enum attributes
(`printBitEnumPrimaryGroups`) to control printing when the enum uses "group"
cases. An example would be an implementation that uses a `fastmath` enum value
as an alias for individual fastmath flags. The proposed field would allow
printing of simply `fast` for the enum value, instead of the more verbose list
that would include `fast` as well as the individual flags (e.g. `reassoc,nnan,
ninf,nsz,arcp,contract,afn,fast`).

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D123871
  • Loading branch information
jfurtek authored and Mogball committed Apr 25, 2022
1 parent 87468e8 commit 4e5dee2
Show file tree
Hide file tree
Showing 6 changed files with 94 additions and 8 deletions.
9 changes: 9 additions & 0 deletions mlir/include/mlir/IR/EnumAttr.td
Expand Up @@ -243,6 +243,15 @@ class BitEnumAttr<I intType, string name, string summary,

// The delimiter used to separate bit enum cases in strings.
string separator = "|";

// Print the "primary group" only for bits that are members of case groups
// that have all bits present. When the value is 0, printing will display both
// both individual bit case names AND the names for all groups that the bit is
// contained in. When the value is 1, for each bit that is set AND is a member
// of a group with all bits set, only the "primary group" (i.e. the first
// group with all bits set in reverse declaration order) will be printed (for
// conciseness).
bit printBitEnumPrimaryGroups = 0;
}

class I32BitEnumAttr<string name, string summary,
Expand Down
1 change: 1 addition & 0 deletions mlir/include/mlir/TableGen/Attribute.h
Expand Up @@ -206,6 +206,7 @@ class EnumAttr : public Attribute {
bool genSpecializedAttr() const;
llvm::Record *getBaseAttrClass() const;
StringRef getSpecializedAttrClassName() const;
bool printBitEnumPrimaryGroups() const;
};

class StructFieldAttr {
Expand Down
4 changes: 4 additions & 0 deletions mlir/lib/TableGen/Attribute.cpp
Expand Up @@ -239,6 +239,10 @@ StringRef EnumAttr::getSpecializedAttrClassName() const {
return def->getValueAsString("specializedAttrClassName");
}

bool EnumAttr::printBitEnumPrimaryGroups() const {
return def->getValueAsBit("printBitEnumPrimaryGroups");
}

StructFieldAttr::StructFieldAttr(const llvm::Record *record) : def(record) {
assert(def->isSubClassOf("StructFieldAttr") &&
"must be subclass of TableGen 'StructFieldAttr' class");
Expand Down
47 changes: 41 additions & 6 deletions mlir/tools/mlir-tblgen/EnumsGen.cpp
Expand Up @@ -204,12 +204,47 @@ static void emitSymToStrFnForBitEnum(const Record &enumDef, raw_ostream &os) {
allBitsUnsetCase->getSymbol());
}
os << " ::llvm::SmallVector<::llvm::StringRef, 2> strs;\n";
for (const auto &enumerant : enumerants) {
// Skip the special enumerant for None.
if (int64_t val = enumerant.getValue())
os << formatv(
" if ({0}u == ({0}u & val)) {{ strs.push_back(\"{1}\"); }\n ", val,
enumerant.getStr());

// Add case string if the value has all case bits, and remove them to avoid
// printing again. Used only for groups, when printBitEnumPrimaryGroups is 1.
const char *const formatCompareRemove = R"(
if ({0}u == ({0}u & val)) {{
strs.push_back("{1}");
val &= ~static_cast<{2}>({0});
}
)";
// Add case string if the value has all case bits. Used for individual bit
// cases, and for groups when printBitEnumPrimaryGroups is 0.
const char *const formatCompare = R"(
if ({0}u == ({0}u & val))
strs.push_back("{1}");
)";
// Optionally elide bits that are members of groups that will also be printed
// for more concise output.
if (enumAttr.printBitEnumPrimaryGroups()) {
os << " // Print bit enum groups before individual bits\n";
// Emit comparisons for group bit cases in reverse tablegen declaration
// order, removing bits for groups with all bits present.
for (const auto &enumerant : llvm::reverse(enumerants)) {
if ((enumerant.getValue() != 0) &&
enumerant.getDef().isSubClassOf("BitEnumAttrCaseGroup")) {
os << formatv(formatCompareRemove, enumerant.getValue(),
enumerant.getStr(), enumAttr.getUnderlyingType());
}
}
// Emit comparisons for individual bit cases in tablegen declaration order.
for (const auto &enumerant : enumerants) {
if ((enumerant.getValue() != 0) &&
enumerant.getDef().isSubClassOf("BitEnumAttrCaseBit"))
os << formatv(formatCompare, enumerant.getValue(), enumerant.getStr());
}
} else {
// Emit comparisons for ALL nonzero cases (individual bits and groups) in
// tablegen declaration order.
for (const auto &enumerant : enumerants) {
if (enumerant.getValue() != 0)
os << formatv(formatCompare, enumerant.getValue(), enumerant.getStr());
}
}
os << formatv(" return ::llvm::join(strs, \"{0}\");\n", separator);

Expand Down
30 changes: 28 additions & 2 deletions mlir/unittests/TableGen/EnumsGenTest.cpp
Expand Up @@ -70,6 +70,9 @@ TEST(EnumsGenTest, GeneratedBitEnumDefinition) {
EXPECT_EQ(0u, static_cast<uint32_t>(BitEnumWithNone::None));
EXPECT_EQ(1u, static_cast<uint32_t>(BitEnumWithNone::Bit0));
EXPECT_EQ(8u, static_cast<uint32_t>(BitEnumWithNone::Bit3));

EXPECT_EQ(2u, static_cast<uint64_t>(BitEnum64_Test::Bit1));
EXPECT_EQ(144115188075855872u, static_cast<uint64_t>(BitEnum64_Test::Bit57));
}

TEST(EnumsGenTest, GeneratedSymbolToStringFnForBitEnum) {
Expand All @@ -79,8 +82,11 @@ TEST(EnumsGenTest, GeneratedSymbolToStringFnForBitEnum) {
EXPECT_EQ(
stringifyBitEnumWithNone(BitEnumWithNone::Bit0 | BitEnumWithNone::Bit3),
"Bit0|Bit3");
EXPECT_EQ(2u, static_cast<uint64_t>(BitEnum64_Test::Bit1));
EXPECT_EQ(144115188075855872u, static_cast<uint64_t>(BitEnum64_Test::Bit57));

EXPECT_EQ(stringifyBitEnum64_Test(BitEnum64_Test::Bit1), "Bit1");
EXPECT_EQ(
stringifyBitEnum64_Test(BitEnum64_Test::Bit1 | BitEnum64_Test::Bit57),
"Bit1|Bit57");
}

TEST(EnumsGenTest, GeneratedStringToSymbolForBitEnum) {
Expand Down Expand Up @@ -116,6 +122,26 @@ TEST(EnumsGenTest, GeneratedStringToSymbolForGroupedBitEnum) {
BitEnumWithGroup::Bit3 | BitEnumWithGroup::Bit0);
}

TEST(EnumsGenTest, GeneratedSymbolToStringFnForPrimaryGroupBitEnum) {
EXPECT_EQ(stringifyBitEnumPrimaryGroup(
BitEnumPrimaryGroup::Bit0 | BitEnumPrimaryGroup::Bit1 |
BitEnumPrimaryGroup::Bit2 | BitEnumPrimaryGroup::Bit3),
"Bits0To3");
EXPECT_EQ(stringifyBitEnumPrimaryGroup(BitEnumPrimaryGroup::Bit0 |
BitEnumPrimaryGroup::Bit2 |
BitEnumPrimaryGroup::Bit3),
"Bit0,Bit2,Bit3");
EXPECT_EQ(stringifyBitEnumPrimaryGroup(BitEnumPrimaryGroup::Bit0 |
BitEnumPrimaryGroup::Bit4 |
BitEnumPrimaryGroup::Bit5),
"Bits4And5,Bit0");
EXPECT_EQ(stringifyBitEnumPrimaryGroup(
BitEnumPrimaryGroup::Bit0 | BitEnumPrimaryGroup::Bit1 |
BitEnumPrimaryGroup::Bit2 | BitEnumPrimaryGroup::Bit3 |
BitEnumPrimaryGroup::Bit4 | BitEnumPrimaryGroup::Bit5),
"Bits0To5");
}

TEST(EnumsGenTest, GeneratedOperator) {
EXPECT_TRUE(bitEnumContains(BitEnumWithNone::Bit0 | BitEnumWithNone::Bit3,
BitEnumWithNone::Bit0));
Expand Down
11 changes: 11 additions & 0 deletions mlir/unittests/TableGen/enums.td
Expand Up @@ -40,10 +40,21 @@ def BitEnumWithoutNone : I32BitEnumAttr<"BitEnumWithoutNone", "A test enum",

def Bits0To3 : I32BitEnumAttrCaseGroup<"Bits0To3",
[Bit0, Bit1, Bit2, Bit3]>;
def Bits4And5 : I32BitEnumAttrCaseGroup<"Bits4And5",
[Bit4, Bit5]>;
def Bits0To5 : I32BitEnumAttrCaseGroup<"Bits0To5",
[Bits0To3, Bits4And5]>;

def BitEnumWithGroup : I32BitEnumAttr<"BitEnumWithGroup", "A test enum",
[Bit0, Bit1, Bit2, Bit3, Bit4, Bits0To3]>;

def BitEnumPrimaryGroup : I32BitEnumAttr<"BitEnumPrimaryGroup", "test enum",
[Bit0, Bit1, Bit2, Bit3, Bit4, Bit5,
Bits0To3, Bits4And5, Bits0To5]> {
let separator = ",";
let printBitEnumPrimaryGroups = 1;
}

def BitEnum64_None : I64BitEnumAttrCaseNone<"None">;
def BitEnum64_57 : I64BitEnumAttrCaseBit<"Bit57", 57>;
def BitEnum64_1 : I64BitEnumAttrCaseBit<"Bit1", 1>;
Expand Down

0 comments on commit 4e5dee2

Please sign in to comment.