Expand Up
@@ -155,89 +155,86 @@ SPIRVTypeConverter::getStorageClassForMemorySpace(unsigned space) {
#undef STORAGE_SPACE_MAP_LIST
// TODO: This is a utility function that should probably be
// exposed by the SPIR-V dialect. Keeping it local till the use case arises.
static Optional<int64_t > getTypeNumBytes (Type t) {
if (t.isa <spirv::ScalarType>()) {
auto bitWidth = t.getIntOrFloatBitWidth ();
// TODO: This is a utility function that should probably be exposed by the
// SPIR-V dialect. Keeping it local till the use case arises.
static Optional<int64_t >
getTypeNumBytes (const SPIRVTypeConverter::Options &options, Type type) {
if (type.isa <spirv::ScalarType>()) {
auto bitWidth = type.getIntOrFloatBitWidth ();
// According to the SPIR-V spec:
// "There is no physical size or bit pattern defined for values with boolean
// type. If they are stored (in conjunction with OpVariable), they can only
// be used with logical addressing operations, not physical, and only with
// non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup,
// Private, Function, Input, and Output."
if (bitWidth == 1 ) {
if (bitWidth == 1 )
return llvm::None;
}
return bitWidth / 8 ;
}
if (auto vecType = t .dyn_cast <VectorType>()) {
auto elementSize = getTypeNumBytes (vecType.getElementType ());
if (auto vecType = type .dyn_cast <VectorType>()) {
auto elementSize = getTypeNumBytes (options, vecType.getElementType ());
if (!elementSize)
return llvm::None;
return vecType.getNumElements () * * elementSize;
return vecType.getNumElements () * elementSize. getValue () ;
}
if (auto memRefType = t .dyn_cast <MemRefType>()) {
if (auto memRefType = type .dyn_cast <MemRefType>()) {
// TODO: Layout should also be controlled by the ABI attributes. For now
// using the layout from MemRef.
int64_t offset;
SmallVector<int64_t , 4 > strides;
if (!memRefType.hasStaticShape () ||
failed (getStridesAndOffset (memRefType, strides, offset))) {
failed (getStridesAndOffset (memRefType, strides, offset)))
return llvm::None;
}
// To get the size of the memref object in memory, the total size is the
// max(stride * dimension-size) computed for all dimensions times the size
// of the element.
auto elementSize = getTypeNumBytes (memRefType.getElementType ());
if (!elementSize) {
auto elementSize = getTypeNumBytes (options, memRefType.getElementType ());
if (!elementSize)
return llvm::None;
}
if (memRefType.getRank () == 0 ) {
if (memRefType.getRank () == 0 )
return elementSize;
}
auto dims = memRefType.getShape ();
if (llvm::is_contained (dims, ShapedType::kDynamicSize ) ||
offset == MemRefType::getDynamicStrideOrOffset () ||
llvm::is_contained (strides, MemRefType::getDynamicStrideOrOffset ())) {
llvm::is_contained (strides, MemRefType::getDynamicStrideOrOffset ()))
return llvm::None;
}
int64_t memrefSize = -1 ;
for (auto shape : enumerate(dims)) {
for (auto shape : enumerate(dims))
memrefSize = std::max (memrefSize, shape.value () * strides[shape.index ()]);
}
return (offset + memrefSize) * elementSize.getValue ();
}
if (auto tensorType = t .dyn_cast <TensorType>()) {
if (!tensorType.hasStaticShape ()) {
if (auto tensorType = type .dyn_cast <TensorType>()) {
if (!tensorType.hasStaticShape ())
return llvm::None;
}
auto elementSize = getTypeNumBytes (tensorType.getElementType ());
if (!elementSize) {
auto elementSize = getTypeNumBytes (options, tensorType.getElementType ());
if (!elementSize)
return llvm::None;
}
int64_t size = elementSize.getValue ();
for (auto shape : tensorType.getShape ()) {
for (auto shape : tensorType.getShape ())
size *= shape;
}
return size;
}
// TODO: Add size computation for other types.
return llvm::None;
}
Optional<int64_t > SPIRVTypeConverter::getConvertedTypeNumBytes (Type t) {
return getTypeNumBytes (t);
}
// / Converts a scalar `type` to a suitable type under the given `targetEnv`.
static Optional<Type>
convertScalarType (const spirv::TargetEnv &targetEnv, spirv::ScalarType type,
Optional<spirv::StorageClass> storageClass = {}) {
static Type convertScalarType (const spirv::TargetEnv &targetEnv,
const SPIRVTypeConverter::Options &options,
spirv::ScalarType type,
Optional<spirv::StorageClass> storageClass = {}) {
// Get extension and capability requirements for the given type.
SmallVector<ArrayRef<spirv::Extension>, 1 > extensions;
SmallVector<ArrayRef<spirv::Capability>, 2 > capabilities;
Expand All
@@ -251,13 +248,9 @@ convertScalarType(const spirv::TargetEnv &targetEnv, spirv::ScalarType type,
// Otherwise we need to adjust the type, which really means adjusting the
// bitwidth given this is a scalar type.
// TODO: We are unconditionally converting the bitwidth here,
// this might be okay for non-interface types (i.e., types used in
// Private/Function storage classes), but not for interface types (i.e.,
// types used in StorageBuffer/Uniform/PushConstant/etc. storage classes).
// This is because the later actually affects the ABI contract with the
// runtime. So we may want to expose a control on SPIRVTypeConverter to fail
// conversion if we cannot change there.
if (!options.emulateNon32BitScalarTypes )
return nullptr ;
if (auto floatType = type.dyn_cast <FloatType>()) {
LLVM_DEBUG (llvm::dbgs () << type << " converted to 32-bit for SPIR-V\n " );
Expand All
@@ -271,17 +264,18 @@ convertScalarType(const spirv::TargetEnv &targetEnv, spirv::ScalarType type,
}
// / Converts a vector `type` to a suitable type under the given `targetEnv`.
static Optional<Type>
convertVectorType (const spirv::TargetEnv &targetEnv, VectorType type,
Optional<spirv::StorageClass> storageClass = {}) {
static Type convertVectorType (const spirv::TargetEnv &targetEnv,
const SPIRVTypeConverter::Options &options,
VectorType type,
Optional<spirv::StorageClass> storageClass = {}) {
if (type.getRank () == 1 && type.getNumElements () == 1 )
return type.getElementType ();
if (!spirv::CompositeType::isValid (type)) {
// TODO: Vector types with more than four elements can be translated into
// array types.
LLVM_DEBUG (llvm::dbgs () << type << " illegal: > 4-element unimplemented\n " );
return llvm::None ;
return nullptr ;
}
// Get extension and capability requirements for the given type.
Expand All
@@ -296,115 +290,120 @@ convertVectorType(const spirv::TargetEnv &targetEnv, VectorType type,
return type;
auto elementType = convertScalarType (
targetEnv, type.getElementType ().cast <spirv::ScalarType>(), storageClass);
targetEnv, options, type.getElementType ().cast <spirv::ScalarType>(),
storageClass);
if (elementType)
return VectorType::get (type.getShape (), * elementType);
return llvm::None ;
return VectorType::get (type.getShape (), elementType);
return nullptr ;
}
// / Converts a tensor `type` to a suitable type under the given `targetEnv`.
// /
// / Note that this is mainly for lowering constant tensors.In SPIR-V one can
// / Note that this is mainly for lowering constant tensors. In SPIR-V one can
// / create composite constants with OpConstantComposite to embed relative large
// / constant values and use OpCompositeExtract and OpCompositeInsert to
// / manipulate, like what we do for vectors.
static Optional<Type> convertTensorType (const spirv::TargetEnv &targetEnv,
TensorType type) {
static Type convertTensorType (const spirv::TargetEnv &targetEnv,
const SPIRVTypeConverter::Options &options,
TensorType type) {
// TODO: Handle dynamic shapes.
if (!type.hasStaticShape ()) {
LLVM_DEBUG (llvm::dbgs ()
<< type << " illegal: dynamic shape unimplemented\n " );
return llvm::None ;
return nullptr ;
}
auto scalarType = type.getElementType ().dyn_cast <spirv::ScalarType>();
if (!scalarType) {
LLVM_DEBUG (llvm::dbgs ()
<< type << " illegal: cannot convert non-scalar element type\n " );
return llvm::None ;
return nullptr ;
}
Optional<int64_t > scalarSize = getTypeNumBytes (scalarType);
Optional<int64_t > tensorSize = getTypeNumBytes (type);
Optional<int64_t > scalarSize = getTypeNumBytes (options, scalarType);
Optional<int64_t > tensorSize = getTypeNumBytes (options, type);
if (!scalarSize || !tensorSize) {
LLVM_DEBUG (llvm::dbgs ()
<< type << " illegal: cannot deduce element count\n " );
return llvm::None ;
return nullptr ;
}
auto arrayElemCount = *tensorSize / *scalarSize;
auto arrayElemType = convertScalarType (targetEnv, scalarType);
auto arrayElemType = convertScalarType (targetEnv, options, scalarType);
if (!arrayElemType)
return llvm::None ;
Optional<int64_t > arrayElemSize = getTypeNumBytes (* arrayElemType);
return nullptr ;
Optional<int64_t > arrayElemSize = getTypeNumBytes (options, arrayElemType);
if (!arrayElemSize) {
LLVM_DEBUG (llvm::dbgs ()
<< type << " illegal: cannot deduce converted element size\n " );
return llvm::None ;
return nullptr ;
}
return spirv::ArrayType::get (* arrayElemType, arrayElemCount, *arrayElemSize);
return spirv::ArrayType::get (arrayElemType, arrayElemCount, *arrayElemSize);
}
static Optional<Type> convertMemrefType (const spirv::TargetEnv &targetEnv,
MemRefType type) {
static Type convertMemrefType (const spirv::TargetEnv &targetEnv,
const SPIRVTypeConverter::Options &options,
MemRefType type) {
Optional<spirv::StorageClass> storageClass =
SPIRVTypeConverter::getStorageClassForMemorySpace (
type.getMemorySpaceAsInt ());
if (!storageClass) {
LLVM_DEBUG (llvm::dbgs ()
<< type << " illegal: cannot convert memory space\n " );
return llvm::None ;
return nullptr ;
}
Optional< Type> arrayElemType;
Type arrayElemType;
Type elementType = type.getElementType ();
if (auto vecType = elementType.dyn_cast <VectorType>()) {
arrayElemType = convertVectorType (targetEnv, vecType, storageClass);
arrayElemType =
convertVectorType (targetEnv, options, vecType, storageClass);
} else if (auto scalarType = elementType.dyn_cast <spirv::ScalarType>()) {
arrayElemType = convertScalarType (targetEnv, scalarType, storageClass);
arrayElemType =
convertScalarType (targetEnv, options, scalarType, storageClass);
} else {
LLVM_DEBUG (
llvm::dbgs ()
<< type
<< " unhandled: can only convert scalar or vector element type\n " );
return llvm::None ;
return nullptr ;
}
if (!arrayElemType)
return llvm::None ;
return nullptr ;
Optional<int64_t > elementSize = getTypeNumBytes (elementType);
Optional<int64_t > elementSize = getTypeNumBytes (options, elementType);
if (!elementSize) {
LLVM_DEBUG (llvm::dbgs ()
<< type << " illegal: cannot deduce element size\n " );
return llvm::None ;
return nullptr ;
}
if (!type.hasStaticShape ()) {
auto arrayType = spirv::RuntimeArrayType::get (* arrayElemType, *elementSize);
auto arrayType = spirv::RuntimeArrayType::get (arrayElemType, *elementSize);
// Wrap in a struct to satisfy Vulkan interface requirements.
auto structType = spirv::StructType::get (arrayType, 0 );
return spirv::PointerType::get (structType, *storageClass);
}
Optional<int64_t > memrefSize = getTypeNumBytes (type);
Optional<int64_t > memrefSize = getTypeNumBytes (options, type);
if (!memrefSize) {
LLVM_DEBUG (llvm::dbgs ()
<< type << " illegal: cannot deduce element count\n " );
return llvm::None ;
return nullptr ;
}
auto arrayElemCount = *memrefSize / *elementSize;
Optional<int64_t > arrayElemSize = getTypeNumBytes (* arrayElemType);
Optional<int64_t > arrayElemSize = getTypeNumBytes (options, arrayElemType);
if (!arrayElemSize) {
LLVM_DEBUG (llvm::dbgs ()
<< type << " illegal: cannot deduce converted element size\n " );
return llvm::None ;
return nullptr ;
}
auto arrayType =
spirv::ArrayType::get (* arrayElemType, arrayElemCount, *arrayElemSize);
spirv::ArrayType::get (arrayElemType, arrayElemCount, *arrayElemSize);
// Wrap in a struct to satisfy Vulkan interface requirements. Memrefs with
// workgroup storage class do not need the struct to be laid out explicitly.
Expand All
@@ -414,13 +413,11 @@ static Optional<Type> convertMemrefType(const spirv::TargetEnv &targetEnv,
return spirv::PointerType::get (structType, *storageClass);
}
SPIRVTypeConverter::SPIRVTypeConverter (spirv::TargetEnvAttr targetAttr)
: targetEnv(targetAttr) {
SPIRVTypeConverter::SPIRVTypeConverter (spirv::TargetEnvAttr targetAttr,
Options options)
: targetEnv(targetAttr), options(options) {
// Add conversions. The order matters here: later ones will be tried earlier.
// All other cases failed. Then we cannot convert this type.
addConversion ([](Type type) { return llvm::None; });
// Allow all SPIR-V dialect specific types. This assumes all builtin types
// adopted in the SPIR-V dialect (i.e., IntegerType, FloatType, VectorType)
// were tried before.
Expand All
@@ -437,26 +434,26 @@ SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr)
addConversion ([this ](IntegerType intType) -> Optional<Type> {
if (auto scalarType = intType.dyn_cast <spirv::ScalarType>())
return convertScalarType (targetEnv, scalarType);
return llvm::None ;
return convertScalarType (this -> targetEnv , this -> options , scalarType);
return Type () ;
});
addConversion ([this ](FloatType floatType) -> Optional<Type> {
if (auto scalarType = floatType.dyn_cast <spirv::ScalarType>())
return convertScalarType (targetEnv, scalarType);
return llvm::None ;
return convertScalarType (this -> targetEnv , this -> options , scalarType);
return Type () ;
});
addConversion ([this ](VectorType vectorType) {
return convertVectorType (targetEnv, vectorType);
return convertVectorType (this -> targetEnv , this -> options , vectorType);
});
addConversion ([this ](TensorType tensorType) {
return convertTensorType (targetEnv, tensorType);
return convertTensorType (this -> targetEnv , this -> options , tensorType);
});
addConversion ([this ](MemRefType memRefType) {
return convertMemrefType (targetEnv, memRefType);
return convertMemrefType (this -> targetEnv , this -> options , memRefType);
});
}
Expand Down
Expand Up
@@ -493,8 +490,11 @@ FuncOpConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
}
Type resultType;
if (fnType.getNumResults () == 1 )
if (fnType.getNumResults () == 1 ) {
resultType = getTypeConverter ()->convertType (fnType.getResult (0 ));
if (!resultType)
return failure ();
}
// Create the converted spv.func op.
auto newFuncOp = rewriter.create <spirv::FuncOp>(
Expand Down