Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions mlir/include/mlir/Dialect/OpenACC/OpenACCUtilsLoop.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
#define MLIR_DIALECT_OPENACC_OPENACCUTILSLOOP_H_

namespace mlir {
class OpBuilder;
class RewriterBase;
namespace scf {
class ForOp;
class ParallelOp;
Expand All @@ -27,26 +27,30 @@ class LoopOp;
/// The loop arguments are converted to index type. If enableCollapse is true,
/// nested loops are collapsed into a single loop.
/// @param loopOp The acc.loop operation to convert (must not be unstructured)
/// @param rewriter RewriterBase for creating operations
/// @param enableCollapse Whether to collapse nested loops into one
/// @return The created scf.for operation or nullptr on creation error.
/// An InFlightDiagnostic is emitted on creation error.
scf::ForOp convertACCLoopToSCFFor(LoopOp loopOp, bool enableCollapse);
scf::ForOp convertACCLoopToSCFFor(LoopOp loopOp, RewriterBase &rewriter,
bool enableCollapse);

/// Convert acc.loop to scf.parallel.
/// The loop induction variables are converted to index types.
/// @param loopOp The acc.loop operation to convert
/// @param builder OpBuilder for creating operations
/// @param rewriter RewriterBase for creating and erasing operations
/// @return The created scf.parallel operation or nullptr on creation error.
/// An InFlightDiagnostic is emitted on creation error.
scf::ParallelOp convertACCLoopToSCFParallel(LoopOp loopOp, OpBuilder &builder);
scf::ParallelOp convertACCLoopToSCFParallel(LoopOp loopOp,
RewriterBase &rewriter);

/// Convert an unstructured acc.loop to scf.execute_region.
/// @param loopOp The acc.loop operation to convert (must be unstructured)
/// @param builder OpBuilder for creating operations
/// @param rewriter RewriterBase for creating and erasing operations
/// @return The created scf.execute_region operation or nullptr on creation
/// error. An InFlightDiagnostic is emitted on creation error.
scf::ExecuteRegionOp
convertUnstructuredACCLoopToSCFExecuteRegion(LoopOp loopOp, OpBuilder &builder);
convertUnstructuredACCLoopToSCFExecuteRegion(LoopOp loopOp,
RewriterBase &rewriter);

} // namespace acc
} // namespace mlir
Expand Down
122 changes: 122 additions & 0 deletions mlir/include/mlir/Dialect/OpenACC/Transforms/ACCSpecializePatterns.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
//===- ACCSpecializePatterns.h - Common ACC Specialization Patterns ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains common rewrite pattern templates used by both
// ACCSpecializeForHost and ACCSpecializeForDevice passes.
//
// The patterns provide the following transformations:
//
// - ACCOpReplaceWithVarConversion<OpTy>: Replaces a data entry operation
// with its var operand. Used for ops like acc.copyin, acc.create, etc.
//
// - ACCOpEraseConversion<OpTy>: Simply erases an operation. Used for
// data exit ops like acc.copyout, acc.delete, and runtime ops.
//
// - ACCRegionUnwrapConversion<OpTy>: Inlines the region of an operation
// and erases the wrapper. Used for structured data constructs
// (acc.data, acc.host_data) and compute constructs (acc.parallel, etc.)
//
// - ACCDeclareEnterOpConversion: Erases acc.declare_enter and its
// associated acc.declare_exit operation.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_OPENACC_TRANSFORMS_ACCSPECIALIZEPATTERNS_H
#define MLIR_DIALECT_OPENACC_TRANSFORMS_ACCSPECIALIZEPATTERNS_H

#include "mlir/Dialect/OpenACC/OpenACC.h"
#include "mlir/IR/PatternMatch.h"

namespace mlir {
namespace acc {

//===----------------------------------------------------------------------===//
// Generic pattern templates for ACC specialization
//===----------------------------------------------------------------------===//

/// Rewrite pattern that substitutes an ACC op with its `var` operand.
/// Intended for data entry ops such as acc.copyin, acc.create, acc.attach, etc.
template <typename OpTy>
class ACCOpReplaceWithVarConversion : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

public:
  /// Always succeeds: the op is either erased (no uses) or replaced by the
  /// value returned from `op.getVar()`.
  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override {
    // The former consumer of this op may already have been converted, leaving
    // the op without uses; then there is nothing to replace and the op is
    // simply dropped.
    if (op->use_empty()) {
      rewriter.eraseOp(op);
      return success();
    }

    // Otherwise substitute the op with its var operand.
    rewriter.replaceOp(op, op.getVar());
    return success();
  }
};

/// Rewrite pattern that unconditionally removes an ACC op producing no results.
/// Intended for data exit ops such as acc.copyout, acc.delete, acc.detach, etc.
template <typename OpTy>
class ACCOpEraseConversion : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

public:
  /// Always succeeds after erasing `op`.
  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override {
    Operation *target = op.getOperation();
    // This pattern must only be instantiated for result-less ops; the
    // precondition is checked in assert-enabled builds.
    assert(target->getNumResults() == 0 && "expected op with no results");
    rewriter.eraseOp(target);
    return success();
  }
};

/// Rewrite pattern that dissolves a region-holding ACC op: the operations of
/// its single block are inlined right before the op, and the now-empty wrapper
/// is erased.
/// Intended for structured data constructs (acc.data, acc.host_data,
/// acc.kernel_environment, acc.declare) and compute constructs (acc.parallel,
/// acc.serial, acc.kernels).
template <typename OpTy>
class ACCRegionUnwrapConversion : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

public:
  /// Always succeeds. The op's region is expected to hold exactly one block
  /// (checked in assert-enabled builds).
  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override {
    auto &body = op.getRegion();
    assert(body.hasOneBlock() && "expected one block");
    Block &entry = body.front();

    // Drop the terminator (acc.yield or acc.terminator) first so it is not
    // carried along into the parent block.
    rewriter.eraseOp(entry.getTerminator());

    // Splice the remaining operations in front of the wrapper, then remove
    // the wrapper itself.
    rewriter.inlineBlockBefore(&entry, op);
    rewriter.eraseOp(op);
    return success();
  }
};

/// Pattern to erase acc.declare_enter and its associated acc.declare_exit.
/// The declare_enter produces a token that is consumed by declare_exit, so the
/// exit (the token's user) is erased before the enter.
class ACCDeclareEnterOpConversion
    : public OpRewritePattern<acc::DeclareEnterOp> {
  using OpRewritePattern<acc::DeclareEnterOp>::OpRewritePattern;

public:
  /// Always succeeds. If the token has a use, that use is expected to be the
  /// single acc.declare_exit (checked in assert-enabled builds).
  LogicalResult matchAndRewrite(acc::DeclareEnterOp op,
                                PatternRewriter &rewriter) const override {
    // If the enter token is used by an exit, erase exit first.
    if (!op->use_empty()) {
      assert(op->hasOneUse() && "expected one use");
      // cast<> asserts that the user really is a declare exit op; there is no
      // failure path to handle, so dyn_cast<> plus a separate null assert is
      // not needed.
      auto exitOp = cast<acc::DeclareExitOp>(*op->getUsers().begin());
      rewriter.eraseOp(exitOp);
    }
    rewriter.eraseOp(op);
    return success();
  }
};

} // namespace acc
} // namespace mlir

#endif // MLIR_DIALECT_OPENACC_TRANSFORMS_ACCSPECIALIZEPATTERNS_H
32 changes: 32 additions & 0 deletions mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/OpenACC/OpenACC.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Pass/Pass.h"

namespace mlir {
Expand All @@ -22,9 +23,40 @@ class FuncOp;

namespace acc {

class OpenACCSupport;

#define GEN_PASS_DECL
#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"

//===----------------------------------------------------------------------===//
// ACCSpecializeForDevice patterns
//===----------------------------------------------------------------------===//

/// Populates all patterns for device specialization.
/// In specialized device code (such as specialized acc routine), many ACC
/// operations do not make sense because they are host-side constructs. This
/// function adds patterns to remove or transform them.
/// @param patterns The pattern set that the device specialization patterns
/// are added to.
void populateACCSpecializeForDevicePatterns(RewritePatternSet &patterns);

//===----------------------------------------------------------------------===//
// ACCSpecializeForHost patterns
//===----------------------------------------------------------------------===//

/// Populates patterns for converting orphan ACC operations to host.
/// All patterns check that the operation is NOT inside or associated with a
/// compute region before converting.
/// @param patterns The pattern set that the orphan-to-host patterns are added
/// to.
/// @param accSupport OpenACCSupport reference; presumably consulted by the
/// populated patterns (TODO confirm against the implementation).
/// @param enableLoopConversion Whether to convert orphan acc.loop operations.
void populateACCOrphanToHostPatterns(RewritePatternSet &patterns,
OpenACCSupport &accSupport,
bool enableLoopConversion = true);

/// Populates all patterns for host fallback path (when `if` clause evaluates
/// to false). In this mode, ALL ACC operations should be converted or removed.
/// @param patterns The pattern set that the host fallback patterns are added
/// to.
/// @param accSupport OpenACCSupport reference; presumably consulted by the
/// populated patterns (TODO confirm against the implementation).
/// @param enableLoopConversion Whether to convert orphan acc.loop operations.
void populateACCHostFallbackPatterns(RewritePatternSet &patterns,
OpenACCSupport &accSupport,
bool enableLoopConversion = true);

/// Generate the code for registering conversion passes.
#define GEN_PASS_REGISTRATION
#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
Expand Down
58 changes: 58 additions & 0 deletions mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -194,4 +194,62 @@ def ACCLoopTiling : Pass<"acc-loop-tiling", "mlir::func::FuncOp"> {
];
}

// Pass anchored on func.func that removes or transforms OpenACC operations
// which are host-side constructs when they appear in device code. The
// description below lists how each group of operations is handled.
def ACCSpecializeForDevice : Pass<"acc-specialize-for-device", "mlir::func::FuncOp"> {
let summary = "Strip OpenACC constructs inside device code";
let description = [{
In a specialized acc routine or compute construct, many OpenACC operations
do not make sense because they are host-side constructs. This pass removes
or transforms these operations appropriately.

The following operations are handled:
- Data entry ops (replaced with var): acc.attach, acc.copyin, acc.create,
acc.declare_device_resident, acc.declare_link, acc.deviceptr,
acc.get_deviceptr, acc.nocreate, acc.present, acc.update_device,
acc.use_device
- Data exit ops (erased): acc.copyout, acc.delete, acc.detach,
acc.update_host
- Structured data (inline region): acc.data, acc.host_data,
acc.kernel_environment
- Unstructured data (erased): acc.enter_data, acc.exit_data, acc.update,
acc.declare_enter, acc.declare_exit
- Compute constructs (inline region): acc.parallel, acc.serial, acc.kernels
- Runtime ops (erased): acc.init, acc.shutdown, acc.set, acc.wait
}];
let dependentDialects = ["mlir::acc::OpenACCDialect"];
}

// Pass anchored on func.func that converts OpenACC operations for host
// execution. By default only orphan operations are converted; the
// `enable-host-fallback` option extends the conversion to all ACC operations.
// The SCF dialect is a dependent dialect of this pass.
def ACCSpecializeForHost : Pass<"acc-specialize-for-host", "mlir::func::FuncOp"> {
let summary = "Convert OpenACC operations for host execution";
let description = [{
This pass converts OpenACC operations to host-compatible representations.
It serves as a conversion pass that transforms ACC constructs to enable
execution on the host rather than on accelerator devices.

There are two modes of operation:

1. Default mode (orphan operations only): Only orphan operations that are
not allowed outside compute regions are converted. Structured/unstructured
data constructs, compute constructs, and their associated data operations
are NOT removed.

2. Host fallback mode (enableHostFallback=true): ALL ACC operations within
the region are converted to host equivalents. This is used when the `if`
clause evaluates to false at runtime.

The following operations are handled:
- Atomic ops: converted to load/store operations
- Loop ops: converted to scf.for or scf.execute_region
- Data entry ops (orphan): replaced with var operand
- In host fallback mode: all data, compute, and runtime ops are removed
}];
let dependentDialects = ["mlir::acc::OpenACCDialect",
"mlir::scf::SCFDialect"];
let options = [
Option<"enableHostFallback", "enable-host-fallback", "bool", "false",
"Enable host fallback mode which converts ALL ACC operations, "
"not just orphan operations. Use this when the `if` clause "
"evaluates to false.">
];
}

#endif // MLIR_DIALECT_OPENACC_TRANSFORMS_PASSES
Loading
Loading