diff --git a/llvm/lib/Target/SPIRV/SPIRV.h b/llvm/lib/Target/SPIRV/SPIRV.h index 6979107349d96..fb8580cd47c01 100644 --- a/llvm/lib/Target/SPIRV/SPIRV.h +++ b/llvm/lib/Target/SPIRV/SPIRV.h @@ -24,7 +24,7 @@ FunctionPass *createSPIRVStripConvergenceIntrinsicsPass(); FunctionPass *createSPIRVRegularizerPass(); FunctionPass *createSPIRVPreLegalizerPass(); FunctionPass *createSPIRVPostLegalizerPass(); -FunctionPass *createSPIRVEmitIntrinsicsPass(SPIRVTargetMachine *TM); +ModulePass *createSPIRVEmitIntrinsicsPass(SPIRVTargetMachine *TM); InstructionSelector * createSPIRVInstructionSelector(const SPIRVTargetMachine &TM, const SPIRVSubtarget &Subtarget, diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp index 9e4ba2191366b..c107b99cf4cb6 100644 --- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp @@ -383,7 +383,16 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder, if (F.isDeclaration()) GR->add(&F, &MIRBuilder.getMF(), FuncVReg); FunctionType *FTy = getOriginalFunctionType(F); - SPIRVType *RetTy = GR->getOrCreateSPIRVType(FTy->getReturnType(), MIRBuilder); + Type *FRetTy = FTy->getReturnType(); + if (isUntypedPointerTy(FRetTy)) { + if (Type *FRetElemTy = GR->findDeducedElementType(&F)) { + TypedPointerType *DerivedTy = + TypedPointerType::get(FRetElemTy, getPointerAddressSpace(FRetTy)); + GR->addReturnType(&F, DerivedTy); + FRetTy = DerivedTy; + } + } + SPIRVType *RetTy = GR->getOrCreateSPIRVType(FRetTy, MIRBuilder); FTy = fixFunctionTypeIfPtrArgs(GR, F, FTy, RetTy, ArgTypeVRegs); SPIRVType *FuncTy = GR->getOrCreateOpTypeFunctionWithArgs( FTy, RetTy, ArgTypeVRegs, MIRBuilder); @@ -505,8 +514,13 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder, // TODO: support constexpr casts and indirect calls. if (CF == nullptr) return false; - if (FunctionType *FTy = getOriginalFunctionType(*CF)) + if (FunctionType *FTy = getOriginalFunctionType(*CF)) { OrigRetTy = FTy->getReturnType(); + if (isUntypedPointerTy(OrigRetTy)) { + if (auto *DerivedRetTy = GR->findReturnType(CF)) + OrigRetTy = DerivedRetTy; + } + } } MachineRegisterInfo *MRI = MIRBuilder.getMRI(); diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp index e8ce5a35b457d..472bc8638c9af 100644 --- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp @@ -51,7 +51,7 @@ void initializeSPIRVEmitIntrinsicsPass(PassRegistry &); namespace { class SPIRVEmitIntrinsics - : public FunctionPass, + : public ModulePass, public InstVisitor { SPIRVTargetMachine *TM = nullptr; SPIRVGlobalRegistry *GR = nullptr; @@ -61,6 +61,9 @@ class SPIRVEmitIntrinsics DenseMap AggrConstTypes; DenseSet AggrStores; + // a registry of created Intrinsic::spv_assign_ptr_type instructions + DenseMap AssignPtrTypeInstr; + // deduce element type of untyped pointers Type *deduceElementType(Value *I); Type *deduceElementTypeHelper(Value *I); @@ -75,6 +78,9 @@ class SPIRVEmitIntrinsics Type *deduceNestedTypeHelper(User *U, Type *Ty, std::unordered_set &Visited); + // deduce Types of operands of the Instruction if possible + void deduceOperandElementType(Instruction *I); + void preprocessCompositeConstants(IRBuilder<> &B); void preprocessUndefs(IRBuilder<> &B); @@ -111,10 +117,10 @@ class SPIRVEmitIntrinsics public: static char ID; - SPIRVEmitIntrinsics() : FunctionPass(ID) { + SPIRVEmitIntrinsics() : ModulePass(ID) { initializeSPIRVEmitIntrinsicsPass(*PassRegistry::getPassRegistry()); } - SPIRVEmitIntrinsics(SPIRVTargetMachine *_TM) : FunctionPass(ID), TM(_TM) { + SPIRVEmitIntrinsics(SPIRVTargetMachine *_TM) : ModulePass(ID), TM(_TM) { initializeSPIRVEmitIntrinsicsPass(*PassRegistry::getPassRegistry()); } Instruction *visitInstruction(Instruction &I) { return &I; } @@ -130,7 +136,15 @@ class SPIRVEmitIntrinsics Instruction *visitAllocaInst(AllocaInst &I); Instruction *visitAtomicCmpXchgInst(AtomicCmpXchgInst &I); Instruction *visitUnreachableInst(UnreachableInst &I); - bool runOnFunction(Function &F) override; + + StringRef getPassName() const override { return "SPIRV emit intrinsics"; } + + bool runOnModule(Module &M) override; + bool runOnFunction(Function &F); + + void getAnalysisUsage(AnalysisUsage &AU) const override { + ModulePass::getAnalysisUsage(AU); + } }; } // namespace @@ -269,6 +283,12 @@ Type *SPIRVEmitIntrinsics::deduceElementTypeHelper( if (Ty) break; } + } else if (auto *Ref = dyn_cast(I)) { + for (Value *Op : {Ref->getTrueValue(), Ref->getFalseValue()}) { + Ty = deduceElementTypeByUsersDeep(Op, Visited); + if (Ty) + break; + } } // remember the found relationship @@ -368,6 +388,112 @@ Type *SPIRVEmitIntrinsics::deduceElementType(Value *I) { return IntegerType::getInt8Ty(I->getContext()); } +// If the Instruction has Pointer operands with unresolved types, this function +// tries to deduce them. If the Instruction has Pointer operands with known +// types which differ from expected, this function tries to insert a bitcast to +// resolve the issue. +void SPIRVEmitIntrinsics::deduceOperandElementType(Instruction *I) { + SmallVector> Ops; + Type *KnownElemTy = nullptr; + // look for known basic patterns of type inference + if (auto *Ref = dyn_cast(I)) { + if (!isPointerTy(I->getType()) || + !(KnownElemTy = GR->findDeducedElementType(I))) + return; + for (unsigned i = 0; i < Ref->getNumIncomingValues(); i++) { + Value *Op = Ref->getIncomingValue(i); + if (isPointerTy(Op->getType())) + Ops.push_back(std::make_pair(Op, i)); + } + } else if (auto *Ref = dyn_cast(I)) { + if (!isPointerTy(I->getType()) || + !(KnownElemTy = GR->findDeducedElementType(I))) + return; + for (unsigned i = 0; i < Ref->getNumOperands(); i++) { + Value *Op = Ref->getOperand(i); + if (isPointerTy(Op->getType())) + Ops.push_back(std::make_pair(Op, i)); + } + } else if (auto *Ref = dyn_cast(I)) { + Type *RetTy = F->getReturnType(); + if (!isPointerTy(RetTy)) + return; + Value *Op = Ref->getReturnValue(); + if (!Op) + return; + if (!(KnownElemTy = GR->findDeducedElementType(F))) { + if (Type *OpElemTy = GR->findDeducedElementType(Op)) { + GR->addDeducedElementType(F, OpElemTy); + TypedPointerType *DerivedTy = + TypedPointerType::get(OpElemTy, getPointerAddressSpace(RetTy)); + GR->addReturnType(F, DerivedTy); + } + return; + } + Ops.push_back(std::make_pair(Op, 0)); + } else if (auto *Ref = dyn_cast(I)) { + if (!isPointerTy(Ref->getOperand(0)->getType())) + return; + Value *Op0 = Ref->getOperand(0); + Value *Op1 = Ref->getOperand(1); + Type *ElemTy0 = GR->findDeducedElementType(Op0); + Type *ElemTy1 = GR->findDeducedElementType(Op1); + if (ElemTy0) { + KnownElemTy = ElemTy0; + Ops.push_back(std::make_pair(Op1, 1)); + } else if (ElemTy1) { + KnownElemTy = ElemTy1; + Ops.push_back(std::make_pair(Op0, 0)); + } + } + + // There is no enough info to deduce types or all is valid. + if (!KnownElemTy || Ops.size() == 0) + return; + + LLVMContext &Ctx = F->getContext(); + IRBuilder<> B(Ctx); + for (auto &OpIt : Ops) { + Value *Op = OpIt.first; + if (Op->use_empty()) + continue; + Type *Ty = GR->findDeducedElementType(Op); + if (Ty == KnownElemTy) + continue; + if (Instruction *User = dyn_cast(Op->use_begin()->get())) + setInsertPointSkippingPhis(B, User->getNextNode()); + else + B.SetInsertPoint(I); + Value *OpTyVal = Constant::getNullValue(KnownElemTy); + Type *OpTy = Op->getType(); + if (!Ty) { + GR->addDeducedElementType(Op, KnownElemTy); + // check if there is existing Intrinsic::spv_assign_ptr_type instruction + auto It = AssignPtrTypeInstr.find(Op); + if (It == AssignPtrTypeInstr.end()) { + CallInst *CI = + buildIntrWithMD(Intrinsic::spv_assign_ptr_type, {OpTy}, OpTyVal, Op, + {B.getInt32(getPointerAddressSpace(OpTy))}, B); + AssignPtrTypeInstr[Op] = CI; + } else { + It->second->setArgOperand( + 1, + MetadataAsValue::get( + Ctx, MDNode::get(Ctx, ValueAsMetadata::getConstant(OpTyVal)))); + } + } else { + SmallVector Types = {OpTy, OpTy}; + MetadataAsValue *VMD = MetadataAsValue::get( + Ctx, MDNode::get(Ctx, ValueAsMetadata::getConstant(OpTyVal))); + SmallVector Args = {Op, VMD, + B.getInt32(getPointerAddressSpace(OpTy))}; + CallInst *PtrCastI = + B.CreateIntrinsic(Intrinsic::spv_ptrcast, {Types}, Args); + I->setOperand(OpIt.second, PtrCastI); + } + } +} + void SPIRVEmitIntrinsics::replaceMemInstrUses(Instruction *Old, Instruction *New, IRBuilder<> &B) { @@ -630,6 +756,7 @@ void SPIRVEmitIntrinsics::replacePointerOperandWithPtrCast( ExpectedElementTypeConst, Pointer, {B.getInt32(AddressSpace)}, B); GR->addDeducedElementType(CI, ExpectedElementType); GR->addDeducedElementType(Pointer, ExpectedElementType); + AssignPtrTypeInstr[Pointer] = CI; return; } @@ -914,6 +1041,7 @@ void SPIRVEmitIntrinsics::insertAssignPtrTypeIntrs(Instruction *I, CallInst *CI = buildIntrWithMD(Intrinsic::spv_assign_ptr_type, {I->getType()}, EltTyConst, I, {B.getInt32(AddressSpace)}, B); GR->addDeducedElementType(CI, ElemTy); + AssignPtrTypeInstr[I] = CI; } void SPIRVEmitIntrinsics::insertAssignTypeIntrs(Instruction *I, @@ -1070,6 +1198,7 @@ void SPIRVEmitIntrinsics::processParamTypes(Function *F, IRBuilder<> &B) { {B.getInt32(getPointerAddressSpace(Arg->getType()))}, B); GR->addDeducedElementType(AssignPtrTyCI, ElemTy); GR->addDeducedElementType(Arg, ElemTy); + AssignPtrTypeInstr[Arg] = AssignPtrTyCI; } } } @@ -1114,6 +1243,10 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) { insertAssignTypeIntrs(I, B); insertPtrCastOrAssignTypeInstr(I, B); } + + for (auto &I : instructions(Func)) + deduceOperandElementType(&I); + for (auto *I : Worklist) { TrackConstants = true; if (!I->getType()->isVoidTy() || isa(I)) @@ -1126,13 +1259,29 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) { processInstrAfterVisit(I, B); } - // check if function parameter types are set - if (!F->isIntrinsic()) - processParamTypes(F, B); - return true; } -FunctionPass *llvm::createSPIRVEmitIntrinsicsPass(SPIRVTargetMachine *TM) { +bool SPIRVEmitIntrinsics::runOnModule(Module &M) { + bool Changed = false; + + for (auto &F : M) { + Changed |= runOnFunction(F); + } + + for (auto &F : M) { + // check if function parameter types are set + if (!F.isDeclaration() && !F.isIntrinsic()) { + const SPIRVSubtarget &ST = TM->getSubtarget(F); + GR = ST.getSPIRVGlobalRegistry(); + IRBuilder<> B(F.getContext()); + processParamTypes(&F, B); + } + } + + return Changed; +} + +ModulePass *llvm::createSPIRVEmitIntrinsicsPass(SPIRVTargetMachine *TM) { return new SPIRVEmitIntrinsics(TM); } diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp index 70197e948c658..05e41e06248e3 100644 --- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp @@ -23,7 +23,6 @@ #include "llvm/ADT/APInt.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Type.h" -#include "llvm/IR/TypedPointerType.h" #include "llvm/Support/Casting.h" #include @@ -61,7 +60,6 @@ SPIRVType *SPIRVGlobalRegistry::assignVectTypeToVReg( SPIRVType *SPIRVGlobalRegistry::assignTypeToVReg( const Type *Type, Register VReg, MachineIRBuilder &MIRBuilder, SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) { - SPIRVType *SpirvType = getOrCreateSPIRVType(Type, MIRBuilder, AccessQual, EmitIR); assignSPIRVTypeToVReg(SpirvType, VReg, MIRBuilder.getMF()); @@ -726,7 +724,8 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeStruct(const StructType *Ty, bool EmitIR) { SmallVector FieldTypes; for (const auto &Elem : Ty->elements()) { - SPIRVType *ElemTy = findSPIRVType(Elem, MIRBuilder); + SPIRVType *ElemTy = + findSPIRVType(toTypedPointer(Elem, Ty->getContext()), MIRBuilder); assert(ElemTy && ElemTy->getOpcode() != SPIRV::OpTypeVoid && "Invalid struct element type"); FieldTypes.push_back(getSPIRVTypeID(ElemTy)); @@ -919,8 +918,10 @@ SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType( return SpirvType; } -SPIRVType *SPIRVGlobalRegistry::getSPIRVTypeForVReg(Register VReg) const { - auto t = VRegToTypeMap.find(CurMF); +SPIRVType * +SPIRVGlobalRegistry::getSPIRVTypeForVReg(Register VReg, + const MachineFunction *MF) const { + auto t = VRegToTypeMap.find(MF ? MF : CurMF); if (t != VRegToTypeMap.end()) { auto tt = t->second.find(VReg); if (tt != t->second.end()) diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h index 2e3e69456ac26..55979ba403a0e 100644 --- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h +++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h @@ -21,6 +21,7 @@ #include "SPIRVInstrInfo.h" #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h" #include "llvm/IR/Constant.h" +#include "llvm/IR/TypedPointerType.h" namespace llvm { using SPIRVType = const MachineInstr; @@ -58,6 +59,9 @@ class SPIRVGlobalRegistry { SmallPtrSet TypesInProcessing; DenseMap ForwardPointerTypes; + // if a function returns a pointer, this is to map it into TypedPointerType + DenseMap FunResPointerTypes; + // Number of bits pointers and size_t integers require. const unsigned PointerSize; @@ -134,6 +138,16 @@ class SPIRVGlobalRegistry { void setBound(unsigned V) { Bound = V; } unsigned getBound() { return Bound; } + // Add a record to the map of function return pointer types. + void addReturnType(const Function *ArgF, TypedPointerType *DerivedTy) { + FunResPointerTypes[ArgF] = DerivedTy; + } + // Find a record in the map of function return pointer types. + const TypedPointerType *findReturnType(const Function *ArgF) { + auto It = FunResPointerTypes.find(ArgF); + return It == FunResPointerTypes.end() ? nullptr : It->second; + } + // Deduced element types of untyped pointers and composites: // - Add a record to the map of deduced element types. void addDeducedElementType(Value *Val, Type *Ty) { DeducedElTys[Val] = Ty; } @@ -276,8 +290,12 @@ class SPIRVGlobalRegistry { SPIRV::AccessQualifier::ReadWrite); // Return the SPIR-V type instruction corresponding to the given VReg, or - // nullptr if no such type instruction exists. - SPIRVType *getSPIRVTypeForVReg(Register VReg) const; + // nullptr if no such type instruction exists. The second argument MF + // allows to search for the association in a context of the machine functions + // than the current one, without switching between different "current" machine + // functions. + SPIRVType *getSPIRVTypeForVReg(Register VReg, + const MachineFunction *MF = nullptr) const; // Whether the given VReg has a SPIR-V type mapped to it yet. bool hasSPIRVTypeForVReg(Register VReg) const { diff --git a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp index 8db54c74f2369..b8296c3f6eeae 100644 --- a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp @@ -88,19 +88,24 @@ static void validatePtrTypes(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI, SPIRVGlobalRegistry &GR, MachineInstr &I, unsigned OpIdx, SPIRVType *ResType, const Type *ResTy = nullptr) { + // Get operand type + MachineFunction *MF = I.getParent()->getParent(); Register OpReg = I.getOperand(OpIdx).getReg(); SPIRVType *TypeInst = MRI->getVRegDef(OpReg); - SPIRVType *OpType = GR.getSPIRVTypeForVReg( + Register OpTypeReg = TypeInst && TypeInst->getOpcode() == SPIRV::OpFunctionParameter ? TypeInst->getOperand(1).getReg() - : OpReg); + : OpReg; + SPIRVType *OpType = GR.getSPIRVTypeForVReg(OpTypeReg, MF); if (!ResType || !OpType || OpType->getOpcode() != SPIRV::OpTypePointer) return; - SPIRVType *ElemType = GR.getSPIRVTypeForVReg(OpType->getOperand(2).getReg()); + // Get operand's pointee type + Register ElemTypeReg = OpType->getOperand(2).getReg(); + SPIRVType *ElemType = GR.getSPIRVTypeForVReg(ElemTypeReg, MF); if (!ElemType) return; - bool IsSameMF = - ElemType->getParent()->getParent() == ResType->getParent()->getParent(); + // Check if we need a bitcast to make a statement valid + bool IsSameMF = MF == ResType->getParent()->getParent(); bool IsEqualTypes = IsSameMF ? ElemType == ResType : GR.getTypeForSPIRVType(ElemType) == ResTy; if (IsEqualTypes) @@ -156,7 +161,8 @@ void validateFunCallMachineDef(const SPIRVSubtarget &STI, SPIRVType *DefPtrType = DefMRI->getVRegDef(FunDef->getOperand(1).getReg()); SPIRVType *DefElemType = DefPtrType && DefPtrType->getOpcode() == SPIRV::OpTypePointer - ? GR.getSPIRVTypeForVReg(DefPtrType->getOperand(2).getReg()) + ? GR.getSPIRVTypeForVReg(DefPtrType->getOperand(2).getReg(), + DefPtrType->getParent()->getParent()) : nullptr; if (DefElemType) { const Type *DefElemTy = GR.getTypeForSPIRVType(DefElemType); @@ -177,7 +183,7 @@ void validateFunCallMachineDef(const SPIRVSubtarget &STI, // with a processed definition. Return Function pointer if it's a forward // call (ahead of definition), and nullptr otherwise. const Function *validateFunCall(const SPIRVSubtarget &STI, - MachineRegisterInfo *MRI, + MachineRegisterInfo *CallMRI, SPIRVGlobalRegistry &GR, MachineInstr &FunCall) { const GlobalValue *GV = FunCall.getOperand(2).getGlobal(); @@ -186,7 +192,8 @@ const Function *validateFunCall(const SPIRVSubtarget &STI, const_cast(GR.getFunctionDefinition(F)); if (!FunDef) return F; - validateFunCallMachineDef(STI, MRI, MRI, GR, FunCall, FunDef); + MachineRegisterInfo *DefMRI = &FunDef->getParent()->getParent()->getRegInfo(); + validateFunCallMachineDef(STI, DefMRI, CallMRI, GR, FunCall, FunDef); return nullptr; } diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp index e3f76419f1313..aacfecc1e313f 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp @@ -248,7 +248,7 @@ void SPIRVInstrInfo::copyPhysReg(MachineBasicBlock &MBB, bool SPIRVInstrInfo::expandPostRAPseudo(MachineInstr &MI) const { if (MI.getOpcode() == SPIRV::GET_ID || MI.getOpcode() == SPIRV::GET_fID || MI.getOpcode() == SPIRV::GET_pID || MI.getOpcode() == SPIRV::GET_vfID || - MI.getOpcode() == SPIRV::GET_vID) { + MI.getOpcode() == SPIRV::GET_vID || MI.getOpcode() == SPIRV::GET_vpID) { auto &MRI = MI.getMF()->getRegInfo(); MRI.replaceRegWith(MI.getOperand(0).getReg(), MI.getOperand(1).getReg()); MI.eraseFromParent(); diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td index 99c57dac4141d..a3f981457c8da 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td +++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td @@ -22,6 +22,7 @@ let isCodeGenOnly=1 in { def GET_pID: Pseudo<(outs pID:$dst_id), (ins ANYID:$src)>; def GET_vID: Pseudo<(outs vID:$dst_id), (ins ANYID:$src)>; def GET_vfID: Pseudo<(outs vfID:$dst_id), (ins ANYID:$src)>; + def GET_vpID: Pseudo<(outs vpID:$dst_id), (ins ANYID:$src)>; } def SPVTypeBin : SDTypeProfile<1, 2, []>; @@ -55,7 +56,7 @@ multiclass BinOpTypedGen opCode, SDNode node, bit genF = 0 } } -multiclass TernOpTypedGen opCode, SDNode node, bit genI = 1, bit genF = 0, bit genV = 0> { +multiclass TernOpTypedGen opCode, SDNode node, bit genP = 1, bit genI = 1, bit genF = 0, bit genV = 0> { if genF then { def SFSCond: TernOpTyped; def SFVCond: TernOpTyped; @@ -64,6 +65,10 @@ multiclass TernOpTypedGen opCode, SDNode node, bit genI = def SISCond: TernOpTyped; def SIVCond: TernOpTyped; } + if genP then { + def SPSCond: TernOpTyped; + def SPVCond: TernOpTyped; + } if genV then { if genF then { def VFSCond: TernOpTyped; @@ -73,6 +78,10 @@ multiclass TernOpTypedGen opCode, SDNode node, bit genI = def VISCond: TernOpTyped; def VIVCond: TernOpTyped; } + if genP then { + def VPSCond: TernOpTyped; + def VPVCond: TernOpTyped; + } } } @@ -552,7 +561,7 @@ def OpLogicalOr: BinOp<"OpLogicalOr", 166>; def OpLogicalAnd: BinOp<"OpLogicalAnd", 167>; def OpLogicalNot: UnOp<"OpLogicalNot", 168>; -defm OpSelect: TernOpTypedGen<"OpSelect", 169, select, 1, 1, 1>; +defm OpSelect: TernOpTypedGen<"OpSelect", 169, select, 1, 1, 1, 1>; def OpIEqual: BinOp<"OpIEqual", 170>; def OpINotEqual: BinOp<"OpINotEqual", 171>; diff --git a/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp index b9d66de9555b1..f069a92ac6868 100644 --- a/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp @@ -56,7 +56,7 @@ extern void processInstr(MachineInstr &MI, MachineIRBuilder &MIB, static bool isMetaInstrGET(unsigned Opcode) { return Opcode == SPIRV::GET_ID || Opcode == SPIRV::GET_fID || Opcode == SPIRV::GET_pID || Opcode == SPIRV::GET_vID || - Opcode == SPIRV::GET_vfID; + Opcode == SPIRV::GET_vfID || Opcode == SPIRV::GET_vpID; } static bool mayBeInserted(unsigned Opcode) { diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp index 7e155a36aadbc..2c964595fc39e 100644 --- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp @@ -64,9 +64,16 @@ static void addConstantsToTrack(MachineFunction &MF, SPIRVGlobalRegistry *GR) { auto *BuildVec = MRI.getVRegDef(MI.getOperand(2).getReg()); assert(BuildVec && BuildVec->getOpcode() == TargetOpcode::G_BUILD_VECTOR); - for (unsigned i = 0; i < ConstVec->getNumElements(); ++i) - GR->add(ConstVec->getElementAsConstant(i), &MF, - BuildVec->getOperand(1 + i).getReg()); + for (unsigned i = 0; i < ConstVec->getNumElements(); ++i) { + // Ensure that OpConstantComposite reuses a constant when it's + // already created and available in the same machine function. + Constant *ElemConst = ConstVec->getElementAsConstant(i); + Register ElemReg = GR->find(ElemConst, &MF); + if (!ElemReg.isValid()) + GR->add(ElemConst, &MF, BuildVec->getOperand(1 + i).getReg()); + else + BuildVec->getOperand(1 + i).setReg(ElemReg); + } } GR->add(Const, &MF, MI.getOperand(2).getReg()); } else { diff --git a/llvm/lib/Target/SPIRV/SPIRVRegisterBankInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVRegisterBankInfo.cpp index 9bf9d7fe5b39e..5983c9229cb3c 100644 --- a/llvm/lib/Target/SPIRV/SPIRVRegisterBankInfo.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVRegisterBankInfo.cpp @@ -39,6 +39,8 @@ SPIRVRegisterBankInfo::getRegBankFromRegClass(const TargetRegisterClass &RC, return SPIRV::vIDRegBank; case SPIRV::vfIDRegClassID: return SPIRV::vfIDRegBank; + case SPIRV::vpIDRegClassID: + return SPIRV::vpIDRegBank; case SPIRV::ANYIDRegClassID: case SPIRV::ANYRegClassID: return SPIRV::IDRegBank; diff --git a/llvm/lib/Target/SPIRV/SPIRVRegisterBanks.td b/llvm/lib/Target/SPIRV/SPIRVRegisterBanks.td index 90c7f3a6e6726..c7f1e172f3d4f 100644 --- a/llvm/lib/Target/SPIRV/SPIRVRegisterBanks.td +++ b/llvm/lib/Target/SPIRV/SPIRVRegisterBanks.td @@ -12,4 +12,5 @@ def IDRegBank : RegisterBank<"IDBank", [ID]>; def fIDRegBank : RegisterBank<"fIDBank", [fID]>; def vIDRegBank : RegisterBank<"vIDBank", [vID]>; def vfIDRegBank : RegisterBank<"vfIDBank", [vfID]>; +def vpIDRegBank : RegisterBank<"vpIDBank", [vpID]>; def TYPERegBank : RegisterBank<"TYPEBank", [TYPE]>; diff --git a/llvm/lib/Target/SPIRV/SPIRVRegisterInfo.td b/llvm/lib/Target/SPIRV/SPIRVRegisterInfo.td index d0b64b6895d03..6d2bfb91a97f1 100644 --- a/llvm/lib/Target/SPIRV/SPIRVRegisterInfo.td +++ b/llvm/lib/Target/SPIRV/SPIRVRegisterInfo.td @@ -12,6 +12,17 @@ let Namespace = "SPIRV" in { def p0 : PtrValueType ; + + class P0Vec + : PtrValueType { + let nElem = 2; + let ElementType = p0; + let isInteger = false; + let isFP = false; + let isVector = true; + } + + def v2p0 : P0Vec; // All registers are for 32-bit identifiers, so have a single dummy register // Class for registers that are the result of OpTypeXXX instructions @@ -21,14 +32,16 @@ let Namespace = "SPIRV" in { // Class for every other non-type ID def ID0 : Register<"ID0">; def ID : RegisterClass<"SPIRV", [i32], 32, (add ID0)>; - def fID0 : Register<"FID0">; + def fID0 : Register<"fID0">; def fID : RegisterClass<"SPIRV", [f32], 32, (add fID0)>; def pID0 : Register<"pID0">; def pID : RegisterClass<"SPIRV", [p0], 32, (add pID0)>; - def vID0 : Register<"pID0">; + def vID0 : Register<"vID0">; def vID : RegisterClass<"SPIRV", [v2i32], 32, (add vID0)>; - def vfID0 : Register<"pID0">; + def vfID0 : Register<"vfID0">; def vfID : RegisterClass<"SPIRV", [v2f32], 32, (add vfID0)>; + def vpID0 : Register<"vpID0">; + def vpID : RegisterClass<"SPIRV", [v2p0], 32, (add vpID0)>; def ANYID : RegisterClass<"SPIRV", [i32, f32, p0, v2i32, v2f32], 32, (add ID, fID, pID, vID, vfID)>; diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp index 299a4341193bf..2e44c208ed8e0 100644 --- a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp @@ -251,7 +251,8 @@ bool isSpvIntrinsic(const MachineInstr &MI, Intrinsic::ID IntrinsicID) { } Type *getMDOperandAsType(const MDNode *N, unsigned I) { - return cast(N->getOperand(I))->getType(); + Type *ElementTy = cast(N->getOperand(I))->getType(); + return toTypedPointer(ElementTy, N->getContext()); } // The set of names is borrowed from the SPIR-V translator. diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h index c2c3475e1a936..cd1a2af09147e 100644 --- a/llvm/lib/Target/SPIRV/SPIRVUtils.h +++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h @@ -149,5 +149,12 @@ inline Type *reconstructFunctionType(Function *F) { return FunctionType::get(F->getReturnType(), ArgTys, F->isVarArg()); } +inline Type *toTypedPointer(Type *Ty, LLVMContext &Ctx) { + return isUntypedPointerTy(Ty) + ? TypedPointerType::get(IntegerType::getInt8Ty(Ctx), + getPointerAddressSpace(Ty)) + : Ty; +} + } // namespace llvm #endif // LLVM_LIB_TARGET_SPIRV_SPIRVUTILS_H diff --git a/llvm/test/CodeGen/SPIRV/const-composite.ll b/llvm/test/CodeGen/SPIRV/const-composite.ll new file mode 100644 index 0000000000000..4e304bb951670 --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/const-composite.ll @@ -0,0 +1,26 @@ +; This test is to ensure that OpConstantComposite reuses a constant when it's +; already created and available in the same machine function. In this test case +; it's `1` that is passed implicitly as a part of the `foo` function argument +; and also takes part in a composite constant creation. + +; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV +; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %} + +; CHECK-SPIRV: %[[#type_int32:]] = OpTypeInt 32 0 +; CHECK-SPIRV: %[[#const1:]] = OpConstant %[[#type_int32]] 1 +; CHECK-SPIRV: OpTypeArray %[[#]] %[[#const1:]] +; CHECK-SPIRV: %[[#const0:]] = OpConstant %[[#type_int32]] 0 +; CHECK-SPIRV: OpConstantComposite %[[#]] %[[#const0]] %[[#const1]] + +%struct = type { [1 x i64] } + +define spir_kernel void @foo(ptr noundef byval(%struct) %arg) { +entry: + call spir_func void @bar(<2 x i32> noundef ) + ret void +} + +define spir_func void @bar(<2 x i32> noundef) { +entry: + ret void +} diff --git a/llvm/test/CodeGen/SPIRV/instructions/ret-type.ll b/llvm/test/CodeGen/SPIRV/instructions/ret-type.ll new file mode 100644 index 0000000000000..bf71eb5628e21 --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/instructions/ret-type.ll @@ -0,0 +1,82 @@ +; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s +; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown --translator-compatibility-mode %s -o - -filetype=obj | spirv-val %} +; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %} + +; CHECK-DAG: OpName %[[Test1:.*]] "test1" +; CHECK-DAG: OpName %[[Foo:.*]] "foo" +; CHECK-DAG: OpName %[[Bar:.*]] "bar" +; CHECK-DAG: OpName %[[Test2:.*]] "test2" + +; CHECK-DAG: %[[Long:.*]] = OpTypeInt 64 0 +; CHECK-DAG: %[[Array:.*]] = OpTypeArray %[[Long]] %[[#]] +; CHECK-DAG: %[[Struct1:.*]] = OpTypeStruct %[[Array]] +; CHECK-DAG: %[[Struct2:.*]] = OpTypeStruct %[[Struct1]] +; CHECK-DAG: %[[StructPtr:.*]] = OpTypePointer Function %[[Struct2]] +; CHECK-DAG: %[[Bool:.*]] = OpTypeBool +; CHECK-DAG: %[[FooType:.*]] = OpTypeFunction %[[StructPtr:.*]] %[[StructPtr]] %[[StructPtr]] %[[Bool]] +; CHECK-DAG: %[[Char:.*]] = OpTypeInt 8 0 +; CHECK-DAG: %[[CharPtr:.*]] = OpTypePointer Function %[[Char]] + +; CHECK: %[[Test1]] = OpFunction +; CHECK: OpFunctionCall %[[StructPtr:.*]] %[[Foo]] +; CHECK: OpFunctionCall %[[StructPtr:.*]] %[[Bar]] +; CHECK: OpFunctionEnd + +; CHECK: %[[Foo]] = OpFunction %[[StructPtr:.*]] None %[[FooType]] +; CHECK: %[[Arg1:.*]] = OpFunctionParameter %[[StructPtr]] +; CHECK: %[[Arg2:.*]] = OpFunctionParameter +; CHECK: %[[Sw:.*]] = OpFunctionParameter +; CHECK: %[[Res:.*]] = OpInBoundsPtrAccessChain %[[StructPtr]] %[[Arg1]] %[[#]] +; CHECK: OpReturnValue %[[Res]] +; CHECK: OpReturnValue %[[Arg2]] + +; CHECK: %[[Bar]] = OpFunction %[[StructPtr:.*]] None %[[#]] +; CHECK: %[[BarArg:.*]] = OpFunctionParameter +; CHECK: %[[BarRes:.*]] = OpInBoundsPtrAccessChain %[[CharPtr]] %[[BarArg]] %[[#]] +; CHECK: %[[BarResCasted:.*]] = OpBitcast %[[StructPtr]] %[[BarRes]] +; CHECK: %[[BarResStruct:.*]] = OpInBoundsPtrAccessChain %[[StructPtr]] %[[#]] %[[#]] +; CHECK: OpReturnValue %[[BarResStruct]] +; CHECK: OpReturnValue %[[BarResCasted]] + +; CHECK: %[[Test2]] = OpFunction +; CHECK: OpFunctionCall %[[StructPtr:.*]] %[[Foo]] +; CHECK: OpFunctionCall %[[StructPtr:.*]] %[[Bar]] +; CHECK: OpFunctionEnd + +%struct = type { %array } +%array = type { [1 x i64] } + +define spir_func void @test1(ptr %arg1, ptr %arg2, i1 %sw) { +entry: + %r1 = call ptr @foo(ptr %arg1, ptr %arg2, i1 %sw) + %r2 = call ptr @bar(ptr %arg1, i1 %sw) + ret void +} + +define spir_func ptr @foo(ptr %arg1, ptr %arg2, i1 %sw) { +entry: + br i1 %sw, label %exit, label %sw1 +sw1: + %result = getelementptr inbounds %struct, ptr %arg1, i64 100 + ret ptr %result +exit: + ret ptr %arg2 +} + +define spir_func ptr @bar(ptr %arg1, i1 %sw) { +entry: + %charptr = getelementptr inbounds i8, ptr %arg1, i64 0 + br i1 %sw, label %exit, label %sw1 +sw1: + %result = getelementptr inbounds %struct, ptr %arg1, i64 100 + ret ptr %result +exit: + ret ptr %charptr +} + +define spir_func void @test2(ptr %arg1, ptr %arg2, i1 %sw) { +entry: + %r1 = call ptr @foo(ptr %arg1, ptr %arg2, i1 %sw) + %r2 = call ptr @bar(ptr %arg1, i1 %sw) + ret void +} diff --git a/llvm/test/CodeGen/SPIRV/instructions/select-phi.ll b/llvm/test/CodeGen/SPIRV/instructions/select-phi.ll new file mode 100644 index 0000000000000..afc75c616f023 --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/instructions/select-phi.ll @@ -0,0 +1,58 @@ +; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s +; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown --translator-compatibility-mode %s -o - -filetype=obj | spirv-val %} +; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %} + +; CHECK-DAG: %[[Char:.*]] = OpTypeInt 8 0 +; CHECK-DAG: %[[Long:.*]] = OpTypeInt 32 0 +; CHECK-DAG: %[[Array:.*]] = OpTypeArray %[[Long]] %[[#]] +; CHECK-DAG: %[[Struct:.*]] = OpTypeStruct %[[Array]] +; CHECK-DAG: %[[StructPtr:.*]] = OpTypePointer Function %[[Struct]] +; CHECK-DAG: %[[CharPtr:.*]] = OpTypePointer Function %[[Char]] + +; CHECK: %[[Branch1:.*]] = OpLabel +; CHECK: %[[Res1:.*]] = OpVariable %[[StructPtr]] Function +; CHECK: OpBranchConditional %[[#]] %[[#]] %[[Branch2:.*]] +; CHECK: %[[Res2:.*]] = OpInBoundsPtrAccessChain %[[CharPtr]] %[[#]] %[[#]] +; CHECK: %[[Res2Casted:.*]] = OpBitcast %[[StructPtr]] %[[Res2]] +; CHECK: OpBranchConditional %[[#]] %[[#]] %[[BranchSelect:.*]] +; CHECK: %[[SelectRes:.*]] = OpSelect %[[CharPtr]] %[[#]] %[[#]] %[[#]] +; CHECK: %[[SelectResCasted:.*]] = OpBitcast %[[StructPtr]] %[[SelectRes]] +; CHECK: OpLabel +; CHECK: OpPhi %[[StructPtr]] %[[Res1]] %[[Branch1]] %[[Res2Casted]] %[[Branch2]] %[[SelectResCasted]] %[[BranchSelect]] + +%struct = type { %array } +%array = type { [1 x i64] } +%array3 = type { [3 x i32] } + +define spir_kernel void @foo(ptr addrspace(1) noundef align 1 %arg1, ptr noundef byval(%struct) align 8 %arg2, i1 noundef zeroext %expected) { +entry: + %agg = alloca %array3, align 8 + %r0 = load i64, ptr %arg2, align 8 + %add.ptr = getelementptr inbounds i8, ptr %agg, i64 12 + %r1 = load i32, ptr %agg, align 4 + %tobool0 = icmp slt i32 %r1, 0 + br i1 %tobool0, label %exit, label %sw1 + +sw1: ; preds = %entry + %incdec1 = getelementptr inbounds i8, ptr %agg, i64 4 + %r2 = load i32, ptr %incdec1, align 4 + %tobool1 = icmp slt i32 %r2, 0 + br i1 %tobool1, label %exit, label %sw2 + +sw2: ; preds = %sw1 + %incdec2 = getelementptr inbounds i8, ptr %agg, i64 8 + %r3 = load i32, ptr %incdec2, align 4 + %tobool2 = icmp slt i32 %r3, 0 + %spec.select = select i1 %tobool2, ptr %incdec2, ptr %add.ptr + br label %exit + +exit: ; preds = %sw2, %sw1, %entry + %retval.0 = phi ptr [ %agg, %entry ], [ %incdec1, %sw1 ], [ %spec.select, %sw2 ] + %add.ptr.i = getelementptr inbounds i8, ptr addrspace(1) %arg1, i64 %r0 + %r4 = icmp eq ptr %retval.0, %add.ptr + %cmp = xor i1 %r4, %expected + %frombool6.i = zext i1 %cmp to i8 + store i8 %frombool6.i, ptr addrspace(1) %add.ptr.i, align 1 + %r5 = icmp eq ptr %add.ptr, %retval.0 + ret void +} diff --git a/llvm/test/CodeGen/SPIRV/instructions/select.ll b/llvm/test/CodeGen/SPIRV/instructions/select.ll index f54ef21f20859..c4176b17abb44 100644 --- a/llvm/test/CodeGen/SPIRV/instructions/select.ll +++ b/llvm/test/CodeGen/SPIRV/instructions/select.ll @@ -1,6 +1,8 @@ ; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s +; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %} ; CHECK-DAG: OpName [[SCALARi32:%.+]] "select_i32" +; CHECK-DAG: OpName [[SCALARPTR:%.+]] "select_ptr" ; CHECK-DAG: OpName [[VEC2i32:%.+]] "select_i32v2" ; CHECK-DAG: OpName [[VEC2i32v2:%.+]] "select_v2i32v2" @@ -17,6 +19,19 @@ define i32 @select_i32(i1 %c, i32 %t, i32 %f) { ret i32 %r } +; CHECK: [[SCALARPTR]] = OpFunction +; CHECK-NEXT: [[C:%.+]] = OpFunctionParameter +; CHECK-NEXT: [[T:%.+]] = OpFunctionParameter +; CHECK-NEXT: [[F:%.+]] = OpFunctionParameter +; CHECK: OpLabel +; CHECK: [[R:%.+]] = OpSelect {{%.+}} [[C]] [[T]] [[F]] +; CHECK: OpReturnValue [[R]] +; CHECK-NEXT: OpFunctionEnd +define ptr @select_ptr(i1 %c, ptr %t, ptr %f) { + %r = select i1 %c, ptr %t, ptr %f + ret ptr %r +} + ; CHECK: [[VEC2i32]] = OpFunction ; CHECK-NEXT: [[C:%.+]] = OpFunctionParameter ; CHECK-NEXT: [[T:%.+]] = OpFunctionParameter