diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp index f44552c4556c2..a90dcc8cc3ef1 100644 --- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp +++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp @@ -699,6 +699,35 @@ LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor, return success(); } +template +static FailureOr> +extractLoadCoordsForComposite(memref::LoadOp loadOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) { + // At present we only support linear "tiling" as specified in Vulkan, this + // means that texels are assumed to be laid out in memory in a row-major + // order. This allows us to support any memref layout that is a permutation of + // the dimensions. Future work will pass an optional image layout to the + // rewrite pattern so that we can support optimized target specific tilings. + SmallVector indices = adaptor.getIndices(); + AffineMap map = loadOp.getMemRefType().getLayout().getAffineMap(); + if (!map.isPermutation()) + return rewriter.notifyMatchFailure( + loadOp, + "Cannot lower memrefs with memory layout which is not a permutation"); + + // The memrefs layout determines the dimension ordering so we need to follow + // the map to get the ordering of the dimensions/indices. + const unsigned dimCount = map.getNumDims(); + SmallVector coords(dimCount); + for (unsigned dim = 0; dim < dimCount; ++dim) + coords[map.getDimPosition(dim)] = indices[dim]; + + // We need to reverse the coordinates because the memref layout is slowest to + // fastest moving and the vector coordinates for the image op is fastest to + // slowest moving. + return llvm::to_vector(llvm::reverse(coords)); +} + LogicalResult ImageLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { @@ -755,13 +784,17 @@ ImageLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor, // Build a vector of coordinates or just a scalar index if we have a 1D image. Value coords; - if (memrefType.getRank() != 1) { + if (memrefType.getRank() == 1) { + coords = adaptor.getIndices()[0]; + } else { + FailureOr> maybeCoords = + extractLoadCoordsForComposite(loadOp, adaptor, rewriter); + if (failed(maybeCoords)) + return failure(); auto coordVectorType = VectorType::get({loadOp.getMemRefType().getRank()}, adaptor.getIndices().getType()[0]); coords = spirv::CompositeConstructOp::create(rewriter, loc, coordVectorType, - adaptor.getIndices()); - } else { - coords = adaptor.getIndices()[0]; + maybeCoords.value()); } // Fetch the value out of the image. diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir index e6321e99693ac..ab3c8b7397e1a 100644 --- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir +++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir @@ -515,6 +515,12 @@ module attributes { // Check Image Support. +// CHECK: #[[$COLMAJMAP:.*]] = affine_map<(d0, d1) -> (d1, d0)> +#col_major = affine_map<(d0, d1) -> (d1, d0)> +// CHECK: #[[$CUSTOMLAYOUTMAP:.*]] = affine_map<(d0, d1, d2) -> (d2, d1, d0)> +#custom = affine_map<(d0, d1, d2) -> (d2, d1, d0)> +// CHECK: #[[$NONPERMMAP:.*]] = affine_map<(d0, d1) -> (d0, d1 mod 2)> +#non_permutation = affine_map<(d0, d1) -> (d0, d1 mod 2)> module attributes { spirv.target_env = #spirv.target_env<#spirv.vce>, %[[ARG1:.*]]: memref<1xf32, #spirv.storage_class> func.func @load_from_image_1D(%arg0: memref<1xf32, #spirv.storage_class>, %arg1: memref<1xf32, #spirv.storage_class>) { -// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<1xf32, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> -// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<1xf32, #spirv.storage_class> to !spirv.ptr>, UniformConstant> +// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %[[ARG1]] : memref<1xf32, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> +// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<1xf32, #spirv.storage_class> to !spirv.ptr>, UniformConstant> %cst = arith.constant 0 : index // CHECK: %[[COORDS:.*]] = builtin.unrealized_conversion_cast %{{.*}} : index to i32 // CHECK: %[[SIMAGE:.*]] = spirv.Load "UniformConstant" %[[IMAGE_PTR]] : !spirv.sampled_image> @@ -550,121 +556,206 @@ module attributes { } // CHECK-LABEL: @load_from_image_2D( - // CHECK-SAME: %[[ARG0:.*]]: memref<1x1xf32, #spirv.storage_class>, %[[ARG1:.*]]: memref<1x1xf32, #spirv.storage_class> - func.func @load_from_image_2D(%arg0: memref<1x1xf32, #spirv.storage_class>, %arg1: memref<1x1xf32, #spirv.storage_class>) { -// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<1x1xf32, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> -// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<1x1xf32, #spirv.storage_class> to !spirv.ptr>, UniformConstant> - %cst = arith.constant 0 : index + // CHECK-SAME: %[[ARG0:.*]]: memref<2x4xf32, #spirv.storage_class>, %[[ARG1:.*]]: memref<2x4xf32, #spirv.storage_class> + func.func @load_from_image_2D(%arg0: memref<2x4xf32, #spirv.storage_class>, %arg1: memref<2x4xf32, #spirv.storage_class>) { +// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %[[ARG1]] : memref<2x4xf32, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> +// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<2x4xf32, #spirv.storage_class> to !spirv.ptr>, UniformConstant> + // CHECK: %[[X:.*]] = arith.constant 3 : index + // CHECK: %[[X32:.*]] = builtin.unrealized_conversion_cast %[[X]] : index to i32 + %x = arith.constant 3 : index + // CHECK: %[[Y:.*]] = arith.constant 1 : index + // CHECK: %[[Y32:.*]] = builtin.unrealized_conversion_cast %[[Y]] : index to i32 + %y = arith.constant 1 : index + // CHECK: %[[SIMAGE:.*]] = spirv.Load "UniformConstant" %[[IMAGE_PTR]] : !spirv.sampled_image> + // CHECK: %[[IMAGE:.*]] = spirv.Image %[[SIMAGE]] : !spirv.sampled_image> + // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %[[X32]], %[[Y32]] : (i32, i32) -> vector<2xi32> + // CHECK: %[[RES_VEC:.*]] = spirv.ImageFetch %[[IMAGE]], %[[COORDS]] : !spirv.image, vector<2xi32> -> vector<4xf32> + // CHECK: %[[RESULT:.*]] = spirv.CompositeExtract %[[RES_VEC]][0 : i32] : vector<4xf32> + %0 = memref.load %arg0[%y, %x] : memref<2x4xf32, #spirv.storage_class> + // CHECK: spirv.Store "StorageBuffer" %{{.*}}, %[[RESULT]] : f32 + memref.store %0, %arg1[%y, %x] : memref<2x4xf32, #spirv.storage_class> + return + } + + // CHECK-LABEL: @load_from_col_major_image_2D( + // CHECK-SAME: %[[ARG0:.*]]: memref<2x4xf32, #[[$COLMAJMAP]], #spirv.storage_class>, %[[ARG1:.*]]: memref<2x4xf32, #spirv.storage_class> + func.func @load_from_col_major_image_2D(%arg0: memref<2x4xf32, #col_major, #spirv.storage_class>, %arg1: memref<2x4xf32, #spirv.storage_class>) { +// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %[[ARG1]] : memref<2x4xf32, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> +// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<2x4xf32, #[[$COLMAJMAP]], #spirv.storage_class> to !spirv.ptr>, UniformConstant> + // CHECK: %[[X:.*]] = arith.constant 3 : index + // CHECK: %[[X32:.*]] = builtin.unrealized_conversion_cast %[[X]] : index to i32 + %x = arith.constant 3 : index + // CHECK: %[[Y:.*]] = arith.constant 1 : index + // CHECK: %[[Y32:.*]] = builtin.unrealized_conversion_cast %[[Y]] : index to i32 + %y = arith.constant 1 : index // CHECK: %[[SIMAGE:.*]] = spirv.Load "UniformConstant" %[[IMAGE_PTR]] : !spirv.sampled_image> // CHECK: %[[IMAGE:.*]] = spirv.Image %[[SIMAGE]] : !spirv.sampled_image> - // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %{{.*}}, %{{.*}} : (i32, i32) -> vector<2xi32> + // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %[[X32]], %[[Y32]] : (i32, i32) -> vector<2xi32> // CHECK: %[[RES_VEC:.*]] = spirv.ImageFetch %[[IMAGE]], %[[COORDS]] : !spirv.image, vector<2xi32> -> vector<4xf32> // CHECK: %[[RESULT:.*]] = spirv.CompositeExtract %[[RES_VEC]][0 : i32] : vector<4xf32> - %0 = memref.load %arg0[%cst, %cst] : memref<1x1xf32, #spirv.storage_class> + %0 = memref.load %arg0[%x, %y] : memref<2x4xf32, #col_major, #spirv.storage_class> // CHECK: spirv.Store "StorageBuffer" %{{.*}}, %[[RESULT]] : f32 - memref.store %0, %arg1[%cst, %cst] : memref<1x1xf32, #spirv.storage_class> + memref.store %0, %arg1[%y, %x] : memref<2x4xf32, #spirv.storage_class> return } // CHECK-LABEL: @load_from_image_3D( - // CHECK-SAME: %[[ARG0:.*]]: memref<1x1x1xf32, #spirv.storage_class>, %[[ARG1:.*]]: memref<1x1x1xf32, #spirv.storage_class> - func.func @load_from_image_3D(%arg0: memref<1x1x1xf32, #spirv.storage_class>, %arg1: memref<1x1x1xf32, #spirv.storage_class>) { -// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<1x1x1xf32, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> -// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<1x1x1xf32, #spirv.storage_class> to !spirv.ptr>, UniformConstant> - %cst = arith.constant 0 : index + // CHECK-SAME: %[[ARG0:.*]]: memref<2x3x4xf32, #spirv.storage_class>, %[[ARG1:.*]]: memref<2x3x4xf32, #spirv.storage_class> + func.func @load_from_image_3D(%arg0: memref<2x3x4xf32, #spirv.storage_class>, %arg1: memref<2x3x4xf32, #spirv.storage_class>) { +// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %[[ARG1]] : memref<2x3x4xf32, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> +// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<2x3x4xf32, #spirv.storage_class> to !spirv.ptr>, UniformConstant> + // CHECK: %[[X:.*]] = arith.constant 3 : index + // CHECK: %[[X32:.*]] = builtin.unrealized_conversion_cast %[[X]] : index to i32 + %x = arith.constant 3 : index + // CHECK: %[[Y:.*]] = arith.constant 2 : index + // CHECK: %[[Y32:.*]] = builtin.unrealized_conversion_cast %[[Y]] : index to i32 + %y = arith.constant 2 : index + // CHECK: %[[Z:.*]] = arith.constant 1 : index + // CHECK: %[[Z32:.*]] = builtin.unrealized_conversion_cast %[[Z]] : index to i32 + %z = arith.constant 1 : index + // CHECK: %[[SIMAGE:.*]] = spirv.Load "UniformConstant" %[[IMAGE_PTR]] : !spirv.sampled_image> + // CHECK: %[[IMAGE:.*]] = spirv.Image %[[SIMAGE]] : !spirv.sampled_image> + // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %[[X32]], %[[Y32]], %[[Z32]] : (i32, i32, i32) -> vector<3xi32> + // CHECK: %[[RES_VEC:.*]] = spirv.ImageFetch %[[IMAGE]], %[[COORDS]] : !spirv.image, vector<3xi32> -> vector<4xf32> + // CHECK: %[[RESULT:.*]] = spirv.CompositeExtract %[[RES_VEC]][0 : i32] : vector<4xf32> + %0 = memref.load %arg0[%z, %y, %x] : memref<2x3x4xf32, #spirv.storage_class> + // CHECK: spirv.Store "StorageBuffer" %{{.*}}, %[[RESULT]] : f32 + memref.store %0, %arg1[%z, %y, %x] : memref<2x3x4xf32, #spirv.storage_class> + return + } + + // CHECK-LABEL: @load_from_custom_layout_image_3D( + // CHECK-SAME: %[[ARG0:.*]]: memref<2x3x4xf32, #[[$CUSTOMLAYOUTMAP]], #spirv.storage_class>, %[[ARG1:.*]]: memref<2x3x4xf32, #spirv.storage_class> + func.func @load_from_custom_layout_image_3D(%arg0: memref<2x3x4xf32, #custom, #spirv.storage_class>, %arg1: memref<2x3x4xf32, #spirv.storage_class>) { +// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %[[ARG1]] : memref<2x3x4xf32, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> +// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<2x3x4xf32, #[[$CUSTOMLAYOUTMAP]], #spirv.storage_class> to !spirv.ptr>, UniformConstant> + // CHECK: %[[X:.*]] = arith.constant 3 : index + // CHECK: %[[X32:.*]] = builtin.unrealized_conversion_cast %[[X]] : index to i32 + %x = arith.constant 3 : index + // CHECK: %[[Y:.*]] = arith.constant 2 : index + // CHECK: %[[Y32:.*]] = builtin.unrealized_conversion_cast %[[Y]] : index to i32 + %y = arith.constant 2 : index + // CHECK: %[[Z:.*]] = arith.constant 1 : index + // CHECK: %[[Z32:.*]] = builtin.unrealized_conversion_cast %[[Z]] : index to i32 + %z = arith.constant 1 : index // CHECK: %[[SIMAGE:.*]] = spirv.Load "UniformConstant" %[[IMAGE_PTR]] : !spirv.sampled_image> // CHECK: %[[IMAGE:.*]] = spirv.Image %[[SIMAGE]] : !spirv.sampled_image> - // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %{{.*}}, %{{.*}}, %{{.*}} : (i32, i32, i32) -> vector<3xi32> + // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %[[X32]], %[[Y32]], %[[Z32]] : (i32, i32, i32) -> vector<3xi32> // CHECK: %[[RES_VEC:.*]] = spirv.ImageFetch %[[IMAGE]], %[[COORDS]] : !spirv.image, vector<3xi32> -> vector<4xf32> // CHECK: %[[RESULT:.*]] = spirv.CompositeExtract %[[RES_VEC]][0 : i32] : vector<4xf32> - %0 = memref.load %arg0[%cst, %cst, %cst] : memref<1x1x1xf32, #spirv.storage_class> + %0 = memref.load %arg0[%x, %y, %z] : memref<2x3x4xf32, #custom, #spirv.storage_class> // CHECK: spirv.Store "StorageBuffer" %{{.*}}, %[[RESULT]] : f32 - memref.store %0, %arg1[%cst, %cst, %cst] : memref<1x1x1xf32, #spirv.storage_class> + memref.store %0, %arg1[%z, %y, %x] : memref<2x3x4xf32, #spirv.storage_class> return } // CHECK-LABEL: @load_from_image_2D_f16( - // CHECK-SAME: %[[ARG0:.*]]: memref<1x1xf16, #spirv.storage_class>, %[[ARG1:.*]]: memref<1x1xf16, #spirv.storage_class> - func.func @load_from_image_2D_f16(%arg0: memref<1x1xf16, #spirv.storage_class>, %arg1: memref<1x1xf16, #spirv.storage_class>) { -// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<1x1xf16, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> -// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<1x1xf16, #spirv.storage_class> to !spirv.ptr>, UniformConstant> - %cst = arith.constant 0 : index + // CHECK-SAME: %[[ARG0:.*]]: memref<2x3xf16, #spirv.storage_class>, %[[ARG1:.*]]: memref<2x3xf16, #spirv.storage_class> + func.func @load_from_image_2D_f16(%arg0: memref<2x3xf16, #spirv.storage_class>, %arg1: memref<2x3xf16, #spirv.storage_class>) { +// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %[[ARG1]] : memref<2x3xf16, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> +// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<2x3xf16, #spirv.storage_class> to !spirv.ptr>, UniformConstant> + // CHECK: %[[X:.*]] = arith.constant 2 : index + // CHECK: %[[X32:.*]] = builtin.unrealized_conversion_cast %[[X]] : index to i32 + %x = arith.constant 2 : index + // CHECK: %[[Y:.*]] = arith.constant 1 : index + // CHECK: %[[Y32:.*]] = builtin.unrealized_conversion_cast %[[Y]] : index to i32 + %y = arith.constant 1 : index // CHECK: %[[SIMAGE:.*]] = spirv.Load "UniformConstant" %[[IMAGE_PTR]] : !spirv.sampled_image> // CHECK: %[[IMAGE:.*]] = spirv.Image %[[SIMAGE]] : !spirv.sampled_image> - // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %{{.*}}, %{{.*}} : (i32, i32) -> vector<2xi32> + // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %[[X32]], %[[Y32]] : (i32, i32) -> vector<2xi32> // CHECK: %[[RES_VEC:.*]] = spirv.ImageFetch %[[IMAGE]], %[[COORDS]] : !spirv.image, vector<2xi32> -> vector<4xf16> // CHECK: %[[RESULT:.*]] = spirv.CompositeExtract %[[RES_VEC]][0 : i32] : vector<4xf16> - %0 = memref.load %arg0[%cst, %cst] : memref<1x1xf16, #spirv.storage_class> + %0 = memref.load %arg0[%y, %x] : memref<2x3xf16, #spirv.storage_class> // CHECK: spirv.Store "StorageBuffer" %{{.*}}, %[[RESULT]] : f16 - memref.store %0, %arg1[%cst, %cst] : memref<1x1xf16, #spirv.storage_class> + memref.store %0, %arg1[%y, %x] : memref<2x3xf16, #spirv.storage_class> return } // CHECK-LABEL: @load_from_image_2D_i32( - // CHECK-SAME: %[[ARG0:.*]]: memref<1x1xi32, #spirv.storage_class>, %[[ARG1:.*]]: memref<1x1xi32, #spirv.storage_class> - func.func @load_from_image_2D_i32(%arg0: memref<1x1xi32, #spirv.storage_class>, %arg1: memref<1x1xi32, #spirv.storage_class>) { -// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<1x1xi32, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> -// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<1x1xi32, #spirv.storage_class> to !spirv.ptr>, UniformConstant> - %cst = arith.constant 0 : index + // CHECK-SAME: %[[ARG0:.*]]: memref<2x3xi32, #spirv.storage_class>, %[[ARG1:.*]]: memref<2x3xi32, #spirv.storage_class> + func.func @load_from_image_2D_i32(%arg0: memref<2x3xi32, #spirv.storage_class>, %arg1: memref<2x3xi32, #spirv.storage_class>) { +// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %[[ARG1]] : memref<2x3xi32, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> +// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<2x3xi32, #spirv.storage_class> to !spirv.ptr>, UniformConstant> + // CHECK: %[[X:.*]] = arith.constant 2 : index + // CHECK: %[[X32:.*]] = builtin.unrealized_conversion_cast %[[X]] : index to i32 + %x = arith.constant 2 : index + // CHECK: %[[Y:.*]] = arith.constant 1 : index + // CHECK: %[[Y32:.*]] = builtin.unrealized_conversion_cast %[[Y]] : index to i32 + %y = arith.constant 1 : index // CHECK: %[[SIMAGE:.*]] = spirv.Load "UniformConstant" %[[IMAGE_PTR]] : !spirv.sampled_image> // CHECK: %[[IMAGE:.*]] = spirv.Image %[[SIMAGE]] : !spirv.sampled_image> - // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %{{.*}}, %{{.*}} : (i32, i32) -> vector<2xi32> + // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %[[X32]], %[[Y32]] : (i32, i32) -> vector<2xi32> // CHECK: %[[RES_VEC:.*]] = spirv.ImageFetch %[[IMAGE]], %[[COORDS]] : !spirv.image, vector<2xi32> -> vector<4xi32> // CHECK: %[[RESULT:.*]] = spirv.CompositeExtract %[[RES_VEC]][0 : i32] : vector<4xi32> - %0 = memref.load %arg0[%cst, %cst] : memref<1x1xi32, #spirv.storage_class> + %0 = memref.load %arg0[%y, %x] : memref<2x3xi32, #spirv.storage_class> // CHECK: spirv.Store "StorageBuffer" %{{.*}}, %[[RESULT]] : i32 - memref.store %0, %arg1[%cst, %cst] : memref<1x1xi32, #spirv.storage_class> + memref.store %0, %arg1[%y, %x] : memref<2x3xi32, #spirv.storage_class> return } // CHECK-LABEL: @load_from_image_2D_ui32( - // CHECK-SAME: %[[ARG0:.*]]: memref<1x1xui32, #spirv.storage_class>, %[[ARG1:.*]]: memref<1x1xui32, #spirv.storage_class> - func.func @load_from_image_2D_ui32(%arg0: memref<1x1xui32, #spirv.storage_class>, %arg1: memref<1x1xui32, #spirv.storage_class>) { -// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<1x1xui32, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> -// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<1x1xui32, #spirv.storage_class> to !spirv.ptr>, UniformConstant> - %cst = arith.constant 0 : index + // CHECK-SAME: %[[ARG0:.*]]: memref<2x3xui32, #spirv.storage_class>, %[[ARG1:.*]]: memref<2x3xui32, #spirv.storage_class> + func.func @load_from_image_2D_ui32(%arg0: memref<2x3xui32, #spirv.storage_class>, %arg1: memref<2x3xui32, #spirv.storage_class>) { +// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %[[ARG1]] : memref<2x3xui32, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> +// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<2x3xui32, #spirv.storage_class> to !spirv.ptr>, UniformConstant> + // CHECK: %[[X:.*]] = arith.constant 2 : index + // CHECK: %[[X32:.*]] = builtin.unrealized_conversion_cast %[[X]] : index to i32 + %x = arith.constant 2 : index + // CHECK: %[[Y:.*]] = arith.constant 1 : index + // CHECK: %[[Y32:.*]] = builtin.unrealized_conversion_cast %[[Y]] : index to i32 + %y = arith.constant 1 : index // CHECK: %[[SIMAGE:.*]] = spirv.Load "UniformConstant" %[[IMAGE_PTR]] : !spirv.sampled_image> // CHECK: %[[IMAGE:.*]] = spirv.Image %[[SIMAGE]] : !spirv.sampled_image> - // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %{{.*}}, %{{.*}} : (i32, i32) -> vector<2xi32> + // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %[[X32]], %[[Y32]] : (i32, i32) -> vector<2xi32> // CHECK: %[[RES_VEC:.*]] = spirv.ImageFetch %[[IMAGE]], %[[COORDS]] : !spirv.image, vector<2xi32> -> vector<4xui32> // CHECK: %[[RESULT:.*]] = spirv.CompositeExtract %[[RES_VEC]][0 : i32] : vector<4xui32> - %0 = memref.load %arg0[%cst, %cst] : memref<1x1xui32, #spirv.storage_class> + %0 = memref.load %arg0[%y, %x] : memref<2x3xui32, #spirv.storage_class> // CHECK: spirv.Store "StorageBuffer" %{{.*}}, %[[RESULT]] : ui32 - memref.store %0, %arg1[%cst, %cst] : memref<1x1xui32, #spirv.storage_class> + memref.store %0, %arg1[%y, %x] : memref<2x3xui32, #spirv.storage_class> return } // CHECK-LABEL: @load_from_image_2D_i16( - // CHECK-SAME: %[[ARG0:.*]]: memref<1x1xi16, #spirv.storage_class>, %[[ARG1:.*]]: memref<1x1xi16, #spirv.storage_class> - func.func @load_from_image_2D_i16(%arg0: memref<1x1xi16, #spirv.storage_class>, %arg1: memref<1x1xi16, #spirv.storage_class>) { -// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<1x1xi16, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> -// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<1x1xi16, #spirv.storage_class> to !spirv.ptr>, UniformConstant> - %cst = arith.constant 0 : index + // CHECK-SAME: %[[ARG0:.*]]: memref<2x3xi16, #spirv.storage_class>, %[[ARG1:.*]]: memref<2x3xi16, #spirv.storage_class> + func.func @load_from_image_2D_i16(%arg0: memref<2x3xi16, #spirv.storage_class>, %arg1: memref<2x3xi16, #spirv.storage_class>) { +// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %[[ARG1]] : memref<2x3xi16, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> +// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<2x3xi16, #spirv.storage_class> to !spirv.ptr>, UniformConstant> + // CHECK: %[[X:.*]] = arith.constant 2 : index + // CHECK: %[[X32:.*]] = builtin.unrealized_conversion_cast %[[X]] : index to i32 + %x = arith.constant 2 : index + // CHECK: %[[Y:.*]] = arith.constant 1 : index + // CHECK: %[[Y32:.*]] = builtin.unrealized_conversion_cast %[[Y]] : index to i32 + %y = arith.constant 1 : index // CHECK: %[[SIMAGE:.*]] = spirv.Load "UniformConstant" %[[IMAGE_PTR]] : !spirv.sampled_image> // CHECK: %[[IMAGE:.*]] = spirv.Image %[[SIMAGE]] : !spirv.sampled_image> - // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %{{.*}}, %{{.*}} : (i32, i32) -> vector<2xi32> + // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %[[X32]], %[[Y32]] : (i32, i32) -> vector<2xi32> // CHECK: %[[RES_VEC:.*]] = spirv.ImageFetch %[[IMAGE]], %[[COORDS]] : !spirv.image, vector<2xi32> -> vector<4xi16> // CHECK: %[[RESULT:.*]] = spirv.CompositeExtract %[[RES_VEC]][0 : i32] : vector<4xi16> - %0 = memref.load %arg0[%cst, %cst] : memref<1x1xi16, #spirv.storage_class> + %0 = memref.load %arg0[%y, %x] : memref<2x3xi16, #spirv.storage_class> // CHECK: spirv.Store "StorageBuffer" %{{.*}}, %[[RESULT]] : i16 - memref.store %0, %arg1[%cst, %cst] : memref<1x1xi16, #spirv.storage_class> + memref.store %0, %arg1[%y, %x] : memref<2x3xi16, #spirv.storage_class> return } // CHECK-LABEL: @load_from_image_2D_ui16( - // CHECK-SAME: %[[ARG0:.*]]: memref<1x1xui16, #spirv.storage_class>, %[[ARG1:.*]]: memref<1x1xui16, #spirv.storage_class> - func.func @load_from_image_2D_ui16(%arg0: memref<1x1xui16, #spirv.storage_class>, %arg1: memref<1x1xui16, #spirv.storage_class>) { -// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %arg1 : memref<1x1xui16, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> -// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<1x1xui16, #spirv.storage_class> to !spirv.ptr>, UniformConstant> - %cst = arith.constant 0 : index + // CHECK-SAME: %[[ARG0:.*]]: memref<2x3xui16, #spirv.storage_class>, %[[ARG1:.*]]: memref<2x3xui16, #spirv.storage_class> + func.func @load_from_image_2D_ui16(%arg0: memref<2x3xui16, #spirv.storage_class>, %arg1: memref<2x3xui16, #spirv.storage_class>) { +// CHECK-DAG: %[[SB:.*]] = builtin.unrealized_conversion_cast %[[ARG1]] : memref<2x3xui16, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> +// CHECK-DAG: %[[IMAGE_PTR:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<2x3xui16, #spirv.storage_class> to !spirv.ptr>, UniformConstant> + // CHECK: %[[X:.*]] = arith.constant 2 : index + // CHECK: %[[X32:.*]] = builtin.unrealized_conversion_cast %[[X]] : index to i32 + %x = arith.constant 2 : index + // CHECK: %[[Y:.*]] = arith.constant 1 : index + // CHECK: %[[Y32:.*]] = builtin.unrealized_conversion_cast %[[Y]] : index to i32 + %y = arith.constant 1 : index // CHECK: %[[SIMAGE:.*]] = spirv.Load "UniformConstant" %[[IMAGE_PTR]] : !spirv.sampled_image> // CHECK: %[[IMAGE:.*]] = spirv.Image %[[SIMAGE]] : !spirv.sampled_image> - // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %{{.*}}, %{{.*}} : (i32, i32) -> vector<2xi32> + // CHECK: %[[COORDS:.*]] = spirv.CompositeConstruct %[[X32]], %[[Y32]] : (i32, i32) -> vector<2xi32> // CHECK: %[[RES_VEC:.*]] = spirv.ImageFetch %[[IMAGE]], %[[COORDS]] : !spirv.image, vector<2xi32> -> vector<4xui16> // CHECK: %[[RESULT:.*]] = spirv.CompositeExtract %[[RES_VEC]][0 : i32] : vector<4xui16> - %0 = memref.load %arg0[%cst, %cst] : memref<1x1xui16, #spirv.storage_class> + %0 = memref.load %arg0[%y, %x] : memref<2x3xui16, #spirv.storage_class> // CHECK: spirv.Store "StorageBuffer" %{{.*}}, %[[RESULT]] : ui16 - memref.store %0, %arg1[%cst, %cst] : memref<1x1xui16, #spirv.storage_class> + memref.store %0, %arg1[%y, %x] : memref<2x3xui16, #spirv.storage_class> return } @@ -697,4 +788,15 @@ module attributes { memref.store %0, %arg1[%cst] : memref<1xvector<1xf32>, #spirv.storage_class> return } + + // CHECK-LABEL: @load_non_perm_layout( + func.func @load_non_perm_layout(%arg0: memref<2x4xf32, #non_permutation, #spirv.storage_class>, %arg1: memref<2x4xf32, #spirv.storage_class>) { + %x = arith.constant 3 : index + %y = arith.constant 1 : index + // CHECK-NOT: spirv.Image + // CHECK-NOT: spirv.ImageFetch + %0 = memref.load %arg0[%y, %x] : memref<2x4xf32, #non_permutation, #spirv.storage_class> + memref.store %0, %arg1[%y, %x] : memref<2x4xf32, #spirv.storage_class> + return + } }