diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp index 81116cf6f13ad..565dee6f27589 100644 --- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp +++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp @@ -163,6 +163,27 @@ static std::optional getAtomicOpScope(MemRefType type) { return {}; } +/// Extracts the element type from a SPIR-V pointer type pointing to storage. +/// +/// For Kernel capability, the pointer points directly to the element type +/// (possibly wrapped in an array). For Vulkan, the pointer points to a struct +/// containing an array or runtime array, and we need to unwrap to get the +/// element type. +static Type +getElementTypeForStoragePointer(Type pointeeType, + const SPIRVTypeConverter &typeConverter) { + if (typeConverter.allows(spirv::Capability::Kernel)) { + if (auto arrayType = dyn_cast(pointeeType)) + return arrayType.getElementType(); + return pointeeType; + } + // For Vulkan we need to extract element from wrapping struct and array. + Type structElemType = cast(pointeeType).getElementType(0); + if (auto arrayType = dyn_cast(structElemType)) + return arrayType.getElementType(); + return cast(structElemType).getElementType(); +} + /// Casts the given `srcInt` into a boolean value. static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder) { if (srcInt.getType().isInteger(1)) @@ -437,22 +458,8 @@ AtomicRMWOpPattern::matchAndRewrite(memref::AtomicRMWOp atomicOp, "failed to convert memref type"); Type pointeeType = pointerType.getPointeeType(); - IntegerType dstType; - if (typeConverter.allows(spirv::Capability::Kernel)) { - if (auto arrayType = dyn_cast(pointeeType)) - dstType = dyn_cast(arrayType.getElementType()); - else - dstType = dyn_cast(pointeeType); - } else { - Type structElemType = - cast(pointeeType).getElementType(0); - if (auto arrayType = dyn_cast(structElemType)) - dstType = dyn_cast(arrayType.getElementType()); - else - dstType = dyn_cast( - cast(structElemType).getElementType()); - } - + auto dstType = dyn_cast( + getElementTypeForStoragePointer(pointeeType, typeConverter)); if (!dstType) return rewriter.notifyMatchFailure( atomicOp, "failed to determine destination element type"); @@ -691,21 +698,7 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor, return rewriter.notifyMatchFailure(loadOp, "failed to convert memref type"); Type pointeeType = pointerType.getPointeeType(); - Type dstType; - if (typeConverter.allows(spirv::Capability::Kernel)) { - if (auto arrayType = dyn_cast(pointeeType)) - dstType = arrayType.getElementType(); - else - dstType = pointeeType; - } else { - // For Vulkan we need to extract element from wrapping struct and array. - Type structElemType = - cast(pointeeType).getElementType(0); - if (auto arrayType = dyn_cast(structElemType)) - dstType = arrayType.getElementType(); - else - dstType = cast(structElemType).getElementType(); - } + Type dstType = getElementTypeForStoragePointer(pointeeType, typeConverter); int dstBits = dstType.getIntOrFloatBitWidth(); assert(dstBits % srcBits == 0); @@ -963,23 +956,8 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor, "failed to convert memref type"); Type pointeeType = pointerType.getPointeeType(); - IntegerType dstType; - if (typeConverter.allows(spirv::Capability::Kernel)) { - if (auto arrayType = dyn_cast(pointeeType)) - dstType = dyn_cast(arrayType.getElementType()); - else - dstType = dyn_cast(pointeeType); - } else { - // For Vulkan we need to extract element from wrapping struct and array. - Type structElemType = - cast(pointeeType).getElementType(0); - if (auto arrayType = dyn_cast(structElemType)) - dstType = dyn_cast(arrayType.getElementType()); - else - dstType = dyn_cast( - cast(structElemType).getElementType()); - } - + auto dstType = dyn_cast( + getElementTypeForStoragePointer(pointeeType, typeConverter)); if (!dstType) return rewriter.notifyMatchFailure( storeOp, "failed to determine destination element type");