
Conversation

FranklandJack
Contributor

When converting a memref.load from the image address space to a spirv.ImageFetch ensure that we reverse the load coordinates when they are extracted from the load.

This is required because the coordinate operand to the fetch operation is a vector with the coordinates ordered from fastest-moving to slowest-moving dimension, whereas the load operation lists its coordinates from slowest-moving to fastest-moving. For example, if the memref.load operation loads from the coordinates [%z, %y, %x], then %x is the fastest-moving dimension, followed by %y and then %z, so the spirv.ImageFetch operation expects the vector <%x, %y, %z>.

@llvmbot
Member

llvmbot commented Sep 24, 2025

@llvm/pr-subscribers-mlir-spirv

@llvm/pr-subscribers-mlir

Author: Jack Frankland (FranklandJack)

Changes

When converting a memref.load from the image address space to a spirv.ImageFetch ensure that we reverse the load coordinates when they are extracted from the load.

This is required because the coordinate operand to the fetch operation is a vector with the coordinates ordered from fastest-moving to slowest-moving dimension, whereas the load operation lists its coordinates from slowest-moving to fastest-moving. For example, if the memref.load operation loads from the coordinates [%z, %y, %x], then %x is the fastest-moving dimension, followed by %y and then %z, so the spirv.ImageFetch operation expects the vector <%x, %y, %z>.


Patch is 23.82 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/160495.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp (+3-1)
  • (modified) mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir (+94-56)
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index 49d06497dbeea..aba56e5ca3f32 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -758,8 +758,10 @@ ImageLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
   if (memrefType.getRank() != 1) {
     auto coordVectorType = VectorType::get({loadOp.getMemRefType().getRank()},
                                            adaptor.getIndices().getType()[0]);
+    auto indices = llvm::to_vector(adaptor.getIndices());
+    auto indicesReversed = llvm::to_vector(llvm::reverse(indices));
     coords = spirv::CompositeConstructOp::create(rewriter, loc, coordVectorType,
-                                                 adaptor.getIndices());
+                                                 indicesReversed);
   } else {
     coords = adaptor.getIndices()[0];
   }
diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
index e6321e99693ac..56bf4939a0e63 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
@@ -550,121 +550,159 @@ module attributes {
   }
 
   // CHECK-LABEL: @load_from_image_2D(
-  // CHECK-SAME: %[[ARG0:.*]]: memref<1x1xf32, #spirv.storage_class<Image>>, %[[ARG1:.*]]: memref<1x1xf32, #spirv.storage_class<StorageBuffer>>
-  func.func @load_from_image_2D(%arg0: memref<1x1xf32, #spirv.storage_class<Image>>, %arg1: memref<1x1xf32, #spirv.storage_class<StorageBuffer>>) {
-// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<1x1xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>
-// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<1x1xf32, #spirv.storage_class<Image>> to !spirv.ptr<!spirv.sampled_image<!spirv.image<f32, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>, UniformConstant>
-    %cst = arith.constant 0 : index
+  // CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32, #spirv.storage_class<Image>>, %[[ARG1:.*]]: memref<2x2xf32, #spirv.storage_class<StorageBuffer>>
+  func.func @load_from_image_2D(%arg0: memref<2x2xf32, #spirv.storage_class<Image>>, %arg1: memref<2x2xf32, #spirv.storage_class<StorageBuffer>>) {
+// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<2x2xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<4 x f32, stride=4> [0])>, StorageBuffer>
+// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<2x2xf32, #spirv.storage_class<Image>> to !spirv.ptr<!spirv.sampled_image<!spirv.image<f32, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>, UniformConstant>
+    // CHECK: %[[X:.*]] = arith.constant 0 : index
+    // CHECK: %[[X32:.*]] = builtin.unrealized_conversion_cast %[[X]] : index to i32
+    %x = arith.constant 0 : index
+    // CHECK: %[[Y:.*]] = arith.constant 1 : index
+    // CHECK: %[[Y32:.*]] = builtin.unrealized_conversion_cast %[[Y]] : index to i32
+    %y = arith.constant 1 : index
     // CHECK: %[[SIMAGE:.*]] = spirv.Load "UniformConstant" %[[IMAGE_PTR]] : !spirv.sampled_image<!spirv.image<f32, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>
     // CHECK: %[[IMAGE:.*]] = spirv.Image %[[SIMAGE]] : !spirv.sampled_image<!spirv.image<f32, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>
-    // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %{{.*}}, %{{.*}} : (i32, i32) -> vector<2xi32>
+    // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %[[X32]], %[[Y32]] : (i32, i32) -> vector<2xi32>
     // CHECK: %[[RES_VEC:.*]] =  spirv.ImageFetch %[[IMAGE]], %[[COORDS]]  : !spirv.image<f32, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>, vector<2xi32> -> vector<4xf32>
     // CHECK: %[[RESULT:.*]] = spirv.CompositeExtract %[[RES_VEC]][0 : i32] : vector<4xf32>
-    %0 = memref.load %arg0[%cst, %cst] : memref<1x1xf32, #spirv.storage_class<Image>>
+    %0 = memref.load %arg0[%y, %x] : memref<2x2xf32, #spirv.storage_class<Image>>
     // CHECK: spirv.Store "StorageBuffer" %{{.*}}, %[[RESULT]] : f32
-    memref.store %0, %arg1[%cst, %cst] : memref<1x1xf32, #spirv.storage_class<StorageBuffer>>
+    memref.store %0, %arg1[%y, %x] : memref<2x2xf32, #spirv.storage_class<StorageBuffer>>
     return
   }
 
   // CHECK-LABEL: @load_from_image_3D(
-  // CHECK-SAME: %[[ARG0:.*]]: memref<1x1x1xf32, #spirv.storage_class<Image>>, %[[ARG1:.*]]: memref<1x1x1xf32, #spirv.storage_class<StorageBuffer>>
-  func.func @load_from_image_3D(%arg0: memref<1x1x1xf32, #spirv.storage_class<Image>>, %arg1: memref<1x1x1xf32, #spirv.storage_class<StorageBuffer>>) {
-// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<1x1x1xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>
-// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<1x1x1xf32, #spirv.storage_class<Image>> to !spirv.ptr<!spirv.sampled_image<!spirv.image<f32, Dim3D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>, UniformConstant>
-    %cst = arith.constant 0 : index
+  // CHECK-SAME: %[[ARG0:.*]]: memref<3x3x3xf32, #spirv.storage_class<Image>>, %[[ARG1:.*]]: memref<3x3x3xf32, #spirv.storage_class<StorageBuffer>>
+  func.func @load_from_image_3D(%arg0: memref<3x3x3xf32, #spirv.storage_class<Image>>, %arg1: memref<3x3x3xf32, #spirv.storage_class<StorageBuffer>>) {
+// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<3x3x3xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<27 x f32, stride=4> [0])>, StorageBuffer>
+// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<3x3x3xf32, #spirv.storage_class<Image>> to !spirv.ptr<!spirv.sampled_image<!spirv.image<f32, Dim3D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>, UniformConstant>
+    // CHECK: %[[X:.*]] = arith.constant 0 : index
+    // CHECK: %[[X32:.*]] = builtin.unrealized_conversion_cast %[[X]] : index to i32
+    %x = arith.constant 0 : index
+    // CHECK: %[[Y:.*]] = arith.constant 1 : index
+    // CHECK: %[[Y32:.*]] = builtin.unrealized_conversion_cast %[[Y]] : index to i32
+    %y = arith.constant 1 : index
+    // CHECK: %[[Z:.*]] = arith.constant 2 : index
+    // CHECK: %[[Z32:.*]] = builtin.unrealized_conversion_cast %[[Z]] : index to i32
+    %z = arith.constant 2 : index
     // CHECK: %[[SIMAGE:.*]] = spirv.Load "UniformConstant" %[[IMAGE_PTR]] : !spirv.sampled_image<!spirv.image<f32, Dim3D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>
     // CHECK: %[[IMAGE:.*]] = spirv.Image %[[SIMAGE]] : !spirv.sampled_image<!spirv.image<f32, Dim3D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>
-    // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %{{.*}}, %{{.*}}, %{{.*}} : (i32, i32, i32) -> vector<3xi32>
+    // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %[[X32]], %[[Y32]], %[[Z32]] : (i32, i32, i32) -> vector<3xi32>
     // CHECK: %[[RES_VEC:.*]] =  spirv.ImageFetch %[[IMAGE]], %[[COORDS]]  : !spirv.image<f32, Dim3D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>, vector<3xi32> -> vector<4xf32>
     // CHECK: %[[RESULT:.*]] = spirv.CompositeExtract %[[RES_VEC]][0 : i32] : vector<4xf32>
-    %0 = memref.load %arg0[%cst, %cst, %cst] : memref<1x1x1xf32, #spirv.storage_class<Image>>
+    %0 = memref.load %arg0[%z, %y, %x] : memref<3x3x3xf32, #spirv.storage_class<Image>>
     // CHECK: spirv.Store "StorageBuffer" %{{.*}}, %[[RESULT]] : f32
-    memref.store %0, %arg1[%cst, %cst, %cst] : memref<1x1x1xf32, #spirv.storage_class<StorageBuffer>>
+    memref.store %0, %arg1[%z, %y, %x] : memref<3x3x3xf32, #spirv.storage_class<StorageBuffer>>
     return
   }
 
   // CHECK-LABEL: @load_from_image_2D_f16(
-  // CHECK-SAME: %[[ARG0:.*]]: memref<1x1xf16, #spirv.storage_class<Image>>, %[[ARG1:.*]]: memref<1x1xf16, #spirv.storage_class<StorageBuffer>>
-  func.func @load_from_image_2D_f16(%arg0: memref<1x1xf16, #spirv.storage_class<Image>>, %arg1: memref<1x1xf16, #spirv.storage_class<StorageBuffer>>) {
-// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<1x1xf16, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<1 x f16, stride=2> [0])>, StorageBuffer>
-// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<1x1xf16, #spirv.storage_class<Image>> to !spirv.ptr<!spirv.sampled_image<!spirv.image<f16, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R16f>>, UniformConstant>
-    %cst = arith.constant 0 : index
+  // CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf16, #spirv.storage_class<Image>>, %[[ARG1:.*]]: memref<2x2xf16, #spirv.storage_class<StorageBuffer>>
+  func.func @load_from_image_2D_f16(%arg0: memref<2x2xf16, #spirv.storage_class<Image>>, %arg1: memref<2x2xf16, #spirv.storage_class<StorageBuffer>>) {
+// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<2x2xf16, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<4 x f16, stride=2> [0])>, StorageBuffer>
+// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<2x2xf16, #spirv.storage_class<Image>> to !spirv.ptr<!spirv.sampled_image<!spirv.image<f16, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R16f>>, UniformConstant>
+    // CHECK: %[[X:.*]] = arith.constant 0 : index
+    // CHECK: %[[X32:.*]] = builtin.unrealized_conversion_cast %[[X]] : index to i32
+    %x = arith.constant 0 : index
+    // CHECK: %[[Y:.*]] = arith.constant 1 : index
+    // CHECK: %[[Y32:.*]] = builtin.unrealized_conversion_cast %[[Y]] : index to i32
+    %y = arith.constant 1 : index
     // CHECK: %[[SIMAGE:.*]] = spirv.Load "UniformConstant" %[[IMAGE_PTR]] : !spirv.sampled_image<!spirv.image<f16, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R16f>>
     // CHECK: %[[IMAGE:.*]] = spirv.Image %[[SIMAGE]] : !spirv.sampled_image<!spirv.image<f16, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R16f>>
-    // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %{{.*}}, %{{.*}} : (i32, i32) -> vector<2xi32>
+    // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %[[X32]], %[[Y32]] : (i32, i32) -> vector<2xi32>
     // CHECK: %[[RES_VEC:.*]] =  spirv.ImageFetch %[[IMAGE]], %[[COORDS]]  : !spirv.image<f16, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R16f>, vector<2xi32> -> vector<4xf16>
     // CHECK: %[[RESULT:.*]] = spirv.CompositeExtract %[[RES_VEC]][0 : i32] : vector<4xf16>
-    %0 = memref.load %arg0[%cst, %cst] : memref<1x1xf16, #spirv.storage_class<Image>>
+    %0 = memref.load %arg0[%y, %x] : memref<2x2xf16, #spirv.storage_class<Image>>
     // CHECK: spirv.Store "StorageBuffer" %{{.*}}, %[[RESULT]] : f16
-    memref.store %0, %arg1[%cst, %cst] : memref<1x1xf16, #spirv.storage_class<StorageBuffer>>
+    memref.store %0, %arg1[%y, %x] : memref<2x2xf16, #spirv.storage_class<StorageBuffer>>
     return
   }
 
   // CHECK-LABEL: @load_from_image_2D_i32(
-  // CHECK-SAME: %[[ARG0:.*]]: memref<1x1xi32, #spirv.storage_class<Image>>, %[[ARG1:.*]]: memref<1x1xi32, #spirv.storage_class<StorageBuffer>>
-  func.func @load_from_image_2D_i32(%arg0: memref<1x1xi32, #spirv.storage_class<Image>>, %arg1: memref<1x1xi32, #spirv.storage_class<StorageBuffer>>) {
-// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<1x1xi32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0])>, StorageBuffer>
-// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<1x1xi32, #spirv.storage_class<Image>> to !spirv.ptr<!spirv.sampled_image<!spirv.image<i32, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32i>>, UniformConstant>
-    %cst = arith.constant 0 : index
+  // CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi32, #spirv.storage_class<Image>>, %[[ARG1:.*]]: memref<2x2xi32, #spirv.storage_class<StorageBuffer>>
+  func.func @load_from_image_2D_i32(%arg0: memref<2x2xi32, #spirv.storage_class<Image>>, %arg1: memref<2x2xi32, #spirv.storage_class<StorageBuffer>>) {
+// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<2x2xi32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<4 x i32, stride=4> [0])>, StorageBuffer>
+// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<2x2xi32, #spirv.storage_class<Image>> to !spirv.ptr<!spirv.sampled_image<!spirv.image<i32, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32i>>, UniformConstant>
+    // CHECK: %[[X:.*]] = arith.constant 0 : index
+    // CHECK: %[[X32:.*]] = builtin.unrealized_conversion_cast %[[X]] : index to i32
+    %x = arith.constant 0 : index
+    // CHECK: %[[Y:.*]] = arith.constant 1 : index
+    // CHECK: %[[Y32:.*]] = builtin.unrealized_conversion_cast %[[Y]] : index to i32
+    %y = arith.constant 1 : index
     // CHECK: %[[SIMAGE:.*]] = spirv.Load "UniformConstant" %[[IMAGE_PTR]] : !spirv.sampled_image<!spirv.image<i32, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32i>>
     // CHECK: %[[IMAGE:.*]] = spirv.Image %[[SIMAGE]] : !spirv.sampled_image<!spirv.image<i32, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32i>>
-    // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %{{.*}}, %{{.*}} : (i32, i32) -> vector<2xi32>
+    // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %[[X32]], %[[Y32]] : (i32, i32) -> vector<2xi32>
     // CHECK: %[[RES_VEC:.*]] =  spirv.ImageFetch %[[IMAGE]], %[[COORDS]]  : !spirv.image<i32, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32i>, vector<2xi32> -> vector<4xi32>
     // CHECK: %[[RESULT:.*]] = spirv.CompositeExtract %[[RES_VEC]][0 : i32] : vector<4xi32>
-    %0 = memref.load %arg0[%cst, %cst] : memref<1x1xi32, #spirv.storage_class<Image>>
+    %0 = memref.load %arg0[%y, %x] : memref<2x2xi32, #spirv.storage_class<Image>>
     // CHECK: spirv.Store "StorageBuffer" %{{.*}}, %[[RESULT]] : i32
-    memref.store %0, %arg1[%cst, %cst] : memref<1x1xi32, #spirv.storage_class<StorageBuffer>>
+    memref.store %0, %arg1[%y, %x] : memref<2x2xi32, #spirv.storage_class<StorageBuffer>>
     return
   }
 
   // CHECK-LABEL: @load_from_image_2D_ui32(
-  // CHECK-SAME: %[[ARG0:.*]]: memref<1x1xui32, #spirv.storage_class<Image>>, %[[ARG1:.*]]: memref<1x1xui32, #spirv.storage_class<StorageBuffer>>
-  func.func @load_from_image_2D_ui32(%arg0: memref<1x1xui32, #spirv.storage_class<Image>>, %arg1: memref<1x1xui32, #spirv.storage_class<StorageBuffer>>) {
-// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<1x1xui32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<1 x ui32, stride=4> [0])>, StorageBuffer>
-// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<1x1xui32, #spirv.storage_class<Image>> to !spirv.ptr<!spirv.sampled_image<!spirv.image<ui32, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32ui>>, UniformConstant>
-    %cst = arith.constant 0 : index
+  // CHECK-SAME: %[[ARG0:.*]]: memref<2x2xui32, #spirv.storage_class<Image>>, %[[ARG1:.*]]: memref<2x2xui32, #spirv.storage_class<StorageBuffer>>
+  func.func @load_from_image_2D_ui32(%arg0: memref<2x2xui32, #spirv.storage_class<Image>>, %arg1: memref<2x2xui32, #spirv.storage_class<StorageBuffer>>) {
+// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<2x2xui32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<4 x ui32, stride=4> [0])>, StorageBuffer>
+// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<2x2xui32, #spirv.storage_class<Image>> to !spirv.ptr<!spirv.sampled_image<!spirv.image<ui32, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32ui>>, UniformConstant>
+    // CHECK: %[[X:.*]] = arith.constant 0 : index
+    // CHECK: %[[X32:.*]] = builtin.unrealized_conversion_cast %[[X]] : index to i32
+    %x = arith.constant 0 : index
+    // CHECK: %[[Y:.*]] = arith.constant 1 : index
+    // CHECK: %[[Y32:.*]] = builtin.unrealized_conversion_cast %[[Y]] : index to i32
+    %y = arith.constant 1 : index
     // CHECK: %[[SIMAGE:.*]] = spirv.Load "UniformConstant" %[[IMAGE_PTR]] : !spirv.sampled_image<!spirv.image<ui32, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32ui>>
     // CHECK: %[[IMAGE:.*]] = spirv.Image %[[SIMAGE]] : !spirv.sampled_image<!spirv.image<ui32, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32ui>>
-    // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %{{.*}}, %{{.*}} : (i32, i32) -> vector<2xi32>
+    // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %[[X32]], %[[Y32]] : (i32, i32) -> vector<2xi32>
     // CHECK: %[[RES_VEC:.*]] =  spirv.ImageFetch %[[IMAGE]], %[[COORDS]]  : !spirv.image<ui32, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32ui>, vector<2xi32> -> vector<4xui32>
     // CHECK: %[[RESULT:.*]] = spirv.CompositeExtract %[[RES_VEC]][0 : i32] : vector<4xui32>
-    %0 = memref.load %arg0[%cst, %cst] : memref<1x1xui32, #spirv.storage_class<Image>>
+    %0 = memref.load %arg0[%y, %x] : memref<2x2xui32, #spirv.storage_class<Image>>
     // CHECK: spirv.Store "StorageBuffer" %{{.*}}, %[[RESULT]] : ui32
-    memref.store %0, %arg1[%cst, %cst] : memref<1x1xui32, #spirv.storage_class<StorageBuffer>>
+    memref.store %0, %arg1[%y, %x] : memref<2x2xui32, #spirv.storage_class<StorageBuffer>>
     return
   }
 
   // CHECK-LABEL: @load_from_image_2D_i16(
-  // CHECK-SAME: %[[ARG0:.*]]: memref<1x1xi16, #spirv.storage_class<Image>>, %[[ARG1:.*]]: memref<1x1xi16, #spirv.storage_class<StorageBuffer>>
-  func.func @load_from_image_2D_i16(%arg0: memref<1x1xi16, #spirv.storage_class<Image>>, %arg1: memref<1x1xi16, #spirv.storage_class<StorageBuffer>>) {
-// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<1x1xi16, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<1 x i16, stride=2> [0])>, StorageBuffer>
-// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<1x1xi16, #spirv.storage_class<Image>> to !spirv.ptr<!spirv.sampled_image<!spirv.image<i16, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R16i>>, UniformConstant>
-    %cst = arith.constant 0 : index
+  // CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi16, #spirv.storage_class<Image>>, %[[ARG1:.*]]: memref<2x2xi16, #spirv.storage_class<StorageBuffer>>
+  func.func @load_from_image_2D_i16(%arg0: memref<2x2xi16, #spirv.storage_class<Image>>, %arg1: memref<2x2xi16, #spirv.storage_class<StorageBuffer>>) {
+// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<2x2xi16, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<4 x i16, stride=2> [0])>, StorageBuffer>
+// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<2x2xi16, #spirv.storage_class<Image>> to !spirv.ptr<!spirv.sampled_image<!spirv.image<i16, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R16i>>, UniformConstant>
+    // CHECK: %[[X:.*]] = arith.constant 0 : index
+    // CHECK: %[[X32:.*]] = builtin.unrealized_conversion_cast %[[X]] : index to i32
+    %x = arith.constant 0 : index
+    // CHECK: %[[Y:.*]] = arith.constant 1 : index
+    // CHECK: %[[Y32:.*]] = builtin.unrealized_conversion_cast %[[Y]] : index to i32
+    %y = arith.constant 1 : index
     // CHECK: %[[SIMAGE:.*]] = spirv.Load "UniformConstant" %[[IMAGE_PTR]] : !spirv.sampled_image<!spirv.image<i16, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R16i>>
     // CHECK: %[[IMAGE:.*]] = spirv.Image %[[SIMAGE]] : !spirv.sampled_image<!spirv.image<i16, Dim2D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R16i>>
-    // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %{{.*}}, %{{.*}} : (i32, i32) -> vector<2xi32>
+    // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %[[X32]], %[[Y32]] : (i32, i32) -> vector<2xi32>
     // CHECK: %[[RES_VEC:.*]] =  spirv.ImageFetch %[[IMAGE]], %[[COORDS]]  : !spirv.image<i16, Dim2D, Dept...
[truncated]

Member

@kuhar kuhar left a comment


@IgWod-IMG could you review this one?

@kuhar kuhar requested a review from IgWod-IMG September 24, 2025 12:26
@IgWod-IMG
Contributor

I am a bit confused, but maybe a picture will help. Let's look at Figure 1 here. Is what you are trying to say that in memref the first dimension would be height and the second dimension would be width, but in Vulkan (as in the figure) the first dimension is width, whereas the second dimension is height? And since in both cases you want to access data row by row, we need to reverse the indices? So basically, without this patch the conversion silently changes iteration from row-major to column-major?

Am I understanding it right?

@FranklandJack
Contributor Author

I am a bit confused, but maybe a picture will help. Let's look at the Figure 1 in here. Is what you are trying to say that in memref the first dimension would be height, and the second dimension would be width, but in Vulkan (as in the figure) the first dimension is width, whereas the second dimension is height? And since in both cases you want to access data row by row, we need to reverse indices? So basically, without this patch the conversion silently changes iteration from row-major to column-major?

Am I understanding it right?

Yes, I think so.

The definition of OpImageFetch has the following wording for the coordinates argument:

Coordinate must be a scalar or vector of [integer type](https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#Integer). It contains (u[, v] …​ [, array layer]) as needed by the definition of Sampled Image.

I'm not entirely sure what u[, v] …​ [, array layer] refers to here but the Vulkan spec has for normalized coordinates:

    u: Coordinate in the first dimension of an image.
    v: Coordinate in the second dimension of an image.
    w: Coordinate in the third dimension of an image.
    a: Coordinate for array layer.

For a memref the innermost dimension is the first dimension (although maybe this isn't a safe assumption if the layout isn't row-major 🤷 ), so when you get the indices via a call to getIndices() it would seem you need to reverse them to get the coordinate ordering required by the SPIR-V op?

I may have misunderstood the SPIR-V/Vulkan specs here, so apologies if this isn't the case.

@IgWod-IMG
Contributor

I may have misunderstood the SPIR-V/Vulkan specs here, so apologies if this isn't the case.

I am not a SPIR-V/Vulkan expert either, but I agree with your assessment. I guess it's a bit tricky since Image does not really have a layout, just coordinates.

(although maybe this isn't a safe assumption if the layout isn't row-major 🤷 )

I think this is an important point: we only need to reverse indices for row-major memrefs, and column-major should just work as it is. And I guess any other layout should not be supported, unless the conversion can transform indices to match the logical Image coordinates.

So, my recommendation would be:

  1. Only do the reversal for row-major memrefs.
  2. Update the description with what we discussed so the reason behind the decision is clear; also, I would use i/u, j/v, k/w, instead of x, y, z.
  3. minor: Update tests to use non-square/cube images; having all dimensions the same in tests does not help the mental model.

@FranklandJack
Contributor Author

I may have misunderstood the SPIR-V/Vulkan specs here, so apologies if this isn't the case.

I am not a SPIR-V/Vulkan expert either, but I agree with your assessment. I guess it's a bit tricky since Image does not really have a layout, just coordinates.

(although maybe this isn't a safe assumption if the layout isn't row-major 🤷 )

I think this is an important point, we only need to reverse indices for row-major memrefs, and column-major should just work as it is. And I guess any other layout should not be supported - unless conversion can transform indices to match the logical Image coordinates.

So, my recommendation would be:

1. Only do the reversal for row-major `memref`s.

2. Update the description with what we discussed so the reason behind the decision is clear; also, I would use `i/u`, `j/v`, `k/w`, instead of `x`, `y`, `z`.

3. minor: Update tests to use non-square/cube images; having all dimensions the same in tests does not help the mental model.

Okay so, I started implementing these changes but thinking about this a bit more, does SPIR-V provide any mapping from texel coordinates to hardware memory, or are images and texels completely opaque?

In the case of MLIR's memrefs we have a mapping at the IR level, namely the layout given by the memref's type, which is assumed to be the identity if none is present (I guess a backend is free to remap the layout however it likes, but that is up to the target). In this conversion we are providing a generic lowering that assumes that, for images, width is the contiguous dimension in memory, followed by height and then depth. Whilst this might be a valid assumption for some hardware, is it true for all targets? Does this lowering even make sense as an upstream conversion?

@GeorgeARM GeorgeARM self-requested a review September 25, 2025 13:04
@IgWod-IMG
Contributor

Okay so, I started implementing these changes but thinking about this a bit more, does SPIR-V provide any mapping from texel coordinates to hardware memory, or are images and texels completely opaque?

My understanding is that in SPIR-V images are just handles and are opaque.

In this conversion we are providing a generic lowering that is assuming for images width is the contiguous dimension in memory, followed by height and then depth. Whilst this might be a valid assumption for some hardware, is this true for all targets?

I investigated image stuff a bit more and it seems that it is okay to assume you always have width, height and depth in that order. SPIR-V is mostly hardware agnostic, so it is up to the vendor to translate OpImageFetch coordinates into something that fits into the memory layout of the actual hardware, as such I don’t think we need to concern ourselves with the actual hardware layouts at this level.

Does this lowering as an upstream conversion even make sense?

I think it does, but we need to be careful with how we map a memref, which can represent many different layouts, to something that is rather opaque. At the same time, it seems to me that the alternative solution would be to leave the conversion as it is, and then it would be up to the driver/host to make sure the image is in the correct order so that indexing makes sense.

I think it would be useful to get an input on this patch from someone who has more experience with SPIR-V.

@FranklandJack
Contributor Author

I think it would be useful to get an input on this patch from someone who has more experience with SPIR-V.

Agreed, is there anyone from the MLIR community we could tag here to get input?

@GeorgeARM
Contributor

Hello, not a SPIR-V expert but here are my 2cents.

In general the SPIR-V image coordinate system is well-defined and independent of the memory layout itself, so the index reversal does make sense to reconcile the memref indexing order with the SPIR-V one from a logical perspective.

Note, though, that there is a caveat on the return type of the texel fetch: it will always have 4 elements, and depending on the format we might need to extract one, or expect the memref to be encoded in such a way as to capture this.
For formats like R32F this is straightforward (we always extract the element we want in the conversion), but for the likes of RGBA8, which are multi-channel, we probably need to expect memrefs with elements of vector<4xi8>.

Finally, you are right that the image layout is opaque in SPIR-V, and for good reason, but it shouldn't be conflated with memref layouts. We shouldn't try to retrofit the memref layout onto images; this could be quite tricky. Hence, to keep things safe and clean, we can initially require that memrefs are in the image address space (as you probably already do) and expect an identity layout. It is then up to the client/user to ensure that the image is created and populated correctly through the API.

I think this is a great start to getting basic image support, mainly in compute, and the approach seems quite reasonable.

When converting a `memref.load` from the image address space to a
`spirv.ImageFetch` ensure that we reverse the load coordinates when they
are extracted from the load.

This is required because the coordinate operand to the fetch operation
is a vector with the coordinates ordered from fastest-moving to
slowest-moving dimension, whereas the load operation lists its
coordinates from slowest-moving to fastest-moving. For example, if the
memref.load operation loads from the coordinates `[%z, %y, %x]`, then
`%x` is the fastest-moving dimension, followed by `%y` and then `%z`,
so the spirv.ImageFetch operation expects the vector `<%x, %y, %z>`.

Signed-off-by: Jack Frankland <jack.frankland@arm.com>
Reverse indices in place.

Signed-off-by: Jack Frankland <jack.frankland@arm.com>
* Generalize coordinate mapping to support arbitrary permutations.
* Add lit tests for permutations.
* Make memrefs in lit tests non square.

Signed-off-by: Jack Frankland <jack.frankland@arm.com>
Update comments to reflect support for linearly tiled images only.

Signed-off-by: Jack Frankland <jack.frankland@arm.com>
Make map variables globally scoped in lit tests.

Signed-off-by: Jack Frankland <jack.frankland@arm.com>
@FranklandJack
Contributor Author

@IgWod-IMG Would you be comfortable approving this PR? I think I've addressed your comments and I'll update the commit message when I squash the commits on merging :)

Member

@kuhar kuhar left a comment


The code looks OK % nits, but I haven't reviewed the logic

Contributor

@IgWod-IMG IgWod-IMG left a comment


Looks okay in general, just two more questions and I'm happy to approve.

* Remove `auto` usage
* Expand comment

Signed-off-by: Jack Frankland <jack.frankland@arm.com>
Contributor

@IgWod-IMG IgWod-IMG left a comment


The rationale makes sense so LGTM. Let's just wait for the clarification on the negative testing.

* Add negative test

Signed-off-by: Jack Frankland <jack.frankland@arm.com>
@FranklandJack FranklandJack merged commit 048922e into llvm:main Sep 30, 2025
9 checks passed
mahesh-attarde pushed a commit to mahesh-attarde/llvm-project that referenced this pull request Oct 3, 2025
When converting a `memref.load` from the image address space to a
`spirv.ImageFetch` ensure that we correctly map the load indices to
width, height and depth.

The lowering currently assumes a linear image tiling, that is, a
row-major memory layout. This allows us to support any memref layout
that is a permutation of the dimensions; more complex layouts are not
currently supported. Because the ordering of the dimensions in the
vector passed to the image fetch is the opposite of the ordering in
the memref, a final reversal of the mapped dimensions is always
required.

---------

Signed-off-by: Jack Frankland <jack.frankland@arm.com>