Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -136,4 +136,20 @@ def ACCImplicitRoutine : Pass<"acc-implicit-routine", "mlir::ModuleOp"> {
];
}

// Function-level pass (anchored on `mlir::func::FuncOp`) registered under the
// command-line name `acc-legalize-serial`.
def ACCLegalizeSerial : Pass<"acc-legalize-serial", "mlir::func::FuncOp"> {
let summary = "Legalize OpenACC serial constructs";
let description = [{
This pass converts `acc.serial` constructs into `acc.parallel` constructs
with `num_gangs(1)`, `num_workers(1)`, and `vector_length(1)`.

This transformation simplifies processing of acc regions by unifying the
handling of serial and parallel constructs. Since an OpenACC serial region
executes sequentially (like a parallel region with a single gang, worker,
and vector), this conversion is semantically equivalent while enabling code
reuse in later compilation stages.
}];
// The pass creates `acc.parallel` ops (OpenACC dialect) and an
// `arith.constant` for the value 1 (Arith dialect), so both dialects are
// declared as dependencies.
let dependentDialects = ["mlir::acc::OpenACCDialect",
"mlir::arith::ArithDialect"];
}

#endif // MLIR_DIALECT_OPENACC_TRANSFORMS_PASSES
117 changes: 117 additions & 0 deletions mlir/lib/Dialect/OpenACC/Transforms/ACCLegalizeSerial.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
//===- ACCLegalizeSerial.cpp - Legalize ACC Serial region -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass converts acc.serial into acc.parallel with num_gangs(1),
// num_workers(1), and vector_length(1).
//
// This transformation simplifies processing of acc regions by unifying the
// handling of serial and parallel constructs. Since an OpenACC serial region
// executes sequentially (like a parallel region with a single gang, worker, and
// vector), this conversion is semantically equivalent while enabling code reuse
// in later compilation stages.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/OpenACC/Transforms/Passes.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/OpenACC/OpenACC.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Region.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/Debug.h"

namespace mlir {
namespace acc {
#define GEN_PASS_DEF_ACCLEGALIZESERIAL
#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
} // namespace acc
} // namespace mlir

#define DEBUG_TYPE "acc-legalize-serial"

namespace {
using namespace mlir;

/// Rewrites an `acc.serial` op into an `acc.parallel` op that carries the same
/// async, wait, if, self, reduction, private, firstprivate, data-clause,
/// default, and combined information, and additionally specifies
/// `num_gangs(1)`, `num_workers(1)`, and `vector_length(1)`.
///
/// The body of the serial op is moved (not cloned) into the new parallel op,
/// after which the serial op is replaced.
struct ACCSerialOpConversion : public OpRewritePattern<acc::SerialOp> {
  using OpRewritePattern<acc::SerialOp>::OpRewritePattern;

  /// Builds the replacement `acc.parallel` at the location of `serialOp`,
  /// transfers the region body to it, and replaces `serialOp`.
  ///
  /// \param serialOp the `acc.serial` op being legalized.
  /// \param rewriter the rewriter through which all IR changes are made.
  /// \returns success unconditionally; the pattern has no failure path.
  LogicalResult matchAndRewrite(acc::SerialOp serialOp,
                                PatternRewriter &rewriter) const override {

    const Location loc = serialOp.getLoc();

    // Create a container holding the 32-bit constant value 1 for use as the
    // num_gangs, num_workers, and vector_length operands.
    llvm::SmallVector<mlir::Value> numValues;
    auto value = arith::ConstantIntOp::create(rewriter, loc, 1, 32);
    numValues.push_back(value);

    // Since num_gangs is specified as both attributes and values, create a
    // segment attribute recording how many values belong to the (single)
    // num_gangs group.
    llvm::SmallVector<int32_t> numGangsSegments;
    numGangsSegments.push_back(numValues.size());
    auto gangSegmentsAttr = rewriter.getDenseI32ArrayAttr(numGangsSegments);

    // Create a device_type attribute set to `none` which ensures that
    // the parallel dimensions specification applies to the default clauses.
    llvm::SmallVector<mlir::Attribute> crtDeviceTypes;
    auto crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
        rewriter.getContext(), mlir::acc::DeviceType::None);
    crtDeviceTypes.push_back(crtDeviceTypeAttr);
    auto devTypeAttr =
        mlir::ArrayAttr::get(rewriter.getContext(), crtDeviceTypes);

    LLVM_DEBUG(llvm::dbgs() << "acc.serial OP: " << serialOp << "\n");

    // Create a new acc.parallel op with the same operands - except include the
    // num_gangs, num_workers, and vector_length operands and their
    // device_type attributes.
    acc::ParallelOp parOp = acc::ParallelOp::create(
        rewriter, loc, serialOp.getAsyncOperands(),
        serialOp.getAsyncOperandsDeviceTypeAttr(), serialOp.getAsyncOnlyAttr(),
        serialOp.getWaitOperands(), serialOp.getWaitOperandsSegmentsAttr(),
        serialOp.getWaitOperandsDeviceTypeAttr(),
        serialOp.getHasWaitDevnumAttr(), serialOp.getWaitOnlyAttr(), numValues,
        gangSegmentsAttr, devTypeAttr, numValues, devTypeAttr, numValues,
        devTypeAttr, serialOp.getIfCond(), serialOp.getSelfCond(),
        serialOp.getSelfAttrAttr(), serialOp.getReductionOperands(),
        serialOp.getPrivateOperands(), serialOp.getFirstprivateOperands(),
        serialOp.getDataClauseOperands(), serialOp.getDefaultAttrAttr(),
        serialOp.getCombinedAttr());

    // Move the body through the rewriter rather than mutating the regions
    // directly, so that the pattern driver is notified of the change.
    rewriter.inlineRegionBefore(serialOp.getRegion(), parOp.getRegion(),
                                parOp.getRegion().end());

    LLVM_DEBUG(llvm::dbgs() << "acc.parallel OP: " << parOp << "\n");
    rewriter.replaceOp(serialOp, parOp);

    return success();
  }
};

/// Pass implementation for `acc-legalize-serial`: greedily applies
/// ACCSerialOpConversion to the function the pass is run on.
class ACCLegalizeSerial
    : public mlir::acc::impl::ACCLegalizeSerialBase<ACCLegalizeSerial> {
public:
  using ACCLegalizeSerialBase<ACCLegalizeSerial>::ACCLegalizeSerialBase;

  /// Populates a pattern set with ACCSerialOpConversion and runs the greedy
  /// rewrite driver over the current function. The pass is marked as failed
  /// if the driver reports failure.
  void runOnOperation() override {
    func::FuncOp funcOp = getOperation();
    MLIRContext *context = funcOp.getContext();
    RewritePatternSet patterns(context);
    patterns.add<ACCSerialOpConversion>(context);
    // Do not silently drop the driver's result: surface a failed rewrite to
    // the pass manager instead of continuing as if legalization succeeded.
    if (failed(applyPatternsGreedily(funcOp, std::move(patterns))))
      signalPassFailure();
  }
};

} // namespace
1 change: 1 addition & 0 deletions mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIROpenACCTransforms
ACCImplicitData.cpp
ACCImplicitDeclare.cpp
ACCImplicitRoutine.cpp
ACCLegalizeSerial.cpp
LegalizeDataValues.cpp

ADDITIONAL_HEADER_DIRS
Expand Down
164 changes: 164 additions & 0 deletions mlir/test/Dialect/OpenACC/legalize-serial.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
// RUN: mlir-opt %s -acc-legalize-serial | FileCheck %s

// Private recipe for memref<10xf32>: `init` allocates a fresh buffer and
// yields it; `destroy` deallocates the buffer it is given.
// NOTE(review): no op in the visible part of this test references this recipe.
acc.private.recipe @privatization_memref_10_f32 : memref<10xf32> init {
^bb0(%arg0: memref<10xf32>):
%0 = memref.alloc() : memref<10xf32>
acc.yield %0 : memref<10xf32>
} destroy {
^bb0(%arg0: memref<10xf32>):
memref.dealloc %arg0 : memref<10xf32>
acc.terminator
}

// Private recipe for memref<10x10xf32>: `init` allocates a fresh buffer and
// yields it; `destroy` deallocates the buffer it is given. Used by the
// `acc.private` op in @testserialop.
acc.private.recipe @privatization_memref_10_10_f32 : memref<10x10xf32> init {
^bb0(%arg0: memref<10x10xf32>):
%0 = memref.alloc() : memref<10x10xf32>
acc.yield %0 : memref<10x10xf32>
} destroy {
^bb0(%arg0: memref<10x10xf32>):
memref.dealloc %arg0 : memref<10x10xf32>
acc.terminator
}

// Firstprivate recipe for memref<10xf32>: `init` allocates a fresh buffer and
// yields it; `copy` contains only a terminator (no data is actually copied in
// this test fixture); `destroy` deallocates the buffer it is given. Used by
// the `acc.firstprivate` op in @testserialop.
acc.firstprivate.recipe @firstprivatization_memref_10xf32 : memref<10xf32> init {
^bb0(%arg0: memref<10xf32>):
%0 = memref.alloc() : memref<10xf32>
acc.yield %0 : memref<10xf32>
} copy {
^bb0(%arg0: memref<10xf32>, %arg1: memref<10xf32>):
acc.terminator
} destroy {
^bb0(%arg0: memref<10xf32>):
memref.dealloc %arg0 : memref<10xf32>
acc.terminator
}

// Add-reduction recipe on i64 values: `init` yields the constant 0;
// `combiner` yields the sum of its two arguments.
// NOTE(review): no op in the visible part of this test references this recipe.
acc.reduction.recipe @reduction_add_i64 : i64 reduction_operator<add> init {
^bb0(%0: i64):
%1 = arith.constant 0 : i64
acc.yield %1 : i64
} combiner {
^bb0(%0: i64, %1: i64):
%2 = arith.addi %0, %1 : i64
acc.yield %2 : i64
}

// Add-reduction recipe on memref<i64>: `init` creates a memref with
// memref.alloca, stores 0 into it, and yields it; `combiner` loads both
// arguments, adds them, and stores the sum back into the first argument.
// Used by the `acc.reduction` op in @testserialop.
acc.reduction.recipe @reduction_add_memref_i64 : memref<i64> reduction_operator<add> init {
^bb0(%arg0: memref<i64>):
%0 = memref.alloca() : memref<i64>
%c0 = arith.constant 0 : i64
memref.store %c0, %0[] : memref<i64>
acc.yield %0 : memref<i64>
} combiner {
^bb0(%arg0: memref<i64>, %arg1: memref<i64>):
%0 = memref.load %arg0[] : memref<i64>
%1 = memref.load %arg1[] : memref<i64>
%2 = arith.addi %0, %1 : i64
memref.store %2, %arg0[] : memref<i64>
acc.terminator
}

// CHECK: func.func @testserialop(%[[VAL_0:.*]]: memref<10xf32>, %[[VAL_1:.*]]: memref<10xf32>, %[[VAL_2:.*]]: memref<10x10xf32>) {
// CHECK: %[[VAL_3:.*]] = arith.constant 1 : i64
// CHECK: %[[VAL_4:.*]] = arith.constant 1 : i32
// CHECK: %[[VAL_5:.*]] = arith.constant 1 : index
// CHECK: acc.parallel async(%[[VAL_3]] : i64) num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) {
// CHECK: }
// CHECK: acc.parallel async(%[[VAL_4]] : i32) num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) {
// CHECK: }
// CHECK: acc.parallel async(%[[VAL_5]] : index) num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) {
// CHECK: }
// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) wait({%[[VAL_3]] : i64}) {
// CHECK: }
// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) wait({%[[VAL_4]] : i32}) {
// CHECK: }
// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) wait({%[[VAL_5]] : index}) {
// CHECK: }
// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) wait({%[[VAL_3]] : i64, %[[VAL_4]] : i32, %[[VAL_5]] : index}) {
// CHECK: }
// CHECK: %[[VAL_6:.*]] = acc.firstprivate varPtr(%[[VAL_1]] : memref<10xf32>) recipe(@firstprivatization_memref_10xf32) -> memref<10xf32>
// CHECK: %[[VAL_9:.*]] = acc.private varPtr(%[[VAL_2]] : memref<10x10xf32>) recipe(@privatization_memref_10_10_f32) -> memref<10x10xf32>
// CHECK: acc.parallel firstprivate(%[[VAL_6]] : memref<10xf32>) num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) private(%[[VAL_9]] : memref<10x10xf32>) vector_length(%[[VAL_4]] : i32) {
// CHECK: }
// CHECK: %[[VAL_7:.*]] = acc.copyin varPtr(%[[VAL_0]] : memref<10xf32>) -> memref<10xf32> {dataClause = #acc<data_clause acc_copy>}
// CHECK: acc.parallel dataOperands(%[[VAL_7]] : memref<10xf32>) num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) {
// CHECK: }
// CHECK: %[[I64MEM:.*]] = memref.alloca() : memref<i64>
// CHECK: memref.store %[[VAL_3]], %[[I64MEM]][] : memref<i64>
// CHECK: %[[VAL_10:.*]] = acc.reduction varPtr(%[[I64MEM]] : memref<i64>) recipe(@reduction_add_memref_i64) -> memref<i64>
// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) reduction(%[[VAL_10]] : memref<i64>) {
// CHECK: }
// CHECK: acc.parallel combined(loop) num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) {
// CHECK: acc.loop combined(serial) control(%{{.*}} : index) = (%[[VAL_5]] : index) to (%[[VAL_5]] : index) step (%[[VAL_5]] : index) {
// CHECK: acc.yield
// CHECK: } attributes {seq = [#acc.device_type<none>]}
// CHECK: acc.terminator
// CHECK: }
// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) {
// CHECK: } attributes {defaultAttr = #acc<defaultvalue none>}
// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) {
// CHECK: } attributes {defaultAttr = #acc<defaultvalue present>}
// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) {
// CHECK: }
// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) {
// CHECK: }
// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) {
// CHECK: } attributes {selfAttr}
// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) {
// CHECK: acc.yield
// CHECK: } attributes {selfAttr}
// CHECK: return
// CHECK: }

// Test input: a sequence of `acc.serial` ops, each exercising one clause or
// attribute, which the pass is expected to rewrite into `acc.parallel` ops.
// The order of the ops below must stay in sync with the expected-output
// directives above this function.
func.func @testserialop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x10xf32>) -> () {
// Constants of three different types, reused as async/wait operands and as
// loop bounds.
%i64value = arith.constant 1 : i64
%i32value = arith.constant 1 : i32
%idxValue = arith.constant 1 : index
// async clause with an i64, i32, and index operand respectively.
acc.serial async(%i64value: i64) {
}
acc.serial async(%i32value: i32) {
}
acc.serial async(%idxValue: index) {
}
// wait clause with a single operand of each type, then with all three.
acc.serial wait({%i64value: i64}) {
}
acc.serial wait({%i32value: i32}) {
}
acc.serial wait({%idxValue: index}) {
}
acc.serial wait({%i64value : i64, %i32value : i32, %idxValue : index}) {
}
// private and firstprivate clauses.
%firstprivate = acc.firstprivate varPtr(%b : memref<10xf32>) recipe(@firstprivatization_memref_10xf32) -> memref<10xf32>
%c_private = acc.private varPtr(%c : memref<10x10xf32>) recipe(@privatization_memref_10_10_f32) -> memref<10x10xf32>
acc.serial private(%c_private : memref<10x10xf32>) firstprivate(%firstprivate : memref<10xf32>) {
}
// Data clause operand.
%copyinfromcopy = acc.copyin varPtr(%a : memref<10xf32>) -> memref<10xf32> {dataClause = #acc<data_clause acc_copy>}
acc.serial dataOperands(%copyinfromcopy : memref<10xf32>) {
}
// reduction clause on a memref<i64> initialized with %i64value.
%i64mem = memref.alloca() : memref<i64>
memref.store %i64value, %i64mem[] : memref<i64>
%i64reduction = acc.reduction varPtr(%i64mem : memref<i64>) recipe(@reduction_add_memref_i64) -> memref<i64>
acc.serial reduction(%i64reduction : memref<i64>) {
}
// Combined construct containing a nested acc.loop; the nested loop keeps its
// `combined(serial)` marker in the expected output.
acc.serial combined(loop) {
acc.loop combined(serial) control(%arg3 : index) = (%idxValue : index) to (%idxValue : index) step (%idxValue : index) {
acc.yield
} attributes {seq = [#acc.device_type<none>]}
acc.terminator
}
// default(none) and default(present) attributes.
acc.serial {
} attributes {defaultAttr = #acc<defaultvalue none>}
acc.serial {
} attributes {defaultAttr = #acc<defaultvalue present>}
// NOTE(review): unlike `selfAttr` below, `asyncAttr` and `waitAttr` do not
// appear on the corresponding ops in the expected output. Confirm whether
// these are the intended attribute names for the async-only / wait-only cases.
acc.serial {
} attributes {asyncAttr}
acc.serial {
} attributes {waitAttr}
// self attribute, with an empty region and with an explicit acc.yield.
acc.serial {
} attributes {selfAttr}
acc.serial {
acc.yield
} attributes {selfAttr}
return
}