From c65261785efc0a60dc9b457a0831557ca2f62d0f Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski
Date: Tue, 21 Nov 2023 13:13:03 +0000
Subject: [PATCH 1/6] [mlir][Vector] Add a rewrite pattern for gather over a
 strided memref

This patch adds a rewrite pattern for `vector.gather` over a strided
memref. A gather like the following:

```mlir
%subview = memref.subview %arg0[0, 0] [100, 1] [1, 1] : memref<100x3xf32> to memref<100xf32, strided<[3]>>
%gather = vector.gather %subview[%c0] [%idxs], %cst_0, %cst : memref<100xf32, strided<[3]>>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32>
```

is rewritten as:

```mlir
%collapse_shape = memref.collapse_shape %arg0 [[0, 1]] : memref<100x3xf32> into memref<300xf32>
%1 = arith.muli %arg3, %cst : vector<4xindex>
%gather = vector.gather %collapse_shape[%c0] [%1], %cst_1, %cst_0 : memref<300xf32>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32>
```

Fixes https://github.com/openxla/iree/issues/15364.
---
 .../Vector/Transforms/LowerVectorGather.cpp   | 80 ++++++++++++++++++-
 .../Vector/vector-gather-lowering.mlir        | 54 +++++++++++++
 2 files changed, 132 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
index 152aefa65effc..54b350d7ac352 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
@@ -96,6 +96,82 @@ struct FlattenGather : OpRewritePattern<vector::GatherOp> {
   }
 };
 
+/// Rewrites a vector.gather of a strided MemRef as a gather of a non-strided
+/// MemRef with updated indices that model the strided access.
+///
+/// ```mlir
+/// %subview = memref.subview %M (...) to memref<100xf32, strided<[3]>>
+/// %gather = vector.gather %subview (...) : memref<100xf32, strided<[3]>>
+/// ```
+/// ==>
+/// ```mlir
+/// %collapse_shape = memref.collapse_shape %M (...) into memref<300xf32>
+/// %1 = arith.muli %idxs, %c3 : vector<4xindex>
+/// %gather = vector.gather %collapse_shape (...) : memref<300xf32> (...)
+/// ```
+///
+/// ATM this is effectively limited to reading a 1D Vector from a 2D MemRef,
+/// but should be fairly straightforward to extend beyond that.
+struct RemoveStrideFromGatherSource : OpRewritePattern<vector::GatherOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::GatherOp op,
+                                PatternRewriter &rewriter) const override {
+    Value base = op.getBase();
+    if (!base.getDefiningOp())
+      return failure();
+
+    // TODO: Strided accesses might be coming from other ops as well
+    auto subview = dyn_cast<memref::SubViewOp>(base.getDefiningOp());
+    if (!subview)
+      return failure();
+
+    // TODO: Allows ranks > 2.
+    if (subview.getSource().getType().getRank() != 2)
+      return failure();
+
+    // Get strides
+    auto layout = subview.getResult().getType().getLayout();
+    auto stridedLayoutAttr = llvm::dyn_cast<StridedLayoutAttr>(layout);
+
+    // TODO: Allow the access to be strided in multiple dimensions.
+    if (stridedLayoutAttr.getStrides().size() != 1)
+      return failure();
+
+    int64_t srcTrailingDim = subview.getSource().getType().getShape().back();
+
+    // Assume that the stride matches the trailing dimension of the source
+    // memref.
+    // TODO: Relax this assumption.
+    if (stridedLayoutAttr.getStrides()[0] != srcTrailingDim)
+      return failure();
+
+    // 1. Collapse the input memref so that it's "flat".
+    SmallVector<ReassociationIndices> reassoc = {{0, 1}};
+    Value collapsed = rewriter.create<memref::CollapseShapeOp>(
+        op.getLoc(), subview.getSource(), reassoc);
+
+    // 2. Generate new gather indices that will model the
+    // strided access.
+    auto stride = rewriter.getIndexAttr(srcTrailingDim);
+    auto vType = op.getIndexVec().getType();
+    Value mulCst = rewriter.create<arith::ConstantOp>(
+        op.getLoc(), vType, DenseElementsAttr::get(vType, stride));
+
+    Value newIdxs =
+        rewriter.create<arith::MulIOp>(op.getLoc(), op.getIndexVec(), mulCst);
+
+    // 3. Create an updated gather op with the collapsed input memref and the
+    // updated indices.
+    Value newGather = rewriter.create<vector::GatherOp>(
+        op.getLoc(), op.getResult().getType(), collapsed, op.getIndices(),
+        newIdxs, op.getMask(), op.getPassThru());
+    rewriter.replaceOp(op, newGather);
+
+    return success();
+  }
+};
+
 /// Turns 1-d `vector.gather` into a scalarized sequence of `vector.loads` or
 /// `tensor.extract`s. To avoid out-of-bounds memory accesses, these
 /// loads/extracts are made conditional using `scf.if` ops.
@@ -168,6 +244,6 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
 void mlir::vector::populateVectorGatherLoweringPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
-  patterns.add<FlattenGather, Gather1DToConditionalLoads>(patterns.getContext(),
-                                                          benefit);
+  patterns.add<FlattenGather, RemoveStrideFromGatherSource,
+               Gather1DToConditionalLoads>(patterns.getContext(), benefit);
 }
diff --git a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
index 026bec8cd65d3..3de7f44e4fb3e 100644
--- a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
@@ -151,3 +151,57 @@ func.func @gather_tensor_1d_none_set(%base: tensor<?xf32>, %v: vector<2xindex>,
   %0 = vector.gather %base[%c0][%v], %mask, %pass_thru : tensor<?xf32>, vector<2xindex>, vector<2xi1>, vector<2xf32> into vector<2xf32>
   return %0 : vector<2xf32>
 }
+
+// Check that vector.gather of a strided memref is replaced with a
+// vector.gather with indices encoding the original strides. Note that with the
+// other patterns
+#map = affine_map<()[s0] -> (s0 * 4096)>
+#map1 = affine_map<()[s0] -> (s0 * -4096 + 518400, 4096)>
+func.func @strided_gather(%M_in : memref<100x3xf32>, %M_out: memref<518400xf32>, %idxs : vector<4xindex>, %x : index, %y : index) {
+  %c0 = arith.constant 0 : index
+  %x_1 = affine.apply #map()[%x]
+  // Strided MemRef
+  %subview = memref.subview %M_in[0, 0] [100, 1] [1, 1] : memref<100x3xf32> to memref<100xf32, strided<[3]>>
+  %cst_0 = arith.constant dense<true> : vector<4xi1>
+  %cst = arith.constant dense<0.000000e+00> : vector<4xf32>
+  // Gather of a strided MemRef
+  %7 = vector.gather %subview[%c0] [%idxs], %cst_0, %cst : memref<100xf32, strided<[3]>>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+  %subview_1 = memref.subview %M_out[%x_1] [%y] [1] : memref<518400xf32> to memref<?xf32, strided<[1], offset: ?>>
+  vector.store %7, %subview_1[%c0] : memref<?xf32, strided<[1], offset: ?>>, vector<4xf32>
+  return
+}
+// CHECK-LABEL: func.func @strided_gather(
+// CHECK-SAME: %[[M_in:.*]]: memref<100x3xf32>,
+// CHECK-SAME: %[[M_out:.*]]: memref<518400xf32>,
+// CHECK-SAME: %[[IDXS:.*]]: vector<4xindex>,
+// CHECK-SAME: %[[VAL_4:.*]]: index,
+// CHECK-SAME: %[[VAL_5:.*]]: index) {
+// CHECK: %[[CST_3:.*]] = arith.constant dense<3> : vector<4xindex>
+// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<4xi1>
+
+// CHECK: %[[COLLAPSED:.*]] = memref.collapse_shape %[[M_in]] {{\[\[}}0, 1]] : memref<100x3xf32> into memref<300xf32>
+// CHECK: %[[NEW_IDXS:.*]] = arith.muli %[[IDXS]], %[[CST_3]] : vector<4xindex>
+
+// CHECK: %[[MASK_0:.*]] = vector.extract %[[MASK]][0] : i1 from vector<4xi1>
+// CHECK: %[[IDX_0:.*]] = vector.extract %[[NEW_IDXS]][0] : index from vector<4xindex>
+// CHECK: scf.if %[[MASK_0]] -> (vector<4xf32>)
+// CHECK: %[[M_0:.*]] = vector.load %[[COLLAPSED]]{{\[}}%[[IDX_0]]] : memref<300xf32>, vector<1xf32>
+// CHECK: %[[V_0:.*]] = vector.extract %[[M_0]][0] : f32 from vector<1xf32>
+
+// CHECK: %[[MASK_1:.*]] = vector.extract %[[MASK]][1] : i1 from vector<4xi1>
+// CHECK: %[[IDX_1:.*]] = vector.extract %[[NEW_IDXS]][1] : index from vector<4xindex>
+// CHECK: scf.if %[[MASK_1]] -> (vector<4xf32>)
+// CHECK: %[[M_1:.*]] = vector.load %[[COLLAPSED]]{{\[}}%[[IDX_1]]] : memref<300xf32>, vector<1xf32>
+// CHECK: %[[V_1:.*]] = vector.extract %[[M_1]][0] : f32 from vector<1xf32>
+
+// CHECK: %[[MASK_2:.*]] = vector.extract %[[MASK]][2] : i1 from vector<4xi1>
+// CHECK: %[[IDX_2:.*]] = vector.extract %[[NEW_IDXS]][2] : index from vector<4xindex>
+// CHECK: scf.if %[[MASK_2]] -> (vector<4xf32>)
+// CHECK: %[[M_2:.*]] = vector.load %[[COLLAPSED]][%[[IDX_2]]] : memref<300xf32>, vector<1xf32>
+// CHECK: %[[V_2:.*]] = vector.extract %[[M_2]][0] : f32 from vector<1xf32>
+
+// CHECK: %[[MASK_3:.*]] = vector.extract %[[MASK]][3] : i1 from vector<4xi1>
+// CHECK: %[[IDX_3:.*]] = vector.extract %[[NEW_IDXS]][3] : index from vector<4xindex>
+// CHECK: scf.if %[[MASK_3]] -> (vector<4xf32>)
+// CHECK: %[[M_3:.*]] = vector.load %[[COLLAPSED]]{{\[}}%[[IDX_3]]] : memref<300xf32>, vector<1xf32>
+// CHECK: %[[V_3:.*]] = vector.extract %[[M_3]][0] : f32 from vector<1xf32>

From bb71dbcfe1d6b0e0016712cdeb7c9941b23405d8 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski
Date: Tue, 28 Nov 2023 11:42:37 +0000
Subject: [PATCH 2/6] fixup! [mlir][Vector] Add a rewrite pattern for gather
 over a strided memref

Refine based on PR feedback
---
 .../Vector/Transforms/LowerVectorGather.cpp   | 20 ++++++-----
 .../Vector/vector-gather-lowering.mlir        | 36 ++++++++++---------
 2 files changed, 32 insertions(+), 24 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
index 54b350d7ac352..74487db5cdfc2 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
@@ -100,12 +100,12 @@ struct FlattenGather : OpRewritePattern<vector::GatherOp> {
 /// MemRef with updated indices that model the strided access.
 ///
 /// ```mlir
-/// %subview = memref.subview %M (...) to memref<100xf32, strided<[3]>>
+/// %subview = memref.subview %M (...) memref<100x3xf32> to memref<100xf32, strided<[3]>>
 /// %gather = vector.gather %subview (...) : memref<100xf32, strided<[3]>>
 /// ```
 /// ==>
 /// ```mlir
-/// %collapse_shape = memref.collapse_shape %M (...) into memref<300xf32>
+/// %collapse_shape = memref.collapse_shape %M (...) memref<100x3xf32> into memref<300xf32>
 /// %1 = arith.muli %idxs, %c3 : vector<4xindex>
 /// %gather = vector.gather %collapse_shape (...) : memref<300xf32> (...)
 /// ```
@@ -122,23 +122,27 @@ struct RemoveStrideFromGatherSource : OpRewritePattern<vector::GatherOp> {
       return failure();
 
     // TODO: Strided accesses might be coming from other ops as well
-    auto subview = dyn_cast<memref::SubViewOp>(base.getDefiningOp());
+    auto subview = base.getDefiningOp<memref::SubViewOp>();
     if (!subview)
       return failure();
 
-    // TODO: Allows ranks > 2.
-    if (subview.getSource().getType().getRank() != 2)
+    auto sourceType = subview.getSource().getType();
+
+    // TODO: Allow ranks > 2.
+    if (sourceType.getRank() != 2)
       return failure();
 
     // Get strides
     auto layout = subview.getResult().getType().getLayout();
     auto stridedLayoutAttr = llvm::dyn_cast<StridedLayoutAttr>(layout);
+    if (!stridedLayoutAttr)
+      return failure();
 
     // TODO: Allow the access to be strided in multiple dimensions.
     if (stridedLayoutAttr.getStrides().size() != 1)
       return failure();
 
-    int64_t srcTrailingDim = subview.getSource().getType().getShape().back();
+    int64_t srcTrailingDim = sourceType.getShape().back();
 
     // Assume that the stride matches the trailing dimension of the source
     // memref.
@@ -153,8 +157,8 @@ struct RemoveStrideFromGatherSource : OpRewritePattern<vector::GatherOp> {
 
     // 2. Generate new gather indices that will model the
     // strided access.
-    auto stride = rewriter.getIndexAttr(srcTrailingDim);
-    auto vType = op.getIndexVec().getType();
+    IntegerAttr stride = rewriter.getIndexAttr(srcTrailingDim);
+    VectorType vType = op.getIndexVec().getType();
     Value mulCst = rewriter.create<arith::ConstantOp>(
         op.getLoc(), vType, DenseElementsAttr::get(vType, stride));
 
diff --git a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
index 3de7f44e4fb3e..a7291c359d351 100644
--- a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
@@ -153,45 +153,49 @@ func.func @gather_tensor_1d_none_set(%base: tensor<?xf32>, %v: vector<2xindex>,
 }
 
 // Check that vector.gather of a strided memref is replaced with a
-// vector.gather with indices encoding the original strides. Note that with the
-// other patterns
+// vector.gather with indices encoding the original strides. Note that multiple
+// patterns are run for this example, e.g.:
+ // 1. "remove stride from gather source"
+ // 2. "flatten gather"
+// However, the main goal is to test Pattern 1 above.
 #map = affine_map<()[s0] -> (s0 * 4096)>
 #map1 = affine_map<()[s0] -> (s0 * -4096 + 518400, 4096)>
-func.func @strided_gather(%M_in : memref<100x3xf32>, %M_out: memref<518400xf32>, %idxs : vector<4xindex>, %x : index, %y : index) {
+func.func @strided_gather(%base : memref<100x3xf32>,
+                          %M_out: memref<518400xf32>,
+                          %idxs : vector<4xindex>,
+                          %x : index, %y : index) -> vector<4xf32> {
   %c0 = arith.constant 0 : index
   %x_1 = affine.apply #map()[%x]
   // Strided MemRef
-  %subview = memref.subview %M_in[0, 0] [100, 1] [1, 1] : memref<100x3xf32> to memref<100xf32, strided<[3]>>
-  %cst_0 = arith.constant dense<true> : vector<4xi1>
-  %cst = arith.constant dense<0.000000e+00> : vector<4xf32>
+  %subview = memref.subview %base[0, 0] [100, 1] [1, 1] : memref<100x3xf32> to memref<100xf32, strided<[3]>>
+  %mask = arith.constant dense<true> : vector<4xi1>
+  %pass_thru = arith.constant dense<0.000000e+00> : vector<4xf32>
   // Gather of a strided MemRef
-  %7 = vector.gather %subview[%c0] [%idxs], %cst_0, %cst : memref<100xf32, strided<[3]>>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32>
-  %subview_1 = memref.subview %M_out[%x_1] [%y] [1] : memref<518400xf32> to memref<?xf32, strided<[1], offset: ?>>
-  vector.store %7, %subview_1[%c0] : memref<?xf32, strided<[1], offset: ?>>, vector<4xf32>
-  return
+  %res = vector.gather %subview[%c0] [%idxs], %mask, %pass_thru : memref<100xf32, strided<[3]>>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+  return %res : vector<4xf32>
 }
 // CHECK-LABEL: func.func @strided_gather(
-// CHECK-SAME: %[[M_in:.*]]: memref<100x3xf32>,
+// CHECK-SAME: %[[base:.*]]: memref<100x3xf32>,
 // CHECK-SAME: %[[M_out:.*]]: memref<518400xf32>,
 // CHECK-SAME: %[[IDXS:.*]]: vector<4xindex>,
 // CHECK-SAME: %[[VAL_4:.*]]: index,
-// CHECK-SAME: %[[VAL_5:.*]]: index) {
+// CHECK-SAME: %[[VAL_5:.*]]: index) -> vector<4xf32> {
 // CHECK: %[[CST_3:.*]] = arith.constant dense<3> : vector<4xindex>
 // CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<4xi1>
 
-// CHECK: %[[COLLAPSED:.*]] = memref.collapse_shape %[[M_in]] {{\[\[}}0, 1]] : memref<100x3xf32> into memref<300xf32>
+// CHECK: %[[COLLAPSED:.*]] = memref.collapse_shape %[[base]] {{\[\[}}0, 1]] : memref<100x3xf32> into memref<300xf32>
 // CHECK: %[[NEW_IDXS:.*]] = arith.muli %[[IDXS]], %[[CST_3]] : vector<4xindex>
 
 // CHECK: %[[MASK_0:.*]] = vector.extract %[[MASK]][0] : i1 from vector<4xi1>
 // CHECK: %[[IDX_0:.*]] = vector.extract %[[NEW_IDXS]][0] : index from vector<4xindex>
 // CHECK: scf.if %[[MASK_0]] -> (vector<4xf32>)
-// CHECK: %[[M_0:.*]] = vector.load %[[COLLAPSED]]{{\[}}%[[IDX_0]]] : memref<300xf32>, vector<1xf32>
+// CHECK: %[[M_0:.*]] = vector.load %[[COLLAPSED]][%[[IDX_0]]] : memref<300xf32>, vector<1xf32>
 // CHECK: %[[V_0:.*]] = vector.extract %[[M_0]][0] : f32 from vector<1xf32>
 
 // CHECK: %[[MASK_1:.*]] = vector.extract %[[MASK]][1] : i1 from vector<4xi1>
 // CHECK: %[[IDX_1:.*]] = vector.extract %[[NEW_IDXS]][1] : index from vector<4xindex>
 // CHECK: scf.if %[[MASK_1]] -> (vector<4xf32>)
-// CHECK: %[[M_1:.*]] = vector.load %[[COLLAPSED]]{{\[}}%[[IDX_1]]] : memref<300xf32>, vector<1xf32>
+// CHECK: %[[M_1:.*]] = vector.load %[[COLLAPSED]][%[[IDX_1]]] : memref<300xf32>, vector<1xf32>
 // CHECK: %[[V_1:.*]] = vector.extract %[[M_1]][0] : f32 from vector<1xf32>
 
 // CHECK: %[[MASK_2:.*]] = vector.extract %[[MASK]][2] : i1 from vector<4xi1>
@@ -203,5 +207,5 @@ func.func @strided_gather(%M_in : memref<100x3xf32>, %M_out: memref<518400xf32>,
 // CHECK: %[[MASK_3:.*]] = vector.extract %[[MASK]][3] : i1 from vector<4xi1>
 // CHECK: %[[IDX_3:.*]] = vector.extract %[[NEW_IDXS]][3] : index from vector<4xindex>
 // CHECK: scf.if %[[MASK_3]] -> (vector<4xf32>)
-// CHECK: %[[M_3:.*]] = vector.load %[[COLLAPSED]]{{\[}}%[[IDX_3]]] : memref<300xf32>, vector<1xf32>
+// CHECK: %[[M_3:.*]] = vector.load %[[COLLAPSED]][%[[IDX_3]]] : memref<300xf32>, vector<1xf32>
 // CHECK: %[[V_3:.*]] = vector.extract %[[M_3]][0] : f32 from vector<1xf32>

From 2fad34c305999e2fa333bba620795db3ff485cd0 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski
Date: Tue, 28 Nov 2023 11:54:22 +0000
Subject: [PATCH 3/6] fixup! [mlir][Vector] Add a rewrite pattern for gather
 over a strided memref

Fix formatting
---
 .../Dialect/Vector/Transforms/LowerVectorGather.cpp | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
index 74487db5cdfc2..3bbbf8167f52b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
@@ -100,14 +100,15 @@ struct FlattenGather : OpRewritePattern<vector::GatherOp> {
 /// MemRef with updated indices that model the strided access.
 ///
 /// ```mlir
-/// %subview = memref.subview %M (...) memref<100x3xf32> to memref<100xf32, strided<[3]>>
-/// %gather = vector.gather %subview (...) : memref<100xf32, strided<[3]>>
+/// %subview = memref.subview %M (...) memref<100x3xf32> to memref<100xf32,
+/// strided<[3]>> %gather = vector.gather %subview (...) : memref<100xf32,
+/// strided<[3]>>
 /// ```
 /// ==>
 /// ```mlir
-/// %collapse_shape = memref.collapse_shape %M (...) memref<100x3xf32> into memref<300xf32>
-/// %1 = arith.muli %idxs, %c3 : vector<4xindex>
-/// %gather = vector.gather %collapse_shape (...) : memref<300xf32> (...)
+/// %collapse_shape = memref.collapse_shape %M (...) memref<100x3xf32> into
+/// memref<300xf32> %1 = arith.muli %idxs, %c3 : vector<4xindex> %gather =
+/// vector.gather %collapse_shape (...) : memref<300xf32> (...)
 /// ```
 ///
 /// ATM this is effectively limited to reading a 1D Vector from a 2D MemRef,

From 7d9f17c6b1a0cf77c06f89b9d4ae60bfee08ad4d Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski
Date: Tue, 28 Nov 2023 13:07:02 +0000
Subject: [PATCH 4/6] fixup! [mlir][Vector] Add a rewrite pattern for gather
 over a strided memref

Restrict Gather1DToConditionalLoads
---
 .../Dialect/Vector/Transforms/LowerVectorGather.cpp | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
index 3bbbf8167f52b..8372ffa157162 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
@@ -196,6 +196,17 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
 
     Value condMask = op.getMask();
     Value base = op.getBase();
+
+    // vector.load requires the most minor memref dim to have unit stride
+    if (auto memType = dyn_cast<MemRefType>(base.getType())) {
+      memType.getLayout();
+      if (auto stridesAttr =
+              dyn_cast_if_present<StridedLayoutAttr>(memType.getLayout())) {
+        if (stridesAttr.getStrides().back() != 1)
+          return failure();
+      }
+    }
+
     Value indexVec = rewriter.createOrFold<arith::IndexCastOp>(
         loc, op.getIndexVectorType().clone(rewriter.getIndexType()),
         op.getIndexVec());

From 5646bce72691caae0226ecb5b052015fd8c9b02e Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski
Date: Thu, 30 Nov 2023 09:14:07 +0000
Subject: [PATCH 5/6] fixup! [mlir][Vector] Add a rewrite pattern for gather
 over a strided memref

Update comments
---
 .../Vector/Transforms/LowerVectorGather.cpp    | 14 ++++++++------
 .../Dialect/Vector/vector-gather-lowering.mlir |  2 --
 2 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
index 8372ffa157162..e56bf6dc98084 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
@@ -100,15 +100,17 @@ struct FlattenGather : OpRewritePattern<vector::GatherOp> {
 /// MemRef with updated indices that model the strided access.
 ///
 /// ```mlir
-/// %subview = memref.subview %M (...) memref<100x3xf32> to memref<100xf32,
-/// strided<[3]>> %gather = vector.gather %subview (...) : memref<100xf32,
-/// strided<[3]>>
+/// %subview = memref.subview %M (...)
+///   : memref<100x3xf32> to memref<100xf32, strided<[3]>>
+/// %gather = vector.gather %subview[%idxs] (...) : memref<100xf32, strided<[3]>>
 /// ```
 /// ==>
 /// ```mlir
-/// %collapse_shape = memref.collapse_shape %M (...) memref<100x3xf32> into
-/// memref<300xf32> %1 = arith.muli %idxs, %c3 : vector<4xindex> %gather =
-/// vector.gather %collapse_shape (...) : memref<300xf32> (...)
+/// %collapse_shape = memref.collapse_shape %M (...)
+///   : memref<100x3xf32> into memref<300xf32>
+/// %new_idxs = arith.muli %idxs, %c3 : vector<4xindex>
+/// %gather = vector.gather %collapse_shape[%new_idxs] (...)
+///   : memref<300xf32> (...)
 /// ```
 ///
 /// ATM this is effectively limited to reading a 1D Vector from a 2D MemRef,
diff --git a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
index a7291c359d351..78f42baca7e31 100644
--- a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
@@ -161,7 +161,6 @@ func.func @gather_tensor_1d_none_set(%base: tensor<?xf32>, %v: vector<2xindex>,
 #map = affine_map<()[s0] -> (s0 * 4096)>
 #map1 = affine_map<()[s0] -> (s0 * -4096 + 518400, 4096)>
 func.func @strided_gather(%base : memref<100x3xf32>,
-                          %M_out: memref<518400xf32>,
                           %idxs : vector<4xindex>,
                           %x : index, %y : index) -> vector<4xf32> {
   %c0 = arith.constant 0 : index
   %x_1 = affine.apply #map()[%x]
   // Strided MemRef
@@ -176,7 +175,6 @@ func.func @strided_gather(%base : memref<100x3xf32>,
 }
 // CHECK-LABEL: func.func @strided_gather(
 // CHECK-SAME: %[[base:.*]]: memref<100x3xf32>,
-// CHECK-SAME: %[[M_out:.*]]: memref<518400xf32>,
 // CHECK-SAME: %[[IDXS:.*]]: vector<4xindex>,
 // CHECK-SAME: %[[VAL_4:.*]]: index,
 // CHECK-SAME: %[[VAL_5:.*]]: index) -> vector<4xf32> {

From 258cd7d1aea9a179ad1810e4cc1a5f06bd7ff729 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski
Date: Thu, 30 Nov 2023 15:04:08 +0000
Subject: [PATCH 6/6] fixup! [mlir][Vector] Add a rewrite pattern for gather
 over a strided memref

Remove unnecessary code
---
 mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp | 3 ---
 mlir/test/Dialect/Vector/vector-gather-lowering.mlir     | 1 -
 2 files changed, 4 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
index e56bf6dc98084..90128126d0fa1 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
@@ -121,8 +121,6 @@ struct RemoveStrideFromGatherSource : OpRewritePattern<vector::GatherOp> {
   LogicalResult matchAndRewrite(vector::GatherOp op,
                                 PatternRewriter &rewriter) const override {
     Value base = op.getBase();
-    if (!base.getDefiningOp())
-      return failure();
 
     // TODO: Strided accesses might be coming from other ops as well
     auto subview = base.getDefiningOp<memref::SubViewOp>();
@@ -201,7 +199,6 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
     // vector.load requires the most minor memref dim to have unit stride
     if (auto memType = dyn_cast<MemRefType>(base.getType())) {
-      memType.getLayout();
       if (auto stridesAttr =
               dyn_cast_if_present<StridedLayoutAttr>(memType.getLayout())) {
         if (stridesAttr.getStrides().back() != 1)
           return failure();
diff --git a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
index 78f42baca7e31..d047ac629d87e 100644
--- a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
@@ -159,7 +159,6 @@ func.func @gather_tensor_1d_none_set(%base: tensor<?xf32>, %v: vector<2xindex>,
 // 2. "flatten gather"
 // However, the main goal is to test Pattern 1 above.
 #map = affine_map<()[s0] -> (s0 * 4096)>
-#map1 = affine_map<()[s0] -> (s0 * -4096 + 518400, 4096)>
 func.func @strided_gather(%base : memref<100x3xf32>,
                           %idxs : vector<4xindex>,
                           %x : index, %y : index) -> vector<4xf32> {
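
For readers who want to exercise these patterns outside the in-tree test, below is a minimal sketch of how the `populateVectorGatherLoweringPatterns` entry point touched by this series might be driven. The helper name `lowerGathers`, the exact header paths, and the use of the greedy rewrite driver are assumptions for illustration; the patches themselves only define the patterns and the populate hook.

```cpp
// Minimal sketch, not part of the patch series: drive the gather lowering
// patterns with MLIR's greedy rewrite driver. Header paths are assumed to
// match the tree around the time of these patches.
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Hypothetical helper: collects FlattenGather, RemoveStrideFromGatherSource
// and Gather1DToConditionalLoads via the populate hook shown above, then
// applies them greedily to `root` (e.g. a func.func).
static LogicalResult lowerGathers(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  vector::populateVectorGatherLoweringPatterns(patterns);
  return applyPatternsAndFoldGreedily(root, std::move(patterns));
}
```

Driving all three patterns from one greedy application matches the test comment above: after RemoveStrideFromGatherSource rewrites the strided gather, the resulting flat gather is still a candidate for FlattenGather and Gather1DToConditionalLoads in the same run.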