Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for the SPV_INTEL_usm_storage_classes extension #82247

Conversation

VyacheslavLevytskyy
Copy link
Contributor

Copy link

github-actions bot commented Feb 19, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@llvmbot
Copy link
Collaborator

llvmbot commented Feb 19, 2024

@llvm/pr-subscribers-backend-spir-v

Author: Vyacheslav Levytskyy (VyacheslavLevytskyy)

Changes

Add support for the SPV_INTEL_usm_storage_classes extension:


Patch is 23.97 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/82247.diff

12 Files Affected:

  • (modified) llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp (+9-8)
  • (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp (+4-1)
  • (modified) llvm/lib/Target/SPIRV/SPIRVInstrInfo.td (+4)
  • (modified) llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp (+30-6)
  • (modified) llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp (+9-7)
  • (modified) llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp (+7)
  • (modified) llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp (+9-2)
  • (modified) llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp (+6)
  • (modified) llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td (+3)
  • (modified) llvm/lib/Target/SPIRV/SPIRVUtils.cpp (+16-3)
  • (modified) llvm/lib/Target/SPIRV/SPIRVUtils.h (+2-1)
  • (added) llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_usm_storage_classes/intel-usm-addrspaces.ll (+84)
diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index baeed2ad895a4b..abce2ea1c77820 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -150,7 +150,8 @@ getKernelArgTypeQual(const Function &F, unsigned ArgIdx) {
 
 static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
                                   SPIRVGlobalRegistry *GR,
-                                  MachineIRBuilder &MIRBuilder) {
+                                  MachineIRBuilder &MIRBuilder,
+                                  const SPIRVSubtarget &ST) {
   // Read argument's access qualifier from metadata or default.
   SPIRV::AccessQualifier::AccessQualifier ArgAccessQual =
       getArgAccessQual(F, ArgIdx);
@@ -169,8 +170,8 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
     if (MDTypeStr.ends_with("*"))
       ResArgType = GR->getOrCreateSPIRVTypeByName(
           MDTypeStr, MIRBuilder,
-          addressSpaceToStorageClass(
-              OriginalArgType->getPointerAddressSpace()));
+          addressSpaceToStorageClass(OriginalArgType->getPointerAddressSpace(),
+                                     ST));
     else if (MDTypeStr.ends_with("_t"))
       ResArgType = GR->getOrCreateSPIRVTypeByName(
           "opencl." + MDTypeStr.str(), MIRBuilder,
@@ -220,6 +221,10 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
   assert(GR && "Must initialize the SPIRV type registry before lowering args.");
   GR->setCurrentFunc(MIRBuilder.getMF());
 
+  // Get access to information about available extensions
+  const SPIRVSubtarget *ST =
+      static_cast<const SPIRVSubtarget *>(&MIRBuilder.getMF().getSubtarget());
+
   // Assign types and names to all args, and store their types for later.
   FunctionType *FTy = getOriginalFunctionType(F);
   SmallVector<SPIRVType *, 4> ArgTypeVRegs;
@@ -230,7 +235,7 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
       // TODO: handle the case of multiple registers.
       if (VRegs[i].size() > 1)
         return false;
-      auto *SpirvTy = getArgSPIRVType(F, i, GR, MIRBuilder);
+      auto *SpirvTy = getArgSPIRVType(F, i, GR, MIRBuilder, *ST);
       GR->assignSPIRVTypeToVReg(SpirvTy, VRegs[i][0], MIRBuilder.getMF());
       ArgTypeVRegs.push_back(SpirvTy);
 
@@ -332,10 +337,6 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
   if (F.hasName())
     buildOpName(FuncVReg, F.getName(), MIRBuilder);
 
-  // Get access to information about available extensions
-  const auto *ST =
-      static_cast<const SPIRVSubtarget *>(&MIRBuilder.getMF().getSubtarget());
-
   // Handle entry points and function linkage.
   if (isEntryPoint(F)) {
     const auto &STI = MIRBuilder.getMF().getSubtarget<SPIRVSubtarget>();
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 47fec745c3f18a..a1cb630f1aa477 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -709,7 +709,10 @@ SPIRVType *SPIRVGlobalRegistry::createSPIRVType(
     // TODO: change the implementation once opaque pointers are supported
     // in the SPIR-V specification.
     SpvElementType = getOrCreateSPIRVIntegerType(8, MIRBuilder);
-    auto SC = addressSpaceToStorageClass(PType->getAddressSpace());
+    // Get access to information about available extensions
+    const SPIRVSubtarget *ST =
+        static_cast<const SPIRVSubtarget *>(&MIRBuilder.getMF().getSubtarget());
+    auto SC = addressSpaceToStorageClass(PType->getAddressSpace(), *ST);
     // Null pointer means we have a loop in type definitions, make and
     // return corresponding OpTypeForwardPointer.
     if (SpvElementType == nullptr) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
index 7965dd969e5cfa..b16e73466ebe59 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
@@ -430,6 +430,10 @@ def OpGenericCastToPtrExplicit : Op<123, (outs ID:$r), (ins TYPE:$t, ID:$p, Stor
                               "$r = OpGenericCastToPtrExplicit $t $p $s">;
 def OpBitcast : UnOp<"OpBitcast", 124>;
 
+// SPV_INTEL_usm_storage_classes
+def OpPtrCastToCrossWorkgroupINTEL : UnOp<"OpPtrCastToCrossWorkgroupINTEL", 5934>;
+def OpCrossWorkgroupCastToPtrINTEL : UnOp<"OpCrossWorkgroupCastToPtrINTEL", 5938>;
+
 // 3.42.12 Composite Instructions
 
 def OpVectorExtractDynamic: Op<77, (outs ID:$res), (ins TYPE:$type, vID:$vec, ID:$idx),
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 53d19a1e31382d..7258d3b4d88ed3 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -828,8 +828,18 @@ static bool isGenericCastablePtr(SPIRV::StorageClass::StorageClass SC) {
   }
 }
 
+static bool isUSMStorageClass(SPIRV::StorageClass::StorageClass SC) {
+  switch (SC) {
+  case SPIRV::StorageClass::DeviceOnlyINTEL:
+  case SPIRV::StorageClass::HostOnlyINTEL:
+    return true;
+  default:
+    return false;
+  }
+}
+
 // In SPIR-V address space casting can only happen to and from the Generic
-// storage class. We can also only case Workgroup, CrossWorkgroup, or Function
+// storage class. We can also only cast Workgroup, CrossWorkgroup, or Function
 // pointers to and from Generic pointers. As such, we can convert e.g. from
 // Workgroup to Function by going via a Generic pointer as an intermediary. All
 // other combinations can only be done by a bitcast, and are probably not safe.
@@ -862,13 +872,17 @@ bool SPIRVInstructionSelector::selectAddrSpaceCast(Register ResVReg,
   SPIRV::StorageClass::StorageClass SrcSC = GR.getPointerStorageClass(SrcPtr);
   SPIRV::StorageClass::StorageClass DstSC = GR.getPointerStorageClass(ResVReg);
 
-  // Casting from an eligable pointer to Generic.
+  // don't generate a cast between identical storage classes
+  if (SrcSC == DstSC)
+    return true;
+
+  // Casting from an eligible pointer to Generic.
   if (DstSC == SPIRV::StorageClass::Generic && isGenericCastablePtr(SrcSC))
     return selectUnOp(ResVReg, ResType, I, SPIRV::OpPtrCastToGeneric);
-  // Casting from Generic to an eligable pointer.
+  // Casting from Generic to an eligible pointer.
   if (SrcSC == SPIRV::StorageClass::Generic && isGenericCastablePtr(DstSC))
     return selectUnOp(ResVReg, ResType, I, SPIRV::OpGenericCastToPtr);
-  // Casting between 2 eligable pointers using Generic as an intermediary.
+  // Casting between 2 eligible pointers using Generic as an intermediary.
   if (isGenericCastablePtr(SrcSC) && isGenericCastablePtr(DstSC)) {
     Register Tmp = MRI->createVirtualRegister(&SPIRV::IDRegClass);
     SPIRVType *GenericPtrTy = GR.getOrCreateSPIRVPointerType(
@@ -886,6 +900,16 @@ bool SPIRVInstructionSelector::selectAddrSpaceCast(Register ResVReg,
                           .addUse(Tmp)
                           .constrainAllUses(TII, TRI, RBI);
   }
+
+  // Check if instructions from the SPV_INTEL_usm_storage_classes extension may
+  // be applied
+  if (isUSMStorageClass(SrcSC) && DstSC == SPIRV::StorageClass::CrossWorkgroup)
+    return selectUnOp(ResVReg, ResType, I,
+                      SPIRV::OpPtrCastToCrossWorkgroupINTEL);
+  if (SrcSC == SPIRV::StorageClass::CrossWorkgroup && isUSMStorageClass(DstSC))
+    return selectUnOp(ResVReg, ResType, I,
+                      SPIRV::OpCrossWorkgroupCastToPtrINTEL);
+
   // TODO Should this case just be disallowed completely?
   // We're casting 2 other arbitrary address spaces, so have to bitcast.
   return selectUnOp(ResVReg, ResType, I, SPIRV::OpBitcast);
@@ -1545,7 +1569,7 @@ bool SPIRVInstructionSelector::selectGlobalValue(
   }
   SPIRVType *ResType = GR.getOrCreateSPIRVPointerType(
       PointerBaseType, I, TII,
-      addressSpaceToStorageClass(GV->getAddressSpace()));
+      addressSpaceToStorageClass(GV->getAddressSpace(), STI));
 
   std::string GlobalIdent;
   if (!GV->hasName()) {
@@ -1618,7 +1642,7 @@ bool SPIRVInstructionSelector::selectGlobalValue(
 
   unsigned AddrSpace = GV->getAddressSpace();
   SPIRV::StorageClass::StorageClass Storage =
-      addressSpaceToStorageClass(AddrSpace);
+      addressSpaceToStorageClass(AddrSpace, STI);
   bool HasLnkTy = GV->getLinkage() != GlobalValue::InternalLinkage &&
                   Storage != SPIRV::StorageClass::Function;
   SPIRV::LinkageType::LinkageType LnkType =
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
index 011a550a7b3d9b..4f2e7a240fc2cc 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
@@ -102,14 +102,16 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
   const LLT p2 = LLT::pointer(2, PSize); // UniformConstant
   const LLT p3 = LLT::pointer(3, PSize); // Workgroup
   const LLT p4 = LLT::pointer(4, PSize); // Generic
-  const LLT p5 = LLT::pointer(5, PSize); // Input
+  const LLT p5 =
+      LLT::pointer(5, PSize); // Input, SPV_INTEL_usm_storage_classes (Device)
+  const LLT p6 = LLT::pointer(6, PSize); // SPV_INTEL_usm_storage_classes (Host)
 
   // TODO: remove copy-pasting here by using concatenation in some way.
   auto allPtrsScalarsAndVectors = {
-      p0,    p1,    p2,    p3,    p4,    p5,    s1,     s8,     s16,
-      s32,   s64,   v2s1,  v2s8,  v2s16, v2s32, v2s64,  v3s1,   v3s8,
-      v3s16, v3s32, v3s64, v4s1,  v4s8,  v4s16, v4s32,  v4s64,  v8s1,
-      v8s8,  v8s16, v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};
+      p0,    p1,    p2,    p3,    p4,     p5,     p6,    s1,   s8,   s16,
+      s32,   s64,   v2s1,  v2s8,  v2s16,  v2s32,  v2s64, v3s1, v3s8, v3s16,
+      v3s32, v3s64, v4s1,  v4s8,  v4s16,  v4s32,  v4s64, v8s1, v8s8, v8s16,
+      v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};
 
   auto allScalarsAndVectors = {
       s1,   s8,   s16,   s32,   s64,   v2s1,  v2s8,  v2s16,  v2s32,  v2s64,
@@ -133,8 +135,8 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
 
   auto allFloatAndIntScalars = allIntScalars;
 
-  auto allPtrs = {p0, p1, p2, p3, p4, p5};
-  auto allWritablePtrs = {p0, p1, p3, p4};
+  auto allPtrs = {p0, p1, p2, p3, p4, p5, p6};
+  auto allWritablePtrs = {p0, p1, p3, p4, p5, p6};
 
   for (auto Opc : TypeFoldingSupportingOpcs)
     getActionDefinitionsBuilder(Opc).custom();
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index b3244cfbec7014..83c407a2a89f31 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1063,6 +1063,13 @@ void addInstrRequirements(const MachineInstr &MI,
       Reqs.addCapability(SPIRV::Capability::ExpectAssumeKHR);
     }
     break;
+  case SPIRV::OpPtrCastToCrossWorkgroupINTEL:
+  case SPIRV::OpCrossWorkgroupCastToPtrINTEL:
+    if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes)) {
+      Reqs.addExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes);
+      Reqs.addCapability(SPIRV::Capability::USMStorageClassesINTEL);
+    }
+    break;
   case SPIRV::OpConstantFunctionPointerINTEL:
     if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers)) {
       Reqs.addExtension(SPIRV::Extension::SPV_INTEL_function_pointers);
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
index cbc16fa986614e..144216896eb68c 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -122,6 +122,9 @@ static void foldConstantsIntoIntrinsics(MachineFunction &MF) {
 
 static void insertBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR,
                            MachineIRBuilder MIB) {
+  // Get access to information about available extensions
+  const SPIRVSubtarget *ST =
+      static_cast<const SPIRVSubtarget *>(&MIB.getMF().getSubtarget());
   SmallVector<MachineInstr *, 10> ToErase;
   for (MachineBasicBlock &MBB : MF) {
     for (MachineInstr &MI : MBB) {
@@ -141,7 +144,7 @@ static void insertBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR,
           getMDOperandAsType(MI.getOperand(3).getMetadata(), 0), MIB);
       SPIRVType *AssignedPtrType = GR->getOrCreateSPIRVPointerType(
           BaseTy, MI, *MF.getSubtarget<SPIRVSubtarget>().getInstrInfo(),
-          addressSpaceToStorageClass(MI.getOperand(4).getImm()));
+          addressSpaceToStorageClass(MI.getOperand(4).getImm(), *ST));
 
       // If the bitcast would be redundant, replace all uses with the source
       // register.
@@ -250,6 +253,10 @@ Register insertAssignInstr(Register Reg, Type *Ty, SPIRVType *SpirvTy,
 
 static void generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
                                  MachineIRBuilder MIB) {
+  // Get access to information about available extensions
+  const SPIRVSubtarget *ST =
+      static_cast<const SPIRVSubtarget *>(&MIB.getMF().getSubtarget());
+
   MachineRegisterInfo &MRI = MF.getRegInfo();
   SmallVector<MachineInstr *, 10> ToErase;
 
@@ -269,7 +276,7 @@ static void generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
             getMDOperandAsType(MI.getOperand(2).getMetadata(), 0), MIB);
         SPIRVType *AssignedPtrType = GR->getOrCreateSPIRVPointerType(
             BaseTy, MI, *MF.getSubtarget<SPIRVSubtarget>().getInstrInfo(),
-            addressSpaceToStorageClass(MI.getOperand(3).getImm()));
+            addressSpaceToStorageClass(MI.getOperand(3).getImm(), *ST));
         MachineInstr *Def = MRI.getVRegDef(Reg);
         assert(Def && "Expecting an instruction that defines the register");
         insertAssignInstr(Reg, nullptr, AssignedPtrType, GR, MIB,
diff --git a/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp b/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp
index e0493321041e38..f62c350c2eb7df 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp
@@ -49,6 +49,12 @@ cl::list<SPIRV::Extension::Extension> Extensions(
         clEnumValN(SPIRV::Extension::SPV_INTEL_optnone, "SPV_INTEL_optnone",
                    "Adds OptNoneINTEL value for Function Control mask that "
                    "indicates a request to not optimize the function."),
+        clEnumValN(SPIRV::Extension::SPV_INTEL_usm_storage_classes,
+                   "SPV_INTEL_usm_storage_classes",
+                   "Introduces two new storage classes that are sub classes of "
+                   "the CrossWorkgroup storage class "
+                   "that provides additional information that can enable "
+                   "optimization."),
         clEnumValN(SPIRV::Extension::SPV_INTEL_subgroups, "SPV_INTEL_subgroups",
                    "Allows work items in a subgroup to share data without the "
                    "use of local memory and work group barriers, and to "
diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
index d133a2575284c0..4da3afd418a4f8 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
+++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
@@ -461,6 +461,7 @@ defm AtomicFloat16AddEXT : CapabilityOperand<6095, 0, 0, [SPV_EXT_shader_atomic_
 defm AtomicFloat16MinMaxEXT : CapabilityOperand<5616, 0, 0, [SPV_EXT_shader_atomic_float_min_max], []>;
 defm AtomicFloat32MinMaxEXT : CapabilityOperand<5612, 0, 0, [SPV_EXT_shader_atomic_float_min_max], []>;
 defm AtomicFloat64MinMaxEXT : CapabilityOperand<5613, 0, 0, [SPV_EXT_shader_atomic_float_min_max], []>;
+defm USMStorageClassesINTEL : CapabilityOperand<5935, 0, 0, [SPV_INTEL_usm_storage_classes], [Kernel]>;
 
 //===----------------------------------------------------------------------===//
 // Multiclass used to define SourceLanguage enum values and at the same time
@@ -698,6 +699,8 @@ defm IncomingRayPayloadNV : StorageClassOperand<5342, [RayTracingNV]>;
 defm ShaderRecordBufferNV : StorageClassOperand<5343, [RayTracingNV]>;
 defm PhysicalStorageBufferEXT : StorageClassOperand<5349, [PhysicalStorageBufferAddressesEXT]>;
 defm CodeSectionINTEL : StorageClassOperand<5605, [FunctionPointersINTEL]>;
+defm DeviceOnlyINTEL : StorageClassOperand<5936, [USMStorageClassesINTEL]>;
+defm HostOnlyINTEL : StorageClassOperand<5937, [USMStorageClassesINTEL]>;
 
 //===----------------------------------------------------------------------===//
 // Multiclass used to define Dim enum values and at the same time
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
index 89daa19e666f63..f998b3490ca1c3 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
@@ -14,6 +14,7 @@
 #include "MCTargetDesc/SPIRVBaseInfo.h"
 #include "SPIRV.h"
 #include "SPIRVInstrInfo.h"
+#include "SPIRVSubtarget.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
@@ -146,15 +147,19 @@ unsigned storageClassToAddressSpace(SPIRV::StorageClass::StorageClass SC) {
     return 3;
   case SPIRV::StorageClass::Generic:
     return 4;
+  case SPIRV::StorageClass::DeviceOnlyINTEL:
+    return 5;
+  case SPIRV::StorageClass::HostOnlyINTEL:
+    return 6;
   case SPIRV::StorageClass::Input:
     return 7;
   default:
-    llvm_unreachable("Unable to get address space id");
+    report_fatal_error("Unable to get address space id");
   }
 }
 
 SPIRV::StorageClass::StorageClass
-addressSpaceToStorageClass(unsigned AddrSpace) {
+addressSpaceToStorageClass(unsigned AddrSpace, const SPIRVSubtarget &STI) {
   switch (AddrSpace) {
   case 0:
     return SPIRV::StorageClass::Function;
@@ -166,10 +171,18 @@ addressSpaceToStorageClass(unsigned AddrSpace) {
     return SPIRV::StorageClass::Workgroup;
   case 4:
     return SPIRV::StorageClass::Generic;
+  case 5:
+    return STI.canUseExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes)
+               ? SPIRV::StorageClass::DeviceOnlyINTEL
+               : SPIRV::StorageClass::CrossWorkgroup;
+  case 6:
+    return STI.canUseExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes)
+               ? SPIRV::StorageClass::HostOnlyINTEL
+               : SPIRV::StorageClass::CrossWorkgroup;
   case 7:
     return SPIRV::StorageClass::Input;
   default:
-    llvm_unreachable("Unknown address space");
+    report_fatal_error("Unknown address space");
   }
 }
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h
index 60742e2f272808..42d2d313615770 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.h
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h
@@ -27,6 +27,7 @@ class MachineRegisterInfo;
 class Register;
 class StringRef;
 class SPIRVInstrInfo;
+class SPIRVSubtarget;
 
 // Add the given string as a series of integer operand, inserting null
 // terminators and padding to make sure the operands all have 32-bit
@@ -62,7 +63,7 @@ unsigned storageClassToAddressSpace(SPIRV::StorageClass::StorageClass SC);
 
 // Convert an LLVM IR address space to a SPIR-V storage class.
 SPIRV::StorageClass::StorageClass
-addressSpaceToStorageClass(unsigned AddrSpace);
+addressSpaceToStorageClass(unsigned AddrSpace, const SPIRVSubtarget &STI);
 
 SPIRV::MemorySemantics::MemorySemantics
 getMemSemanticsForStorageClass(SPIRV::StorageClass::StorageClass SC);
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_usm_storage_classes/intel-usm-addrspaces.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_usm_storage_classes/intel-usm-addrspaces.ll
new file mode 100644
index 00000000000000..30c16350bf2b1f
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_usm_storage_classes/intel-usm-addrspaces.ll
@@ -0,0 +1,84 @@
+; Modified from: https://github.com/KhronosGroup/SPIRV-LLVM-Translator/test/extensions/INTEL/SPV_INTEL_usm_storage_classes/intel_usm_addrspaces.ll
+
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown --spirv-extensions=SPV_INTEL_usm_storage_classes %s -o - | FileCheck %s --check-prefixes=CHECK-SPIRV,CHECK-SPIRV-EXT
+; TODO: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-extensions=SPV_INTEL_usm_storage_classes %s -o - -filetype=obj | spirv-val %}
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check...
[truncated]

@VyacheslavLevytskyy VyacheslavLevytskyy merged commit 4a602d9 into llvm:main Feb 22, 2024
3 of 5 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants