
[MLIR][SCF] Update scf.parallel lowering to OpenMP (3/5) #89212

Merged

merged 16 commits into main from users/skatrak/spr/wsloop-wrapper-03-scf-parallel on Apr 24, 2024

Conversation

@skatrak (Contributor) commented Apr 18, 2024

This patch updates the `scf.parallel` to `omp.parallel` + `omp.wsloop` lowering pass so that it also introduces a nested `omp.loop_nest` and follows the new loop wrapper role of `omp.wsloop`.
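As a rough sketch of the structural change, based on the updated CHECK lines in scf-to-openmp.mlir below (value names are illustrative):

```mlir
// Before: input scf.parallel (bounds and step values are illustrative).
scf.parallel (%i, %j) = (%lb0, %lb1) to (%ub0, %ub1) step (%st0, %st1) {
  "test.payload"(%i, %j) : (index, index) -> ()
}

// After: omp.wsloop is now a wrapper whose only content is a nested
// omp.loop_nest carrying the loop bounds and body.
omp.parallel {
  omp.wsloop {
    omp.loop_nest (%i, %j) : index = (%lb0, %lb1) to (%ub0, %ub1) step (%st0, %st1) {
      memref.alloca_scope {
        "test.payload"(%i, %j) : (index, index) -> ()
      }
      omp.yield
    }
    omp.terminator
  }
  omp.terminator
}
```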

This PR on its own will not pass premerge tests. All patches in the stack are needed before it can compile and pass tests.

This patch updates the definition of `omp.wsloop` to enforce the restrictions
of a loop wrapper operation.

Related tests are updated, but this PR on its own will not pass premerge tests.
All patches in the stack are needed before it can compile and pass tests.

This patch updates verifiers for `omp.ordered.region`, `omp.cancel` and
`omp.cancellation_point`, which check for a parent `omp.wsloop`.

After transitioning to a loop wrapper-based approach, the expected direct
parent will become `omp.loop_nest` instead, so verifiers need to take this into
account.
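A minimal sketch of the nesting those verifiers now have to account for (bounds and body are illustrative placeholders):

```mlir
omp.wsloop {
  omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
    // Ops such as omp.ordered.region, omp.cancel and omp.cancellation_point
    // placed here now have omp.loop_nest as their direct parent; the
    // omp.wsloop wrapper sits one level further up.
    omp.yield
  }
  omp.terminator
}
```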

This PR on its own will not pass premerge tests. All patches in the stack are
needed before it can compile and pass tests.

This patch updates the `scf.parallel` to `omp.parallel` + `omp.wsloop` lowering
pass so that it also introduces a nested `omp.loop_nest` and follows the new
loop wrapper role of `omp.wsloop`.

This PR on its own will not pass premerge tests. All patches in the stack are
needed before it can compile and pass tests.
@llvmbot (Collaborator) commented Apr 18, 2024

@llvm/pr-subscribers-flang-openmp
@llvm/pr-subscribers-mlir-openmp

@llvm/pr-subscribers-mlir

Author: Sergio Afonso (skatrak)

Changes

This patch updates the `scf.parallel` to `omp.parallel` + `omp.wsloop` lowering pass so that it also introduces a nested `omp.loop_nest` and follows the new loop wrapper role of `omp.wsloop`.

This PR on its own will not pass premerge tests. All patches in the stack are needed before it can compile and pass tests.


Full diff: https://github.com/llvm/llvm-project/pull/89212.diff

3 Files Affected:

  • (modified) mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp (+40-12)
  • (modified) mlir/test/Conversion/SCFToOpenMP/reductions.mlir (+5)
  • (modified) mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir (+23-8)
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index 7f91367ad427a2..f0b8d6c5309357 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -461,18 +461,51 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
       // Replace the loop.
       {
         OpBuilder::InsertionGuard allocaGuard(rewriter);
-        auto loop = rewriter.create<omp::WsloopOp>(
+        // Create worksharing loop wrapper.
+        auto wsloopOp = rewriter.create<omp::WsloopOp>(parallelOp.getLoc());
+        if (!reductionVariables.empty()) {
+          wsloopOp.setReductionsAttr(
+              ArrayAttr::get(rewriter.getContext(), reductionDeclSymbols));
+          wsloopOp.getReductionVarsMutable().append(reductionVariables);
+        }
+        rewriter.create<omp::TerminatorOp>(loc); // omp.parallel terminator.
+
+        // The wrapper's entry block arguments will define the reduction
+        // variables.
+        llvm::SmallVector<mlir::Type> reductionTypes;
+        reductionTypes.reserve(reductionVariables.size());
+        llvm::transform(reductionVariables, std::back_inserter(reductionTypes),
+                        [](mlir::Value v) { return v.getType(); });
+        rewriter.createBlock(
+            &wsloopOp.getRegion(), {}, reductionTypes,
+            llvm::SmallVector<mlir::Location>(reductionVariables.size(),
+                                              parallelOp.getLoc()));
+
+        rewriter.setInsertionPoint(
+            rewriter.create<omp::TerminatorOp>(parallelOp.getLoc()));
+
+        // Create loop nest and populate region with contents of scf.parallel.
+        auto loopOp = rewriter.create<omp::LoopNestOp>(
             parallelOp.getLoc(), parallelOp.getLowerBound(),
             parallelOp.getUpperBound(), parallelOp.getStep());
-        rewriter.create<omp::TerminatorOp>(loc);
 
-        rewriter.inlineRegionBefore(parallelOp.getRegion(), loop.getRegion(),
-                                    loop.getRegion().begin());
+        rewriter.inlineRegionBefore(parallelOp.getRegion(), loopOp.getRegion(),
+                                    loopOp.getRegion().begin());
+
+        // Remove reduction-related block arguments from omp.loop_nest and
+        // redirect uses to the corresponding omp.wsloop block argument.
+        mlir::Block &loopOpEntryBlock = loopOp.getRegion().front();
+        unsigned numLoops = parallelOp.getNumLoops();
+        rewriter.replaceAllUsesWith(
+            loopOpEntryBlock.getArguments().drop_front(numLoops),
+            wsloopOp.getRegion().getArguments());
+        loopOpEntryBlock.eraseArguments(
+            numLoops, loopOpEntryBlock.getNumArguments() - numLoops);
 
-        Block *ops = rewriter.splitBlock(&*loop.getRegion().begin(),
-                                         loop.getRegion().begin()->begin());
+        Block *ops = rewriter.splitBlock(&*loopOp.getRegion().begin(),
+                                         loopOp.getRegion().begin()->begin());
 
-        rewriter.setInsertionPointToStart(&*loop.getRegion().begin());
+        rewriter.setInsertionPointToStart(&*loopOp.getRegion().begin());
 
         auto scope = rewriter.create<memref::AllocaScopeOp>(parallelOp.getLoc(),
                                                             TypeRange());
@@ -481,11 +514,6 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
         rewriter.mergeBlocks(ops, scopeBlock);
         rewriter.setInsertionPointToEnd(&*scope.getBodyRegion().begin());
         rewriter.create<memref::AllocaScopeReturnOp>(loc, ValueRange());
-        if (!reductionVariables.empty()) {
-          loop.setReductionsAttr(
-              ArrayAttr::get(rewriter.getContext(), reductionDeclSymbols));
-          loop.getReductionVarsMutable().append(reductionVariables);
-        }
       }
     }
 
diff --git a/mlir/test/Conversion/SCFToOpenMP/reductions.mlir b/mlir/test/Conversion/SCFToOpenMP/reductions.mlir
index 3b6c145d62f1a8..fc6d56559c2618 100644
--- a/mlir/test/Conversion/SCFToOpenMP/reductions.mlir
+++ b/mlir/test/Conversion/SCFToOpenMP/reductions.mlir
@@ -28,6 +28,7 @@ func.func @reduction1(%arg0 : index, %arg1 : index, %arg2 : index,
   // CHECK: omp.parallel
   // CHECK: omp.wsloop
   // CHECK-SAME: reduction(@[[$REDF]] %[[BUF]] -> %[[PVT_BUF:[a-z0-9]+]]
+  // CHECK: omp.loop_nest
   // CHECK: memref.alloca_scope
   scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
                             step (%arg4, %step) init (%zero) -> (f32) {
@@ -43,6 +44,7 @@ func.func @reduction1(%arg0 : index, %arg1 : index, %arg2 : index,
     }
     // CHECK: omp.yield
   }
+  // CHECK:   omp.terminator
   // CHECK: omp.terminator
   // CHECK: llvm.load %[[BUF]]
   return
@@ -107,6 +109,7 @@ func.func @reduction_muli(%arg0 : index, %arg1 : index, %arg2 : index,
   %one = arith.constant 1 : i32
   // CHECK: %[[RED_VAR:.*]] = llvm.alloca %{{.*}} x i32 : (i64) -> !llvm.ptr
   // CHECK: omp.wsloop reduction(@[[$REDI]] %[[RED_VAR]] -> %[[RED_PVT_VAR:.*]] : !llvm.ptr)
+  // CHECK: omp.loop_nest
   scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
                             step (%arg4, %step) init (%one) -> (i32) {
     // CHECK: %[[C2:.*]] = arith.constant 2 : i32
@@ -208,6 +211,7 @@ func.func @reduction4(%arg0 : index, %arg1 : index, %arg2 : index,
   // CHECK: omp.wsloop
   // CHECK-SAME: reduction(@[[$REDF1]] %[[BUF1]] -> %[[PVT_BUF1:[a-z0-9]+]]
   // CHECK-SAME:           @[[$REDF2]] %[[BUF2]] -> %[[PVT_BUF2:[a-z0-9]+]]
+  // CHECK: omp.loop_nest
   // CHECK: memref.alloca_scope
   %res:2 = scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
                         step (%arg4, %step) init (%zero, %ione) -> (f32, i64) {
@@ -236,6 +240,7 @@ func.func @reduction4(%arg0 : index, %arg1 : index, %arg2 : index,
     }
     // CHECK: omp.yield
   }
+  // CHECK:   omp.terminator
   // CHECK: omp.terminator
   // CHECK: %[[RES1:.*]] = llvm.load %[[BUF1]] : !llvm.ptr -> f32
   // CHECK: %[[RES2:.*]] = llvm.load %[[BUF2]] : !llvm.ptr -> i64
diff --git a/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir b/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir
index acd2690c56e2e6..b2f19d294cb5fe 100644
--- a/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir
+++ b/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir
@@ -2,10 +2,11 @@
 
 // CHECK-LABEL: @parallel
 func.func @parallel(%arg0: index, %arg1: index, %arg2: index,
-          %arg3: index, %arg4: index, %arg5: index) {
+                    %arg3: index, %arg4: index, %arg5: index) {
   // CHECK: %[[FOUR:.+]] = llvm.mlir.constant(4 : i32) : i32
   // CHECK: omp.parallel num_threads(%[[FOUR]] : i32) {
-  // CHECK: omp.wsloop for (%[[LVAR1:.*]], %[[LVAR2:.*]]) : index = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %arg5) {
+  // CHECK: omp.wsloop {
+  // CHECK: omp.loop_nest (%[[LVAR1:.*]], %[[LVAR2:.*]]) : index = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %arg5) {
   // CHECK: memref.alloca_scope
   scf.parallel (%i, %j) = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %arg5) {
     // CHECK: "test.payload"(%[[LVAR1]], %[[LVAR2]]) : (index, index) -> ()
@@ -13,6 +14,8 @@ func.func @parallel(%arg0: index, %arg1: index, %arg2: index,
     // CHECK:   omp.yield
     // CHECK: }
   }
+  // CHECK:     omp.terminator
+  // CHECK:   }
   // CHECK:   omp.terminator
   // CHECK: }
   return
@@ -23,20 +26,26 @@ func.func @nested_loops(%arg0: index, %arg1: index, %arg2: index,
                    %arg3: index, %arg4: index, %arg5: index) {
   // CHECK: %[[FOUR:.+]] = llvm.mlir.constant(4 : i32) : i32
   // CHECK: omp.parallel num_threads(%[[FOUR]] : i32) {
-  // CHECK: omp.wsloop for (%[[LVAR_OUT1:.*]]) : index = (%arg0) to (%arg2) step (%arg4) {
-    // CHECK: memref.alloca_scope
+  // CHECK: omp.wsloop {
+  // CHECK: omp.loop_nest (%[[LVAR_OUT1:.*]]) : index = (%arg0) to (%arg2) step (%arg4) {
+  // CHECK: memref.alloca_scope
   scf.parallel (%i) = (%arg0) to (%arg2) step (%arg4) {
     // CHECK: omp.parallel
-    // CHECK: omp.wsloop for (%[[LVAR_IN1:.*]]) : index = (%arg1) to (%arg3) step (%arg5) {
+    // CHECK: omp.wsloop {
+    // CHECK: omp.loop_nest (%[[LVAR_IN1:.*]]) : index = (%arg1) to (%arg3) step (%arg5) {
     // CHECK: memref.alloca_scope
     scf.parallel (%j) = (%arg1) to (%arg3) step (%arg5) {
       // CHECK: "test.payload"(%[[LVAR_OUT1]], %[[LVAR_IN1]]) : (index, index) -> ()
       "test.payload"(%i, %j) : (index, index) -> ()
       // CHECK: }
     }
-    // CHECK:   omp.yield
+    // CHECK:     omp.yield
+    // CHECK:   }
+    // CHECK:   omp.terminator
     // CHECK: }
   }
+  // CHECK:     omp.terminator
+  // CHECK:   }
   // CHECK:   omp.terminator
   // CHECK: }
   return
@@ -47,7 +56,8 @@ func.func @adjacent_loops(%arg0: index, %arg1: index, %arg2: index,
                      %arg3: index, %arg4: index, %arg5: index) {
   // CHECK: %[[FOUR:.+]] = llvm.mlir.constant(4 : i32) : i32
   // CHECK: omp.parallel num_threads(%[[FOUR]] : i32) {
-  // CHECK: omp.wsloop for (%[[LVAR_AL1:.*]]) : index = (%arg0) to (%arg2) step (%arg4) {
+  // CHECK: omp.wsloop {
+  // CHECK: omp.loop_nest (%[[LVAR_AL1:.*]]) : index = (%arg0) to (%arg2) step (%arg4) {
   // CHECK: memref.alloca_scope
   scf.parallel (%i) = (%arg0) to (%arg2) step (%arg4) {
     // CHECK: "test.payload1"(%[[LVAR_AL1]]) : (index) -> ()
@@ -55,12 +65,15 @@ func.func @adjacent_loops(%arg0: index, %arg1: index, %arg2: index,
     // CHECK:   omp.yield
     // CHECK: }
   }
+  // CHECK:     omp.terminator
+  // CHECK:   }
   // CHECK:   omp.terminator
   // CHECK: }
 
   // CHECK: %[[FOUR:.+]] = llvm.mlir.constant(4 : i32) : i32
   // CHECK: omp.parallel num_threads(%[[FOUR]] : i32) {
-  // CHECK: omp.wsloop for (%[[LVAR_AL2:.*]]) : index = (%arg1) to (%arg3) step (%arg5) {
+  // CHECK: omp.wsloop {
+  // CHECK: omp.loop_nest (%[[LVAR_AL2:.*]]) : index = (%arg1) to (%arg3) step (%arg5) {
   // CHECK: memref.alloca_scope
   scf.parallel (%j) = (%arg1) to (%arg3) step (%arg5) {
     // CHECK: "test.payload2"(%[[LVAR_AL2]]) : (index) -> ()
@@ -68,6 +81,8 @@ func.func @adjacent_loops(%arg0: index, %arg1: index, %arg2: index,
     // CHECK:   omp.yield
     // CHECK: }
   }
+  // CHECK:     omp.terminator
+  // CHECK:   }
   // CHECK:   omp.terminator
   // CHECK: }
   return
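For the reduction case covered by reductions.mlir, the resulting structure looks roughly like the following; the reduction declaration symbol, variable names and types are illustrative, and the loop body is elided:

```mlir
omp.parallel {
  // The reduction clause now lives on the omp.wsloop wrapper, whose entry
  // block argument (%pvt) provides the privatized reduction variable.
  omp.wsloop reduction(@add_f32 %buf -> %pvt : !llvm.ptr) {
    omp.loop_nest (%i, %j) : index = (%lb0, %lb1) to (%ub0, %ub1) step (%st0, %st1) {
      memref.alloca_scope {
        // Loop body updating the reduction through %pvt (elided).
      }
      omp.yield
    }
    omp.terminator
  }
  omp.terminator
}
```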

@tblah (Contributor) left a comment

LGTM, thanks

Base automatically changed from users/skatrak/spr/wsloop-wrapper-02-dependent-ops to main on April 24, 2024 at 13:28
@skatrak merged commit 8843d54 into main on Apr 24, 2024
2 of 3 checks passed
@skatrak deleted the users/skatrak/spr/wsloop-wrapper-03-scf-parallel branch on April 24, 2024 at 13:29