Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPIR-V] Re-implement switch and improve validation of forward calls #87823

Merged
merged 2 commits into from
Apr 9, 2024

Conversation

VyacheslavLevytskyy
Copy link
Contributor

@VyacheslavLevytskyy VyacheslavLevytskyy commented Apr 5, 2024

This PR fixes issue #87763 and preserves valid CFG in cases when previous scheme failed to generate valid code for a switch statement. The PR hardens one existing test case and adds one more test case as a validation of a new switch generation. Tests are passing spirv-val now.

This PR also improves validation of forward calls.

@llvmbot
Copy link
Collaborator

llvmbot commented Apr 5, 2024

@llvm/pr-subscribers-backend-spir-v

Author: Vyacheslav Levytskyy (VyacheslavLevytskyy)

Changes

This PR fixes issue #87763 and preserves valid CFG in cases when previous scheme failed to generate valid code for a switch statement.

This PR also improves validation of forward calls.


Patch is 22.26 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/87823.diff

8 Files Affected:

  • (modified) llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp (+4-1)
  • (modified) llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp (+28-7)
  • (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h (+6-1)
  • (modified) llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp (+13-4)
  • (modified) llvm/lib/Target/SPIRV/SPIRVISelLowering.h (+4)
  • (modified) llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp (+56-167)
  • (modified) llvm/test/CodeGen/SPIRV/branching/OpSwitchUnreachable.ll (+3-2)
  • (added) llvm/test/CodeGen/SPIRV/branching/switch-range-check.ll (+118)
diff --git a/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp b/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp
index 4eee8062f28248..ffaa7ada9a8060 100644
--- a/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp
@@ -43,6 +43,8 @@ using namespace llvm;
 
 namespace {
 class SPIRVAsmPrinter : public AsmPrinter {
+  unsigned NLabels = 0;
+
 public:
   explicit SPIRVAsmPrinter(TargetMachine &TM,
                            std::unique_ptr<MCStreamer> Streamer)
@@ -112,7 +114,7 @@ void SPIRVAsmPrinter::emitEndOfAsmFile(Module &M) {
   // TODO: calculate Bound more carefully from maximum used register number,
   // accounting for generated OpLabels and other related instructions if
   // needed.
-  unsigned Bound = 2 * (ST->getBound() + 1);
+  unsigned Bound = 2 * (ST->getBound() + 1) + NLabels;
   bool FlagToRestore = OutStreamer->getUseAssemblerInfoForParsing();
   OutStreamer->setUseAssemblerInfoForParsing(true);
   if (MCAssembler *Asm = OutStreamer->getAssemblerPtr())
@@ -158,6 +160,7 @@ void SPIRVAsmPrinter::emitOpLabel(const MachineBasicBlock &MBB) {
   LabelInst.setOpcode(SPIRV::OpLabel);
   LabelInst.addOperand(MCOperand::createReg(MAI->getOrCreateMBBRegister(MBB)));
   outputMCInst(LabelInst);
+  ++NLabels;
 }
 
 void SPIRVAsmPrinter::emitBasicBlockStart(const MachineBasicBlock &MBB) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index b341fcb41d0312..e8ce5a35b457d5 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -460,15 +460,36 @@ void SPIRVEmitIntrinsics::preprocessCompositeConstants(IRBuilder<> &B) {
 }
 
 Instruction *SPIRVEmitIntrinsics::visitSwitchInst(SwitchInst &I) {
-  IRBuilder<> B(I.getParent());
+  BasicBlock *ParentBB = I.getParent();
+  IRBuilder<> B(ParentBB);
+  B.SetInsertPoint(&I);
   SmallVector<Value *, 4> Args;
-  for (auto &Op : I.operands())
-    if (Op.get()->getType()->isSized())
+  SmallVector<BasicBlock *> BBCases;
+  for (auto &Op : I.operands()) {
+    if (Op.get()->getType()->isSized()) {
       Args.push_back(Op);
-  B.SetInsertPoint(&I);
-  B.CreateIntrinsic(Intrinsic::spv_switch, {I.getOperand(0)->getType()},
-                    {Args});
-  return &I;
+    } else if (BasicBlock *BB = dyn_cast<BasicBlock>(Op.get())) {
+      BBCases.push_back(BB);
+      Args.push_back(BlockAddress::get(BB->getParent(), BB));
+    } else {
+      report_fatal_error("Unexpected switch operand");
+    }
+  }
+  CallInst *NewI = B.CreateIntrinsic(Intrinsic::spv_switch,
+                                     {I.getOperand(0)->getType()}, {Args});
+  // remove switch to avoid its unneeded and undesirable unwrap into branches
+  // and conditions
+  I.replaceAllUsesWith(NewI);
+  I.eraseFromParent();
+  // insert artificial and temporary instruction to preserve valid CFG,
+  // it will be removed after IR translation pass
+  B.SetInsertPoint(ParentBB);
+  IndirectBrInst *BrI = B.CreateIndirectBr(
+      Constant::getNullValue(PointerType::getUnqual(ParentBB->getContext())),
+      BBCases.size());
+  for (BasicBlock *BBCase : BBCases)
+    BrI->addDestination(BBCase);
+  return BrI;
 }
 
 Instruction *SPIRVEmitIntrinsics::visitGetElementPtrInst(GetElementPtrInst &I) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index ac799374adce8c..37f575e884ef48 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -284,7 +284,12 @@ class SPIRVGlobalRegistry {
   // Return the VReg holding the result of the given OpTypeXXX instruction.
   Register getSPIRVTypeID(const SPIRVType *SpirvType) const;
 
-  void setCurrentFunc(MachineFunction &MF) { CurMF = &MF; }
+  // Return previous value of the current machine function
+  MachineFunction *setCurrentFunc(MachineFunction &MF) {
+    MachineFunction *Ret = CurMF;
+    CurMF = &MF;
+    return Ret;
+  }
 
   // Whether the given VReg has an OpTypeXXX instruction mapped to it with the
   // given opcode (e.g. OpTypeFloat).
diff --git a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
index d450078d793fb7..8db54c74f23690 100644
--- a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
@@ -160,12 +160,15 @@ void validateFunCallMachineDef(const SPIRVSubtarget &STI,
             : nullptr;
     if (DefElemType) {
       const Type *DefElemTy = GR.getTypeForSPIRVType(DefElemType);
-      // Switch GR context to the call site instead of the (default) definition
-      // side
-      GR.setCurrentFunc(*FunCall.getParent()->getParent());
+      // validatePtrTypes() works in the context if the call site
+      // When we process historical records about forward calls
+      // we need to switch context to the (forward) call site and
+      // then restore it back to the current machine function.
+      MachineFunction *CurMF =
+          GR.setCurrentFunc(*FunCall.getParent()->getParent());
       validatePtrTypes(STI, CallMRI, GR, FunCall, OpIdx, DefElemType,
                        DefElemTy);
-      GR.setCurrentFunc(*FunDef->getParent()->getParent());
+      GR.setCurrentFunc(*CurMF);
     }
   }
 }
@@ -215,6 +218,11 @@ void validateAccessChain(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI,
 // TODO: the logic of inserting additional bitcast's is to be moved
 // to pre-IRTranslation passes eventually
 void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
+  // finalizeLowering() is called twice (see GlobalISel/InstructionSelect.cpp)
+  // We'd like to avoid the needless second processing pass.
+  if (ProcessedMF.find(&MF) != ProcessedMF.end())
+    return;
+
   MachineRegisterInfo *MRI = &MF.getRegInfo();
   SPIRVGlobalRegistry &GR = *STI.getSPIRVGlobalRegistry();
   GR.setCurrentFunc(MF);
@@ -302,5 +310,6 @@ void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
       }
     }
   }
+  ProcessedMF.insert(&MF);
   TargetLowering::finalizeLowering(MF);
 }
diff --git a/llvm/lib/Target/SPIRV/SPIRVISelLowering.h b/llvm/lib/Target/SPIRV/SPIRVISelLowering.h
index b01571bfc1eeb5..8c1de7d97d1a3c 100644
--- a/llvm/lib/Target/SPIRV/SPIRVISelLowering.h
+++ b/llvm/lib/Target/SPIRV/SPIRVISelLowering.h
@@ -16,6 +16,7 @@
 
 #include "SPIRVGlobalRegistry.h"
 #include "llvm/CodeGen/TargetLowering.h"
+#include <set>
 
 namespace llvm {
 class SPIRVSubtarget;
@@ -23,6 +24,9 @@ class SPIRVSubtarget;
 class SPIRVTargetLowering : public TargetLowering {
   const SPIRVSubtarget &STI;
 
+  // Record of already processed machine functions
+  mutable std::set<const MachineFunction *> ProcessedMF;
+
 public:
   explicit SPIRVTargetLowering(const TargetMachine &TM,
                                const SPIRVSubtarget &ST)
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
index b133f0ae85de20..7e155a36aadbc4 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -438,186 +438,75 @@ static void processInstrsWithTypeFolding(MachineFunction &MF,
   }
 }
 
+// Find basic blocks of the switch and replace registers in spv_switch() by its
+// MBB equivalent.
 static void processSwitches(MachineFunction &MF, SPIRVGlobalRegistry *GR,
                             MachineIRBuilder MIB) {
-  // Before IRTranslator pass, calls to spv_switch intrinsic are inserted before
-  // each switch instruction. IRTranslator lowers switches to G_ICMP + G_BRCOND
-  // + G_BR triples. A switch with two cases may be transformed to this MIR
-  // sequence:
-  //
-  //   intrinsic(@llvm.spv.switch), %CmpReg, %Const0, %Const1
-  //   %Dst0 = G_ICMP intpred(eq), %CmpReg, %Const0
-  //   G_BRCOND %Dst0, %bb.2
-  //   G_BR %bb.5
-  // bb.5.entry:
-  //   %Dst1 = G_ICMP intpred(eq), %CmpReg, %Const1
-  //   G_BRCOND %Dst1, %bb.3
-  //   G_BR %bb.4
-  // bb.2.sw.bb:
-  //   ...
-  // bb.3.sw.bb1:
-  //   ...
-  // bb.4.sw.epilog:
-  //   ...
-  //
-  // Sometimes (in case of range-compare switches), additional G_SUBs
-  // instructions are inserted before G_ICMPs. Those need to be additionally
-  // processed.
-  //
-  // This function modifies spv_switch call's operands to include destination
-  // MBBs (default and for each constant value).
-  //
-  // At the end, the function removes redundant [G_SUB] + G_ICMP + G_BRCOND +
-  // G_BR sequences.
-
-  MachineRegisterInfo &MRI = MF.getRegInfo();
-
-  // Collect spv_switches and G_ICMPs across all MBBs in MF.
-  std::vector<MachineInstr *> RelevantInsts;
-
-  // Collect redundant MIs from [G_SUB] + G_ICMP + G_BRCOND + G_BR sequences.
-  // After updating spv_switches, the instructions can be removed.
-  std::vector<MachineInstr *> PostUpdateArtifacts;
-
-  // Temporary set of compare registers. G_SUBs and G_ICMPs relating to
-  // spv_switch use these registers.
-  DenseSet<Register> CompareRegs;
+  DenseMap<const BasicBlock *, MachineBasicBlock *> BB2MBB;
+  SmallVector<std::pair<MachineInstr *, SmallVector<MachineInstr *, 8>>>
+      Switches;
   for (MachineBasicBlock &MBB : MF) {
+    MachineRegisterInfo &MRI = MF.getRegInfo();
+    BB2MBB[MBB.getBasicBlock()] = &MBB;
     for (MachineInstr &MI : MBB) {
+      if (!isSpvIntrinsic(MI, Intrinsic::spv_switch))
+        continue;
       // Calls to spv_switch intrinsics representing IR switches.
-      if (isSpvIntrinsic(MI, Intrinsic::spv_switch)) {
-        assert(MI.getOperand(1).isReg());
-        CompareRegs.insert(MI.getOperand(1).getReg());
-        RelevantInsts.push_back(&MI);
-      }
-
-      // G_SUBs coming from range-compare switch lowering. G_SUBs are found
-      // after spv_switch but before G_ICMP.
-      if (MI.getOpcode() == TargetOpcode::G_SUB && MI.getOperand(1).isReg() &&
-          CompareRegs.contains(MI.getOperand(1).getReg())) {
-        assert(MI.getOperand(0).isReg() && MI.getOperand(1).isReg());
-        Register Dst = MI.getOperand(0).getReg();
-        CompareRegs.insert(Dst);
-        PostUpdateArtifacts.push_back(&MI);
-      }
-
-      // G_ICMPs relating to switches.
-      if (MI.getOpcode() == TargetOpcode::G_ICMP && MI.getOperand(2).isReg() &&
-          CompareRegs.contains(MI.getOperand(2).getReg())) {
-        Register Dst = MI.getOperand(0).getReg();
-        RelevantInsts.push_back(&MI);
-        PostUpdateArtifacts.push_back(&MI);
-        MachineInstr *CBr = MRI.use_begin(Dst)->getParent();
-        assert(CBr->getOpcode() == SPIRV::G_BRCOND);
-        PostUpdateArtifacts.push_back(CBr);
-        MachineInstr *Br = CBr->getNextNode();
-        assert(Br->getOpcode() == SPIRV::G_BR);
-        PostUpdateArtifacts.push_back(Br);
+      SmallVector<MachineInstr *, 8> NewOps;
+      for (unsigned i = 2; i < MI.getNumOperands(); ++i) {
+        Register Reg = MI.getOperand(i).getReg();
+        if (i % 2 == 1) {
+          MachineInstr *ConstInstr = getDefInstrMaybeConstant(Reg, &MRI);
+          NewOps.push_back(ConstInstr);
+        } else {
+          MachineInstr *BuildMBB = MRI.getVRegDef(Reg);
+          assert(BuildMBB &&
+                 BuildMBB->getOpcode() == TargetOpcode::G_BLOCK_ADDR &&
+                 BuildMBB->getOperand(1).isBlockAddress() &&
+                 BuildMBB->getOperand(1).getBlockAddress());
+          NewOps.push_back(BuildMBB);
+        }
       }
+      Switches.push_back(std::make_pair(&MI, NewOps));
     }
   }
 
-  // Update each spv_switch with destination MBBs.
-  for (auto i = RelevantInsts.begin(); i != RelevantInsts.end(); i++) {
-    if (!isSpvIntrinsic(**i, Intrinsic::spv_switch))
-      continue;
-
-    // Currently considered spv_switch.
-    MachineInstr *Switch = *i;
-    // Set the first successor as default MBB to support empty switches.
-    MachineBasicBlock *DefaultMBB = *Switch->getParent()->succ_begin();
-    // Container for mapping values to MMBs.
-    SmallDenseMap<uint64_t, MachineBasicBlock *> ValuesToMBBs;
-
-    // Walk all G_ICMPs to collect ValuesToMBBs. Start at currently considered
-    // spv_switch (i) and break at any spv_switch with the same compare
-    // register (indicating we are back at the same scope).
-    Register CompareReg = Switch->getOperand(1).getReg();
-    for (auto j = i + 1; j != RelevantInsts.end(); j++) {
-      if (isSpvIntrinsic(**j, Intrinsic::spv_switch) &&
-          (*j)->getOperand(1).getReg() == CompareReg)
-        break;
-
-      if (!((*j)->getOpcode() == TargetOpcode::G_ICMP &&
-            (*j)->getOperand(2).getReg() == CompareReg))
-        continue;
-
-      MachineInstr *ICMP = *j;
-      Register Dst = ICMP->getOperand(0).getReg();
-      MachineOperand &PredOp = ICMP->getOperand(1);
-      const auto CC = static_cast<CmpInst::Predicate>(PredOp.getPredicate());
-      (void)CC;
-      assert((CC == CmpInst::ICMP_EQ || CC == CmpInst::ICMP_ULE) &&
-             MRI.hasOneUse(Dst) && MRI.hasOneDef(CompareReg));
-      uint64_t Value = getIConstVal(ICMP->getOperand(3).getReg(), &MRI);
-      MachineInstr *CBr = MRI.use_begin(Dst)->getParent();
-      assert(CBr->getOpcode() == SPIRV::G_BRCOND && CBr->getOperand(1).isMBB());
-      MachineBasicBlock *MBB = CBr->getOperand(1).getMBB();
-
-      // Map switch case Value to target MBB.
-      ValuesToMBBs[Value] = MBB;
-
-      // Add target MBB as successor to the switch's MBB.
-      Switch->getParent()->addSuccessor(MBB);
-
-      // The next MI is always G_BR to either the next case or the default.
-      MachineInstr *NextMI = CBr->getNextNode();
-      assert(NextMI->getOpcode() == SPIRV::G_BR &&
-             NextMI->getOperand(0).isMBB());
-      MachineBasicBlock *NextMBB = NextMI->getOperand(0).getMBB();
-      // Default MBB does not begin with G_ICMP using spv_switch compare
-      // register.
-      if (NextMBB->front().getOpcode() != SPIRV::G_ICMP ||
-          (NextMBB->front().getOperand(2).isReg() &&
-           NextMBB->front().getOperand(2).getReg() != CompareReg)) {
-        // Set default MBB and add it as successor to the switch's MBB.
-        DefaultMBB = NextMBB;
-        Switch->getParent()->addSuccessor(DefaultMBB);
+  SmallPtrSet<MachineInstr *, 8> ToEraseMI;
+  for (auto &SwIt : Switches) {
+    MachineInstr &MI = *SwIt.first;
+    SmallVector<MachineInstr *, 8> &Ins = SwIt.second;
+    SmallVector<MachineOperand, 8> NewOps;
+    for (unsigned i = 0; i < Ins.size(); ++i) {
+      if (Ins[i]->getOpcode() == TargetOpcode::G_BLOCK_ADDR) {
+        BasicBlock *CaseBB =
+            Ins[i]->getOperand(1).getBlockAddress()->getBasicBlock();
+        auto It = BB2MBB.find(CaseBB);
+        if (It == BB2MBB.end())
+          report_fatal_error("cannot find a machine basic block by a basic "
+                             "block in a switch statement");
+        NewOps.push_back(MachineOperand::CreateMBB(It->second));
+        MI.getParent()->addSuccessor(It->second);
+        ToEraseMI.insert(Ins[i]);
+      } else {
+        NewOps.push_back(
+            MachineOperand::CreateCImm(Ins[i]->getOperand(1).getCImm()));
       }
     }
-
-    // Modify considered spv_switch operands using collected Values and
-    // MBBs.
-    SmallVector<const ConstantInt *, 3> Values;
-    SmallVector<MachineBasicBlock *, 3> MBBs;
-    for (unsigned k = 2; k < Switch->getNumExplicitOperands(); k++) {
-      Register CReg = Switch->getOperand(k).getReg();
-      uint64_t Val = getIConstVal(CReg, &MRI);
-      MachineInstr *ConstInstr = getDefInstrMaybeConstant(CReg, &MRI);
-      if (!ValuesToMBBs[Val])
-        continue;
-
-      Values.push_back(ConstInstr->getOperand(1).getCImm());
-      MBBs.push_back(ValuesToMBBs[Val]);
-    }
-
-    for (unsigned k = Switch->getNumExplicitOperands() - 1; k > 1; k--)
-      Switch->removeOperand(k);
-
-    Switch->addOperand(MachineOperand::CreateMBB(DefaultMBB));
-    for (unsigned k = 0; k < Values.size(); k++) {
-      Switch->addOperand(MachineOperand::CreateCImm(Values[k]));
-      Switch->addOperand(MachineOperand::CreateMBB(MBBs[k]));
-    }
-  }
-
-  for (MachineInstr *MI : PostUpdateArtifacts) {
-    MachineBasicBlock *ParentMBB = MI->getParent();
-    MI->eraseFromParent();
-    // If G_ICMP + G_BRCOND + G_BR were the only MIs in MBB, erase this MBB. It
-    // can be safely assumed, there are no breaks or phis directing into this
-    // MBB. However, we need to remove this MBB from the CFG graph. MBBs must be
-    // erased top-down.
-    if (ParentMBB->empty()) {
-      while (!ParentMBB->pred_empty())
-        (*ParentMBB->pred_begin())->removeSuccessor(ParentMBB);
-
-      while (!ParentMBB->succ_empty())
-        ParentMBB->removeSuccessor(ParentMBB->succ_begin());
-
-      ParentMBB->eraseFromParent();
+    for (unsigned i = MI.getNumOperands() - 1; i > 1; --i)
+      MI.removeOperand(i);
+    for (auto &MO : NewOps)
+      MI.addOperand(MO);
+    if (MachineInstr *Next = MI.getNextNode()) {
+      if (isSpvIntrinsic(*Next, Intrinsic::spv_track_constant)) {
+        ToEraseMI.insert(Next);
+        Next = MI.getNextNode();
+      }
+      if (Next && Next->getOpcode() == TargetOpcode::G_BRINDIRECT)
+        ToEraseMI.insert(Next);
     }
   }
+  for (MachineInstr *BlockAddrI : ToEraseMI)
+    BlockAddrI->eraseFromParent();
 }
 
 static bool isImplicitFallthrough(MachineBasicBlock &MBB) {
diff --git a/llvm/test/CodeGen/SPIRV/branching/OpSwitchUnreachable.ll b/llvm/test/CodeGen/SPIRV/branching/OpSwitchUnreachable.ll
index e73efbeade70dc..6eb36e5756ecf6 100644
--- a/llvm/test/CodeGen/SPIRV/branching/OpSwitchUnreachable.ll
+++ b/llvm/test/CodeGen/SPIRV/branching/OpSwitchUnreachable.ll
@@ -1,8 +1,9 @@
 ; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
 
 define void @test_switch_with_unreachable_block(i1 %a) {
   %value = zext i1 %a to i32
-; CHECK-SPIRV:      OpSwitch %[[#]] %[[#REACHABLE:]]
+; CHECK-SPIRV:      OpSwitch %[[#]] %[[#UNREACHABLE:]] 0 %[[#REACHABLE:]] 1 %[[#REACHABLE:]]
   switch i32 %value, label %unreachable [
     i32 0, label %reachable
     i32 1, label %reachable
@@ -13,7 +14,7 @@ reachable:
 ; CHECK-SPIRV-NEXT: OpReturn
   ret void
 
-; CHECK-SPIRV:      %[[#]] = OpLabel
+; CHECK-SPIRV:      %[[#UNREACHABLE]] = OpLabel
 ; CHECK-SPIRV-NEXT: OpUnreachable
 unreachable:
   unreachable
diff --git a/llvm/test/CodeGen/SPIRV/branching/switch-range-check.ll b/llvm/test/CodeGen/SPIRV/branching/switch-range-check.ll
new file mode 100644
index 00000000000000..8ec384a0b07d2d
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/branching/switch-range-check.ll
@@ -0,0 +1,118 @@
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK: %[[#Var:]] = OpPhi
+; CHECK: OpSwitch %[[#Var]] %[[#]] [[#]] %[[#]] [[#]] %[[#]] [[#]] %[[#]] [[#]] %[[#]] [[#]] %[[#]] [[#]] %[[#]] [[#]] %[[#]] [[#]] %[[#]] [[#]] %[[#]] [[#]] %[[#]] [[#]] %[[#]] [[#]] %[[#]]
+; CHECK-COUNT-11: OpBranch
+; CHECK-NOT: OpBranch
+
+define spir_func void @foo(i64 noundef %addr, i64 noundef %as) {
+entry:
+  %0 = inttoptr i64 %as to ptr addrspace(4)
+  %1 = load i8, ptr addrspace(4) %0
+  %cmp = icmp sgt i8 %1, 0
+  br i1 %cmp, label %if.then, label %if.end
+
+if.then:                                          ; preds = %entry
+  %add.ptr = getelementptr inbounds i8, ptr addrspace(4) %0, i64 1
+  %2 = load i8, ptr addrspace(4) %add.ptr
+  br label %if.end
+
+if.end:                                           ; preds = %if.then, %entry
+  %shadow_value.0.in = phi i8 [ %2, %if.then ], [ %1, %entry ]
+  switch i8 %shadow_value.0.in, label %sw.default [
+    i8 -127, label %sw.epilog
+    i8 -126, label %sw.bb3
+    i8 -125, label %sw.bb4
+    i8 -111, label %sw.bb5
+    i8 -110, label %sw.bb6
+    i8 -109, label %sw.bb7
+    i8 -15, label %sw.bb8
+    i8 -14, label %sw.bb8
+    i8 -13, label %sw.bb8
+    i8 -124, label %sw.bb9
+    i8 -95, label %sw.bb10
+    i8 -123, label %sw.bb11
+  ]
+
+sw.bb3:                                           ; preds = %if.end
+  br label %sw.epilog
+
+sw.bb4:                                           ; preds = %if.end
+  br label %sw.epilog
+
+sw.bb5:                                           ; preds = %if.end
+  br label %sw.epilog
+
+sw.bb6:                                           ; preds = %if.end
+  br lab...
[truncated]

Comment on lines 114 to 116
// TODO: calculate Bound more carefully from maximum used register number,
// accounting for generated OpLabels and other related instructions if
// needed.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't we need to update this comment?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed, thank you.

Copy link
Member

@michalpaszkowski michalpaszkowski left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Apologies for this taking so long... needed to dig out older versions of benchmarks. I remember switches used to be quite sensitive point for performance. No significant differences here.

@VyacheslavLevytskyy VyacheslavLevytskyy merged commit 23b058c into llvm:main Apr 9, 2024
4 of 5 checks passed
@Keenuts
Copy link
Contributor

Keenuts commented May 15, 2024

@VyacheslavLevytskyy
Seems like using indirect branch to keep the CFG correct caused issues with the ASM printer:

Example building switch-range-check.ll

	OpReturn
Ltmp0:                                  ; Address of block that was removed by CodeGen
Ltmp1:                                  ; Address of block that was removed by CodeGen
Ltmp2:                                  ; Address of block that was removed by CodeGen
Ltmp3:                                  ; Address of block that was removed by CodeGen
Ltmp4:                                  ; Address of block that was removed by CodeGen
Ltmp5:                                  ; Address of block that was removed by CodeGen
Ltmp6:                                  ; Address of block that was removed by CodeGen
Ltmp7:                                  ; Address of block that was removed by CodeGen
Ltmp8:                                  ; Address of block that was removed by CodeGen
Ltmp9:                                  ; Address of block that was removed by CodeGen
Ltmp10:                                 ; Address of block that was removed by CodeGen
	OpFunctionEnd

Seems like the indirect branch is removed by the IRTranslator since the switch don't select it, but because the MBB/BBs are still marked as "AddressTaken", the asm printer generates bad labels.
The dirty solution I have locally is to modify the AsmPrinter and check for the triple before emitting this label+comment, but do you have an idea how to fix this?

@VyacheslavLevytskyy
Copy link
Contributor Author

@Keenuts Thank you! I created #92390 to address the issue.

VyacheslavLevytskyy added a commit that referenced this pull request May 17, 2024
…after internal intrinsic 'spv_switch' is processed (#92390)

After internal intrinsic 'spv_switch' is processed we need to delete
G_BLOCK_ADDR instructions that were generated to keep track of the
corresponding basic blocks. If we just delete G_BLOCK_ADDR instructions
with BlockAddress operands, this leaves their BasicBlock counterparts in
a "address taken" status. This would make AsmPrinter to generate a
series of unneeded labels of a `"Address of block that was removed by
CodeGen"` kind. This PR is to ensure that we don't have a dangling
BlockAddress constants by zapping the BlockAddress nodes, and only after
that proceed with erasing G_BLOCK_ADDR instructions.

See also #87823 for more
details.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

5 participants