From 9b88b28f92f16b10b670579e668fa02402899a5a Mon Sep 17 00:00:00 2001 From: Adam Siemieniuk Date: Thu, 13 Jun 2024 12:11:48 +0200 Subject: [PATCH] [mlir][scf] Fix scf.forall to scf.parallel pass walker Adds proper walk results to the pass body to prevent runtime crashes on transformation failure. --- .../SCF/Transforms/ForallToParallel.cpp | 3 ++- mlir/test/Dialect/SCF/forall-to-parallel.mlir | 18 ++++++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/SCF/Transforms/ForallToParallel.cpp b/mlir/lib/Dialect/SCF/Transforms/ForallToParallel.cpp index 44e6840b03a3d..925d4a3c0a085 100644 --- a/mlir/lib/Dialect/SCF/Transforms/ForallToParallel.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/ForallToParallel.cpp @@ -71,8 +71,9 @@ struct ForallToParallelLoop final parentOp->walk([&](scf::ForallOp forallOp) { if (failed(scf::forallToParallelLoop(rewriter, forallOp))) { - return signalPassFailure(); + return WalkResult::skip(); } + return WalkResult::advance(); }); } }; diff --git a/mlir/test/Dialect/SCF/forall-to-parallel.mlir b/mlir/test/Dialect/SCF/forall-to-parallel.mlir index acde601d47259..21e816956a094 100644 --- a/mlir/test/Dialect/SCF/forall-to-parallel.mlir +++ b/mlir/test/Dialect/SCF/forall-to-parallel.mlir @@ -78,3 +78,21 @@ func.func @mapping_attr() -> () { return } + +// ----- + +// CHECK-LABEL: @forall_with_outputs +// CHECK-SAME: %[[ARG0:.+]]: tensor<32x32xf32> +func.func @forall_with_outputs(%arg0: tensor<32x32xf32>) -> tensor<8x112x32x32xf32> { + // CHECK-NOT: scf.parallel + // CHECK: %[[RES:.+]] = scf.forall{{.*}}shared_outs + // CHECK: return %[[RES]] : tensor<8x112x32x32xf32> + + %0 = tensor.empty() : tensor<8x112x32x32xf32> + %1 = scf.forall (%arg1, %arg2) in (8, 112) shared_outs(%arg3 = %0) -> (tensor<8x112x32x32xf32>) { + scf.forall.in_parallel { + tensor.parallel_insert_slice %arg0 into %arg3[%arg1, %arg2, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<32x32xf32> into tensor<8x112x32x32xf32> + } + } + return %1 : tensor<8x112x32x32xf32> +}