diff --git a/mlir/include/mlir/IR/EnumAttr.td b/mlir/include/mlir/IR/EnumAttr.td index d5bc51a3aab94..929283e4d48b6 100644 --- a/mlir/include/mlir/IR/EnumAttr.td +++ b/mlir/include/mlir/IR/EnumAttr.td @@ -243,6 +243,15 @@ class BitEnumAttrgetValueAsString("specializedAttrClassName"); } +bool EnumAttr::printBitEnumPrimaryGroups() const { + return def->getValueAsBit("printBitEnumPrimaryGroups"); +} + StructFieldAttr::StructFieldAttr(const llvm::Record *record) : def(record) { assert(def->isSubClassOf("StructFieldAttr") && "must be subclass of TableGen 'StructFieldAttr' class"); diff --git a/mlir/tools/mlir-tblgen/EnumsGen.cpp b/mlir/tools/mlir-tblgen/EnumsGen.cpp index 15b17199735ac..1e71cdb16fe03 100644 --- a/mlir/tools/mlir-tblgen/EnumsGen.cpp +++ b/mlir/tools/mlir-tblgen/EnumsGen.cpp @@ -204,12 +204,47 @@ static void emitSymToStrFnForBitEnum(const Record &enumDef, raw_ostream &os) { allBitsUnsetCase->getSymbol()); } os << " ::llvm::SmallVector<::llvm::StringRef, 2> strs;\n"; - for (const auto &enumerant : enumerants) { - // Skip the special enumerant for None. - if (int64_t val = enumerant.getValue()) - os << formatv( - " if ({0}u == ({0}u & val)) {{ strs.push_back(\"{1}\"); }\n ", val, - enumerant.getStr()); + + // Add case string if the value has all case bits, and remove them to avoid + // printing again. Used only for groups, when printBitEnumPrimaryGroups is 1. + const char *const formatCompareRemove = R"( + if ({0}u == ({0}u & val)) {{ + strs.push_back("{1}"); + val &= ~static_cast<{2}>({0}); + } +)"; + // Add case string if the value has all case bits. Used for individual bit + // cases, and for groups when printBitEnumPrimaryGroups is 0. + const char *const formatCompare = R"( + if ({0}u == ({0}u & val)) + strs.push_back("{1}"); +)"; + // Optionally elide bits that are members of groups that will also be printed + // for more concise output. + if (enumAttr.printBitEnumPrimaryGroups()) { + os << " // Print bit enum groups before individual bits\n"; + // Emit comparisons for group bit cases in reverse tablegen declaration + // order, removing bits for groups with all bits present. + for (const auto &enumerant : llvm::reverse(enumerants)) { + if ((enumerant.getValue() != 0) && + enumerant.getDef().isSubClassOf("BitEnumAttrCaseGroup")) { + os << formatv(formatCompareRemove, enumerant.getValue(), + enumerant.getStr(), enumAttr.getUnderlyingType()); + } + } + // Emit comparisons for individual bit cases in tablegen declaration order. + for (const auto &enumerant : enumerants) { + if ((enumerant.getValue() != 0) && + enumerant.getDef().isSubClassOf("BitEnumAttrCaseBit")) + os << formatv(formatCompare, enumerant.getValue(), enumerant.getStr()); + } + } else { + // Emit comparisons for ALL nonzero cases (individual bits and groups) in + // tablegen declaration order. + for (const auto &enumerant : enumerants) { + if (enumerant.getValue() != 0) + os << formatv(formatCompare, enumerant.getValue(), enumerant.getStr()); + } } os << formatv(" return ::llvm::join(strs, \"{0}\");\n", separator); diff --git a/mlir/unittests/TableGen/EnumsGenTest.cpp b/mlir/unittests/TableGen/EnumsGenTest.cpp index a5819c5a857ca..1b6f23932249f 100644 --- a/mlir/unittests/TableGen/EnumsGenTest.cpp +++ b/mlir/unittests/TableGen/EnumsGenTest.cpp @@ -70,6 +70,9 @@ TEST(EnumsGenTest, GeneratedBitEnumDefinition) { EXPECT_EQ(0u, static_cast(BitEnumWithNone::None)); EXPECT_EQ(1u, static_cast(BitEnumWithNone::Bit0)); EXPECT_EQ(8u, static_cast(BitEnumWithNone::Bit3)); + + EXPECT_EQ(2u, static_cast(BitEnum64_Test::Bit1)); + EXPECT_EQ(144115188075855872u, static_cast(BitEnum64_Test::Bit57)); } TEST(EnumsGenTest, GeneratedSymbolToStringFnForBitEnum) { @@ -79,8 +82,11 @@ TEST(EnumsGenTest, GeneratedSymbolToStringFnForBitEnum) { EXPECT_EQ( stringifyBitEnumWithNone(BitEnumWithNone::Bit0 | BitEnumWithNone::Bit3), "Bit0|Bit3"); - EXPECT_EQ(2u, static_cast(BitEnum64_Test::Bit1)); - EXPECT_EQ(144115188075855872u, static_cast(BitEnum64_Test::Bit57)); + + EXPECT_EQ(stringifyBitEnum64_Test(BitEnum64_Test::Bit1), "Bit1"); + EXPECT_EQ( + stringifyBitEnum64_Test(BitEnum64_Test::Bit1 | BitEnum64_Test::Bit57), + "Bit1|Bit57"); } TEST(EnumsGenTest, GeneratedStringToSymbolForBitEnum) { @@ -116,6 +122,26 @@ TEST(EnumsGenTest, GeneratedStringToSymbolForGroupedBitEnum) { BitEnumWithGroup::Bit3 | BitEnumWithGroup::Bit0); } +TEST(EnumsGenTest, GeneratedSymbolToStringFnForPrimaryGroupBitEnum) { + EXPECT_EQ(stringifyBitEnumPrimaryGroup( + BitEnumPrimaryGroup::Bit0 | BitEnumPrimaryGroup::Bit1 | + BitEnumPrimaryGroup::Bit2 | BitEnumPrimaryGroup::Bit3), + "Bits0To3"); + EXPECT_EQ(stringifyBitEnumPrimaryGroup(BitEnumPrimaryGroup::Bit0 | + BitEnumPrimaryGroup::Bit2 | + BitEnumPrimaryGroup::Bit3), + "Bit0,Bit2,Bit3"); + EXPECT_EQ(stringifyBitEnumPrimaryGroup(BitEnumPrimaryGroup::Bit0 | + BitEnumPrimaryGroup::Bit4 | + BitEnumPrimaryGroup::Bit5), + "Bits4And5,Bit0"); + EXPECT_EQ(stringifyBitEnumPrimaryGroup( + BitEnumPrimaryGroup::Bit0 | BitEnumPrimaryGroup::Bit1 | + BitEnumPrimaryGroup::Bit2 | BitEnumPrimaryGroup::Bit3 | + BitEnumPrimaryGroup::Bit4 | BitEnumPrimaryGroup::Bit5), + "Bits0To5"); +} + TEST(EnumsGenTest, GeneratedOperator) { EXPECT_TRUE(bitEnumContains(BitEnumWithNone::Bit0 | BitEnumWithNone::Bit3, BitEnumWithNone::Bit0)); diff --git a/mlir/unittests/TableGen/enums.td b/mlir/unittests/TableGen/enums.td index 2baaeb0a50248..5c48b2c770907 100644 --- a/mlir/unittests/TableGen/enums.td +++ b/mlir/unittests/TableGen/enums.td @@ -40,10 +40,21 @@ def BitEnumWithoutNone : I32BitEnumAttr<"BitEnumWithoutNone", "A test enum", def Bits0To3 : I32BitEnumAttrCaseGroup<"Bits0To3", [Bit0, Bit1, Bit2, Bit3]>; +def Bits4And5 : I32BitEnumAttrCaseGroup<"Bits4And5", + [Bit4, Bit5]>; +def Bits0To5 : I32BitEnumAttrCaseGroup<"Bits0To5", + [Bits0To3, Bits4And5]>; def BitEnumWithGroup : I32BitEnumAttr<"BitEnumWithGroup", "A test enum", [Bit0, Bit1, Bit2, Bit3, Bit4, Bits0To3]>; +def BitEnumPrimaryGroup : I32BitEnumAttr<"BitEnumPrimaryGroup", "test enum", + [Bit0, Bit1, Bit2, Bit3, Bit4, Bit5, + Bits0To3, Bits4And5, Bits0To5]> { + let separator = ","; + let printBitEnumPrimaryGroups = 1; +} + def BitEnum64_None : I64BitEnumAttrCaseNone<"None">; def BitEnum64_57 : I64BitEnumAttrCaseBit<"Bit57", 57>; def BitEnum64_1 : I64BitEnumAttrCaseBit<"Bit1", 1>;