diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index c11b36a088545..40b652057e87f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -291,6 +291,7 @@ buildBoolRegister(MachineIRBuilder &MIRBuilder, const SPIRVType *ResultType,
 
   Register ResultRegister =
       MIRBuilder.getMRI()->createGenericVirtualRegister(Type);
+  MIRBuilder.getMRI()->setRegClass(ResultRegister, &SPIRV::IDRegClass);
   GR->assignSPIRVTypeToVReg(BoolType, ResultRegister, MIRBuilder.getMF());
   return std::make_tuple(ResultRegister, BoolType);
 }
@@ -417,33 +418,41 @@ static Register buildConstantIntReg(uint64_t Val, MachineIRBuilder &MIRBuilder,
 }
 
 static Register buildScopeReg(Register CLScopeRegister,
+                              SPIRV::Scope::Scope Scope,
                               MachineIRBuilder &MIRBuilder,
                               SPIRVGlobalRegistry *GR,
-                              const MachineRegisterInfo *MRI) {
-  auto CLScope =
-      static_cast<SPIRV::CLMemoryScope>(getIConstVal(CLScopeRegister, MRI));
-  SPIRV::Scope::Scope Scope = getSPIRVScope(CLScope);
-
-  if (CLScope == static_cast<unsigned>(Scope))
-    return CLScopeRegister;
-
+                              MachineRegisterInfo *MRI) {
+  if (CLScopeRegister.isValid()) {
+    auto CLScope =
+        static_cast<SPIRV::CLMemoryScope>(getIConstVal(CLScopeRegister, MRI));
+    Scope = getSPIRVScope(CLScope);
+
+    if (CLScope == static_cast<unsigned>(Scope)) {
+      MRI->setRegClass(CLScopeRegister, &SPIRV::IDRegClass);
+      return CLScopeRegister;
+    }
+  }
   return buildConstantIntReg(Scope, MIRBuilder, GR);
 }
 
 static Register buildMemSemanticsReg(Register SemanticsRegister,
-                                     Register PtrRegister,
-                                     const MachineRegisterInfo *MRI,
+                                     Register PtrRegister, unsigned &Semantics,
+                                     MachineIRBuilder &MIRBuilder,
                                      SPIRVGlobalRegistry *GR) {
-  std::memory_order Order =
-      static_cast<std::memory_order>(getIConstVal(SemanticsRegister, MRI));
-  unsigned Semantics =
-      getSPIRVMemSemantics(Order) |
-      getMemSemanticsForStorageClass(GR->getPointerStorageClass(PtrRegister));
-
-  if (Order == Semantics)
-    return SemanticsRegister;
+  if (SemanticsRegister.isValid()) {
+    MachineRegisterInfo *MRI = MIRBuilder.getMRI();
+    std::memory_order Order =
+        static_cast<std::memory_order>(getIConstVal(SemanticsRegister, MRI));
+    Semantics =
+        getSPIRVMemSemantics(Order) |
+        getMemSemanticsForStorageClass(GR->getPointerStorageClass(PtrRegister));
 
-  return Register();
+    if (Order == Semantics) {
+      MRI->setRegClass(SemanticsRegister, &SPIRV::IDRegClass);
+      return SemanticsRegister;
+    }
+  }
+  return buildConstantIntReg(Semantics, MIRBuilder, GR);
 }
 
 /// Helper function for translating atomic init to OpStore.
@@ -451,7 +460,8 @@ static bool buildAtomicInitInst(const SPIRV::IncomingCall *Call,
                                 MachineIRBuilder &MIRBuilder) {
   assert(Call->Arguments.size() == 2 &&
          "Need 2 arguments for atomic init translation");
-
+  MIRBuilder.getMRI()->setRegClass(Call->Arguments[0], &SPIRV::IDRegClass);
+  MIRBuilder.getMRI()->setRegClass(Call->Arguments[1], &SPIRV::IDRegClass);
   MIRBuilder.buildInstr(SPIRV::OpStore)
       .addUse(Call->Arguments[0])
       .addUse(Call->Arguments[1]);
@@ -463,19 +473,22 @@ static bool buildAtomicLoadInst(const SPIRV::IncomingCall *Call,
                                 MachineIRBuilder &MIRBuilder,
                                 SPIRVGlobalRegistry *GR) {
   Register PtrRegister = Call->Arguments[0];
+  MIRBuilder.getMRI()->setRegClass(PtrRegister, &SPIRV::IDRegClass);
   // TODO: if true insert call to __translate_ocl_memory_sccope before
   // OpAtomicLoad and the function implementation. We can use Translator's
   // output for transcoding/atomic_explicit_arguments.cl as an example.
   Register ScopeRegister;
-  if (Call->Arguments.size() > 1)
+  if (Call->Arguments.size() > 1) {
     ScopeRegister = Call->Arguments[1];
-  else
+    MIRBuilder.getMRI()->setRegClass(ScopeRegister, &SPIRV::IDRegClass);
+  } else
     ScopeRegister = buildConstantIntReg(SPIRV::Scope::Device, MIRBuilder, GR);
 
   Register MemSemanticsReg;
   if (Call->Arguments.size() > 2) {
     // TODO: Insert call to __translate_ocl_memory_order before OpAtomicLoad.
     MemSemanticsReg = Call->Arguments[2];
+    MIRBuilder.getMRI()->setRegClass(MemSemanticsReg, &SPIRV::IDRegClass);
   } else {
     int Semantics =
         SPIRV::MemorySemantics::SequentiallyConsistent |
@@ -499,11 +512,12 @@ static bool buildAtomicStoreInst(const SPIRV::IncomingCall *Call,
   Register ScopeRegister =
       buildConstantIntReg(SPIRV::Scope::Device, MIRBuilder, GR);
   Register PtrRegister = Call->Arguments[0];
+  MIRBuilder.getMRI()->setRegClass(PtrRegister, &SPIRV::IDRegClass);
   int Semantics =
       SPIRV::MemorySemantics::SequentiallyConsistent |
       getMemSemanticsForStorageClass(GR->getPointerStorageClass(PtrRegister));
   Register MemSemanticsReg = buildConstantIntReg(Semantics, MIRBuilder, GR);
-
+  MIRBuilder.getMRI()->setRegClass(Call->Arguments[1], &SPIRV::IDRegClass);
   MIRBuilder.buildInstr(SPIRV::OpAtomicStore)
       .addUse(PtrRegister)
       .addUse(ScopeRegister)
@@ -525,6 +539,9 @@ static bool buildAtomicCompareExchangeInst(const SPIRV::IncomingCall *Call,
   Register ObjectPtr = Call->Arguments[0];   // Pointer (volatile A *object.)
   Register ExpectedArg = Call->Arguments[1]; // Comparator (C* expected).
   Register Desired = Call->Arguments[2];     // Value (C Desired).
+  MRI->setRegClass(ObjectPtr, &SPIRV::IDRegClass);
+  MRI->setRegClass(ExpectedArg, &SPIRV::IDRegClass);
+  MRI->setRegClass(Desired, &SPIRV::IDRegClass);
   SPIRVType *SpvDesiredTy = GR->getSPIRVTypeForVReg(Desired);
   LLT DesiredLLT = MRI->getType(Desired);
 
@@ -564,6 +581,8 @@ static bool buildAtomicCompareExchangeInst(const SPIRV::IncomingCall *Call,
       MemSemEqualReg = Call->Arguments[3];
     if (MemOrdNeq == MemSemEqual)
       MemSemUnequalReg = Call->Arguments[4];
+    MRI->setRegClass(Call->Arguments[3], &SPIRV::IDRegClass);
+    MRI->setRegClass(Call->Arguments[4], &SPIRV::IDRegClass);
   }
   if (!MemSemEqualReg.isValid())
     MemSemEqualReg = buildConstantIntReg(MemSemEqual, MIRBuilder, GR);
@@ -580,6 +599,7 @@ static bool buildAtomicCompareExchangeInst(const SPIRV::IncomingCall *Call,
     Scope = getSPIRVScope(ClScope);
     if (ClScope == static_cast<unsigned>(Scope))
       ScopeReg = Call->Arguments[5];
+    MRI->setRegClass(Call->Arguments[5], &SPIRV::IDRegClass);
   }
   if (!ScopeReg.isValid())
     ScopeReg = buildConstantIntReg(Scope, MIRBuilder, GR);
@@ -591,6 +611,8 @@ static bool buildAtomicCompareExchangeInst(const SPIRV::IncomingCall *Call,
   MRI->setType(Expected, DesiredLLT);
   Register Tmp = !IsCmpxchg ? MRI->createGenericVirtualRegister(DesiredLLT)
                             : Call->ReturnRegister;
+  if (!MRI->getRegClassOrNull(Tmp))
+    MRI->setRegClass(Tmp, &SPIRV::IDRegClass);
   GR->assignSPIRVTypeToVReg(SpvDesiredTy, Tmp, MIRBuilder.getMF());
 
   SPIRVType *IntTy = GR->getOrCreateSPIRVIntegerType(32, MIRBuilder);
@@ -614,30 +636,23 @@ static bool buildAtomicRMWInst(const SPIRV::IncomingCall *Call, unsigned Opcode,
                                MachineIRBuilder &MIRBuilder,
                                SPIRVGlobalRegistry *GR) {
-  const MachineRegisterInfo *MRI = MIRBuilder.getMRI();
-  SPIRV::Scope::Scope Scope = SPIRV::Scope::Workgroup;
-  Register ScopeRegister;
-
-  if (Call->Arguments.size() >= 4) {
-    assert(Call->Arguments.size() == 4 &&
-           "Too many args for explicit atomic RMW");
-    ScopeRegister = buildScopeReg(Call->Arguments[3], MIRBuilder, GR, MRI);
-  }
+  MachineRegisterInfo *MRI = MIRBuilder.getMRI();
+  Register ScopeRegister =
+      Call->Arguments.size() >= 4 ? Call->Arguments[3] : Register();
 
-  if (!ScopeRegister.isValid())
-    ScopeRegister = buildConstantIntReg(Scope, MIRBuilder, GR);
+  assert(Call->Arguments.size() <= 4 &&
+         "Too many args for explicit atomic RMW");
+  ScopeRegister = buildScopeReg(ScopeRegister, SPIRV::Scope::Workgroup,
+                                MIRBuilder, GR, MRI);
 
   Register PtrRegister = Call->Arguments[0];
   unsigned Semantics = SPIRV::MemorySemantics::None;
-  Register MemSemanticsReg;
-
-  if (Call->Arguments.size() >= 3)
-    MemSemanticsReg =
-        buildMemSemanticsReg(Call->Arguments[2], PtrRegister, MRI, GR);
-
-  if (!MemSemanticsReg.isValid())
-    MemSemanticsReg = buildConstantIntReg(Semantics, MIRBuilder, GR);
-
+  MRI->setRegClass(PtrRegister, &SPIRV::IDRegClass);
+  Register MemSemanticsReg =
+      Call->Arguments.size() >= 3 ? Call->Arguments[2] : Register();
+  MemSemanticsReg = buildMemSemanticsReg(MemSemanticsReg, PtrRegister,
+                                         Semantics, MIRBuilder, GR);
+  MRI->setRegClass(Call->Arguments[1], &SPIRV::IDRegClass);
   MIRBuilder.buildInstr(Opcode)
       .addDef(Call->ReturnRegister)
       .addUse(GR->getSPIRVTypeID(Call->ReturnType))
@@ -653,32 +668,23 @@ static bool buildAtomicFlagInst(const SPIRV::IncomingCall *Call,
                                 unsigned Opcode, MachineIRBuilder &MIRBuilder,
                                 SPIRVGlobalRegistry *GR) {
-  const MachineRegisterInfo *MRI = MIRBuilder.getMRI();
-
+  MachineRegisterInfo *MRI = MIRBuilder.getMRI();
   Register PtrRegister = Call->Arguments[0];
   unsigned Semantics = SPIRV::MemorySemantics::SequentiallyConsistent;
-  Register MemSemanticsReg;
-
-  if (Call->Arguments.size() >= 2)
-    MemSemanticsReg =
-        buildMemSemanticsReg(Call->Arguments[1], PtrRegister, MRI, GR);
-
-  if (!MemSemanticsReg.isValid())
-    MemSemanticsReg = buildConstantIntReg(Semantics, MIRBuilder, GR);
+  Register MemSemanticsReg =
+      Call->Arguments.size() >= 2 ? Call->Arguments[1] : Register();
+  MemSemanticsReg = buildMemSemanticsReg(MemSemanticsReg, PtrRegister,
+                                         Semantics, MIRBuilder, GR);
 
   assert((Opcode != SPIRV::OpAtomicFlagClear ||
           (Semantics != SPIRV::MemorySemantics::Acquire &&
            Semantics != SPIRV::MemorySemantics::AcquireRelease)) &&
          "Invalid memory order argument!");
 
-  SPIRV::Scope::Scope Scope = SPIRV::Scope::Device;
-  Register ScopeRegister;
-
-  if (Call->Arguments.size() >= 3)
-    ScopeRegister = buildScopeReg(Call->Arguments[2], MIRBuilder, GR, MRI);
-
-  if (!ScopeRegister.isValid())
-    ScopeRegister = buildConstantIntReg(Scope, MIRBuilder, GR);
+  Register ScopeRegister =
+      Call->Arguments.size() >= 3 ? Call->Arguments[2] : Register();
+  ScopeRegister =
+      buildScopeReg(ScopeRegister, SPIRV::Scope::Device, MIRBuilder, GR, MRI);
 
   auto MIB = MIRBuilder.buildInstr(Opcode);
   if (Opcode == SPIRV::OpAtomicFlagTestAndSet)
@@ -694,7 +700,7 @@ static bool buildBarrierInst(const SPIRV::IncomingCall *Call, unsigned Opcode,
                              MachineIRBuilder &MIRBuilder,
                              SPIRVGlobalRegistry *GR) {
-  const MachineRegisterInfo *MRI = MIRBuilder.getMRI();
+  MachineRegisterInfo *MRI = MIRBuilder.getMRI();
 
   unsigned MemFlags = getIConstVal(Call->Arguments[0], MRI);
   unsigned MemSemantics = SPIRV::MemorySemantics::None;
@@ -716,9 +722,10 @@ static bool buildBarrierInst(const SPIRV::IncomingCall *Call, unsigned Opcode,
   }
 
   Register MemSemanticsReg;
-  if (MemFlags == MemSemantics)
+  if (MemFlags == MemSemantics) {
     MemSemanticsReg = Call->Arguments[0];
-  else
+    MRI->setRegClass(MemSemanticsReg, &SPIRV::IDRegClass);
+  } else
     MemSemanticsReg = buildConstantIntReg(MemSemantics, MIRBuilder, GR);
 
   Register ScopeReg;
@@ -738,8 +745,10 @@ static bool buildBarrierInst(const SPIRV::IncomingCall *Call, unsigned Opcode,
         (Opcode == SPIRV::OpMemoryBarrier))
       Scope = MemScope;
 
-    if (CLScope == static_cast<unsigned>(Scope))
+    if (CLScope == static_cast<unsigned>(Scope)) {
       ScopeReg = Call->Arguments[1];
+      MRI->setRegClass(ScopeReg, &SPIRV::IDRegClass);
+    }
   }
 
   if (!ScopeReg.isValid())
@@ -834,7 +843,7 @@ static bool generateGroupInst(const SPIRV::IncomingCall *Call,
   const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
   const SPIRV::GroupBuiltin *GroupBuiltin =
       SPIRV::lookupGroupBuiltin(Builtin->Name);
-  const MachineRegisterInfo *MRI = MIRBuilder.getMRI();
+  MachineRegisterInfo *MRI = MIRBuilder.getMRI();
   Register Arg0;
   if (GroupBuiltin->HasBoolArg) {
     Register ConstRegister = Call->Arguments[0];
@@ -876,8 +885,11 @@ static bool generateGroupInst(const SPIRV::IncomingCall *Call,
     MIB.addImm(GroupBuiltin->GroupOperation);
   if (Call->Arguments.size() > 0) {
     MIB.addUse(Arg0.isValid() ? Arg0 : Call->Arguments[0]);
-    for (unsigned i = 1; i < Call->Arguments.size(); i++)
+    MRI->setRegClass(Call->Arguments[0], &SPIRV::IDRegClass);
+    for (unsigned i = 1; i < Call->Arguments.size(); i++) {
       MIB.addUse(Call->Arguments[i]);
+      MRI->setRegClass(Call->Arguments[i], &SPIRV::IDRegClass);
+    }
   }
 
   // Build select instruction.
@@ -936,16 +948,17 @@ static bool genWorkgroupQuery(const SPIRV::IncomingCall *Call,
   // If it's out of range (max dimension is 3), we can just return the constant
   // default value (0 or 1 depending on which query function).
   if (IsConstantIndex && getIConstVal(IndexRegister, MRI) >= 3) {
-    Register defaultReg = Call->ReturnRegister;
+    Register DefaultReg = Call->ReturnRegister;
     if (PointerSize != ResultWidth) {
-      defaultReg = MRI->createGenericVirtualRegister(LLT::scalar(PointerSize));
-      GR->assignSPIRVTypeToVReg(PointerSizeType, defaultReg,
+      DefaultReg = MRI->createGenericVirtualRegister(LLT::scalar(PointerSize));
+      MRI->setRegClass(DefaultReg, &SPIRV::IDRegClass);
+      GR->assignSPIRVTypeToVReg(PointerSizeType, DefaultReg,
                                 MIRBuilder.getMF());
-      ToTruncate = defaultReg;
+      ToTruncate = DefaultReg;
     }
     auto NewRegister =
         GR->buildConstantInt(DefaultValue, MIRBuilder, PointerSizeType);
-    MIRBuilder.buildCopy(defaultReg, NewRegister);
+    MIRBuilder.buildCopy(DefaultReg, NewRegister);
   } else { // If it could be in range, we need to load from the given builtin.
     auto Vec3Ty =
         GR->getOrCreateSPIRVVectorType(PointerSizeType, 3, MIRBuilder);
@@ -956,6 +969,7 @@ static bool genWorkgroupQuery(const SPIRV::IncomingCall *Call,
     Register Extracted = Call->ReturnRegister;
     if (!IsConstantIndex || PointerSize != ResultWidth) {
       Extracted = MRI->createGenericVirtualRegister(LLT::scalar(PointerSize));
+      MRI->setRegClass(Extracted, &SPIRV::IDRegClass);
       GR->assignSPIRVTypeToVReg(PointerSizeType, Extracted, MIRBuilder.getMF());
     }
     // Use Intrinsic::spv_extractelt so dynamic vs static extraction is
@@ -974,6 +988,7 @@ static bool genWorkgroupQuery(const SPIRV::IncomingCall *Call,
 
       Register CompareRegister =
          MRI->createGenericVirtualRegister(LLT::scalar(1));
+      MRI->setRegClass(CompareRegister, &SPIRV::IDRegClass);
       GR->assignSPIRVTypeToVReg(BoolType, CompareRegister, MIRBuilder.getMF());
 
       // Use G_ICMP to check if idxVReg < 3.
@@ -990,6 +1005,7 @@ static bool genWorkgroupQuery(const SPIRV::IncomingCall *Call,
       if (PointerSize != ResultWidth) {
         SelectionResult =
             MRI->createGenericVirtualRegister(LLT::scalar(PointerSize));
+        MRI->setRegClass(SelectionResult, &SPIRV::IDRegClass);
         GR->assignSPIRVTypeToVReg(PointerSizeType, SelectionResult,
                                   MIRBuilder.getMF());
       }
@@ -1125,6 +1141,7 @@ static bool generateImageSizeQueryInst(const SPIRV::IncomingCall *Call,
   if (NumExpectedRetComponents != NumActualRetComponents) {
     QueryResult = MIRBuilder.getMRI()->createGenericVirtualRegister(
         LLT::fixed_vector(NumActualRetComponents, 32));
+    MIRBuilder.getMRI()->setRegClass(QueryResult, &SPIRV::IDRegClass);
     SPIRVType *IntTy = GR->getOrCreateSPIRVIntegerType(32, MIRBuilder);
     QueryResultType = GR->getOrCreateSPIRVVectorType(
         IntTy, NumActualRetComponents, MIRBuilder);
@@ -1133,6 +1150,7 @@ static bool generateImageSizeQueryInst(const SPIRV::IncomingCall *Call,
   bool IsDimBuf = ImgType->getOperand(2).getImm() == SPIRV::Dim::DIM_Buffer;
   unsigned Opcode =
       IsDimBuf ? SPIRV::OpImageQuerySize : SPIRV::OpImageQuerySizeLod;
+  MIRBuilder.getMRI()->setRegClass(Call->Arguments[0], &SPIRV::IDRegClass);
   auto MIB = MIRBuilder.buildInstr(Opcode)
                  .addDef(QueryResult)
                  .addUse(GR->getSPIRVTypeID(QueryResultType))
@@ -1177,6 +1195,7 @@ static bool generateImageMiscQueryInst(const SPIRV::IncomingCall *Call,
       SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set)->Opcode;
 
   Register Image = Call->Arguments[0];
+  MIRBuilder.getMRI()->setRegClass(Image, &SPIRV::IDRegClass);
   SPIRV::Dim::Dim ImageDimensionality = static_cast<SPIRV::Dim::Dim>(
       GR->getSPIRVTypeForVReg(Image)->getOperand(2).getImm());
 
@@ -1239,8 +1258,13 @@ static bool generateReadImageInst(const StringRef DemangledCall,
                                   SPIRVGlobalRegistry *GR) {
   Register Image = Call->Arguments[0];
   MachineRegisterInfo *MRI = MIRBuilder.getMRI();
-
-  if (DemangledCall.contains_insensitive("ocl_sampler")) {
+  MRI->setRegClass(Image, &SPIRV::IDRegClass);
+  MRI->setRegClass(Call->Arguments[1], &SPIRV::IDRegClass);
+  bool HasOclSampler = DemangledCall.contains_insensitive("ocl_sampler");
+  bool HasMsaa = DemangledCall.contains_insensitive("msaa");
+  if (HasOclSampler || HasMsaa)
+    MRI->setRegClass(Call->Arguments[2], &SPIRV::IDRegClass);
+  if (HasOclSampler) {
     Register Sampler = Call->Arguments[1];
 
     if (!GR->isScalarOfType(Sampler, SPIRV::OpTypeSampler) &&
@@ -1274,6 +1298,7 @@ static bool generateReadImageInst(const StringRef DemangledCall,
     }
     LLT LLType = LLT::scalar(GR->getScalarOrVectorBitWidth(TempType));
     Register TempRegister = MRI->createGenericVirtualRegister(LLType);
+    MRI->setRegClass(TempRegister, &SPIRV::IDRegClass);
     GR->assignSPIRVTypeToVReg(TempType, TempRegister, MIRBuilder.getMF());
 
     MIRBuilder.buildInstr(SPIRV::OpImageSampleExplicitLod)
@@ -1290,7 +1315,7 @@ static bool generateReadImageInst(const StringRef DemangledCall,
         .addUse(GR->getSPIRVTypeID(Call->ReturnType))
         .addUse(TempRegister)
        .addImm(0);
-  } else if (DemangledCall.contains_insensitive("msaa")) {
+  } else if (HasMsaa) {
     MIRBuilder.buildInstr(SPIRV::OpImageRead)
         .addDef(Call->ReturnRegister)
         .addUse(GR->getSPIRVTypeID(Call->ReturnType))
@@ -1311,6 +1336,9 @@ static bool generateReadImageInst(const StringRef DemangledCall,
 static bool generateWriteImageInst(const SPIRV::IncomingCall *Call,
                                    MachineIRBuilder &MIRBuilder,
                                    SPIRVGlobalRegistry *GR) {
+  MIRBuilder.getMRI()->setRegClass(Call->Arguments[0], &SPIRV::IDRegClass);
+  MIRBuilder.getMRI()->setRegClass(Call->Arguments[1], &SPIRV::IDRegClass);
+  MIRBuilder.getMRI()->setRegClass(Call->Arguments[2], &SPIRV::IDRegClass);
   MIRBuilder.buildInstr(SPIRV::OpImageWrite)
       .addUse(Call->Arguments[0]) // Image.
       .addUse(Call->Arguments[1]) // Coordinate.
@@ -1322,10 +1350,11 @@ static bool generateSampleImageInst(const StringRef DemangledCall,
                                     const SPIRV::IncomingCall *Call,
                                     MachineIRBuilder &MIRBuilder,
                                     SPIRVGlobalRegistry *GR) {
+  MachineRegisterInfo *MRI = MIRBuilder.getMRI();
   if (Call->Builtin->Name.contains_insensitive(
           "__translate_sampler_initializer")) {
     // Build sampler literal.
-    uint64_t Bitmask = getIConstVal(Call->Arguments[0], MIRBuilder.getMRI());
+    uint64_t Bitmask = getIConstVal(Call->Arguments[0], MRI);
     Register Sampler = GR->buildConstantSampler(
         Call->ReturnRegister, getSamplerAddressingModeFromBitmask(Bitmask),
         getSamplerParamFromBitmask(Bitmask),
@@ -1340,7 +1369,7 @@ static bool generateSampleImageInst(const StringRef DemangledCall,
     Register SampledImage =
         Call->ReturnRegister.isValid()
             ? Call->ReturnRegister
-            : MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::IDRegClass);
+            : MRI->createVirtualRegister(&SPIRV::IDRegClass);
     MIRBuilder.buildInstr(SPIRV::OpSampledImage)
         .addDef(SampledImage)
         .addUse(GR->getSPIRVTypeID(SampledImageType))
@@ -1356,6 +1385,10 @@ static bool generateSampleImageInst(const StringRef DemangledCall,
       ReturnType = ReturnType.substr(0, ReturnType.find('('));
     }
     SPIRVType *Type = GR->getOrCreateSPIRVTypeByName(ReturnType, MIRBuilder);
+    MRI->setRegClass(Call->Arguments[0], &SPIRV::IDRegClass);
+    MRI->setRegClass(Call->Arguments[1], &SPIRV::IDRegClass);
+    MRI->setRegClass(Call->Arguments[3], &SPIRV::IDRegClass);
+
     MIRBuilder.buildInstr(SPIRV::OpImageSampleExplicitLod)
         .addDef(Call->ReturnRegister)
         .addUse(GR->getSPIRVTypeID(Type))
@@ -1431,6 +1464,75 @@ static bool generateSpecConstantInst(const SPIRV::IncomingCall *Call,
   }
 }
 
+static bool buildNDRange(const SPIRV::IncomingCall *Call,
+                         MachineIRBuilder &MIRBuilder,
+                         SPIRVGlobalRegistry *GR) {
+  MachineRegisterInfo *MRI = MIRBuilder.getMRI();
+  MRI->setRegClass(Call->Arguments[0], &SPIRV::IDRegClass);
+  SPIRVType *PtrType = GR->getSPIRVTypeForVReg(Call->Arguments[0]);
+  assert(PtrType->getOpcode() == SPIRV::OpTypePointer &&
+         PtrType->getOperand(2).isReg());
+  Register TypeReg = PtrType->getOperand(2).getReg();
+  SPIRVType *StructType = GR->getSPIRVTypeForVReg(TypeReg);
+  MachineFunction &MF = MIRBuilder.getMF();
+  Register TmpReg = MRI->createVirtualRegister(&SPIRV::IDRegClass);
+  GR->assignSPIRVTypeToVReg(StructType, TmpReg, MF);
+  // Skip the first arg, it's the destination pointer. OpBuildNDRange takes
+  // three other arguments, so pass zero constant on absence.
+  unsigned NumArgs = Call->Arguments.size();
+  assert(NumArgs >= 2);
+  Register GlobalWorkSize = Call->Arguments[NumArgs < 4 ? 1 : 2];
+  MRI->setRegClass(GlobalWorkSize, &SPIRV::IDRegClass);
+  Register LocalWorkSize =
+      NumArgs == 2 ? Register(0) : Call->Arguments[NumArgs < 4 ? 2 : 3];
+  if (LocalWorkSize.isValid())
+    MRI->setRegClass(LocalWorkSize, &SPIRV::IDRegClass);
+  Register GlobalWorkOffset = NumArgs <= 3 ? Register(0) : Call->Arguments[1];
+  if (GlobalWorkOffset.isValid())
+    MRI->setRegClass(GlobalWorkOffset, &SPIRV::IDRegClass);
+  if (NumArgs < 4) {
+    Register Const;
+    SPIRVType *SpvTy = GR->getSPIRVTypeForVReg(GlobalWorkSize);
+    if (SpvTy->getOpcode() == SPIRV::OpTypePointer) {
+      MachineInstr *DefInstr = MRI->getUniqueVRegDef(GlobalWorkSize);
+      assert(DefInstr && isSpvIntrinsic(*DefInstr, Intrinsic::spv_gep) &&
+             DefInstr->getOperand(3).isReg());
+      Register GWSPtr = DefInstr->getOperand(3).getReg();
+      if (!MRI->getRegClassOrNull(GWSPtr))
+        MRI->setRegClass(GWSPtr, &SPIRV::IDRegClass);
+      // TODO: Maybe simplify generation of the type of the fields.
+      unsigned Size = Call->Builtin->Name.equals("ndrange_3D") ? 3 : 2;
+      unsigned BitWidth = GR->getPointerSize() == 64 ? 64 : 32;
+      Type *BaseTy = IntegerType::get(MF.getFunction().getContext(), BitWidth);
+      Type *FieldTy = ArrayType::get(BaseTy, Size);
+      SPIRVType *SpvFieldTy = GR->getOrCreateSPIRVType(FieldTy, MIRBuilder);
+      GlobalWorkSize = MRI->createVirtualRegister(&SPIRV::IDRegClass);
+      GR->assignSPIRVTypeToVReg(SpvFieldTy, GlobalWorkSize, MF);
+      MIRBuilder.buildInstr(SPIRV::OpLoad)
+          .addDef(GlobalWorkSize)
+          .addUse(GR->getSPIRVTypeID(SpvFieldTy))
+          .addUse(GWSPtr);
+      Const = GR->getOrCreateConsIntArray(0, MIRBuilder, SpvFieldTy);
+    } else {
+      Const = GR->buildConstantInt(0, MIRBuilder, SpvTy);
+    }
+    if (!LocalWorkSize.isValid())
+      LocalWorkSize = Const;
+    if (!GlobalWorkOffset.isValid())
+      GlobalWorkOffset = Const;
+  }
+  assert(LocalWorkSize.isValid() && GlobalWorkOffset.isValid());
+  MIRBuilder.buildInstr(SPIRV::OpBuildNDRange)
+      .addDef(TmpReg)
+      .addUse(TypeReg)
+      .addUse(GlobalWorkSize)
+      .addUse(LocalWorkSize)
+      .addUse(GlobalWorkOffset);
+  return MIRBuilder.buildInstr(SPIRV::OpStore)
+      .addUse(Call->Arguments[0])
+      .addUse(TmpReg);
+}
+
 static MachineInstr *getBlockStructInstr(Register ParamReg,
                                          MachineRegisterInfo *MRI) {
   // We expect the following sequence of instructions:
@@ -1538,9 +1640,8 @@ static bool buildEnqueueKernel(const SPIRV::IncomingCall *Call,
   const SPIRVType *PointerSizeTy = GR->getOrCreateSPIRVPointerType(
       Int32Ty, MIRBuilder, SPIRV::StorageClass::Function);
   for (unsigned I = 0; I < LocalSizeNum; ++I) {
-    Register Reg =
-        MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::IDRegClass);
-    MIRBuilder.getMRI()->setType(Reg, LLType);
+    Register Reg = MRI->createVirtualRegister(&SPIRV::IDRegClass);
+    MRI->setType(Reg, LLType);
     GR->assignSPIRVTypeToVReg(PointerSizeTy, Reg, MIRBuilder.getMF());
     auto GEPInst = MIRBuilder.buildIntrinsic(Intrinsic::spv_gep,
                                              ArrayRef<Register>{Reg}, true);
@@ -1605,6 +1706,7 @@ static bool generateEnqueueInst(const SPIRV::IncomingCall *Call,
   switch (Opcode) {
   case SPIRV::OpRetainEvent:
   case SPIRV::OpReleaseEvent:
+    MIRBuilder.getMRI()->setRegClass(Call->Arguments[0], &SPIRV::IDRegClass);
     return MIRBuilder.buildInstr(Opcode).addUse(Call->Arguments[0]);
   case SPIRV::OpCreateUserEvent:
   case SPIRV::OpGetDefaultQueue:
@@ -1612,77 +1714,27 @@ static bool generateEnqueueInst(const SPIRV::IncomingCall *Call,
         .addDef(Call->ReturnRegister)
         .addUse(GR->getSPIRVTypeID(Call->ReturnType));
   case SPIRV::OpIsValidEvent:
+    MIRBuilder.getMRI()->setRegClass(Call->Arguments[0], &SPIRV::IDRegClass);
     return MIRBuilder.buildInstr(Opcode)
         .addDef(Call->ReturnRegister)
         .addUse(GR->getSPIRVTypeID(Call->ReturnType))
         .addUse(Call->Arguments[0]);
   case SPIRV::OpSetUserEventStatus:
+    MIRBuilder.getMRI()->setRegClass(Call->Arguments[0], &SPIRV::IDRegClass);
+    MIRBuilder.getMRI()->setRegClass(Call->Arguments[1], &SPIRV::IDRegClass);
     return MIRBuilder.buildInstr(Opcode)
         .addUse(Call->Arguments[0])
        .addUse(Call->Arguments[1]);
   case SPIRV::OpCaptureEventProfilingInfo:
+    MIRBuilder.getMRI()->setRegClass(Call->Arguments[0], &SPIRV::IDRegClass);
+    MIRBuilder.getMRI()->setRegClass(Call->Arguments[1], &SPIRV::IDRegClass);
+    MIRBuilder.getMRI()->setRegClass(Call->Arguments[2], &SPIRV::IDRegClass);
     return MIRBuilder.buildInstr(Opcode)
         .addUse(Call->Arguments[0])
        .addUse(Call->Arguments[1])
        .addUse(Call->Arguments[2]);
-  case SPIRV::OpBuildNDRange: {
-    MachineRegisterInfo *MRI = MIRBuilder.getMRI();
-    SPIRVType *PtrType = GR->getSPIRVTypeForVReg(Call->Arguments[0]);
-    assert(PtrType->getOpcode() == SPIRV::OpTypePointer &&
-           PtrType->getOperand(2).isReg());
-    Register TypeReg = PtrType->getOperand(2).getReg();
-    SPIRVType *StructType = GR->getSPIRVTypeForVReg(TypeReg);
-    Register TmpReg = MRI->createVirtualRegister(&SPIRV::IDRegClass);
-    GR->assignSPIRVTypeToVReg(StructType, TmpReg, MIRBuilder.getMF());
-    // Skip the first arg, it's the destination pointer. OpBuildNDRange takes
-    // three other arguments, so pass zero constant on absence.
-    unsigned NumArgs = Call->Arguments.size();
-    assert(NumArgs >= 2);
-    Register GlobalWorkSize = Call->Arguments[NumArgs < 4 ? 1 : 2];
-    Register LocalWorkSize =
-        NumArgs == 2 ? Register(0) : Call->Arguments[NumArgs < 4 ? 2 : 3];
-    Register GlobalWorkOffset = NumArgs <= 3 ? Register(0) : Call->Arguments[1];
-    if (NumArgs < 4) {
-      Register Const;
-      SPIRVType *SpvTy = GR->getSPIRVTypeForVReg(GlobalWorkSize);
-      if (SpvTy->getOpcode() == SPIRV::OpTypePointer) {
-        MachineInstr *DefInstr = MRI->getUniqueVRegDef(GlobalWorkSize);
-        assert(DefInstr && isSpvIntrinsic(*DefInstr, Intrinsic::spv_gep) &&
-               DefInstr->getOperand(3).isReg());
-        Register GWSPtr = DefInstr->getOperand(3).getReg();
-        // TODO: Maybe simplify generation of the type of the fields.
-        unsigned Size = Call->Builtin->Name.equals("ndrange_3D") ? 3 : 2;
-        unsigned BitWidth = GR->getPointerSize() == 64 ? 64 : 32;
-        Type *BaseTy = IntegerType::get(
-            MIRBuilder.getMF().getFunction().getContext(), BitWidth);
-        Type *FieldTy = ArrayType::get(BaseTy, Size);
-        SPIRVType *SpvFieldTy = GR->getOrCreateSPIRVType(FieldTy, MIRBuilder);
-        GlobalWorkSize = MRI->createVirtualRegister(&SPIRV::IDRegClass);
-        GR->assignSPIRVTypeToVReg(SpvFieldTy, GlobalWorkSize,
-                                  MIRBuilder.getMF());
-        MIRBuilder.buildInstr(SPIRV::OpLoad)
-            .addDef(GlobalWorkSize)
-            .addUse(GR->getSPIRVTypeID(SpvFieldTy))
-            .addUse(GWSPtr);
-        Const = GR->getOrCreateConsIntArray(0, MIRBuilder, SpvFieldTy);
-      } else {
-        Const = GR->buildConstantInt(0, MIRBuilder, SpvTy);
-      }
-      if (!LocalWorkSize.isValid())
-        LocalWorkSize = Const;
-      if (!GlobalWorkOffset.isValid())
-        GlobalWorkOffset = Const;
-    }
-    MIRBuilder.buildInstr(Opcode)
-        .addDef(TmpReg)
-        .addUse(TypeReg)
-        .addUse(GlobalWorkSize)
-        .addUse(LocalWorkSize)
-        .addUse(GlobalWorkOffset);
-    return MIRBuilder.buildInstr(SPIRV::OpStore)
-        .addUse(Call->Arguments[0])
-        .addUse(TmpReg);
-  }
+  case SPIRV::OpBuildNDRange:
+    return buildNDRange(Call, MIRBuilder, GR);
   case SPIRV::OpEnqueueKernel:
     return buildEnqueueKernel(Call, MIRBuilder, GR);
   default:
@@ -1817,16 +1869,23 @@ static bool generateLoadStoreInst(const SPIRV::IncomingCall *Call,
   }
   // Add a pointer to the value to load/store.
   MIB.addUse(Call->Arguments[0]);
+  MachineRegisterInfo *MRI = MIRBuilder.getMRI();
+  MRI->setRegClass(Call->Arguments[0], &SPIRV::IDRegClass);
   // Add a value to store.
-  if (!IsLoad)
+  if (!IsLoad) {
     MIB.addUse(Call->Arguments[1]);
+    MRI->setRegClass(Call->Arguments[1], &SPIRV::IDRegClass);
+  }
   // Add optional memory attributes and an alignment.
-  MachineRegisterInfo *MRI = MIRBuilder.getMRI();
   unsigned NumArgs = Call->Arguments.size();
-  if ((IsLoad && NumArgs >= 2) || NumArgs >= 3)
+  if ((IsLoad && NumArgs >= 2) || NumArgs >= 3) {
     MIB.addImm(getConstFromIntrinsic(Call->Arguments[IsLoad ? 1 : 2], MRI));
+    MRI->setRegClass(Call->Arguments[IsLoad ? 1 : 2], &SPIRV::IDRegClass);
+  }
-  if ((IsLoad && NumArgs >= 3) || NumArgs >= 4)
+  if ((IsLoad && NumArgs >= 3) || NumArgs >= 4) {
     MIB.addImm(getConstFromIntrinsic(Call->Arguments[IsLoad ? 2 : 3], MRI));
+    MRI->setRegClass(Call->Arguments[IsLoad ? 2 : 3], &SPIRV::IDRegClass);
+  }
 
   return true;
 }
@@ -1846,6 +1905,8 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
   SPIRVType *ReturnType = nullptr;
   if (OrigRetTy && !OrigRetTy->isVoidTy()) {
     ReturnType = GR->assignTypeToVReg(OrigRetTy, OrigRet, MIRBuilder);
+    if (!MIRBuilder.getMRI()->getRegClassOrNull(ReturnRegister))
+      MIRBuilder.getMRI()->setRegClass(ReturnRegister, &SPIRV::IDRegClass);
   } else if (OrigRetTy && OrigRetTy->isVoidTy()) {
     ReturnRegister = MIRBuilder.getMRI()->createVirtualRegister(&IDRegClass);
     MIRBuilder.getMRI()->setType(ReturnRegister, LLT::scalar(32));
diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index 8b618686ee7da..47b25a1f83515 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -374,6 +374,7 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
     FTy = getOriginalFunctionType(*CF);
   }
 
+  MachineRegisterInfo *MRI = MIRBuilder.getMRI();
   Register ResVReg =
       Info.OrigRet.Regs.empty() ? Register(0) : Info.OrigRet.Regs[0];
   std::string FuncName = Info.Callee.getGlobal()->getName().str();
@@ -410,8 +411,9 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
     for (const Argument &Arg : CF->args()) {
       if (MIRBuilder.getDataLayout().getTypeStoreSize(Arg.getType()).isZero())
         continue; // Don't handle zero sized types.
-      ToInsert.push_back(
-          {MIRBuilder.getMRI()->createGenericVirtualRegister(LLT::scalar(32))});
+      Register Reg = MRI->createGenericVirtualRegister(LLT::scalar(32));
+      MRI->setRegClass(Reg, &SPIRV::IDRegClass);
+      ToInsert.push_back({Reg});
       VRegArgs.push_back(ToInsert.back());
     }
     // TODO: Reuse FunctionLoweringInfo
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 062188abbf5e8..c77a7f860eda2 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -143,6 +143,7 @@ SPIRVGlobalRegistry::getOrCreateConstIntReg(uint64_t Val, SPIRVType *SpvType,
     unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32;
     LLT LLTy = LLT::scalar(32);
     Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
+    CurMF->getRegInfo().setRegClass(Res, &SPIRV::IDRegClass);
     if (MIRBuilder)
       assignTypeToVReg(LLVMIntTy, Res, *MIRBuilder);
     else
@@ -202,6 +203,7 @@ Register SPIRVGlobalRegistry::buildConstantInt(uint64_t Val,
     unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32;
     LLT LLTy = LLT::scalar(EmitIR ? BitWidth : 32);
     Res = MF.getRegInfo().createGenericVirtualRegister(LLTy);
+    MF.getRegInfo().setRegClass(Res, &SPIRV::IDRegClass);
     assignTypeToVReg(LLVMIntTy, Res, MIRBuilder,
                      SPIRV::AccessQualifier::ReadWrite, EmitIR);
     DT.add(ConstInt, &MIRBuilder.getMF(), Res);
@@ -247,6 +249,7 @@ Register SPIRVGlobalRegistry::buildConstantFP(APFloat Val,
   if (!Res.isValid()) {
     unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32;
     Res = MF.getRegInfo().createGenericVirtualRegister(LLT::scalar(BitWidth));
+    MF.getRegInfo().setRegClass(Res, &SPIRV::IDRegClass);
     assignTypeToVReg(LLVMFPTy, Res, MIRBuilder);
     DT.add(ConstFP, &MF, Res);
     MIRBuilder.buildFConstant(Res, *ConstFP);
@@ -272,6 +275,7 @@ Register SPIRVGlobalRegistry::getOrCreateIntCompositeOrNull(
     LLT LLTy = LLT::scalar(32);
     Register SpvVecConst =
         CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
+    CurMF->getRegInfo().setRegClass(SpvVecConst, &SPIRV::IDRegClass);
     assignSPIRVTypeToVReg(SpvType, SpvVecConst, *CurMF);
     DT.add(CA, CurMF, SpvVecConst);
     MachineInstrBuilder MIB;
@@ -343,6 +347,7 @@ Register SPIRVGlobalRegistry::getOrCreateIntCompositeOrNull(
     LLT LLTy = EmitIR ? LLT::fixed_vector(ElemCnt, BitWidth) : LLT::scalar(32);
     Register SpvVecConst =
         CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
+    CurMF->getRegInfo().setRegClass(SpvVecConst, &SPIRV::IDRegClass);
     assignSPIRVTypeToVReg(SpvType, SpvVecConst, *CurMF);
     DT.add(CA, CurMF, SpvVecConst);
     if (EmitIR) {
@@ -411,6 +416,7 @@ SPIRVGlobalRegistry::getOrCreateConstNullPtr(MachineIRBuilder &MIRBuilder,
   if (!Res.isValid()) {
     LLT LLTy = LLT::pointer(LLVMPtrTy->getAddressSpace(), PointerSize);
     Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
+    CurMF->getRegInfo().setRegClass(Res, &SPIRV::IDRegClass);
     assignSPIRVTypeToVReg(SpvType, Res, *CurMF);
     MIRBuilder.buildInstr(SPIRV::OpConstantNull)
         .addDef(Res)
@@ -1090,6 +1096,7 @@ Register SPIRVGlobalRegistry::getOrCreateUndef(MachineInstr &I,
     return Res;
   LLT LLTy = LLT::scalar(32);
   Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
+  CurMF->getRegInfo().setRegClass(Res, &SPIRV::IDRegClass);
   assignSPIRVTypeToVReg(SpvType, Res, *CurMF);
   DT.add(UV, CurMF, Res);
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
index 27d0e8a976f0d..2818329ece3cb 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -85,6 +85,9 @@ static void addConstantsToTrack(MachineFunction &MF, SPIRVGlobalRegistry *GR) {
       Register Reg = MI->getOperand(2).getReg();
       if (RegsAlreadyAddedToDT.find(MI) != RegsAlreadyAddedToDT.end())
         Reg = RegsAlreadyAddedToDT[MI];
+      auto *RC = MRI.getRegClassOrNull(MI->getOperand(0).getReg());
+      if (!MRI.getRegClassOrNull(Reg) && RC)
+        MRI.setRegClass(Reg, RC);
       MRI.replaceRegWith(MI->getOperand(0).getReg(), Reg);
       MI->eraseFromParent();
     }
@@ -201,8 +204,12 @@ Register insertAssignInstr(Register Reg, Type *Ty, SPIRVType *SpirvTy,
                  (Def->getNextNode() ? Def->getNextNode()->getIterator()
                                      : Def->getParent()->end()));
   Register NewReg = MRI.createGenericVirtualRegister(MRI.getType(Reg));
-  if (auto *RC = MRI.getRegClassOrNull(Reg))
+  if (auto *RC = MRI.getRegClassOrNull(Reg)) {
     MRI.setRegClass(NewReg, RC);
+  } else {
+    MRI.setRegClass(NewReg, &SPIRV::IDRegClass);
+    MRI.setRegClass(Reg, &SPIRV::IDRegClass);
+  }
   SpirvTy = SpirvTy ? SpirvTy : GR->getOrCreateSPIRVType(Ty, MIB);
   GR->assignSPIRVTypeToVReg(SpirvTy, Reg, MIB.getMF());
   // This is to make it convenient for Legalizer to get the SPIRVType
@@ -217,7 +224,6 @@ Register insertAssignInstr(Register Reg, Type *Ty, SPIRVType *SpirvTy,
       .addUse(GR->getSPIRVTypeID(SpirvTy))
       .setMIFlags(Flags);
   Def->getOperand(0).setReg(NewReg);
-  MRI.setRegClass(Reg, &SPIRV::ANYIDRegClass);
   return NewReg;
 }
 } // namespace llvm