Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,8 @@ getKernelArgTypeQual(const Function &F, unsigned ArgIdx) {

static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
SPIRVGlobalRegistry *GR,
MachineIRBuilder &MIRBuilder) {
MachineIRBuilder &MIRBuilder,
const SPIRVSubtarget &ST) {
// Read argument's access qualifier from metadata or default.
SPIRV::AccessQualifier::AccessQualifier ArgAccessQual =
getArgAccessQual(F, ArgIdx);
Expand All @@ -169,8 +170,8 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
if (MDTypeStr.ends_with("*"))
ResArgType = GR->getOrCreateSPIRVTypeByName(
MDTypeStr, MIRBuilder,
addressSpaceToStorageClass(
OriginalArgType->getPointerAddressSpace()));
addressSpaceToStorageClass(OriginalArgType->getPointerAddressSpace(),
ST));
else if (MDTypeStr.ends_with("_t"))
ResArgType = GR->getOrCreateSPIRVTypeByName(
"opencl." + MDTypeStr.str(), MIRBuilder,
Expand Down Expand Up @@ -206,6 +207,10 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
assert(GR && "Must initialize the SPIRV type registry before lowering args.");
GR->setCurrentFunc(MIRBuilder.getMF());

// Get access to information about available extensions
const SPIRVSubtarget *ST =
static_cast<const SPIRVSubtarget *>(&MIRBuilder.getMF().getSubtarget());

// Assign types and names to all args, and store their types for later.
FunctionType *FTy = getOriginalFunctionType(F);
SmallVector<SPIRVType *, 4> ArgTypeVRegs;
Expand All @@ -216,7 +221,7 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
// TODO: handle the case of multiple registers.
if (VRegs[i].size() > 1)
return false;
auto *SpirvTy = getArgSPIRVType(F, i, GR, MIRBuilder);
auto *SpirvTy = getArgSPIRVType(F, i, GR, MIRBuilder, *ST);
GR->assignSPIRVTypeToVReg(SpirvTy, VRegs[i][0], MIRBuilder.getMF());
ArgTypeVRegs.push_back(SpirvTy);

Expand Down Expand Up @@ -318,10 +323,6 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
if (F.hasName())
buildOpName(FuncVReg, F.getName(), MIRBuilder);

// Get access to information about available extensions
const auto *ST =
static_cast<const SPIRVSubtarget *>(&MIRBuilder.getMF().getSubtarget());

// Handle entry points and function linkage.
if (isEntryPoint(F)) {
const auto &STI = MIRBuilder.getMF().getSubtarget<SPIRVSubtarget>();
Expand Down
5 changes: 4 additions & 1 deletion llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -709,7 +709,10 @@ SPIRVType *SPIRVGlobalRegistry::createSPIRVType(
// TODO: change the implementation once opaque pointers are supported
// in the SPIR-V specification.
SpvElementType = getOrCreateSPIRVIntegerType(8, MIRBuilder);
auto SC = addressSpaceToStorageClass(PType->getAddressSpace());
// Get access to information about available extensions
const SPIRVSubtarget *ST =
static_cast<const SPIRVSubtarget *>(&MIRBuilder.getMF().getSubtarget());
auto SC = addressSpaceToStorageClass(PType->getAddressSpace(), *ST);
// Null pointer means we have a loop in type definitions, make and
// return corresponding OpTypeForwardPointer.
if (SpvElementType == nullptr) {
Expand Down
4 changes: 4 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,10 @@ def OpGenericCastToPtrExplicit : Op<123, (outs ID:$r), (ins TYPE:$t, ID:$p, Stor
"$r = OpGenericCastToPtrExplicit $t $p $s">;
def OpBitcast : UnOp<"OpBitcast", 124>;

// SPV_INTEL_usm_storage_classes
def OpPtrCastToCrossWorkgroupINTEL : UnOp<"OpPtrCastToCrossWorkgroupINTEL", 5934>;
def OpCrossWorkgroupCastToPtrINTEL : UnOp<"OpCrossWorkgroupCastToPtrINTEL", 5938>;

// 3.42.12 Composite Instructions

def OpVectorExtractDynamic: Op<77, (outs ID:$res), (ins TYPE:$type, vID:$vec, ID:$idx),
Expand Down
36 changes: 30 additions & 6 deletions llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -828,8 +828,18 @@ static bool isGenericCastablePtr(SPIRV::StorageClass::StorageClass SC) {
}
}

static bool isUSMStorageClass(SPIRV::StorageClass::StorageClass SC) {
switch (SC) {
case SPIRV::StorageClass::DeviceOnlyINTEL:
case SPIRV::StorageClass::HostOnlyINTEL:
return true;
default:
return false;
}
}

// In SPIR-V address space casting can only happen to and from the Generic
// storage class. We can also only case Workgroup, CrossWorkgroup, or Function
// storage class. We can also only cast Workgroup, CrossWorkgroup, or Function
// pointers to and from Generic pointers. As such, we can convert e.g. from
// Workgroup to Function by going via a Generic pointer as an intermediary. All
// other combinations can only be done by a bitcast, and are probably not safe.
Expand Down Expand Up @@ -862,13 +872,17 @@ bool SPIRVInstructionSelector::selectAddrSpaceCast(Register ResVReg,
SPIRV::StorageClass::StorageClass SrcSC = GR.getPointerStorageClass(SrcPtr);
SPIRV::StorageClass::StorageClass DstSC = GR.getPointerStorageClass(ResVReg);

// Casting from an eligable pointer to Generic.
// don't generate a cast between identical storage classes
if (SrcSC == DstSC)
return true;

// Casting from an eligible pointer to Generic.
if (DstSC == SPIRV::StorageClass::Generic && isGenericCastablePtr(SrcSC))
return selectUnOp(ResVReg, ResType, I, SPIRV::OpPtrCastToGeneric);
// Casting from Generic to an eligable pointer.
// Casting from Generic to an eligible pointer.
if (SrcSC == SPIRV::StorageClass::Generic && isGenericCastablePtr(DstSC))
return selectUnOp(ResVReg, ResType, I, SPIRV::OpGenericCastToPtr);
// Casting between 2 eligable pointers using Generic as an intermediary.
// Casting between 2 eligible pointers using Generic as an intermediary.
if (isGenericCastablePtr(SrcSC) && isGenericCastablePtr(DstSC)) {
Register Tmp = MRI->createVirtualRegister(&SPIRV::IDRegClass);
SPIRVType *GenericPtrTy = GR.getOrCreateSPIRVPointerType(
Expand All @@ -886,6 +900,16 @@ bool SPIRVInstructionSelector::selectAddrSpaceCast(Register ResVReg,
.addUse(Tmp)
.constrainAllUses(TII, TRI, RBI);
}

// Check if instructions from the SPV_INTEL_usm_storage_classes extension may
// be applied
if (isUSMStorageClass(SrcSC) && DstSC == SPIRV::StorageClass::CrossWorkgroup)
return selectUnOp(ResVReg, ResType, I,
SPIRV::OpPtrCastToCrossWorkgroupINTEL);
if (SrcSC == SPIRV::StorageClass::CrossWorkgroup && isUSMStorageClass(DstSC))
return selectUnOp(ResVReg, ResType, I,
SPIRV::OpCrossWorkgroupCastToPtrINTEL);

// TODO Should this case just be disallowed completely?
// We're casting 2 other arbitrary address spaces, so have to bitcast.
return selectUnOp(ResVReg, ResType, I, SPIRV::OpBitcast);
Expand Down Expand Up @@ -1545,7 +1569,7 @@ bool SPIRVInstructionSelector::selectGlobalValue(
}
SPIRVType *ResType = GR.getOrCreateSPIRVPointerType(
PointerBaseType, I, TII,
addressSpaceToStorageClass(GV->getAddressSpace()));
addressSpaceToStorageClass(GV->getAddressSpace(), STI));

std::string GlobalIdent;
if (!GV->hasName()) {
Expand Down Expand Up @@ -1618,7 +1642,7 @@ bool SPIRVInstructionSelector::selectGlobalValue(

unsigned AddrSpace = GV->getAddressSpace();
SPIRV::StorageClass::StorageClass Storage =
addressSpaceToStorageClass(AddrSpace);
addressSpaceToStorageClass(AddrSpace, STI);
bool HasLnkTy = GV->getLinkage() != GlobalValue::InternalLinkage &&
Storage != SPIRV::StorageClass::Function;
SPIRV::LinkageType::LinkageType LnkType =
Expand Down
16 changes: 9 additions & 7 deletions llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,14 +102,16 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
const LLT p2 = LLT::pointer(2, PSize); // UniformConstant
const LLT p3 = LLT::pointer(3, PSize); // Workgroup
const LLT p4 = LLT::pointer(4, PSize); // Generic
const LLT p5 = LLT::pointer(5, PSize); // Input
const LLT p5 =
LLT::pointer(5, PSize); // Input, SPV_INTEL_usm_storage_classes (Device)
const LLT p6 = LLT::pointer(6, PSize); // SPV_INTEL_usm_storage_classes (Host)

// TODO: remove copy-pasting here by using concatenation in some way.
auto allPtrsScalarsAndVectors = {
p0, p1, p2, p3, p4, p5, s1, s8, s16,
s32, s64, v2s1, v2s8, v2s16, v2s32, v2s64, v3s1, v3s8,
v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, v4s64, v8s1,
v8s8, v8s16, v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};
p0, p1, p2, p3, p4, p5, p6, s1, s8, s16,
s32, s64, v2s1, v2s8, v2s16, v2s32, v2s64, v3s1, v3s8, v3s16,
v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, v4s64, v8s1, v8s8, v8s16,
v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};

auto allScalarsAndVectors = {
s1, s8, s16, s32, s64, v2s1, v2s8, v2s16, v2s32, v2s64,
Expand All @@ -133,8 +135,8 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {

auto allFloatAndIntScalars = allIntScalars;

auto allPtrs = {p0, p1, p2, p3, p4, p5};
auto allWritablePtrs = {p0, p1, p3, p4};
auto allPtrs = {p0, p1, p2, p3, p4, p5, p6};
auto allWritablePtrs = {p0, p1, p3, p4, p5, p6};

for (auto Opc : TypeFoldingSupportingOpcs)
getActionDefinitionsBuilder(Opc).custom();
Expand Down
7 changes: 7 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1063,6 +1063,13 @@ void addInstrRequirements(const MachineInstr &MI,
Reqs.addCapability(SPIRV::Capability::ExpectAssumeKHR);
}
break;
case SPIRV::OpPtrCastToCrossWorkgroupINTEL:
case SPIRV::OpCrossWorkgroupCastToPtrINTEL:
if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes)) {
Reqs.addExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes);
Reqs.addCapability(SPIRV::Capability::USMStorageClassesINTEL);
}
break;
case SPIRV::OpConstantFunctionPointerINTEL:
if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers)) {
Reqs.addExtension(SPIRV::Extension::SPV_INTEL_function_pointers);
Expand Down
11 changes: 9 additions & 2 deletions llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,9 @@ static void foldConstantsIntoIntrinsics(MachineFunction &MF) {

static void insertBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR,
MachineIRBuilder MIB) {
// Get access to information about available extensions
const SPIRVSubtarget *ST =
static_cast<const SPIRVSubtarget *>(&MIB.getMF().getSubtarget());
SmallVector<MachineInstr *, 10> ToErase;
for (MachineBasicBlock &MBB : MF) {
for (MachineInstr &MI : MBB) {
Expand All @@ -141,7 +144,7 @@ static void insertBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR,
getMDOperandAsType(MI.getOperand(3).getMetadata(), 0), MIB);
SPIRVType *AssignedPtrType = GR->getOrCreateSPIRVPointerType(
BaseTy, MI, *MF.getSubtarget<SPIRVSubtarget>().getInstrInfo(),
addressSpaceToStorageClass(MI.getOperand(4).getImm()));
addressSpaceToStorageClass(MI.getOperand(4).getImm(), *ST));

// If the bitcast would be redundant, replace all uses with the source
// register.
Expand Down Expand Up @@ -250,6 +253,10 @@ Register insertAssignInstr(Register Reg, Type *Ty, SPIRVType *SpirvTy,

static void generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
MachineIRBuilder MIB) {
// Get access to information about available extensions
const SPIRVSubtarget *ST =
static_cast<const SPIRVSubtarget *>(&MIB.getMF().getSubtarget());

MachineRegisterInfo &MRI = MF.getRegInfo();
SmallVector<MachineInstr *, 10> ToErase;

Expand All @@ -269,7 +276,7 @@ static void generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
getMDOperandAsType(MI.getOperand(2).getMetadata(), 0), MIB);
SPIRVType *AssignedPtrType = GR->getOrCreateSPIRVPointerType(
BaseTy, MI, *MF.getSubtarget<SPIRVSubtarget>().getInstrInfo(),
addressSpaceToStorageClass(MI.getOperand(3).getImm()));
addressSpaceToStorageClass(MI.getOperand(3).getImm(), *ST));
MachineInstr *Def = MRI.getVRegDef(Reg);
assert(Def && "Expecting an instruction that defines the register");
insertAssignInstr(Reg, nullptr, AssignedPtrType, GR, MIB,
Expand Down
6 changes: 6 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@ cl::list<SPIRV::Extension::Extension> Extensions(
clEnumValN(SPIRV::Extension::SPV_INTEL_optnone, "SPV_INTEL_optnone",
"Adds OptNoneINTEL value for Function Control mask that "
"indicates a request to not optimize the function."),
clEnumValN(SPIRV::Extension::SPV_INTEL_usm_storage_classes,
"SPV_INTEL_usm_storage_classes",
"Introduces two new storage classes that are sub classes of "
"the CrossWorkgroup storage class "
"that provides additional information that can enable "
"optimization."),
clEnumValN(SPIRV::Extension::SPV_INTEL_subgroups, "SPV_INTEL_subgroups",
"Allows work items in a subgroup to share data without the "
"use of local memory and work group barriers, and to "
Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,7 @@ defm AtomicFloat16MinMaxEXT : CapabilityOperand<5616, 0, 0, [SPV_EXT_shader_atom
defm AtomicFloat32MinMaxEXT : CapabilityOperand<5612, 0, 0, [SPV_EXT_shader_atomic_float_min_max], []>;
defm AtomicFloat64MinMaxEXT : CapabilityOperand<5613, 0, 0, [SPV_EXT_shader_atomic_float_min_max], []>;
defm GroupUniformArithmeticKHR : CapabilityOperand<6400, 0, 0, [SPV_KHR_uniform_group_instructions], []>;
defm USMStorageClassesINTEL : CapabilityOperand<5935, 0, 0, [SPV_INTEL_usm_storage_classes], [Kernel]>;

//===----------------------------------------------------------------------===//
// Multiclass used to define SourceLanguage enum values and at the same time
Expand Down Expand Up @@ -700,6 +701,8 @@ defm IncomingRayPayloadNV : StorageClassOperand<5342, [RayTracingNV]>;
defm ShaderRecordBufferNV : StorageClassOperand<5343, [RayTracingNV]>;
defm PhysicalStorageBufferEXT : StorageClassOperand<5349, [PhysicalStorageBufferAddressesEXT]>;
defm CodeSectionINTEL : StorageClassOperand<5605, [FunctionPointersINTEL]>;
defm DeviceOnlyINTEL : StorageClassOperand<5936, [USMStorageClassesINTEL]>;
defm HostOnlyINTEL : StorageClassOperand<5937, [USMStorageClassesINTEL]>;

//===----------------------------------------------------------------------===//
// Multiclass used to define Dim enum values and at the same time
Expand Down
19 changes: 16 additions & 3 deletions llvm/lib/Target/SPIRV/SPIRVUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "MCTargetDesc/SPIRVBaseInfo.h"
#include "SPIRV.h"
#include "SPIRVInstrInfo.h"
#include "SPIRVSubtarget.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
Expand Down Expand Up @@ -146,15 +147,19 @@ unsigned storageClassToAddressSpace(SPIRV::StorageClass::StorageClass SC) {
return 3;
case SPIRV::StorageClass::Generic:
return 4;
case SPIRV::StorageClass::DeviceOnlyINTEL:
return 5;
case SPIRV::StorageClass::HostOnlyINTEL:
return 6;
case SPIRV::StorageClass::Input:
return 7;
default:
llvm_unreachable("Unable to get address space id");
report_fatal_error("Unable to get address space id");
}
}

SPIRV::StorageClass::StorageClass
addressSpaceToStorageClass(unsigned AddrSpace) {
addressSpaceToStorageClass(unsigned AddrSpace, const SPIRVSubtarget &STI) {
switch (AddrSpace) {
case 0:
return SPIRV::StorageClass::Function;
Expand All @@ -166,10 +171,18 @@ addressSpaceToStorageClass(unsigned AddrSpace) {
return SPIRV::StorageClass::Workgroup;
case 4:
return SPIRV::StorageClass::Generic;
case 5:
return STI.canUseExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes)
? SPIRV::StorageClass::DeviceOnlyINTEL
: SPIRV::StorageClass::CrossWorkgroup;
case 6:
return STI.canUseExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes)
? SPIRV::StorageClass::HostOnlyINTEL
: SPIRV::StorageClass::CrossWorkgroup;
case 7:
return SPIRV::StorageClass::Input;
default:
llvm_unreachable("Unknown address space");
report_fatal_error("Unknown address space");
}
}

Expand Down
3 changes: 2 additions & 1 deletion llvm/lib/Target/SPIRV/SPIRVUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class MachineRegisterInfo;
class Register;
class StringRef;
class SPIRVInstrInfo;
class SPIRVSubtarget;

// Add the given string as a series of integer operand, inserting null
// terminators and padding to make sure the operands all have 32-bit
Expand Down Expand Up @@ -62,7 +63,7 @@ unsigned storageClassToAddressSpace(SPIRV::StorageClass::StorageClass SC);

// Convert an LLVM IR address space to a SPIR-V storage class.
SPIRV::StorageClass::StorageClass
addressSpaceToStorageClass(unsigned AddrSpace);
addressSpaceToStorageClass(unsigned AddrSpace, const SPIRVSubtarget &STI);

SPIRV::MemorySemantics::MemorySemantics
getMemSemanticsForStorageClass(SPIRV::StorageClass::StorageClass SC);
Expand Down
Loading