Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][spirv][memref] Calculate alignment for PhysicalStorageBuffers #80243

Merged
merged 2 commits into from
Feb 1, 2024

Conversation

kuhar
Copy link
Member

@kuhar kuhar commented Feb 1, 2024

The SPIR-V spec requires that memory accesses to PhysicalStorageBuffers are annotated with appropriate alignment attributes [1]. Calculate these based on memref alignment attributes or scalar type sizes.

[1] Otherwise spirv-val complains:

[VULKAN] ! Validation Error: [ VUID-VkShaderModuleCreateInfo-pCode-01379 ] | MessageID = 0x2a1bf17f | SPIR-V module not valid: [VUID-StandaloneSpirv-PhysicalStorageBuffer64-04708] Memory accesses with PhysicalStorageBuffer must use Aligned.
  %48 = OpLoad %float %47

The SPIR-V spec requires that memory accesses to PhysicalStorageBuffers
are annotated with appropriate alignment attributes. Calculate these
based on memref alignment attributes or scalar type sizes.
@llvmbot
Copy link
Collaborator

llvmbot commented Feb 1, 2024

@llvm/pr-subscribers-mlir

Author: Jakub Kuderski (kuhar)

Changes

The SPIR-V spec requires that memory accesses to PhysicalStorageBuffers are annotated with appropriate alignment attributes [1]. Calculate these based on memref alignment attributes or scalar type sizes.

[1] Otherwise spirv-val complains:

[VULKAN] ! Validation Error: [ VUID-VkShaderModuleCreateInfo-pCode-01379 ] | MessageID = 0x2a1bf17f | SPIR-V module not valid: [VUID-StandaloneSpirv-PhysicalStorageBuffer64-04708] Memory accesses with PhysicalStorageBuffer must use Aligned.
  %48 = OpLoad %float %47

Full diff: https://github.com/llvm/llvm-project/pull/80243.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp (+88-9)
  • (modified) mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir (+51-4)
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index acddb3c4da461..c0470ce189cd3 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -12,12 +12,18 @@
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Visitors.h"
+#include "mlir/Support/LogicalResult.h"
 #include "llvm/Support/Debug.h"
+#include <cassert>
 #include <optional>
 
 #define DEBUG_TYPE "memref-to-spirv-pattern"
@@ -439,6 +445,52 @@ DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation,
 // LoadOp
 //===----------------------------------------------------------------------===//
 
+using AlignmentRequirements =
+    FailureOr<std::pair<spirv::MemoryAccessAttr, IntegerAttr>>;
+
+/// Given an accessed SPIR-V pointer, calculates its alignment requirements, if
+/// any.
+static AlignmentRequirements calculateRequiredAlignment(Value accessedPtr) {
+  auto ptrType = cast<spirv::PointerType>(accessedPtr.getType());
+  if (ptrType.getStorageClass() != spirv::StorageClass::PhysicalStorageBuffer)
+    return std::pair{spirv::MemoryAccessAttr{}, IntegerAttr{}};
+
+  // PhysicalStorageBuffers require the `Aligned` attribute.
+  auto pointeeType = dyn_cast<spirv::ScalarType>(ptrType.getPointeeType());
+  if (!pointeeType)
+    return failure();
+
+  // For scalar types, the alignment is determined by their size.
+  std::optional<int64_t> sizeInBytes = pointeeType.getSizeInBytes();
+  if (!sizeInBytes.has_value())
+    return failure();
+
+  MLIRContext *ctx = accessedPtr.getContext();
+  auto memAccessAttr =
+      spirv::MemoryAccessAttr::get(ctx, spirv::MemoryAccess::Aligned);
+  auto alignment = IntegerAttr::get(IntegerType::get(ctx, 32), *sizeInBytes);
+  return std::pair{memAccessAttr, alignment};
+}
+
+/// Given an accessed SPIR-V pointer and the original memref load/store
+/// `memAccess` op, calculates the alignment requirements, if any. Takes into
+/// account the alignment attributes applied to the load/store op.
+static AlignmentRequirements
+calculateRequiredAlignment(Value accessedPtr, Operation *memrefAccessOp) {
+  assert(memrefAccessOp);
+  assert((isa<memref::LoadOp, memref::StoreOp>(memrefAccessOp)) &&
+         "Bad op type");
+
+  auto memrefMemAccess = memrefAccessOp->getAttrOfType<spirv::MemoryAccessAttr>(
+      spirv::attributeName<spirv::MemoryAccess>());
+  auto memrefAlignment =
+      memrefAccessOp->getAttrOfType<IntegerAttr>("alignment");
+  if (memrefMemAccess && memrefAlignment)
+    return std::pair{memrefMemAccess, memrefAlignment};
+
+  return calculateRequiredAlignment(accessedPtr);
+}
+
 LogicalResult
 IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                                   ConversionPatternRewriter &rewriter) const {
@@ -486,7 +538,15 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
   // If the rewritten load op has the same bit width, use the loading value
   // directly.
   if (srcBits == dstBits) {
-    Value loadVal = rewriter.create<spirv::LoadOp>(loc, accessChain);
+    AlignmentRequirements alignmentRequirements =
+        calculateRequiredAlignment(accessChain, loadOp);
+    if (failed(alignmentRequirements))
+      return rewriter.notifyMatchFailure(
+          loadOp, "failed to determine alignment requirements");
+
+    auto [memoryAccess, alignment] = *alignmentRequirements;
+    Value loadVal = rewriter.create<spirv::LoadOp>(loc, accessChain,
+                                                   memoryAccess, alignment);
     if (isBool)
       loadVal = castIntNToBool(loc, loadVal, rewriter);
     rewriter.replaceOp(loadOp, loadVal);
@@ -508,11 +568,15 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
   assert(accessChainOp.getIndices().size() == 2);
   Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
                                                    srcBits, dstBits, rewriter);
-  Value spvLoadOp = rewriter.create<spirv::LoadOp>(
-      loc, dstType, adjustedPtr,
-      loadOp->getAttrOfType<spirv::MemoryAccessAttr>(
-          spirv::attributeName<spirv::MemoryAccess>()),
-      loadOp->getAttrOfType<IntegerAttr>("alignment"));
+  AlignmentRequirements alignmentRequirements =
+      calculateRequiredAlignment(adjustedPtr, loadOp);
+  if (failed(alignmentRequirements))
+    return rewriter.notifyMatchFailure(
+        loadOp, "failed to determine alignment requirements");
+
+  auto [memoryAccess, alignment] = *alignmentRequirements;
+  Value spvLoadOp = rewriter.create<spirv::LoadOp>(loc, dstType, adjustedPtr,
+                                                   memoryAccess, alignment);
 
   // Shift the bits to the rightmost.
   // ____XXXX________ -> ____________XXXX
@@ -552,14 +616,21 @@ LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
   auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
   if (memrefType.getElementType().isSignlessInteger())
     return failure();
-  auto loadPtr = spirv::getElementPtr(
+  Value loadPtr = spirv::getElementPtr(
       *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
       adaptor.getIndices(), loadOp.getLoc(), rewriter);
 
   if (!loadPtr)
     return failure();
 
-  rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr);
+  AlignmentRequirements requiredAlignment = calculateRequiredAlignment(loadPtr);
+  if (failed(requiredAlignment))
+    return rewriter.notifyMatchFailure(
+        loadOp, "failed to determine alignment requirements");
+
+  auto [memAccessAttr, alignment] = *requiredAlignment;
+  rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr, memAccessAttr,
+                                             alignment);
   return success();
 }
 
@@ -618,10 +689,18 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
   assert(dstBits % srcBits == 0);
 
   if (srcBits == dstBits) {
+    AlignmentRequirements requiredAlignment =
+        calculateRequiredAlignment(accessChain);
+    if (failed(requiredAlignment))
+      return rewriter.notifyMatchFailure(
+          storeOp, "failed to determine alignment requirements");
+
+    auto [memAccessAttr, alignment] = *requiredAlignment;
     Value storeVal = adaptor.getValue();
     if (isBool)
       storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter);
-    rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, accessChain, storeVal);
+    rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, accessChain, storeVal,
+                                                memAccessAttr, alignment);
     return success();
   }
 
diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
index a8b550367d5fa..aa05fd9bc8ca8 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
@@ -1,17 +1,19 @@
-// RUN: mlir-opt -split-input-file -convert-memref-to-spirv="bool-num-bits=8" -cse %s -o - | FileCheck %s
+// RUN: mlir-opt --split-input-file --convert-memref-to-spirv="bool-num-bits=8" --cse %s | FileCheck %s
 
 // Check that with proper compute and storage extensions, we don't need to
 // perform special tricks.
 
 module attributes {
   spirv.target_env = #spirv.target_env<
-    #spirv.vce<v1.0,
+    #spirv.vce<v1.5,
       [
         Shader, Int8, Int16, Int64, Float16, Float64,
         StorageBuffer16BitAccess, StorageUniform16, StoragePushConstant16,
-        StorageBuffer8BitAccess, UniformAndStorageBuffer8BitAccess, StoragePushConstant8
+        StorageBuffer8BitAccess, UniformAndStorageBuffer8BitAccess, StoragePushConstant8,
+        PhysicalStorageBufferAddresses
       ],
-      [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
+      [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_storage_buffer_storage_class, SPV_KHR_physical_storage_buffer]>,
+      #spirv.resource_limits<>>
 } {
 
 // CHECK-LABEL: @load_store_zero_rank_float
@@ -119,6 +121,51 @@ func.func @store_i16(%arg0: memref<10xi16, #spirv.storage_class<StorageBuffer>>,
   return
 }
 
+// CHECK-LABEL: @load_store_i32_physical
+func.func @load_store_i32_physical(%arg0: memref<i32, #spirv.storage_class<PhysicalStorageBuffer>>) {
+  //     CHECK: spirv.Load "PhysicalStorageBuffer" %{{.+}} ["Aligned", 4] : i32
+  //     CHECK: spirv.Store "PhysicalStorageBuffer" %{{.+}}, %{{.+}} ["Aligned", 4] : i32
+  %0 = memref.load %arg0[] : memref<i32, #spirv.storage_class<PhysicalStorageBuffer>>
+  memref.store %0, %arg0[] : memref<i32, #spirv.storage_class<PhysicalStorageBuffer>>
+  return
+}
+
+// CHECK-LABEL: @load_store_i8_physical
+func.func @load_store_i8_physical(%arg0: memref<i8, #spirv.storage_class<PhysicalStorageBuffer>>) {
+  //     CHECK: spirv.Load "PhysicalStorageBuffer" %{{.+}} ["Aligned", 1] : i8
+  //     CHECK: spirv.Store "PhysicalStorageBuffer" %{{.+}}, %{{.+}} ["Aligned", 1] : i8
+  %0 = memref.load %arg0[] : memref<i8, #spirv.storage_class<PhysicalStorageBuffer>>
+  memref.store %0, %arg0[] : memref<i8, #spirv.storage_class<PhysicalStorageBuffer>>
+  return
+}
+
+// CHECK-LABEL: @load_store_i1_physical
+func.func @load_store_i1_physical(%arg0: memref<i1, #spirv.storage_class<PhysicalStorageBuffer>>) {
+  //     CHECK: spirv.Load "PhysicalStorageBuffer" %{{.+}} ["Aligned", 1] : i8
+  //     CHECK: spirv.Store "PhysicalStorageBuffer" %{{.+}}, %{{.+}} ["Aligned", 1] : i8
+  %0 = memref.load %arg0[] : memref<i1, #spirv.storage_class<PhysicalStorageBuffer>>
+  memref.store %0, %arg0[] : memref<i1, #spirv.storage_class<PhysicalStorageBuffer>>
+  return
+}
+
+// CHECK-LABEL: @load_store_f32_physical
+func.func @load_store_f32_physical(%arg0: memref<f32, #spirv.storage_class<PhysicalStorageBuffer>>) {
+  //     CHECK: spirv.Load "PhysicalStorageBuffer" %{{.+}} ["Aligned", 4] : f32
+  //     CHECK: spirv.Store "PhysicalStorageBuffer" %{{.+}}, %{{.+}} ["Aligned", 4] : f32
+  %0 = memref.load %arg0[] : memref<f32, #spirv.storage_class<PhysicalStorageBuffer>>
+  memref.store %0, %arg0[] : memref<f32, #spirv.storage_class<PhysicalStorageBuffer>>
+  return
+}
+
+// CHECK-LABEL: @load_store_f16_physical
+func.func @load_store_f16_physical(%arg0: memref<f16, #spirv.storage_class<PhysicalStorageBuffer>>) {
+  //     CHECK: spirv.Load "PhysicalStorageBuffer" %{{.+}} ["Aligned", 2] : f16
+  //     CHECK: spirv.Store "PhysicalStorageBuffer" %{{.+}}, %{{.+}} ["Aligned", 2] : f16
+  %0 = memref.load %arg0[] : memref<f16, #spirv.storage_class<PhysicalStorageBuffer>>
+  memref.store %0, %arg0[] : memref<f16, #spirv.storage_class<PhysicalStorageBuffer>>
+  return
+}
+
 } // end module
 
 // -----

@llvmbot
Copy link
Collaborator

llvmbot commented Feb 1, 2024

@llvm/pr-subscribers-mlir-spirv

Author: Jakub Kuderski (kuhar)

Changes

The SPIR-V spec requires that memory accesses to PhysicalStorageBuffers are annotated with appropriate alignment attributes [1]. Calculate these based on memref alignment attributes or scalar type sizes.

[1] Otherwise spirv-val complains:

[VULKAN] ! Validation Error: [ VUID-VkShaderModuleCreateInfo-pCode-01379 ] | MessageID = 0x2a1bf17f | SPIR-V module not valid: [VUID-StandaloneSpirv-PhysicalStorageBuffer64-04708] Memory accesses with PhysicalStorageBuffer must use Aligned.
  %48 = OpLoad %float %47

Full diff: https://github.com/llvm/llvm-project/pull/80243.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp (+88-9)
  • (modified) mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir (+51-4)
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index acddb3c4da461..c0470ce189cd3 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -12,12 +12,18 @@
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Visitors.h"
+#include "mlir/Support/LogicalResult.h"
 #include "llvm/Support/Debug.h"
+#include <cassert>
 #include <optional>
 
 #define DEBUG_TYPE "memref-to-spirv-pattern"
@@ -439,6 +445,52 @@ DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation,
 // LoadOp
 //===----------------------------------------------------------------------===//
 
+using AlignmentRequirements =
+    FailureOr<std::pair<spirv::MemoryAccessAttr, IntegerAttr>>;
+
+/// Given an accessed SPIR-V pointer, calculates its alignment requirements, if
+/// any.
+static AlignmentRequirements calculateRequiredAlignment(Value accessedPtr) {
+  auto ptrType = cast<spirv::PointerType>(accessedPtr.getType());
+  if (ptrType.getStorageClass() != spirv::StorageClass::PhysicalStorageBuffer)
+    return std::pair{spirv::MemoryAccessAttr{}, IntegerAttr{}};
+
+  // PhysicalStorageBuffers require the `Aligned` attribute.
+  auto pointeeType = dyn_cast<spirv::ScalarType>(ptrType.getPointeeType());
+  if (!pointeeType)
+    return failure();
+
+  // For scalar types, the alignment is determined by their size.
+  std::optional<int64_t> sizeInBytes = pointeeType.getSizeInBytes();
+  if (!sizeInBytes.has_value())
+    return failure();
+
+  MLIRContext *ctx = accessedPtr.getContext();
+  auto memAccessAttr =
+      spirv::MemoryAccessAttr::get(ctx, spirv::MemoryAccess::Aligned);
+  auto alignment = IntegerAttr::get(IntegerType::get(ctx, 32), *sizeInBytes);
+  return std::pair{memAccessAttr, alignment};
+}
+
+/// Given an accessed SPIR-V pointer and the original memref load/store
+/// `memAccess` op, calculates the alignment requirements, if any. Takes into
+/// account the alignment attributes applied to the load/store op.
+static AlignmentRequirements
+calculateRequiredAlignment(Value accessedPtr, Operation *memrefAccessOp) {
+  assert(memrefAccessOp);
+  assert((isa<memref::LoadOp, memref::StoreOp>(memrefAccessOp)) &&
+         "Bad op type");
+
+  auto memrefMemAccess = memrefAccessOp->getAttrOfType<spirv::MemoryAccessAttr>(
+      spirv::attributeName<spirv::MemoryAccess>());
+  auto memrefAlignment =
+      memrefAccessOp->getAttrOfType<IntegerAttr>("alignment");
+  if (memrefMemAccess && memrefAlignment)
+    return std::pair{memrefMemAccess, memrefAlignment};
+
+  return calculateRequiredAlignment(accessedPtr);
+}
+
 LogicalResult
 IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                                   ConversionPatternRewriter &rewriter) const {
@@ -486,7 +538,15 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
   // If the rewritten load op has the same bit width, use the loading value
   // directly.
   if (srcBits == dstBits) {
-    Value loadVal = rewriter.create<spirv::LoadOp>(loc, accessChain);
+    AlignmentRequirements alignmentRequirements =
+        calculateRequiredAlignment(accessChain, loadOp);
+    if (failed(alignmentRequirements))
+      return rewriter.notifyMatchFailure(
+          loadOp, "failed to determine alignment requirements");
+
+    auto [memoryAccess, alignment] = *alignmentRequirements;
+    Value loadVal = rewriter.create<spirv::LoadOp>(loc, accessChain,
+                                                   memoryAccess, alignment);
     if (isBool)
       loadVal = castIntNToBool(loc, loadVal, rewriter);
     rewriter.replaceOp(loadOp, loadVal);
@@ -508,11 +568,15 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
   assert(accessChainOp.getIndices().size() == 2);
   Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
                                                    srcBits, dstBits, rewriter);
-  Value spvLoadOp = rewriter.create<spirv::LoadOp>(
-      loc, dstType, adjustedPtr,
-      loadOp->getAttrOfType<spirv::MemoryAccessAttr>(
-          spirv::attributeName<spirv::MemoryAccess>()),
-      loadOp->getAttrOfType<IntegerAttr>("alignment"));
+  AlignmentRequirements alignmentRequirements =
+      calculateRequiredAlignment(adjustedPtr, loadOp);
+  if (failed(alignmentRequirements))
+    return rewriter.notifyMatchFailure(
+        loadOp, "failed to determine alignment requirements");
+
+  auto [memoryAccess, alignment] = *alignmentRequirements;
+  Value spvLoadOp = rewriter.create<spirv::LoadOp>(loc, dstType, adjustedPtr,
+                                                   memoryAccess, alignment);
 
   // Shift the bits to the rightmost.
   // ____XXXX________ -> ____________XXXX
@@ -552,14 +616,21 @@ LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
   auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
   if (memrefType.getElementType().isSignlessInteger())
     return failure();
-  auto loadPtr = spirv::getElementPtr(
+  Value loadPtr = spirv::getElementPtr(
       *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
       adaptor.getIndices(), loadOp.getLoc(), rewriter);
 
   if (!loadPtr)
     return failure();
 
-  rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr);
+  AlignmentRequirements requiredAlignment = calculateRequiredAlignment(loadPtr);
+  if (failed(requiredAlignment))
+    return rewriter.notifyMatchFailure(
+        loadOp, "failed to determine alignment requirements");
+
+  auto [memAccessAttr, alignment] = *requiredAlignment;
+  rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr, memAccessAttr,
+                                             alignment);
   return success();
 }
 
@@ -618,10 +689,18 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
   assert(dstBits % srcBits == 0);
 
   if (srcBits == dstBits) {
+    AlignmentRequirements requiredAlignment =
+        calculateRequiredAlignment(accessChain);
+    if (failed(requiredAlignment))
+      return rewriter.notifyMatchFailure(
+          storeOp, "failed to determine alignment requirements");
+
+    auto [memAccessAttr, alignment] = *requiredAlignment;
     Value storeVal = adaptor.getValue();
     if (isBool)
       storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter);
-    rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, accessChain, storeVal);
+    rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, accessChain, storeVal,
+                                                memAccessAttr, alignment);
     return success();
   }
 
diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
index a8b550367d5fa..aa05fd9bc8ca8 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
@@ -1,17 +1,19 @@
-// RUN: mlir-opt -split-input-file -convert-memref-to-spirv="bool-num-bits=8" -cse %s -o - | FileCheck %s
+// RUN: mlir-opt --split-input-file --convert-memref-to-spirv="bool-num-bits=8" --cse %s | FileCheck %s
 
 // Check that with proper compute and storage extensions, we don't need to
 // perform special tricks.
 
 module attributes {
   spirv.target_env = #spirv.target_env<
-    #spirv.vce<v1.0,
+    #spirv.vce<v1.5,
       [
         Shader, Int8, Int16, Int64, Float16, Float64,
         StorageBuffer16BitAccess, StorageUniform16, StoragePushConstant16,
-        StorageBuffer8BitAccess, UniformAndStorageBuffer8BitAccess, StoragePushConstant8
+        StorageBuffer8BitAccess, UniformAndStorageBuffer8BitAccess, StoragePushConstant8,
+        PhysicalStorageBufferAddresses
       ],
-      [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
+      [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_storage_buffer_storage_class, SPV_KHR_physical_storage_buffer]>,
+      #spirv.resource_limits<>>
 } {
 
 // CHECK-LABEL: @load_store_zero_rank_float
@@ -119,6 +121,51 @@ func.func @store_i16(%arg0: memref<10xi16, #spirv.storage_class<StorageBuffer>>,
   return
 }
 
+// CHECK-LABEL: @load_store_i32_physical
+func.func @load_store_i32_physical(%arg0: memref<i32, #spirv.storage_class<PhysicalStorageBuffer>>) {
+  //     CHECK: spirv.Load "PhysicalStorageBuffer" %{{.+}} ["Aligned", 4] : i32
+  //     CHECK: spirv.Store "PhysicalStorageBuffer" %{{.+}}, %{{.+}} ["Aligned", 4] : i32
+  %0 = memref.load %arg0[] : memref<i32, #spirv.storage_class<PhysicalStorageBuffer>>
+  memref.store %0, %arg0[] : memref<i32, #spirv.storage_class<PhysicalStorageBuffer>>
+  return
+}
+
+// CHECK-LABEL: @load_store_i8_physical
+func.func @load_store_i8_physical(%arg0: memref<i8, #spirv.storage_class<PhysicalStorageBuffer>>) {
+  //     CHECK: spirv.Load "PhysicalStorageBuffer" %{{.+}} ["Aligned", 1] : i8
+  //     CHECK: spirv.Store "PhysicalStorageBuffer" %{{.+}}, %{{.+}} ["Aligned", 1] : i8
+  %0 = memref.load %arg0[] : memref<i8, #spirv.storage_class<PhysicalStorageBuffer>>
+  memref.store %0, %arg0[] : memref<i8, #spirv.storage_class<PhysicalStorageBuffer>>
+  return
+}
+
+// CHECK-LABEL: @load_store_i1_physical
+func.func @load_store_i1_physical(%arg0: memref<i1, #spirv.storage_class<PhysicalStorageBuffer>>) {
+  //     CHECK: spirv.Load "PhysicalStorageBuffer" %{{.+}} ["Aligned", 1] : i8
+  //     CHECK: spirv.Store "PhysicalStorageBuffer" %{{.+}}, %{{.+}} ["Aligned", 1] : i8
+  %0 = memref.load %arg0[] : memref<i1, #spirv.storage_class<PhysicalStorageBuffer>>
+  memref.store %0, %arg0[] : memref<i1, #spirv.storage_class<PhysicalStorageBuffer>>
+  return
+}
+
+// CHECK-LABEL: @load_store_f32_physical
+func.func @load_store_f32_physical(%arg0: memref<f32, #spirv.storage_class<PhysicalStorageBuffer>>) {
+  //     CHECK: spirv.Load "PhysicalStorageBuffer" %{{.+}} ["Aligned", 4] : f32
+  //     CHECK: spirv.Store "PhysicalStorageBuffer" %{{.+}}, %{{.+}} ["Aligned", 4] : f32
+  %0 = memref.load %arg0[] : memref<f32, #spirv.storage_class<PhysicalStorageBuffer>>
+  memref.store %0, %arg0[] : memref<f32, #spirv.storage_class<PhysicalStorageBuffer>>
+  return
+}
+
+// CHECK-LABEL: @load_store_f16_physical
+func.func @load_store_f16_physical(%arg0: memref<f16, #spirv.storage_class<PhysicalStorageBuffer>>) {
+  //     CHECK: spirv.Load "PhysicalStorageBuffer" %{{.+}} ["Aligned", 2] : f16
+  //     CHECK: spirv.Store "PhysicalStorageBuffer" %{{.+}}, %{{.+}} ["Aligned", 2] : f16
+  %0 = memref.load %arg0[] : memref<f16, #spirv.storage_class<PhysicalStorageBuffer>>
+  memref.store %0, %arg0[] : memref<f16, #spirv.storage_class<PhysicalStorageBuffer>>
+  return
+}
+
 } // end module
 
 // -----

Copy link
Member

@raikonenfnu raikonenfnu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall Looks good!

NIT: Maybe have a failing case on the LIT test. But for this specific case may be not super impactful.

@kuhar kuhar merged commit 8fd0bce into llvm:main Feb 1, 2024
4 checks passed
agozillon pushed a commit to agozillon/llvm-project that referenced this pull request Feb 5, 2024
llvm#80243)

The SPIR-V spec requires that memory accesses to
`PhysicalStorageBuffer`s are annotated with appropriate alignment
attributes [1]. Calculate these based on memref alignment attributes or
scalar type sizes.

[1] Otherwise spirv-val complains:
```
[VULKAN] ! Validation Error: [ VUID-VkShaderModuleCreateInfo-pCode-01379 ] | MessageID = 0x2a1bf17f | SPIR-V module not valid: [VUID-StandaloneSpirv-PhysicalStorageBuffer64-04708] Memory accesses with PhysicalStorageBuffer must use Aligned.
  %48 = OpLoad %float %47
```
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants