Skip to content

Commit

Permalink
[mlir] update transform dialect tutorials (#81199)
Browse files Browse the repository at this point in the history
Use the "main" transform-interpreter pass instead of the test pass.
This, along with the previously introduced debug extension, now allow
tutorials to no longer depend on test passes and extensions.
  • Loading branch information
ftynse committed Feb 9, 2024
1 parent 7291761 commit b33b91a
Show file tree
Hide file tree
Showing 22 changed files with 615 additions and 559 deletions.
347 changes: 182 additions & 165 deletions mlir/docs/Tutorials/transform/Ch1.md

Large diffs are not rendered by default.

202 changes: 109 additions & 93 deletions mlir/docs/Tutorials/transform/Ch2.md
Expand Up @@ -10,37 +10,40 @@ The Transform dialect uses the dialect extension mechanism to allow additional o
// In MyExtension.cpp.
#include "mlir/Dialect/Transform/IR/TransformDialect.h"

// Define a new Transform dialect extension. This uses the CRTP idiom to identify
// extensions.
// Define a new Transform dialect extension. This uses the CRTP idiom to
// identify extensions.
class MyExtension : public ::mlir::transform::TransformDialectExtension<MyExtension> {
public:
// The extension must derive the base constructor.
using Base::Base;

// This function initializes the extension, similarly to `initialize` in dialect
// definitions. List individual operations and dependent dialects here.
// This function initializes the extension, similarly to `initialize` in
// dialect definitions. List individual operations and dependent dialects
// here.
void init();
};

void MyExtension::init() {
// Similarly to dialects, an extension can declare a dependent dialect. This dialect
// will be loaded along with the extension and, therefore, along with the Transform
// dialect. Only declare as dependent the dialects that contain the attributes or
// types used by transform operations. Do NOT declare as dependent the dialects
// produced during the transformation.
// Similarly to dialects, an extension can declare a dependent dialect. This
// dialect will be loaded along with the extension and, therefore, along with
// the Transform dialect. Only declare as dependent the dialects that contain
// the attributes or types used by transform operations. Do NOT declare as
// dependent the dialects produced during the transformation.
//
// declareDependentDialect<MyDialect>();

// When transformations are applied, they may produce new operations from previously
// unloaded dialects. Typically, a pass would need to declare itself dependent on
// the dialects containing such new operations. To avoid confusion with the dialects
// the extension itself depends on, the Transform dialects differentiates between:
// When transformations are applied, they may produce new operations from
// previously unloaded dialects. Typically, a pass would need to declare
// itself dependent on the dialects containing such new operations. To avoid
// confusion with the dialects the extension itself depends on, the Transform
// dialects differentiates between:
// - dependent dialects, which are used by the transform operations, and
// - generated dialects, which contain the entities (attributes, operations,
// types) that may be produced by applying the transformation even when not
// present in the original payload IR.
// In the following chapter, we will be add operations that generate function calls
// and structured control flow operations, so let's declare the corresponding
// dialects as generated.
// - generated dialects, which contain the entities (attributes, operations,
// types) that may be produced by applying the transformation even when
// not present in the original payload IR.
// In the following chapter, we will be add operations that generate function
// calls and structured control flow operations, so let's declare the
// corresponding dialects as generated.
declareGeneratedDialect<::mlir::scf::SCFDialect>();
declareGeneratedDialect<::mlir::func::FuncDialect>();

Expand Down Expand Up @@ -89,7 +92,7 @@ mlir_tablegen(MyExtension.cpp.inc -gen-op-defs)
# Add a CMakeTarget we can depend on to ensure the generation happens before the compilation.
add_public_tablegen_target(MyExtensionIncGen)

# Don't forget to generate the documentation, this will produce a MyExtension.md under
# Don't forget to generate the documentation, this will produce a MyExtension.md under
# Dialects.
add_mlir_doc(MyExtension MyExtension Dialects/ -gen-op-doc)
```
Expand All @@ -103,7 +106,8 @@ add_mlir_library(
# Built from the following source files.
MyExtension.cpp

# Make sure ODS declaration and definitions are generated before compiling this.
# Make sure ODS declaration and definitions are generated before compiling
# this.
DEPENDS
MyExtensionIncGen

Expand Down Expand Up @@ -136,10 +140,10 @@ This will generate two files, `MyExtension.h.inc` and `MyExtension.cpp.inc`, tha
void MyExtension::init() {
// …

// Finally, we register the additional transform operations with the dialect. List all
// operations generated from ODS. This call will perform additional checks that the
// operations implement the transform and memory effect interfaces required by the
// dialect interpreter and assert if they do not.
// Finally, we register the additional transform operations with the dialect.
// List all operations generated from ODS. This call will perform additional
// checks that the operations implement the transform and memory effect
// interfaces required by the dialect interpreter and assert if they do not.
registerTransformOps<
#define GET_OP_LIST
#include "MyExtension.cpp.inc"
Expand All @@ -154,34 +158,36 @@ With this setup, we are now ready to define the new transform operation to rewri
```tablegen
// In MyExtension.td.
// Define the new operation. By convention, prefix its name with the name of the dialect
// extension, "my.". The full operation name will be further prefixed with "transform.".
// Define the new operation. By convention, prefix its name with the name of the
// dialect extension, "my.". The full operation name will be further prefixed
// with "transform.".
def ChangeCallTargetOp : Op<Transform_Dialect, "my.change_call_target",
// Indicate that the operation implements the required TransformOpInterface and
// MemoryEffectsOpInterface.
// Indicate that the operation implements the required TransformOpInterface
// and MemoryEffectsOpInterface.
[DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
// Provide a brief and a full description. It is recommended that the latter describes
// the effects on the operands and how the operation processes various failure modes.
// Provide a brief and a full description. It is recommended that the latter
// describes the effects on the operands and how the operation processes
// various failure modes.
let summary = "Changes the callee of a call operation to the specified one";
let description = [{
For each `func.call` payload operation associated with the handle, changes its
callee to be the symbol whose name is provided as an attribute to this operation.
For each `func.call` payload operation associated with the handle, changes
its callee to be the symbol whose name is provided as an attribute to this operation.
Generates a silenceable failure if the operand is associated with payload operations
that are not `func.call`.
Only reads the operand.
Generates a silenceable failure if the operand is associated with payload operations that are not `func.call`. Only reads the operand.
}];
// The arguments include the handle to the payload operations and the attribute that
// specifies the new callee. The handle must implement TransformHandleTypeInterface.
// We use a string attribute as the symbol may not exist in the transform IR so the
// verification may fail.
// The arguments include the handle to the payload operations and the
// attribute that specifies the new callee. The handle must implement
// TransformHandleTypeInterface.
// We use a string attribute as the symbol may not exist in the transform IR
// so the verification may fail.
let arguments = (ins
TransformHandleTypeInterface:$call,
StrAttr:$new_target);
// The results are empty as the transformation does not produce any new payload.
// The results are empty as the transformation does not produce any new
// payload.
let results = (outs);
// Provide nice syntax.
Expand Down Expand Up @@ -224,8 +230,8 @@ must be modified with the provided rewriter.
// It can also carry additional user-defined state.
::mlir::transform::TransformState &state) {

// First, we need to obtain the list of payload operations that are associated with
// the operand handle.
// First, we need to obtain the list of payload operations that are associated
// with the operand handle.
auto payload = state.getPayloadOps(getCall());

// Then, we iterate over the list of operands and call the actual IR-mutating
Expand Down Expand Up @@ -280,56 +286,66 @@ void registerMyExtension(::mlir::DialectRegistry &registry) {
After registering the extension, it becomes possible to use our new operation in the Transform dialect interpreter. The upstream testing pass can be used as is.
```mlir
transform.sequence failures(propagate) {
^bb0(%arg0: !transform.any_op,
%arg1: !transform.op<"linalg.matmul">,
%arg2: !transform.op<"linalg.elemwise_binary">):
// Since the %arg2 handle is associated with both elementwise operations,
// we need to split it into two handles so we can target only the second
// elementwise operation.
%add, %max = transform.split_handle %arg2 : (!transform.op<"linalg.elemwise_binary">)
-> (!transform.any_op, !transform.any_op)
// The actual tiling transformation takes tile sizes as attributes. It produces a
// handle to the loop generated during tiling.
%loop, %tiled = transform.structured.tile_using_forall %max tile_sizes [8, 32]
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
// We can now fuse the other operations into the loop. Here, we fuse
// operations one-by-one. This requires the operation that is being fused
// to define the value used within the loop, so the order of such fusions
// is important. We could also use "transform.merge_handles" to obtain
// a single handle to all operations and give it to `fuse_into_containing_op`
// that would take care of the ordering in this case.
%add_fused = transform.structured.fuse_into_containing_op %add into %loop
: (!transform.any_op, !transform.any_op) -> !transform.any_op
%matmul_fused = transform.structured.fuse_into_containing_op %arg1 into %loop
: (!transform.op<"linalg.matmul">, !transform.any_op) -> !transform.any_op
// Tile again to get the desired size. Note that this time this tiles the
// "add" operation and fuses matmul into the loop, but doesn't affect the
// "max" operation. This illustrates the precise targeting with the transform
// dialect. Otherwise, it is difficult to differentiate "add" and "max", both
// of which having the same kind.
%loop_2, %tiled_2 = transform.structured.tile_using_forall %add_fused tile_sizes [4, 4]
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
%matmul_fused_2 = transform.structured.fuse_into_containing_op %matmul_fused into %loop_2
: (!transform.any_op, !transform.any_op) -> !transform.any_op
// Since outlining is currently only implemented for region-holding operations
// such as loops, use tiling to size 1 to materialize the outer loop that is
// going to be outlined.
%outline_target, %_ = transform.structured.tile_using_forall %tiled_2 tile_sizes [1]
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.structured.fuse_into_containing_op %matmul_fused_2 into %outline_target
: (!transform.any_op, !transform.any_op) -> !transform.any_op
%func, %call = transform.loop.outline %outline_target {func_name = "outlined"}
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
// Rewrite the call target.
transform.my.change_call_target %call, "microkernel" : !transform.any_op
transform.yield
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(
%arg0: !transform.any_op,
%arg1: !transform.op<"linalg.matmul">,
%arg2: !transform.op<"linalg.elemwise_binary">) {
// Since the %arg2 handle is associated with both elementwise operations,
// we need to split it into two handles so we can target only the second
// elementwise operation.
%add, %max = transform.split_handle %arg2
: (!transform.op<"linalg.elemwise_binary">)
-> (!transform.any_op, !transform.any_op)
// The actual tiling transformation takes tile sizes as attributes. It
// produces a handle to the loop generated during tiling.
%loop, %tiled = transform.structured.tile_using_forall %max
tile_sizes [8, 32]
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
// We can now fuse the other operations into the loop. Here, we fuse
// operations one-by-one. This requires the operation that is being fused
// to define the value used within the loop, so the order of such fusions
// is important. We could also use "transform.merge_handles" to obtain
// a single handle to all operations and give it to
// `fuse_into_containing_op` that would take care of the ordering in this
// case.
%add_fused = transform.structured.fuse_into_containing_op %add into %loop
: (!transform.any_op, !transform.any_op) -> !transform.any_op
%matmul_fused = transform.structured.fuse_into_containing_op %arg1
into %loop
: (!transform.op<"linalg.matmul">, !transform.any_op)
-> !transform.any_op
// Tile again to get the desired size. Note that this time this tiles the
// "add" operation and fuses matmul into the loop, but doesn't affect the
// "max" operation. This illustrates the precise targeting with the
// transform dialect. Otherwise, it is difficult to differentiate "add" and
// "max", both of which having the same kind.
%loop_2, %tiled_2 = transform.structured.tile_using_forall %add_fused
tile_sizes [4, 4]
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
%matmul_fused_2 = transform.structured.fuse_into_containing_op %matmul_fused
into %loop_2
: (!transform.any_op, !transform.any_op) -> !transform.any_op
// Since outlining is currently only implemented for region-holding
// operations such as loops, use tiling to size 1 to materialize the outer
// loop that is going to be outlined.
%outline_target, %_ = transform.structured.tile_using_forall %tiled_2 tile_sizes [1]
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.structured.fuse_into_containing_op %matmul_fused_2 into %outline_target
: (!transform.any_op, !transform.any_op) -> !transform.any_op
%func, %call = transform.loop.outline %outline_target
{func_name = "outlined"}
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
// Rewrite the call target.
transform.my.change_call_target %call, "microkernel" : !transform.any_op
transform.yield
}
}
```

Expand Down
12 changes: 6 additions & 6 deletions mlir/docs/Tutorials/transform/Ch3.md
Expand Up @@ -79,15 +79,15 @@ def CallOpInterfaceHandle
// The type must implement `TransformHandleTypeInterface`.
[DeclareTypeInterfaceMethods<TransformHandleTypeInterface>]> {
// The usual components of a type such as description, mnemonic and assembly format
// The usual components of a type such as description, mnemonic and assembly format
// should be provided.
let summary = "handle to payload operations implementing CallOpInterface";
let mnemonic = "my.call_op_interface";
let assemblyFormat = "";
}
```

We will omit the generation of declaration and definitions using Tablegen for brevity as it is identical to the regular case.
We will omit the generation of declaration and definitions using Tablegen for brevity as it is identical to the regular case.

To finalize the definition of a transform type, one must implement the interface methods.

Expand All @@ -109,9 +109,9 @@ mlir::transform::CallOpInterfaceHandleType::checkPayload(
if (llvm::isa<mlir::CallOpInterface>(op))
continue;

// By convention, these verifiers always emit a silenceable failure since they are
// By convention, these verifiers always emit a silenceable failure since they are
// checking a precondition.
DiagnosedSilenceableFailure diag = emitSilenceableError(loc)
DiagnosedSilenceableFailure diag = emitSilenceableError(loc)
<< "expected the payload operation to implement CallOpInterface";
diag.attachNote(op->getLoc()) << "offending operation";
return diag;
Expand All @@ -129,8 +129,8 @@ Additional attributes and types need to be registered in the extension, next to
// In MyExtension.cpp.
void MyExtension::init() {
//
// ...
registerTypes<
#define GET_TYPEDEF_LIST
#include "MyExtensionTypes.cpp.inc"
Expand Down
2 changes: 1 addition & 1 deletion mlir/docs/Tutorials/transform/Ch4.md
Expand Up @@ -205,7 +205,7 @@ transform.named_sequence @__transform_main(
%root: !transform.any_op {transform.readonly}) {
// Collect groups of operations that match the criteria specified in the
// named sequence.
%matmul, %el1, %el2 = transform.collect_matching @match_matmul_elemwise in %root
%matmul, %el1, %el2 = transform.collect_matching @match_matmul_elemwise in %root
: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
%elemwise = transform.merge_handles %el1, %el2 : !transform.any_op
Expand Down
22 changes: 4 additions & 18 deletions mlir/examples/transform/Ch2/transform-opt/transform-opt.cpp
Expand Up @@ -12,6 +12,7 @@

#include "MyExtension.h"

#include "mlir/Dialect/Transform/Transforms/Passes.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/InitAllDialects.h"
Expand All @@ -20,14 +21,6 @@
#include "mlir/Transforms/Passes.h"
#include <cstdlib>

// Forward declarations of test passes that used in this chapter for
// illustrative purposes. Test passes are not directly exposed for use in
// binaries other than mlir-opt, which is too big to serve as an example.
namespace mlir::test {
void registerTestTransformDialectEraseSchedulePass();
void registerTestTransformDialectInterpreterPass();
} // namespace mlir::test

namespace test {
void registerTestTransformDialectExtension(mlir::DialectRegistry &);
} // namespace test
Expand All @@ -39,22 +32,15 @@ int main(int argc, char **argv) {
mlir::registerAllExtensions(registry);
registerMyExtension(registry);

// Register transform interpreter pass.
mlir::transform::registerInterpreterPass();

// Register a handful of cleanup passes that we can run to make the output IR
// look nicer.
mlir::registerCanonicalizerPass();
mlir::registerCSEPass();
mlir::registerSymbolDCEPass();

// Register the test passes.
#ifdef MLIR_INCLUDE_TESTS
mlir::test::registerTestTransformDialectEraseSchedulePass();
mlir::test::registerTestTransformDialectInterpreterPass();
test::registerTestTransformDialectExtension(registry);
#else
llvm::errs() << "warning: MLIR built without test passes, interpreter "
"testing will not be available\n";
#endif // MLIR_INCLUDE_TESTS

// Delegate to the MLIR utility for parsing and pass management.
return mlir::MlirOptMain(argc, argv, "transform-opt-ch2", registry)
.succeeded()
Expand Down

0 comments on commit b33b91a

Please sign in to comment.