diff --git a/llvm/include/llvm/Transforms/Utils/BasicBlockUtils.h b/llvm/include/llvm/Transforms/Utils/BasicBlockUtils.h index 49885b7f06a15..c951597498033 100644 --- a/llvm/include/llvm/Transforms/Utils/BasicBlockUtils.h +++ b/llvm/include/llvm/Transforms/Utils/BasicBlockUtils.h @@ -45,6 +45,10 @@ class ReturnInst; class TargetLibraryInfo; class Value; +/// Check if the given basic block contains any loop or entry convergent +/// intrinsic instructions. +LLVM_ABI bool HasLoopOrEntryConvergenceToken(const BasicBlock *BB); + /// Replace contents of every block in \p BBs with single unreachable /// instruction. If \p Updates is specified, collect all necessary DT updates /// into this vector. If \p KeepOneInputPHIs is true, one-input Phis in diff --git a/llvm/lib/Transforms/Scalar/JumpThreading.cpp b/llvm/lib/Transforms/Scalar/JumpThreading.cpp index c7d71eb5633ec..aa8c80abb66ab 100644 --- a/llvm/lib/Transforms/Scalar/JumpThreading.cpp +++ b/llvm/lib/Transforms/Scalar/JumpThreading.cpp @@ -1920,6 +1920,13 @@ bool JumpThreadingPass::maybeMergeBasicBlockIntoOnlyPred(BasicBlock *BB) { if (Unreachable.count(SinglePred)) return false; + // Don't merge if both the basic block and the predecessor contain loop or + // entry convergent intrinsics, since there may only be one convergence token + // per block. + if (HasLoopOrEntryConvergenceToken(BB) && + HasLoopOrEntryConvergenceToken(SinglePred)) + return false; + // If SinglePred was a loop header, BB becomes one. if (LoopHeaders.erase(SinglePred)) LoopHeaders.insert(BB); diff --git a/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp b/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp index 076c5da4393fc..da3e60da9df74 100644 --- a/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp +++ b/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp @@ -92,7 +92,7 @@ emptyAndDetachBlock(BasicBlock *BB, "applying corresponding DTU updates."); } -static bool HasLoopOrEntryConvergenceToken(const BasicBlock *BB) { +bool llvm::HasLoopOrEntryConvergenceToken(const BasicBlock *BB) { for (const Instruction &I : *BB) { const ConvergenceControlInst *CCI = dyn_cast(&I); if (CCI && (CCI->isLoop() || CCI->isEntry())) diff --git a/llvm/test/Transforms/JumpThreading/skip-merging-duplicate-convergence-intrinsics.ll b/llvm/test/Transforms/JumpThreading/skip-merging-duplicate-convergence-intrinsics.ll new file mode 100644 index 0000000000000..7d3ef44db203c --- /dev/null +++ b/llvm/test/Transforms/JumpThreading/skip-merging-duplicate-convergence-intrinsics.ll @@ -0,0 +1,68 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py +; RUN: opt < %s -S -passes=jump-threading | FileCheck %s + +declare token @llvm.experimental.convergence.entry() #0 + +define void @nested(i32 %tidx, i32 %tidy, ptr %array) #0 { +; CHECK-LABEL: @nested( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[TMP0:%.*]] = tail call token @llvm.experimental.convergence.entry() +; CHECK-NEXT: [[TMP1:%.*]] = or i32 [[TIDY:%.*]], [[TIDX:%.*]] +; CHECK-NEXT: [[OR_COND_I:%.*]] = icmp eq i32 [[TMP1]], 0 +; CHECK-NEXT: br label [[FOR_COND_I:%.*]] +; CHECK: for.cond.i: +; CHECK-NEXT: [[TMP2:%.*]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[TMP0]]) ] +; CHECK-NEXT: br label [[FOR_COND1_I:%.*]] +; CHECK: for.cond1.i: +; CHECK-NEXT: [[CMP2_I:%.*]] = phi i1 [ false, [[FOR_BODY4_I:%.*]] ], [ true, [[FOR_COND_I]] ] +; CHECK-NEXT: [[TMP3:%.*]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[TMP2]]) ] +; CHECK-NEXT: br i1 [[CMP2_I]], label [[FOR_BODY4_I]], label [[EXIT:%.*]] +; CHECK: for.body4.i: +; CHECK-NEXT: br i1 [[OR_COND_I]], label [[IF_THEN_I:%.*]], label [[FOR_COND1_I]] +; CHECK: if.then.i: +; CHECK-NEXT: [[TEST_VAL:%.*]] = call spir_func i32 @func_test(i32 0) [ "convergencectrl"(token [[TMP3]]) ] +; CHECK-NEXT: [[TMP4:%.*]] = getelementptr inbounds i32, ptr [[ARRAY:%.*]], i32 0 +; CHECK-NEXT: store i32 [[TEST_VAL]], ptr [[TMP4]], align 4 +; CHECK-NEXT: br label [[EXIT]] +; CHECK: exit: +; CHECK-NEXT: ret void +; +entry: + %0 = tail call token @llvm.experimental.convergence.entry() + %2 = or i32 %tidy, %tidx + %or.cond.i = icmp eq i32 %2, 0 + br label %for.cond.i + +for.cond.i: + %3 = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token %0) ] + br label %for.cond1.i + +for.cond1.i: + %cmp2.i = phi i1 [ false, %for.body4.i ], [ true, %for.cond.i ] + %4 = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token %3) ] + br i1 %cmp2.i, label %for.body4.i, label %cleanup.i.loopexit + +for.body4.i: + br i1 %or.cond.i, label %if.then.i, label %for.cond1.i + +if.then.i: + %test.val = call spir_func i32 @func_test(i32 0) [ "convergencectrl"(token %4) ] + %5 = getelementptr inbounds i32, ptr %array, i32 0 + store i32 %test.val, ptr %5, align 4 + br label %cleanup.i + +cleanup.i.loopexit: + br label %cleanup.i + +cleanup.i: + br label %exit + +exit: + ret void +} + +declare token @llvm.experimental.convergence.loop() #0 + +declare i32 @func_test(i32) #0 + +attributes #0 = { convergent } \ No newline at end of file