diff --git a/mlir/CMakeLists.txt b/mlir/CMakeLists.txt index 5c4301af040b4..4c0ef8387b8df 100644 --- a/mlir/CMakeLists.txt +++ b/mlir/CMakeLists.txt @@ -185,10 +185,13 @@ include_directories( ${MLIR_INCLUDE_DIR}) add_subdirectory(tools/mlir-linalg-ods-gen) add_subdirectory(tools/mlir-pdll) add_subdirectory(tools/mlir-tblgen) +add_subdirectory(tools/mlir-src-sharder) set(MLIR_TABLEGEN_EXE "${MLIR_TABLEGEN_EXE}" CACHE INTERNAL "") set(MLIR_TABLEGEN_TARGET "${MLIR_TABLEGEN_TARGET}" CACHE INTERNAL "") set(MLIR_PDLL_TABLEGEN_EXE "${MLIR_PDLL_TABLEGEN_EXE}" CACHE INTERNAL "") set(MLIR_PDLL_TABLEGEN_TARGET "${MLIR_PDLL_TABLEGEN_TARGET}" CACHE INTERNAL "") +set(MLIR_SRC_SHARDER_TABLEGEN_EXE "${MLIR_SRC_SHARDER_TABLEGEN_EXE}" CACHE INTERNAL "") +set(MLIR_SRC_SHARDER_TABLEGEN_TARGET "${MLIR_SRC_SHARDER_TABLEGEN_TARGET}" CACHE INTERNAL "") add_subdirectory(include/mlir) add_subdirectory(lib) diff --git a/mlir/cmake/modules/AddMLIR.cmake b/mlir/cmake/modules/AddMLIR.cmake index 1d2ed748bc2f1..afb74fb2d0002 100644 --- a/mlir/cmake/modules/AddMLIR.cmake +++ b/mlir/cmake/modules/AddMLIR.cmake @@ -5,6 +5,28 @@ function(mlir_tablegen ofn) tablegen(MLIR ${ARGV}) set(TABLEGEN_OUTPUT ${TABLEGEN_OUTPUT} ${CMAKE_CURRENT_BINARY_DIR}/${ofn} PARENT_SCOPE) + + # Get the current set of include paths for this td file. + cmake_parse_arguments(ARG "" "" "DEPENDS;EXTRA_INCLUDES" ${ARGN}) + get_directory_property(tblgen_includes INCLUDE_DIRECTORIES) + list(APPEND tblgen_includes ${ARG_EXTRA_INCLUDES}) + # Filter out any empty include items. + list(REMOVE_ITEM tblgen_includes "") + + # Build the absolute path for the current input file. + if (IS_ABSOLUTE ${LLVM_TARGET_DEFINITIONS}) + set(LLVM_TARGET_DEFINITIONS_ABSOLUTE ${LLVM_TARGET_DEFINITIONS}) + else() + set(LLVM_TARGET_DEFINITIONS_ABSOLUTE ${CMAKE_CURRENT_SOURCE_DIR}/${LLVM_TARGET_DEFINITIONS}) + endif() + + # Append the includes used for this file to the tablegen_compile_commands + # file. + file(APPEND ${CMAKE_BINARY_DIR}/tablegen_compile_commands.yml + "--- !FileInfo:\n" + " filepath: \"${LLVM_TARGET_DEFINITIONS_ABSOLUTE}\"\n" + " includes: \"${CMAKE_CURRENT_SOURCE_DIR};${tblgen_includes}\"\n" + ) endfunction() # Clear out any pre-existing compile_commands file before processing. This @@ -149,6 +171,22 @@ function(add_mlir_dialect dialect dialect_namespace) add_dependencies(mlir-headers MLIR${dialect}IncGen) endfunction() +# Declare sharded dialect operation declarations and definitions +function(add_sharded_ops ops_target shard_count) + set(LLVM_TARGET_DEFINITIONS ${ops_target}.td) + mlir_tablegen(${ops_target}.h.inc -gen-op-decls -op-shard-count=${shard_count}) + mlir_tablegen(${ops_target}.cpp.inc -gen-op-defs -op-shard-count=${shard_count}) + set(LLVM_TARGET_DEFINITIONS ${ops_target}.cpp) + foreach(index RANGE ${shard_count}) + set(SHARDED_SRC ${ops_target}.${index}.cpp) + list(APPEND SHARDED_SRCS ${SHARDED_SRC}) + tablegen(MLIR_SRC_SHARDER ${SHARDED_SRC} -op-shard-index=${index}) + set(TABLEGEN_OUTPUT ${TABLEGEN_OUTPUT} ${CMAKE_CURRENT_BINARY_DIR}/${SHARDED_SRC}) + endforeach() + add_public_tablegen_target(MLIR${ops_target}ShardGen) + set(SHARDED_SRCS ${SHARDED_SRCS} PARENT_SCOPE) +endfunction() + # Declare a dialect in the include directory function(add_mlir_interface interface) set(LLVM_TARGET_DEFINITIONS ${interface}.td) diff --git a/mlir/cmake/modules/CMakeLists.txt b/mlir/cmake/modules/CMakeLists.txt index 8d2904ef46dfe..3ac1c79b090ed 100644 --- a/mlir/cmake/modules/CMakeLists.txt +++ b/mlir/cmake/modules/CMakeLists.txt @@ -39,6 +39,7 @@ set(MLIR_CONFIG_INCLUDE_DIRS # Refer to the best host mlir-tbgen, which might be a host-optimized version set(MLIR_CONFIG_TABLEGEN_EXE "${MLIR_TABLEGEN_EXE}") set(MLIR_CONFIG_PDLL_TABLEGEN_EXE "${MLIR_PDLL_TABLEGEN_EXE}") +set(MLIR_CONFIG_SRC_SHARDER_TABLEGEN_EXE "${MLIR_SRC_SHARDER_TABLEGEN_EXE}") configure_file( ${CMAKE_CURRENT_SOURCE_DIR}/MLIRConfig.cmake.in @@ -77,6 +78,7 @@ set(MLIR_CONFIG_INCLUDE_DIRS # if we're building with a host-optimized mlir-tblgen (with LLVM_OPTIMIZED_TABLEGEN). set(MLIR_CONFIG_TABLEGEN_EXE mlir-tblgen) set(MLIR_CONFIG_PDLL_TABLEGEN_EXE mlir-pdll) +set(MLIR_CONFIG_SRC_SHARDER_TABLEGEN_EXE mlir-src-sharder) configure_file( ${CMAKE_CURRENT_SOURCE_DIR}/MLIRConfig.cmake.in diff --git a/mlir/cmake/modules/MLIRConfig.cmake.in b/mlir/cmake/modules/MLIRConfig.cmake.in index d4da3cd98cce9..7076d94a32f2b 100644 --- a/mlir/cmake/modules/MLIRConfig.cmake.in +++ b/mlir/cmake/modules/MLIRConfig.cmake.in @@ -11,6 +11,7 @@ set(MLIR_CMAKE_DIR "@MLIR_CONFIG_CMAKE_DIR@") set(MLIR_INCLUDE_DIRS "@MLIR_CONFIG_INCLUDE_DIRS@") set(MLIR_TABLEGEN_EXE "@MLIR_CONFIG_TABLEGEN_EXE@") set(MLIR_PDLL_TABLEGEN_EXE "@MLIR_CONFIG_PDLL_TABLEGEN_EXE@") +set(MLIR_SRC_SHARDER_TABLEGEN_EXE "@MLIR_CONFIG_SRC_SHARDER_TABLEGEN_EXE@") set(MLIR_INSTALL_AGGREGATE_OBJECTS "@MLIR_INSTALL_AGGREGATE_OBJECTS@") set(MLIR_ENABLE_BINDINGS_PYTHON "@MLIR_ENABLE_BINDINGS_PYTHON@") set(MLIR_ENABLE_EXECUTION_ENGINE "@MLIR_ENABLE_EXECUTION_ENGINE@") diff --git a/mlir/include/mlir/TableGen/CodeGenHelpers.h b/mlir/include/mlir/TableGen/CodeGenHelpers.h index dd17a44c889bb..c263c69c53d1e 100644 --- a/mlir/include/mlir/TableGen/CodeGenHelpers.h +++ b/mlir/include/mlir/TableGen/CodeGenHelpers.h @@ -99,8 +99,14 @@ class NamespaceEmitter { /// class StaticVerifierFunctionEmitter { public: + /// Create a constraint uniquer with a unique prefix derived from the record + /// keeper with an optional tag. StaticVerifierFunctionEmitter(raw_ostream &os, - const llvm::RecordKeeper &records); + const llvm::RecordKeeper &records, + StringRef tag = ""); + + /// Collect and unique all the constraints used by operations. + void collectOpConstraints(ArrayRef opDefs); /// Collect and unique all compatible type, attribute, successor, and region /// constraints from the operations in the file and emit them at the top of @@ -108,7 +114,7 @@ class StaticVerifierFunctionEmitter { /// /// Constraints that do not meet the restriction that they can only reference /// `$_self` and `$_op` are not uniqued. - void emitOpConstraints(ArrayRef opDefs, bool emitDecl); + void emitOpConstraints(ArrayRef opDefs); /// Unique all compatible type and attribute constraints from a pattern file /// and emit them at the top of the generated file. @@ -177,8 +183,6 @@ class StaticVerifierFunctionEmitter { /// Emit pattern constraints. void emitPatternConstraints(); - /// Collect and unique all the constraints used by operations. - void collectOpConstraints(ArrayRef opDefs); /// Collect and unique all pattern constraints. void collectPatternConstraints(ArrayRef constraints); diff --git a/mlir/lib/TableGen/CodeGenHelpers.cpp b/mlir/lib/TableGen/CodeGenHelpers.cpp index d906de6b56afc..59865146e20bc 100644 --- a/mlir/lib/TableGen/CodeGenHelpers.cpp +++ b/mlir/lib/TableGen/CodeGenHelpers.cpp @@ -24,7 +24,8 @@ using namespace mlir::tblgen; /// Generate a unique label based on the current file name to prevent name /// collisions if multiple generated files are included at once. -static std::string getUniqueOutputLabel(const llvm::RecordKeeper &records) { +static std::string getUniqueOutputLabel(const llvm::RecordKeeper &records, + StringRef tag) { // Use the input file name when generating a unique name. std::string inputFilename = records.getInputFilename(); @@ -33,7 +34,7 @@ static std::string getUniqueOutputLabel(const llvm::RecordKeeper &records) { nameRef.consume_back(".td"); // Sanitize any invalid characters. - std::string uniqueName; + std::string uniqueName(tag); for (char c : nameRef) { if (llvm::isAlnum(c) || c == '_') uniqueName.push_back(c); @@ -44,15 +45,11 @@ static std::string getUniqueOutputLabel(const llvm::RecordKeeper &records) { } StaticVerifierFunctionEmitter::StaticVerifierFunctionEmitter( - raw_ostream &os, const llvm::RecordKeeper &records) - : os(os), uniqueOutputLabel(getUniqueOutputLabel(records)) {} + raw_ostream &os, const llvm::RecordKeeper &records, StringRef tag) + : os(os), uniqueOutputLabel(getUniqueOutputLabel(records, tag)) {} void StaticVerifierFunctionEmitter::emitOpConstraints( - ArrayRef opDefs, bool emitDecl) { - collectOpConstraints(opDefs); - if (emitDecl) - return; - + ArrayRef opDefs) { NamespaceEmitter namespaceEmitter(os, Operator(*opDefs[0]).getCppNamespace()); emitTypeConstraints(); emitAttrConstraints(); diff --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt index f63e4d330e6ac..fab8937809332 100644 --- a/mlir/test/lib/Dialect/Test/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt @@ -31,8 +31,6 @@ mlir_tablegen(TestOpEnums.cpp.inc -gen-enum-defs) add_public_tablegen_target(MLIRTestEnumDefIncGen) set(LLVM_TARGET_DEFINITIONS TestOps.td) -mlir_tablegen(TestOps.h.inc -gen-op-decls) -mlir_tablegen(TestOps.cpp.inc -gen-op-defs) mlir_tablegen(TestOpsDialect.h.inc -gen-dialect-decls -dialect=test) mlir_tablegen(TestOpsDialect.cpp.inc -gen-dialect-defs -dialect=test) mlir_tablegen(TestPatterns.inc -gen-rewriters) @@ -43,6 +41,8 @@ mlir_tablegen(TestOpsSyntax.h.inc -gen-op-decls) mlir_tablegen(TestOpsSyntax.cpp.inc -gen-op-defs) add_public_tablegen_target(MLIRTestOpsSyntaxIncGen) +add_sharded_ops(TestOps 20) + # Exclude tests from libMLIR.so add_mlir_library(MLIRTestDialect TestAttributes.cpp @@ -56,6 +56,7 @@ add_mlir_library(MLIRTestDialect TestTypes.cpp TestOpsSyntax.cpp TestDialectInterfaces.cpp + ${SHARDED_SRCS} EXCLUDE_FROM_LIBMLIR @@ -66,6 +67,7 @@ add_mlir_library(MLIRTestDialect MLIRTestTypeDefIncGen MLIRTestOpsIncGen MLIRTestOpsSyntaxIncGen + MLIRTestOpsShardGen LINK_LIBS PUBLIC MLIRControlFlowInterfaces diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp index 77fd7e61bd3a0..bfb9592e63828 100644 --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -326,12 +326,9 @@ struct TestOpEffectInterfaceFallback void TestDialect::initialize() { registerAttributes(); registerTypes(); - addOperations< -#define GET_OP_LIST -#include "TestOps.cpp.inc" - >(); registerOpsSyntax(); addOperations(); + registerTestDialectOperations(this); registerDynamicOp(getDynamicGenericOp(this)); registerDynamicOp(getDynamicOneOperandTwoResultsOp(this)); registerDynamicOp(getDynamicCustomParserPrinterOp(this)); diff --git a/mlir/test/lib/Dialect/Test/TestOps.cpp b/mlir/test/lib/Dialect/Test/TestOps.cpp index ce7e476be74e6..47d5b1b19121e 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.cpp +++ b/mlir/test/lib/Dialect/Test/TestOps.cpp @@ -14,5 +14,4 @@ using namespace mlir; using namespace test; -#define GET_OP_CLASSES #include "TestOps.cpp.inc" diff --git a/mlir/test/mlir-tblgen/shard-op-defs.td b/mlir/test/mlir-tblgen/shard-op-defs.td new file mode 100644 index 0000000000000..84ac6b0fbe9eb --- /dev/null +++ b/mlir/test/mlir-tblgen/shard-op-defs.td @@ -0,0 +1,33 @@ +// RUN: mlir-tblgen -gen-op-defs -op-shard-count=2 -I %S/../../include %s | FileCheck %s --check-prefix=DEFS +// RUN: mlir-tblgen -gen-op-decls -op-shard-count=2 -I %S/../../include %s | FileCheck %s --check-prefix=DECLS + +include "mlir/IR/OpBase.td" + +def Test_Dialect : Dialect { + let name = "test"; + let cppNamespace = "test"; +} + +class Test_Op traits = []> + : Op; + +def OpA : Test_Op<"a">; +def OpB : Test_Op<"b">; +def OpC : Test_Op<"c">; + +// DECLS: OpA +// DECLS: OpB +// DECLS: OpC +// DECLS: registerTestDialectOperations( +// DECLS: registerTestDialectOperations0( +// DECLS: registerTestDialectOperations1( + +// DEFS-LABEL: GET_OP_DEFS_0 +// DEFS: void test::registerTestDialectOperations( +// DEFS: void test::registerTestDialectOperations0( +// DEFS: OpAAdaptor +// DEFS: OpBAdaptor + +// DEFS-LABEL: GET_OP_DEFS_1 +// DEFS: void test::registerTestDialectOperations1( +// DEFS: OpCAdaptor diff --git a/mlir/tools/mlir-src-sharder/CMakeLists.txt b/mlir/tools/mlir-src-sharder/CMakeLists.txt new file mode 100644 index 0000000000000..4ef870b61124a --- /dev/null +++ b/mlir/tools/mlir-src-sharder/CMakeLists.txt @@ -0,0 +1,14 @@ +set(LLVM_LINK_COMPONENTS Support) +set(LIBS MLIRSupport) + +add_tablegen(mlir-src-sharder MLIR_SRC_SHARDER + mlir-src-sharder.cpp + + DEPENDS + ${LIBS} + ) + +set_target_properties(mlir-src-sharder PROPERTIES FOLDER "Tablegenning") +target_link_libraries(mlir-src-sharder PRIVATE ${LIBS}) + +mlir_check_all_link_libraries(mlir-src-sharder) diff --git a/mlir/tools/mlir-src-sharder/mlir-src-sharder.cpp b/mlir/tools/mlir-src-sharder/mlir-src-sharder.cpp new file mode 100644 index 0000000000000..dc1e2939c7d25 --- /dev/null +++ b/mlir/tools/mlir-src-sharder/mlir-src-sharder.cpp @@ -0,0 +1,114 @@ +//===- mlir-src-sharder.cpp - A tool for sharder generated source files ---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Support/FileUtilities.h" +#include "mlir/Support/LogicalResult.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/InitLLVM.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/ToolOutputFile.h" + +using namespace mlir; + +/// Create a dependency file for `-d` option. +/// +/// This functionality is generally only for the benefit of the build system, +/// and is modeled after the same option in TableGen. +static LogicalResult createDependencyFile(StringRef outputFilename, + StringRef dependencyFile) { + if (outputFilename == "-") { + llvm::errs() << "error: the option -d must be used together with -o\n"; + return failure(); + } + + std::string errorMessage; + std::unique_ptr outputFile = + openOutputFile(dependencyFile, &errorMessage); + if (!outputFile) { + llvm::errs() << errorMessage << "\n"; + return failure(); + } + + outputFile->os() << outputFilename << ":\n"; + outputFile->keep(); + return success(); +} + +int main(int argc, char **argv) { + // FIXME: This is necessary because we link in TableGen, which defines its + // options as static variables.. some of which overlap with our options. + llvm::cl::ResetCommandLineParser(); + + llvm::cl::opt opShardIndex( + "op-shard-index", llvm::cl::desc("The current shard index")); + llvm::cl::opt inputFilename(llvm::cl::Positional, + llvm::cl::desc(""), + llvm::cl::init("-")); + llvm::cl::opt outputFilename( + "o", llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"), + llvm::cl::init("-")); + llvm::cl::list includeDirs( + "I", llvm::cl::desc("Directory of include files"), + llvm::cl::value_desc("directory"), llvm::cl::Prefix); + llvm::cl::opt dependencyFilename( + "d", llvm::cl::desc("Dependency filename"), + llvm::cl::value_desc("filename"), llvm::cl::init("")); + llvm::cl::opt writeIfChanged( + "write-if-changed", + llvm::cl::desc("Only write to the output file if it changed")); + + llvm::InitLLVM y(argc, argv); + llvm::cl::ParseCommandLineOptions(argc, argv); + + // Open the input file. + std::string errorMessage; + std::unique_ptr inputFile = + openInputFile(inputFilename, &errorMessage); + if (!inputFile) { + llvm::errs() << errorMessage << "\n"; + return 1; + } + + // Write the output to a buffer. + std::string outputStr; + llvm::raw_string_ostream os(outputStr); + os << "#define GET_OP_DEFS_" << opShardIndex << "\n" + << inputFile->getBuffer(); + + // Determine whether we need to write the output file. + bool shouldWriteOutput = true; + if (writeIfChanged) { + // Only update the real output file if there are any differences. This + // prevents recompilation of all the files depending on it if there aren't + // any. + if (auto existingOrErr = + llvm::MemoryBuffer::getFile(outputFilename, /*IsText=*/true)) + if (std::move(existingOrErr.get())->getBuffer() == os.str()) + shouldWriteOutput = false; + } + + // Populate the output file if necessary. + if (shouldWriteOutput) { + std::unique_ptr outputFile = + openOutputFile(outputFilename, &errorMessage); + if (!outputFile) { + llvm::errs() << errorMessage << "\n"; + return 1; + } + outputFile->os() << os.str(); + outputFile->keep(); + } + + // Always write the depfile, even if the main output hasn't changed. If it's + // missing, Ninja considers the output dirty. + if (!dependencyFilename.empty()) + if (failed(createDependencyFile(outputFilename, dependencyFilename))) + return 1; + + return 0; +} diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index 53ed5cb7c043e..63fe5a8099074 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -4303,32 +4303,15 @@ void OpOperandAdaptorEmitter::emitDef( emitter.adaptor.writeDefTo(os); } -// Emits the opcode enum and op classes. -static void emitOpClasses(const RecordKeeper &recordKeeper, - const std::vector &defs, raw_ostream &os, - bool emitDecl) { - // First emit forward declaration for each class, this allows them to refer - // to each others in traits for example. - if (emitDecl) { - os << "#if defined(GET_OP_CLASSES) || defined(GET_OP_FWD_DEFINES)\n"; - os << "#undef GET_OP_FWD_DEFINES\n"; - for (auto *def : defs) { - Operator op(*def); - NamespaceEmitter emitter(os, op.getCppNamespace()); - os << "class " << op.getCppClassName() << ";\n"; - } - os << "#endif\n\n"; - } - - IfDefScope scope("GET_OP_CLASSES", os); +/// Emit the class declarations or definitions for the given op defs. +static void +emitOpClasses(const RecordKeeper &recordKeeper, + const std::vector &defs, raw_ostream &os, + const StaticVerifierFunctionEmitter &staticVerifierEmitter, + bool emitDecl) { if (defs.empty()) return; - // Generate all of the locally instantiated methods first. - StaticVerifierFunctionEmitter staticVerifierEmitter(os, recordKeeper); - os << formatv(opCommentHeader, "Local Utility Method", "Definitions"); - staticVerifierEmitter.emitOpConstraints(defs, emitDecl); - for (auto *def : defs) { Operator op(*def); if (emitDecl) { @@ -4358,34 +4341,145 @@ static void emitOpClasses(const RecordKeeper &recordKeeper, } } -// Emits a comma-separated list of the ops. -static void emitOpList(const std::vector &defs, raw_ostream &os) { - IfDefScope scope("GET_OP_LIST", os); +/// Emit the declarations for the provided op classes. +static void emitOpClassDecls(const RecordKeeper &recordKeeper, + const std::vector &defs, + raw_ostream &os) { + // First emit forward declaration for each class, this allows them to refer + // to each others in traits for example. + for (auto *def : defs) { + Operator op(*def); + NamespaceEmitter emitter(os, op.getCppNamespace()); + os << "class " << op.getCppClassName() << ";\n"; + } + + // Emit the op class declarations. + IfDefScope scope("GET_OP_CLASSES", os); + if (defs.empty()) + return; + StaticVerifierFunctionEmitter staticVerifierEmitter(os, recordKeeper); + staticVerifierEmitter.collectOpConstraints(defs); + emitOpClasses(recordKeeper, defs, os, staticVerifierEmitter, + /*emitDecl=*/true); +} + +/// Emit the definitions for the provided op classes. +static void emitOpClassDefs(const RecordKeeper &recordKeeper, + ArrayRef defs, raw_ostream &os, + StringRef constraintPrefix = "") { + if (defs.empty()) + return; + + // Generate all of the locally instantiated methods first. + StaticVerifierFunctionEmitter staticVerifierEmitter(os, recordKeeper, + constraintPrefix); + os << formatv(opCommentHeader, "Local Utility Method", "Definitions"); + staticVerifierEmitter.collectOpConstraints(defs); + staticVerifierEmitter.emitOpConstraints(defs); - interleave( - // TODO: We are constructing the Operator wrapper instance just for - // getting it's qualified class name here. Reduce the overhead by having a - // lightweight version of Operator class just for that purpose. - defs, [&os](Record *def) { os << Operator(def).getQualCppClassName(); }, - [&os]() { os << ",\n"; }); + // Emit the classes. + emitOpClasses(recordKeeper, defs, os, staticVerifierEmitter, + /*emitDecl=*/false); } +/// Emit op declarations for all op records. static bool emitOpDecls(const RecordKeeper &recordKeeper, raw_ostream &os) { emitSourceFileHeader("Op Declarations", os, recordKeeper); std::vector defs = getRequestedOpDefinitions(recordKeeper); - emitOpClasses(recordKeeper, defs, os, /*emitDecl=*/true); + emitOpClassDecls(recordKeeper, defs, os); + + // If we are generating sharded op definitions, emit the sharded op + // registration hooks. + SmallVector, 4> shardedDefs; + shardOpDefinitions(defs, shardedDefs); + if (defs.empty() || shardedDefs.size() <= 1) + return false; + + Dialect dialect = Operator(defs.front()).getDialect(); + NamespaceEmitter ns(os, dialect); + + const char *const opRegistrationHook = + "void register{0}Operations{1}({2}::{0} *dialect);\n"; + os << formatv(opRegistrationHook, dialect.getCppClassName(), "", + dialect.getCppNamespace()); + for (unsigned i = 0; i < shardedDefs.size(); ++i) { + os << formatv(opRegistrationHook, dialect.getCppClassName(), i, + dialect.getCppNamespace()); + } return false; } +/// Generate the dialect op registration hook and the op class definitions for a +/// shard of ops. +static void emitOpDefShard(const RecordKeeper &recordKeeper, + ArrayRef defs, const Dialect &dialect, + unsigned shardIndex, unsigned shardCount, + raw_ostream &os) { + std::string shardGuard = "GET_OP_DEFS_"; + std::string indexStr = std::to_string(shardIndex); + shardGuard += indexStr; + IfDefScope scope(shardGuard, os); + + // Emit the op registration hook in the first shard. + const char *const opRegistrationHook = + "void {0}::register{1}Operations{2}({0}::{1} *dialect) {{\n"; + if (shardIndex == 0) { + os << formatv(opRegistrationHook, dialect.getCppNamespace(), + dialect.getCppClassName(), ""); + for (unsigned i = 0; i < shardCount; ++i) { + os << formatv(" {0}::register{1}Operations{2}(dialect);\n", + dialect.getCppNamespace(), dialect.getCppClassName(), i); + } + os << "}\n"; + } + + // Generate the per-shard op registration hook. + os << formatv(opCommentHeader, dialect.getCppClassName(), + "Op Registration Hook") + << formatv(opRegistrationHook, dialect.getCppNamespace(), + dialect.getCppClassName(), shardIndex); + for (Record *def : defs) { + os << formatv(" ::mlir::RegisteredOperationName::insert<{0}>(*dialect);\n", + Operator(def).getQualCppClassName()); + } + os << "}\n"; + + // Generate the per-shard op definitions. + emitOpClassDefs(recordKeeper, defs, os, indexStr); +} + +/// Emit op definitions for all op records. static bool emitOpDefs(const RecordKeeper &recordKeeper, raw_ostream &os) { emitSourceFileHeader("Op Definitions", os, recordKeeper); std::vector defs = getRequestedOpDefinitions(recordKeeper); - emitOpList(defs, os); - emitOpClasses(recordKeeper, defs, os, /*emitDecl=*/false); + SmallVector, 4> shardedDefs; + shardOpDefinitions(defs, shardedDefs); + + // If no shard was requested, emit the regular op list and class definitions. + if (shardedDefs.size() == 1) { + { + IfDefScope scope("GET_OP_LIST", os); + interleave( + defs, os, + [&](Record *def) { os << Operator(def).getQualCppClassName(); }, + ",\n"); + } + { + IfDefScope scope("GET_OP_CLASSES", os); + emitOpClassDefs(recordKeeper, defs, os); + } + return false; + } + if (defs.empty()) + return false; + Dialect dialect = Operator(defs.front()).getDialect(); + for (auto [idx, value] : llvm::enumerate(shardedDefs)) { + emitOpDefShard(recordKeeper, value, dialect, idx, shardedDefs.size(), os); + } return false; } diff --git a/mlir/tools/mlir-tblgen/OpGenHelpers.cpp b/mlir/tools/mlir-tblgen/OpGenHelpers.cpp index 7fd34df8460d3..c2a2423a24026 100644 --- a/mlir/tools/mlir-tblgen/OpGenHelpers.cpp +++ b/mlir/tools/mlir-tblgen/OpGenHelpers.cpp @@ -31,6 +31,10 @@ static cl::opt opExcFilter( "op-exclude-regex", cl::desc("Regex of name of op's to exclude (no filter if empty)"), cl::cat(opDefGenCat)); +static cl::opt opShardCount( + "op-shard-count", + cl::desc("The number of shards into which the op classes will be divided"), + cl::cat(opDefGenCat), cl::init(1)); static std::string getOperationName(const Record &def) { auto prefix = def.getValueAsDef("opDialect")->getValueAsString("name"); @@ -79,4 +83,23 @@ bool mlir::tblgen::isPythonReserved(StringRef str) { reserved.insert("issubclass"); reserved.insert("type"); return reserved.contains(str); -} \ No newline at end of file +} + +void mlir::tblgen::shardOpDefinitions( + ArrayRef defs, + SmallVectorImpl> &shardedDefs) { + assert(opShardCount > 0 && "expected a positive shard count"); + if (opShardCount == 1) { + shardedDefs.push_back(defs); + return; + } + + unsigned minShardSize = defs.size() / opShardCount; + unsigned numMissing = defs.size() - minShardSize * opShardCount; + shardedDefs.reserve(opShardCount); + for (unsigned i = 0, start = 0; i < opShardCount; ++i) { + unsigned size = minShardSize + (i < numMissing); + shardedDefs.push_back(defs.slice(start, size)); + start += size; + } +} diff --git a/mlir/tools/mlir-tblgen/OpGenHelpers.h b/mlir/tools/mlir-tblgen/OpGenHelpers.h index 3dcff14d1221e..1b43d5d3ce3a7 100644 --- a/mlir/tools/mlir-tblgen/OpGenHelpers.h +++ b/mlir/tools/mlir-tblgen/OpGenHelpers.h @@ -13,6 +13,7 @@ #ifndef MLIR_TOOLS_MLIRTBLGEN_OPGENHELPERS_H_ #define MLIR_TOOLS_MLIRTBLGEN_OPGENHELPERS_H_ +#include "mlir/Support/LLVM.h" #include "llvm/TableGen/Record.h" #include @@ -28,6 +29,10 @@ getRequestedOpDefinitions(const llvm::RecordKeeper &recordKeeper); /// Regenerate using python -c"print(set(sorted(__import__('keyword').kwlist)))" bool isPythonReserved(llvm::StringRef str); +/// Shard the op defintions into the number of shards set by "op-shard-count". +void shardOpDefinitions(ArrayRef defs, + SmallVectorImpl> &shardedDefs); + } // namespace tblgen } // namespace mlir diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index 6c732b8f13490..00a40019ac2e7 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -9727,6 +9727,15 @@ cc_binary( ], ) +cc_binary( + name = "mlir-src-sharder", + srcs = ["tools/mlir-src-sharder/mlir-src-sharder.cpp"], + deps = [ + ":Support", + "//llvm:Support", + ], +) + cc_binary( name = "mlir-linalg-ods-yaml-gen", srcs = [ diff --git a/utils/bazel/llvm-project-overlay/mlir/tblgen.bzl b/utils/bazel/llvm-project-overlay/mlir/tblgen.bzl index fdf6a57107ac3..e45ba1fe0ef72 100644 --- a/utils/bazel/llvm-project-overlay/mlir/tblgen.bzl +++ b/utils/bazel/llvm-project-overlay/mlir/tblgen.bzl @@ -432,3 +432,136 @@ def gentbl_cc_library( copts = copts, **kwargs ) + +def _gentbl_shard_impl(ctx): + args = ctx.actions.args() + args.add(ctx.file.src_file) + args.add("-op-shard-index", ctx.attr.index) + args.add("-o", ctx.outputs.out.path) + ctx.actions.run( + outputs = [ctx.outputs.out], + inputs = [ctx.file.src_file], + executable = ctx.executable.sharder, + arguments = [args], + use_default_shell_env = True, + mnemonic = "ShardGenerate", + ) + +gentbl_shard_rule = rule( + _gentbl_shard_impl, + doc = "", + output_to_genfiles = True, + attrs = { + "index": attr.int(mandatory = True, doc = ""), + "sharder": attr.label( + doc = "", + executable = True, + cfg = "exec", + ), + "src_file": attr.label( + doc = "", + allow_single_file = True, + mandatory = True, + ), + "out": attr.output( + doc = "", + mandatory = True, + ), + }, +) + +def gentbl_sharded_ops( + name, + tblgen, + sharder, + td_file, + shard_count, + src_file, + src_out, + hdr_out, + test = False, + includes = [], + strip_include_prefix = None, + deps = []): + """Generate sharded op declarations and definitions. + + This special build rule shards op definitions in a TableGen file and generates multiple copies + of a template source file for including and compiling each shard. The rule defines a filegroup + consisting of the source shards, the generated source file, and the generated header file. + + Args: + name: The name of the filegroup. + tblgen: The binary used to produce the output. + sharder: The source file sharder to use. + td_file: The primary table definitions file. + shard_count: The number of op definition shards to produce. + src_file: The source file template. + src_out: The generated source file. + hdr_out: The generated header file. + test: Whether this is a test target. + includes: See gentbl_rule.includes + deps: See gentbl_rule.deps + strip_include_prefix: Attribute to pass through to cc_library. + """ + cc_lib_name = name + "__gentbl_cc_lib" + gentbl_cc_library( + name = cc_lib_name, + strip_include_prefix = strip_include_prefix, + includes = includes, + tbl_outs = [ + ( + [ + "-gen-op-defs", + "-op-shard-count=" + str(shard_count), + ], + src_out, + ), + ( + [ + "-gen-op-decls", + "-op-shard-count=" + str(shard_count), + ], + hdr_out, + ), + ], + tblgen = tblgen, + td_file = td_file, + test = test, + deps = deps, + ) + all_files = [hdr_out, src_out] + for i in range(0, shard_count): + out_file = "shard_copy_" + str(i) + "_" + src_file + gentbl_shard_rule( + index = i, + name = name + "__src_shard" + str(i), + testonly = test, + out = out_file, + sharder = sharder, + src_file = src_file, + ) + all_files.append(out_file) + native.filegroup(name = name, srcs = all_files) + +def gentbl_sharded_op_defs(name, source_file, shard_count): + """Generates multiple copies of a source file that includes sharded op definitions. + + Args: + name: The name of the rule. + source_file: The source to copy. + shard_count: The number of shards. + + Returns: + A list of the copied filenames to be included in the dialect library. + """ + copies = [] + for i in range(0, shard_count): + out_file = "shard_copy_" + str(i) + "_" + source_file + copies.append(out_file) + native.genrule( + name = name + "_shard_" + str(i), + srcs = [source_file], + outs = [out_file], + cmd = "echo -e \"#define GET_OP_DEFS_" + str(i) + "\n$$(cat $(SRCS))\" > $(OUTS)", + ) + return copies diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel index dc5f4047c286d..b98f7eb5613af 100644 --- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel @@ -4,7 +4,7 @@ load("@bazel_skylib//rules:expand_template.bzl", "expand_template") load("//llvm:lit_test.bzl", "package_path") -load("//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") +load("//mlir:tblgen.bzl", "gentbl_cc_library", "td_library", "gentbl_sharded_ops", "td_library") package( default_visibility = ["//visibility:public"], @@ -151,14 +151,6 @@ gentbl_cc_library( name = "TestOpsIncGen", strip_include_prefix = "lib/Dialect/Test", tbl_outs = [ - ( - ["-gen-op-decls"], - "lib/Dialect/Test/TestOps.h.inc", - ), - ( - ["-gen-op-defs"], - "lib/Dialect/Test/TestOps.cpp.inc", - ), ( [ "-gen-dialect-decls", @@ -370,12 +362,25 @@ cc_library( ], ) +gentbl_sharded_ops( + name = "TestDialectOpSrcs", + hdr_out = "lib/Dialect/Test/TestOps.h.inc", + shard_count = 20, + sharder = "//mlir:mlir-src-sharder", + src_file = "lib/Dialect/Test/TestOps.cpp", + src_out = "lib/Dialect/Test/TestOps.cpp.inc", + tblgen = "//mlir:mlir-tblgen", + td_file = "lib/Dialect/Test/TestOps.td", + test = True, + deps = [":TestOpTdFiles"], +) + cc_library( name = "TestDialect", srcs = glob( ["lib/Dialect/Test/*.cpp"], exclude = ["lib/Dialect/Test/TestToLLVMIRTranslation.cpp"], - ), + ) + [":TestDialectOpSrcs"], hdrs = glob(["lib/Dialect/Test/*.h"]), includes = [ "lib/Dialect/Test",