Skip to content

Commit

Permalink
[mlir][TilingInterface] Move TilingInterface tests to use transform d…
Browse files Browse the repository at this point in the history
…ialect ops. (#77204)

In the process a couple of test transform dialect ops are added just
for testing. These operations are not intended to use as full flushed
out of transformation ops, but are rather operations added for testing.

A separate operation is added to `LinalgTransformOps.td` to convert a
`TilingInterface` operation to loops using the
`generateScalarImplementation` method implemented by the
operation. Eventually this and other operations related to tiling
using the `TilingInterface` need to move to a better place (i.e. out
of `Linalg` dialect)
  • Loading branch information
MaheshRavishankar committed Jan 12, 2024
1 parent 3699811 commit aa2a96a
Show file tree
Hide file tree
Showing 15 changed files with 856 additions and 758 deletions.
Expand Up @@ -293,7 +293,10 @@ def FuseOp : Op<Transform_Dialect, "structured.fuse",
let results = (outs TransformHandleTypeInterface:$transformed,
Variadic<TransformHandleTypeInterface>:$loops);

let hasCustomAssemblyFormat = 1;
let assemblyFormat = [{
$target ($tile_sizes^)? (`interchange` $tile_interchange^)?
attr-dict `:` functional-type(operands, results)
}];
let hasVerifier = 1;
}

Expand Down Expand Up @@ -1269,6 +1272,33 @@ def ScalarizeOp : Op<Transform_Dialect, "structured.scalarize",
}];
}

def ConvertToLoopsOp : Op<Transform_Dialect, "structured.convert_to_loops",
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
TransformOpInterface, TransformEachOpTrait,
ReportTrackingListenerFailuresOpTrait]> {
let description = [{
For operations that implement the `TilingInterface`, and implement
the `generateScalarImplementation` method, lowers the operation to
loops. This operation does not return any handles.
}];

let arguments = (ins TransformHandleTypeInterface:$target);
let results = (outs);

let assemblyFormat = [{
$target attr-dict `:` type($target)
}];

let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
::mlir::transform::TransformRewriter &rewriter,
::mlir::TilingInterface target,
::mlir::transform::ApplyToEachResultList &results,
::mlir::transform::TransformState &state);
}];
}


//===----------------------------------------------------------------------===//
// DecomposeInterfaceOp
//===----------------------------------------------------------------------===//
Expand Down
55 changes: 22 additions & 33 deletions mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
Expand Up @@ -492,38 +492,6 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter,
: DiagnosedSilenceableFailure::success();
}

ParseResult transform::FuseOp::parse(OpAsmParser &parser,
OperationState &result) {
OpAsmParser::UnresolvedOperand targetOperand;
if (parser.parseOperand(targetOperand) ||
parser.parseOptionalAttrDict(result.attributes))
return failure();

FunctionType trailingType;
SMLoc typeLoc;
if (parser.getCurrentLocation(&typeLoc) ||
parser.parseColonType(trailingType)) {
return failure();
}
if (trailingType.getNumInputs() != 1)
return parser.emitError(typeLoc) << "expected one input type";

result.addTypes(trailingType.getResults());
if (parser.resolveOperand(targetOperand, trailingType.getInput(0),
result.operands))
return failure();
return success();
}

void transform::FuseOp::print(OpAsmPrinter &p) {
p << ' ';
p << getTarget();
p.printOptionalAttrDict((*this)->getAttrs());
p << " : ";
p.printFunctionalType(TypeRange(getOperand().getType()),
getResults().getTypes());
}

LogicalResult transform::FuseOp::verify() {
SmallVector<int64_t> permutation =
extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
Expand Down Expand Up @@ -2111,6 +2079,22 @@ transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter,
return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// ConvertToLoopsOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure transform::ConvertToLoopsOp::applyToOne(
transform::TransformRewriter &rewriter, TilingInterface target,
transform::ApplyToEachResultList &results,
transform::TransformState &state) {
rewriter.setInsertionPoint(target);
FailureOr<SmallVector<scf::ForOp>> loops =
scf::lowerToLoopsUsingSCFForOp(rewriter, target);
if (failed(loops))
return emitDefaultDefiniteFailure(target);
return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// RewriteInDestinationPassingStyleOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -2620,7 +2604,12 @@ transform::TileUsingForOp::apply(transform::TransformRewriter &rewriter,
}

scf::SCFTilingOptions tilingOptions;
if (!tileSizes.empty()) {
if (tileSizes.empty()) {
tilingOptions.setTileSizeComputationFunction(
[](OpBuilder &, Operation *) -> SmallVector<OpFoldResult> {
return {};
});
} else {
tilingOptions.setTileSizeComputationFunction([&, index = i](OpBuilder &b,
Operation *) {
SmallVector<OpFoldResult> sizes;
Expand Down
5 changes: 1 addition & 4 deletions mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
Expand Up @@ -283,10 +283,7 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
// 1. Get the range of the loops that are represented by the operation.
SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
size_t numLoops = iterationDomain.size();
if (numLoops == 0) {
return rewriter.notifyMatchFailure(
op, "unable to tile op with no iteration domain");
}

// 2. Materialize the tile sizes. Enforce the convention that "tiling by zero"
// skips tiling a particular dimension. This convention is significantly
// simpler to handle instead of adjusting affine maps to account for missing
Expand Down
@@ -1,18 +1,27 @@
// RUN: mlir-opt -test-tiling-interface=lower-to-scalar-using-scf-for -split-input-file %s | FileCheck %s
// RUN: mlir-opt -transform-interpreter -split-input-file -canonicalize -cse %s | FileCheck %s

func.func @gemm(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32>,
%arg2 : memref<?x?xf32>) {
linalg.matmul ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>)
outs(%arg2 : memref<?x?xf32>)
return
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
%matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1
: (!transform.any_op) -> !transform.any_op
transform.structured.convert_to_loops %matmul : !transform.any_op
transform.yield
}
}
// CHECK-LABEL: func @gemm
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: memref<?x?xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: memref<?x?xf32>
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[M:.+]] = memref.dim %[[ARG0]], %[[C0]]
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[K:.+]] = memref.dim %[[ARG0]], %[[C1]]
// CHECK-DAG: %[[N:.+]] = memref.dim %[[ARG1]], %[[C1]]
// CHECK: scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[M]] step %[[C1]]
Expand Down Expand Up @@ -51,6 +60,15 @@ func.func @indexed_generic(%arg0 : memref<200x300xi32>, %arg1 : memref<300xi16>,
}
return
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
%generic = transform.structured.match ops{["linalg.generic"]} in %arg1
: (!transform.any_op) -> !transform.any_op
transform.structured.convert_to_loops %generic : !transform.any_op
transform.yield
}
}
// CHECK-LABEL: func @indexed_generic
// CHECK-SAME: %[[ARG0:.+]]: memref<200x300xi32>
// CHECK-SAME: %[[ARG1:.+]]: memref<300xi16>
Expand Down Expand Up @@ -87,8 +105,18 @@ func.func @conv_strides_and_dilation(%arg0 : memref<?x?x?x?xf32>, %arg1 : memref
outs(%arg2 : memref<?x?x?x?xf32>)
return
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1 + d4 * 3)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d2 * 2 + d5 * 4)>

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
%conv = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg1
: (!transform.any_op) -> !transform.any_op
transform.structured.convert_to_loops %conv : !transform.any_op
transform.yield
}
}

// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 3)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0 * 2 + d1 * 4)>
// CHECK: func @conv_strides_and_dilation(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
Expand All @@ -111,8 +139,8 @@ func.func @conv_strides_and_dilation(%arg0 : memref<?x?x?x?xf32>, %arg1 : memref
// CHECK: scf.for %[[IV4:[a-zA-Z0-9]+]] = %[[C0]] to %[[H]] step %[[C1]]
// CHECK: scf.for %[[IV5:[a-zA-Z0-9]+]] = %[[C0]] to %[[W]] step %[[C1]]
// CHECK: scf.for %[[IV6:[a-zA-Z0-9]+]] = %[[C0]] to %[[C]] step %[[C1]]
// CHECK-DAG: %[[I:.+]] = affine.apply #[[MAP0]](%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]], %[[IV4]], %[[IV5]], %[[IV6]])
// CHECK-DAG: %[[J:.+]] = affine.apply #[[MAP1]](%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]], %[[IV4]], %[[IV5]], %[[IV6]])
// CHECK-DAG: %[[I:.+]] = affine.apply #[[MAP0]](%[[IV1]], %[[IV4]])
// CHECK-DAG: %[[J:.+]] = affine.apply #[[MAP1]](%[[IV2]], %[[IV5]])
// CHECK-DAG: %[[T9:.+]] = memref.load %[[ARG0]][%[[IV0]], %[[I]], %[[J]], %[[IV6]]]
// CHECK-DAG: %[[T10:.+]] = memref.load %[[ARG1]][%[[IV4]], %[[IV5]], %[[IV6]], %[[IV3]]]
// CHECK-DAG: %[[T11:.+]] = memref.load %[[ARG2]][%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]]
Expand All @@ -131,8 +159,18 @@ func.func @pool_strides_and_dilation(%arg0 : memref<?x?x?x?xf32>, %arg1 : memref
outs(%arg2 : memref<?x?x?x?xf32>)
return
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1 + d4 * 3)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2 * 2 + d5 * 4)>

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
%pool = transform.structured.match ops{["linalg.pooling_nhwc_max"]} in %arg1
: (!transform.any_op) -> !transform.any_op
transform.structured.convert_to_loops %pool : !transform.any_op
transform.yield
}
}

// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 3)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0 * 2 + d1 * 4)>
// CHECK: func @pool_strides_and_dilation
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: memref<?x?xf32>
Expand All @@ -153,8 +191,8 @@ func.func @pool_strides_and_dilation(%arg0 : memref<?x?x?x?xf32>, %arg1 : memref
// CHECK: scf.for %[[IV3:[a-zA-Z0-9]+]] = %[[C0]] to %[[C]] step %[[C1]]
// CHECK: scf.for %[[IV4:[a-zA-Z0-9]+]] = %[[C0]] to %[[H]] step %[[C1]]
// CHECK: scf.for %[[IV5:[a-zA-Z0-9]+]] = %[[C0]] to %[[W]] step %[[C1]]
// CHECK-DAG: %[[I:.+]] = affine.apply #[[MAP0]](%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]], %[[IV4]], %[[IV5]])
// CHECK-DAG: %[[J:.+]] = affine.apply #[[MAP1]](%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]], %[[IV4]], %[[IV5]])
// CHECK-DAG: %[[I:.+]] = affine.apply #[[MAP0]](%[[IV1]], %[[IV4]])
// CHECK-DAG: %[[J:.+]] = affine.apply #[[MAP1]](%[[IV2]], %[[IV5]])
// CHECK-DAG: %[[T8:.+]] = memref.load %[[ARG0]][%[[IV0]], %[[I]], %[[J]], %[[IV3]]]
// CHECK-DAG: %[[T9:.+]] = memref.load %[[ARG2]][%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]]
// CHECK: %[[T10:.+]] = arith.maximumf %[[T9]], %[[T8]]
Expand All @@ -172,6 +210,15 @@ func.func @map(%lhs: memref<64xf32>,
}
return
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
%map = transform.structured.match ops{["linalg.map"]} in %arg1
: (!transform.any_op) -> !transform.any_op
transform.structured.convert_to_loops %map : !transform.any_op
transform.yield
}
}
// CHECK-LABEL: func.func @map(
// CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: memref<64xf32>,
// CHECK-SAME: %[[RHS:[a-zA-Z0-9]+]]: memref<64xf32>,
Expand All @@ -195,6 +242,15 @@ func.func @transpose(%arg0: memref<16x32x64xf32>,
outs(%arg1 : memref<32x64x16xf32>) permutation = [1, 2, 0]
return
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
%transpose = transform.structured.match ops{["linalg.transpose"]} in %arg1
: (!transform.any_op) -> !transform.any_op
transform.structured.convert_to_loops %transpose : !transform.any_op
transform.yield
}
}
// CHECK-LABEL: func.func @transpose(
// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]: memref<16x32x64xf32>,
// CHECK-SAME: %[[OUT:[a-zA-Z0-9]+]]: memref<32x64x16xf32>)
Expand Down Expand Up @@ -223,6 +279,15 @@ func.func @reduce(%arg0: memref<16x32x64xf32>,
}
return
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
%reduce = transform.structured.match ops{["linalg.reduce"]} in %arg1
: (!transform.any_op) -> !transform.any_op
transform.structured.convert_to_loops %reduce : !transform.any_op
transform.yield
}
}
// CHECK-LABEL: func.func @reduce(
// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]: memref<16x32x64xf32>,
// CHECK-SAME: %[[OUT:[a-zA-Z0-9]+]]: memref<16x64xf32>
Expand Down Expand Up @@ -251,6 +316,15 @@ func.func @broadcast(%input: memref<8x32xf32>,
dimensions = [1]
func.return
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
%broadcast = transform.structured.match ops{["linalg.broadcast"]} in %arg1
: (!transform.any_op) -> !transform.any_op
transform.structured.convert_to_loops %broadcast : !transform.any_op
transform.yield
}
}
// CHECK-LABEL: func.func @broadcast(
// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]: memref<8x32xf32>,
// CHECK-SAME: %[[OUT:[a-zA-Z0-9]+]]: memref<8x16x32xf32>
Expand Down

0 comments on commit aa2a96a

Please sign in to comment.