Skip to content

Commit

Permalink
Fix bug in fusion refactor and add test
Browse files Browse the repository at this point in the history
  • Loading branch information
srcarroll committed Jul 3, 2024
1 parent 6825c15 commit 4e4a96e
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 0 deletions.
1 change: 1 addition & 0 deletions mlir/lib/Interfaces/LoopLikeInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ LoopLikeOpInterface mlir::createFused(LoopLikeOpInterface target,
if (failed(maybeFusedLoop))
llvm_unreachable("failed to replace loop");
LoopLikeOpInterface fusedLoop = *maybeFusedLoop;
rewriter.moveOpBefore(fusedLoop, source);

// Map control operands.
IRMapping mapping;
Expand Down
56 changes: 56 additions & 0 deletions mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,62 @@ module attributes {transform.with_named_sequence} {
}
}


// -----

// CHECK: #[[$MAP:.+]] = affine_map<(d0) -> (d0 * 32)
#map = affine_map<(d0) -> (d0 * 32)>
#map1 = affine_map<(d0, d1) -> (d0, d1)>
module {
// CHECK: func.func @loop_sibling_fusion(%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}, %[[ARG3:.*]]: {{.*}}
func.func @loop_sibling_fusion(%arg0: tensor<128xf32>, %arg1: tensor<128x128xf16>, %arg2: tensor<128x64xf32>, %arg3: tensor<128x128xf32>) -> (tensor<128xf32>, tensor<128x128xf16>) {
// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<128x128xf16>
// CHECK-NEXT: %[[RESULTS:.*]]:2 = scf.forall (%[[I:.*]]) in (4) shared_outs(%[[S1:.*]] = %[[ARG0]], %[[S2:.*]] = %[[ARG1]]) -> (tensor<128xf32>, tensor<128x128xf16>) {
// CHECK-NEXT: %[[IDX:.*]] = affine.apply #[[$MAP]](%[[I]])
// CHECK-NEXT: %[[SLICE0:.*]] = tensor.extract_slice %[[ARG3]][%[[IDX]], 0] [32, 1] [1, 1] : tensor<128x128xf32> to tensor<32xf32>
// CHECK-NEXT: %[[SLICE1:.*]] = tensor.extract_slice %[[ARG3]][%[[IDX]], 0] [32, 128] [1, 1] : tensor<128x128xf32> to tensor<32x128xf32>
// CHECK-NEXT: %[[SLICE2:.*]] = tensor.extract_slice %[[EMPTY]][%[[IDX]], 0] [32, 128] [1, 1] : tensor<128x128xf16> to tensor<32x128xf16>
// CHECK-NEXT: %[[GENERIC:.*]] = linalg.generic {{.*}} ins(%[[SLICE1]] : {{.*}}) outs(%[[SLICE2]] : {{.*}})
// CHECK: scf.forall.in_parallel {
// CHECK-NEXT: tensor.parallel_insert_slice %[[SLICE0]] into %[[S1]][%[[IDX]]] [32] [1] : tensor<32xf32> into tensor<128xf32>
// CHECK-NEXT: tensor.parallel_insert_slice %[[GENERIC]] into %[[S2]][%[[IDX]], 0] [32, 128] [1, 1] : tensor<32x128xf16> into tensor<128x128xf16>
// CHECK-NEXT: }
// CHECK-NEXT: } {mapping = [#gpu.warp<linear_dim_0>]}
// CHECK-NEXT: return %[[RESULTS]]#0, %[[RESULTS]]#1
%0 = scf.forall (%arg4) in (4) shared_outs(%arg5 = %arg0) -> (tensor<128xf32>) {
%3 = affine.apply #map(%arg4)
%extracted_slice = tensor.extract_slice %arg3[%3, 0] [32, 1] [1, 1] : tensor<128x128xf32> to tensor<32xf32>
scf.forall.in_parallel {
tensor.parallel_insert_slice %extracted_slice into %arg5[%3] [32] [1] : tensor<32xf32> into tensor<128xf32>
}
} {mapping = [#gpu.warp<linear_dim_0>]}
%1 = tensor.empty() : tensor<128x128xf16>
%2 = scf.forall (%arg4) in (4) shared_outs(%arg5 = %arg1) -> (tensor<128x128xf16>) {
%3 = affine.apply #map(%arg4)
%extracted_slice = tensor.extract_slice %arg3[%3, 0] [32, 128] [1, 1] : tensor<128x128xf32> to tensor<32x128xf32>
%extracted_slice_0 = tensor.extract_slice %1[%3, 0] [32, 128] [1, 1] : tensor<128x128xf16> to tensor<32x128xf16>
%4 = linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%extracted_slice : tensor<32x128xf32>) outs(%extracted_slice_0 : tensor<32x128xf16>) {
^bb0(%in: f32, %out: f16):
%5 = arith.truncf %in : f32 to f16
linalg.yield %5 : f16
} -> tensor<32x128xf16>
scf.forall.in_parallel {
tensor.parallel_insert_slice %4 into %arg5[%3, 0] [32, 128] [1, 1] : tensor<32x128xf16> into tensor<128x128xf16>
}
} {mapping = [#gpu.warp<linear_dim_0>]}
return %0, %2 : tensor<128xf32>, tensor<128x128xf16>
}
}

module attributes { transform.with_named_sequence } {
transform.named_sequence @__transform_main(%root: !transform.any_op) {
%loops = transform.structured.match ops{["scf.forall"]} in %root : (!transform.any_op) -> !transform.any_op
%loop1, %loop2 = transform.split_handle %loops : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
%loop3 = transform.loop.fuse_sibling %loop1 into %loop2 : (!transform.any_op, !transform.any_op) -> !transform.any_op
transform.yield
}
}

// -----

func.func @source_for_uses_result_of_target_for_err(%A: tensor<128xf32>, %B: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) {
Expand Down

0 comments on commit 4e4a96e

Please sign in to comment.