Skip to content

Commit

Permalink
[mlir] add readonly/consume annotations to transform named sequences
Browse files Browse the repository at this point in the history
Use the argument attribute mechanism for function-like operations to
annotate the arguments of named transform sequences as consuming or only
reading the handles passed as arguments. This makes it possible to
correctly specify handle invalidation for external named sequences by
requiring their declarations to always provide such annotations.
Additionally, these annotations remove the need to analyze the body of
a named sequence to understand its effects on the arguments. Make them
required for named sequences that are called from the same file, in
addition to external sequences.

Provide a convenience pass that infers annotations by analyzing bodies
of named sequences provided they are not called from the same file.

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D147223
  • Loading branch information
ftynse committed Apr 4, 2023
1 parent ebd579c commit 4110934
Show file tree
Hide file tree
Showing 17 changed files with 380 additions and 28 deletions.
9 changes: 8 additions & 1 deletion mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
Expand Up @@ -36,6 +36,13 @@ def Transform_Dialect : Dialect {
constexpr const static llvm::StringLiteral
kTargetTagAttrName = "transform.target_tag";

/// Names of the attributes indicating whether an argument of an external
/// transform dialect symbol is consumed or only read.
constexpr const static llvm::StringLiteral
kArgConsumedAttrName = "transform.consumed";
constexpr const static llvm::StringLiteral
kArgReadOnlyAttrName = "transform.readonly";

/// Returns the named PDL constraint functions available in the dialect
/// as a map from their name to the function.
const ::llvm::StringMap<::mlir::PDLConstraintFunction> &
Expand Down Expand Up @@ -114,7 +121,7 @@ def Transform_Dialect : Dialect {
}];
}

// Base class for ops that belong to the tranfsorm dialect. Ops defined in
// Base class for ops that belong to the transform dialect. Ops defined in
// extensions of this dialect may also use this.
class TransformDialectOp<string mnemonic, list<Trait> traits = []>
: Op<Transform_Dialect, mnemonic, traits>;
Expand Down
5 changes: 5 additions & 0 deletions mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
Expand Up @@ -847,6 +847,11 @@ bool isHandleConsumed(Value handle, transform::TransformOpInterface transform);
void modifiesPayload(SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
void onlyReadsPayload(SmallVectorImpl<MemoryEffects::EffectInstance> &effects);

/// Populates `consumedArguments` with positions of `block` arguments that are
/// consumed by the operations in the `block`.
void getConsumedBlockArguments(
Block &block, llvm::SmallDenseSet<unsigned> &consumedArguments);

/// Trait implementing the MemoryEffectOpInterface for operations that "consume"
/// their operands and produce new results.
template <typename OpTy>
Expand Down
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/Transform/Transforms/Passes.h
@@ -1,4 +1,4 @@
//===- CheckUses.h - Expensive transform value validity checks --*- C++ -*-===//
//===- Passes.h - Transform dialect pass entry points -----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
Expand Down
10 changes: 10 additions & 0 deletions mlir/include/mlir/Dialect/Transform/Transforms/Passes.td
Expand Up @@ -32,4 +32,14 @@ def CheckUsesPass : Pass<"transform-dialect-check-uses"> {
}];
}

def InferEffectsPass : Pass<"transform-infer-effects"> {
let summary = "infer transform side effects for symbols";
let description = [{
This pass analyzes the definitions of transform dialect callable symbol
operations, such as `transform.named_sequence`, and annotates the symbol
arguments with attributes indicating the side effects that the nested
operations have on them.
}];
}

#endif // MLIR_DIALECT_TRANSFORM_TRANSFORMS_PASSES
8 changes: 8 additions & 0 deletions mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
Expand Up @@ -175,6 +175,14 @@ LogicalResult transform::TransformDialect::verifyOperationAttribute(
}
return success();
}
if (attribute.getName().getValue() == kArgConsumedAttrName ||
attribute.getName().getValue() == kArgReadOnlyAttrName) {
if (!attribute.getValue().isa<UnitAttr>()) {
return op->emitError()
<< attribute.getName() << " must be a unit attribute";
}
return success();
}
return emitError(op->getLoc())
<< "unknown attribute: " << attribute.getName();
}
Expand Down
23 changes: 23 additions & 0 deletions mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
Expand Up @@ -1318,6 +1318,29 @@ void transform::onlyReadsPayload(
effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
}

void transform::getConsumedBlockArguments(
Block &block, llvm::SmallDenseSet<unsigned int> &consumedArguments) {
SmallVector<MemoryEffects::EffectInstance> effects;
for (Operation &nested : block) {
auto iface = dyn_cast<MemoryEffectOpInterface>(nested);
if (!iface)
continue;

effects.clear();
iface.getEffects(effects);
for (const MemoryEffects::EffectInstance &effect : effects) {
BlockArgument argument =
dyn_cast_or_null<BlockArgument>(effect.getValue());
if (!argument || argument.getOwner() != &block ||
!isa<MemoryEffects::Free>(effect.getEffect()) ||
effect.getResource() != transform::TransformMappingResource::get()) {
continue;
}
consumedArguments.insert(argument.getArgNumber());
}
}
}

//===----------------------------------------------------------------------===//
// Utilities for TransformOpInterface.
//===----------------------------------------------------------------------===//
Expand Down
90 changes: 75 additions & 15 deletions mlir/lib/Dialect/Transform/IR/TransformOps.cpp
Expand Up @@ -720,39 +720,97 @@ verifyNamedSequenceOp(transform::NamedSequenceOp op);

void transform::IncludeOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
// Always mark as modifying the payload.
// TODO: a mechanism to annotate effects on payload. Even when all handles are
// only read, the payload may still be modified, so we currently stay on the
// conservative side and always indicate modification. This may prevent some
// code reordering.
modifiesPayload(effects);

// Results are always produced.
producesHandle(getResults(), effects);

// Adds default effects to operands and results. This will be added if
// preconditions fail so the trait verifier doesn't complain about missing
// effects and the real precondition failure is reported later on.
auto defaultEffects = [&] { onlyReadsHandle(getOperands(), effects); };

// Bail if the callee is unknown. This may run as part of the verification
// process before we verified the validity of the callee or of this op.
auto target =
getOperation()->getAttrOfType<SymbolRefAttr>(getTargetAttrName());
if (!target)
return;
return defaultEffects();
auto callee = SymbolTable::lookupNearestSymbolFrom<NamedSequenceOp>(
getOperation(), getTarget());
if (!callee)
return;
return defaultEffects();
DiagnosedSilenceableFailure earlyVerifierResult =
verifyNamedSequenceOp(callee);
if (!earlyVerifierResult.succeeded()) {
(void)earlyVerifierResult.silence();
return;
return defaultEffects();
}

// Carry over effects from the callee.
// TODO: external callees must provides attributes annotating the
// readonly/consume effects on operands.
if (!callee.isExternal())
remapArgumentEffects(callee.getBody().front(), getOperands(), effects);

// Proper effects.
onlyReadsHandle(getOperands(), effects);
producesHandle(getResults(), effects);
for (unsigned i = 0, e = getNumOperands(); i < e; ++i) {
if (callee.getArgAttr(i, TransformDialect::kArgConsumedAttrName))
consumesHandle(getOperand(i), effects);
else
onlyReadsHandle(getOperand(i), effects);
}
}

template <typename... Tys>
static bool implementSameInterface(Type t1, Type t2) {
return ((isa<Tys>(t1) && isa<Tys>(t2)) || ... || false);
}

/// Checks that the attributes of the named sequence operation have correct
/// consumption effect annotations. If `alsoVerifyInternal`, checks for
/// annotations being present even if they can be inferred from the body.
static DiagnosedSilenceableFailure
verifyNamedSequenceConsumeAnnotations(transform::NamedSequenceOp op,
bool alsoVerifyInternal = false) {
llvm::SmallDenseSet<unsigned> consumedArguments;
if (!op.isExternal()) {
transform::getConsumedBlockArguments(op.getBody().front(),
consumedArguments);
}
for (unsigned i = 0, e = op.getFunctionType().getNumInputs(); i < e; ++i) {
bool isConsumed =
op.getArgAttr(i, transform::TransformDialect::kArgConsumedAttrName) !=
nullptr;
bool isReadOnly =
op.getArgAttr(i, transform::TransformDialect::kArgReadOnlyAttrName) !=
nullptr;
if (isConsumed && isReadOnly) {
return op.emitSilenceableError()
<< "argument #" << i << " cannot be both readonly and consumed";
}
if ((op.isExternal() || alsoVerifyInternal) && !isConsumed && !isReadOnly) {
return op.emitSilenceableError()
<< "must provide consumed/readonly status for arguments of "
"external or called ops";
}
if (op.isExternal())
continue;

if (consumedArguments.contains(i) && !isConsumed && isReadOnly) {
return op.emitSilenceableError()
<< "argument #" << i
<< " is consumed in the body but is not marked as such";
}
if (!consumedArguments.contains(i) && isConsumed) {
Diagnostic warning(op->getLoc(), DiagnosticSeverity::Warning);
warning << "argument #" << i
<< " is not consumed in the body but is marked as consumed";
return DiagnosedSilenceableFailure::silenceableFailure(
std::move(warning));
}
}
return DiagnosedSilenceableFailure::success();
}

LogicalResult
transform::IncludeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
// Access through indirection and do additional checking because this may be
Expand Down Expand Up @@ -794,7 +852,9 @@ transform::IncludeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
}
}

return success();
return verifyNamedSequenceConsumeAnnotations(target,
/*alsoVerifyInternal=*/true)
.checkAndReport();
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -899,7 +959,7 @@ verifyNamedSequenceOp(transform::NamedSequenceOp op) {
}

if (op.isExternal() || op.getBody().empty())
return DiagnosedSilenceableFailure::success();
return verifyNamedSequenceConsumeAnnotations(op);

if (op.getBody().front().empty())
return emitSilenceableFailure(op) << "expected a non-empty body block";
Expand Down Expand Up @@ -931,7 +991,7 @@ verifyNamedSequenceOp(transform::NamedSequenceOp op) {
<< operandType << " vs " << resultType << ")";
}

return DiagnosedSilenceableFailure::success();
return verifyNamedSequenceConsumeAnnotations(op);
}

LogicalResult transform::NamedSequenceOp::verify() {
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Transform/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_dialect_library(MLIRTransformDialectTransforms
CheckUses.cpp
InferEffects.cpp
TransformInterpreterPassBase.cpp

DEPENDS
Expand Down
69 changes: 69 additions & 0 deletions mlir/lib/Dialect/Transform/Transforms/InferEffects.cpp
@@ -0,0 +1,69 @@
//===- InferEffects.cpp - Infer memory effects for named symbols ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/Transforms/Passes.h"

#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/IR/FunctionInterfaces.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "llvm/ADT/DenseSet.h"

using namespace mlir;

namespace mlir {
namespace transform {
#define GEN_PASS_DEF_INFEREFFECTSPASS
#include "mlir/Dialect/Transform/Transforms/Passes.h.inc"
} // namespace transform
} // namespace mlir

static LogicalResult inferSideEffectAnnotations(Operation *op) {
if (!isa<transform::TransformOpInterface>(op))
return success();

auto func = dyn_cast<FunctionOpInterface>(op);
if (!func || func.isExternal())
return success();

if (!func.getFunctionBody().hasOneBlock()) {
return op->emitError()
<< "only single-block operations are currently supported";
}

// Note that there can't be an inclusion of an unannotated symbol because it
// wouldn't have passed the verifier, so recursion isn't necessary here.
llvm::SmallDenseSet<unsigned> consumedArguments;
transform::getConsumedBlockArguments(func.getFunctionBody().front(),
consumedArguments);

for (unsigned i = 0, e = func.getNumArguments(); i < e; ++i) {
func.setArgAttr(i,
consumedArguments.contains(i)
? transform::TransformDialect::kArgConsumedAttrName
: transform::TransformDialect::kArgReadOnlyAttrName,
UnitAttr::get(op->getContext()));
}
return success();
}

namespace {
class InferEffectsPass
: public transform::impl::InferEffectsPassBase<InferEffectsPass> {
public:
void runOnOperation() override {
WalkResult result = getOperation()->walk([](Operation *op) {
return failed(inferSideEffectAnnotations(op)) ? WalkResult::interrupt()
: WalkResult::advance();
});
if (result.wasInterrupted())
return signalPassFailure();
}
};
} // namespace
Expand Up @@ -12,6 +12,7 @@
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/FunctionInterfaces.h"
Expand Down Expand Up @@ -298,6 +299,12 @@ static void performOptionalDebugActions(
/// Replaces external symbols in `block` with their (non-external) definitions
/// from the given module.
static LogicalResult defineDeclaredSymbols(Block &block, ModuleOp definitions) {
MLIRContext &ctx = *definitions->getContext();
auto consumedName =
StringAttr::get(&ctx, transform::TransformDialect::kArgConsumedAttrName);
auto readOnlyName =
StringAttr::get(&ctx, transform::TransformDialect::kArgReadOnlyAttrName);

for (Operation &op : llvm::make_early_inc_range(block)) {
LLVM_DEBUG(DBGS() << op << "\n");
auto symbol = dyn_cast<SymbolOpInterface>(op);
Expand Down Expand Up @@ -330,6 +337,30 @@ static LogicalResult defineDeclaredSymbols(Block &block, ModuleOp definitions) {
<< externalSymbolFunc.getFunctionType() << ")";
}

for (unsigned i = 0, e = symbolFunc.getNumArguments(); i < e; ++i) {
bool isExternalConsumed =
externalSymbolFunc.getArgAttr(i, consumedName) != nullptr;
bool isExternalReadonly =
externalSymbolFunc.getArgAttr(i, readOnlyName) != nullptr;
bool isConsumed = symbolFunc.getArgAttr(i, consumedName) != nullptr;
bool isReadonly = symbolFunc.getArgAttr(i, readOnlyName) != nullptr;
if (!isExternalConsumed && !isExternalReadonly) {
if (isConsumed)
externalSymbolFunc.setArgAttr(i, consumedName, UnitAttr::get(&ctx));
else if (isReadonly)
externalSymbolFunc.setArgAttr(i, readOnlyName, UnitAttr::get(&ctx));
continue;
}

if ((isExternalConsumed && !isConsumed) ||
(isExternalReadonly && !isReadonly)) {
return symbolFunc.emitError()
<< "external definition has mismatching consumption annotations "
"for argument #"
<< i;
}
}

OpBuilder builder(&op);
builder.setInsertionPoint(&op);
builder.clone(*externalSymbol);
Expand Down
13 changes: 13 additions & 0 deletions mlir/test/Dialect/Transform/infer-effects.mlir
@@ -0,0 +1,13 @@
// RUN: mlir-opt %s --transform-infer-effects | FileCheck %s

module attributes { transform.with_named_sequence } {
// CHECK-LABEL: @infer
// CHECK-SAME: %{{.*}}: !transform.any_op {transform.consumed}
// CHECK-SAME: %{{.*}}: !transform.any_op {transform.readonly}
// CHECK-SAME: %{{.*}}: !transform.param<i32> {transform.readonly}
transform.named_sequence @infer(%op: !transform.any_op, %other: !transform.any_op, %param: !transform.param<i32>) {
transform.test_consume_operand %op : !transform.any_op
transform.test_print_remark_at_operand %other, "" : !transform.any_op
transform.yield
}
}

0 comments on commit 4110934

Please sign in to comment.