Skip to content

Commit 7fdc2ed

Browse files
committed
[mlir] reallow null results in TransformEachOpTrait
Previous changes in 98acd74 were overly eager to disallow null payload everywhere. The semantics of TransformEachOpTrait allows individual applications to return null payloads as means of filtering out the operations to which they are not applicable without emitting even a silenceable failure. This is a questionable choice, but one apparently relied upon. Null payloads are not supposed to leak outside of the trait. Reviewed By: qcolombet Differential Revision: https://reviews.llvm.org/D143904
1 parent 64dad4b commit 7fdc2ed

File tree

4 files changed

+55
-28
lines changed

4 files changed

+55
-28
lines changed

mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1015,6 +1015,8 @@ mlir::transform::TransformEachOpTrait<OpTy>::apply(
10151015
for (OpResult r : this->getOperation()->getResults()) {
10161016
if (r.getType().isa<TransformParamTypeInterface>())
10171017
transformResults.setParams(r, emptyParams);
1018+
else if (r.getType().isa<TransformValueHandleTypeInterface>())
1019+
transformResults.setValues(r, ValueRange());
10181020
else
10191021
transformResults.set(r, emptyPayload);
10201022
}

mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp

Lines changed: 35 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -921,48 +921,60 @@ transform::detail::checkApplyToOne(Operation *transformOp,
921921
// Check that the right kind of value was produced.
922922
for (const auto &[ptr, res] :
923923
llvm::zip(partialResult, transformOp->getResults())) {
924-
if (ptr.isNull()) {
925-
return emitDiag() << "null result #" << res.getResultNumber()
926-
<< " produced";
924+
if (ptr.isNull())
925+
continue;
926+
if (res.getType().template isa<TransformHandleTypeInterface>() &&
927+
!ptr.is<Operation *>()) {
928+
return emitDiag() << "application of " << transformOpName
929+
<< " expected to produce an Operation * for result #"
930+
<< res.getResultNumber();
927931
}
928-
if (ptr.is<Operation *>() &&
929-
!res.getType().template isa<TransformHandleTypeInterface>()) {
932+
if (res.getType().template isa<TransformParamTypeInterface>() &&
933+
!ptr.is<Attribute>()) {
930934
return emitDiag() << "application of " << transformOpName
931935
<< " expected to produce an Attribute for result #"
932936
<< res.getResultNumber();
933937
}
934-
if (ptr.is<Attribute>() &&
935-
!res.getType().template isa<TransformParamTypeInterface>()) {
938+
if (res.getType().template isa<TransformValueHandleTypeInterface>() &&
939+
!ptr.is<Value>()) {
936940
return emitDiag() << "application of " << transformOpName
937-
<< " expected to produce an Operation * for result #"
941+
<< " expected to produce a Value for result #"
938942
<< res.getResultNumber();
939943
}
940944
}
941945
return success();
942946
}
943947

948+
template <typename T>
949+
static SmallVector<T> castVector(ArrayRef<transform::MappedValue> range) {
950+
return llvm::to_vector(llvm::map_range(
951+
range, [](transform::MappedValue value) { return value.get<T>(); }));
952+
}
953+
944954
void transform::detail::setApplyToOneResults(
945955
Operation *transformOp, TransformResults &transformResults,
946956
ArrayRef<ApplyToEachResultList> results) {
957+
SmallVector<SmallVector<MappedValue>> transposed;
958+
transposed.resize(transformOp->getNumResults());
959+
for (const ApplyToEachResultList &partialResults : results) {
960+
if (llvm::any_of(partialResults,
961+
[](MappedValue value) { return value.isNull(); }))
962+
continue;
963+
assert(transformOp->getNumResults() == partialResults.size() &&
964+
"expected as many partial results as op as results");
965+
for (auto &[i, value] : llvm::enumerate(partialResults))
966+
transposed[i].push_back(value);
967+
}
968+
947969
for (OpResult r : transformOp->getResults()) {
970+
unsigned position = r.getResultNumber();
948971
if (r.getType().isa<TransformParamTypeInterface>()) {
949-
auto params = llvm::to_vector(
950-
llvm::map_range(results, [r](const ApplyToEachResultList &oneResult) {
951-
return oneResult[r.getResultNumber()].get<Attribute>();
952-
}));
953-
transformResults.setParams(r, params);
972+
transformResults.setParams(r,
973+
castVector<Attribute>(transposed[position]));
954974
} else if (r.getType().isa<TransformValueHandleTypeInterface>()) {
955-
auto values = llvm::to_vector(
956-
llvm::map_range(results, [r](const ApplyToEachResultList &oneResult) {
957-
return oneResult[r.getResultNumber()].get<Value>();
958-
}));
959-
transformResults.setValues(r, values);
975+
transformResults.setValues(r, castVector<Value>(transposed[position]));
960976
} else {
961-
auto payloads = llvm::to_vector(
962-
llvm::map_range(results, [r](const ApplyToEachResultList &oneResult) {
963-
return oneResult[r.getResultNumber()].get<Operation *>();
964-
}));
965-
transformResults.set(r, payloads);
977+
transformResults.set(r, castVector<Operation *>(transposed[position]));
966978
}
967979
}
968980
}

mlir/test/Dialect/Transform/test-interpreter.mlir

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -495,8 +495,9 @@ transform.with_pdl_patterns {
495495

496496
// -----
497497

498+
// This should not fail.
499+
498500
func.func @foo() {
499-
// expected-note @below {{when applied to this op}}
500501
"op" () : () -> ()
501502
return
502503
}
@@ -513,7 +514,6 @@ transform.with_pdl_patterns {
513514
transform.sequence %arg0 : !pdl.operation failures(propagate) {
514515
^bb0(%arg1: !pdl.operation):
515516
%0 = pdl_match @some in %arg1 : (!pdl.operation) -> !pdl.operation
516-
// expected-error @below {{null result #0 produced}}
517517
transform.test_mixed_null_and_non_null_results %0
518518
}
519519
}
@@ -1053,11 +1053,11 @@ module {
10531053

10541054
// -----
10551055

1056-
// expected-note @below {{when applied to this op}}
1056+
// Should not fail.
1057+
10571058
module {
10581059
transform.sequence failures(propagate) {
10591060
^bb0(%arg0: !transform.any_op):
1060-
// expected-error @below {{null result #0 produced}}
10611061
transform.test_produce_transform_param_or_forward_operand %arg0
10621062
{ first_result_is_null }
10631063
: (!transform.any_op) -> (!transform.any_op, !transform.param<i64>)
@@ -1079,6 +1079,19 @@ module {
10791079

10801080
// -----
10811081

1082+
// expected-note @below {{when applied to this op}}
1083+
module {
1084+
transform.sequence failures(propagate) {
1085+
^bb0(%arg0: !transform.any_op):
1086+
// expected-error @below {{expected to produce a Value for result #0}}
1087+
transform.test_produce_transform_param_or_forward_operand %arg0
1088+
{ second_result_is_handle }
1089+
: (!transform.any_op) -> (!transform.any_value, !transform.param<i64>)
1090+
}
1091+
}
1092+
1093+
// -----
1094+
10821095
transform.sequence failures(propagate) {
10831096
^bb0(%arg0: !transform.any_op):
10841097
// expected-error @below {{attempting to assign a null payload op to this transform value}}

mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -371,7 +371,7 @@ def TestProduceTransformParamOrForwardOperandOp
371371
UnitAttr:$first_result_is_param,
372372
UnitAttr:$first_result_is_null,
373373
UnitAttr:$second_result_is_handle);
374-
let results = (outs TransformHandleTypeInterface:$out,
374+
let results = (outs AnyType:$out,
375375
TransformParamTypeInterface:$param);
376376
let assemblyFormat = "$in attr-dict `:` functional-type(operands, results)";
377377
let cppNamespace = "::mlir::test";

0 commit comments

Comments
 (0)