diff --git a/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td b/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td
index e10fde3c2691f..68a52e0706d60 100644
--- a/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td
@@ -252,4 +252,39 @@ def ACCSpecializeForHost : Pass<"acc-specialize-for-host", "mlir::func::FuncOp">
   ];
 }
 
+def ACCIfClauseLowering : Pass<"acc-if-clause-lowering", "mlir::func::FuncOp"> {
+  let summary = "Lower if clauses in ACC compute constructs";
+  let description = [{
+    This pass lowers OpenACC compute constructs (parallel, kernels, serial) with
+    `if` clauses using region specialization. It creates two execution paths:
+    device execution when the condition is true, host execution when false.
+
+    When an ACC compute construct has an `if` clause, the construct should only
+    execute on the device when the condition is true. If the condition is false,
+    the code should execute on the host instead. This pass transforms:
+
+    ```mlir
+    acc.parallel if(%cond) { ... }
+    ```
+
+    Into:
+
+    ```mlir
+    scf.if %cond {
+      // Device path: clone data ops, compute construct without if, exit ops
+      acc.parallel { ... }
+    } else {
+      // Host path: original region body with ACC ops converted to host
+    }
+    ```
+
+    The transformation handles:
+    - Data entry operations (acc.copyin, acc.create, etc.) are cloned to device path
+    - Data exit operations (acc.copyout, acc.delete, etc.) are cloned to device path
+    - The host path uses `populateACCHostFallbackPatterns` to convert ACC ops
+  }];
+  let dependentDialects = ["mlir::acc::OpenACCDialect",
+                           "mlir::scf::SCFDialect"];
+}
+
 #endif // MLIR_DIALECT_OPENACC_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/OpenACC/Transforms/ACCIfClauseLowering.cpp b/mlir/lib/Dialect/OpenACC/Transforms/ACCIfClauseLowering.cpp
new file mode 100644
index 0000000000000..5524c291a80e7
--- /dev/null
+++ b/mlir/lib/Dialect/OpenACC/Transforms/ACCIfClauseLowering.cpp
@@ -0,0 +1,267 @@
+//===- ACCIfClauseLowering.cpp - Lower ACC compute construct if clauses --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass lowers OpenACC compute constructs (parallel, kernels, serial) with
+// `if` clauses using region specialization. It creates two execution paths:
+// device execution when the condition is true, host execution when false.
+//
+// Overview:
+// ---------
+// When an ACC compute construct has an `if` clause, the construct should only
+// execute on the device when the condition is true. If the condition is false,
+// the code should execute on the host instead. This pass transforms:
+//
+//   acc.parallel if(%cond) { ... }
+//
+// Into:
+//
+//   scf.if %cond {
+//     // Device path: clone data ops, compute construct without if, exit ops
+//     acc.parallel { ... }
+//   } else {
+//     // Host path: original region body with ACC ops converted to host
+//   }
+//
+// Transformations:
+// ----------------
+// For each compute construct with an `if` clause:
+//
+// 1. Device Path (true branch):
+//    - Clone data entry operations (acc.copyin, acc.create, etc.)
+//    - Clone the compute construct without the `if` clause
+//    - Clone data exit operations (acc.copyout, acc.delete, etc.)
+//
+// 2. Host Path (false branch):
+//    - Move the original region body to the else branch
+//    - Apply host fallback patterns to convert ACC ops to host equivalents
+//
+// 3. Cleanup:
+//    - Erase the original compute construct and data operations
+//    - Replace uses of ACC variables with host variables in the else branch
+//
+// Requirements:
+// -------------
+// To use this pass in a pipeline, the following requirements exist:
+//
+// 1. Analysis Registration (Optional): If custom behavior is needed for
+//    emitting not-yet-implemented messages for unsupported cases, the pipeline
+//    should pre-register the `acc::OpenACCSupport` analysis.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/OpenACC/Transforms/Passes.h"
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/OpenACC/Analysis/OpenACCSupport.h"
+#include "mlir/Dialect/OpenACC/OpenACC.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/Debug.h"
+
+namespace mlir {
+namespace acc {
+#define GEN_PASS_DEF_ACCIFCLAUSELOWERING
+#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
+} // namespace acc
+} // namespace mlir
+
+#define DEBUG_TYPE "acc-if-clause-lowering"
+
+using namespace mlir;
+using namespace mlir::acc;
+
+namespace {
+
+/// Pass that rewrites `acc.parallel`, `acc.kernels` and `acc.serial` ops that
+/// carry an `if` clause into an `scf.if` with a device branch and a host
+/// branch.
+class ACCIfClauseLowering
+    : public acc::impl::ACCIfClauseLoweringBase<ACCIfClauseLowering> {
+  using ACCIfClauseLoweringBase::ACCIfClauseLoweringBase;
+
+private:
+  /// Analysis used to report not-yet-implemented cases. Assigned in
+  /// runOnOperation() before any construct is lowered.
+  OpenACCSupport *accSupport = nullptr;
+
+  /// Applies the host fallback patterns to the ACC dialect ops in `region`.
+  void convertHostRegion(Operation *computeOp, Region &region);
+
+  /// Lowers one compute construct; ops that must be deleted afterwards are
+  /// appended to `eraseOps`.
+  template <typename OpTy>
+  void lowerIfClauseForComputeConstruct(OpTy computeConstructOp,
+                                        SmallVector<Operation *> &eraseOps);
+
+public:
+  void runOnOperation() override;
+};
+
+/// Converts the ACC dialect operations nested in `region` with the patterns
+/// registered by `populateACCHostFallbackPatterns`. `computeOp` provides the
+/// context and the location used to report a failed conversion.
+void ACCIfClauseLowering::convertHostRegion(Operation *computeOp,
+                                            Region &region) {
+  // Only collect ACC dialect operations - other ops don't need conversion
+  SmallVector<Operation *> hostOps;
+  region.walk([&](Operation *op) {
+    if (isa<acc::OpenACCDialect>(op->getDialect()))
+      hostOps.push_back(op);
+  });
+
+  RewritePatternSet patterns(computeOp->getContext());
+  populateACCHostFallbackPatterns(patterns, *accSupport);
+
+  GreedyRewriteConfig config;
+  config.setUseTopDownTraversal(true);
+  config.setStrictness(GreedyRewriteStrictness::ExistingOps);
+  if (failed(applyOpPatternsGreedily(hostOps, std::move(patterns), config)))
+    accSupport->emitNYI(computeOp->getLoc(), "failed to convert host region");
+}
+
+/// Lowers one compute construct that has an `if` clause:
+///  - the data entry ops, the construct (minus its `if` operand) and the data
+///    exit ops are cloned into the then-branch of a new `scf.if`;
+///  - the original region body is moved into the else-branch and converted
+///    with the host fallback patterns.
+/// The original construct and its data ops are queued in `eraseOps`.
+/// Constructs without an `if` clause are left untouched.
+template <typename OpTy>
+void ACCIfClauseLowering::lowerIfClauseForComputeConstruct(
+    OpTy computeConstructOp, SmallVector<Operation *> &eraseOps) {
+  Value ifCond = computeConstructOp.getIfCond();
+  if (!ifCond)
+    return;
+
+  // Reject multi-block regions before any IR is created: the host path below
+  // only handles a single block, and bailing out later would leave a
+  // half-built scf.if next to the untouched construct.
+  if (!computeConstructOp.getRegion().hasOneBlock()) {
+    accSupport->emitNYI(computeConstructOp.getLoc(),
+                        "region with multiple blocks");
+    return;
+  }
+
+  IRRewriter rewriter(computeConstructOp);
+
+  LLVM_DEBUG(llvm::dbgs() << "Converting " << computeConstructOp->getName()
+                          << " with if condition: " << computeConstructOp
+                          << "\n");
+
+  // Collect data clause operations that need to be recreated in the if
+  // condition
+  SmallVector<Operation *> dataEntryOps;
+  SmallVector<Operation *> dataExitOps;
+
+  // Collect data entry operations
+  for (Value operand : computeConstructOp.getDataClauseOperands()) {
+    if (Operation *defOp = operand.getDefiningOp())
+      if (isa<ACC_DATA_ENTRY_OPS>(defOp))
+        dataEntryOps.push_back(defOp);
+  }
+
+  // Find corresponding exit operations for each entry operation.
+  // Iterate backwards through entry ops since exit ops appear in reverse order.
+  for (Operation *dataEntryOp : llvm::reverse(dataEntryOps))
+    for (Operation *user : dataEntryOp->getUsers())
+      if (isa<ACC_DATA_EXIT_OPS>(user))
+        dataExitOps.push_back(user);
+
+  // Create scf.if with device and host execution paths
+  auto ifOp = scf::IfOp::create(rewriter, computeConstructOp.getLoc(),
+                                TypeRange{}, ifCond, /*withElseRegion=*/true);
+
+  // Maps original data entry results to their device-path clones
+  IRMapping deviceMapping;
+
+  // Device execution path (true branch)
+  Block &thenBlock = ifOp.getThenRegion().front();
+  rewriter.setInsertionPointToStart(&thenBlock);
+
+  // Clone data entry operations
+  SmallVector<Value> deviceDataOperands;
+
+  LLVM_DEBUG(llvm::dbgs() << "Cloning " << dataEntryOps.size()
+                          << " data entry operations for device path\n");
+
+  for (Operation *dataOp : dataEntryOps) {
+    Operation *clonedDataOp = rewriter.clone(*dataOp, deviceMapping);
+    deviceDataOperands.push_back(clonedDataOp->getResult(0));
+    deviceMapping.map(dataOp->getResult(0), clonedDataOp->getResult(0));
+  }
+
+  // Create new compute op without if condition for device execution by
+  // cloning
+  OpTy newComputeOp = cast<OpTy>(
+      rewriter.clone(*computeConstructOp.getOperation(), deviceMapping));
+  newComputeOp.getIfCondMutable().clear();
+  newComputeOp.getDataClauseOperandsMutable().assign(deviceDataOperands);
+
+  // Clone data exit operations
+  rewriter.setInsertionPointAfter(newComputeOp);
+  for (Operation *dataOp : dataExitOps)
+    rewriter.clone(*dataOp, deviceMapping);
+
+  rewriter.setInsertionPointToEnd(&thenBlock);
+  if (!thenBlock.getTerminator())
+    scf::YieldOp::create(rewriter, computeConstructOp.getLoc());
+
+  // Host execution path (false branch)
+  // Don't need to clone original ops, just take them and legalize for host
+  ifOp.getElseRegion().takeBody(computeConstructOp.getRegion());
+
+  // Swap acc yield for scf yield
+  Block &elseBlock = ifOp.getElseRegion().front();
+  elseBlock.getTerminator()->erase();
+  rewriter.setInsertionPointToEnd(&elseBlock);
+  scf::YieldOp::create(rewriter, computeConstructOp.getLoc());
+
+  convertHostRegion(computeConstructOp, ifOp.getElseRegion());
+
+  // The original op is now empty and can be erased
+  eraseOps.push_back(computeConstructOp);
+
+  // TODO: Can probably 'move' the data ops instead of cloning them
+  // which would eliminate need to explicitly erase
+  for (Operation *dataOp : dataExitOps)
+    eraseOps.push_back(dataOp);
+
+  for (Operation *dataOp : dataEntryOps) {
+    // The new host code may contain uses of the acc variables. Replace them by
+    // the host values.
+    getAccVar(dataOp).replaceAllUsesWith(getVar(dataOp));
+    eraseOps.push_back(dataOp);
+  }
+}
+
+/// Lowers every parallel, kernels and serial construct of the function that
+/// has an `if` clause, then erases the ops queued during the lowering.
+void ACCIfClauseLowering::runOnOperation() {
+  func::FuncOp funcOp = getOperation();
+  accSupport = &getAnalysis<OpenACCSupport>();
+
+  SmallVector<Operation *> eraseOps;
+  funcOp.walk([&](Operation *op) {
+    if (auto parallelOp = dyn_cast<acc::ParallelOp>(op))
+      lowerIfClauseForComputeConstruct(parallelOp, eraseOps);
+    else if (auto kernelsOp = dyn_cast<acc::KernelsOp>(op))
+      lowerIfClauseForComputeConstruct(kernelsOp, eraseOps);
+    else if (auto serialOp = dyn_cast<acc::SerialOp>(op))
+      lowerIfClauseForComputeConstruct(serialOp, eraseOps);
+  });
+
+  // Erasure is deferred until the walk has finished
+  for (Operation *op : eraseOps)
+    op->erase();
+}
+
+} // namespace
diff --git a/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt b/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt
index e94ac6f332834..3a0ca338766e4 100644
--- a/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_dialect_library(MLIROpenACCTransforms
+  ACCIfClauseLowering.cpp
   ACCImplicitData.cpp
   ACCLoopTiling.cpp
   ACCImplicitDeclare.cpp
diff --git a/mlir/test/Dialect/OpenACC/acc-if-clause-lowering.mlir b/mlir/test/Dialect/OpenACC/acc-if-clause-lowering.mlir
new file mode 100644
index 0000000000000..3f0df18619bc0
--- /dev/null
+++ 
b/mlir/test/Dialect/OpenACC/acc-if-clause-lowering.mlir
@@ -0,0 +1,224 @@
+// RUN: mlir-opt %s -acc-if-clause-lowering -split-input-file | FileCheck %s
+// NOTE(review): several types/attributes below (bare `memref`, `#acc`, `array`, `reduction_operator`) look like they lost their `<...>` parameters in extraction - confirm against the upstream test.
+// Test acc.parallel with an if clause (device path in then, host path in else)
+// CHECK-LABEL: func.func @test_parallel_if
+func.func @test_parallel_if(%arg0: memref<10xi32>, %cond: i1) {
+  %c0_i32 = arith.constant 0 : i32
+  %c1 = arith.constant 1 : index
+  %c10 = arith.constant 10 : index
+
+  %copyin = acc.copyin varPtr(%arg0 : memref<10xi32>) -> memref<10xi32>
+  %create = acc.create varPtr(%arg0 : memref<10xi32>) -> memref<10xi32> {dataClause = #acc}
+
+  // CHECK-NOT: acc.parallel if
+  // CHECK: scf.if %{{.*}} {
+  // CHECK: %[[COPYIN:.*]] = acc.copyin varPtr(%{{.*}}) -> memref<10xi32>
+  // CHECK: %[[CREATE:.*]] = acc.create varPtr(%{{.*}}) -> memref<10xi32>
+  // CHECK: acc.parallel dataOperands(%[[COPYIN]], %[[CREATE]] : memref<10xi32>, memref<10xi32>) {
+  // CHECK: scf.for
+  // CHECK: acc.yield
+  // CHECK: }
+  // CHECK: acc.delete accPtr(%[[CREATE]] : memref<10xi32>)
+  // CHECK: acc.copyout accPtr(%[[COPYIN]] : memref<10xi32>) to varPtr(%{{.*}} : memref<10xi32>)
+  // CHECK: } else {
+  // CHECK: scf.for
+  // CHECK: }
+  acc.parallel dataOperands(%copyin, %create : memref<10xi32>, memref<10xi32>) if(%cond) {
+    scf.for %i = %c1 to %c10 step %c1 {
+      memref.store %c0_i32, %arg0[%i] : memref<10xi32>
+    }
+    acc.yield
+  }
+
+  acc.delete accPtr(%create : memref<10xi32>)
+  acc.copyout accPtr(%copyin : memref<10xi32>) to varPtr(%arg0 : memref<10xi32>)
+  return
+}
+
+// -----
+
+// Test acc.kernels with an if clause (device path in then, host path in else)
+// CHECK-LABEL: func.func @test_kernels_if
+func.func @test_kernels_if(%arg0: memref<5xi32>, %cond: i1) {
+  %c1_i32 = arith.constant 1 : i32
+  %c1 = arith.constant 1 : index
+  %c5 = arith.constant 5 : index
+
+  %copyin = acc.copyin varPtr(%arg0 : memref<5xi32>) -> memref<5xi32>
+  %create = acc.create varPtr(%arg0 : memref<5xi32>) -> memref<5xi32> {dataClause = #acc}
+
+  // CHECK-NOT: acc.kernels if
+  // CHECK: scf.if %{{.*}} {
+  // CHECK: %[[COPYIN:.*]] = acc.copyin
+  // CHECK: %[[CREATE:.*]] = acc.create
+  // CHECK: acc.kernels dataOperands(%[[COPYIN]], %[[CREATE]] : memref<5xi32>, memref<5xi32>) {
+  // CHECK: scf.for
+  // CHECK: acc.terminator
+  // CHECK: }
+  // CHECK: acc.delete accPtr(%[[CREATE]] : memref<5xi32>)
+  // CHECK: acc.copyout accPtr(%[[COPYIN]] : memref<5xi32>) to varPtr(%{{.*}} : memref<5xi32>)
+  // CHECK: } else {
+  // CHECK: scf.for
+  // CHECK: }
+  acc.kernels dataOperands(%copyin, %create : memref<5xi32>, memref<5xi32>) if(%cond) {
+    scf.for %i = %c1 to %c5 step %c1 {
+      memref.store %c1_i32, %arg0[%i] : memref<5xi32>
+    }
+    acc.terminator
+  }
+
+  acc.delete accPtr(%create : memref<5xi32>)
+  acc.copyout accPtr(%copyin : memref<5xi32>) to varPtr(%arg0 : memref<5xi32>)
+  return
+}
+
+// -----
+
+// Test acc.serial with an if clause (device path in then, host path in else)
+// CHECK-LABEL: func.func @test_serial_if
+func.func @test_serial_if(%arg0: memref<8xi32>, %cond: i1) {
+  %c2_i32 = arith.constant 2 : i32
+  %c1 = arith.constant 1 : index
+  %c8 = arith.constant 8 : index
+
+  %copyin = acc.copyin varPtr(%arg0 : memref<8xi32>) -> memref<8xi32>
+  %create = acc.create varPtr(%arg0 : memref<8xi32>) -> memref<8xi32> {dataClause = #acc}
+
+  // CHECK-NOT: acc.serial if
+  // CHECK: scf.if %{{.*}} {
+  // CHECK: %[[COPYIN:.*]] = acc.copyin
+  // CHECK: %[[CREATE:.*]] = acc.create
+  // CHECK: acc.serial dataOperands(%[[COPYIN]], %[[CREATE]] : memref<8xi32>, memref<8xi32>) {
+  // CHECK: scf.for
+  // CHECK: acc.yield
+  // CHECK: }
+  // CHECK: acc.delete accPtr(%[[CREATE]] : memref<8xi32>)
+  // CHECK: acc.copyout accPtr(%[[COPYIN]] : memref<8xi32>) to varPtr(%{{.*}} : memref<8xi32>)
+  // CHECK: } else {
+  // CHECK: scf.for
+  // CHECK: }
+  acc.serial dataOperands(%copyin, %create : memref<8xi32>, memref<8xi32>) if(%cond) {
+    scf.for %i = %c1 to %c8 step %c1 {
+      memref.store %c2_i32, %arg0[%i] : memref<8xi32>
+    }
+    acc.yield
+  }
+
+  acc.delete accPtr(%create : memref<8xi32>)
+  acc.copyout accPtr(%copyin : memref<8xi32>) to varPtr(%arg0 : memref<8xi32>)
+  return
+}
+
+// -----
+
+// Test that acc.parallel without an if clause is left unmodified
+// CHECK-LABEL: func.func @test_parallel_no_if
+func.func @test_parallel_no_if(%arg0: memref<10xi32>) {
+  %c0_i32 = arith.constant 0 : i32
+  %c1 = arith.constant 1 : index
+  %c10 = arith.constant 10 : index
+
+  %copyin = acc.copyin varPtr(%arg0 : memref<10xi32>) -> memref<10xi32>
+
+  // CHECK-NOT: scf.if
+  // CHECK: acc.parallel dataOperands(%{{.*}}) {
+  acc.parallel dataOperands(%copyin : memref<10xi32>) {
+    scf.for %i = %c1 to %c10 step %c1 {
+      memref.store %c0_i32, %arg0[%i] : memref<10xi32>
+    }
+    acc.yield
+  }
+
+  acc.copyout accPtr(%copyin : memref<10xi32>) to varPtr(%arg0 : memref<10xi32>)
+  return
+}
+
+// -----
+
+// Test with private and reduction clauses inside the compute construct; the two recipes below are referenced by its acc.private and acc.reduction ops
+acc.private.recipe @privatization_memref_i32 : memref init {
+^bb0(%arg0: memref):
+  %0 = memref.alloca() : memref
+  acc.yield %0 : memref
+}
+
+acc.reduction.recipe @reduction_add_memref_f32 : memref reduction_operator init {
+^bb0(%arg0: memref):
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = memref.alloca() : memref
+  memref.store %cst, %0[] : memref
+  acc.yield %0 : memref
+} combiner {
+^bb0(%arg0: memref, %arg1: memref):
+  %0 = memref.load %arg0[] : memref
+  %1 = memref.load %arg1[] : memref
+  %2 = arith.addf %0, %1 : f32
+  memref.store %2, %arg0[] : memref
+  acc.yield %arg0 : memref
+}
+
+// CHECK-LABEL: func.func @test_reduction_if
+func.func @test_reduction_if(%r: memref, %a: memref<8xf32>, %cond: i1) {
+  %c8_i32 = arith.constant 8 : i32
+  %c1_i32 = arith.constant 1 : i32
+
+  %copyin = acc.copyin varPtr(%r : memref) -> memref {dataClause = #acc, implicit = true}
+
+  // CHECK: scf.if
+  // CHECK: acc.parallel
+  // CHECK: } else {
+  // The else branch should have acc ops converted to host
+  // CHECK-NOT: acc.loop
+  // CHECK-NOT: acc.reduction
+  // CHECK-NOT: acc.private
+  // CHECK: }
+  acc.parallel combined(loop) dataOperands(%copyin : memref) if(%cond) {
+    %red = acc.reduction varPtr(%r : memref) recipe(@reduction_add_memref_f32) -> memref
+    %iter_var = memref.alloca() : memref
+    %priv = acc.private varPtr(%iter_var : memref) recipe(@privatization_memref_i32) -> memref
+    acc.loop combined(parallel) vector private(%priv : memref) reduction(%red : memref) control(%iv : i32) = (%c1_i32 : i32) to (%c8_i32 : i32) step (%c1_i32 : i32) {
+      memref.store %iv, %priv[] : memref
+      %idx = memref.load %priv[] : memref
+      %idx_cast = arith.index_cast %idx : i32 to index
+      %elem = memref.load %a[%idx_cast] : memref<8xf32>
+      %r_val = memref.load %r[] : memref
+      %new_r = arith.addf %r_val, %elem : f32
+      memref.store %new_r, %r[] : memref
+      acc.yield
+    } attributes {inclusiveUpperbound = array, independent = [#acc.device_type]}
+    acc.yield
+  }
+
+  acc.copyout accPtr(%copyin : memref) to varPtr(%r : memref) {dataClause = #acc, implicit = true}
+  return
+}
+
+// -----
+
+// Test that uses of acc variables in the host path are replaced with the host variables
+// CHECK-LABEL: func.func @test_acc_var_replacement
+func.func @test_acc_var_replacement(%arg0: memref<10xi32>, %cond: i1) {
+  %c0_i32 = arith.constant 0 : i32
+  %c1 = arith.constant 1 : index
+  %c10 = arith.constant 10 : index
+
+  %copyin = acc.copyin varPtr(%arg0 : memref<10xi32>) -> memref<10xi32>
+
+  // In the else branch, uses of %copyin should be replaced with %arg0
+  // CHECK: scf.if
+  // CHECK: } else {
+  // CHECK: scf.for
+  // CHECK: memref.store %{{.*}}, %arg0[%{{.*}}]
+  // CHECK: }
+  acc.parallel dataOperands(%copyin : memref<10xi32>) if(%cond) {
+    scf.for %i = %c1 to %c10 step %c1 {
+      // Use the acc ptr inside the region
+      memref.store %c0_i32, %copyin[%i] : memref<10xi32>
+    }
+    acc.yield
+  }
+
+  acc.copyout accPtr(%copyin : memref<10xi32>) to varPtr(%arg0 : memref<10xi32>)
+  return
+}
+