Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 18 additions & 27 deletions flang/lib/Lower/OpenMP/ClauseProcessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,6 @@ mlir::omp::ReductionModifier translateReductionModifier(ReductionModifier mod) {
return mlir::omp::ReductionModifier::defaultmod;
}

/// Reject map operand types that are not handled yet.
///
/// A `fir::ReferenceType` is first looked through to its element type. If the
/// resulting type is a `fir::BoxType` whose element type is not a
/// `fir::PointerType`, the `TODO` macro is invoked with \p location. Every
/// other type passes through without any action.
static void checkMapType(mlir::Location location, mlir::Type type) {
  mlir::Type underlying = type;
  if (auto asReference = mlir::dyn_cast<fir::ReferenceType>(underlying))
    underlying = asReference.getElementType();

  // Only boxes are of interest from here on.
  auto asBox = mlir::dyn_cast_or_null<fir::BoxType>(underlying);
  if (!asBox)
    return;

  // A box wrapping a pointer is the one box flavour that is accepted.
  if (mlir::isa<fir::PointerType>(asBox.getElementType()))
    return;

  TODO(location, "OMPD_target_data MapOperand BoxType");
}

static mlir::omp::ScheduleModifier
translateScheduleModifier(const omp::clause::Schedule::OrderingModifier &m) {
switch (m) {
Expand Down Expand Up @@ -211,18 +202,6 @@ getIfClauseOperand(lower::AbstractConverter &converter,
ifVal);
}

/// Shared lowering helper for clauses that carry a list of device objects.
///
/// \p objects is handed to `genObjectList` together with \p operands
/// (presumably appending one value per object -- confirm against
/// `genObjectList`). Afterwards every value currently held in \p operands,
/// including any that were already there before this call, is validated with
/// `checkMapType`, and the symbol of each object is appended to
/// \p useDeviceSyms in clause order.
static void addUseDeviceClause(
    lower::AbstractConverter &converter, const omp::ObjectList &objects,
    llvm::SmallVectorImpl<mlir::Value> &operands,
    llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) {
  genObjectList(objects, converter, operands);

  // Validation deliberately walks the whole operand list, not just the tail.
  for (mlir::Value &value : operands) {
    mlir::Location valueLoc = value.getLoc();
    mlir::Type valueType = value.getType();
    checkMapType(valueLoc, valueType);
  }

  for (const omp::Object &obj : objects) {
    const semantics::Symbol *symbol = obj.sym();
    useDeviceSyms.push_back(symbol);
  }
}

//===----------------------------------------------------------------------===//
// ClauseProcessor unique clauses
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1225,14 +1204,26 @@ bool ClauseProcessor::processInReduction(
}

bool ClauseProcessor::processIsDevicePtr(
mlir::omp::IsDevicePtrClauseOps &result,
lower::StatementContext &stmtCtx, mlir::omp::IsDevicePtrClauseOps &result,
llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const {
return findRepeatableClause<omp::clause::IsDevicePtr>(
[&](const omp::clause::IsDevicePtr &devPtrClause,
const parser::CharBlock &) {
addUseDeviceClause(converter, devPtrClause.v, result.isDevicePtrVars,
isDeviceSyms);
std::map<Object, OmpMapParentAndMemberData> parentMemberIndices;
bool clauseFound = findRepeatableClause<omp::clause::IsDevicePtr>(
[&](const omp::clause::IsDevicePtr &clause,
const parser::CharBlock &source) {
mlir::Location location = converter.genLocation(source);
// Force a map so the descriptor is materialized on the device with the
// device address inside.
mlir::omp::ClauseMapFlags mapTypeBits =
mlir::omp::ClauseMapFlags::is_device_ptr |
mlir::omp::ClauseMapFlags::to;
processMapObjects(stmtCtx, location, clause.v, mapTypeBits,
parentMemberIndices, result.isDevicePtrVars,
isDeviceSyms);
});

insertChildMapInfoIntoParent(converter, semaCtx, stmtCtx, parentMemberIndices,
result.isDevicePtrVars, isDeviceSyms);
return clauseFound;
}

bool ClauseProcessor::processLinear(mlir::omp::LinearClauseOps &result) const {
Expand Down
2 changes: 1 addition & 1 deletion flang/lib/Lower/OpenMP/ClauseProcessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ class ClauseProcessor {
mlir::Location currentLocation, mlir::omp::InReductionClauseOps &result,
llvm::SmallVectorImpl<const semantics::Symbol *> &outReductionSyms) const;
bool processIsDevicePtr(
mlir::omp::IsDevicePtrClauseOps &result,
lower::StatementContext &stmtCtx, mlir::omp::IsDevicePtrClauseOps &result,
llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const;
bool processLinear(mlir::omp::LinearClauseOps &result) const;
bool
Expand Down
42 changes: 38 additions & 4 deletions flang/lib/Lower/OpenMP/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1671,7 +1671,7 @@ static void genTargetClauses(
hostEvalInfo->collectValues(clauseOps.hostEvalVars);
}
cp.processIf(llvm::omp::Directive::OMPD_target, clauseOps);
cp.processIsDevicePtr(clauseOps, isDevicePtrSyms);
cp.processIsDevicePtr(stmtCtx, clauseOps, isDevicePtrSyms);
cp.processMap(loc, stmtCtx, clauseOps, llvm::omp::Directive::OMPD_unknown,
&mapSyms);
cp.processNowait(clauseOps);
Expand Down Expand Up @@ -2487,13 +2487,15 @@ static bool isDuplicateMappedSymbol(
const semantics::Symbol &sym,
const llvm::SetVector<const semantics::Symbol *> &privatizedSyms,
const llvm::SmallVectorImpl<const semantics::Symbol *> &hasDevSyms,
const llvm::SmallVectorImpl<const semantics::Symbol *> &mappedSyms) {
const llvm::SmallVectorImpl<const semantics::Symbol *> &mappedSyms,
const llvm::SmallVectorImpl<const semantics::Symbol *> &isDevicePtrSyms) {
llvm::SmallVector<const semantics::Symbol *> concatSyms;
concatSyms.reserve(privatizedSyms.size() + hasDevSyms.size() +
mappedSyms.size());
mappedSyms.size() + isDevicePtrSyms.size());
concatSyms.append(privatizedSyms.begin(), privatizedSyms.end());
concatSyms.append(hasDevSyms.begin(), hasDevSyms.end());
concatSyms.append(mappedSyms.begin(), mappedSyms.end());
concatSyms.append(isDevicePtrSyms.begin(), isDevicePtrSyms.end());

auto checkSymbol = [&](const semantics::Symbol &checkSym) {
return std::any_of(concatSyms.begin(), concatSyms.end(),
Expand Down Expand Up @@ -2533,6 +2535,38 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
loc, clauseOps, defaultMaps, hasDeviceAddrSyms,
isDevicePtrSyms, mapSyms);

if (!isDevicePtrSyms.empty()) {
// is_device_ptr maps get duplicated so the clause and the synthesized
// has_device_addr entry each own a unique MapInfoOp user, keeping
// MapInfoFinalization happy while still wiring the symbol into
// has_device_addr when the user didn't spell it explicitly.
auto insertionPt = firOpBuilder.saveInsertionPoint();
auto alreadyPresent = [&](const semantics::Symbol *sym) {
return llvm::any_of(hasDeviceAddrSyms, [&](const semantics::Symbol *s) {
return s && sym && s->GetUltimate() == sym->GetUltimate();
});
};

for (auto [idx, sym] : llvm::enumerate(isDevicePtrSyms)) {
mlir::Value mapVal = clauseOps.isDevicePtrVars[idx];
assert(sym && "expected symbol for is_device_ptr");
assert(mapVal && "expected map value for is_device_ptr");
auto mapInfo = mapVal.getDefiningOp<mlir::omp::MapInfoOp>();
assert(mapInfo && "expected map info op");

if (!alreadyPresent(sym)) {
clauseOps.hasDeviceAddrVars.push_back(mapVal);
hasDeviceAddrSyms.push_back(sym);
}

firOpBuilder.setInsertionPointAfter(mapInfo);
mlir::Operation *clonedOp = firOpBuilder.clone(*mapInfo.getOperation());
auto clonedMapInfo = mlir::cast<mlir::omp::MapInfoOp>(clonedOp);
clauseOps.isDevicePtrVars[idx] = clonedMapInfo.getResult();
}
firOpBuilder.restoreInsertionPoint(insertionPt);
}

DataSharingProcessor dsp(converter, semaCtx, item->clauses, eval,
/*shouldCollectPreDeterminedSymbols=*/
lower::omp::isLastItemInQueue(item, queue),
Expand Down Expand Up @@ -2572,7 +2606,7 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
return;

if (!isDuplicateMappedSymbol(sym, dsp.getAllSymbolsToPrivatize(),
hasDeviceAddrSyms, mapSyms)) {
hasDeviceAddrSyms, mapSyms, isDevicePtrSyms)) {
if (const auto *details =
sym.template detailsIf<semantics::HostAssocDetails>())
converter.copySymbolBinding(details->symbol(), sym);
Expand Down
9 changes: 9 additions & 0 deletions flang/test/Integration/OpenMP/map-types-and-sizes.f90
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,15 @@ subroutine mapType_array
!$omp end target
end subroutine mapType_array

!CHECK: @.offload_sizes{{.*}} = private unnamed_addr constant [1 x i64] [i64 8]
!CHECK: @.offload_maptypes{{.*}} = private unnamed_addr constant [1 x i64] [i64 33]
! Maps a single type(c_ptr) variable through is_device_ptr on an otherwise
! empty target region. The FileCheck directives placed directly above this
! subroutine expect exactly one offload entry of size 8 with map type 33
! (presumably OMP_MAP_TO | OMP_MAP_TARGET_PARAM, i.e. 0x1 | 0x20 -- confirm
! against the OpenMP offload mapping flag definitions).
subroutine mapType_is_device_ptr
use iso_c_binding, only : c_ptr
type(c_ptr) :: p
!$omp target is_device_ptr(p)
!$omp end target
end subroutine mapType_is_device_ptr

!CHECK: @.offload_sizes{{.*}} = private unnamed_addr constant [5 x i64] [i64 0, i64 0, i64 0, i64 8, i64 0]
!CHECK: @.offload_maptypes{{.*}} = private unnamed_addr constant [5 x i64] [i64 32, i64 281474976711173, i64 281474976711173, i64 281474976711171, i64 281474976711187]
subroutine mapType_ptr
Expand Down
30 changes: 30 additions & 0 deletions flang/test/Lower/OpenMP/target.f90
Original file line number Diff line number Diff line change
Expand Up @@ -566,6 +566,36 @@ subroutine omp_target_device_addr
end subroutine omp_target_device_addr


!===============================================================================
! Target `is_device_ptr` clause
!===============================================================================

!CHECK-LABEL: func.func @_QPomp_target_is_device_ptr() {
! Lowering test for the is_device_ptr clause on a target construct.
! `p` is the only variable named in the clause; `i` and `arr` are used inside
! the region without appearing in any clause.
subroutine omp_target_is_device_ptr
use iso_c_binding, only: c_ptr
implicit none
integer :: i
integer :: arr(4)
type(c_ptr) :: p

i = 0
arr = 0

! Expected shape of the lowered IR: two map-info ops named "p" (the first is
! consumed by has_device_addr, the second by is_device_ptr), followed by a
! map-info op for "arr" that is consumed by map_entries.
!CHECK: %[[P_STORAGE:.*]] = omp.map.info {{.*}}{name = "p"}
!CHECK: %[[P_IS:.*]] = omp.map.info {{.*}}{name = "p"}
!CHECK: %[[ARR_MAP:.*]] = omp.map.info {{.*}}{name = "arr"}
!CHECK: omp.target is_device_ptr(%[[P_IS]] :
!CHECK-SAME: has_device_addr(%[[P_STORAGE]] ->
!CHECK-SAME: map_entries({{.*}}%[[ARR_MAP]] ->
!$omp target is_device_ptr(p)
i = i + 1
arr(1) = i
!$omp end target
!CHECK: omp.terminator
!CHECK: }
end subroutine omp_target_is_device_ptr


!===============================================================================
! Target Data with unstructured code
!===============================================================================
Expand Down
4 changes: 3 additions & 1 deletion mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ def ClauseMapFlagsAttachAuto : I32BitEnumAttrCaseBit<"attach_auto", 15>;
def ClauseMapFlagsRefPtr : I32BitEnumAttrCaseBit<"ref_ptr", 16>;
def ClauseMapFlagsRefPtee : I32BitEnumAttrCaseBit<"ref_ptee", 17>;
def ClauseMapFlagsRefPtrPtee : I32BitEnumAttrCaseBit<"ref_ptr_ptee", 18>;
// Bit 19: set on map entries that are created for an `is_device_ptr` clause.
def ClauseMapFlagsIsDevicePtr : I32BitEnumAttrCaseBit<"is_device_ptr", 19>;

def ClauseMapFlags : OpenMP_BitEnumAttr<
"ClauseMapFlags",
Expand All @@ -151,7 +152,8 @@ def ClauseMapFlags : OpenMP_BitEnumAttr<
ClauseMapFlagsAttachAuto,
ClauseMapFlagsRefPtr,
ClauseMapFlagsRefPtee,
ClauseMapFlagsRefPtrPtee
ClauseMapFlagsRefPtrPtee,
ClauseMapFlagsIsDevicePtr
]>;

def ClauseMapFlagsAttr : OpenMP_EnumAttr<ClauseMapFlags,
Expand Down
5 changes: 5 additions & 0 deletions mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1818,6 +1818,9 @@ static ParseResult parseMapClause(OpAsmParser &parser,
if (mapTypeMod == "ref_ptr_ptee")
mapTypeBits |= ClauseMapFlags::ref_ptr_ptee;

if (mapTypeMod == "is_device_ptr")
mapTypeBits |= ClauseMapFlags::is_device_ptr;

return success();
};

Expand Down Expand Up @@ -1887,6 +1890,8 @@ static void printMapClause(OpAsmPrinter &p, Operation *op,
mapTypeStrs.push_back("ref_ptee");
if (mapTypeToBool(mapFlags, ClauseMapFlags::ref_ptr_ptee))
mapTypeStrs.push_back("ref_ptr_ptee");
if (mapTypeToBool(mapFlags, ClauseMapFlags::is_device_ptr))
mapTypeStrs.push_back("is_device_ptr");
if (mapFlags == ClauseMapFlags::none)
mapTypeStrs.push_back("none");

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -332,10 +332,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
op.getInReductionSyms())
result = todo("in_reduction");
};
// Reports the is_device_ptr clause through `todo` whenever the operation
// carries at least one is_device_ptr operand; otherwise leaves `result` alone.
auto checkIsDevicePtr = [&todo](auto op, LogicalResult &result) {
  const bool hasIsDevicePtrOperands = !op.getIsDevicePtrVars().empty();
  if (hasIsDevicePtrOperands)
    result = todo("is_device_ptr");
};
auto checkNowait = [&todo](auto op, LogicalResult &result) {
if (op.getNowait())
result = todo("nowait");
Expand Down Expand Up @@ -435,7 +431,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
checkBare(op, result);
checkDevice(op, result);
checkInReduction(op, result);
checkIsDevicePtr(op, result);
})
.Default([](Operation &) {
// Assume all clauses for an operation can be translated unless they are
Expand Down Expand Up @@ -3986,6 +3981,9 @@ convertClauseMapFlags(omp::ClauseMapFlags mlirFlags) {
auto mapTypeToBool = [&mlirFlags](omp::ClauseMapFlags flag) {
return (mlirFlags & flag) == flag;
};
const bool hasExplicitMap =
(mlirFlags & ~omp::ClauseMapFlags::is_device_ptr) !=
omp::ClauseMapFlags::none;

llvm::omp::OpenMPOffloadMappingFlags mapType =
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
Expand Down Expand Up @@ -4026,6 +4024,12 @@ convertClauseMapFlags(omp::ClauseMapFlags mlirFlags) {
if (mapTypeToBool(omp::ClauseMapFlags::attach))
mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ATTACH;

if (mapTypeToBool(omp::ClauseMapFlags::is_device_ptr)) {
mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
if (!hasExplicitMap)
mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
}

return mapType;
}

Expand Down Expand Up @@ -4149,6 +4153,9 @@ static void collectMapDataFromMapOperands(
llvm::Value *origValue = moduleTranslation.lookupValue(offloadPtr);
auto mapType = convertClauseMapFlags(mapOp.getMapType());
auto mapTypeAlways = llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
bool isDevicePtr =
(mapOp.getMapType() & omp::ClauseMapFlags::is_device_ptr) !=
omp::ClauseMapFlags::none;

mapData.OriginalValue.push_back(origValue);
mapData.BasePointers.push_back(origValue);
Expand All @@ -4175,14 +4182,18 @@ static void collectMapDataFromMapOperands(
mapData.Mappers.push_back(nullptr);
}
} else {
// For is_device_ptr we need the map type to propagate so the runtime
// can materialize the device-side copy of the pointer container.
mapData.Types.push_back(
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL);
isDevicePtr ? mapType
: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL);
mapData.Mappers.push_back(nullptr);
}
mapData.Names.push_back(LLVM::createMappingInformation(
mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder()));
mapData.DevicePointers.push_back(
llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
isDevicePtr ? llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer
: llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
mapData.IsAMapping.push_back(false);
mapData.IsAMember.push_back(checkIsAMember(hasDevAddrOperands, mapOp));
}
Expand Down
17 changes: 17 additions & 0 deletions mlir/test/Target/LLVMIR/omptarget-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -622,3 +622,20 @@ module attributes {omp.target_triples = ["amdgcn-amd-amdhsa"]} {
// CHECK: br label %[[VAL_40]]
// CHECK: omp.done: ; preds = %[[VAL_68]], %[[VAL_63]], %[[VAL_32]]
// CHECK: ret void

// -----

// Translation test for a map entry whose only map clause is `is_device_ptr`.
// The entry is handed to omp.target through map_entries. The FileCheck
// directives that follow this module expect a single 8-byte offload entry
// with map type 288 (presumably OMP_MAP_TARGET_PARAM | OMP_MAP_LITERAL,
// i.e. 0x20 | 0x100 -- confirm against the offload mapping flag definitions).
module attributes {omp.target_triples = ["amdgcn-amd-amdhsa"]} {
llvm.func @_QPomp_target_is_device_ptr(%arg0 : !llvm.ptr) {
%map = omp.map.info var_ptr(%arg0 : !llvm.ptr, !llvm.ptr)
map_clauses(is_device_ptr) capture(ByRef) -> !llvm.ptr {name = ""}
omp.target map_entries(%map -> %ptr_arg : !llvm.ptr) {
omp.terminator
}
llvm.return
}
}

// CHECK: @.offload_sizes = private unnamed_addr constant [1 x i64] [i64 8]
// CHECK: @.offload_maptypes = private unnamed_addr constant [1 x i64] [i64 288]
// CHECK-LABEL: define void @_QPomp_target_is_device_ptr
11 changes: 0 additions & 11 deletions mlir/test/Target/LLVMIR/openmp-todo.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -212,17 +212,6 @@ llvm.func @target_in_reduction(%x : !llvm.ptr) {

// -----

// Negative test: translating an omp.target that carries an is_device_ptr
// operand is expected to fail with the two diagnostics annotated inside the
// function body (an unhandled-clause report followed by a translation failure).
llvm.func @target_is_device_ptr(%x : !llvm.ptr) {
// expected-error@below {{not yet implemented: Unhandled clause is_device_ptr in omp.target operation}}
// expected-error@below {{LLVM Translation failed for operation: omp.target}}
omp.target is_device_ptr(%x : !llvm.ptr) {
omp.terminator
}
llvm.return
}

// -----

llvm.func @target_enter_data_depend(%x: !llvm.ptr) {
// expected-error@below {{not yet implemented: Unhandled clause depend in omp.target_enter_data operation}}
// expected-error@below {{LLVM Translation failed for operation: omp.target_enter_data}}
Expand Down
Loading
Loading