diff --git a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp index a500228d68c77..45cef9c162c70 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp @@ -13,6 +13,7 @@ #include "mlir/Dialect/Transform/IR/Utils.h" #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" #include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/Verifier.h" #include "llvm/ADT/SCCIterator.h" #include "llvm/ADT/TypeSwitch.h" @@ -140,6 +141,20 @@ LogicalResult transform::TransformDialect::verifyOperationAttribute( "operations with symbol tables"; } + // Pre-verify calls and callables because call graph construction below + // assumes they are valid, but this verifier runs before verifying the + // nested operations. + WalkResult walkResult = op->walk([](Operation *nested) { + if (!isa(nested)) + return WalkResult::advance(); + + if (failed(verify(nested, /*verifyRecursively=*/false))) + return WalkResult::interrupt(); + return WalkResult::advance(); + }); + if (walkResult.wasInterrupted()) + return failure(); + const mlir::CallGraph callgraph(op); for (auto scc = llvm::scc_begin(&callgraph); !scc.isAtEnd(); ++scc) { if (!scc.hasCycle()) diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp index 132ed815c354e..dfda42d5d18ee 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -2098,17 +2098,11 @@ void transform::IncludeOp::getEffects( getOperation(), getTarget()); if (!callee) return defaultEffects(); - DiagnosedSilenceableFailure earlyVerifierResult = - verifyNamedSequenceOp(callee, /*emitWarnings=*/false); - if (!earlyVerifierResult.succeeded()) { - (void)earlyVerifierResult.silence(); - return defaultEffects(); - } for (unsigned i = 0, e = getNumOperands(); i < e; ++i) { if (callee.getArgAttr(i, TransformDialect::kArgConsumedAttrName)) consumesHandle(getOperation()->getOpOperand(i), effects); - else + else if (callee.getArgAttr(i, TransformDialect::kArgReadOnlyAttrName)) onlyReadsHandle(getOperation()->getOpOperand(i), effects); } } diff --git a/mlir/test/Dialect/Transform/ops-invalid.mlir b/mlir/test/Dialect/Transform/ops-invalid.mlir index 71a260f1196e9..68305de73761a 100644 --- a/mlir/test/Dialect/Transform/ops-invalid.mlir +++ b/mlir/test/Dialect/Transform/ops-invalid.mlir @@ -369,6 +369,7 @@ module attributes { transform.with_named_sequence } { // expected-error @below {{recursion not allowed in named sequences}} transform.named_sequence @self_recursion() -> () { transform.include @self_recursion failures(suppress) () : () -> () + transform.yield } } @@ -376,13 +377,13 @@ module attributes { transform.with_named_sequence } { module @mutual_recursion attributes { transform.with_named_sequence } { // expected-note @below {{operation on recursion stack}} - transform.named_sequence @foo(%arg0: !transform.any_op) -> () { + transform.named_sequence @foo(%arg0: !transform.any_op {transform.readonly}) -> () { transform.include @bar failures(suppress) (%arg0) : (!transform.any_op) -> () transform.yield } // expected-error @below {{recursion not allowed in named sequences}} - transform.named_sequence @bar(%arg0: !transform.any_op) -> () { + transform.named_sequence @bar(%arg0: !transform.any_op {transform.readonly}) -> () { transform.include @foo failures(propagate) (%arg0) : (!transform.any_op) -> () transform.yield } @@ -430,7 +431,7 @@ module attributes { transform.with_named_sequence } { // ----- module attributes { transform.with_named_sequence } { - transform.named_sequence @foo(%arg0: !transform.any_op) -> () { + transform.named_sequence @foo(%arg0: !transform.any_op {transform.readonly}) -> () { transform.yield } @@ -444,7 +445,7 @@ module attributes { transform.with_named_sequence } { // ----- module attributes { transform.with_named_sequence } { - transform.named_sequence @foo(%arg0: !transform.any_op) -> (!transform.any_op) { + transform.named_sequence @foo(%arg0: !transform.any_op {transform.readonly}) -> (!transform.any_op) { transform.yield %arg0 : !transform.any_op } @@ -458,7 +459,7 @@ module attributes { transform.with_named_sequence } { // ----- module attributes { transform.with_named_sequence } { - transform.named_sequence @foo(%arg0: !transform.any_op) -> (!transform.any_op) { + transform.named_sequence @foo(%arg0: !transform.any_op {transform.readonly}) -> (!transform.any_op) { transform.yield %arg0 : !transform.any_op } @@ -543,7 +544,6 @@ module attributes { transform.with_named_sequence } { // ----- module attributes { transform.with_named_sequence } { - // expected-error @below {{must provide consumed/readonly status for arguments of external or called ops}} transform.named_sequence @foo(%op: !transform.any_op) { transform.debug.emit_remark_at %op, "message" : !transform.any_op transform.yield @@ -551,6 +551,8 @@ module attributes { transform.with_named_sequence } { transform.sequence failures(propagate) { ^bb0(%arg0: !transform.any_op): + // expected-error @below {{TransformOpInterface requires memory effects on operands to be specified}} + // expected-note @below {{no effects specified for operand #0}} transform.include @foo failures(propagate) (%arg0) : (!transform.any_op) -> () transform.yield } @@ -908,3 +910,54 @@ module attributes { transform.with_named_sequence } { transform.yield } } + +// ----- + +module attributes { transform.with_named_sequence } { + transform.named_sequence @__transform_main(%arg0: !transform.any_op) -> () { + // Intentionally malformed func with no region. This shouldn't crash the + // verifier of `with_named_sequence` that runs before we get to the + // function. + // expected-error @below {{requires one region}} + "func.func"() : () -> () + transform.yield + } +} + +// ----- + +module attributes { transform.with_named_sequence } { + transform.named_sequence @__transform_main(%arg0: !transform.any_op) -> () { + // Intentionally malformed call with a region. This shouldn't crash the + // verifier of `with_named_sequence` that runs before we get to the call. + // expected-error @below {{requires zero regions}} + "func.call"() <{ + function_type = () -> (), + sym_name = "lambda_function" + }> ({ + ^bb0: + "func.return"() : () -> () + }) : () -> () + transform.yield + } +} + +// ----- + +module attributes { transform.with_named_sequence } { + // Intentionally malformed sequence where the verifier should not crash. + // expected-error @below {{ op expects argument attribute array to have the same number of elements as the number of function arguments, got 1, but expected 3}} + "transform.named_sequence"() <{ + arg_attrs = [{transform.readonly}], + function_type = (i1, tensor, tensor) -> (), + sym_name = "print_message" + }> ({}) : () -> () + "transform.named_sequence"() <{ + function_type = (!transform.any_op) -> (), + sym_name = "reference_other_module" + }> ({ + ^bb0(%arg0: !transform.any_op): + "transform.include"(%arg0) <{target = @print_message}> : (!transform.any_op) -> () + "transform.yield"() : () -> () + }) : () -> () +}