[SPIR-V] Add pass to merge convergence region exit targets #92531

Merged Jun 3, 2024 (5 commits)

@Keenuts Keenuts commented May 17, 2024

The structurizer required regions to be SESE: single entry, single exit.
This new pass transforms multiple-exit regions into single-exit regions.

      +---+
      | A |
      +---+
      /   \
   +---+ +---+
   | B | | C |  A, B & C belong to the same convergence region.
   +---+ +---+
     |     |
   +---+ +---+
   | D | | E |  D & E belong to the parent convergence region.
   +---+ +---+  This means B & C are the exit blocks of the region,
      \   /     and D & E are the targets of those exits.
       \ /
        |
      +---+
      | F |
      +---+

This pass assigns one value per exit target:
D = 0
E = 1

Then it creates one variable per exit block (B, C), set to the value of the target that block exits to: in B, the variable has the value 0, and in C, the value 1.

Then it creates a new block H, with a PHI node to gather those two variables and a switch to route to the correct target.

Finally, the branches in B and C are updated to exit to this new block.

      +---+
      | A |
      +---+
      /   \
   +---+ +---+
   | B | | C |
   +---+ +---+
      \   /
      +---+
      | H |
      +---+
      /   \
   +---+ +---+
   | D | | E |
   +---+ +---+
      \   /
       \ /
        |
      +---+
      | F |
      +---+
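
The rewiring shown in the two diagrams can be sketched on a toy CFG, outside of LLVM. This is only an illustrative model, not the pass itself: `Cfg`, `MergeResult`, `mergeExitTargets` and `exampleCfg` are hypothetical names, blocks are plain strings, and each exit block is assumed to leave the region through a single edge.

```cpp
#include <cassert>
#include <map>
#include <set>
#include <string>
#include <vector>

// Toy CFG (not LLVM types): block name -> ordered list of successors.
using Cfg = std::map<std::string, std::vector<std::string>>;

struct MergeResult {
  std::map<std::string, int> targetToValue; // exit target -> constant
  std::map<std::string, int> exitToValue;   // exit block -> value fed to the PHI
};

// Redirects every edge leaving `region` to `newExit`, then lets `newExit`
// branch to the original targets (the "switch").
// Assumption: each exit block has at most one successor outside the region.
MergeResult mergeExitTargets(Cfg &cfg, const std::set<std::string> &region,
                             const std::string &newExit) {
  MergeResult r;

  // Gather the exit targets: successors of region blocks that lie outside.
  for (const auto &[block, succs] : cfg) {
    if (!region.count(block))
      continue;
    for (const auto &s : succs)
      if (!region.count(s))
        r.targetToValue.emplace(s, 0);
  }

  // Zero or one exit target: nothing to do.
  if (r.targetToValue.size() <= 1)
    return r;

  // One constant per target. std::map iterates in key order, so the
  // numbering is deterministic.
  int next = 0;
  for (auto &[target, value] : r.targetToValue)
    value = next++;

  // Record the value each exit block selects, and redirect its exiting edge.
  for (auto &[block, succs] : cfg) {
    if (!region.count(block))
      continue;
    for (auto &s : succs) {
      auto it = r.targetToValue.find(s);
      if (it == r.targetToValue.end())
        continue; // internal edge, left intact
      r.exitToValue[block] = it->second;
      s = newExit;
    }
  }

  // The new block routes to the original targets; position i in its
  // successor list corresponds to the constant i.
  std::vector<std::string> &routes = cfg[newExit];
  for (const auto &[target, value] : r.targetToValue)
    routes.push_back(target);
  return r;
}

// The CFG from the first diagram: A -> {B, C}, B -> D, C -> E, D/E -> F.
Cfg exampleCfg() {
  return {{"A", {"B", "C"}}, {"B", {"D"}}, {"C", {"E"}},
          {"D", {"F"}},      {"E", {"F"}}, {"F", {}}};
}
```

Calling `mergeExitTargets(cfg, {"A", "B", "C"}, "H")` on `exampleCfg()` leaves the edges out of A untouched, points B and C at H, and gives H the successors D and E, which matches the second diagram.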

Note: the variable is set depending on the condition used to branch. If B's terminator were conditional, the variable would be set using a SELECT.
All internal edges of a region are left intact; only exiting edges are updated.
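
The rule in the note can be modelled in a few lines of plain C++. This is a sketch under stated assumptions: `BranchModel` and `exitValue` are hypothetical names, and the condition is evaluated immediately here, whereas the pass would emit a select instruction that is evaluated at run time.

```cpp
#include <cassert>
#include <map>
#include <optional>
#include <string>

// Hypothetical stand-in for a branch terminator (not an LLVM type).
struct BranchModel {
  std::string ifTrue;                 // first successor
  std::optional<std::string> ifFalse; // empty for an unconditional branch
  bool condition = true;              // value of the branch condition
};

// Returns the value the exit variable takes for this terminator, or nothing
// if none of its successors is an exit target.
std::optional<int> exitValue(const BranchModel &br,
                             const std::map<std::string, int> &targetToValue) {
  auto lookup = [&](const std::optional<std::string> &t) -> std::optional<int> {
    if (!t)
      return std::nullopt;
    auto it = targetToValue.find(*t);
    if (it == targetToValue.end())
      return std::nullopt; // successor stays inside the region
    return it->second;
  };

  std::optional<int> lhs = lookup(br.ifTrue);
  std::optional<int> rhs = lookup(br.ifFalse);

  // Only one side (or neither) leaves the region: no select is needed.
  if (!lhs || !rhs)
    return lhs ? lhs : rhs;

  // Both sides leave the region: pick the value based on the condition.
  return br.condition ? *lhs : *rhs;
}
```

With targets D = 0 and E = 1, an unconditional branch to D yields 0, while a conditional branch to D or E yields 0 or 1 depending on the condition.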

llvmbot commented May 17, 2024

@llvm/pr-subscribers-backend-spir-v

Author: Nathan Gauër (Keenuts)

Changes

Patch is 29.86 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/92531.diff

10 Files Affected:

  • (modified) llvm/lib/Target/SPIRV/CMakeLists.txt (+1)
  • (modified) llvm/lib/Target/SPIRV/SPIRV.h (+1)
  • (modified) llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp (+23)
  • (modified) llvm/lib/Target/SPIRV/SPIRVInstrInfo.td (+1-1)
  • (added) llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp (+290)
  • (modified) llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp (+1)
  • (added) llvm/test/CodeGen/SPIRV/structurizer/merge-exit-break.ll (+84)
  • (added) llvm/test/CodeGen/SPIRV/structurizer/merge-exit-convergence-in-break.ll (+94)
  • (added) llvm/test/CodeGen/SPIRV/structurizer/merge-exit-multiple-break.ll (+103)
  • (added) llvm/test/CodeGen/SPIRV/structurizer/merge-exit-simple-white-identity.ll (+49)
diff --git a/llvm/lib/Target/SPIRV/CMakeLists.txt b/llvm/lib/Target/SPIRV/CMakeLists.txt
index 7001ac382f41c..35a463a89ec64 100644
--- a/llvm/lib/Target/SPIRV/CMakeLists.txt
+++ b/llvm/lib/Target/SPIRV/CMakeLists.txt
@@ -24,6 +24,7 @@ add_llvm_target(SPIRVCodeGen
   SPIRVInstrInfo.cpp
   SPIRVInstructionSelector.cpp
   SPIRVStripConvergentIntrinsics.cpp
+  SPIRVMergeRegionExitTargets.cpp
   SPIRVISelLowering.cpp
   SPIRVLegalizerInfo.cpp
   SPIRVMCInstLower.cpp
diff --git a/llvm/lib/Target/SPIRV/SPIRV.h b/llvm/lib/Target/SPIRV/SPIRV.h
index fb8580cd47c01..e597a1dc8dc06 100644
--- a/llvm/lib/Target/SPIRV/SPIRV.h
+++ b/llvm/lib/Target/SPIRV/SPIRV.h
@@ -20,6 +20,7 @@ class InstructionSelector;
 class RegisterBankInfo;
 
 ModulePass *createSPIRVPrepareFunctionsPass(const SPIRVTargetMachine &TM);
+FunctionPass *createSPIRVMergeRegionExitTargetsPass();
 FunctionPass *createSPIRVStripConvergenceIntrinsicsPass();
 FunctionPass *createSPIRVRegularizerPass();
 FunctionPass *createSPIRVPreLegalizerPass();
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index c00066f5dca62..2f294fdb4075e 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -150,6 +150,16 @@ class SPIRVEmitIntrinsics
     ModulePass::getAnalysisUsage(AU);
   }
 };
+
+bool isConvergenceIntrinsic(const Instruction *I) {
+  const auto *II = dyn_cast<IntrinsicInst>(I);
+  if (!II)
+    return false;
+
+  return II->getIntrinsicID() == Intrinsic::experimental_convergence_entry ||
+         II->getIntrinsicID() == Intrinsic::experimental_convergence_loop ||
+         II->getIntrinsicID() == Intrinsic::experimental_convergence_anchor;
+}
 } // namespace
 
 char SPIRVEmitIntrinsics::ID = 0;
@@ -1067,6 +1077,10 @@ void SPIRVEmitIntrinsics::processGlobalValue(GlobalVariable &GV,
 
 void SPIRVEmitIntrinsics::insertAssignPtrTypeIntrs(Instruction *I,
                                                    IRBuilder<> &B) {
+  // Don't assign types to LLVM tokens.
+  if (isConvergenceIntrinsic(I))
+    return;
+
   reportFatalOnTokenType(I);
   if (!isPointerTy(I->getType()) || !requireAssignType(I) ||
       isa<BitCastInst>(I))
@@ -1085,6 +1099,10 @@ void SPIRVEmitIntrinsics::insertAssignPtrTypeIntrs(Instruction *I,
 
 void SPIRVEmitIntrinsics::insertAssignTypeIntrs(Instruction *I,
                                                 IRBuilder<> &B) {
+  // Don't assign types to LLVM tokens.
+  if (isConvergenceIntrinsic(I))
+    return;
+
   reportFatalOnTokenType(I);
   Type *Ty = I->getType();
   if (!Ty->isVoidTy() && !isPointerTy(Ty) && requireAssignType(I)) {
@@ -1312,6 +1330,11 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
     I = visit(*I);
     if (!I)
       continue;
+
+    // Don't emit intrinsics for convergence operations.
+    if (isConvergenceIntrinsic(I))
+      continue;
+
     processInstrAfterVisit(I, B);
   }
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
index 151d0ec1fe569..6634481daf12e 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
@@ -615,7 +615,7 @@ def OpFwidthCoarse: UnOp<"OpFwidthCoarse", 215>;
 def OpPhi: Op<245, (outs ID:$res), (ins TYPE:$type, ID:$var0, ID:$block0, variable_ops),
                   "$res = OpPhi $type $var0 $block0">;
 def OpLoopMerge: Op<246, (outs), (ins ID:$merge, ID:$continue, LoopControl:$lc, variable_ops),
-                  "OpLoopMerge $merge $merge $continue $lc">;
+                  "OpLoopMerge $merge $continue $lc">;
 def OpSelectionMerge: Op<247, (outs), (ins ID:$merge, SelectionControl:$sc),
                   "OpSelectionMerge $merge $sc">;
 def OpLabel: Op<248, (outs ID:$label), (ins), "$label = OpLabel">;
diff --git a/llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp b/llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp
new file mode 100644
index 0000000000000..13781e24f0d42
--- /dev/null
+++ b/llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp
@@ -0,0 +1,290 @@
+//===-- SPIRVMergeRegionExitTargets.cpp ----------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Merge the multiple exit targets of a convergence region into a single block.
+// Each exit target will be assigned a constant value, and a phi node + switch
+// will allow the new exit target to re-route to the correct basic block.
+//
+//===----------------------------------------------------------------------===//
+
+#include "Analysis/SPIRVConvergenceRegionAnalysis.h"
+#include "SPIRV.h"
+#include "SPIRVSubtarget.h"
+#include "SPIRVTargetMachine.h"
+#include "SPIRVUtils.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/CodeGen/IntrinsicLowering.h"
+#include "llvm/IR/CFG.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/IntrinsicsSPIRV.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/Transforms/Utils/Cloning.h"
+#include "llvm/Transforms/Utils/LoopSimplify.h"
+#include "llvm/Transforms/Utils/LowerMemIntrinsics.h"
+
+using namespace llvm;
+
+namespace llvm {
+void initializeSPIRVMergeRegionExitTargetsPass(PassRegistry &);
+} // namespace llvm
+
+namespace llvm {
+
+class SPIRVMergeRegionExitTargets : public FunctionPass {
+public:
+  static char ID;
+
+  SPIRVMergeRegionExitTargets() : FunctionPass(ID) {
+    initializeSPIRVMergeRegionExitTargetsPass(*PassRegistry::getPassRegistry());
+  };
+
+  // Gather all the successors of |BB|.
+  // This function asserts if the terminator is not a branch, switch or return.
+  std::unordered_set<BasicBlock *> gatherSuccessors(BasicBlock *BB) {
+    std::unordered_set<BasicBlock *> output;
+    auto *T = BB->getTerminator();
+
+    if (auto *BI = dyn_cast<BranchInst>(T)) {
+      output.insert(BI->getSuccessor(0));
+      if (BI->isConditional())
+        output.insert(BI->getSuccessor(1));
+      return output;
+    }
+
+    if (auto *SI = dyn_cast<SwitchInst>(T)) {
+      output.insert(SI->getDefaultDest());
+      for (auto &Case : SI->cases()) {
+        output.insert(Case.getCaseSuccessor());
+      }
+      return output;
+    }
+
+    if (auto *RI = dyn_cast<ReturnInst>(T))
+      return output;
+
+    assert(false && "Unhandled terminator type.");
+    return output;
+  }
+
+  /// Create a value in BB set to the value associated with the branch the block
+  /// terminator will take.
+  llvm::Value *createExitVariable(
+      BasicBlock *BB,
+      const std::unordered_map<BasicBlock *, ConstantInt *> &TargetToValue) {
+    auto *T = BB->getTerminator();
+    if (auto *RI = dyn_cast<ReturnInst>(T)) {
+      return nullptr;
+    }
+
+    IRBuilder<> Builder(BB);
+    Builder.SetInsertPoint(T);
+
+    if (auto *BI = dyn_cast<BranchInst>(T)) {
+
+      BasicBlock *LHSTarget = BI->getSuccessor(0);
+      BasicBlock *RHSTarget =
+          BI->isConditional() ? BI->getSuccessor(1) : nullptr;
+
+      Value *LHS = TargetToValue.count(LHSTarget) != 0
+                       ? TargetToValue.at(LHSTarget)
+                       : nullptr;
+      Value *RHS = TargetToValue.count(RHSTarget) != 0
+                       ? TargetToValue.at(RHSTarget)
+                       : nullptr;
+
+      if (LHS == nullptr || RHS == nullptr)
+        return LHS == nullptr ? RHS : LHS;
+      return Builder.CreateSelect(BI->getCondition(), LHS, RHS);
+    }
+
+    // TODO: add support for switch cases.
+    assert(false && "Unhandled terminator type.");
+  }
+
+  /// Replaces |BB|'s branch targets present in |ToReplace| with |NewTarget|.
+  void replaceBranchTargets(BasicBlock *BB,
+                            const std::unordered_set<BasicBlock *> ToReplace,
+                            BasicBlock *NewTarget) {
+    auto *T = BB->getTerminator();
+    if (auto *RI = dyn_cast<ReturnInst>(T))
+      return;
+
+    if (auto *BI = dyn_cast<BranchInst>(T)) {
+      for (size_t i = 0; i < BI->getNumSuccessors(); i++) {
+        if (ToReplace.count(BI->getSuccessor(i)) != 0)
+          BI->setSuccessor(i, NewTarget);
+      }
+      return;
+    }
+
+    if (auto *SI = dyn_cast<SwitchInst>(T)) {
+      for (size_t i = 0; i < SI->getNumSuccessors(); i++) {
+        if (ToReplace.count(SI->getSuccessor(i)) != 0)
+          SI->setSuccessor(i, NewTarget);
+      }
+      return;
+    }
+
+    assert(false && "Unhandled terminator type.");
+  }
+
+  // Run the pass on the given convergence region, ignoring the sub-regions.
+  // Returns true if the CFG changed, false otherwise.
+  bool runOnConvergenceRegionNoRecurse(LoopInfo &LI,
+                                       const SPIRV::ConvergenceRegion *CR) {
+    // Gather all the exit targets for this region.
+    std::unordered_set<BasicBlock *> ExitTargets;
+    for (BasicBlock *Exit : CR->Exits) {
+      for (BasicBlock *Target : gatherSuccessors(Exit)) {
+        if (CR->Blocks.count(Target) == 0)
+          ExitTargets.insert(Target);
+      }
+    }
+
+    // If we have zero or one exit target, nothing to do.
+    if (ExitTargets.size() <= 1)
+      return false;
+
+    // Create the new single exit target.
+    auto F = CR->Entry->getParent();
+    auto NewExitTarget = BasicBlock::Create(F->getContext(), "new.exit", F);
+    IRBuilder<> Builder(NewExitTarget);
+
+    // CodeGen output needs to be stable. Using the set as-is would order
+    // the targets differently depending on the allocation pattern.
+    // Sorting per basic-block ordering in the function.
+    std::vector<BasicBlock *> SortedExitTargets;
+    std::vector<BasicBlock *> SortedExits;
+    for (BasicBlock &BB : *F) {
+      if (ExitTargets.count(&BB) != 0)
+        SortedExitTargets.push_back(&BB);
+      if (CR->Exits.count(&BB) != 0)
+        SortedExits.push_back(&BB);
+    }
+
+    // Creating one constant per distinct exit target. This will be used to
+    // route to the correct target.
+    std::unordered_map<BasicBlock *, ConstantInt *> TargetToValue;
+    for (BasicBlock *Target : SortedExitTargets)
+      TargetToValue.emplace(Target, Builder.getInt32(TargetToValue.size()));
+
+    // Creating one variable per exit node, set to the constant matching the
+    // targeted external block.
+    std::vector<std::pair<BasicBlock *, Value *>> ExitToVariable;
+    for (auto Exit : SortedExits) {
+      llvm::Value *Value = createExitVariable(Exit, TargetToValue);
+      ExitToVariable.emplace_back(std::make_pair(Exit, Value));
+    }
+
+    // Gather the correct value depending on the exit we came from.
+    llvm::PHINode *node =
+        Builder.CreatePHI(Builder.getInt32Ty(), ExitToVariable.size());
+    for (auto [BB, Value] : ExitToVariable) {
+      node->addIncoming(Value, BB);
+    }
+
+    // Creating the switch to jump to the correct exit target.
+    std::vector<std::pair<BasicBlock *, ConstantInt *>> CasesList(
+        TargetToValue.begin(), TargetToValue.end());
+    llvm::SwitchInst *Sw =
+        Builder.CreateSwitch(node, CasesList[0].first, CasesList.size() - 1);
+    for (size_t i = 1; i < CasesList.size(); i++)
+      Sw->addCase(CasesList[i].second, CasesList[i].first);
+
+    // Fix exit branches to redirect to the new exit.
+    for (auto Exit : CR->Exits)
+      replaceBranchTargets(Exit, ExitTargets, NewExitTarget);
+
+    return true;
+  }
+
+  /// Run the pass on the given convergence region and sub-regions (DFS).
+  /// Returns true if a region/sub-region was modified, false otherwise.
+  /// This returns as soon as one region/sub-region has been modified.
+  bool runOnConvergenceRegion(LoopInfo &LI,
+                              const SPIRV::ConvergenceRegion *CR) {
+    for (auto *Child : CR->Children)
+      if (runOnConvergenceRegion(LI, Child))
+        return true;
+
+    return runOnConvergenceRegionNoRecurse(LI, CR);
+  }
+
+#if !NDEBUG
+  /// Validates each edge exiting the region has the same destination basic
+  /// block.
+  void validateRegionExits(const SPIRV::ConvergenceRegion *CR) {
+    for (auto *Child : CR->Children)
+      validateRegionExits(Child);
+
+    std::unordered_set<BasicBlock *> ExitTargets;
+    for (auto *Exit : CR->Exits) {
+      auto Set = gatherSuccessors(Exit);
+      for (auto *BB : Set) {
+        if (CR->Blocks.count(BB) == 0)
+          ExitTargets.insert(BB);
+      }
+    }
+
+    assert(ExitTargets.size() <= 1);
+  }
+#endif
+
+  virtual bool runOnFunction(Function &F) override {
+    LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
+    const auto *TopLevelRegion =
+        getAnalysis<SPIRVConvergenceRegionAnalysisWrapperPass>()
+            .getRegionInfo()
+            .getTopLevelRegion();
+
+    // FIXME: very inefficient method: each time a region is modified, we bubble
+    // back up, and recompute the whole convergence region tree. Once the
+    // algorithm is completed and test coverage good enough, rewrite this pass
+    // to be efficient instead of simple.
+    bool modified = false;
+    while (runOnConvergenceRegion(LI, TopLevelRegion)) {
+      TopLevelRegion = getAnalysis<SPIRVConvergenceRegionAnalysisWrapperPass>()
+                           .getRegionInfo()
+                           .getTopLevelRegion();
+      modified = true;
+    }
+
+    F.dump();
+#if !NDEBUG
+    validateRegionExits(TopLevelRegion);
+#endif
+    return modified;
+  }
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.addRequired<DominatorTreeWrapperPass>();
+    AU.addRequired<LoopInfoWrapperPass>();
+    AU.addRequired<SPIRVConvergenceRegionAnalysisWrapperPass>();
+    FunctionPass::getAnalysisUsage(AU);
+  }
+};
+} // namespace llvm
+
+char SPIRVMergeRegionExitTargets::ID = 0;
+
+INITIALIZE_PASS_BEGIN(SPIRVMergeRegionExitTargets, "split-region-exit-blocks",
+                      "SPIRV split region exit blocks", false, false)
+INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
+INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(SPIRVConvergenceRegionAnalysisWrapperPass)
+
+INITIALIZE_PASS_END(SPIRVMergeRegionExitTargets, "split-region-exit-blocks",
+                    "SPIRV split region exit blocks", false, false)
+
+FunctionPass *llvm::createSPIRVMergeRegionExitTargetsPass() {
+  return new SPIRVMergeRegionExitTargets();
+}
diff --git a/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp b/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp
index ae8baa3f11913..d0e51caf46e73 100644
--- a/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp
@@ -164,6 +164,7 @@ void SPIRVPassConfig::addIRPasses() {
     //  - all loop exits are dominated by the loop pre-header.
     //  - loops have a single back-edge.
     addPass(createLoopSimplifyPass());
+    addPass(createSPIRVMergeRegionExitTargetsPass());
   }
 
   TargetPassConfig::addIRPasses();
diff --git a/llvm/test/CodeGen/SPIRV/structurizer/merge-exit-break.ll b/llvm/test/CodeGen/SPIRV/structurizer/merge-exit-break.ll
new file mode 100644
index 0000000000000..b3fcdc978625f
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/structurizer/merge-exit-break.ll
@@ -0,0 +1,84 @@
+; RUN: llc -mtriple=spirv-unknown-vulkan-compute -O0 %s -o - | FileCheck %s --match-full-lines
+
+target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-G1"
+target triple = "spirv-unknown-vulkan-compute"
+
+define internal spir_func void @main() #0 {
+
+; CHECK:                      OpDecorate %[[#builtin:]] BuiltIn SubgroupLocalInvocationId
+; CHECK-DAG:  %[[#int_ty:]] = OpTypeInt 32 0
+; CHECK-DAG: %[[#pint_ty:]] = OpTypePointer Function %[[#int_ty]]
+; CHECK-DAG: %[[#bool_ty:]] = OpTypeBool
+; CHECK-DAG:   %[[#int_0:]] = OpConstant %[[#int_ty]] 0
+; CHECK-DAG:   %[[#int_1:]] = OpConstant %[[#int_ty]] 1
+; CHECK-DAG:  %[[#int_10:]] = OpConstant %[[#int_ty]] 10
+
+; CHECK:   %[[#entry:]] = OpLabel
+; CHECK:     %[[#idx:]] = OpVariable %[[#pint_ty]] Function
+; CHECK:                  OpStore %[[#idx]] %[[#int_0]] Aligned 4
+; CHECK:                  OpBranch %[[#while_cond:]]
+entry:
+  %0 = call token @llvm.experimental.convergence.entry()
+  %idx = alloca i32, align 4
+  store i32 0, ptr %idx, align 4
+  br label %while.cond
+
+; CHECK:   %[[#while_cond]] = OpLabel
+; CHECK:         %[[#tmp:]] = OpLoad %[[#int_ty]] %[[#idx]] Aligned 4
+; CHECK:         %[[#cmp:]] = OpINotEqual %[[#bool_ty]] %[[#tmp]] %[[#int_10]]
+; CHECK:                      OpBranchConditional %[[#cmp]] %[[#while_body:]] %[[#new_end:]]
+while.cond:
+  %1 = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token %0) ]
+  %2 = load i32, ptr %idx, align 4
+  %cmp = icmp ne i32 %2, 10
+  br i1 %cmp, label %while.body, label %while.end
+
+; CHECK:   %[[#while_body]] = OpLabel
+; CHECK-NEXT:    %[[#tmp:]] = OpLoad %[[#int_ty]] %[[#builtin]] Aligned 1
+; CHECK-NEXT:                 OpStore %[[#idx]] %[[#tmp]] Aligned 4
+; CHECK-NEXT:    %[[#tmp:]] = OpLoad %[[#int_ty]] %[[#idx]] Aligned 4
+; CHECK-NEXT:   %[[#cmp1:]] = OpIEqual %[[#bool_ty]] %[[#tmp]] %[[#int_0]]
+; CHECK:                      OpBranchConditional %[[#cmp1]] %[[#new_end]] %[[#if_end:]]
+while.body:
+  %3 = call i32 @__hlsl_wave_get_lane_index() [ "convergencectrl"(token %1) ]
+  store i32 %3, ptr %idx, align 4
+  %4 = load i32, ptr %idx, align 4
+  %cmp1 = icmp eq i32 %4, 0
+  br i1 %cmp1, label %if.then, label %if.end
+
+; CHECK:   %[[#if_then:]] = OpLabel
+; CHECK:                    OpBranch %[[#while_end:]]
+if.then:
+  br label %while.end
+
+; CHECK:   %[[#if_end]] = OpLabel
+; CHECK:                  OpBranch %[[#while_cond]]
+if.end:
+  br label %while.cond
+
+; CHECK:   %[[#while_end_loopexit:]] = OpLabel
+; CHECK:                               OpBranch %[[#while_end]]
+
+; CHECK:   %[[#while_end]] = OpLabel
+; CHECK:                     OpReturn
+while.end:
+  ret void
+
+; CHECK:   %[[#new_end]] = OpLabel
+; CHECK:    %[[#route:]] = OpPhi %[[#int_ty]] %[[#int_1]] %[[#while_cond]] %[[#int_0]] %[[#while_body]]
+; CHECK:                   OpSwitch %[[#route]] %[[#while_end_loopexit]] 0 %[[#if_then]]
+}
+
+declare token @llvm.experimental.convergence.entry() #2
+declare token @llvm.experimental.convergence.loop() #2
+declare i32 @__hlsl_wave_get_lane_index() #3
+
+attributes #0 = { convergent noinline norecurse nounwind optnone "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
+attributes #1 = { convergent norecurse "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
+attributes #2 = { convergent nocallback nofree nosync nounwind willreturn memory(none) }
+attributes #3 = { convergent }
+
+!llvm.module.flags = !{!0, !1}
+
+!0 = !{i32 1, !"wchar_size", i32 4}
+!1 = !{i32 4, !"dx.disable_optimizations", i32 1}
diff --git a/llvm/test/CodeGen/SPIRV/structurizer/merge-exit-convergence-in-break.ll b/llvm/test/CodeGen/SPIRV/structurizer/merge-exit-convergence-in-break.ll
new file mode 100644
index 0000000000000..a67c58fdd5749
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/structurizer/merge-exit-convergence-in-break.ll
@@ -0,0 +1,94 @@
+; RUN: llc -mtriple=spirv-unknown-vulkan-compute -O0 %s -o - | FileCheck %s --match-full-lines
+
+target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-G1"
+target triple = "spirv-unknown-vulkan-compute"
+
+define internal spir_func void @main() #0 {
+
+; CHECK:                      OpDecorate %[[#builtin:]] BuiltIn SubgroupLocalInvocationId
+; CHECK-DAG:  %[[#int_ty:]] = OpTypeInt 32 0
+; CHECK-DAG: %[[#pint_ty:]] = OpTypePointer Function %[[#int_ty]]
+; CHECK-DAG: %[[#bool_ty:]] = OpTypeBool
+; CHECK-DAG:   %[[#int_0:]] = OpConstant %[[#int_ty]] 0
+; CHECK-DAG:   %[[#int_1:]] = OpConstant %[[#int_ty]] 1
+; CHECK-DAG:  %[[#int_10:]] = OpConstant %[[#int_ty]] 10
+
+; CHECK:   %[[#entry:]] = OpLabel
+; CHECK:     %[[#idx:]] = OpVariable %[[#pint_ty]] Function
+; CHECK:                  OpStore %[[#idx]] %[[#int_0]] Aligned 4
+; CHECK:                  OpBranch %[[#while_cond:]]
+entry:
+  %0 = call token @llvm.experimental.convergence.entry()
+  %idx = alloca i32, align 4
+  store i32 0, ptr %idx, align 4
...
[truncated]

Signed-off-by: Nathan Gauër <brioche@google.com>

Keenuts commented May 21, 2024

Thanks, feedback applied!


@michalpaszkowski michalpaszkowski left a comment

@Keenuts Thank you for the patch and the detailed explanation! LGTM!
A high-level note for either this or a future patch: it may be worth updating the SPIRVUsage document to describe how target intrinsics are affected in the Vulkan environment. We can discuss whether there is a sensible way to split this document to cover the differences between SPIR-V "flavors".


Keenuts commented Jun 3, 2024

Oh right, yes, it could be valuable to add this and, more generally, to explain how the LLVM IR is structurized for graphics SPIR-V.

Keenuts merged commit a5641f1 into llvm:main on Jun 3, 2024
8 checks passed
Keenuts deleted the add-merge-exit-target branch on June 3, 2024, 09:35
MaskRay added a commit that referenced this pull request on Jun 3, 2024

MaskRay commented Jun 3, 2024

`auto *RI = dyn_cast` when `RI` is unused causes `-Wunused-but-set-variable` warnings with a relatively new Clang (>= 2021-04). I fixed them in a088c61.

You probably want to compile and test with a modern Clang to catch such issues.
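
The pattern behind this warning can be reproduced without LLVM. The sketch below is an analogy in plain C++, using `dynamic_cast` in place of `dyn_cast`; `Terminator`, `ReturnTerm`, `BranchTerm` and `isReturn` are hypothetical names. In LLVM code, the usual way to test a type without binding a pointer is `isa<ReturnInst>(T)`. Whether the referenced follow-up commit used exactly that form is not stated here.

```cpp
#include <cassert>

// Minimal polymorphic hierarchy standing in for terminator instructions.
struct Terminator {
  virtual ~Terminator() = default;
};
struct ReturnTerm : Terminator {};
struct BranchTerm : Terminator {};

// Pattern reported in the comment above: the cast result is bound to a
// name that is never read afterwards.
//
//   if (auto *RI = dynamic_cast<const ReturnTerm *>(T))
//     return true;
//
// Alternative that binds no name: test the cast result directly.
bool isReturn(const Terminator *T) {
  return dynamic_cast<const ReturnTerm *>(T) != nullptr;
}
```

Binding a name only when the casted pointer is actually used keeps the code free of this warning regardless of which compiler version builds it.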


Keenuts commented Jun 3, 2024

Oh, thanks for fixing!
Weirdly, my system's clang is 16.0.6, so it seems to be from 2023.
I added `-DCMAKE_CXX_FLAGS=-Wunused-but-set-variable`, but this warning still does not show up on main (before your PR, of course). Which options do you compile with?

Mine are:

cmake -Hllvm -GNinja -Bbuild -DCMAKE_BUILD_TYPE=Debug -DLLVM_ENABLE_PROJECTS='clang' -DLLVM_TARGETS_TO_BUILD='X86' -DLLVM_EXPERIMENTAL_TARGETS_TO_BUILD="DirectX;SPIRV" -DLLVM_OPTIMIZED_TABLEGEN=1 -DLLVM_ENABLE_LLD=1 -DLLVM_USE_SPLIT_DWARF=1 -DCMAKE_CXX_COMPILER=clang++-16 -DCMAKE_C_COMPILER=clang-16 -DCMAKE_EXPORT_COMPILE_COMMANDS=1 -DCMAKE_INSTALL_PREFIX=build/install -DLLVM_INCLUDE_SPIRV_TOOLS_TESTS=1


@ssahasra ssahasra left a comment

There is a lot of overlap between GPU backends like SPIRV and AMDGPU, where we are tackling similar problems with the CFG. It will be good to reuse and improve code wherever possible. After a cursory look, I think FixIrreducible.cpp and UnifyLoopExits.cpp can both potentially benefit from the "exit variable" created here. Those passes currently track a whole bunch of boolean variables, one for each exit block.

@@ -150,6 +150,16 @@ class SPIRVEmitIntrinsics
     ModulePass::getAnalysisUsage(AU);
   }
 };
+
+bool isConvergenceIntrinsic(const Instruction *I) {

This wasn't necessary, because now we can simply use `isa<ConvergenceControlInst>(I)`.
