267 changes: 200 additions & 67 deletions mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Large diffs are not rendered by default.

26 changes: 26 additions & 0 deletions mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,32 @@ static void emitIsPositiveIndexAssertion(ImplicitLocOpBuilder &b,
b.getStringAttr("expected strictly positive tile size and divisor"));
}

/// Computes a multi-tile-size specification for the loop of `op` at position
/// `dimension`, whose trip count must be static. The trip count is decomposed
/// into `lowTripCount` iterations of `lowTileSize` plus `highTripCount`
/// iterations of `highTileSize`, where both tile sizes are multiples of
/// `divisor` and close to `targetSize`. Returns failure when no such exact
/// decomposition exists (e.g. the trip count is not a multiple of `divisor`).
FailureOr<StaticMultiSizeSpecification>
mlir::linalg::computeStaticMultiTileSizes(LinalgOp op, unsigned dimension,
                                          int64_t targetSize, int64_t divisor) {
  assert(!op.hasDynamicShape() &&
         "cannot compute static multi-tile sizes for an op with dynamic shape");
  // NOTE: the checks are strict (> 0), so the messages must say "strictly
  // positive", not "non-negative".
  assert(targetSize > 0 && "target size must be strictly positive");
  assert(divisor > 0 && "divisor must be strictly positive");
  assert(dimension < op.getNumLoops() && "dimension overflow");

  StaticMultiSizeSpecification spec;
  int64_t tripCount = op.getStaticLoopRanges()[dimension];
  // Trip count expressed in divisor-sized chunks.
  int64_t a = tripCount / divisor;
  // Target tile size in chunks, rounded up.
  int64_t t = (targetSize + divisor - 1) / divisor;
  // Number of tiles needed to cover all chunks with tiles of at most t chunks.
  int64_t totalTripCount = (a + t - 1) / t;
  spec.lowTileSize = (a / totalTripCount) * divisor;
  spec.highTileSize = spec.lowTileSize + divisor;
  spec.highTripCount = a % totalTripCount;
  spec.lowTripCount = totalTripCount - spec.highTripCount;
  // Verify the decomposition covers the trip count exactly; if not (e.g.
  // tripCount is not divisible by divisor), signal failure to the caller.
  if (spec.lowTileSize * spec.lowTripCount +
          spec.highTileSize * spec.highTripCount !=
      tripCount) {
    return failure();
  }
  return spec;
}

FailureOr<MultiSizeSpecification>
mlir::linalg::computeMultiTileSizes(OpBuilder &builder, LinalgOp op,
unsigned dimension, OpFoldResult targetSize,
Expand Down
3 changes: 0 additions & 3 deletions mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ transform::GetParentForOp::apply(transform::TransformResults &results,
: scf::ForOp::getOperationName())
<< "' parent";
diag.attachNote(target->getLoc()) << "target op";
results.set(getResult().cast<OpResult>(), {});
return diag;
}
current = loop;
Expand Down Expand Up @@ -96,7 +95,6 @@ transform::LoopOutlineOp::apply(transform::TransformResults &results,
DiagnosedSilenceableFailure diag = emitSilenceableError()
<< "failed to outline";
diag.attachNote(target->getLoc()) << "target op";
results.set(getTransformed().cast<OpResult>(), {});
return diag;
}
func::CallOp call;
Expand Down Expand Up @@ -200,7 +198,6 @@ transform::LoopPipelineOp::applyToOne(scf::ForOp target,
results.push_back(*patternResult);
return DiagnosedSilenceableFailure::success();
}
results.assign(1, nullptr);
return emitDefaultSilenceableFailure(target);
}

Expand Down
81 changes: 44 additions & 37 deletions mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,20 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
if (result.isDefiniteFailure())
return result;

// If a silenceable failure was produced, some results may be unset, set them
// to empty lists.
if (result.isSilenceableFailure()) {
for (OpResult opResult : transform->getResults()) {
if (results.isSet(opResult.getResultNumber()))
continue;

if (opResult.getType().isa<TransformParamTypeInterface>())
results.setParams(opResult, {});
else
results.set(opResult, {});
}
}

// Remove the mapping for the operand if it is consumed by the operation. This
// allows us to catch use-after-free with assertions later on.
auto memEffectInterface =
Expand Down Expand Up @@ -421,6 +435,13 @@ bool transform::TransformResults::isParam(unsigned resultNumber) const {
return paramSegments[resultNumber].data() != nullptr;
}

/// Returns true if the result at `resultNumber` has already been associated
/// with a payload, i.e. either its parameter segment or its operation segment
/// has been populated.
bool transform::TransformResults::isSet(unsigned resultNumber) const {
  assert(resultNumber < paramSegments.size() &&
         "querying association for a non-existent handle");
  if (paramSegments[resultNumber].data() != nullptr)
    return true;
  return segments[resultNumber].data() != nullptr;
}

//===----------------------------------------------------------------------===//
// Utilities for TransformEachOpTrait.
//===----------------------------------------------------------------------===//
Expand All @@ -432,57 +453,43 @@ transform::detail::checkApplyToOne(Operation *transformOp,
Location transformOpLoc = transformOp->getLoc();
StringRef transformOpName = transformOp->getName().getStringRef();
unsigned expectedNumResults = transformOp->getNumResults();
// TODO: encode this implicit must always produce `expectedNumResults`
// and nullptr is fine with a proper trait.

// Reuse the emission of the diagnostic note.
auto emitDiag = [&]() {
auto diag = mlir::emitError(transformOpLoc);
diag.attachNote(payloadOpLoc) << "when applied to this op";
return diag;
};

if (partialResult.size() != expectedNumResults) {
auto diag = mlir::emitError(transformOpLoc, "applications of ")
<< transformOpName << " expected to produce "
<< expectedNumResults << " results (actually produced "
<< partialResult.size() << ").";
auto diag = emitDiag() << "application of " << transformOpName
<< " expected to produce " << expectedNumResults
<< " results (actually produced "
<< partialResult.size() << ").";
diag.attachNote(transformOpLoc)
<< "If you need variadic results, consider a generic `apply` "
<< "if you need variadic results, consider a generic `apply` "
<< "instead of the specialized `applyToOne`.";
diag.attachNote(transformOpLoc)
<< "Producing " << expectedNumResults << " null results is "
<< "allowed if the use case warrants it.";
diag.attachNote(payloadOpLoc) << "when applied to this op";
return failure();
}

// Check that all is null or none is null
// TODO: relax this behavior and encode with a proper trait.
if (llvm::any_of(
partialResult,
[](llvm::PointerUnion<Operation *, Attribute> ptr) { return ptr; }) &&
llvm::any_of(partialResult,
[](llvm::PointerUnion<Operation *, Attribute> ptr) {
return !ptr;
})) {
auto diag = mlir::emitError(transformOpLoc, "unexpected application of ")
<< transformOpName
<< " produces both null and non null results.";
diag.attachNote(payloadOpLoc) << "when applied to this op";
return failure();
}

// Check that the right kind of value was produced.
for (const auto &[ptr, res] :
llvm::zip(partialResult, transformOp->getResults())) {
if (ptr.isNull()) {
return emitDiag() << "null result #" << res.getResultNumber()
<< " produced";
}
if (ptr.is<Operation *>() &&
!res.getType().template isa<TransformHandleTypeInterface>()) {
mlir::emitError(transformOpLoc)
<< "applications of " << transformOpName
<< " expected to produce an Attribute for result #"
<< res.getResultNumber();
return failure();
return emitDiag() << "application of " << transformOpName
<< " expected to produce an Attribute for result #"
<< res.getResultNumber();
}
if (ptr.is<Attribute>() &&
!res.getType().template isa<TransformParamTypeInterface>()) {
mlir::emitError(transformOpLoc)
<< "applications of " << transformOpName
<< " expected to produce an Operation * for result #"
<< res.getResultNumber();
return failure();
return emitDiag() << "application of " << transformOpName
<< " expected to produce an Operation * for result #"
<< res.getResultNumber();
}
}
return success();
Expand Down
29 changes: 17 additions & 12 deletions mlir/lib/Dialect/Transform/IR/TransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -444,8 +444,6 @@ transform::GetProducerOfOperand::apply(transform::TransformResults &results,
<< "could not find a producer for operand number: " << operandNumber
<< " of " << *target;
diag.attachNote(target->getLoc()) << "target op";
results.set(getResult().cast<OpResult>(),
SmallVector<mlir::Operation *>{});
return diag;
}
producers.push_back(producer);
Expand Down Expand Up @@ -518,10 +516,6 @@ transform::SplitHandlesOp::apply(transform::TransformResults &results,
getHandle() ? state.getPayloadOps(getHandle()).size() : 0;
int64_t expectedNumResultHandles = getNumResultHandles();
if (numResultHandles != expectedNumResultHandles) {
// Failing case needs to propagate gracefully for both suppress and
// propagate modes.
for (int64_t idx = 0; idx < expectedNumResultHandles; ++idx)
results.set(getResults()[idx].cast<OpResult>(), {});
// Empty input handle corner case: always propagates empty handles in both
// suppress and propagate modes.
if (numResultHandles == 0)
Expand Down Expand Up @@ -586,12 +580,23 @@ transform::ReplicateOp::apply(transform::TransformResults &results,
unsigned numRepetitions = state.getPayloadOps(getPattern()).size();
for (const auto &en : llvm::enumerate(getHandles())) {
Value handle = en.value();
ArrayRef<Operation *> current = state.getPayloadOps(handle);
SmallVector<Operation *> payload;
payload.reserve(numRepetitions * current.size());
for (unsigned i = 0; i < numRepetitions; ++i)
llvm::append_range(payload, current);
results.set(getReplicated()[en.index()].cast<OpResult>(), payload);
if (handle.getType().isa<TransformHandleTypeInterface>()) {
ArrayRef<Operation *> current = state.getPayloadOps(handle);
SmallVector<Operation *> payload;
payload.reserve(numRepetitions * current.size());
for (unsigned i = 0; i < numRepetitions; ++i)
llvm::append_range(payload, current);
results.set(getReplicated()[en.index()].cast<OpResult>(), payload);
} else {
assert(handle.getType().isa<TransformParamTypeInterface>() &&
"expected param type");
ArrayRef<Attribute> current = state.getParams(handle);
SmallVector<Attribute> params;
params.reserve(numRepetitions * current.size());
for (unsigned i = 0; i < numRepetitions; ++i)
llvm::append_range(params, current);
results.setParams(getReplicated()[en.index()].cast<OpResult>(), params);
}
}
return DiagnosedSilenceableFailure::success();
}
Expand Down
73 changes: 55 additions & 18 deletions mlir/python/mlir/dialects/_structured_transform_ops_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
try:
from ..ir import *
from ._ods_common import get_op_result_or_value as _get_op_result_or_value
from ..dialects import pdl
from ..dialects import pdl, transform
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e

from typing import List, Optional, Sequence, Union
from typing import List, Optional, Sequence, Union, overload

IntOrAttrList = Sequence[Union[IntegerAttr, int]]
OptionalIntList = Optional[Union[ArrayAttr, IntOrAttrList]]
Expand Down Expand Up @@ -51,13 +51,13 @@ def _get_int_array_attr(

def _get_dense_int64_array_attr(
values: Sequence[int]) -> DenseI64ArrayAttr:
"""Creates a dense integer array from a sequence of integers.
"""Creates a dense integer array from a sequence of integers.
Expects the thread-local MLIR context to have been set by the context
manager.
"""
if values is None:
return DenseI64ArrayAttr.get([])
return DenseI64ArrayAttr.get(values)
if values is None:
return DenseI64ArrayAttr.get([])
return DenseI64ArrayAttr.get(values)

def _get_int_int_array_attr(
values: Optional[Union[ArrayAttr, Sequence[Union[ArrayAttr,
Expand Down Expand Up @@ -141,6 +141,7 @@ class MultiTileSizesOp:
"""Specialization for MultitileSizesOp class."""

def __init__(self,
result_type: Type,
target: Union[Operation, Value],
*,
dimension: Union[int, IntegerAttr],
Expand All @@ -149,9 +150,9 @@ def __init__(self,
loc=None,
ip=None):
super().__init__(
pdl.OperationType.get(),
pdl.OperationType.get(),
pdl.OperationType.get(),
result_type,
result_type,
result_type,
_get_op_result_or_value(target),
dimension=_get_int64_attr(dimension),
target_size=_get_int64_attr(target_size),
Expand Down Expand Up @@ -223,11 +224,12 @@ def __init__(self,
static_split_point = _get_int64_attr(ShapedType.get_dynamic_size())
dynamic_split_point = _get_op_result_or_value(split_point)

pdl_operation_type = pdl.OperationType.get()
target = _get_op_result_or_value(target)

super().__init__(
pdl_operation_type,
pdl_operation_type,
_get_op_result_or_value(target),
target.type,
target.type,
target,
dimension=dimension,
static_split_point=static_split_point,
dynamic_split_point=dynamic_split_point,
Expand All @@ -238,17 +240,38 @@ def __init__(self,
class TileOp:
"""Specialization for TileOp class."""

@overload
def __init__(self,
loop_types: Union[Type, List[Type]],
target: Union[Operation, Value],
*,
sizes: Optional[Union[Sequence[Union[int, IntegerAttr, Operation,
Value]], ArrayAttr]] = None,
interchange: OptionalIntList = None,
loc=None,
ip=None):
pdl_operation_type = pdl.OperationType.get()
i64_type = IntegerType.get_signless(64)
...

@overload
def __init__(self,
target: Union[Operation, Value],
*,
sizes: Optional[Union[Sequence[Union[int, IntegerAttr, Operation,
Value]], ArrayAttr]] = None,
interchange: OptionalIntList = None,
loc=None,
ip=None):
...

def __init__(self,
loop_types_or_target: Union[Type, List[Type], Operation, Value],
target_or_none: Optional[Union[Operation, Value]] = None,
*,
sizes: Optional[Union[Sequence[Union[int, IntegerAttr, Operation,
Value]], ArrayAttr]] = None,
interchange: OptionalIntList = None,
loc=None,
ip=None):
if sizes is None:
sizes = []

Expand All @@ -267,12 +290,26 @@ def __init__(self,

num_loops = sum(
v if v == 0 else 1 for v in self.__extract_values(sizes_attr))

if isinstance(loop_types_or_target, (Operation, Value)):
loop_types = [transform.AnyOpType.get()] * num_loops
target = loop_types_or_target
assert target_or_none is None, "Cannot construct TileOp with two targets."
else:
loop_types = ([loop_types_or_target] * num_loops) if isinstance(
loop_types_or_target, Type) else loop_types_or_target
target = target_or_none

target = _get_op_result_or_value(target)

super().__init__(
pdl_operation_type, [pdl_operation_type] * num_loops,
_get_op_result_or_value(target),
target.type,
loop_types,
target,
dynamic_sizes=dynamic_sizes,
static_sizes=sizes_attr,
interchange=_get_dense_int64_array_attr(interchange) if interchange else None,
interchange=_get_dense_int64_array_attr(interchange)
if interchange else None,
loc=loc,
ip=ip)

Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Dialect/LLVM/transform-e2e.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ func.func @matmul_tensors(
transform.sequence failures(propagate) {
^bb1(%module_op: !pdl.operation):
%0 = transform.structured.match ops{["linalg.matmul"]} in %module_op
%1, %loops:3 = transform.structured.tile %0 [2, 2, 2]
%1, %loops:3 = transform.structured.tile %0 [2, 2, 2] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation)
%2 = get_closest_isolated_parent %1 : (!pdl.operation) -> !pdl.operation
transform.structured.vectorize %2
transform.bufferization.one_shot_bufferize layout{IdentityLayoutMap} %module_op
Expand Down
117 changes: 108 additions & 9 deletions mlir/test/Dialect/Linalg/multisize-tiling-full.mlir
Original file line number Diff line number Diff line change
@@ -1,23 +1,29 @@
// RUN: mlir-opt --test-transform-dialect-interpreter --scf-for-loop-canonicalization --canonicalize %s | FileCheck %s
// RUN: mlir-opt --test-transform-dialect-interpreter --scf-for-loop-canonicalization --canonicalize --split-input-file %s | FileCheck %s
// RUN: mlir-opt --test-transform-dialect-interpreter --split-input-file %s | FileCheck %s --check-prefix=NOCANON

// This implements a 2D multisize tiling with target sizes [3, 10].
transform.sequence failures(propagate) {
^bb1(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1
%1:3 = transform.structured.multitile_sizes %0 { dimension = 0, target_size = 3}
%t:3 = transform.structured.multitile_sizes %0 { dimension = 1, target_size = 10}
%2:2 = transform.structured.split %0 after %1#2 { dimension = 0 }
%3:2 = transform.structured.tile %2#0 [%1#0]
%4:2 = transform.structured.tile %2#1 [%1#1]
%1:3 = transform.structured.multitile_sizes %0 { dimension = 0, target_size = 3} : (!pdl.operation) -> !pdl.operation
%t:3 = transform.structured.multitile_sizes %0 { dimension = 1, target_size = 10} : (!pdl.operation) -> !pdl.operation
%2:2 = transform.structured.split %0 after %1#2 { dimension = 0 } : !pdl.operation, !pdl.operation
%3:2 = transform.structured.tile %2#0 [%1#0] : (!pdl.operation, !pdl.operation) -> (!pdl.operation, !pdl.operation)
%4:2 = transform.structured.tile %2#1 [%1#1] : (!pdl.operation, !pdl.operation) -> (!pdl.operation, !pdl.operation)
%5 = merge_handles %3#0, %4#0 : !pdl.operation
%tt:3 = replicate num(%5) %t#0, %t#1, %t#2 : !pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation
%6:2 = transform.structured.split %5 after %tt#2 { dimension = 1 }
transform.structured.tile %6#0 [0, %tt#0]
transform.structured.tile %6#1 [0, %tt#1]
%6:2 = transform.structured.split %5 after %tt#2 { dimension = 1 } : !pdl.operation, !pdl.operation
transform.structured.tile %6#0 [0, %tt#0] : (!pdl.operation, !pdl.operation) -> (!pdl.operation, !pdl.operation)
transform.structured.tile %6#1 [0, %tt#1] : (!pdl.operation, !pdl.operation) -> (!pdl.operation, !pdl.operation)
}

func.func private @elem(%arg0: f32, %arg1: index, %arg2: index) -> f32

// Without canonicalization, tile sizes are computed dynamically as affine maps.
// NOCANON-LABEL: @two_d
// NOCANON-COUNT-8: affine.apply
// NOCANON: scf.for

// CHECK-LABEL: @two_d
// CHECK-SAME: %[[IN:.+]]: tensor<10x34xf32>, %[[OUT:.+]]: tensor<10x34xf32>
func.func @two_d(%arg0: tensor<10x34xf32>,
Expand Down Expand Up @@ -93,3 +99,96 @@ func.func @two_d(%arg0: tensor<10x34xf32>,

return %0 : tensor<10x34xf32>
}

// -----

transform.sequence failures(propagate) {
^bb1(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1
%1:3 = transform.structured.multitile_sizes %0 { dimension = 0, target_size = 3} : (!pdl.operation) -> !transform.param<i64>
%t:3 = transform.structured.multitile_sizes %0 { dimension = 1, target_size = 10} : (!pdl.operation) -> !transform.param<i64>
%2:2 = transform.structured.split %0 after %1#2 { dimension = 0 } : !pdl.operation, !transform.param<i64>
%3:2 = transform.structured.tile %2#0 [%1#0] : (!pdl.operation, !transform.param<i64>) -> (!pdl.operation, !pdl.operation)
%4:2 = transform.structured.tile %2#1 [%1#1] : (!pdl.operation, !transform.param<i64>) -> (!pdl.operation, !pdl.operation)
%5 = merge_handles %3#0, %4#0 : !pdl.operation
%tt:3 = replicate num(%5) %t#0, %t#1, %t#2 : !pdl.operation, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>
%6:2 = transform.structured.split %5 after %tt#2 { dimension = 1 } : !pdl.operation, !transform.param<i64>
transform.structured.tile %6#0 [0, %tt#0] : (!pdl.operation, !transform.param<i64>) -> (!pdl.operation, !pdl.operation)
transform.structured.tile %6#1 [0, %tt#1] : (!pdl.operation, !transform.param<i64>) -> (!pdl.operation, !pdl.operation)
}

func.func private @elem(%arg0: f32, %arg1: index, %arg2: index) -> f32

// Even without canonicalization, tile sizes can be computed statically thanks
// to parameters.
// NOCANON-LABEL: @two_d
// NOCANON-NOT: affine.apply
// NOCANON: scf.for

// CHECK-LABEL: @two_d_param
// CHECK-SAME: %[[IN:.+]]: tensor<10x34xf32>, %[[OUT:.+]]: tensor<10x34xf32>
func.func @two_d_param(%arg0: tensor<10x34xf32>,
%arg1: tensor<10x34xf32>) -> tensor<10x34xf32> {
%0 = linalg.generic {
indexing_maps = [affine_map<(i, j) -> (i, j)>,
affine_map<(i, j) -> (i, j)>],
iterator_types = ["parallel", "parallel"]
}
ins(%arg0: tensor<10x34xf32>)
outs(%arg1: tensor<10x34xf32>) {
^bb0(%0: f32, %1: f32):
%i = linalg.index 0 : index
%j = linalg.index 1 : index
%call_res = func.call @elem(%0, %i, %j) : (f32, index, index) -> f32
linalg.yield %call_res : f32
} -> tensor<10x34xf32>

// CHECK: %[[SLICE_1_IN:.+]] = tensor.extract_slice %[[IN]][0, 0] [4, 34] [1, 1]
// CHECK: %[[SLICE_1:.+]] = tensor.extract_slice %[[OUT]][0, 0] [4, 34] [1, 1]
// CHECK: scf.for %[[I1:.+]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ITERARG_1:.+]] = %[[SLICE_1]])
// CHECK: %[[OUTSLICE_1_IN:.+]] = tensor.extract_slice %[[SLICE_1_IN]][%[[I1]], 0] [2, 34] [1, 1]
// CHECK: %[[OUTSLICE_1:.+]] = tensor.extract_slice %[[ITERARG_1]][%[[I1]], 0] [2, 34] [1, 1]

// CHECK: %[[SLICE_2_IN:.+]] = tensor.extract_slice %[[OUTSLICE_1_IN]][0, 0] [2, 16] [1, 1]
// CHECK: %[[SLICE_2:.+]] = tensor.extract_slice %[[OUTSLICE_1]][0, 0] [2, 16] [1, 1]
// CHECK: %[[LOOPRES:.+]] = scf.for %[[I2:.+]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ITERARG_2:.+]] = %[[SLICE_2]])
// CHECK: %[[INSLICE_2:.+]] = tensor.extract_slice %[[SLICE_2_IN]][0, %[[I2]]] [2, 8] [1, 1]
// CHECK: %[[OUTSLICE_2:.+]] = tensor.extract_slice %[[ITERARG_2]][0, %[[I2]]] [2, 8] [1, 1]
// CHECK: %[[RESSLICE_1:.+]] = linalg.generic {{.*}} ins(%[[INSLICE_2]] : tensor<2x8xf32>) outs(%[[OUTSLICE_2]] : tensor<2x8xf32>)
// CHECK: %[[RESPARTIAL:.+]] = tensor.insert_slice %[[RESSLICE_1]] into %[[ITERARG_2]]
// CHECK: scf.yield %[[RESPARTIAL]]

// CHECK: %[[INSERTED:.+]] = tensor.insert_slice %[[LOOPRES]] into %[[OUTSLICE_1]][0, 0] [2, 16] [1, 1]
// CHECK: %[[OUTSLICE_3:.+]] = tensor.extract_slice %[[INSERTED]][0, 16] [2, 18] [1, 1]
// CHECK: scf.for %{{.*}} iter_args(%{{.*}} = %[[OUTSLICE_3]])
// CHECK-COUNT-2: tensor.extract_slice
// CHECK: linalg.generic {{.*}} ins(%{{.*}} : tensor<2x9xf32>)
// CHECK: tensor.insert_slice
// CHECK: scf.yield
// CHECK: %[[INSERTED_2:.+]] = tensor.insert_slice %{{.*}} into %[[INSERTED]]
// CHECK: %[[INSERTED_3:.+]] = tensor.insert_slice %[[INSERTED_2]] into %[[ITERARG_1]]
// CHECK: scf.yield %[[INSERTED_3]]

// CHECK: tensor.insert_slice
// CHECK: tensor.extract_slice
// CHECK: scf.for
// CHECK-COUNT-2: tensor.extract_slice
// CHECK: scf.for
// CHECK-COUNT-2: tensor.extract_slice
// CHECK: linalg.generic {{.*}} ins(%{{.*}} : tensor<3x8xf32>)
// CHECK: tensor.insert_slice
// CHECK: scf.yield
// CHECK: tensor.insert_slice
// CHECK: tensor.extract_slice
// CHECK: scf.for
// CHECK-COUNT-2: tensor.extract_slice
// CHECK: linalg.generic {{.*}} ins(%{{.*}} : tensor<3x9xf32>)
// CHECK: tensor.insert_slice
// CHECK: scf.yield
// CHECK-COUNT-2: tensor.insert_slice
// CHECK: scf.yield
// CHECK: %[[RESULT:.+]] = tensor.insert_slice
// CHECK: return %[[RESULT]]

return %0 : tensor<10x34xf32>
}
2 changes: 1 addition & 1 deletion mlir/test/Dialect/Linalg/promotion_options.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,6 @@ func.func @gemm(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>
transform.sequence failures(propagate) {
^bb0(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
%1, %loops:3 = transform.structured.tile %0 [16, 16, 16]
%1, %loops:3 = transform.structured.tile %0 [16, 16, 16] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation)
%2 = transform.structured.promote %1 { operands_to_promote = [0, 2], force_full_tiles = [false, false], use_full_tiles_by_default }
}
2 changes: 1 addition & 1 deletion mlir/test/Dialect/Linalg/tile-conv.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ func.func @conv(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32>, %arg2 : memref
transform.sequence failures(propagate) {
^bb0(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["linalg.conv_2d"]} in %arg1
%1, %loop:2 = transform.structured.tile %0 [2, 3]
%1, %loop:2 = transform.structured.tile %0 [2, 3] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation)
}

// CHECK: func @conv
Expand Down
4 changes: 2 additions & 2 deletions mlir/test/Dialect/Linalg/tile-indexed.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ func.func @indexed_vector(%arg0: memref<50xindex>) {
transform.sequence failures(propagate) {
^bb0(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1
%1, %loop = transform.structured.tile %0 [10]
%1, %loop = transform.structured.tile %0 [10] : (!pdl.operation) -> (!pdl.operation, !pdl.operation)
}

// TILE-10n25-DAG: [[$MAP:#[a-zA-Z0-9_]*]] = affine_map<(d0, d1) -> (d0 + d1)>
Expand Down Expand Up @@ -44,7 +44,7 @@ func.func @indexed_matrix(%arg0: memref<50x50xindex>) {
transform.sequence failures(propagate) {
^bb0(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1
%1, %loop:2 = transform.structured.tile %0 [10, 25]
%1, %loop:2 = transform.structured.tile %0 [10, 25] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation)
}

// TILE-10n25-DAG: [[$MAP:#[a-zA-Z0-9_]*]] = affine_map<(d0, d1) -> (d0 + d1)>
Expand Down
6 changes: 3 additions & 3 deletions mlir/test/Dialect/Linalg/tile-tensors.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func.func @matmul_tensors(
transform.sequence failures(propagate) {
^bb0(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
%1, %loops:3 = transform.structured.tile %0 [2, 3, 4]
%1, %loops:3 = transform.structured.tile %0 [2, 3, 4] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation)
}

// -----
Expand Down Expand Up @@ -61,7 +61,7 @@ func.func @generic_op_tensors(
transform.sequence failures(propagate) {
^bb0(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1
%1, %loops:3 = transform.structured.tile %0 [2, 3, 4]
%1, %loops:3 = transform.structured.tile %0 [2, 3, 4] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation)
}

// CHECK-LABEL: func @generic_op_tensors
Expand Down Expand Up @@ -132,5 +132,5 @@ func.func @fold_extract_slice(
transform.sequence failures(propagate) {
^bb0(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1
%1, %loops:3 = transform.structured.tile %0 [2, 3, 4]
%1, %loops:3 = transform.structured.tile %0 [2, 3, 4] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation)
}
2 changes: 1 addition & 1 deletion mlir/test/Dialect/Linalg/transform-op-fuse.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ transform.sequence failures(propagate) {
^bb1(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1
%1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [5, 0, 7], tile_interchange = [0, 2, 1]}
%2, %loops_2 = transform.structured.tile %1 [0, 4]
%2, %loops_2 = transform.structured.tile %1 [0, 4] : (!pdl.operation) -> (!pdl.operation, !pdl.operation)
}

// -----
Expand Down
54 changes: 51 additions & 3 deletions mlir/test/Dialect/Linalg/transform-op-multitile-sizes.mlir
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file | FileCheck %s
// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file --verify-diagnostics | FileCheck %s

// CHECK-DAG: #[[$MAP13:.+]] = affine_map<() -> (13)>

transform.sequence failures(propagate) {
^bb0(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
transform.structured.multitile_sizes %0 { target_size = 3, dimension = 0 }
transform.structured.multitile_sizes %0 { target_size = 3, dimension = 0 } : (!pdl.operation) -> !pdl.operation
}

// CHECK-LABEL: @multitile_sizes_static
Expand All @@ -29,7 +29,34 @@ func.func @multitile_sizes_static(
transform.sequence failures(propagate) {
^bb0(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
transform.structured.multitile_sizes %0 { target_size = 3, divisor = 2, dimension = 0 }
%low_tile, %high_tile, %split_point =
transform.structured.multitile_sizes %0 { target_size = 3, dimension = 0 }
: (!pdl.operation) -> !transform.param<i64>
// expected-remark @below {{2 : i64}}
transform.test_print_param %low_tile : !transform.param<i64>
// expected-remark @below {{3 : i64}}
transform.test_print_param %high_tile : !transform.param<i64>
// expected-remark @below {{4 : i64}}
transform.test_print_param %split_point : !transform.param<i64>
}

// CHECK-LABEL: @multitile_sizes_static_gen
func.func @multitile_sizes_static_gen(
%arg0: tensor<13x34xf32>, %arg1: tensor<34x42xf32>, %arg2: tensor<13x42xf32>)
-> tensor<13x42xf32> {
%0 = linalg.matmul ins(%arg0, %arg1: tensor<13x34xf32>, tensor<34x42xf32>)
outs(%arg2: tensor<13x42xf32>)
-> tensor<13x42xf32>

return %0 : tensor<13x42xf32>
}

// -----

transform.sequence failures(propagate) {
^bb0(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
transform.structured.multitile_sizes %0 { target_size = 3, divisor = 2, dimension = 0 } : (!pdl.operation) -> !pdl.operation
}

// CHECK: #[[$MAP_A:.+]] = affine_map<()[s0] -> ([[A_IMPL:s0 floordiv 2]])>
Expand Down Expand Up @@ -64,3 +91,24 @@ func.func @multitile_sizes_dynamic(

return %0 : tensor<?x?xf32>
}

// -----

transform.sequence failures(propagate) {
^bb0(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
// expected-error @below {{cannot compute parametric tile sizes for dynamically shaped payload op}}
transform.structured.multitile_sizes %0 { target_size = 3, divisor = 2, dimension = 0 }
: (!pdl.operation) -> !transform.param<i64>
}

func.func @multitile_sizes_dynamic_gen(
%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>)
-> tensor<?x?xf32> {
// expected-note @below {{payload op}}
%0 = linalg.matmul ins(%arg0, %arg1: tensor<?x?xf32>, tensor<?x?xf32>)
outs(%arg2: tensor<?x?xf32>)
-> tensor<?x?xf32>

return %0 : tensor<?x?xf32>
}
2 changes: 1 addition & 1 deletion mlir/test/Dialect/Linalg/transform-op-scalarize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,6 @@ func.func @scalarize(%arg0: tensor<24x12xf32>,
transform.sequence failures(propagate) {
^bb1(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
%1, %loops = transform.structured.tile %0 [10, 0, 0]
%1, %loops = transform.structured.tile %0 [10, 0, 0] : (!pdl.operation) -> (!pdl.operation, !pdl.operation)
%2 = transform.structured.scalarize %1
}
20 changes: 10 additions & 10 deletions mlir/test/Dialect/Linalg/transform-op-split.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
transform.sequence failures(propagate) {
^bb1(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1
%1:2 = transform.structured.split %0 after 42 { dimension = 0 }
%1:2 = transform.structured.split %0 after 42 { dimension = 0 } : !pdl.operation
}

func.func private @elem(%arg0: f32, %arg1: index, %arg2: index) -> f32
Expand Down Expand Up @@ -51,7 +51,7 @@ func.func @one_d_static(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tenso
transform.sequence failures(propagate) {
^bb1(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1
%1:2 = transform.structured.split %0 after 42 { dimension = 0 }
%1:2 = transform.structured.split %0 after 42 { dimension = 0 } : !pdl.operation
}

func.func private @elem(%arg0: f32, %arg1: index, %arg2: index) -> f32
Expand Down Expand Up @@ -85,7 +85,7 @@ transform.sequence failures(propagate) {
^bb1(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1
%1 = transform.structured.match ops{["func.call"]} in %arg1
transform.structured.split %0 after %1 { dimension = 0 }
transform.structured.split %0 after %1 { dimension = 0 } : !pdl.operation, !pdl.operation
}

func.func private @get_size() -> index
Expand Down Expand Up @@ -132,8 +132,8 @@ func.func @dynamic(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tensor<100
transform.sequence failures(propagate) {
^bb1(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1
%1:2 = transform.structured.split %0 after 4 { dimension = 0}
%2:2 = transform.structured.split %1#1 after 16 { dimension = 1 }
%1:2 = transform.structured.split %0 after 4 { dimension = 0 } : !pdl.operation
%2:2 = transform.structured.split %1#1 after 16 { dimension = 1 } : !pdl.operation
}

func.func private @elem(%arg0: f32, %arg1: index, %arg2: index) -> f32
Expand Down Expand Up @@ -199,7 +199,7 @@ transform.sequence failures(propagate) {
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1
%1 = transform.structured.match ops{["func.call"]} in %arg1
// expected-error @below {{expected dynamic split point handle to point to a single-result index-typed op}}
transform.structured.split %0 after %1 { dimension = 0 }
transform.structured.split %0 after %1 { dimension = 0 } : !pdl.operation, !pdl.operation
}

func.func private @get_size() -> i64
Expand All @@ -225,7 +225,7 @@ transform.sequence failures(propagate) {
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1
%1 = transform.structured.match ops{["func.call"]} in %arg1
// expected-error @below {{expected the dynamic split point handle to point to as many operations (0) as the target handle (1)}}
transform.structured.split %0 after %1 { dimension = 0 }
transform.structured.split %0 after %1 { dimension = 0 } : !pdl.operation, !pdl.operation
}

func.func private @get_size() -> i64
Expand All @@ -248,7 +248,7 @@ transform.sequence failures(propagate) {
^bb1(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["func.return"]} in %arg1
// expected-error @below {{only applies to structured ops}}
transform.structured.split %0 after 16 { dimension = 1 }
transform.structured.split %0 after 16 { dimension = 1 } : !pdl.operation
}

func.func @noop(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> {
Expand All @@ -262,7 +262,7 @@ transform.sequence failures(propagate) {
^bb1(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1
// expected-error @below {{dimension 1 does not exist in target op}}
transform.structured.split %0 after 16 { dimension = 1 }
transform.structured.split %0 after 16 { dimension = 1 } : !pdl.operation
}

func.func @one_d_static(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> {
Expand All @@ -285,7 +285,7 @@ transform.sequence failures(propagate) {
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1
// expected-error @below {{splitting does not produce the second part for a subset of targets}}
// expected-note @below {{expected splitting to produce the second part of all or none of the targets}}
%1:2 = transform.structured.split %0 after 142 { dimension = 0 }
%1:2 = transform.structured.split %0 after 142 { dimension = 0 } : !pdl.operation
}

func.func private @elem(%arg0: f32, %arg1: index, %arg2: index) -> f32
Expand Down
56 changes: 53 additions & 3 deletions mlir/test/Dialect/Linalg/transform-op-tile.mlir
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
// RUN: mlir-opt --test-transform-dialect-interpreter --split-input-file %s | FileCheck %s
// RUN: mlir-opt --test-transform-dialect-interpreter --split-input-file --verify-diagnostics %s | FileCheck %s

transform.sequence failures(propagate) {
^bb0(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
%1, %loops:3 = transform.structured.tile %0 [4, 4, 4]
%1, %loops:3 = transform.structured.tile %0 [4, 4, 4] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation)
}

// CHECK-LABEL: func @tile_linalg_matmul(
Expand Down Expand Up @@ -40,7 +40,7 @@ transform.sequence failures(propagate) {
^bb0(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
%1 = transform.structured.match ops{["func.call"]} in %arg1
%2, %loops:3 = transform.structured.tile %0 [%1, %1, 4]
%2, %loops:3 = transform.structured.tile %0 [%1, %1, 4] : (!pdl.operation, !pdl.operation, !pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation)
}

func.func private @get_dynamic_tile_size() -> index
Expand Down Expand Up @@ -73,3 +73,53 @@ func.func @tile_linalg_matmul_dynamic(
// CHECK: return %[[TD0]] : tensor<128x128xf32>
return %0 : tensor<128x128xf32>
}

// -----

transform.sequence failures(propagate) {
^bb0(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
// expected-note @below {{for this parameter}}
%1 = transform.test_produce_integer_param_with_type i64 : !transform.param<i64>
// expected-error @below {{expected as many parameter values (0) as target ops (2)}}
transform.structured.tile %0 [%1, %1, %1]
: (!pdl.operation, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>)
-> (!pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation)
}

func.func @tile_linalg_matmul(
%arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>)
-> (tensor<128x128xf32>, tensor<128x128xf32>) {
%0 = linalg.matmul ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
outs(%arg2: tensor<128x128xf32>)
-> tensor<128x128xf32>
%1 = linalg.matmul ins(%0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
outs(%arg2: tensor<128x128xf32>)
-> tensor<128x128xf32>
return %0, %1 : tensor<128x128xf32>, tensor<128x128xf32>
}

// -----

transform.sequence failures(propagate) {
^bb0(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
// expected-note @below {{for this handle}}
%1 = transform.structured.match ops{["arith.constant"]} in %arg1
// expected-error @below {{expected as many dynamic size-producing operations (0) as target ops (2)}}
transform.structured.tile %0 [%1, %1, 1]
: (!pdl.operation, !pdl.operation, !pdl.operation)
-> (!pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation)
}

func.func @tile_linalg_matmul(
%arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>)
-> (tensor<128x128xf32>, tensor<128x128xf32>) {
%0 = linalg.matmul ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
outs(%arg2: tensor<128x128xf32>)
-> tensor<128x128xf32>
%1 = linalg.matmul ins(%0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
outs(%arg2: tensor<128x128xf32>)
-> tensor<128x128xf32>
return %0, %1 : tensor<128x128xf32>, tensor<128x128xf32>
}
9 changes: 9 additions & 0 deletions mlir/test/Dialect/Linalg/transform-ops-invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,12 @@ transform.sequence failures(propagate) {
// expected-error@below {{'transform.structured.interchange' op attribute 'iterator_interchange' failed to satisfy constraint: i64 dense array attribute whose value is non-negative}}
transform.structured.interchange %arg0 iterator_interchange = [-3, 1]
}

// -----

transform.sequence failures(propagate) {
^bb0(%arg0: !pdl.operation):
// expected-error@below {{expects all results type to be the same}}
"transform.structured.multitile_sizes"(%arg0) { target_size = 3, divisor = 2, dimension = 0 }
: (!pdl.operation) -> (!transform.param<i64>, !transform.param<i64>, !transform.param<i32>)
}
8 changes: 7 additions & 1 deletion mlir/test/Dialect/Linalg/transform-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,13 @@
transform.sequence failures(propagate) {
^bb1(%arg0: !pdl.operation):
// CHECK %{{.*}}, %{{.*}}:2 = transform.structured.tile
%0, %1:2 = transform.structured.tile %arg0 [2, 0, 3]
%0, %1:2 = transform.structured.tile %arg0 [2, 0, 3] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation)
}

transform.sequence failures(propagate) {
^bb1(%arg0: !transform.any_op):
%0:2 = transform.structured.split %arg0 after 42 { dimension = 0 } : !transform.any_op
transform.structured.split %0#0 after %0#1 { dimension = 1 } : !transform.any_op, !transform.any_op
}

//===----------------------------------------------------------------------===//
Expand Down
20 changes: 10 additions & 10 deletions mlir/test/Dialect/Linalg/transform-patterns.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ func.func @dot(%x: memref<?xf32, strided<[1], offset: ?>>,
transform.sequence failures(propagate) {
^bb0(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["linalg.dot"]} in %arg1
%1, %loop = transform.structured.tile %0 [8000]
%1, %loop = transform.structured.tile %0 [8000] : (!pdl.operation) -> (!pdl.operation, !pdl.operation)
}

// CHECK-LABEL: func @dot
Expand All @@ -38,7 +38,7 @@ func.func @matvec(%A: memref<?x?xf32, strided<[?, 1], offset: ?>>,
transform.sequence failures(propagate) {
^bb0(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["linalg.matvec"]} in %arg1
%1, %loops:2 = transform.structured.tile %0 [5, 6]
%1, %loops:2 = transform.structured.tile %0 [5, 6] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation)
}

// CHECK-LABEL: func @matvec
Expand All @@ -65,10 +65,10 @@ func.func @matmul(%A: memref<?x?xf32, strided<[?, 1], offset: ?>>,
transform.sequence failures(propagate) {
^bb0(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
%1, %loops:3 = transform.structured.tile %0 [2000, 3000, 4000]
%2, %loops_2:3 = transform.structured.tile %1 [200, 300, 400]
%3, %loops_3:3 = transform.structured.tile %2 [20, 30, 40]
%4, %loops_4:3 = transform.structured.tile %3 [2, 3, 4]
%1, %loops:3 = transform.structured.tile %0 [2000, 3000, 4000] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation)
%2, %loops_2:3 = transform.structured.tile %1 [200, 300, 400] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation)
%3, %loops_3:3 = transform.structured.tile %2 [20, 30, 40] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation)
%4, %loops_4:3 = transform.structured.tile %3 [2, 3, 4] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation)
}

// CHECK-LABEL: func @matmul
Expand Down Expand Up @@ -164,7 +164,7 @@ func.func @matvec_perm(%A: memref<?x?xf32, strided<[?, 1], offset: ?>>,
transform.sequence failures(propagate) {
^bb0(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["linalg.matvec"]} in %arg1
%1, %loops:2 = transform.structured.tile %0 [5, 6] {interchange = [1, 0]}
%1, %loops:2 = transform.structured.tile %0 [5, 6] {interchange = [1, 0]} : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation)
}

// CHECK-LABEL: func @matvec_perm
Expand All @@ -191,9 +191,9 @@ func.func @matmul_perm(%A: memref<?x?xf32, strided<[?, 1], offset: ?>>,
transform.sequence failures(propagate) {
^bb0(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
%1, %loops:3 = transform.structured.tile %0 [2000, 3000, 4000] {interchange = [1, 2, 0]}
%2, %loops_2:3 = transform.structured.tile %1 [200, 300, 400] {interchange = [1, 0, 2]}
%3, %loops_3:3 = transform.structured.tile %2 [20, 30, 40]
%1, %loops:3 = transform.structured.tile %0 [2000, 3000, 4000] {interchange = [1, 2, 0]} : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation)
%2, %loops_2:3 = transform.structured.tile %1 [200, 300, 400] {interchange = [1, 0, 2]} : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation)
%3, %loops_3:3 = transform.structured.tile %2 [20, 30, 40] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation)
}

// CHECK-LABEL: func @matmul_perm
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Dialect/Transform/selective-targeting.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ transform.with_pdl_patterns {
transform.sequence %arg0 : !pdl.operation failures(propagate) {
^bb1(%arg1: !pdl.operation):
%0 = pdl_match @pdl_target_attrA in %arg1 : (!pdl.operation) -> !pdl.operation
transform.structured.tile %0 [4, 4, 4]
transform.structured.tile %0 [4, 4, 4] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation)
%1 = pdl_match @pdl_target_attrC in %arg1 : (!pdl.operation) -> !pdl.operation
%2 = transform.get_closest_isolated_parent %1 : (!pdl.operation) -> !pdl.operation
transform.structured.vectorize %2
Expand Down
44 changes: 24 additions & 20 deletions mlir/test/Dialect/Transform/test-interpreter.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -409,9 +409,8 @@ transform.with_pdl_patterns {
transform.sequence %arg0 : !pdl.operation failures(propagate) {
^bb0(%arg1: !pdl.operation):
%0 = pdl_match @some in %arg1 : (!pdl.operation) -> !pdl.operation
// expected-error @below {{applications of transform.test_wrong_number_of_results expected to produce 3 results (actually produced 1).}}
// expected-note @below {{If you need variadic results, consider a generic `apply` instead of the specialized `applyToOne`.}}
// expected-note @below {{Producing 3 null results is allowed if the use case warrants it.}}
// expected-error @below {{application of transform.test_wrong_number_of_results expected to produce 3 results (actually produced 1).}}
// expected-note @below {{if you need variadic results, consider a generic `apply` instead of the specialized `applyToOne`.}}
transform.test_wrong_number_of_results %0
}
}
Expand All @@ -437,9 +436,8 @@ transform.with_pdl_patterns {
transform.sequence %arg0 : !pdl.operation failures(propagate) {
^bb0(%arg1: !pdl.operation):
%0 = pdl_match @some in %arg1 : (!pdl.operation) -> !pdl.operation
// expected-error @below {{applications of transform.test_wrong_number_of_multi_results expected to produce 1 results (actually produced 0)}}
// expected-note @below {{If you need variadic results, consider a generic `apply` instead of the specialized `applyToOne`.}}
// expected-note @below {{Producing 1 null results is allowed if the use case warrants it.}}
// expected-error @below {{application of transform.test_wrong_number_of_multi_results expected to produce 1 results (actually produced 0)}}
// expected-note @below {{if you need variadic results, consider a generic `apply` instead of the specialized `applyToOne`.}}
transform.test_wrong_number_of_multi_results %0
}
}
Expand Down Expand Up @@ -514,7 +512,7 @@ transform.with_pdl_patterns {
transform.sequence %arg0 : !pdl.operation failures(propagate) {
^bb0(%arg1: !pdl.operation):
%0 = pdl_match @some in %arg1 : (!pdl.operation) -> !pdl.operation
// expected-error @below {{unexpected application of transform.test_mixed_null_and_non_null_results produces both null and non null results.}}
// expected-error @below {{null result #0 produced}}
transform.test_mixed_null_and_non_null_results %0
}
}
Expand Down Expand Up @@ -1041,12 +1039,15 @@ func.func private @three_test_ops(%arg0: i32) {

// -----

transform.sequence failures(propagate) {
^bb0(%arg0: !transform.any_op):
// expected-error @below {{expected to produce an Operation * for result #0}}
transform.test_produce_transform_param_or_forward_operand %arg0
{ first_result_is_param }
: (!transform.any_op) -> (!transform.any_op, !transform.param<i64>)
// expected-note @below {{when applied to this op}}
module {
transform.sequence failures(propagate) {
^bb0(%arg0: !transform.any_op):
// expected-error @below {{expected to produce an Operation * for result #0}}
transform.test_produce_transform_param_or_forward_operand %arg0
{ first_result_is_param }
: (!transform.any_op) -> (!transform.any_op, !transform.param<i64>)
}
}

// -----
Expand All @@ -1055,7 +1056,7 @@ transform.sequence failures(propagate) {
module {
transform.sequence failures(propagate) {
^bb0(%arg0: !transform.any_op):
// expected-error @below {{produces both null and non null results}}
// expected-error @below {{null result #0 produced}}
transform.test_produce_transform_param_or_forward_operand %arg0
{ first_result_is_null }
: (!transform.any_op) -> (!transform.any_op, !transform.param<i64>)
Expand All @@ -1064,12 +1065,15 @@ module {

// -----

transform.sequence failures(propagate) {
^bb0(%arg0: !transform.any_op):
// expected-error @below {{expected to produce an Attribute for result #1}}
transform.test_produce_transform_param_or_forward_operand %arg0
{ second_result_is_handle }
: (!transform.any_op) -> (!transform.any_op, !transform.param<i64>)
// expected-note @below {{when applied to this op}}
module {
transform.sequence failures(propagate) {
^bb0(%arg0: !transform.any_op):
// expected-error @below {{expected to produce an Attribute for result #1}}
transform.test_produce_transform_param_or_forward_operand %arg0
{ second_result_is_handle }
: (!transform.any_op) -> (!transform.any_op, !transform.param<i64>)
}
}

// -----
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Dialect/Vector/transform-vector.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ func.func @matmul_tensors(
transform.sequence failures(propagate) {
^bb1(%module_op: !pdl.operation):
%0 = transform.structured.match ops{["linalg.matmul"]} in %module_op
%1, %loops:3 = transform.structured.tile %0 [8, 4, 2]
%1, %loops:3 = transform.structured.tile %0 [8, 4, 2] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation)
%2 = get_closest_isolated_parent %1 : (!pdl.operation) -> !pdl.operation
transform.structured.vectorize %2
transform.bufferization.one_shot_bufferize %module_op
Expand Down
53 changes: 46 additions & 7 deletions mlir/test/python/dialects/transform_structured_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,10 @@ def testInterchange():
def testMultitileSizes():
sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get())
with InsertionPoint(sequence.body):
structured.MultiTileSizesOp(
sequence.bodyTarget, dimension=1, target_size=42)
structured.MultiTileSizesOp(pdl.OperationType.get(),
sequence.bodyTarget,
dimension=1,
target_size=42)
transform.YieldOp()
# CHECK-LABEL: TEST: testMultitileSizes
# CHECK: transform.sequence
Expand Down Expand Up @@ -110,7 +112,9 @@ def testSplit():
def testTileCompact():
sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get())
with InsertionPoint(sequence.body):
structured.TileOp(sequence.bodyTarget, sizes=[4, 8], interchange=[0, 1])
structured.TileOp(sequence.bodyTarget,
sizes=[4, 8],
interchange=[0, 1])
transform.YieldOp()
# CHECK-LABEL: TEST: testTileCompact
# CHECK: transform.sequence
Expand All @@ -123,7 +127,9 @@ def testTileAttributes():
attr = DenseI64ArrayAttr.get([4, 8])
ichange = DenseI64ArrayAttr.get([0, 1])
with InsertionPoint(sequence.body):
structured.TileOp(sequence.bodyTarget, sizes=attr, interchange=ichange)
structured.TileOp(sequence.bodyTarget,
sizes=attr,
interchange=ichange)
transform.YieldOp()
# CHECK-LABEL: TEST: testTileAttributes
# CHECK: transform.sequence
Expand All @@ -134,8 +140,9 @@ def testTileAttributes():
def testTileZero():
sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get())
with InsertionPoint(sequence.body):
structured.TileOp(
sequence.bodyTarget, sizes=[4, 0, 2, 0], interchange=[0, 1, 2, 3])
structured.TileOp(sequence.bodyTarget,
sizes=[4, 0, 2, 0],
interchange=[0, 1, 2, 3])
transform.YieldOp()
# CHECK-LABEL: TEST: testTileZero
# CHECK: transform.sequence
Expand All @@ -151,14 +158,46 @@ def testTileDynamic():
with InsertionPoint(sequence.body):
m1 = transform.PDLMatchOp(pdl.OperationType.get(), sequence.bodyTarget, "first")
m2 = transform.PDLMatchOp(pdl.OperationType.get(), sequence.bodyTarget, "second")
structured.TileOp(sequence.bodyTarget, sizes=[m1, 3, m2, 0])
structured.TileOp(sequence.bodyTarget,
sizes=[m1, 3, m2, 0])
transform.YieldOp()
# CHECK-LABEL: TEST: testTileDynamic
# CHECK: %[[FIRST:.+]] = pdl_match
# CHECK: %[[SECOND:.+]] = pdl_match
# CHECK: %{{.+}}, %{{.+}}:3 = transform.structured.tile %{{.*}}[%[[FIRST]], 3, %[[SECOND]], 0]


@run
def testTileExplicitLoopTypeSingle():
sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE,
[], transform.AnyOpType.get())
with InsertionPoint(sequence.body):
structured.TileOp(transform.OperationType.get("scf.for"),
sequence.bodyTarget,
sizes=[2, 3, 4])
transform.YieldOp()
# CHECK-LABEL: TEST: testTileExplicitLoopTypeSingle
# CHECK: = transform.structured.tile %{{.*}} : (!{{.*}}) ->
# CHECK-COUNT-3: !transform.op<"scf.for">



@run
def testTileExplicitLoopTypeAll():
sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE,
[], transform.AnyOpType.get())
types = [
transform.OperationType.get(x)
for x in ["scf.for", "scf.parallel", "scf.foreach_thread"]
]
with InsertionPoint(sequence.body):
structured.TileOp(types, sequence.bodyTarget, sizes=[2, 3, 4])
transform.YieldOp()
# CHECK-LABEL: TEST: testTileExplicitLoopTypeAll
# CHECK: = transform.structured.tile
# CHECK-SAME : (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">,
# CHECK-SAME: !transform.op<"scf.parallel">, !transform.op<"scf.foreach_thread">

@run
def testVectorize():
sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get())
Expand Down