Skip to content

Commit

Permalink
[SPIR-V] Emit proper pointer type for OpenCL kernel arguments (#67726)
Browse files Browse the repository at this point in the history
  • Loading branch information
michalpaszkowski committed Oct 19, 2023
1 parent b858309 commit 8175190
Show file tree
Hide file tree
Showing 10 changed files with 315 additions and 86 deletions.
107 changes: 52 additions & 55 deletions llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2010,60 +2010,6 @@ static Type *parseTypeString(const StringRef Name, LLVMContext &Context) {
llvm_unreachable("Unable to recognize type!");
}

static const TargetExtType *parseToTargetExtType(const Type *OpaqueType,
MachineIRBuilder &MIRBuilder) {
assert(isSpecialOpaqueType(OpaqueType) &&
"Not a SPIR-V/OpenCL special opaque type!");
assert(!OpaqueType->isTargetExtTy() &&
"This already is SPIR-V/OpenCL TargetExtType!");

StringRef NameWithParameters = OpaqueType->getStructName();

// Pointers-to-opaque-structs representing OpenCL types are first translated
// to equivalent SPIR-V types. OpenCL builtin type names should have the
// following format: e.g. %opencl.event_t
if (NameWithParameters.startswith("opencl.")) {
const SPIRV::OpenCLType *OCLTypeRecord =
SPIRV::lookupOpenCLType(NameWithParameters);
if (!OCLTypeRecord)
report_fatal_error("Missing TableGen record for OpenCL type: " +
NameWithParameters);
NameWithParameters = OCLTypeRecord->SpirvTypeLiteral;
// Continue with the SPIR-V builtin type...
}

// Names of the opaque structs representing a SPIR-V builtins without
// parameters should have the following format: e.g. %spirv.Event
assert(NameWithParameters.startswith("spirv.") &&
"Unknown builtin opaque type!");

// Parameterized SPIR-V builtins names follow this format:
// e.g. %spirv.Image._void_1_0_0_0_0_0_0, %spirv.Pipe._0
if (NameWithParameters.find('_') == std::string::npos)
return TargetExtType::get(OpaqueType->getContext(), NameWithParameters);

SmallVector<StringRef> Parameters;
unsigned BaseNameLength = NameWithParameters.find('_') - 1;
SplitString(NameWithParameters.substr(BaseNameLength + 1), Parameters, "_");

SmallVector<Type *, 1> TypeParameters;
bool HasTypeParameter = !isDigit(Parameters[0][0]);
if (HasTypeParameter)
TypeParameters.push_back(parseTypeString(
Parameters[0], MIRBuilder.getMF().getFunction().getContext()));
SmallVector<unsigned> IntParameters;
for (unsigned i = HasTypeParameter ? 1 : 0; i < Parameters.size(); i++) {
unsigned IntParameter = 0;
bool ValidLiteral = !Parameters[i].getAsInteger(10, IntParameter);
assert(ValidLiteral &&
"Invalid format of SPIR-V builtin parameter literal!");
IntParameters.push_back(IntParameter);
}
return TargetExtType::get(OpaqueType->getContext(),
NameWithParameters.substr(0, BaseNameLength),
TypeParameters, IntParameters);
}

//===----------------------------------------------------------------------===//
// Implementation functions for builtin types.
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -2127,6 +2073,56 @@ static SPIRVType *getSampledImageType(const TargetExtType *OpaqueType,
}

namespace SPIRV {
const TargetExtType *
parseBuiltinTypeNameToTargetExtType(std::string TypeName,
MachineIRBuilder &MIRBuilder) {
StringRef NameWithParameters = TypeName;

// Pointers-to-opaque-structs representing OpenCL types are first translated
// to equivalent SPIR-V types. OpenCL builtin type names should have the
// following format: e.g. %opencl.event_t
if (NameWithParameters.startswith("opencl.")) {
const SPIRV::OpenCLType *OCLTypeRecord =
SPIRV::lookupOpenCLType(NameWithParameters);
if (!OCLTypeRecord)
report_fatal_error("Missing TableGen record for OpenCL type: " +
NameWithParameters);
NameWithParameters = OCLTypeRecord->SpirvTypeLiteral;
// Continue with the SPIR-V builtin type...
}

// Names of the opaque structs representing a SPIR-V builtins without
// parameters should have the following format: e.g. %spirv.Event
assert(NameWithParameters.startswith("spirv.") &&
"Unknown builtin opaque type!");

// Parameterized SPIR-V builtins names follow this format:
// e.g. %spirv.Image._void_1_0_0_0_0_0_0, %spirv.Pipe._0
if (NameWithParameters.find('_') == std::string::npos)
return TargetExtType::get(MIRBuilder.getContext(), NameWithParameters);

SmallVector<StringRef> Parameters;
unsigned BaseNameLength = NameWithParameters.find('_') - 1;
SplitString(NameWithParameters.substr(BaseNameLength + 1), Parameters, "_");

SmallVector<Type *, 1> TypeParameters;
bool HasTypeParameter = !isDigit(Parameters[0][0]);
if (HasTypeParameter)
TypeParameters.push_back(parseTypeString(
Parameters[0], MIRBuilder.getMF().getFunction().getContext()));
SmallVector<unsigned> IntParameters;
for (unsigned i = HasTypeParameter ? 1 : 0; i < Parameters.size(); i++) {
unsigned IntParameter = 0;
bool ValidLiteral = !Parameters[i].getAsInteger(10, IntParameter);
assert(ValidLiteral &&
"Invalid format of SPIR-V builtin parameter literal!");
IntParameters.push_back(IntParameter);
}
return TargetExtType::get(MIRBuilder.getContext(),
NameWithParameters.substr(0, BaseNameLength),
TypeParameters, IntParameters);
}

SPIRVType *lowerBuiltinType(const Type *OpaqueType,
SPIRV::AccessQualifier::AccessQualifier AccessQual,
MachineIRBuilder &MIRBuilder,
Expand All @@ -2141,7 +2137,8 @@ SPIRVType *lowerBuiltinType(const Type *OpaqueType,
// will be removed in the future release of LLVM.
const TargetExtType *BuiltinType = dyn_cast<TargetExtType>(OpaqueType);
if (!BuiltinType)
BuiltinType = parseToTargetExtType(OpaqueType, MIRBuilder);
BuiltinType = parseBuiltinTypeNameToTargetExtType(
OpaqueType->getStructName().str(), MIRBuilder);

unsigned NumStartingVRegs = MIRBuilder.getMRI()->getNumVirtRegs();

Expand Down
12 changes: 12 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVBuiltins.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,18 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
const Register OrigRet, const Type *OrigRetTy,
const SmallVectorImpl<Register> &Args,
SPIRVGlobalRegistry *GR);

/// Translates a string representing a SPIR-V or OpenCL builtin type to a
/// TargetExtType that can be further lowered with lowerBuiltinType().
///
/// \return A TargetExtType representing the builtin SPIR-V type.
///
/// \p TypeName is the full string representation of the SPIR-V or OpenCL
/// builtin type.
const TargetExtType *
parseBuiltinTypeNameToTargetExtType(std::string TypeName,
MachineIRBuilder &MIRBuilder);

/// Handles the translation of the provided special opaque/builtin type \p Type
/// to SPIR-V type. Generates the corresponding machine instructions for the
/// target type or gets the already existing OpType<...> register from the
Expand Down
43 changes: 28 additions & 15 deletions llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -194,23 +194,38 @@ getKernelArgTypeQual(const Function &KernelFunction, unsigned ArgIdx) {
return {};
}

static Type *getArgType(const Function &F, unsigned ArgIdx) {
static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
SPIRVGlobalRegistry *GR,
MachineIRBuilder &MIRBuilder) {
// Read argument's access qualifier from metadata or default.
SPIRV::AccessQualifier::AccessQualifier ArgAccessQual =
getArgAccessQual(F, ArgIdx);

Type *OriginalArgType = getOriginalFunctionType(F)->getParamType(ArgIdx);

// In case of non-kernel SPIR-V function or already TargetExtType, use the
// original IR type.
if (F.getCallingConv() != CallingConv::SPIR_KERNEL ||
isSpecialOpaqueType(OriginalArgType))
return OriginalArgType;
return GR->getOrCreateSPIRVType(OriginalArgType, MIRBuilder, ArgAccessQual);

MDString *MDKernelArgType =
getKernelArgAttribute(F, ArgIdx, "kernel_arg_type");
if (!MDKernelArgType || !MDKernelArgType->getString().endswith("_t"))
return OriginalArgType;

std::string KernelArgTypeStr = "opencl." + MDKernelArgType->getString().str();
Type *ExistingOpaqueType =
StructType::getTypeByName(F.getContext(), KernelArgTypeStr);
return ExistingOpaqueType
? ExistingOpaqueType
: StructType::create(F.getContext(), KernelArgTypeStr);
if (!MDKernelArgType || (MDKernelArgType->getString().ends_with("*") &&
MDKernelArgType->getString().ends_with("_t")))
return GR->getOrCreateSPIRVType(OriginalArgType, MIRBuilder, ArgAccessQual);

if (MDKernelArgType->getString().ends_with("*"))
return GR->getOrCreateSPIRVTypeByName(
MDKernelArgType->getString(), MIRBuilder,
addressSpaceToStorageClass(OriginalArgType->getPointerAddressSpace()));

if (MDKernelArgType->getString().ends_with("_t"))
return GR->getOrCreateSPIRVTypeByName(
"opencl." + MDKernelArgType->getString().str(), MIRBuilder,
SPIRV::StorageClass::Function, ArgAccessQual);

llvm_unreachable("Unable to recognize argument type name.");
}

static bool isEntryPoint(const Function &F) {
Expand Down Expand Up @@ -262,10 +277,8 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
// TODO: handle the case of multiple registers.
if (VRegs[i].size() > 1)
return false;
SPIRV::AccessQualifier::AccessQualifier ArgAccessQual =
getArgAccessQual(F, i);
auto *SpirvTy = GR->assignTypeToVReg(getArgType(F, i), VRegs[i][0],
MIRBuilder, ArgAccessQual);
auto *SpirvTy = getArgSPIRVType(F, i, GR, MIRBuilder);
GR->assignSPIRVTypeToVReg(SpirvTy, VRegs[i][0], MIRBuilder.getMF());
ArgTypeVRegs.push_back(SpirvTy);

if (Arg.hasName())
Expand Down
68 changes: 55 additions & 13 deletions llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -956,40 +956,82 @@ SPIRVGlobalRegistry::checkSpecialInstr(const SPIRV::SpecialTypeDescriptor &TD,
}

// TODO: maybe use tablegen to implement this.
SPIRVType *
SPIRVGlobalRegistry::getOrCreateSPIRVTypeByName(StringRef TypeStr,
MachineIRBuilder &MIRBuilder) {
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVTypeByName(
StringRef TypeStr, MachineIRBuilder &MIRBuilder,
SPIRV::StorageClass::StorageClass SC,
SPIRV::AccessQualifier::AccessQualifier AQ) {
unsigned VecElts = 0;
auto &Ctx = MIRBuilder.getMF().getFunction().getContext();

// Parse strings representing either a SPIR-V or OpenCL builtin type.
if (hasBuiltinTypePrefix(TypeStr))
return getOrCreateSPIRVType(
SPIRV::parseBuiltinTypeNameToTargetExtType(TypeStr.str(), MIRBuilder),
MIRBuilder, AQ);

// Parse type name in either "typeN" or "type vector[N]" format, where
// N is the number of elements of the vector.
Type *Type;
Type *Ty;

if (TypeStr.starts_with("atomic_"))
TypeStr = TypeStr.substr(strlen("atomic_"));

if (TypeStr.startswith("void")) {
Type = Type::getVoidTy(Ctx);
Ty = Type::getVoidTy(Ctx);
TypeStr = TypeStr.substr(strlen("void"));
} else if (TypeStr.startswith("bool")) {
Ty = Type::getIntNTy(Ctx, 1);
TypeStr = TypeStr.substr(strlen("bool"));
} else if (TypeStr.startswith("char") || TypeStr.startswith("uchar")) {
Ty = Type::getInt8Ty(Ctx);
TypeStr = TypeStr.startswith("char") ? TypeStr.substr(strlen("char"))
: TypeStr.substr(strlen("uchar"));
} else if (TypeStr.startswith("short") || TypeStr.startswith("ushort")) {
Ty = Type::getInt16Ty(Ctx);
TypeStr = TypeStr.startswith("short") ? TypeStr.substr(strlen("short"))
: TypeStr.substr(strlen("ushort"));
} else if (TypeStr.startswith("int") || TypeStr.startswith("uint")) {
Type = Type::getInt32Ty(Ctx);
Ty = Type::getInt32Ty(Ctx);
TypeStr = TypeStr.startswith("int") ? TypeStr.substr(strlen("int"))
: TypeStr.substr(strlen("uint"));
} else if (TypeStr.startswith("float")) {
Type = Type::getFloatTy(Ctx);
TypeStr = TypeStr.substr(strlen("float"));
} else if (TypeStr.starts_with("long") || TypeStr.starts_with("ulong")) {
Ty = Type::getInt64Ty(Ctx);
TypeStr = TypeStr.startswith("long") ? TypeStr.substr(strlen("long"))
: TypeStr.substr(strlen("ulong"));
} else if (TypeStr.startswith("half")) {
Type = Type::getHalfTy(Ctx);
Ty = Type::getHalfTy(Ctx);
TypeStr = TypeStr.substr(strlen("half"));
} else if (TypeStr.startswith("opencl.sampler_t")) {
Type = StructType::create(Ctx, "opencl.sampler_t");
} else if (TypeStr.startswith("float")) {
Ty = Type::getFloatTy(Ctx);
TypeStr = TypeStr.substr(strlen("float"));
} else if (TypeStr.startswith("double")) {
Ty = Type::getDoubleTy(Ctx);
TypeStr = TypeStr.substr(strlen("double"));
} else
llvm_unreachable("Unable to recognize SPIRV type name.");

auto SpirvTy = getOrCreateSPIRVType(Ty, MIRBuilder, AQ);

// Handle "type*" or "type* vector[N]".
if (TypeStr.starts_with("*")) {
SpirvTy = getOrCreateSPIRVPointerType(SpirvTy, MIRBuilder, SC);
TypeStr = TypeStr.substr(strlen("*"));
}

// Handle "typeN*" or "type vector[N]*".
bool IsPtrToVec = TypeStr.consume_back("*");

if (TypeStr.startswith(" vector[")) {
TypeStr = TypeStr.substr(strlen(" vector["));
TypeStr = TypeStr.substr(0, TypeStr.find(']'));
}
TypeStr.getAsInteger(10, VecElts);
auto SpirvTy = getOrCreateSPIRVType(Type, MIRBuilder);
if (VecElts > 0)
SpirvTy = getOrCreateSPIRVVectorType(SpirvTy, VecElts, MIRBuilder);

if (IsPtrToVec)
SpirvTy = getOrCreateSPIRVPointerType(SpirvTy, MIRBuilder, SC);

return SpirvTy;
}

Expand Down
7 changes: 5 additions & 2 deletions llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,11 @@ class SPIRVGlobalRegistry {

// Either generate a new OpTypeXXX instruction or return an existing one
// corresponding to the given string containing the name of the builtin type.
SPIRVType *getOrCreateSPIRVTypeByName(StringRef TypeStr,
MachineIRBuilder &MIRBuilder);
SPIRVType *getOrCreateSPIRVTypeByName(
StringRef TypeStr, MachineIRBuilder &MIRBuilder,
SPIRV::StorageClass::StorageClass SC = SPIRV::StorageClass::Function,
SPIRV::AccessQualifier::AccessQualifier AQ =
SPIRV::AccessQualifier::ReadWrite);

// Return the SPIR-V type instruction corresponding to the given VReg, or
// nullptr if no such type instruction exists.
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Target/SPIRV/SPIRVUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ const Type *getTypedPtrEltType(const Type *Ty) {
return PType->getNonOpaquePointerElementType();
}

static bool hasBuiltinTypePrefix(StringRef Name) {
bool hasBuiltinTypePrefix(StringRef Name) {
if (Name.starts_with("opencl.") || Name.starts_with("spirv."))
return true;
return false;
Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,9 @@ std::string getOclOrSpirvBuiltinDemangledName(StringRef Name);
// element type, otherwise return Type.
const Type *getTypedPtrEltType(const Type *Type);

// Check if a string contains a builtin prefix.
bool hasBuiltinTypePrefix(StringRef Name);

// Check if given LLVM type is a special opaque builtin type.
bool isSpecialOpaqueType(const Type *Ty);
} // namespace llvm
Expand Down
18 changes: 18 additions & 0 deletions llvm/test/CodeGen/SPIRV/pointers/getelementptr-base-type.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s

; CHECK: %[[#FLOAT32:]] = OpTypeFloat 32
; CHECK: %[[#PTR:]] = OpTypePointer CrossWorkgroup %[[#FLOAT32]]
; CHECK: %[[#ARG:]] = OpFunctionParameter %[[#PTR]]
; CHECK: %[[#GEP:]] = OpInBoundsPtrAccessChain %[[#PTR]] %[[#ARG]] %[[#]]
; CHECK: %[[#]] = OpLoad %[[#FLOAT32]] %[[#GEP]] Aligned 4

define spir_kernel void @test1(ptr addrspace(1) %arg1) !kernel_arg_addr_space !1 !kernel_arg_access_qual !2 !kernel_arg_type !3 !kernel_arg_type_qual !4 {
%a = getelementptr inbounds float, ptr addrspace(1) %arg1, i64 1
%b = load float, ptr addrspace(1) %a, align 4
ret void
}

!1 = !{i32 1}
!2 = !{!"none"}
!3 = !{!"float*"}
!4 = !{!""}

0 comments on commit 8175190

Please sign in to comment.