Skip to content

[mlir][acc] Add canonicalization patterns for compute_region#192376

Merged
razvanlupusoru merged 3 commits into
llvm:mainfrom
razvanlupusoru:acccomputecanonical
Apr 16, 2026
Merged

[mlir][acc] Add canonicalization patterns for compute_region#192376
razvanlupusoru merged 3 commits into
llvm:mainfrom
razvanlupusoru:acccomputecanonical

Conversation

@razvanlupusoru
Copy link
Copy Markdown
Contributor

This PR improves the APIs for navigating through acc.compute_region block arguments and also adds canonicalization patterns for those arguments to remove unused ones and merge duplicates.

This PR improves the APIs for navigating through acc.compute_region
block arguments and also adds canonicalization patterns for those
arguments to remove unused ones and merge duplicates.
@llvmbot
Copy link
Copy Markdown
Member

llvmbot commented Apr 16, 2026

@llvm/pr-subscribers-mlir-openacc
@llvm/pr-subscribers-openacc

@llvm/pr-subscribers-mlir

Author: Razvan Lupusoru (razvanlupusoru)

Changes

This PR improves the APIs for navigating through acc.compute_region block arguments and also adds canonicalization patterns for those arguments to remove unused ones and merge duplicates.


Full diff: https://github.com/llvm/llvm-project/pull/192376.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/OpenACC/OpenACCCGOps.td (+17-2)
  • (modified) mlir/lib/Dialect/OpenACC/IR/OpenACCCG.cpp (+111-2)
  • (modified) mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp (+2)
  • (added) mlir/test/Dialect/OpenACC/compute-region-canonicalize.mlir (+78)
  • (modified) mlir/unittests/Dialect/OpenACC/OpenACCUtilsCGTest.cpp (+5)
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCCGOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCCGOps.td
index 63fc8476c08d4..76902a6d2690e 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCCGOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCCGOps.td
@@ -297,6 +297,11 @@ def OpenACC_ComputeRegionOp
     region (e.g., `"acc.parallel"`, `"acc.kernels"`). This is intended to
     be solely informational.
 
+    Canonicalization may simplify `ins` captures: duplicate `ins` operands
+    (same SSA value threaded more than once) are merged by reusing the first
+    block argument, and unused `ins` operands (block arguments with no uses)
+    are removed. `launch` operands are never merged or dropped.
+
     Example:
 
     ```mlir
@@ -327,6 +332,8 @@ def OpenACC_ComputeRegionOp
 
   let regions = (region AnyRegion:$region);
 
+  let hasCanonicalizer = 1;
+
   let extraClassDeclaration = [{
     /// Look up the par_width op for the given dimension among launch args.
     std::optional<mlir::Value> getLaunchArg(
@@ -365,9 +372,17 @@ def OpenACC_ComputeRegionOp
       return &getRegion().back().back();
     }
 
-    /// Map a block argument back to its corresponding operand
-    /// ($launchArgs or $inputArgs).
+    /// Return the `launch` or `ins` operand threaded to `blockArg`, or a null
+    /// `Value` if `blockArg` is not an argument of `getBody()` or its index is
+    /// out of range for this op's `launch` and `ins` operands.
     ::mlir::Value getOperand(::mlir::BlockArgument blockArg);
+
+    /// If `value` is a launch or input operand, return the body block argument
+    /// it is threaded through; otherwise `std::nullopt`. `launch` operands are
+    /// searched before `ins` operands, so a value that appears in both lists
+    /// resolves to its `launch` block argument. If `value` matches more than
+    /// one operand within a list, the first match is returned
+    /// (canonicalization may merge duplicate `ins` values, but duplicate
+    /// `launch` operands are never folded).
+    std::optional<::mlir::BlockArgument> getBlockArg(::mlir::Value value);
   }];
 
   let hasVerifier = 1;
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACCCG.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACCCG.cpp
index 04f8c848c7287..1e86dcf2c3946 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACCCG.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACCCG.cpp
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/Region.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Support/LogicalResult.h"
@@ -107,6 +108,90 @@ struct RemoveEmptyKernelEnvironment
   }
 };
 
+/// Rewrite the operand-segment-size attribute of `op` after `ins` operands
+/// have been erased directly from the operation.
+///
+/// The segments are, in order: `launch`, `ins`, and the optional `stream`.
+/// `numInput` is the number of `ins` operands that remain on `op`.
+///
+/// This is called while the stored segment sizes are stale (they still
+/// describe the operand list from before the erasure). The `launch` segment
+/// starts at operand 0 and its size is unchanged, so `getLaunchArgs()` is
+/// still valid. The `stream` accessor is not: it would index using the old
+/// `ins` count and land past the live operands. The stream count is therefore
+/// derived from the actual operand count instead of `getStream()`.
+static void updateComputeRegionInputOperandSegments(ComputeRegionOp op,
+                                                    PatternRewriter &rewriter,
+                                                    size_t numInput) {
+  const size_t numLaunch = op.getLaunchArgs().size();
+  const size_t numOperands = op->getNumOperands();
+  assert(numOperands >= numLaunch + numInput &&
+         numOperands - numLaunch - numInput <= 1 &&
+         "operand count inconsistent with launch/ins/stream segments");
+  // Whatever is left after `launch` and `ins` is the optional stream (0 or 1).
+  const size_t numStream = numOperands - numLaunch - numInput;
+  op->setAttr(ComputeRegionOp::getOperandSegmentSizeAttr(),
+              rewriter.getDenseI32ArrayAttr(
+                  {static_cast<int32_t>(numLaunch),
+                   static_cast<int32_t>(numInput),
+                   static_cast<int32_t>(numStream)}));
+}
+
+/// Canonicalization pattern that merges duplicate `ins` operands of an
+/// `acc.compute_region`.
+///
+/// When the same SSA value is passed as more than one `ins` operand, every
+/// later occurrence is dropped and the uses of its body block argument are
+/// redirected to the block argument of the first occurrence. `launch`
+/// operands are never inspected or merged.
+///
+/// The pattern first computes all merges without touching the IR, so a failed
+/// match leaves the operation untouched. It then mutates the op inside
+/// `modifyOpInPlace` so that the rewriter and its listeners are notified of
+/// the operand, block-argument and attribute changes.
+struct ComputeRegionRemoveDuplicateArgs
+    : public OpRewritePattern<ComputeRegionOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ComputeRegionOp op,
+                                PatternRewriter &rewriter) const override {
+    Block *body = op.getBody();
+    const unsigned numLaunch = op.getLaunchArgs().size();
+    const unsigned numInput = op.getInputArgs().size();
+    assert(body->getNumArguments() == numLaunch + numInput &&
+           "region args mismatch");
+
+    // Match phase: record (dropIdx, keepIdx) pairs in body-argument numbering.
+    // `keepIdx` is always the first occurrence of the value, so it is never
+    // itself scheduled for removal. Pairs are appended in increasing dropIdx.
+    SmallVector<std::pair<unsigned, unsigned>> merges;
+    for (unsigned j = 1; j < numInput; ++j) {
+      Value candidate = op->getOperand(numLaunch + j);
+      for (unsigned i = 0; i < j; ++i) {
+        if (op->getOperand(numLaunch + i) != candidate)
+          continue;
+        merges.emplace_back(numLaunch + j, numLaunch + i);
+        break;
+      }
+    }
+    if (merges.empty())
+      return failure();
+
+    // Redirect uses while every block argument still has its original index.
+    for (const auto &[dropIdx, keepIdx] : merges)
+      rewriter.replaceAllUsesWith(body->getArgument(dropIdx),
+                                  body->getArgument(keepIdx));
+
+    rewriter.modifyOpInPlace(op, [&] {
+      // Erase from the highest index down so pending indices stay valid.
+      for (const auto &merge : llvm::reverse(merges)) {
+        body->eraseArgument(merge.first);
+        op->eraseOperand(merge.first);
+      }
+      updateComputeRegionInputOperandSegments(op, rewriter,
+                                              numInput - merges.size());
+    });
+    return success();
+  }
+};
+
+/// Canonicalization pattern that removes `ins` operands of an
+/// `acc.compute_region` whose body block argument has no uses.
+///
+/// `launch` operands (the leading body arguments) are never removed, even
+/// when unused. Dead indices are collected first without touching the IR, so
+/// a failed match leaves the operation untouched. The erasure is then done
+/// inside `modifyOpInPlace` so that the rewriter and its listeners are
+/// notified of the operand, block-argument and attribute changes.
+struct ComputeRegionRemoveUnusedArgs
+    : public OpRewritePattern<ComputeRegionOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ComputeRegionOp op,
+                                PatternRewriter &rewriter) const override {
+    Block *body = op.getBody();
+    const unsigned numLaunch = op.getLaunchArgs().size();
+    const unsigned numInput = op.getInputArgs().size();
+    assert(body->getNumArguments() == numLaunch + numInput &&
+           "region args mismatch");
+
+    // Match phase: body-argument indices of unused `ins` captures, ascending.
+    SmallVector<unsigned> deadIdxs;
+    for (unsigned k = numLaunch, e = numLaunch + numInput; k < e; ++k)
+      if (body->getArgument(k).use_empty())
+        deadIdxs.push_back(k);
+    if (deadIdxs.empty())
+      return failure();
+
+    rewriter.modifyOpInPlace(op, [&] {
+      // Erase from the highest index down so pending indices stay valid.
+      for (unsigned idx : llvm::reverse(deadIdxs)) {
+        body->eraseArgument(idx);
+        op->eraseOperand(idx);
+      }
+      updateComputeRegionInputOperandSegments(op, rewriter,
+                                              numInput - deadIdxs.size());
+    });
+    return success();
+  }
+};
+
 template <typename EffectTy>
 static void addOperandEffect(
     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
@@ -441,15 +526,39 @@ SmallVector<GPUParallelDimAttr> ComputeRegionOp::getLaunchParDims() {
 }
 
 Value ComputeRegionOp::getOperand(BlockArgument blockArg) {
+  Block *body = getBody();
+  if (blockArg.getOwner() != body)
+    return Value();
   unsigned argNumber = blockArg.getArgNumber();
   unsigned numLaunchArgs = getLaunchArgs().size();
-  assert(argNumber < (numLaunchArgs + getInputArgs().size()) &&
-         "invalid block argument");
+  unsigned numInputArgs = getInputArgs().size();
+  if (argNumber >= numLaunchArgs + numInputArgs)
+    return Value();
   if (argNumber < numLaunchArgs)
     return getLaunchArgs()[argNumber];
   return getInputArgs()[argNumber - numLaunchArgs];
 }
 
+/// Return the body block argument that `value` is threaded through, or
+/// `std::nullopt` if `value` is neither a `launch` nor an `ins` operand.
+/// Body arguments are numbered with all `launch` operands first, followed by
+/// the `ins` operands; the lowest matching index wins.
+std::optional<BlockArgument> ComputeRegionOp::getBlockArg(Value value) {
+  Block *body = getBody();
+  const auto launchOperands = getLaunchArgs();
+  const auto inputOperands = getInputArgs();
+  const unsigned launchCount = launchOperands.size();
+  const unsigned threadedCount = launchCount + inputOperands.size();
+  for (unsigned argIdx = 0; argIdx < threadedCount; ++argIdx) {
+    // Indices below `launchCount` map to `launch`, the remainder to `ins`.
+    Value threaded = argIdx < launchCount
+                         ? launchOperands[argIdx]
+                         : inputOperands[argIdx - launchCount];
+    if (threaded == value)
+      return body->getArgument(argIdx);
+  }
+  return std::nullopt;
+}
+
+/// Register the `acc.compute_region` canonicalization patterns: duplicate
+/// `ins` merging first, then unused `ins` removal.
+void ComputeRegionOp::getCanonicalizationPatterns(
+    RewritePatternSet &results, MLIRContext *context) {
+  results.add<ComputeRegionRemoveDuplicateArgs>(context);
+  results.add<ComputeRegionRemoveUnusedArgs>(context);
+}
+
 BlockArgument ComputeRegionOp::gpuParWidth(gpu::Processor processor) {
   return parDimToWidth(GPUParallelDimAttr::get(getContext(), processor));
 }
diff --git a/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp b/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp
index 1cc313206a99f..f20ace4398696 100644
--- a/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp
+++ b/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp
@@ -37,6 +37,8 @@ mlir::Operation *mlir::acc::getACCDataClauseOpForBlockArg(mlir::Value v) {
     return nullptr;
 
   mlir::Value orig = computeReg.getOperand(barg);
+  if (!orig)
+    return nullptr;
   mlir::Operation *def = orig.getDefiningOp();
   return mlir::isa_and_nonnull<ACC_DATA_ENTRY_OPS>(def) ? def : nullptr;
 }
diff --git a/mlir/test/Dialect/OpenACC/compute-region-canonicalize.mlir b/mlir/test/Dialect/OpenACC/compute-region-canonicalize.mlir
new file mode 100644
index 0000000000000..68b1193508ad6
--- /dev/null
+++ b/mlir/test/Dialect/OpenACC/compute-region-canonicalize.mlir
@@ -0,0 +1,78 @@
+// RUN: mlir-opt -canonicalize -split-input-file %s | FileCheck %s
+
+// -----
+
+// The same memref is captured twice and BOTH block arguments are used, so the
+// only way to end up with a single `ins` operand is to merge the duplicate
+// and redirect the uses of the dropped argument to the kept one.
+// CHECK-LABEL: func @merge_duplicate_ins
+func.func @merge_duplicate_ins() -> i32 {
+  %c0 = arith.constant 0 : i32
+  %m = memref.alloca() : memref<i32>
+  memref.store %c0, %m[] : memref<i32>
+  acc.compute_region ins(%a = %m, %b = %m) : (memref<i32>, memref<i32>) {
+    %c1 = arith.constant 1 : i32
+    %v = memref.load %a[] : memref<i32>
+    %x = arith.addi %v, %c1 : i32
+    memref.store %x, %b[] : memref<i32>
+    acc.yield
+  } {origin = "acc.serial"}
+  %r = memref.load %m[] : memref<i32>
+  return %r : i32
+}
+// CHECK: acc.compute_region ins(%[[ARG:[^ ]+]] = %{{[^ ,)]+}}) : (memref<i32>) {
+// CHECK:   memref.load %[[ARG]][] : memref<i32>
+// CHECK:   memref.store %{{.*}}, %[[ARG]][] : memref<i32>
+
+// -----
+
+// Seven `ins` captures over three distinct memrefs (%ma x3, %mb x2, %mc x2),
+// interleaved so that duplicates are not adjacent. Every block argument is
+// used in the body, so only duplicate merging can shrink the operand list;
+// three `ins` operands are expected to remain.
+// CHECK-LABEL: func @merge_duplicate_ins_complex_pattern
+func.func @merge_duplicate_ins_complex_pattern() -> i32 {
+  %c0 = arith.constant 0 : i32
+  %ma = memref.alloca() : memref<i32>
+  %mb = memref.alloca() : memref<i32>
+  %mc = memref.alloca() : memref<i32>
+  memref.store %c0, %ma[] : memref<i32>
+  memref.store %c0, %mb[] : memref<i32>
+  memref.store %c0, %mc[] : memref<i32>
+  acc.compute_region ins(%a0 = %ma, %b0 = %mb, %a1 = %ma, %mc0 = %mc, %mc1 = %mc, %b1 = %mb, %a2 = %ma) : (memref<i32>, memref<i32>, memref<i32>, memref<i32>, memref<i32>, memref<i32>, memref<i32>) {
+    %one = arith.constant 1 : i32
+    %v0 = memref.load %a0[] : memref<i32>
+    %v1 = memref.load %b0[] : memref<i32>
+    %v2 = memref.load %a1[] : memref<i32>
+    %v3 = memref.load %mc0[] : memref<i32>
+    %v4 = memref.load %mc1[] : memref<i32>
+    %v5 = memref.load %b1[] : memref<i32>
+    %v6 = memref.load %a2[] : memref<i32>
+    %sum1 = arith.addi %v0, %v1 : i32
+    %sum2 = arith.addi %sum1, %v2 : i32
+    %sum3 = arith.addi %sum2, %v3 : i32
+    %sum4 = arith.addi %sum3, %v4 : i32
+    %sum5 = arith.addi %sum4, %v5 : i32
+    %sum6 = arith.addi %sum5, %v6 : i32
+    %out = arith.addi %sum6, %one : i32
+    memref.store %out, %a0[] : memref<i32>
+    acc.yield
+  } {origin = "acc.serial"}
+  %r = memref.load %ma[] : memref<i32>
+  return %r : i32
+}
+// CHECK: acc.compute_region ins({{.*}}) : (memref<i32>, memref<i32>, memref<i32>) {
+
+// -----
+
+// Three distinct memrefs are captured but only the first block argument is
+// used. The two unused captures must be dropped, and the surviving `ins`
+// operand must be the one whose block argument the body actually uses (%ma,
+// which is also the memref loaded after the region).
+// CHECK-LABEL: func @drop_unused_ins
+func.func @drop_unused_ins() -> i32 {
+  %c0 = arith.constant 0 : i32
+  %ma = memref.alloca() : memref<i32>
+  %mb = memref.alloca() : memref<i32>
+  %mc = memref.alloca() : memref<i32>
+  memref.store %c0, %ma[] : memref<i32>
+  memref.store %c0, %mb[] : memref<i32>
+  memref.store %c0, %mc[] : memref<i32>
+  acc.compute_region ins(%a = %ma, %b = %mb, %c = %mc) : (memref<i32>, memref<i32>, memref<i32>) {
+    %c1 = arith.constant 1 : i32
+    %v = memref.load %a[] : memref<i32>
+    %x = arith.addi %v, %c1 : i32
+    memref.store %x, %a[] : memref<i32>
+    acc.yield
+  } {origin = "acc.serial"}
+  %r = memref.load %ma[] : memref<i32>
+  return %r : i32
+}
+// CHECK: acc.compute_region ins(%[[ARG:[^ ]+]] = %[[MA:[^ ,)]+]]) : (memref<i32>) {
+// CHECK:   memref.load %[[ARG]][] : memref<i32>
+// CHECK:   memref.store %{{.*}}, %[[ARG]][] : memref<i32>
+// CHECK: memref.load %[[MA]][] : memref<i32>
diff --git a/mlir/unittests/Dialect/OpenACC/OpenACCUtilsCGTest.cpp b/mlir/unittests/Dialect/OpenACC/OpenACCUtilsCGTest.cpp
index e7e5974ed5c70..6fe0ffb2d54fe 100644
--- a/mlir/unittests/Dialect/OpenACC/OpenACCUtilsCGTest.cpp
+++ b/mlir/unittests/Dialect/OpenACC/OpenACCUtilsCGTest.cpp
@@ -216,5 +216,10 @@ TEST_F(OpenACCUtilsCGTest, buildComputeRegionWithInputArgsToMap) {
   }
   EXPECT_TRUE(foundAddI);
 
+  EXPECT_EQ(cr.getOperand(crBlock.getArgument(0)), deviceBlock->getArgument(0));
+  ASSERT_TRUE(cr.getBlockArg(deviceBlock->getArgument(0)).has_value());
+  EXPECT_EQ(*cr.getBlockArg(deviceBlock->getArgument(0)),
+            crBlock.getArgument(0));
+
   func::ReturnOp::create(rewriter, loc);
 }

@github-actions
Copy link
Copy Markdown

github-actions Bot commented Apr 16, 2026

✅ With the latest revision this PR passed the C/C++ code formatter.

Copy link
Copy Markdown
Contributor

@vzakhari vzakhari left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Thank you, Razvan!

@razvanlupusoru razvanlupusoru merged commit 1b433e9 into llvm:main Apr 16, 2026
10 checks passed
alexfh pushed a commit to alexfh/llvm-project that referenced this pull request Apr 18, 2026
…2376)

This PR improves the APIs for navigating through acc.compute_region
block arguments and also adds canonicalization patterns for those
arguments to remove unused ones and merge duplicates.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants