diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp index 8363e73857e5c..12aa11e9e33f5 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -1236,7 +1236,7 @@ class DropInnerMostUnitDimsTransferRead return failure(); auto srcType = dyn_cast(readOp.getSource().getType()); - if (!srcType) + if (!srcType || !srcType.hasStaticShape()) return failure(); if (!readOp.getPermutationMap().isMinorIdentity()) @@ -1260,21 +1260,19 @@ class DropInnerMostUnitDimsTransferRead targetType.getElementType()); auto loc = readOp.getLoc(); - SmallVector sizes = - memref::getMixedSizes(rewriter, loc, readOp.getSource()); - SmallVector offsets(srcType.getRank(), - rewriter.getIndexAttr(0)); - SmallVector strides(srcType.getRank(), - rewriter.getIndexAttr(1)); MemRefType resultMemrefType = getMemRefTypeWithDroppingInnerDims(rewriter, srcType, dimsToDrop); + SmallVector offsets(srcType.getRank(), 0); + SmallVector strides(srcType.getRank(), 1); + ArrayAttr inBoundsAttr = readOp.getInBounds() ? rewriter.getArrayAttr( readOp.getInBoundsAttr().getValue().drop_back(dimsToDrop)) : ArrayAttr(); Value rankedReducedView = rewriter.create( - loc, resultMemrefType, readOp.getSource(), offsets, sizes, strides); + loc, resultMemrefType, readOp.getSource(), offsets, srcType.getShape(), + strides); auto permMap = getTransferMinorIdentityMap( cast(rankedReducedView.getType()), resultTargetVecType); Value result = rewriter.create( @@ -1320,7 +1318,7 @@ class DropInnerMostUnitDimsTransferWrite return failure(); auto srcType = dyn_cast(writeOp.getSource().getType()); - if (!srcType) + if (!srcType || !srcType.hasStaticShape()) return failure(); if (!writeOp.getPermutationMap().isMinorIdentity()) @@ -1343,23 +1341,20 @@ class DropInnerMostUnitDimsTransferWrite VectorType::get(targetType.getShape().drop_back(dimsToDrop), targetType.getElementType()); - Location loc = writeOp.getLoc(); - SmallVector sizes = - memref::getMixedSizes(rewriter, loc, writeOp.getSource()); - SmallVector offsets(srcType.getRank(), - rewriter.getIndexAttr(0)); - SmallVector strides(srcType.getRank(), - rewriter.getIndexAttr(1)); MemRefType resultMemrefType = getMemRefTypeWithDroppingInnerDims(rewriter, srcType, dimsToDrop); + SmallVector offsets(srcType.getRank(), 0); + SmallVector strides(srcType.getRank(), 1); ArrayAttr inBoundsAttr = writeOp.getInBounds() ? rewriter.getArrayAttr( writeOp.getInBoundsAttr().getValue().drop_back(dimsToDrop)) : ArrayAttr(); + Location loc = writeOp.getLoc(); Value rankedReducedView = rewriter.create( - loc, resultMemrefType, writeOp.getSource(), offsets, sizes, strides); + loc, resultMemrefType, writeOp.getSource(), offsets, srcType.getShape(), + strides); auto permMap = getTransferMinorIdentityMap( cast(rankedReducedView.getType()), resultTargetVecType); diff --git a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir index 3984f17f9e8cd..d6d69c8af8850 100644 --- a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir @@ -16,25 +16,6 @@ func.func @contiguous_inner_most_view(%in: memref<1x1x8x1xf32, strided<[3072, 8, // ----- -func.func @contiguous_outer_dyn_inner_most_view(%in: memref>) -> vector<1x8x1xf32>{ - %c0 = arith.constant 0 : index - %cst = arith.constant 0.0 : f32 - %0 = vector.transfer_read %in[%c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref>, vector<1x8x1xf32> - return %0 : vector<1x8x1xf32> -} -// CHECK: func @contiguous_outer_dyn_inner_most_view( -// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]] -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[D0:.+]] = memref.dim %[[SRC]], %[[C0]] -// CHECK: %[[SRC_0:.+]] = memref.subview %[[SRC]][0, 0, 0, 0] [%[[D0]], 1, 8, 1] [1, 1, 1, 1] -// CHECK-SAME: memref> to memref> -// CHECK: %[[VEC:.+]] = vector.transfer_read %[[SRC_0]] -// CHECK-SAME: memref>, vector<1x8xf32> -// CHECK: %[[RESULT:.+]] = vector.shape_cast %[[VEC]] -// CHECK: return %[[RESULT]] - -// ----- - func.func @contiguous_inner_most_dim(%A: memref<16x1xf32>, %i:index, %j:index) -> (vector<8x1xf32>) { %c0 = arith.constant 0 : index %f0 = arith.constant 0.0 : f32 @@ -138,27 +119,6 @@ func.func @drop_inner_most_dim_for_transfer_write(%arg0: memref<1x512x16x1xf32, // ----- -func.func @outer_dyn_drop_inner_most_dim_for_transfer_write(%arg0: memref>, %arg1: vector<1x16x16x1xf32>, %arg2: index) { - %c0 = arith.constant 0 : index - vector.transfer_write %arg1, %arg0[%arg2, %c0, %c0, %c0] - {in_bounds = [true, true, true, true]} - : vector<1x16x16x1xf32>, memref> - return -} -// CHECK: func.func @outer_dyn_drop_inner_most_dim_for_transfer_write -// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]] -// CHECK-SAME: %[[VEC:[a-zA-Z0-9]+]] -// CHECK-SAME: %[[IDX:[a-zA-Z0-9]+]] -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[D0:.+]] = memref.dim %[[SRC]], %[[C0]] -// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[DEST]][0, 0, 0, 0] [%[[D0]], 512, 16, 1] -// CHECK-SAME: memref> to memref> -// CHECK: %[[CAST:.+]] = vector.shape_cast %[[VEC]] : vector<1x16x16x1xf32> to vector<1x16x16xf32> -// CHECK: vector.transfer_write %[[CAST]], %[[SUBVIEW]] -// CHECK-SAME: [%[[IDX]], %[[C0]], %[[C0]]] - -// ----- - func.func @non_unit_strides(%arg0: memref<512x16x1xf32, strided<[8192, 16, 4], offset: ?>>, %arg1: vector<16x16x1xf32>, %arg2: index) { %c0 = arith.constant 0 : index vector.transfer_write %arg1, %arg0[%arg2, %c0, %c0]