Skip to content

Commit

Permalink
[mlir] add transform dialect entry point
Browse files Browse the repository at this point in the history
Introduce `transform::applyTransforms` as a top-level entry point to the
Transform dialect-driven transformation infrastructure, by analogy with
`applyFull/PartialConversion`. Clients are expected to use this function
and no longer need to maintain the transformation state. Make the
constructor of the TransformState private for that purpose.

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D135681
  • Loading branch information
ftynse committed Oct 12, 2022
1 parent 812ad21 commit 32f0bde
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 37 deletions.
27 changes: 19 additions & 8 deletions mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,18 @@ def Transform_Dialect : Dialect {
let description = [{
## Disclaimer

** Proceed with care: not ready for general use. **
**This dialect is actively developed and may change frequently.**

This dialect is evolving rapidly and may change on a very short notice. To
decrease the maintenance burden and churn, only a few in-tree use cases are
currently supported in the main tree:
To decrease the maintenance burden and churn, please post a description of
the intended use case on the MLIR forum. A few in-tree use cases are
currently supported:

- high-level transformations on "structured ops" (i.e. ops that operate on
chunks of data in a way that can be decomposed into operations on
smaller chunks of data and control flow) in Linalg, Tensor and Vector
dialects.

*Please post a description of the intended use case on the MLIR forum and
wait for confirmation.*
dialects;
- loop transformations in the SCF dialect.


## Overview

Expand Down Expand Up @@ -79,6 +78,18 @@ def Transform_Dialect : Dialect {
expected to have the `PossibleTopLevelTransformOpTrait` and may be used
without arguments.

A program transformation expressed using the Transform dialect can be
programmatically triggered by calling:

```c++
LogicalResult transform::applyTransforms(Operation *payloadRoot,
TransformOpInterface transform,
const TransformOptions &options);
```

that applies the transformations specified by the top-level `transform` to
payload IR contained in `payloadRoot`.

## Dialect Extension Mechanism

This dialect is designed to be extensible, that is, clients of this dialect
Expand Down
29 changes: 21 additions & 8 deletions mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,16 @@ class TransformOptions {
bool expensiveChecksEnabled = true;
};

/// Entry point to the Transform dialect infrastructure. Applies the
/// transformation specified by `transform` to payload IR contained in
/// `payloadRoot`. The `transform` operation may contain other operations that
/// will be executed following the internal logic of the operation. It must
/// have the `PossibleTopLevelTransformOp` trait and not have any operands.
/// This function internally keeps track of the transformation state.
LogicalResult
applyTransforms(Operation *payloadRoot, TransformOpInterface transform,
const TransformOptions &options = TransformOptions());

/// The state maintained across applications of various ops implementing the
/// TransformOpInterface. The operations implementing this interface and the
/// surrounding structure are referred to as transform IR. The operations to
Expand Down Expand Up @@ -250,15 +260,11 @@ class TransformState {
TransformOpReverseMapping reverse;
};

public:
/// Creates a state for transform ops living in the given region. The parent
/// operation of the region. The second argument points to the root operation
/// in the payload IR being transformed, which may or may not contain the
/// region with transform ops. Additional options can be provided through the
/// trailing configuration object.
TransformState(Region &region, Operation *root,
const TransformOptions &options = TransformOptions());
friend LogicalResult applyTransforms(Operation *payloadRoot,
TransformOpInterface transform,
const TransformOptions &options);

public:
/// Returns the op at which the transformation state is rooted. This is
/// typically helpful for transformations that apply globally.
Operation *getTopLevel() const;
Expand Down Expand Up @@ -438,6 +444,13 @@ class TransformState {
/// Identifier for storing top-level value in the `operations` mapping.
static constexpr Value kTopLevelValue = Value();

/// Creates a state for transform ops living in the given region. The second
/// argument points to the root operation in the payload IR being transformed,
/// which may or may not contain the region with transform ops. Additional
/// options can be provided through the trailing configuration object.
TransformState(Region *region, Operation *payloadRoot,
const TransformOptions &options = TransformOptions());

/// Returns the mappings frame for the reigon in which the value is defined.
const Mappings &getMapping(Value value) const {
return const_cast<TransformState *>(this)->getMapping(value);
Expand Down
31 changes: 27 additions & 4 deletions mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "mlir/IR/Operation.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"

#define DEBUG_TYPE "transform-dialect"
#define DEBUG_PRINT_AFTER_ALL "transform-dialect-print-top-level-after-all"
Expand All @@ -25,14 +26,15 @@ using namespace mlir;

constexpr const Value transform::TransformState::kTopLevelValue;

transform::TransformState::TransformState(Region &region, Operation *root,
transform::TransformState::TransformState(Region *region,
Operation *payloadRoot,
const TransformOptions &options)
: topLevel(root), options(options) {
auto result = mappings.try_emplace(&region);
: topLevel(payloadRoot), options(options) {
auto result = mappings.try_emplace(region);
assert(result.second && "the region scope is already present");
(void)result;
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
regionStack.push_back(&region);
regionStack.push_back(region);
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
}

Expand Down Expand Up @@ -447,6 +449,27 @@ void transform::onlyReadsPayload(
effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
}

//===----------------------------------------------------------------------===//
// Entry point.
//===----------------------------------------------------------------------===//

LogicalResult transform::applyTransforms(Operation *payloadRoot,
TransformOpInterface transform,
const TransformOptions &options) {
#ifndef NDEBUG
if (!transform->hasTrait<PossibleTopLevelTransformOpTrait>() ||
transform->getNumOperands() != 0) {
transform->emitError()
<< "expected transform to start at the top-level transform op";
llvm::report_fatal_error("could not run transforms",
/*gen_crash_diag=*/false);
}
#endif // NDEBUG

TransformState state(transform->getParentRegion(), payloadRoot, options);
return state.applyTransform(transform).checkAndReport();
}

//===----------------------------------------------------------------------===//
// Generated interface implementation.
//===----------------------------------------------------------------------===//
Expand Down
36 changes: 24 additions & 12 deletions mlir/test/Dialect/Transform/test-interpreter.mlir
Original file line number Diff line number Diff line change
@@ -1,29 +1,41 @@
// RUN: mlir-opt %s --test-transform-dialect-interpreter -allow-unregistered-dialect --split-input-file --verify-diagnostics

// expected-remark @below {{applying transformation}}
transform.test_transform_op
transform.sequence failures(propagate) {
^bb0(%arg0: !transform.any_op):
// expected-remark @below {{applying transformation}}
transform.test_transform_op
}

// -----

%0 = transform.test_produce_param_or_forward_operand 42 { foo = "bar" }
// expected-remark @below {{succeeded}}
transform.test_consume_operand_if_matches_param_or_fail %0[42]
transform.sequence failures(propagate) {
^bb0(%arg0: !transform.any_op):
%0 = transform.test_produce_param_or_forward_operand 42 { foo = "bar" }
// expected-remark @below {{succeeded}}
transform.test_consume_operand_if_matches_param_or_fail %0[42]
}

// -----

%0 = transform.test_produce_param_or_forward_operand 42 { foo = "bar" }
// expected-error @below {{expected the operand to be associated with 21 got 42}}
transform.test_consume_operand_if_matches_param_or_fail %0[21]
transform.sequence failures(propagate) {
^bb0(%arg0: !transform.any_op):
%0 = transform.test_produce_param_or_forward_operand 42 { foo = "bar" }
// expected-error @below {{expected the operand to be associated with 21 got 42}}
transform.test_consume_operand_if_matches_param_or_fail %0[21]
}

// -----

// It is okay to have multiple handles to the same payload op as long
// as only one of them is consumed. The expensive checks mode is necessary
// to detect double-consumption.
%0 = transform.test_produce_param_or_forward_operand 42 { foo = "bar" }
%1 = transform.test_copy_payload %0
// expected-remark @below {{succeeded}}
transform.test_consume_operand_if_matches_param_or_fail %0[42]
transform.sequence failures(propagate) {
^bb0(%arg0: !transform.any_op):
%0 = transform.test_produce_param_or_forward_operand 42 { foo = "bar" }
%1 = transform.test_copy_payload %0
// expected-remark @below {{succeeded}}
transform.test_consume_operand_if_matches_param_or_fail %0[42]
}

// -----

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,12 @@ class TestTransformDialectInterpreterPass

void runOnOperation() override {
ModuleOp module = getOperation();
transform::TransformState state(
module.getBodyRegion(), module,
transform::TransformOptions().enableExpensiveChecks(
enableExpensiveChecks));
for (auto op :
module.getBody()->getOps<transform::TransformOpInterface>()) {
if (failed(state.applyTransform(op).checkAndReport()))
if (failed(transform::applyTransforms(
module, op,
transform::TransformOptions().enableExpensiveChecks(
enableExpensiveChecks))))
return signalPassFailure();
}
}
Expand Down

0 comments on commit 32f0bde

Please sign in to comment.