diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index cc438b2bb8d4d..10569ef0468bd 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -150,7 +150,8 @@ getKernelArgTypeQual(const Function &F, unsigned ArgIdx) {
 
 static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
                                   SPIRVGlobalRegistry *GR,
-                                  MachineIRBuilder &MIRBuilder) {
+                                  MachineIRBuilder &MIRBuilder,
+                                  const SPIRVSubtarget &ST) {
   // Read argument's access qualifier from metadata or default.
   SPIRV::AccessQualifier::AccessQualifier ArgAccessQual =
       getArgAccessQual(F, ArgIdx);
@@ -169,8 +170,8 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
     if (MDTypeStr.ends_with("*"))
       ResArgType = GR->getOrCreateSPIRVTypeByName(
          MDTypeStr, MIRBuilder,
-          addressSpaceToStorageClass(
-              OriginalArgType->getPointerAddressSpace()));
+          addressSpaceToStorageClass(OriginalArgType->getPointerAddressSpace(),
+                                     ST));
     else if (MDTypeStr.ends_with("_t"))
       ResArgType = GR->getOrCreateSPIRVTypeByName(
           "opencl." + MDTypeStr.str(), MIRBuilder,
@@ -206,6 +207,10 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
   assert(GR && "Must initialize the SPIRV type registry before lowering args.");
   GR->setCurrentFunc(MIRBuilder.getMF());
 
+  // Get access to information about available extensions
+  const SPIRVSubtarget *ST =
+      static_cast<const SPIRVSubtarget *>(&MIRBuilder.getMF().getSubtarget());
+
   // Assign types and names to all args, and store their types for later.
   FunctionType *FTy = getOriginalFunctionType(F);
   SmallVector<SPIRVType *, 4> ArgTypeVRegs;
@@ -216,7 +221,7 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
       // TODO: handle the case of multiple registers.
       if (VRegs[i].size() > 1)
         return false;
-      auto *SpirvTy = getArgSPIRVType(F, i, GR, MIRBuilder);
+      auto *SpirvTy = getArgSPIRVType(F, i, GR, MIRBuilder, *ST);
       GR->assignSPIRVTypeToVReg(SpirvTy, VRegs[i][0], MIRBuilder.getMF());
       ArgTypeVRegs.push_back(SpirvTy);
 
@@ -318,10 +323,6 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
   if (F.hasName())
     buildOpName(FuncVReg, F.getName(), MIRBuilder);
 
-  // Get access to information about available extensions
-  const auto *ST =
-      static_cast<const SPIRVSubtarget *>(&MIRBuilder.getMF().getSubtarget());
-
   // Handle entry points and function linkage.
   if (isEntryPoint(F)) {
     const auto &STI = MIRBuilder.getMF().getSubtarget();
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 47fec745c3f18..a1cb630f1aa47 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -709,7 +709,10 @@ SPIRVType *SPIRVGlobalRegistry::createSPIRVType(
     // TODO: change the implementation once opaque pointers are supported
     // in the SPIR-V specification.
     SpvElementType = getOrCreateSPIRVIntegerType(8, MIRBuilder);
-    auto SC = addressSpaceToStorageClass(PType->getAddressSpace());
+    // Get access to information about available extensions
+    const SPIRVSubtarget *ST =
+        static_cast<const SPIRVSubtarget *>(&MIRBuilder.getMF().getSubtarget());
+    auto SC = addressSpaceToStorageClass(PType->getAddressSpace(), *ST);
     // Null pointer means we have a loop in type definitions, make and
     // return corresponding OpTypeForwardPointer.
     if (SpvElementType == nullptr) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
index 86f65b6320d53..7c5252e8cb372 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
@@ -430,6 +430,10 @@ def OpGenericCastToPtrExplicit : Op<123, (outs ID:$r), (ins TYPE:$t, ID:$p, Stor
                   "$r = OpGenericCastToPtrExplicit $t $p $s">;
 def OpBitcast : UnOp<"OpBitcast", 124>;
 
+// SPV_INTEL_usm_storage_classes
+def OpPtrCastToCrossWorkgroupINTEL : UnOp<"OpPtrCastToCrossWorkgroupINTEL", 5934>;
+def OpCrossWorkgroupCastToPtrINTEL : UnOp<"OpCrossWorkgroupCastToPtrINTEL", 5938>;
+
 // 3.42.12 Composite Instructions
 
 def OpVectorExtractDynamic: Op<77, (outs ID:$res), (ins TYPE:$type, vID:$vec, ID:$idx),
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 53d19a1e31382..7258d3b4d88ed 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -828,8 +828,18 @@ static bool isGenericCastablePtr(SPIRV::StorageClass::StorageClass SC) {
   }
 }
 
+static bool isUSMStorageClass(SPIRV::StorageClass::StorageClass SC) {
+  switch (SC) {
+  case SPIRV::StorageClass::DeviceOnlyINTEL:
+  case SPIRV::StorageClass::HostOnlyINTEL:
+    return true;
+  default:
+    return false;
+  }
+}
+
 // In SPIR-V address space casting can only happen to and from the Generic
-// storage class. We can also only case Workgroup, CrossWorkgroup, or Function
+// storage class. We can also only cast Workgroup, CrossWorkgroup, or Function
 // pointers to and from Generic pointers. As such, we can convert e.g. from
 // Workgroup to Function by going via a Generic pointer as an intermediary. All
 // other combinations can only be done by a bitcast, and are probably not safe.
@@ -862,13 +872,17 @@ bool SPIRVInstructionSelector::selectAddrSpaceCast(Register ResVReg,
   SPIRV::StorageClass::StorageClass SrcSC = GR.getPointerStorageClass(SrcPtr);
   SPIRV::StorageClass::StorageClass DstSC = GR.getPointerStorageClass(ResVReg);
 
-  // Casting from an eligable pointer to Generic.
+  // don't generate a cast between identical storage classes
+  if (SrcSC == DstSC)
+    return true;
+
+  // Casting from an eligible pointer to Generic.
   if (DstSC == SPIRV::StorageClass::Generic && isGenericCastablePtr(SrcSC))
     return selectUnOp(ResVReg, ResType, I, SPIRV::OpPtrCastToGeneric);
-  // Casting from Generic to an eligable pointer.
+  // Casting from Generic to an eligible pointer.
   if (SrcSC == SPIRV::StorageClass::Generic && isGenericCastablePtr(DstSC))
     return selectUnOp(ResVReg, ResType, I, SPIRV::OpGenericCastToPtr);
-  // Casting between 2 eligable pointers using Generic as an intermediary.
+  // Casting between 2 eligible pointers using Generic as an intermediary.
   if (isGenericCastablePtr(SrcSC) && isGenericCastablePtr(DstSC)) {
     Register Tmp = MRI->createVirtualRegister(&SPIRV::IDRegClass);
     SPIRVType *GenericPtrTy = GR.getOrCreateSPIRVPointerType(
@@ -886,6 +900,16 @@ bool SPIRVInstructionSelector::selectAddrSpaceCast(Register ResVReg,
         .addUse(Tmp)
         .constrainAllUses(TII, TRI, RBI);
   }
+
+  // Check if instructions from the SPV_INTEL_usm_storage_classes extension may
+  // be applied
+  if (isUSMStorageClass(SrcSC) && DstSC == SPIRV::StorageClass::CrossWorkgroup)
+    return selectUnOp(ResVReg, ResType, I,
+                      SPIRV::OpPtrCastToCrossWorkgroupINTEL);
+  if (SrcSC == SPIRV::StorageClass::CrossWorkgroup && isUSMStorageClass(DstSC))
+    return selectUnOp(ResVReg, ResType, I,
+                      SPIRV::OpCrossWorkgroupCastToPtrINTEL);
+
   // TODO Should this case just be disallowed completely?
   // We're casting 2 other arbitrary address spaces, so have to bitcast.
   return selectUnOp(ResVReg, ResType, I, SPIRV::OpBitcast);
@@ -1545,7 +1569,7 @@ bool SPIRVInstructionSelector::selectGlobalValue(
   }
   SPIRVType *ResType = GR.getOrCreateSPIRVPointerType(
       PointerBaseType, I, TII,
-      addressSpaceToStorageClass(GV->getAddressSpace()));
+      addressSpaceToStorageClass(GV->getAddressSpace(), STI));
 
   std::string GlobalIdent;
   if (!GV->hasName()) {
@@ -1618,7 +1642,7 @@ bool SPIRVInstructionSelector::selectGlobalValue(
 
   unsigned AddrSpace = GV->getAddressSpace();
   SPIRV::StorageClass::StorageClass Storage =
-      addressSpaceToStorageClass(AddrSpace);
+      addressSpaceToStorageClass(AddrSpace, STI);
   bool HasLnkTy = GV->getLinkage() != GlobalValue::InternalLinkage &&
                   Storage != SPIRV::StorageClass::Function;
   SPIRV::LinkageType::LinkageType LnkType =
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
index 011a550a7b3d9..4f2e7a240fc2c 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
@@ -102,14 +102,16 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
   const LLT p2 = LLT::pointer(2, PSize); // UniformConstant
   const LLT p3 = LLT::pointer(3, PSize); // Workgroup
   const LLT p4 = LLT::pointer(4, PSize); // Generic
-  const LLT p5 = LLT::pointer(5, PSize); // Input
+  const LLT p5 =
+      LLT::pointer(5, PSize); // Input, SPV_INTEL_usm_storage_classes (Device)
+  const LLT p6 = LLT::pointer(6, PSize); // SPV_INTEL_usm_storage_classes (Host)
 
   // TODO: remove copy-pasting here by using concatenation in some way.
   auto allPtrsScalarsAndVectors = {
-      p0, p1, p2, p3, p4, p5, s1, s8, s16,
-      s32, s64, v2s1, v2s8, v2s16, v2s32, v2s64, v3s1, v3s8,
-      v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, v4s64, v8s1,
-      v8s8, v8s16, v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};
+      p0, p1, p2, p3, p4, p5, p6, s1, s8, s16,
+      s32, s64, v2s1, v2s8, v2s16, v2s32, v2s64, v3s1, v3s8, v3s16,
+      v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, v4s64, v8s1, v8s8, v8s16,
+      v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};
 
   auto allScalarsAndVectors = {
       s1, s8, s16, s32, s64, v2s1, v2s8, v2s16, v2s32, v2s64,
@@ -133,8 +135,8 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
 
   auto allFloatAndIntScalars = allIntScalars;
 
-  auto allPtrs = {p0, p1, p2, p3, p4, p5};
-  auto allWritablePtrs = {p0, p1, p3, p4};
+  auto allPtrs = {p0, p1, p2, p3, p4, p5, p6};
+  auto allWritablePtrs = {p0, p1, p3, p4, p5, p6};
 
   for (auto Opc : TypeFoldingSupportingOpcs)
     getActionDefinitionsBuilder(Opc).custom();
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index 9b9575b987994..3be28c97d9538 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1063,6 +1063,13 @@ void addInstrRequirements(const MachineInstr &MI,
       Reqs.addCapability(SPIRV::Capability::ExpectAssumeKHR);
     }
     break;
+  case SPIRV::OpPtrCastToCrossWorkgroupINTEL:
+  case SPIRV::OpCrossWorkgroupCastToPtrINTEL:
+    if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes)) {
+      Reqs.addExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes);
+      Reqs.addCapability(SPIRV::Capability::USMStorageClassesINTEL);
+    }
+    break;
   case SPIRV::OpConstantFunctionPointerINTEL:
     if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers)) {
       Reqs.addExtension(SPIRV::Extension::SPV_INTEL_function_pointers);
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
index cbc16fa986614..144216896eb68 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -122,6 +122,9 @@ static void foldConstantsIntoIntrinsics(MachineFunction &MF) {
 
 static void insertBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR,
                            MachineIRBuilder MIB) {
+  // Get access to information about available extensions
+  const SPIRVSubtarget *ST =
+      static_cast<const SPIRVSubtarget *>(&MIB.getMF().getSubtarget());
   SmallVector<MachineInstr *, 10> ToErase;
   for (MachineBasicBlock &MBB : MF) {
     for (MachineInstr &MI : MBB) {
@@ -141,7 +144,7 @@ static void insertBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR,
           getMDOperandAsType(MI.getOperand(3).getMetadata(), 0), MIB);
       SPIRVType *AssignedPtrType = GR->getOrCreateSPIRVPointerType(
           BaseTy, MI, *MF.getSubtarget<SPIRVSubtarget>().getInstrInfo(),
-          addressSpaceToStorageClass(MI.getOperand(4).getImm()));
+          addressSpaceToStorageClass(MI.getOperand(4).getImm(), *ST));
 
       // If the bitcast would be redundant, replace all uses with the source
       // register.
@@ -250,6 +253,10 @@ Register insertAssignInstr(Register Reg, Type *Ty, SPIRVType *SpirvTy,
 
 static void generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
                                  MachineIRBuilder MIB) {
+  // Get access to information about available extensions
+  const SPIRVSubtarget *ST =
+      static_cast<const SPIRVSubtarget *>(&MIB.getMF().getSubtarget());
+
   MachineRegisterInfo &MRI = MF.getRegInfo();
   SmallVector<MachineInstr *, 10> ToErase;
 
@@ -269,7 +276,7 @@ static void generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
             getMDOperandAsType(MI.getOperand(2).getMetadata(), 0), MIB);
         SPIRVType *AssignedPtrType = GR->getOrCreateSPIRVPointerType(
             BaseTy, MI, *MF.getSubtarget<SPIRVSubtarget>().getInstrInfo(),
-            addressSpaceToStorageClass(MI.getOperand(3).getImm()));
+            addressSpaceToStorageClass(MI.getOperand(3).getImm(), *ST));
         MachineInstr *Def = MRI.getVRegDef(Reg);
         assert(Def && "Expecting an instruction that defines the register");
         insertAssignInstr(Reg, nullptr, AssignedPtrType, GR, MIB,
diff --git a/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp b/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp
index 4694363614ef6..79f16146ccd94 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp
@@ -49,6 +49,12 @@ cl::list<SPIRV::Extension::Extension> Extensions(
         clEnumValN(SPIRV::Extension::SPV_INTEL_optnone, "SPV_INTEL_optnone",
                    "Adds OptNoneINTEL value for Function Control mask that "
                    "indicates a request to not optimize the function."),
+        clEnumValN(SPIRV::Extension::SPV_INTEL_usm_storage_classes,
+                   "SPV_INTEL_usm_storage_classes",
+                   "Introduces two new storage classes that are subclasses of "
+                   "the CrossWorkgroup storage class "
+                   "and provide additional information that can enable "
+                   "optimization."),
         clEnumValN(SPIRV::Extension::SPV_INTEL_subgroups, "SPV_INTEL_subgroups",
                    "Allows work items in a subgroup to share data without the "
                    "use of local memory and work group barriers, and to "
diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
index 6c36087baa85e..b022b97408d7d 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
+++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
@@ -463,6 +463,7 @@ defm AtomicFloat16MinMaxEXT : CapabilityOperand<5616, 0, 0, [SPV_EXT_shader_atom
 defm AtomicFloat32MinMaxEXT : CapabilityOperand<5612, 0, 0, [SPV_EXT_shader_atomic_float_min_max], []>;
 defm AtomicFloat64MinMaxEXT : CapabilityOperand<5613, 0, 0, [SPV_EXT_shader_atomic_float_min_max], []>;
 defm GroupUniformArithmeticKHR : CapabilityOperand<6400, 0, 0, [SPV_KHR_uniform_group_instructions], []>;
+defm USMStorageClassesINTEL : CapabilityOperand<5935, 0, 0, [SPV_INTEL_usm_storage_classes], [Kernel]>;
 
 //===----------------------------------------------------------------------===//
 // Multiclass used to define SourceLanguage enum values and at the same time
@@ -700,6 +701,8 @@ defm IncomingRayPayloadNV : StorageClassOperand<5342, [RayTracingNV]>;
 defm ShaderRecordBufferNV : StorageClassOperand<5343, [RayTracingNV]>;
 defm PhysicalStorageBufferEXT : StorageClassOperand<5349, [PhysicalStorageBufferAddressesEXT]>;
 defm CodeSectionINTEL : StorageClassOperand<5605, [FunctionPointersINTEL]>;
+defm DeviceOnlyINTEL : StorageClassOperand<5936, [USMStorageClassesINTEL]>;
+defm HostOnlyINTEL : StorageClassOperand<5937, [USMStorageClassesINTEL]>;
 
 //===----------------------------------------------------------------------===//
 // Multiclass used to define Dim enum values and at the same time
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
index 05f766d3ec548..169d7cc93897e 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
@@ -14,6 +14,7 @@
 #include "MCTargetDesc/SPIRVBaseInfo.h"
 #include "SPIRV.h"
 #include "SPIRVInstrInfo.h"
+#include "SPIRVSubtarget.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
@@ -146,15 +147,19 @@ unsigned storageClassToAddressSpace(SPIRV::StorageClass::StorageClass SC) {
     return 3;
   case SPIRV::StorageClass::Generic:
     return 4;
+  case SPIRV::StorageClass::DeviceOnlyINTEL:
+    return 5;
+  case SPIRV::StorageClass::HostOnlyINTEL:
+    return 6;
   case SPIRV::StorageClass::Input:
     return 7;
   default:
-    llvm_unreachable("Unable to get address space id");
+    report_fatal_error("Unable to get address space id");
   }
 }
 
 SPIRV::StorageClass::StorageClass
-addressSpaceToStorageClass(unsigned AddrSpace) {
+addressSpaceToStorageClass(unsigned AddrSpace, const SPIRVSubtarget &STI) {
   switch (AddrSpace) {
   case 0:
     return SPIRV::StorageClass::Function;
@@ -166,10 +171,18 @@ addressSpaceToStorageClass(unsigned AddrSpace) {
     return SPIRV::StorageClass::Workgroup;
   case 4:
     return SPIRV::StorageClass::Generic;
+  case 5:
+    return STI.canUseExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes)
+               ? SPIRV::StorageClass::DeviceOnlyINTEL
+               : SPIRV::StorageClass::CrossWorkgroup;
+  case 6:
+    return STI.canUseExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes)
+               ? SPIRV::StorageClass::HostOnlyINTEL
+               : SPIRV::StorageClass::CrossWorkgroup;
   case 7:
     return SPIRV::StorageClass::Input;
   default:
-    llvm_unreachable("Unknown address space");
+    report_fatal_error("Unknown address space");
   }
 }
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h
index a33dc02f854f5..1af53dcd0c4cd 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.h
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h
@@ -27,6 +27,7 @@ class MachineRegisterInfo;
 class Register;
 class StringRef;
 class SPIRVInstrInfo;
+class SPIRVSubtarget;
 
 // Add the given string as a series of integer operand, inserting null
 // terminators and padding to make sure the operands all have 32-bit
@@ -62,7 +63,7 @@ unsigned storageClassToAddressSpace(SPIRV::StorageClass::StorageClass SC);
 
 // Convert an LLVM IR address space to a SPIR-V storage class.
 SPIRV::StorageClass::StorageClass
-addressSpaceToStorageClass(unsigned AddrSpace);
+addressSpaceToStorageClass(unsigned AddrSpace, const SPIRVSubtarget &STI);
 
 SPIRV::MemorySemantics::MemorySemantics
 getMemSemanticsForStorageClass(SPIRV::StorageClass::StorageClass SC);
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_usm_storage_classes/intel-usm-addrspaces.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_usm_storage_classes/intel-usm-addrspaces.ll
new file mode 100644
index 0000000000000..30c16350bf2b1
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_usm_storage_classes/intel-usm-addrspaces.ll
@@ -0,0 +1,84 @@
+; Modified from: https://github.com/KhronosGroup/SPIRV-LLVM-Translator/test/extensions/INTEL/SPV_INTEL_usm_storage_classes/intel_usm_addrspaces.ll
+
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown --spirv-extensions=SPV_INTEL_usm_storage_classes %s -o - | FileCheck %s --check-prefixes=CHECK-SPIRV,CHECK-SPIRV-EXT
+; TODO: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-extensions=SPV_INTEL_usm_storage_classes %s -o - -filetype=obj | spirv-val %}
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefixes=CHECK-SPIRV,CHECK-SPIRV-WITHOUT
+; TODO: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-SPIRV-EXT: Capability USMStorageClassesINTEL
+; CHECK-SPIRV-WITHOUT-NOT: Capability USMStorageClassesINTEL
+; CHECK-SPIRV-EXT-DAG: %[[DevTy:[0-9]+]] = OpTypePointer DeviceOnlyINTEL %[[#]]
+; CHECK-SPIRV-EXT-DAG: %[[HostTy:[0-9]+]] = OpTypePointer HostOnlyINTEL %[[#]]
+; CHECK-SPIRV-DAG: %[[CrsWrkTy:[0-9]+]] = OpTypePointer CrossWorkgroup %[[#]]
+
+target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
+target triple = "spir64-unknown-unknown"
+
+define spir_kernel void @foo_kernel() {
+entry:
+  ret void
+}
+
+; CHECK-SPIRV: %[[Ptr1:[0-9]+]] = OpLoad %[[CrsWrkTy]] %[[#]]
+; CHECK-SPIRV-EXT: %[[CastedPtr1:[0-9]+]] = OpCrossWorkgroupCastToPtrINTEL %[[DevTy]] %[[Ptr1]]
+; CHECK-SPIRV-WITHOUT-NOT: OpCrossWorkgroupCastToPtrINTEL
+; CHECK-SPIRV-EXT: OpStore %[[#]] %[[CastedPtr1]]
+define spir_func void @test1(ptr addrspace(1) %arg_glob, ptr addrspace(5) %arg_dev) {
+entry:
+  %arg_glob.addr = alloca ptr addrspace(1), align 4
+  %arg_dev.addr = alloca ptr addrspace(5), align 4
+  store ptr addrspace(1) %arg_glob, ptr %arg_glob.addr, align 4
+  store ptr addrspace(5) %arg_dev, ptr %arg_dev.addr, align 4
+  %loaded_glob = load ptr addrspace(1), ptr %arg_glob.addr, align 4
+  %casted_ptr = addrspacecast ptr addrspace(1) %loaded_glob to ptr addrspace(5)
+  store ptr addrspace(5) %casted_ptr, ptr %arg_dev.addr, align 4
+  ret void
+}
+
+; CHECK-SPIRV: %[[Ptr2:[0-9]+]] = OpLoad %[[CrsWrkTy]] %[[#]]
+; CHECK-SPIRV-EXT: %[[CastedPtr2:[0-9]+]] = OpCrossWorkgroupCastToPtrINTEL %[[HostTy]] %[[Ptr2]]
+; CHECK-SPIRV-WITHOUT-NOT: OpCrossWorkgroupCastToPtrINTEL
+; CHECK-SPIRV-EXT: OpStore %[[#]] %[[CastedPtr2]]
+define spir_func void @test2(ptr addrspace(1) %arg_glob, ptr addrspace(6) %arg_host) {
+entry:
+  %arg_glob.addr = alloca ptr addrspace(1), align 4
+  %arg_host.addr = alloca ptr addrspace(6), align 4
+  store ptr addrspace(1) %arg_glob, ptr %arg_glob.addr, align 4
+  store ptr addrspace(6) %arg_host, ptr %arg_host.addr, align 4
+  %loaded_glob = load ptr addrspace(1), ptr %arg_glob.addr, align 4
+  %casted_ptr = addrspacecast ptr addrspace(1) %loaded_glob to ptr addrspace(6)
+  store ptr addrspace(6) %casted_ptr, ptr %arg_host.addr, align 4
+  ret void
+}
+
+; CHECK-SPIRV-EXT: %[[Ptr3:[0-9]+]] = OpLoad %[[DevTy]] %[[#]]
+; CHECK-SPIRV-EXT: %[[CastedPtr3:[0-9]+]] = OpPtrCastToCrossWorkgroupINTEL %[[CrsWrkTy]] %[[Ptr3]]
+; CHECK-SPIRV-WITHOUT-NOT: OpPtrCastToCrossWorkgroupINTEL
+; CHECK-SPIRV-EXT: OpStore %[[#]] %[[CastedPtr3]]
+define spir_func void @test3(ptr addrspace(1) %arg_glob, ptr addrspace(5) %arg_dev) {
+entry:
+  %arg_glob.addr = alloca ptr addrspace(1), align 4
+  %arg_dev.addr = alloca ptr addrspace(5), align 4
+  store ptr addrspace(1) %arg_glob, ptr %arg_glob.addr, align 4
+  store ptr addrspace(5) %arg_dev, ptr %arg_dev.addr, align 4
+  %loaded_dev = load ptr addrspace(5), ptr %arg_dev.addr, align 4
+  %casted_ptr = addrspacecast ptr addrspace(5) %loaded_dev to ptr addrspace(1)
+  store ptr addrspace(1) %casted_ptr, ptr %arg_glob.addr, align 4
+  ret void
+}
+
+; CHECK-SPIRV-EXT: %[[Ptr4:[0-9]+]] = OpLoad %[[HostTy]] %[[#]]
+; CHECK-SPIRV-EXT: %[[CastedPtr4:[0-9]+]] = OpPtrCastToCrossWorkgroupINTEL %[[CrsWrkTy]] %[[Ptr4]]
+; CHECK-SPIRV-WITHOUT-NOT: OpPtrCastToCrossWorkgroupINTEL
+; CHECK-SPIRV-EXT: OpStore %[[#]] %[[CastedPtr4]]
+define spir_func void @test4(ptr addrspace(1) %arg_glob, ptr addrspace(6) %arg_host) {
+entry:
+  %arg_glob.addr = alloca ptr addrspace(1), align 4
+  %arg_host.addr = alloca ptr addrspace(6), align 4
+  store ptr addrspace(1) %arg_glob, ptr %arg_glob.addr, align 4
+  store ptr addrspace(6) %arg_host, ptr %arg_host.addr, align 4
+  %loaded_host = load ptr addrspace(6), ptr %arg_host.addr, align 4
+  %casted_ptr = addrspacecast ptr addrspace(6) %loaded_host to ptr addrspace(1)
+  store ptr addrspace(1) %casted_ptr, ptr %arg_glob.addr, align 4
+  ret void
+}
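A minimal standalone sketch, distilled from the test above, can be used to exercise the new lowering outside of lit; the file name, function name, and value names here are illustrative only and not part of the patch:

; llc -O0 -mtriple=spirv64-unknown-unknown \
;     --spirv-extensions=SPV_INTEL_usm_storage_classes usm_cast.ll -o -
define spir_func ptr addrspace(5) @usm_cast(ptr addrspace(1) %global_ptr) {
entry:
  ; With the extension enabled, addrspace(1) maps to CrossWorkgroup and
  ; addrspace(5) to DeviceOnlyINTEL, so this cast should select
  ; OpCrossWorkgroupCastToPtrINTEL; without the extension both address spaces
  ; fold to CrossWorkgroup and no cast instruction is expected.
  %device_ptr = addrspacecast ptr addrspace(1) %global_ptr to ptr addrspace(5)
  ret ptr addrspace(5) %device_ptr
}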