Skip to content

Commit

Permalink
[flang][openacc] Add data operands conversion from FIR
Browse files Browse the repository at this point in the history
This patch revives an old PR attempt [1] to perform the
data operand conversion needed for translation to LLVM IR.

Box/class types are currently not supported, since they will
normally not reach this pass once the changes proposed in RFC [2]
are implemented.

[1] flang-compiler#915
[2] https://discourse.llvm.org/t/rfc-openacc-dialect-data-operation-improvements/69825/2

Depends on D147824

Reviewed By: PeteSteinfeld, razvanlupusoru

Differential Revision: https://reviews.llvm.org/D147825
  • Loading branch information
clementval committed Apr 10, 2023
1 parent 016970d commit 68bcd64
Show file tree
Hide file tree
Showing 5 changed files with 278 additions and 1 deletion.
2 changes: 2 additions & 0 deletions flang/include/flang/Optimizer/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ namespace fir {
#define GEN_PASS_DECL_SIMPLIFYREGIONLITE
#define GEN_PASS_DECL_ALGEBRAICSIMPLIFICATION
#define GEN_PASS_DECL_POLYMORPHICOPCONVERSION
#define GEN_PASS_DECL_OPENACCDATAOPERANDCONVERSION
#include "flang/Optimizer/Transforms/Passes.h.inc"

std::unique_ptr<mlir::Pass> createAbstractResultOnFuncOptPass();
Expand Down Expand Up @@ -70,6 +71,7 @@ std::unique_ptr<mlir::Pass> createAlgebraicSimplificationPass();
std::unique_ptr<mlir::Pass>
createAlgebraicSimplificationPass(const mlir::GreedyRewriteConfig &config);
std::unique_ptr<mlir::Pass> createPolymorphicOpConversionPass();
std::unique_ptr<mlir::Pass> createOpenACCDataOperandConversionPass();

// declarative passes
#define GEN_PASS_REGISTRATION
Expand Down
11 changes: 10 additions & 1 deletion flang/include/flang/Optimizer/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -284,5 +284,14 @@ def PolymorphicOpConversion : Pass<"fir-polymorphic-op", "::mlir::func::FuncOp">
];
}


// Pass definition for `fir-openacc-data-operand-conversion`, anchored on
// `func.func` operations. It declares the LLVM dialect as a dependent dialect
// and exposes a single boolean option, `use-opaque-pointers`, which defaults
// to true.
def OpenACCDataOperandConversion : Pass<"fir-openacc-data-operand-conversion", "::mlir::func::FuncOp"> {
let summary = "Convert the FIR operands in OpenACC ops to LLVM dialect";
let dependentDialects = ["mlir::LLVM::LLVMDialect"];
let options = [
Option<"useOpaquePointers", "use-opaque-pointers", "bool",
/*default=*/"true", "Generate LLVM IR using opaque pointers "
"instead of typed pointers">,
];
}

#endif // FLANG_OPTIMIZER_TRANSFORMS_PASSES
2 changes: 2 additions & 0 deletions flang/lib/Optimizer/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@ add_flang_library(FIRTransforms
SimplifyIntrinsics.cpp
AddDebugFoundation.cpp
PolymorphicOpConversion.cpp
OpenACC/OpenACCDataOperandConversion.cpp

DEPENDS
FIRDialect
FIROptTransformsPassIncGen

LINK_LIBS
FIRBuilder
FIRCodeGen
FIRDialect
FIRDialectSupport
FIRSupport
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
//===- OpenACCDataOperandConversion.cpp - OpenACC data operand conversion -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Transforms/Passes.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/OpenACCToLLVM/ConvertOpenACCToLLVM.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/OpenACC/OpenACC.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"

namespace fir {
#define GEN_PASS_DEF_OPENACCDATAOPERANDCONVERSION
#include "flang/Optimizer/Transforms/Passes.h.inc"
} // namespace fir

#define DEBUG_TYPE "flang-openacc-conversion"
#include "../CodeGen/TypeConverter.h"

using namespace fir;
using namespace mlir;

//===----------------------------------------------------------------------===//
// Conversion patterns
//===----------------------------------------------------------------------===//

namespace {

/// Conversion pattern that rebuilds an OpenACC operation of type `Op` so that
/// every data operand is an LLVM pointer value instead of a FIR reference.
///
/// Non-data operands are forwarded as provided by the adaptor. Each data
/// operand must be a `fir::ReferenceType` whose element type is not a
/// `fir::BaseBoxType`; it is wrapped in an `UnrealizedConversionCastOp`
/// producing the converted pointer type. Any other data operand type, or a
/// reference type that does not convert to an LLVM pointer, makes the pattern
/// report a match failure.
template <typename Op>
class LegalizeDataOpForLLVMTranslation : public ConvertOpToLLVMPattern<Op> {
  using ConvertOpToLLVMPattern<Op>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
                  ConversionPatternRewriter &builder) const override {
    Location loc = op.getLoc();
    // The pattern is expected to be instantiated with a
    // fir::LLVMTypeConverter (this is what the pass in this file does).
    fir::LLVMTypeConverter &converter =
        *static_cast<fir::LLVMTypeConverter *>(this->getTypeConverter());

    unsigned numDataOperands = op.getNumDataOperands();

    // Keep the non data operands without modification. The last
    // `numDataOperands` operands are treated as the data operands.
    auto nonDataOperands = adaptor.getOperands().take_front(
        adaptor.getOperands().size() - numDataOperands);
    SmallVector<Value> convertedOperands;
    convertedOperands.append(nonDataOperands.begin(), nonDataOperands.end());

    // Go over the data operands and legalize them for translation.
    for (unsigned idx = 0; idx < numDataOperands; ++idx) {
      Value originalDataOperand = op.getDataOperand(idx);

      auto refTy =
          originalDataOperand.getType().dyn_cast<fir::ReferenceType>();
      if (!refTy)
        return builder.notifyMatchFailure(op, "expecting a reference type");
      if (refTy.getEleTy().isa<fir::BaseBoxType>())
        return builder.notifyMatchFailure(op, "BaseBoxType not supported");

      // Check the converted type explicitly rather than using `cast`, so a
      // failed or unexpected conversion is reported as a match failure
      // instead of tripping an assertion.
      mlir::Type convertedType = converter.convertType(refTy);
      if (!convertedType || !convertedType.isa<mlir::LLVM::LLVMPointerType>())
        return builder.notifyMatchFailure(
            op, "reference type did not convert to an LLVM pointer type");

      mlir::Value castedOperand =
          builder
              .create<mlir::UnrealizedConversionCastOp>(loc, convertedType,
                                                        originalDataOperand)
              .getResult(0);
      convertedOperands.push_back(castedOperand);
    }

    // Re-create the operation with the legalized operands, carrying over all
    // of the original attributes.
    builder.replaceOpWithNewOp<Op>(op, TypeRange(), convertedOperands,
                                   op.getOperation()->getAttrs());

    return success();
  }
};
} // namespace

namespace {
/// Pass class for the OpenACC data operand conversion. It derives from the
/// TableGen-generated base and only overrides `runOnOperation`, which is
/// defined out of line below.
struct OpenACCDataOperandConversion
    : fir::impl::OpenACCDataOperandConversionBase<
          OpenACCDataOperandConversion> {
  // Reuse the constructors provided by the generated base class.
  using Base::Base;

  void runOnOperation() override;
};
} // namespace

/// Run a partial dialect conversion over the current operation that rewrites
/// the data operands of `acc.data`, `acc.enter_data`, `acc.exit_data`,
/// `acc.parallel` and `acc.update` into LLVM pointer values.
///
/// FIR and LLVM dialect operations, as well as
/// `UnrealizedConversionCastOp`, are always legal. Each of the OpenACC
/// operations above is legal only once every one of its data operands has an
/// `LLVM::LLVMPointerType`. The pass signals failure if the conversion fails.
void OpenACCDataOperandConversion::runOnOperation() {
  auto op = getOperation();
  auto *context = op.getContext();

  // Patterns legalizing the data operands of the OpenACC operations.
  RewritePatternSet patterns(context);

  // NOTE(review): the `useOpaquePointers` pass option is not forwarded
  // anywhere in this function. The converter is constructed from the
  // enclosing module and a boolean only, so the pointer representation is
  // whatever fir::LLVMTypeConverter selects. Wiring the option through would
  // require a converter entry point that accepts it - TODO confirm.
  fir::LLVMTypeConverter converter(
      op.getOperation()->getParentOfType<mlir::ModuleOp>(), true);
  patterns.add<LegalizeDataOpForLLVMTranslation<acc::DataOp>>(converter);
  patterns.add<LegalizeDataOpForLLVMTranslation<acc::EnterDataOp>>(converter);
  patterns.add<LegalizeDataOpForLLVMTranslation<acc::ExitDataOp>>(converter);
  patterns.add<LegalizeDataOpForLLVMTranslation<acc::ParallelOp>>(converter);
  patterns.add<LegalizeDataOpForLLVMTranslation<acc::UpdateOp>>(converter);

  ConversionTarget target(*context);
  target.addLegalDialect<fir::FIROpsDialect>();
  target.addLegalDialect<LLVM::LLVMDialect>();
  // The patterns materialize casts from FIR references to LLVM pointers;
  // these must be allowed to remain after the conversion.
  target.addLegalOp<UnrealizedConversionCastOp>();

  // Returns true when every value in `operands` already has an LLVM pointer
  // type (trivially true for an empty range).
  auto allDataOperandsAreConverted = [](ValueRange operands) {
    for (Value operand : operands) {
      if (!operand.getType().isa<LLVM::LLVMPointerType>())
        return false;
    }
    return true;
  };

  target.addDynamicallyLegalOp<acc::DataOp>(
      [allDataOperandsAreConverted](acc::DataOp op) {
        return allDataOperandsAreConverted(op.getCopyOperands()) &&
               allDataOperandsAreConverted(op.getCopyinOperands()) &&
               allDataOperandsAreConverted(op.getCopyinReadonlyOperands()) &&
               allDataOperandsAreConverted(op.getCopyoutOperands()) &&
               allDataOperandsAreConverted(op.getCopyoutZeroOperands()) &&
               allDataOperandsAreConverted(op.getCreateOperands()) &&
               allDataOperandsAreConverted(op.getCreateZeroOperands()) &&
               allDataOperandsAreConverted(op.getNoCreateOperands()) &&
               allDataOperandsAreConverted(op.getPresentOperands()) &&
               allDataOperandsAreConverted(op.getDeviceptrOperands()) &&
               allDataOperandsAreConverted(op.getAttachOperands());
      });

  target.addDynamicallyLegalOp<acc::EnterDataOp>(
      [allDataOperandsAreConverted](acc::EnterDataOp op) {
        return allDataOperandsAreConverted(op.getCopyinOperands()) &&
               allDataOperandsAreConverted(op.getCreateOperands()) &&
               allDataOperandsAreConverted(op.getCreateZeroOperands()) &&
               allDataOperandsAreConverted(op.getAttachOperands());
      });

  target.addDynamicallyLegalOp<acc::ExitDataOp>(
      [allDataOperandsAreConverted](acc::ExitDataOp op) {
        return allDataOperandsAreConverted(op.getCopyoutOperands()) &&
               allDataOperandsAreConverted(op.getDeleteOperands()) &&
               allDataOperandsAreConverted(op.getDetachOperands());
      });

  target.addDynamicallyLegalOp<acc::ParallelOp>(
      [allDataOperandsAreConverted](acc::ParallelOp op) {
        return allDataOperandsAreConverted(op.getReductionOperands()) &&
               allDataOperandsAreConverted(op.getCopyOperands()) &&
               allDataOperandsAreConverted(op.getCopyinOperands()) &&
               allDataOperandsAreConverted(op.getCopyinReadonlyOperands()) &&
               allDataOperandsAreConverted(op.getCopyoutOperands()) &&
               allDataOperandsAreConverted(op.getCopyoutZeroOperands()) &&
               allDataOperandsAreConverted(op.getCreateOperands()) &&
               allDataOperandsAreConverted(op.getCreateZeroOperands()) &&
               allDataOperandsAreConverted(op.getNoCreateOperands()) &&
               allDataOperandsAreConverted(op.getPresentOperands()) &&
               allDataOperandsAreConverted(op.getDevicePtrOperands()) &&
               allDataOperandsAreConverted(op.getAttachOperands()) &&
               allDataOperandsAreConverted(op.getGangPrivateOperands()) &&
               allDataOperandsAreConverted(op.getGangFirstPrivateOperands());
      });

  target.addDynamicallyLegalOp<acc::UpdateOp>(
      [allDataOperandsAreConverted](acc::UpdateOp op) {
        return allDataOperandsAreConverted(op.getHostOperands()) &&
               allDataOperandsAreConverted(op.getDeviceOperands());
      });

  if (failed(applyPartialConversion(op, target, std::move(patterns))))
    signalPassFailure();
}
84 changes: 84 additions & 0 deletions flang/test/Transforms/OpenACC/convert-data-operands-to-llvmir.fir
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
// RUN: fir-opt -fir-openacc-data-operand-conversion='use-opaque-pointers=1' -split-input-file %s | FileCheck %s

// Test input: an `acc.data` region whose `copy` operand is a FIR reference to
// a 10-element f32 array obtained from `fir.address_of`.
func.func @_QQsub1() attributes {fir.bindc_name = "arr"} {
%0 = fir.address_of(@_QFEa) : !fir.ref<!fir.array<10xf32>>
acc.data copy(%0 : !fir.ref<!fir.array<10xf32>>) {
acc.terminator
}
return
}

// CHECK-LABEL: func.func @_QQsub1() attributes {fir.bindc_name = "arr"} {
// CHECK: %[[ADDR:.*]] = fir.address_of(@_QFEa) : !fir.ref<!fir.array<10xf32>>
// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[ADDR]] : !fir.ref<!fir.array<10xf32>> to !llvm.ptr<array<10 x f32>>
// CHECK: acc.data copy(%[[CAST]] : !llvm.ptr<array<10 x f32>>)

// -----

// Test input: the same FIR array reference is used as the `copyin` operand of
// `acc.enter_data` and as the `copyout` operand of `acc.exit_data`, so each
// operation is expected to receive its own conversion cast.
func.func @_QQsub_enter_exit() attributes {fir.bindc_name = "a"} {
%0 = fir.address_of(@_QFEa) : !fir.ref<!fir.array<10xf32>>
acc.enter_data copyin(%0 : !fir.ref<!fir.array<10xf32>>)
acc.exit_data copyout(%0 : !fir.ref<!fir.array<10xf32>>)
return
}

// CHECK-LABEL: func.func @_QQsub_enter_exit() attributes {fir.bindc_name = "a"} {
// CHECK: %[[ADDR:.*]] = fir.address_of(@_QFEa) : !fir.ref<!fir.array<10xf32>>
// CHECK: %[[CAST0:.*]] = builtin.unrealized_conversion_cast %[[ADDR]] : !fir.ref<!fir.array<10xf32>> to !llvm.ptr<array<10 x f32>>
// CHECK: acc.enter_data copyin(%[[CAST0]] : !llvm.ptr<array<10 x f32>>)
// CHECK: %[[CAST1:.*]] = builtin.unrealized_conversion_cast %[[ADDR]] : !fir.ref<!fir.array<10xf32>> to !llvm.ptr<array<10 x f32>>
// CHECK: acc.exit_data copyout(%[[CAST1]] : !llvm.ptr<array<10 x f32>>)

// -----

// Test input: an `acc.update` operation whose `device` operand is a FIR
// reference to a 10-element f32 array.
func.func @_QQsub_update() attributes {fir.bindc_name = "a"} {
%0 = fir.address_of(@_QFEa) : !fir.ref<!fir.array<10xf32>>
acc.update device(%0 : !fir.ref<!fir.array<10xf32>>)
return
}

// CHECK-LABEL: func.func @_QQsub_update() attributes {fir.bindc_name = "a"} {
// CHECK: %[[ADDR:.*]] = fir.address_of(@_QFEa) : !fir.ref<!fir.array<10xf32>>
// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[ADDR]] : !fir.ref<!fir.array<10xf32>> to !llvm.ptr<array<10 x f32>>
// CHECK: acc.update device(%[[CAST]] : !llvm.ptr<array<10 x f32>>)

// -----

// Test input: an `acc.parallel` region with a FIR array reference as `copyin`
// operand. The region contains an `acc.loop` wrapping a `fir.do_loop` that
// also uses the original reference directly through `fir.coordinate_of`.
func.func @_QQsub_parallel() attributes {fir.bindc_name = "test"} {
%0 = fir.address_of(@_QFEa) : !fir.ref<!fir.array<10xf32>>
%1 = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFEi"}
acc.parallel copyin(%0: !fir.ref<!fir.array<10xf32>>) {
acc.loop {
%c1_i32 = arith.constant 1 : i32
%2 = fir.convert %c1_i32 : (i32) -> index
%c10_i32 = arith.constant 10 : i32
%3 = fir.convert %c10_i32 : (i32) -> index
%c1 = arith.constant 1 : index
%4 = fir.convert %2 : (index) -> i32
%5:2 = fir.do_loop %arg0 = %2 to %3 step %c1 iter_args(%arg1 = %4) -> (index, i32) {
fir.store %arg1 to %1 : !fir.ref<i32>
%6 = fir.load %1 : !fir.ref<i32>
%7 = fir.convert %6 : (i32) -> f32
%c10_i64 = arith.constant 10 : i64
%c1_i64 = arith.constant 1 : i64
%8 = arith.subi %c10_i64, %c1_i64 : i64
%9 = fir.coordinate_of %0, %8 : (!fir.ref<!fir.array<10xf32>>, i64) -> !fir.ref<f32>
fir.store %7 to %9 : !fir.ref<f32>
%10 = arith.addi %arg0, %c1 : index
%11 = fir.convert %c1 : (index) -> i32
%12 = fir.load %1 : !fir.ref<i32>
%13 = arith.addi %12, %11 : i32
fir.result %10, %13 : index, i32
}
fir.store %5#1 to %1 : !fir.ref<i32>
acc.yield
}
acc.yield
}
return
}

// CHECK-LABEL: func.func @_QQsub_parallel() attributes {fir.bindc_name = "test"} {
// CHECK: %[[ADDR:.*]] = fir.address_of(@_QFEa) : !fir.ref<!fir.array<10xf32>>
// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[ADDR]] : !fir.ref<!fir.array<10xf32>> to !llvm.ptr<array<10 x f32>>
// CHECK: acc.parallel copyin(%[[CAST]]: !llvm.ptr<array<10 x f32>>) {

0 comments on commit 68bcd64

Please sign in to comment.