Skip to content

Conversation

@luciechoi
Copy link
Contributor

@luciechoi luciechoi commented Oct 30, 2025

Skip constant folding the loop predicates if the loop contains control convergence tokens referenced outside the loop.

Fixes #164496.

Verified loop_peeling.test passes with the fix.

Similar control convergence issues are found on other passes. #165642

HLSL used for tests:

RWStructuredBuffer<uint> Out : register(u0);

[numthreads(8,1,1)]
void main(uint3 TID : SV_GroupThreadID) {
    for (uint i = 0; i < 8; i++) {
        if (i == TID.x) {
            Out[TID.x] = WaveActiveMax(TID.x);
            break;
        }
    }
}

With nested loop:

RWStructuredBuffer<uint> Out : register(u0);

[numthreads(8,8,1)]
void main(uint3 TID : SV_GroupThreadID) {
    for (uint i = 0; i < 8; i++) {
        for (uint j = 0; j < 8; j++) {
            if (i == TID.x && j == TID.y) {
                uint index = TID.x * 8 + TID.y;
                Out[index] = WaveActiveMax(index);
                break;
            }
        }
    }
}

@llvmbot
Copy link
Member

llvmbot commented Oct 30, 2025

@llvm/pr-subscribers-llvm-transforms

Author: Lucie Choi (luciechoi)

Changes

Skip constant folding the loop predicates if the loop contains control convergence operations (e.g. wave intrinsics).

Fixes #164496.

Verified loop_peeling.test passes with the fix.


Full diff: https://github.com/llvm/llvm-project/pull/165643.diff

3 Files Affected:

  • (modified) llvm/lib/Transforms/Scalar/IndVarSimplify.cpp (+31)
  • (added) llvm/test/Transforms/IndVarSimplify/skip-predication-convergence.ll (+98)
  • (added) llvm/test/Transforms/IndVarSimplify/skip-predictaion-nested-convergence.ll (+139)
diff --git a/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp b/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp
index 7ebcc219efc15..421aad8872f9a 100644
--- a/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp
+++ b/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp
@@ -1859,6 +1859,37 @@ bool IndVarSimplify::predicateLoopExits(Loop *L, SCEVExpander &Rewriter) {
         }
       }
 
+  // If the loop body uses a convergence token defined within the loop, skip
+  // predication. This is to avoid changing the convergence behavior of the
+  // loop.
+  SmallVector<BasicBlock *, 16> blocks = ExitingBlocks;
+  SmallVector<Value *, 16> tokens = {};
+  size_t index = 0; // Assume Exiting Blocks are sorted.
+  while (index < blocks.size()) {
+    BasicBlock *BB = blocks[index];
+    index++;
+    const auto exitingBlockName = BB->getName();
+    for (Instruction &I : *BB) {
+      // Check if the instruction uses any convergence tokens.
+      if (auto *CB = dyn_cast<CallBase>(&I);
+          CB && !isa<ConvergenceControlInst>(&I)) {
+        auto token = CB->getConvergenceControlToken();
+        if (token && llvm::is_contained(tokens, token)) {
+          return false;
+        }
+      }
+      if (isa<ConvergenceControlInst>(&I)) {
+        tokens.push_back(cast<Value>(&I));
+      }
+    }
+
+    for (BasicBlock *Succ : successors(BB)) {
+      const auto succName = Succ->getName();
+      if (Succ != L->getLoopLatch() && !llvm::is_contained(blocks, Succ))
+        blocks.push_back(Succ);
+    }
+  }
+
   bool Changed = false;
   // Finally, do the actual predication for all predicatable blocks.  A couple
   // of notes here:
diff --git a/llvm/test/Transforms/IndVarSimplify/skip-predication-convergence.ll b/llvm/test/Transforms/IndVarSimplify/skip-predication-convergence.ll
new file mode 100644
index 0000000000000..12fca6778f15e
--- /dev/null
+++ b/llvm/test/Transforms/IndVarSimplify/skip-predication-convergence.ll
@@ -0,0 +1,98 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt < %s -passes=indvars -indvars-predicate-loops=1 -S | FileCheck %s
+
+; Loop with body using loop convergence token should be skipped by IndVarSimplify.
+
+%"class.hlsl::RWStructuredBuffer" = type { target("spirv.VulkanBuffer", [0 x i32], 12, 1), target("spirv.VulkanBuffer", i32, 12, 1) }
+
+@_ZL3Out = internal global %"class.hlsl::RWStructuredBuffer" poison, align 8
+@.str = private unnamed_addr constant [4 x i8] c"Out\00", align 1
+
+declare token @llvm.experimental.convergence.entry() #0
+
+define void @loop() local_unnamed_addr #1 {
+; CHECK-LABEL: @loop(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = tail call token @llvm.experimental.convergence.entry()
+; CHECK-NEXT:    [[TMP1:%.*]] = tail call target("spirv.VulkanBuffer", [0 x i32], 12, 1) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0i32_12_1t(i32 0, i32 0, i32 1, i32 0, ptr nonnull @.str)
+; CHECK-NEXT:    [[TMP2:%.*]] = tail call target("spirv.VulkanBuffer", i32, 12, 1) @llvm.spv.resource.counterhandlefromimplicitbinding.tspirv.VulkanBuffer_i32_12_1t.tspirv.VulkanBuffer_a0i32_12_1t(target("spirv.VulkanBuffer", [0 x i32], 12, 1) [[TMP1]], i32 0, i32 0)
+; CHECK-NEXT:    store target("spirv.VulkanBuffer", [0 x i32], 12, 1) [[TMP1]], ptr @_ZL3Out, align 8
+; CHECK-NEXT:    store target("spirv.VulkanBuffer", i32, 12, 1) [[TMP2]], ptr getelementptr inbounds nuw (i8, ptr @_ZL3Out, i64 8), align 8
+; CHECK-NEXT:    [[TMP3:%.*]] = tail call i32 @llvm.spv.thread.id.in.group.i32(i32 0)
+; CHECK-NEXT:    br label [[FOR_COND_I:%.*]]
+; CHECK:       for.cond.i:
+; CHECK-NEXT:    [[I_0_I:%.*]] = phi i32 [ 0, [[ENTRY:%.*]] ], [ [[INC_I:%.*]], [[FOR_BODY_I:%.*]] ]
+; CHECK-NEXT:    [[TMP4:%.*]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[TMP0]]) ]
+; CHECK-NEXT:    [[CMP_I:%.*]] = icmp ult i32 [[I_0_I]], 8
+; CHECK-NEXT:    br i1 [[CMP_I]], label [[FOR_BODY_I]], label [[_Z4LOOPDV3_J_EXIT_LOOPEXIT:%.*]]
+; CHECK:       for.body.i:
+; CHECK-NEXT:    [[CMP1_I:%.*]] = icmp eq i32 [[I_0_I]], [[TMP3]]
+; CHECK-NEXT:    [[INC_I]] = add nuw nsw i32 [[I_0_I]], 1
+; CHECK-NEXT:    br i1 [[CMP1_I]], label [[IF_THEN_I:%.*]], label [[FOR_COND_I]]
+; CHECK:       _Z4loopDv3_j.exit.loopexit:
+; CHECK-NEXT:    br label [[_Z4LOOPDV3_J_EXIT:%.*]]
+; CHECK:       if.then.i:
+; CHECK-NEXT:    [[HLSL_WAVE_ACTIVE_MAX2_I:%.*]] = call spir_func i32 @llvm.spv.wave.reduce.umax.i32(i32 [[TMP3]]) [ "convergencectrl"(token [[TMP4]]) ]
+; CHECK-NEXT:    [[TMP5:%.*]] = tail call noundef align 4 dereferenceable(4) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0i32_12_1t(target("spirv.VulkanBuffer", [0 x i32], 12, 1) [[TMP1]], i32 [[TMP3]])
+; CHECK-NEXT:    store i32 [[HLSL_WAVE_ACTIVE_MAX2_I]], ptr addrspace(11) [[TMP5]], align 4
+; CHECK-NEXT:    br label [[_Z4LOOPDV3_J_EXIT]]
+; CHECK:       _Z4loopDv3_j.exit:
+; CHECK-NEXT:    ret void
+;
+entry:
+  %0 = tail call token @llvm.experimental.convergence.entry()
+  %1 = tail call target("spirv.VulkanBuffer", [0 x i32], 12, 1) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0i32_12_1t(i32 0, i32 0, i32 1, i32 0, ptr nonnull @.str)
+  %2 = tail call target("spirv.VulkanBuffer", i32, 12, 1) @llvm.spv.resource.counterhandlefromimplicitbinding.tspirv.VulkanBuffer_i32_12_1t.tspirv.VulkanBuffer_a0i32_12_1t(target("spirv.VulkanBuffer", [0 x i32], 12, 1) %1, i32 0, i32 0)
+  store target("spirv.VulkanBuffer", [0 x i32], 12, 1) %1, ptr @_ZL3Out, align 8
+  store target("spirv.VulkanBuffer", i32, 12, 1) %2, ptr getelementptr inbounds nuw (i8, ptr @_ZL3Out, i64 8), align 8
+  %3 = tail call i32 @llvm.spv.thread.id.in.group.i32(i32 0)
+  br label %for.cond.i
+
+; Loop:
+for.cond.i:                                       ; preds = %for.body.i, %entry
+  %i.0.i = phi i32 [ 0, %entry ], [ %inc.i, %for.body.i ]
+  %4 = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token %0) ]
+  %cmp.i = icmp ult i32 %i.0.i, 8
+  br i1 %cmp.i, label %for.body.i, label %_Z4loopDv3_j.exit.loopexit
+
+for.body.i:                                       ; preds = %for.cond.i
+  %cmp1.i = icmp eq i32 %i.0.i, %3
+  %inc.i = add nuw nsw i32 %i.0.i, 1
+  br i1 %cmp1.i, label %if.then.i, label %for.cond.i
+
+; Exit blocks
+_Z4loopDv3_j.exit.loopexit:                       ; preds = %for.cond.i
+  br label %_Z4loopDv3_j.exit
+
+if.then.i:                                        ; preds = %for.body.i
+  %hlsl.wave.active.max2.i = call spir_func i32 @llvm.spv.wave.reduce.umax.i32(i32 %3) [ "convergencectrl"(token %4) ]
+  %5 = tail call noundef align 4 dereferenceable(4) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0i32_12_1t(target("spirv.VulkanBuffer", [0 x i32], 12, 1) %1, i32 %3)
+  store i32 %hlsl.wave.active.max2.i, ptr addrspace(11) %5, align 4
+  br label %_Z4loopDv3_j.exit
+
+_Z4loopDv3_j.exit:                                ; preds = %_Z4loopDv3_j.exit.loopexit, %if.then.i
+  ret void
+}
+
+; Function Attrs: mustprogress nofree nosync nounwind willreturn memory(none)
+declare i32 @llvm.spv.thread.id.in.group.i32(i32) #2
+
+; Function Attrs: convergent mustprogress nocallback nofree nosync nounwind willreturn memory(none)
+declare token @llvm.experimental.convergence.loop() #0
+
+; Function Attrs: convergent mustprogress nocallback nofree nosync nounwind willreturn memory(none)
+declare i32 @llvm.spv.wave.reduce.umax.i32(i32) #0
+
+; Function Attrs: mustprogress nocallback nofree nosync nounwind willreturn memory(none)
+declare target("spirv.VulkanBuffer", [0 x i32], 12, 1) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0i32_12_1t(i32, i32, i32, i32, ptr) #4
+
+; Function Attrs: mustprogress nocallback nofree nosync nounwind willreturn memory(none)
+declare target("spirv.VulkanBuffer", i32, 12, 1) @llvm.spv.resource.counterhandlefromimplicitbinding.tspirv.VulkanBuffer_i32_12_1t.tspirv.VulkanBuffer_a0i32_12_1t(target("spirv.VulkanBuffer", [0 x i32], 12, 1), i32, i32) #4
+
+; Function Attrs: mustprogress nocallback nofree nosync nounwind willreturn memory(none)
+declare ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0i32_12_1t(target("spirv.VulkanBuffer", [0 x i32], 12, 1), i32) #4
+
+attributes #0 = { convergent mustprogress nocallback nofree nosync nounwind willreturn memory(none) }
+attributes #1 = { convergent noinline norecurse "frame-pointer"="all" "hlsl.numthreads"="8,1,1" "hlsl.shader"="compute" "no-infs-fp-math"="true" "no-nans-fp-math"="true" "no-signed-zeros-fp-math"="true" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
+attributes #2 = { mustprogress nofree nosync nounwind willreturn memory(none) }
+attributes #4 = { mustprogress nocallback nofree nosync nounwind willreturn memory(none) }
diff --git a/llvm/test/Transforms/IndVarSimplify/skip-predictaion-nested-convergence.ll b/llvm/test/Transforms/IndVarSimplify/skip-predictaion-nested-convergence.ll
new file mode 100644
index 0000000000000..22f25b1428556
--- /dev/null
+++ b/llvm/test/Transforms/IndVarSimplify/skip-predictaion-nested-convergence.ll
@@ -0,0 +1,139 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt < %s -passes=indvars -indvars-predicate-loops=1 -S | FileCheck %s
+
+; Nested loops with body using loop convergence token should be skipped by IndVarSimplify.
+
+%"class.hlsl::RWStructuredBuffer" = type { target("spirv.VulkanBuffer", [0 x i32], 12, 1), target("spirv.VulkanBuffer", i32, 12, 1) }
+
+@_ZL3Out = internal global %"class.hlsl::RWStructuredBuffer" poison, align 8
+@.str = private unnamed_addr constant [4 x i8] c"Out\00", align 1
+
+; Function Attrs: convergent mustprogress nocallback nofree nosync nounwind willreturn memory(none)
+declare token @llvm.experimental.convergence.entry() #0
+
+define void @nested() local_unnamed_addr #1 {
+; CHECK-LABEL: @nested(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = tail call token @llvm.experimental.convergence.entry()
+; CHECK-NEXT:    [[TMP1:%.*]] = tail call target("spirv.VulkanBuffer", [0 x i32], 12, 1) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0i32_12_1t(i32 0, i32 0, i32 1, i32 0, ptr nonnull @.str)
+; CHECK-NEXT:    [[TMP2:%.*]] = tail call target("spirv.VulkanBuffer", i32, 12, 1) @llvm.spv.resource.counterhandlefromimplicitbinding.tspirv.VulkanBuffer_i32_12_1t.tspirv.VulkanBuffer_a0i32_12_1t(target("spirv.VulkanBuffer", [0 x i32], 12, 1) [[TMP1]], i32 0, i32 0)
+; CHECK-NEXT:    store target("spirv.VulkanBuffer", [0 x i32], 12, 1) [[TMP1]], ptr @_ZL3Out, align 8
+; CHECK-NEXT:    store target("spirv.VulkanBuffer", i32, 12, 1) [[TMP2]], ptr getelementptr inbounds nuw (i8, ptr @_ZL3Out, i64 8), align 8
+; CHECK-NEXT:    [[TMP3:%.*]] = tail call i32 @llvm.spv.thread.id.in.group.i32(i32 0)
+; CHECK-NEXT:    [[TMP4:%.*]] = tail call i32 @llvm.spv.thread.id.in.group.i32(i32 1)
+; CHECK-NEXT:    [[MUL_I:%.*]] = shl nuw nsw i32 [[TMP3]], 3
+; CHECK-NEXT:    [[ADD_I:%.*]] = add nuw nsw i32 [[MUL_I]], [[TMP4]]
+; CHECK-NEXT:    br label [[FOR_COND_I:%.*]]
+; CHECK:       for.cond.i:
+; CHECK-NEXT:    [[I_0_I:%.*]] = phi i32 [ 0, [[ENTRY:%.*]] ], [ [[INC10_I:%.*]], [[CLEANUP_I:%.*]] ]
+; CHECK-NEXT:    [[TMP5:%.*]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[TMP0]]) ]
+; CHECK-NEXT:    [[CMP_I:%.*]] = icmp ult i32 [[I_0_I]], 8
+; CHECK-NEXT:    br i1 [[CMP_I]], label [[FOR_COND1_I_PREHEADER:%.*]], label [[_Z4NESTEDDV3_J_EXIT:%.*]]
+; CHECK:       for.cond1.i.preheader:
+; CHECK-NEXT:    [[CMP5_I:%.*]] = icmp eq i32 [[I_0_I]], [[TMP3]]
+; CHECK-NEXT:    br label [[FOR_COND1_I:%.*]]
+; CHECK:       for.cond1.i:
+; CHECK-NEXT:    [[J_0_I:%.*]] = phi i32 [ [[INC_I:%.*]], [[FOR_BODY4_I:%.*]] ], [ 0, [[FOR_COND1_I_PREHEADER]] ]
+; CHECK-NEXT:    [[TMP6:%.*]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[TMP5]]) ]
+; CHECK-NEXT:    [[CMP2_I:%.*]] = icmp ult i32 [[J_0_I]], 8
+; CHECK-NEXT:    br i1 [[CMP2_I]], label [[FOR_BODY4_I]], label [[CLEANUP_I_LOOPEXIT:%.*]]
+; CHECK:       for.body4.i:
+; CHECK-NEXT:    [[CMP6_I:%.*]] = icmp eq i32 [[J_0_I]], [[TMP4]]
+; CHECK-NEXT:    [[OR_COND:%.*]] = select i1 [[CMP5_I]], i1 [[CMP6_I]], i1 false
+; CHECK-NEXT:    [[INC_I]] = add nuw nsw i32 [[J_0_I]], 1
+; CHECK-NEXT:    br i1 [[OR_COND]], label [[IF_THEN_I:%.*]], label [[FOR_COND1_I]]
+; CHECK:       cleanup.i.loopexit:
+; CHECK-NEXT:    br label [[CLEANUP_I]]
+; CHECK:       if.then.i:
+; CHECK-NEXT:    [[HLSL_WAVE_ACTIVE_MAX7_I:%.*]] = call spir_func i32 @llvm.spv.wave.reduce.umax.i32(i32 [[ADD_I]]) [ "convergencectrl"(token [[TMP6]]) ]
+; CHECK-NEXT:    [[TMP7:%.*]] = tail call noundef align 4 dereferenceable(4) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0i32_12_1t(target("spirv.VulkanBuffer", [0 x i32], 12, 1) [[TMP1]], i32 [[ADD_I]])
+; CHECK-NEXT:    store i32 [[HLSL_WAVE_ACTIVE_MAX7_I]], ptr addrspace(11) [[TMP7]], align 4
+; CHECK-NEXT:    br label [[CLEANUP_I]]
+; CHECK:       cleanup.i:
+; CHECK-NEXT:    [[INC10_I]] = add nuw nsw i32 [[I_0_I]], 1
+; CHECK-NEXT:    br label [[FOR_COND_I]]
+; CHECK:       _Z4nestedDv3_j.exit:
+; CHECK-NEXT:    ret void
+;
+entry:
+  %0 = tail call token @llvm.experimental.convergence.entry()
+  %1 = tail call target("spirv.VulkanBuffer", [0 x i32], 12, 1) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0i32_12_1t(i32 0, i32 0, i32 1, i32 0, ptr nonnull @.str)
+  %2 = tail call target("spirv.VulkanBuffer", i32, 12, 1) @llvm.spv.resource.counterhandlefromimplicitbinding.tspirv.VulkanBuffer_i32_12_1t.tspirv.VulkanBuffer_a0i32_12_1t(target("spirv.VulkanBuffer", [0 x i32], 12, 1) %1, i32 0, i32 0)
+  store target("spirv.VulkanBuffer", [0 x i32], 12, 1) %1, ptr @_ZL3Out, align 8
+  store target("spirv.VulkanBuffer", i32, 12, 1) %2, ptr getelementptr inbounds nuw (i8, ptr @_ZL3Out, i64 8), align 8
+  %3 = tail call i32 @llvm.spv.thread.id.in.group.i32(i32 0)
+  %4 = tail call i32 @llvm.spv.thread.id.in.group.i32(i32 1)
+  %mul.i = shl nuw nsw i32 %3, 3
+  %add.i = add nuw nsw i32 %mul.i, %4
+  br label %for.cond.i
+
+for.cond.i:                                       ; preds = %cleanup.i, %entry
+  %i.0.i = phi i32 [ 0, %entry ], [ %inc10.i, %cleanup.i ]
+  %5 = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token %0) ]
+  %cmp.i = icmp ult i32 %i.0.i, 8
+  br i1 %cmp.i, label %for.cond1.i.preheader, label %_Z4nestedDv3_j.exit
+
+; Preheader:
+for.cond1.i.preheader:                            ; preds = %for.cond.i
+  %cmp5.i = icmp eq i32 %i.0.i, %3
+  br label %for.cond1.i
+
+; Loop:
+for.cond1.i:                                      ; preds = %for.body4.i, %for.cond1.i.preheader
+  %j.0.i = phi i32 [ %inc.i, %for.body4.i ], [ 0, %for.cond1.i.preheader ]
+  %6 = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token %5) ]
+  %cmp2.i = icmp ult i32 %j.0.i, 8
+  br i1 %cmp2.i, label %for.body4.i, label %cleanup.i.loopexit
+
+for.body4.i:                                      ; preds = %for.cond1.i
+  %cmp6.i = icmp eq i32 %j.0.i, %4
+  %or.cond = select i1 %cmp5.i, i1 %cmp6.i, i1 false
+  %inc.i = add nuw nsw i32 %j.0.i, 1
+  br i1 %or.cond, label %if.then.i, label %for.cond1.i
+
+; Exit blocks
+cleanup.i.loopexit:                               ; preds = %for.cond1.i
+  br label %cleanup.i
+
+if.then.i:                                        ; preds = %for.body4.i
+  %hlsl.wave.active.max7.i = call spir_func i32 @llvm.spv.wave.reduce.umax.i32(i32 %add.i) [ "convergencectrl"(token %6) ]
+  %7 = tail call noundef align 4 dereferenceable(4) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0i32_12_1t(target("spirv.VulkanBuffer", [0 x i32], 12, 1) %1, i32 %add.i)
+  store i32 %hlsl.wave.active.max7.i, ptr addrspace(11) %7, align 4
+  br label %cleanup.i
+
+cleanup.i:                                        ; preds = %cleanup.i.loopexit, %if.then.i
+  %inc10.i = add nuw nsw i32 %i.0.i, 1
+  br label %for.cond.i
+
+_Z4nestedDv3_j.exit:                                ; preds = %for.cond.i
+  ret void
+}
+
+; Function Attrs: mustprogress nofree nosync nounwind willreturn memory(none)
+declare i32 @llvm.spv.thread.id.in.group.i32(i32) #2
+
+; Function Attrs: convergent mustprogress nocallback nofree nosync nounwind willreturn memory(none)
+declare token @llvm.experimental.convergence.loop() #0
+
+; Function Attrs: convergent mustprogress nocallback nofree nosync nounwind willreturn memory(none)
+declare i32 @llvm.spv.wave.reduce.umax.i32(i32) #0
+
+; Function Attrs: mustprogress nocallback nofree nosync nounwind willreturn memory(none)
+declare target("spirv.VulkanBuffer", [0 x i32], 12, 1) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0i32_12_1t(i32, i32, i32, i32, ptr) #4
+
+; Function Attrs: mustprogress nocallback nofree nosync nounwind willreturn memory(none)
+declare target("spirv.VulkanBuffer", i32, 12, 1) @llvm.spv.resource.counterhandlefromimplicitbinding.tspirv.VulkanBuffer_i32_12_1t.tspirv.VulkanBuffer_a0i32_12_1t(target("spirv.VulkanBuffer", [0 x i32], 12, 1), i32, i32) #4
+
+; Function Attrs: mustprogress nocallback nofree nosync nounwind willreturn memory(none)
+declare ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0i32_12_1t(target("spirv.VulkanBuffer", [0 x i32], 12, 1), i32) #4
+
+; Function Attrs: nocallback nofree nosync nounwind willreturn memory(inaccessiblemem: readwrite)
+declare void @llvm.experimental.noalias.scope.decl(metadata) #5
+
+attributes #0 = { convergent mustprogress nocallback nofree nosync nounwind willreturn memory(none) }
+attributes #1 = { convergent noinline norecurse "frame-pointer"="all" "hlsl.numthreads"="8,8,1" "hlsl.shader"="compute" "no-infs-fp-math"="true" "no-nans-fp-math"="true" "no-signed-zeros-fp-math"="true" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
+attributes #2 = { mustprogress nofree nosync nounwind willreturn memory(none) }
+attributes #3 = { mustprogress nocallback nofree nosync nounwind willreturn memory(argmem: readwrite) }
+attributes #4 = { mustprogress nocallback nofree nosync nounwind willreturn memory(none) }
+attributes #5 = { nocallback nofree nosync nounwind willreturn memory(inaccessiblemem: readwrite) }
+attributes #6 = { nounwind }

@luciechoi luciechoi requested review from Keenuts and s-perron October 30, 2025 03:47
}
}

for (BasicBlock *Succ : successors(BB)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is very inefficient. It would make more sense to look for any convergence control tokens defined in the loop and inspect their users. Instead of inspecting the entire function reachable from the loop exit.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you both for pointing out. As discussed offline with Steven, I've modified to use analyzeBasicBlock helper instead, like in LoopUnrollPass.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think analyzeBasicBlock is appropriate here -- CodeMetrics is a utility for transforms that perform block cloning.

Looking a bit closer, I'm also not sure the problem here is really related to convergence control per se. This transform is generally not valid for cases where a value defined in the loop is used outside the loop. Normally, this is enforced by checking for the absence of LCSSA phis. However, token-returning instructions are a special case, because LCSSA cannot be formed for them.

So I think what you should be doing is looking for any token-returning instructions in the loop (not necessarily convergence control in particular) and bailing out if they have uses outside the loop.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think analyzeBasicBlock is appropriate here -- CodeMetrics is a utility for transforms that perform block cloning.

For the convergence tokens, the check if there is an extended loop does exactly what you suggested. If the loop has a convergence token and it is used outside the loop, then the convergence will be ExtendedLoop. However, it can be overkill because it checks for more than it necessary for this pass.

Also it is not appropriate if we want to generalized this to all tokens, which seems to be the right direction.

Copy link
Contributor Author

@luciechoi luciechoi Nov 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, removed analyzeBasicBlock and updated the code to check for any tokens used outside the loop.

@nikic nikic requested a review from ssahasra October 30, 2025 09:01
@luciechoi luciechoi requested a review from nikic October 30, 2025 22:19
}
}

for (BasicBlock *Succ : successors(BB)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think analyzeBasicBlock is appropriate here -- CodeMetrics is a utility for transforms that perform block cloning.

For the convergence tokens, the check if there is an extended loop does exactly what you suggested. If the loop has a convergence token and it is used outside the loop, then the convergence will be ExtendedLoop. However, it can be overkill because it checks for more than it necessary for this pass.

Also it is not appropriate if we want to generalized this to all tokens, which seems to be the right direction.

@github-actions
Copy link

github-actions bot commented Nov 1, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[HLSL][SPIR-V] Invalid LLVM IR "convergencectrl"(token poison) for wave intrinsics WaveActiveMax

4 participants