186 changes: 93 additions & 93 deletions mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -155,89 +155,86 @@ SPIRVTypeConverter::getStorageClassForMemorySpace(unsigned space) {

#undef STORAGE_SPACE_MAP_LIST

// TODO: This is a utility function that should probably be
// exposed by the SPIR-V dialect. Keeping it local till the use case arises.
static Optional<int64_t> getTypeNumBytes(Type t) {
if (t.isa<spirv::ScalarType>()) {
auto bitWidth = t.getIntOrFloatBitWidth();
// TODO: This is a utility function that should probably be exposed by the
// SPIR-V dialect. Keeping it local till the use case arises.
static Optional<int64_t>
getTypeNumBytes(const SPIRVTypeConverter::Options &options, Type type) {
if (type.isa<spirv::ScalarType>()) {
auto bitWidth = type.getIntOrFloatBitWidth();
// According to the SPIR-V spec:
// "There is no physical size or bit pattern defined for values with boolean
// type. If they are stored (in conjunction with OpVariable), they can only
// be used with logical addressing operations, not physical, and only with
// non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup,
// Private, Function, Input, and Output."
if (bitWidth == 1) {
if (bitWidth == 1)
return llvm::None;
}
return bitWidth / 8;
}

if (auto vecType = t.dyn_cast<VectorType>()) {
auto elementSize = getTypeNumBytes(vecType.getElementType());
if (auto vecType = type.dyn_cast<VectorType>()) {
auto elementSize = getTypeNumBytes(options, vecType.getElementType());
if (!elementSize)
return llvm::None;
return vecType.getNumElements() * *elementSize;
return vecType.getNumElements() * elementSize.getValue();
}

if (auto memRefType = t.dyn_cast<MemRefType>()) {
if (auto memRefType = type.dyn_cast<MemRefType>()) {
// TODO: Layout should also be controlled by the ABI attributes. For now
// using the layout from MemRef.
int64_t offset;
SmallVector<int64_t, 4> strides;
if (!memRefType.hasStaticShape() ||
failed(getStridesAndOffset(memRefType, strides, offset))) {
failed(getStridesAndOffset(memRefType, strides, offset)))
return llvm::None;
}

// To get the size of the memref object in memory, the total size is the
// max(stride * dimension-size) computed for all dimensions times the size
// of the element.
auto elementSize = getTypeNumBytes(memRefType.getElementType());
if (!elementSize) {
auto elementSize = getTypeNumBytes(options, memRefType.getElementType());
if (!elementSize)
return llvm::None;
}
if (memRefType.getRank() == 0) {

if (memRefType.getRank() == 0)
return elementSize;
}

auto dims = memRefType.getShape();
if (llvm::is_contained(dims, ShapedType::kDynamicSize) ||
offset == MemRefType::getDynamicStrideOrOffset() ||
llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset())) {
llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset()))
return llvm::None;
}

int64_t memrefSize = -1;
for (auto shape : enumerate(dims)) {
for (auto shape : enumerate(dims))
memrefSize = std::max(memrefSize, shape.value() * strides[shape.index()]);
}

return (offset + memrefSize) * elementSize.getValue();
}

if (auto tensorType = t.dyn_cast<TensorType>()) {
if (!tensorType.hasStaticShape()) {
if (auto tensorType = type.dyn_cast<TensorType>()) {
if (!tensorType.hasStaticShape())
return llvm::None;
}
auto elementSize = getTypeNumBytes(tensorType.getElementType());
if (!elementSize) {

auto elementSize = getTypeNumBytes(options, tensorType.getElementType());
if (!elementSize)
return llvm::None;
}

int64_t size = elementSize.getValue();
for (auto shape : tensorType.getShape()) {
for (auto shape : tensorType.getShape())
size *= shape;
}

return size;
}

// TODO: Add size computation for other types.
return llvm::None;
}
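
To make the sizing rule above concrete, here is a small standalone sketch (not part of the patch; the function name and container types are illustrative) that mirrors the same arithmetic for a statically shaped, statically strided memref: the byte size is (offset + max over all dimensions of dim * stride) times the element size.

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Illustrative mirror of the static-memref branch of getTypeNumBytes.
int64_t memrefNumBytes(const std::vector<int64_t> &dims,
                       const std::vector<int64_t> &strides, int64_t offset,
                       int64_t elementSize) {
  int64_t maxExtent = -1;
  for (size_t i = 0, e = dims.size(); i < e; ++i)
    maxExtent = std::max(maxExtent, dims[i] * strides[i]);
  return (offset + maxExtent) * elementSize;
}

// Example: memref<4x8xf32> with the identity layout has strides {8, 1} and
// offset 0, so the size is (0 + max(4*8, 8*1)) * 4 = 128 bytes.
```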

Optional<int64_t> SPIRVTypeConverter::getConvertedTypeNumBytes(Type t) {
return getTypeNumBytes(t);
}

/// Converts a scalar `type` to a suitable type under the given `targetEnv`.
static Optional<Type>
convertScalarType(const spirv::TargetEnv &targetEnv, spirv::ScalarType type,
Optional<spirv::StorageClass> storageClass = {}) {
static Type convertScalarType(const spirv::TargetEnv &targetEnv,
const SPIRVTypeConverter::Options &options,
spirv::ScalarType type,
Optional<spirv::StorageClass> storageClass = {}) {
// Get extension and capability requirements for the given type.
SmallVector<ArrayRef<spirv::Extension>, 1> extensions;
SmallVector<ArrayRef<spirv::Capability>, 2> capabilities;
@@ -251,13 +248,9 @@ convertScalarType(const spirv::TargetEnv &targetEnv, spirv::ScalarType type,

// Otherwise we need to adjust the type, which really means adjusting the
// bitwidth given this is a scalar type.
// TODO: We are unconditionally converting the bitwidth here,
// this might be okay for non-interface types (i.e., types used in
// Private/Function storage classes), but not for interface types (i.e.,
// types used in StorageBuffer/Uniform/PushConstant/etc. storage classes).
// This is because the later actually affects the ABI contract with the
// runtime. So we may want to expose a control on SPIRVTypeConverter to fail
// conversion if we cannot change there.

if (!options.emulateNon32BitScalarTypes)
return nullptr;

if (auto floatType = type.dyn_cast<FloatType>()) {
LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
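
The `options.emulateNon32BitScalarTypes` check above is the new knob introduced by this change. Its declaration is not part of this hunk; a plausible shape, inferred from the uses in this file and from the tests (widening happens unless the pass option `emulate-non-32-bit-scalar-types=false` is given), would be:

```cpp
// Sketch only: assumed nested struct on SPIRVTypeConverter. The field name
// comes from this patch; the default value is inferred from the tests.
struct Options {
  /// Emulate narrower/wider scalar types (i8, i16, i64, f16, f64) by
  /// converting them to 32-bit equivalents when the target environment lacks
  /// the corresponding capability. When false, such types fail to convert
  /// instead of being widened.
  bool emulateNon32BitScalarTypes = true;
};
```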
@@ -271,17 +264,18 @@ convertScalarType(const spirv::TargetEnv &targetEnv, spirv::ScalarType type,
}

/// Converts a vector `type` to a suitable type under the given `targetEnv`.
static Optional<Type>
convertVectorType(const spirv::TargetEnv &targetEnv, VectorType type,
Optional<spirv::StorageClass> storageClass = {}) {
static Type convertVectorType(const spirv::TargetEnv &targetEnv,
const SPIRVTypeConverter::Options &options,
VectorType type,
Optional<spirv::StorageClass> storageClass = {}) {
if (type.getRank() == 1 && type.getNumElements() == 1)
return type.getElementType();

if (!spirv::CompositeType::isValid(type)) {
// TODO: Vector types with more than four elements can be translated into
// array types.
LLVM_DEBUG(llvm::dbgs() << type << " illegal: > 4-element unimplemented\n");
return llvm::None;
return nullptr;
}

// Get extension and capability requirements for the given type.
@@ -296,115 +290,120 @@ convertVectorType(const spirv::TargetEnv &targetEnv, VectorType type,
return type;

auto elementType = convertScalarType(
targetEnv, type.getElementType().cast<spirv::ScalarType>(), storageClass);
targetEnv, options, type.getElementType().cast<spirv::ScalarType>(),
storageClass);
if (elementType)
return VectorType::get(type.getShape(), *elementType);
return llvm::None;
return VectorType::get(type.getShape(), elementType);
return nullptr;
}

/// Converts a tensor `type` to a suitable type under the given `targetEnv`.
///
/// Note that this is mainly for lowering constant tensors.In SPIR-V one can
/// Note that this is mainly for lowering constant tensors. In SPIR-V one can
/// create composite constants with OpConstantComposite to embed relatively large
/// constant values and use OpCompositeExtract and OpCompositeInsert to
/// manipulate, like what we do for vectors.
static Optional<Type> convertTensorType(const spirv::TargetEnv &targetEnv,
TensorType type) {
static Type convertTensorType(const spirv::TargetEnv &targetEnv,
const SPIRVTypeConverter::Options &options,
TensorType type) {
// TODO: Handle dynamic shapes.
if (!type.hasStaticShape()) {
LLVM_DEBUG(llvm::dbgs()
<< type << " illegal: dynamic shape unimplemented\n");
return llvm::None;
return nullptr;
}

auto scalarType = type.getElementType().dyn_cast<spirv::ScalarType>();
if (!scalarType) {
LLVM_DEBUG(llvm::dbgs()
<< type << " illegal: cannot convert non-scalar element type\n");
return llvm::None;
return nullptr;
}

Optional<int64_t> scalarSize = getTypeNumBytes(scalarType);
Optional<int64_t> tensorSize = getTypeNumBytes(type);
Optional<int64_t> scalarSize = getTypeNumBytes(options, scalarType);
Optional<int64_t> tensorSize = getTypeNumBytes(options, type);
if (!scalarSize || !tensorSize) {
LLVM_DEBUG(llvm::dbgs()
<< type << " illegal: cannot deduce element count\n");
return llvm::None;
return nullptr;
}

auto arrayElemCount = *tensorSize / *scalarSize;
auto arrayElemType = convertScalarType(targetEnv, scalarType);
auto arrayElemType = convertScalarType(targetEnv, options, scalarType);
if (!arrayElemType)
return llvm::None;
Optional<int64_t> arrayElemSize = getTypeNumBytes(*arrayElemType);
return nullptr;
Optional<int64_t> arrayElemSize = getTypeNumBytes(options, arrayElemType);
if (!arrayElemSize) {
LLVM_DEBUG(llvm::dbgs()
<< type << " illegal: cannot deduce converted element size\n");
return llvm::None;
return nullptr;
}

return spirv::ArrayType::get(*arrayElemType, arrayElemCount, *arrayElemSize);
return spirv::ArrayType::get(arrayElemType, arrayElemCount, *arrayElemSize);
}
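
As a concrete illustration of the tensor path (derived from the byte-count arithmetic above rather than from a test in this patch): a static tensor whose element type converts cleanly becomes a fixed-size array whose element count is the total byte size divided by the element byte size.

```mlir
// Illustrative mapping only, assuming a 32-bit-capable target:
//   tensor<2x3xi32>  --(24 bytes / 4-byte elements)-->  !spv.array<6 x i32, stride=4>
```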

static Optional<Type> convertMemrefType(const spirv::TargetEnv &targetEnv,
MemRefType type) {
static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
const SPIRVTypeConverter::Options &options,
MemRefType type) {
Optional<spirv::StorageClass> storageClass =
SPIRVTypeConverter::getStorageClassForMemorySpace(
type.getMemorySpaceAsInt());
if (!storageClass) {
LLVM_DEBUG(llvm::dbgs()
<< type << " illegal: cannot convert memory space\n");
return llvm::None;
return nullptr;
}

Optional<Type> arrayElemType;
Type arrayElemType;
Type elementType = type.getElementType();
if (auto vecType = elementType.dyn_cast<VectorType>()) {
arrayElemType = convertVectorType(targetEnv, vecType, storageClass);
arrayElemType =
convertVectorType(targetEnv, options, vecType, storageClass);
} else if (auto scalarType = elementType.dyn_cast<spirv::ScalarType>()) {
arrayElemType = convertScalarType(targetEnv, scalarType, storageClass);
arrayElemType =
convertScalarType(targetEnv, options, scalarType, storageClass);
} else {
LLVM_DEBUG(
llvm::dbgs()
<< type
<< " unhandled: can only convert scalar or vector element type\n");
return llvm::None;
return nullptr;
}
if (!arrayElemType)
return llvm::None;
return nullptr;

Optional<int64_t> elementSize = getTypeNumBytes(elementType);
Optional<int64_t> elementSize = getTypeNumBytes(options, elementType);
if (!elementSize) {
LLVM_DEBUG(llvm::dbgs()
<< type << " illegal: cannot deduce element size\n");
return llvm::None;
return nullptr;
}

if (!type.hasStaticShape()) {
auto arrayType = spirv::RuntimeArrayType::get(*arrayElemType, *elementSize);
auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, *elementSize);
// Wrap in a struct to satisfy Vulkan interface requirements.
auto structType = spirv::StructType::get(arrayType, 0);
return spirv::PointerType::get(structType, *storageClass);
}

Optional<int64_t> memrefSize = getTypeNumBytes(type);
Optional<int64_t> memrefSize = getTypeNumBytes(options, type);
if (!memrefSize) {
LLVM_DEBUG(llvm::dbgs()
<< type << " illegal: cannot deduce element count\n");
return llvm::None;
return nullptr;
}

auto arrayElemCount = *memrefSize / *elementSize;

Optional<int64_t> arrayElemSize = getTypeNumBytes(*arrayElemType);
Optional<int64_t> arrayElemSize = getTypeNumBytes(options, arrayElemType);
if (!arrayElemSize) {
LLVM_DEBUG(llvm::dbgs()
<< type << " illegal: cannot deduce converted element size\n");
return llvm::None;
return nullptr;
}

auto arrayType =
spirv::ArrayType::get(*arrayElemType, arrayElemCount, *arrayElemSize);
spirv::ArrayType::get(arrayElemType, arrayElemCount, *arrayElemSize);

// Wrap in a struct to satisfy Vulkan interface requirements. Memrefs with
// workgroup storage class do not need the struct to be laid out explicitly.
@@ -414,13 +413,11 @@ static Optional<Type> convertMemrefType(const spirv::TargetEnv &targetEnv,
return spirv::PointerType::get(structType, *storageClass);
}
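
The net effect of the memref path shows up in the tests added below. For example, with emulation enabled (the default) and no Int8 capability, a 16-element i8 memref in the default memory space becomes a pointer to a struct-wrapped i32 array, whereas a capable target keeps the element type and stride; dynamically shaped memrefs take the earlier branch and use a runtime array inside the same struct/pointer wrapping.

```mlir
// Taken from the tests in this patch:
//   memref<16xi8>     -->  !spv.ptr<!spv.struct<(!spv.array<16 x i32, stride=4> [0])>, StorageBuffer>  (no Int8, emulated)
//   memref<16xi8, 7>  -->  !spv.ptr<!spv.struct<(!spv.array<16 x i8, stride=1> [0])>, PushConstant>    (Int8 available)
```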

SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr)
: targetEnv(targetAttr) {
SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
Options options)
: targetEnv(targetAttr), options(options) {
// Add conversions. The order matters here: later ones will be tried earlier.

// All other cases failed. Then we cannot convert this type.
addConversion([](Type type) { return llvm::None; });

// Allow all SPIR-V dialect specific types. This assumes all builtin types
// adopted in the SPIR-V dialect (i.e., IntegerType, FloatType, VectorType)
// were tried before.
@@ -437,26 +434,26 @@ SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr)

addConversion([this](IntegerType intType) -> Optional<Type> {
if (auto scalarType = intType.dyn_cast<spirv::ScalarType>())
return convertScalarType(targetEnv, scalarType);
return llvm::None;
return convertScalarType(this->targetEnv, this->options, scalarType);
return Type();
});

addConversion([this](FloatType floatType) -> Optional<Type> {
if (auto scalarType = floatType.dyn_cast<spirv::ScalarType>())
return convertScalarType(targetEnv, scalarType);
return llvm::None;
return convertScalarType(this->targetEnv, this->options, scalarType);
return Type();
});

addConversion([this](VectorType vectorType) {
return convertVectorType(targetEnv, vectorType);
return convertVectorType(this->targetEnv, this->options, vectorType);
});

addConversion([this](TensorType tensorType) {
return convertTensorType(targetEnv, tensorType);
return convertTensorType(this->targetEnv, this->options, tensorType);
});

addConversion([this](MemRefType memRefType) {
return convertMemrefType(targetEnv, memRefType);
return convertMemrefType(this->targetEnv, this->options, memRefType);
});
}
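
A minimal usage sketch of the new constructor (not code from this patch; assumes a caller that already has a `spirv::TargetEnvAttr` and the usual MLIR SPIR-V headers):

```cpp
// Sketch: construct the converter with emulation disabled so unsupported
// narrow/wide scalar types fail to convert instead of silently widening.
static void buildConverter(spirv::TargetEnvAttr targetAttr) { // hypothetical helper
  SPIRVTypeConverter::Options options;
  options.emulateNon32BitScalarTypes = false;
  SPIRVTypeConverter typeConverter(targetAttr, options);
  // ... hand typeConverter to the std-to-SPIR-V patterns / ConversionTarget ...
}
```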

@@ -493,8 +490,11 @@ FuncOpConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
}

Type resultType;
if (fnType.getNumResults() == 1)
if (fnType.getNumResults() == 1) {
resultType = getTypeConverter()->convertType(fnType.getResult(0));
if (!resultType)
return failure();
}

// Create the converted spv.func op.
auto newFuncOp = rewriter.create<spirv::FuncOp>(
68 changes: 68 additions & 0 deletions mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir
@@ -1,4 +1,5 @@
// RUN: mlir-opt -split-input-file -convert-std-to-spirv %s -o - | FileCheck %s
// RUN: mlir-opt -split-input-file -convert-std-to-spirv="emulate-non-32-bit-scalar-types=false" %s -o - | FileCheck %s --check-prefix=NOEMU

//===----------------------------------------------------------------------===//
// Integer types
@@ -14,18 +15,30 @@ module attributes {
// CHECK-SAME: i32
// CHECK-SAME: si32
// CHECK-SAME: ui32
// NOEMU-LABEL: func @integer8
// NOEMU-SAME: i8
// NOEMU-SAME: si8
// NOEMU-SAME: ui8
func @integer8(%arg0: i8, %arg1: si8, %arg2: ui8) { return }

// CHECK-LABEL: spv.func @integer16
// CHECK-SAME: i32
// CHECK-SAME: si32
// CHECK-SAME: ui32
// NOEMU-LABEL: func @integer16
// NOEMU-SAME: i16
// NOEMU-SAME: si16
// NOEMU-SAME: ui16
func @integer16(%arg0: i16, %arg1: si16, %arg2: ui16) { return }

// CHECK-LABEL: spv.func @integer64
// CHECK-SAME: i32
// CHECK-SAME: si32
// CHECK-SAME: ui32
// NOEMU-LABEL: func @integer64
// NOEMU-SAME: i64
// NOEMU-SAME: si64
// NOEMU-SAME: ui64
func @integer64(%arg0: i64, %arg1: si64, %arg2: ui64) { return }

} // end module
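
The NOEMU prefix checks the no-emulation configuration: this module's target environment provides no Int8/Int16/Int64 capabilities, so with `emulate-non-32-bit-scalar-types=false` the type conversion fails and the functions are left as unconverted `func` ops (note NOEMU-LABEL matches `func`, not `spv.func`). Roughly:

```mlir
// Default (emulation on): the signature is widened during conversion.
//   func @integer8(%arg0: i8)  ==>  spv.func @integer8(%arg0: i32)
// emulate-non-32-bit-scalar-types=false: conversion fails; the op stays as-is.
//   func @integer8(%arg0: i8)  ==>  func @integer8(%arg0: i8)
```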
@@ -42,18 +55,30 @@ module attributes {
// CHECK-SAME: i8
// CHECK-SAME: si8
// CHECK-SAME: ui8
// NOEMU-LABEL: spv.func @integer8
// NOEMU-SAME: i8
// NOEMU-SAME: si8
// NOEMU-SAME: ui8
func @integer8(%arg0: i8, %arg1: si8, %arg2: ui8) { return }

// CHECK-LABEL: spv.func @integer16
// CHECK-SAME: i16
// CHECK-SAME: si16
// CHECK-SAME: ui16
// NOEMU-LABEL: spv.func @integer16
// NOEMU-SAME: i16
// NOEMU-SAME: si16
// NOEMU-SAME: ui16
func @integer16(%arg0: i16, %arg1: si16, %arg2: ui16) { return }

// CHECK-LABEL: spv.func @integer64
// CHECK-SAME: i64
// CHECK-SAME: si64
// CHECK-SAME: ui64
// NOEMU-LABEL: spv.func @integer64
// NOEMU-SAME: i64
// NOEMU-SAME: si64
// NOEMU-SAME: ui64
func @integer64(%arg0: i64, %arg1: si64, %arg2: ui64) { return }

} // end module
@@ -106,10 +131,14 @@ module attributes {

// CHECK-LABEL: spv.func @float16
// CHECK-SAME: f32
// NOEMU-LABEL: func @float16
// NOEMU-SAME: f16
func @float16(%arg0: f16) { return }

// CHECK-LABEL: spv.func @float64
// CHECK-SAME: f32
// NOEMU-LABEL: func @float64
// NOEMU-SAME: f64
func @float64(%arg0: f64) { return }

} // end module
@@ -124,10 +153,14 @@ module attributes {

// CHECK-LABEL: spv.func @float16
// CHECK-SAME: f16
// NOEMU-LABEL: spv.func @float16
// NOEMU-SAME: f16
func @float16(%arg0: f16) { return }

// CHECK-LABEL: spv.func @float64
// CHECK-SAME: f64
// NOEMU-LABEL: spv.func @float64
// NOEMU-SAME: f64
func @float64(%arg0: f64) { return }

} // end module
@@ -276,34 +309,50 @@ module attributes {

// CHECK-LABEL: spv.func @memref_8bit_StorageBuffer
// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i32, stride=4> [0])>, StorageBuffer>
// NOEMU-LABEL: func @memref_8bit_StorageBuffer
// NOEMU-SAME: memref<16xi8>
func @memref_8bit_StorageBuffer(%arg0: memref<16xi8, 0>) { return }

// CHECK-LABEL: spv.func @memref_8bit_Uniform
// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x si32, stride=4> [0])>, Uniform>
// NOEMU-LABEL: func @memref_8bit_Uniform
// NOEMU-SAME: memref<16xsi8, 4>
func @memref_8bit_Uniform(%arg0: memref<16xsi8, 4>) { return }

// CHECK-LABEL: spv.func @memref_8bit_PushConstant
// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x ui32, stride=4> [0])>, PushConstant>
// NOEMU-LABEL: func @memref_8bit_PushConstant
// NOEMU-SAME: memref<16xui8, 7>
func @memref_8bit_PushConstant(%arg0: memref<16xui8, 7>) { return }

// CHECK-LABEL: spv.func @memref_16bit_StorageBuffer
// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i32, stride=4> [0])>, StorageBuffer>
// NOEMU-LABEL: func @memref_16bit_StorageBuffer
// NOEMU-SAME: memref<16xi16>
func @memref_16bit_StorageBuffer(%arg0: memref<16xi16, 0>) { return }

// CHECK-LABEL: spv.func @memref_16bit_Uniform
// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x si32, stride=4> [0])>, Uniform>
// NOEMU-LABEL: func @memref_16bit_Uniform
// NOEMU-SAME: memref<16xsi16, 4>
func @memref_16bit_Uniform(%arg0: memref<16xsi16, 4>) { return }

// CHECK-LABEL: spv.func @memref_16bit_PushConstant
// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x ui32, stride=4> [0])>, PushConstant>
// NOEMU-LABEL: func @memref_16bit_PushConstant
// NOEMU-SAME: memref<16xui16, 7>
func @memref_16bit_PushConstant(%arg0: memref<16xui16, 7>) { return }

// CHECK-LABEL: spv.func @memref_16bit_Input
// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f32, stride=4> [0])>, Input>
// NOEMU-LABEL: func @memref_16bit_Input
// NOEMU-SAME: memref<16xf16, 9>
func @memref_16bit_Input(%arg3: memref<16xf16, 9>) { return }

// CHECK-LABEL: spv.func @memref_16bit_Output
// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f32, stride=4> [0])>, Output>
// NOEMU-LABEL: func @memref_16bit_Output
// NOEMU-SAME: memref<16xf16, 10>
func @memref_16bit_Output(%arg4: memref<16xf16, 10>) { return }

} // end module
@@ -321,11 +370,16 @@ module attributes {

// CHECK-LABEL: spv.func @memref_8bit_PushConstant
// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i8, stride=1> [0])>, PushConstant>
// NOEMU-LABEL: spv.func @memref_8bit_PushConstant
// NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i8, stride=1> [0])>, PushConstant>
func @memref_8bit_PushConstant(%arg0: memref<16xi8, 7>) { return }

// CHECK-LABEL: spv.func @memref_16bit_PushConstant
// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i16, stride=2> [0])>, PushConstant>
// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f16, stride=2> [0])>, PushConstant>
// NOEMU-LABEL: spv.func @memref_16bit_PushConstant
// NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i16, stride=2> [0])>, PushConstant>
// NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f16, stride=2> [0])>, PushConstant>
func @memref_16bit_PushConstant(
%arg0: memref<16xi16, 7>,
%arg1: memref<16xf16, 7>
@@ -346,11 +400,16 @@ module attributes {

// CHECK-LABEL: spv.func @memref_8bit_StorageBuffer
// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i8, stride=1> [0])>, StorageBuffer>
// NOEMU-LABEL: spv.func @memref_8bit_StorageBuffer
// NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i8, stride=1> [0])>, StorageBuffer>
func @memref_8bit_StorageBuffer(%arg0: memref<16xi8, 0>) { return }

// CHECK-LABEL: spv.func @memref_16bit_StorageBuffer
// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i16, stride=2> [0])>, StorageBuffer>
// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f16, stride=2> [0])>, StorageBuffer>
// NOEMU-LABEL: spv.func @memref_16bit_StorageBuffer
// NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i16, stride=2> [0])>, StorageBuffer>
// NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f16, stride=2> [0])>, StorageBuffer>
func @memref_16bit_StorageBuffer(
%arg0: memref<16xi16, 0>,
%arg1: memref<16xf16, 0>
@@ -371,11 +430,16 @@ module attributes {

// CHECK-LABEL: spv.func @memref_8bit_Uniform
// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i8, stride=1> [0])>, Uniform>
// NOEMU-LABEL: spv.func @memref_8bit_Uniform
// NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i8, stride=1> [0])>, Uniform>
func @memref_8bit_Uniform(%arg0: memref<16xi8, 4>) { return }

// CHECK-LABEL: spv.func @memref_16bit_Uniform
// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i16, stride=2> [0])>, Uniform>
// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f16, stride=2> [0])>, Uniform>
// NOEMU-LABEL: spv.func @memref_16bit_Uniform
// NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i16, stride=2> [0])>, Uniform>
// NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f16, stride=2> [0])>, Uniform>
func @memref_16bit_Uniform(
%arg0: memref<16xi16, 4>,
%arg1: memref<16xf16, 4>
@@ -395,10 +459,14 @@ module attributes {

// CHECK-LABEL: spv.func @memref_16bit_Input
// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f16, stride=2> [0])>, Input>
// NOEMU-LABEL: spv.func @memref_16bit_Input
// NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f16, stride=2> [0])>, Input>
func @memref_16bit_Input(%arg3: memref<16xf16, 9>) { return }

// CHECK-LABEL: spv.func @memref_16bit_Output
// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i16, stride=2> [0])>, Output>
// NOEMU-LABEL: spv.func @memref_16bit_Output
// NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i16, stride=2> [0])>, Output>
func @memref_16bit_Output(%arg4: memref<16xi16, 10>) { return }

} // end module
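
For reference, the two configurations exercised by this file correspond to the RUN lines at the top; a standalone invocation mirroring the no-emulation run would look like this (file name illustrative):

```shell
mlir-opt -split-input-file \
  -convert-std-to-spirv="emulate-non-32-bit-scalar-types=false" \
  std-types-to-spirv.mlir
```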