-
Notifications
You must be signed in to change notification settings - Fork 10.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MLIR][SCF] Update scf.parallel lowering to OpenMP (3/5) #89212
[MLIR][SCF] Update scf.parallel lowering to OpenMP (3/5) #89212
Conversation
This patch updates the definition of `omp.wsloop` to enforce the restrictions of a loop wrapper operation. Related tests are updated, but this PR on its own will not pass premerge tests. All patches in the stack are needed before it can be compiled and pass its tests.
This patch updates verifiers for `omp.ordered.region`, `omp.cancel` and `omp.cancellation_point`, which check for a parent `omp.wsloop`. After transitioning to a loop wrapper-based approach, the expected direct parent will become `omp.loop_nest` instead, so verifiers need to take this into account. This PR on its own will not pass premerge tests. All patches in the stack are needed before it can be compiled and pass its tests.
This patch makes changes to the `scf.parallel` to `omp.parallel` + `omp.wsloop` lowering pass in order to introduce a nested `omp.loop_nest` as well, and to follow the new loop wrapper role for `omp.wsloop`. This PR on its own will not pass premerge tests. All patches in the stack are needed before it can be compiled and pass its tests.
@llvm/pr-subscribers-flang-openmp @llvm/pr-subscribers-mlir Author: Sergio Afonso (skatrak). Changes: This patch makes changes to the `scf.parallel` to `omp.parallel` + `omp.wsloop` lowering pass in order to introduce a nested `omp.loop_nest` and to follow the new loop wrapper role for `omp.wsloop`. This PR on its own will not pass premerge tests. All patches in the stack are needed before it can be compiled and pass its tests. Full diff: https://github.com/llvm/llvm-project/pull/89212.diff — 3 files affected.
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index 7f91367ad427a2..f0b8d6c5309357 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -461,18 +461,51 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
// Replace the loop.
{
OpBuilder::InsertionGuard allocaGuard(rewriter);
- auto loop = rewriter.create<omp::WsloopOp>(
+ // Create worksharing loop wrapper.
+ auto wsloopOp = rewriter.create<omp::WsloopOp>(parallelOp.getLoc());
+ if (!reductionVariables.empty()) {
+ wsloopOp.setReductionsAttr(
+ ArrayAttr::get(rewriter.getContext(), reductionDeclSymbols));
+ wsloopOp.getReductionVarsMutable().append(reductionVariables);
+ }
+ rewriter.create<omp::TerminatorOp>(loc); // omp.parallel terminator.
+
+ // The wrapper's entry block arguments will define the reduction
+ // variables.
+ llvm::SmallVector<mlir::Type> reductionTypes;
+ reductionTypes.reserve(reductionVariables.size());
+ llvm::transform(reductionVariables, std::back_inserter(reductionTypes),
+ [](mlir::Value v) { return v.getType(); });
+ rewriter.createBlock(
+ &wsloopOp.getRegion(), {}, reductionTypes,
+ llvm::SmallVector<mlir::Location>(reductionVariables.size(),
+ parallelOp.getLoc()));
+
+ rewriter.setInsertionPoint(
+ rewriter.create<omp::TerminatorOp>(parallelOp.getLoc()));
+
+ // Create loop nest and populate region with contents of scf.parallel.
+ auto loopOp = rewriter.create<omp::LoopNestOp>(
parallelOp.getLoc(), parallelOp.getLowerBound(),
parallelOp.getUpperBound(), parallelOp.getStep());
- rewriter.create<omp::TerminatorOp>(loc);
- rewriter.inlineRegionBefore(parallelOp.getRegion(), loop.getRegion(),
- loop.getRegion().begin());
+ rewriter.inlineRegionBefore(parallelOp.getRegion(), loopOp.getRegion(),
+ loopOp.getRegion().begin());
+
+ // Remove reduction-related block arguments from omp.loop_nest and
+ // redirect uses to the corresponding omp.wsloop block argument.
+ mlir::Block &loopOpEntryBlock = loopOp.getRegion().front();
+ unsigned numLoops = parallelOp.getNumLoops();
+ rewriter.replaceAllUsesWith(
+ loopOpEntryBlock.getArguments().drop_front(numLoops),
+ wsloopOp.getRegion().getArguments());
+ loopOpEntryBlock.eraseArguments(
+ numLoops, loopOpEntryBlock.getNumArguments() - numLoops);
- Block *ops = rewriter.splitBlock(&*loop.getRegion().begin(),
- loop.getRegion().begin()->begin());
+ Block *ops = rewriter.splitBlock(&*loopOp.getRegion().begin(),
+ loopOp.getRegion().begin()->begin());
- rewriter.setInsertionPointToStart(&*loop.getRegion().begin());
+ rewriter.setInsertionPointToStart(&*loopOp.getRegion().begin());
auto scope = rewriter.create<memref::AllocaScopeOp>(parallelOp.getLoc(),
TypeRange());
@@ -481,11 +514,6 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
rewriter.mergeBlocks(ops, scopeBlock);
rewriter.setInsertionPointToEnd(&*scope.getBodyRegion().begin());
rewriter.create<memref::AllocaScopeReturnOp>(loc, ValueRange());
- if (!reductionVariables.empty()) {
- loop.setReductionsAttr(
- ArrayAttr::get(rewriter.getContext(), reductionDeclSymbols));
- loop.getReductionVarsMutable().append(reductionVariables);
- }
}
}
diff --git a/mlir/test/Conversion/SCFToOpenMP/reductions.mlir b/mlir/test/Conversion/SCFToOpenMP/reductions.mlir
index 3b6c145d62f1a8..fc6d56559c2618 100644
--- a/mlir/test/Conversion/SCFToOpenMP/reductions.mlir
+++ b/mlir/test/Conversion/SCFToOpenMP/reductions.mlir
@@ -28,6 +28,7 @@ func.func @reduction1(%arg0 : index, %arg1 : index, %arg2 : index,
// CHECK: omp.parallel
// CHECK: omp.wsloop
// CHECK-SAME: reduction(@[[$REDF]] %[[BUF]] -> %[[PVT_BUF:[a-z0-9]+]]
+ // CHECK: omp.loop_nest
// CHECK: memref.alloca_scope
scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
step (%arg4, %step) init (%zero) -> (f32) {
@@ -43,6 +44,7 @@ func.func @reduction1(%arg0 : index, %arg1 : index, %arg2 : index,
}
// CHECK: omp.yield
}
+ // CHECK: omp.terminator
// CHECK: omp.terminator
// CHECK: llvm.load %[[BUF]]
return
@@ -107,6 +109,7 @@ func.func @reduction_muli(%arg0 : index, %arg1 : index, %arg2 : index,
%one = arith.constant 1 : i32
// CHECK: %[[RED_VAR:.*]] = llvm.alloca %{{.*}} x i32 : (i64) -> !llvm.ptr
// CHECK: omp.wsloop reduction(@[[$REDI]] %[[RED_VAR]] -> %[[RED_PVT_VAR:.*]] : !llvm.ptr)
+ // CHECK: omp.loop_nest
scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
step (%arg4, %step) init (%one) -> (i32) {
// CHECK: %[[C2:.*]] = arith.constant 2 : i32
@@ -208,6 +211,7 @@ func.func @reduction4(%arg0 : index, %arg1 : index, %arg2 : index,
// CHECK: omp.wsloop
// CHECK-SAME: reduction(@[[$REDF1]] %[[BUF1]] -> %[[PVT_BUF1:[a-z0-9]+]]
// CHECK-SAME: @[[$REDF2]] %[[BUF2]] -> %[[PVT_BUF2:[a-z0-9]+]]
+ // CHECK: omp.loop_nest
// CHECK: memref.alloca_scope
%res:2 = scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
step (%arg4, %step) init (%zero, %ione) -> (f32, i64) {
@@ -236,6 +240,7 @@ func.func @reduction4(%arg0 : index, %arg1 : index, %arg2 : index,
}
// CHECK: omp.yield
}
+ // CHECK: omp.terminator
// CHECK: omp.terminator
// CHECK: %[[RES1:.*]] = llvm.load %[[BUF1]] : !llvm.ptr -> f32
// CHECK: %[[RES2:.*]] = llvm.load %[[BUF2]] : !llvm.ptr -> i64
diff --git a/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir b/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir
index acd2690c56e2e6..b2f19d294cb5fe 100644
--- a/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir
+++ b/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir
@@ -2,10 +2,11 @@
// CHECK-LABEL: @parallel
func.func @parallel(%arg0: index, %arg1: index, %arg2: index,
- %arg3: index, %arg4: index, %arg5: index) {
+ %arg3: index, %arg4: index, %arg5: index) {
// CHECK: %[[FOUR:.+]] = llvm.mlir.constant(4 : i32) : i32
// CHECK: omp.parallel num_threads(%[[FOUR]] : i32) {
- // CHECK: omp.wsloop for (%[[LVAR1:.*]], %[[LVAR2:.*]]) : index = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %arg5) {
+ // CHECK: omp.wsloop {
+ // CHECK: omp.loop_nest (%[[LVAR1:.*]], %[[LVAR2:.*]]) : index = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %arg5) {
// CHECK: memref.alloca_scope
scf.parallel (%i, %j) = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %arg5) {
// CHECK: "test.payload"(%[[LVAR1]], %[[LVAR2]]) : (index, index) -> ()
@@ -13,6 +14,8 @@ func.func @parallel(%arg0: index, %arg1: index, %arg2: index,
// CHECK: omp.yield
// CHECK: }
}
+ // CHECK: omp.terminator
+ // CHECK: }
// CHECK: omp.terminator
// CHECK: }
return
@@ -23,20 +26,26 @@ func.func @nested_loops(%arg0: index, %arg1: index, %arg2: index,
%arg3: index, %arg4: index, %arg5: index) {
// CHECK: %[[FOUR:.+]] = llvm.mlir.constant(4 : i32) : i32
// CHECK: omp.parallel num_threads(%[[FOUR]] : i32) {
- // CHECK: omp.wsloop for (%[[LVAR_OUT1:.*]]) : index = (%arg0) to (%arg2) step (%arg4) {
- // CHECK: memref.alloca_scope
+ // CHECK: omp.wsloop {
+ // CHECK: omp.loop_nest (%[[LVAR_OUT1:.*]]) : index = (%arg0) to (%arg2) step (%arg4) {
+ // CHECK: memref.alloca_scope
scf.parallel (%i) = (%arg0) to (%arg2) step (%arg4) {
// CHECK: omp.parallel
- // CHECK: omp.wsloop for (%[[LVAR_IN1:.*]]) : index = (%arg1) to (%arg3) step (%arg5) {
+ // CHECK: omp.wsloop {
+ // CHECK: omp.loop_nest (%[[LVAR_IN1:.*]]) : index = (%arg1) to (%arg3) step (%arg5) {
// CHECK: memref.alloca_scope
scf.parallel (%j) = (%arg1) to (%arg3) step (%arg5) {
// CHECK: "test.payload"(%[[LVAR_OUT1]], %[[LVAR_IN1]]) : (index, index) -> ()
"test.payload"(%i, %j) : (index, index) -> ()
// CHECK: }
}
- // CHECK: omp.yield
+ // CHECK: omp.yield
+ // CHECK: }
+ // CHECK: omp.terminator
// CHECK: }
}
+ // CHECK: omp.terminator
+ // CHECK: }
// CHECK: omp.terminator
// CHECK: }
return
@@ -47,7 +56,8 @@ func.func @adjacent_loops(%arg0: index, %arg1: index, %arg2: index,
%arg3: index, %arg4: index, %arg5: index) {
// CHECK: %[[FOUR:.+]] = llvm.mlir.constant(4 : i32) : i32
// CHECK: omp.parallel num_threads(%[[FOUR]] : i32) {
- // CHECK: omp.wsloop for (%[[LVAR_AL1:.*]]) : index = (%arg0) to (%arg2) step (%arg4) {
+ // CHECK: omp.wsloop {
+ // CHECK: omp.loop_nest (%[[LVAR_AL1:.*]]) : index = (%arg0) to (%arg2) step (%arg4) {
// CHECK: memref.alloca_scope
scf.parallel (%i) = (%arg0) to (%arg2) step (%arg4) {
// CHECK: "test.payload1"(%[[LVAR_AL1]]) : (index) -> ()
@@ -55,12 +65,15 @@ func.func @adjacent_loops(%arg0: index, %arg1: index, %arg2: index,
// CHECK: omp.yield
// CHECK: }
}
+ // CHECK: omp.terminator
+ // CHECK: }
// CHECK: omp.terminator
// CHECK: }
// CHECK: %[[FOUR:.+]] = llvm.mlir.constant(4 : i32) : i32
// CHECK: omp.parallel num_threads(%[[FOUR]] : i32) {
- // CHECK: omp.wsloop for (%[[LVAR_AL2:.*]]) : index = (%arg1) to (%arg3) step (%arg5) {
+ // CHECK: omp.wsloop {
+ // CHECK: omp.loop_nest (%[[LVAR_AL2:.*]]) : index = (%arg1) to (%arg3) step (%arg5) {
// CHECK: memref.alloca_scope
scf.parallel (%j) = (%arg1) to (%arg3) step (%arg5) {
// CHECK: "test.payload2"(%[[LVAR_AL2]]) : (index) -> ()
@@ -68,6 +81,8 @@ func.func @adjacent_loops(%arg0: index, %arg1: index, %arg2: index,
// CHECK: omp.yield
// CHECK: }
}
+ // CHECK: omp.terminator
+ // CHECK: }
// CHECK: omp.terminator
// CHECK: }
return
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks
… users/skatrak/spr/wsloop-wrapper-03-scf-parallel
…atrak/spr/wsloop-wrapper-02-dependent-ops
… users/skatrak/spr/wsloop-wrapper-03-scf-parallel
…atrak/spr/wsloop-wrapper-02-dependent-ops
… users/skatrak/spr/wsloop-wrapper-03-scf-parallel
…atrak/spr/wsloop-wrapper-02-dependent-ops
… users/skatrak/spr/wsloop-wrapper-03-scf-parallel
This patch makes changes to the `scf.parallel` to `omp.parallel` + `omp.wsloop` lowering pass in order to introduce a nested `omp.loop_nest` as well, and to follow the new loop wrapper role for `omp.wsloop`. This PR on its own will not pass premerge tests. All patches in the stack are needed before it can be compiled and pass its tests.