Skip to content

Conversation

skatrak
Copy link
Member

@skatrak skatrak commented Mar 11, 2025

This patch introduces a use for the new getBlockArgsPairs to avoid having to manually list each applicable clause.

Also, the numClauseBlockArgs() function is introduced, which simplifies the implementation of the interface's verifier and enables better memory handling within getBlockArgsPairs.

@llvmbot
Copy link
Member

llvmbot commented Mar 11, 2025

@llvm/pr-subscribers-flang-openmp

@llvm/pr-subscribers-mlir

Author: Sergio Afonso (skatrak)

Changes

This patch introduces a use for the new getBlockArgsPairs to avoid having to manually list each applicable clause.

Also, the numClauseBlockArgs() function is introduced, which simplifies the implementation of the interface's verifier and enables better memory handling within getBlockArgsPairs.


Full diff: https://github.com/llvm/llvm-project/pull/130789.diff

3 Files Affected:

  • (modified) mlir/docs/Dialects/OpenMPDialect/_index.md (+2)
  • (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td (+9-5)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+5-7)
diff --git a/mlir/docs/Dialects/OpenMPDialect/_index.md b/mlir/docs/Dialects/OpenMPDialect/_index.md
index adde176750437..1df80fac2a684 100644
--- a/mlir/docs/Dialects/OpenMPDialect/_index.md
+++ b/mlir/docs/Dialects/OpenMPDialect/_index.md
@@ -372,6 +372,8 @@ accessed:
   should be located.
   - `get<ClauseName>BlockArgs()`: Returns the list of entry block arguments
   defined by the given clause.
+  - `numClauseBlockArgs()`: Returns the total number of entry block arguments
+  defined by all clauses.
   - `getBlockArgsPairs()`: Returns a list of pairs where the first element is
   the outside value, or operand, and the second element is the corresponding
   entry block argument.
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
index 0766b4e8d1472..3fa54d35ed09b 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
@@ -136,12 +136,20 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {
     !foreach(clause, clauses, clause.startMethod),
     !foreach(clause, clauses, clause.blockArgsMethod),
     [
+      InterfaceMethod<
+        "Get the total number of clause-defined entry block arguments",
+        "unsigned", "numClauseBlockArgs", (ins),
+        "return " # !interleave(
+          !foreach(clause, clauses, "$_op." # clause.numArgsMethod.name # "()"),
+          " + ") # ";"
+      >,
       InterfaceMethod<
         "Populate a vector of pairs representing the matching between operands "
         "and entry block arguments.", "void", "getBlockArgsPairs",
         (ins "::llvm::SmallVectorImpl<std::pair<::mlir::Value, ::mlir::BlockArgument>> &" : $pairs),
         [{
           auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
+          pairs.reserve(pairs.size() + iface.numClauseBlockArgs());
         }] # !interleave(!foreach(clause, clauses, [{
         }] # "if (iface." # clause.numArgsMethod.name # "() > 0) {" # [{
         }] # "  for (auto [var, arg] : ::llvm::zip_equal(" #
@@ -155,11 +163,7 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {
 
   let verify = [{
     auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>($_op);
-  }] # "unsigned expectedArgs = "
-     # !interleave(
-         !foreach(clause, clauses, "iface." # clause.numArgsMethod.name # "()"),
-         " + "
-       ) # ";" # [{
+    unsigned expectedArgs = iface.numClauseBlockArgs();
     if ($_op->getRegion(0).getNumArguments() < expectedArgs)
       return $_op->emitOpError() << "expected at least " << expectedArgs
                                  << " entry block argument(s)";
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 3373f19a006ba..b9893716980fe 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -550,18 +550,16 @@ convertIgnoredWrapper(omp::LoopWrapperInterface opInst,
   // corresponding operand. This is semantically equivalent to this wrapper not
   // being present.
   auto forwardArgs =
-      [&moduleTranslation](llvm::ArrayRef<BlockArgument> blockArgs,
-                           OperandRange operands) {
-        for (auto [arg, var] : llvm::zip_equal(blockArgs, operands))
+      [&moduleTranslation](omp::BlockArgOpenMPOpInterface blockArgIface) {
+        llvm::SmallVector<std::pair<Value, BlockArgument>> blockArgsPairs;
+        blockArgIface.getBlockArgsPairs(blockArgsPairs);
+        for (auto [var, arg] : blockArgsPairs)
           moduleTranslation.mapValue(arg, moduleTranslation.lookupValue(var));
       };
 
   return llvm::TypeSwitch<Operation *, LogicalResult>(opInst)
       .Case([&](omp::SimdOp op) {
-        auto blockArgIface = cast<omp::BlockArgOpenMPOpInterface>(*op);
-        forwardArgs(blockArgIface.getPrivateBlockArgs(), op.getPrivateVars());
-        forwardArgs(blockArgIface.getReductionBlockArgs(),
-                    op.getReductionVars());
+        forwardArgs(cast<omp::BlockArgOpenMPOpInterface>(*op));
         op.emitWarning() << "simd information on composite construct discarded";
         return success();
       })

@llvmbot
Copy link
Member

llvmbot commented Mar 11, 2025

@llvm/pr-subscribers-mlir-llvm

Author: Sergio Afonso (skatrak)

Changes

This patch introduces a use for the new getBlockArgsPairs to avoid having to manually list each applicable clause.

Also, the numClauseBlockArgs() function is introduced, which simplifies the implementation of the interface's verifier and enables better memory handling within getBlockArgsPairs.


Full diff: https://github.com/llvm/llvm-project/pull/130789.diff

3 Files Affected:

  • (modified) mlir/docs/Dialects/OpenMPDialect/_index.md (+2)
  • (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td (+9-5)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+5-7)
diff --git a/mlir/docs/Dialects/OpenMPDialect/_index.md b/mlir/docs/Dialects/OpenMPDialect/_index.md
index adde176750437..1df80fac2a684 100644
--- a/mlir/docs/Dialects/OpenMPDialect/_index.md
+++ b/mlir/docs/Dialects/OpenMPDialect/_index.md
@@ -372,6 +372,8 @@ accessed:
   should be located.
   - `get<ClauseName>BlockArgs()`: Returns the list of entry block arguments
   defined by the given clause.
+  - `numClauseBlockArgs()`: Returns the total number of entry block arguments
+  defined by all clauses.
   - `getBlockArgsPairs()`: Returns a list of pairs where the first element is
   the outside value, or operand, and the second element is the corresponding
   entry block argument.
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
index 0766b4e8d1472..3fa54d35ed09b 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
@@ -136,12 +136,20 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {
     !foreach(clause, clauses, clause.startMethod),
     !foreach(clause, clauses, clause.blockArgsMethod),
     [
+      InterfaceMethod<
+        "Get the total number of clause-defined entry block arguments",
+        "unsigned", "numClauseBlockArgs", (ins),
+        "return " # !interleave(
+          !foreach(clause, clauses, "$_op." # clause.numArgsMethod.name # "()"),
+          " + ") # ";"
+      >,
       InterfaceMethod<
         "Populate a vector of pairs representing the matching between operands "
         "and entry block arguments.", "void", "getBlockArgsPairs",
         (ins "::llvm::SmallVectorImpl<std::pair<::mlir::Value, ::mlir::BlockArgument>> &" : $pairs),
         [{
           auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
+          pairs.reserve(pairs.size() + iface.numClauseBlockArgs());
         }] # !interleave(!foreach(clause, clauses, [{
         }] # "if (iface." # clause.numArgsMethod.name # "() > 0) {" # [{
         }] # "  for (auto [var, arg] : ::llvm::zip_equal(" #
@@ -155,11 +163,7 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {
 
   let verify = [{
     auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>($_op);
-  }] # "unsigned expectedArgs = "
-     # !interleave(
-         !foreach(clause, clauses, "iface." # clause.numArgsMethod.name # "()"),
-         " + "
-       ) # ";" # [{
+    unsigned expectedArgs = iface.numClauseBlockArgs();
     if ($_op->getRegion(0).getNumArguments() < expectedArgs)
       return $_op->emitOpError() << "expected at least " << expectedArgs
                                  << " entry block argument(s)";
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 3373f19a006ba..b9893716980fe 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -550,18 +550,16 @@ convertIgnoredWrapper(omp::LoopWrapperInterface opInst,
   // corresponding operand. This is semantically equivalent to this wrapper not
   // being present.
   auto forwardArgs =
-      [&moduleTranslation](llvm::ArrayRef<BlockArgument> blockArgs,
-                           OperandRange operands) {
-        for (auto [arg, var] : llvm::zip_equal(blockArgs, operands))
+      [&moduleTranslation](omp::BlockArgOpenMPOpInterface blockArgIface) {
+        llvm::SmallVector<std::pair<Value, BlockArgument>> blockArgsPairs;
+        blockArgIface.getBlockArgsPairs(blockArgsPairs);
+        for (auto [var, arg] : blockArgsPairs)
           moduleTranslation.mapValue(arg, moduleTranslation.lookupValue(var));
       };
 
   return llvm::TypeSwitch<Operation *, LogicalResult>(opInst)
       .Case([&](omp::SimdOp op) {
-        auto blockArgIface = cast<omp::BlockArgOpenMPOpInterface>(*op);
-        forwardArgs(blockArgIface.getPrivateBlockArgs(), op.getPrivateVars());
-        forwardArgs(blockArgIface.getReductionBlockArgs(),
-                    op.getReductionVars());
+        forwardArgs(cast<omp::BlockArgOpenMPOpInterface>(*op));
         op.emitWarning() << "simd information on composite construct discarded";
         return success();
       })

@llvmbot
Copy link
Member

llvmbot commented Mar 11, 2025

@llvm/pr-subscribers-mlir-openmp

Author: Sergio Afonso (skatrak)

Changes

This patch introduces a use for the new getBlockArgsPairs to avoid having to manually list each applicable clause.

Also, the numClauseBlockArgs() function is introduced, which simplifies the implementation of the interface's verifier and enables better memory handling within getBlockArgsPairs.


Full diff: https://github.com/llvm/llvm-project/pull/130789.diff

3 Files Affected:

  • (modified) mlir/docs/Dialects/OpenMPDialect/_index.md (+2)
  • (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td (+9-5)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+5-7)
diff --git a/mlir/docs/Dialects/OpenMPDialect/_index.md b/mlir/docs/Dialects/OpenMPDialect/_index.md
index adde176750437..1df80fac2a684 100644
--- a/mlir/docs/Dialects/OpenMPDialect/_index.md
+++ b/mlir/docs/Dialects/OpenMPDialect/_index.md
@@ -372,6 +372,8 @@ accessed:
   should be located.
   - `get<ClauseName>BlockArgs()`: Returns the list of entry block arguments
   defined by the given clause.
+  - `numClauseBlockArgs()`: Returns the total number of entry block arguments
+  defined by all clauses.
   - `getBlockArgsPairs()`: Returns a list of pairs where the first element is
   the outside value, or operand, and the second element is the corresponding
   entry block argument.
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
index 0766b4e8d1472..3fa54d35ed09b 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
@@ -136,12 +136,20 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {
     !foreach(clause, clauses, clause.startMethod),
     !foreach(clause, clauses, clause.blockArgsMethod),
     [
+      InterfaceMethod<
+        "Get the total number of clause-defined entry block arguments",
+        "unsigned", "numClauseBlockArgs", (ins),
+        "return " # !interleave(
+          !foreach(clause, clauses, "$_op." # clause.numArgsMethod.name # "()"),
+          " + ") # ";"
+      >,
       InterfaceMethod<
         "Populate a vector of pairs representing the matching between operands "
         "and entry block arguments.", "void", "getBlockArgsPairs",
         (ins "::llvm::SmallVectorImpl<std::pair<::mlir::Value, ::mlir::BlockArgument>> &" : $pairs),
         [{
           auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
+          pairs.reserve(pairs.size() + iface.numClauseBlockArgs());
         }] # !interleave(!foreach(clause, clauses, [{
         }] # "if (iface." # clause.numArgsMethod.name # "() > 0) {" # [{
         }] # "  for (auto [var, arg] : ::llvm::zip_equal(" #
@@ -155,11 +163,7 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {
 
   let verify = [{
     auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>($_op);
-  }] # "unsigned expectedArgs = "
-     # !interleave(
-         !foreach(clause, clauses, "iface." # clause.numArgsMethod.name # "()"),
-         " + "
-       ) # ";" # [{
+    unsigned expectedArgs = iface.numClauseBlockArgs();
     if ($_op->getRegion(0).getNumArguments() < expectedArgs)
       return $_op->emitOpError() << "expected at least " << expectedArgs
                                  << " entry block argument(s)";
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 3373f19a006ba..b9893716980fe 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -550,18 +550,16 @@ convertIgnoredWrapper(omp::LoopWrapperInterface opInst,
   // corresponding operand. This is semantically equivalent to this wrapper not
   // being present.
   auto forwardArgs =
-      [&moduleTranslation](llvm::ArrayRef<BlockArgument> blockArgs,
-                           OperandRange operands) {
-        for (auto [arg, var] : llvm::zip_equal(blockArgs, operands))
+      [&moduleTranslation](omp::BlockArgOpenMPOpInterface blockArgIface) {
+        llvm::SmallVector<std::pair<Value, BlockArgument>> blockArgsPairs;
+        blockArgIface.getBlockArgsPairs(blockArgsPairs);
+        for (auto [var, arg] : blockArgsPairs)
           moduleTranslation.mapValue(arg, moduleTranslation.lookupValue(var));
       };
 
   return llvm::TypeSwitch<Operation *, LogicalResult>(opInst)
       .Case([&](omp::SimdOp op) {
-        auto blockArgIface = cast<omp::BlockArgOpenMPOpInterface>(*op);
-        forwardArgs(blockArgIface.getPrivateBlockArgs(), op.getPrivateVars());
-        forwardArgs(blockArgIface.getReductionBlockArgs(),
-                    op.getReductionVars());
+        forwardArgs(cast<omp::BlockArgOpenMPOpInterface>(*op));
         op.emitWarning() << "simd information on composite construct discarded";
         return success();
       })

Copy link
Member

@Meinersbur Meinersbur left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Contributor

@tblah tblah left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks

Copy link
Member

@ergawy ergawy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Base automatically changed from users/skatrak/omp-blockarg-iface-operands to main March 12, 2025 11:50
This patch introduces a use for the new `getBlockArgsPairs` to avoid having to
manually list each applicable clause.

Also, the `numClauseBlockArgs()` function is introduced, which simplifies the
implementation of the interface's verifier and enables better memory handling
within `getBlockArgsPairs`.
@skatrak skatrak force-pushed the users/skatrak/omp-blockarg-iface-uses branch from 270790b to cc7cb76 Compare March 12, 2025 11:53
@skatrak skatrak merged commit 6ff33ed into main Mar 13, 2025
11 checks passed
@skatrak skatrak deleted the users/skatrak/omp-blockarg-iface-uses branch March 13, 2025 14:48
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants