From 32b43c867a2d9aceddf102c18eafd44b7757fe78 Mon Sep 17 00:00:00 2001 From: Jack Frankland Date: Wed, 24 Sep 2025 11:13:59 +0100 Subject: [PATCH 1/7] [mlir][memref-to-spirv]: Reverse Image Load Coordinates When converting a `memref.load` from the image address space to a `spirv.ImageFetch` ensure that we reverse the load coordinates when they are extracted from the load. This is required because the coordinate operand to the fetch operation is a vector with the coordinates in increasing rank whereas the load operation has coordinates of decreaseing rank. For example, if the memref.load operation loaded from the coordinates `[%z, %y, %x]` then `%x` would be the fastest moving dimension followed by `%y` then `%z` so the spirv.ImageFetch operation would expect a vector `<%x, %y, %z>`. Signed-off-by: Jack Frankland --- .../MemRefToSPIRV/MemRefToSPIRV.cpp | 4 +- .../MemRefToSPIRV/memref-to-spirv.mlir | 150 +++++++++++------- 2 files changed, 97 insertions(+), 57 deletions(-) diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp index f44552c4556c2..da1c2a18414cc 100644 --- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp +++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp @@ -758,8 +758,10 @@ ImageLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor, if (memrefType.getRank() != 1) { auto coordVectorType = VectorType::get({loadOp.getMemRefType().getRank()}, adaptor.getIndices().getType()[0]); + auto indices = llvm::to_vector(adaptor.getIndices()); + auto indicesReversed = llvm::to_vector(llvm::reverse(indices)); coords = spirv::CompositeConstructOp::create(rewriter, loc, coordVectorType, - adaptor.getIndices()); + indicesReversed); } else { coords = adaptor.getIndices()[0]; } diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir index e6321e99693ac..56bf4939a0e63 100644 --- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir +++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir @@ -550,121 +550,159 @@ module attributes { } // CHECK-LABEL: @load_from_image_2D( - // CHECK-SAME: %[[ARG0:.*]]: memref<1x1xf32, #spirv.storage_class>, %[[ARG1:.*]]: memref<1x1xf32, #spirv.storage_class> - func.func @load_from_image_2D(%arg0: memref<1x1xf32, #spirv.storage_class>, %arg1: memref<1x1xf32, #spirv.storage_class>) { -// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<1x1xf32, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> -// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<1x1xf32, #spirv.storage_class> to !spirv.ptr>, UniformConstant> - %cst = arith.constant 0 : index + // CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32, #spirv.storage_class>, %[[ARG1:.*]]: memref<2x2xf32, #spirv.storage_class> + func.func @load_from_image_2D(%arg0: memref<2x2xf32, #spirv.storage_class>, %arg1: memref<2x2xf32, #spirv.storage_class>) { +// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<2x2xf32, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> +// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<2x2xf32, #spirv.storage_class> to !spirv.ptr>, UniformConstant> + // CHECK: %[[X:.*]] = arith.constant 0 : index + // CHECK: %[[X32:.*]] = builtin.unrealized_conversion_cast %[[X]] : index to i32 + %x = arith.constant 0 : index + // CHECK: %[[Y:.*]] = arith.constant 1 : index + // CHECK: %[[Y32:.*]] = builtin.unrealized_conversion_cast %[[Y]] : index to i32 + %y = arith.constant 1 : index // CHECK: %[[SIMAGE:.*]] = spirv.Load "UniformConstant" %[[IMAGE_PTR]] : !spirv.sampled_image> // CHECK: %[[IMAGE:.*]] = spirv.Image %[[SIMAGE]] : !spirv.sampled_image> - // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %{{.*}}, %{{.*}} : (i32, i32) -> vector<2xi32> + // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %[[X32]], %[[Y32]] : (i32, i32) -> vector<2xi32> // CHECK: %[[RES_VEC:.*]] = spirv.ImageFetch %[[IMAGE]], %[[COORDS]] : !spirv.image, vector<2xi32> -> vector<4xf32> // CHECK: %[[RESULT:.*]] = spirv.CompositeExtract %[[RES_VEC]][0 : i32] : vector<4xf32> - %0 = memref.load %arg0[%cst, %cst] : memref<1x1xf32, #spirv.storage_class> + %0 = memref.load %arg0[%y, %x] : memref<2x2xf32, #spirv.storage_class> // CHECK: spirv.Store "StorageBuffer" %{{.*}}, %[[RESULT]] : f32 - memref.store %0, %arg1[%cst, %cst] : memref<1x1xf32, #spirv.storage_class> + memref.store %0, %arg1[%y, %x] : memref<2x2xf32, #spirv.storage_class> return } // CHECK-LABEL: @load_from_image_3D( - // CHECK-SAME: %[[ARG0:.*]]: memref<1x1x1xf32, #spirv.storage_class>, %[[ARG1:.*]]: memref<1x1x1xf32, #spirv.storage_class> - func.func @load_from_image_3D(%arg0: memref<1x1x1xf32, #spirv.storage_class>, %arg1: memref<1x1x1xf32, #spirv.storage_class>) { -// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<1x1x1xf32, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> -// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<1x1x1xf32, #spirv.storage_class> to !spirv.ptr>, UniformConstant> - %cst = arith.constant 0 : index + // CHECK-SAME: %[[ARG0:.*]]: memref<3x3x3xf32, #spirv.storage_class>, %[[ARG1:.*]]: memref<3x3x3xf32, #spirv.storage_class> + func.func @load_from_image_3D(%arg0: memref<3x3x3xf32, #spirv.storage_class>, %arg1: memref<3x3x3xf32, #spirv.storage_class>) { +// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<3x3x3xf32, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> +// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<3x3x3xf32, #spirv.storage_class> to !spirv.ptr>, UniformConstant> + // CHECK: %[[X:.*]] = arith.constant 0 : index + // CHECK: %[[X32:.*]] = builtin.unrealized_conversion_cast %[[X]] : index to i32 + %x = arith.constant 0 : index + // CHECK: %[[Y:.*]] = arith.constant 1 : index + // CHECK: %[[Y32:.*]] = builtin.unrealized_conversion_cast %[[Y]] : index to i32 + %y = arith.constant 1 : index + // CHECK: %[[Z:.*]] = arith.constant 2 : index + // CHECK: %[[Z32:.*]] = builtin.unrealized_conversion_cast %[[Z]] : index to i32 + %z = arith.constant 2 : index // CHECK: %[[SIMAGE:.*]] = spirv.Load "UniformConstant" %[[IMAGE_PTR]] : !spirv.sampled_image> // CHECK: %[[IMAGE:.*]] = spirv.Image %[[SIMAGE]] : !spirv.sampled_image> - // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %{{.*}}, %{{.*}}, %{{.*}} : (i32, i32, i32) -> vector<3xi32> + // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %[[X32]], %[[Y32]], %[[Z32]] : (i32, i32, i32) -> vector<3xi32> // CHECK: %[[RES_VEC:.*]] = spirv.ImageFetch %[[IMAGE]], %[[COORDS]] : !spirv.image, vector<3xi32> -> vector<4xf32> // CHECK: %[[RESULT:.*]] = spirv.CompositeExtract %[[RES_VEC]][0 : i32] : vector<4xf32> - %0 = memref.load %arg0[%cst, %cst, %cst] : memref<1x1x1xf32, #spirv.storage_class> + %0 = memref.load %arg0[%z, %y, %x] : memref<3x3x3xf32, #spirv.storage_class> // CHECK: spirv.Store "StorageBuffer" %{{.*}}, %[[RESULT]] : f32 - memref.store %0, %arg1[%cst, %cst, %cst] : memref<1x1x1xf32, #spirv.storage_class> + memref.store %0, %arg1[%z, %y, %x] : memref<3x3x3xf32, #spirv.storage_class> return } // CHECK-LABEL: @load_from_image_2D_f16( - // CHECK-SAME: %[[ARG0:.*]]: memref<1x1xf16, #spirv.storage_class>, %[[ARG1:.*]]: memref<1x1xf16, #spirv.storage_class> - func.func @load_from_image_2D_f16(%arg0: memref<1x1xf16, #spirv.storage_class>, %arg1: memref<1x1xf16, #spirv.storage_class>) { -// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<1x1xf16, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> -// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<1x1xf16, #spirv.storage_class> to !spirv.ptr>, UniformConstant> - %cst = arith.constant 0 : index + // CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf16, #spirv.storage_class>, %[[ARG1:.*]]: memref<2x2xf16, #spirv.storage_class> + func.func @load_from_image_2D_f16(%arg0: memref<2x2xf16, #spirv.storage_class>, %arg1: memref<2x2xf16, #spirv.storage_class>) { +// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<2x2xf16, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> +// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<2x2xf16, #spirv.storage_class> to !spirv.ptr>, UniformConstant> + // CHECK: %[[X:.*]] = arith.constant 0 : index + // CHECK: %[[X32:.*]] = builtin.unrealized_conversion_cast %[[X]] : index to i32 + %x = arith.constant 0 : index + // CHECK: %[[Y:.*]] = arith.constant 1 : index + // CHECK: %[[Y32:.*]] = builtin.unrealized_conversion_cast %[[Y]] : index to i32 + %y = arith.constant 1 : index // CHECK: %[[SIMAGE:.*]] = spirv.Load "UniformConstant" %[[IMAGE_PTR]] : !spirv.sampled_image> // CHECK: %[[IMAGE:.*]] = spirv.Image %[[SIMAGE]] : !spirv.sampled_image> - // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %{{.*}}, %{{.*}} : (i32, i32) -> vector<2xi32> + // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %[[X32]], %[[Y32]] : (i32, i32) -> vector<2xi32> // CHECK: %[[RES_VEC:.*]] = spirv.ImageFetch %[[IMAGE]], %[[COORDS]] : !spirv.image, vector<2xi32> -> vector<4xf16> // CHECK: %[[RESULT:.*]] = spirv.CompositeExtract %[[RES_VEC]][0 : i32] : vector<4xf16> - %0 = memref.load %arg0[%cst, %cst] : memref<1x1xf16, #spirv.storage_class> + %0 = memref.load %arg0[%y, %x] : memref<2x2xf16, #spirv.storage_class> // CHECK: spirv.Store "StorageBuffer" %{{.*}}, %[[RESULT]] : f16 - memref.store %0, %arg1[%cst, %cst] : memref<1x1xf16, #spirv.storage_class> + memref.store %0, %arg1[%y, %x] : memref<2x2xf16, #spirv.storage_class> return } // CHECK-LABEL: @load_from_image_2D_i32( - // CHECK-SAME: %[[ARG0:.*]]: memref<1x1xi32, #spirv.storage_class>, %[[ARG1:.*]]: memref<1x1xi32, #spirv.storage_class> - func.func @load_from_image_2D_i32(%arg0: memref<1x1xi32, #spirv.storage_class>, %arg1: memref<1x1xi32, #spirv.storage_class>) { -// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<1x1xi32, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> -// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<1x1xi32, #spirv.storage_class> to !spirv.ptr>, UniformConstant> - %cst = arith.constant 0 : index + // CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi32, #spirv.storage_class>, %[[ARG1:.*]]: memref<2x2xi32, #spirv.storage_class> + func.func @load_from_image_2D_i32(%arg0: memref<2x2xi32, #spirv.storage_class>, %arg1: memref<2x2xi32, #spirv.storage_class>) { +// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<2x2xi32, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> +// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<2x2xi32, #spirv.storage_class> to !spirv.ptr>, UniformConstant> + // CHECK: %[[X:.*]] = arith.constant 0 : index + // CHECK: %[[X32:.*]] = builtin.unrealized_conversion_cast %[[X]] : index to i32 + %x = arith.constant 0 : index + // CHECK: %[[Y:.*]] = arith.constant 1 : index + // CHECK: %[[Y32:.*]] = builtin.unrealized_conversion_cast %[[Y]] : index to i32 + %y = arith.constant 1 : index // CHECK: %[[SIMAGE:.*]] = spirv.Load "UniformConstant" %[[IMAGE_PTR]] : !spirv.sampled_image> // CHECK: %[[IMAGE:.*]] = spirv.Image %[[SIMAGE]] : !spirv.sampled_image> - // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %{{.*}}, %{{.*}} : (i32, i32) -> vector<2xi32> + // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %[[X32]], %[[Y32]] : (i32, i32) -> vector<2xi32> // CHECK: %[[RES_VEC:.*]] = spirv.ImageFetch %[[IMAGE]], %[[COORDS]] : !spirv.image, vector<2xi32> -> vector<4xi32> // CHECK: %[[RESULT:.*]] = spirv.CompositeExtract %[[RES_VEC]][0 : i32] : vector<4xi32> - %0 = memref.load %arg0[%cst, %cst] : memref<1x1xi32, #spirv.storage_class> + %0 = memref.load %arg0[%y, %x] : memref<2x2xi32, #spirv.storage_class> // CHECK: spirv.Store "StorageBuffer" %{{.*}}, %[[RESULT]] : i32 - memref.store %0, %arg1[%cst, %cst] : memref<1x1xi32, #spirv.storage_class> + memref.store %0, %arg1[%y, %x] : memref<2x2xi32, #spirv.storage_class> return } // CHECK-LABEL: @load_from_image_2D_ui32( - // CHECK-SAME: %[[ARG0:.*]]: memref<1x1xui32, #spirv.storage_class>, %[[ARG1:.*]]: memref<1x1xui32, #spirv.storage_class> - func.func @load_from_image_2D_ui32(%arg0: memref<1x1xui32, #spirv.storage_class>, %arg1: memref<1x1xui32, #spirv.storage_class>) { -// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<1x1xui32, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> -// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<1x1xui32, #spirv.storage_class> to !spirv.ptr>, UniformConstant> - %cst = arith.constant 0 : index + // CHECK-SAME: %[[ARG0:.*]]: memref<2x2xui32, #spirv.storage_class>, %[[ARG1:.*]]: memref<2x2xui32, #spirv.storage_class> + func.func @load_from_image_2D_ui32(%arg0: memref<2x2xui32, #spirv.storage_class>, %arg1: memref<2x2xui32, #spirv.storage_class>) { +// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<2x2xui32, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> +// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<2x2xui32, #spirv.storage_class> to !spirv.ptr>, UniformConstant> + // CHECK: %[[X:.*]] = arith.constant 0 : index + // CHECK: %[[X32:.*]] = builtin.unrealized_conversion_cast %[[X]] : index to i32 + %x = arith.constant 0 : index + // CHECK: %[[Y:.*]] = arith.constant 1 : index + // CHECK: %[[Y32:.*]] = builtin.unrealized_conversion_cast %[[Y]] : index to i32 + %y = arith.constant 1 : index // CHECK: %[[SIMAGE:.*]] = spirv.Load "UniformConstant" %[[IMAGE_PTR]] : !spirv.sampled_image> // CHECK: %[[IMAGE:.*]] = spirv.Image %[[SIMAGE]] : !spirv.sampled_image> - // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %{{.*}}, %{{.*}} : (i32, i32) -> vector<2xi32> + // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %[[X32]], %[[Y32]] : (i32, i32) -> vector<2xi32> // CHECK: %[[RES_VEC:.*]] = spirv.ImageFetch %[[IMAGE]], %[[COORDS]] : !spirv.image, vector<2xi32> -> vector<4xui32> // CHECK: %[[RESULT:.*]] = spirv.CompositeExtract %[[RES_VEC]][0 : i32] : vector<4xui32> - %0 = memref.load %arg0[%cst, %cst] : memref<1x1xui32, #spirv.storage_class> + %0 = memref.load %arg0[%y, %x] : memref<2x2xui32, #spirv.storage_class> // CHECK: spirv.Store "StorageBuffer" %{{.*}}, %[[RESULT]] : ui32 - memref.store %0, %arg1[%cst, %cst] : memref<1x1xui32, #spirv.storage_class> + memref.store %0, %arg1[%y, %x] : memref<2x2xui32, #spirv.storage_class> return } // CHECK-LABEL: @load_from_image_2D_i16( - // CHECK-SAME: %[[ARG0:.*]]: memref<1x1xi16, #spirv.storage_class>, %[[ARG1:.*]]: memref<1x1xi16, #spirv.storage_class> - func.func @load_from_image_2D_i16(%arg0: memref<1x1xi16, #spirv.storage_class>, %arg1: memref<1x1xi16, #spirv.storage_class>) { -// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<1x1xi16, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> -// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<1x1xi16, #spirv.storage_class> to !spirv.ptr>, UniformConstant> - %cst = arith.constant 0 : index + // CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi16, #spirv.storage_class>, %[[ARG1:.*]]: memref<2x2xi16, #spirv.storage_class> + func.func @load_from_image_2D_i16(%arg0: memref<2x2xi16, #spirv.storage_class>, %arg1: memref<2x2xi16, #spirv.storage_class>) { +// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<2x2xi16, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> +// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<2x2xi16, #spirv.storage_class> to !spirv.ptr>, UniformConstant> + // CHECK: %[[X:.*]] = arith.constant 0 : index + // CHECK: %[[X32:.*]] = builtin.unrealized_conversion_cast %[[X]] : index to i32 + %x = arith.constant 0 : index + // CHECK: %[[Y:.*]] = arith.constant 1 : index + // CHECK: %[[Y32:.*]] = builtin.unrealized_conversion_cast %[[Y]] : index to i32 + %y = arith.constant 1 : index // CHECK: %[[SIMAGE:.*]] = spirv.Load "UniformConstant" %[[IMAGE_PTR]] : !spirv.sampled_image> // CHECK: %[[IMAGE:.*]] = spirv.Image %[[SIMAGE]] : !spirv.sampled_image> - // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %{{.*}}, %{{.*}} : (i32, i32) -> vector<2xi32> + // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %[[X32]], %[[Y32]] : (i32, i32) -> vector<2xi32> // CHECK: %[[RES_VEC:.*]] = spirv.ImageFetch %[[IMAGE]], %[[COORDS]] : !spirv.image, vector<2xi32> -> vector<4xi16> // CHECK: %[[RESULT:.*]] = spirv.CompositeExtract %[[RES_VEC]][0 : i32] : vector<4xi16> - %0 = memref.load %arg0[%cst, %cst] : memref<1x1xi16, #spirv.storage_class> + %0 = memref.load %arg0[%y, %x] : memref<2x2xi16, #spirv.storage_class> // CHECK: spirv.Store "StorageBuffer" %{{.*}}, %[[RESULT]] : i16 - memref.store %0, %arg1[%cst, %cst] : memref<1x1xi16, #spirv.storage_class> + memref.store %0, %arg1[%y, %x] : memref<2x2xi16, #spirv.storage_class> return } // CHECK-LABEL: @load_from_image_2D_ui16( - // CHECK-SAME: %[[ARG0:.*]]: memref<1x1xui16, #spirv.storage_class>, %[[ARG1:.*]]: memref<1x1xui16, #spirv.storage_class> - func.func @load_from_image_2D_ui16(%arg0: memref<1x1xui16, #spirv.storage_class>, %arg1: memref<1x1xui16, #spirv.storage_class>) { -// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<1x1xui16, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> -// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<1x1xui16, #spirv.storage_class> to !spirv.ptr>, UniformConstant> - %cst = arith.constant 0 : index + // CHECK-SAME: %[[ARG0:.*]]: memref<2x2xui16, #spirv.storage_class>, %[[ARG1:.*]]: memref<2x2xui16, #spirv.storage_class> + func.func @load_from_image_2D_ui16(%arg0: memref<2x2xui16, #spirv.storage_class>, %arg1: memref<2x2xui16, #spirv.storage_class>) { +// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<2x2xui16, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> +// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<2x2xui16, #spirv.storage_class> to !spirv.ptr>, UniformConstant> + // CHECK: %[[X:.*]] = arith.constant 0 : index + // CHECK: %[[X32:.*]] = builtin.unrealized_conversion_cast %[[X]] : index to i32 + %x = arith.constant 0 : index + // CHECK: %[[Y:.*]] = arith.constant 1 : index + // CHECK: %[[Y32:.*]] = builtin.unrealized_conversion_cast %[[Y]] : index to i32 + %y = arith.constant 1 : index // CHECK: %[[SIMAGE:.*]] = spirv.Load "UniformConstant" %[[IMAGE_PTR]] : !spirv.sampled_image> // CHECK: %[[IMAGE:.*]] = spirv.Image %[[SIMAGE]] : !spirv.sampled_image> - // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %{{.*}}, %{{.*}} : (i32, i32) -> vector<2xi32> + // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %[[X32]], %[[Y32]] : (i32, i32) -> vector<2xi32> // CHECK: %[[RES_VEC:.*]] = spirv.ImageFetch %[[IMAGE]], %[[COORDS]] : !spirv.image, vector<2xi32> -> vector<4xui16> // CHECK: %[[RESULT:.*]] = spirv.CompositeExtract %[[RES_VEC]][0 : i32] : vector<4xui16> - %0 = memref.load %arg0[%cst, %cst] : memref<1x1xui16, #spirv.storage_class> + %0 = memref.load %arg0[%y, %x] : memref<2x2xui16, #spirv.storage_class> // CHECK: spirv.Store "StorageBuffer" %{{.*}}, %[[RESULT]] : ui16 - memref.store %0, %arg1[%cst, %cst] : memref<1x1xui16, #spirv.storage_class> + memref.store %0, %arg1[%y, %x] : memref<2x2xui16, #spirv.storage_class> return } From dcc20d1545cff05617ece999848249c7d9600f29 Mon Sep 17 00:00:00 2001 From: Jack Frankland Date: Wed, 24 Sep 2025 14:52:47 +0100 Subject: [PATCH 2/7] Address PR Feedback Reverse indices in place. Signed-off-by: Jack Frankland --- mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp index da1c2a18414cc..1b66dab9ee985 100644 --- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp +++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp @@ -758,8 +758,7 @@ ImageLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor, if (memrefType.getRank() != 1) { auto coordVectorType = VectorType::get({loadOp.getMemRefType().getRank()}, adaptor.getIndices().getType()[0]); - auto indices = llvm::to_vector(adaptor.getIndices()); - auto indicesReversed = llvm::to_vector(llvm::reverse(indices)); + auto indicesReversed = llvm::to_vector(llvm::reverse(adaptor.getIndices())); coords = spirv::CompositeConstructOp::create(rewriter, loc, coordVectorType, indicesReversed); } else { From 7a396077a416f54a13fcc616361ea13584ebe9ad Mon Sep 17 00:00:00 2001 From: Jack Frankland Date: Thu, 25 Sep 2025 16:06:45 +0100 Subject: [PATCH 3/7] Address Feedback: * Generalize coordinate mapping to support arbitrary permutations. * Add lit tests for permutations. * Make memrefs in lit tests non square. Signed-off-by: Jack Frankland --- .../MemRefToSPIRV/MemRefToSPIRV.cpp | 43 ++++- .../MemRefToSPIRV/memref-to-spirv.mlir | 169 ++++++++++++------ 2 files changed, 148 insertions(+), 64 deletions(-) diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp index 1b66dab9ee985..0dc7f693a6cc8 100644 --- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp +++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp @@ -699,6 +699,37 @@ LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor, return success(); } +template +static FailureOr> +extractLoadCoordsForComposite(memref::LoadOp loadOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) { + // Texel coordinates are ordered from inner most to outer most dimension + // i.e. u, v, w, a where: + // u: Coordinate in the first dimension of an image. + // v: Coordinate in the second dimension of an image. + // w: Coordinate in the third dimension of an image. + // a: Coordinate for array layer. + // + // The memrefs layout determines the dimension ordering so we need to invert + // the map to get the ordering. + SmallVector indices = adaptor.getIndices(); + auto map = loadOp.getMemRefType().getLayout().getAffineMap(); + if (!map.isPermutation()) + return rewriter.notifyMatchFailure( + loadOp, + "Cannot lower memrefs with memory layout which is not a permutation"); + + const unsigned dimCount = map.getNumDims(); + SmallVector coords(dimCount); + for (unsigned dim = 0; dim < dimCount; ++dim) + coords[map.getDimPosition(dim)] = indices[dim]; + + // We need to do a final reversal since the image fetch op expects the first + // dimension in the 0th element position, 2nd dimension in the 1st element + // position etc. which is the opposite to the ordering in the map. + return llvm::to_vector(llvm::reverse(coords)); +} + LogicalResult ImageLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { @@ -755,14 +786,16 @@ ImageLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor, // Build a vector of coordinates or just a scalar index if we have a 1D image. Value coords; - if (memrefType.getRank() != 1) { + if (memrefType.getRank() == 1) { + coords = adaptor.getIndices()[0]; + } else { + auto maybeCoords = extractLoadCoordsForComposite(loadOp, adaptor, rewriter); + if (failed(maybeCoords)) + return failure(); auto coordVectorType = VectorType::get({loadOp.getMemRefType().getRank()}, adaptor.getIndices().getType()[0]); - auto indicesReversed = llvm::to_vector(llvm::reverse(adaptor.getIndices())); coords = spirv::CompositeConstructOp::create(rewriter, loc, coordVectorType, - indicesReversed); - } else { - coords = adaptor.getIndices()[0]; + maybeCoords.value()); } // Fetch the value out of the image. diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir index 56bf4939a0e63..0f7542cd9c469 100644 --- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir +++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir @@ -515,6 +515,10 @@ module attributes { // Check Image Support. +// CHECK: #[[COLMAJMAP:[a-z_]+]] = affine_map<(d0, d1) -> (d1, d0)> +#col_major = affine_map<(d0, d1) -> (d1, d0)> +// CHECK: #[[CUSTOMLAYOUTMAP:[a-z0-9_]+]] = affine_map<(d0, d1, d2) -> (d2, d1, d0)> +#custom = affine_map<(d0, d1, d2) -> (d2, d1, d0)> module attributes { spirv.target_env = #spirv.target_env<#spirv.vce>, %[[ARG1:.*]]: memref<2x2xf32, #spirv.storage_class> - func.func @load_from_image_2D(%arg0: memref<2x2xf32, #spirv.storage_class>, %arg1: memref<2x2xf32, #spirv.storage_class>) { -// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<2x2xf32, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> -// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<2x2xf32, #spirv.storage_class> to !spirv.ptr>, UniformConstant> - // CHECK: %[[X:.*]] = arith.constant 0 : index + // CHECK-SAME: %[[ARG0:.*]]: memref<2x4xf32, #spirv.storage_class>, %[[ARG1:.*]]: memref<2x4xf32, #spirv.storage_class> + func.func @load_from_image_2D(%arg0: memref<2x4xf32, #spirv.storage_class>, %arg1: memref<2x4xf32, #spirv.storage_class>) { +// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<2x4xf32, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> +// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<2x4xf32, #spirv.storage_class> to !spirv.ptr>, UniformConstant> + // CHECK: %[[X:.*]] = arith.constant 3 : index // CHECK: %[[X32:.*]] = builtin.unrealized_conversion_cast %[[X]] : index to i32 - %x = arith.constant 0 : index + %x = arith.constant 3 : index // CHECK: %[[Y:.*]] = arith.constant 1 : index // CHECK: %[[Y32:.*]] = builtin.unrealized_conversion_cast %[[Y]] : index to i32 %y = arith.constant 1 : index @@ -565,45 +569,92 @@ module attributes { // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %[[X32]], %[[Y32]] : (i32, i32) -> vector<2xi32> // CHECK: %[[RES_VEC:.*]] = spirv.ImageFetch %[[IMAGE]], %[[COORDS]] : !spirv.image, vector<2xi32> -> vector<4xf32> // CHECK: %[[RESULT:.*]] = spirv.CompositeExtract %[[RES_VEC]][0 : i32] : vector<4xf32> - %0 = memref.load %arg0[%y, %x] : memref<2x2xf32, #spirv.storage_class> + %0 = memref.load %arg0[%y, %x] : memref<2x4xf32, #spirv.storage_class> // CHECK: spirv.Store "StorageBuffer" %{{.*}}, %[[RESULT]] : f32 - memref.store %0, %arg1[%y, %x] : memref<2x2xf32, #spirv.storage_class> + memref.store %0, %arg1[%y, %x] : memref<2x4xf32, #spirv.storage_class> return } - // CHECK-LABEL: @load_from_image_3D( - // CHECK-SAME: %[[ARG0:.*]]: memref<3x3x3xf32, #spirv.storage_class>, %[[ARG1:.*]]: memref<3x3x3xf32, #spirv.storage_class> - func.func @load_from_image_3D(%arg0: memref<3x3x3xf32, #spirv.storage_class>, %arg1: memref<3x3x3xf32, #spirv.storage_class>) { -// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<3x3x3xf32, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> -// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<3x3x3xf32, #spirv.storage_class> to !spirv.ptr>, UniformConstant> - // CHECK: %[[X:.*]] = arith.constant 0 : index + // CHECK-LABEL: @load_from_col_major_image_2D( + // CHECK-SAME: %[[ARG0:.*]]: memref<2x4xf32, #[[COLMAJMAP]], #spirv.storage_class>, %[[ARG1:.*]]: memref<2x4xf32, #spirv.storage_class> + func.func @load_from_col_major_image_2D(%arg0: memref<2x4xf32, #col_major, #spirv.storage_class>, %arg1: memref<2x4xf32, #spirv.storage_class>) { +// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<2x4xf32, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> +// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<2x4xf32, #[[COLMAJMAP]], #spirv.storage_class> to !spirv.ptr>, UniformConstant> + // CHECK: %[[X:.*]] = arith.constant 3 : index // CHECK: %[[X32:.*]] = builtin.unrealized_conversion_cast %[[X]] : index to i32 - %x = arith.constant 0 : index + %x = arith.constant 3 : index // CHECK: %[[Y:.*]] = arith.constant 1 : index // CHECK: %[[Y32:.*]] = builtin.unrealized_conversion_cast %[[Y]] : index to i32 %y = arith.constant 1 : index - // CHECK: %[[Z:.*]] = arith.constant 2 : index + // CHECK: %[[SIMAGE:.*]] = spirv.Load "UniformConstant" %[[IMAGE_PTR]] : !spirv.sampled_image> + // CHECK: %[[IMAGE:.*]] = spirv.Image %[[SIMAGE]] : !spirv.sampled_image> + // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %[[X32]], %[[Y32]] : (i32, i32) -> vector<2xi32> + // CHECK: %[[RES_VEC:.*]] = spirv.ImageFetch %[[IMAGE]], %[[COORDS]] : !spirv.image, vector<2xi32> -> vector<4xf32> + // CHECK: %[[RESULT:.*]] = spirv.CompositeExtract %[[RES_VEC]][0 : i32] : vector<4xf32> + %0 = memref.load %arg0[%x, %y] : memref<2x4xf32, #col_major, #spirv.storage_class> + // CHECK: spirv.Store "StorageBuffer" %{{.*}}, %[[RESULT]] : f32 + memref.store %0, %arg1[%y, %x] : memref<2x4xf32, #spirv.storage_class> + return + } + + // CHECK-LABEL: @load_from_image_3D( + // CHECK-SAME: %[[ARG0:.*]]: memref<2x3x4xf32, #spirv.storage_class>, %[[ARG1:.*]]: memref<2x3x4xf32, #spirv.storage_class> + func.func @load_from_image_3D(%arg0: memref<2x3x4xf32, #spirv.storage_class>, %arg1: memref<2x3x4xf32, #spirv.storage_class>) { +// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<2x3x4xf32, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> +// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<2x3x4xf32, #spirv.storage_class> to !spirv.ptr>, UniformConstant> + // CHECK: %[[X:.*]] = arith.constant 3 : index + // CHECK: %[[X32:.*]] = builtin.unrealized_conversion_cast %[[X]] : index to i32 + %x = arith.constant 3 : index + // CHECK: %[[Y:.*]] = arith.constant 2 : index + // CHECK: %[[Y32:.*]] = builtin.unrealized_conversion_cast %[[Y]] : index to i32 + %y = arith.constant 2 : index + // CHECK: %[[Z:.*]] = arith.constant 1 : index + // CHECK: %[[Z32:.*]] = builtin.unrealized_conversion_cast %[[Z]] : index to i32 + %z = arith.constant 1 : index + // CHECK: %[[SIMAGE:.*]] = spirv.Load "UniformConstant" %[[IMAGE_PTR]] : !spirv.sampled_image> + // CHECK: %[[IMAGE:.*]] = spirv.Image %[[SIMAGE]] : !spirv.sampled_image> + // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %[[X32]], %[[Y32]], %[[Z32]] : (i32, i32, i32) -> vector<3xi32> + // CHECK: %[[RES_VEC:.*]] = spirv.ImageFetch %[[IMAGE]], %[[COORDS]] : !spirv.image, vector<3xi32> -> vector<4xf32> + // CHECK: %[[RESULT:.*]] = spirv.CompositeExtract %[[RES_VEC]][0 : i32] : vector<4xf32> + %0 = memref.load %arg0[%z, %y, %x] : memref<2x3x4xf32, #spirv.storage_class> + // CHECK: spirv.Store "StorageBuffer" %{{.*}}, %[[RESULT]] : f32 + memref.store %0, %arg1[%z, %y, %x] : memref<2x3x4xf32, #spirv.storage_class> + return + } + + // CHECK-LABEL: @load_from_custom_layout_image_3D( + // CHECK-SAME: %[[ARG0:.*]]: memref<2x3x4xf32, #[[CUSTOMLAYOUTMAP]], #spirv.storage_class>, %[[ARG1:.*]]: memref<2x3x4xf32, #spirv.storage_class> + func.func @load_from_custom_layout_image_3D(%arg0: memref<2x3x4xf32, #custom, #spirv.storage_class>, %arg1: memref<2x3x4xf32, #spirv.storage_class>) { +// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<2x3x4xf32, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> +// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<2x3x4xf32, #[[CUSTOMLAYOUTMAP]], #spirv.storage_class> to !spirv.ptr>, UniformConstant> + // CHECK: %[[X:.*]] = arith.constant 3 : index + // CHECK: %[[X32:.*]] = builtin.unrealized_conversion_cast %[[X]] : index to i32 + %x = arith.constant 3 : index + // CHECK: %[[Y:.*]] = arith.constant 2 : index + // CHECK: %[[Y32:.*]] = builtin.unrealized_conversion_cast %[[Y]] : index to i32 + %y = arith.constant 2 : index + // CHECK: %[[Z:.*]] = arith.constant 1 : index // CHECK: %[[Z32:.*]] = builtin.unrealized_conversion_cast %[[Z]] : index to i32 - %z = arith.constant 2 : index + %z = arith.constant 1 : index // CHECK: %[[SIMAGE:.*]] = spirv.Load "UniformConstant" %[[IMAGE_PTR]] : !spirv.sampled_image> // CHECK: %[[IMAGE:.*]] = spirv.Image %[[SIMAGE]] : !spirv.sampled_image> // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %[[X32]], %[[Y32]], %[[Z32]] : (i32, i32, i32) -> vector<3xi32> // CHECK: %[[RES_VEC:.*]] = spirv.ImageFetch %[[IMAGE]], %[[COORDS]] : !spirv.image, vector<3xi32> -> vector<4xf32> // CHECK: %[[RESULT:.*]] = spirv.CompositeExtract %[[RES_VEC]][0 : i32] : vector<4xf32> - %0 = memref.load %arg0[%z, %y, %x] : memref<3x3x3xf32, #spirv.storage_class> + %0 = memref.load %arg0[%x, %y, %z] : memref<2x3x4xf32, #custom, #spirv.storage_class> // CHECK: spirv.Store "StorageBuffer" %{{.*}}, %[[RESULT]] : f32 - memref.store %0, %arg1[%z, %y, %x] : memref<3x3x3xf32, #spirv.storage_class> + memref.store %0, %arg1[%z, %y, %x] : memref<2x3x4xf32, #spirv.storage_class> return } // CHECK-LABEL: @load_from_image_2D_f16( - // CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf16, #spirv.storage_class>, %[[ARG1:.*]]: memref<2x2xf16, #spirv.storage_class> - func.func @load_from_image_2D_f16(%arg0: memref<2x2xf16, #spirv.storage_class>, %arg1: memref<2x2xf16, #spirv.storage_class>) { -// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<2x2xf16, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> -// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<2x2xf16, #spirv.storage_class> to !spirv.ptr>, UniformConstant> - // CHECK: %[[X:.*]] = arith.constant 0 : index + // CHECK-SAME: %[[ARG0:.*]]: memref<2x3xf16, #spirv.storage_class>, %[[ARG1:.*]]: memref<2x3xf16, #spirv.storage_class> + func.func @load_from_image_2D_f16(%arg0: memref<2x3xf16, #spirv.storage_class>, %arg1: memref<2x3xf16, #spirv.storage_class>) { +// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<2x3xf16, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> +// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<2x3xf16, #spirv.storage_class> to !spirv.ptr>, UniformConstant> + // CHECK: %[[X:.*]] = arith.constant 2 : index // CHECK: %[[X32:.*]] = builtin.unrealized_conversion_cast %[[X]] : index to i32 - %x = arith.constant 0 : index + %x = arith.constant 2 : index // CHECK: %[[Y:.*]] = arith.constant 1 : index // CHECK: %[[Y32:.*]] = builtin.unrealized_conversion_cast %[[Y]] : index to i32 %y = arith.constant 1 : index @@ -612,20 +663,20 @@ module attributes { // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %[[X32]], %[[Y32]] : (i32, i32) -> vector<2xi32> // CHECK: %[[RES_VEC:.*]] = spirv.ImageFetch %[[IMAGE]], %[[COORDS]] : !spirv.image, vector<2xi32> -> vector<4xf16> // CHECK: %[[RESULT:.*]] = spirv.CompositeExtract %[[RES_VEC]][0 : i32] : vector<4xf16> - %0 = memref.load %arg0[%y, %x] : memref<2x2xf16, #spirv.storage_class> + %0 = memref.load %arg0[%y, %x] : memref<2x3xf16, #spirv.storage_class> // CHECK: spirv.Store "StorageBuffer" %{{.*}}, %[[RESULT]] : f16 - memref.store %0, %arg1[%y, %x] : memref<2x2xf16, #spirv.storage_class> + memref.store %0, %arg1[%y, %x] : memref<2x3xf16, #spirv.storage_class> return } // CHECK-LABEL: @load_from_image_2D_i32( - // CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi32, #spirv.storage_class>, %[[ARG1:.*]]: memref<2x2xi32, #spirv.storage_class> - func.func @load_from_image_2D_i32(%arg0: memref<2x2xi32, #spirv.storage_class>, %arg1: memref<2x2xi32, #spirv.storage_class>) { -// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<2x2xi32, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> -// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<2x2xi32, #spirv.storage_class> to !spirv.ptr>, UniformConstant> - // CHECK: %[[X:.*]] = arith.constant 0 : index + // CHECK-SAME: %[[ARG0:.*]]: memref<2x3xi32, #spirv.storage_class>, %[[ARG1:.*]]: memref<2x3xi32, #spirv.storage_class> + func.func @load_from_image_2D_i32(%arg0: memref<2x3xi32, #spirv.storage_class>, %arg1: memref<2x3xi32, #spirv.storage_class>) { +// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<2x3xi32, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> +// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<2x3xi32, #spirv.storage_class> to !spirv.ptr>, UniformConstant> + // CHECK: %[[X:.*]] = arith.constant 2 : index // CHECK: %[[X32:.*]] = builtin.unrealized_conversion_cast %[[X]] : index to i32 - %x = arith.constant 0 : index + %x = arith.constant 2 : index // CHECK: %[[Y:.*]] = arith.constant 1 : index // CHECK: %[[Y32:.*]] = builtin.unrealized_conversion_cast %[[Y]] : index to i32 %y = arith.constant 1 : index @@ -634,20 +685,20 @@ module attributes { // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %[[X32]], %[[Y32]] : (i32, i32) -> vector<2xi32> // CHECK: %[[RES_VEC:.*]] = spirv.ImageFetch %[[IMAGE]], %[[COORDS]] : !spirv.image, vector<2xi32> -> vector<4xi32> // CHECK: %[[RESULT:.*]] = spirv.CompositeExtract %[[RES_VEC]][0 : i32] : vector<4xi32> - %0 = memref.load %arg0[%y, %x] : memref<2x2xi32, #spirv.storage_class> + %0 = memref.load %arg0[%y, %x] : memref<2x3xi32, #spirv.storage_class> // CHECK: spirv.Store "StorageBuffer" %{{.*}}, %[[RESULT]] : i32 - memref.store %0, %arg1[%y, %x] : memref<2x2xi32, #spirv.storage_class> + memref.store %0, %arg1[%y, %x] : memref<2x3xi32, #spirv.storage_class> return } // CHECK-LABEL: @load_from_image_2D_ui32( - // CHECK-SAME: %[[ARG0:.*]]: memref<2x2xui32, #spirv.storage_class>, %[[ARG1:.*]]: memref<2x2xui32, #spirv.storage_class> - func.func @load_from_image_2D_ui32(%arg0: memref<2x2xui32, #spirv.storage_class>, %arg1: memref<2x2xui32, #spirv.storage_class>) { -// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<2x2xui32, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> -// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<2x2xui32, #spirv.storage_class> to !spirv.ptr>, UniformConstant> - // CHECK: %[[X:.*]] = arith.constant 0 : index + // CHECK-SAME: %[[ARG0:.*]]: memref<2x3xui32, #spirv.storage_class>, %[[ARG1:.*]]: memref<2x3xui32, #spirv.storage_class> + func.func @load_from_image_2D_ui32(%arg0: memref<2x3xui32, #spirv.storage_class>, %arg1: memref<2x3xui32, #spirv.storage_class>) { +// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<2x3xui32, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> +// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<2x3xui32, #spirv.storage_class> to !spirv.ptr>, UniformConstant> + // CHECK: %[[X:.*]] = arith.constant 2 : index // CHECK: %[[X32:.*]] = builtin.unrealized_conversion_cast %[[X]] : index to i32 - %x = arith.constant 0 : index + %x = arith.constant 2 : index // CHECK: %[[Y:.*]] = arith.constant 1 : index // CHECK: %[[Y32:.*]] = builtin.unrealized_conversion_cast %[[Y]] : index to i32 %y = arith.constant 1 : index @@ -656,20 +707,20 @@ module attributes { // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %[[X32]], %[[Y32]] : (i32, i32) -> vector<2xi32> // CHECK: %[[RES_VEC:.*]] = spirv.ImageFetch %[[IMAGE]], %[[COORDS]] : !spirv.image, vector<2xi32> -> vector<4xui32> // CHECK: %[[RESULT:.*]] = spirv.CompositeExtract %[[RES_VEC]][0 : i32] : vector<4xui32> - %0 = memref.load %arg0[%y, %x] : memref<2x2xui32, #spirv.storage_class> + %0 = memref.load %arg0[%y, %x] : memref<2x3xui32, #spirv.storage_class> // CHECK: spirv.Store "StorageBuffer" %{{.*}}, %[[RESULT]] : ui32 - memref.store %0, %arg1[%y, %x] : memref<2x2xui32, #spirv.storage_class> + memref.store %0, %arg1[%y, %x] : memref<2x3xui32, #spirv.storage_class> return } // CHECK-LABEL: @load_from_image_2D_i16( - // CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi16, #spirv.storage_class>, %[[ARG1:.*]]: memref<2x2xi16, #spirv.storage_class> - func.func @load_from_image_2D_i16(%arg0: memref<2x2xi16, #spirv.storage_class>, %arg1: memref<2x2xi16, #spirv.storage_class>) { -// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<2x2xi16, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> -// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<2x2xi16, #spirv.storage_class> to !spirv.ptr>, UniformConstant> - // CHECK: %[[X:.*]] = arith.constant 0 : index + // CHECK-SAME: %[[ARG0:.*]]: memref<2x3xi16, #spirv.storage_class>, %[[ARG1:.*]]: memref<2x3xi16, #spirv.storage_class> + func.func @load_from_image_2D_i16(%arg0: memref<2x3xi16, #spirv.storage_class>, %arg1: memref<2x3xi16, #spirv.storage_class>) { +// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<2x3xi16, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> +// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<2x3xi16, #spirv.storage_class> to !spirv.ptr>, UniformConstant> + // CHECK: %[[X:.*]] = arith.constant 2 : index // CHECK: %[[X32:.*]] = builtin.unrealized_conversion_cast %[[X]] : index to i32 - %x = arith.constant 0 : index + %x = arith.constant 2 : index // CHECK: %[[Y:.*]] = arith.constant 1 : index // CHECK: %[[Y32:.*]] = builtin.unrealized_conversion_cast %[[Y]] : index to i32 %y = arith.constant 1 : index @@ -678,20 +729,20 @@ module attributes { // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %[[X32]], %[[Y32]] : (i32, i32) -> vector<2xi32> // CHECK: %[[RES_VEC:.*]] = spirv.ImageFetch %[[IMAGE]], %[[COORDS]] : !spirv.image, vector<2xi32> -> vector<4xi16> // CHECK: %[[RESULT:.*]] = spirv.CompositeExtract %[[RES_VEC]][0 : i32] : vector<4xi16> - %0 = memref.load %arg0[%y, %x] : memref<2x2xi16, #spirv.storage_class> + %0 = memref.load %arg0[%y, %x] : memref<2x3xi16, #spirv.storage_class> // CHECK: spirv.Store "StorageBuffer" %{{.*}}, %[[RESULT]] : i16 - memref.store %0, %arg1[%y, %x] : memref<2x2xi16, #spirv.storage_class> + memref.store %0, %arg1[%y, %x] : memref<2x3xi16, #spirv.storage_class> return } // CHECK-LABEL: @load_from_image_2D_ui16( - // CHECK-SAME: %[[ARG0:.*]]: memref<2x2xui16, #spirv.storage_class>, %[[ARG1:.*]]: memref<2x2xui16, #spirv.storage_class> - func.func @load_from_image_2D_ui16(%arg0: memref<2x2xui16, #spirv.storage_class>, %arg1: memref<2x2xui16, #spirv.storage_class>) { -// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<2x2xui16, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> -// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<2x2xui16, #spirv.storage_class> to !spirv.ptr>, UniformConstant> - // CHECK: %[[X:.*]] = arith.constant 0 : index + // CHECK-SAME: %[[ARG0:.*]]: memref<2x3xui16, #spirv.storage_class>, %[[ARG1:.*]]: memref<2x3xui16, #spirv.storage_class> + func.func @load_from_image_2D_ui16(%arg0: memref<2x3xui16, #spirv.storage_class>, %arg1: memref<2x3xui16, #spirv.storage_class>) { +// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<2x3xui16, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> +// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<2x3xui16, #spirv.storage_class> to !spirv.ptr>, UniformConstant> + // CHECK: %[[X:.*]] = arith.constant 2 : index // CHECK: %[[X32:.*]] = builtin.unrealized_conversion_cast %[[X]] : index to i32 - %x = arith.constant 0 : index + %x = arith.constant 2 : index // CHECK: %[[Y:.*]] = arith.constant 1 : index // CHECK: %[[Y32:.*]] = builtin.unrealized_conversion_cast %[[Y]] : index to i32 %y = arith.constant 1 : index @@ -700,9 +751,9 @@ module attributes { // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %[[X32]], %[[Y32]] : (i32, i32) -> vector<2xi32> // CHECK: %[[RES_VEC:.*]] = spirv.ImageFetch %[[IMAGE]], %[[COORDS]] : !spirv.image, vector<2xi32> -> vector<4xui16> // CHECK: %[[RESULT:.*]] = spirv.CompositeExtract %[[RES_VEC]][0 : i32] : vector<4xui16> - %0 = memref.load %arg0[%y, %x] : memref<2x2xui16, #spirv.storage_class> + %0 = memref.load %arg0[%y, %x] : memref<2x3xui16, #spirv.storage_class> // CHECK: spirv.Store "StorageBuffer" %{{.*}}, %[[RESULT]] : ui16 - memref.store %0, %arg1[%y, %x] : memref<2x2xui16, #spirv.storage_class> + memref.store %0, %arg1[%y, %x] : memref<2x3xui16, #spirv.storage_class> return } From ae1c1303370b3c988347cb50321733f3d40807a6 Mon Sep 17 00:00:00 2001 From: Jack Frankland Date: Mon, 29 Sep 2025 10:44:32 +0100 Subject: [PATCH 4/7] Address Feedback Update comments to reflect support for linearly tiled images only. Signed-off-by: Jack Frankland --- mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp index 0dc7f693a6cc8..0f56c11748da9 100644 --- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp +++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp @@ -703,12 +703,11 @@ template static FailureOr> extractLoadCoordsForComposite(memref::LoadOp loadOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) { - // Texel coordinates are ordered from inner most to outer most dimension - // i.e. u, v, w, a where: - // u: Coordinate in the first dimension of an image. - // v: Coordinate in the second dimension of an image. - // w: Coordinate in the third dimension of an image. - // a: Coordinate for array layer. + // At present we only support linear "tiling" as specified in Vulkan, this + // means that texels are assumed to be laid out in memory in a row-major + // order. This allows us to support any memref layout that is a permutation of + // the dimensions. Future work will pass an optional image layout to the + // rewrite pattern so that we can support optimized target specific tilings. // // The memrefs layout determines the dimension ordering so we need to invert // the map to get the ordering. From a6e360ef79c0d8bb877a1f3ecd8f14a036eadf1c Mon Sep 17 00:00:00 2001 From: Jack Frankland Date: Mon, 29 Sep 2025 11:47:02 +0100 Subject: [PATCH 5/7] Address Feedback Make map variables globally scoped in lit tests. Signed-off-by: Jack Frankland --- .../MemRefToSPIRV/memref-to-spirv.mlir | 48 +++++++++---------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir index 0f7542cd9c469..1c95ec2389ed1 100644 --- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir +++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir @@ -515,9 +515,9 @@ module attributes { // Check Image Support. -// CHECK: #[[COLMAJMAP:[a-z_]+]] = affine_map<(d0, d1) -> (d1, d0)> +// CHECK: #[[$COLMAJMAP:.*]] = affine_map<(d0, d1) -> (d1, d0)> #col_major = affine_map<(d0, d1) -> (d1, d0)> -// CHECK: #[[CUSTOMLAYOUTMAP:[a-z0-9_]+]] = affine_map<(d0, d1, d2) -> (d2, d1, d0)> +// CHECK: #[[$CUSTOMLAYOUTMAP:.*]] = affine_map<(d0, d1, d2) -> (d2, d1, d0)> #custom = affine_map<(d0, d1, d2) -> (d2, d1, d0)> module attributes { spirv.target_env = #spirv.target_env<#spirv.vce>, %[[ARG1:.*]]: memref<1xf32, #spirv.storage_class> func.func @load_from_image_1D(%arg0: memref<1xf32, #spirv.storage_class>, %arg1: memref<1xf32, #spirv.storage_class>) { -// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<1xf32, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> -// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<1xf32, #spirv.storage_class> to !spirv.ptr>, UniformConstant> +// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %[[ARG1]] : memref<1xf32, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> +// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<1xf32, #spirv.storage_class> to !spirv.ptr>, UniformConstant> %cst = arith.constant 0 : index // CHECK: %[[COORDS:.*]] = builtin.unrealized_conversion_cast %{{.*}} : index to i32 // CHECK: %[[SIMAGE:.*]] = spirv.Load "UniformConstant" %[[IMAGE_PTR]] : !spirv.sampled_image> @@ -556,8 +556,8 @@ module attributes { // CHECK-LABEL: @load_from_image_2D( // CHECK-SAME: %[[ARG0:.*]]: memref<2x4xf32, #spirv.storage_class>, %[[ARG1:.*]]: memref<2x4xf32, #spirv.storage_class> func.func @load_from_image_2D(%arg0: memref<2x4xf32, #spirv.storage_class>, %arg1: memref<2x4xf32, #spirv.storage_class>) { -// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<2x4xf32, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> -// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<2x4xf32, #spirv.storage_class> to !spirv.ptr>, UniformConstant> +// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %[[ARG1]] : memref<2x4xf32, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> +// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<2x4xf32, #spirv.storage_class> to !spirv.ptr>, UniformConstant> // CHECK: %[[X:.*]] = arith.constant 3 : index // CHECK: %[[X32:.*]] = builtin.unrealized_conversion_cast %[[X]] : index to i32 %x = arith.constant 3 : index @@ -576,10 +576,10 @@ module attributes { } // CHECK-LABEL: @load_from_col_major_image_2D( - // CHECK-SAME: %[[ARG0:.*]]: memref<2x4xf32, #[[COLMAJMAP]], #spirv.storage_class>, %[[ARG1:.*]]: memref<2x4xf32, #spirv.storage_class> + // CHECK-SAME: %[[ARG0:.*]]: memref<2x4xf32, #[[$COLMAJMAP]], #spirv.storage_class>, %[[ARG1:.*]]: memref<2x4xf32, #spirv.storage_class> func.func @load_from_col_major_image_2D(%arg0: memref<2x4xf32, #col_major, #spirv.storage_class>, %arg1: memref<2x4xf32, #spirv.storage_class>) { -// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<2x4xf32, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> -// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<2x4xf32, #[[COLMAJMAP]], #spirv.storage_class> to !spirv.ptr>, UniformConstant> +// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %[[ARG1]] : memref<2x4xf32, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> +// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<2x4xf32, #[[$COLMAJMAP]], #spirv.storage_class> to !spirv.ptr>, UniformConstant> // CHECK: %[[X:.*]] = arith.constant 3 : index // CHECK: %[[X32:.*]] = builtin.unrealized_conversion_cast %[[X]] : index to i32 %x = arith.constant 3 : index @@ -600,8 +600,8 @@ module attributes { // CHECK-LABEL: @load_from_image_3D( // CHECK-SAME: %[[ARG0:.*]]: memref<2x3x4xf32, #spirv.storage_class>, %[[ARG1:.*]]: memref<2x3x4xf32, #spirv.storage_class> func.func @load_from_image_3D(%arg0: memref<2x3x4xf32, #spirv.storage_class>, %arg1: memref<2x3x4xf32, #spirv.storage_class>) { -// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<2x3x4xf32, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> -// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<2x3x4xf32, #spirv.storage_class> to !spirv.ptr>, UniformConstant> +// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %[[ARG1]] : memref<2x3x4xf32, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> +// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<2x3x4xf32, #spirv.storage_class> to !spirv.ptr>, UniformConstant> // CHECK: %[[X:.*]] = arith.constant 3 : index // CHECK: %[[X32:.*]] = builtin.unrealized_conversion_cast %[[X]] : index to i32 %x = arith.constant 3 : index @@ -623,10 +623,10 @@ module attributes { } // CHECK-LABEL: @load_from_custom_layout_image_3D( - // CHECK-SAME: %[[ARG0:.*]]: memref<2x3x4xf32, #[[CUSTOMLAYOUTMAP]], #spirv.storage_class>, %[[ARG1:.*]]: memref<2x3x4xf32, #spirv.storage_class> + // CHECK-SAME: %[[ARG0:.*]]: memref<2x3x4xf32, #[[$CUSTOMLAYOUTMAP]], #spirv.storage_class>, %[[ARG1:.*]]: memref<2x3x4xf32, #spirv.storage_class> func.func @load_from_custom_layout_image_3D(%arg0: memref<2x3x4xf32, #custom, #spirv.storage_class>, %arg1: memref<2x3x4xf32, #spirv.storage_class>) { -// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<2x3x4xf32, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> -// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<2x3x4xf32, #[[CUSTOMLAYOUTMAP]], #spirv.storage_class> to !spirv.ptr>, UniformConstant> +// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %[[ARG1]] : memref<2x3x4xf32, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> +// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<2x3x4xf32, #[[$CUSTOMLAYOUTMAP]], #spirv.storage_class> to !spirv.ptr>, UniformConstant> // CHECK: %[[X:.*]] = arith.constant 3 : index // CHECK: %[[X32:.*]] = builtin.unrealized_conversion_cast %[[X]] : index to i32 %x = arith.constant 3 : index @@ -650,8 +650,8 @@ module attributes { // CHECK-LABEL: @load_from_image_2D_f16( // CHECK-SAME: %[[ARG0:.*]]: memref<2x3xf16, #spirv.storage_class>, %[[ARG1:.*]]: memref<2x3xf16, #spirv.storage_class> func.func @load_from_image_2D_f16(%arg0: memref<2x3xf16, #spirv.storage_class>, %arg1: memref<2x3xf16, #spirv.storage_class>) { -// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<2x3xf16, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> -// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<2x3xf16, #spirv.storage_class> to !spirv.ptr>, UniformConstant> +// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %[[ARG1]] : memref<2x3xf16, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> +// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<2x3xf16, #spirv.storage_class> to !spirv.ptr>, UniformConstant> // CHECK: %[[X:.*]] = arith.constant 2 : index // CHECK: %[[X32:.*]] = builtin.unrealized_conversion_cast %[[X]] : index to i32 %x = arith.constant 2 : index @@ -672,8 +672,8 @@ module attributes { // CHECK-LABEL: @load_from_image_2D_i32( // CHECK-SAME: %[[ARG0:.*]]: memref<2x3xi32, #spirv.storage_class>, %[[ARG1:.*]]: memref<2x3xi32, #spirv.storage_class> func.func @load_from_image_2D_i32(%arg0: memref<2x3xi32, #spirv.storage_class>, %arg1: memref<2x3xi32, #spirv.storage_class>) { -// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<2x3xi32, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> -// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<2x3xi32, #spirv.storage_class> to !spirv.ptr>, UniformConstant> +// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %[[ARG1]] : memref<2x3xi32, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> +// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<2x3xi32, #spirv.storage_class> to !spirv.ptr>, UniformConstant> // CHECK: %[[X:.*]] = arith.constant 2 : index // CHECK: %[[X32:.*]] = builtin.unrealized_conversion_cast %[[X]] : index to i32 %x = arith.constant 2 : index @@ -694,8 +694,8 @@ module attributes { // CHECK-LABEL: @load_from_image_2D_ui32( // CHECK-SAME: %[[ARG0:.*]]: memref<2x3xui32, #spirv.storage_class>, %[[ARG1:.*]]: memref<2x3xui32, #spirv.storage_class> func.func @load_from_image_2D_ui32(%arg0: memref<2x3xui32, #spirv.storage_class>, %arg1: memref<2x3xui32, #spirv.storage_class>) { -// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<2x3xui32, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> -// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<2x3xui32, #spirv.storage_class> to !spirv.ptr>, UniformConstant> +// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %[[ARG1]] : memref<2x3xui32, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> +// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<2x3xui32, #spirv.storage_class> to !spirv.ptr>, UniformConstant> // CHECK: %[[X:.*]] = arith.constant 2 : index // CHECK: %[[X32:.*]] = builtin.unrealized_conversion_cast %[[X]] : index to i32 %x = arith.constant 2 : index @@ -716,8 +716,8 @@ module attributes { // CHECK-LABEL: @load_from_image_2D_i16( // CHECK-SAME: %[[ARG0:.*]]: memref<2x3xi16, #spirv.storage_class>, %[[ARG1:.*]]: memref<2x3xi16, #spirv.storage_class> func.func @load_from_image_2D_i16(%arg0: memref<2x3xi16, #spirv.storage_class>, %arg1: memref<2x3xi16, #spirv.storage_class>) { -// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<2x3xi16, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> -// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<2x3xi16, #spirv.storage_class> to !spirv.ptr>, UniformConstant> +// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %[[ARG1]] : memref<2x3xi16, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> +// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<2x3xi16, #spirv.storage_class> to !spirv.ptr>, UniformConstant> // CHECK: %[[X:.*]] = arith.constant 2 : index // CHECK: %[[X32:.*]] = builtin.unrealized_conversion_cast %[[X]] : index to i32 %x = arith.constant 2 : index @@ -738,8 +738,8 @@ module attributes { // CHECK-LABEL: @load_from_image_2D_ui16( // CHECK-SAME: %[[ARG0:.*]]: memref<2x3xui16, #spirv.storage_class>, %[[ARG1:.*]]: memref<2x3xui16, #spirv.storage_class> func.func @load_from_image_2D_ui16(%arg0: memref<2x3xui16, #spirv.storage_class>, %arg1: memref<2x3xui16, #spirv.storage_class>) { -// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<2x3xui16, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> -// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<2x3xui16, #spirv.storage_class> to !spirv.ptr>, UniformConstant> +// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %[[ARG1]] : memref<2x3xui16, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> +// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<2x3xui16, #spirv.storage_class> to !spirv.ptr>, UniformConstant> // CHECK: %[[X:.*]] = arith.constant 2 : index // CHECK: %[[X32:.*]] = builtin.unrealized_conversion_cast %[[X]] : index to i32 %x = arith.constant 2 : index From 94ef93ace623e3a790614aa4c54aa9468533b571 Mon Sep 17 00:00:00 2001 From: Jack Frankland Date: Mon, 29 Sep 2025 17:22:24 +0100 Subject: [PATCH 6/7] Address Feedback * Remove `auto` usage * Expand comment Signed-off-by: Jack Frankland --- .../Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp index 0f56c11748da9..a90dcc8cc3ef1 100644 --- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp +++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp @@ -708,24 +708,23 @@ extractLoadCoordsForComposite(memref::LoadOp loadOp, OpAdaptor adaptor, // order. This allows us to support any memref layout that is a permutation of // the dimensions. Future work will pass an optional image layout to the // rewrite pattern so that we can support optimized target specific tilings. - // - // The memrefs layout determines the dimension ordering so we need to invert - // the map to get the ordering. SmallVector indices = adaptor.getIndices(); - auto map = loadOp.getMemRefType().getLayout().getAffineMap(); + AffineMap map = loadOp.getMemRefType().getLayout().getAffineMap(); if (!map.isPermutation()) return rewriter.notifyMatchFailure( loadOp, "Cannot lower memrefs with memory layout which is not a permutation"); + // The memrefs layout determines the dimension ordering so we need to follow + // the map to get the ordering of the dimensions/indices. const unsigned dimCount = map.getNumDims(); SmallVector coords(dimCount); for (unsigned dim = 0; dim < dimCount; ++dim) coords[map.getDimPosition(dim)] = indices[dim]; - // We need to do a final reversal since the image fetch op expects the first - // dimension in the 0th element position, 2nd dimension in the 1st element - // position etc. which is the opposite to the ordering in the map. + // We need to reverse the coordinates because the memref layout is slowest to + // fastest moving and the vector coordinates for the image op is fastest to + // slowest moving. return llvm::to_vector(llvm::reverse(coords)); } @@ -788,7 +787,8 @@ ImageLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor, if (memrefType.getRank() == 1) { coords = adaptor.getIndices()[0]; } else { - auto maybeCoords = extractLoadCoordsForComposite(loadOp, adaptor, rewriter); + FailureOr> maybeCoords = + extractLoadCoordsForComposite(loadOp, adaptor, rewriter); if (failed(maybeCoords)) return failure(); auto coordVectorType = VectorType::get({loadOp.getMemRefType().getRank()}, From 83a6276ee1d3c3cdde9a2530d78e7ce0d95a3ed1 Mon Sep 17 00:00:00 2001 From: Jack Frankland Date: Mon, 29 Sep 2025 17:52:09 +0100 Subject: [PATCH 7/7] Address Feedback * Add negative test Signed-off-by: Jack Frankland --- .../Conversion/MemRefToSPIRV/memref-to-spirv.mlir | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir index 1c95ec2389ed1..ab3c8b7397e1a 100644 --- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir +++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir @@ -519,6 +519,8 @@ module attributes { #col_major = affine_map<(d0, d1) -> (d1, d0)> // CHECK: #[[$CUSTOMLAYOUTMAP:.*]] = affine_map<(d0, d1, d2) -> (d2, d1, d0)> #custom = affine_map<(d0, d1, d2) -> (d2, d1, d0)> +// CHECK: #[[$NONPERMMAP:.*]] = affine_map<(d0, d1) -> (d0, d1 mod 2)> +#non_permutation = affine_map<(d0, d1) -> (d0, d1 mod 2)> module attributes { spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.storage_class> return } + + // CHECK-LABEL: @load_non_perm_layout( + func.func @load_non_perm_layout(%arg0: memref<2x4xf32, #non_permutation, #spirv.storage_class>, %arg1: memref<2x4xf32, #spirv.storage_class>) { + %x = arith.constant 3 : index + %y = arith.constant 1 : index + // CHECK-NOT: spirv.Image + // CHECK-NOT: spirv.ImageFetch + %0 = memref.load %arg0[%y, %x] : memref<2x4xf32, #non_permutation, #spirv.storage_class> + memref.store %0, %arg1[%y, %x] : memref<2x4xf32, #spirv.storage_class> + return + } }