// File: mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
// RUN: mlir-opt --test-transform-dialect-interpreter --split-input-file %s | FileCheck %s

#map0 = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
#map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>

module {
// Verifies tile-and-fuse of a *tileable* producer: the linalg.fill that
// produces %0 is expected to be re-materialized inside scf.foreach_thread,
// computing only the slice consumed by linalg.elemwise_unary.
// CHECK-LABEL: func.func @fuse_tileable_op
// CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index
// CHECK-SAME: %[[IN:[0-9a-z]+]]: tensor<?xf32>
// CHECK-SAME: %[[OUT:[0-9a-z]+]]: tensor<?xf32>
func.func @fuse_tileable_op(%arg0: index, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>) -> tensor<?xf32> {
%cst = arith.constant 4.200000e+01 : f32
%c0 = arith.constant 0 : index
// Producer to be fused; its only use is the extract_slice inside the loop.
%0 = linalg.fill ins(%cst : f32) outs(%arg1 : tensor<?xf32>) -> tensor<?xf32>
%d0 = tensor.dim %arg1, %c0 : tensor<?xf32>
// Thread count: ceildiv(dim, chunk_size) via #map0.
%1 = affine.apply #map0()[%d0, %arg0]

// CHECK: scf.foreach_thread {{.*}} {
%2 = scf.foreach_thread (%arg3) in (%1) -> (tensor<?xf32>) {
// Per-thread chunk offset (#map1) and boundary-clamped size (#map2).
%3 = affine.apply #map1(%arg3)[%arg0]
%4 = affine.min #map2(%arg3)[%d0, %arg0]
%5 = tensor.extract_slice %arg2[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>

// After fusion, the fill is applied to a slice of its original output.
// CHECK: %[[T0:.*]] = tensor.extract_slice %[[IN]][%{{.*}}] [%{{.*}}] [{{.*}}]
// CHECK: %[[T1:.*]] = linalg.fill {{.*}} outs(%[[T0]]
%6 = tensor.extract_slice %0[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>

// The consumer reads the fused fill's result rather than a slice of %0.
// CHECK: %[[T2:.*]] = linalg.elemwise_unary ins(%[[T1]]
%7 = linalg.elemwise_unary ins(%6 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
scf.foreach_thread.perform_concurrently {
tensor.parallel_insert_slice %7 into %arg2[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
}
}
// CHECK: }
func.return %2 : tensor<?xf32>
}

// Transform script: match the producer (linalg.fill) and the containing
// scf.foreach_thread loop, then fuse the former into the latter.
transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
transform.sequence %arg0 {
^bb1(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["linalg.fill"]} in %arg1
%1 = transform.structured.match ops{["scf.foreach_thread"]} in %arg1

// linalg.fill is tileable. The op is tiled and fused.
transform.structured.fuse_into_containing_op %0 into %1
}
}
}

// -----

#map0 = affine_map<()[s0] -> (64 ceildiv s0)>
#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
#map2 = affine_map<(d0)[s0] -> (-(d0 * s0) + 64, s0)>

module {
// Verifies fusion of an *untileable* producer: linalg.init_tensor cannot be
// tiled, so it is expected to be cloned wholesale into the loop body instead
// (see the trailing comment in the transform script below).
// CHECK-LABEL: func.func @fuse_untileable_op
// CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index
// CHECK-SAME: %[[IN:[0-9a-z]+]]: tensor<64xf32>
// CHECK-SAME: %[[OUT:[0-9a-z]+]]: tensor<64xf32>
func.func @fuse_untileable_op(%arg0: index, %arg1: tensor<64xf32>, %arg2: tensor<64xf32>) -> tensor<64xf32> {
// Producer to be cloned into the loop.
%0 = linalg.init_tensor [%arg0] : tensor<?xf32>
// Thread count: ceildiv(64, chunk_size) via #map0.
%1 = affine.apply #map0()[%arg0]

// CHECK: scf.foreach_thread {{.*}} {
%2 = scf.foreach_thread (%arg3) in (%1) -> (tensor<64xf32>) {
// The clone of the producer appears inside the loop body.
// CHECK: %[[INIT_TENSOR:.*]] = linalg.init_tensor
%3 = affine.apply #map1(%arg3)[%arg0]
%4 = affine.min #map2(%arg3)[%arg0]
%5 = tensor.extract_slice %arg2[%3] [%4] [1] : tensor<64xf32> to tensor<?xf32>

// The consumer now reads the cloned init_tensor.
// CHECK: %[[T2:.*]] = linalg.elemwise_unary ins(%[[INIT_TENSOR]]
%7 = linalg.elemwise_unary ins(%0 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
scf.foreach_thread.perform_concurrently {
tensor.parallel_insert_slice %7 into %arg2[%3] [%4] [1] : tensor<?xf32> into tensor<64xf32>
}
}
// CHECK: }

func.return %2 : tensor<64xf32>
}

// Transform script: match the producer (linalg.init_tensor) and the
// containing scf.foreach_thread loop, then fuse (here: clone) the former
// into the latter.
transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
transform.sequence %arg0 {
^bb1(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["linalg.init_tensor"]} in %arg1
%1 = transform.structured.match ops{["scf.foreach_thread"]} in %arg1

// linalg.init_tensor is not tileable. The op is cloned and fused.
transform.structured.fuse_into_containing_op %0 into %1
}
}
}
// File: mlir/test/Dialect/Linalg/transform-op-match.mlir
// RUN: mlir-opt %s --test-transform-dialect-interpreter -allow-unregistered-dialect --split-input-file --verify-diagnostics

// Payload function: both constants match by op name ("arith.constant"), so
// both carry a "matched op name" remark; only the first additionally has the
// unit attribute `my_attr` and thus also receives the "matched attr name"
// remark from the attribute-filtered match.
func.func @bar() {
// expected-remark @below {{matched op name}}
// expected-remark @below {{matched attr name}}
%0 = arith.constant {my_attr} 0: i32
// expected-remark @below {{matched op name}}
%1 = arith.constant 1 : i32
return
}

// Transform script: exercises transform.structured.match with a plain op-name
// filter and with a combined op-name + attribute filter. Each handle is
// consumed after use via transform.test_consume_operand.
transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
transform.sequence %arg0 {
^bb1(%arg1: !pdl.operation):
// Matches every arith.constant in the payload by name.
%match_name = transform.structured.match ops{["arith.constant"]} in %arg1
transform.test_print_remark_at_operand %match_name, "matched op name"
transform.test_consume_operand %match_name

// Matches only the arith.constant carrying the `my_attr` attribute.
%match_attr = transform.structured.match ops{["arith.constant"]} attribute{"my_attr"} in %arg1
transform.test_print_remark_at_operand %match_attr, "matched attr name"
transform.test_consume_operand %match_attr
}
}
// File: mlir/test/Dialect/Linalg/transform-tile-and-fuse.mlir
// RUN: mlir-opt %s --test-transform-dialect-interpreter -canonicalize | FileCheck %s

// This is a simple tile-and-fuse example with a single fusion group.

module {
// Single-fusion-group test: the linalg.generic tagged `__root__` is tiled to
// an scf.foreach_thread, then both ops tagged `__producer__` (fill, matmul)
// are fused into the resulting loop, so all three appear inside it.
// CHECK: func @foo
// CHECK: scf.foreach_thread {{.*}} {
// CHECK: linalg.fill
// CHECK: linalg.matmul
// CHECK: linalg.generic
// CHECK: }
func.func @foo(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?xf32>,
%D: tensor<?x?xf32>, %sz0: index, %sz1: index)
-> tensor<?x?xf32>
{
%cst = arith.constant 0.000000e+00 : f32
// Producer #1: zero-fill of the accumulator tensor.
%5 = linalg.fill
{__producer__}
ins(%cst : f32)
outs(%D : tensor<?x?xf32>) -> tensor<?x?xf32>
// Producer #2: matmul accumulating into the filled tensor.
%6 = linalg.matmul
{__producer__}
ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%5 : tensor<?x?xf32>) -> tensor<?x?xf32>
// Root of the fusion group: elementwise epilogue over the matmul result.
%7 = linalg.generic
{__root__,
indexing_maps = [affine_map<(d0, d1) -> (d0)>,
affine_map<(d0, d1) -> (d0, d1)>,
affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]
}
ins(%C, %6 : tensor<?xf32>, tensor<?x?xf32>)
outs(%D : tensor<?x?xf32>) {
^bb0(%arg2: f32, %arg3: f32, %arg4: f32):
%16 = arith.maxf %arg3, %cst : f32
%17 = arith.cmpf ogt, %arg2, %cst : f32
%18 = arith.select %17, %cst, %16 : f32
linalg.yield %18 : f32
} -> tensor<?x?xf32>
return %7 : tensor<?x?xf32>
}

// Transform script: ops are selected by marker attribute rather than by op
// name, so the fusion group membership is explicit in the payload IR.
transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
transform.sequence %arg0 {
^bb1(%arg1: !pdl.operation):
// Find the root and all producers.
%root = transform.structured.match attribute{"__root__"} in %arg1
%producers = transform.structured.match attribute{"__producer__"} in %arg1

// Tile the root.
%foreach_thread_op, %tiled_op = transform.structured.tile_to_foreach_thread_op %root num_threads [10, 20]

// Fuse all producers.
transform.structured.fuse_into_containing_op %producers into %foreach_thread_op
}
}
}