Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][ods] Allow sharding of op definitions #89423

Merged
merged 1 commit into from
Apr 24, 2024
Merged

Conversation

Mogball
Copy link
Contributor

@Mogball Mogball commented Apr 19, 2024

Stacked PRs:


[mlir][ods] Allow sharding of op definitions

Adds an option to mlir-tblgen -gen-op-defs op-shard-count=N that divides the
op class definitions and op list into N segments, e.g.

// mlir-tblgen -gen-op-defs -op-shard-count=2

void FooDialect::initialize() {
  addOperations<
  >();
  addOperations<
  >();
}

When split across multiple source files, this can help significantly improve
dialect compile time for dialects with a large opset.

@Mogball Mogball requested a review from rupprecht as a code owner April 19, 2024 17:51
@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir bazel "Peripheral" support tier build system: utils/bazel labels Apr 19, 2024
@llvmbot
Copy link
Collaborator

llvmbot commented Apr 19, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: Jeff Niu (Mogball)

Changes

Adds an option to mlir-tblgen -gen-op-defs op-shard-count=N that divides the op class definitions and op list into N segments, e.g.

// mlir-tblgen -gen-op-defs -op-shard-count=2

void FooDialect::initialize() {
  addOperations&lt;
  &gt;();
  addOperations&lt;
  &gt;();
}

When split across multiple source files, this can help significantly improve dialect compile time for dialects with a large opset.


Patch is 29.98 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/89423.diff

14 Files Affected:

  • (modified) mlir/CMakeLists.txt (+3)
  • (modified) mlir/cmake/modules/AddMLIR.cmake (+38)
  • (modified) mlir/cmake/modules/CMakeLists.txt (+2)
  • (modified) mlir/cmake/modules/MLIRConfig.cmake.in (+1)
  • (modified) mlir/include/mlir/TableGen/CodeGenHelpers.h (+8-4)
  • (modified) mlir/lib/TableGen/CodeGenHelpers.cpp (+6-9)
  • (added) mlir/test/mlir-tblgen/shard-op-defs.td (+33)
  • (added) mlir/tools/mlir-src-sharder/CMakeLists.txt (+14)
  • (added) mlir/tools/mlir-src-sharder/mlir-src-sharder.cpp (+114)
  • (modified) mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp (+129-35)
  • (modified) mlir/tools/mlir-tblgen/OpGenHelpers.cpp (+24-1)
  • (modified) mlir/tools/mlir-tblgen/OpGenHelpers.h (+5)
  • (modified) utils/bazel/llvm-project-overlay/mlir/BUILD.bazel (+9)
  • (modified) utils/bazel/llvm-project-overlay/mlir/tblgen.bzl (+133)
diff --git a/mlir/CMakeLists.txt b/mlir/CMakeLists.txt
index 5c4301af040b47..4c0ef8387b8dff 100644
--- a/mlir/CMakeLists.txt
+++ b/mlir/CMakeLists.txt
@@ -185,10 +185,13 @@ include_directories( ${MLIR_INCLUDE_DIR})
 add_subdirectory(tools/mlir-linalg-ods-gen)
 add_subdirectory(tools/mlir-pdll)
 add_subdirectory(tools/mlir-tblgen)
+add_subdirectory(tools/mlir-src-sharder)
 set(MLIR_TABLEGEN_EXE "${MLIR_TABLEGEN_EXE}" CACHE INTERNAL "")
 set(MLIR_TABLEGEN_TARGET "${MLIR_TABLEGEN_TARGET}" CACHE INTERNAL "")
 set(MLIR_PDLL_TABLEGEN_EXE "${MLIR_PDLL_TABLEGEN_EXE}" CACHE INTERNAL "")
 set(MLIR_PDLL_TABLEGEN_TARGET "${MLIR_PDLL_TABLEGEN_TARGET}" CACHE INTERNAL "")
+set(MLIR_SRC_SHARDER_TABLEGEN_EXE "${MLIR_SRC_SHARDER_TABLEGEN_EXE}" CACHE INTERNAL "")
+set(MLIR_SRC_SHARDER_TABLEGEN_TARGET "${MLIR_SRC_SHARDER_TABLEGEN_TARGET}" CACHE INTERNAL "")
 
 add_subdirectory(include/mlir)
 add_subdirectory(lib)
diff --git a/mlir/cmake/modules/AddMLIR.cmake b/mlir/cmake/modules/AddMLIR.cmake
index 1d2ed748bc2f13..afb74fb2d00025 100644
--- a/mlir/cmake/modules/AddMLIR.cmake
+++ b/mlir/cmake/modules/AddMLIR.cmake
@@ -5,6 +5,28 @@ function(mlir_tablegen ofn)
   tablegen(MLIR ${ARGV})
   set(TABLEGEN_OUTPUT ${TABLEGEN_OUTPUT} ${CMAKE_CURRENT_BINARY_DIR}/${ofn}
       PARENT_SCOPE)
+
+  # Get the current set of include paths for this td file.
+  cmake_parse_arguments(ARG "" "" "DEPENDS;EXTRA_INCLUDES" ${ARGN})
+  get_directory_property(tblgen_includes INCLUDE_DIRECTORIES)
+  list(APPEND tblgen_includes ${ARG_EXTRA_INCLUDES})
+  # Filter out any empty include items.
+  list(REMOVE_ITEM tblgen_includes "")
+
+  # Build the absolute path for the current input file.
+  if (IS_ABSOLUTE ${LLVM_TARGET_DEFINITIONS})
+    set(LLVM_TARGET_DEFINITIONS_ABSOLUTE ${LLVM_TARGET_DEFINITIONS})
+  else()
+    set(LLVM_TARGET_DEFINITIONS_ABSOLUTE ${CMAKE_CURRENT_SOURCE_DIR}/${LLVM_TARGET_DEFINITIONS})
+  endif()
+
+  # Append the includes used for this file to the tablegen_compile_commands
+  # file.
+  file(APPEND ${CMAKE_BINARY_DIR}/tablegen_compile_commands.yml
+      "--- !FileInfo:\n"
+      "  filepath: \"${LLVM_TARGET_DEFINITIONS_ABSOLUTE}\"\n"
+      "  includes: \"${CMAKE_CURRENT_SOURCE_DIR};${tblgen_includes}\"\n"
+  )
 endfunction()
 
 # Clear out any pre-existing compile_commands file before processing. This
@@ -149,6 +171,22 @@ function(add_mlir_dialect dialect dialect_namespace)
   add_dependencies(mlir-headers MLIR${dialect}IncGen)
 endfunction()
 
+# Declare sharded dialect operation declarations and definitions
+function(add_sharded_ops ops_target shard_count)
+  set(LLVM_TARGET_DEFINITIONS ${ops_target}.td)
+  mlir_tablegen(${ops_target}.h.inc -gen-op-decls -op-shard-count=${shard_count})
+  mlir_tablegen(${ops_target}.cpp.inc -gen-op-defs -op-shard-count=${shard_count})
+  set(LLVM_TARGET_DEFINITIONS ${ops_target}.cpp)
+  foreach(index RANGE ${shard_count})
+    set(SHARDED_SRC ${ops_target}.${index}.cpp)
+    list(APPEND SHARDED_SRCS ${SHARDED_SRC})
+    tablegen(MLIR_SRC_SHARDER ${SHARDED_SRC} -op-shard-index=${index})
+    set(TABLEGEN_OUTPUT ${TABLEGEN_OUTPUT} ${CMAKE_CURRENT_BINARY_DIR}/${SHARDED_SRC})
+  endforeach()
+  add_public_tablegen_target(MLIR${ops_target}ShardGen)
+  set(SHARDED_SRCS ${SHARDED_SRCS} PARENT_SCOPE)
+endfunction()
+
 # Declare a dialect in the include directory
 function(add_mlir_interface interface)
   set(LLVM_TARGET_DEFINITIONS ${interface}.td)
diff --git a/mlir/cmake/modules/CMakeLists.txt b/mlir/cmake/modules/CMakeLists.txt
index 8d2904ef46dfe8..3ac1c79b090ed6 100644
--- a/mlir/cmake/modules/CMakeLists.txt
+++ b/mlir/cmake/modules/CMakeLists.txt
@@ -39,6 +39,7 @@ set(MLIR_CONFIG_INCLUDE_DIRS
 # Refer to the best host mlir-tbgen, which might be a host-optimized version
 set(MLIR_CONFIG_TABLEGEN_EXE "${MLIR_TABLEGEN_EXE}")
 set(MLIR_CONFIG_PDLL_TABLEGEN_EXE "${MLIR_PDLL_TABLEGEN_EXE}")
+set(MLIR_CONFIG_SRC_SHARDER_TABLEGEN_EXE "${MLIR_SRC_SHARDER_TABLEGEN_EXE}")
 
 configure_file(
   ${CMAKE_CURRENT_SOURCE_DIR}/MLIRConfig.cmake.in
@@ -77,6 +78,7 @@ set(MLIR_CONFIG_INCLUDE_DIRS
 # if we're building with a host-optimized mlir-tblgen (with LLVM_OPTIMIZED_TABLEGEN).
 set(MLIR_CONFIG_TABLEGEN_EXE mlir-tblgen)
 set(MLIR_CONFIG_PDLL_TABLEGEN_EXE mlir-pdll)
+set(MLIR_CONFIG_SRC_SHARDER_TABLEGEN_EXE mlir-src-sharder)
 
 configure_file(
   ${CMAKE_CURRENT_SOURCE_DIR}/MLIRConfig.cmake.in
diff --git a/mlir/cmake/modules/MLIRConfig.cmake.in b/mlir/cmake/modules/MLIRConfig.cmake.in
index d4da3cd98cce98..7076d94a32f2bc 100644
--- a/mlir/cmake/modules/MLIRConfig.cmake.in
+++ b/mlir/cmake/modules/MLIRConfig.cmake.in
@@ -11,6 +11,7 @@ set(MLIR_CMAKE_DIR "@MLIR_CONFIG_CMAKE_DIR@")
 set(MLIR_INCLUDE_DIRS "@MLIR_CONFIG_INCLUDE_DIRS@")
 set(MLIR_TABLEGEN_EXE "@MLIR_CONFIG_TABLEGEN_EXE@")
 set(MLIR_PDLL_TABLEGEN_EXE "@MLIR_CONFIG_PDLL_TABLEGEN_EXE@")
+set(MLIR_SRC_SHARDER_TABLEGEN_EXE "@MLIR_CONFIG_SRC_SHARDER_TABLEGEN_EXE@")
 set(MLIR_INSTALL_AGGREGATE_OBJECTS "@MLIR_INSTALL_AGGREGATE_OBJECTS@")
 set(MLIR_ENABLE_BINDINGS_PYTHON "@MLIR_ENABLE_BINDINGS_PYTHON@")
 set(MLIR_ENABLE_EXECUTION_ENGINE "@MLIR_ENABLE_EXECUTION_ENGINE@")
diff --git a/mlir/include/mlir/TableGen/CodeGenHelpers.h b/mlir/include/mlir/TableGen/CodeGenHelpers.h
index dd17a44c889bbe..c263c69c53d1e3 100644
--- a/mlir/include/mlir/TableGen/CodeGenHelpers.h
+++ b/mlir/include/mlir/TableGen/CodeGenHelpers.h
@@ -99,8 +99,14 @@ class NamespaceEmitter {
 ///
 class StaticVerifierFunctionEmitter {
 public:
+  /// Create a constraint uniquer with a unique prefix derived from the record
+  /// keeper with an optional tag.
   StaticVerifierFunctionEmitter(raw_ostream &os,
-                                const llvm::RecordKeeper &records);
+                                const llvm::RecordKeeper &records,
+                                StringRef tag = "");
+
+  /// Collect and unique all the constraints used by operations.
+  void collectOpConstraints(ArrayRef<llvm::Record *> opDefs);
 
   /// Collect and unique all compatible type, attribute, successor, and region
   /// constraints from the operations in the file and emit them at the top of
@@ -108,7 +114,7 @@ class StaticVerifierFunctionEmitter {
   ///
   /// Constraints that do not meet the restriction that they can only reference
   /// `$_self` and `$_op` are not uniqued.
-  void emitOpConstraints(ArrayRef<llvm::Record *> opDefs, bool emitDecl);
+  void emitOpConstraints(ArrayRef<llvm::Record *> opDefs);
 
   /// Unique all compatible type and attribute constraints from a pattern file
   /// and emit them at the top of the generated file.
@@ -177,8 +183,6 @@ class StaticVerifierFunctionEmitter {
   /// Emit pattern constraints.
   void emitPatternConstraints();
 
-  /// Collect and unique all the constraints used by operations.
-  void collectOpConstraints(ArrayRef<llvm::Record *> opDefs);
   /// Collect and unique all pattern constraints.
   void collectPatternConstraints(ArrayRef<DagLeaf> constraints);
 
diff --git a/mlir/lib/TableGen/CodeGenHelpers.cpp b/mlir/lib/TableGen/CodeGenHelpers.cpp
index d906de6b56afc0..59865146e20bc4 100644
--- a/mlir/lib/TableGen/CodeGenHelpers.cpp
+++ b/mlir/lib/TableGen/CodeGenHelpers.cpp
@@ -24,7 +24,8 @@ using namespace mlir::tblgen;
 
 /// Generate a unique label based on the current file name to prevent name
 /// collisions if multiple generated files are included at once.
-static std::string getUniqueOutputLabel(const llvm::RecordKeeper &records) {
+static std::string getUniqueOutputLabel(const llvm::RecordKeeper &records,
+                                        StringRef tag) {
   // Use the input file name when generating a unique name.
   std::string inputFilename = records.getInputFilename();
 
@@ -33,7 +34,7 @@ static std::string getUniqueOutputLabel(const llvm::RecordKeeper &records) {
   nameRef.consume_back(".td");
 
   // Sanitize any invalid characters.
-  std::string uniqueName;
+  std::string uniqueName(tag);
   for (char c : nameRef) {
     if (llvm::isAlnum(c) || c == '_')
       uniqueName.push_back(c);
@@ -44,15 +45,11 @@ static std::string getUniqueOutputLabel(const llvm::RecordKeeper &records) {
 }
 
 StaticVerifierFunctionEmitter::StaticVerifierFunctionEmitter(
-    raw_ostream &os, const llvm::RecordKeeper &records)
-    : os(os), uniqueOutputLabel(getUniqueOutputLabel(records)) {}
+    raw_ostream &os, const llvm::RecordKeeper &records, StringRef tag)
+    : os(os), uniqueOutputLabel(getUniqueOutputLabel(records, tag)) {}
 
 void StaticVerifierFunctionEmitter::emitOpConstraints(
-    ArrayRef<llvm::Record *> opDefs, bool emitDecl) {
-  collectOpConstraints(opDefs);
-  if (emitDecl)
-    return;
-
+    ArrayRef<llvm::Record *> opDefs) {
   NamespaceEmitter namespaceEmitter(os, Operator(*opDefs[0]).getCppNamespace());
   emitTypeConstraints();
   emitAttrConstraints();
diff --git a/mlir/test/mlir-tblgen/shard-op-defs.td b/mlir/test/mlir-tblgen/shard-op-defs.td
new file mode 100644
index 00000000000000..84ac6b0fbe9ebe
--- /dev/null
+++ b/mlir/test/mlir-tblgen/shard-op-defs.td
@@ -0,0 +1,33 @@
+// RUN: mlir-tblgen -gen-op-defs -op-shard-count=2 -I %S/../../include %s | FileCheck %s --check-prefix=DEFS
+// RUN: mlir-tblgen -gen-op-decls -op-shard-count=2 -I %S/../../include %s | FileCheck %s --check-prefix=DECLS
+
+include "mlir/IR/OpBase.td"
+
+def Test_Dialect : Dialect {
+  let name = "test";
+  let cppNamespace = "test";
+}
+
+class Test_Op<string mnemonic, list<Trait> traits = []> 
+    : Op<Test_Dialect, mnemonic, traits>;
+
+def OpA : Test_Op<"a">;
+def OpB : Test_Op<"b">;
+def OpC : Test_Op<"c">;
+
+// DECLS: OpA
+// DECLS: OpB
+// DECLS: OpC
+// DECLS: registerTestDialectOperations(
+// DECLS: registerTestDialectOperations0(
+// DECLS: registerTestDialectOperations1(
+
+// DEFS-LABEL: GET_OP_DEFS_0
+// DEFS: void test::registerTestDialectOperations(
+// DEFS: void test::registerTestDialectOperations0(
+// DEFS: OpAAdaptor
+// DEFS: OpBAdaptor
+
+// DEFS-LABEL: GET_OP_DEFS_1
+// DEFS: void test::registerTestDialectOperations1(
+// DEFS: OpCAdaptor
diff --git a/mlir/tools/mlir-src-sharder/CMakeLists.txt b/mlir/tools/mlir-src-sharder/CMakeLists.txt
new file mode 100644
index 00000000000000..4ef870b61124ad
--- /dev/null
+++ b/mlir/tools/mlir-src-sharder/CMakeLists.txt
@@ -0,0 +1,14 @@
+set(LLVM_LINK_COMPONENTS Support)
+set(LIBS MLIRSupport)
+
+add_tablegen(mlir-src-sharder MLIR_SRC_SHARDER
+  mlir-src-sharder.cpp
+
+  DEPENDS
+  ${LIBS}
+  )
+
+set_target_properties(mlir-src-sharder PROPERTIES FOLDER "Tablegenning")
+target_link_libraries(mlir-src-sharder PRIVATE ${LIBS})
+
+mlir_check_all_link_libraries(mlir-src-sharder)
diff --git a/mlir/tools/mlir-src-sharder/mlir-src-sharder.cpp b/mlir/tools/mlir-src-sharder/mlir-src-sharder.cpp
new file mode 100644
index 00000000000000..dc1e2939c7d25b
--- /dev/null
+++ b/mlir/tools/mlir-src-sharder/mlir-src-sharder.cpp
@@ -0,0 +1,114 @@
+//===- mlir-src-sharder.cpp - A tool for sharder generated source files ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Support/FileUtilities.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/InitLLVM.h"
+#include "llvm/Support/MemoryBuffer.h"
+#include "llvm/Support/ToolOutputFile.h"
+
+using namespace mlir;
+
+/// Create a dependency file for `-d` option.
+///
+/// This functionality is generally only for the benefit of the build system,
+/// and is modeled after the same option in TableGen.
+static LogicalResult createDependencyFile(StringRef outputFilename,
+                                          StringRef dependencyFile) {
+  if (outputFilename == "-") {
+    llvm::errs() << "error: the option -d must be used together with -o\n";
+    return failure();
+  }
+
+  std::string errorMessage;
+  std::unique_ptr<llvm::ToolOutputFile> outputFile =
+      openOutputFile(dependencyFile, &errorMessage);
+  if (!outputFile) {
+    llvm::errs() << errorMessage << "\n";
+    return failure();
+  }
+
+  outputFile->os() << outputFilename << ":\n";
+  outputFile->keep();
+  return success();
+}
+
+int main(int argc, char **argv) {
+  // FIXME: This is necessary because we link in TableGen, which defines its
+  // options as static variables.. some of which overlap with our options.
+  llvm::cl::ResetCommandLineParser();
+
+  llvm::cl::opt<unsigned> opShardIndex(
+      "op-shard-index", llvm::cl::desc("The current shard index"));
+  llvm::cl::opt<std::string> inputFilename(llvm::cl::Positional,
+                                           llvm::cl::desc("<input file>"),
+                                           llvm::cl::init("-"));
+  llvm::cl::opt<std::string> outputFilename(
+      "o", llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"),
+      llvm::cl::init("-"));
+  llvm::cl::list<std::string> includeDirs(
+      "I", llvm::cl::desc("Directory of include files"),
+      llvm::cl::value_desc("directory"), llvm::cl::Prefix);
+  llvm::cl::opt<std::string> dependencyFilename(
+      "d", llvm::cl::desc("Dependency filename"),
+      llvm::cl::value_desc("filename"), llvm::cl::init(""));
+  llvm::cl::opt<bool> writeIfChanged(
+      "write-if-changed",
+      llvm::cl::desc("Only write to the output file if it changed"));
+
+  llvm::InitLLVM y(argc, argv);
+  llvm::cl::ParseCommandLineOptions(argc, argv);
+
+  // Open the input file.
+  std::string errorMessage;
+  std::unique_ptr<llvm::MemoryBuffer> inputFile =
+      openInputFile(inputFilename, &errorMessage);
+  if (!inputFile) {
+    llvm::errs() << errorMessage << "\n";
+    return 1;
+  }
+
+  // Write the output to a buffer.
+  std::string outputStr;
+  llvm::raw_string_ostream os(outputStr);
+  os << "#define GET_OP_DEFS_" << opShardIndex << "\n"
+     << inputFile->getBuffer();
+
+  // Determine whether we need to write the output file.
+  bool shouldWriteOutput = true;
+  if (writeIfChanged) {
+    // Only update the real output file if there are any differences. This
+    // prevents recompilation of all the files depending on it if there aren't
+    // any.
+    if (auto existingOrErr =
+            llvm::MemoryBuffer::getFile(outputFilename, /*IsText=*/true))
+      if (std::move(existingOrErr.get())->getBuffer() == os.str())
+        shouldWriteOutput = false;
+  }
+
+  // Populate the output file if necessary.
+  if (shouldWriteOutput) {
+    std::unique_ptr<llvm::ToolOutputFile> outputFile =
+        openOutputFile(outputFilename, &errorMessage);
+    if (!outputFile) {
+      llvm::errs() << errorMessage << "\n";
+      return 1;
+    }
+    outputFile->os() << os.str();
+    outputFile->keep();
+  }
+
+  // Always write the depfile, even if the main output hasn't changed. If it's
+  // missing, Ninja considers the output dirty.
+  if (!dependencyFilename.empty())
+    if (failed(createDependencyFile(outputFilename, dependencyFilename)))
+      return 1;
+
+  return 0;
+}
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 53ed5cb7c043ec..63fe5a80990746 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -4303,32 +4303,15 @@ void OpOperandAdaptorEmitter::emitDef(
   emitter.adaptor.writeDefTo(os);
 }
 
-// Emits the opcode enum and op classes.
-static void emitOpClasses(const RecordKeeper &recordKeeper,
-                          const std::vector<Record *> &defs, raw_ostream &os,
-                          bool emitDecl) {
-  // First emit forward declaration for each class, this allows them to refer
-  // to each others in traits for example.
-  if (emitDecl) {
-    os << "#if defined(GET_OP_CLASSES) || defined(GET_OP_FWD_DEFINES)\n";
-    os << "#undef GET_OP_FWD_DEFINES\n";
-    for (auto *def : defs) {
-      Operator op(*def);
-      NamespaceEmitter emitter(os, op.getCppNamespace());
-      os << "class " << op.getCppClassName() << ";\n";
-    }
-    os << "#endif\n\n";
-  }
-
-  IfDefScope scope("GET_OP_CLASSES", os);
+/// Emit the class declarations or definitions for the given op defs.
+static void
+emitOpClasses(const RecordKeeper &recordKeeper,
+              const std::vector<Record *> &defs, raw_ostream &os,
+              const StaticVerifierFunctionEmitter &staticVerifierEmitter,
+              bool emitDecl) {
   if (defs.empty())
     return;
 
-  // Generate all of the locally instantiated methods first.
-  StaticVerifierFunctionEmitter staticVerifierEmitter(os, recordKeeper);
-  os << formatv(opCommentHeader, "Local Utility Method", "Definitions");
-  staticVerifierEmitter.emitOpConstraints(defs, emitDecl);
-
   for (auto *def : defs) {
     Operator op(*def);
     if (emitDecl) {
@@ -4358,34 +4341,145 @@ static void emitOpClasses(const RecordKeeper &recordKeeper,
   }
 }
 
-// Emits a comma-separated list of the ops.
-static void emitOpList(const std::vector<Record *> &defs, raw_ostream &os) {
-  IfDefScope scope("GET_OP_LIST", os);
+/// Emit the declarations for the provided op classes.
+static void emitOpClassDecls(const RecordKeeper &recordKeeper,
+                             const std::vector<Record *> &defs,
+                             raw_ostream &os) {
+  // First emit forward declaration for each class, this allows them to refer
+  // to each others in traits for example.
+  for (auto *def : defs) {
+    Operator op(*def);
+    NamespaceEmitter emitter(os, op.getCppNamespace());
+    os << "class " << op.getCppClassName() << ";\n";
+  }
+
+  // Emit the op class declarations.
+  IfDefScope scope("GET_OP_CLASSES", os);
+  if (defs.empty())
+    return;
+  StaticVerifierFunctionEmitter staticVerifierEmitter(os, recordKeeper);
+  staticVerifierEmitter.collectOpConstraints(defs);
+  emitOpClasses(recordKeeper, defs, os, staticVerifierEmitter,
+                /*emitDecl=*/true);
+}
+
+/// Emit the definitions for the provided op classes.
+static void emitOpClassDefs(const RecordKeeper &recordKeeper,
+                            ArrayRef<Record *> defs, raw_ostream &os,
+                            StringRef constraintPrefix = "") {
+  if (defs.empty())
+    return;
+
+  // Generate all of the locally instantiated methods first.
+  StaticVerifierFunctionEmitter staticVerifierEmitter(os, recordKeeper,
+                                                      constraintPrefix);
+  os << formatv(opCommentHeader, "Local Utility Method", "Definitions");
+  staticVerifierEmitter.collectOpConstraints(defs);
+  staticVerifierEmitter.emitOpConstraints(defs);
 
-  interleave(
-      // TODO: We are constructing the Operator wrapper instance just for
-      // getting it's qualified class name here. Reduce the overhead by having a
-      // lightweight version of Operator class just for that purpose.
-      defs, [&os](Record *def) { os << Operator(def).getQualCppClassName(); },
-      [&os]() { os << ",\n"; });
+  // Emit the classes.
+  emitOpClasses(recordKeeper, defs, os, staticVerifierEmitter,
+                /*emitDecl=*/false);
 }
 
+/// Emit op declarations for all op records.
 static bool emitOpDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
   emitSourceFileHeader("Op Declarations", os, recordKeeper);
 
   std::vector<Record *> defs = getRequestedOpDefinitions(recordKeeper);
-  emitOpClasses(recordKeeper, defs, os, /*emitDecl=*/true);
+  emitOpClassDecls(recordKeeper, defs, os);
+
+  // If we are generating sharded op definitions, emit the sharded op
+  // registration hooks.
+  SmallVector<ArrayRef<Record *>, 4> shardedDefs;
+  shardOpDefinitions(defs, shardedDefs);
+  if (defs.empty() || shardedDefs.size() <= 1)
+    return false;
+
+  Dialect dialect = Operator(defs.front()).getDialect();
+  NamespaceEmitter ns(os, dialect);
+
+  const char *const opRegistrationHook =
+      "void register{0}Operations{1}({2}::{0} *dialect);\n";
+  os << formatv(opRegistrationHook, dialect.getCppClassName(), "",
+                dialect.getCppNamespace());
+  for (unsigned i = 0; i < shardedDefs.size(); ++i) {
+    os << formatv(opRegistrationHook, dialect.getCppClassName(), i,
+                  dialect.getCppNamespace());
+  }
 
   return false;
 }
 
+/// Generate the dialect op registration hook and the op class definitions for a
+/// shard of ops.
+static void emitOpDefShard(const RecordKeeper &recordKeeper,
+ ...
[truncated]

@joker-eph
Copy link
Collaborator

Can we apply this to the TestDialect as a proof of concept?

@joker-eph
Copy link
Collaborator

Can we apply this to the TestDialect as a proof of concept?

Actually I just see that this is what you do in the follow-up PR, so LGTM here

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

MLIR src sharded makes me think it's related to . mlir files rather than ODS ones. Did you consider making this a mlir-tblgen "function" (such as attribute gen or doc gen) and then call it 2x?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mlir-tblgen only ingests .td files as records. Do you want me to inject a hook into its main function to sniff the command and change its operating mode?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't follow: can you expand on the command line issue with mlir-tblgen?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mlir-tblgen calls into the TableGen parser and then calls into a function based on the command line with the parsed records. In order to make it ingest a C++ file (or another kind of file), I have to intercept it in the main function:

Turn this

// Generator that prints records.
GenRegistration printRecords("print-records", "Print all records to stdout",
                             [](const RecordKeeper &records, raw_ostream &os) {
                               os << records;
                               return false;
                             });

int main(int argc, char **argv) { return MlirTblgenMain(argc, argv); }

Into this:

int main(int argc, char **argv) { 
  if (argv[1] == "shard-src-files") return shardSourceFiles(argc, argv)
  return MlirTblgenMain(argc, argv); 
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@joker-eph ping!

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ping should be for @jpienaar who started in this direction first :)

I agree though that using mlir-tblgen for non-tablegen file does not seem right.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This wasn't resolved?

jollaitbot pushed a commit to sailfishos-mirror/llvm-project that referenced this pull request Apr 22, 2024
Adds an option to `mlir-tblgen -gen-op-defs` `op-shard-count=N` that divides the
op class definitions and op list into N segments, e.g.

```
// mlir-tblgen -gen-op-defs -op-shard-count=2

void FooDialect::initialize() {
  addOperations<
  >();
  addOperations<
  >();
}

```

When split across multiple source files, this can help significantly improve
dialect compile time for dialects with a large opset.

stack-info: PR: llvm/llvm-project#89423, branch: users/mogball/pr_1
@joker-eph
Copy link
Collaborator

Before I forget: we should add documentation for this, including the how to structure the dialect to support it.

@Mogball
Copy link
Contributor Author

Mogball commented Apr 22, 2024

Before I forget: we should add documentation for this, including the how to structure the dialect to support it.

Added in #89664

Base automatically changed from users/mogball/pr_2 to main April 22, 2024 20:42
Adds an option to `mlir-tblgen -gen-op-defs` `op-shard-count=N` that divides the
op class definitions and op list into N segments, e.g.

```
// mlir-tblgen -gen-op-defs -op-shard-count=2

void FooDialect::initialize() {
  addOperations<
  >();
  addOperations<
  >();
}

```

When split across multiple source files, this can help significantly improve
dialect compile time for dialects with a large opset.

stack-info: PR: #89423, branch: users/mogball/pr_1
@Mogball Mogball merged commit 1b232fa into main Apr 24, 2024
3 of 4 checks passed
@Mogball Mogball deleted the users/mogball/pr_1 branch April 24, 2024 21:58
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bazel "Peripheral" support tier build system: utils/bazel mlir:core MLIR Core Infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants