diff --git a/mlir/tools/mlir-tblgen/EnumsGen.cpp b/mlir/tools/mlir-tblgen/EnumsGen.cpp index d4d32f5885971..92890a4a89dd2 100644 --- a/mlir/tools/mlir-tblgen/EnumsGen.cpp +++ b/mlir/tools/mlir-tblgen/EnumsGen.cpp @@ -20,6 +20,7 @@ #include "llvm/ADT/StringExtras.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/TableGen/CodeGenHelpers.h" #include "llvm/TableGen/Error.h" #include "llvm/TableGen/Record.h" #include "llvm/TableGen/TableGenBackend.h" @@ -703,11 +704,7 @@ static void emitEnumDecl(const Record &enumDef, raw_ostream &os) { StringRef underlyingToSymFnName = enumInfo.getUnderlyingToSymbolFnName(); auto enumerants = enumInfo.getAllCases(); - SmallVector namespaces; - llvm::SplitString(cppNamespace, namespaces, "::"); - - for (auto ns : namespaces) - os << "namespace " << ns << " {\n"; + llvm::NamespaceEmitter ns(os, cppNamespace); // Emit the enum class definition emitEnumClass(enumDef, enumName, underlyingType, description, enumerants, os); @@ -768,8 +765,7 @@ class {1} : public ::mlir::{2} { os << formatv(attrClassDecl, enumName, attrClassName, baseAttrClassName); } - for (auto ns : llvm::reverse(namespaces)) - os << "} // namespace " << ns << "\n"; + ns.close(); // Generate a generic parser and printer for the enum. std::string qualName = @@ -792,13 +788,8 @@ static bool emitEnumDecls(const RecordKeeper &records, raw_ostream &os) { static void emitEnumDef(const Record &enumDef, raw_ostream &os) { EnumInfo enumInfo(enumDef); - StringRef cppNamespace = enumInfo.getCppNamespace(); - SmallVector namespaces; - llvm::SplitString(cppNamespace, namespaces, "::"); - - for (auto ns : namespaces) - os << "namespace " << ns << " {\n"; + llvm::NamespaceEmitter ns(os, enumInfo.getCppNamespace()); if (enumInfo.isBitEnum()) { emitSymToStrFnForBitEnum(enumDef, os); @@ -812,10 +803,6 @@ static void emitEnumDef(const Record &enumDef, raw_ostream &os) { if (enumInfo.genSpecializedAttr()) emitSpecializedAttrDef(enumDef, os); - - for (auto ns : llvm::reverse(namespaces)) - os << "} // namespace " << ns << "\n"; - os << "\n"; } static bool emitEnumDefs(const RecordKeeper &records, raw_ostream &os) { diff --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp index 3cc1636ac3317..ae80806cbd912 100644 --- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp +++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp @@ -19,6 +19,7 @@ #include "llvm/ADT/StringExtras.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/TableGen/CodeGenHelpers.h" #include "llvm/TableGen/Error.h" #include "llvm/TableGen/Record.h" #include "llvm/TableGen/TableGenBackend.h" @@ -340,11 +341,7 @@ void InterfaceGenerator::emitModelDecl(const Interface &interface) { } void InterfaceGenerator::emitModelMethodsDef(const Interface &interface) { - llvm::SmallVector namespaces; - llvm::SplitString(interface.getCppNamespace(), namespaces, "::"); - for (StringRef ns : namespaces) - os << "namespace " << ns << " {\n"; - + llvm::NamespaceEmitter ns(os, interface.getCppNamespace()); for (auto &method : interface.getMethods()) { os << "template\n"; emitCPPType(method.getReturnType(), os); @@ -440,18 +437,11 @@ void InterfaceGenerator::emitModelMethodsDef(const Interface &interface) { method.isStatic() ? &ctx : &nonStaticMethodFmt); os << "\n}\n"; } - - for (StringRef ns : llvm::reverse(namespaces)) - os << "} // namespace " << ns << "\n"; } void InterfaceGenerator::emitInterfaceTraitDecl(const Interface &interface) { - llvm::SmallVector namespaces; - llvm::SplitString(interface.getCppNamespace(), namespaces, "::"); - for (StringRef ns : namespaces) - os << "namespace " << ns << " {\n"; - - os << "namespace detail {\n"; + auto cppNamespace = (interface.getCppNamespace() + "::detail").str(); + llvm::NamespaceEmitter ns(os, cppNamespace); StringRef interfaceName = interface.getName(); auto interfaceTraitsName = (interfaceName + "InterfaceTraits").str(); @@ -501,10 +491,6 @@ void InterfaceGenerator::emitInterfaceTraitDecl(const Interface &interface) { os << tblgen::tgfmt(*extraTraitDecls, &traitMethodFmt) << "\n"; os << " };\n"; - os << "}// namespace detail\n"; - - for (StringRef ns : llvm::reverse(namespaces)) - os << "} // namespace " << ns << "\n"; } static void emitInterfaceDeclMethods(const Interface &interface, @@ -529,10 +515,7 @@ static void emitInterfaceDeclMethods(const Interface &interface, } void InterfaceGenerator::forwardDeclareInterface(const Interface &interface) { - llvm::SmallVector namespaces; - llvm::SplitString(interface.getCppNamespace(), namespaces, "::"); - for (StringRef ns : namespaces) - os << "namespace " << ns << " {\n"; + llvm::NamespaceEmitter ns(os, interface.getCppNamespace()); // Emit a forward declaration of the interface class so that it becomes usable // in the signature of its methods. @@ -544,16 +527,10 @@ void InterfaceGenerator::forwardDeclareInterface(const Interface &interface) { StringRef interfaceName = interface.getName(); os << "class " << interfaceName << ";\n"; - - for (StringRef ns : llvm::reverse(namespaces)) - os << "} // namespace " << ns << "\n"; } void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) { - llvm::SmallVector namespaces; - llvm::SplitString(interface.getCppNamespace(), namespaces, "::"); - for (StringRef ns : namespaces) - os << "namespace " << ns << " {\n"; + llvm::NamespaceEmitter ns(os, interface.getCppNamespace()); StringRef interfaceName = interface.getName(); auto interfaceTraitsName = (interfaceName + "InterfaceTraits").str(); @@ -633,9 +610,6 @@ void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) { } os << "};\n"; - - for (StringRef ns : llvm::reverse(namespaces)) - os << "} // namespace " << ns << "\n"; } bool InterfaceGenerator::emitInterfaceDecls() { diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp index 3ead2f0e37214..ca291b57f4344 100644 --- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp +++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp @@ -259,8 +259,8 @@ static void emitInterfaceDecl(const Availability &availability, std::string interfaceTraitsName = std::string(formatv("{0}Traits", interfaceName)); - StringRef cppNamespace = availability.getInterfaceClassNamespace(); - llvm::NamespaceEmitter nsEmitter(os, cppNamespace); + llvm::NamespaceEmitter nsEmitter(os, + availability.getInterfaceClassNamespace()); os << "class " << interfaceName << ";\n\n"; // Emit the traits struct containing the concept and model declarations. @@ -418,15 +418,9 @@ static void emitAvailabilityQueryForBitEnum(const Record &enumDef, static void emitEnumDecl(const Record &enumDef, raw_ostream &os) { EnumInfo enumInfo(enumDef); StringRef enumName = enumInfo.getEnumClassName(); - StringRef cppNamespace = enumInfo.getCppNamespace(); auto enumerants = enumInfo.getAllCases(); - llvm::SmallVector namespaces; - llvm::SplitString(cppNamespace, namespaces, "::"); - - for (auto ns : namespaces) - os << "namespace " << ns << " {\n"; - + llvm::NamespaceEmitter ns(os, enumInfo.getCppNamespace()); llvm::StringSet<> handledClasses; // Place all availability specifications to their corresponding @@ -441,9 +435,6 @@ static void emitEnumDecl(const Record &enumDef, raw_ostream &os) { enumName); handledClasses.insert(className); } - - for (auto ns : llvm::reverse(namespaces)) - os << "} // namespace " << ns << "\n"; } static bool emitEnumDecls(const RecordKeeper &records, raw_ostream &os) { @@ -459,31 +450,19 @@ static bool emitEnumDecls(const RecordKeeper &records, raw_ostream &os) { static void emitEnumDef(const Record &enumDef, raw_ostream &os) { EnumInfo enumInfo(enumDef); - StringRef cppNamespace = enumInfo.getCppNamespace(); - - llvm::SmallVector namespaces; - llvm::SplitString(cppNamespace, namespaces, "::"); - - for (auto ns : namespaces) - os << "namespace " << ns << " {\n"; + llvm::NamespaceEmitter ns(os, enumInfo.getCppNamespace()); - if (enumInfo.isBitEnum()) { + if (enumInfo.isBitEnum()) emitAvailabilityQueryForBitEnum(enumDef, os); - } else { + else emitAvailabilityQueryForIntEnum(enumDef, os); - } - - for (auto ns : llvm::reverse(namespaces)) - os << "} // namespace " << ns << "\n"; - os << "\n"; } static bool emitEnumDefs(const RecordKeeper &records, raw_ostream &os) { llvm::emitSourceFileHeader("SPIR-V Enum Availability Definitions", os, records); - auto defs = records.getAllDerivedDefinitions("EnumInfo"); - for (const auto *def : defs) + for (const Record *def : records.getAllDerivedDefinitions("EnumInfo")) emitEnumDef(*def, os); return false;