-
Notifications
You must be signed in to change notification settings - Fork 15.7k
[Flang][OpenMP] Implement device clause lowering for target directive #173509
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
|
@llvm/pr-subscribers-mlir-openmp @llvm/pr-subscribers-mlir Author: Chi-Chun, Chen (chichunchen) ChangesAdd lowering support for the OpenMP The device expression is propagated through MLIR OpenMP and passed to the host-side Full diff: https://github.com/llvm/llvm-project/pull/173509.diff 7 Files Affected:
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 7965119764e5d..4f2b8ef15519c 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -4087,7 +4087,8 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
!std::holds_alternative<clause::Mergeable>(clause.u) &&
!std::holds_alternative<clause::Untied>(clause.u) &&
!std::holds_alternative<clause::TaskReduction>(clause.u) &&
- !std::holds_alternative<clause::Detach>(clause.u)) {
+ !std::holds_alternative<clause::Detach>(clause.u) &&
+ !std::holds_alternative<clause::Device>(clause.u)) {
std::string name =
parser::ToUpperCaseLetters(llvm::omp::getOpenMPClauseName(clause.id));
if (!semaCtx.langOptions().OpenMPSimd)
diff --git a/flang/test/Lower/OpenMP/target.f90 b/flang/test/Lower/OpenMP/target.f90
index c5d39695e5389..55a6b7a595ed1 100644
--- a/flang/test/Lower/OpenMP/target.f90
+++ b/flang/test/Lower/OpenMP/target.f90
@@ -694,3 +694,44 @@ subroutine target_unstructured
!$omp end target
!CHECK: }
end subroutine target_unstructured
+
+!===============================================================================
+! Target `device` clause
+!===============================================================================
+
+!CHECK-LABEL: func.func @_QPomp_target_device() {
+subroutine omp_target_device
+ integer :: dev32
+ integer(kind=8) :: dev64
+ integer(kind=2) :: dev16
+
+ dev32 = 1
+ dev64 = 2_8
+ dev16 = 3_2
+
+ !$omp target device(dev32)
+ !$omp end target
+ ! CHECK: %[[DEV32:.*]] = fir.load %{{.*}} : !fir.ref<i32>
+ ! CHECK: omp.target device(%[[DEV32]] : i32)
+
+ !$omp target device(dev64)
+ !$omp end target
+ ! CHECK: %[[DEV64:.*]] = fir.load %{{.*}} : !fir.ref<i64>
+ ! CHECK: omp.target device(%[[DEV64]] : i64)
+
+ !$omp target device(dev16)
+ !$omp end target
+ ! CHECK: %[[DEV16:.*]] = fir.load %{{.*}} : !fir.ref<i16>
+ ! CHECK: omp.target device(%[[DEV16]] : i16)
+
+ !$omp target device(2)
+ !$omp end target
+ ! CHECK: %[[C2:.*]] = arith.constant 2 : i32
+ ! CHECK: omp.target device(%[[C2]] : i32)
+
+ !$omp target device(5_8)
+ !$omp end target
+ ! CHECK: %[[C5:.*]] = arith.constant 5 : i64
+ ! CHECK: omp.target device(%[[C5]] : i64)
+
+end subroutine omp_target_device
\ No newline at end of file
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index f5eb6222fd58d..8103a7e9504ea 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -3341,7 +3341,7 @@ class OpenMPIRBuilder {
const LocationDescription &Loc, bool IsOffloadEntry,
OpenMPIRBuilder::InsertPointTy AllocaIP,
OpenMPIRBuilder::InsertPointTy CodeGenIP, TargetDataInfo &Info,
- TargetRegionEntryInfo &EntryInfo,
+ Value *DeviceID, TargetRegionEntryInfo &EntryInfo,
const TargetKernelDefaultAttrs &DefaultAttrs,
const TargetKernelRuntimeAttrs &RuntimeAttrs, Value *IfCond,
SmallVectorImpl<Value *> &Inputs, GenMapInfoCallbackTy GenMapInfoCB,
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 716f8582dd7b2..3be96350cb058 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -8548,7 +8548,7 @@ Error OpenMPIRBuilder::emitOffloadingArraysAndArgs(
static void emitTargetCall(
OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
OpenMPIRBuilder::InsertPointTy AllocaIP,
- OpenMPIRBuilder::TargetDataInfo &Info,
+ OpenMPIRBuilder::TargetDataInfo &Info, Value *DeviceID,
const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
const OpenMPIRBuilder::TargetKernelRuntimeAttrs &RuntimeAttrs,
Value *IfCond, Function *OutlinedFn, Constant *OutlinedFnID,
@@ -8680,8 +8680,6 @@ static void emitTargetCall(
}
unsigned NumTargetItems = Info.NumberOfPtrs;
- // TODO: Use correct device ID
- Value *DeviceID = Builder.getInt64(OMP_DEVICEID_UNDEF);
uint32_t SrcLocStrSize;
Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize,
@@ -8740,7 +8738,7 @@ static void emitTargetCall(
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP,
- InsertPointTy CodeGenIP, TargetDataInfo &Info,
+ InsertPointTy CodeGenIP, TargetDataInfo &Info, Value *DeviceID,
TargetRegionEntryInfo &EntryInfo,
const TargetKernelDefaultAttrs &DefaultAttrs,
const TargetKernelRuntimeAttrs &RuntimeAttrs, Value *IfCond,
@@ -8770,10 +8768,10 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
// to make a remote call (offload) to the previously outlined function
// that represents the target region. Do that now.
if (!Config.isTargetDevice())
- emitTargetCall(*this, Builder, AllocaIP, Info, DefaultAttrs, RuntimeAttrs,
- IfCond, OutlinedFn, OutlinedFnID, Inputs, GenMapInfoCB,
- CustomMapperCB, Dependencies, HasNowait, DynCGroupMem,
- DynCGroupMemFallback);
+ emitTargetCall(*this, Builder, AllocaIP, Info, DeviceID, DefaultAttrs,
+ RuntimeAttrs, IfCond, OutlinedFn, OutlinedFnID, Inputs,
+ GenMapInfoCB, CustomMapperCB, Dependencies, HasNowait,
+ DynCGroupMem, DynCGroupMemFallback);
return Builder.saveIP();
}
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 03d67a52853f6..ac2d6c93b890e 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -320,7 +320,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
result = todo("depend");
};
auto checkDevice = [&todo](auto op, LogicalResult &result) {
- if (op.getDevice())
+ if (op.getDevice() && !isa<omp::TargetOp>(op))
result = todo("device");
};
auto checkHint = [](auto op, LogicalResult &) {
@@ -5961,6 +5961,15 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
bool isTargetDevice = ompBuilder->Config.isTargetDevice();
bool isGPU = ompBuilder->Config.isGPU();
+ llvm::Value *deviceIDValue = builder.getInt64(llvm::omp::OMP_DEVICEID_UNDEF);
+
+ if (!isTargetDevice) {
+ if (mlir::Value devId = targetOp.getDevice()) {
+ deviceIDValue = moduleTranslation.lookupValue(devId);
+ deviceIDValue =
+ builder.CreateSExtOrTrunc(deviceIDValue, builder.getInt64Ty());
+ }
+ }
auto parentFn = opInst.getParentOfType<LLVM::LLVMFuncOp>();
auto argIface = cast<omp::BlockArgOpenMPOpInterface>(opInst);
@@ -6235,9 +6244,10 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
moduleTranslation.getOpenMPBuilder()->createTarget(
- ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), info, entryInfo,
- defaultAttrs, runtimeAttrs, ifCond, kernelInput, genMapInfoCB, bodyCB,
- argAccessorCB, customMapperCB, dds, targetOp.getNowait());
+ ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), info,
+ deviceIDValue, entryInfo, defaultAttrs, runtimeAttrs, ifCond,
+ kernelInput, genMapInfoCB, bodyCB, argAccessorCB, customMapperCB, dds,
+ targetOp.getNowait());
if (failed(handleError(afterIP, opInst)))
return failure();
diff --git a/mlir/test/Target/LLVMIR/omptarget-device.mlir b/mlir/test/Target/LLVMIR/omptarget-device.mlir
new file mode 100644
index 0000000000000..b4c9744cc0c87
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/omptarget-device.mlir
@@ -0,0 +1,68 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+module attributes {omp.is_target_device = false, omp.target_triples = ["nvptx64-nvidia-cuda"]} {
+ llvm.func @foo(%d16 : i16, %d32 : i32, %d64 : i64) {
+ %x = llvm.mlir.constant(0 : i32) : i32
+
+ // Constant i16 -> i64 in the runtime call.
+ %c1_i16 = llvm.mlir.constant(1 : i16) : i16
+ omp.target device(%c1_i16 : i16)
+ host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) {
+ omp.terminator
+ }
+
+ // Constant i32 -> i64 in the runtime call.
+ %c2_i32 = llvm.mlir.constant(2 : i32) : i32
+ omp.target device(%c2_i32 : i32)
+ host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) {
+ omp.terminator
+ }
+
+ // Constant i64 stays i64 in the runtime call.
+ %c3_i64 = llvm.mlir.constant(3 : i64) : i64
+ omp.target device(%c3_i64 : i64)
+ host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) {
+ omp.terminator
+ }
+
+ // Variable i16 -> cast to i64.
+ omp.target device(%d16 : i16)
+ host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) {
+ omp.terminator
+ }
+
+ // Variable i32 -> cast to i64.
+ omp.target device(%d32 : i32)
+ host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) {
+ omp.terminator
+ }
+
+ // Variable i64 stays i64.
+ omp.target device(%d64 : i64)
+ host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) {
+ omp.terminator
+ }
+
+ llvm.return
+ }
+}
+
+// CHECK-LABEL: define void @foo(i16 %{{.*}}, i32 %{{.*}}, i64 %{{.*}}) {
+// CHECK: br label %entry
+// CHECK: entry:
+
+// ---- Constant cases (device id is 2nd argument) ----
+// CHECK-DAG: call i32 @__tgt_target_kernel(ptr {{.*}}, i64 1, i32 {{.*}}, i32 {{.*}}, ptr {{.*}}, ptr {{.*}})
+// CHECK-DAG: call i32 @__tgt_target_kernel(ptr {{.*}}, i64 2, i32 {{.*}}, i32 {{.*}}, ptr {{.*}}, ptr {{.*}})
+// CHECK-DAG: call i32 @__tgt_target_kernel(ptr {{.*}}, i64 3, i32 {{.*}}, i32 {{.*}}, ptr {{.*}}, ptr {{.*}})
+
+// Variable i16 -> i64
+// CHECK: %[[D16_I64:.*]] = sext i16 %{{.*}} to i64
+// CHECK: call i32 @__tgt_target_kernel(ptr {{.*}}, i64 %[[D16_I64]], i32 {{.*}}, i32 {{.*}}, ptr {{.*}}, ptr {{.*}})
+
+// Variable i32 -> i64
+// CHECK: %[[D32_I64:.*]] = sext i32 %{{.*}} to i64
+// CHECK: call i32 @__tgt_target_kernel(ptr {{.*}}, i64 %[[D32_I64]], i32 {{.*}}, i32 {{.*}}, ptr {{.*}}, ptr {{.*}})
+
+// Variable i64
+// CHECK: call i32 @__tgt_target_kernel(ptr {{.*}}, i64 %{{.*}}, i32 {{.*}}, i32 {{.*}}, ptr {{.*}}, ptr {{.*}})
\ No newline at end of file
diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir
index 396c57af81c44..d4cc9e215de1d 100644
--- a/mlir/test/Target/LLVMIR/openmp-todo.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir
@@ -174,8 +174,6 @@ llvm.func @target_allocate(%x : !llvm.ptr) {
// -----
llvm.func @target_device(%x : i32) {
- // expected-error@below {{not yet implemented: Unhandled clause device in omp.target operation}}
- // expected-error@below {{LLVM Translation failed for operation: omp.target}}
omp.target device(%x : i32) {
omp.terminator
}
|
|
@llvm/pr-subscribers-flang-openmp Author: Chi-Chun, Chen (chichunchen) ChangesAdd lowering support for the OpenMP The device expression is propagated through MLIR OpenMP and passed to the host-side Full diff: https://github.com/llvm/llvm-project/pull/173509.diff 7 Files Affected:
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 7965119764e5d..4f2b8ef15519c 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -4087,7 +4087,8 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
!std::holds_alternative<clause::Mergeable>(clause.u) &&
!std::holds_alternative<clause::Untied>(clause.u) &&
!std::holds_alternative<clause::TaskReduction>(clause.u) &&
- !std::holds_alternative<clause::Detach>(clause.u)) {
+ !std::holds_alternative<clause::Detach>(clause.u) &&
+ !std::holds_alternative<clause::Device>(clause.u)) {
std::string name =
parser::ToUpperCaseLetters(llvm::omp::getOpenMPClauseName(clause.id));
if (!semaCtx.langOptions().OpenMPSimd)
diff --git a/flang/test/Lower/OpenMP/target.f90 b/flang/test/Lower/OpenMP/target.f90
index c5d39695e5389..55a6b7a595ed1 100644
--- a/flang/test/Lower/OpenMP/target.f90
+++ b/flang/test/Lower/OpenMP/target.f90
@@ -694,3 +694,44 @@ subroutine target_unstructured
!$omp end target
!CHECK: }
end subroutine target_unstructured
+
+!===============================================================================
+! Target `device` clause
+!===============================================================================
+
+!CHECK-LABEL: func.func @_QPomp_target_device() {
+subroutine omp_target_device
+ integer :: dev32
+ integer(kind=8) :: dev64
+ integer(kind=2) :: dev16
+
+ dev32 = 1
+ dev64 = 2_8
+ dev16 = 3_2
+
+ !$omp target device(dev32)
+ !$omp end target
+ ! CHECK: %[[DEV32:.*]] = fir.load %{{.*}} : !fir.ref<i32>
+ ! CHECK: omp.target device(%[[DEV32]] : i32)
+
+ !$omp target device(dev64)
+ !$omp end target
+ ! CHECK: %[[DEV64:.*]] = fir.load %{{.*}} : !fir.ref<i64>
+ ! CHECK: omp.target device(%[[DEV64]] : i64)
+
+ !$omp target device(dev16)
+ !$omp end target
+ ! CHECK: %[[DEV16:.*]] = fir.load %{{.*}} : !fir.ref<i16>
+ ! CHECK: omp.target device(%[[DEV16]] : i16)
+
+ !$omp target device(2)
+ !$omp end target
+ ! CHECK: %[[C2:.*]] = arith.constant 2 : i32
+ ! CHECK: omp.target device(%[[C2]] : i32)
+
+ !$omp target device(5_8)
+ !$omp end target
+ ! CHECK: %[[C5:.*]] = arith.constant 5 : i64
+ ! CHECK: omp.target device(%[[C5]] : i64)
+
+end subroutine omp_target_device
\ No newline at end of file
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index f5eb6222fd58d..8103a7e9504ea 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -3341,7 +3341,7 @@ class OpenMPIRBuilder {
const LocationDescription &Loc, bool IsOffloadEntry,
OpenMPIRBuilder::InsertPointTy AllocaIP,
OpenMPIRBuilder::InsertPointTy CodeGenIP, TargetDataInfo &Info,
- TargetRegionEntryInfo &EntryInfo,
+ Value *DeviceID, TargetRegionEntryInfo &EntryInfo,
const TargetKernelDefaultAttrs &DefaultAttrs,
const TargetKernelRuntimeAttrs &RuntimeAttrs, Value *IfCond,
SmallVectorImpl<Value *> &Inputs, GenMapInfoCallbackTy GenMapInfoCB,
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 716f8582dd7b2..3be96350cb058 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -8548,7 +8548,7 @@ Error OpenMPIRBuilder::emitOffloadingArraysAndArgs(
static void emitTargetCall(
OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
OpenMPIRBuilder::InsertPointTy AllocaIP,
- OpenMPIRBuilder::TargetDataInfo &Info,
+ OpenMPIRBuilder::TargetDataInfo &Info, Value *DeviceID,
const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
const OpenMPIRBuilder::TargetKernelRuntimeAttrs &RuntimeAttrs,
Value *IfCond, Function *OutlinedFn, Constant *OutlinedFnID,
@@ -8680,8 +8680,6 @@ static void emitTargetCall(
}
unsigned NumTargetItems = Info.NumberOfPtrs;
- // TODO: Use correct device ID
- Value *DeviceID = Builder.getInt64(OMP_DEVICEID_UNDEF);
uint32_t SrcLocStrSize;
Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize,
@@ -8740,7 +8738,7 @@ static void emitTargetCall(
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP,
- InsertPointTy CodeGenIP, TargetDataInfo &Info,
+ InsertPointTy CodeGenIP, TargetDataInfo &Info, Value *DeviceID,
TargetRegionEntryInfo &EntryInfo,
const TargetKernelDefaultAttrs &DefaultAttrs,
const TargetKernelRuntimeAttrs &RuntimeAttrs, Value *IfCond,
@@ -8770,10 +8768,10 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
// to make a remote call (offload) to the previously outlined function
// that represents the target region. Do that now.
if (!Config.isTargetDevice())
- emitTargetCall(*this, Builder, AllocaIP, Info, DefaultAttrs, RuntimeAttrs,
- IfCond, OutlinedFn, OutlinedFnID, Inputs, GenMapInfoCB,
- CustomMapperCB, Dependencies, HasNowait, DynCGroupMem,
- DynCGroupMemFallback);
+ emitTargetCall(*this, Builder, AllocaIP, Info, DeviceID, DefaultAttrs,
+ RuntimeAttrs, IfCond, OutlinedFn, OutlinedFnID, Inputs,
+ GenMapInfoCB, CustomMapperCB, Dependencies, HasNowait,
+ DynCGroupMem, DynCGroupMemFallback);
return Builder.saveIP();
}
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 03d67a52853f6..ac2d6c93b890e 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -320,7 +320,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
result = todo("depend");
};
auto checkDevice = [&todo](auto op, LogicalResult &result) {
- if (op.getDevice())
+ if (op.getDevice() && !isa<omp::TargetOp>(op))
result = todo("device");
};
auto checkHint = [](auto op, LogicalResult &) {
@@ -5961,6 +5961,15 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
bool isTargetDevice = ompBuilder->Config.isTargetDevice();
bool isGPU = ompBuilder->Config.isGPU();
+ llvm::Value *deviceIDValue = builder.getInt64(llvm::omp::OMP_DEVICEID_UNDEF);
+
+ if (!isTargetDevice) {
+ if (mlir::Value devId = targetOp.getDevice()) {
+ deviceIDValue = moduleTranslation.lookupValue(devId);
+ deviceIDValue =
+ builder.CreateSExtOrTrunc(deviceIDValue, builder.getInt64Ty());
+ }
+ }
auto parentFn = opInst.getParentOfType<LLVM::LLVMFuncOp>();
auto argIface = cast<omp::BlockArgOpenMPOpInterface>(opInst);
@@ -6235,9 +6244,10 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
moduleTranslation.getOpenMPBuilder()->createTarget(
- ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), info, entryInfo,
- defaultAttrs, runtimeAttrs, ifCond, kernelInput, genMapInfoCB, bodyCB,
- argAccessorCB, customMapperCB, dds, targetOp.getNowait());
+ ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), info,
+ deviceIDValue, entryInfo, defaultAttrs, runtimeAttrs, ifCond,
+ kernelInput, genMapInfoCB, bodyCB, argAccessorCB, customMapperCB, dds,
+ targetOp.getNowait());
if (failed(handleError(afterIP, opInst)))
return failure();
diff --git a/mlir/test/Target/LLVMIR/omptarget-device.mlir b/mlir/test/Target/LLVMIR/omptarget-device.mlir
new file mode 100644
index 0000000000000..b4c9744cc0c87
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/omptarget-device.mlir
@@ -0,0 +1,68 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+module attributes {omp.is_target_device = false, omp.target_triples = ["nvptx64-nvidia-cuda"]} {
+ llvm.func @foo(%d16 : i16, %d32 : i32, %d64 : i64) {
+ %x = llvm.mlir.constant(0 : i32) : i32
+
+ // Constant i16 -> i64 in the runtime call.
+ %c1_i16 = llvm.mlir.constant(1 : i16) : i16
+ omp.target device(%c1_i16 : i16)
+ host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) {
+ omp.terminator
+ }
+
+ // Constant i32 -> i64 in the runtime call.
+ %c2_i32 = llvm.mlir.constant(2 : i32) : i32
+ omp.target device(%c2_i32 : i32)
+ host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) {
+ omp.terminator
+ }
+
+ // Constant i64 stays i64 in the runtime call.
+ %c3_i64 = llvm.mlir.constant(3 : i64) : i64
+ omp.target device(%c3_i64 : i64)
+ host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) {
+ omp.terminator
+ }
+
+ // Variable i16 -> cast to i64.
+ omp.target device(%d16 : i16)
+ host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) {
+ omp.terminator
+ }
+
+ // Variable i32 -> cast to i64.
+ omp.target device(%d32 : i32)
+ host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) {
+ omp.terminator
+ }
+
+ // Variable i64 stays i64.
+ omp.target device(%d64 : i64)
+ host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) {
+ omp.terminator
+ }
+
+ llvm.return
+ }
+}
+
+// CHECK-LABEL: define void @foo(i16 %{{.*}}, i32 %{{.*}}, i64 %{{.*}}) {
+// CHECK: br label %entry
+// CHECK: entry:
+
+// ---- Constant cases (device id is 2nd argument) ----
+// CHECK-DAG: call i32 @__tgt_target_kernel(ptr {{.*}}, i64 1, i32 {{.*}}, i32 {{.*}}, ptr {{.*}}, ptr {{.*}})
+// CHECK-DAG: call i32 @__tgt_target_kernel(ptr {{.*}}, i64 2, i32 {{.*}}, i32 {{.*}}, ptr {{.*}}, ptr {{.*}})
+// CHECK-DAG: call i32 @__tgt_target_kernel(ptr {{.*}}, i64 3, i32 {{.*}}, i32 {{.*}}, ptr {{.*}}, ptr {{.*}})
+
+// Variable i16 -> i64
+// CHECK: %[[D16_I64:.*]] = sext i16 %{{.*}} to i64
+// CHECK: call i32 @__tgt_target_kernel(ptr {{.*}}, i64 %[[D16_I64]], i32 {{.*}}, i32 {{.*}}, ptr {{.*}}, ptr {{.*}})
+
+// Variable i32 -> i64
+// CHECK: %[[D32_I64:.*]] = sext i32 %{{.*}} to i64
+// CHECK: call i32 @__tgt_target_kernel(ptr {{.*}}, i64 %[[D32_I64]], i32 {{.*}}, i32 {{.*}}, ptr {{.*}}, ptr {{.*}})
+
+// Variable i64
+// CHECK: call i32 @__tgt_target_kernel(ptr {{.*}}, i64 %{{.*}}, i32 {{.*}}, i32 {{.*}}, ptr {{.*}}, ptr {{.*}})
\ No newline at end of file
diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir
index 396c57af81c44..d4cc9e215de1d 100644
--- a/mlir/test/Target/LLVMIR/openmp-todo.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir
@@ -174,8 +174,6 @@ llvm.func @target_allocate(%x : !llvm.ptr) {
// -----
llvm.func @target_device(%x : i32) {
- // expected-error@below {{not yet implemented: Unhandled clause device in omp.target operation}}
- // expected-error@below {{LLVM Translation failed for operation: omp.target}}
omp.target device(%x : i32) {
omp.terminator
}
|
🐧 Linux x64 Test Results
✅ The build succeeded and all tests passed. |
🪟 Windows x64 Test Results
✅ The build succeeded and all tests passed. |
Add lowering support for the OpenMP `device` clause on the `target` directive in Flang. The device expression is propagated through MLIR OpenMP and passed to the host-side `__tgt_target_kernel` call.
5e71194 to
3519e24
Compare
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
ergawy
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks! Just one small comment about proberly reporting current lowering status.
| }; | ||
| auto checkDevice = [&todo](auto op, LogicalResult &result) { | ||
| if (op.getDevice()) | ||
| if (op.getDevice() && !isa<omp::TargetOp>(op)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This change is not needed. I think instead you should add checkDevice(...) to .Case<omp::TargetEnterDataOp, omp::TargetExitDataOp, omp::TargetUpdateOp> below. Also, we should add a Case<omp::TargetDataOp>.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I removed checkDevice from TargetOp and added it to TargetUpdateOp and TargetDataOp. I did not add checkDevice to TargetEnterDataOp or TargetExitDataOp, because doing so would cause regressions in omptarget-llvm.mlir. That test uses the device clause but hardcodes the device_id to -1 when calling the __tgt_target_enter_data and __tgt_target_exit_data APIs.
tblah
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for implementing this. Please could you update flang/docs/OpenMPSupport.md
e05fb83 to
320c3a9
Compare
- Update OpenMP support for the device clause on the target construct
- Remove checkDevice from the target construct.
- Add checkDevice on target_update and target_data
- Does not add checkDevice on target_enter and target_exit
- Adding checkDevice to target_enter and target_exit would cause a
regression in omptarget-llvm.mlir, since that test currently uses
the device clause in both target_enter and target_exit but expects
-1 to be passed to the __tgt_target_data* APIs.
320c3a9 to
d9313a0
Compare
Done! |
ergawy
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM! Thanks
tblah
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
Add lowering support for the OpenMP
deviceclause on thetargetdirective in Flang.The device expression is propagated through MLIR OpenMP and passed to the host-side
__tgt_target_kernelcall.