Skip to content

Conversation

@chichunchen
Copy link
Contributor

Add lowering support for the OpenMP device clause on the target directive in Flang.

The device expression is propagated through MLIR OpenMP and passed to the host-side __tgt_target_kernel call.

@chichunchen chichunchen requested a review from skatrak December 24, 2025 20:27
@llvmbot llvmbot added mlir:llvm mlir flang Flang issues not falling into any other category mlir:openmp flang:fir-hlfir flang:openmp clang:openmp OpenMP related changes to Clang labels Dec 24, 2025
@llvmbot
Copy link
Member

llvmbot commented Dec 24, 2025

@llvm/pr-subscribers-mlir-openmp
@llvm/pr-subscribers-flang-fir-hlfir

@llvm/pr-subscribers-mlir

Author: Chi-Chun, Chen (chichunchen)

Changes

Add lowering support for the OpenMP device clause on the target directive in Flang.

The device expression is propagated through MLIR OpenMP and passed to the host-side __tgt_target_kernel call.


Full diff: https://github.com/llvm/llvm-project/pull/173509.diff

7 Files Affected:

  • (modified) flang/lib/Lower/OpenMP/OpenMP.cpp (+2-1)
  • (modified) flang/test/Lower/OpenMP/target.f90 (+41)
  • (modified) llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h (+1-1)
  • (modified) llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp (+6-8)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+14-4)
  • (added) mlir/test/Target/LLVMIR/omptarget-device.mlir (+68)
  • (modified) mlir/test/Target/LLVMIR/openmp-todo.mlir (-2)
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 7965119764e5d..4f2b8ef15519c 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -4087,7 +4087,8 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
         !std::holds_alternative<clause::Mergeable>(clause.u) &&
         !std::holds_alternative<clause::Untied>(clause.u) &&
         !std::holds_alternative<clause::TaskReduction>(clause.u) &&
-        !std::holds_alternative<clause::Detach>(clause.u)) {
+        !std::holds_alternative<clause::Detach>(clause.u) &&
+        !std::holds_alternative<clause::Device>(clause.u)) {
       std::string name =
           parser::ToUpperCaseLetters(llvm::omp::getOpenMPClauseName(clause.id));
       if (!semaCtx.langOptions().OpenMPSimd)
diff --git a/flang/test/Lower/OpenMP/target.f90 b/flang/test/Lower/OpenMP/target.f90
index c5d39695e5389..55a6b7a595ed1 100644
--- a/flang/test/Lower/OpenMP/target.f90
+++ b/flang/test/Lower/OpenMP/target.f90
@@ -694,3 +694,44 @@ subroutine target_unstructured
    !$omp end target
    !CHECK: }
 end subroutine target_unstructured
+
+!===============================================================================
+! Target `device` clause
+!===============================================================================
+
+!CHECK-LABEL: func.func @_QPomp_target_device() {
+subroutine omp_target_device
+  integer            :: dev32
+  integer(kind=8)    :: dev64
+  integer(kind=2)    :: dev16
+
+  dev32 = 1
+  dev64 = 2_8
+  dev16 = 3_2
+
+  !$omp target device(dev32)
+  !$omp end target
+  ! CHECK: %[[DEV32:.*]] = fir.load %{{.*}} : !fir.ref<i32>
+  ! CHECK: omp.target device(%[[DEV32]] : i32)
+
+  !$omp target device(dev64)
+  !$omp end target
+  ! CHECK: %[[DEV64:.*]] = fir.load %{{.*}} : !fir.ref<i64>
+  ! CHECK: omp.target device(%[[DEV64]] : i64)
+
+  !$omp target device(dev16)
+  !$omp end target
+  ! CHECK: %[[DEV16:.*]] = fir.load %{{.*}} : !fir.ref<i16>
+  ! CHECK: omp.target device(%[[DEV16]] : i16)
+
+  !$omp target device(2)
+  !$omp end target
+  ! CHECK: %[[C2:.*]] = arith.constant 2 : i32
+  ! CHECK: omp.target device(%[[C2]] : i32)
+
+  !$omp target device(5_8)
+  !$omp end target
+  ! CHECK: %[[C5:.*]] = arith.constant 5 : i64
+  ! CHECK: omp.target device(%[[C5]] : i64)
+
+end subroutine omp_target_device
\ No newline at end of file
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index f5eb6222fd58d..8103a7e9504ea 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -3341,7 +3341,7 @@ class OpenMPIRBuilder {
       const LocationDescription &Loc, bool IsOffloadEntry,
       OpenMPIRBuilder::InsertPointTy AllocaIP,
       OpenMPIRBuilder::InsertPointTy CodeGenIP, TargetDataInfo &Info,
-      TargetRegionEntryInfo &EntryInfo,
+      Value *DeviceID, TargetRegionEntryInfo &EntryInfo,
       const TargetKernelDefaultAttrs &DefaultAttrs,
       const TargetKernelRuntimeAttrs &RuntimeAttrs, Value *IfCond,
       SmallVectorImpl<Value *> &Inputs, GenMapInfoCallbackTy GenMapInfoCB,
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 716f8582dd7b2..3be96350cb058 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -8548,7 +8548,7 @@ Error OpenMPIRBuilder::emitOffloadingArraysAndArgs(
 static void emitTargetCall(
     OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
     OpenMPIRBuilder::InsertPointTy AllocaIP,
-    OpenMPIRBuilder::TargetDataInfo &Info,
+    OpenMPIRBuilder::TargetDataInfo &Info, Value *DeviceID,
     const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
     const OpenMPIRBuilder::TargetKernelRuntimeAttrs &RuntimeAttrs,
     Value *IfCond, Function *OutlinedFn, Constant *OutlinedFnID,
@@ -8680,8 +8680,6 @@ static void emitTargetCall(
     }
 
     unsigned NumTargetItems = Info.NumberOfPtrs;
-    // TODO: Use correct device ID
-    Value *DeviceID = Builder.getInt64(OMP_DEVICEID_UNDEF);
     uint32_t SrcLocStrSize;
     Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
     Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize,
@@ -8740,7 +8738,7 @@ static void emitTargetCall(
 
 OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
     const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP,
-    InsertPointTy CodeGenIP, TargetDataInfo &Info,
+    InsertPointTy CodeGenIP, TargetDataInfo &Info, Value *DeviceID,
     TargetRegionEntryInfo &EntryInfo,
     const TargetKernelDefaultAttrs &DefaultAttrs,
     const TargetKernelRuntimeAttrs &RuntimeAttrs, Value *IfCond,
@@ -8770,10 +8768,10 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
   // to make a remote call (offload) to the previously outlined function
   // that represents the target region. Do that now.
   if (!Config.isTargetDevice())
-    emitTargetCall(*this, Builder, AllocaIP, Info, DefaultAttrs, RuntimeAttrs,
-                   IfCond, OutlinedFn, OutlinedFnID, Inputs, GenMapInfoCB,
-                   CustomMapperCB, Dependencies, HasNowait, DynCGroupMem,
-                   DynCGroupMemFallback);
+    emitTargetCall(*this, Builder, AllocaIP, Info, DeviceID, DefaultAttrs,
+                   RuntimeAttrs, IfCond, OutlinedFn, OutlinedFnID, Inputs,
+                   GenMapInfoCB, CustomMapperCB, Dependencies, HasNowait,
+                   DynCGroupMem, DynCGroupMemFallback);
   return Builder.saveIP();
 }
 
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 03d67a52853f6..ac2d6c93b890e 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -320,7 +320,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
       result = todo("depend");
   };
   auto checkDevice = [&todo](auto op, LogicalResult &result) {
-    if (op.getDevice())
+    if (op.getDevice() && !isa<omp::TargetOp>(op))
       result = todo("device");
   };
   auto checkHint = [](auto op, LogicalResult &) {
@@ -5961,6 +5961,15 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
   llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
   bool isTargetDevice = ompBuilder->Config.isTargetDevice();
   bool isGPU = ompBuilder->Config.isGPU();
+  llvm::Value *deviceIDValue = builder.getInt64(llvm::omp::OMP_DEVICEID_UNDEF);
+
+  if (!isTargetDevice) {
+    if (mlir::Value devId = targetOp.getDevice()) {
+      deviceIDValue = moduleTranslation.lookupValue(devId);
+      deviceIDValue =
+          builder.CreateSExtOrTrunc(deviceIDValue, builder.getInt64Ty());
+    }
+  }
 
   auto parentFn = opInst.getParentOfType<LLVM::LLVMFuncOp>();
   auto argIface = cast<omp::BlockArgOpenMPOpInterface>(opInst);
@@ -6235,9 +6244,10 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
 
   llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
       moduleTranslation.getOpenMPBuilder()->createTarget(
-          ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), info, entryInfo,
-          defaultAttrs, runtimeAttrs, ifCond, kernelInput, genMapInfoCB, bodyCB,
-          argAccessorCB, customMapperCB, dds, targetOp.getNowait());
+          ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), info,
+          deviceIDValue, entryInfo, defaultAttrs, runtimeAttrs, ifCond,
+          kernelInput, genMapInfoCB, bodyCB, argAccessorCB, customMapperCB, dds,
+          targetOp.getNowait());
 
   if (failed(handleError(afterIP, opInst)))
     return failure();
diff --git a/mlir/test/Target/LLVMIR/omptarget-device.mlir b/mlir/test/Target/LLVMIR/omptarget-device.mlir
new file mode 100644
index 0000000000000..b4c9744cc0c87
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/omptarget-device.mlir
@@ -0,0 +1,68 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+module attributes {omp.is_target_device = false, omp.target_triples = ["nvptx64-nvidia-cuda"]} {
+  llvm.func @foo(%d16 : i16, %d32 : i32, %d64 : i64) {
+    %x  = llvm.mlir.constant(0 : i32) : i32
+
+    // Constant i16 -> i64 in the runtime call.
+    %c1_i16 = llvm.mlir.constant(1 : i16) : i16
+    omp.target device(%c1_i16 : i16)
+      host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) {
+      omp.terminator
+    }
+
+    // Constant i32 -> i64 in the runtime call.
+    %c2_i32 = llvm.mlir.constant(2 : i32) : i32
+    omp.target device(%c2_i32 : i32)
+      host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) {
+      omp.terminator
+    }
+
+    // Constant i64 stays i64 in the runtime call.
+    %c3_i64 = llvm.mlir.constant(3 : i64) : i64
+    omp.target device(%c3_i64 : i64)
+      host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) {
+      omp.terminator
+    }
+
+    // Variable i16 -> cast to i64.
+    omp.target device(%d16 : i16)
+      host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) {
+      omp.terminator
+    }
+
+    // Variable i32 -> cast to i64.
+    omp.target device(%d32 : i32)
+      host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) {
+      omp.terminator
+    }
+
+    // Variable i64 stays i64.
+    omp.target device(%d64 : i64)
+      host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) {
+      omp.terminator
+    }
+
+    llvm.return
+  }
+}
+
+// CHECK-LABEL: define void @foo(i16 %{{.*}}, i32 %{{.*}}, i64 %{{.*}}) {
+// CHECK: br label %entry
+// CHECK: entry:
+
+// ---- Constant cases (device id is 2nd argument) ----
+// CHECK-DAG: call i32 @__tgt_target_kernel(ptr {{.*}}, i64 1, i32 {{.*}}, i32 {{.*}}, ptr {{.*}}, ptr {{.*}})
+// CHECK-DAG: call i32 @__tgt_target_kernel(ptr {{.*}}, i64 2, i32 {{.*}}, i32 {{.*}}, ptr {{.*}}, ptr {{.*}})
+// CHECK-DAG: call i32 @__tgt_target_kernel(ptr {{.*}}, i64 3, i32 {{.*}}, i32 {{.*}}, ptr {{.*}}, ptr {{.*}})
+
+// Variable i16 -> i64
+// CHECK: %[[D16_I64:.*]] = sext i16 %{{.*}} to i64
+// CHECK: call i32 @__tgt_target_kernel(ptr {{.*}}, i64 %[[D16_I64]], i32 {{.*}}, i32 {{.*}}, ptr {{.*}}, ptr {{.*}})
+
+// Variable i32 -> i64
+// CHECK: %[[D32_I64:.*]] = sext i32 %{{.*}} to i64
+// CHECK: call i32 @__tgt_target_kernel(ptr {{.*}}, i64 %[[D32_I64]], i32 {{.*}}, i32 {{.*}}, ptr {{.*}}, ptr {{.*}})
+
+// Variable i64
+// CHECK: call i32 @__tgt_target_kernel(ptr {{.*}}, i64 %{{.*}}, i32 {{.*}}, i32 {{.*}}, ptr {{.*}}, ptr {{.*}})
\ No newline at end of file
diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir
index 396c57af81c44..d4cc9e215de1d 100644
--- a/mlir/test/Target/LLVMIR/openmp-todo.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir
@@ -174,8 +174,6 @@ llvm.func @target_allocate(%x : !llvm.ptr) {
 // -----
 
 llvm.func @target_device(%x : i32) {
-  // expected-error@below {{not yet implemented: Unhandled clause device in omp.target operation}}
-  // expected-error@below {{LLVM Translation failed for operation: omp.target}}
   omp.target device(%x : i32) {
     omp.terminator
   }

@llvmbot
Copy link
Member

llvmbot commented Dec 24, 2025

@llvm/pr-subscribers-flang-openmp

Author: Chi-Chun, Chen (chichunchen)

Changes

Add lowering support for the OpenMP device clause on the target directive in Flang.

The device expression is propagated through MLIR OpenMP and passed to the host-side __tgt_target_kernel call.


Full diff: https://github.com/llvm/llvm-project/pull/173509.diff

7 Files Affected:

  • (modified) flang/lib/Lower/OpenMP/OpenMP.cpp (+2-1)
  • (modified) flang/test/Lower/OpenMP/target.f90 (+41)
  • (modified) llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h (+1-1)
  • (modified) llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp (+6-8)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+14-4)
  • (added) mlir/test/Target/LLVMIR/omptarget-device.mlir (+68)
  • (modified) mlir/test/Target/LLVMIR/openmp-todo.mlir (-2)
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 7965119764e5d..4f2b8ef15519c 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -4087,7 +4087,8 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
         !std::holds_alternative<clause::Mergeable>(clause.u) &&
         !std::holds_alternative<clause::Untied>(clause.u) &&
         !std::holds_alternative<clause::TaskReduction>(clause.u) &&
-        !std::holds_alternative<clause::Detach>(clause.u)) {
+        !std::holds_alternative<clause::Detach>(clause.u) &&
+        !std::holds_alternative<clause::Device>(clause.u)) {
       std::string name =
           parser::ToUpperCaseLetters(llvm::omp::getOpenMPClauseName(clause.id));
       if (!semaCtx.langOptions().OpenMPSimd)
diff --git a/flang/test/Lower/OpenMP/target.f90 b/flang/test/Lower/OpenMP/target.f90
index c5d39695e5389..55a6b7a595ed1 100644
--- a/flang/test/Lower/OpenMP/target.f90
+++ b/flang/test/Lower/OpenMP/target.f90
@@ -694,3 +694,44 @@ subroutine target_unstructured
    !$omp end target
    !CHECK: }
 end subroutine target_unstructured
+
+!===============================================================================
+! Target `device` clause
+!===============================================================================
+
+!CHECK-LABEL: func.func @_QPomp_target_device() {
+subroutine omp_target_device
+  integer            :: dev32
+  integer(kind=8)    :: dev64
+  integer(kind=2)    :: dev16
+
+  dev32 = 1
+  dev64 = 2_8
+  dev16 = 3_2
+
+  !$omp target device(dev32)
+  !$omp end target
+  ! CHECK: %[[DEV32:.*]] = fir.load %{{.*}} : !fir.ref<i32>
+  ! CHECK: omp.target device(%[[DEV32]] : i32)
+
+  !$omp target device(dev64)
+  !$omp end target
+  ! CHECK: %[[DEV64:.*]] = fir.load %{{.*}} : !fir.ref<i64>
+  ! CHECK: omp.target device(%[[DEV64]] : i64)
+
+  !$omp target device(dev16)
+  !$omp end target
+  ! CHECK: %[[DEV16:.*]] = fir.load %{{.*}} : !fir.ref<i16>
+  ! CHECK: omp.target device(%[[DEV16]] : i16)
+
+  !$omp target device(2)
+  !$omp end target
+  ! CHECK: %[[C2:.*]] = arith.constant 2 : i32
+  ! CHECK: omp.target device(%[[C2]] : i32)
+
+  !$omp target device(5_8)
+  !$omp end target
+  ! CHECK: %[[C5:.*]] = arith.constant 5 : i64
+  ! CHECK: omp.target device(%[[C5]] : i64)
+
+end subroutine omp_target_device
\ No newline at end of file
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index f5eb6222fd58d..8103a7e9504ea 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -3341,7 +3341,7 @@ class OpenMPIRBuilder {
       const LocationDescription &Loc, bool IsOffloadEntry,
       OpenMPIRBuilder::InsertPointTy AllocaIP,
       OpenMPIRBuilder::InsertPointTy CodeGenIP, TargetDataInfo &Info,
-      TargetRegionEntryInfo &EntryInfo,
+      Value *DeviceID, TargetRegionEntryInfo &EntryInfo,
       const TargetKernelDefaultAttrs &DefaultAttrs,
       const TargetKernelRuntimeAttrs &RuntimeAttrs, Value *IfCond,
       SmallVectorImpl<Value *> &Inputs, GenMapInfoCallbackTy GenMapInfoCB,
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 716f8582dd7b2..3be96350cb058 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -8548,7 +8548,7 @@ Error OpenMPIRBuilder::emitOffloadingArraysAndArgs(
 static void emitTargetCall(
     OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
     OpenMPIRBuilder::InsertPointTy AllocaIP,
-    OpenMPIRBuilder::TargetDataInfo &Info,
+    OpenMPIRBuilder::TargetDataInfo &Info, Value *DeviceID,
     const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
     const OpenMPIRBuilder::TargetKernelRuntimeAttrs &RuntimeAttrs,
     Value *IfCond, Function *OutlinedFn, Constant *OutlinedFnID,
@@ -8680,8 +8680,6 @@ static void emitTargetCall(
     }
 
     unsigned NumTargetItems = Info.NumberOfPtrs;
-    // TODO: Use correct device ID
-    Value *DeviceID = Builder.getInt64(OMP_DEVICEID_UNDEF);
     uint32_t SrcLocStrSize;
     Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
     Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize,
@@ -8740,7 +8738,7 @@ static void emitTargetCall(
 
 OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
     const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP,
-    InsertPointTy CodeGenIP, TargetDataInfo &Info,
+    InsertPointTy CodeGenIP, TargetDataInfo &Info, Value *DeviceID,
     TargetRegionEntryInfo &EntryInfo,
     const TargetKernelDefaultAttrs &DefaultAttrs,
     const TargetKernelRuntimeAttrs &RuntimeAttrs, Value *IfCond,
@@ -8770,10 +8768,10 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
   // to make a remote call (offload) to the previously outlined function
   // that represents the target region. Do that now.
   if (!Config.isTargetDevice())
-    emitTargetCall(*this, Builder, AllocaIP, Info, DefaultAttrs, RuntimeAttrs,
-                   IfCond, OutlinedFn, OutlinedFnID, Inputs, GenMapInfoCB,
-                   CustomMapperCB, Dependencies, HasNowait, DynCGroupMem,
-                   DynCGroupMemFallback);
+    emitTargetCall(*this, Builder, AllocaIP, Info, DeviceID, DefaultAttrs,
+                   RuntimeAttrs, IfCond, OutlinedFn, OutlinedFnID, Inputs,
+                   GenMapInfoCB, CustomMapperCB, Dependencies, HasNowait,
+                   DynCGroupMem, DynCGroupMemFallback);
   return Builder.saveIP();
 }
 
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 03d67a52853f6..ac2d6c93b890e 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -320,7 +320,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
       result = todo("depend");
   };
   auto checkDevice = [&todo](auto op, LogicalResult &result) {
-    if (op.getDevice())
+    if (op.getDevice() && !isa<omp::TargetOp>(op))
       result = todo("device");
   };
   auto checkHint = [](auto op, LogicalResult &) {
@@ -5961,6 +5961,15 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
   llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
   bool isTargetDevice = ompBuilder->Config.isTargetDevice();
   bool isGPU = ompBuilder->Config.isGPU();
+  llvm::Value *deviceIDValue = builder.getInt64(llvm::omp::OMP_DEVICEID_UNDEF);
+
+  if (!isTargetDevice) {
+    if (mlir::Value devId = targetOp.getDevice()) {
+      deviceIDValue = moduleTranslation.lookupValue(devId);
+      deviceIDValue =
+          builder.CreateSExtOrTrunc(deviceIDValue, builder.getInt64Ty());
+    }
+  }
 
   auto parentFn = opInst.getParentOfType<LLVM::LLVMFuncOp>();
   auto argIface = cast<omp::BlockArgOpenMPOpInterface>(opInst);
@@ -6235,9 +6244,10 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
 
   llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
       moduleTranslation.getOpenMPBuilder()->createTarget(
-          ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), info, entryInfo,
-          defaultAttrs, runtimeAttrs, ifCond, kernelInput, genMapInfoCB, bodyCB,
-          argAccessorCB, customMapperCB, dds, targetOp.getNowait());
+          ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), info,
+          deviceIDValue, entryInfo, defaultAttrs, runtimeAttrs, ifCond,
+          kernelInput, genMapInfoCB, bodyCB, argAccessorCB, customMapperCB, dds,
+          targetOp.getNowait());
 
   if (failed(handleError(afterIP, opInst)))
     return failure();
diff --git a/mlir/test/Target/LLVMIR/omptarget-device.mlir b/mlir/test/Target/LLVMIR/omptarget-device.mlir
new file mode 100644
index 0000000000000..b4c9744cc0c87
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/omptarget-device.mlir
@@ -0,0 +1,68 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+module attributes {omp.is_target_device = false, omp.target_triples = ["nvptx64-nvidia-cuda"]} {
+  llvm.func @foo(%d16 : i16, %d32 : i32, %d64 : i64) {
+    %x  = llvm.mlir.constant(0 : i32) : i32
+
+    // Constant i16 -> i64 in the runtime call.
+    %c1_i16 = llvm.mlir.constant(1 : i16) : i16
+    omp.target device(%c1_i16 : i16)
+      host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) {
+      omp.terminator
+    }
+
+    // Constant i32 -> i64 in the runtime call.
+    %c2_i32 = llvm.mlir.constant(2 : i32) : i32
+    omp.target device(%c2_i32 : i32)
+      host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) {
+      omp.terminator
+    }
+
+    // Constant i64 stays i64 in the runtime call.
+    %c3_i64 = llvm.mlir.constant(3 : i64) : i64
+    omp.target device(%c3_i64 : i64)
+      host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) {
+      omp.terminator
+    }
+
+    // Variable i16 -> cast to i64.
+    omp.target device(%d16 : i16)
+      host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) {
+      omp.terminator
+    }
+
+    // Variable i32 -> cast to i64.
+    omp.target device(%d32 : i32)
+      host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) {
+      omp.terminator
+    }
+
+    // Variable i64 stays i64.
+    omp.target device(%d64 : i64)
+      host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) {
+      omp.terminator
+    }
+
+    llvm.return
+  }
+}
+
+// CHECK-LABEL: define void @foo(i16 %{{.*}}, i32 %{{.*}}, i64 %{{.*}}) {
+// CHECK: br label %entry
+// CHECK: entry:
+
+// ---- Constant cases (device id is 2nd argument) ----
+// CHECK-DAG: call i32 @__tgt_target_kernel(ptr {{.*}}, i64 1, i32 {{.*}}, i32 {{.*}}, ptr {{.*}}, ptr {{.*}})
+// CHECK-DAG: call i32 @__tgt_target_kernel(ptr {{.*}}, i64 2, i32 {{.*}}, i32 {{.*}}, ptr {{.*}}, ptr {{.*}})
+// CHECK-DAG: call i32 @__tgt_target_kernel(ptr {{.*}}, i64 3, i32 {{.*}}, i32 {{.*}}, ptr {{.*}}, ptr {{.*}})
+
+// Variable i16 -> i64
+// CHECK: %[[D16_I64:.*]] = sext i16 %{{.*}} to i64
+// CHECK: call i32 @__tgt_target_kernel(ptr {{.*}}, i64 %[[D16_I64]], i32 {{.*}}, i32 {{.*}}, ptr {{.*}}, ptr {{.*}})
+
+// Variable i32 -> i64
+// CHECK: %[[D32_I64:.*]] = sext i32 %{{.*}} to i64
+// CHECK: call i32 @__tgt_target_kernel(ptr {{.*}}, i64 %[[D32_I64]], i32 {{.*}}, i32 {{.*}}, ptr {{.*}}, ptr {{.*}})
+
+// Variable i64
+// CHECK: call i32 @__tgt_target_kernel(ptr {{.*}}, i64 %{{.*}}, i32 {{.*}}, i32 {{.*}}, ptr {{.*}}, ptr {{.*}})
\ No newline at end of file
diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir
index 396c57af81c44..d4cc9e215de1d 100644
--- a/mlir/test/Target/LLVMIR/openmp-todo.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir
@@ -174,8 +174,6 @@ llvm.func @target_allocate(%x : !llvm.ptr) {
 // -----
 
 llvm.func @target_device(%x : i32) {
-  // expected-error@below {{not yet implemented: Unhandled clause device in omp.target operation}}
-  // expected-error@below {{LLVM Translation failed for operation: omp.target}}
   omp.target device(%x : i32) {
     omp.terminator
   }

@github-actions
Copy link

github-actions bot commented Dec 24, 2025

🐧 Linux x64 Test Results

  • 188314 tests passed
  • 4984 tests skipped

✅ The build succeeded and all tests passed.

@github-actions
Copy link

github-actions bot commented Dec 24, 2025

🪟 Windows x64 Test Results

  • 132904 tests passed
  • 3085 tests skipped

✅ The build succeeded and all tests passed.

Add lowering support for the OpenMP `device` clause on the `target`
directive in Flang.

The device expression is propagated through MLIR OpenMP and passed to
the host-side `__tgt_target_kernel` call.
@chichunchen chichunchen force-pushed the cchen/rfe/flang_device_clause branch from 5e71194 to 3519e24 Compare January 2, 2026 22:28
@github-actions
Copy link

github-actions bot commented Jan 2, 2026

✅ With the latest revision this PR passed the C/C++ code formatter.

@chichunchen chichunchen requested review from ergawy and tblah January 4, 2026 23:47
Copy link
Member

@ergawy ergawy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Just one small comment about proberly reporting current lowering status.

};
auto checkDevice = [&todo](auto op, LogicalResult &result) {
if (op.getDevice())
if (op.getDevice() && !isa<omp::TargetOp>(op))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change is not needed. I think instead you should add checkDevice(...) to .Case<omp::TargetEnterDataOp, omp::TargetExitDataOp, omp::TargetUpdateOp> below. Also, we should add a Case<omp::TargetDataOp>.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I removed checkDevice from TargetOp and added it to TargetUpdateOp and TargetDataOp. I did not add checkDevice to TargetEnterDataOp or TargetExitDataOp, because doing so would cause regressions in omptarget-llvm.mlir. That test uses the device clause but hardcodes the device_id to -1 when calling the __tgt_target_enter_data and __tgt_target_exit_data APIs.

Copy link
Contributor

@tblah tblah left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for implementing this. Please could you update flang/docs/OpenMPSupport.md

@chichunchen chichunchen force-pushed the cchen/rfe/flang_device_clause branch from e05fb83 to 320c3a9 Compare January 5, 2026 18:17
@chichunchen chichunchen requested review from ergawy and tblah January 5, 2026 18:26
- Update OpenMP support for the device clause on the target construct
- Remove checkDevice from the target construct.
- Add checkDevice on target_update and target_data
- Does not add checkDevice on target_enter and target_exit
    - Adding checkDevice to target_enter and target_exit would cause a
      regression in omptarget-llvm.mlir, since that test currently uses
      the device clause in both target_enter and target_exit but expects
      -1 to be passed to the __tgt_target_data* APIs.
@chichunchen chichunchen force-pushed the cchen/rfe/flang_device_clause branch from 320c3a9 to d9313a0 Compare January 5, 2026 18:27
@chichunchen
Copy link
Contributor Author

Thank you for implementing this. Please could you update flang/docs/OpenMPSupport.md

Done!

Copy link
Member

@ergawy ergawy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Thanks

Copy link
Contributor

@tblah tblah left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

clang:openmp OpenMP related changes to Clang flang:fir-hlfir flang:openmp flang Flang issues not falling into any other category mlir:llvm mlir:openmp mlir

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants