diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td index c689b7e46ea9e..5b89f741e296d 100644 --- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td @@ -2184,6 +2184,8 @@ def OpenACC_KernelEnvironmentOp : OpenACC_Op<"kernel_environment", ) $region attr-dict }]; + + let hasCanonicalizer = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp index b2f1d840f3bca..8c9c137b8aebb 100644 --- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp +++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp @@ -1042,6 +1042,65 @@ struct RemoveConstantIfConditionWithRegion : public OpRewritePattern { } }; +/// Remove empty acc.kernel_environment operations. If the operation has wait +/// operands, create a acc.wait operation to preserve synchronization. +struct RemoveEmptyKernelEnvironment + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(acc::KernelEnvironmentOp op, + PatternRewriter &rewriter) const override { + assert(op->getNumRegions() == 1 && "expected op to have one region"); + + Block &block = op.getRegion().front(); + if (!block.empty()) + return failure(); + + // Conservatively disable canonicalization of empty acc.kernel_environment + // operations if the wait operands in the kernel_environment cannot be fully + // represented by acc.wait operation. + + // Disable canonicalization if device type is not the default + if (auto deviceTypeAttr = op.getWaitOperandsDeviceTypeAttr()) { + for (auto attr : deviceTypeAttr) { + if (auto dtAttr = mlir::dyn_cast(attr)) { + if (dtAttr.getValue() != mlir::acc::DeviceType::None) + return failure(); + } + } + } + + // Disable canonicalization if any wait segment has a devnum + if (auto hasDevnumAttr = op.getHasWaitDevnumAttr()) { + for (auto attr : hasDevnumAttr) { + if (auto boolAttr = mlir::dyn_cast(attr)) { + if (boolAttr.getValue()) + return failure(); + } + } + } + + // Disable canonicalization if there are multiple wait segments + if (auto segmentsAttr = op.getWaitOperandsSegmentsAttr()) { + if (segmentsAttr.size() > 1) + return failure(); + } + + // Remove empty kernel environment. + // Preserve synchronization by creating acc.wait operation if needed. + if (!op.getWaitOperands().empty() || op.getWaitOnlyAttr()) + rewriter.replaceOpWithNewOp(op, op.getWaitOperands(), + /*asyncOperand=*/Value(), + /*waitDevnum=*/Value(), + /*async=*/nullptr, + /*ifCond=*/Value()); + else + rewriter.eraseOp(op); + + return success(); + } +}; + //===----------------------------------------------------------------------===// // Recipe Region Helpers //===----------------------------------------------------------------------===// @@ -2690,6 +2749,15 @@ void acc::HostDataOp::getCanonicalizationPatterns(RewritePatternSet &results, results.add>(context); } +//===----------------------------------------------------------------------===// +// KernelEnvironmentOp +//===----------------------------------------------------------------------===// + +void acc::KernelEnvironmentOp::getCanonicalizationPatterns( + RewritePatternSet &results, MLIRContext *context) { + results.add(context); +} + //===----------------------------------------------------------------------===// // LoopOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/OpenACC/canonicalize.mlir b/mlir/test/Dialect/OpenACC/canonicalize.mlir index fdc8e6b5cae6e..38d3df31305ad 100644 --- a/mlir/test/Dialect/OpenACC/canonicalize.mlir +++ b/mlir/test/Dialect/OpenACC/canonicalize.mlir @@ -219,3 +219,30 @@ func.func @update_unnecessary_computations(%x: memref) { // CHECK-LABEL: func.func @update_unnecessary_computations // CHECK-NOT: acc.atomic.update // CHECK: acc.atomic.write + +// ----- + +func.func @kernel_environment_canonicalization(%q1: i32, %q2: i32, %q3: i32) { + // Empty kernel_environment (no wait) - should be removed + acc.kernel_environment { + } + + acc.kernel_environment wait({%q1 : i32, %q2 : i32}) { + } + + acc.kernel_environment wait { + } + + acc.kernel_environment wait({%q3 : i32} [#acc.device_type]) { + } + + return +} + +// CHECK-LABEL: func.func @kernel_environment_canonicalization +// CHECK-SAME: ([[Q1:%.*]]: i32, [[Q2:%.*]]: i32, [[Q3:%.*]]: i32) +// CHECK-NOT: acc.kernel_environment wait({{.*}}[#acc.device_type]) +// CHECK: acc.wait([[Q1]], [[Q2]] : i32, i32) +// CHECK: acc.wait{{$}} +// CHECK: acc.kernel_environment wait({{.*}}[#acc.device_type]) +// CHECK: return