diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index 652414f6cbe54..87ee611bd63f0 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -1986,6 +1986,20 @@ struct FoldTensorCastOfOutputIntoForallOp if (tensorCastProducers.empty()) return failure(); + llvm::SmallMapVector yieldOpToIterArgsIndex; + for (auto [index, iterArg] : + llvm::enumerate(forallOp.getRegionIterArgs())) { + for (Operation *user : iterArg.getUsers()) { + if (isa(user)) { + auto [it, inserted] = yieldOpToIterArgsIndex.try_emplace(user, index); + if (!inserted) { + return rewriter.notifyMatchFailure( + forallOp, "expected exactly one iter arg per yielding op"); + } + } + } + } + // Create new loop. Location loc = forallOp.getLoc(); auto newForallOp = ForallOp::create( @@ -2012,13 +2026,11 @@ struct FoldTensorCastOfOutputIntoForallOp // After `mergeBlocks` happened, the destinations in the terminator were // mapped to the tensor.cast old-typed results of the output bbArgs. The // destination have to be updated to point to the output bbArgs directly. - auto terminator = newForallOp.getTerminator(); - for (auto [yieldingOp, outputBlockArg] : llvm::zip( - terminator.getYieldingOps(), newForallOp.getRegionIterArgs())) { - if (auto parallelCombingingOp = - dyn_cast(yieldingOp)) { - parallelCombingingOp.getUpdatedDestinations().assign(outputBlockArg); - } + ArrayRef newIterArgs = newForallOp.getRegionIterArgs(); + for (auto [yieldOp, iterArgsIndex] : yieldOpToIterArgsIndex) { + auto parallelCombiningOp = cast(yieldOp); + parallelCombiningOp.getUpdatedDestinations().assign( + newIterArgs[iterArgsIndex]); } // Cast results back to the original types. diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir index ac590fc0c47b9..3cd018d4729cf 100644 --- a/mlir/test/Dialect/SCF/canonicalize.mlir +++ b/mlir/test/Dialect/SCF/canonicalize.mlir @@ -2028,6 +2028,43 @@ func.func @do_not_fold_tensor_cast_from_dynamic_to_static_type_into_forall( // ----- +// CHECK-LABEL: func.func @fold_tensor_cast_into_forall_with_multiple_result( +// CHECK-SAME: %[[ARG0:.*]]: tensor<16xf32>, +// CHECK-SAME: %[[ARG1:.*]]: tensor<8xf32>) -> (tensor, tensor<64xf32>) { +// CHECK: %[[CONSTANT_0:.*]] = arith.constant 8 : index +// CHECK: %[[CONSTANT_1:.*]] = arith.constant 16 : index +// CHECK: %[[EMPTY_0:.*]] = tensor.empty() : tensor<32xf32> +// CHECK: %[[EMPTY_1:.*]] = tensor.empty() : tensor<64xf32> +// CHECK: %[[FORALL_0:.*]]:2 = scf.forall (%[[VAL_0:.*]]) in (4) shared_outs(%[[VAL_1:.*]] = %[[EMPTY_0]], %[[VAL_2:.*]] = %[[EMPTY_1]]) -> (tensor<32xf32>, tensor<64xf32>) { +// CHECK: %[[MULI_0:.*]] = arith.muli %[[VAL_0]], %[[CONSTANT_0]] : index +// CHECK: %[[MULI_1:.*]] = arith.muli %[[VAL_0]], %[[CONSTANT_1]] : index +// CHECK: scf.forall.in_parallel { +// CHECK: tensor.parallel_insert_slice %[[ARG0]] into %[[VAL_2]]{{\[}}%[[MULI_1]]] [16] [1] : tensor<16xf32> into tensor<64xf32> +// CHECK: tensor.parallel_insert_slice %[[ARG1]] into %[[VAL_1]]{{\[}}%[[MULI_0]]] [8] [1] : tensor<8xf32> into tensor<32xf32> +// CHECK: } +// CHECK: } +// CHECK: %[[CAST_0:.*]] = tensor.cast %[[FORALL_0]]#0 : tensor<32xf32> to tensor +// CHECK: return %[[CAST_0]], %[[FORALL_0]]#1 : tensor, tensor<64xf32> +// CHECK: } +func.func @fold_tensor_cast_into_forall_with_multiple_result(%arg0: tensor<16xf32>, %arg1: tensor<8xf32>) -> (tensor, tensor<64xf32>) { + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %0 = tensor.empty(%c32) : tensor + %1 = tensor.empty() : tensor<64xf32> + %2:2 = scf.forall (%arg2) in (4) shared_outs(%arg3 = %0, %arg4 = %1) -> (tensor, tensor<64xf32>) { + %3 = arith.muli %c8, %arg2 : index + %4 = arith.muli %c16, %arg2 : index + scf.forall.in_parallel { + tensor.parallel_insert_slice %arg0 into %arg4[%4] [16] [1] : tensor<16xf32> into tensor<64xf32> + tensor.parallel_insert_slice %arg1 into %arg3[%3] [8] [1] : tensor<8xf32> into tensor + } + } + return %2#0, %2#1 : tensor, tensor<64xf32> +} + +// ----- + #map = affine_map<()[s0, s1] -> (s0 ceildiv s1)> #map1 = affine_map<(d0)[s0] -> (d0 * s0)> #map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>