[mlir][acc] Add canonicalization patterns for compute_region #192376
Merged
Conversation
This PR improves the APIs for navigating through acc.compute_region block arguments and also adds canonicalization patterns for those arguments to remove unused ones and merge duplicates.
Member
|
@llvm/pr-subscribers-mlir-openacc @llvm/pr-subscribers-mlir Author: Razvan Lupusoru (razvanlupusoru). Changes: This PR improves the APIs for navigating through acc.compute_region block arguments and also adds canonicalization patterns for those arguments to remove unused ones and merge duplicates. Full diff: https://github.com/llvm/llvm-project/pull/192376.diff (5 files affected):
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCCGOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCCGOps.td
index 63fc8476c08d4..76902a6d2690e 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCCGOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCCGOps.td
@@ -297,6 +297,11 @@ def OpenACC_ComputeRegionOp
region (e.g., `"acc.parallel"`, `"acc.kernels"`). This is intended to
be solely informational.
+ Canonicalization may simplify `ins` captures: duplicate `ins` operands
+ (same SSA value threaded more than once) are merged by reusing the first
+ block argument, and unused `ins` operands (block arguments with no uses)
+ are removed. `launch` operands are never merged or dropped.
+
Example:
```mlir
@@ -327,6 +332,8 @@ def OpenACC_ComputeRegionOp
let regions = (region AnyRegion:$region);
+ let hasCanonicalizer = 1;
+
let extraClassDeclaration = [{
/// Look up the par_width op for the given dimension among launch args.
std::optional<mlir::Value> getLaunchArg(
@@ -365,9 +372,17 @@ def OpenACC_ComputeRegionOp
return &getRegion().back().back();
}
- /// Map a block argument back to its corresponding operand
- /// ($launchArgs or $inputArgs).
+ /// Return the `launch` or `ins` operand threaded to `blockArg`, or a null
+ /// `Value` if `blockArg` is not an argument of `getBody()` or its index is
+ /// out of range for this op's `launch` and `ins` operands.
::mlir::Value getOperand(::mlir::BlockArgument blockArg);
+
+ /// If `value` is a launch or input operand, return the body block argument
+ /// it is threaded through; otherwise `std::nullopt`. If `value` matches
+ /// more than one `ins` operand, the first match is returned (canonicalization
+ /// may merge duplicate `ins` values). Duplicate `launch` operands are not
+ /// folded.
+ std::optional<::mlir::BlockArgument> getBlockArg(::mlir::Value value);
}];
let hasVerifier = 1;
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACCCG.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACCCG.cpp
index 04f8c848c7287..1e86dcf2c3946 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACCCG.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACCCG.cpp
@@ -19,6 +19,7 @@
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Region.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Support/LogicalResult.h"
@@ -107,6 +108,90 @@ struct RemoveEmptyKernelEnvironment
}
};
+/// Rewrite the operand-segment-size attribute of `op` after `ins` operands
+/// have been erased, recording `numInput` as the new size of the `ins`
+/// segment.
+///
+/// This must be called once the operand list already reflects the erasures.
+/// The segment layout written here is {launch, ins, stream}. The `launch`
+/// segment is the leading one, so its size can still be read through the
+/// accessor. The `stream` count is derived from the current operand total
+/// instead of `op.getStream()`: until the attribute is rewritten, that
+/// accessor would locate the stream operand using the stale `ins` size and
+/// index past the live operands.
+static void updateComputeRegionInputOperandSegments(ComputeRegionOp op,
+                                                    PatternRewriter &rewriter,
+                                                    size_t numInput) {
+  const size_t numLaunch = op.getLaunchArgs().size();
+  const size_t numOperands = op->getNumOperands();
+  assert(numOperands >= numLaunch + numInput &&
+         "operand list is shorter than the launch and ins segments");
+  // Whatever follows the launch and ins operands is the optional stream.
+  const size_t numStream = numOperands - numLaunch - numInput;
+  assert(numStream <= 1 && "expected at most one stream operand");
+  op->setAttr(ComputeRegionOp::getOperandSegmentSizeAttr(),
+              rewriter.getDenseI32ArrayAttr(
+                  {static_cast<int32_t>(numLaunch),
+                   static_cast<int32_t>(numInput),
+                   static_cast<int32_t>(numStream)}));
+}
+
+/// Canonicalization pattern that merges duplicate `ins` operands of an
+/// `acc.compute_region`.
+///
+/// When the same SSA value appears more than once among the `ins` operands,
+/// every later occurrence is dropped: uses of its block argument are
+/// redirected to the block argument of the first occurrence, then the
+/// redundant block argument and operand are erased. `launch` operands are
+/// never merged.
+struct ComputeRegionRemoveDuplicateArgs
+    : public OpRewritePattern<ComputeRegionOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ComputeRegionOp op,
+                                PatternRewriter &rewriter) const override {
+    Block *body = op.getBody();
+    const unsigned numLaunch = op.getLaunchArgs().size();
+    unsigned numInput = op.getInputArgs().size();
+    assert(body->getNumArguments() == numLaunch + numInput &&
+           "region args mismatch");
+
+    // Position (within the `ins` operands) of the first operand that carries
+    // the same value as `ins` operand `j`; equals `j` when operand `j` has no
+    // earlier duplicate. Raw operand indices are used on purpose: they stay
+    // valid while the segment-size attribute is temporarily out of date.
+    auto firstOccurrence = [&](unsigned j) -> unsigned {
+      Value candidate = op->getOperand(numLaunch + j);
+      for (unsigned i = 0; i < j; ++i)
+        if (op->getOperand(numLaunch + i) == candidate)
+          return i;
+      return j;
+    };
+
+    // Match phase: report failure before touching the IR when there is
+    // nothing to merge.
+    bool hasDuplicate = false;
+    for (unsigned j = 1; j < numInput && !hasDuplicate; ++j)
+      hasDuplicate = firstOccurrence(j) != j;
+    if (!hasDuplicate)
+      return failure();
+
+    // Rewrite phase: the op is mutated in place, so the rewriter has to be
+    // told about it. A single forward pass suffices: everything before `j` is
+    // already duplicate-free, so after erasing position `j` the scan resumes
+    // at the same index, which now holds the next operand.
+    rewriter.modifyOpInPlace(op, [&] {
+      for (unsigned j = 1; j < numInput;) {
+        const unsigned i = firstOccurrence(j);
+        if (i == j) {
+          ++j;
+          continue;
+        }
+        const unsigned keepIdx = numLaunch + i;
+        const unsigned dropIdx = numLaunch + j;
+        rewriter.replaceAllUsesWith(body->getArgument(dropIdx),
+                                    body->getArgument(keepIdx));
+        body->eraseArgument(dropIdx);
+        op->eraseOperand(dropIdx);
+        --numInput;
+      }
+      updateComputeRegionInputOperandSegments(op, rewriter, numInput);
+    });
+    return success();
+  }
+};
+
+/// Canonicalization pattern that removes `ins` operands of an
+/// `acc.compute_region` whose block argument has no uses inside the region.
+/// `launch` operands are never dropped, even when unused.
+struct ComputeRegionRemoveUnusedArgs
+    : public OpRewritePattern<ComputeRegionOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ComputeRegionOp op,
+                                PatternRewriter &rewriter) const override {
+    Block *body = op.getBody();
+    const unsigned numLaunch = op.getLaunchArgs().size();
+    const unsigned numInput = op.getInputArgs().size();
+    assert(body->getNumArguments() == numLaunch + numInput &&
+           "region args mismatch");
+
+    // Match phase: report failure before touching the IR when every `ins`
+    // block argument is still in use.
+    bool hasUnused = false;
+    for (unsigned k = numLaunch; k < numLaunch + numInput && !hasUnused; ++k)
+      hasUnused = body->getArgument(k).use_empty();
+    if (!hasUnused)
+      return failure();
+
+    // Rewrite phase: the op is mutated in place, so the rewriter has to be
+    // told about it. Walking from the back keeps the indices of the not yet
+    // visited arguments stable while entries are erased.
+    rewriter.modifyOpInPlace(op, [&] {
+      unsigned remaining = numInput;
+      for (unsigned k = numLaunch + numInput; k-- > numLaunch;) {
+        if (!body->getArgument(k).use_empty())
+          continue;
+        body->eraseArgument(k);
+        op->eraseOperand(k);
+        --remaining;
+      }
+      updateComputeRegionInputOperandSegments(op, rewriter, remaining);
+    });
+    return success();
+  }
+};
+
template <typename EffectTy>
static void addOperandEffect(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
@@ -441,15 +526,39 @@ SmallVector<GPUParallelDimAttr> ComputeRegionOp::getLaunchParDims() {
}
Value ComputeRegionOp::getOperand(BlockArgument blockArg) {
+ Block *body = getBody();
+ if (blockArg.getOwner() != body)
+ return Value();
unsigned argNumber = blockArg.getArgNumber();
unsigned numLaunchArgs = getLaunchArgs().size();
- assert(argNumber < (numLaunchArgs + getInputArgs().size()) &&
- "invalid block argument");
+ unsigned numInputArgs = getInputArgs().size();
+ if (argNumber >= numLaunchArgs + numInputArgs)
+ return Value();
if (argNumber < numLaunchArgs)
return getLaunchArgs()[argNumber];
return getInputArgs()[argNumber - numLaunchArgs];
}
+/// Find the body block argument that `value` is threaded through.
+///
+/// `launch` operands are searched first, then `ins` operands; the first
+/// operand equal to `value` wins. Returns `std::nullopt` when `value` is
+/// neither a `launch` nor an `ins` operand of this op.
+std::optional<BlockArgument> ComputeRegionOp::getBlockArg(Value value) {
+  Block *body = getBody();
+
+  // Launch operand `i` is paired with block argument `i`.
+  auto launchArgs = getLaunchArgs();
+  const unsigned numLaunch = launchArgs.size();
+  for (unsigned i = 0; i < numLaunch; ++i)
+    if (launchArgs[i] == value)
+      return body->getArgument(i);
+
+  // Input operand `i` is paired with the block argument right after the
+  // launch ones, i.e. `numLaunch + i`.
+  auto inputArgs = getInputArgs();
+  for (unsigned i = 0, e = inputArgs.size(); i < e; ++i)
+    if (inputArgs[i] == value)
+      return body->getArgument(numLaunch + i);
+
+  return std::nullopt;
+}
+
+/// Register the canonicalization patterns of `acc.compute_region`: merging of
+/// duplicate `ins` operands and removal of unused `ins` operands.
+void ComputeRegionOp::getCanonicalizationPatterns(
+    RewritePatternSet &results, MLIRContext *context) {
+  results.add<ComputeRegionRemoveDuplicateArgs>(context);
+  results.add<ComputeRegionRemoveUnusedArgs>(context);
+}
+
BlockArgument ComputeRegionOp::gpuParWidth(gpu::Processor processor) {
return parDimToWidth(GPUParallelDimAttr::get(getContext(), processor));
}
diff --git a/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp b/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp
index 1cc313206a99f..f20ace4398696 100644
--- a/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp
+++ b/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp
@@ -37,6 +37,8 @@ mlir::Operation *mlir::acc::getACCDataClauseOpForBlockArg(mlir::Value v) {
return nullptr;
mlir::Value orig = computeReg.getOperand(barg);
+ if (!orig)
+ return nullptr;
mlir::Operation *def = orig.getDefiningOp();
return mlir::isa_and_nonnull<ACC_DATA_ENTRY_OPS>(def) ? def : nullptr;
}
diff --git a/mlir/test/Dialect/OpenACC/compute-region-canonicalize.mlir b/mlir/test/Dialect/OpenACC/compute-region-canonicalize.mlir
new file mode 100644
index 0000000000000..68b1193508ad6
--- /dev/null
+++ b/mlir/test/Dialect/OpenACC/compute-region-canonicalize.mlir
@@ -0,0 +1,78 @@
+// RUN: mlir-opt -canonicalize -split-input-file %s | FileCheck %s
+
+// -----
+
+// CHECK-LABEL: func @merge_duplicate_ins
+// Both `%a` and `%b` are used inside the region, so the single remaining
+// `ins` operand can only be produced by merging the duplicates; dropping an
+// unused argument alone would not satisfy the check below.
+func.func @merge_duplicate_ins() -> i32 {
+  %c0 = arith.constant 0 : i32
+  %m = memref.alloca() : memref<i32>
+  memref.store %c0, %m[] : memref<i32>
+  acc.compute_region ins(%a = %m, %b = %m) : (memref<i32>, memref<i32>) {
+    %c1 = arith.constant 1 : i32
+    %v = memref.load %b[] : memref<i32>
+    %x = arith.addi %v, %c1 : i32
+    memref.store %x, %a[] : memref<i32>
+    acc.yield
+  } {origin = "acc.serial"}
+  %r = memref.load %m[] : memref<i32>
+  return %r : i32
+}
+// CHECK: acc.compute_region ins({{.*}}) : (memref<i32>) {
+
+// -----
+
+// CHECK-LABEL: func @merge_duplicate_ins_complex_pattern
+func.func @merge_duplicate_ins_complex_pattern() -> i32 {
+ %c0 = arith.constant 0 : i32
+ %ma = memref.alloca() : memref<i32>
+ %mb = memref.alloca() : memref<i32>
+ %mc = memref.alloca() : memref<i32>
+ memref.store %c0, %ma[] : memref<i32>
+ memref.store %c0, %mb[] : memref<i32>
+ memref.store %c0, %mc[] : memref<i32>
+ acc.compute_region ins(%a0 = %ma, %b0 = %mb, %a1 = %ma, %mc0 = %mc, %mc1 = %mc, %b1 = %mb, %a2 = %ma) : (memref<i32>, memref<i32>, memref<i32>, memref<i32>, memref<i32>, memref<i32>, memref<i32>) {
+ %one = arith.constant 1 : i32
+ %v0 = memref.load %a0[] : memref<i32>
+ %v1 = memref.load %b0[] : memref<i32>
+ %v2 = memref.load %a1[] : memref<i32>
+ %v3 = memref.load %mc0[] : memref<i32>
+ %v4 = memref.load %mc1[] : memref<i32>
+ %v5 = memref.load %b1[] : memref<i32>
+ %v6 = memref.load %a2[] : memref<i32>
+ %sum1 = arith.addi %v0, %v1 : i32
+ %sum2 = arith.addi %sum1, %v2 : i32
+ %sum3 = arith.addi %sum2, %v3 : i32
+ %sum4 = arith.addi %sum3, %v4 : i32
+ %sum5 = arith.addi %sum4, %v5 : i32
+ %sum6 = arith.addi %sum5, %v6 : i32
+ %out = arith.addi %sum6, %one : i32
+ memref.store %out, %a0[] : memref<i32>
+ acc.yield
+ } {origin = "acc.serial"}
+ %r = memref.load %ma[] : memref<i32>
+ return %r : i32
+}
+// CHECK: acc.compute_region ins({{.*}}) : (memref<i32>, memref<i32>, memref<i32>) {
+
+// -----
+
+// CHECK-LABEL: func @drop_unused_ins
+func.func @drop_unused_ins() -> i32 {
+ %c0 = arith.constant 0 : i32
+ %ma = memref.alloca() : memref<i32>
+ %mb = memref.alloca() : memref<i32>
+ %mc = memref.alloca() : memref<i32>
+ memref.store %c0, %ma[] : memref<i32>
+ memref.store %c0, %mb[] : memref<i32>
+ memref.store %c0, %mc[] : memref<i32>
+ acc.compute_region ins(%a = %ma, %b = %mb, %c = %mc) : (memref<i32>, memref<i32>, memref<i32>) {
+ %c1 = arith.constant 1 : i32
+ %v = memref.load %a[] : memref<i32>
+ %x = arith.addi %v, %c1 : i32
+ memref.store %x, %a[] : memref<i32>
+ acc.yield
+ } {origin = "acc.serial"}
+ %r = memref.load %ma[] : memref<i32>
+ return %r : i32
+}
+// CHECK: acc.compute_region ins({{.*}}) : (memref<i32>) {
diff --git a/mlir/unittests/Dialect/OpenACC/OpenACCUtilsCGTest.cpp b/mlir/unittests/Dialect/OpenACC/OpenACCUtilsCGTest.cpp
index e7e5974ed5c70..6fe0ffb2d54fe 100644
--- a/mlir/unittests/Dialect/OpenACC/OpenACCUtilsCGTest.cpp
+++ b/mlir/unittests/Dialect/OpenACC/OpenACCUtilsCGTest.cpp
@@ -216,5 +216,10 @@ TEST_F(OpenACCUtilsCGTest, buildComputeRegionWithInputArgsToMap) {
}
EXPECT_TRUE(foundAddI);
+ EXPECT_EQ(cr.getOperand(crBlock.getArgument(0)), deviceBlock->getArgument(0));
+ ASSERT_TRUE(cr.getBlockArg(deviceBlock->getArgument(0)).has_value());
+ EXPECT_EQ(*cr.getBlockArg(deviceBlock->getArgument(0)),
+ crBlock.getArgument(0));
+
func::ReturnOp::create(rewriter, loc);
}
|
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
added 2 commits
April 15, 2026 19:26
vzakhari
approved these changes
Apr 16, 2026
Contributor
vzakhari
left a comment
There was a problem hiding this comment.
LGTM. Thank you, Razvan!
alexfh
pushed a commit
to alexfh/llvm-project
that referenced
this pull request
Apr 18, 2026
…2376) This PR improves the APIs for navigating through acc.compute_region block arguments and also adds canonicalization patterns for those arguments to remove unused ones and merge duplicates.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
This PR improves the APIs for navigating through acc.compute_region block arguments and also adds canonicalization patterns for those arguments to remove unused ones and merge duplicates.