626 changes: 235 additions & 391 deletions mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h

Large diffs are not rendered by default.

43 changes: 32 additions & 11 deletions mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -103,27 +103,21 @@ def TransformOpInterface : OpInterface<"TransformOpInterface"> {
}];
}

def TransformTypeInterface : TypeInterface<"TransformTypeInterface"> {
let description = [{
Types that can be used for Transform dialect handle values. Such types
define the properties of Payload IR operations associated with the handle.
A user of such a handle can assume that these properties have been verified
for any Payload IR operation associated with it.
}];

class TransformTypeInterfaceBase<string cppClass, string cppObjectType>
: TypeInterface<cppClass> {
let cppNamespace = "::mlir::transform";

let methods = [
InterfaceMethod<
/*desc=*/[{
Checks if the given list of associated Payload IR operations satisfy
the conditions defined by this type. If not, produces a silenceable
Checks if the given associated objects (Payload IR operations or attributes)
satisfy the conditions defined by this type. If not, produces a silenceable
error at the specified location.
}],
/*returnType=*/"::mlir::DiagnosedSilenceableFailure",
/*name=*/"checkPayload",
/*arguments=*/(ins "::mlir::Location":$loc,
"::mlir::ArrayRef<::mlir::Operation *>":$payload)
"::mlir::ArrayRef<" # cppObjectType # ">":$payload)
>
];

Expand All @@ -135,6 +129,29 @@ def TransformTypeInterface : TypeInterface<"TransformTypeInterface"> {
}];
}

def TransformHandleTypeInterface
: TransformTypeInterfaceBase<"TransformHandleTypeInterface",
"::mlir::Operation *"> {
let description = [{
Types that can be used for the Transform dialect handle values. Such types
define the properties of Payload IR operations associated with the handle.
A user of such a handle can assume that these properties have been verified
for any Payload IR operation associated with it.
}];
}

def TransformParamTypeInterface
: TransformTypeInterfaceBase<"TransformParamTypeInterface",
"::mlir::Attribute"> {
let description = [{
Types that can be used for the Transform dialect parameter values. Such types
define the structure of the parameters associated with the value, e.g., their
underlying type. A user of the value can assume that the parameter has been
verified.
}];

}

def FunctionalStyleTransformOpTrait
: NativeOpTrait<"FunctionalStyleTransformOpTrait"> {
let cppNamespace = "::mlir::transform";
Expand All @@ -148,4 +165,8 @@ def NavigationTransformOpTrait : NativeOpTrait<"NavigationTransformOpTrait"> {
let cppNamespace = "::mlir::transform";
}

def ParamProducerTransformOpTrait : NativeOpTrait<"ParamProducerTransformOpTrait"> {
let cppNamespace = "::mlir::transform";
}

#endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORM_INTERFACES_TD
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMOPS_H

#include "mlir/Dialect/PDL/IR/PDLTypes.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
Expand Down
50 changes: 25 additions & 25 deletions mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,8 @@ def AlternativesOp : TransformDialectOp<"alternatives",
```
}];

let arguments = (ins Optional<TransformTypeInterface>:$scope);
let results = (outs Variadic<TransformTypeInterface>:$results);
let arguments = (ins Optional<TransformHandleTypeInterface>:$scope);
let results = (outs Variadic<TransformHandleTypeInterface>:$results);
let regions = (region VariadicRegion<SizedRegion<1>>:$alternatives);

let assemblyFormat =
Expand All @@ -102,14 +102,14 @@ def CastOp : TransformDialectOp<"cast",
[TransformOpInterface, TransformEachOpTrait,
DeclareOpInterfaceMethods<CastOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
let arguments = (ins TransformTypeInterface:$input);
let results = (outs TransformTypeInterface:$output);
let arguments = (ins TransformHandleTypeInterface:$input);
let results = (outs TransformHandleTypeInterface:$output);
let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)";

let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
::mlir::Operation *target,
::llvm::SmallVectorImpl<::mlir::Operation *> &results,
::mlir::transform::ApplyToEachResultList &results,
::mlir::transform::TransformState &state);
}];
}
Expand Down Expand Up @@ -143,8 +143,8 @@ def ForeachOp : TransformDialectOp<"foreach",
merged and mapped to the same resulting handle.
}];

let arguments = (ins TransformTypeInterface:$target);
let results = (outs Variadic<TransformTypeInterface>:$results);
let arguments = (ins TransformHandleTypeInterface:$target);
let results = (outs Variadic<TransformHandleTypeInterface>:$results);
let regions = (region SizedRegion<1>:$body);
let assemblyFormat =
"$target `:` type($target) (`->` type($results)^)? $body attr-dict";
Expand Down Expand Up @@ -183,8 +183,8 @@ def GetClosestIsolatedParentOp : TransformDialectOp<"get_closest_isolated_parent
on the further transformation applied to the handle produced here.
}];

let arguments = (ins TransformTypeInterface:$target);
let results = (outs TransformTypeInterface:$parent);
let arguments = (ins TransformHandleTypeInterface:$target);
let results = (outs TransformHandleTypeInterface:$parent);
let assemblyFormat =
"$target attr-dict `:` functional-type(operands, results)";
}
Expand All @@ -202,9 +202,9 @@ def GetProducerOfOperand : TransformDialectOp<"get_producer_of_operand",
computational operations, which can be empty.
}];

let arguments = (ins TransformTypeInterface:$target,
let arguments = (ins TransformHandleTypeInterface:$target,
I64Attr:$operand_number);
let results = (outs TransformTypeInterface:$parent);
let results = (outs TransformHandleTypeInterface:$parent);
let assemblyFormat = "$target `[` $operand_number `]` attr-dict `:` "
"functional-type(operands, results)";
}
Expand All @@ -225,9 +225,9 @@ def MergeHandlesOp : TransformDialectOp<"merge_handles",
same or different handles. Consumes the operands and produces a new handle.
}];

let arguments = (ins Variadic<TransformTypeInterface>:$handles,
let arguments = (ins Variadic<TransformHandleTypeInterface>:$handles,
UnitAttr:$deduplicate);
let results = (outs TransformTypeInterface:$result);
let results = (outs TransformHandleTypeInterface:$result);
let assemblyFormat = "($deduplicate^)? $handles attr-dict `:` type($result)";
let hasFolder = 1;
}
Expand All @@ -250,9 +250,9 @@ def SplitHandlesOp : TransformDialectOp<"split_handles",
operations contained in the source `handle`. Otherwise it silently fails.
}];

let arguments = (ins TransformTypeInterface:$handle,
let arguments = (ins TransformHandleTypeInterface:$handle,
I64Attr:$num_result_handles);
let results = (outs Variadic<TransformTypeInterface>:$results);
let results = (outs Variadic<TransformHandleTypeInterface>:$results);

let builders = [
OpBuilder<(ins "Value":$handle, "int64_t":$numResultHandles)>
Expand Down Expand Up @@ -286,10 +286,10 @@ def PDLMatchOp : TransformDialectOp<"pdl_match",
}];

let arguments = (ins
Arg<TransformTypeInterface, "Payload IR scope to match within">:$root,
Arg<TransformHandleTypeInterface, "Payload IR scope to match within">:$root,
SymbolRefAttr:$pattern_name);
let results = (outs
Res<TransformTypeInterface, "Handle to the matched Payload IR ops">:$matched);
Res<TransformHandleTypeInterface, "Handle to the matched Payload IR ops">:$matched);

let assemblyFormat = "$pattern_name `in` $root attr-dict `:` "
"functional-type(operands, results)";
Expand All @@ -307,7 +307,7 @@ def PrintOp : TransformDialectOp<"print",
This op is useful for printf-style debugging.
}];

let arguments = (ins Optional<TransformTypeInterface>:$target,
let arguments = (ins Optional<TransformHandleTypeInterface>:$target,
OptionalAttr<StrAttr>:$name);
let results = (outs);

Expand Down Expand Up @@ -349,9 +349,9 @@ def ReplicateOp : TransformDialectOp<"replicate",
MergeHandlesOp can be used to construct arbitrary lists with repetitions.
}];

let arguments = (ins TransformTypeInterface:$pattern,
Variadic<TransformTypeInterface>:$handles);
let results = (outs Variadic<TransformTypeInterface>:$replicated);
let arguments = (ins TransformHandleTypeInterface:$pattern,
Variadic<TransformHandleTypeInterface>:$handles);
let results = (outs Variadic<TransformHandleTypeInterface>:$replicated);
let assemblyFormat = "`num` `(` $pattern `)` $handles attr-dict `:` "
"type($pattern) `,` type($handles)";
}
Expand Down Expand Up @@ -396,8 +396,8 @@ def SequenceOp : TransformDialectOp<"sequence",
}];

let arguments = (ins FailurePropagationMode:$failure_propagation_mode,
Optional<TransformTypeInterface>:$root);
let results = (outs Variadic<TransformTypeInterface>:$results);
Optional<TransformHandleTypeInterface>:$root);
let results = (outs Variadic<TransformHandleTypeInterface>:$results);
let regions = (region SizedRegion<1>:$body);

let assemblyFormat =
Expand Down Expand Up @@ -467,7 +467,7 @@ def WithPDLPatternsOp : TransformDialectOp<"with_pdl_patterns",
}];

let arguments = (ins
Arg<Optional<TransformTypeInterface>, "Root operation of the Payload IR",
Arg<Optional<TransformHandleTypeInterface>, "Root operation of the Payload IR",
[TransformMappingRead]>:$root);
let regions = (region SizedRegion<1>:$body);
let assemblyFormat = "($root^ `:` type($root))? attr-dict-with-keyword regions";
Expand All @@ -489,7 +489,7 @@ def YieldOp : TransformDialectOp<"yield", [Terminator]> {
}];

let arguments = (ins
Arg<Variadic<TransformTypeInterface>, "Operation handles yielded back to the parent",
Arg<Variadic<TransformHandleTypeInterface>, "Operation handles yielded back to the parent",
[TransformMappingRead]>:$operands);
let assemblyFormat = "operands attr-dict (`:` type($operands)^)?";

Expand Down
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/Transform/IR/TransformTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
#ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORMTYPES_H
#define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMTYPES_H

#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Dialect/Transform/Utils/DiagnosedSilenceableFailure.h"
#include "mlir/IR/Types.h"
#include "mlir/Support/LLVM.h"

Expand Down
20 changes: 18 additions & 2 deletions mlir/include/mlir/Dialect/Transform/IR/TransformTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
include "mlir/Dialect/Transform/IR/TransformDialect.td"

def Transform_AnyOpType : TypeDef<Transform_Dialect, "AnyOp",
[DeclareTypeInterfaceMethods<TransformTypeInterface>]> {
[DeclareTypeInterfaceMethods<TransformHandleTypeInterface>]> {
let description = [{
Transform IR handle that can be associated with a list of arbitrary
Payload IR operations.
Expand All @@ -24,7 +24,7 @@ def Transform_AnyOpType : TypeDef<Transform_Dialect, "AnyOp",
}

def Transform_OperationType : TypeDef<Transform_Dialect, "Operation",
[DeclareTypeInterfaceMethods<TransformTypeInterface>]> {
[DeclareTypeInterfaceMethods<TransformHandleTypeInterface>]> {
let description = [{
Transform IR handle that can be associated with a list of Payload IR
operations with the specified operation name.
Expand All @@ -36,6 +36,22 @@ def Transform_OperationType : TypeDef<Transform_Dialect, "Operation",
let assemblyFormat = "`<` $operation_name `>`";
}

def Transform_ParamType : TypeDef<Transform_Dialect, "Param",
[DeclareTypeInterfaceMethods<TransformParamTypeInterface>]> {
let description = [{
Transform IR value that can be associated with the list of parameters
of the given type. Types are currently limited to integers, but may be
extended in the future to other types values of which can be contained
in attributes.
}];
let mnemonic = "param";
let parameters = (ins
TypeParameter<"::mlir::Type", "Underlying type of the parameter">:$type
);
let assemblyFormat = "`<` $type `>`";
let genVerifyDecl = 1;
}

class Transform_ConcreteOpType<string opname>
: Type<And<[Transform_OperationType.predicate,
CPred<"$_self.cast<::mlir::transform::OperationType>()"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,266 @@
//===- DiagnosedSilenceableFailure.h - Tri-state result ----------- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file declares the DiagnosedSilenceableFailure class allowing to store
// a tri-state result (definite failure, recoverable failure, success) with an
// optional associated list of diagnostics.
//
//===----------------------------------------------------------------------===//

#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Operation.h"

#ifndef MLIR_DIALECT_TRANSFORM_UTILS_DIAGNOSEDSILENCEABLEFAILURE_H
#define MLIR_DIALECT_TRANSFORM_UTILS_DIAGNOSEDSILENCEABLEFAILURE_H

namespace mlir {
/// The result of a transform IR operation application. This can have one of the
/// three states:
/// - success;
/// - silenceable (recoverable) failure with yet-unreported diagnostic;
/// - definite failure.
/// Silenceable failure is intended to communicate information about
/// transformations that did not apply but in a way that supports recovery,
/// for example, they did not modify the payload IR or modified it in some
/// predictable way. They are associated with a Diagnostic that provides more
/// details on the failure. Silenceable failure can be discarded, turning the
/// result into success, or "reported", emitting the diagnostic and turning the
/// result into definite failure.
/// Transform IR operations containing other operations are allowed to do either
/// with the results of the nested transformations, but must propagate definite
/// failures as their diagnostics have been already reported to the user.
class [[nodiscard]] DiagnosedSilenceableFailure {
public:
DiagnosedSilenceableFailure(const DiagnosedSilenceableFailure &) = delete;
DiagnosedSilenceableFailure &
operator=(const DiagnosedSilenceableFailure &) = delete;
DiagnosedSilenceableFailure(DiagnosedSilenceableFailure &&) = default;
DiagnosedSilenceableFailure &
operator=(DiagnosedSilenceableFailure &&) = default;

/// Constructs a DiagnosedSilenceableFailure in the success state.
static DiagnosedSilenceableFailure success() {
return DiagnosedSilenceableFailure(::mlir::success());
}

/// Constructs a DiagnosedSilenceableFailure in the failure state. Typically,
/// a diagnostic has been emitted before this.
static DiagnosedSilenceableFailure definiteFailure() {
return DiagnosedSilenceableFailure(::mlir::failure());
}

/// Constructs a DiagnosedSilenceableFailure in the silenceable failure state,
/// ready to emit the given diagnostic. This is considered a failure
/// regardless of the diagnostic severity.
static DiagnosedSilenceableFailure silenceableFailure(Diagnostic &&diag) {
return DiagnosedSilenceableFailure(std::forward<Diagnostic>(diag));
}
static DiagnosedSilenceableFailure
silenceableFailure(SmallVector<Diagnostic> &&diag) {
return DiagnosedSilenceableFailure(
std::forward<SmallVector<Diagnostic>>(diag));
}

/// Converts all kinds of failure into a LogicalResult failure, emitting the
/// diagnostic if necessary. Must not be called more than once.
LogicalResult checkAndReport();

/// Returns `true` if this is a success.
bool succeeded() const {
return ::mlir::succeeded(result) && diagnostics.empty();
}

/// Returns `true` if this is a definite failure.
bool isDefiniteFailure() const {
return ::mlir::failed(result) && diagnostics.empty();
}

/// Returns `true` if this is a silenceable failure.
bool isSilenceableFailure() const { return !diagnostics.empty(); }

/// Returns the diagnostic message without emitting it. Expects this object
/// to be a silenceable failure.
std::string getMessage() const {
std::string res;
for (auto &diagnostic : diagnostics) {
res.append(diagnostic.str());
res.append("\n");
}
return res;
}

/// Returns a string representation of the failure mode (for error reporting).
std::string getStatusString() const {
if (succeeded())
return "success";
if (isSilenceableFailure())
return "silenceable failure";
return "definite failure";
}

/// Converts silenceable failure into LogicalResult success without reporting
/// the diagnostic, preserves the other states.
LogicalResult silence() {
if (!diagnostics.empty()) {
diagnostics.clear();
result = ::mlir::success();
}
return result;
}

/// Take the diagnostics and silence.
void takeDiagnostics(SmallVectorImpl<Diagnostic> &diags) {
assert(!diagnostics.empty() && "expected a diagnostic to be present");
diags.append(std::make_move_iterator(diagnostics.begin()),
std::make_move_iterator(diagnostics.end()));
}

/// Streams the given values into the last diagnostic.
/// Expects this object to be a silenceable failure.
template <typename T>
DiagnosedSilenceableFailure &operator<<(T &&value) & {
assert(isSilenceableFailure() &&
"can only append output in silenceable failure state");
diagnostics.back() << std::forward<T>(value);
return *this;
}
template <typename T>
DiagnosedSilenceableFailure &&operator<<(T &&value) && {
return std::move(this->operator<<(std::forward<T>(value)));
}

/// Attaches a note to the last diagnostic.
/// Expects this object to be a silenceable failure.
Diagnostic &attachNote(Optional<Location> loc = std::nullopt) {
assert(isSilenceableFailure() &&
"can only attach notes to silenceable failures");
return diagnostics.back().attachNote(loc);
}

private:
explicit DiagnosedSilenceableFailure(LogicalResult result) : result(result) {}
explicit DiagnosedSilenceableFailure(Diagnostic &&diagnostic)
: result(failure()) {
diagnostics.emplace_back(std::move(diagnostic));
}
explicit DiagnosedSilenceableFailure(SmallVector<Diagnostic> &&diagnostics)
: diagnostics(std::move(diagnostics)), result(failure()) {}

/// The diagnostics associated with this object. If non-empty, the object is
/// considered to be in the silenceable failure state regardless of the
/// `result` field.
SmallVector<Diagnostic, 1> diagnostics;

/// The "definite" logical state, either success or failure.
/// Ignored if the diagnostics message is present.
LogicalResult result;

#if LLVM_ENABLE_ABI_BREAKING_CHECKS
/// Whether the associated diagnostics have been reported.
/// Diagnostics reporting consumes the diagnostics, so we need a mechanism to
/// differentiate reported diagnostics from a state where it was never
/// created.
bool reported = false;
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
};

class DiagnosedDefiniteFailure;

DiagnosedDefiniteFailure emitDefiniteFailure(Location loc,
const Twine &message = {});

/// A compatibility class connecting `InFlightDiagnostic` to
/// `DiagnosedSilenceableFailure` while providing an interface similar to the
/// former. Implicitly convertible to `DiagnosticSilenceableFailure` in definite
/// failure state and to `LogicalResult` failure. Reports the error on
/// conversion or on destruction. Instances of this class can be created by
/// `emitDefiniteFailure()`.
class DiagnosedDefiniteFailure {
friend DiagnosedDefiniteFailure emitDefiniteFailure(Location loc,
const Twine &message);

public:
/// Only move-constructible because it carries an in-flight diagnostic.
DiagnosedDefiniteFailure(DiagnosedDefiniteFailure &&) = default;

/// Forward the message to the diagnostic.
template <typename T>
DiagnosedDefiniteFailure &operator<<(T &&value) & {
diag << std::forward<T>(value);
return *this;
}
template <typename T>
DiagnosedDefiniteFailure &&operator<<(T &&value) && {
return std::move(this->operator<<(std::forward<T>(value)));
}

/// Attaches a note to the error.
Diagnostic &attachNote(Optional<Location> loc = std::nullopt) {
return diag.attachNote(loc);
}

/// Implicit conversion to DiagnosedSilenceableFailure in the definite failure
/// state. Reports the error.
operator DiagnosedSilenceableFailure() {
diag.report();
return DiagnosedSilenceableFailure::definiteFailure();
}

/// Implicit conversion to LogicalResult in the failure state. Reports the
/// error.
operator LogicalResult() {
diag.report();
return failure();
}

private:
/// Constructs a definite failure at the given location with the given
/// message.
explicit DiagnosedDefiniteFailure(Location loc, const Twine &message)
: diag(emitError(loc, message)) {}

/// Copy-construction and any assignment is disallowed to prevent repeated
/// error reporting.
DiagnosedDefiniteFailure(const DiagnosedDefiniteFailure &) = delete;
DiagnosedDefiniteFailure &
operator=(const DiagnosedDefiniteFailure &) = delete;
DiagnosedDefiniteFailure &operator=(DiagnosedDefiniteFailure &&) = delete;

/// The error message.
InFlightDiagnostic diag;
};

/// Emits a definite failure with the given message. The returned object allows
/// for last-minute modification to the error message, such as attaching notes
/// and completing the message. It will be reported when the object is
/// destructed or converted.
inline DiagnosedDefiniteFailure emitDefiniteFailure(Location loc,
const Twine &message) {
return DiagnosedDefiniteFailure(loc, message);
}
inline DiagnosedDefiniteFailure emitDefiniteFailure(Operation *op,
const Twine &message = {}) {
return emitDefiniteFailure(op->getLoc(), message);
}

/// Emits a silenceable failure with the given message. A silenceable failure
/// must be either suppressed or converted into a definite failure and reported
/// to the user.
inline DiagnosedSilenceableFailure
emitSilenceableFailure(Location loc, const Twine &message = {}) {
Diagnostic diag(loc, DiagnosticSeverity::Error);
diag << message;
return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
}
inline DiagnosedSilenceableFailure
emitSilenceableFailure(Operation *op, const Twine &message = {}) {
return emitSilenceableFailure(op->getLoc(), message);
}
} // namespace mlir

#endif // MLIR_DIALECT_TRANSFORM_UTILS_DIAGNOSEDSILENCEABLEFAILURE_H
2 changes: 0 additions & 2 deletions mlir/include/mlir/Dialect/Transform/Utils/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
#include "mlir/IR/ValueRange.h"
#include "mlir/Support/LLVM.h"

#include "llvm/ADT/SmallVector.h"

namespace mlir {
class OpAsmPrinter;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ void transform::OneShotBufferizeOp::getEffects(

DiagnosedSilenceableFailure
EmptyTensorToAllocTensorOp::applyToOne(tensor::EmptyOp target,
SmallVector<Operation *> &results,
ApplyToEachResultList &results,
transform::TransformState &state) {
IRRewriter rewriter(target->getContext());
rewriter.setInsertionPoint(target);
Expand Down
20 changes: 9 additions & 11 deletions mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -291,14 +291,14 @@ static void generateGpuBlockIds(RewriterBase &rewriter,

DiagnosedSilenceableFailure
transform::MapForeachToBlocks::applyToOne(Operation *target,
SmallVectorImpl<Operation *> &results,
ApplyToEachResultList &results,
transform::TransformState &state) {
LaunchOp gpuLaunch = dyn_cast<LaunchOp>(target);
TrivialPatternRewriter rewriter(getContext());
auto transformOp = cast<TransformOpInterface>(getOperation());

if (!getGenerateGpuLaunch() && !gpuLaunch) {
results.assign({target});
results.push_back(target);
DiagnosedSilenceableFailure diag =
emitSilenceableError()
<< "Given target is not gpu.launch, set `generate_gpu_launch` "
Expand All @@ -312,7 +312,7 @@ transform::MapForeachToBlocks::applyToOne(Operation *target,
mlir::transform::gpu::findTopLevelForeachThreadOp(
target, topLevelForeachThreadOp, transformOp);
if (!diag.succeeded()) {
results.assign({target});
results.push_back(target);
diag.attachNote(target->getLoc()) << "when applied to this payload op";
return diag;
}
Expand All @@ -325,7 +325,7 @@ transform::MapForeachToBlocks::applyToOne(Operation *target,
DiagnosedSilenceableFailure diag =
createGpuLaunch(rewriter, target->getLoc(), transformOp, gpuLaunch);
if (!diag.succeeded()) {
results.assign({target});
results.push_back(target);
return diag;
}
rewriter.setInsertionPointToStart(&gpuLaunch.getBody().front());
Expand All @@ -352,7 +352,7 @@ transform::MapForeachToBlocks::applyToOne(Operation *target,
gridDim[0], gridDim[1], gridDim[2]);
}

results.assign({gpuLaunch});
results.push_back(gpuLaunch);
return diag;
}

Expand Down Expand Up @@ -520,14 +520,12 @@ DiagnosedSilenceableFailure mlir::transform::gpu::mapNestedForeachToThreadsImpl(
}

DiagnosedSilenceableFailure transform::MapNestedForeachToThreads::applyToOne(
::mlir::Operation *target,
::llvm::SmallVectorImpl<::mlir::Operation *> &results,
::mlir::transform::TransformState &state) {
Operation *target, ApplyToEachResultList &results, TransformState &state) {
LaunchOp gpuLaunch = dyn_cast<LaunchOp>(target);
auto transformOp = cast<TransformOpInterface>(getOperation());

if (!gpuLaunch) {
results.assign({target});
results.push_back(target);
return emitSilenceableError() << "Given target is not gpu.launch";
}

Expand All @@ -538,7 +536,7 @@ DiagnosedSilenceableFailure transform::MapNestedForeachToThreads::applyToOne(
checkGpuLimits(transformOp, std::nullopt, std::nullopt, std::nullopt,
blockDim[0], blockDim[1], blockDim[2]);
if (diag.isSilenceableFailure()) {
results.assign({target});
results.push_back(target);
diag.attachNote(getLoc()) << getBlockDimAttrName() << " is very large";
return diag;
}
Expand All @@ -562,7 +560,7 @@ DiagnosedSilenceableFailure transform::MapNestedForeachToThreads::applyToOne(
blockDim[2]);
}

results.assign({gpuLaunch});
results.push_back(gpuLaunch.getOperation());
return diag;
}

Expand Down
34 changes: 19 additions & 15 deletions mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/Debug.h"

Expand Down Expand Up @@ -65,7 +66,7 @@ static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) {

DiagnosedSilenceableFailure
transform::DecomposeOp::applyToOne(linalg::LinalgOp target,
SmallVectorImpl<Operation *> &results,
transform::ApplyToEachResultList &results,
transform::TransformState &state) {
#define DOWNSCALE(trans) \
{ \
Expand Down Expand Up @@ -576,7 +577,7 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results,

DiagnosedSilenceableFailure
transform::GeneralizeOp::applyToOne(linalg::LinalgOp target,
SmallVectorImpl<Operation *> &results,
transform::ApplyToEachResultList &results,
transform::TransformState &state) {
// Exit early if no transformation is needed.
if (isa<GenericOp>(target)) {
Expand All @@ -598,7 +599,7 @@ transform::GeneralizeOp::applyToOne(linalg::LinalgOp target,

DiagnosedSilenceableFailure
transform::InterchangeOp::applyToOne(linalg::GenericOp target,
SmallVectorImpl<Operation *> &results,
transform::ApplyToEachResultList &results,
transform::TransformState &state) {
ArrayRef<int64_t> interchangeVector = getIteratorInterchange();
// Exit early if no transformation is needed.
Expand Down Expand Up @@ -707,7 +708,8 @@ transform::MatchOp::apply(transform::TransformResults &results,
//===---------------------------------------------------------------------===//

DiagnosedSilenceableFailure transform::MultiTileSizesOp::applyToOne(
LinalgOp target, SmallVector<Operation *> &results, TransformState &state) {
LinalgOp target, transform::ApplyToEachResultList &results,
TransformState &state) {
OpBuilder builder(target.getContext());
builder.setInsertionPoint(target);
OpFoldResult targetSize = builder.getIndexAttr(getTargetSize());
Expand Down Expand Up @@ -747,7 +749,7 @@ void transform::MultiTileSizesOp::getEffects(

DiagnosedSilenceableFailure
transform::PadOp::applyToOne(linalg::LinalgOp target,
SmallVectorImpl<Operation *> &results,
transform::ApplyToEachResultList &results,
transform::TransformState &state) {
// Convert the integer packing flags to booleans.
SmallVector<bool> packPaddings;
Expand Down Expand Up @@ -860,7 +862,7 @@ LogicalResult transform::PadOp::verify() {

DiagnosedSilenceableFailure
transform::PromoteOp::applyToOne(linalg::LinalgOp target,
SmallVectorImpl<Operation *> &results,
transform::ApplyToEachResultList &results,
transform::TransformState &state) {
LinalgPromotionOptions promotionOptions;
if (!getOperandsToPromote().empty())
Expand Down Expand Up @@ -954,7 +956,7 @@ LogicalResult transform::ReplaceOp::verify() {

DiagnosedSilenceableFailure
transform::ScalarizeOp::applyToOne(linalg::LinalgOp target,
SmallVectorImpl<Operation *> &results,
transform::ApplyToEachResultList &results,
transform::TransformState &state) {
scf::SCFTilingOptions tilingOptions;
tilingOptions.setTileSizeComputationFunction([&](OpBuilder &b, Operation *) {
Expand Down Expand Up @@ -990,7 +992,10 @@ transform::ScalarizeOp::applyToOne(linalg::LinalgOp target,
rewriter.replaceOp(target, maybeTilingResult->replacements);
else
rewriter.eraseOp(target);
results.append(maybeTilingResult->tiledOps);

results.reserve(maybeTilingResult->tiledOps.size());
for (Operation *tiled : maybeTilingResult->tiledOps)
results.push_back(tiled);
return DiagnosedSilenceableFailure::success();
}

Expand Down Expand Up @@ -1171,10 +1176,9 @@ void transform::SplitReductionOp::build(
result.addTypes({resultType, resultType, resultType, resultType});
}

DiagnosedSilenceableFailure
transform::SplitReductionOp::applyToOne(linalg::LinalgOp target,
SmallVectorImpl<Operation *> &results,
transform::TransformState &state) {
DiagnosedSilenceableFailure transform::SplitReductionOp::applyToOne(
linalg::LinalgOp target, transform::ApplyToEachResultList &results,
transform::TransformState &state) {
ControlSplitReductionFn splitFn = [&](LinalgOp) {
return linalg::SplitReductionOptions{int64_t(getSplitFactor()),
unsigned(getInsertSplitDimension()),
Expand Down Expand Up @@ -1218,7 +1222,7 @@ void transform::TileReductionUsingScfOp::build(
}

DiagnosedSilenceableFailure transform::TileReductionUsingScfOp::applyToOne(
linalg::LinalgOp target, SmallVectorImpl<Operation *> &results,
linalg::LinalgOp target, transform::ApplyToEachResultList &results,
transform::TransformState &state) {
TrivialPatternRewriter rewriter(getContext());
rewriter.setInsertionPoint(target);
Expand Down Expand Up @@ -1262,7 +1266,7 @@ void transform::TileReductionUsingForeachThreadOp::build(

DiagnosedSilenceableFailure
transform::TileReductionUsingForeachThreadOp::applyToOne(
linalg::LinalgOp target, SmallVectorImpl<Operation *> &results,
linalg::LinalgOp target, transform::ApplyToEachResultList &results,
transform::TransformState &state) {
TrivialPatternRewriter rewriter(getContext());
rewriter.setInsertionPoint(target);
Expand Down Expand Up @@ -1951,7 +1955,7 @@ struct VectorizationPattern : public RewritePattern {

DiagnosedSilenceableFailure
transform::VectorizeOp::applyToOne(Operation *target,
SmallVectorImpl<Operation *> &results,
transform::ApplyToEachResultList &results,
transform::TransformState &state) {
if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
auto diag = this->emitOpError("requires isolated-from-above targets");
Expand Down
7 changes: 3 additions & 4 deletions mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,9 @@ using namespace mlir;
// MemRefMultiBufferOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure
transform::MemRefMultiBufferOp::applyToOne(memref::AllocOp target,
SmallVector<Operation *> &results,
transform::TransformState &state) {
DiagnosedSilenceableFailure transform::MemRefMultiBufferOp::applyToOne(
memref::AllocOp target, transform::ApplyToEachResultList &results,
transform::TransformState &state) {
auto newBuffer = memref::multiBuffer(target, getFactor());
if (failed(newBuffer)) {
Diagnostic diag(target->getLoc(), DiagnosticSeverity::Note);
Expand Down
6 changes: 3 additions & 3 deletions mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ transform::LoopOutlineOp::apply(transform::TransformResults &results,

DiagnosedSilenceableFailure
transform::LoopPeelOp::applyToOne(scf::ForOp target,
SmallVector<Operation *> &results,
transform::ApplyToEachResultList &results,
transform::TransformState &state) {
scf::ForOp result;
IRRewriter rewriter(target->getContext());
Expand Down Expand Up @@ -182,7 +182,7 @@ loopScheduling(scf::ForOp forOp,

DiagnosedSilenceableFailure
transform::LoopPipelineOp::applyToOne(scf::ForOp target,
SmallVector<Operation *> &results,
transform::ApplyToEachResultList &results,
transform::TransformState &state) {
scf::PipeliningOption options;
options.getScheduleFn =
Expand Down Expand Up @@ -210,7 +210,7 @@ transform::LoopPipelineOp::applyToOne(scf::ForOp target,

DiagnosedSilenceableFailure
transform::LoopUnrollOp::applyToOne(Operation *op,
SmallVector<Operation *> &results,
transform::ApplyToEachResultList &results,
transform::TransformState &state) {
LogicalResult result(failure());
if (scf::ForOp scfFor = dyn_cast<scf::ForOp>(op))
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Transform/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,5 @@ add_mlir_dialect_library(MLIRTransformDialect
MLIRPDLInterpDialect
MLIRRewrite
MLIRSideEffectInterfaces
MLIRTransformDialectUtils
)
18 changes: 12 additions & 6 deletions mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,17 +35,23 @@ void transform::detail::checkImplementsTransformOpInterface(
"MemoryEffectsOpInterface");
}

void transform::detail::checkImplementsTransformTypeInterface(
void transform::detail::checkImplementsTransformHandleTypeInterface(
TypeID typeID, MLIRContext *context) {
const auto &abstractType = AbstractType::lookup(typeID, context);
assert(abstractType.hasInterface(TransformTypeInterface::getInterfaceID()));
assert(
(abstractType.hasInterface(
TransformHandleTypeInterface::getInterfaceID()) ||
abstractType.hasInterface(
TransformParamTypeInterface::getInterfaceID())) &&
"expected Transform dialect type to implement one of the two interfaces");
}
#endif // NDEBUG

namespace {
struct PDLOperationTypeTransformTypeInterfaceImpl
: public transform::TransformTypeInterface::ExternalModel<
PDLOperationTypeTransformTypeInterfaceImpl, pdl::OperationType> {
struct PDLOperationTypeTransformHandleTypeInterfaceImpl
: public transform::TransformHandleTypeInterface::ExternalModel<
PDLOperationTypeTransformHandleTypeInterfaceImpl,
pdl::OperationType> {
DiagnosedSilenceableFailure
checkPayload(Type type, Location loc, ArrayRef<Operation *> payload) const {
return DiagnosedSilenceableFailure::success();
Expand All @@ -63,7 +69,7 @@ void transform::TransformDialect::initialize() {
initializeTypes();

pdl::OperationType::attachInterface<
PDLOperationTypeTransformTypeInterfaceImpl>(*getContext());
PDLOperationTypeTransformHandleTypeInterfaceImpl>(*getContext());
}

void transform::TransformDialect::mergeInPDLMatchHooks(
Expand Down
221 changes: 210 additions & 11 deletions mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Operation.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"

Expand Down Expand Up @@ -44,7 +45,16 @@ ArrayRef<Operation *>
transform::TransformState::getPayloadOps(Value value) const {
const TransformOpMapping &operationMapping = getMapping(value).direct;
auto iter = operationMapping.find(value);
assert(iter != operationMapping.end() && "unknown handle");
assert(iter != operationMapping.end() &&
"cannot find mapping for payload handle (param handle provided?)");
return iter->getSecond();
}

ArrayRef<Attribute> transform::TransformState::getParams(Value value) const {
const ParamMapping &mapping = getMapping(value).params;
auto iter = mapping.find(value);
assert(iter != mapping.end() &&
"cannot find mapping for param handle (payload handle provided?)");
return iter->getSecond();
}

Expand All @@ -67,8 +77,10 @@ transform::TransformState::setPayloadOps(Value value,
ArrayRef<Operation *> targets) {
assert(value != kTopLevelValue &&
"attempting to reset the transformation root");
assert(!value.getType().isa<TransformParamTypeInterface>() &&
"cannot associate payload ops with a value of parameter type");

auto iface = value.getType().cast<TransformTypeInterface>();
auto iface = value.getType().cast<TransformHandleTypeInterface>();
DiagnosedSilenceableFailure result =
iface.checkPayload(value.getLoc(), targets);
if (failed(result.checkAndReport()))
Expand All @@ -89,6 +101,26 @@ transform::TransformState::setPayloadOps(Value value,
return success();
}

LogicalResult transform::TransformState::setParams(Value value,
ArrayRef<Param> params) {
assert(value != nullptr && "attempting to set params for a null value");

auto valueType = value.getType().dyn_cast<TransformParamTypeInterface>();
assert(value &&
"cannot associate parameter with a value of non-parameter type");
DiagnosedSilenceableFailure result =
valueType.checkPayload(value.getLoc(), params);
if (failed(result.checkAndReport()))
return failure();

Mappings &mappings = getMapping(value);
bool inserted =
mappings.params.insert({value, llvm::to_vector(params)}).second;
assert(inserted && "value is already associated with another list of params");
(void)inserted;
return success();
}

void transform::TransformState::dropReverseMapping(Mappings &mappings,
Operation *op, Value value) {
auto it = mappings.reverse.find(op);
Expand All @@ -112,8 +144,8 @@ LogicalResult transform::TransformState::updatePayloadOps(
Mappings &mappings = getMapping(value);
auto it = mappings.direct.find(value);
assert(it != mappings.direct.end() && "unknown handle");
SmallVector<Operation *> &association = it->getSecond();
SmallVector<Operation *> updated;
SmallVector<Operation *, 2> &association = it->getSecond();
SmallVector<Operation *, 2> updated;
updated.reserve(association.size());

for (Operation *op : association) {
Expand All @@ -124,7 +156,7 @@ LogicalResult transform::TransformState::updatePayloadOps(
}
}

auto iface = value.getType().cast<TransformTypeInterface>();
auto iface = value.getType().cast<TransformHandleTypeInterface>();
DiagnosedSilenceableFailure result =
iface.checkPayload(value.getLoc(), updated);
if (failed(result.checkAndReport()))
Expand Down Expand Up @@ -269,8 +301,21 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
assert(result.getDefiningOp() == transform.getOperation() &&
"payload IR association for a value other than the result of the "
"current transform op");
if (failed(setPayloadOps(result, results.get(result.getResultNumber()))))
return DiagnosedSilenceableFailure::definiteFailure();
if (result.getType().isa<TransformParamTypeInterface>()) {
assert(results.isParam(result.getResultNumber()) &&
"expected parameters for the parameter-typed result");
if (failed(
setParams(result, results.getParams(result.getResultNumber())))) {
return DiagnosedSilenceableFailure::definiteFailure();
}
} else {
assert(!results.isParam(result.getResultNumber()) &&
"expected payload ops for the non-parameter typed result");
if (failed(
setPayloadOps(result, results.get(result.getResultNumber())))) {
return DiagnosedSilenceableFailure::definiteFailure();
}
}
}

printOnFailureRAII.release();
Expand Down Expand Up @@ -312,6 +357,8 @@ transform::TransformState::Extension::replacePayloadOp(Operation *op,
transform::TransformResults::TransformResults(unsigned numSegments) {
segments.resize(numSegments,
ArrayRef<Operation *>(nullptr, static_cast<size_t>(0)));
paramSegments.resize(numSegments, ArrayRef<TransformState::Param>(
nullptr, static_cast<size_t>(0)));
}

void transform::TransformResults::set(OpResult value,
Expand All @@ -325,14 +372,128 @@ void transform::TransformResults::set(OpResult value,
segments[position] = makeArrayRef(operations).drop_front(start);
}

void transform::TransformResults::setParams(
OpResult value, ArrayRef<transform::TransformState::Param> params) {
int64_t position = value.getResultNumber();
assert(position < static_cast<int64_t>(paramSegments.size()) &&
"setting params for a non-existent handle");
assert(paramSegments[position].data() == nullptr && "params already set");
size_t start = this->params.size();
llvm::append_range(this->params, params);
paramSegments[position] = makeArrayRef(this->params).drop_front(start);
}

ArrayRef<Operation *>
transform::TransformResults::get(unsigned resultNumber) const {
assert(resultNumber < segments.size() &&
"querying results for a non-existent handle");
assert(segments[resultNumber].data() != nullptr && "querying unset results");
assert(segments[resultNumber].data() != nullptr &&
"querying unset results (param expected?)");
return segments[resultNumber];
}

ArrayRef<transform::TransformState::Param>
transform::TransformResults::getParams(unsigned resultNumber) const {
assert(resultNumber < paramSegments.size() &&
"querying params for a non-existent handle");
assert(paramSegments[resultNumber].data() != nullptr &&
"querying unset params (payload ops expected?)");
return paramSegments[resultNumber];
}

bool transform::TransformResults::isParam(unsigned resultNumber) const {
assert(resultNumber < paramSegments.size() &&
"querying association for a non-existent handle");
return paramSegments[resultNumber].data() != nullptr;
}

//===----------------------------------------------------------------------===//
// Utilities for TransformEachOpTrait.
//===----------------------------------------------------------------------===//

LogicalResult
transform::detail::checkApplyToOne(Operation *transformOp,
Location payloadOpLoc,
const ApplyToEachResultList &partialResult) {
Location transformOpLoc = transformOp->getLoc();
StringRef transformOpName = transformOp->getName().getStringRef();
unsigned expectedNumResults = transformOp->getNumResults();
// TODO: encode this implicit must always produce `expectedNumResults`
// and nullptr is fine with a proper trait.
if (partialResult.size() != expectedNumResults) {
auto diag = mlir::emitError(transformOpLoc, "applications of ")
<< transformOpName << " expected to produce "
<< expectedNumResults << " results (actually produced "
<< partialResult.size() << ").";
diag.attachNote(transformOpLoc)
<< "If you need variadic results, consider a generic `apply` "
<< "instead of the specialized `applyToOne`.";
diag.attachNote(transformOpLoc)
<< "Producing " << expectedNumResults << " null results is "
<< "allowed if the use case warrants it.";
diag.attachNote(payloadOpLoc) << "when applied to this op";
return failure();
}

// Check that all is null or none is null
// TODO: relax this behavior and encode with a proper trait.
if (llvm::any_of(
partialResult,
[](llvm::PointerUnion<Operation *, Attribute> ptr) { return ptr; }) &&
llvm::any_of(partialResult,
[](llvm::PointerUnion<Operation *, Attribute> ptr) {
return !ptr;
})) {
auto diag = mlir::emitError(transformOpLoc, "unexpected application of ")
<< transformOpName
<< " produces both null and non null results.";
diag.attachNote(payloadOpLoc) << "when applied to this op";
return failure();
}

// Check that the right kind of value was produced.
for (const auto &[ptr, res] :
llvm::zip(partialResult, transformOp->getResults())) {
if (ptr.is<Operation *>() &&
!res.getType().template isa<TransformHandleTypeInterface>()) {
mlir::emitError(transformOpLoc)
<< "applications of " << transformOpName
<< " expected to produce an Attribute for result #"
<< res.getResultNumber();
return failure();
}
if (ptr.is<Attribute>() &&
!res.getType().template isa<TransformParamTypeInterface>()) {
mlir::emitError(transformOpLoc)
<< "applications of " << transformOpName
<< " expected to produce an Operation * for result #"
<< res.getResultNumber();
return failure();
}
}
return success();
}

void transform::detail::setApplyToOneResults(
Operation *transformOp, TransformResults &transformResults,
ArrayRef<ApplyToEachResultList> results) {
for (OpResult r : transformOp->getResults()) {
if (r.getType().isa<TransformParamTypeInterface>()) {
auto params = llvm::to_vector(
llvm::map_range(results, [r](const ApplyToEachResultList &oneResult) {
return oneResult[r.getResultNumber()].get<Attribute>();
}));
transformResults.setParams(r, params);
} else {
auto payloads = llvm::to_vector(
llvm::map_range(results, [r](const ApplyToEachResultList &oneResult) {
return oneResult[r.getResultNumber()].get<Operation *>();
}));
transformResults.set(r, payloads);
}
}
}

//===----------------------------------------------------------------------===//
// Utilities for PossibleTopLevelTransformOpTrait.
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -366,9 +527,10 @@ transform::detail::verifyPossibleTopLevelTransformOpTrait(Operation *op) {

Block *body = &bodyRegion->front();
if (body->getNumArguments() != 1 ||
!body->getArgumentTypes()[0].isa<TransformTypeInterface>()) {
return op->emitOpError() << "expects the entry block to have one argument "
"of type implementing TransformTypeInterface";
!body->getArgumentTypes()[0].isa<TransformHandleTypeInterface>()) {
return op->emitOpError()
<< "expects the entry block to have one argument "
"of type implementing TransformHandleTypeInterface";
}

if (auto *parent =
Expand All @@ -386,6 +548,43 @@ transform::detail::verifyPossibleTopLevelTransformOpTrait(Operation *op) {
return success();
}

//===----------------------------------------------------------------------===//
// Utilities for ParamProducedTransformOpTrait.
//===----------------------------------------------------------------------===//

void transform::detail::getParamProducerTransformOpTraitEffects(
Operation *op, SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
producesHandle(op->getResults(), effects);
bool hasPayloadOperands = false;
for (Value operand : op->getOperands()) {
onlyReadsHandle(operand, effects);
if (operand.getType().isa<TransformHandleTypeInterface>())
hasPayloadOperands = true;
}
if (hasPayloadOperands)
onlyReadsPayload(effects);
}

LogicalResult
transform::detail::verifyParamProducerTransformOpTrait(Operation *op) {
// Interfaces can be attached dynamically, so this cannot be a static
// assert.
if (!op->getName().getInterface<MemoryEffectOpInterface>()) {
llvm::report_fatal_error(
Twine("ParamProducerTransformOpTrait must be attached to an op that "
"implements MemoryEffectsOpInterface, found on ") +
op->getName().getStringRef());
}
for (Value result : op->getResults()) {
if (result.getType().isa<TransformParamTypeInterface>())
continue;
return op->emitOpError()
<< "ParamProducerTransformOpTrait attached to this op expects "
"result types to implement TransformParamTypeInterface";
}
return success();
}

//===----------------------------------------------------------------------===//
// Memory effects.
//===----------------------------------------------------------------------===//
Expand Down
12 changes: 6 additions & 6 deletions mlir/lib/Dialect/Transform/IR/TransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -261,8 +261,7 @@ LogicalResult transform::AlternativesOp::verify() {
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure
transform::CastOp::applyToOne(Operation *target,
SmallVectorImpl<Operation *> &results,
transform::CastOp::applyToOne(Operation *target, ApplyToEachResultList &results,
transform::TransformState &state) {
results.push_back(target);
return DiagnosedSilenceableFailure::success();
Expand All @@ -281,7 +280,8 @@ bool transform::CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
return llvm::all_of(
std::initializer_list<Type>{inputs.front(), outputs.front()},
[](Type ty) {
return ty.isa<pdl::OperationType, transform::TransformTypeInterface>();
return ty
.isa<pdl::OperationType, transform::TransformHandleTypeInterface>();
});
}

Expand Down Expand Up @@ -370,9 +370,9 @@ LogicalResult transform::ForeachOp::verify() {
return emitOpError() << "expects the same number of results as the "
"terminator has operands";
for (Value v : yieldOp.getOperands())
if (!v.getType().isa<TransformTypeInterface>())
return yieldOp->emitOpError(
"expects operands to have types implementing TransformTypeInterface");
if (!v.getType().isa<TransformHandleTypeInterface>())
return yieldOp->emitOpError("expects operands to have types implementing "
"TransformHandleTypeInterface");
return success();
}

Expand Down
41 changes: 41 additions & 0 deletions mlir/lib/Dialect/Transform/IR/TransformTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Types.h"
Expand Down Expand Up @@ -37,12 +38,20 @@ void transform::TransformDialect::initializeTypes() {
>();
}

//===----------------------------------------------------------------------===//
// transform::AnyOpType
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure
transform::AnyOpType::checkPayload(Location loc,
ArrayRef<Operation *> payload) const {
return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// transform::OperationType
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure
transform::OperationType::checkPayload(Location loc,
ArrayRef<Operation *> payload) const {
Expand All @@ -58,3 +67,35 @@ transform::OperationType::checkPayload(Location loc,

return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// transform::ParamType
//===----------------------------------------------------------------------===//

LogicalResult
transform::ParamType::verify(function_ref<InFlightDiagnostic()> emitError,
Type type) {
IntegerType intType = type.dyn_cast<IntegerType>();
if (!intType || intType.getWidth() > 64)
return emitError() << "only supports integer types with width <=64";
return success();
}

DiagnosedSilenceableFailure
transform::ParamType::checkPayload(Location loc,
ArrayRef<Attribute> payload) const {
for (Attribute attr : payload) {
auto integerAttr = attr.dyn_cast<IntegerAttr>();
if (!integerAttr) {
return emitSilenceableError(loc)
<< "expected parameter to be an integer attribute, got " << attr;
}
if (integerAttr.getType() != getType()) {
return emitSilenceableError(loc)
<< "expected the type of the parameter attribute ("
<< integerAttr.getType() << ") to match the parameter type ("
<< getType() << ")";
}
}
return DiagnosedSilenceableFailure::success();
}
5 changes: 4 additions & 1 deletion mlir/lib/Dialect/Transform/Utils/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
add_mlir_dialect_library(MLIRTransformDialectUtils
DiagnosedSilenceableFailure.cpp
Utils.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Transform

LINK_LIBS PUBLIC
MLIRDialectUtils
MLIRIR
MLIRSupport
MLIRTransformDialect
MLIRViewLikeInterface
)
33 changes: 33 additions & 0 deletions mlir/lib/Dialect/Transform/Utils/DiagnosedSilenceableFailure.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
//===- DiagnosedSilenceableFailure.cpp - Tri-state result -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the DiagnosedSilenceableFailure class allowing to store
// a tri-state result (definite failure, recoverable failure, success) with an
// optional associated list of diagnostics.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Transform/Utils/DiagnosedSilenceableFailure.h"

using namespace mlir;

LogicalResult mlir::DiagnosedSilenceableFailure::checkAndReport() {
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
assert(!reported && "attempting to report a diagnostic more than once");
reported = true;
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
if (!diagnostics.empty()) {
for (auto &&diagnostic : diagnostics) {
diagnostic.getLocation().getContext()->getDiagEngine().emit(
std::move(diagnostic));
}
diagnostics.clear();
result = ::mlir::failure();
}
return result;
}
13 changes: 4 additions & 9 deletions mlir/lib/Dialect/Transform/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,15 @@

#include "mlir/Dialect/Transform/Utils/Utils.h"

#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/ViewLikeInterface.h"

using namespace mlir;
using namespace mlir::transform;

void transform::printPackedOrDynamicIndexList(OpAsmPrinter &printer,
Operation *op, Value packed,
OperandRange values,
ArrayRef<int64_t> integers) {
void mlir::transform::printPackedOrDynamicIndexList(
OpAsmPrinter &printer, Operation *op, Value packed, OperandRange values,
ArrayRef<int64_t> integers) {
if (packed) {
assert(values.empty() && integers.empty() && "expected no values/integers");
printer << packed;
Expand All @@ -30,7 +25,7 @@ void transform::printPackedOrDynamicIndexList(OpAsmPrinter &printer,
printDynamicIndexList(printer, op, values, integers);
}

ParseResult transform::parsePackedOrDynamicIndexList(
ParseResult mlir::transform::parsePackedOrDynamicIndexList(
OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &packed,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
DenseI64ArrayAttr &integers) {
Expand Down
4 changes: 2 additions & 2 deletions mlir/test/Dialect/Transform/ops-invalid.mlir
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// RUN: mlir-opt %s -split-input-file -verify-diagnostics

// expected-error @below {{expects the entry block to have one argument of type implementing TransformTypeInterface}}
// expected-error @below {{expects the entry block to have one argument of type implementing TransformHandleTypeInterface}}
transform.sequence failures(propagate) {
}

Expand Down Expand Up @@ -190,7 +190,7 @@ transform.sequence failures(propagate) {

// -----

// expected-error @below {{expects the entry block to have one argument of type implementing TransformTypeInterface}}
// expected-error @below {{expects the entry block to have one argument of type implementing TransformHandleTypeInterface}}
transform.alternatives {
^bb0:
transform.yield
Expand Down
87 changes: 87 additions & 0 deletions mlir/test/Dialect/Transform/test-interpreter.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -920,6 +920,7 @@ transform.with_pdl_patterns {
}

"test.some_op"() : () -> ()

// -----

func.func @split_handles(%a: index, %b: index, %c: index) {
Expand All @@ -937,3 +938,89 @@ transform.sequence -> !pdl.operation failures(propagate) {
/// propagate mode.
yield %fun : !pdl.operation
}

// -----

transform.sequence failures(propagate) {
^bb0(%arg0: !transform.any_op):
%0 = transform.test_produce_integer_param_with_type i32 : !transform.test_dialect_param
// expected-remark @below {{0 : i32}}
transform.test_print_param %0 : !transform.test_dialect_param
}

// -----

transform.sequence failures(propagate) {
^bb0(%arg0: !transform.any_op):
// expected-error @below {{expected the type of the parameter attribute ('i32') to match the parameter type ('i64')}}
transform.test_produce_integer_param_with_type i32 : !transform.param<i64>
}

// -----


transform.sequence failures(propagate) {
^bb0(%arg0: !transform.any_op):
%0 = transform.test_add_to_param 40
%1 = transform.test_add_to_param %0, 2
// expected-remark @below {{42 : i32}}
transform.test_print_param %1 : !transform.test_dialect_param
}

// -----

transform.sequence failures(propagate) {
^bb0(%arg0: !pdl.operation):
%0 = transform.structured.match ops{["func.func"]} in %arg0
%1 = transform.test_produce_param_with_number_of_test_ops %0 : !pdl.operation
// expected-remark @below {{1 : i32, 3 : i32}}
transform.test_print_param %1 : !transform.test_dialect_param
%2 = transform.test_add_to_param %1, 100
// expected-remark @below {{101 : i32, 103 : i32}}
transform.test_print_param %2 : !transform.test_dialect_param
}

func.func private @one_test_op(%arg0: i32) {
"test.op_a"(%arg0) { attr = 0 : i32} : (i32) -> i32
return
}

func.func private @three_test_ops(%arg0: i32) {
"test.op_a"(%arg0) { attr = 0 : i32} : (i32) -> i32
"test.op_a"(%arg0) { attr = 0 : i32} : (i32) -> i32
"test.op_a"(%arg0) { attr = 0 : i32} : (i32) -> i32
return
}

// -----

transform.sequence failures(propagate) {
^bb0(%arg0: !transform.any_op):
// expected-error @below {{expected to produce an Operation * for result #0}}
transform.test_produce_transform_param_or_forward_operand %arg0
{ first_result_is_param }
: (!transform.any_op) -> (!transform.any_op, !transform.param<i64>)
}

// -----

// expected-note @below {{when applied to this op}}
module {
transform.sequence failures(propagate) {
^bb0(%arg0: !transform.any_op):
// expected-error @below {{produces both null and non null results}}
transform.test_produce_transform_param_or_forward_operand %arg0
{ first_result_is_null }
: (!transform.any_op) -> (!transform.any_op, !transform.param<i64>)
}
}

// -----

transform.sequence failures(propagate) {
^bb0(%arg0: !transform.any_op):
// expected-error @below {{expected to produce an Attribute for result #1}}
transform.test_produce_transform_param_or_forward_operand %arg0
{ second_result_is_handle }
: (!transform.any_op) -> (!transform.any_op, !transform.param<i64>)
}
134 changes: 123 additions & 11 deletions mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/IR/OpImplementation.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;

Expand Down Expand Up @@ -241,7 +243,7 @@ DiagnosedSilenceableFailure mlir::test::TestEmitRemarkAndEraseOperandOp::apply(
}

DiagnosedSilenceableFailure mlir::test::TestWrongNumberOfResultsOp::applyToOne(
Operation *target, SmallVectorImpl<Operation *> &results,
Operation *target, transform::ApplyToEachResultList &results,
transform::TransformState &state) {
OperationState opState(target->getLoc(), "foo");
results.push_back(OpBuilder(target).create(opState));
Expand All @@ -250,7 +252,7 @@ DiagnosedSilenceableFailure mlir::test::TestWrongNumberOfResultsOp::applyToOne(

DiagnosedSilenceableFailure
mlir::test::TestWrongNumberOfMultiResultsOp::applyToOne(
Operation *target, SmallVectorImpl<Operation *> &results,
Operation *target, transform::ApplyToEachResultList &results,
transform::TransformState &state) {
static int count = 0;
if (count++ == 0) {
Expand All @@ -262,7 +264,7 @@ mlir::test::TestWrongNumberOfMultiResultsOp::applyToOne(

DiagnosedSilenceableFailure
mlir::test::TestCorrectNumberOfMultiResultsOp::applyToOne(
Operation *target, SmallVectorImpl<Operation *> &results,
Operation *target, transform::ApplyToEachResultList &results,
transform::TransformState &state) {
OperationState opState(target->getLoc(), "foo");
results.push_back(OpBuilder(target).create(opState));
Expand All @@ -272,7 +274,7 @@ mlir::test::TestCorrectNumberOfMultiResultsOp::applyToOne(

DiagnosedSilenceableFailure
mlir::test::TestMixedNullAndNonNullResultsOp::applyToOne(
Operation *target, SmallVectorImpl<Operation *> &results,
Operation *target, transform::ApplyToEachResultList &results,
transform::TransformState &state) {
OperationState opState(target->getLoc(), "foo");
results.push_back(nullptr);
Expand All @@ -282,7 +284,7 @@ mlir::test::TestMixedNullAndNonNullResultsOp::applyToOne(

DiagnosedSilenceableFailure
mlir::test::TestMixedSuccessAndSilenceableOp::applyToOne(
Operation *target, SmallVectorImpl<Operation *> &results,
Operation *target, transform::ApplyToEachResultList &results,
transform::TransformState &state) {
if (target->hasAttr("target_me"))
return DiagnosedSilenceableFailure::success();
Expand Down Expand Up @@ -317,15 +319,27 @@ DiagnosedSilenceableFailure mlir::transform::TestDialectOpType::checkPayload(

for (Operation *op : payload) {
if (op->getName().getDialectNamespace() != "test") {
Diagnostic diag(loc, DiagnosticSeverity::Error);
diag << "expected the payload operation to belong to the 'test' dialect";
return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
return emitSilenceableError(loc) << "expected the payload operation to "
"belong to the 'test' dialect";
}
}

return DiagnosedSilenceableFailure::success();
}

DiagnosedSilenceableFailure mlir::transform::TestDialectParamType::checkPayload(
Location loc, ArrayRef<Attribute> payload) const {
for (Attribute attr : payload) {
auto integerAttr = attr.dyn_cast<IntegerAttr>();
if (integerAttr && integerAttr.getType().isSignlessInteger(32))
continue;
return emitSilenceableError(loc)
<< "expected the parameter to be a i32 integer attribute";
}

return DiagnosedSilenceableFailure::success();
}

void mlir::test::TestReportNumberOfTrackedHandlesNestedUnder::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
transform::onlyReadsHandle(getTarget(), effects);
Expand All @@ -346,6 +360,104 @@ mlir::test::TestReportNumberOfTrackedHandlesNestedUnder::apply(
return DiagnosedSilenceableFailure::success();
}

void mlir::test::TestPrintParamOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
transform::onlyReadsHandle(getParam(), effects);
}

DiagnosedSilenceableFailure
mlir::test::TestPrintParamOp::apply(transform::TransformResults &results,
transform::TransformState &state) {
std::string str;
llvm::raw_string_ostream os(str);
llvm::interleaveComma(state.getParams(getParam()), os);
auto diag = emitRemark() << os.str();
return DiagnosedSilenceableFailure::success();
}

DiagnosedSilenceableFailure
mlir::test::TestAddToParamOp::apply(transform::TransformResults &results,
transform::TransformState &state) {
SmallVector<uint32_t> values(/*Size=*/1, /*Value=*/0);
if (Value param = getParam()) {
values = llvm::to_vector(
llvm::map_range(state.getParams(param), [](Attribute attr) -> uint32_t {
return attr.cast<IntegerAttr>().getValue().getLimitedValue(
UINT32_MAX);
}));
}

Builder builder(getContext());
SmallVector<Attribute> result = llvm::to_vector(
llvm::map_range(values, [this, &builder](uint32_t value) -> Attribute {
return builder.getI32IntegerAttr(value + getAddendum());
}));
results.setParams(getResult().cast<OpResult>(), result);
return DiagnosedSilenceableFailure::success();
}

DiagnosedSilenceableFailure
mlir::test::TestProduceParamWithNumberOfTestOps::apply(
transform::TransformResults &results, transform::TransformState &state) {
Builder builder(getContext());
SmallVector<Attribute> result = llvm::to_vector(
llvm::map_range(state.getPayloadOps(getHandle()),
[&builder](Operation *payload) -> Attribute {
int32_t count = 0;
payload->walk([&count](Operation *op) {
if (op->getName().getDialectNamespace() == "test")
++count;
});
return builder.getI32IntegerAttr(count);
}));
results.setParams(getResult().cast<OpResult>(), result);
return DiagnosedSilenceableFailure::success();
}

DiagnosedSilenceableFailure
mlir::test::TestProduceIntegerParamWithTypeOp::apply(
transform::TransformResults &results, transform::TransformState &state) {
Attribute zero = IntegerAttr::get(getType(), 0);
results.setParams(getResult().cast<OpResult>(), zero);
return DiagnosedSilenceableFailure::success();
}

LogicalResult mlir::test::TestProduceIntegerParamWithTypeOp::verify() {
if (!getType().isa<IntegerType>()) {
return emitOpError() << "expects an integer type";
}
return success();
}

void mlir::test::TestProduceTransformParamOrForwardOperandOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
transform::onlyReadsHandle(getIn(), effects);
transform::producesHandle(getOut(), effects);
transform::producesHandle(getParam(), effects);
}

DiagnosedSilenceableFailure
mlir::test::TestProduceTransformParamOrForwardOperandOp::applyToOne(
Operation *target, ::transform::ApplyToEachResultList &results,
::transform::TransformState &state) {
Builder builder(getContext());
if (getFirstResultIsParam()) {
results.push_back(builder.getI64IntegerAttr(0));
} else if (getFirstResultIsNull()) {
results.push_back(nullptr);
} else {
results.push_back(state.getPayloadOps(getIn()).front());
}

if (getSecondResultIsHandle()) {
results.push_back(state.getPayloadOps(getIn()).front());
} else {
results.push_back(builder.getI64IntegerAttr(42));
}

return DiagnosedSilenceableFailure::success();
}

namespace {
/// Test extension of the Transform dialect. Registers additional ops and
/// declares PDL as dependent dialect since the additional ops are using PDL
Expand All @@ -371,9 +483,6 @@ class TestTransformDialectExtension
};
} // namespace

#define GET_OP_CLASSES
#include "TestTransformDialectExtension.cpp.inc"

// These are automatically generated by ODS but are not used as the Transform
// dialect uses a different dispatch mechanism to support dialect extensions.
LLVM_ATTRIBUTE_UNUSED static OptionalParseResult
Expand All @@ -384,6 +493,9 @@ generatedTypePrinter(Type def, AsmPrinter &printer);
#define GET_TYPEDEF_CLASSES
#include "TestTransformDialectExtensionTypes.cpp.inc"

#define GET_OP_CLASSES
#include "TestTransformDialectExtension.cpp.inc"

void ::test::registerTestTransformDialectExtension(DialectRegistry &registry) {
registry.addExtensions<TestTransformDialectExtension>();
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@ namespace mlir {
class DialectRegistry;
} // namespace mlir

#define GET_OP_CLASSES
#include "TestTransformDialectExtension.h.inc"

#define GET_TYPEDEF_CLASSES
#include "TestTransformDialectExtensionTypes.h.inc"

#define GET_OP_CLASSES
#include "TestTransformDialectExtension.h.inc"

namespace test {
/// Registers the test extension to the Transform dialect.
void registerTestTransformDialectExtension(::mlir::DialectRegistry &registry);
Expand Down
86 changes: 79 additions & 7 deletions mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,22 @@ include "mlir/Dialect/PDL/IR/PDLTypes.td"

def TestTransformTestDialectHandleType
: TypeDef<Transform_Dialect, "TestDialectOp",
[DeclareTypeInterfaceMethods<TransformTypeInterface>]> {
[DeclareTypeInterfaceMethods<TransformHandleTypeInterface>]> {
let description = [{Handle pointing to an op from the Test dialect.}];
let mnemonic = "test_dialect_op";
let assemblyFormat = "";
}

def TestTransformTestDialectParamType
: TypeDef<Transform_Dialect, "TestDialectParam",
[DeclareTypeInterfaceMethods<TransformParamTypeInterface>]> {
let description = [{
Parameter associated with an i32 attribute for testing purposes.
}];
let mnemonic = "test_dialect_param";
let assemblyFormat = "";
}

def TestProduceParamOrForwardOperandOp
: Op<Transform_Dialect, "test_produce_param_or_forward_operand",
[DeclareOpInterfaceMethods<TransformOpInterface>]> {
Expand Down Expand Up @@ -69,7 +79,7 @@ def TestPrintRemarkAtOperandOp
: Op<Transform_Dialect, "test_print_remark_at_operand",
[DeclareOpInterfaceMethods<TransformOpInterface>]> {
let arguments = (ins
Arg<TransformTypeInterface, "",
Arg<TransformHandleTypeInterface, "",
[TransformMappingRead, PayloadIRRead]>:$operand,
StrAttr:$message);
let assemblyFormat =
Expand Down Expand Up @@ -163,7 +173,7 @@ def TestWrongNumberOfResultsOp
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
::mlir::Operation * target,
::llvm::SmallVectorImpl<::mlir::Operation *> &results,
::mlir::transform::ApplyToEachResultList &results,
::mlir::transform::TransformState &state);
}];
}
Expand All @@ -179,7 +189,7 @@ def TestWrongNumberOfMultiResultsOp
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
::mlir::Operation * target,
::llvm::SmallVectorImpl<::mlir::Operation *> &results,
::mlir::transform::ApplyToEachResultList &results,
::mlir::transform::TransformState &state);
}];
}
Expand All @@ -196,7 +206,7 @@ def TestCorrectNumberOfMultiResultsOp
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
::mlir::Operation * target,
::llvm::SmallVectorImpl<::mlir::Operation *> &results,
::mlir::transform::ApplyToEachResultList &results,
::mlir::transform::TransformState &state);
}];
}
Expand All @@ -213,7 +223,7 @@ def TestMixedNullAndNonNullResultsOp
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
::mlir::Operation * target,
::llvm::SmallVectorImpl<::mlir::Operation *> &results,
::mlir::transform::ApplyToEachResultList &results,
::mlir::transform::TransformState &state);
}];
}
Expand All @@ -229,7 +239,7 @@ def TestMixedSuccessAndSilenceableOp
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
::mlir::Operation * target,
::llvm::SmallVectorImpl<::mlir::Operation *> &results,
::mlir::transform::ApplyToEachResultList &results,
::mlir::transform::TransformState &state);
}];
}
Expand Down Expand Up @@ -262,4 +272,66 @@ def TestReportNumberOfTrackedHandlesNestedUnder
let cppNamespace = "::mlir::test";
}

def TestPrintParamOp
: Op<Transform_Dialect, "test_print_param",
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<TransformOpInterface>]> {
let arguments = (ins TransformParamTypeInterface:$param);
let assemblyFormat = "$param attr-dict `:` type($param)";
let cppNamespace = "::mlir::test";
}

def TestAddToParamOp
: Op<Transform_Dialect, "test_add_to_param",
[MemoryEffectsOpInterface, ParamProducerTransformOpTrait,
DeclareOpInterfaceMethods<TransformOpInterface>]> {
let arguments = (ins Optional<TestTransformTestDialectParamType>:$param,
I32Attr:$addendum);
let results = (outs TestTransformTestDialectParamType:$result);
let assemblyFormat = "($param^ `,`)? $addendum attr-dict";
let cppNamespace = "::mlir::test";
}

def TestProduceParamWithNumberOfTestOps
: Op<Transform_Dialect, "test_produce_param_with_number_of_test_ops",
[MemoryEffectsOpInterface, ParamProducerTransformOpTrait,
DeclareOpInterfaceMethods<TransformOpInterface>]> {
let arguments = (ins TransformHandleTypeInterface:$handle);
let results = (outs TestTransformTestDialectParamType:$result);
let assemblyFormat = "$handle attr-dict `:` type($handle)";
let cppNamespace = "::mlir::test";
}

def TestProduceIntegerParamWithTypeOp
: Op<Transform_Dialect, "test_produce_integer_param_with_type",
[MemoryEffectsOpInterface, ParamProducerTransformOpTrait,
DeclareOpInterfaceMethods<TransformOpInterface>]> {
let arguments = (ins TypeAttr:$type);
let results = (outs TransformParamTypeInterface:$result);
let assemblyFormat = "$type attr-dict `:` type($result)";
let cppNamespace = "::mlir::test";
let hasVerifier = 1;
}

def TestProduceTransformParamOrForwardOperandOp
: Op<Transform_Dialect, "test_produce_transform_param_or_forward_operand",
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
TransformEachOpTrait, TransformOpInterface]> {
let arguments = (ins TransformHandleTypeInterface:$in,
UnitAttr:$first_result_is_param,
UnitAttr:$first_result_is_null,
UnitAttr:$second_result_is_handle);
let results = (outs TransformHandleTypeInterface:$out,
TransformParamTypeInterface:$param);
let assemblyFormat = "$in attr-dict `:` functional-type(operands, results)";
let cppNamespace = "::mlir::test";

let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
::mlir::Operation *target,
::mlir::transform::ApplyToEachResultList &results,
::mlir::transform::TransformState &state);
}];
}

#endif // MLIR_TESTTRANSFORMDIALECTEXTENSION_TD
6 changes: 3 additions & 3 deletions utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -9138,6 +9138,7 @@ cc_library(
":TransformDialectEnumsIncGen",
":TransformDialectIncGen",
":TransformDialectInterfacesIncGen",
":TransformDialectUtils",
":TransformOpsIncGen",
":TransformTypesIncGen",
"//llvm:Support",
Expand Down Expand Up @@ -9187,14 +9188,13 @@ cc_library(

cc_library(
name = "TransformDialectUtils",
srcs = ["lib/Dialect/Transform/Utils/Utils.cpp"],
hdrs = ["include/mlir/Dialect/Transform/Utils/Utils.h"],
srcs = glob(["lib/Dialect/Transform/Utils/*cpp"]),
hdrs = glob(["include/mlir/Dialect/Transform/Utils/*.h"]),
includes = ["include"],
deps = [
":DialectUtils",
":IR",
":Support",
":TransformDialect",
":ViewLikeInterface",
"//llvm:Support",
],
Expand Down