diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp index afdca01561b0b..ad4e72a3128b1 100644 --- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp @@ -201,21 +201,30 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx, if (!isPointerTy(OriginalArgType)) return GR->getOrCreateSPIRVType(OriginalArgType, MIRBuilder, ArgAccessQual); - // In case OriginalArgType is of pointer type, there are three possibilities: + Argument *Arg = F.getArg(ArgIdx); + Type *ArgType = Arg->getType(); + if (isTypedPointerTy(ArgType)) { + SPIRVType *ElementType = GR->getOrCreateSPIRVType( + cast(ArgType)->getElementType(), MIRBuilder); + return GR->getOrCreateSPIRVPointerType( + ElementType, MIRBuilder, + addressSpaceToStorageClass(getPointerAddressSpace(ArgType), ST)); + } + + // In case OriginalArgType is of untyped pointer type, there are three + // possibilities: // 1) This is a pointer of an LLVM IR element type, passed byval/byref. // 2) This is an OpenCL/SPIR-V builtin type if there is spv_assign_type - // intrinsic assigning a TargetExtType. + // intrinsic assigning a TargetExtType. // 3) This is a pointer, try to retrieve pointer element type from a // spv_assign_ptr_type intrinsic or otherwise use default pointer element // type. - Argument *Arg = F.getArg(ArgIdx); - if (HasPointeeTypeAttr(Arg)) { - Type *ByValRefType = Arg->hasByValAttr() ? 
Arg->getParamByValType() - : Arg->getParamByRefType(); - SPIRVType *ElementType = GR->getOrCreateSPIRVType(ByValRefType, MIRBuilder); + if (hasPointeeTypeAttr(Arg)) { + SPIRVType *ElementType = + GR->getOrCreateSPIRVType(getPointeeTypeByAttr(Arg), MIRBuilder); return GR->getOrCreateSPIRVPointerType( ElementType, MIRBuilder, - addressSpaceToStorageClass(getPointerAddressSpace(Arg->getType()), ST)); + addressSpaceToStorageClass(getPointerAddressSpace(ArgType), ST)); } for (auto User : Arg->users()) { diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp index 5828db6669ff1..7c5a38fa48d00 100644 --- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp @@ -14,6 +14,7 @@ #include "SPIRV.h" #include "SPIRVBuiltins.h" #include "SPIRVMetadata.h" +#include "SPIRVSubtarget.h" #include "SPIRVTargetMachine.h" #include "SPIRVUtils.h" #include "llvm/IR/IRBuilder.h" @@ -53,14 +54,22 @@ class SPIRVEmitIntrinsics : public FunctionPass, public InstVisitor { SPIRVTargetMachine *TM = nullptr; + SPIRVGlobalRegistry *GR = nullptr; Function *F = nullptr; bool TrackConstants = true; DenseMap AggrConsts; + DenseMap AggrConstTypes; DenseSet AggrStores; - // deduce values type - DenseMap DeducedElTys; + // deduce element type of untyped pointers Type *deduceElementType(Value *I); + Type *deduceElementTypeHelper(Value *I); + Type *deduceElementTypeHelper(Value *I, std::unordered_set &Visited); + + // deduce nested types of composites + Type *deduceNestedTypeHelper(User *U); + Type *deduceNestedTypeHelper(User *U, Type *Ty, + std::unordered_set &Visited); void preprocessCompositeConstants(IRBuilder<> &B); void preprocessUndefs(IRBuilder<> &B); @@ -92,9 +101,9 @@ class SPIRVEmitIntrinsics void insertPtrCastOrAssignTypeInstr(Instruction *I, IRBuilder<> &B); void processGlobalValue(GlobalVariable &GV, IRBuilder<> &B); void processParamTypes(Function *F, IRBuilder<> &B); - Type 
*deduceFunParamType(Function *F, unsigned OpIdx); - Type *deduceFunParamType(Function *F, unsigned OpIdx, - std::unordered_set &FVisited); + Type *deduceFunParamElementType(Function *F, unsigned OpIdx); + Type *deduceFunParamElementType(Function *F, unsigned OpIdx, + std::unordered_set &FVisited); public: static char ID; @@ -169,17 +178,20 @@ static inline void reportFatalOnTokenType(const Instruction *I) { // Deduce and return a successfully deduced Type of the Instruction, // or nullptr otherwise. -static Type *deduceElementTypeHelper(Value *I, - std::unordered_set &Visited, - DenseMap &DeducedElTys) { +Type *SPIRVEmitIntrinsics::deduceElementTypeHelper(Value *I) { + std::unordered_set Visited; + return deduceElementTypeHelper(I, Visited); +} + +Type *SPIRVEmitIntrinsics::deduceElementTypeHelper( + Value *I, std::unordered_set &Visited) { // allow to pass nullptr as an argument if (!I) return nullptr; // maybe already known - auto It = DeducedElTys.find(I); - if (It != DeducedElTys.end()) - return It->second; + if (Type *KnownTy = GR->findDeducedElementType(I)) + return KnownTy; // maybe a cycle if (Visited.find(I) != Visited.end()) @@ -195,25 +207,99 @@ static Type *deduceElementTypeHelper(Value *I, Ty = Ref->getResultElementType(); } else if (auto *Ref = dyn_cast(I)) { Ty = Ref->getValueType(); + if (Value *Op = Ref->getNumOperands() > 0 ? 
Ref->getOperand(0) : nullptr) { + if (auto *PtrTy = dyn_cast(Ty)) { + if (Type *NestedTy = deduceElementTypeHelper(Op, Visited)) + Ty = TypedPointerType::get(NestedTy, PtrTy->getAddressSpace()); + } else { + Ty = deduceNestedTypeHelper(dyn_cast(Op), Ty, Visited); + } + } } else if (auto *Ref = dyn_cast(I)) { - Ty = deduceElementTypeHelper(Ref->getPointerOperand(), Visited, - DeducedElTys); + Ty = deduceElementTypeHelper(Ref->getPointerOperand(), Visited); } else if (auto *Ref = dyn_cast(I)) { if (Type *Src = Ref->getSrcTy(), *Dest = Ref->getDestTy(); isPointerTy(Src) && isPointerTy(Dest)) - Ty = deduceElementTypeHelper(Ref->getOperand(0), Visited, DeducedElTys); + Ty = deduceElementTypeHelper(Ref->getOperand(0), Visited); } // remember the found relationship - if (Ty) - DeducedElTys[I] = Ty; + if (Ty) { + // specify nested types if needed, otherwise return unchanged + GR->addDeducedElementType(I, Ty); + } return Ty; } -Type *SPIRVEmitIntrinsics::deduceElementType(Value *I) { +// Re-create a type of the value if it has untyped pointer fields, also nested. +// Return the original value type if no correction of untyped pointer +// information is found or needed. 
+Type *SPIRVEmitIntrinsics::deduceNestedTypeHelper(User *U) { std::unordered_set Visited; - if (Type *Ty = deduceElementTypeHelper(I, Visited, DeducedElTys)) + return deduceNestedTypeHelper(U, U->getType(), Visited); +} + +Type *SPIRVEmitIntrinsics::deduceNestedTypeHelper( + User *U, Type *OrigTy, std::unordered_set &Visited) { + if (!U) + return OrigTy; + + // maybe already known + if (Type *KnownTy = GR->findDeducedCompositeType(U)) + return KnownTy; + + // maybe a cycle + if (Visited.find(U) != Visited.end()) + return OrigTy; + Visited.insert(U); + + if (dyn_cast(OrigTy)) { + SmallVector Tys; + bool Change = false; + for (unsigned i = 0; i < U->getNumOperands(); ++i) { + Value *Op = U->getOperand(i); + Type *OpTy = Op->getType(); + Type *Ty = OpTy; + if (Op) { + if (auto *PtrTy = dyn_cast(OpTy)) { + if (Type *NestedTy = deduceElementTypeHelper(Op, Visited)) + Ty = TypedPointerType::get(NestedTy, PtrTy->getAddressSpace()); + } else { + Ty = deduceNestedTypeHelper(dyn_cast(Op), OpTy, Visited); + } + } + Tys.push_back(Ty); + Change |= Ty != OpTy; + } + if (Change) { + Type *NewTy = StructType::create(Tys); + GR->addDeducedCompositeType(U, NewTy); + return NewTy; + } + } else if (auto *ArrTy = dyn_cast(OrigTy)) { + if (Value *Op = U->getNumOperands() > 0 ? 
U->getOperand(0) : nullptr) { + Type *OpTy = ArrTy->getElementType(); + Type *Ty = OpTy; + if (auto *PtrTy = dyn_cast(OpTy)) { + if (Type *NestedTy = deduceElementTypeHelper(Op, Visited)) + Ty = TypedPointerType::get(NestedTy, PtrTy->getAddressSpace()); + } else { + Ty = deduceNestedTypeHelper(dyn_cast(Op), OpTy, Visited); + } + if (Ty != OpTy) { + Type *NewTy = ArrayType::get(Ty, ArrTy->getNumElements()); + GR->addDeducedCompositeType(U, NewTy); + return NewTy; + } + } + } + + return OrigTy; +} + +Type *SPIRVEmitIntrinsics::deduceElementType(Value *I) { + if (Type *Ty = deduceElementTypeHelper(I)) return Ty; return IntegerType::getInt8Ty(I->getContext()); } @@ -257,6 +343,7 @@ void SPIRVEmitIntrinsics::preprocessUndefs(IRBuilder<> &B) { Worklist.push(IntrUndef); I->replaceUsesOfWith(Op, IntrUndef); AggrConsts[IntrUndef] = AggrUndef; + AggrConstTypes[IntrUndef] = AggrUndef->getType(); } } } @@ -282,6 +369,7 @@ void SPIRVEmitIntrinsics::preprocessCompositeConstants(IRBuilder<> &B) { I->replaceUsesOfWith(Op, CCI); KeepInst = true; SEI.AggrConsts[CCI] = AggrC; + SEI.AggrConstTypes[CCI] = SEI.deduceNestedTypeHelper(AggrC); }; if (auto *AggrC = dyn_cast(Op)) { @@ -396,8 +484,7 @@ void SPIRVEmitIntrinsics::replacePointerOperandWithPtrCast( Pointer = BC->getOperand(0); // Do not emit spv_ptrcast if Pointer's element type is ExpectedElementType - std::unordered_set Visited; - Type *PointerElemTy = deduceElementTypeHelper(Pointer, Visited, DeducedElTys); + Type *PointerElemTy = deduceElementTypeHelper(Pointer); if (PointerElemTy == ExpectedElementType) return; @@ -456,8 +543,8 @@ void SPIRVEmitIntrinsics::replacePointerOperandWithPtrCast( CallInst *CI = buildIntrWithMD( Intrinsic::spv_assign_ptr_type, {Pointer->getType()}, ExpectedElementTypeConst, Pointer, {B.getInt32(AddressSpace)}, B); - DeducedElTys[CI] = ExpectedElementType; - DeducedElTys[Pointer] = ExpectedElementType; + GR->addDeducedElementType(CI, ExpectedElementType); + GR->addDeducedElementType(Pointer, 
ExpectedElementType); return; } @@ -498,25 +585,29 @@ void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I, Function *CalledF = CI->getCalledFunction(); SmallVector CalledArgTys; bool HaveTypes = false; - for (auto &CalledArg : CalledF->args()) { - if (!isPointerTy(CalledArg.getType())) { + for (unsigned OpIdx = 0; OpIdx < CalledF->arg_size(); ++OpIdx) { + Argument *CalledArg = CalledF->getArg(OpIdx); + Type *ArgType = CalledArg->getType(); + if (!isPointerTy(ArgType)) { CalledArgTys.push_back(nullptr); - continue; - } - auto It = DeducedElTys.find(&CalledArg); - Type *ParamTy = It != DeducedElTys.end() ? It->second : nullptr; - if (!ParamTy) { - for (User *U : CalledArg.users()) { - if (Instruction *Inst = dyn_cast(U)) { - std::unordered_set Visited; - ParamTy = deduceElementTypeHelper(Inst, Visited, DeducedElTys); - if (ParamTy) - break; + } else if (isTypedPointerTy(ArgType)) { + CalledArgTys.push_back(cast(ArgType)->getElementType()); + HaveTypes = true; + } else { + Type *ElemTy = GR->findDeducedElementType(CalledArg); + if (!ElemTy && hasPointeeTypeAttr(CalledArg)) + ElemTy = getPointeeTypeByAttr(CalledArg); + if (!ElemTy) { + for (User *U : CalledArg->users()) { + if (Instruction *Inst = dyn_cast(U)) { + if ((ElemTy = deduceElementTypeHelper(Inst)) != nullptr) + break; + } } } + HaveTypes |= ElemTy != nullptr; + CalledArgTys.push_back(ElemTy); } - HaveTypes |= ParamTy != nullptr; - CalledArgTys.push_back(ParamTy); } std::string DemangledName = @@ -706,6 +797,10 @@ void SPIRVEmitIntrinsics::processGlobalValue(GlobalVariable &GV, if (GV.getName() == "llvm.global.annotations") return; if (GV.hasInitializer() && !isa(GV.getInitializer())) { + // Deduce element type and store results in Global Registry. + // Result is ignored, because TypedPointerType is not supported + // by llvm IR general logic. + deduceElementTypeHelper(&GV); Constant *Init = GV.getInitializer(); Type *Ty = isAggrToReplace(Init) ? 
B.getInt32Ty() : Init->getType(); Constant *Const = isAggrToReplace(Init) ? B.getInt32(1) : Init; @@ -732,7 +827,7 @@ void SPIRVEmitIntrinsics::insertAssignPtrTypeIntrs(Instruction *I, unsigned AddressSpace = getPointerAddressSpace(I->getType()); CallInst *CI = buildIntrWithMD(Intrinsic::spv_assign_ptr_type, {I->getType()}, EltTyConst, I, {B.getInt32(AddressSpace)}, B); - DeducedElTys[CI] = ElemTy; + GR->addDeducedElementType(CI, ElemTy); } void SPIRVEmitIntrinsics::insertAssignTypeIntrs(Instruction *I, @@ -745,9 +840,10 @@ void SPIRVEmitIntrinsics::insertAssignTypeIntrs(Instruction *I, if (auto *II = dyn_cast(I)) { if (II->getIntrinsicID() == Intrinsic::spv_const_composite || II->getIntrinsicID() == Intrinsic::spv_undef) { - auto t = AggrConsts.find(II); - assert(t != AggrConsts.end()); - TypeToAssign = t->second->getType(); + auto It = AggrConstTypes.find(II); + if (It == AggrConstTypes.end()) + report_fatal_error("Unknown composite intrinsic type"); + TypeToAssign = It->second; } } Constant *Const = UndefValue::get(TypeToAssign); @@ -807,12 +903,13 @@ void SPIRVEmitIntrinsics::processInstrAfterVisit(Instruction *I, } } -Type *SPIRVEmitIntrinsics::deduceFunParamType(Function *F, unsigned OpIdx) { +Type *SPIRVEmitIntrinsics::deduceFunParamElementType(Function *F, + unsigned OpIdx) { std::unordered_set FVisited; - return deduceFunParamType(F, OpIdx, FVisited); + return deduceFunParamElementType(F, OpIdx, FVisited); } -Type *SPIRVEmitIntrinsics::deduceFunParamType( +Type *SPIRVEmitIntrinsics::deduceFunParamElementType( Function *F, unsigned OpIdx, std::unordered_set &FVisited) { // maybe a cycle if (FVisited.find(F) != FVisited.end()) @@ -830,15 +927,15 @@ Type *SPIRVEmitIntrinsics::deduceFunParamType( if (!isPointerTy(OpArg->getType())) continue; // maybe we already know operand's element type - if (auto It = DeducedElTys.find(OpArg); It != DeducedElTys.end()) - return It->second; + if (Type *KnownTy = GR->findDeducedElementType(OpArg)) + return KnownTy; // search 
in actual parameter's users for (User *OpU : OpArg->users()) { Instruction *Inst = dyn_cast(OpU); if (!Inst || Inst == CI) continue; Visited.clear(); - if (Type *Ty = deduceElementTypeHelper(Inst, Visited, DeducedElTys)) + if (Type *Ty = deduceElementTypeHelper(Inst, Visited)) return Ty; } // check if it's a formal parameter of the outer function @@ -857,7 +954,7 @@ Type *SPIRVEmitIntrinsics::deduceFunParamType( // search in function parameters for (auto &Pair : Lookup) { - if (Type *Ty = deduceFunParamType(Pair.first, Pair.second, FVisited)) + if (Type *Ty = deduceFunParamElementType(Pair.first, Pair.second, FVisited)) return Ty; } @@ -866,19 +963,23 @@ Type *SPIRVEmitIntrinsics::deduceFunParamType( void SPIRVEmitIntrinsics::processParamTypes(Function *F, IRBuilder<> &B) { B.SetInsertPointPastAllocas(F); - DenseMap Args; for (unsigned OpIdx = 0; OpIdx < F->arg_size(); ++OpIdx) { Argument *Arg = F->getArg(OpIdx); - if (isUntypedPointerTy(Arg->getType()) && - DeducedElTys.find(Arg) == DeducedElTys.end() && - !HasPointeeTypeAttr(Arg)) { - if (Type *ElemTy = deduceFunParamType(F, OpIdx)) { + if (!isUntypedPointerTy(Arg->getType())) + continue; + + Type *ElemTy = GR->findDeducedElementType(Arg); + if (!ElemTy) { + if (hasPointeeTypeAttr(Arg) && + (ElemTy = getPointeeTypeByAttr(Arg)) != nullptr) { + GR->addDeducedElementType(Arg, ElemTy); + } else if ((ElemTy = deduceFunParamElementType(F, OpIdx)) != nullptr) { CallInst *AssignPtrTyCI = buildIntrWithMD( Intrinsic::spv_assign_ptr_type, {Arg->getType()}, Constant::getNullValue(ElemTy), Arg, {B.getInt32(getPointerAddressSpace(Arg->getType()))}, B); - DeducedElTys[AssignPtrTyCI] = ElemTy; - DeducedElTys[Arg] = ElemTy; + GR->addDeducedElementType(AssignPtrTyCI, ElemTy); + GR->addDeducedElementType(Arg, ElemTy); } } } @@ -887,9 +988,14 @@ void SPIRVEmitIntrinsics::processParamTypes(Function *F, IRBuilder<> &B) { bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) { if (Func.isDeclaration()) return false; + + const 
SPIRVSubtarget &ST = TM->getSubtarget(Func); + GR = ST.getSPIRVGlobalRegistry(); + F = &Func; IRBuilder<> B(Func.getContext()); AggrConsts.clear(); + AggrConstTypes.clear(); AggrStores.clear(); // StoreInst's operand type can be changed during the next transformations, diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h index ed0f90ff89ce6..e0099e5294472 100644 --- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h +++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h @@ -41,9 +41,13 @@ class SPIRVGlobalRegistry { // map a Function to its definition (as a machine instruction operand) DenseMap FunctionToInstr; + DenseMap FunctionToInstrRev; // map function pointer (as a machine instruction operand) to the used // Function DenseMap InstrToFunction; + // Maps Functions to their calls (in a form of the machine instruction, + // OpFunctionCall) that happened before the definition is available + DenseMap> ForwardCalls; // Look for an equivalent of the newType in the map. Return the equivalent // if it's found, otherwise insert newType to the map and return the type. @@ -59,6 +63,13 @@ class SPIRVGlobalRegistry { // Holds the maximum ID we have in the module. unsigned Bound; + // Maps values associated with untyped pointers into deduced element types of + // untyped pointers. + DenseMap DeducedElTys; + // Maps composite values to deduced types where untyped pointers are replaced + // with typed ones + DenseMap DeducedNestedTys; + // Add a new OpTypeXXX instruction without checking for duplicates. SPIRVType *createSPIRVType(const Type *Type, MachineIRBuilder &MIRBuilder, SPIRV::AccessQualifier::AccessQualifier AQ = @@ -122,6 +133,37 @@ class SPIRVGlobalRegistry { void setBound(unsigned V) { Bound = V; } unsigned getBound() { return Bound; } + // Deduced element types of untyped pointers and composites: + // - Add a record to the map of deduced element types. 
+ void addDeducedElementType(Value *Val, Type *Ty) { DeducedElTys[Val] = Ty; } + // - Find a record in the map of deduced element types. + Type *findDeducedElementType(const Value *Val) { + auto It = DeducedElTys.find(Val); + return It == DeducedElTys.end() ? nullptr : It->second; + } + // - Add a record to the map of deduced composite types. + void addDeducedCompositeType(Value *Val, Type *Ty) { + DeducedNestedTys[Val] = Ty; + } + // - Find a record in the map of deduced composite types. + Type *findDeducedCompositeType(const Value *Val) { + auto It = DeducedNestedTys.find(Val); + return It == DeducedNestedTys.end() ? nullptr : It->second; + } + // - Find a type of the given Global value + Type *getDeducedGlobalValueType(const GlobalValue *Global) { + // we may know element type if it was deduced earlier + Type *ElementTy = findDeducedElementType(Global); + if (!ElementTy) { + // or we may know element type if it's associated with a composite + // value + if (Value *GlobalElem = + Global->getNumOperands() > 0 ? Global->getOperand(0) : nullptr) + ElementTy = findDeducedCompositeType(GlobalElem); + } + return ElementTy ? ElementTy : Global->getValueType(); + } + // Map a machine operand that represents a use of a function via function // pointer to a machine operand that represents the function definition. // Return either the register or invalid value, because we have no context for @@ -133,18 +175,56 @@ class SPIRVGlobalRegistry { auto ResReg = FunctionToInstr.find(ResF->second); return ResReg == FunctionToInstr.end() ? nullptr : ResReg->second; } + + // Map a Function to a machine instruction that represents the function + // definition. + const MachineInstr *getFunctionDefinition(const Function *F) { + if (!F) + return nullptr; + auto MOIt = FunctionToInstr.find(F); + return MOIt == FunctionToInstr.end() ? nullptr : MOIt->second->getParent(); + } + + // Map a machine instruction that represents the function definition back + // to its Function. 
+ const Function *getFunctionByDefinition(const MachineInstr *MI) { + if (!MI) + return nullptr; + auto FIt = FunctionToInstrRev.find(MI); + return FIt == FunctionToInstrRev.end() ? nullptr : FIt->second; + } + // map function pointer (as a machine instruction operand) to the used // Function void recordFunctionPointer(const MachineOperand *MO, const Function *F) { InstrToFunction[MO] = F; } + // map a Function to its definition (as a machine instruction) void recordFunctionDefinition(const Function *F, const MachineOperand *MO) { FunctionToInstr[F] = MO; + FunctionToInstrRev[MO->getParent()] = F; } + // Return true if any OpConstantFunctionPointerINTEL were generated bool hasConstFunPtr() { return !InstrToFunction.empty(); } + // Add a record about forward function call. + void addForwardCall(const Function *F, MachineInstr *MI) { + auto It = ForwardCalls.find(F); + if (It == ForwardCalls.end()) + ForwardCalls[F] = {MI}; + else + It->second.push_back(MI); + } + + // Map a Function to the vector of machine instructions that represents + // forward function calls or to nullptr if not found. + SmallVector *getForwardCalls(const Function *F) { + auto It = ForwardCalls.find(F); + return It == ForwardCalls.end() ? nullptr : &It->second; + } + // Get or create a SPIR-V type corresponding the given LLVM IR type, // and map it to the given VReg by creating an ASSIGN_TYPE instruction. SPIRVType *assignTypeToVReg(const Type *Type, Register VReg, diff --git a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp index 55b4c47c197da..4f5c1dc4f90b0 100644 --- a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp @@ -86,8 +86,8 @@ bool SPIRVTargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info, // when there is a type mismatch between results and operand types. 
static void validatePtrTypes(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI, SPIRVGlobalRegistry &GR, - MachineInstr &I, SPIRVType *ResType, - unsigned OpIdx) { + MachineInstr &I, unsigned OpIdx, + SPIRVType *ResType, const Type *ResTy = nullptr) { Register OpReg = I.getOperand(OpIdx).getReg(); SPIRVType *TypeInst = MRI->getVRegDef(OpReg); SPIRVType *OpType = GR.getSPIRVTypeForVReg( @@ -97,7 +97,13 @@ static void validatePtrTypes(const SPIRVSubtarget &STI, if (!ResType || !OpType || OpType->getOpcode() != SPIRV::OpTypePointer) return; SPIRVType *ElemType = GR.getSPIRVTypeForVReg(OpType->getOperand(2).getReg()); - if (!ElemType || ElemType == ResType) + if (!ElemType) + return; + bool IsSameMF = + ElemType->getParent()->getParent() == ResType->getParent()->getParent(); + bool IsEqualTypes = IsSameMF ? ElemType == ResType + : GR.getTypeForSPIRVType(ElemType) == ResTy; + if (IsEqualTypes) return; // There is a type mismatch between results and operand types // and we insert a bitcast before the instruction to keep SPIR-V code valid @@ -105,7 +111,11 @@ static void validatePtrTypes(const SPIRVSubtarget &STI, static_cast( OpType->getOperand(1).getImm()); MachineIRBuilder MIB(I); - SPIRVType *NewPtrType = GR.getOrCreateSPIRVPointerType(ResType, MIB, SC); + SPIRVType *NewBaseType = + IsSameMF ? ResType + : GR.getOrCreateSPIRVType( + ResTy, MIB, SPIRV::AccessQualifier::ReadWrite, false); + SPIRVType *NewPtrType = GR.getOrCreateSPIRVPointerType(NewBaseType, MIB, SC); if (!GR.isBitcastCompatible(NewPtrType, OpType)) report_fatal_error( "insert validation bitcast: incompatible result and operand types"); @@ -123,6 +133,74 @@ static void validatePtrTypes(const SPIRVSubtarget &STI, I.getOperand(OpIdx).setReg(NewReg); } +// Insert a bitcast before the function call instruction to keep SPIR-V code +// valid when there is a type mismatch between actual and expected types of an +// argument: +// %formal = OpFunctionParameter %formal_type +// ... 
+// %res = OpFunctionCall %ty %fun %actual ... +// implies that %actual is of %formal_type; with opaque pointers +// we may need to insert a bitcast to ensure this. +void validateFunCallMachineDef(const SPIRVSubtarget &STI, + MachineRegisterInfo *DefMRI, + MachineRegisterInfo *CallMRI, + SPIRVGlobalRegistry &GR, MachineInstr &FunCall, + MachineInstr *FunDef) { + if (FunDef->getOpcode() != SPIRV::OpFunction) + return; + unsigned OpIdx = 3; + for (FunDef = FunDef->getNextNode(); + FunDef && FunDef->getOpcode() == SPIRV::OpFunctionParameter && + OpIdx < FunCall.getNumOperands(); + FunDef = FunDef->getNextNode(), OpIdx++) { + SPIRVType *DefPtrType = DefMRI->getVRegDef(FunDef->getOperand(1).getReg()); + SPIRVType *DefElemType = + DefPtrType && DefPtrType->getOpcode() == SPIRV::OpTypePointer + ? GR.getSPIRVTypeForVReg(DefPtrType->getOperand(2).getReg()) + : nullptr; + if (DefElemType) { + const Type *DefElemTy = GR.getTypeForSPIRVType(DefElemType); + // Switch GR context to the call site instead of the (default) definition + // side + GR.setCurrentFunc(*FunCall.getParent()->getParent()); + validatePtrTypes(STI, CallMRI, GR, FunCall, OpIdx, DefElemType, + DefElemTy); + GR.setCurrentFunc(*FunDef->getParent()->getParent()); + } + } +} + +// Ensure there is no mismatch between actual and expected arg types: calls +// with a processed definition. Return Function pointer if it's a forward +// call (ahead of definition), and nullptr otherwise. 
+const Function *validateFunCall(const SPIRVSubtarget &STI, + MachineRegisterInfo *MRI, + SPIRVGlobalRegistry &GR, + MachineInstr &FunCall) { + const GlobalValue *GV = FunCall.getOperand(2).getGlobal(); + const Function *F = dyn_cast(GV); + MachineInstr *FunDef = + const_cast(GR.getFunctionDefinition(F)); + if (!FunDef) + return F; + validateFunCallMachineDef(STI, MRI, MRI, GR, FunCall, FunDef); + return nullptr; +} + +// Ensure there is no mismatch between actual and expected arg types: calls +// ahead of a processed definition. +void validateForwardCalls(const SPIRVSubtarget &STI, + MachineRegisterInfo *DefMRI, SPIRVGlobalRegistry &GR, + MachineInstr &FunDef) { + const Function *F = GR.getFunctionByDefinition(&FunDef); + if (SmallVector *FwdCalls = GR.getForwardCalls(F)) + for (MachineInstr *FunCall : *FwdCalls) { + MachineRegisterInfo *CallMRI = + &FunCall->getParent()->getParent()->getRegInfo(); + validateFunCallMachineDef(STI, DefMRI, CallMRI, GR, *FunCall, &FunDef); + } +} + // TODO: the logic of inserting additional bitcast's is to be moved // to pre-IRTranslation passes eventually void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const { @@ -137,14 +215,28 @@ void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const { switch (MI.getOpcode()) { case SPIRV::OpLoad: // OpLoad , ptr %Op implies that %Op is a pointer to - validatePtrTypes(STI, MRI, GR, MI, - GR.getSPIRVTypeForVReg(MI.getOperand(0).getReg()), 2); + validatePtrTypes(STI, MRI, GR, MI, 2, + GR.getSPIRVTypeForVReg(MI.getOperand(0).getReg())); break; case SPIRV::OpStore: // OpStore ptr %Op, implies that %Op points to the 's type - validatePtrTypes(STI, MRI, GR, MI, - GR.getSPIRVTypeForVReg(MI.getOperand(1).getReg()), 0); + validatePtrTypes(STI, MRI, GR, MI, 0, + GR.getSPIRVTypeForVReg(MI.getOperand(1).getReg())); break; + + case SPIRV::OpFunctionCall: + // ensure there is no mismatch between actual and expected arg types: + // calls with a processed definition + if 
(MI.getNumOperands() > 3) + if (const Function *F = validateFunCall(STI, MRI, GR, MI)) + GR.addForwardCall(F, &MI); + break; + case SPIRV::OpFunction: + // ensure there is no mismatch between actual and expected arg types: + // calls ahead of a processed definition + validateForwardCalls(STI, MRI, GR, MI); + break; + // ensure that LLVM IR bitwise instructions result in logical SPIR-V // instructions when applied to bool type case SPIRV::OpBitwiseOrS: diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp index 505b19a4d66ed..f4525e713c987 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -1897,7 +1897,7 @@ bool SPIRVInstructionSelector::selectGlobalValue( // FIXME: don't use MachineIRBuilder here, replace it with BuildMI. MachineIRBuilder MIRBuilder(I); const GlobalValue *GV = I.getOperand(1).getGlobal(); - Type *GVType = GV->getValueType(); + Type *GVType = GR.getDeducedGlobalValueType(GV); SPIRVType *PointerBaseType; if (GVType->isArrayTy()) { SPIRVType *ArrayElementType = diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp index 41807da6afcbc..b133f0ae85de2 100644 --- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp @@ -186,8 +186,9 @@ static SPIRVType *propagateSPIRVType(MachineInstr *MI, SPIRVGlobalRegistry *GR, } case TargetOpcode::G_GLOBAL_VALUE: { MIB.setInsertPt(*MI->getParent(), MI); - const auto *Global = MI->getOperand(1).getGlobal(); - auto *Ty = TypedPointerType::get(Global->getValueType(), + const GlobalValue *Global = MI->getOperand(1).getGlobal(); + Type *ElementTy = GR->getDeducedGlobalValueType(Global); + auto *Ty = TypedPointerType::get(ElementTy, Global->getType()->getAddressSpace()); SpirvTy = GR->getOrCreateSPIRVType(Ty, MIB); break; diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h 
b/llvm/lib/Target/SPIRV/SPIRVUtils.h index eb87349f0941c..c2c3475e1a936 100644 --- a/llvm/lib/Target/SPIRV/SPIRVUtils.h +++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h @@ -127,8 +127,26 @@ inline unsigned getPointerAddressSpace(const Type *T) { } // Return true if the Argument is decorated with a pointee type -inline bool HasPointeeTypeAttr(Argument *Arg) { - return Arg->hasByValAttr() || Arg->hasByRefAttr(); +inline bool hasPointeeTypeAttr(Argument *Arg) { + return Arg->hasByValAttr() || Arg->hasByRefAttr() || Arg->hasStructRetAttr(); +} + +// Return the pointee type of the argument or nullptr otherwise +inline Type *getPointeeTypeByAttr(Argument *Arg) { + if (Arg->hasByValAttr()) + return Arg->getParamByValType(); + if (Arg->hasStructRetAttr()) + return Arg->getParamStructRetType(); + if (Arg->hasByRefAttr()) + return Arg->getParamByRefType(); + return nullptr; +} + +inline Type *reconstructFunctionType(Function *F) { + SmallVector ArgTys; + for (unsigned i = 0; i < F->arg_size(); ++i) + ArgTys.push_back(F->getArg(i)->getType()); + return FunctionType::get(F->getReturnType(), ArgTys, F->isVarArg()); } } // namespace llvm diff --git a/llvm/test/CodeGen/SPIRV/pointers/nested-struct-opaque-pointers.ll b/llvm/test/CodeGen/SPIRV/pointers/nested-struct-opaque-pointers.ll new file mode 100644 index 0000000000000..77b895c7762fb --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/pointers/nested-struct-opaque-pointers.ll @@ -0,0 +1,20 @@ +; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s +; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %} + +; CHECK-NOT: OpTypeInt 8 0 + +@GI = addrspace(1) constant i64 42 + +@GS = addrspace(1) global {ptr addrspace(1), ptr addrspace(1)} { ptr addrspace(1) @GI, ptr addrspace(1) @GI } +@GS2 = addrspace(1) global {ptr addrspace(1), ptr addrspace(1)} { ptr addrspace(1) @GS, ptr addrspace(1) @GS } +@GS3 = addrspace(1) global {ptr addrspace(1), ptr addrspace(1)} { ptr addrspace(1) @GS2, 
ptr addrspace(1) @GS2 } + +@GPS = addrspace(1) global ptr addrspace(1) @GS3 + +@GPI1 = addrspace(1) global ptr addrspace(1) @GI +@GPI2 = addrspace(1) global ptr addrspace(1) @GPI1 +@GPI3 = addrspace(1) global ptr addrspace(1) @GPI2 + +define spir_kernel void @foo() { + ret void +} diff --git a/llvm/test/CodeGen/SPIRV/pointers/struct-opaque-pointers.ll b/llvm/test/CodeGen/SPIRV/pointers/struct-opaque-pointers.ll index ce3ab8895a594..6d4913f802c28 100644 --- a/llvm/test/CodeGen/SPIRV/pointers/struct-opaque-pointers.ll +++ b/llvm/test/CodeGen/SPIRV/pointers/struct-opaque-pointers.ll @@ -1,14 +1,14 @@ ; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s ; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %} -; CHECK: %[[TyInt8:.*]] = OpTypeInt 8 0 -; CHECK: %[[TyInt8Ptr:.*]] = OpTypePointer {{[a-zA-Z]+}} %[[TyInt8]] -; CHECK: %[[TyStruct:.*]] = OpTypeStruct %[[TyInt8Ptr]] %[[TyInt8Ptr]] +; CHECK: %[[TyInt64:.*]] = OpTypeInt 64 0 +; CHECK: %[[TyInt64Ptr:.*]] = OpTypePointer {{[a-zA-Z]+}} %[[TyInt64]] +; CHECK: %[[TyStruct:.*]] = OpTypeStruct %[[TyInt64Ptr]] %[[TyInt64Ptr]] ; CHECK: %[[ConstStruct:.*]] = OpConstantComposite %[[TyStruct]] %[[ConstField:.*]] %[[ConstField]] ; CHECK: %[[TyStructPtr:.*]] = OpTypePointer {{[a-zA-Z]+}} %[[TyStruct]] ; CHECK: OpVariable %[[TyStructPtr]] {{[a-zA-Z]+}} %[[ConstStruct]] -@a = addrspace(1) constant i32 123 +@a = addrspace(1) constant i64 42 @struct = addrspace(1) global {ptr addrspace(1), ptr addrspace(1)} { ptr addrspace(1) @a, ptr addrspace(1) @a } define spir_kernel void @foo() { diff --git a/llvm/test/CodeGen/SPIRV/pointers/type-deduce-by-call-chain.ll b/llvm/test/CodeGen/SPIRV/pointers/type-deduce-by-call-chain.ll index 703f1e22a0321..1071d3443056c 100644 --- a/llvm/test/CodeGen/SPIRV/pointers/type-deduce-by-call-chain.ll +++ b/llvm/test/CodeGen/SPIRV/pointers/type-deduce-by-call-chain.ll @@ -34,6 +34,12 @@ entry: %addr = addrspacecast ptr addrspace(1) %lptr to 
ptr addrspace(4) %object = bitcast ptr addrspace(4) %addr to ptr addrspace(4) call spir_func void @foo(ptr addrspace(4) %object, i32 3) + %halfptr = getelementptr inbounds half, ptr addrspace(1) %_arg_cum, i64 1 + %halfaddr = addrspacecast ptr addrspace(1) %halfptr to ptr addrspace(4) + call spir_func void @foo(ptr addrspace(4) %halfaddr, i32 3) + %dblptr = getelementptr inbounds double, ptr addrspace(1) %_arg_cum, i64 1 + %dbladdr = addrspacecast ptr addrspace(1) %dblptr to ptr addrspace(4) + call spir_func void @foo(ptr addrspace(4) %dbladdr, i32 3) ret void } @@ -49,4 +55,3 @@ define void @foo(ptr addrspace(4) noundef %foo_object, i32 noundef %mem_order) { tail call void @foo_stub(ptr addrspace(4) noundef %foo_object, i32 noundef %mem_order) ret void } -