-
Notifications
You must be signed in to change notification settings - Fork 11.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][spirv][memref] Calculate alignment for PhysicalStorageBuffers
#80243
Conversation
The SPIR-V spec requires that memory accesses to PhysicalStorageBuffers are annotated with appropriate alignment attributes. Calculate these based on memref alignment attributes or scalar type sizes.
@llvm/pr-subscribers-mlir Author: Jakub Kuderski (kuhar) Changes: The SPIR-V spec requires that memory accesses to `PhysicalStorageBuffer`s are annotated with appropriate alignment attributes [1]. Calculate these based on memref alignment attributes or scalar type sizes. [1] Otherwise spirv-val complains:
Full diff: https://github.com/llvm/llvm-project/pull/80243.diff 2 Files Affected:
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index acddb3c4da461..c0470ce189cd3 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -12,12 +12,18 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Visitors.h"
+#include "mlir/Support/LogicalResult.h"
#include "llvm/Support/Debug.h"
+#include <cassert>
#include <optional>
#define DEBUG_TYPE "memref-to-spirv-pattern"
@@ -439,6 +445,52 @@ DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation,
// LoadOp
//===----------------------------------------------------------------------===//
+using AlignmentRequirements =
+ FailureOr<std::pair<spirv::MemoryAccessAttr, IntegerAttr>>;
+
+/// Given an accessed SPIR-V pointer, calculates its alignment requirements, if
+/// any.
+static AlignmentRequirements calculateRequiredAlignment(Value accessedPtr) {
+ auto ptrType = cast<spirv::PointerType>(accessedPtr.getType());
+ if (ptrType.getStorageClass() != spirv::StorageClass::PhysicalStorageBuffer)
+ return std::pair{spirv::MemoryAccessAttr{}, IntegerAttr{}};
+
+ // PhysicalStorageBuffers require the `Aligned` attribute.
+ auto pointeeType = dyn_cast<spirv::ScalarType>(ptrType.getPointeeType());
+ if (!pointeeType)
+ return failure();
+
+ // For scalar types, the alignment is determined by their size.
+ std::optional<int64_t> sizeInBytes = pointeeType.getSizeInBytes();
+ if (!sizeInBytes.has_value())
+ return failure();
+
+ MLIRContext *ctx = accessedPtr.getContext();
+ auto memAccessAttr =
+ spirv::MemoryAccessAttr::get(ctx, spirv::MemoryAccess::Aligned);
+ auto alignment = IntegerAttr::get(IntegerType::get(ctx, 32), *sizeInBytes);
+ return std::pair{memAccessAttr, alignment};
+}
+
+/// Given an accessed SPIR-V pointer and the original memref load/store
+/// `memAccess` op, calculates the alignment requirements, if any. Takes into
+/// account the alignment attributes applied to the load/store op.
+static AlignmentRequirements
+calculateRequiredAlignment(Value accessedPtr, Operation *memrefAccessOp) {
+ assert(memrefAccessOp);
+ assert((isa<memref::LoadOp, memref::StoreOp>(memrefAccessOp)) &&
+ "Bad op type");
+
+ auto memrefMemAccess = memrefAccessOp->getAttrOfType<spirv::MemoryAccessAttr>(
+ spirv::attributeName<spirv::MemoryAccess>());
+ auto memrefAlignment =
+ memrefAccessOp->getAttrOfType<IntegerAttr>("alignment");
+ if (memrefMemAccess && memrefAlignment)
+ return std::pair{memrefMemAccess, memrefAlignment};
+
+ return calculateRequiredAlignment(accessedPtr);
+}
+
LogicalResult
IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
@@ -486,7 +538,15 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
// If the rewritten load op has the same bit width, use the loading value
// directly.
if (srcBits == dstBits) {
- Value loadVal = rewriter.create<spirv::LoadOp>(loc, accessChain);
+ AlignmentRequirements alignmentRequirements =
+ calculateRequiredAlignment(accessChain, loadOp);
+ if (failed(alignmentRequirements))
+ return rewriter.notifyMatchFailure(
+ loadOp, "failed to determine alignment requirements");
+
+ auto [memoryAccess, alignment] = *alignmentRequirements;
+ Value loadVal = rewriter.create<spirv::LoadOp>(loc, accessChain,
+ memoryAccess, alignment);
if (isBool)
loadVal = castIntNToBool(loc, loadVal, rewriter);
rewriter.replaceOp(loadOp, loadVal);
@@ -508,11 +568,15 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
assert(accessChainOp.getIndices().size() == 2);
Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
srcBits, dstBits, rewriter);
- Value spvLoadOp = rewriter.create<spirv::LoadOp>(
- loc, dstType, adjustedPtr,
- loadOp->getAttrOfType<spirv::MemoryAccessAttr>(
- spirv::attributeName<spirv::MemoryAccess>()),
- loadOp->getAttrOfType<IntegerAttr>("alignment"));
+ AlignmentRequirements alignmentRequirements =
+ calculateRequiredAlignment(adjustedPtr, loadOp);
+ if (failed(alignmentRequirements))
+ return rewriter.notifyMatchFailure(
+ loadOp, "failed to determine alignment requirements");
+
+ auto [memoryAccess, alignment] = *alignmentRequirements;
+ Value spvLoadOp = rewriter.create<spirv::LoadOp>(loc, dstType, adjustedPtr,
+ memoryAccess, alignment);
// Shift the bits to the rightmost.
// ____XXXX________ -> ____________XXXX
@@ -552,14 +616,21 @@ LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
if (memrefType.getElementType().isSignlessInteger())
return failure();
- auto loadPtr = spirv::getElementPtr(
+ Value loadPtr = spirv::getElementPtr(
*getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
adaptor.getIndices(), loadOp.getLoc(), rewriter);
if (!loadPtr)
return failure();
- rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr);
+ AlignmentRequirements requiredAlignment = calculateRequiredAlignment(loadPtr);
+ if (failed(requiredAlignment))
+ return rewriter.notifyMatchFailure(
+ loadOp, "failed to determine alignment requirements");
+
+ auto [memAccessAttr, alignment] = *requiredAlignment;
+ rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr, memAccessAttr,
+ alignment);
return success();
}
@@ -618,10 +689,18 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
assert(dstBits % srcBits == 0);
if (srcBits == dstBits) {
+ AlignmentRequirements requiredAlignment =
+ calculateRequiredAlignment(accessChain);
+ if (failed(requiredAlignment))
+ return rewriter.notifyMatchFailure(
+ storeOp, "failed to determine alignment requirements");
+
+ auto [memAccessAttr, alignment] = *requiredAlignment;
Value storeVal = adaptor.getValue();
if (isBool)
storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter);
- rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, accessChain, storeVal);
+ rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, accessChain, storeVal,
+ memAccessAttr, alignment);
return success();
}
diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
index a8b550367d5fa..aa05fd9bc8ca8 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
@@ -1,17 +1,19 @@
-// RUN: mlir-opt -split-input-file -convert-memref-to-spirv="bool-num-bits=8" -cse %s -o - | FileCheck %s
+// RUN: mlir-opt --split-input-file --convert-memref-to-spirv="bool-num-bits=8" --cse %s | FileCheck %s
// Check that with proper compute and storage extensions, we don't need to
// perform special tricks.
module attributes {
spirv.target_env = #spirv.target_env<
- #spirv.vce<v1.0,
+ #spirv.vce<v1.5,
[
Shader, Int8, Int16, Int64, Float16, Float64,
StorageBuffer16BitAccess, StorageUniform16, StoragePushConstant16,
- StorageBuffer8BitAccess, UniformAndStorageBuffer8BitAccess, StoragePushConstant8
+ StorageBuffer8BitAccess, UniformAndStorageBuffer8BitAccess, StoragePushConstant8,
+ PhysicalStorageBufferAddresses
],
- [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
+ [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_storage_buffer_storage_class, SPV_KHR_physical_storage_buffer]>,
+ #spirv.resource_limits<>>
} {
// CHECK-LABEL: @load_store_zero_rank_float
@@ -119,6 +121,51 @@ func.func @store_i16(%arg0: memref<10xi16, #spirv.storage_class<StorageBuffer>>,
return
}
+// CHECK-LABEL: @load_store_i32_physical
+func.func @load_store_i32_physical(%arg0: memref<i32, #spirv.storage_class<PhysicalStorageBuffer>>) {
+ // CHECK: spirv.Load "PhysicalStorageBuffer" %{{.+}} ["Aligned", 4] : i32
+ // CHECK: spirv.Store "PhysicalStorageBuffer" %{{.+}}, %{{.+}} ["Aligned", 4] : i32
+ %0 = memref.load %arg0[] : memref<i32, #spirv.storage_class<PhysicalStorageBuffer>>
+ memref.store %0, %arg0[] : memref<i32, #spirv.storage_class<PhysicalStorageBuffer>>
+ return
+}
+
+// CHECK-LABEL: @load_store_i8_physical
+func.func @load_store_i8_physical(%arg0: memref<i8, #spirv.storage_class<PhysicalStorageBuffer>>) {
+ // CHECK: spirv.Load "PhysicalStorageBuffer" %{{.+}} ["Aligned", 1] : i8
+ // CHECK: spirv.Store "PhysicalStorageBuffer" %{{.+}}, %{{.+}} ["Aligned", 1] : i8
+ %0 = memref.load %arg0[] : memref<i8, #spirv.storage_class<PhysicalStorageBuffer>>
+ memref.store %0, %arg0[] : memref<i8, #spirv.storage_class<PhysicalStorageBuffer>>
+ return
+}
+
+// CHECK-LABEL: @load_store_i1_physical
+func.func @load_store_i1_physical(%arg0: memref<i1, #spirv.storage_class<PhysicalStorageBuffer>>) {
+ // CHECK: spirv.Load "PhysicalStorageBuffer" %{{.+}} ["Aligned", 1] : i8
+ // CHECK: spirv.Store "PhysicalStorageBuffer" %{{.+}}, %{{.+}} ["Aligned", 1] : i8
+ %0 = memref.load %arg0[] : memref<i1, #spirv.storage_class<PhysicalStorageBuffer>>
+ memref.store %0, %arg0[] : memref<i1, #spirv.storage_class<PhysicalStorageBuffer>>
+ return
+}
+
+// CHECK-LABEL: @load_store_f32_physical
+func.func @load_store_f32_physical(%arg0: memref<f32, #spirv.storage_class<PhysicalStorageBuffer>>) {
+ // CHECK: spirv.Load "PhysicalStorageBuffer" %{{.+}} ["Aligned", 4] : f32
+ // CHECK: spirv.Store "PhysicalStorageBuffer" %{{.+}}, %{{.+}} ["Aligned", 4] : f32
+ %0 = memref.load %arg0[] : memref<f32, #spirv.storage_class<PhysicalStorageBuffer>>
+ memref.store %0, %arg0[] : memref<f32, #spirv.storage_class<PhysicalStorageBuffer>>
+ return
+}
+
+// CHECK-LABEL: @load_store_f16_physical
+func.func @load_store_f16_physical(%arg0: memref<f16, #spirv.storage_class<PhysicalStorageBuffer>>) {
+ // CHECK: spirv.Load "PhysicalStorageBuffer" %{{.+}} ["Aligned", 2] : f16
+ // CHECK: spirv.Store "PhysicalStorageBuffer" %{{.+}}, %{{.+}} ["Aligned", 2] : f16
+ %0 = memref.load %arg0[] : memref<f16, #spirv.storage_class<PhysicalStorageBuffer>>
+ memref.store %0, %arg0[] : memref<f16, #spirv.storage_class<PhysicalStorageBuffer>>
+ return
+}
+
} // end module
// -----
|
@llvm/pr-subscribers-mlir-spirv Author: Jakub Kuderski (kuhar) Changes: The SPIR-V spec requires that memory accesses to `PhysicalStorageBuffer`s are annotated with appropriate alignment attributes [1]. Calculate these based on memref alignment attributes or scalar type sizes. [1] Otherwise spirv-val complains:
Full diff: https://github.com/llvm/llvm-project/pull/80243.diff 2 Files Affected:
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index acddb3c4da461..c0470ce189cd3 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -12,12 +12,18 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Visitors.h"
+#include "mlir/Support/LogicalResult.h"
#include "llvm/Support/Debug.h"
+#include <cassert>
#include <optional>
#define DEBUG_TYPE "memref-to-spirv-pattern"
@@ -439,6 +445,52 @@ DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation,
// LoadOp
//===----------------------------------------------------------------------===//
+using AlignmentRequirements =
+ FailureOr<std::pair<spirv::MemoryAccessAttr, IntegerAttr>>;
+
+/// Given an accessed SPIR-V pointer, calculates its alignment requirements, if
+/// any.
+static AlignmentRequirements calculateRequiredAlignment(Value accessedPtr) {
+ auto ptrType = cast<spirv::PointerType>(accessedPtr.getType());
+ if (ptrType.getStorageClass() != spirv::StorageClass::PhysicalStorageBuffer)
+ return std::pair{spirv::MemoryAccessAttr{}, IntegerAttr{}};
+
+ // PhysicalStorageBuffers require the `Aligned` attribute.
+ auto pointeeType = dyn_cast<spirv::ScalarType>(ptrType.getPointeeType());
+ if (!pointeeType)
+ return failure();
+
+ // For scalar types, the alignment is determined by their size.
+ std::optional<int64_t> sizeInBytes = pointeeType.getSizeInBytes();
+ if (!sizeInBytes.has_value())
+ return failure();
+
+ MLIRContext *ctx = accessedPtr.getContext();
+ auto memAccessAttr =
+ spirv::MemoryAccessAttr::get(ctx, spirv::MemoryAccess::Aligned);
+ auto alignment = IntegerAttr::get(IntegerType::get(ctx, 32), *sizeInBytes);
+ return std::pair{memAccessAttr, alignment};
+}
+
+/// Given an accessed SPIR-V pointer and the original memref load/store
+/// `memAccess` op, calculates the alignment requirements, if any. Takes into
+/// account the alignment attributes applied to the load/store op.
+static AlignmentRequirements
+calculateRequiredAlignment(Value accessedPtr, Operation *memrefAccessOp) {
+ assert(memrefAccessOp);
+ assert((isa<memref::LoadOp, memref::StoreOp>(memrefAccessOp)) &&
+ "Bad op type");
+
+ auto memrefMemAccess = memrefAccessOp->getAttrOfType<spirv::MemoryAccessAttr>(
+ spirv::attributeName<spirv::MemoryAccess>());
+ auto memrefAlignment =
+ memrefAccessOp->getAttrOfType<IntegerAttr>("alignment");
+ if (memrefMemAccess && memrefAlignment)
+ return std::pair{memrefMemAccess, memrefAlignment};
+
+ return calculateRequiredAlignment(accessedPtr);
+}
+
LogicalResult
IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
@@ -486,7 +538,15 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
// If the rewritten load op has the same bit width, use the loading value
// directly.
if (srcBits == dstBits) {
- Value loadVal = rewriter.create<spirv::LoadOp>(loc, accessChain);
+ AlignmentRequirements alignmentRequirements =
+ calculateRequiredAlignment(accessChain, loadOp);
+ if (failed(alignmentRequirements))
+ return rewriter.notifyMatchFailure(
+ loadOp, "failed to determine alignment requirements");
+
+ auto [memoryAccess, alignment] = *alignmentRequirements;
+ Value loadVal = rewriter.create<spirv::LoadOp>(loc, accessChain,
+ memoryAccess, alignment);
if (isBool)
loadVal = castIntNToBool(loc, loadVal, rewriter);
rewriter.replaceOp(loadOp, loadVal);
@@ -508,11 +568,15 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
assert(accessChainOp.getIndices().size() == 2);
Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
srcBits, dstBits, rewriter);
- Value spvLoadOp = rewriter.create<spirv::LoadOp>(
- loc, dstType, adjustedPtr,
- loadOp->getAttrOfType<spirv::MemoryAccessAttr>(
- spirv::attributeName<spirv::MemoryAccess>()),
- loadOp->getAttrOfType<IntegerAttr>("alignment"));
+ AlignmentRequirements alignmentRequirements =
+ calculateRequiredAlignment(adjustedPtr, loadOp);
+ if (failed(alignmentRequirements))
+ return rewriter.notifyMatchFailure(
+ loadOp, "failed to determine alignment requirements");
+
+ auto [memoryAccess, alignment] = *alignmentRequirements;
+ Value spvLoadOp = rewriter.create<spirv::LoadOp>(loc, dstType, adjustedPtr,
+ memoryAccess, alignment);
// Shift the bits to the rightmost.
// ____XXXX________ -> ____________XXXX
@@ -552,14 +616,21 @@ LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
if (memrefType.getElementType().isSignlessInteger())
return failure();
- auto loadPtr = spirv::getElementPtr(
+ Value loadPtr = spirv::getElementPtr(
*getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
adaptor.getIndices(), loadOp.getLoc(), rewriter);
if (!loadPtr)
return failure();
- rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr);
+ AlignmentRequirements requiredAlignment = calculateRequiredAlignment(loadPtr);
+ if (failed(requiredAlignment))
+ return rewriter.notifyMatchFailure(
+ loadOp, "failed to determine alignment requirements");
+
+ auto [memAccessAttr, alignment] = *requiredAlignment;
+ rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr, memAccessAttr,
+ alignment);
return success();
}
@@ -618,10 +689,18 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
assert(dstBits % srcBits == 0);
if (srcBits == dstBits) {
+ AlignmentRequirements requiredAlignment =
+ calculateRequiredAlignment(accessChain);
+ if (failed(requiredAlignment))
+ return rewriter.notifyMatchFailure(
+ storeOp, "failed to determine alignment requirements");
+
+ auto [memAccessAttr, alignment] = *requiredAlignment;
Value storeVal = adaptor.getValue();
if (isBool)
storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter);
- rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, accessChain, storeVal);
+ rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, accessChain, storeVal,
+ memAccessAttr, alignment);
return success();
}
diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
index a8b550367d5fa..aa05fd9bc8ca8 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
@@ -1,17 +1,19 @@
-// RUN: mlir-opt -split-input-file -convert-memref-to-spirv="bool-num-bits=8" -cse %s -o - | FileCheck %s
+// RUN: mlir-opt --split-input-file --convert-memref-to-spirv="bool-num-bits=8" --cse %s | FileCheck %s
// Check that with proper compute and storage extensions, we don't need to
// perform special tricks.
module attributes {
spirv.target_env = #spirv.target_env<
- #spirv.vce<v1.0,
+ #spirv.vce<v1.5,
[
Shader, Int8, Int16, Int64, Float16, Float64,
StorageBuffer16BitAccess, StorageUniform16, StoragePushConstant16,
- StorageBuffer8BitAccess, UniformAndStorageBuffer8BitAccess, StoragePushConstant8
+ StorageBuffer8BitAccess, UniformAndStorageBuffer8BitAccess, StoragePushConstant8,
+ PhysicalStorageBufferAddresses
],
- [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
+ [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_storage_buffer_storage_class, SPV_KHR_physical_storage_buffer]>,
+ #spirv.resource_limits<>>
} {
// CHECK-LABEL: @load_store_zero_rank_float
@@ -119,6 +121,51 @@ func.func @store_i16(%arg0: memref<10xi16, #spirv.storage_class<StorageBuffer>>,
return
}
+// CHECK-LABEL: @load_store_i32_physical
+func.func @load_store_i32_physical(%arg0: memref<i32, #spirv.storage_class<PhysicalStorageBuffer>>) {
+ // CHECK: spirv.Load "PhysicalStorageBuffer" %{{.+}} ["Aligned", 4] : i32
+ // CHECK: spirv.Store "PhysicalStorageBuffer" %{{.+}}, %{{.+}} ["Aligned", 4] : i32
+ %0 = memref.load %arg0[] : memref<i32, #spirv.storage_class<PhysicalStorageBuffer>>
+ memref.store %0, %arg0[] : memref<i32, #spirv.storage_class<PhysicalStorageBuffer>>
+ return
+}
+
+// CHECK-LABEL: @load_store_i8_physical
+func.func @load_store_i8_physical(%arg0: memref<i8, #spirv.storage_class<PhysicalStorageBuffer>>) {
+ // CHECK: spirv.Load "PhysicalStorageBuffer" %{{.+}} ["Aligned", 1] : i8
+ // CHECK: spirv.Store "PhysicalStorageBuffer" %{{.+}}, %{{.+}} ["Aligned", 1] : i8
+ %0 = memref.load %arg0[] : memref<i8, #spirv.storage_class<PhysicalStorageBuffer>>
+ memref.store %0, %arg0[] : memref<i8, #spirv.storage_class<PhysicalStorageBuffer>>
+ return
+}
+
+// CHECK-LABEL: @load_store_i1_physical
+func.func @load_store_i1_physical(%arg0: memref<i1, #spirv.storage_class<PhysicalStorageBuffer>>) {
+ // CHECK: spirv.Load "PhysicalStorageBuffer" %{{.+}} ["Aligned", 1] : i8
+ // CHECK: spirv.Store "PhysicalStorageBuffer" %{{.+}}, %{{.+}} ["Aligned", 1] : i8
+ %0 = memref.load %arg0[] : memref<i1, #spirv.storage_class<PhysicalStorageBuffer>>
+ memref.store %0, %arg0[] : memref<i1, #spirv.storage_class<PhysicalStorageBuffer>>
+ return
+}
+
+// CHECK-LABEL: @load_store_f32_physical
+func.func @load_store_f32_physical(%arg0: memref<f32, #spirv.storage_class<PhysicalStorageBuffer>>) {
+ // CHECK: spirv.Load "PhysicalStorageBuffer" %{{.+}} ["Aligned", 4] : f32
+ // CHECK: spirv.Store "PhysicalStorageBuffer" %{{.+}}, %{{.+}} ["Aligned", 4] : f32
+ %0 = memref.load %arg0[] : memref<f32, #spirv.storage_class<PhysicalStorageBuffer>>
+ memref.store %0, %arg0[] : memref<f32, #spirv.storage_class<PhysicalStorageBuffer>>
+ return
+}
+
+// CHECK-LABEL: @load_store_f16_physical
+func.func @load_store_f16_physical(%arg0: memref<f16, #spirv.storage_class<PhysicalStorageBuffer>>) {
+ // CHECK: spirv.Load "PhysicalStorageBuffer" %{{.+}} ["Aligned", 2] : f16
+ // CHECK: spirv.Store "PhysicalStorageBuffer" %{{.+}}, %{{.+}} ["Aligned", 2] : f16
+ %0 = memref.load %arg0[] : memref<f16, #spirv.storage_class<PhysicalStorageBuffer>>
+ memref.store %0, %arg0[] : memref<f16, #spirv.storage_class<PhysicalStorageBuffer>>
+ return
+}
+
} // end module
// -----
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall Looks good!
NIT: Maybe have a failing case on the LIT test. But for this specific case may be not super impactful.
llvm#80243) The SPIR-V spec requires that memory accesses to `PhysicalStorageBuffer`s are annotated with appropriate alignment attributes [1]. Calculate these based on memref alignment attributes or scalar type sizes. [1] Otherwise spirv-val complains: ``` [VULKAN] ! Validation Error: [ VUID-VkShaderModuleCreateInfo-pCode-01379 ] | MessageID = 0x2a1bf17f | SPIR-V module not valid: [VUID-StandaloneSpirv-PhysicalStorageBuffer64-04708] Memory accesses with PhysicalStorageBuffer must use Aligned. %48 = OpLoad %float %47 ```
The SPIR-V spec requires that memory accesses to `PhysicalStorageBuffer`s are annotated with appropriate alignment attributes [1]. Calculate these based on memref alignment attributes or scalar type sizes. [1] Otherwise spirv-val complains: