-
Notifications
You must be signed in to change notification settings - Fork 11k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][ods] Add documentation on how to use sharded op definitions (NFC) #89664
Conversation
8991d4f
to
f8c449c
Compare
@llvm/pr-subscribers-mlir Author: Jeff Niu (Mogball) ChangesStacked PRs:
[mlir][ods] Add documentation on how to use sharded op definitions (NFC)This adds explanations and instructions on how to set up a dialect for Full diff: https://github.com/llvm/llvm-project/pull/89664.diff 1 Files Affected:
diff --git a/mlir/docs/DefiningDialects/Operations.md b/mlir/docs/DefiningDialects/Operations.md
index 729393d5362673..79a0cc55f13840 100644
--- a/mlir/docs/DefiningDialects/Operations.md
+++ b/mlir/docs/DefiningDialects/Operations.md
@@ -1114,6 +1114,100 @@ void process(AddOp op, ArrayRef<Value> newOperands) {
}
```
+#### Sharded Operation Definitions
+
+Large dialects with many operations may struggle with C++ compile time of
+generated op definitions, due to large compilation units. `mlir-tblgen`
+provides the ability to shard op definitions by splitting them up evenly
+by passing `-op-shard-count` to `-gen-op-defs` and `-gen-op-decls`. The tool
+will generate a single include file for the definitions broken up by
+`GET_OP_DEFS_${N}` where `${N}` is the shard number. A shard can be compiled in
+a single compilation unit by adding a file like this to your dialect library:
+
+```c++
+#include "mlir/IR/Operation.h"
+// Add any other required includes.
+
+// Utilities shared by generated op definitions: custom directive parsers,
+// printers, etc.
+#include "OpUtils.h"
+
+#define GET_OP_DEFS_0
+#include "MyDialectOps.cpp.inc"
+```
+
+Note: this requires restructing shared utility functions within the dialect
+library so they can be shared by multiple compilation units. I.e. instead of
+defining `static` methods in the same source file, you should declare them in a
+shared header and define them in their own source file.
+
+The op registration hooks are also sharded, because the template instantiation
+can take a very long time to compile. Operations should be registered in your
+dialect like:
+
+```c++
+void MyDialect::initialize() {
+ registerMyDialectOperations(this);
+}
+```
+
+CMake and Bazel functions are included to make sharding dialects easier.
+Assuming you have organized your operation utility functions into their own
+header, define a file that looks like the one above, but without the `#define`:
+
+```c++
+// MyDialectOps.cpp
+#include "mlir/IR/Operation.h"
+
+#include "OpUtils.h"
+
+#include "MyDialectOps.cpp.inc"
+```
+
+In CMake, remove the manual `mlir_tablegen` invocations and replace them with:
+
+```cmake
+set(LLVM_TARGET_DEFINITIONS MyDialectOps.td)
+add_sharded_ops(MyDialectOps 8) # shard the op definitions by 8
+
+add_mlir_library(MyDialect
+ MyDialect.cpp
+ MyDialectOpDefs.cpp
+ ${SHARDED_SRCS}
+
+ DEPENDS
+ MLIRTestOpsShardGen
+)
+```
+
+This will automatically duplicate the `MyDialectOps.cpp` source file and add the
+`#define` up the number of shards indicated.
+
+It is recommended that any out-of-line op member functions (like verifiers) be
+defined in a separate source file. In this example, it is called
+`MyDialectOpDefs.cpp`.
+
+In Bazel, remove the `-gen-op-defs` and `-gen-op-decls` invocations, and add
+
+```bazel
+gentbl_sharded_ops(
+ name = "MyDialectOpSrcs",
+ hdr_out = "MyDialectOps.h.inc",
+ shard_count = 8,
+ sharder = "//mlir:mlir-src-sharder",
+ src_file = "MyDialectOps.cpp",
+ src_out = "MyDialectOps.cpp.inc",
+ tblgen = "//mlir:mlir-tblgen",
+ td_file = "MyDialectOps.td",
+ deps = [":MyDialectOpsTdFiles"],
+)
+
+cc_library(
+ name = "MyDialect",
+ srcs = glob(["MyDialect/*.cpp"]) + [":MyDialectOpSrcs"]
+)
+```
+
## Constraints
Constraint is a core concept in table-driven operation definition: operation
|
Adds an option to `mlir-tblgen -gen-op-defs` `op-shard-count=N` that divides the op class definitions and op list into N segments, e.g. ``` // mlir-tblgen -gen-op-defs -op-shard-count=2 void FooDialect::initialize() { addOperations< >(); addOperations< >(); } ``` When split across multiple source files, this can help significantly improve dialect compile time for dialects with a large opset. stack-info: PR: #89423, branch: users/mogball/pr_1
This PR uses the new op sharding mechanism in tablegen to shard the test dialect's op definitions. This breaks the definition of ops into multiple source files, speeding up compile time of the test dialect dramatically. This improves developer cycle times when iterating on the test dialect. stack-info: PR: #89628, branch: users/Mogball/stack/1
This adds explanations and instructions on how to set up a dialect for sharded op definitions to the MLIR documentation. stack-info: PR: #89664, branch: users/Mogball/stack/3
f8c449c
to
f782f1a
Compare
@Mogball , This PR broke our local build. We append
And in
Now we got error message from
|
The tool likely should accept the same sets of flags as TableGen, considering it is using the tablegen() macro. Feel free to send a patch. |
Thanks for responding. I am afraid I am not familiar enough to the code in this PR to create a patch. Would be either @Mogball or you able to please provide a fix? |
I put up a PR #91329 to fix this regression. I am not sure it is the best fix though as if new options are added for |
…tered. (#91329) PR #89664 introduced a regression that it unregistered llvm-tblgen option `-D` for macros. The test `TestOps.cpp` failed due to passing a macros to llvm-tblgen. It caused our internal build to fail because we append `-DLOCAL_NAME` into `LLVM_TABLEGEN_FLANGS` in `llvm/lib/cmake/llvm/TableGen.cmake` as ``` list(APPEND LLVM_TABLEGEN_FLAGS "-DLOCAL_NAME") ``` And in `./llvm/lib/Target/PowerPC/PPC.td`, we check it for some downstream code as: ``` ... #ifdef LOCAL_NAME ... #endif ``` Now we got error message from mlir-src-sharder as ``` mlir-src-sharder -op-shard-index=1 -DLOCAL_NAME llvm-project/mlir/test/lib/Dialect/Test/TestOps.cpp --write-if-changed -o tools/mlir/test/lib/Dialect/Test/TestOps.1.cpp -d tools/mlir/test/lib/Dialect/Test/TestOps.1.cpp.d mlir-src-sharder: Unknown command line argument '-DLOCAL_NAME'. Try: 'llvm-project/build/bin/mlir-src-sharder --help' mlir-src-sharder: Did you mean '-I'? ``` This PR is to fix the regression.
Stacked PRs:
[mlir][ods] Add documentation on how to use sharded op definitions (NFC)
This adds explanations and instructions on how to set up a dialect for
sharded op definitions to the MLIR documentation.