[Flang][OpenMP] Don't generate code for unreachable target regions.#178937
Conversation
When a target region is placed inside a constant false condition (e.g.,
`if (.false.)`), MLIR's Canonicalizer pass correctly eliminates the dead
code on the host side, removing the `omp.target` operation entirely.
However, the device-side compilation pipeline is unaware of this
elimination and attempts to generate kernel code. Since the host never
created offload metadata for the eliminated target, the device-side
kernel function lacks the "kernel" attribute, causing OpenMPOpt to fail
with an assertion when it expects all outlined kernels to have this
attribute. The problem can be seen with the following code:
```fortran
program cele
implicit none
real :: V
integer :: i
if (.false.) then
!$omp target teams distribute parallel do
do i = 1, 5
V = V * 2
end do
!$omp end target teams distribute parallel do
end if
end program
```
It currently fails with the follwoing assertion:
```
Assertion `omp::isOpenMPKernel(*Kernel) && "Expected kernel function!"' failed.
llvm/lib/Transforms/IPO/OpenMPOpt.cpp:4291
```
This PR adds MarkUnreachableTargetsPass that identifies `omp.target`
operations in unreachable code blocks and marks them with
`omp.target_unreachable` attribute. This attribute is later used in
FunctionFilteringPass and in OpenMPToLLVMIRTranslation to prevent
generation of code for such op.
|
@llvm/pr-subscribers-mlir-openmp @llvm/pr-subscribers-flang-fir-hlfir Author: Abid Qadeer (abidh) ChangesWhen a target region is placed inside a constant false condition (e.g., program cele
implicit none
real :: V
integer :: i
if (.false.) then
!$omp target teams distribute parallel do
do i = 1, 5
V = V * 2
end do
!$omp end target teams distribute parallel do
end if
end programIt currently fails with the following assertion: This PR adds Patch is 21.50 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/178937.diff 9 Files Affected:
diff --git a/flang/include/flang/Optimizer/OpenMP/Passes.td b/flang/include/flang/Optimizer/OpenMP/Passes.td
index f17b1e3794908..793d575b7da8e 100644
--- a/flang/include/flang/Optimizer/OpenMP/Passes.td
+++ b/flang/include/flang/Optimizer/OpenMP/Passes.td
@@ -41,6 +41,18 @@ def MarkDeclareTargetPass
let dependentDialects = ["mlir::omp::OpenMPDialect"];
}
+def MarkUnreachableTargetsPass
+ : Pass<"omp-mark-unreachable-targets", "mlir::ModuleOp"> {
+ let summary = "Marks OpenMP target operations in unreachable code";
+ let description = [{
+ Identifies OpenMP target operations that reside in unreachable code
+ (e.g., inside if(.false.) blocks) and marks them with an attribute.
+ This allows device compilation to skip generating code for targets
+ that were eliminated on the host side.
+ }];
+ let dependentDialects = ["mlir::omp::OpenMPDialect"];
+}
+
def FunctionFilteringPass : Pass<"omp-function-filtering"> {
let summary = "Filters out functions intended for the host when compiling "
"for the target device.";
diff --git a/flang/lib/Optimizer/OpenMP/CMakeLists.txt b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
index 23a7dc8f08399..136b9d9ea9313 100644
--- a/flang/lib/Optimizer/OpenMP/CMakeLists.txt
+++ b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
@@ -8,6 +8,7 @@ add_flang_library(FlangOpenMPTransforms
MapsForPrivatizedSymbols.cpp
MapInfoFinalization.cpp
MarkDeclareTarget.cpp
+ MarkUnreachableTargets.cpp
LowerWorkdistribute.cpp
LowerWorkshare.cpp
LowerNontemporal.cpp
diff --git a/flang/lib/Optimizer/OpenMP/FunctionFiltering.cpp b/flang/lib/Optimizer/OpenMP/FunctionFiltering.cpp
index 0acee8991e372..f7828df21182e 100644
--- a/flang/lib/Optimizer/OpenMP/FunctionFiltering.cpp
+++ b/flang/lib/Optimizer/OpenMP/FunctionFiltering.cpp
@@ -120,10 +120,15 @@ class FunctionFilteringPass
// Do not filter functions with target regions inside, because they have
// to be available for both host and device so that regular and reverse
// offloading can be supported.
+ // However, skip target regions marked as unreachable.
bool hasTargetRegion =
funcOp
- ->walk<WalkOrder::PreOrder>(
- [&](omp::TargetOp) { return WalkResult::interrupt(); })
+ ->walk<WalkOrder::PreOrder>([&](omp::TargetOp targetOp) {
+ // Skip targets marked as unreachable
+ if (targetOp->hasAttr("omp.target_unreachable"))
+ return WalkResult::advance();
+ return WalkResult::interrupt();
+ })
.wasInterrupted();
omp::DeclareTargetDeviceType declareType =
diff --git a/flang/lib/Optimizer/OpenMP/MarkUnreachableTargets.cpp b/flang/lib/Optimizer/OpenMP/MarkUnreachableTargets.cpp
new file mode 100644
index 0000000000000..06423f32a77f0
--- /dev/null
+++ b/flang/lib/Optimizer/OpenMP/MarkUnreachableTargets.cpp
@@ -0,0 +1,166 @@
+//===- MarkUnreachableTargets.cpp ----------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass marks OpenMP target operations that are in unreachable code
+// with an attribute. This allows device compilation to skip generating code
+// for such ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/OpenMP/Passes.h"
+
+#include "flang/Optimizer/Dialect/FIRDialect.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/SmallSet.h"
+
+namespace flangomp {
+#define GEN_PASS_DEF_MARKUNREACHABLETARGETSPASS
+#include "flang/Optimizer/OpenMP/Passes.h.inc"
+} // namespace flangomp
+
+using namespace mlir;
+
+namespace {
+
+/// Check if an operation is nested inside a fir.if with a constant false
+/// condition.
+static bool isInUnreachableIfBlock(Operation *op) {
+ Operation *current = op;
+
+ // Walk up through parent operations
+ while (current) {
+ Operation *parentOp = current->getParentOp();
+ if (!parentOp)
+ break;
+
+ // Check for fir.if with constant false condition
+ if (auto firIf = dyn_cast<fir::IfOp>(parentOp)) {
+ if (auto constOp =
+ firIf.getCondition().getDefiningOp<arith::ConstantOp>()) {
+ if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue())) {
+ // If condition is false (0) and op is in the "then" region
+ if (intAttr.getInt() == 0 &&
+ current->getParentRegion() == &firIf.getThenRegion())
+ return true;
+ // If condition is true (non-zero) and op is in the "else" region
+ if (intAttr.getInt() != 0 && !firIf.getElseRegion().empty() &&
+ current->getParentRegion() == &firIf.getElseRegion())
+ return true;
+ }
+ }
+ }
+
+ current = parentOp;
+ }
+
+ return false;
+}
+
+/// Check if a block is unreachable due to constant condition branches.
+/// A block is unreachable only if ALL predecessors lead to it through
+/// unreachable paths (i.e., constant false conditions).
+/// This handles patterns like:
+/// %false = arith.constant false
+/// cf.cond_br %false, ^bb1, ^bb2
+/// where ^bb1 is unreachable.
+static bool isBlockUnreachable(Block *block) {
+ // Entry blocks and blocks with no predecessors are reachable
+ if (block->hasNoPredecessors())
+ return false;
+
+ // Check all predecessors - block is unreachable only if ALL paths are
+ // provably unreachable via constant conditions
+ for (Block *pred : block->getPredecessors()) {
+ Operation *terminator = pred->getTerminator();
+
+ // Check if this is a cf.cond_br with constant condition
+ if (auto condBr = dyn_cast<cf::CondBranchOp>(terminator)) {
+ // Try to get the constant value of the condition
+ if (auto constOp =
+ condBr.getCondition().getDefiningOp<arith::ConstantOp>()) {
+ if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue())) {
+ bool condIsTrue = intAttr.getInt() != 0;
+
+ // If condition is false and block is the true destination,
+ // this path is unreachable - continue checking other predecessors
+ if (!condIsTrue && block == condBr.getTrueDest())
+ continue;
+ // If condition is true and block is the false destination,
+ // this path is unreachable - continue checking other predecessors
+ if (condIsTrue && block == condBr.getFalseDest())
+ continue;
+ // Otherwise, this path IS reachable (condition matches destination)
+ return false;
+ }
+ }
+ }
+
+ // If we reach here, this predecessor either:
+ // - is not a CondBranchOp, OR
+ // - doesn't have a constant condition
+ // Either way, this path could be taken, so block is reachable
+ return false;
+ }
+
+ // All predecessors lead to this block through unreachable paths
+ return true;
+}
+
+/// Recursively check if an operation is in an unreachable block.
+/// This walks up the block hierarchy to check if any containing block
+/// is unreachable, handling both fir.if and cf.cond_br patterns.
+static bool isOperationUnreachable(Operation *op) {
+ // First check for fir.if patterns (before SCF lowering)
+ if (isInUnreachableIfBlock(op))
+ return true;
+
+ // Then check for cf.cond_br patterns (after SCF lowering)
+ Block *currentBlock = op->getBlock();
+
+ // Walk up through nested regions checking each block
+ while (currentBlock) {
+ if (isBlockUnreachable(currentBlock))
+ return true;
+
+ // Move to parent operation's block
+ Operation *parentOp = currentBlock->getParentOp();
+ if (!parentOp || isa<ModuleOp>(parentOp) || isa<func::FuncOp>(parentOp))
+ break;
+
+ currentBlock = parentOp->getBlock();
+ }
+
+ return false;
+}
+
+class MarkUnreachableTargetsPass
+ : public flangomp::impl::MarkUnreachableTargetsPassBase<
+ MarkUnreachableTargetsPass> {
+public:
+ MarkUnreachableTargetsPass() = default;
+
+ void runOnOperation() override {
+ MLIRContext *context = &getContext();
+ auto module = getOperation();
+
+ // Walk all target operations and mark those that are unreachable
+ module.walk([&](omp::TargetOp targetOp) {
+ if (isOperationUnreachable(targetOp.getOperation()))
+ targetOp->setAttr("omp.target_unreachable", UnitAttr::get(context));
+ });
+ }
+};
+
+} // namespace
diff --git a/flang/lib/Optimizer/Passes/Pipelines.cpp b/flang/lib/Optimizer/Passes/Pipelines.cpp
index 6054675643c64..a73a3f9d2325c 100644
--- a/flang/lib/Optimizer/Passes/Pipelines.cpp
+++ b/flang/lib/Optimizer/Passes/Pipelines.cpp
@@ -341,6 +341,11 @@ void createOpenMPFIRPassPipeline(mlir::PassManager &pm,
pm.addPass(flangomp::createAutomapToTargetDataPass());
pm.addPass(flangomp::createMapInfoFinalizationPass());
pm.addPass(flangomp::createMarkDeclareTargetPass());
+
+ // Mark unreachable target operations before FunctionFilteringPass
+ // extracts them.
+ pm.addPass(flangomp::createMarkUnreachableTargetsPass());
+
pm.addPass(flangomp::createGenericLoopConversionPass());
if (opts.isTargetDevice)
pm.addPass(flangomp::createFunctionFilteringPass());
diff --git a/flang/test/Lower/OpenMP/target-dead-code.f90 b/flang/test/Lower/OpenMP/target-dead-code.f90
new file mode 100644
index 0000000000000..82932ca82858f
--- /dev/null
+++ b/flang/test/Lower/OpenMP/target-dead-code.f90
@@ -0,0 +1,82 @@
+! RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s --check-prefix=FIR
+
+! Test that OpenMP target regions in dead code are marked for elimination
+
+! Test 1: if (.false.) with target - should be marked unreachable
+! FIR-LABEL: func.func @_QPtest_dead_simple
+! FIR: %[[FALSE:.*]] = arith.constant false
+! FIR: fir.if %[[FALSE]] {
+! FIR: omp.target
+! FIR: } {omp.target_unreachable}
+subroutine test_dead_simple()
+ real :: v
+ if (.false.) then
+ !$omp target map(tofrom:v)
+ v = 1.0
+ !$omp end target
+ end if
+end subroutine
+
+! Test 2: Live target - should NOT be marked
+! FIR-LABEL: func.func @_QPtest_live_simple
+! FIR: omp.target
+! FIR-NOT: omp.target_unreachable
+subroutine test_live_simple()
+ real :: v
+ !$omp target map(tofrom:v)
+ v = 2.0
+ !$omp end target
+end subroutine
+
+! Test 3: Mixed dead and live
+! FIR-LABEL: func.func @_QPtest_mixed
+subroutine test_mixed()
+ real :: v
+ ! Dead - should be marked
+ ! FIR: fir.if %{{.*}} {
+ if (.false.) then
+ !$omp target map(tofrom:v)
+ ! FIR: omp.target
+ ! FIR: } {omp.target_unreachable}
+ v = 3.0
+ !$omp end target
+ end if
+ ! Live - should NOT be marked
+ !$omp target map(tofrom:v)
+ ! FIR: omp.target
+ ! FIR-NOT: omp.target_unreachable
+ v = 4.0
+ !$omp end target
+end subroutine
+
+! Test 4: Nested - outer false
+! FIR-LABEL: func.func @_QPtest_nested_outer_false
+subroutine test_nested_outer_false()
+ real :: v
+ ! FIR: fir.if %{{.*}} {
+ if (.false.) then
+ ! FIR: fir.if %{{.*}} {
+ if (.true.) then
+ ! FIR: omp.target
+ ! FIR: } {omp.target_unreachable}
+ !$omp target map(tofrom:v)
+ v = 5.0
+ !$omp end target
+ end if
+ end if
+end subroutine
+
+! Test 5: Parameter constant
+! FIR-LABEL: func.func @_QPtest_parameter
+subroutine test_parameter()
+ real :: v
+ logical, parameter :: DEAD = .false.
+ ! FIR: fir.if %{{.*}} {
+ if (DEAD) then
+ ! FIR: omp.target
+ ! FIR: } {omp.target_unreachable}
+ !$omp target map(tofrom:v)
+ v = 6.0
+ !$omp end target
+ end if
+end subroutine
diff --git a/flang/test/Transforms/OpenMP/mark-unreachable-targets.mlir b/flang/test/Transforms/OpenMP/mark-unreachable-targets.mlir
new file mode 100644
index 0000000000000..66f7e607a65fd
--- /dev/null
+++ b/flang/test/Transforms/OpenMP/mark-unreachable-targets.mlir
@@ -0,0 +1,331 @@
+// RUN: fir-opt --omp-mark-unreachable-targets %s | FileCheck %s
+
+// CHECK-LABEL: func.func @test_if_false_simple
+func.func @test_if_false_simple() {
+ %false = arith.constant false
+ // CHECK: fir.if %{{.*}} {
+ fir.if %false {
+ // CHECK: omp.target
+ // CHECK-NEXT: omp.terminator
+ // CHECK-NEXT: } {omp.target_unreachable}
+ omp.target {
+ omp.terminator
+ }
+ }
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_if_true_simple
+func.func @test_if_true_simple() {
+ %true = arith.constant true
+ // CHECK: fir.if %{{.*}} {
+ fir.if %true {
+ // CHECK: omp.target {
+ // CHECK-NEXT: omp.terminator
+ // CHECK-NEXT: }
+ // CHECK-NOT: omp.target_unreachable
+ omp.target {
+ omp.terminator
+ }
+ }
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_nested_outer_false
+func.func @test_nested_outer_false() {
+ %false = arith.constant false
+ %true = arith.constant true
+ // CHECK: fir.if %{{.*}} {
+ fir.if %false {
+ // CHECK: fir.if %{{.*}} {
+ fir.if %true {
+ // CHECK: omp.target
+ // CHECK-NEXT: omp.terminator
+ // CHECK-NEXT: } {omp.target_unreachable}
+ omp.target {
+ omp.terminator
+ }
+ }
+ }
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_nested_inner_false
+func.func @test_nested_inner_false() {
+ %false = arith.constant false
+ %true = arith.constant true
+ // CHECK: fir.if %{{.*}} {
+ fir.if %true {
+ // CHECK: fir.if %{{.*}} {
+ fir.if %false {
+ // CHECK: omp.target
+ // CHECK: } {omp.target_unreachable}
+ omp.target {
+ omp.terminator
+ }
+ }
+ }
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_nested_both_true
+func.func @test_nested_both_true() {
+ %true1 = arith.constant true
+ %true2 = arith.constant true
+ // CHECK: fir.if %{{.*}} {
+ fir.if %true1 {
+ // CHECK: fir.if %{{.*}} {
+ fir.if %true2 {
+ // CHECK: omp.target {
+ // CHECK-NEXT: omp.terminator
+ // CHECK-NEXT: }
+ // CHECK-NOT: omp.target_unreachable
+ omp.target {
+ omp.terminator
+ }
+ }
+ }
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_mixed_targets
+func.func @test_mixed_targets() {
+ %false = arith.constant false
+ %true = arith.constant true
+
+ // Dead target
+ // CHECK: fir.if %{{.*}} {
+ fir.if %false {
+ // CHECK: omp.target
+ // CHECK: } {omp.target_unreachable}
+ omp.target {
+ omp.terminator
+ }
+ }
+
+ // Live target - should NOT have unreachable attribute
+ // CHECK: omp.target {
+ // CHECK-NEXT: omp.terminator
+ // CHECK-NEXT: }
+ // CHECK-NOT: omp.target_unreachable
+ omp.target {
+ omp.terminator
+ }
+
+ // Another live target in if (true)
+ // CHECK: fir.if %{{.*}} {
+ fir.if %true {
+ // CHECK: omp.target {
+ // CHECK-NEXT: omp.terminator
+ // CHECK-NEXT: }
+ // CHECK-NOT: omp.target_unreachable
+ omp.target {
+ omp.terminator
+ }
+ }
+
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_multiple_dead_targets
+func.func @test_multiple_dead_targets() {
+ %false = arith.constant false
+
+ // CHECK: fir.if %{{.*}} {
+ fir.if %false {
+ // CHECK: omp.target
+ // CHECK: } {omp.target_unreachable}
+ omp.target {
+ omp.terminator
+ }
+
+ // CHECK: omp.target
+ // CHECK: } {omp.target_unreachable}
+ omp.target {
+ omp.terminator
+ }
+
+ // CHECK: omp.target
+ // CHECK: } {omp.target_unreachable}
+ omp.target {
+ omp.terminator
+ }
+ }
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_if_else_false
+func.func @test_if_else_false() {
+ %false = arith.constant false
+
+ // CHECK: fir.if %{{.*}} {
+ fir.if %false {
+ // CHECK: omp.target
+ // CHECK: } {omp.target_unreachable}
+ omp.target {
+ omp.terminator
+ }
+ } else {
+ // Else branch should not be marked (it's reachable)
+ // CHECK: omp.target {
+ // CHECK-NEXT: omp.terminator
+ // CHECK-NEXT: }
+ // CHECK-NOT: omp.target_unreachable
+ omp.target {
+ omp.terminator
+ }
+ }
+ return
+}
+
+// -----
+
+// Test with cf.cond_br
+// CHECK-LABEL: func.func @test_cf_cond_br_false
+func.func @test_cf_cond_br_false() {
+ %false = arith.constant false
+ // CHECK: cf.cond_br %{{.*}}, ^bb1, ^bb2
+ cf.cond_br %false, ^bb1, ^bb2
+^bb1:
+ // CHECK: omp.target
+ // CHECK: } {omp.target_unreachable}
+ omp.target {
+ omp.terminator
+ }
+ cf.br ^bb2
+^bb2:
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_cf_cond_br_true
+func.func @test_cf_cond_br_true() {
+ %true = arith.constant true
+ // CHECK: cf.cond_br %{{.*}}, ^bb1, ^bb2
+ cf.cond_br %true, ^bb1, ^bb2
+^bb1:
+ // CHECK: omp.target {
+ // CHECK-NEXT: omp.terminator
+ // CHECK-NEXT: }
+ // CHECK-NOT: omp.target_unreachable
+ omp.target {
+ omp.terminator
+ }
+ cf.br ^bb2
+^bb2:
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_runtime_condition
+func.func @test_runtime_condition(%arg0: i1) {
+ // CHECK: fir.if %arg0 {
+ fir.if %arg0 {
+ // Runtime condition - should NOT be marked
+ // CHECK: omp.target {
+ // CHECK-NEXT: omp.terminator
+ // CHECK-NEXT: }
+ // CHECK-NOT: omp.target_unreachable
+ omp.target {
+ omp.terminator
+ }
+ }
+ return
+}
+
+// -----
+
+// Test for multiple predecessors - one reachable, one unreachable
+// The block should NOT be marked as unreachable if ANY path is reachable
+// CHECK-LABEL: func.func @test_multiple_predecessors
+func.func @test_multiple_predecessors() {
+ %false = arith.constant false
+ cf.cond_br %false, ^bb2, ^bb1
+^bb1:
+ // Reachable path to bb2
+ cf.br ^bb2
+^bb2:
+ // This block has two predecessors:
+ // - bb0 with false condition (unreachable path)
+ // - bb1 with unconditional branch (reachable path)
+ // Target should NOT be marked unreachable because bb1 provides a reachable path
+ // CHECK: omp.target {
+ // CHECK-NEXT: omp.terminator
+ // CHECK-NEXT: }
+ // CHECK-NOT: omp.target_unreachable
+ omp.target {
+ omp.terminator
+ }
+ return
+}
+
+// -----
+
+// Test for multiple predecessors - ALL unreachable
+// CHECK-LABEL: func.func @test_multiple_predecessors_all_unreachable
+func.func @test_multiple_predecessors_all_unreachable() {
+ %false1 = arith.constant false
+ %false2 = arith.constant false
+ cf.cond_br %false1, ^bb3, ^bb1
+^bb1:
+ cf.cond_br %false2, ^bb3, ^bb2
+^bb2:
+ cf.br ^bb4
+^bb3:
+ // This block has two predecessors:
+ // - bb0 with false condition to bb3 (unreachable)
+ // - bb1 with false condition to bb3 (unreachable)
+ // Target SHOULD be marked unreachable because ALL paths are unreachable
+ // CHECK: omp.target
+ // CHECK-NEXT: omp.terminator
+ // CHECK-NEXT: } {omp.target_unreachable}
+ omp.target {
+ omp.terminator
+ }
+ cf.br ^bb4
+^bb4:
+ return
+}
+
+// -----
+
+// Test for multiple predecessors with mixed constant and runtime conditions
+// CHECK-LABEL: func.func @test_multiple_predecessors_mixed
+func.func @test_multiple_predecessors_mixed(%arg0: i1) {
+ %false = arith.constant false
+ cf.cond_br %false, ^bb2, ^bb1
+^bb1:
+ // Runtime condition - could branch to bb2
+ cf.cond_br %arg0, ^bb2, ^bb3
+^bb2:
+ // This block has two predecessors:
+ // - bb0 with false condition (unreachable)
+ // - bb1 with runtime condition (potentially reachable)
+ // Target should NOT be marked because we can't prove bb1 path is unreachable
+ // CHECK: omp.target {
+ // CHECK-NEXT: omp.terminator
+ // CHECK-NEXT: }
+ // CHECK-NOT: omp.target_unreachable
+ omp.target {
+ omp.terminator
+ }
+ cf.br ^bb3
+^bb3:
+ return
+}
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index f04d614633965..acb6145628799 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -6370,6 +6370,11 @@ static LogicalResult
convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &modul...
[truncated]
|
|
@llvm/pr-subscribers-mlir Author: Abid Qadeer (abidh) ChangesWhen a target region is placed inside a constant false condition (e.g., program cele
implicit none
real :: V
integer :: i
if (.false.) then
!$omp target teams distribute parallel do
do i = 1, 5
V = V * 2
end do
!$omp end target teams distribute parallel do
end if
end programIt currently fails with the following assertion: This PR adds Patch is 21.50 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/178937.diff 9 Files Affected:
diff --git a/flang/include/flang/Optimizer/OpenMP/Passes.td b/flang/include/flang/Optimizer/OpenMP/Passes.td
index f17b1e3794908..793d575b7da8e 100644
--- a/flang/include/flang/Optimizer/OpenMP/Passes.td
+++ b/flang/include/flang/Optimizer/OpenMP/Passes.td
@@ -41,6 +41,18 @@ def MarkDeclareTargetPass
let dependentDialects = ["mlir::omp::OpenMPDialect"];
}
+def MarkUnreachableTargetsPass
+ : Pass<"omp-mark-unreachable-targets", "mlir::ModuleOp"> {
+ let summary = "Marks OpenMP target operations in unreachable code";
+ let description = [{
+ Identifies OpenMP target operations that reside in unreachable code
+ (e.g., inside if(.false.) blocks) and marks them with an attribute.
+ This allows device compilation to skip generating code for targets
+ that were eliminated on the host side.
+ }];
+ let dependentDialects = ["mlir::omp::OpenMPDialect"];
+}
+
def FunctionFilteringPass : Pass<"omp-function-filtering"> {
let summary = "Filters out functions intended for the host when compiling "
"for the target device.";
diff --git a/flang/lib/Optimizer/OpenMP/CMakeLists.txt b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
index 23a7dc8f08399..136b9d9ea9313 100644
--- a/flang/lib/Optimizer/OpenMP/CMakeLists.txt
+++ b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
@@ -8,6 +8,7 @@ add_flang_library(FlangOpenMPTransforms
MapsForPrivatizedSymbols.cpp
MapInfoFinalization.cpp
MarkDeclareTarget.cpp
+ MarkUnreachableTargets.cpp
LowerWorkdistribute.cpp
LowerWorkshare.cpp
LowerNontemporal.cpp
diff --git a/flang/lib/Optimizer/OpenMP/FunctionFiltering.cpp b/flang/lib/Optimizer/OpenMP/FunctionFiltering.cpp
index 0acee8991e372..f7828df21182e 100644
--- a/flang/lib/Optimizer/OpenMP/FunctionFiltering.cpp
+++ b/flang/lib/Optimizer/OpenMP/FunctionFiltering.cpp
@@ -120,10 +120,15 @@ class FunctionFilteringPass
// Do not filter functions with target regions inside, because they have
// to be available for both host and device so that regular and reverse
// offloading can be supported.
+ // However, skip target regions marked as unreachable.
bool hasTargetRegion =
funcOp
- ->walk<WalkOrder::PreOrder>(
- [&](omp::TargetOp) { return WalkResult::interrupt(); })
+ ->walk<WalkOrder::PreOrder>([&](omp::TargetOp targetOp) {
+ // Skip targets marked as unreachable
+ if (targetOp->hasAttr("omp.target_unreachable"))
+ return WalkResult::advance();
+ return WalkResult::interrupt();
+ })
.wasInterrupted();
omp::DeclareTargetDeviceType declareType =
diff --git a/flang/lib/Optimizer/OpenMP/MarkUnreachableTargets.cpp b/flang/lib/Optimizer/OpenMP/MarkUnreachableTargets.cpp
new file mode 100644
index 0000000000000..06423f32a77f0
--- /dev/null
+++ b/flang/lib/Optimizer/OpenMP/MarkUnreachableTargets.cpp
@@ -0,0 +1,166 @@
+//===- MarkUnreachableTargets.cpp ----------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass marks OpenMP target operations that are in unreachable code
+// with an attribute. This allows device compilation to skip generating code
+// for such ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/OpenMP/Passes.h"
+
+#include "flang/Optimizer/Dialect/FIRDialect.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/SmallSet.h"
+
+namespace flangomp {
+#define GEN_PASS_DEF_MARKUNREACHABLETARGETSPASS
+#include "flang/Optimizer/OpenMP/Passes.h.inc"
+} // namespace flangomp
+
+using namespace mlir;
+
+namespace {
+
+/// Check if an operation is nested inside a fir.if with a constant false
+/// condition.
+static bool isInUnreachableIfBlock(Operation *op) {
+ Operation *current = op;
+
+ // Walk up through parent operations
+ while (current) {
+ Operation *parentOp = current->getParentOp();
+ if (!parentOp)
+ break;
+
+ // Check for fir.if with constant false condition
+ if (auto firIf = dyn_cast<fir::IfOp>(parentOp)) {
+ if (auto constOp =
+ firIf.getCondition().getDefiningOp<arith::ConstantOp>()) {
+ if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue())) {
+ // If condition is false (0) and op is in the "then" region
+ if (intAttr.getInt() == 0 &&
+ current->getParentRegion() == &firIf.getThenRegion())
+ return true;
+ // If condition is true (non-zero) and op is in the "else" region
+ if (intAttr.getInt() != 0 && !firIf.getElseRegion().empty() &&
+ current->getParentRegion() == &firIf.getElseRegion())
+ return true;
+ }
+ }
+ }
+
+ current = parentOp;
+ }
+
+ return false;
+}
+
+/// Check if a block is unreachable due to constant condition branches.
+/// A block is unreachable only if ALL predecessors lead to it through
+/// unreachable paths (i.e., constant false conditions).
+/// This handles patterns like:
+/// %false = arith.constant false
+/// cf.cond_br %false, ^bb1, ^bb2
+/// where ^bb1 is unreachable.
+static bool isBlockUnreachable(Block *block) {
+ // Entry blocks and blocks with no predecessors are reachable
+ if (block->hasNoPredecessors())
+ return false;
+
+ // Check all predecessors - block is unreachable only if ALL paths are
+ // provably unreachable via constant conditions
+ for (Block *pred : block->getPredecessors()) {
+ Operation *terminator = pred->getTerminator();
+
+ // Check if this is a cf.cond_br with constant condition
+ if (auto condBr = dyn_cast<cf::CondBranchOp>(terminator)) {
+ // Try to get the constant value of the condition
+ if (auto constOp =
+ condBr.getCondition().getDefiningOp<arith::ConstantOp>()) {
+ if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue())) {
+ bool condIsTrue = intAttr.getInt() != 0;
+
+ // If condition is false and block is the true destination,
+ // this path is unreachable - continue checking other predecessors
+ if (!condIsTrue && block == condBr.getTrueDest())
+ continue;
+ // If condition is true and block is the false destination,
+ // this path is unreachable - continue checking other predecessors
+ if (condIsTrue && block == condBr.getFalseDest())
+ continue;
+ // Otherwise, this path IS reachable (condition matches destination)
+ return false;
+ }
+ }
+ }
+
+ // If we reach here, this predecessor either:
+ // - is not a CondBranchOp, OR
+ // - doesn't have a constant condition
+ // Either way, this path could be taken, so block is reachable
+ return false;
+ }
+
+ // All predecessors lead to this block through unreachable paths
+ return true;
+}
+
+/// Recursively check if an operation is in an unreachable block.
+/// This walks up the block hierarchy to check if any containing block
+/// is unreachable, handling both fir.if and cf.cond_br patterns.
+static bool isOperationUnreachable(Operation *op) {
+ // First check for fir.if patterns (before SCF lowering)
+ if (isInUnreachableIfBlock(op))
+ return true;
+
+ // Then check for cf.cond_br patterns (after SCF lowering)
+ Block *currentBlock = op->getBlock();
+
+ // Walk up through nested regions checking each block
+ while (currentBlock) {
+ if (isBlockUnreachable(currentBlock))
+ return true;
+
+ // Move to parent operation's block
+ Operation *parentOp = currentBlock->getParentOp();
+ if (!parentOp || isa<ModuleOp>(parentOp) || isa<func::FuncOp>(parentOp))
+ break;
+
+ currentBlock = parentOp->getBlock();
+ }
+
+ return false;
+}
+
+class MarkUnreachableTargetsPass
+ : public flangomp::impl::MarkUnreachableTargetsPassBase<
+ MarkUnreachableTargetsPass> {
+public:
+ MarkUnreachableTargetsPass() = default;
+
+ void runOnOperation() override {
+ MLIRContext *context = &getContext();
+ auto module = getOperation();
+
+ // Walk all target operations and mark those that are unreachable
+ module.walk([&](omp::TargetOp targetOp) {
+ if (isOperationUnreachable(targetOp.getOperation()))
+ targetOp->setAttr("omp.target_unreachable", UnitAttr::get(context));
+ });
+ }
+};
+
+} // namespace
diff --git a/flang/lib/Optimizer/Passes/Pipelines.cpp b/flang/lib/Optimizer/Passes/Pipelines.cpp
index 6054675643c64..a73a3f9d2325c 100644
--- a/flang/lib/Optimizer/Passes/Pipelines.cpp
+++ b/flang/lib/Optimizer/Passes/Pipelines.cpp
@@ -341,6 +341,11 @@ void createOpenMPFIRPassPipeline(mlir::PassManager &pm,
pm.addPass(flangomp::createAutomapToTargetDataPass());
pm.addPass(flangomp::createMapInfoFinalizationPass());
pm.addPass(flangomp::createMarkDeclareTargetPass());
+
+ // Mark unreachable target operations before FunctionFilteringPass
+ // extracts them.
+ pm.addPass(flangomp::createMarkUnreachableTargetsPass());
+
pm.addPass(flangomp::createGenericLoopConversionPass());
if (opts.isTargetDevice)
pm.addPass(flangomp::createFunctionFilteringPass());
diff --git a/flang/test/Lower/OpenMP/target-dead-code.f90 b/flang/test/Lower/OpenMP/target-dead-code.f90
new file mode 100644
index 0000000000000..82932ca82858f
--- /dev/null
+++ b/flang/test/Lower/OpenMP/target-dead-code.f90
@@ -0,0 +1,82 @@
+! RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s --check-prefix=FIR
+
+! Test that OpenMP target regions in dead code are marked for elimination
+
+! Test 1: if (.false.) with target - should be marked unreachable
+! FIR-LABEL: func.func @_QPtest_dead_simple
+! FIR: %[[FALSE:.*]] = arith.constant false
+! FIR: fir.if %[[FALSE]] {
+! FIR: omp.target
+! FIR: } {omp.target_unreachable}
+subroutine test_dead_simple()
+ real :: v
+ if (.false.) then
+ !$omp target map(tofrom:v)
+ v = 1.0
+ !$omp end target
+ end if
+end subroutine
+
+! Test 2: Live target - should NOT be marked
+! FIR-LABEL: func.func @_QPtest_live_simple
+! FIR: omp.target
+! FIR-NOT: omp.target_unreachable
+subroutine test_live_simple()
+ real :: v
+ !$omp target map(tofrom:v)
+ v = 2.0
+ !$omp end target
+end subroutine
+
+! Test 3: Mixed dead and live
+! FIR-LABEL: func.func @_QPtest_mixed
+subroutine test_mixed()
+ real :: v
+ ! Dead - should be marked
+ ! FIR: fir.if %{{.*}} {
+ if (.false.) then
+ !$omp target map(tofrom:v)
+ ! FIR: omp.target
+ ! FIR: } {omp.target_unreachable}
+ v = 3.0
+ !$omp end target
+ end if
+ ! Live - should NOT be marked
+ !$omp target map(tofrom:v)
+ ! FIR: omp.target
+ ! FIR-NOT: omp.target_unreachable
+ v = 4.0
+ !$omp end target
+end subroutine
+
+! Test 4: Nested - outer false
+! FIR-LABEL: func.func @_QPtest_nested_outer_false
+subroutine test_nested_outer_false()
+ real :: v
+ ! FIR: fir.if %{{.*}} {
+ if (.false.) then
+ ! FIR: fir.if %{{.*}} {
+ if (.true.) then
+ ! FIR: omp.target
+ ! FIR: } {omp.target_unreachable}
+ !$omp target map(tofrom:v)
+ v = 5.0
+ !$omp end target
+ end if
+ end if
+end subroutine
+
+! Test 5: Parameter constant
+! FIR-LABEL: func.func @_QPtest_parameter
+subroutine test_parameter()
+ real :: v
+ logical, parameter :: DEAD = .false.
+ ! FIR: fir.if %{{.*}} {
+ if (DEAD) then
+ ! FIR: omp.target
+ ! FIR: } {omp.target_unreachable}
+ !$omp target map(tofrom:v)
+ v = 6.0
+ !$omp end target
+ end if
+end subroutine
diff --git a/flang/test/Transforms/OpenMP/mark-unreachable-targets.mlir b/flang/test/Transforms/OpenMP/mark-unreachable-targets.mlir
new file mode 100644
index 0000000000000..66f7e607a65fd
--- /dev/null
+++ b/flang/test/Transforms/OpenMP/mark-unreachable-targets.mlir
@@ -0,0 +1,331 @@
+// RUN: fir-opt --omp-mark-unreachable-targets %s | FileCheck %s
+
+// CHECK-LABEL: func.func @test_if_false_simple
+func.func @test_if_false_simple() {
+ %false = arith.constant false
+ // CHECK: fir.if %{{.*}} {
+ fir.if %false {
+ // CHECK: omp.target
+ // CHECK-NEXT: omp.terminator
+ // CHECK-NEXT: } {omp.target_unreachable}
+ omp.target {
+ omp.terminator
+ }
+ }
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_if_true_simple
+func.func @test_if_true_simple() {
+ %true = arith.constant true
+ // CHECK: fir.if %{{.*}} {
+ fir.if %true {
+ // CHECK: omp.target {
+ // CHECK-NEXT: omp.terminator
+ // CHECK-NEXT: }
+ // CHECK-NOT: omp.target_unreachable
+ omp.target {
+ omp.terminator
+ }
+ }
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_nested_outer_false
+func.func @test_nested_outer_false() {
+ %false = arith.constant false
+ %true = arith.constant true
+ // CHECK: fir.if %{{.*}} {
+ fir.if %false {
+ // CHECK: fir.if %{{.*}} {
+ fir.if %true {
+ // CHECK: omp.target
+ // CHECK-NEXT: omp.terminator
+ // CHECK-NEXT: } {omp.target_unreachable}
+ omp.target {
+ omp.terminator
+ }
+ }
+ }
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_nested_inner_false
+func.func @test_nested_inner_false() {
+ %false = arith.constant false
+ %true = arith.constant true
+ // CHECK: fir.if %{{.*}} {
+ fir.if %true {
+ // CHECK: fir.if %{{.*}} {
+ fir.if %false {
+ // CHECK: omp.target
+ // CHECK: } {omp.target_unreachable}
+ omp.target {
+ omp.terminator
+ }
+ }
+ }
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_nested_both_true
+func.func @test_nested_both_true() {
+ %true1 = arith.constant true
+ %true2 = arith.constant true
+ // CHECK: fir.if %{{.*}} {
+ fir.if %true1 {
+ // CHECK: fir.if %{{.*}} {
+ fir.if %true2 {
+ // CHECK: omp.target {
+ // CHECK-NEXT: omp.terminator
+ // CHECK-NEXT: }
+ // CHECK-NOT: omp.target_unreachable
+ omp.target {
+ omp.terminator
+ }
+ }
+ }
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_mixed_targets
+func.func @test_mixed_targets() {
+ %false = arith.constant false
+ %true = arith.constant true
+
+ // Dead target
+ // CHECK: fir.if %{{.*}} {
+ fir.if %false {
+ // CHECK: omp.target
+ // CHECK: } {omp.target_unreachable}
+ omp.target {
+ omp.terminator
+ }
+ }
+
+ // Live target - should NOT have unreachable attribute
+ // CHECK: omp.target {
+ // CHECK-NEXT: omp.terminator
+ // CHECK-NEXT: }
+ // CHECK-NOT: omp.target_unreachable
+ omp.target {
+ omp.terminator
+ }
+
+ // Another live target in if (true)
+ // CHECK: fir.if %{{.*}} {
+ fir.if %true {
+ // CHECK: omp.target {
+ // CHECK-NEXT: omp.terminator
+ // CHECK-NEXT: }
+ // CHECK-NOT: omp.target_unreachable
+ omp.target {
+ omp.terminator
+ }
+ }
+
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_multiple_dead_targets
+func.func @test_multiple_dead_targets() {
+ %false = arith.constant false
+
+ // CHECK: fir.if %{{.*}} {
+ fir.if %false {
+ // CHECK: omp.target
+ // CHECK: } {omp.target_unreachable}
+ omp.target {
+ omp.terminator
+ }
+
+ // CHECK: omp.target
+ // CHECK: } {omp.target_unreachable}
+ omp.target {
+ omp.terminator
+ }
+
+ // CHECK: omp.target
+ // CHECK: } {omp.target_unreachable}
+ omp.target {
+ omp.terminator
+ }
+ }
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_if_else_false
+func.func @test_if_else_false() {
+ %false = arith.constant false
+
+ // CHECK: fir.if %{{.*}} {
+ fir.if %false {
+ // CHECK: omp.target
+ // CHECK: } {omp.target_unreachable}
+ omp.target {
+ omp.terminator
+ }
+ } else {
+ // Else branch should not be marked (it's reachable)
+ // CHECK: omp.target {
+ // CHECK-NEXT: omp.terminator
+ // CHECK-NEXT: }
+ // CHECK-NOT: omp.target_unreachable
+ omp.target {
+ omp.terminator
+ }
+ }
+ return
+}
+
+// -----
+
+// Test with cf.cond_br
+// CHECK-LABEL: func.func @test_cf_cond_br_false
+func.func @test_cf_cond_br_false() {
+ %false = arith.constant false
+ // CHECK: cf.cond_br %{{.*}}, ^bb1, ^bb2
+ cf.cond_br %false, ^bb1, ^bb2
+^bb1:
+ // CHECK: omp.target
+ // CHECK: } {omp.target_unreachable}
+ omp.target {
+ omp.terminator
+ }
+ cf.br ^bb2
+^bb2:
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_cf_cond_br_true
+func.func @test_cf_cond_br_true() {
+ %true = arith.constant true
+ // CHECK: cf.cond_br %{{.*}}, ^bb1, ^bb2
+ cf.cond_br %true, ^bb1, ^bb2
+^bb1:
+ // CHECK: omp.target {
+ // CHECK-NEXT: omp.terminator
+ // CHECK-NEXT: }
+ // CHECK-NOT: omp.target_unreachable
+ omp.target {
+ omp.terminator
+ }
+ cf.br ^bb2
+^bb2:
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_runtime_condition
+func.func @test_runtime_condition(%arg0: i1) {
+ // CHECK: fir.if %arg0 {
+ fir.if %arg0 {
+ // Runtime condition - should NOT be marked
+ // CHECK: omp.target {
+ // CHECK-NEXT: omp.terminator
+ // CHECK-NEXT: }
+ // CHECK-NOT: omp.target_unreachable
+ omp.target {
+ omp.terminator
+ }
+ }
+ return
+}
+
+// -----
+
+// Test for multiple predecessors - one reachable, one unreachable
+// The block should NOT be marked as unreachable if ANY path is reachable
+// CHECK-LABEL: func.func @test_multiple_predecessors
+func.func @test_multiple_predecessors() {
+ %false = arith.constant false
+ cf.cond_br %false, ^bb2, ^bb1
+^bb1:
+ // Reachable path to bb2
+ cf.br ^bb2
+^bb2:
+ // This block has two predecessors:
+ // - bb0 with false condition (unreachable path)
+ // - bb1 with unconditional branch (reachable path)
+ // Target should NOT be marked unreachable because bb1 provides a reachable path
+ // CHECK: omp.target {
+ // CHECK-NEXT: omp.terminator
+ // CHECK-NEXT: }
+ // CHECK-NOT: omp.target_unreachable
+ omp.target {
+ omp.terminator
+ }
+ return
+}
+
+// -----
+
+// Test for multiple predecessors - ALL unreachable
+// CHECK-LABEL: func.func @test_multiple_predecessors_all_unreachable
+func.func @test_multiple_predecessors_all_unreachable() {
+ %false1 = arith.constant false
+ %false2 = arith.constant false
+ cf.cond_br %false1, ^bb3, ^bb1
+^bb1:
+ cf.cond_br %false2, ^bb3, ^bb2
+^bb2:
+ cf.br ^bb4
+^bb3:
+ // This block has two predecessors:
+ // - bb0 with false condition to bb3 (unreachable)
+ // - bb1 with false condition to bb3 (unreachable)
+ // Target SHOULD be marked unreachable because ALL paths are unreachable
+ // CHECK: omp.target
+ // CHECK-NEXT: omp.terminator
+ // CHECK-NEXT: } {omp.target_unreachable}
+ omp.target {
+ omp.terminator
+ }
+ cf.br ^bb4
+^bb4:
+ return
+}
+
+// -----
+
+// Test for multiple predecessors with mixed constant and runtime conditions
+// CHECK-LABEL: func.func @test_multiple_predecessors_mixed
+func.func @test_multiple_predecessors_mixed(%arg0: i1) {
+ %false = arith.constant false
+ cf.cond_br %false, ^bb2, ^bb1
+^bb1:
+ // Runtime condition - could branch to bb2
+ cf.cond_br %arg0, ^bb2, ^bb3
+^bb2:
+ // This block has two predecessors:
+ // - bb0 with false condition (unreachable)
+ // - bb1 with runtime condition (potentially reachable)
+ // Target should NOT be marked because we can't prove bb1 path is unreachable
+ // CHECK: omp.target {
+ // CHECK-NEXT: omp.terminator
+ // CHECK-NEXT: }
+ // CHECK-NOT: omp.target_unreachable
+ omp.target {
+ omp.terminator
+ }
+ cf.br ^bb3
+^bb3:
+ return
+}
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index f04d614633965..acb6145628799 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -6370,6 +6370,11 @@ static LogicalResult
convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &modul...
[truncated]
|
|
@llvm/pr-subscribers-flang-openmp Author: Abid Qadeer (abidh) ChangesWhen a target region is placed inside a constant false condition (e.g., program cele
implicit none
real :: V
integer :: i
if (.false.) then
!$omp target teams distribute parallel do
do i = 1, 5
V = V * 2
end do
!$omp end target teams distribute parallel do
end if
end programIt currently fails with the following assertion: This PR adds Patch is 21.50 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/178937.diff 9 Files Affected:
diff --git a/flang/include/flang/Optimizer/OpenMP/Passes.td b/flang/include/flang/Optimizer/OpenMP/Passes.td
index f17b1e3794908..793d575b7da8e 100644
--- a/flang/include/flang/Optimizer/OpenMP/Passes.td
+++ b/flang/include/flang/Optimizer/OpenMP/Passes.td
@@ -41,6 +41,18 @@ def MarkDeclareTargetPass
let dependentDialects = ["mlir::omp::OpenMPDialect"];
}
+def MarkUnreachableTargetsPass
+ : Pass<"omp-mark-unreachable-targets", "mlir::ModuleOp"> {
+ let summary = "Marks OpenMP target operations in unreachable code";
+ let description = [{
+ Identifies OpenMP target operations that reside in unreachable code
+ (e.g., inside if(.false.) blocks) and marks them with an attribute.
+ This allows device compilation to skip generating code for targets
+ that were eliminated on the host side.
+ }];
+ let dependentDialects = ["mlir::omp::OpenMPDialect"];
+}
+
def FunctionFilteringPass : Pass<"omp-function-filtering"> {
let summary = "Filters out functions intended for the host when compiling "
"for the target device.";
diff --git a/flang/lib/Optimizer/OpenMP/CMakeLists.txt b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
index 23a7dc8f08399..136b9d9ea9313 100644
--- a/flang/lib/Optimizer/OpenMP/CMakeLists.txt
+++ b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
@@ -8,6 +8,7 @@ add_flang_library(FlangOpenMPTransforms
MapsForPrivatizedSymbols.cpp
MapInfoFinalization.cpp
MarkDeclareTarget.cpp
+ MarkUnreachableTargets.cpp
LowerWorkdistribute.cpp
LowerWorkshare.cpp
LowerNontemporal.cpp
diff --git a/flang/lib/Optimizer/OpenMP/FunctionFiltering.cpp b/flang/lib/Optimizer/OpenMP/FunctionFiltering.cpp
index 0acee8991e372..f7828df21182e 100644
--- a/flang/lib/Optimizer/OpenMP/FunctionFiltering.cpp
+++ b/flang/lib/Optimizer/OpenMP/FunctionFiltering.cpp
@@ -120,10 +120,15 @@ class FunctionFilteringPass
// Do not filter functions with target regions inside, because they have
// to be available for both host and device so that regular and reverse
// offloading can be supported.
+ // However, skip target regions marked as unreachable.
bool hasTargetRegion =
funcOp
- ->walk<WalkOrder::PreOrder>(
- [&](omp::TargetOp) { return WalkResult::interrupt(); })
+ ->walk<WalkOrder::PreOrder>([&](omp::TargetOp targetOp) {
+ // Skip targets marked as unreachable
+ if (targetOp->hasAttr("omp.target_unreachable"))
+ return WalkResult::advance();
+ return WalkResult::interrupt();
+ })
.wasInterrupted();
omp::DeclareTargetDeviceType declareType =
diff --git a/flang/lib/Optimizer/OpenMP/MarkUnreachableTargets.cpp b/flang/lib/Optimizer/OpenMP/MarkUnreachableTargets.cpp
new file mode 100644
index 0000000000000..06423f32a77f0
--- /dev/null
+++ b/flang/lib/Optimizer/OpenMP/MarkUnreachableTargets.cpp
@@ -0,0 +1,166 @@
+//===- MarkUnreachableTargets.cpp ----------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass marks OpenMP target operations that are in unreachable code
+// with an attribute. This allows device compilation to skip generating code
+// for such ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/OpenMP/Passes.h"
+
+#include "flang/Optimizer/Dialect/FIRDialect.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/SmallSet.h"
+
+namespace flangomp {
+#define GEN_PASS_DEF_MARKUNREACHABLETARGETSPASS
+#include "flang/Optimizer/OpenMP/Passes.h.inc"
+} // namespace flangomp
+
+using namespace mlir;
+
+namespace {
+
+/// Check if an operation is nested inside a fir.if with a constant false
+/// condition.
+static bool isInUnreachableIfBlock(Operation *op) {
+ Operation *current = op;
+
+ // Walk up through parent operations
+ while (current) {
+ Operation *parentOp = current->getParentOp();
+ if (!parentOp)
+ break;
+
+ // Check for fir.if with constant false condition
+ if (auto firIf = dyn_cast<fir::IfOp>(parentOp)) {
+ if (auto constOp =
+ firIf.getCondition().getDefiningOp<arith::ConstantOp>()) {
+ if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue())) {
+ // If condition is false (0) and op is in the "then" region
+ if (intAttr.getInt() == 0 &&
+ current->getParentRegion() == &firIf.getThenRegion())
+ return true;
+ // If condition is true (non-zero) and op is in the "else" region
+ if (intAttr.getInt() != 0 && !firIf.getElseRegion().empty() &&
+ current->getParentRegion() == &firIf.getElseRegion())
+ return true;
+ }
+ }
+ }
+
+ current = parentOp;
+ }
+
+ return false;
+}
+
+/// Check if a block is unreachable due to constant condition branches.
+/// A block is unreachable only if ALL predecessors lead to it through
+/// unreachable paths (i.e., constant false conditions).
+/// This handles patterns like:
+/// %false = arith.constant false
+/// cf.cond_br %false, ^bb1, ^bb2
+/// where ^bb1 is unreachable.
+static bool isBlockUnreachable(Block *block) {
+ // Entry blocks and blocks with no predecessors are reachable
+ if (block->hasNoPredecessors())
+ return false;
+
+ // Check all predecessors - block is unreachable only if ALL paths are
+ // provably unreachable via constant conditions
+ for (Block *pred : block->getPredecessors()) {
+ Operation *terminator = pred->getTerminator();
+
+ // Check if this is a cf.cond_br with constant condition
+ if (auto condBr = dyn_cast<cf::CondBranchOp>(terminator)) {
+ // Try to get the constant value of the condition
+ if (auto constOp =
+ condBr.getCondition().getDefiningOp<arith::ConstantOp>()) {
+ if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue())) {
+ bool condIsTrue = intAttr.getInt() != 0;
+
+ // If condition is false and block is the true destination,
+ // this path is unreachable - continue checking other predecessors
+ if (!condIsTrue && block == condBr.getTrueDest())
+ continue;
+ // If condition is true and block is the false destination,
+ // this path is unreachable - continue checking other predecessors
+ if (condIsTrue && block == condBr.getFalseDest())
+ continue;
+ // Otherwise, this path IS reachable (condition matches destination)
+ return false;
+ }
+ }
+ }
+
+ // If we reach here, this predecessor either:
+ // - is not a CondBranchOp, OR
+ // - doesn't have a constant condition
+ // Either way, this path could be taken, so block is reachable
+ return false;
+ }
+
+ // All predecessors lead to this block through unreachable paths
+ return true;
+}
+
+/// Recursively check if an operation is in an unreachable block.
+/// This walks up the block hierarchy to check if any containing block
+/// is unreachable, handling both fir.if and cf.cond_br patterns.
+static bool isOperationUnreachable(Operation *op) {
+ // First check for fir.if patterns (before SCF lowering)
+ if (isInUnreachableIfBlock(op))
+ return true;
+
+ // Then check for cf.cond_br patterns (after SCF lowering)
+ Block *currentBlock = op->getBlock();
+
+ // Walk up through nested regions checking each block
+ while (currentBlock) {
+ if (isBlockUnreachable(currentBlock))
+ return true;
+
+ // Move to parent operation's block
+ Operation *parentOp = currentBlock->getParentOp();
+ if (!parentOp || isa<ModuleOp>(parentOp) || isa<func::FuncOp>(parentOp))
+ break;
+
+ currentBlock = parentOp->getBlock();
+ }
+
+ return false;
+}
+
+class MarkUnreachableTargetsPass
+ : public flangomp::impl::MarkUnreachableTargetsPassBase<
+ MarkUnreachableTargetsPass> {
+public:
+ MarkUnreachableTargetsPass() = default;
+
+ void runOnOperation() override {
+ MLIRContext *context = &getContext();
+ auto module = getOperation();
+
+ // Walk all target operations and mark those that are unreachable
+ module.walk([&](omp::TargetOp targetOp) {
+ if (isOperationUnreachable(targetOp.getOperation()))
+ targetOp->setAttr("omp.target_unreachable", UnitAttr::get(context));
+ });
+ }
+};
+
+} // namespace
diff --git a/flang/lib/Optimizer/Passes/Pipelines.cpp b/flang/lib/Optimizer/Passes/Pipelines.cpp
index 6054675643c64..a73a3f9d2325c 100644
--- a/flang/lib/Optimizer/Passes/Pipelines.cpp
+++ b/flang/lib/Optimizer/Passes/Pipelines.cpp
@@ -341,6 +341,11 @@ void createOpenMPFIRPassPipeline(mlir::PassManager &pm,
pm.addPass(flangomp::createAutomapToTargetDataPass());
pm.addPass(flangomp::createMapInfoFinalizationPass());
pm.addPass(flangomp::createMarkDeclareTargetPass());
+
+ // Mark unreachable target operations before FunctionFilteringPass
+ // extracts them.
+ pm.addPass(flangomp::createMarkUnreachableTargetsPass());
+
pm.addPass(flangomp::createGenericLoopConversionPass());
if (opts.isTargetDevice)
pm.addPass(flangomp::createFunctionFilteringPass());
diff --git a/flang/test/Lower/OpenMP/target-dead-code.f90 b/flang/test/Lower/OpenMP/target-dead-code.f90
new file mode 100644
index 0000000000000..82932ca82858f
--- /dev/null
+++ b/flang/test/Lower/OpenMP/target-dead-code.f90
@@ -0,0 +1,82 @@
+! RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s --check-prefix=FIR
+
+! Test that OpenMP target regions in dead code are marked for elimination
+
+! Test 1: if (.false.) with target - should be marked unreachable
+! FIR-LABEL: func.func @_QPtest_dead_simple
+! FIR: %[[FALSE:.*]] = arith.constant false
+! FIR: fir.if %[[FALSE]] {
+! FIR: omp.target
+! FIR: } {omp.target_unreachable}
+subroutine test_dead_simple()
+ real :: v
+ if (.false.) then
+ !$omp target map(tofrom:v)
+ v = 1.0
+ !$omp end target
+ end if
+end subroutine
+
+! Test 2: Live target - should NOT be marked
+! FIR-LABEL: func.func @_QPtest_live_simple
+! FIR: omp.target
+! FIR-NOT: omp.target_unreachable
+subroutine test_live_simple()
+ real :: v
+ !$omp target map(tofrom:v)
+ v = 2.0
+ !$omp end target
+end subroutine
+
+! Test 3: Mixed dead and live
+! FIR-LABEL: func.func @_QPtest_mixed
+subroutine test_mixed()
+ real :: v
+ ! Dead - should be marked
+ ! FIR: fir.if %{{.*}} {
+ if (.false.) then
+ !$omp target map(tofrom:v)
+ ! FIR: omp.target
+ ! FIR: } {omp.target_unreachable}
+ v = 3.0
+ !$omp end target
+ end if
+ ! Live - should NOT be marked
+ !$omp target map(tofrom:v)
+ ! FIR: omp.target
+ ! FIR-NOT: omp.target_unreachable
+ v = 4.0
+ !$omp end target
+end subroutine
+
+! Test 4: Nested - outer false
+! FIR-LABEL: func.func @_QPtest_nested_outer_false
+subroutine test_nested_outer_false()
+ real :: v
+ ! FIR: fir.if %{{.*}} {
+ if (.false.) then
+ ! FIR: fir.if %{{.*}} {
+ if (.true.) then
+ ! FIR: omp.target
+ ! FIR: } {omp.target_unreachable}
+ !$omp target map(tofrom:v)
+ v = 5.0
+ !$omp end target
+ end if
+ end if
+end subroutine
+
+! Test 5: Parameter constant
+! FIR-LABEL: func.func @_QPtest_parameter
+subroutine test_parameter()
+ real :: v
+ logical, parameter :: DEAD = .false.
+ ! FIR: fir.if %{{.*}} {
+ if (DEAD) then
+ ! FIR: omp.target
+ ! FIR: } {omp.target_unreachable}
+ !$omp target map(tofrom:v)
+ v = 6.0
+ !$omp end target
+ end if
+end subroutine
diff --git a/flang/test/Transforms/OpenMP/mark-unreachable-targets.mlir b/flang/test/Transforms/OpenMP/mark-unreachable-targets.mlir
new file mode 100644
index 0000000000000..66f7e607a65fd
--- /dev/null
+++ b/flang/test/Transforms/OpenMP/mark-unreachable-targets.mlir
@@ -0,0 +1,331 @@
+// RUN: fir-opt --omp-mark-unreachable-targets %s | FileCheck %s
+
+// CHECK-LABEL: func.func @test_if_false_simple
+func.func @test_if_false_simple() {
+ %false = arith.constant false
+ // CHECK: fir.if %{{.*}} {
+ fir.if %false {
+ // CHECK: omp.target
+ // CHECK-NEXT: omp.terminator
+ // CHECK-NEXT: } {omp.target_unreachable}
+ omp.target {
+ omp.terminator
+ }
+ }
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_if_true_simple
+func.func @test_if_true_simple() {
+ %true = arith.constant true
+ // CHECK: fir.if %{{.*}} {
+ fir.if %true {
+ // CHECK: omp.target {
+ // CHECK-NEXT: omp.terminator
+ // CHECK-NEXT: }
+ // CHECK-NOT: omp.target_unreachable
+ omp.target {
+ omp.terminator
+ }
+ }
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_nested_outer_false
+func.func @test_nested_outer_false() {
+ %false = arith.constant false
+ %true = arith.constant true
+ // CHECK: fir.if %{{.*}} {
+ fir.if %false {
+ // CHECK: fir.if %{{.*}} {
+ fir.if %true {
+ // CHECK: omp.target
+ // CHECK-NEXT: omp.terminator
+ // CHECK-NEXT: } {omp.target_unreachable}
+ omp.target {
+ omp.terminator
+ }
+ }
+ }
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_nested_inner_false
+func.func @test_nested_inner_false() {
+ %false = arith.constant false
+ %true = arith.constant true
+ // CHECK: fir.if %{{.*}} {
+ fir.if %true {
+ // CHECK: fir.if %{{.*}} {
+ fir.if %false {
+ // CHECK: omp.target
+ // CHECK: } {omp.target_unreachable}
+ omp.target {
+ omp.terminator
+ }
+ }
+ }
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_nested_both_true
+func.func @test_nested_both_true() {
+ %true1 = arith.constant true
+ %true2 = arith.constant true
+ // CHECK: fir.if %{{.*}} {
+ fir.if %true1 {
+ // CHECK: fir.if %{{.*}} {
+ fir.if %true2 {
+ // CHECK: omp.target {
+ // CHECK-NEXT: omp.terminator
+ // CHECK-NEXT: }
+ // CHECK-NOT: omp.target_unreachable
+ omp.target {
+ omp.terminator
+ }
+ }
+ }
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_mixed_targets
+func.func @test_mixed_targets() {
+ %false = arith.constant false
+ %true = arith.constant true
+
+ // Dead target
+ // CHECK: fir.if %{{.*}} {
+ fir.if %false {
+ // CHECK: omp.target
+ // CHECK: } {omp.target_unreachable}
+ omp.target {
+ omp.terminator
+ }
+ }
+
+ // Live target - should NOT have unreachable attribute
+ // CHECK: omp.target {
+ // CHECK-NEXT: omp.terminator
+ // CHECK-NEXT: }
+ // CHECK-NOT: omp.target_unreachable
+ omp.target {
+ omp.terminator
+ }
+
+ // Another live target in if (true)
+ // CHECK: fir.if %{{.*}} {
+ fir.if %true {
+ // CHECK: omp.target {
+ // CHECK-NEXT: omp.terminator
+ // CHECK-NEXT: }
+ // CHECK-NOT: omp.target_unreachable
+ omp.target {
+ omp.terminator
+ }
+ }
+
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_multiple_dead_targets
+func.func @test_multiple_dead_targets() {
+ %false = arith.constant false
+
+ // CHECK: fir.if %{{.*}} {
+ fir.if %false {
+ // CHECK: omp.target
+ // CHECK: } {omp.target_unreachable}
+ omp.target {
+ omp.terminator
+ }
+
+ // CHECK: omp.target
+ // CHECK: } {omp.target_unreachable}
+ omp.target {
+ omp.terminator
+ }
+
+ // CHECK: omp.target
+ // CHECK: } {omp.target_unreachable}
+ omp.target {
+ omp.terminator
+ }
+ }
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_if_else_false
+func.func @test_if_else_false() {
+ %false = arith.constant false
+
+ // CHECK: fir.if %{{.*}} {
+ fir.if %false {
+ // CHECK: omp.target
+ // CHECK: } {omp.target_unreachable}
+ omp.target {
+ omp.terminator
+ }
+ } else {
+ // Else branch should not be marked (it's reachable)
+ // CHECK: omp.target {
+ // CHECK-NEXT: omp.terminator
+ // CHECK-NEXT: }
+ // CHECK-NOT: omp.target_unreachable
+ omp.target {
+ omp.terminator
+ }
+ }
+ return
+}
+
+// -----
+
+// Test with cf.cond_br
+// CHECK-LABEL: func.func @test_cf_cond_br_false
+func.func @test_cf_cond_br_false() {
+ %false = arith.constant false
+ // CHECK: cf.cond_br %{{.*}}, ^bb1, ^bb2
+ cf.cond_br %false, ^bb1, ^bb2
+^bb1:
+ // CHECK: omp.target
+ // CHECK: } {omp.target_unreachable}
+ omp.target {
+ omp.terminator
+ }
+ cf.br ^bb2
+^bb2:
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_cf_cond_br_true
+func.func @test_cf_cond_br_true() {
+ %true = arith.constant true
+ // CHECK: cf.cond_br %{{.*}}, ^bb1, ^bb2
+ cf.cond_br %true, ^bb1, ^bb2
+^bb1:
+ // CHECK: omp.target {
+ // CHECK-NEXT: omp.terminator
+ // CHECK-NEXT: }
+ // CHECK-NOT: omp.target_unreachable
+ omp.target {
+ omp.terminator
+ }
+ cf.br ^bb2
+^bb2:
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_runtime_condition
+func.func @test_runtime_condition(%arg0: i1) {
+ // CHECK: fir.if %arg0 {
+ fir.if %arg0 {
+ // Runtime condition - should NOT be marked
+ // CHECK: omp.target {
+ // CHECK-NEXT: omp.terminator
+ // CHECK-NEXT: }
+ // CHECK-NOT: omp.target_unreachable
+ omp.target {
+ omp.terminator
+ }
+ }
+ return
+}
+
+// -----
+
+// Test for multiple predecessors - one reachable, one unreachable
+// The block should NOT be marked as unreachable if ANY path is reachable
+// CHECK-LABEL: func.func @test_multiple_predecessors
+func.func @test_multiple_predecessors() {
+ %false = arith.constant false
+ cf.cond_br %false, ^bb2, ^bb1
+^bb1:
+ // Reachable path to bb2
+ cf.br ^bb2
+^bb2:
+ // This block has two predecessors:
+ // - bb0 with false condition (unreachable path)
+ // - bb1 with unconditional branch (reachable path)
+ // Target should NOT be marked unreachable because bb1 provides a reachable path
+ // CHECK: omp.target {
+ // CHECK-NEXT: omp.terminator
+ // CHECK-NEXT: }
+ // CHECK-NOT: omp.target_unreachable
+ omp.target {
+ omp.terminator
+ }
+ return
+}
+
+// -----
+
+// Test for multiple predecessors - ALL unreachable
+// CHECK-LABEL: func.func @test_multiple_predecessors_all_unreachable
+func.func @test_multiple_predecessors_all_unreachable() {
+ %false1 = arith.constant false
+ %false2 = arith.constant false
+ cf.cond_br %false1, ^bb3, ^bb1
+^bb1:
+ cf.cond_br %false2, ^bb3, ^bb2
+^bb2:
+ cf.br ^bb4
+^bb3:
+ // This block has two predecessors:
+ // - bb0 with false condition to bb3 (unreachable)
+ // - bb1 with false condition to bb3 (unreachable)
+ // Target SHOULD be marked unreachable because ALL paths are unreachable
+ // CHECK: omp.target
+ // CHECK-NEXT: omp.terminator
+ // CHECK-NEXT: } {omp.target_unreachable}
+ omp.target {
+ omp.terminator
+ }
+ cf.br ^bb4
+^bb4:
+ return
+}
+
+// -----
+
+// Test for multiple predecessors with mixed constant and runtime conditions
+// CHECK-LABEL: func.func @test_multiple_predecessors_mixed
+func.func @test_multiple_predecessors_mixed(%arg0: i1) {
+ %false = arith.constant false
+ cf.cond_br %false, ^bb2, ^bb1
+^bb1:
+ // Runtime condition - could branch to bb2
+ cf.cond_br %arg0, ^bb2, ^bb3
+^bb2:
+ // This block has two predecessors:
+ // - bb0 with false condition (unreachable)
+ // - bb1 with runtime condition (potentially reachable)
+ // Target should NOT be marked because we can't prove bb1 path is unreachable
+ // CHECK: omp.target {
+ // CHECK-NEXT: omp.terminator
+ // CHECK-NEXT: }
+ // CHECK-NOT: omp.target_unreachable
+ omp.target {
+ omp.terminator
+ }
+ cf.br ^bb3
+^bb3:
+ return
+}
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index f04d614633965..acb6145628799 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -6370,6 +6370,11 @@ static LogicalResult
convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &modul...
[truncated]
|
bhandarkar-pranav
left a comment
There was a problem hiding this comment.
Thank you for the PR, @abidh. This is a good first pass, but I have some thoughts
- Any reason, you haven't styled this as a forward pass from the entry block to compute reachabililty and then check the omp.target for membership of the set of reachable block/ops? The reason, I say this is that in
IsUnreachableBlockwe do not identify transitively dead (although perhaps rare) omp.target ops. Consider a CFG like
bb0
cf.cond_br false, bb1, bb2
bb1
cf.br bb3
bb2
cf.br bb4
bb3:
omp.target // Since `IsUnreachableBlock` only checks predecessors of the block of
// omp.target, and bb1 doesn't terminate with cf.cond_br this function
// would miss this case.
bb4:
return
- MLIR has
DominanceInfo::isReachableFromEntry(). When combined withSCCPPassyou could get what you want for free
pm.addPass(createSCCPPass()); // Fold constant branches
pm.addPass(createCanonicalizerPass()); // Clean up
pm.addPass(flangomp::createMarkUnreachableTargetsPass()); // now use DominanceInfo::isReachableFromEntry
| //===----------------------------------------------------------------------===// | ||
|
|
||
| #include "flang/Optimizer/OpenMP/Passes.h" | ||
|
|
| /// %false = arith.constant false | ||
| /// cf.cond_br %false, ^bb1, ^bb2 | ||
| /// where ^bb1 is unreachable. | ||
| static bool isBlockUnreachable(Block *block) { |
| auto targetOp = cast<omp::TargetOp>(opInst); | ||
|
|
||
| // Skip target operations marked as unreachable | ||
| if (targetOp->hasAttr("omp.target_unreachable")) |
There was a problem hiding this comment.
This literal is used in a couple of places. Instead of using a string literal, consider adding a constexpr llvm::StringLiteral. For example
constexpr llvm::StringLiteral kTargetUnreachableAttr = "omp.target_unreachable" in an appropriate header file, perhaps mlir/include/mlir/Dialect/OpenMP/OpenMPDialect.h
| // Check if this is a cf.cond_br with constant condition | ||
| if (auto condBr = dyn_cast<cf::CondBranchOp>(terminator)) { | ||
| // Try to get the constant value of the condition | ||
| if (auto constOp = |
There was a problem hiding this comment.
Consider using m_Constant like here
m_Constant matches by trait so it'll also catch ops like llvm.mlir.constant and not just arith::ConstantOp.
There was a problem hiding this comment.
Thanks for the suggestion. Done.
|
It seems like there is a deeper problem here. Correct me if I'm wrong, it's been a while since I looked at what we do past lowering. Here's how I see this. We compile the lowered code that contains both the host and the device code. Then we outline the target regions, and end up with a module that will be further compiled for host, and a module targeted for the device. If we make a change to the host code, that requires updates in the target code, this has to happen while we still have both (i.e. before the two modules are permanently separated). Also, whoever makes the change must do the update to ensure that the result is correct. If we delete the omp.target region on host, either it should be ok to still proceed with the compilation of the outlined target function, or the device module must be updated accordingly at the same time, or we shouldn't delete omp.target regions after outlining at all (if we can't do any of the other things). We should not have a situation where passes on these two modules produce outputs that are not consistent with each other. This new pass seems to try to patch these inconsistencies up, but we shouldn't have them show up in the first place. How do we end up in this situation? What if the dead code elimination pass (that removes the omp.target) ran before outlining? Also, seems like there is a legitimate way that some device function may end up without the kernel attribute. Maybe that should trigger checks whether we need to generate code for it instead of just asserting? |
Thanks for you comment @kparzysz. I am also new to this area. Here is what I understand about this problem. When we give the offload flag (--offload-arch) to the driver, the compiler is run twice: once in host mode and once for device. Host compilation:
Device compilation:
So there is an inconsistency between host and device. Our approach:
This ensures host-device synchronization preventing the inconsistency before outlining happens.
The assertion in OpenMPOpt is relatively harmless and could be changed to a warning or early return. However, the changes in this PR fix the problem at the root cause level and also avoid unnecessarily processing code that will never execute. |
|
@abidh you are correct about how target offloading is currently handled: there are multiple frontend invocations (one for the host and one for each offloading target), each producing an independent MLIR module, so there's never a common MLIR module (though we do expect them to initially be very similar). This behavior is shared with Clang, and one of the reasons for this approach, IIRC, is that there could be e.g. preprocessor logic causing code variations between targets. We do function filtering relatively early (still within Flang), which seems to be the reason why we end up with different lists of target regions. One of the advantages of that is that we avoid processing potentially large amounts of host-only code in MLIR only to get rid of it at the end, but we'd probably address this issue if we updated that pass to run right before translation to LLVM IR. That would have the advantage of not making host and device MLIR modules diverge much until the very end, and other compiler frontends targeting OpenMP offloading could benefit from function filtering too. One other alternative approach I would suggest considering would be to update the In any case, I don't see a good reason to add an attribute to target regions to mark them as "deleted". I'd prefer to just delete them when this situation is detected. In terms of moving around function filtering within the compilation flow, for the time being I'm leaning towards keeping things as they are (less potentially breaking changes to be done, and it doesn't appear that diverging optimizations of host and device code could become a regular source of problems). We might want to discuss this later, with regards to reuse among multiple frontends, though. Edit: I forgot that I had already merged my |
Thanks for your comments @skatrak. Can you please clarify if you are asking me to wait until you have made your changes in |
I just made an edit to my previous comment. There's no need to wait for these changes, among other things because I had already merged that and forgot... 😅 |
1. Use DominanceInfo to find unreachable blocks. 2. Remove unreachable op instead of marking it. 3. Rename pass and files to better reflect functionality.
Thanks @bhandarkar-pranav. That is very useful comment. I have changed code now so that block reachability is checked using |
Thanks for this suggestion. It is much better than marking. I have implemented it now. But I kept the removal in a separate pass which runs for both host and device so that we have the same view on both. The |
There was a problem hiding this comment.
Why not build on top of MLIR's mlir/Analysis/DataFlow/DeadCodeAnalysis.h?
Using dominance instead of data flow analysis might be lighter weight in compilation time (I never checked this) but I'm not sure it will produce easy to interpret results in the presence of loops (maybe that would be okay in this case, I haven't thought hard about it).
I may be jumping the gun here, but in my books unreachable code analysis is a control flow problem rather than a data flow problem. Strictly speaking, dead code analysis should tell you if the results of an operation/instruction are used or not (backwards data flow) whereas the question being asked here is - "Can this instruction be reached from the entry block?" I'll underscore that I haven't looked into |
Use DeadCodeAnalysis instead of Dominance analysis.
|
Thanks @bhandarkar-pranav @tblah for your comments. I have pushed a commit that uses |
tblah
left a comment
There was a problem hiding this comment.
LGTM but please wait for wider agreement on the mlir dead code analysis approach
I noticed that new approach of using DeadCodeAnalysis can also handle the unused function. This commits adds a testcase for that.
|
Another positive thing about using |
bhandarkar-pranav
left a comment
There was a problem hiding this comment.
LGTM. The code got simplified very nicely!
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/169/builds/20019 Here is the relevant piece of the build log for the reference |
The error seems unrelated to PR and the build bot is now green. |
When a target region is placed inside a constant false condition (e.g.,
if (.false.)), the dead code gets eliminated on the host side, removing theomp.targetoperation entirely. However, the device-side compilation pipeline is unaware of this elimination and attempts to generate kernel code. Since the host never created offload metadata for the eliminated target, the device-side kernel function lacks the "kernel" attribute, causingOpenMPOptto fail with an assertion when it expects all outlined kernels to have this attribute. The problem can be seen with the following code:It currently fails with the following assertion:
This PR adds
MarkUnreachableTargetsPassthat identifiesomp.targetoperations in unreachable code blocks and marks them withomp.target_unreachableattribute. This attribute is later used inFunctionFilteringPassand inOpenMPToLLVMIRTranslationto prevent generation of code for such op.