diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp index 7194d41d60df7..fd95ea0c39a54 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp @@ -195,16 +195,6 @@ struct Gather1DToConditionalLoads : OpRewritePattern { // correct N-D load indices from the 1-D gather index. bool useDelinearization = false; if (auto memType = dyn_cast(base.getType())) { - // vector.load requires the most minor memref dim to have unit stride - // (unless reading exactly 1 element). - if (auto stridesAttr = - dyn_cast_if_present(memType.getLayout())) { - if (stridesAttr.getStrides().back() != 1 && - resultTy.getNumElements() != 1) - return rewriter.notifyMatchFailure( - op, "most minor memref dim must have unit stride"); - } - if (memType.getRank() > 1) useDelinearization = true; } diff --git a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir index 59b13e300e5e5..1a5407e7f4752 100644 --- a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir +++ b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir @@ -360,3 +360,31 @@ func.func @gather_memref_2d_delinearize_nonzero_offsets( vector<2xi1>, vector<2xf32> into vector<2xf32> return %0 : vector<2xf32> } + +// ----- + +// Verify that gather on a rank-1 strided memref (from rank-reducing subview) +// is correctly scalarized to per-element vector.load ops. 
+
+// CHECK-LABEL: @gather_rank1_strided_memref
+// CHECK-SAME: (%[[BASE:.+]]: memref<4xf32, strided<[6]>>
+// CHECK-NOT: vector.gather
+// CHECK: vector.extract {{.*}}[0]
+// CHECK: scf.if {{.*}} -> (vector<3xf32>)
+// CHECK: vector.load %[[BASE]][%{{.*}}] : memref<4xf32, strided<[6]>>, vector<1xf32>
+// CHECK: vector.extract {{.*}}[1]
+// CHECK: scf.if {{.*}} -> (vector<3xf32>)
+// CHECK: vector.load %[[BASE]][%{{.*}}] : memref<4xf32, strided<[6]>>, vector<1xf32>
+// CHECK: vector.extract {{.*}}[2]
+// CHECK: scf.if {{.*}} -> (vector<3xf32>)
+// CHECK: vector.load %[[BASE]][%{{.*}}] : memref<4xf32, strided<[6]>>, vector<1xf32>
+func.func @gather_rank1_strided_memref(
+    %base: memref<4xf32, strided<[6]>>, // Rank-1 base whose only dim has a non-unit stride (6).
+    %v: vector<3xindex>, %mask: vector<3xi1>, // Three gather indices and their per-lane mask bits.
+    %pass_thru: vector<3xf32>) -> vector<3xf32> { // Pass-through operand; same type as the result.
+  %c0 = arith.constant 0 : index // Base index supplied to vector.gather below.
+  %0 = vector.gather %base[%c0][%v], %mask, %pass_thru
+      : memref<4xf32, strided<[6]>>, vector<3xindex>,
+        vector<3xi1>, vector<3xf32> into vector<3xf32>
+  return %0 : vector<3xf32>
+}