Skip to content

Commit

Permalink
[MLIR][OpenMP]Add Flang lowering support for device_ptr and device_ad…
Browse files Browse the repository at this point in the history
…dr clauses

Add lowering support for the use_device_ptr and use_Device_addr clauses for the Target Data directive.

Depends on D152822

Differential Revision: https://reviews.llvm.org/D152824
  • Loading branch information
TIFitis committed Jun 22, 2023
1 parent 0657ae3 commit d21580c
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 22 deletions.
117 changes: 95 additions & 22 deletions flang/lib/Lower/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -723,6 +723,48 @@ createBodyOfOp(Op &op, Fortran::lower::AbstractConverter &converter,
}
}

static void createBodyOfTargetOp(
Fortran::lower::AbstractConverter &converter, mlir::omp::DataOp &dataOp,
const llvm::SmallVector<mlir::Type> &useDeviceTypes,
const llvm::SmallVector<mlir::Location> &useDeviceLocs,
const SmallVector<const Fortran::semantics::Symbol *> &useDeviceSymbols,
const mlir::Location &currentLocation) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
mlir::Region &region = dataOp.getRegion();

firOpBuilder.createBlock(&region, {}, useDeviceTypes, useDeviceLocs);
firOpBuilder.create<mlir::omp::TerminatorOp>(currentLocation);
firOpBuilder.setInsertionPointToStart(&region.front());

unsigned argIndex = 0;
for (auto *sym : useDeviceSymbols) {
const mlir::BlockArgument &arg = region.front().getArgument(argIndex);
mlir::Value val = fir::getBase(arg);
fir::ExtendedValue extVal = converter.getSymbolExtendedValue(*sym);
if (auto refType = val.getType().dyn_cast<fir::ReferenceType>()) {
if (fir::isa_builtin_cptr_type(refType.getElementType())) {
converter.bindSymbol(*sym, val);
} else {
extVal.match(
[&](const fir::MutableBoxValue &mbv) {
converter.bindSymbol(
*sym,
fir::MutableBoxValue(
val, fir::factory::getNonDeferredLenParams(extVal), {}));
},
[&](const auto &) {
TODO(converter.getCurrentLocation(),
"use_device clause operand unsupported type");
});
}
} else {
TODO(converter.getCurrentLocation(),
"use_device clause operand unsupported type");
}
argIndex++;
}
}

static void createTargetOp(Fortran::lower::AbstractConverter &converter,
const Fortran::parser::OmpClauseList &opClauseList,
const llvm::omp::Directive &directive,
Expand All @@ -732,13 +774,24 @@ static void createTargetOp(Fortran::lower::AbstractConverter &converter,

mlir::Value ifClauseOperand, deviceOperand, threadLmtOperand;
mlir::UnitAttr nowaitAttr;
llvm::SmallVector<mlir::Value> useDevicePtrOperand, useDeviceAddrOperand,
mapOperands;
llvm::SmallVector<mlir::Value> mapOperands, devicePtrOperands,
deviceAddrOperands;
llvm::SmallVector<mlir::IntegerAttr> mapTypes;
llvm::SmallVector<mlir::Type> useDeviceTypes;
llvm::SmallVector<mlir::Location> useDeviceLocs;
SmallVector<const Fortran::semantics::Symbol *> useDeviceSymbols;

/// Check for unsupported map operand types.
auto checkType = [](auto currentLocation, mlir::Type type) {
if (auto refType = type.dyn_cast<fir::ReferenceType>())
type = refType.getElementType();
if (auto boxType = type.dyn_cast_or_null<fir::BoxType>())
if (!boxType.getElementType().isa<fir::PointerType>())
TODO(currentLocation, "OMPD_target_data MapOperand BoxType");
};

auto addMapClause = [&firOpBuilder, &converter, &mapOperands,
&mapTypes](const auto &mapClause,
mlir::Location &currentLocation) {
auto addMapClause = [&](const auto &mapClause,
mlir::Location &currentLocation) {
auto mapType = std::get<Fortran::parser::OmpMapType::Type>(
std::get<std::optional<Fortran::parser::OmpMapType>>(mapClause->v.t)
->t);
Expand Down Expand Up @@ -793,18 +846,25 @@ static void createTargetOp(Fortran::lower::AbstractConverter &converter,
converter, mapOperand);

for (mlir::Value mapOp : mapOperand) {
/// Check for unsupported map operand types.
mlir::Type checkType = mapOp.getType();
if (auto refType = checkType.dyn_cast<fir::ReferenceType>())
checkType = refType.getElementType();
if (checkType.isa<fir::BoxType>())
TODO(currentLocation, "OMPD_target_data MapOperand BoxType");

checkType(mapOp.getLoc(), mapOp.getType());
mapOperands.push_back(mapOp);
mapTypes.push_back(mapTypeAttr);
}
};

auto addUseDeviceClause = [&](const auto &useDeviceClause, auto &operands) {
genObjectList(useDeviceClause, converter, operands);
for (auto &operand : operands) {
checkType(operand.getLoc(), operand.getType());
useDeviceTypes.push_back(operand.getType());
useDeviceLocs.push_back(operand.getLoc());
}
for (const Fortran::parser::OmpObject &ompObject : useDeviceClause.v) {
Fortran::semantics::Symbol *sym = getOmpObjectSymbol(ompObject);
useDeviceSymbols.push_back(sym);
}
};

for (const Fortran::parser::OmpClause &clause : opClauseList.v) {
mlir::Location currentLocation = converter.genLocation(clause.source);
if (const auto &ifClause =
Expand All @@ -825,19 +885,21 @@ static void createTargetOp(Fortran::lower::AbstractConverter &converter,
deviceOperand =
fir::getBase(converter.genExprValue(*deviceExpr, stmtCtx));
}
} else if (std::get_if<Fortran::parser::OmpClause::UseDevicePtr>(
&clause.u)) {
TODO(currentLocation, "OMPD_target Use Device Ptr");
} else if (std::get_if<Fortran::parser::OmpClause::UseDeviceAddr>(
&clause.u)) {
TODO(currentLocation, "OMPD_target Use Device Addr");
} else if (const auto &threadLmtClause =
std::get_if<Fortran::parser::OmpClause::ThreadLimit>(
&clause.u)) {
threadLmtOperand = fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(threadLmtClause->v), stmtCtx));
} else if (std::get_if<Fortran::parser::OmpClause::Nowait>(&clause.u)) {
nowaitAttr = firOpBuilder.getUnitAttr();
} else if (const auto &devPtrClause =
std::get_if<Fortran::parser::OmpClause::UseDevicePtr>(
&clause.u)) {
addUseDeviceClause(devPtrClause->v, devicePtrOperands);
} else if (const auto &devAddrClause =
std::get_if<Fortran::parser::OmpClause::UseDeviceAddr>(
&clause.u)) {
addUseDeviceClause(devAddrClause->v, deviceAddrOperands);
} else if (const auto &mapClause =
std::get_if<Fortran::parser::OmpClause::Map>(&clause.u)) {
addMapClause(mapClause, currentLocation);
Expand All @@ -859,9 +921,10 @@ static void createTargetOp(Fortran::lower::AbstractConverter &converter,
createBodyOfOp(targetOp, converter, currentLocation, *eval, &opClauseList);
} else if (directive == llvm::omp::Directive::OMPD_target_data) {
auto dataOp = firOpBuilder.create<omp::DataOp>(
currentLocation, ifClauseOperand, deviceOperand, useDevicePtrOperand,
useDeviceAddrOperand, mapOperands, mapTypesArrayAttr);
createBodyOfOp(dataOp, converter, currentLocation, *eval, &opClauseList);
currentLocation, ifClauseOperand, deviceOperand, devicePtrOperands,
deviceAddrOperands, mapOperands, mapTypesArrayAttr);
createBodyOfTargetOp(converter, dataOp, useDeviceTypes, useDeviceLocs,
useDeviceSymbols, currentLocation);
} else if (directive == llvm::omp::Directive::OMPD_target_enter_data) {
firOpBuilder.create<omp::EnterDataOp>(currentLocation, ifClauseOperand,
deviceOperand, nowaitAttr,
Expand Down Expand Up @@ -1157,7 +1220,17 @@ genOMP(Fortran::lower::AbstractConverter &converter,
continue;
} else if (std::get_if<Fortran::parser::OmpClause::Map>(&clause.u)) {
// Map clause is exclusive to Target Data directives. It is handled
// as part of the DataOp creation.
// as part of the TargetOp creation.
continue;
} else if (std::get_if<Fortran::parser::OmpClause::UseDevicePtr>(
&clause.u)) {
// UseDevicePtr clause is exclusive to Target Data directives. It is
// handled as part of the TargetOp creation.
continue;
} else if (std::get_if<Fortran::parser::OmpClause::UseDeviceAddr>(
&clause.u)) {
// UseDeviceAddr clause is exclusive to Target Data directives. It is
// handled as part of the TargetOp creation.
continue;
} else if (std::get_if<Fortran::parser::OmpClause::ThreadLimit>(
&clause.u)) {
Expand Down
36 changes: 36 additions & 0 deletions flang/test/Lower/OpenMP/target.f90
Original file line number Diff line number Diff line change
Expand Up @@ -162,3 +162,39 @@ subroutine omp_target_thread_limit
!$omp end target
!CHECK: }
end subroutine omp_target_thread_limit

!===============================================================================
! Target `use_device_ptr` clause
!===============================================================================

!CHECK-LABEL: func.func @_QPomp_target_device_ptr() {
subroutine omp_target_device_ptr
use iso_c_binding, only : c_ptr, c_loc
type(c_ptr) :: a
integer, target :: b
!CHECK: omp.target_data map((tofrom -> %[[VAL_0:.*]] : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>)) use_device_ptr(%[[VAL_0]] : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>)
!$omp target data map(tofrom: a) use_device_ptr(a)
!CHECK: ^bb0(%[[VAL_1:.*]]: !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>):
!CHECK: {{.*}} = fir.coordinate_of %[[VAL_1:.*]], {{.*}} : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.field) -> !fir.ref<i64>
a = c_loc(b)
!CHECK: omp.terminator
!$omp end target data
!CHECK: }
end subroutine omp_target_device_ptr

!===============================================================================
! Target `use_device_addr` clause
!===============================================================================

!CHECK-LABEL: func.func @_QPomp_target_device_addr() {
subroutine omp_target_device_addr
integer, pointer :: a
!CHECK: omp.target_data map((tofrom -> %[[VAL_0:.*]] : !fir.ref<!fir.box<!fir.ptr<i32>>>)) use_device_addr(%[[VAL_0]] : !fir.ref<!fir.box<!fir.ptr<i32>>>)
!$omp target data map(tofrom: a) use_device_addr(a)
!CHECK: ^bb0(%[[VAL_1:.*]]: !fir.ref<!fir.box<!fir.ptr<i32>>>):
!CHECK: {{.*}} = fir.load %[[VAL_1]] : !fir.ref<!fir.box<!fir.ptr<i32>>>
a = 10
!CHECK: omp.terminator
!$omp end target data
!CHECK: }
end subroutine omp_target_device_addr

0 comments on commit d21580c

Please sign in to comment.