diff --git a/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp b/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp index 4eee8062f2824..1de4616fd5b77 100644 --- a/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp @@ -43,6 +43,8 @@ using namespace llvm; namespace { class SPIRVAsmPrinter : public AsmPrinter { + unsigned NLabels = 0; + public: explicit SPIRVAsmPrinter(TargetMachine &TM, std::unique_ptr Streamer) @@ -109,10 +111,9 @@ void SPIRVAsmPrinter::emitEndOfAsmFile(Module &M) { uint32_t DecSPIRVVersion = ST->getSPIRVVersion(); uint32_t Major = DecSPIRVVersion / 10; uint32_t Minor = DecSPIRVVersion - Major * 10; - // TODO: calculate Bound more carefully from maximum used register number, - // accounting for generated OpLabels and other related instructions if - // needed. - unsigned Bound = 2 * (ST->getBound() + 1); + // Bound is an approximation that accounts for the maximum used register + // number and number of generated OpLabels + unsigned Bound = 2 * (ST->getBound() + 1) + NLabels; bool FlagToRestore = OutStreamer->getUseAssemblerInfoForParsing(); OutStreamer->setUseAssemblerInfoForParsing(true); if (MCAssembler *Asm = OutStreamer->getAssemblerPtr()) @@ -158,6 +159,7 @@ void SPIRVAsmPrinter::emitOpLabel(const MachineBasicBlock &MBB) { LabelInst.setOpcode(SPIRV::OpLabel); LabelInst.addOperand(MCOperand::createReg(MAI->getOrCreateMBBRegister(MBB))); outputMCInst(LabelInst); + ++NLabels; } void SPIRVAsmPrinter::emitBasicBlockStart(const MachineBasicBlock &MBB) { diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp index b341fcb41d031..e8ce5a35b457d 100644 --- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp @@ -460,15 +460,36 @@ void SPIRVEmitIntrinsics::preprocessCompositeConstants(IRBuilder<> &B) { } Instruction *SPIRVEmitIntrinsics::visitSwitchInst(SwitchInst &I) { - IRBuilder<> B(I.getParent()); + BasicBlock *ParentBB = 
I.getParent(); + IRBuilder<> B(ParentBB); + B.SetInsertPoint(&I); SmallVector Args; - for (auto &Op : I.operands()) - if (Op.get()->getType()->isSized()) + SmallVector BBCases; + for (auto &Op : I.operands()) { + if (Op.get()->getType()->isSized()) { Args.push_back(Op); - B.SetInsertPoint(&I); - B.CreateIntrinsic(Intrinsic::spv_switch, {I.getOperand(0)->getType()}, - {Args}); - return &I; + } else if (BasicBlock *BB = dyn_cast(Op.get())) { + BBCases.push_back(BB); + Args.push_back(BlockAddress::get(BB->getParent(), BB)); + } else { + report_fatal_error("Unexpected switch operand"); + } + } + CallInst *NewI = B.CreateIntrinsic(Intrinsic::spv_switch, + {I.getOperand(0)->getType()}, {Args}); + // Erase the switch so that it is not later unwrapped into branches and + // conditions, which is unneeded and undesirable. + I.replaceAllUsesWith(NewI); + I.eraseFromParent(); + // Insert an artificial, temporary instruction to keep the CFG valid; it + // will be removed after the IR translation pass. + B.SetInsertPoint(ParentBB); + IndirectBrInst *BrI = B.CreateIndirectBr( + Constant::getNullValue(PointerType::getUnqual(ParentBB->getContext())), + BBCases.size()); + for (BasicBlock *BBCase : BBCases) + BrI->addDestination(BBCase); + return BrI; } Instruction *SPIRVEmitIntrinsics::visitGetElementPtrInst(GetElementPtrInst &I) { diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h index ac799374adce8..37f575e884ef4 100644 --- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h +++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h @@ -284,7 +284,12 @@ class SPIRVGlobalRegistry { // Return the VReg holding the result of the given OpTypeXXX instruction. 
Register getSPIRVTypeID(const SPIRVType *SpirvType) const; - void setCurrentFunc(MachineFunction &MF) { CurMF = &MF; } + // Set the current machine function and return its previous value. + MachineFunction *setCurrentFunc(MachineFunction &MF) { + MachineFunction *Ret = CurMF; + CurMF = &MF; + return Ret; + } // Whether the given VReg has an OpTypeXXX instruction mapped to it with the // given opcode (e.g. OpTypeFloat). diff --git a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp index d450078d793fb..8db54c74f2369 100644 --- a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp @@ -160,12 +160,15 @@ void validateFunCallMachineDef(const SPIRVSubtarget &STI, : nullptr; if (DefElemType) { const Type *DefElemTy = GR.getTypeForSPIRVType(DefElemType); - // Switch GR context to the call site instead of the (default) definition - // side - GR.setCurrentFunc(*FunCall.getParent()->getParent()); + // validatePtrTypes() works in the context of the call site. + // When we process historical records about forward calls, + // we need to switch context to the (forward) call site and + // then restore it back to the current machine function. + MachineFunction *CurMF = + GR.setCurrentFunc(*FunCall.getParent()->getParent()); validatePtrTypes(STI, CallMRI, GR, FunCall, OpIdx, DefElemType, DefElemTy); - GR.setCurrentFunc(*FunDef->getParent()->getParent()); + GR.setCurrentFunc(*CurMF); } } } @@ -215,6 +218,11 @@ void validateAccessChain(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI, // TODO: the logic of inserting additional bitcast's is to be moved // to pre-IRTranslation passes eventually void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const { + // finalizeLowering() is called twice (see GlobalISel/InstructionSelect.cpp) + // We'd like to avoid the needless second processing pass. 
+ if (ProcessedMF.find(&MF) != ProcessedMF.end()) + return; + MachineRegisterInfo *MRI = &MF.getRegInfo(); SPIRVGlobalRegistry &GR = *STI.getSPIRVGlobalRegistry(); GR.setCurrentFunc(MF); @@ -302,5 +310,6 @@ void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const { } } } + ProcessedMF.insert(&MF); TargetLowering::finalizeLowering(MF); } diff --git a/llvm/lib/Target/SPIRV/SPIRVISelLowering.h b/llvm/lib/Target/SPIRV/SPIRVISelLowering.h index b01571bfc1eeb..8c1de7d97d1a3 100644 --- a/llvm/lib/Target/SPIRV/SPIRVISelLowering.h +++ b/llvm/lib/Target/SPIRV/SPIRVISelLowering.h @@ -16,6 +16,7 @@ #include "SPIRVGlobalRegistry.h" #include "llvm/CodeGen/TargetLowering.h" +#include namespace llvm { class SPIRVSubtarget; @@ -23,6 +24,9 @@ class SPIRVSubtarget; class SPIRVTargetLowering : public TargetLowering { const SPIRVSubtarget &STI; + // Record of already processed machine functions + mutable std::set ProcessedMF; + public: explicit SPIRVTargetLowering(const TargetMachine &TM, const SPIRVSubtarget &ST) diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp index b133f0ae85de2..7e155a36aadbc 100644 --- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp @@ -438,186 +438,75 @@ static void processInstrsWithTypeFolding(MachineFunction &MF, } } +// Find basic blocks of the switch and replace registers in spv_switch() by its +// MBB equivalent. static void processSwitches(MachineFunction &MF, SPIRVGlobalRegistry *GR, MachineIRBuilder MIB) { - // Before IRTranslator pass, calls to spv_switch intrinsic are inserted before - // each switch instruction. IRTranslator lowers switches to G_ICMP + G_BRCOND - // + G_BR triples. 
A switch with two cases may be transformed to this MIR - // sequence: - // - // intrinsic(@llvm.spv.switch), %CmpReg, %Const0, %Const1 - // %Dst0 = G_ICMP intpred(eq), %CmpReg, %Const0 - // G_BRCOND %Dst0, %bb.2 - // G_BR %bb.5 - // bb.5.entry: - // %Dst1 = G_ICMP intpred(eq), %CmpReg, %Const1 - // G_BRCOND %Dst1, %bb.3 - // G_BR %bb.4 - // bb.2.sw.bb: - // ... - // bb.3.sw.bb1: - // ... - // bb.4.sw.epilog: - // ... - // - // Sometimes (in case of range-compare switches), additional G_SUBs - // instructions are inserted before G_ICMPs. Those need to be additionally - // processed. - // - // This function modifies spv_switch call's operands to include destination - // MBBs (default and for each constant value). - // - // At the end, the function removes redundant [G_SUB] + G_ICMP + G_BRCOND + - // G_BR sequences. - - MachineRegisterInfo &MRI = MF.getRegInfo(); - - // Collect spv_switches and G_ICMPs across all MBBs in MF. - std::vector RelevantInsts; - - // Collect redundant MIs from [G_SUB] + G_ICMP + G_BRCOND + G_BR sequences. - // After updating spv_switches, the instructions can be removed. - std::vector PostUpdateArtifacts; - - // Temporary set of compare registers. G_SUBs and G_ICMPs relating to - // spv_switch use these registers. - DenseSet CompareRegs; + DenseMap BB2MBB; + SmallVector>> + Switches; for (MachineBasicBlock &MBB : MF) { + MachineRegisterInfo &MRI = MF.getRegInfo(); + BB2MBB[MBB.getBasicBlock()] = &MBB; for (MachineInstr &MI : MBB) { + if (!isSpvIntrinsic(MI, Intrinsic::spv_switch)) + continue; // Calls to spv_switch intrinsics representing IR switches. - if (isSpvIntrinsic(MI, Intrinsic::spv_switch)) { - assert(MI.getOperand(1).isReg()); - CompareRegs.insert(MI.getOperand(1).getReg()); - RelevantInsts.push_back(&MI); - } - - // G_SUBs coming from range-compare switch lowering. G_SUBs are found - // after spv_switch but before G_ICMP. 
- if (MI.getOpcode() == TargetOpcode::G_SUB && MI.getOperand(1).isReg() && - CompareRegs.contains(MI.getOperand(1).getReg())) { - assert(MI.getOperand(0).isReg() && MI.getOperand(1).isReg()); - Register Dst = MI.getOperand(0).getReg(); - CompareRegs.insert(Dst); - PostUpdateArtifacts.push_back(&MI); - } - - // G_ICMPs relating to switches. - if (MI.getOpcode() == TargetOpcode::G_ICMP && MI.getOperand(2).isReg() && - CompareRegs.contains(MI.getOperand(2).getReg())) { - Register Dst = MI.getOperand(0).getReg(); - RelevantInsts.push_back(&MI); - PostUpdateArtifacts.push_back(&MI); - MachineInstr *CBr = MRI.use_begin(Dst)->getParent(); - assert(CBr->getOpcode() == SPIRV::G_BRCOND); - PostUpdateArtifacts.push_back(CBr); - MachineInstr *Br = CBr->getNextNode(); - assert(Br->getOpcode() == SPIRV::G_BR); - PostUpdateArtifacts.push_back(Br); + SmallVector NewOps; + for (unsigned i = 2; i < MI.getNumOperands(); ++i) { + Register Reg = MI.getOperand(i).getReg(); + if (i % 2 == 1) { + MachineInstr *ConstInstr = getDefInstrMaybeConstant(Reg, &MRI); + NewOps.push_back(ConstInstr); + } else { + MachineInstr *BuildMBB = MRI.getVRegDef(Reg); + assert(BuildMBB && + BuildMBB->getOpcode() == TargetOpcode::G_BLOCK_ADDR && + BuildMBB->getOperand(1).isBlockAddress() && + BuildMBB->getOperand(1).getBlockAddress()); + NewOps.push_back(BuildMBB); + } } + Switches.push_back(std::make_pair(&MI, NewOps)); } } - // Update each spv_switch with destination MBBs. - for (auto i = RelevantInsts.begin(); i != RelevantInsts.end(); i++) { - if (!isSpvIntrinsic(**i, Intrinsic::spv_switch)) - continue; - - // Currently considered spv_switch. - MachineInstr *Switch = *i; - // Set the first successor as default MBB to support empty switches. - MachineBasicBlock *DefaultMBB = *Switch->getParent()->succ_begin(); - // Container for mapping values to MMBs. - SmallDenseMap ValuesToMBBs; - - // Walk all G_ICMPs to collect ValuesToMBBs. 
Start at currently considered - // spv_switch (i) and break at any spv_switch with the same compare - // register (indicating we are back at the same scope). - Register CompareReg = Switch->getOperand(1).getReg(); - for (auto j = i + 1; j != RelevantInsts.end(); j++) { - if (isSpvIntrinsic(**j, Intrinsic::spv_switch) && - (*j)->getOperand(1).getReg() == CompareReg) - break; - - if (!((*j)->getOpcode() == TargetOpcode::G_ICMP && - (*j)->getOperand(2).getReg() == CompareReg)) - continue; - - MachineInstr *ICMP = *j; - Register Dst = ICMP->getOperand(0).getReg(); - MachineOperand &PredOp = ICMP->getOperand(1); - const auto CC = static_cast(PredOp.getPredicate()); - (void)CC; - assert((CC == CmpInst::ICMP_EQ || CC == CmpInst::ICMP_ULE) && - MRI.hasOneUse(Dst) && MRI.hasOneDef(CompareReg)); - uint64_t Value = getIConstVal(ICMP->getOperand(3).getReg(), &MRI); - MachineInstr *CBr = MRI.use_begin(Dst)->getParent(); - assert(CBr->getOpcode() == SPIRV::G_BRCOND && CBr->getOperand(1).isMBB()); - MachineBasicBlock *MBB = CBr->getOperand(1).getMBB(); - - // Map switch case Value to target MBB. - ValuesToMBBs[Value] = MBB; - - // Add target MBB as successor to the switch's MBB. - Switch->getParent()->addSuccessor(MBB); - - // The next MI is always G_BR to either the next case or the default. - MachineInstr *NextMI = CBr->getNextNode(); - assert(NextMI->getOpcode() == SPIRV::G_BR && - NextMI->getOperand(0).isMBB()); - MachineBasicBlock *NextMBB = NextMI->getOperand(0).getMBB(); - // Default MBB does not begin with G_ICMP using spv_switch compare - // register. - if (NextMBB->front().getOpcode() != SPIRV::G_ICMP || - (NextMBB->front().getOperand(2).isReg() && - NextMBB->front().getOperand(2).getReg() != CompareReg)) { - // Set default MBB and add it as successor to the switch's MBB. 
- DefaultMBB = NextMBB; - Switch->getParent()->addSuccessor(DefaultMBB); + SmallPtrSet ToEraseMI; + for (auto &SwIt : Switches) { + MachineInstr &MI = *SwIt.first; + SmallVector &Ins = SwIt.second; + SmallVector NewOps; + for (unsigned i = 0; i < Ins.size(); ++i) { + if (Ins[i]->getOpcode() == TargetOpcode::G_BLOCK_ADDR) { + BasicBlock *CaseBB = + Ins[i]->getOperand(1).getBlockAddress()->getBasicBlock(); + auto It = BB2MBB.find(CaseBB); + if (It == BB2MBB.end()) + report_fatal_error("cannot find a machine basic block by a basic " + "block in a switch statement"); + NewOps.push_back(MachineOperand::CreateMBB(It->second)); + MI.getParent()->addSuccessor(It->second); + ToEraseMI.insert(Ins[i]); + } else { + NewOps.push_back( + MachineOperand::CreateCImm(Ins[i]->getOperand(1).getCImm())); } } - - // Modify considered spv_switch operands using collected Values and - // MBBs. - SmallVector Values; - SmallVector MBBs; - for (unsigned k = 2; k < Switch->getNumExplicitOperands(); k++) { - Register CReg = Switch->getOperand(k).getReg(); - uint64_t Val = getIConstVal(CReg, &MRI); - MachineInstr *ConstInstr = getDefInstrMaybeConstant(CReg, &MRI); - if (!ValuesToMBBs[Val]) - continue; - - Values.push_back(ConstInstr->getOperand(1).getCImm()); - MBBs.push_back(ValuesToMBBs[Val]); - } - - for (unsigned k = Switch->getNumExplicitOperands() - 1; k > 1; k--) - Switch->removeOperand(k); - - Switch->addOperand(MachineOperand::CreateMBB(DefaultMBB)); - for (unsigned k = 0; k < Values.size(); k++) { - Switch->addOperand(MachineOperand::CreateCImm(Values[k])); - Switch->addOperand(MachineOperand::CreateMBB(MBBs[k])); - } - } - - for (MachineInstr *MI : PostUpdateArtifacts) { - MachineBasicBlock *ParentMBB = MI->getParent(); - MI->eraseFromParent(); - // If G_ICMP + G_BRCOND + G_BR were the only MIs in MBB, erase this MBB. It - // can be safely assumed, there are no breaks or phis directing into this - // MBB. However, we need to remove this MBB from the CFG graph. 
MBBs must be - // erased top-down. - if (ParentMBB->empty()) { - while (!ParentMBB->pred_empty()) - (*ParentMBB->pred_begin())->removeSuccessor(ParentMBB); - - while (!ParentMBB->succ_empty()) - ParentMBB->removeSuccessor(ParentMBB->succ_begin()); - - ParentMBB->eraseFromParent(); + for (unsigned i = MI.getNumOperands() - 1; i > 1; --i) + MI.removeOperand(i); + for (auto &MO : NewOps) + MI.addOperand(MO); + if (MachineInstr *Next = MI.getNextNode()) { + if (isSpvIntrinsic(*Next, Intrinsic::spv_track_constant)) { + ToEraseMI.insert(Next); + Next = Next->getNextNode(); + } + if (Next && Next->getOpcode() == TargetOpcode::G_BRINDIRECT) + ToEraseMI.insert(Next); } } + for (MachineInstr *BlockAddrI : ToEraseMI) + BlockAddrI->eraseFromParent(); } static bool isImplicitFallthrough(MachineBasicBlock &MBB) { diff --git a/llvm/test/CodeGen/SPIRV/branching/OpSwitchUnreachable.ll b/llvm/test/CodeGen/SPIRV/branching/OpSwitchUnreachable.ll index e73efbeade70d..6eb36e5756ecf 100644 --- a/llvm/test/CodeGen/SPIRV/branching/OpSwitchUnreachable.ll +++ b/llvm/test/CodeGen/SPIRV/branching/OpSwitchUnreachable.ll @@ -1,8 +1,9 @@ ; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV +; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %} define void @test_switch_with_unreachable_block(i1 %a) { %value = zext i1 %a to i32 -; CHECK-SPIRV: OpSwitch %[[#]] %[[#REACHABLE:]] +; CHECK-SPIRV: OpSwitch %[[#]] %[[#UNREACHABLE:]] 0 %[[#REACHABLE:]] 1 %[[#REACHABLE:]] switch i32 %value, label %unreachable [ i32 0, label %reachable i32 1, label %reachable @@ -13,7 +14,7 @@ reachable: ; CHECK-SPIRV-NEXT: OpReturn ret void -; CHECK-SPIRV: %[[#]] = OpLabel +; CHECK-SPIRV: %[[#UNREACHABLE]] = OpLabel ; CHECK-SPIRV-NEXT: OpUnreachable unreachable: unreachable diff --git a/llvm/test/CodeGen/SPIRV/branching/switch-range-check.ll b/llvm/test/CodeGen/SPIRV/branching/switch-range-check.ll new file mode 100644 index 
0000000000000..85a4d4db089cb --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/branching/switch-range-check.ll @@ -0,0 +1,73 @@ +; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s +; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %} + +; CHECK: %[[#Var:]] = OpPhi +; CHECK: OpSwitch %[[#Var]] %[[#]] [[#]] %[[#]] [[#]] %[[#]] [[#]] %[[#]] [[#]] %[[#]] [[#]] %[[#]] [[#]] %[[#]] [[#]] %[[#]] [[#]] %[[#]] [[#]] %[[#]] [[#]] %[[#]] [[#]] %[[#]] [[#]] %[[#]] +; CHECK-COUNT-11: OpBranch +; CHECK-NOT: OpBranch + +define spir_func void @foo(i64 noundef %addr, i64 noundef %as) { +entry: + %src = inttoptr i64 %as to ptr addrspace(4) + %val = load i8, ptr addrspace(4) %src + %cmp = icmp sgt i8 %val, 0 + br i1 %cmp, label %if.then, label %if.end + +if.then: + %add.ptr = getelementptr inbounds i8, ptr addrspace(4) %src, i64 1 + %cond = load i8, ptr addrspace(4) %add.ptr + br label %if.end + +if.end: + %swval = phi i8 [ %cond, %if.then ], [ %val, %entry ] + switch i8 %swval, label %sw.default [ + i8 -127, label %sw.epilog + i8 -126, label %sw.bb3 + i8 -125, label %sw.bb4 + i8 -111, label %sw.bb5 + i8 -110, label %sw.bb6 + i8 -109, label %sw.bb7 + i8 -15, label %sw.bb8 + i8 -14, label %sw.bb8 + i8 -13, label %sw.bb8 + i8 -124, label %sw.bb9 + i8 -95, label %sw.bb10 + i8 -123, label %sw.bb11 + ] + +sw.bb3: + br label %sw.epilog + +sw.bb4: + br label %sw.epilog + +sw.bb5: + br label %sw.epilog + +sw.bb6: + br label %sw.epilog + +sw.bb7: + br label %sw.epilog + +sw.bb8: + br label %sw.epilog + +sw.bb9: + br label %sw.epilog + +sw.bb10: + br label %sw.epilog + +sw.bb11: + br label %sw.epilog + +sw.default: + br label %sw.epilog + +sw.epilog: + br label %exit + +exit: + ret void +}