Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MLIR][OpenMP] Update omp.wsloop translation to LLVM IR (4/5) #89214

Merged
merged 21 commits into from
Apr 24, 2024

Conversation

skatrak
Copy link
Contributor

@skatrak skatrak commented Apr 18, 2024

This patch introduces minimal changes to the MLIR to LLVM IR translation of omp.wsloop to support the loop wrapper approach.

There is omp.loop_nest related translation code that should be extracted and shared among all loop operations (e.g. omp.simd). This would possibly also help in the addition of support for compound constructs later on. This first approach is only intended to keep things running after the transition to loop wrappers and not to add support for other use cases enabled by that transition.

This PR on its own will not pass premerge tests. All patches in the stack are needed before it can be compiled and passes tests.

This patch updates the definition of `omp.wsloop` to enforce the restrictions
of a loop wrapper operation.

Related tests are updated but this PR on its own will not pass premerge tests.
All patches in the stack are needed before it can be compiled and passes tests.
This patch updates verifiers for `omp.ordered.region`, `omp.cancel` and
`omp.cancellation_point`, which check for a parent `omp.wsloop`.

After transitioning to a loop wrapper-based approach, the expected direct
parent will become `omp.loop_nest` instead, so verifiers need to take this into
account.

This PR on its own will not pass premerge tests. All patches in the stack are
needed before it can be compiled and passes tests.
This patch makes changes to the `scf.parallel` to `omp.parallel` + `omp.wsloop`
lowering pass in order to introduce a nested `omp.loop_nest` as well, and to
follow the new loop wrapper role for `omp.wsloop`.

This PR on its own will not pass premerge tests. All patches in the stack are
needed before it can be compiled and passes tests.
This patch introduces minimal changes to the MLIR to LLVM IR translation of
omp.wsloop to support the loop wrapper approach.

There is `omp.loop_nest` related translation code that should be extracted and
shared among all loop operations (e.g. `omp.simd`). This would possibly also
help in the addition of support for compound constructs later on. This first
approach is only intended to keep things running after the transition to loop
wrappers and not to add support for other use cases enabled by that transition.

This PR on its own will not pass premerge tests. All patches in the stack are
needed before it can be compiled and passes tests.
@llvmbot
Copy link
Collaborator

llvmbot commented Apr 18, 2024

@llvm/pr-subscribers-mlir-openmp
@llvm/pr-subscribers-flang-openmp

@llvm/pr-subscribers-mlir-llvm

Author: Sergio Afonso (skatrak)

Changes

This patch introduces minimal changes to the MLIR to LLVM IR translation of omp.wsloop to support the loop wrapper approach.

There is omp.loop_nest related translation code that should be extracted and shared among all loop operations (e.g. omp.simd). This would possibly also help in the addition of support for compound constructs later on. This first approach is only intended to keep things running after the transition to loop wrappers and not to add support for other use cases enabled by that transition.

This PR on its own will not pass premerge tests. All patches in the stack are needed before it can be compiled and passes tests.


Patch is 69.30 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/89214.diff

9 Files Affected:

  • (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+35-33)
  • (modified) mlir/test/Target/LLVMIR/omptarget-parallel-wsloop.mlir (+7-4)
  • (modified) mlir/test/Target/LLVMIR/omptarget-wsloop-collapsed.mlir (+10-7)
  • (modified) mlir/test/Target/LLVMIR/omptarget-wsloop.mlir (+12-6)
  • (modified) mlir/test/Target/LLVMIR/openmp-data-target-device.mlir (+17-14)
  • (modified) mlir/test/Target/LLVMIR/openmp-llvm.mlir (+404-337)
  • (modified) mlir/test/Target/LLVMIR/openmp-nested.mlir (+18-12)
  • (modified) mlir/test/Target/LLVMIR/openmp-reduction.mlir (+63-50)
  • (modified) mlir/test/Target/LLVMIR/openmp-wsloop-reduction-cleanup.mlir (+6-3)
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index e89ff9209b034a..22d6462b881dc0 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -916,35 +916,37 @@ static LogicalResult inlineReductionCleanup(
 static LogicalResult
 convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
                  LLVM::ModuleTranslation &moduleTranslation) {
-  auto loop = cast<omp::WsloopOp>(opInst);
-  const bool isByRef = loop.getByref();
+  auto wsloopOp = cast<omp::WsloopOp>(opInst);
+  auto loopOp = cast<omp::LoopNestOp>(wsloopOp.getWrappedLoop());
+  const bool isByRef = wsloopOp.getByref();
+
   // TODO: this should be in the op verifier instead.
-  if (loop.getLowerBound().empty())
+  if (loopOp.getLowerBound().empty())
     return failure();
 
   // Static is the default.
   auto schedule =
-      loop.getScheduleVal().value_or(omp::ClauseScheduleKind::Static);
+      wsloopOp.getScheduleVal().value_or(omp::ClauseScheduleKind::Static);
 
   // Find the loop configuration.
-  llvm::Value *step = moduleTranslation.lookupValue(loop.getStep()[0]);
+  llvm::Value *step = moduleTranslation.lookupValue(loopOp.getStep()[0]);
   llvm::Type *ivType = step->getType();
   llvm::Value *chunk = nullptr;
-  if (loop.getScheduleChunkVar()) {
+  if (wsloopOp.getScheduleChunkVar()) {
     llvm::Value *chunkVar =
-        moduleTranslation.lookupValue(loop.getScheduleChunkVar());
+        moduleTranslation.lookupValue(wsloopOp.getScheduleChunkVar());
     chunk = builder.CreateSExtOrTrunc(chunkVar, ivType);
   }
 
   SmallVector<omp::DeclareReductionOp> reductionDecls;
-  collectReductionDecls(loop, reductionDecls);
+  collectReductionDecls(wsloopOp, reductionDecls);
   llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
       findAllocaInsertPoint(builder, moduleTranslation);
 
   SmallVector<llvm::Value *> privateReductionVariables;
   DenseMap<Value, llvm::Value *> reductionVariableMap;
   if (!isByRef) {
-    allocByValReductionVars(loop, builder, moduleTranslation, allocaIP,
+    allocByValReductionVars(wsloopOp, builder, moduleTranslation, allocaIP,
                             reductionDecls, privateReductionVariables,
                             reductionVariableMap);
   }
@@ -952,13 +954,12 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
   // Before the loop, store the initial values of reductions into reduction
   // variables. Although this could be done after allocas, we don't want to mess
   // up with the alloca insertion point.
-  MutableArrayRef<BlockArgument> reductionArgs =
-      loop.getRegion().getArguments().take_back(loop.getNumReductionVars());
-  for (unsigned i = 0; i < loop.getNumReductionVars(); ++i) {
+  ArrayRef<BlockArgument> reductionArgs = wsloopOp.getRegion().getArguments();
+  for (unsigned i = 0; i < wsloopOp.getNumReductionVars(); ++i) {
     SmallVector<llvm::Value *> phis;
 
     // map block argument to initializer region
-    mapInitializationArg(loop, moduleTranslation, reductionDecls, i);
+    mapInitializationArg(wsloopOp, moduleTranslation, reductionDecls, i);
 
     if (failed(inlineConvertOmpRegions(reductionDecls[i].getInitializerRegion(),
                                        "omp.reduction.neutral", builder,
@@ -977,7 +978,7 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
 
       privateReductionVariables.push_back(var);
       moduleTranslation.mapValue(reductionArgs[i], phis[0]);
-      reductionVariableMap.try_emplace(loop.getReductionVars()[i], phis[0]);
+      reductionVariableMap.try_emplace(wsloopOp.getReductionVars()[i], phis[0]);
     } else {
       // for by-ref case the store is inside of the reduction region
       builder.CreateStore(phis[0], privateReductionVariables[i]);
@@ -1008,33 +1009,34 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
   auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip, llvm::Value *iv) {
     // Make sure further conversions know about the induction variable.
     moduleTranslation.mapValue(
-        loop.getRegion().front().getArgument(loopInfos.size()), iv);
+        loopOp.getRegion().front().getArgument(loopInfos.size()), iv);
 
     // Capture the body insertion point for use in nested loops. BodyIP of the
     // CanonicalLoopInfo always points to the beginning of the entry block of
     // the body.
     bodyInsertPoints.push_back(ip);
 
-    if (loopInfos.size() != loop.getNumLoops() - 1)
+    if (loopInfos.size() != loopOp.getNumLoops() - 1)
       return;
 
     // Convert the body of the loop.
     builder.restoreIP(ip);
-    convertOmpOpRegions(loop.getRegion(), "omp.wsloop.region", builder,
+    convertOmpOpRegions(loopOp.getRegion(), "omp.wsloop.region", builder,
                         moduleTranslation, bodyGenStatus);
   };
 
   // Delegate actual loop construction to the OpenMP IRBuilder.
-  // TODO: this currently assumes Wsloop is semantically similar to SCF loop,
-  // i.e. it has a positive step, uses signed integer semantics. Reconsider
-  // this code when Wsloop clearly supports more cases.
+  // TODO: this currently assumes omp.loop_nest is semantically similar to SCF
+  // loop, i.e. it has a positive step, uses signed integer semantics.
+  // Reconsider this code when the nested loop operation clearly supports more
+  // cases.
   llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
-  for (unsigned i = 0, e = loop.getNumLoops(); i < e; ++i) {
+  for (unsigned i = 0, e = loopOp.getNumLoops(); i < e; ++i) {
     llvm::Value *lowerBound =
-        moduleTranslation.lookupValue(loop.getLowerBound()[i]);
+        moduleTranslation.lookupValue(loopOp.getLowerBound()[i]);
     llvm::Value *upperBound =
-        moduleTranslation.lookupValue(loop.getUpperBound()[i]);
-    llvm::Value *step = moduleTranslation.lookupValue(loop.getStep()[i]);
+        moduleTranslation.lookupValue(loopOp.getUpperBound()[i]);
+    llvm::Value *step = moduleTranslation.lookupValue(loopOp.getStep()[i]);
 
     // Make sure loop trip count are emitted in the preheader of the outermost
     // loop at the latest so that they are all available for the new collapsed
@@ -1047,7 +1049,7 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
     }
     loopInfos.push_back(ompBuilder->createCanonicalLoop(
         loc, bodyGen, lowerBound, upperBound, step,
-        /*IsSigned=*/true, loop.getInclusive(), computeIP));
+        /*IsSigned=*/true, loopOp.getInclusive(), computeIP));
 
     if (failed(bodyGenStatus))
       return failure();
@@ -1062,13 +1064,13 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
   allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
 
   // TODO: Handle doacross loops when the ordered clause has a parameter.
-  bool isOrdered = loop.getOrderedVal().has_value();
+  bool isOrdered = wsloopOp.getOrderedVal().has_value();
   std::optional<omp::ScheduleModifier> scheduleModifier =
-      loop.getScheduleModifier();
-  bool isSimd = loop.getSimdModifier();
+      wsloopOp.getScheduleModifier();
+  bool isSimd = wsloopOp.getSimdModifier();
 
   ompBuilder->applyWorkshareLoop(
-      ompLoc.DL, loopInfo, allocaIP, !loop.getNowait(),
+      ompLoc.DL, loopInfo, allocaIP, !wsloopOp.getNowait(),
       convertToScheduleKind(schedule), chunk, isSimd,
       scheduleModifier == omp::ScheduleModifier::monotonic,
       scheduleModifier == omp::ScheduleModifier::nonmonotonic, isOrdered);
@@ -1080,7 +1082,7 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
   builder.restoreIP(afterIP);
 
   // Process the reductions if required.
-  if (loop.getNumReductionVars() == 0)
+  if (wsloopOp.getNumReductionVars() == 0)
     return success();
 
   // Create the reduction generators. We need to own them here because
@@ -1088,7 +1090,7 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
   SmallVector<OwningReductionGen> owningReductionGens;
   SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
   SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos;
-  collectReductionInfo(loop, builder, moduleTranslation, reductionDecls,
+  collectReductionInfo(wsloopOp, builder, moduleTranslation, reductionDecls,
                        owningReductionGens, owningAtomicReductionGens,
                        privateReductionVariables, reductionInfos);
 
@@ -1099,9 +1101,9 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
   builder.SetInsertPoint(tempTerminator);
   llvm::OpenMPIRBuilder::InsertPointTy contInsertPoint =
       ompBuilder->createReductions(builder.saveIP(), allocaIP, reductionInfos,
-                                   loop.getNowait(), isByRef);
+                                   wsloopOp.getNowait(), isByRef);
   if (!contInsertPoint.getBlock())
-    return loop->emitOpError() << "failed to convert reductions";
+    return wsloopOp->emitOpError() << "failed to convert reductions";
   auto nextInsertionPoint =
       ompBuilder->createBarrier(contInsertPoint, llvm::omp::OMPD_for);
   tempTerminator->eraseFromParent();
diff --git a/mlir/test/Target/LLVMIR/omptarget-parallel-wsloop.mlir b/mlir/test/Target/LLVMIR/omptarget-parallel-wsloop.mlir
index b0fe642238f14f..360b3b0c0e60c1 100644
--- a/mlir/test/Target/LLVMIR/omptarget-parallel-wsloop.mlir
+++ b/mlir/test/Target/LLVMIR/omptarget-parallel-wsloop.mlir
@@ -12,10 +12,13 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
       %loop_ub = llvm.mlir.constant(9 : i32) : i32
       %loop_lb = llvm.mlir.constant(0 : i32) : i32
       %loop_step = llvm.mlir.constant(1 : i32) : i32
-      omp.wsloop for  (%loop_cnt) : i32 = (%loop_lb) to (%loop_ub) inclusive step (%loop_step) {
-        %gep = llvm.getelementptr %arg0[0, %loop_cnt] : (!llvm.ptr, i32) -> !llvm.ptr, !llvm.array<10 x i32>
-        llvm.store %loop_cnt, %gep : i32, !llvm.ptr
-        omp.yield
+      omp.wsloop {
+        omp.loop_nest (%loop_cnt) : i32 = (%loop_lb) to (%loop_ub) inclusive step (%loop_step) {
+          %gep = llvm.getelementptr %arg0[0, %loop_cnt] : (!llvm.ptr, i32) -> !llvm.ptr, !llvm.array<10 x i32>
+          llvm.store %loop_cnt, %gep : i32, !llvm.ptr
+          omp.yield
+        }
+        omp.terminator
       }
      omp.terminator
     }
diff --git a/mlir/test/Target/LLVMIR/omptarget-wsloop-collapsed.mlir b/mlir/test/Target/LLVMIR/omptarget-wsloop-collapsed.mlir
index 0d77423abcb4f1..13d34b7e58f77e 100644
--- a/mlir/test/Target/LLVMIR/omptarget-wsloop-collapsed.mlir
+++ b/mlir/test/Target/LLVMIR/omptarget-wsloop-collapsed.mlir
@@ -8,13 +8,16 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
     %loop_ub = llvm.mlir.constant(99 : i32) : i32
     %loop_lb = llvm.mlir.constant(0 : i32) : i32
     %loop_step = llvm.mlir.constant(1 : index) : i32
-    omp.wsloop for  (%arg1, %arg2) : i32 = (%loop_lb, %loop_lb) to (%loop_ub, %loop_ub) inclusive step (%loop_step, %loop_step) {
-      %1 = llvm.add %arg1, %arg2  : i32
-      %2 = llvm.mul %arg2, %loop_ub overflow<nsw>  : i32
-      %3 = llvm.add %arg1, %2 :i32
-      %4 = llvm.getelementptr %arg0[%3] : (!llvm.ptr, i32) -> !llvm.ptr, i32
-      llvm.store %1, %4 : i32, !llvm.ptr
-      omp.yield
+    omp.wsloop {
+      omp.loop_nest (%arg1, %arg2) : i32 = (%loop_lb, %loop_lb) to (%loop_ub, %loop_ub) inclusive step (%loop_step, %loop_step) {
+        %1 = llvm.add %arg1, %arg2  : i32
+        %2 = llvm.mul %arg2, %loop_ub overflow<nsw>  : i32
+        %3 = llvm.add %arg1, %2 :i32
+        %4 = llvm.getelementptr %arg0[%3] : (!llvm.ptr, i32) -> !llvm.ptr, i32
+        llvm.store %1, %4 : i32, !llvm.ptr
+        omp.yield
+      }
+      omp.terminator
     }
     llvm.return
   }
diff --git a/mlir/test/Target/LLVMIR/omptarget-wsloop.mlir b/mlir/test/Target/LLVMIR/omptarget-wsloop.mlir
index 0f3f503dfa5377..ee851eaf71ac0b 100644
--- a/mlir/test/Target/LLVMIR/omptarget-wsloop.mlir
+++ b/mlir/test/Target/LLVMIR/omptarget-wsloop.mlir
@@ -8,10 +8,13 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
       %loop_ub = llvm.mlir.constant(9 : i32) : i32
       %loop_lb = llvm.mlir.constant(0 : i32) : i32
       %loop_step = llvm.mlir.constant(1 : i32) : i32
-      omp.wsloop for  (%loop_cnt) : i32 = (%loop_lb) to (%loop_ub) inclusive step (%loop_step) {
-        %gep = llvm.getelementptr %arg0[0, %loop_cnt] : (!llvm.ptr, i32) -> !llvm.ptr, !llvm.array<10 x i32>
-        llvm.store %loop_cnt, %gep : i32, !llvm.ptr
-        omp.yield
+      omp.wsloop {
+        omp.loop_nest (%loop_cnt) : i32 = (%loop_lb) to (%loop_ub) inclusive step (%loop_step) {
+          %gep = llvm.getelementptr %arg0[0, %loop_cnt] : (!llvm.ptr, i32) -> !llvm.ptr, !llvm.array<10 x i32>
+          llvm.store %loop_cnt, %gep : i32, !llvm.ptr
+          omp.yield
+        }
+        omp.terminator
       }
     llvm.return
   }
@@ -20,8 +23,11 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
       %loop_ub = llvm.mlir.constant(9 : i32) : i32
       %loop_lb = llvm.mlir.constant(0 : i32) : i32
       %loop_step = llvm.mlir.constant(1 : i32) : i32
-      omp.wsloop for  (%loop_cnt) : i32 = (%loop_lb) to (%loop_ub) inclusive step (%loop_step) {
-        omp.yield
+      omp.wsloop {
+        omp.loop_nest (%loop_cnt) : i32 = (%loop_lb) to (%loop_ub) inclusive step (%loop_step) {
+          omp.yield
+        }
+        omp.terminator
       }
     llvm.return
   }
diff --git a/mlir/test/Target/LLVMIR/openmp-data-target-device.mlir b/mlir/test/Target/LLVMIR/openmp-data-target-device.mlir
index d41429a6de066f..4ea9df369af66c 100644
--- a/mlir/test/Target/LLVMIR/openmp-data-target-device.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-data-target-device.mlir
@@ -31,20 +31,23 @@ module attributes { } {
           %18 = llvm.mlir.constant(1 : i64) : i64
           %19 = llvm.alloca %18 x i32 {pinned} : (i64) -> !llvm.ptr<5>
           %20 = llvm.addrspacecast %19 : !llvm.ptr<5> to !llvm.ptr
-          omp.wsloop  for  (%arg2) : i32 = (%16) to (%15) inclusive step (%16) {
-            llvm.store %arg2, %20 : i32, !llvm.ptr
-            %21 = llvm.load %20 : !llvm.ptr -> i32
-            %22 = llvm.sext %21 : i32 to i64
-            %23 = llvm.mlir.constant(1 : i64) : i64
-            %24 = llvm.mlir.constant(0 : i64) : i64
-            %25 = llvm.sub %22, %23 overflow<nsw>  : i64
-            %26 = llvm.mul %25, %23 overflow<nsw>  : i64
-            %27 = llvm.mul %26, %23 overflow<nsw>  : i64
-            %28 = llvm.add %27, %24 overflow<nsw>  : i64
-            %29 = llvm.mul %23, %17 overflow<nsw>  : i64
-            %30 = llvm.getelementptr %arg0[%28] : (!llvm.ptr, i64) -> !llvm.ptr, i32
-            llvm.store %21, %30 : i32, !llvm.ptr
-            omp.yield
+          omp.wsloop {
+            omp.loop_nest (%arg2) : i32 = (%16) to (%15) inclusive step (%16) {
+              llvm.store %arg2, %20 : i32, !llvm.ptr
+              %21 = llvm.load %20 : !llvm.ptr -> i32
+              %22 = llvm.sext %21 : i32 to i64
+              %23 = llvm.mlir.constant(1 : i64) : i64
+              %24 = llvm.mlir.constant(0 : i64) : i64
+              %25 = llvm.sub %22, %23 overflow<nsw>  : i64
+              %26 = llvm.mul %25, %23 overflow<nsw>  : i64
+              %27 = llvm.mul %26, %23 overflow<nsw>  : i64
+              %28 = llvm.add %27, %24 overflow<nsw>  : i64
+              %29 = llvm.mul %23, %17 overflow<nsw>  : i64
+              %30 = llvm.getelementptr %arg0[%28] : (!llvm.ptr, i64) -> !llvm.ptr, i32
+              llvm.store %21, %30 : i32, !llvm.ptr
+              omp.yield
+            }
+            omp.terminator
           }
           omp.terminator
         }
diff --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
index d1390022c1dc44..ad40ca26bec9f8 100644
--- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
@@ -320,18 +320,20 @@ llvm.func @wsloop_simple(%arg0: !llvm.ptr) {
   %1 = llvm.mlir.constant(10 : index) : i64
   %2 = llvm.mlir.constant(1 : index) : i64
   omp.parallel {
-    "omp.wsloop"(%1, %0, %2) ({
-    ^bb0(%arg1: i64):
-      // The form of the emitted IR is controlled by OpenMPIRBuilder and
-      // tested there. Just check that the right functions are called.
-      // CHECK: call i32 @__kmpc_global_thread_num
-      // CHECK: call void @__kmpc_for_static_init_{{.*}}(ptr @[[$loc_struct]],
-      %3 = llvm.mlir.constant(2.000000e+00 : f32) : f32
-      %4 = llvm.getelementptr %arg0[%arg1] : (!llvm.ptr, i64) -> !llvm.ptr, f32
-      llvm.store %3, %4 : f32, !llvm.ptr
-      omp.yield
+    "omp.wsloop"() ({
+      omp.loop_nest (%arg1) : i64 = (%1) to (%0) step (%2) {
+        // The form of the emitted IR is controlled by OpenMPIRBuilder and
+        // tested there. Just check that the right functions are called.
+        // CHECK: call i32 @__kmpc_global_thread_num
+        // CHECK: call void @__kmpc_for_static_init_{{.*}}(ptr @[[$loc_struct]],
+        %3 = llvm.mlir.constant(2.000000e+00 : f32) : f32
+        %4 = llvm.getelementptr %arg0[%arg1] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+        llvm.store %3, %4 : f32, !llvm.ptr
+        omp.yield
+      }
+      omp.terminator
       // CHECK: call void @__kmpc_for_static_fini(ptr @[[$loc_struct]],
-    }) {operandSegmentSizes = array<i32: 1, 1, 1, 0, 0, 0, 0>} : (i64, i64, i64) -> ()
+    }) : () -> ()
     omp.terminator
   }
   llvm.return
@@ -345,13 +347,15 @@ llvm.func @wsloop_inclusive_1(%arg0: !llvm.ptr) {
   %1 = llvm.mlir.constant(10 : index) : i64
   %2 = llvm.mlir.constant(1 : index) : i64
   // CHECK: store i64 31, ptr %{{.*}}upperbound
-  "omp.wsloop"(%1, %0, %2) ({
-  ^bb0(%arg1: i64):
-    %3 = llvm.mlir.constant(2.000000e+00 : f32) : f32
-    %4 = llvm.getelementptr %arg0[%arg1] : (!llvm.ptr, i64) -> !llvm.ptr, f32
-    llvm.store %3, %4 : f32, !llvm.ptr
-    omp.yield
-  }) {operandSegmentSizes = array<i32: 1, 1, 1, 0, 0, 0, 0>} : (i64, i64, i64) -> ()
+  "omp.wsloop"() ({
+    omp.loop_nest (%arg1) : i64 = (%1) to (%0) step (%2) {
+      %3 = llvm.mlir.constant(2.000000e+00 : f32) : f32
+      %4 = llvm.getelementptr %arg0[%arg1] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+      llvm.store %3, %4 : f32, !llvm.ptr
+      omp.yield
+    }
+    omp.terminator
+  }) : () -> ()
   llvm.return
 }
 
@@ -363,13 +367,15 @@ llvm.func @wsloop_inclusive_2(%arg0: !llvm.ptr) {
   %1 = llvm.mlir.constant(10 : index) : i64
   %2 = llvm.mlir.constant(1 : index) : i64
   // CHECK: store i64 32, ptr %{{.*}}upperbound
-  "omp.wsloop"(%1, %0, %2) ({
-  ^bb0(%arg1: i64):
-    %3 = llvm.mlir.constant(2.000000e+00 : f32) : f32
-    %4 = llvm.getelementptr %arg0[%arg1] : (!llvm.ptr, i64) -> !llvm.ptr, f32
-    llvm.store %3, %4 : f32, !llvm.ptr
-    omp.yield
-  }) {inclusive, operandSegmentSizes = array<i32: 1, 1, 1, 0, 0, 0, 0>} : (i64, i64, i64) -> ()
+  "omp.wsloop"() ({
+    omp.loop_nest (%arg1) : i64 = (%1) to (%0) inclusive step (%2) {
+      %3 = llvm.mlir.constant(2.000000e+00 : f32) : f32
+      %4 = llvm.getelementptr %arg0[%arg1] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+      llvm.store %3, %4 : f32, !llvm.ptr
+      omp.yield
+    }
+    omp.terminator
+  }) : () -> ()
   llvm.return
 }
 
@@ -379,14 +385,16 @@ llvm.func @body(i32)
 
 // CHECK-LABEL: @test_omp_wsloop_static_defchunk
 llvm.func @test_omp_wsloop_static_defchunk(%lb : i32, %ub : i32, %step : i32) -> () {
- omp.wsloop schedule(static)
- for (%iv) : i32 = (%lb) to (%ub) step (%step) {
-   // CHECK: call void @__kmpc_for_static_init_4u(ptr @{{.*}}, i32 %{{.*}}, i32 34, ptr %{{.*}}, ptr %{{.*}}, ptr %{{.*}}, ptr %{{.*}}, i32 1, i32 0)
-   // CHECK: call void @__kmpc_for_static_fini
-   llvm.call @body(%iv) : (i32) -> ()
-   omp.yield
- }
- llvm.return
+  omp.wsloop schedule(static) {
+    omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
+      // CHECK: call void @__kmpc_for_static_init_4u(ptr @{{.*}}, i32 %{{.*}}, i32 34, ptr %{{.*}}, ptr %{{.*}}, ptr %{{.*}}, ptr %{{.*}}, i32 1, i32 0)
+      // CHECK: call void @__kmpc_for_static_fini
+      llvm.call @body(%iv) : (i32) -> ()
+      omp.yield
+    }
+    omp.terminator
+  }
+  llvm.return
 }
 
 // -----
@@ -395,15 +403,17 @@ llvm.func @body(i32)
 
 // CHECK-...
[truncated]

@llvmbot
Copy link
Collaborator

llvmbot commented Apr 18, 2024

@llvm/pr-subscribers-mlir

Author: Sergio Afonso (skatrak)

Changes

This patch introduces minimal changes to the MLIR to LLVM IR translation of omp.wsloop to support the loop wrapper approach.

There is omp.loop_nest related translation code that should be extracted and shared among all loop operations (e.g. omp.simd). This would possibly also help in the addition of support for compound constructs later on. This first approach is only intended to keep things running after the transition to loop wrappers and not to add support for other use cases enabled by that transition.

This PR on its own will not pass premerge tests. All patches in the stack are needed before it can be compiled and passes tests.


Patch is 69.30 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/89214.diff

9 Files Affected:

  • (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+35-33)
  • (modified) mlir/test/Target/LLVMIR/omptarget-parallel-wsloop.mlir (+7-4)
  • (modified) mlir/test/Target/LLVMIR/omptarget-wsloop-collapsed.mlir (+10-7)
  • (modified) mlir/test/Target/LLVMIR/omptarget-wsloop.mlir (+12-6)
  • (modified) mlir/test/Target/LLVMIR/openmp-data-target-device.mlir (+17-14)
  • (modified) mlir/test/Target/LLVMIR/openmp-llvm.mlir (+404-337)
  • (modified) mlir/test/Target/LLVMIR/openmp-nested.mlir (+18-12)
  • (modified) mlir/test/Target/LLVMIR/openmp-reduction.mlir (+63-50)
  • (modified) mlir/test/Target/LLVMIR/openmp-wsloop-reduction-cleanup.mlir (+6-3)
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index e89ff9209b034a..22d6462b881dc0 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -916,35 +916,37 @@ static LogicalResult inlineReductionCleanup(
 static LogicalResult
 convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
                  LLVM::ModuleTranslation &moduleTranslation) {
-  auto loop = cast<omp::WsloopOp>(opInst);
-  const bool isByRef = loop.getByref();
+  auto wsloopOp = cast<omp::WsloopOp>(opInst);
+  auto loopOp = cast<omp::LoopNestOp>(wsloopOp.getWrappedLoop());
+  const bool isByRef = wsloopOp.getByref();
+
   // TODO: this should be in the op verifier instead.
-  if (loop.getLowerBound().empty())
+  if (loopOp.getLowerBound().empty())
     return failure();
 
   // Static is the default.
   auto schedule =
-      loop.getScheduleVal().value_or(omp::ClauseScheduleKind::Static);
+      wsloopOp.getScheduleVal().value_or(omp::ClauseScheduleKind::Static);
 
   // Find the loop configuration.
-  llvm::Value *step = moduleTranslation.lookupValue(loop.getStep()[0]);
+  llvm::Value *step = moduleTranslation.lookupValue(loopOp.getStep()[0]);
   llvm::Type *ivType = step->getType();
   llvm::Value *chunk = nullptr;
-  if (loop.getScheduleChunkVar()) {
+  if (wsloopOp.getScheduleChunkVar()) {
     llvm::Value *chunkVar =
-        moduleTranslation.lookupValue(loop.getScheduleChunkVar());
+        moduleTranslation.lookupValue(wsloopOp.getScheduleChunkVar());
     chunk = builder.CreateSExtOrTrunc(chunkVar, ivType);
   }
 
   SmallVector<omp::DeclareReductionOp> reductionDecls;
-  collectReductionDecls(loop, reductionDecls);
+  collectReductionDecls(wsloopOp, reductionDecls);
   llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
       findAllocaInsertPoint(builder, moduleTranslation);
 
   SmallVector<llvm::Value *> privateReductionVariables;
   DenseMap<Value, llvm::Value *> reductionVariableMap;
   if (!isByRef) {
-    allocByValReductionVars(loop, builder, moduleTranslation, allocaIP,
+    allocByValReductionVars(wsloopOp, builder, moduleTranslation, allocaIP,
                             reductionDecls, privateReductionVariables,
                             reductionVariableMap);
   }
@@ -952,13 +954,12 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
   // Before the loop, store the initial values of reductions into reduction
   // variables. Although this could be done after allocas, we don't want to mess
   // up with the alloca insertion point.
-  MutableArrayRef<BlockArgument> reductionArgs =
-      loop.getRegion().getArguments().take_back(loop.getNumReductionVars());
-  for (unsigned i = 0; i < loop.getNumReductionVars(); ++i) {
+  ArrayRef<BlockArgument> reductionArgs = wsloopOp.getRegion().getArguments();
+  for (unsigned i = 0; i < wsloopOp.getNumReductionVars(); ++i) {
     SmallVector<llvm::Value *> phis;
 
     // map block argument to initializer region
-    mapInitializationArg(loop, moduleTranslation, reductionDecls, i);
+    mapInitializationArg(wsloopOp, moduleTranslation, reductionDecls, i);
 
     if (failed(inlineConvertOmpRegions(reductionDecls[i].getInitializerRegion(),
                                        "omp.reduction.neutral", builder,
@@ -977,7 +978,7 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
 
       privateReductionVariables.push_back(var);
       moduleTranslation.mapValue(reductionArgs[i], phis[0]);
-      reductionVariableMap.try_emplace(loop.getReductionVars()[i], phis[0]);
+      reductionVariableMap.try_emplace(wsloopOp.getReductionVars()[i], phis[0]);
     } else {
       // for by-ref case the store is inside of the reduction region
       builder.CreateStore(phis[0], privateReductionVariables[i]);
@@ -1008,33 +1009,34 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
   auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip, llvm::Value *iv) {
     // Make sure further conversions know about the induction variable.
     moduleTranslation.mapValue(
-        loop.getRegion().front().getArgument(loopInfos.size()), iv);
+        loopOp.getRegion().front().getArgument(loopInfos.size()), iv);
 
     // Capture the body insertion point for use in nested loops. BodyIP of the
     // CanonicalLoopInfo always points to the beginning of the entry block of
     // the body.
     bodyInsertPoints.push_back(ip);
 
-    if (loopInfos.size() != loop.getNumLoops() - 1)
+    if (loopInfos.size() != loopOp.getNumLoops() - 1)
       return;
 
     // Convert the body of the loop.
     builder.restoreIP(ip);
-    convertOmpOpRegions(loop.getRegion(), "omp.wsloop.region", builder,
+    convertOmpOpRegions(loopOp.getRegion(), "omp.wsloop.region", builder,
                         moduleTranslation, bodyGenStatus);
   };
 
   // Delegate actual loop construction to the OpenMP IRBuilder.
-  // TODO: this currently assumes Wsloop is semantically similar to SCF loop,
-  // i.e. it has a positive step, uses signed integer semantics. Reconsider
-  // this code when Wsloop clearly supports more cases.
+  // TODO: this currently assumes omp.loop_nest is semantically similar to SCF
+  // loop, i.e. it has a positive step, uses signed integer semantics.
+  // Reconsider this code when the nested loop operation clearly supports more
+  // cases.
   llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
-  for (unsigned i = 0, e = loop.getNumLoops(); i < e; ++i) {
+  for (unsigned i = 0, e = loopOp.getNumLoops(); i < e; ++i) {
     llvm::Value *lowerBound =
-        moduleTranslation.lookupValue(loop.getLowerBound()[i]);
+        moduleTranslation.lookupValue(loopOp.getLowerBound()[i]);
     llvm::Value *upperBound =
-        moduleTranslation.lookupValue(loop.getUpperBound()[i]);
-    llvm::Value *step = moduleTranslation.lookupValue(loop.getStep()[i]);
+        moduleTranslation.lookupValue(loopOp.getUpperBound()[i]);
+    llvm::Value *step = moduleTranslation.lookupValue(loopOp.getStep()[i]);
 
     // Make sure loop trip count are emitted in the preheader of the outermost
     // loop at the latest so that they are all available for the new collapsed
@@ -1047,7 +1049,7 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
     }
     loopInfos.push_back(ompBuilder->createCanonicalLoop(
         loc, bodyGen, lowerBound, upperBound, step,
-        /*IsSigned=*/true, loop.getInclusive(), computeIP));
+        /*IsSigned=*/true, loopOp.getInclusive(), computeIP));
 
     if (failed(bodyGenStatus))
       return failure();
@@ -1062,13 +1064,13 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
   allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
 
   // TODO: Handle doacross loops when the ordered clause has a parameter.
-  bool isOrdered = loop.getOrderedVal().has_value();
+  bool isOrdered = wsloopOp.getOrderedVal().has_value();
   std::optional<omp::ScheduleModifier> scheduleModifier =
-      loop.getScheduleModifier();
-  bool isSimd = loop.getSimdModifier();
+      wsloopOp.getScheduleModifier();
+  bool isSimd = wsloopOp.getSimdModifier();
 
   ompBuilder->applyWorkshareLoop(
-      ompLoc.DL, loopInfo, allocaIP, !loop.getNowait(),
+      ompLoc.DL, loopInfo, allocaIP, !wsloopOp.getNowait(),
       convertToScheduleKind(schedule), chunk, isSimd,
       scheduleModifier == omp::ScheduleModifier::monotonic,
       scheduleModifier == omp::ScheduleModifier::nonmonotonic, isOrdered);
@@ -1080,7 +1082,7 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
   builder.restoreIP(afterIP);
 
   // Process the reductions if required.
-  if (loop.getNumReductionVars() == 0)
+  if (wsloopOp.getNumReductionVars() == 0)
     return success();
 
   // Create the reduction generators. We need to own them here because
@@ -1088,7 +1090,7 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
   SmallVector<OwningReductionGen> owningReductionGens;
   SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
   SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos;
-  collectReductionInfo(loop, builder, moduleTranslation, reductionDecls,
+  collectReductionInfo(wsloopOp, builder, moduleTranslation, reductionDecls,
                        owningReductionGens, owningAtomicReductionGens,
                        privateReductionVariables, reductionInfos);
 
@@ -1099,9 +1101,9 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
   builder.SetInsertPoint(tempTerminator);
   llvm::OpenMPIRBuilder::InsertPointTy contInsertPoint =
       ompBuilder->createReductions(builder.saveIP(), allocaIP, reductionInfos,
-                                   loop.getNowait(), isByRef);
+                                   wsloopOp.getNowait(), isByRef);
   if (!contInsertPoint.getBlock())
-    return loop->emitOpError() << "failed to convert reductions";
+    return wsloopOp->emitOpError() << "failed to convert reductions";
   auto nextInsertionPoint =
       ompBuilder->createBarrier(contInsertPoint, llvm::omp::OMPD_for);
   tempTerminator->eraseFromParent();
diff --git a/mlir/test/Target/LLVMIR/omptarget-parallel-wsloop.mlir b/mlir/test/Target/LLVMIR/omptarget-parallel-wsloop.mlir
index b0fe642238f14f..360b3b0c0e60c1 100644
--- a/mlir/test/Target/LLVMIR/omptarget-parallel-wsloop.mlir
+++ b/mlir/test/Target/LLVMIR/omptarget-parallel-wsloop.mlir
@@ -12,10 +12,13 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
       %loop_ub = llvm.mlir.constant(9 : i32) : i32
       %loop_lb = llvm.mlir.constant(0 : i32) : i32
       %loop_step = llvm.mlir.constant(1 : i32) : i32
-      omp.wsloop for  (%loop_cnt) : i32 = (%loop_lb) to (%loop_ub) inclusive step (%loop_step) {
-        %gep = llvm.getelementptr %arg0[0, %loop_cnt] : (!llvm.ptr, i32) -> !llvm.ptr, !llvm.array<10 x i32>
-        llvm.store %loop_cnt, %gep : i32, !llvm.ptr
-        omp.yield
+      omp.wsloop {
+        omp.loop_nest (%loop_cnt) : i32 = (%loop_lb) to (%loop_ub) inclusive step (%loop_step) {
+          %gep = llvm.getelementptr %arg0[0, %loop_cnt] : (!llvm.ptr, i32) -> !llvm.ptr, !llvm.array<10 x i32>
+          llvm.store %loop_cnt, %gep : i32, !llvm.ptr
+          omp.yield
+        }
+        omp.terminator
       }
      omp.terminator
     }
diff --git a/mlir/test/Target/LLVMIR/omptarget-wsloop-collapsed.mlir b/mlir/test/Target/LLVMIR/omptarget-wsloop-collapsed.mlir
index 0d77423abcb4f1..13d34b7e58f77e 100644
--- a/mlir/test/Target/LLVMIR/omptarget-wsloop-collapsed.mlir
+++ b/mlir/test/Target/LLVMIR/omptarget-wsloop-collapsed.mlir
@@ -8,13 +8,16 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
     %loop_ub = llvm.mlir.constant(99 : i32) : i32
     %loop_lb = llvm.mlir.constant(0 : i32) : i32
     %loop_step = llvm.mlir.constant(1 : index) : i32
-    omp.wsloop for  (%arg1, %arg2) : i32 = (%loop_lb, %loop_lb) to (%loop_ub, %loop_ub) inclusive step (%loop_step, %loop_step) {
-      %1 = llvm.add %arg1, %arg2  : i32
-      %2 = llvm.mul %arg2, %loop_ub overflow<nsw>  : i32
-      %3 = llvm.add %arg1, %2 :i32
-      %4 = llvm.getelementptr %arg0[%3] : (!llvm.ptr, i32) -> !llvm.ptr, i32
-      llvm.store %1, %4 : i32, !llvm.ptr
-      omp.yield
+    omp.wsloop {
+      omp.loop_nest (%arg1, %arg2) : i32 = (%loop_lb, %loop_lb) to (%loop_ub, %loop_ub) inclusive step (%loop_step, %loop_step) {
+        %1 = llvm.add %arg1, %arg2  : i32
+        %2 = llvm.mul %arg2, %loop_ub overflow<nsw>  : i32
+        %3 = llvm.add %arg1, %2 :i32
+        %4 = llvm.getelementptr %arg0[%3] : (!llvm.ptr, i32) -> !llvm.ptr, i32
+        llvm.store %1, %4 : i32, !llvm.ptr
+        omp.yield
+      }
+      omp.terminator
     }
     llvm.return
   }
diff --git a/mlir/test/Target/LLVMIR/omptarget-wsloop.mlir b/mlir/test/Target/LLVMIR/omptarget-wsloop.mlir
index 0f3f503dfa5377..ee851eaf71ac0b 100644
--- a/mlir/test/Target/LLVMIR/omptarget-wsloop.mlir
+++ b/mlir/test/Target/LLVMIR/omptarget-wsloop.mlir
@@ -8,10 +8,13 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
       %loop_ub = llvm.mlir.constant(9 : i32) : i32
       %loop_lb = llvm.mlir.constant(0 : i32) : i32
       %loop_step = llvm.mlir.constant(1 : i32) : i32
-      omp.wsloop for  (%loop_cnt) : i32 = (%loop_lb) to (%loop_ub) inclusive step (%loop_step) {
-        %gep = llvm.getelementptr %arg0[0, %loop_cnt] : (!llvm.ptr, i32) -> !llvm.ptr, !llvm.array<10 x i32>
-        llvm.store %loop_cnt, %gep : i32, !llvm.ptr
-        omp.yield
+      omp.wsloop {
+        omp.loop_nest (%loop_cnt) : i32 = (%loop_lb) to (%loop_ub) inclusive step (%loop_step) {
+          %gep = llvm.getelementptr %arg0[0, %loop_cnt] : (!llvm.ptr, i32) -> !llvm.ptr, !llvm.array<10 x i32>
+          llvm.store %loop_cnt, %gep : i32, !llvm.ptr
+          omp.yield
+        }
+        omp.terminator
       }
     llvm.return
   }
@@ -20,8 +23,11 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
       %loop_ub = llvm.mlir.constant(9 : i32) : i32
       %loop_lb = llvm.mlir.constant(0 : i32) : i32
       %loop_step = llvm.mlir.constant(1 : i32) : i32
-      omp.wsloop for  (%loop_cnt) : i32 = (%loop_lb) to (%loop_ub) inclusive step (%loop_step) {
-        omp.yield
+      omp.wsloop {
+        omp.loop_nest (%loop_cnt) : i32 = (%loop_lb) to (%loop_ub) inclusive step (%loop_step) {
+          omp.yield
+        }
+        omp.terminator
       }
     llvm.return
   }
diff --git a/mlir/test/Target/LLVMIR/openmp-data-target-device.mlir b/mlir/test/Target/LLVMIR/openmp-data-target-device.mlir
index d41429a6de066f..4ea9df369af66c 100644
--- a/mlir/test/Target/LLVMIR/openmp-data-target-device.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-data-target-device.mlir
@@ -31,20 +31,23 @@ module attributes { } {
           %18 = llvm.mlir.constant(1 : i64) : i64
           %19 = llvm.alloca %18 x i32 {pinned} : (i64) -> !llvm.ptr<5>
           %20 = llvm.addrspacecast %19 : !llvm.ptr<5> to !llvm.ptr
-          omp.wsloop  for  (%arg2) : i32 = (%16) to (%15) inclusive step (%16) {
-            llvm.store %arg2, %20 : i32, !llvm.ptr
-            %21 = llvm.load %20 : !llvm.ptr -> i32
-            %22 = llvm.sext %21 : i32 to i64
-            %23 = llvm.mlir.constant(1 : i64) : i64
-            %24 = llvm.mlir.constant(0 : i64) : i64
-            %25 = llvm.sub %22, %23 overflow<nsw>  : i64
-            %26 = llvm.mul %25, %23 overflow<nsw>  : i64
-            %27 = llvm.mul %26, %23 overflow<nsw>  : i64
-            %28 = llvm.add %27, %24 overflow<nsw>  : i64
-            %29 = llvm.mul %23, %17 overflow<nsw>  : i64
-            %30 = llvm.getelementptr %arg0[%28] : (!llvm.ptr, i64) -> !llvm.ptr, i32
-            llvm.store %21, %30 : i32, !llvm.ptr
-            omp.yield
+          omp.wsloop {
+            omp.loop_nest (%arg2) : i32 = (%16) to (%15) inclusive step (%16) {
+              llvm.store %arg2, %20 : i32, !llvm.ptr
+              %21 = llvm.load %20 : !llvm.ptr -> i32
+              %22 = llvm.sext %21 : i32 to i64
+              %23 = llvm.mlir.constant(1 : i64) : i64
+              %24 = llvm.mlir.constant(0 : i64) : i64
+              %25 = llvm.sub %22, %23 overflow<nsw>  : i64
+              %26 = llvm.mul %25, %23 overflow<nsw>  : i64
+              %27 = llvm.mul %26, %23 overflow<nsw>  : i64
+              %28 = llvm.add %27, %24 overflow<nsw>  : i64
+              %29 = llvm.mul %23, %17 overflow<nsw>  : i64
+              %30 = llvm.getelementptr %arg0[%28] : (!llvm.ptr, i64) -> !llvm.ptr, i32
+              llvm.store %21, %30 : i32, !llvm.ptr
+              omp.yield
+            }
+            omp.terminator
           }
           omp.terminator
         }
diff --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
index d1390022c1dc44..ad40ca26bec9f8 100644
--- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
@@ -320,18 +320,20 @@ llvm.func @wsloop_simple(%arg0: !llvm.ptr) {
   %1 = llvm.mlir.constant(10 : index) : i64
   %2 = llvm.mlir.constant(1 : index) : i64
   omp.parallel {
-    "omp.wsloop"(%1, %0, %2) ({
-    ^bb0(%arg1: i64):
-      // The form of the emitted IR is controlled by OpenMPIRBuilder and
-      // tested there. Just check that the right functions are called.
-      // CHECK: call i32 @__kmpc_global_thread_num
-      // CHECK: call void @__kmpc_for_static_init_{{.*}}(ptr @[[$loc_struct]],
-      %3 = llvm.mlir.constant(2.000000e+00 : f32) : f32
-      %4 = llvm.getelementptr %arg0[%arg1] : (!llvm.ptr, i64) -> !llvm.ptr, f32
-      llvm.store %3, %4 : f32, !llvm.ptr
-      omp.yield
+    "omp.wsloop"() ({
+      omp.loop_nest (%arg1) : i64 = (%1) to (%0) step (%2) {
+        // The form of the emitted IR is controlled by OpenMPIRBuilder and
+        // tested there. Just check that the right functions are called.
+        // CHECK: call i32 @__kmpc_global_thread_num
+        // CHECK: call void @__kmpc_for_static_init_{{.*}}(ptr @[[$loc_struct]],
+        %3 = llvm.mlir.constant(2.000000e+00 : f32) : f32
+        %4 = llvm.getelementptr %arg0[%arg1] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+        llvm.store %3, %4 : f32, !llvm.ptr
+        omp.yield
+      }
+      omp.terminator
       // CHECK: call void @__kmpc_for_static_fini(ptr @[[$loc_struct]],
-    }) {operandSegmentSizes = array<i32: 1, 1, 1, 0, 0, 0, 0>} : (i64, i64, i64) -> ()
+    }) : () -> ()
     omp.terminator
   }
   llvm.return
@@ -345,13 +347,15 @@ llvm.func @wsloop_inclusive_1(%arg0: !llvm.ptr) {
   %1 = llvm.mlir.constant(10 : index) : i64
   %2 = llvm.mlir.constant(1 : index) : i64
   // CHECK: store i64 31, ptr %{{.*}}upperbound
-  "omp.wsloop"(%1, %0, %2) ({
-  ^bb0(%arg1: i64):
-    %3 = llvm.mlir.constant(2.000000e+00 : f32) : f32
-    %4 = llvm.getelementptr %arg0[%arg1] : (!llvm.ptr, i64) -> !llvm.ptr, f32
-    llvm.store %3, %4 : f32, !llvm.ptr
-    omp.yield
-  }) {operandSegmentSizes = array<i32: 1, 1, 1, 0, 0, 0, 0>} : (i64, i64, i64) -> ()
+  "omp.wsloop"() ({
+    omp.loop_nest (%arg1) : i64 = (%1) to (%0) step (%2) {
+      %3 = llvm.mlir.constant(2.000000e+00 : f32) : f32
+      %4 = llvm.getelementptr %arg0[%arg1] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+      llvm.store %3, %4 : f32, !llvm.ptr
+      omp.yield
+    }
+    omp.terminator
+  }) : () -> ()
   llvm.return
 }
 
@@ -363,13 +367,15 @@ llvm.func @wsloop_inclusive_2(%arg0: !llvm.ptr) {
   %1 = llvm.mlir.constant(10 : index) : i64
   %2 = llvm.mlir.constant(1 : index) : i64
   // CHECK: store i64 32, ptr %{{.*}}upperbound
-  "omp.wsloop"(%1, %0, %2) ({
-  ^bb0(%arg1: i64):
-    %3 = llvm.mlir.constant(2.000000e+00 : f32) : f32
-    %4 = llvm.getelementptr %arg0[%arg1] : (!llvm.ptr, i64) -> !llvm.ptr, f32
-    llvm.store %3, %4 : f32, !llvm.ptr
-    omp.yield
-  }) {inclusive, operandSegmentSizes = array<i32: 1, 1, 1, 0, 0, 0, 0>} : (i64, i64, i64) -> ()
+  "omp.wsloop"() ({
+    omp.loop_nest (%arg1) : i64 = (%1) to (%0) inclusive step (%2) {
+      %3 = llvm.mlir.constant(2.000000e+00 : f32) : f32
+      %4 = llvm.getelementptr %arg0[%arg1] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+      llvm.store %3, %4 : f32, !llvm.ptr
+      omp.yield
+    }
+    omp.terminator
+  }) : () -> ()
   llvm.return
 }
 
@@ -379,14 +385,16 @@ llvm.func @body(i32)
 
 // CHECK-LABEL: @test_omp_wsloop_static_defchunk
 llvm.func @test_omp_wsloop_static_defchunk(%lb : i32, %ub : i32, %step : i32) -> () {
- omp.wsloop schedule(static)
- for (%iv) : i32 = (%lb) to (%ub) step (%step) {
-   // CHECK: call void @__kmpc_for_static_init_4u(ptr @{{.*}}, i32 %{{.*}}, i32 34, ptr %{{.*}}, ptr %{{.*}}, ptr %{{.*}}, ptr %{{.*}}, i32 1, i32 0)
-   // CHECK: call void @__kmpc_for_static_fini
-   llvm.call @body(%iv) : (i32) -> ()
-   omp.yield
- }
- llvm.return
+  omp.wsloop schedule(static) {
+    omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
+      // CHECK: call void @__kmpc_for_static_init_4u(ptr @{{.*}}, i32 %{{.*}}, i32 34, ptr %{{.*}}, ptr %{{.*}}, ptr %{{.*}}, ptr %{{.*}}, i32 1, i32 0)
+      // CHECK: call void @__kmpc_for_static_fini
+      llvm.call @body(%iv) : (i32) -> ()
+      omp.yield
+    }
+    omp.terminator
+  }
+  llvm.return
 }
 
 // -----
@@ -395,15 +403,17 @@ llvm.func @body(i32)
 
 // CHECK-...
[truncated]

Copy link
Contributor

@tblah tblah left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Comment on lines +919 to +920
auto wsloopOp = cast<omp::WsloopOp>(opInst);
auto loopOp = cast<omp::LoopNestOp>(wsloopOp.getWrappedLoop());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will this be updated for when omp.wsloop { omp.simd { omp.loop_nest ... is going to be supported? Does the current verifier reject this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this won't work for composite constructs. The current verifiers allow the composite constructs that are supported by the spec, although they aren't supported by Flang lowering or by the translation from MLIR to LLVM IR yet. That's something that will come later on, the purpose of this patch is just to keep everything that worked before working after the changes to omp.wsloop.

Comment on lines +1029 to +1032
// TODO: this currently assumes omp.loop_nest is semantically similar to SCF
// loop, i.e. it has a positive step, uses signed integer semantics.
// Reconsider this code when the nested loop operation clearly supports more
// cases.
Copy link
Member

@Meinersbur Meinersbur Apr 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand the context of this comment. scf.for is not used here. OpenMPIRBuilder::createCanonicalLoop should support all these cases.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was a comment that already existed before. The only changes I made to it are to make it refer to the right operation (it mentioned omp.wsloop, but the restriction applies to omp.loop_nest now). As far as I can understand it, what this comes to say is that regardless of whether the OpenMPIRBuilder could represent more cases, the omp.loop_nest operation itself is similar to scf.for, so some assumptions are made in that regard below. When omp.loop_nest is replaced by a full-fledged omp.canonical_loop then these assumptions will no longer apply, so that code will have to be revised.

mlir/test/Target/LLVMIR/openmp-llvm.mlir Show resolved Hide resolved
Copy link
Member

@Meinersbur Meinersbur left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Base automatically changed from users/skatrak/spr/wsloop-wrapper-03-scf-parallel to main April 24, 2024 13:29
@skatrak skatrak merged commit 2e37f28 into main Apr 24, 2024
2 of 3 checks passed
@skatrak skatrak deleted the users/skatrak/spr/wsloop-wrapper-04-llvm-ir branch April 24, 2024 13:29
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants