diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td index 055cd78e6130a..598a0649b524a 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td @@ -118,7 +118,7 @@ def Transform_Dialect : Dialect { mapping between transform IR values and payload IR operations. - An `Allocate` effect from this resource means creating a new mapping - entry, it is always accompanied by a `Write` effet. + entry, it is always accompanied by a `Write` effect. - A `Read` effect from this resource means accessing the mapping. @@ -151,6 +151,37 @@ def Transform_Dialect : Dialect { transform operations can return _new_ handles that can be read or consumed by subsequent operations. + ## Handle Invalidation + + The execution model of the transform dialect expects that a payload IR + operation is associated with _at most one_ transform IR handle. This avoids + the situation when a handle to an operation outlives the operation itself + that can be erased during a transformation triggered through another handle. + + Handles pointing to operations nested in each other are allowed to co-exist + in the transform IR. However, a transform IR operation that consumes such a + handle automatically _invalidates_ all the other handles that are associated + with operations nested in the operations associated with the consumed + handle. Any use of the invalidated handle results in undefined behavior + since the payload IR operations associated with it are likely to have been + mutated or erased. The mere fact of the handle being invalidated does _not_ + trigger undefined behavior, only its appearance as an operand does. + Invalidation applies to the entire handle, even if some of the payload IR + operations associated with it are not nested in payload IR operations + associated with another, consumed handle. + + Note: the restriction on two handles not pointing to the same operation may + be relaxed in the future to follow the invalidation model for nested + operation. + + The Transform dialect infrastructure has the capability of checking whether + the transform IR op operand is invalidated before applying the + transformation. However, such a check is computationally expensive and + must be enabled explicitly through `TransformOptions`. Additionally, the + `transform-dialect-check-uses` pass emits warnings when a handle may be used + after it has been consumed, but does so abstractly, without processing the + payload IR. + ## Intended Use and Integrations The transformation control infrastructure provided by this dialect is @@ -184,7 +215,7 @@ def Transform_Dialect : Dialect { differentiate between the parts of the loop produced by the previous pass (both are the same operation, and it is likely undesirable to pollute the operation with pass-specific information). Implementing passes that run the - combined transfomration would have run into the combinatorial explosion + combined transformation would have run into the combinatorial explosion issue due to multiple possible transform compositions or into the need for deep pass parameterization, the ultimate form of which is an ad-hoc dialect to specify which transformations the pass should run. The transform dialect @@ -200,7 +231,7 @@ def Transform_Dialect : Dialect { takes care of bookkeeping. As such, the transform dialect does not provide the interpreter pass. Instead, it provides a set of utilities that can be used by clients to define their own interpreter passes or as part of a more - complex pass. For example, the mapping between values in the tranfsorm IR + complex pass. For example, the mapping between values in the transform IR and operations in the payload IR, or the function that applies the transformations specified by ops in the given block sequentially. Note that a transform op may have regions with further transform ops in them, with diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h index 1e7cfb53ca8e6..c42474c671e97 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h @@ -18,6 +18,27 @@ namespace transform { class TransformOpInterface; +/// Options controlling the application of transform operations by the +/// TransformState. +class TransformOptions { +public: + TransformOptions() {} + + /// Requests computationally expensive checks of the transform and payload IR + /// well-formedness to be performed before each transformation. In particular, + /// these ensure that the handles still point to valid operations when used. + TransformOptions &enableExpensiveChecks(bool enable = true) { + expensiveChecksEnabled = enable; + return *this; + } + + /// Returns true if the expensive checks are requested. + bool getExpensiveChecksEnabled() const { return expensiveChecksEnabled; } + +private: + bool expensiveChecksEnabled = true; +}; + /// The state maintained across applications of various ops implementing the /// TransformOpInterface. The operations implementing this interface and the /// surrounding structure are referred to as transform IR. The operations to @@ -63,8 +84,10 @@ class TransformState { /// Creates a state for transform ops living in the given region. The parent /// operation of the region. The second argument points to the root operation /// in the payload IR beind transformed, which may or may not contain the - /// region with transform ops. - TransformState(Region ®ion, Operation *root); + /// region with transform ops. Additional options can be provided through the + /// trailing configuration object. + TransformState(Region ®ion, Operation *root, + const TransformOptions &options = TransformOptions()); /// Returns the op at which the transformation state is rooted. This is /// typically helpful for transformations that apply globally. @@ -296,6 +319,21 @@ class TransformState { static LogicalResult tryEmplaceReverseMapping(Mappings &map, Operation *op, Value handle); + /// If the operand is a handle consumed by the operation, i.e. has the "free" + /// memory effect associated with it, identifies other handles that are + /// pointing to payload IR operations nested in the operations pointed to by + /// the consumed handle. Marks all such handles as invalidated so trigger + /// errors if they are used. + void recordHandleInvalidation(OpOperand &handle); + + /// Checks that the operation does not use invalidated handles as operands. + /// Reports errors and returns failure if it does. Otherwise, invalidates the + /// handles consumed by the operation as well as any handles pointing to + /// payload IR operations nested in the operations associated with the + /// consumed handles. + LogicalResult + checkAndRecordHandleInvalidation(TransformOpInterface transform); + /// The mappings between transform IR values and payload IR ops, aggregated by /// the region in which the transform IR values are defined. llvm::SmallDenseMap mappings; @@ -307,6 +345,14 @@ class TransformState { /// The top-level operation that contains all payload IR, typically a module. Operation *topLevel; + /// Additional options controlling the transformation state behavior. + TransformOptions options; + + /// The mapping from invalidated handles to the error-reporting functions that + /// describe when the handles were invalidated. Calling such a function emits + /// a user-visible diagnostic. + DenseMap> invalidatedHandles; + #if LLVM_ENABLE_ABI_BREAKING_CHECKS /// A stack of nested regions that are being processed in the transform IR. /// Each region must be an ancestor of the following regions in this list. diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp index 3e11b6794bbde..ff28f447e43bd 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp @@ -21,8 +21,9 @@ using namespace mlir; constexpr const Value transform::TransformState::kTopLevelValue; -transform::TransformState::TransformState(Region ®ion, Operation *root) - : topLevel(root) { +transform::TransformState::TransformState(Region ®ion, Operation *root, + const TransformOptions &options) + : topLevel(root), options(options) { auto result = mappings.try_emplace(®ion); assert(result.second && "the region scope is already present"); (void)result; @@ -120,8 +121,78 @@ LogicalResult transform::TransformState::updatePayloadOps( return success(); } +void transform::TransformState::recordHandleInvalidation(OpOperand &handle) { + ArrayRef potentialAncestors = getPayloadOps(handle.get()); + for (const Mappings &mapping : llvm::make_second_range(mappings)) { + for (const auto &kvp : mapping.reverse) { + // If the op is associated with invalidated handle, skip the check as it + // may be reading invalid IR. + Operation *op = kvp.first; + Value otherHandle = kvp.second; + if (invalidatedHandles.count(otherHandle)) + continue; + + for (Operation *ancestor : potentialAncestors) { + if (!ancestor->isProperAncestor(op)) + continue; + + // Make sure the error-reporting lambda doesn't capture anything + // by-reference because it will go out of scope. Additionally, extract + // location from Payload IR ops because the ops themselves may be + // deleted before the lambda gets called. + Location ancestorLoc = ancestor->getLoc(); + Location opLoc = op->getLoc(); + Operation *owner = handle.getOwner(); + unsigned operandNo = handle.getOperandNumber(); + invalidatedHandles[otherHandle] = [ancestorLoc, opLoc, owner, operandNo, + otherHandle]() { + InFlightDiagnostic diag = + owner->emitOpError() + << "invalidated the handle to payload operations nested in the " + "payload operation associated with its operand #" + << operandNo; + diag.attachNote(ancestorLoc) << "ancestor op"; + diag.attachNote(opLoc) << "nested op"; + diag.attachNote(otherHandle.getLoc()) << "other handle"; + }; + } + } + } +} + +LogicalResult transform::TransformState::checkAndRecordHandleInvalidation( + TransformOpInterface transform) { + auto memoryEffectsIface = + cast(transform.getOperation()); + SmallVector effects; + memoryEffectsIface.getEffectsOnResource( + transform::TransformMappingResource::get(), effects); + + for (OpOperand &target : transform->getOpOperands()) { + // If the operand uses an invalidated handle, report it. + auto it = invalidatedHandles.find(target.get()); + if (it != invalidatedHandles.end()) + return it->getSecond()(), failure(); + + // Invalidate handles pointing to the operations nested in the operation + // associated with the handle consumed by this operation. + auto consumesTarget = [&](const MemoryEffects::EffectInstance &effect) { + return isa(effect.getEffect()) && + effect.getValue() == target.get(); + }; + if (llvm::find_if(effects, consumesTarget) != effects.end()) + recordHandleInvalidation(target); + } + return success(); +} + LogicalResult transform::TransformState::applyTransform(TransformOpInterface transform) { + if (options.getExpensiveChecksEnabled() && + failed(checkAndRecordHandleInvalidation(transform))) { + return failure(); + } + transform::TransformResults results(transform->getNumResults()); if (failed(transform.apply(results, *this))) return failure(); @@ -131,23 +202,23 @@ transform::TransformState::applyTransform(TransformOpInterface transform) { auto memEffectInterface = cast(transform.getOperation()); SmallVector effects; - for (Value target : transform->getOperands()) { + for (OpOperand &target : transform->getOpOperands()) { effects.clear(); - memEffectInterface.getEffectsOnValue(target, effects); + memEffectInterface.getEffectsOnValue(target.get(), effects); if (llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) { return isa( effect.getResource()) && isa(effect.getEffect()); })) { - removePayloadOps(target); + removePayloadOps(target.get()); } } - for (auto &en : llvm::enumerate(transform->getResults())) { - assert(en.value().getDefiningOp() == transform.getOperation() && + for (OpResult result : transform->getResults()) { + assert(result.getDefiningOp() == transform.getOperation() && "payload IR association for a value other than the result of the " "current transform op"); - if (failed(setPayloadOps(en.value(), results.get(en.index())))) + if (failed(setPayloadOps(result, results.get(result.getResultNumber())))) return failure(); } diff --git a/mlir/test/Dialect/Transform/expensive-checks.mlir b/mlir/test/Dialect/Transform/expensive-checks.mlir new file mode 100644 index 0000000000000..c86367154f37b --- /dev/null +++ b/mlir/test/Dialect/Transform/expensive-checks.mlir @@ -0,0 +1,27 @@ +// RUN: mlir-opt --test-transform-dialect-interpreter='enable-expensive-checks=1' --split-input-file --verify-diagnostics %s + +// expected-note @below {{ancestor op}} +func.func @func() { + // expected-note @below {{nested op}} + return +} + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + pdl.pattern @return : benefit(1) { + %0 = operands + %1 = types + %2 = operation "func.return"(%0 : !pdl.range) -> (%1 : !pdl.range) + rewrite %2 with "transform.dialect" + } + + sequence %arg0 { + ^bb1(%arg1: !pdl.operation): + // expected-note @below {{other handle}} + %0 = pdl_match @return in %arg1 + %1 = get_closest_isolated_parent %0 + // expected-error @below {{invalidated the handle to payload operations nested in the payload operation associated with its operand #0}} + test_consume_operand %1 + test_print_remark_at_operand %0, "remark" + } +} diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp index 39a3afce95fc7..e2d0a74b9f429 100644 --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp @@ -119,6 +119,12 @@ LogicalResult mlir::test::TestProduceParamOrForwardOperandOp::verify() { return success(); } +LogicalResult +mlir::test::TestConsumeOperand::apply(transform::TransformResults &results, + transform::TransformState &state) { + return success(); +} + LogicalResult mlir::test::TestConsumeOperandIfMatchesParamOrFail::apply( transform::TransformResults &results, transform::TransformState &state) { ArrayRef payload = state.getPayloadOps(getOperand()); diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td index 8623b8a18aaa8..0cf6b8fb4612d 100644 --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td @@ -34,6 +34,15 @@ def TestProduceParamOrForwardOperandOp let hasVerifier = 1; } +def TestConsumeOperand : Op]> { + let arguments = (ins + Arg:$operand); + let assemblyFormat = "$operand attr-dict"; + let cppNamespace = "::mlir::test"; +} + def TestConsumeOperandIfMatchesParamOrFail : Op]> { diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp index 8c4aa1aee5529..e54b2c96dd91a 100644 --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp @@ -27,6 +27,10 @@ class TestTransformDialectInterpreterPass MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( TestTransformDialectInterpreterPass) + TestTransformDialectInterpreterPass() = default; + TestTransformDialectInterpreterPass( + const TestTransformDialectInterpreterPass &) {} + StringRef getArgument() const override { return "test-transform-dialect-interpreter"; } @@ -37,13 +41,21 @@ class TestTransformDialectInterpreterPass void runOnOperation() override { ModuleOp module = getOperation(); - transform::TransformState state(module.getBodyRegion(), module); + transform::TransformState state( + module.getBodyRegion(), module, + transform::TransformOptions().enableExpensiveChecks( + enableExpensiveChecks)); for (auto op : module.getBody()->getOps()) { if (failed(state.applyTransform(op))) return signalPassFailure(); } } + + Option enableExpensiveChecks{ + *this, "enable-expensive-checks", llvm::cl::init(false), + llvm::cl::desc("perform expensive checks to better report errors in the " + "transform IR")}; }; } // namespace