diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 9d81702581131..815909169c6b8 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -290,6 +290,7 @@ class VectorGatherOpConversion MemRefType memRefType = dyn_cast(gather.getBaseType()); assert(memRefType && "The base should be bufferized"); + // TODO: Add support for strided MemRef. if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter()))) return rewriter.notifyMatchFailure(gather, "memref type not supported"); @@ -348,6 +349,7 @@ class VectorScatterOpConversion auto memRefType = dyn_cast(scatter.getBaseType()); assert(memRefType && "The base should be bufferized"); + // TODO: Add support for strided MemRef. if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter()))) return rewriter.notifyMatchFailure(scatter, "memref type not supported"); diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir index 49c55f5b54496..076209cbc7a4c 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir @@ -2066,6 +2066,20 @@ func.func @gather_with_alignment(%arg0: memref, %arg1: vector<3xi32>, %ar // ----- +// TODO: Implement this lowering. +func.func @negative_gather_on_strided_memref(%arg0: memref>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) -> vector<3xf32> { + %0 = arith.constant 0: index + %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 + : memref>, vector<3xi32>, vector<3xi1>, vector<3xf32> into vector<3xf32> + return %1 : vector<3xf32> +} + +// CHECK-LABEL: func @negative_gather_on_strided_memref +// CHECK-NOT: llvm.intr.masked.gather +// CHECK: vector.gather + +// ----- + //===----------------------------------------------------------------------===// // vector.scatter //===----------------------------------------------------------------------===// @@ -2152,6 +2166,19 @@ func.func @scatter_with_alignment(%arg0: memref, %arg1: vector<3xi32>, %a // CHECK-LABEL: func @scatter_with_alignment // CHECK: llvm.intr.masked.scatter %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : vector<3xf32>, vector<3xi1> into vector<3x!llvm.ptr> +// ----- + +// TODO: Implement this lowering. +func.func @negative_scatter_on_strided_memref(%arg0: memref>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) { + %0 = arith.constant 0: index + vector.scatter %arg0[%0][%arg1], %arg2, %arg3 + : memref>, vector<3xi32>, vector<3xi1>, vector<3xf32> + return +} + +// CHECK-LABEL: func @negative_scatter_on_strided_memref +// CHECK-NOT: llvm.intr.masked.scatter +// CHECK: vector.scatter // ----- diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir index 3957455ccc76e..d8e08c8b2a850 100644 --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -2046,6 +2046,15 @@ func.func @load_non_pow_of_2_alignment(%memref: memref<4xi32>, %c0: index) { // ----- +func.func @load_non_unit_stride(%src : memref>) { + %c0 = arith.constant 0 : index + // expected-error @+1 {{'vector.load' op most minor memref dim must have unit stride}} + %0 = vector.load %src[%c0] : memref>, vector<16xi8> + return +} + +// ----- + //===----------------------------------------------------------------------===// // vector.store //===----------------------------------------------------------------------===// @@ -2073,6 +2082,13 @@ func.func @store_non_pow_of_2_alignment(%memref: memref<4xi32>, %val: vector<4xi return } +// ----- +func.func @store_non_unit_stride(%src : memref>,%val : vector<16xi8>, %c0: index) { + // expected-error @below {{'vector.store' op most minor memref dim must have unit stride}} + vector.store %val, %src[%c0] : memref>, vector<16xi8> + return +} + // ----- // Verify that vector.bitcast rejects vectors with i0 (zero-bitwidth) element type.