Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[clang][SPIR-V] Always add convergence intrinsics #88918

Merged
merged 6 commits into from
May 14, 2024

Conversation

Keenuts
Copy link
Contributor

@Keenuts Keenuts commented Apr 16, 2024

PR #80680 added bits in the codegen to lazily add convergence intrinsics when required. This logic relied on the LoopStack. The issue is when parsing the condition, the loopstack doesn't yet reflect the correct values, as expected since we are not yet in the loop.

However, convergence tokens should sometimes already be available. The solution which seemed the simplest is to greedily generate the tokens when we generate SPIR-V.

Fixes #88144

@llvmbot llvmbot added clang Clang issues not falling into any other category clang:codegen HLSL HLSL Language Support labels Apr 16, 2024
@Keenuts Keenuts requested a review from llvm-beanz April 16, 2024 15:39
@llvmbot
Copy link
Collaborator

llvmbot commented Apr 16, 2024

@llvm/pr-subscribers-clang-codegen
@llvm/pr-subscribers-hlsl

@llvm/pr-subscribers-clang

Author: Nathan Gauër (Keenuts)

Changes

PR #80680 added bits in the codegen to lazily add convergence intrinsics when required. This logic relied on the LoopStack. The issue is when parsing the condition, the loopstack doesn't yet reflect the correct values, as expected since we are not yet in the loop.

However, convergence tokens should sometimes already be available. The solution which seemed the simplest is to greedily generate the tokens when we generate SPIR-V.

Fixes #88144


Patch is 26.72 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/88918.diff

9 Files Affected:

  • (modified) clang/lib/CodeGen/CGBuiltin.cpp (+1-87)
  • (modified) clang/lib/CodeGen/CGCall.cpp (+3)
  • (modified) clang/lib/CodeGen/CGStmt.cpp (+94)
  • (modified) clang/lib/CodeGen/CodeGenFunction.cpp (+9)
  • (modified) clang/lib/CodeGen/CodeGenFunction.h (+8-1)
  • (modified) clang/test/CodeGenHLSL/builtins/RWBuffer-constructor.hlsl (-1)
  • (added) clang/test/CodeGenHLSL/convergence/do.while.hlsl (+90)
  • (added) clang/test/CodeGenHLSL/convergence/for.hlsl (+121)
  • (added) clang/test/CodeGenHLSL/convergence/while.hlsl (+119)
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index df7502b8def531..f5d40a1555fcb5 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -1133,91 +1133,8 @@ struct BitTest {
   static BitTest decodeBitTestBuiltin(unsigned BuiltinID);
 };
 
-// Returns the first convergence entry/loop/anchor instruction found in |BB|.
-// std::nullptr otherwise.
-llvm::IntrinsicInst *getConvergenceToken(llvm::BasicBlock *BB) {
-  for (auto &I : *BB) {
-    auto *II = dyn_cast<llvm::IntrinsicInst>(&I);
-    if (II && isConvergenceControlIntrinsic(II->getIntrinsicID()))
-      return II;
-  }
-  return nullptr;
-}
-
 } // namespace
 
-llvm::CallBase *
-CodeGenFunction::addConvergenceControlToken(llvm::CallBase *Input,
-                                            llvm::Value *ParentToken) {
-  llvm::Value *bundleArgs[] = {ParentToken};
-  llvm::OperandBundleDef OB("convergencectrl", bundleArgs);
-  auto Output = llvm::CallBase::addOperandBundle(
-      Input, llvm::LLVMContext::OB_convergencectrl, OB, Input);
-  Input->replaceAllUsesWith(Output);
-  Input->eraseFromParent();
-  return Output;
-}
-
-llvm::IntrinsicInst *
-CodeGenFunction::emitConvergenceLoopToken(llvm::BasicBlock *BB,
-                                          llvm::Value *ParentToken) {
-  CGBuilderTy::InsertPoint IP = Builder.saveIP();
-  Builder.SetInsertPoint(&BB->front());
-  auto CB = Builder.CreateIntrinsic(
-      llvm::Intrinsic::experimental_convergence_loop, {}, {});
-  Builder.restoreIP(IP);
-
-  auto I = addConvergenceControlToken(CB, ParentToken);
-  return cast<llvm::IntrinsicInst>(I);
-}
-
-llvm::IntrinsicInst *
-CodeGenFunction::getOrEmitConvergenceEntryToken(llvm::Function *F) {
-  auto *BB = &F->getEntryBlock();
-  auto *token = getConvergenceToken(BB);
-  if (token)
-    return token;
-
-  // Adding a convergence token requires the function to be marked as
-  // convergent.
-  F->setConvergent();
-
-  CGBuilderTy::InsertPoint IP = Builder.saveIP();
-  Builder.SetInsertPoint(&BB->front());
-  auto I = Builder.CreateIntrinsic(
-      llvm::Intrinsic::experimental_convergence_entry, {}, {});
-  assert(isa<llvm::IntrinsicInst>(I));
-  Builder.restoreIP(IP);
-
-  return cast<llvm::IntrinsicInst>(I);
-}
-
-llvm::IntrinsicInst *
-CodeGenFunction::getOrEmitConvergenceLoopToken(const LoopInfo *LI) {
-  assert(LI != nullptr);
-
-  auto *token = getConvergenceToken(LI->getHeader());
-  if (token)
-    return token;
-
-  llvm::IntrinsicInst *PII =
-      LI->getParent()
-          ? emitConvergenceLoopToken(
-                LI->getHeader(), getOrEmitConvergenceLoopToken(LI->getParent()))
-          : getOrEmitConvergenceEntryToken(LI->getHeader()->getParent());
-
-  return emitConvergenceLoopToken(LI->getHeader(), PII);
-}
-
-llvm::CallBase *
-CodeGenFunction::addControlledConvergenceToken(llvm::CallBase *Input) {
-  llvm::Value *ParentToken =
-      LoopStack.hasInfo()
-          ? getOrEmitConvergenceLoopToken(&LoopStack.getInfo())
-          : getOrEmitConvergenceEntryToken(Input->getFunction());
-  return addConvergenceControlToken(Input, ParentToken);
-}
-
 BitTest BitTest::decodeBitTestBuiltin(unsigned BuiltinID) {
   switch (BuiltinID) {
     // Main portable variants.
@@ -18306,12 +18223,9 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
         ArrayRef<Value *>{Op0}, nullptr, "dx.rsqrt");
   }
   case Builtin::BI__builtin_hlsl_wave_get_lane_index: {
-    auto *CI = EmitRuntimeCall(CGM.CreateRuntimeFunction(
+    return EmitRuntimeCall(CGM.CreateRuntimeFunction(
         llvm::FunctionType::get(IntTy, {}, false), "__hlsl_wave_get_lane_index",
         {}, false, true));
-    if (getTarget().getTriple().isSPIRVLogical())
-      CI = dyn_cast<CallInst>(addControlledConvergenceToken(CI));
-    return CI;
   }
   }
   return nullptr;
diff --git a/clang/lib/CodeGen/CGCall.cpp b/clang/lib/CodeGen/CGCall.cpp
index f12765b826935b..06d4bceacfd34b 100644
--- a/clang/lib/CodeGen/CGCall.cpp
+++ b/clang/lib/CodeGen/CGCall.cpp
@@ -4824,6 +4824,9 @@ llvm::CallInst *CodeGenFunction::EmitRuntimeCall(llvm::FunctionCallee callee,
   llvm::CallInst *call = Builder.CreateCall(
       callee, args, getBundlesForFunclet(callee.getCallee()), name);
   call->setCallingConv(getRuntimeCC());
+
+  if (getTarget().getTriple().isSPIRVLogical() && call->isConvergent())
+    return dyn_cast<llvm::CallInst>(addControlledConvergenceToken(call));
   return call;
 }
 
diff --git a/clang/lib/CodeGen/CGStmt.cpp b/clang/lib/CodeGen/CGStmt.cpp
index 576fe2f7a2d46f..f8287e100f4bd5 100644
--- a/clang/lib/CodeGen/CGStmt.cpp
+++ b/clang/lib/CodeGen/CGStmt.cpp
@@ -915,6 +915,10 @@ void CodeGenFunction::EmitWhileStmt(const WhileStmt &S,
   JumpDest LoopHeader = getJumpDestInCurrentScope("while.cond");
   EmitBlock(LoopHeader.getBlock());
 
+  if (getTarget().getTriple().isSPIRVLogical())
+    ConvergenceTokenStack.push_back(emitConvergenceLoopToken(
+        LoopHeader.getBlock(), ConvergenceTokenStack.back()));
+
   // Create an exit block for when the condition fails, which will
   // also become the break target.
   JumpDest LoopExit = getJumpDestInCurrentScope("while.end");
@@ -1017,6 +1021,9 @@ void CodeGenFunction::EmitWhileStmt(const WhileStmt &S,
   // block.
   if (llvm::EnableSingleByteCoverage)
     incrementProfileCounter(&S);
+
+  if (getTarget().getTriple().isSPIRVLogical())
+    ConvergenceTokenStack.pop_back();
 }
 
 void CodeGenFunction::EmitDoStmt(const DoStmt &S,
@@ -1036,6 +1043,11 @@ void CodeGenFunction::EmitDoStmt(const DoStmt &S,
     EmitBlockWithFallThrough(LoopBody, S.getBody());
   else
     EmitBlockWithFallThrough(LoopBody, &S);
+
+  if (getTarget().getTriple().isSPIRVLogical())
+    ConvergenceTokenStack.push_back(
+        emitConvergenceLoopToken(LoopBody, ConvergenceTokenStack.back()));
+
   {
     RunCleanupsScope BodyScope(*this);
     EmitStmt(S.getBody());
@@ -1090,6 +1102,9 @@ void CodeGenFunction::EmitDoStmt(const DoStmt &S,
   // block.
   if (llvm::EnableSingleByteCoverage)
     incrementProfileCounter(&S);
+
+  if (getTarget().getTriple().isSPIRVLogical())
+    ConvergenceTokenStack.pop_back();
 }
 
 void CodeGenFunction::EmitForStmt(const ForStmt &S,
@@ -1109,6 +1124,10 @@ void CodeGenFunction::EmitForStmt(const ForStmt &S,
   llvm::BasicBlock *CondBlock = CondDest.getBlock();
   EmitBlock(CondBlock);
 
+  if (getTarget().getTriple().isSPIRVLogical())
+    ConvergenceTokenStack.push_back(
+        emitConvergenceLoopToken(CondBlock, ConvergenceTokenStack.back()));
+
   Expr::EvalResult Result;
   bool CondIsConstInt =
       !S.getCond() || S.getCond()->EvaluateAsInt(Result, getContext());
@@ -1222,6 +1241,9 @@ void CodeGenFunction::EmitForStmt(const ForStmt &S,
   // block.
   if (llvm::EnableSingleByteCoverage)
     incrementProfileCounter(&S);
+
+  if (getTarget().getTriple().isSPIRVLogical())
+    ConvergenceTokenStack.pop_back();
 }
 
 void
@@ -1244,6 +1266,10 @@ CodeGenFunction::EmitCXXForRangeStmt(const CXXForRangeStmt &S,
   llvm::BasicBlock *CondBlock = createBasicBlock("for.cond");
   EmitBlock(CondBlock);
 
+  if (getTarget().getTriple().isSPIRVLogical())
+    ConvergenceTokenStack.push_back(
+        emitConvergenceLoopToken(CondBlock, ConvergenceTokenStack.back()));
+
   const SourceRange &R = S.getSourceRange();
   LoopStack.push(CondBlock, CGM.getContext(), CGM.getCodeGenOpts(), ForAttrs,
                  SourceLocToDebugLoc(R.getBegin()),
@@ -1312,6 +1338,9 @@ CodeGenFunction::EmitCXXForRangeStmt(const CXXForRangeStmt &S,
   // block.
   if (llvm::EnableSingleByteCoverage)
     incrementProfileCounter(&S);
+
+  if (getTarget().getTriple().isSPIRVLogical())
+    ConvergenceTokenStack.pop_back();
 }
 
 void CodeGenFunction::EmitReturnOfRValue(RValue RV, QualType Ty) {
@@ -3101,3 +3130,68 @@ CodeGenFunction::GenerateCapturedStmtFunction(const CapturedStmt &S) {
 
   return F;
 }
+
+namespace {
+// Returns the first convergence entry/loop/anchor instruction found in |BB|.
+// std::nullptr otherwise.
+llvm::IntrinsicInst *getConvergenceToken(llvm::BasicBlock *BB) {
+  for (auto &I : *BB) {
+    auto *II = dyn_cast<llvm::IntrinsicInst>(&I);
+    if (II && llvm::isConvergenceControlIntrinsic(II->getIntrinsicID()))
+      return II;
+  }
+  return nullptr;
+}
+
+} // namespace
+
+llvm::CallBase *
+CodeGenFunction::addConvergenceControlToken(llvm::CallBase *Input,
+                                            llvm::Value *ParentToken) {
+  llvm::Value *bundleArgs[] = {ParentToken};
+  llvm::OperandBundleDef OB("convergencectrl", bundleArgs);
+  auto Output = llvm::CallBase::addOperandBundle(
+      Input, llvm::LLVMContext::OB_convergencectrl, OB, Input);
+  Input->replaceAllUsesWith(Output);
+  Input->eraseFromParent();
+  return Output;
+}
+
+llvm::IntrinsicInst *
+CodeGenFunction::emitConvergenceLoopToken(llvm::BasicBlock *BB,
+                                          llvm::Value *ParentToken) {
+  CGBuilderTy::InsertPoint IP = Builder.saveIP();
+
+  if (BB->empty())
+    Builder.SetInsertPoint(BB);
+  else
+    Builder.SetInsertPoint(&BB->front());
+
+  auto CB = Builder.CreateIntrinsic(
+      llvm::Intrinsic::experimental_convergence_loop, {}, {});
+  Builder.restoreIP(IP);
+
+  auto I = addConvergenceControlToken(CB, ParentToken);
+  return cast<llvm::IntrinsicInst>(I);
+}
+
+llvm::IntrinsicInst *
+CodeGenFunction::getOrEmitConvergenceEntryToken(llvm::Function *F) {
+  auto *BB = &F->getEntryBlock();
+  auto *token = getConvergenceToken(BB);
+  if (token)
+    return token;
+
+  // Adding a convergence token requires the function to be marked as
+  // convergent.
+  F->setConvergent();
+
+  CGBuilderTy::InsertPoint IP = Builder.saveIP();
+  Builder.SetInsertPoint(&BB->front());
+  auto I = Builder.CreateIntrinsic(
+      llvm::Intrinsic::experimental_convergence_entry, {}, {});
+  assert(isa<llvm::IntrinsicInst>(I));
+  Builder.restoreIP(IP);
+
+  return cast<llvm::IntrinsicInst>(I);
+}
diff --git a/clang/lib/CodeGen/CodeGenFunction.cpp b/clang/lib/CodeGen/CodeGenFunction.cpp
index 6474d6c8c1d1e4..8f3327bf12a4b3 100644
--- a/clang/lib/CodeGen/CodeGenFunction.cpp
+++ b/clang/lib/CodeGen/CodeGenFunction.cpp
@@ -347,6 +347,12 @@ void CodeGenFunction::FinishFunction(SourceLocation EndLoc) {
   assert(BreakContinueStack.empty() &&
          "mismatched push/pop in break/continue stack!");
 
+  if (getTarget().getTriple().isSPIRVLogical()) {
+    ConvergenceTokenStack.pop_back();
+    assert(ConvergenceTokenStack.empty() &&
+           "mismatched push/pop in convergence stack!");
+  }
+
   bool OnlySimpleReturnStmts = NumSimpleReturnExprs > 0
     && NumSimpleReturnExprs == NumReturnExprs
     && ReturnBlock.getBlock()->use_empty();
@@ -1271,6 +1277,9 @@ void CodeGenFunction::StartFunction(GlobalDecl GD, QualType RetTy,
   if (CurFuncDecl)
     if (const auto *VecWidth = CurFuncDecl->getAttr<MinVectorWidthAttr>())
       LargestVectorWidth = VecWidth->getVectorWidth();
+
+  if (getTarget().getTriple().isSPIRVLogical())
+    ConvergenceTokenStack.push_back(getOrEmitConvergenceEntryToken(CurFn));
 }
 
 void CodeGenFunction::EmitFunctionBody(const Stmt *Body) {
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index e2a7e28c8211ea..12c5e71bf6af60 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -314,6 +314,9 @@ class CodeGenFunction : public CodeGenTypeCache {
   /// Stack to track the Logical Operator recursion nest for MC/DC.
   SmallVector<const BinaryOperator *, 16> MCDCLogOpStack;
 
+  /// Stack to track the controlled convergence tokens.
+  SmallVector<llvm::IntrinsicInst *, 4> ConvergenceTokenStack;
+
   /// Number of nested loop to be consumed by the last surrounding
   /// loop-associated directive.
   int ExpectedOMPLoopDepth = 0;
@@ -4987,7 +4990,11 @@ class CodeGenFunction : public CodeGenTypeCache {
                                      const llvm::Twine &Name = "");
   // Adds a convergence_ctrl token to |Input| and emits the required parent
   // convergence instructions.
-  llvm::CallBase *addControlledConvergenceToken(llvm::CallBase *Input);
+  template <typename CallType>
+  CallType *addControlledConvergenceToken(CallType *Input) {
+    return dyn_cast<CallType>(
+        addConvergenceControlToken(Input, ConvergenceTokenStack.back()));
+  }
 
 private:
   // Emits a convergence_loop instruction for the given |BB|, with |ParentToken|
diff --git a/clang/test/CodeGenHLSL/builtins/RWBuffer-constructor.hlsl b/clang/test/CodeGenHLSL/builtins/RWBuffer-constructor.hlsl
index 74b3f59bf7600f..e51eac7f57c2d3 100644
--- a/clang/test/CodeGenHLSL/builtins/RWBuffer-constructor.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/RWBuffer-constructor.hlsl
@@ -1,4 +1,3 @@
-// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl -emit-llvm -disable-llvm-passes -o - %s | FileCheck %s
 // RUN: %clang_cc1 -triple spirv-vulkan-library -x hlsl -emit-llvm -disable-llvm-passes -o - %s | FileCheck %s --check-prefix=CHECK-SPIRV
 
 RWBuffer<float> Buf;
diff --git a/clang/test/CodeGenHLSL/convergence/do.while.hlsl b/clang/test/CodeGenHLSL/convergence/do.while.hlsl
new file mode 100644
index 00000000000000..ea5a45ba8fd780
--- /dev/null
+++ b/clang/test/CodeGenHLSL/convergence/do.while.hlsl
@@ -0,0 +1,90 @@
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
+// RUN:   spirv-pc-vulkan-library %s -emit-llvm -disable-llvm-passes -o - | FileCheck %s
+
+bool cond();
+void foo();
+
+void test1() {
+  do {
+  } while (cond());
+}
+// CHECK: define spir_func void @_Z5test1v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK: do.body:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK: do.cond:
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3:#[0-9]+]] [ "convergencectrl"(token [[T1]]) ]
+
+void test2() {
+  do {
+    foo();
+  } while (cond());
+}
+// CHECK: define spir_func void @_Z5test2v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK: do.body:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK:                    call spir_func void @_Z3foov() [[A3]] [ "convergencectrl"(token [[T1]]) ]
+// CHECK: do.cond:
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3:#[0-9]+]] [ "convergencectrl"(token [[T1]]) ]
+
+void test3() {
+  do {
+    if (cond())
+      foo();
+  } while (cond());
+}
+// CHECK: define spir_func void @_Z5test3v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK: do.body:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK: if.then:
+// CHECK:                    call spir_func void @_Z3foov() [[A3]] [ "convergencectrl"(token [[T1]]) ]
+// CHECK: do.cond:
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3:#[0-9]+]] [ "convergencectrl"(token [[T1]]) ]
+
+void test4() {
+  do {
+    if (cond()) {
+      foo();
+      break;
+    }
+  } while (cond());
+}
+// CHECK: define spir_func void @_Z5test4v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK: do.body:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK: if.then:
+// CHECK:                    call spir_func void @_Z3foov() [[A3]] [ "convergencectrl"(token [[T1]]) ]
+// CHECK: do.cond:
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3:#[0-9]+]] [ "convergencectrl"(token [[T1]]) ]
+
+void test5() {
+  do {
+    while (cond()) {
+      if (cond()) {
+        foo();
+        break;
+      }
+    }
+  } while (cond());
+}
+// CHECK: define spir_func void @_Z5test5v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK: do.body:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK: while.cond:
+// CHECK:   [[T2:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T1]]) ]
+// CHECK: if.then:
+// CHECK:                    call spir_func void @_Z3foov() [[A3]] [ "convergencectrl"(token [[T2]]) ]
+// CHECK: do.cond:
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3:#[0-9]+]] [ "convergencectrl"(token [[T1]]) ]
+
+// CHECK-DAG: attributes [[A0]] = { {{.*}}convergent{{.*}} }
+// CHECK-DAG: attributes [[A3]] = { {{.*}}convergent{{.*}} }
diff --git a/clang/test/CodeGenHLSL/convergence/for.hlsl b/clang/test/CodeGenHLSL/convergence/for.hlsl
new file mode 100644
index 00000000000000..180fae74ba7514
--- /dev/null
+++ b/clang/test/CodeGenHLSL/convergence/for.hlsl
@@ -0,0 +1,121 @@
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
+// RUN:   spirv-pc-vulkan-library %s -emit-llvm -disable-llvm-passes -o - | FileCheck %s
+
+bool cond();
+bool cond2();
+void foo();
+
+void test1() {
+  for (;;) {
+    foo();
+  }
+}
+// CHECK: define spir_func void @_Z5test1v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK: for.cond:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK:                    call spir_func void @_Z3foov() [[A3:#[0-9]+]] [ "convergencectrl"(token [[T1]]) ]
+
+void test2() {
+  for (;cond();) {
+    foo();
+  }
+}
+// CHECK: define spir_func void @_Z5test2v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK: for.cond:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3]] [ "convergencectrl"(token [[T1]]) ]
+// CHECK: for.body:
+// CHECK:                    call spir_func void @_Z3foov() [[A3:#[0-9]+]] [ "convergencectrl"(token [[T1]]) ]
+
+void test3() {
+  for (cond();;) {
+    foo();
+  }
+}
+// CHECK: define spir_func void @_Z5test3v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3]] [ "convergencectrl"(token [[T0]]) ]
+// CHECK: for.cond:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK:                    call spir_func void @_Z3foov() [[A3:#[0-9]+]] [ "convergencectrl"(token [[T1]]) ]
+
+void test4() {
+  for (cond();cond2();) {
+    foo();
+  }
+}
+// CHECK: define spir_func void @_Z5test4v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3]] [ "convergencectrl"(token [[T0]]) ]
+// CHECK: for.cond:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK:                    call spir_func noundef i1 @_Z5cond2v() [[A3]] [ "convergencectrl"(token [[T1]]) ]
+// CHECK: for.body:
+// CHECK:                    call spir_func void @_Z3foov() [[A3:#[0-9]+]] [ "convergencectrl"(token [[T1]]) ]
+
+void test5() {
+  for (cond();cond2();foo()) {
+  }
+}
+// CHECK: define spir_func void @_Z5test5v() [[A0:#[0-9]+]] {
+// CHECK: entry:
+// CHECK:   [[T0:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK:                    call spir_func noundef i1 @_Z4condv() [[A3]] [ "convergencectrl"(token [[T0]]) ]
+// CHECK: for.cond:
+// CHECK:   [[T1:%[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[T0]]) ]
+// CHECK:                    call spir_func noundef i1 @_Z5cond2v() [[A3]] [ "convergencectrl"(token [[T1]]) ]
+// CHECK: for.inc:
+// CHECK:                    call spir_func void @_Z3foov() [[A3:#[0-9]+]] [ "convergencectrl"(token [[T1]]) ]
+
+void test6() {
+  for (cond();cond2();foo()) {
+    if (cond()) {
+      foo();
+      break;
+    }
+  }
+}
+// CHECK: defin...
[truncated]

@Keenuts Keenuts requested a review from arsenm April 16, 2024 15:40
clang/lib/CodeGen/CGCall.cpp Outdated Show resolved Hide resolved
clang/lib/CodeGen/CodeGenFunction.h Outdated Show resolved Hide resolved
clang/lib/CodeGen/CGCall.cpp Outdated Show resolved Hide resolved
clang/lib/CodeGen/CGStmt.cpp Outdated Show resolved Hide resolved
clang/lib/CodeGen/CGStmt.cpp Outdated Show resolved Hide resolved
clang/lib/CodeGen/CGStmt.cpp Outdated Show resolved Hide resolved
@Keenuts Keenuts force-pushed the frontend-convergence-greedy branch from c9aa6a6 to 131ea26 Compare May 7, 2024 13:35
@Keenuts
Copy link
Contributor Author

Keenuts commented May 7, 2024

Hi all, rebased on main, and addressed the comments.
This commits changes the register order on SPIR-V vs DXIL, which required me to fix the mad+lerp intrinsic tests. Should be NFC, just storing the register name in a CHECK variable.

@arsenm arsenm changed the title [clang][SPIR-V] Always add convervence intrinsics [clang][SPIR-V] Always add convergence intrinsics May 7, 2024
clang/lib/CodeGen/CodeGenModule.h Show resolved Hide resolved
@Keenuts
Copy link
Contributor Author

Keenuts commented May 13, 2024

Thanks for the reviews. Waiting for 1 approval from MS and I'll merge

PR llvm#80680 added bits in the codegen to lazily add convergence intrinsics
when required. This logic relied on the LoopStack. The issue is
when parsing the condition, the loopstack doesn't yet reflect the
correct values, as expected since we are not yet in the loop.

However, convergence tokens should sometimes already be available.
The solution which seemed the simplest is to greedily generate the
tokens when we generate SPIR-V.

Fixes llvm#88144

Signed-off-by: Nathan Gauër <brioche@google.com>
Signed-off-by: Nathan Gauër <brioche@google.com>
Signed-off-by: Nathan Gauër <brioche@google.com>
@Keenuts Keenuts force-pushed the frontend-convergence-greedy branch from 7bf0431 to 166afdb Compare May 14, 2024 09:32
@Keenuts
Copy link
Contributor Author

Keenuts commented May 14, 2024

rebased on main, local tests are passing, waiting on CI to merge.

@Keenuts Keenuts merged commit e08f1fd into llvm:main May 14, 2024
3 of 4 checks passed
@Keenuts Keenuts deleted the frontend-convergence-greedy branch May 14, 2024 15:00
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
clang:codegen clang Clang issues not falling into any other category HLSL HLSL Language Support
Projects
Status: Done
Development

Successfully merging this pull request may close these issues.

[HLSL] Handle convergence intrinsic generation for do...while & while conditions
6 participants